BinKhoaLe1812 commited on
Commit
c82ba47
·
verified ·
1 Parent(s): 0c0f651

Upd reranker

Browse files
Files changed (1) hide show
  1. api/retrieval.py +125 -2
api/retrieval.py CHANGED
@@ -1,15 +1,61 @@
1
  # api/retrieval.py
 
 
 
 
2
  import numpy as np
3
  import logging
 
4
  from .database import db_manager
 
5
 
6
- logger = logging.getLogger("medical-chatbot")
7
 
8
  class RetrievalEngine:
9
  def __init__(self):
10
  self.db_manager = db_manager
 
11
 
12
- def retrieve_medical_info(self, query: str, k: int = 5, min_sim: float = 0.9) -> list:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  """
14
  Retrieve medical information from FAISS index
15
  Min similarity between query and kb is to be 80%
@@ -66,6 +112,30 @@ class RetrievalEngine:
66
  kept.append(answer)
67
  kept_vecs.append(new_vec)
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  return kept if kept else [""]
70
 
71
  def retrieve_diagnosis_from_symptoms(self, symptom_text: str, top_k: int = 5, min_sim: float = 0.5) -> list:
@@ -98,3 +168,56 @@ class RetrievalEngine:
98
 
99
  # Global retrieval engine instance
100
  retrieval_engine = RetrievalEngine()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # api/retrieval.py
2
+ import os
3
+ import re
4
+ import time
5
+ import requests
6
  import numpy as np
7
  import logging
8
+ from typing import List, Dict
9
  from .database import db_manager
10
+ from models import summarizer
11
 
12
+ logger = logging.getLogger("retrieval-bot")
13
 
14
  class RetrievalEngine:
15
  def __init__(self):
16
  self.db_manager = db_manager
17
+ self._reranker = _NvidiaReranker()
18
 
19
+ @staticmethod
20
+ def _is_cpg_text(text: str) -> bool:
21
+ """Heuristic to detect Clinical Practice Guideline (CPG) content."""
22
+ if not text:
23
+ return False
24
+ keywords = [
25
+ # common CPG indicators
26
+ r"\bguideline(s)?\b", r"\bclinical practice\b", r"\brecommend(ation|ed|s)?\b",
27
+ r"\bshould\b", r"\bmust\b", r"\bstrongly (recommend|suggest)\b",
28
+ r"\bNICE\b", r"\bAHA\b", r"\bACC\b", r"\bWHO\b", r"\bUSPSTF\b", r"\bIDSA\b",
29
+ r"\bclass (I|IIa|IIb|III)\b", r"\blevel (A|B|C)\b"
30
+ ]
31
+ text_lc = text.lower()
32
+ return any(re.search(p, text_lc, flags=re.IGNORECASE) for p in keywords)
33
+
34
+ @staticmethod
35
+ def _extract_guideline_sentences(text: str) -> str:
36
+ """Extract likely guideline sentences to reduce conversational/noisy content before summarization."""
37
+ if not text:
38
+ return ""
39
+ sentences = re.split(r"(?<=[.!?])\s+", text)
40
+ keep_patterns = [
41
+ r"\b(recommend|should|must|indicated|contraindicated|preferred|first-line|consider)\b",
42
+ r"\b(class\s*(I|IIa|IIb|III)|level\s*(A|B|C))\b",
43
+ r"\b(dose|mg|route|frequency)\b",
44
+ r"\b(screen|treat|manage|evaluate|monitor)\b"
45
+ ]
46
+ kept = []
47
+ for s in sentences:
48
+ s_norm = s.strip()
49
+ if not s_norm:
50
+ continue
51
+ if any(re.search(p, s_norm, flags=re.IGNORECASE) for p in keep_patterns):
52
+ kept.append(s_norm)
53
+ # Fallback: if filtering too aggressive, keep truncated original
54
+ if not kept:
55
+ return text[:1200]
56
+ return " ".join(kept)[:2000]
57
+
58
+ def retrieve_medical_info(self, query: str, k: int = 5, min_sim: float = 0.8) -> list:
59
  """
60
  Retrieve medical information from FAISS index
61
  Min similarity between query and kb is to be 80%
 
