LiamKhoaLe committed
Commit 410be5e · 1 Parent(s): 2a31cee

Enhance UI

Files changed (4)
  1. app.py +55 -4
  2. llama_integration.py +122 -0
  3. requirements.txt +3 -0
  4. search.py +134 -0
app.py CHANGED
@@ -13,6 +13,8 @@ from sentence_transformers.util import cos_sim
 from memory import MemoryManager
 from translation import translate_query
 from vlm import process_medical_image
+from search import search_web
+from llama_integration import process_search_query
 
 # ✅ Enable Logging for Debugging
 import logging
@@ -221,7 +223,7 @@ class RAGMedicalChatbot:
         self.model_name = model_name
         self.retrieve = retrieve_function
 
-    def chat(self, user_id: str, user_query: str, lang: str = "EN", image_diagnosis: str = "") -> str:
+    def chat(self, user_id: str, user_query: str, lang: str = "EN", image_diagnosis: str = "", search_mode: bool = False) -> str:
         # 0. Translate query if not EN, this help our RAG system
         if lang.upper() in {"VI", "ZH"}:
             user_query = translate_query(user_query, lang.lower())
@@ -232,6 +234,24 @@ class RAGMedicalChatbot:
         knowledge_base = "\n".join(retrieved_info)
         ## b. Diagnosis RAG from symptom query
         diagnosis_guides = retrieve_diagnosis_from_symptoms(user_query) # smart matcher
+
+        # 1.5. Search mode - web search and Llama processing
+        search_context = ""
+        url_mapping = {}
+        if search_mode:
+            logger.info("[SEARCH] Starting web search mode")
+            try:
+                # Search the web
+                search_results = search_web(user_query, num_results=5)
+                if search_results:
+                    # Process with Llama
+                    search_context, url_mapping = process_search_query(user_query, search_results)
+                    logger.info(f"[SEARCH] Found {len(search_results)} results, processed with Llama")
+                else:
+                    logger.warning("[SEARCH] No search results found")
+            except Exception as e:
+                logger.error(f"[SEARCH] Search failed: {e}")
+                search_context = ""
 
         # 2. Hybrid Context Retrieval: RAG + Recent History + Intelligent Selection
         contextual_chunks = memory.get_contextual_chunks(user_id, user_query, lang)
@@ -258,16 +278,46 @@ class RAGMedicalChatbot:
         # Symptom-Diagnosis prediction RAG
         if diagnosis_guides:
             parts.append("Symptom-based diagnosis guidance (if applicable):\n" + "\n".join(diagnosis_guides))
+
+        # 5. Search context with citation instructions
+        if search_context:
+            parts.append("Additional information from web search:\n" + search_context)
+            parts.append("IMPORTANT: When you use information from the web search results above, you MUST add a citation tag <#ID> immediately after the relevant content, where ID is the document number (1, 2, 3, etc.). For example: 'According to recent studies <#1>, this condition affects...'")
+
         parts.append(f"User's question: {user_query}")
         parts.append(f"Language to generate answer: {lang}")
         prompt = "\n\n".join(parts)
         logger.info(f"[LLM] Question query in `prompt`: {prompt}") # Debug out checking RAG on kb and history
         response = gemini_flash_completion(prompt, model=self.model_name, temperature=0.7)
+
+        # 6. Process citations and replace with URLs
+        if search_mode and url_mapping:
+            response = self._process_citations(response, url_mapping)
+
         # Store exchange + chunking
         if user_id:
             memory.add_exchange(user_id, user_query, response, lang=lang)
         logger.info(f"[LLM] Response on `prompt`: {response.strip()}") # Debug out base response
         return response.strip()
+
+    def _process_citations(self, response: str, url_mapping: Dict[int, str]) -> str:
+        """Replace citation tags with actual URLs"""
+        import re
+
+        # Find all citation tags like <#1>, <#2>, etc.
+        citation_pattern = r'<#(\d+)>'
+
+        def replace_citation(match):
+            doc_id = int(match.group(1))
+            if doc_id in url_mapping:
+                return f'<{url_mapping[doc_id]}>'
+            return match.group(0) # Keep original if URL not found
+
+        # Replace citations with URLs
+        processed_response = re.sub(citation_pattern, replace_citation, response)
+
+        logger.info(f"[CITATION] Processed citations, found {len(re.findall(citation_pattern, response))} citations")
+        return processed_response
 
 # ✅ Initialize Chatbot
 chatbot = RAGMedicalChatbot(model_name="gemini-2.5-flash", retrieve_function=retrieve_medical_info)
@@ -280,23 +330,24 @@ async def chat_endpoint(req: Request):
     query_raw = body.get("query")
     query = query_raw.strip() if isinstance(query_raw, str) else ""
     lang = body.get("lang", "EN")
+    search_mode = body.get("search", False)
     image_base64 = body.get("image_base64", None)
     img_desc = body.get("img_desc", "Describe and investigate any clinical findings from this medical image.")
     start = time.time()
     image_diagnosis = ""
     # LLM Only
     if not image_base64:
-        logger.info("[BOT] LLM scenario.")
+        logger.info(f"[BOT] LLM scenario. Search mode: {search_mode}")
     # LLM+VLM
     else:
         # If image is present → diagnose first
        safe_load = len(image_base64.encode("utf-8"))
         if safe_load > 5_000_000: # Img size safe processor
             return JSONResponse({"response": "⚠️ Image too large. Please upload smaller images (<5MB)."})
-        logger.info("[BOT] VLM+LLM scenario.")
+        logger.info(f"[BOT] VLM+LLM scenario. Search mode: {search_mode}")
         logger.info(f"[VLM] Process medical image size: {safe_load}, desc: {img_desc}, {lang}.")
         image_diagnosis = process_medical_image(image_base64, img_desc, lang)
-    answer = chatbot.chat(user_id, query, lang, image_diagnosis)
+    answer = chatbot.chat(user_id, query, lang, image_diagnosis, search_mode)
     elapsed = time.time() - start
     # Final
     return JSONResponse({"response": f"{answer}\n\n(Response time: {elapsed:.2f}s)"})
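As a quick reference, a minimal client-side sketch of exercising the new search flag. It is an illustration only: the route path ("/chat") and the "user_id" body field are assumptions, since neither appears in this diff.

```python
# Hypothetical request against the updated endpoint (path and user_id field assumed).
import requests

payload = {
    "user_id": "demo-user",  # assumed field name, not shown in this diff
    "query": "What are early symptoms of type 2 diabetes?",
    "lang": "EN",
    "search": True,  # read by the handler via body.get("search", False)
}
resp = requests.post("http://localhost:8000/chat", json=payload, timeout=120)
print(resp.json()["response"])
```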
 
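End to end, the citation flow works like this: the Gemini prompt instructs the model to emit `<#ID>` tags after any claim drawn from the web context, and `_process_citations` then rewrites each tag to the URL recorded in `url_mapping`. A standalone sketch of that rewrite, with an assumed mapping and response:

```python
# Standalone illustration of the <#ID> -> <URL> rewrite in _process_citations.
import re

url_mapping = {1: "https://example.org/diabetes-overview"}  # assumed example mapping
response = "Early symptoms include fatigue and increased thirst <#1>."

def replace_citation(match):
    doc_id = int(match.group(1))
    # Unknown document ids keep their original tag, mirroring the diff's fallback
    return f"<{url_mapping[doc_id]}>" if doc_id in url_mapping else match.group(0)

print(re.sub(r"<#(\d+)>", replace_citation, response))
# Early symptoms include fatigue and increased thirst <https://example.org/diabetes-overview>.
```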
llama_integration.py ADDED
@@ -0,0 +1,122 @@
+import os
+import requests
+import json
+import logging
+from typing import List, Dict, Tuple
+
+logger = logging.getLogger(__name__)
+
+class NVIDIALLamaClient:
+    def __init__(self):
+        self.api_key = os.getenv("NVIDIA_URI")
+        if not self.api_key:
+            raise ValueError("NVIDIA_URI environment variable not set")
+
+        self.base_url = "https://api.nvcf.nvidia.com/v2/nvcf/chat/completions"
+        self.model = "meta/llama-3.1-8b-instruct"
+
+    def generate_keywords(self, user_query: str) -> List[str]:
+        """Use Llama to generate search keywords from user query"""
+        try:
+            prompt = f"""Given this medical question: "{user_query}"
+
+Generate 3-5 specific search keywords that would help find relevant medical information online.
+Focus on medical terms, symptoms, conditions, treatments, or procedures mentioned.
+Return only the keywords separated by commas, no explanations.
+
+Keywords:"""
+
+            response = self._call_llama(prompt)
+
+            # Extract keywords from response
+            keywords = [kw.strip() for kw in response.split(',') if kw.strip()]
+            logger.info(f"Generated keywords: {keywords}")
+            return keywords[:5] # Limit to 5 keywords
+
+        except Exception as e:
+            logger.error(f"Failed to generate keywords: {e}")
+            return [user_query] # Fallback to original query
+
+    def summarize_documents(self, documents: List[Dict], user_query: str) -> Tuple[str, Dict[int, str]]:
+        """Use Llama to summarize documents and return summary with URL mapping"""
+        try:
+            # Create document summaries
+            doc_summaries = []
+            url_mapping = {}
+
+            for doc in documents:
+                doc_id = doc['id']
+                url_mapping[doc_id] = doc['url']
+
+                # Create a summary prompt for each document
+                summary_prompt = f"""Summarize this medical information in 2-3 sentences, focusing on details relevant to: "{user_query}"
+
+Document: {doc['title']}
+Content: {doc['content'][:1000]}...
+
+Summary:"""
+
+                summary = self._call_llama(summary_prompt)
+                doc_summaries.append(f"Document {doc_id}: {summary}")
+
+            # Combine all summaries
+            combined_summary = "\n\n".join(doc_summaries)
+
+            return combined_summary, url_mapping
+
+        except Exception as e:
+            logger.error(f"Failed to summarize documents: {e}")
+            return "", {}
+
+    def _call_llama(self, prompt: str) -> str:
+        """Make API call to NVIDIA Llama model"""
+        try:
+            headers = {
+                "Authorization": f"Bearer {self.api_key}",
+                "Content-Type": "application/json"
+            }
+
+            payload = {
+                "model": self.model,
+                "messages": [
+                    {
+                        "role": "user",
+                        "content": prompt
+                    }
+                ],
+                "temperature": 0.7,
+                "max_tokens": 1000
+            }
+
+            response = requests.post(
+                self.base_url,
+                headers=headers,
+                json=payload,
+                timeout=30
+            )
+
+            response.raise_for_status()
+            result = response.json()
+
+            return result['choices'][0]['message']['content'].strip()
+
+        except Exception as e:
+            logger.error(f"Llama API call failed: {e}")
+            raise
+
+def process_search_query(user_query: str, search_results: List[Dict]) -> Tuple[str, Dict[int, str]]:
+    """Process search results using Llama model"""
+    try:
+        llama_client = NVIDIALLamaClient()
+
+        # Generate search keywords
+        keywords = llama_client.generate_keywords(user_query)
+
+        # Summarize documents
+        summary, url_mapping = llama_client.summarize_documents(search_results, user_query)
+
+        return summary, url_mapping
+
+    except Exception as e:
+        logger.error(f"Failed to process search query: {e}")
+        return "", {}
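A hedged smoke test for the new module, feeding process_search_query canned results so no live web search is involved (the underlying Llama calls still require the NVIDIA_URI environment variable):

```python
# Sketch: exercise process_search_query with canned search results.
# The example URL and content are placeholders; NVIDIA_URI must be set.
from llama_integration import process_search_query

canned_results = [
    {"id": 1, "url": "https://example.org/flu", "title": "Influenza overview",
     "content": "Influenza is a contagious respiratory illness caused by viruses..."},
]
summary, url_mapping = process_search_query("how is the flu treated", canned_results)
print(summary)      # "Document 1: <2-3 sentence summary>"
print(url_mapping)  # {1: "https://example.org/flu"}
```

Note that process_search_query also calls generate_keywords, but its return value is currently unused; only the summary and URL mapping flow back into the chat prompt.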
requirements.txt CHANGED
@@ -21,3 +21,6 @@ uvicorn
 fastapi
 torch # Reduce model load with half-precision (float16) to reduce RAM usage
 psutil # CPU/RAM logger
+# **Web Search**
+requests
+beautifulsoup4
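Both additions are pure-Python packages, so they install cleanly alongside the existing stack with `pip install -r requirements.txt` (or directly via `pip install requests beautifulsoup4`).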
search.py ADDED
@@ -0,0 +1,134 @@
+import requests
+from bs4 import BeautifulSoup
+import re
+from urllib.parse import urljoin, urlparse
+import time
+import logging
+from typing import List, Dict, Tuple
+import os
+
+logger = logging.getLogger(__name__)
+
+class WebSearcher:
+    def __init__(self):
+        self.session = requests.Session()
+        self.session.headers.update({
+            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
+        })
+        self.max_results = 10
+        self.timeout = 10
+
+    def search_google(self, query: str, num_results: int = 10) -> List[Dict]:
+        """Search Google and return results with URLs and titles"""
+        try:
+            # Use DuckDuckGo as it's more reliable for scraping
+            return self.search_duckduckgo(query, num_results)
+        except Exception as e:
+            logger.error(f"Google search failed: {e}")
+            return []
+
+    def search_duckduckgo(self, query: str, num_results: int = 10) -> List[Dict]:
+        """Search DuckDuckGo and return results"""
+        try:
+            url = "https://html.duckduckgo.com/html/"
+            params = {
+                'q': query,
+                'kl': 'us-en'
+            }
+
+            response = self.session.get(url, params=params, timeout=self.timeout)
+            response.raise_for_status()
+
+            soup = BeautifulSoup(response.content, 'html.parser')
+            results = []
+
+            # Find result links
+            result_links = soup.find_all('a', class_='result__a')
+
+            for link in result_links[:num_results]:
+                try:
+                    href = link.get('href')
+                    if href and href.startswith('http'):
+                        title = link.get_text(strip=True)
+                        if title and href:
+                            results.append({
+                                'url': href,
+                                'title': title,
+                                'content': '' # Will be filled later
+                            })
+                except Exception as e:
+                    logger.warning(f"Error parsing result: {e}")
+                    continue
+
+            return results
+
+        except Exception as e:
+            logger.error(f"DuckDuckGo search failed: {e}")
+            return []
+
+    def extract_content(self, url: str) -> str:
+        """Extract text content from a webpage"""
+        try:
+            response = self.session.get(url, timeout=self.timeout)
+            response.raise_for_status()
+
+            soup = BeautifulSoup(response.content, 'html.parser')
+
+            # Remove script and style elements
+            for script in soup(["script", "style"]):
+                script.decompose()
+
+            # Get text content
+            text = soup.get_text()
+
+            # Clean up text
+            lines = (line.strip() for line in text.splitlines())
+            chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
+            text = ' '.join(chunk for chunk in chunks if chunk)
+
+            # Limit content length
+            if len(text) > 2000:
+                text = text[:2000] + "..."
+
+            return text
+
+        except Exception as e:
+            logger.warning(f"Failed to extract content from {url}: {e}")
+            return ""
+
+    def search_and_extract(self, query: str, num_results: int = 5) -> List[Dict]:
+        """Search for query and extract content from top results"""
+        logger.info(f"Searching for: {query}")
+
+        # Get search results
+        search_results = self.search_duckduckgo(query, num_results)
+
+        # Extract content from each result
+        enriched_results = []
+        for i, result in enumerate(search_results):
+            try:
+                logger.info(f"Extracting content from {result['url']}")
+                content = self.extract_content(result['url'])
+
+                if content:
+                    enriched_results.append({
+                        'id': i + 1,
+                        'url': result['url'],
+                        'title': result['title'],
+                        'content': content
+                    })
+
+                # Add delay to be respectful
+                time.sleep(1)
+
+            except Exception as e:
+                logger.warning(f"Failed to process {result['url']}: {e}")
+                continue
+
+        logger.info(f"Successfully processed {len(enriched_results)} results")
+        return enriched_results
+
+def search_web(query: str, num_results: int = 5) -> List[Dict]:
+    """Main function to search the web and return enriched results"""
+    searcher = WebSearcher()
+    return searcher.search_and_extract(query, num_results)
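A quick manual check of the new search path (this performs live HTTP requests against DuckDuckGo's HTML endpoint, so it depends on network access and on that page's markup staying stable):

```python
# Sketch: live smoke test of search_web; requires network access.
from search import search_web

results = search_web("metformin common side effects", num_results=3)
for doc in results:
    print(doc["id"], doc["title"], doc["url"])
    print(doc["content"][:200], "...")
```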