Spaces:
Sleeping
Sleeping
Upd reranker
Browse files- api/retrieval.py +125 -2
api/retrieval.py
CHANGED
|
@@ -1,15 +1,61 @@
|
|
| 1 |
# api/retrieval.py
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import numpy as np
|
| 3 |
import logging
|
|
|
|
| 4 |
from .database import db_manager
|
|
|
|
| 5 |
|
| 6 |
-
logger = logging.getLogger("
|
| 7 |
|
| 8 |
class RetrievalEngine:
|
| 9 |
def __init__(self):
|
| 10 |
self.db_manager = db_manager
|
|
|
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
"""
|
| 14 |
Retrieve medical information from FAISS index
|
| 15 |
Min similarity between query and kb is to be 80%
|
|
@@ -66,6 +112,30 @@ class RetrievalEngine:
|
|
| 66 |
kept.append(answer)
|
| 67 |
kept_vecs.append(new_vec)
|
| 68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
return kept if kept else [""]
|
| 70 |
|
| 71 |
def retrieve_diagnosis_from_symptoms(self, symptom_text: str, top_k: int = 5, min_sim: float = 0.5) -> list:
|
|
@@ -98,3 +168,56 @@ class RetrievalEngine:
|
|
| 98 |
|
| 99 |
# Global retrieval engine instance
|
| 100 |
retrieval_engine = RetrievalEngine()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# api/retrieval.py
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
import time
|
| 5 |
+
import requests
|
| 6 |
import numpy as np
|
| 7 |
import logging
|
| 8 |
+
from typing import List, Dict
|
| 9 |
from .database import db_manager
|
| 10 |
+
from models import summarizer
|
| 11 |
|
| 12 |
+
logger = logging.getLogger("retrieval-bot")
|
| 13 |
|
| 14 |
class RetrievalEngine:
|
| 15 |
def __init__(self):
    """Set up the collaborators used by the retrieval methods.

    Attributes assigned here:
      * ``_reranker``  - a fresh ``_NvidiaReranker`` client.
      * ``db_manager`` - the module-level ``db_manager`` imported from
        ``.database``.
    """
    # The two assignments are independent of each other.
    # `_NvidiaReranker` is resolved when this constructor runs, so the class
    # has to exist by the time a RetrievalEngine is instantiated.
    self._reranker = _NvidiaReranker()
    self.db_manager = db_manager
|
| 18 |
|
| 19 |
+
@staticmethod
|
| 20 |
+
def _is_cpg_text(text: str) -> bool:
|
| 21 |
+
"""Heuristic to detect Clinical Practice Guideline (CPG) content."""
|
| 22 |
+
if not text:
|
| 23 |
+
return False
|
| 24 |
+
keywords = [
|
| 25 |
+
# common CPG indicators
|
| 26 |
+
r"\bguideline(s)?\b", r"\bclinical practice\b", r"\brecommend(ation|ed|s)?\b",
|
| 27 |
+
r"\bshould\b", r"\bmust\b", r"\bstrongly (recommend|suggest)\b",
|
| 28 |
+
r"\bNICE\b", r"\bAHA\b", r"\bACC\b", r"\bWHO\b", r"\bUSPSTF\b", r"\bIDSA\b",
|
| 29 |
+
r"\bclass (I|IIa|IIb|III)\b", r"\blevel (A|B|C)\b"
|
| 30 |
+
]
|
| 31 |
+
text_lc = text.lower()
|
| 32 |
+
return any(re.search(p, text_lc, flags=re.IGNORECASE) for p in keywords)
|
| 33 |
+
|
| 34 |
+
@staticmethod
|
| 35 |
+
def _extract_guideline_sentences(text: str) -> str:
|
| 36 |
+
"""Extract likely guideline sentences to reduce conversational/noisy content before summarization."""
|
| 37 |
+
if not text:
|
| 38 |
+
return ""
|
| 39 |
+
sentences = re.split(r"(?<=[.!?])\s+", text)
|
| 40 |
+
keep_patterns = [
|
| 41 |
+
r"\b(recommend|should|must|indicated|contraindicated|preferred|first-line|consider)\b",
|
| 42 |
+
r"\b(class\s*(I|IIa|IIb|III)|level\s*(A|B|C))\b",
|
| 43 |
+
r"\b(dose|mg|route|frequency)\b",
|
| 44 |
+
r"\b(screen|treat|manage|evaluate|monitor)\b"
|
| 45 |
+
]
|
| 46 |
+
kept = []
|
| 47 |
+
for s in sentences:
|
| 48 |
+
s_norm = s.strip()
|
| 49 |
+
if not s_norm:
|
| 50 |
+
continue
|
| 51 |
+
if any(re.search(p, s_norm, flags=re.IGNORECASE) for p in keep_patterns):
|
| 52 |
+
kept.append(s_norm)
|
| 53 |
+
# Fallback: if filtering too aggressive, keep truncated original
|
| 54 |
+
if not kept:
|
| 55 |
+
return text[:1200]
|
| 56 |
+
return " ".join(kept)[:2000]
|
| 57 |
+
|
| 58 |
+
def retrieve_medical_info(self, query: str, k: int = 5, min_sim: float = 0.8) -> list:
|
| 59 |
"""
|
| 60 |
Retrieve medical information from FAISS index
|
| 61 |
Min similarity between query and kb is to be 80%
|
|
|
|
| 112 |
kept.append(answer)
|
| 113 |
kept_vecs.append(new_vec)
|
| 114 |
|
| 115 |
+
# If any CPG-like content is present, rerank with NVIDIA NIM reranker and summarize to key guidelines
|
| 116 |
+
try:
|
| 117 |
+
cpg_candidates = [t for t in kept if self._is_cpg_text(t)]
|
| 118 |
+
if cpg_candidates:
|
| 119 |
+
logger.info("[Retrieval] CPG content detected; invoking NVIDIA reranker")
|
| 120 |
+
reranked = self._reranker.rerank(query, cpg_candidates)
|
| 121 |
+
# Keep only valid high-scoring items
|
| 122 |
+
filtered: List[Dict] = [r for r in reranked if r.get("score", 0) >= 0.3 and r.get("text")]
|
| 123 |
+
# Limit to top 3 for prompt efficiency
|
| 124 |
+
top_items = filtered[:3]
|
| 125 |
+
if top_items:
|
| 126 |
+
summarized: List[str] = []
|
| 127 |
+
for item in top_items:
|
| 128 |
+
guideline_text = self._extract_guideline_sentences(item["text"])
|
| 129 |
+
# Summarize to key clinical guidelines only (no conversational content)
|
| 130 |
+
concise = summarizer.summarize_text(guideline_text, max_length=300)
|
| 131 |
+
if concise:
|
| 132 |
+
summarized.append(concise)
|
| 133 |
+
# If summarization produced results, replace kept with these
|
| 134 |
+
if summarized:
|
| 135 |
+
kept = summarized
|
| 136 |
+
except Exception as e:
|
| 137 |
+
logger.warning(f"[Retrieval] CPG rerank/summarize step skipped due to error: {e}")
|
| 138 |
+
|
| 139 |
return kept if kept else [""]
|
| 140 |
|
| 141 |
def retrieve_diagnosis_from_symptoms(self, symptom_text: str, top_k: int = 5, min_sim: float = 0.5) -> list:
|
|
|
|
| 168 |
|
| 169 |
class _NvidiaReranker:
    """Simple client for NVIDIA NIM reranking: nvidia/rerank-qa-mistral-4b

    Configuration is read from the environment when the client is created:
      * ``NVIDIA_URI``             - bearer token sent in the Authorization header.
      * ``NVIDIA_RERANK_ENDPOINT`` - optional override of the rerank URL.
    """

    def __init__(self):
        # The env var is named NVIDIA_URI but its value is used as the API key.
        self.api_key = os.getenv("NVIDIA_URI")
        self.model = "nvidia/rerank-qa-mistral-4b"
        # NIM rerank endpoint (subject to environment); keep configurable
        self.base_url = os.getenv("NVIDIA_RERANK_ENDPOINT", "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking")
        self.timeout_s = 30

    @staticmethod
    def _parse_response(data, documents: List[str]) -> List[Dict]:
        """Convert a rerank response body into ``[{"text", "score"}, ...]`` sorted by score desc.

        Accepted shapes:
          * ``{"rankings": [{"index": i, "logit": x}, ...]}`` (NIM reranking)
          * ``{"results" | "data": [{"index" | "text", "score"}, ...]}``
          * ``{"scores": [...]}`` aligned with the input order

        When an entry carries ``logit`` instead of ``score``, the raw logit is
        used as the score. Entries whose text cannot be resolved are dropped.
        """
        results: List[Dict] = []
        if not isinstance(data, dict):
            return results

        entries = data.get("rankings") or data.get("results") or data.get("data") or []
        if isinstance(entries, list) and entries:
            for entry in entries:
                if not isinstance(entry, dict):
                    continue
                text = entry.get("text")
                idx = entry.get("index")
                # Only trust in-range, non-negative integer indices; a negative
                # index would silently pick a document from the end of the list.
                if not text and isinstance(idx, int) and 0 <= idx < len(documents):
                    text = documents[idx]
                score = entry.get("score")
                if score is None:
                    score = entry.get("logit", 0)
                if text:
                    results.append({"text": text, "score": float(score)})
        else:
            # Fallback: if API returns scores aligned to input order
            scores = data.get("scores")
            if isinstance(scores, list) and len(scores) == len(documents):
                results.extend(
                    {"text": t, "score": float(s)} for t, s in zip(documents, scores)
                )

        results.sort(key=lambda item: item["score"], reverse=True)
        return results

    def rerank(self, query: str, documents: List[str]) -> List[Dict]:
        """Score ``documents`` against ``query`` using the reranking endpoint.

        Args:
            query: The user query.
            documents: Candidate passages to rank.

        Returns:
            A list of ``{"text": str, "score": float}`` dicts, highest score
            first. Returns ``[]`` for empty ``documents``. If the HTTP call or
            response parsing fails, returns every document in its original
            order with a neutral score of 0.0.

        Raises:
            ValueError: If no API key was found in ``NVIDIA_URI``.
        """
        if not self.api_key:
            raise ValueError("NVIDIA_URI not set for reranker")
        if not documents:
            return []

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        # NIM reranking schema: the query is an object and the candidates are
        # sent under "passages".
        payload = {
            "model": self.model,
            "query": {"text": query},
            "passages": [{"text": d} for d in documents],
        }
        try:
            resp = requests.post(self.base_url, headers=headers, json=payload, timeout=self.timeout_s)
            resp.raise_for_status()
            return self._parse_response(resp.json(), documents)
        except Exception as e:
            # Deliberately best-effort: reranking must never break retrieval.
            logger.warning(f"[Reranker] Failed calling NVIDIA reranker: {e}")
            # On failure, return original order with neutral scores
            return [{"text": d, "score": 0.0} for d in documents]


# Global retrieval engine instance
# Created only after _NvidiaReranker is defined: RetrievalEngine.__init__
# instantiates the reranker, so building the engine earlier raises NameError
# while the module is being imported.
retrieval_engine = RetrievalEngine()
|