# api/retrieval.py
import os
import re
import time
import requests
import numpy as np
import logging
from typing import List, Dict

from .database import db_manager
from models import summarizer

logger = logging.getLogger("retrieval-bot")


class RetrievalEngine:
    def __init__(self):
        self.db_manager = db_manager
        # Lazy-init reranker to avoid NameError during module import ordering
        self._reranker = None

    def _get_reranker(self):
        """Initialize the NVIDIA reranker on first use."""
        if self._reranker is None:
            self._reranker = _NvidiaReranker()
        return self._reranker

    @staticmethod
    def _is_cpg_text(text: str) -> bool:
        """Heuristic to detect Clinical Practice Guideline (CPG) content."""
        if not text:
            return False
        keywords = [
            # common CPG indicators
            r"\bguideline(s)?\b",
            r"\bclinical practice\b",
            r"\brecommend(ation|ed|s)?\b",
            r"\bshould\b",
            r"\bmust\b",
            r"\bstrongly (recommend|suggest)\b",
            r"\bNICE\b",
            r"\bAHA\b",
            r"\bACC\b",
            r"\bWHO\b",
            r"\bUSPSTF\b",
            r"\bIDSA\b",
            r"\bclass (I|IIa|IIb|III)\b",
            r"\blevel (A|B|C)\b",
        ]
        text_lc = text.lower()
        return any(re.search(p, text_lc, flags=re.IGNORECASE) for p in keywords)

    @staticmethod
    def _extract_guideline_sentences(text: str) -> str:
        """Extract likely guideline sentences to reduce conversational/noisy content before summarization."""
        if not text:
            return ""
        sentences = re.split(r"(?<=[.!?])\s+", text)
        keep_patterns = [
            r"\b(recommend|should|must|indicated|contraindicated|preferred|first-line|consider)\b",
            r"\b(class\s*(I|IIa|IIb|III)|level\s*(A|B|C))\b",
            r"\b(dose|mg|route|frequency)\b",
            r"\b(screen|treat|manage|evaluate|monitor)\b",
        ]
        kept = []
        for s in sentences:
            s_norm = s.strip()
            if not s_norm:
                continue
            if any(re.search(p, s_norm, flags=re.IGNORECASE) for p in keep_patterns):
                kept.append(s_norm)
        # Fallback: if filtering too aggressive, keep truncated original
        if not kept:
            return text[:1200]
        return " ".join(kept)[:2000]
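
    # Flow of retrieve_medical_info (summary): embed the query, take the FAISS
    # top-k, drop hits below min_sim, deduplicate near-identical answers, then
    # optionally rerank and summarize CPG-like content.
    # Assumption (not verified in this module): the FAISS index is built from
    # normalized embeddings, so the scores returned by index.search behave like
    # cosine similarities and the 0.8 floor is meaningful.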
    def retrieve_medical_info(self, query: str, k: int = 5, min_sim: float = 0.8) -> list:
        """
        Retrieve medical information from the FAISS index.
        Candidates must reach a minimum similarity of 0.8 (80%) against the
        query to be kept.
        """
        index = self.db_manager.load_faiss_index()
        if index is None:
            return [""]
        embedding_model = self.db_manager.get_embedding_model()
        qa_collection = self.db_manager.get_qa_collection()

        # Embed query
        query_vec = embedding_model.encode([query], convert_to_numpy=True)
        D, I = index.search(query_vec, k=k)

        # Filter by cosine threshold, deduplicating semantically similar candidates
        kept = []
        kept_vecs = []
        for score, idx in zip(D[0], I[0]):
            if score < min_sim:
                continue
            # Look up the knowledge-base document behind this index hit
            doc = qa_collection.find_one({"i": int(idx)})
            if not doc:
                continue
            # Only the answer text is compared and returned
            answer = doc.get("Doctor", "").strip()
            if not answer:
                continue

            # Check semantic redundancy against previously kept results
            new_vec = embedding_model.encode([answer], convert_to_numpy=True)[0]
            is_similar = False
            for i, vec in enumerate(kept_vecs):
                sim = np.dot(vec, new_vec) / (np.linalg.norm(vec) * np.linalg.norm(new_vec) + 1e-9)
                if sim >= 0.9:  # High semantic similarity
                    is_similar = True
                    # Keep only the better match to the original query
                    cur_sim_to_query = np.dot(vec, query_vec[0]) / (
                        np.linalg.norm(vec) * np.linalg.norm(query_vec[0]) + 1e-9
                    )
                    new_sim_to_query = np.dot(new_vec, query_vec[0]) / (
                        np.linalg.norm(new_vec) * np.linalg.norm(query_vec[0]) + 1e-9
                    )
                    if new_sim_to_query > cur_sim_to_query:
                        kept[i] = answer
                        kept_vecs[i] = new_vec
                    break
            # Non-similar candidates are kept as new entries
            if not is_similar:
                kept.append(answer)
                kept_vecs.append(new_vec)

        # If any CPG-like content is present, rerank with the NVIDIA NIM reranker
        # and summarize down to the key guidelines
        try:
            cpg_candidates = [t for t in kept if self._is_cpg_text(t)]
            if cpg_candidates:
                logger.info("[Retrieval] CPG content detected; invoking NVIDIA reranker")
                reranked = self._get_reranker().rerank(query, cpg_candidates)
                # Keep only valid high-scoring items
                filtered: List[Dict] = [r for r in reranked if r.get("score", 0) >= 0.3 and r.get("text")]
                # Limit to top 3 for prompt efficiency
                top_items = filtered[:3]
                if top_items:
                    summarized: List[str] = []
                    for item in top_items:
                        guideline_text = self._extract_guideline_sentences(item["text"])
                        # Summarize to key clinical guidelines only (no conversational content)
                        concise = summarizer.summarize_text(guideline_text, max_length=300)
                        if concise:
                            summarized.append(concise)
                    # If summarization produced results, replace kept with these
                    if summarized:
                        kept = summarized
        except Exception as e:
            logger.warning(f"[Retrieval] CPG rerank/summarize step skipped due to error: {e}")

        return kept if kept else [""]

    def retrieve_diagnosis_from_symptoms(self, symptom_text: str, top_k: int = 5, min_sim: float = 0.5) -> list:
        """
        Retrieve diagnosis information from symptom vectors.
        """
        self.db_manager.load_symptom_vectors()
        embedding_model = self.db_manager.get_embedding_model()

        # Embed the input and normalize so the dot product below is cosine similarity
        qvec = embedding_model.encode(symptom_text, convert_to_numpy=True)
        qvec = qvec / (np.linalg.norm(qvec) + 1e-9)

        # Compute similarities and take the top-k indices
        sims = self.db_manager.symptom_vectors @ qvec  # cosine
        sorted_idx = np.argsort(sims)[-top_k:][::-1]

        # Deduplicate by diagnosis label
        seen_diag = set()
        final = []
        for i in sorted_idx:
            sim = sims[i]
            if sim < min_sim:
                continue
            label = self.db_manager.symptom_docs[i]["prognosis"]
            if label not in seen_diag:
                final.append(self.db_manager.symptom_docs[i]["answer"])
                seen_diag.add(label)
        return final


# Global retrieval engine instance
retrieval_engine = RetrievalEngine()
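
# Usage sketch (illustrative only; the calling code is not defined in this module):
#
#   answers = retrieval_engine.retrieve_medical_info("first-line treatment for hypertension")
#   diagnoses = retrieval_engine.retrieve_diagnosis_from_symptoms("fever, dry cough, fatigue")
#
# retrieve_medical_info returns [""] when the index is missing or nothing clears
# the similarity floor; retrieve_diagnosis_from_symptoms returns an empty list
# when no diagnosis clears min_sim.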

class _NvidiaReranker:
    """Simple client for NVIDIA NIM reranking: nvidia/rerank-qa-mistral-4b"""

    def __init__(self):
        self.api_key = os.getenv("NVIDIA_URI")
        # Use provider doc model identifier
        self.model = os.getenv("NVIDIA_RERANK_MODEL", "nv-rerank-qa-mistral-4b:1")
        # NIM rerank endpoint (subject to environment); keep configurable
        self.base_url = os.getenv(
            "NVIDIA_RERANK_ENDPOINT",
            "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking",
        )
        self.timeout_s = 30

    def rerank(self, query: str, documents: List[str]) -> List[Dict]:
        if not self.api_key:
            raise ValueError("NVIDIA_URI not set for reranker")
        if not documents:
            return []
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }

        # Truncate and limit candidates to avoid 4xx
        docs = documents[:10]
        docs = [d[:2000] for d in docs if isinstance(d, str)]

        # Two payload shapes based on provider doc
        payloads = [
            {
                "model": self.model,
                "query": {"text": query},
                "passages": [{"text": d} for d in docs],
            },
            {
                "model": self.model,
                "query": query,
                "documents": [{"text": d} for d in docs],
            },
        ]

        try:
            data = None
            for p in payloads:
                resp = requests.post(self.base_url, headers=headers, json=p, timeout=self.timeout_s)
                if resp.status_code >= 400:
                    # try next shape
                    continue
                data = resp.json()
                break
            if data is None:
                # last attempt for diagnostics
                resp.raise_for_status()

            # Expecting a list with scores and indices or texts
            results = []
            entries = data.get("results") or data.get("data") or []
            if isinstance(entries, list) and entries:
                for entry in entries:
                    # Common patterns: {index, score} or {text, score}
                    idx = entry.get("index")
                    text = entry.get("text") if entry.get("text") else (
                        documents[idx] if idx is not None and idx < len(documents) else None
                    )
                    score = entry.get("score", 0)
                    if text:
                        results.append({"text": text, "score": float(score)})
            else:
                # Fallback: if API returns scores aligned to input order
                scores = data.get("scores")
                if isinstance(scores, list) and len(scores) == len(documents):
                    for t, s in zip(documents, scores):
                        results.append({"text": t, "score": float(s)})

            # Sort by score desc
            results.sort(key=lambda x: x.get("score", 0), reverse=True)
            return results
        except Exception as e:
            logger.warning(f"[Reranker] Failed calling NVIDIA reranker: {e}")
            # On failure, return original order with neutral scores
            return [{"text": d, "score": 0.0} for d in documents]
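

# Minimal smoke test (a sketch, not part of the service): exercises only the
# CPG heuristics, which need no database or network access. Run via
# `python -m api.retrieval` so the relative imports resolve; the sample text
# is illustrative.
if __name__ == "__main__":
    sample = (
        "The guideline strongly recommends beta-blockers as first-line therapy. "
        "Class I, Level A. Thanks for writing in, hope this helps!"
    )
    # Expected: True, and only the two guideline-like sentences survive
    print(RetrievalEngine._is_cpg_text(sample))
    print(RetrievalEngine._extract_guideline_sentences(sample))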