112
  kept.append(answer)
113
  kept_vecs.append(new_vec)
114
 
115
+ # If any CPG-like content is present, rerank with NVIDIA NIM reranker and summarize to key guidelines
116
+ try:
117
+ cpg_candidates = [t for t in kept if self._is_cpg_text(t)]
118
+ if cpg_candidates:
119
+ logger.info("[Retrieval] CPG content detected; invoking NVIDIA reranker")
120
+ reranked = self._reranker.rerank(query, cpg_candidates)
121
+ # Keep only valid high-scoring items
122
+ filtered: List[Dict] = [r for r in reranked if r.get("score", 0) >= 0.3 and r.get("text")]
123
+ # Limit to top 3 for prompt efficiency
124
+ top_items = filtered[:3]
125
+ if top_items:
126
+ summarized: List[str] = []
127
+ for item in top_items:
128
+ guideline_text = self._extract_guideline_sentences(item["text"])
129
+ # Summarize to key clinical guidelines only (no conversational content)
130
+ concise = summarizer.summarize_text(guideline_text, max_length=300)
131
+ if concise:
132
+ summarized.append(concise)
133
+ # If summarization produced results, replace kept with these
134
+ if summarized:
135
+ kept = summarized
136
+ except Exception as e:
137
+ logger.warning(f"[Retrieval] CPG rerank/summarize step skipped due to error: {e}")
138
+
139
  return kept if kept else [""]
140
 
141
  def retrieve_diagnosis_from_symptoms(self, symptom_text: str, top_k: int = 5, min_sim: float = 0.5) -> list:
 
168
 
169
  # Global retrieval engine instance
170
  retrieval_engine = RetrievalEngine()
171
+
172
+
173
+ class _NvidiaReranker:
174
+ """Simple client for NVIDIA NIM reranking: nvidia/rerank-qa-mistral-4b"""
175
+ def __init__(self):
176
+ self.api_key = os.getenv("NVIDIA_URI")
177
+ self.model = "nvidia/rerank-qa-mistral-4b"
178
+ # NIM rerank endpoint (subject to environment); keep configurable
179
+ self.base_url = os.getenv("NVIDIA_RERANK_ENDPOINT", "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking")
180
+ self.timeout_s = 30
181
+
182
+ def rerank(self, query: str, documents: List[str]) -> List[Dict]:
183
+ if not self.api_key:
184
+ raise ValueError("NVIDIA_URI not set for reranker")
185
+ if not documents:
186
+ return []
187
+ headers = {
188
+ "Authorization": f"Bearer {self.api_key}",
189
+ "Content-Type": "application/json",
190
+ }
191
+ payload = {
192
+ "model": self.model,
193
+ "query": query,
194
+ "documents": [{"text": d} for d in documents],
195
+ }
196
+ try:
197
+ resp = requests.post(self.base_url, headers=headers, json=payload, timeout=self.timeout_s)
198
+ resp.raise_for_status()
199
+ data = resp.json()
200
+ # Expecting a list with scores and indices or texts
201
+ results = []
202
+ entries = data.get("results") or data.get("data") or []
203
+ if isinstance(entries, list) and entries:
204
+ for entry in entries:
205
+ # Common patterns: {index, score} or {text, score}
206
+ idx = entry.get("index")
207
+ text = entry.get("text") if entry.get("text") else (documents[idx] if idx is not None and idx < len(documents) else None)
208
+ score = entry.get("score", 0)
209
+ if text:
210
+ results.append({"text": text, "score": float(score)})
211
+ else:
212
+ # Fallback: if API returns scores aligned to input order
213
+ scores = data.get("scores")
214
+ if isinstance(scores, list) and len(scores) == len(documents):
215
+ for t, s in zip(documents, scores):
216
+ results.append({"text": t, "score": float(s)})
217
+ # Sort by score desc
218
+ results.sort(key=lambda x: x.get("score", 0), reverse=True)
219
+ return results
220
+ except Exception as e:
221
+ logger.warning(f"[Reranker] Failed calling NVIDIA reranker: {e}")
222
+ # On failure, return original order with neutral scores
223
+ return [{"text": d, "score": 0.0} for d in documents]