File size: 10,503 Bytes
b8bf5c8
c82ba47
 
 
 
b8bf5c8
 
c82ba47
2fb34cd
c82ba47
b8bf5c8
c82ba47
b8bf5c8
 
 
 
3b787c4
 
 
 
 
 
 
 
b8bf5c8
c82ba47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8bf5c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c82ba47
 
 
 
 
3b787c4
c82ba47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8bf5c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c82ba47
 
 
 
 
 
13f8f13
 
c82ba47
 
 
 
 
 
 
 
 
 
 
 
13f8f13
c82ba47
13f8f13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c82ba47
13f8f13
 
 
 
 
 
 
 
 
 
 
c82ba47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13f8f13
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
# api/retrieval.py
import os
import re
import time
import requests
import numpy as np
import logging
from typing import List, Dict
from .database import db_manager
from models import summarizer

logger = logging.getLogger("retrieval-bot")

class RetrievalEngine:
    """Retrieval over a FAISS QA index and precomputed symptom vectors.

    Clinical-practice-guideline (CPG) content found among retrieved answers
    is optionally reranked via an NVIDIA NIM reranker and condensed with the
    project summarizer before being returned.
    """

    def __init__(self):
        self.db_manager = db_manager
        # Lazy-init reranker to avoid NameError during module import ordering
        # (_NvidiaReranker is defined later in this module).
        self._reranker = None

    def _get_reranker(self):
        """Create the NVIDIA reranker client on first use and cache it."""
        if self._reranker is None:
            self._reranker = _NvidiaReranker()
        return self._reranker

    @staticmethod
    def _is_cpg_text(text: str) -> bool:
        """Heuristic to detect Clinical Practice Guideline (CPG) content.

        Generic keywords match case-insensitively; issuing-body acronyms
        (NICE, AHA, ACC, WHO, USPSTF, IDSA) must match case-sensitively.
        BUGFIX: the original lower-cased the text and searched everything
        with IGNORECASE, so the acronym patterns fired on the everyday
        words "who", "nice", "aha", "acc", flagging ordinary conversation
        as guideline content.
        """
        if not text:
            return False
        keyword_patterns = [
            # common CPG indicators (case-insensitive)
            r"\bguideline(s)?\b", r"\bclinical practice\b", r"\brecommend(ation|ed|s)?\b",
            r"\bshould\b", r"\bmust\b", r"\bstrongly (recommend|suggest)\b",
            r"\bclass (I|IIa|IIb|III)\b", r"\blevel (A|B|C)\b",
        ]
        acronym_patterns = [
            # issuing-body acronyms (case-sensitive to avoid common-word hits)
            r"\bNICE\b", r"\bAHA\b", r"\bACC\b", r"\bWHO\b", r"\bUSPSTF\b", r"\bIDSA\b",
        ]
        if any(re.search(p, text, flags=re.IGNORECASE) for p in keyword_patterns):
            return True
        return any(re.search(p, text) for p in acronym_patterns)

    @staticmethod
    def _extract_guideline_sentences(text: str) -> str:
        """Keep only sentences that look like actionable guideline statements.

        Reduces conversational/noisy content before summarization. Falls
        back to a truncated copy of the input when filtering removes
        everything; output is capped at 2000 characters.
        """
        if not text:
            return ""
        sentences = re.split(r"(?<=[.!?])\s+", text)
        keep_patterns = [
            r"\b(recommend|should|must|indicated|contraindicated|preferred|first-line|consider)\b",
            r"\b(class\s*(I|IIa|IIb|III)|level\s*(A|B|C))\b",
            r"\b(dose|mg|route|frequency)\b",
            r"\b(screen|treat|manage|evaluate|monitor)\b"
        ]
        kept = []
        for s in sentences:
            s_norm = s.strip()
            if not s_norm:
                continue
            if any(re.search(p, s_norm, flags=re.IGNORECASE) for p in keep_patterns):
                kept.append(s_norm)
        # Fallback: if filtering was too aggressive, keep a truncated original
        if not kept:
            return text[:1200]
        return " ".join(kept)[:2000]

    def retrieve_medical_info(self, query: str, k: int = 5, min_sim: float = 0.8) -> list:
        """Retrieve up to *k* semantically-deduplicated answers for *query*.

        Candidates whose index score is below *min_sim* are dropped
        (assumes the FAISS index returns similarity scores, e.g. inner
        product on normalized vectors — TODO confirm against how
        db_manager builds the index). Near-duplicate answers (cosine
        >= 0.9 between their embeddings) are collapsed, keeping the one
        closer to the query. When CPG-like content is present, results
        are reranked and summarized (best-effort). Returns [""] when
        nothing qualifies so callers always get a non-empty list.
        """
        index = self.db_manager.load_faiss_index()
        if index is None:
            return [""]

        embedding_model = self.db_manager.get_embedding_model()
        qa_collection = self.db_manager.get_qa_collection()

        # Embed query and search the index
        query_vec = embedding_model.encode([query], convert_to_numpy=True)
        D, I = index.search(query_vec, k=k)

        kept = []       # surviving answer texts
        kept_vecs = []  # their embeddings, for redundancy checks

        # Similarity filter + semantic dedup between near-identical candidates
        for score, idx in zip(D[0], I[0]):
            if score < min_sim:
                continue

            doc = qa_collection.find_one({"i": int(idx)})
            if not doc:
                continue

            # Only the answer text is compared / deduplicated
            answer = doc.get("Doctor", "").strip()
            if not answer:
                continue

            # Check semantic redundancy against previously kept results
            new_vec = embedding_model.encode([answer], convert_to_numpy=True)[0]
            is_similar = False

            for i, vec in enumerate(kept_vecs):
                sim = np.dot(vec, new_vec) / (np.linalg.norm(vec) * np.linalg.norm(new_vec) + 1e-9)
                if sim >= 0.9:  # High semantic similarity: treat as duplicate
                    is_similar = True
                    # Of the two duplicates, keep the one closer to the query
                    cur_sim_to_query = np.dot(vec, query_vec[0]) / (np.linalg.norm(vec) * np.linalg.norm(query_vec[0]) + 1e-9)
                    new_sim_to_query = np.dot(new_vec, query_vec[0]) / (np.linalg.norm(new_vec) * np.linalg.norm(query_vec[0]) + 1e-9)
                    if new_sim_to_query > cur_sim_to_query:
                        kept[i] = answer
                        kept_vecs[i] = new_vec
                    break

            # Genuinely new candidate
            if not is_similar:
                kept.append(answer)
                kept_vecs.append(new_vec)

        # If any CPG-like content is present, rerank with the NVIDIA NIM
        # reranker and summarize to key guidelines. Best-effort: any failure
        # leaves the unreranked results intact.
        try:
            cpg_candidates = [t for t in kept if self._is_cpg_text(t)]
            if cpg_candidates:
                logger.info("[Retrieval] CPG content detected; invoking NVIDIA reranker")
                reranked = self._get_reranker().rerank(query, cpg_candidates)
                # Keep only valid high-scoring items
                filtered: List[Dict] = [r for r in reranked if r.get("score", 0) >= 0.3 and r.get("text")]
                # Limit to top 3 for prompt efficiency
                top_items = filtered[:3]
                if top_items:
                    summarized: List[str] = []
                    for item in top_items:
                        guideline_text = self._extract_guideline_sentences(item["text"])
                        # Summarize to key clinical guidelines only (no conversational content)
                        concise = summarizer.summarize_text(guideline_text, max_length=300)
                        if concise:
                            summarized.append(concise)
                    # If summarization produced results, replace kept with these
                    if summarized:
                        kept = summarized
        except Exception as e:
            logger.warning(f"[Retrieval] CPG rerank/summarize step skipped due to error: {e}")

        return kept if kept else [""]

    def retrieve_diagnosis_from_symptoms(self, symptom_text: str, top_k: int = 5, min_sim: float = 0.5) -> list:
        """Return answers for the *top_k* distinct diagnoses matching the symptoms.

        Embeds *symptom_text*, ranks precomputed symptom vectors by cosine
        similarity (assumes db_manager.symptom_vectors rows are normalized
        — TODO confirm), drops matches below *min_sim*, and deduplicates
        by prognosis label.
        """
        self.db_manager.load_symptom_vectors()
        embedding_model = self.db_manager.get_embedding_model()

        # Embed input and normalize for cosine similarity
        qvec = embedding_model.encode(symptom_text, convert_to_numpy=True)
        qvec = qvec / (np.linalg.norm(qvec) + 1e-9)

        # Similarity compute
        sims = self.db_manager.symptom_vectors @ qvec  # cosine
        sorted_idx = np.argsort(sims)[-top_k:][::-1]
        seen_diag = set()
        final = []  # Dedup: one answer per prognosis label

        for i in sorted_idx:
            sim = sims[i]
            if sim < min_sim:
                continue
            label = self.db_manager.symptom_docs[i]["prognosis"]
            if label not in seen_diag:
                final.append(self.db_manager.symptom_docs[i]["answer"])
                seen_diag.add(label)

        return final

