# api/retrieval.py
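"""Retrieval layer for the medical chatbot: FAISS-backed QA lookup, symptom-to-diagnosis
matching, and an optional NVIDIA NIM rerank + summarization pass for clinical practice
guideline (CPG) content."""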
import os
import re
import time
import requests
import numpy as np
import logging
from typing import List, Dict
from .database import db_manager
from models import summarizer
logger = logging.getLogger("retrieval-bot")
class RetrievalEngine:
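    """FAISS-backed QA retrieval, symptom-to-diagnosis matching, and optional CPG rerank/summarization."""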
def __init__(self):
self.db_manager = db_manager
        # Lazily initialize the reranker: _NvidiaReranker is defined later in this
        # module, so constructing it here would raise a NameError at import time
self._reranker = None
def _get_reranker(self):
"""Initialize the NVIDIA reranker on first use."""
if self._reranker is None:
self._reranker = _NvidiaReranker()
return self._reranker
@staticmethod
def _is_cpg_text(text: str) -> bool:
"""Heuristic to detect Clinical Practice Guideline (CPG) content."""
if not text:
return False
        # Common CPG phrasing cues (case-insensitive)
        general_patterns = [
            r"\bguideline(s)?\b", r"\bclinical practice\b", r"\brecommend(ation|ed|s)?\b",
            r"\bshould\b", r"\bmust\b", r"\bstrongly (recommend|suggest)\b",
            r"\bclass (I|IIa|IIb|III)\b", r"\blevel (A|B|C)\b",
        ]
        # Issuing-body acronyms, matched case-sensitively so ordinary words such as
        # "who", "nice", or "aha" are not mistaken for WHO, NICE, or AHA
        acronym_patterns = [r"\bNICE\b", r"\bAHA\b", r"\bACC\b", r"\bWHO\b", r"\bUSPSTF\b", r"\bIDSA\b"]
        if any(re.search(p, text, flags=re.IGNORECASE) for p in general_patterns):
            return True
        return any(re.search(p, text) for p in acronym_patterns)
@staticmethod
def _extract_guideline_sentences(text: str) -> str:
"""Extract likely guideline sentences to reduce conversational/noisy content before summarization."""
if not text:
return ""
sentences = re.split(r"(?<=[.!?])\s+", text)
keep_patterns = [
r"\b(recommend|should|must|indicated|contraindicated|preferred|first-line|consider)\b",
r"\b(class\s*(I|IIa|IIb|III)|level\s*(A|B|C))\b",
r"\b(dose|mg|route|frequency)\b",
r"\b(screen|treat|manage|evaluate|monitor)\b"
]
kept = []
for s in sentences:
s_norm = s.strip()
if not s_norm:
continue
if any(re.search(p, s_norm, flags=re.IGNORECASE) for p in keep_patterns):
kept.append(s_norm)
        # Fallback: if the filtering was too aggressive, keep a truncated slice of the original text
if not kept:
return text[:1200]
return " ".join(kept)[:2000]
def retrieve_medical_info(self, query: str, k: int = 5, min_sim: float = 0.8) -> list:
"""
Retrieve medical information from FAISS index
Min similarity between query and kb is to be 80%
"""
index = self.db_manager.load_faiss_index()
if index is None:
return [""]
embedding_model = self.db_manager.get_embedding_model()
qa_collection = self.db_manager.get_qa_collection()
# Embed query
query_vec = embedding_model.encode([query], convert_to_numpy=True)
D, I = index.search(query_vec, k=k)
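        # D[0] holds the k best similarity scores for the single query; I[0] holds the
        # matching vector ids, which correspond to the "i" field of the QA documents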
# Filter by cosine threshold
results = []
kept = []
kept_vecs = []
        # Deduplicate candidates whose answers are semantically near-identical (cosine >= 0.9)
for score, idx in zip(D[0], I[0]):
if score < min_sim:
continue
            # Look up the QA document stored under this FAISS row index
doc = qa_collection.find_one({"i": int(idx)})
if not doc:
continue
# Only compare answers
answer = doc.get("Doctor", "").strip()
if not answer:
continue
# Check semantic redundancy among previously kept results
new_vec = embedding_model.encode([answer], convert_to_numpy=True)[0]
is_similar = False
for i, vec in enumerate(kept_vecs):
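                # Cosine similarity with a small epsilon guarding against zero-norm vectors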
sim = np.dot(vec, new_vec) / (np.linalg.norm(vec) * np.linalg.norm(new_vec) + 1e-9)
if sim >= 0.9: # High semantic similarity
is_similar = True
                    # Keep whichever answer is more similar to the original query
cur_sim_to_query = np.dot(vec, query_vec[0]) / (np.linalg.norm(vec) * np.linalg.norm(query_vec[0]) + 1e-9)
new_sim_to_query = np.dot(new_vec, query_vec[0]) / (np.linalg.norm(new_vec) * np.linalg.norm(query_vec[0]) + 1e-9)
if new_sim_to_query > cur_sim_to_query:
kept[i] = answer
kept_vecs[i] = new_vec
break
# Non-similar candidates
if not is_similar:
kept.append(answer)
kept_vecs.append(new_vec)
# If any CPG-like content is present, rerank with NVIDIA NIM reranker and summarize to key guidelines
try:
cpg_candidates = [t for t in kept if self._is_cpg_text(t)]
if cpg_candidates:
logger.info("[Retrieval] CPG content detected; invoking NVIDIA reranker")
reranked = self._get_reranker().rerank(query, cpg_candidates)
# Keep only valid high-scoring items
filtered: List[Dict] = [r for r in reranked if r.get("score", 0) >= 0.3 and r.get("text")]
# Limit to top 3 for prompt efficiency
top_items = filtered[:3]
if top_items:
summarized: List[str] = []
for item in top_items:
guideline_text = self._extract_guideline_sentences(item["text"])
# Summarize to key clinical guidelines only (no conversational content)
concise = summarizer.summarize_text(guideline_text, max_length=300)
if concise:
summarized.append(concise)
# If summarization produced results, replace kept with these
if summarized:
kept = summarized
except Exception as e:
logger.warning(f"[Retrieval] CPG rerank/summarize step skipped due to error: {e}")
return kept if kept else [""]
def retrieve_diagnosis_from_symptoms(self, symptom_text: str, top_k: int = 5, min_sim: float = 0.5) -> list:
"""
Retrieve diagnosis information from symptom vectors
"""
self.db_manager.load_symptom_vectors()
embedding_model = self.db_manager.get_embedding_model()
# Embed input
qvec = embedding_model.encode(symptom_text, convert_to_numpy=True)
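        # L2-normalize the query so the dot product below equals cosine similarity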
qvec = qvec / (np.linalg.norm(qvec) + 1e-9)
        # Cosine similarity (assumes the stored symptom vectors are already L2-normalized)
        sims = self.db_manager.symptom_vectors @ qvec
        # Indices of the top_k most similar rows, in descending order of similarity
        sorted_idx = np.argsort(sims)[-top_k:][::-1]
seen_diag = set()
        final = []  # Deduplicated answers, one per distinct prognosis label
for i in sorted_idx:
sim = sims[i]
if sim < min_sim:
continue
label = self.db_manager.symptom_docs[i]["prognosis"]
if label not in seen_diag:
final.append(self.db_manager.symptom_docs[i]["answer"])
seen_diag.add(label)
return final
# Global retrieval engine instance
retrieval_engine = RetrievalEngine()
class _NvidiaReranker:
"""Simple client for NVIDIA NIM reranking: nvidia/rerank-qa-mistral-4b"""
def __init__(self):
self.api_key = os.getenv("NVIDIA_URI")
        # Model identifier from the provider documentation
        self.model = os.getenv("NVIDIA_RERANK_MODEL", "nv-rerank-qa-mistral-4b:1")
        # The NIM rerank endpoint can differ between deployments, so keep it configurable
        self.base_url = os.getenv("NVIDIA_RERANK_ENDPOINT", "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking")
self.timeout_s = 30
def rerank(self, query: str, documents: List[str]) -> List[Dict]:
if not self.api_key:
raise ValueError("NVIDIA_URI not set for reranker")
if not documents:
return []
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
"Accept": "application/json",
}
# Truncate and limit candidates to avoid 4xx
docs = documents[:10]
docs = [d[:2000] for d in docs if isinstance(d, str)]
        # Two payload shapes seen in provider docs; try each until one is accepted
payloads = [
{
"model": self.model,
"query": {"text": query},
"passages": [{"text": d} for d in docs],
},
{
"model": self.model,
"query": query,
"documents": [{"text": d} for d in docs],
},
]
try:
data = None
for p in payloads:
resp = requests.post(self.base_url, headers=headers, json=p, timeout=self.timeout_s)
if resp.status_code >= 400:
# try next shape
continue
data = resp.json()
break
if data is None:
                # Surface the last HTTP error for diagnostics
resp.raise_for_status()
# Expecting a list with scores and indices or texts
results = []
entries = data.get("results") or data.get("data") or []
if isinstance(entries, list) and entries:
for entry in entries:
                    # Common patterns: {index, score} or {text, score}
                    idx = entry.get("index")
                    text = entry.get("text") or (
                        documents[idx] if idx is not None and idx < len(documents) else None
                    )
score = entry.get("score", 0)
if text:
results.append({"text": text, "score": float(score)})
else:
# Fallback: if API returns scores aligned to input order
scores = data.get("scores")
if isinstance(scores, list) and len(scores) == len(documents):
for t, s in zip(documents, scores):
results.append({"text": t, "score": float(s)})
# Sort by score desc
results.sort(key=lambda x: x.get("score", 0), reverse=True)
return results
except Exception as e:
logger.warning(f"[Reranker] Failed calling NVIDIA reranker: {e}")
# On failure, return original order with neutral scores
return [{"text": d, "score": 0.0} for d in documents]
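# --- Usage sketch (illustrative) ---
# Assumes db_manager can reach MongoDB, the FAISS index, and the symptom vectors,
# and that NVIDIA_URI is set if the CPG rerank step should run; the example queries
# below are hypothetical.
if __name__ == "__main__":
    answers = retrieval_engine.retrieve_medical_info("How should type 2 diabetes be managed?", k=5)
    diagnoses = retrieval_engine.retrieve_diagnosis_from_symptoms("fever, dry cough, shortness of breath")
    print("QA answers:", answers)
    print("Possible diagnoses:", diagnoses)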