# Global retrieval engine instance. Safe to construct here even though
# _NvidiaReranker is defined below: the reranker is created lazily on first use.
retrieval_engine = RetrievalEngine()


class _NvidiaReranker:
    """Simple client for NVIDIA NIM reranking: nvidia/rerank-qa-mistral-4b"""
    def __init__(self):
        self.api_key = os.getenv("NVIDIA_URI")
        # Use provider doc model identifier
        self.model = os.getenv("NVIDIA_RERANK_MODEL", "nv-rerank-qa-mistral-4b:1")
        # NIM rerank endpoint (subject to environment); keep configurable
        self.base_url = os.getenv("NVIDIA_RERANK_ENDPOINT", "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking")
        self.timeout_s = 30

    def rerank(self, query: str, documents: List[str]) -> List[Dict]:
        if not self.api_key:
            raise ValueError("NVIDIA_URI not set for reranker")
        if not documents:
            return []
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        # Truncate and limit candidates to avoid 4xx
        docs = documents[:10]
        docs = [d[:2000] for d in docs if isinstance(d, str)]
        # Two payload shapes based on provider doc
        payloads = [
            {
                "model": self.model,
                "query": {"text": query},
                "passages": [{"text": d} for d in docs],
            },
            {
                "model": self.model,
                "query": query,
                "documents": [{"text": d} for d in docs],
            },
        ]
        try:
            data = None
            for p in payloads:
                resp = requests.post(self.base_url, headers=headers, json=p, timeout=self.timeout_s)
                if resp.status_code >= 400:
                    # try next shape
                    continue
                data = resp.json()
                break
            if data is None:
                # last attempt for diagnostics
                resp.raise_for_status()
            # Expecting a list with scores and indices or texts
            results = []
            entries = data.get("results") or data.get("data") or []
            if isinstance(entries, list) and entries:
                for entry in entries:
                    # Common patterns: {index, score} or {text, score}
                    idx = entry.get("index")
                    text = entry.get("text") if entry.get("text") else (documents[idx] if idx is not None and idx < len(documents) else None)
                    score = entry.get("score", 0)
                    if text:
                        results.append({"text": text, "score": float(score)})
            else:
                # Fallback: if API returns scores aligned to input order
                scores = data.get("scores")
                if isinstance(scores, list) and len(scores) == len(documents):
                    for t, s in zip(documents, scores):
                        results.append({"text": t, "score": float(s)})
            # Sort by score desc
            results.sort(key=lambda x: x.get("score", 0), reverse=True)
            return results
        except Exception as e:
            logger.warning(f"[Reranker] Failed calling NVIDIA reranker: {e}")
            # On failure, return original order with neutral scores
            return [{"text": d, "score": 0.0} for d in documents]