LiamKhoaLe committed
Commit d74506f · 1 Parent(s): b3797f0
Files changed (2)
  1. README.md +82 -1
  2. app.py +406 -102
README.md CHANGED
@@ -8,7 +8,88 @@ sdk_version: 5.49.1
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
- short_description: 'Medical searcher for web-sources retrieval'
12
  ---
13

14
  > Introduction: A medical app for the MCP-1st-Birthday hackathon, integrating an MCP searcher and document RAG
 
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
+ short_description: 'MedicalMCP RAG & Search with MedSwin'
12
  ---
13
 
14
+ # 🩺 MedLLM Agent
15
+
16
+ **Advanced Medical AI Assistant** powered by fine-tuned MedSwin models with comprehensive knowledge retrieval capabilities.
17
+
18
+ ## ✨ Key Features
19
+
20
+ ### 📄 **Document RAG (Retrieval-Augmented Generation)**
21
+ - Upload medical documents (PDF/TXT) and get answers based on your uploaded content
22
+ - Hierarchical document indexing with auto-merging retrieval (see the sketch after this list)
23
+ - Mitigates hallucination by grounding responses in your documents
24
+ - Toggle RAG on/off - when disabled, provides concise clinical answers without document context
25
+
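A minimal sketch of this indexing-plus-retrieval flow in LlamaIndex, assuming a local embedding model and an illustrative file name (not the exact code in `app.py`):

```python
from llama_index.core import Settings, SimpleDirectoryReader, StorageContext, VectorStoreIndex
from llama_index.core.node_parser import HierarchicalNodeParser, get_leaf_nodes
from llama_index.core.retrievers import AutoMergingRetriever
from llama_index.core.storage.docstore import SimpleDocumentStore
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

# Embed locally so no external API key is needed
Settings.embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")

# Parse uploaded documents into a hierarchy of chunks (root -> leaf)
docs = SimpleDirectoryReader(input_files=["guidelines.pdf"]).load_data()
nodes = HierarchicalNodeParser.from_defaults().get_nodes_from_documents(docs)
leaf_nodes = get_leaf_nodes(nodes)

# Keep every node in the docstore so parents can be merged back in at query time
docstore = SimpleDocumentStore()
docstore.add_documents(nodes)
storage_context = StorageContext.from_defaults(docstore=docstore)

# Index only the leaves; the auto-merging retriever swaps groups of sibling
# leaves for their parent node once enough of them are retrieved together
index = VectorStoreIndex(leaf_nodes, storage_context=storage_context)
retriever = AutoMergingRetriever(
    index.as_retriever(similarity_top_k=15),
    storage_context=storage_context,
    simple_ratio_thresh=0.5,
)
merged_nodes = retriever.retrieve("What are the dosing recommendations?")
```

The `simple_ratio_thresh` value corresponds to the merge-threshold slider in the Advanced Settings panel.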
26
+ ### 🌐 **Web Search Integration (MCP Protocol)**
27
+ - Fetch knowledge from reliable online medical resources (see the sketch after this list)
28
+ - Automatic summarization of web search results using Llama-8B
29
+ - Enriches context for medical specialist models
30
+ - Combines document RAG + web sources for comprehensive answers
31
+
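A rough sketch of the search-and-extract step, using the `duckduckgo-search` package plus `requests` and BeautifulSoup as imported in `app.py` (the function name and query are illustrative):

```python
import requests
from bs4 import BeautifulSoup
from duckduckgo_search import DDGS

def fetch_web_sources(query: str, max_results: int = 5) -> list[dict]:
    """Search DuckDuckGo and attach a short text excerpt from each result page."""
    sources = []
    with DDGS() as ddgs:
        for hit in ddgs.text(query, max_results=max_results):
            entry = {
                "title": hit.get("title", ""),
                "url": hit.get("href", ""),
                "content": hit.get("body", ""),
            }
            try:
                page = requests.get(entry["url"], timeout=5,
                                    headers={"User-Agent": "Mozilla/5.0"})
                if page.ok:
                    soup = BeautifulSoup(page.content, "html.parser")
                    for tag in soup(["script", "style"]):
                        tag.decompose()
                    # Collapse whitespace and keep only a short excerpt
                    entry["content"] += "\n" + " ".join(soup.get_text().split())[:500]
            except requests.RequestException:
                pass  # fall back to the search snippet if the page cannot be fetched
            sources.append(entry)
    return sources
```

The collected snippets are then condensed by the Llama-3.1-8B model into a short summary that is appended to the medical model's context.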
32
+ ### 🧠 **MedSwin Medical Specialist Models**
33
+ - **MedSwin SFT** (default) - Supervised Fine-Tuned model
34
+ - **MedSwin KD** - Knowledge Distillation model
35
+ - **MedSwin TA** - Task-Aware merged model
36
+ - Models download on-demand for efficient resource usage (see the loading sketch after this list)
37
+ - Fine-tuned on MedAlpaca-7B for medical domain expertise
38
+
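On-demand loading amounts to a small in-memory cache in front of `transformers`; a condensed sketch, assuming the repo IDs used in `app.py` and an `HF_TOKEN` environment variable:

```python
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MEDSWIN_MODELS = {
    "MedSwin SFT": "MedSwin/MedSwin-7B-SFT",
    "MedSwin KD": "MedSwin/MedSwin-7B-KD",
    "MedSwin TA": "MedSwin/MedSwin-Merged-TA-SFT-0.7",
}
_loaded = {}  # model name -> (model, tokenizer)

def load_medswin(name: str = "MedSwin SFT"):
    """Download the selected MedSwin checkpoint on first use and reuse it afterwards."""
    if name not in _loaded:
        repo = MEDSWIN_MODELS[name]
        token = os.environ.get("HF_TOKEN")
        tokenizer = AutoTokenizer.from_pretrained(repo, token=token)
        model = AutoModelForCausalLM.from_pretrained(
            repo, device_map="auto", torch_dtype=torch.float16, token=token
        )
        _loaded[name] = (model, tokenizer)
    return _loaded[name]
```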
39
+ ### 🌍 **Multi-Language Support**
40
+ - Automatic language detection (see the sketch after this list)
41
+ - Non-English queries automatically translated to English
42
+ - Medical model processes in English
43
+ - Responses translated back to original language
44
+ - Powered by Llama-3.1-8B-Instruct for translation
45
+
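A minimal sketch of the round trip, using `langdetect` for detection; `translate` and `run_medical_model` stand in for the translation and generation helpers and are illustrative names, not functions defined in `app.py`:

```python
from langdetect import LangDetectException, detect

def answer_in_user_language(question: str, run_medical_model, translate) -> str:
    """Detect the query language, work in English, then translate the answer back."""
    try:
        lang = detect(question)
    except LangDetectException:
        lang = "en"  # default to English when detection fails

    english_question = question if lang == "en" else translate(question, target_lang="en")
    english_answer = run_medical_model(english_question)
    return english_answer if lang == "en" else translate(english_answer, target_lang=lang)
```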
46
+ ### ⚙️ **Advanced Configuration**
47
+ - Customizable generation parameters (temperature, top-p, top-k)
48
+ - Adjustable retrieval settings (top-k, merge threshold)
49
+ - Increased max tokens to prevent early stopping
50
+ - Custom EOS handling for medical models (see the stopping-criteria sketch after this list)
51
+ - Dynamic system prompts based on RAG status
52
+
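The custom EOS handling is a stopping criterion that declines to stop until a minimum number of new tokens has been generated; this condensed sketch mirrors the `MedicalStoppingCriteria` class added in `app.py` (the 100-token floor is the default used there):

```python
from transformers import StoppingCriteria, StoppingCriteriaList

class MinNewTokensThenEOS(StoppingCriteria):
    """Ignore the EOS token until at least `min_new_tokens` have been generated."""

    def __init__(self, eos_token_id: int, prompt_length: int, min_new_tokens: int = 100):
        super().__init__()
        self.eos_token_id = eos_token_id
        self.prompt_length = prompt_length
        self.min_new_tokens = min_new_tokens

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        new_tokens = input_ids.shape[1] - self.prompt_length
        if new_tokens < self.min_new_tokens:
            return False
        return input_ids[0, -1].item() == self.eos_token_id

# Passed to model.generate(..., stopping_criteria=StoppingCriteriaList([MinNewTokensThenEOS(...)]))
```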
53
+ ## 🚀 Usage
54
+
55
+ 1. **Upload Documents**: Drag and drop PDF or text files containing medical information
56
+ 2. **Configure Settings**:
57
+ - Enable/disable Document RAG
58
+ - Enable/disable Web Search (MCP)
59
+ - Select medical model (MedSwin SFT/KD/TA)
60
+ 3. **Ask Questions**: Type your medical question in any language
61
+ 4. **Get Answers**: Receive comprehensive answers based on:
62
+ - Your uploaded documents (if RAG enabled)
63
+ - Web sources (if web search enabled)
64
+ - Medical model's training knowledge
65
+
66
+ ## 🔧 Technical Details
67
+
68
+ - **Medical Models**: MedSwin/MedSwin-7B-SFT, MedSwin-7B-KD, MedSwin-Merged-TA-SFT-0.7
69
+ - **Translation Model**: meta-llama/Meta-Llama-3.1-8B-Instruct
70
+ - **Embedding Model**: sentence-transformers/all-MiniLM-L6-v2
71
+ - **RAG Framework**: LlamaIndex with hierarchical node parsing
72
+ - **Web Search**: DuckDuckGo with content extraction and summarization
73
+
74
+ ## 📋 Requirements
75
+
76
+ See `requirements.txt` for the full dependency list. Key dependencies:
77
+ - transformers, torch
78
+ - llama-index
79
+ - langdetect
80
+ - duckduckgo-search
81
+ - gradio, spaces
82
+
83
+ ## 🎯 Use Cases
84
+
85
+ - Medical document Q&A
86
+ - Clinical information retrieval
87
+ - Medical research assistance
88
+ - Multi-language medical consultations
89
+ - Evidence-based medical answers
90
+
91
+ ---
92
+
93
+ **Note**: This system is designed to assist with medical information retrieval. Always consult qualified healthcare professionals for medical decisions.
94
+
95
  > Introduction: A medical app for the MCP-1st-Birthday hackathon, integrating an MCP searcher and document RAG
app.py CHANGED
@@ -31,24 +31,38 @@ from llama_index.core.storage.docstore import SimpleDocumentStore
31
  from llama_index.llms.huggingface import HuggingFaceLLM
32
  from llama_index.embeddings.huggingface import HuggingFaceEmbedding
33
  from tqdm import tqdm
34
 
35
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
36
  logging.basicConfig(level=logging.INFO)
37
  logger = logging.getLogger(__name__)
38
  hf_logging.set_verbosity_error()
39
 
40
- MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"
41
  EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
42
  HF_TOKEN = os.environ.get("HF_TOKEN")
43
  if not HF_TOKEN:
44
  raise ValueError("HF_TOKEN not found in environment variables")
45
 
46
  # Custom UI
47
- TITLE = "<h1><center>Multi-Document RAG with LLama 3.1-8B Model</center></h1>"
48
  DESCRIPTION = """
49
  <center>
50
  <p>Upload PDF or text files to get started!</p>
51
- <p>After asking question wait for RAG system to get relevant nodes and pass to LLM</p>
52
  </center>
53
  """
54
  CSS = """
@@ -107,6 +121,22 @@ CSS = """
107
  display: flex;
108
  align-items: center;
109
  }
110
  @media (min-width: 768px) {
111
  .main-container {
112
  display: flex;
@@ -124,34 +154,195 @@ CSS = """
124
  }
125
  """
126
 
127
- global_model = None
128
- global_tokenizer = None
129
  global_file_info = {}
130
 
131
- def initialize_model_and_tokenizer():
132
- global global_model, global_tokenizer
133
- if global_model is None or global_tokenizer is None:
134
- logger.info("Initializing model and tokenizer...")
135
- global_tokenizer = AutoTokenizer.from_pretrained(MODEL, token=HF_TOKEN)
136
- global_model = AutoModelForCausalLM.from_pretrained(
137
- MODEL,
 
138
  device_map="auto",
139
  trust_remote_code=True,
140
  token=HF_TOKEN,
141
  torch_dtype=torch.float16
142
  )
143
- logger.info("Model and tokenizer initialized successfully")
144
 
145
- def get_llm(temperature=0.7, max_new_tokens=256, top_p=0.95, top_k=50):
146
- global global_model, global_tokenizer
147
- if global_model is None or global_tokenizer is None:
148
- initialize_model_and_tokenizer()
149
 
150
  return HuggingFaceLLM(
151
  context_window=4096,
152
  max_new_tokens=max_new_tokens,
153
- tokenizer=global_tokenizer,
154
- model=global_model,
155
  generate_kwargs={
156
  "do_sample": True,
157
  "temperature": temperature,
@@ -174,7 +365,7 @@ def extract_text_from_document(file):
174
  else:
175
  return None, 0, ValueError(f"Unsupported file format: {file_extension}")
176
 
177
- @spaces.GPU()
178
  def create_or_update_index(files, request: gr.Request):
179
  global global_file_info
180
 
@@ -185,7 +376,7 @@ def create_or_update_index(files, request: gr.Request):
185
  user_id = request.session_hash
186
  save_dir = f"./{user_id}_index"
187
  # Initialize LlamaIndex modules
188
- llm = get_llm()
189
  embed_model = HuggingFaceEmbedding(model_name=EMBEDDING_MODEL, token=HF_TOKEN)
190
  Settings.llm = llm
191
  Settings.embed_model = embed_model
@@ -234,13 +425,6 @@ def create_or_update_index(files, request: gr.Request):
234
  new_leaf_nodes = get_leaf_nodes(new_nodes)
235
  new_root_nodes = get_root_nodes(new_nodes)
236
  logger.info(f"Generated {len(new_nodes)} total nodes ({len(new_root_nodes)} root, {len(new_leaf_nodes)} leaf)")
237
- node_ancestry = {}
238
- for node in new_nodes:
239
- if hasattr(node, 'metadata') and 'file_name' in node.metadata:
240
- file_origin = node.metadata['file_name']
241
- if file_origin not in node_ancestry:
242
- node_ancestry[file_origin] = 0
243
- node_ancestry[file_origin] += 1
244
 
245
  if os.path.exists(save_dir):
246
  logger.info(f"Loading existing index from {save_dir}")
@@ -288,7 +472,7 @@ def create_or_update_index(files, request: gr.Request):
288
  output_container += "</div>"
289
  return f"Successfully indexed {len(files)} files.", output_container
290
 
291
- @spaces.GPU()
292
  def stream_chat(
293
  message: str,
294
  history: list,
@@ -300,75 +484,125 @@ def stream_chat(
300
  penalty: float,
301
  retriever_k: int,
302
  merge_threshold: float,
303
  request: gr.Request
304
  ):
305
  if not request:
306
  yield history + [{"role": "assistant", "content": "Session initialization failed. Please refresh the page."}]
307
  return
308
  user_id = request.session_hash
309
  index_dir = f"./{user_id}_index"
310
- if not os.path.exists(index_dir):
311
- yield history + [{"role": "assistant", "content": "Please upload documents first."}]
312
- return
313
-
314
- max_new_tokens = int(max_new_tokens) if isinstance(max_new_tokens, (int, float)) else 1024
315
- temperature = float(temperature) if isinstance(temperature, (int, float)) else 0.9
316
- top_p = float(top_p) if isinstance(top_p, (int, float)) else 0.95
317
- top_k = int(top_k) if isinstance(top_k, (int, float)) else 50
318
- penalty = float(penalty) if isinstance(penalty, (int, float)) else 1.2
319
- retriever_k = int(retriever_k) if isinstance(retriever_k, (int, float)) else 15
320
- merge_threshold = float(merge_threshold) if isinstance(merge_threshold, (int, float)) else 0.5
321
- llm = get_llm(temperature=temperature, max_new_tokens=max_new_tokens, top_p=top_p, top_k=top_k)
322
- embed_model = HuggingFaceEmbedding(model_name=EMBEDDING_MODEL, token=HF_TOKEN)
323
- Settings.llm = llm
324
- Settings.embed_model = embed_model
325
- storage_context = StorageContext.from_defaults(persist_dir=index_dir)
326
- index = load_index_from_storage(storage_context, settings=Settings)
327
- base_retriever = index.as_retriever(similarity_top_k=retriever_k)
328
- auto_merging_retriever = AutoMergingRetriever(
329
- base_retriever,
330
- storage_context=storage_context,
331
- simple_ratio_thresh=merge_threshold,
332
- verbose=True
333
- )
334
- logger.info(f"Query: {message}")
335
- retrieval_start = time.time()
336
- base_nodes = base_retriever.retrieve(message)
337
- logger.info(f"Retrieved {len(base_nodes)} base nodes in {time.time() - retrieval_start:.2f}s")
338
- base_file_sources = {}
339
- for node in base_nodes:
340
- if hasattr(node.node, 'metadata') and 'file_name' in node.node.metadata:
341
- file_name = node.node.metadata['file_name']
342
- if file_name not in base_file_sources:
343
- base_file_sources[file_name] = 0
344
- base_file_sources[file_name] += 1
345
- logger.info(f"Base retrieval file distribution: {base_file_sources}")
346
- merging_start = time.time()
347
- merged_nodes = auto_merging_retriever.retrieve(message)
348
- logger.info(f"Retrieved {len(merged_nodes)} merged nodes in {time.time() - merging_start:.2f}s")
349
- merged_file_sources = {}
350
- for node in merged_nodes:
351
- if hasattr(node.node, 'metadata') and 'file_name' in node.node.metadata:
352
- file_name = node.node.metadata['file_name']
353
- if file_name not in merged_file_sources:
354
- merged_file_sources[file_name] = 0
355
- merged_file_sources[file_name] += 1
356
- logger.info(f"Merged retrieval file distribution: {merged_file_sources}")
357
- context = "\n\n".join([n.node.text for n in merged_nodes])
358
  source_info = ""
359
- if merged_file_sources:
360
- source_info = "\n\nRetrieved information from files: " + ", ".join(merged_file_sources.keys())
361
- formatted_system_prompt = f"{system_prompt}\n\nDocument Context:\n{context}{source_info}"
362
  messages = [{"role": "system", "content": formatted_system_prompt}]
363
  for entry in history:
364
  messages.append(entry)
365
  messages.append({"role": "user", "content": message})
366
- prompt = global_tokenizer.apply_chat_template(
367
  messages,
368
  tokenize=False,
369
  add_generation_prompt=True
370
  )
371
  stop_event = threading.Event()
 
372
  class StopOnEvent(StoppingCriteria):
373
  def __init__(self, stop_event):
374
  super().__init__()
@@ -376,13 +610,42 @@ def stream_chat(
376
 
377
  def __call__(self, input_ids, scores, **kwargs):
378
  return self.stop_event.is_set()
379
- stopping_criteria = StoppingCriteriaList([StopOnEvent(stop_event)])
380
  streamer = TextIteratorStreamer(
381
- global_tokenizer,
382
  skip_prompt=True,
383
  skip_special_tokens=True
384
  )
385
- inputs = global_tokenizer(prompt, return_tensors="pt").to(global_model.device)
386
  generation_kwargs = dict(
387
  inputs,
388
  streamer=streamer,
@@ -392,23 +655,36 @@ def stream_chat(
392
  top_k=top_k,
393
  repetition_penalty=penalty,
394
  do_sample=True,
395
- stopping_criteria=stopping_criteria
396
  )
397
- thread = threading.Thread(target=global_model.generate, kwargs=generation_kwargs)
 
398
  thread.start()
 
399
  updated_history = history + [
400
- {"role": "user", "content": message},
401
  {"role": "assistant", "content": ""}
402
  ]
403
  yield updated_history
 
404
  partial_response = ""
405
  try:
406
  for new_text in streamer:
407
  partial_response += new_text
408
  updated_history[-1]["content"] = partial_response
409
  yield updated_history
410
- output_ids = global_tokenizer.encode(partial_response, return_tensors="pt")
411
- yield updated_history
412
  except GeneratorExit:
413
  stop_event.set()
414
  thread.join()
@@ -446,13 +722,13 @@ def create_demo():
446
  with gr.Column(elem_classes="chatbot-container"):
447
  chatbot = gr.Chatbot(
448
  height=500,
449
- placeholder="Chat with your documents here... Type your question below.",
450
  show_label=False,
451
  type="messages"
452
  )
453
  with gr.Row(elem_classes="input-row"):
454
  message_input = gr.Textbox(
455
- placeholder="Type your question here...",
456
  show_label=False,
457
  container=False,
458
  lines=1,
@@ -460,9 +736,28 @@ def create_demo():
460
  )
461
  submit_button = gr.Button("➤", elem_classes="submit-btn", scale=1)
462
 
463
- with gr.Accordion("Advanced Settings", open=False):
464
  system_prompt = gr.Textbox(
465
- value="As a knowledgeable assistant, your task is to provide detailed and context-rich answers based on the relevant information from all uploaded documents. When information is sourced from multiple documents, summarize the key points from each and explain how they relate, noting any connections or contradictions. Your response should be thorough, informative, and easy to understand.",
466
  label="System Prompt",
467
  lines=3
468
  )
@@ -472,15 +767,16 @@ def create_demo():
472
  minimum=0,
473
  maximum=1,
474
  step=0.1,
475
- value=0.9,
476
  label="Temperature"
477
  )
478
  max_new_tokens = gr.Slider(
479
- minimum=128,
480
- maximum=8192,
481
- step=64,
482
- value=1024,
483
  label="Max New Tokens",
 
484
  )
485
  top_p = gr.Slider(
486
  minimum=0.0,
@@ -532,7 +828,10 @@ def create_demo():
532
  top_k,
533
  penalty,
534
  retriever_k,
535
- merge_threshold
536
  ],
537
  outputs=chatbot
538
  )
@@ -549,7 +848,10 @@ def create_demo():
549
  top_k,
550
  penalty,
551
  retriever_k,
552
- merge_threshold
553
  ],
554
  outputs=chatbot
555
  )
@@ -557,6 +859,8 @@ def create_demo():
557
  return demo
558
 
559
  if __name__ == "__main__":
560
- initialize_model_and_tokenizer()
561
  demo = create_demo()
562
- demo.launch()
 
31
  from llama_index.llms.huggingface import HuggingFaceLLM
32
  from llama_index.embeddings.huggingface import HuggingFaceEmbedding
33
  from tqdm import tqdm
34
+ from langdetect import detect, LangDetectException
35
+ from duckduckgo_search import DDGS
36
+ import requests
37
+ from bs4 import BeautifulSoup
38
 
39
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
40
  logging.basicConfig(level=logging.INFO)
41
  logger = logging.getLogger(__name__)
42
  hf_logging.set_verbosity_error()
43
 
44
+ # Model configurations
45
+ TRANSLATION_MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"
46
+ MEDSWIN_MODELS = {
47
+ "MedSwin SFT": "MedSwin/MedSwin-7B-SFT",
48
+ "MedSwin KD": "MedSwin/MedSwin-7B-KD",
49
+ "MedSwin TA": "MedSwin/MedSwin-Merged-TA-SFT-0.7"
50
+ }
51
+ DEFAULT_MEDICAL_MODEL = "MedSwin SFT"
52
  EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
53
  HF_TOKEN = os.environ.get("HF_TOKEN")
54
  if not HF_TOKEN:
55
  raise ValueError("HF_TOKEN not found in environment variables")
56
 
57
  # Custom UI
58
+ TITLE = "<h1><center>🩺 MedLLM Agent - Medical RAG & Web Search System</center></h1>"
59
  DESCRIPTION = """
60
  <center>
61
+ <p><strong>Advanced Medical AI Assistant</strong> powered by MedSwin models</p>
62
+ <p>📄 <strong>Document RAG:</strong> Answer based on uploaded medical documents</p>
63
+ <p>🌐 <strong>Web Search:</strong> Fetch knowledge from reliable online medical resources</p>
64
+ <p>🌍 <strong>Multi-language:</strong> Automatic translation for non-English queries</p>
65
  <p>Upload PDF or text files to get started!</p>
 
66
  </center>
67
  """
68
  CSS = """
 
121
  display: flex;
122
  align-items: center;
123
  }
124
+ .feature-badge {
125
+ display: inline-block;
126
+ padding: 3px 8px;
127
+ margin: 2px;
128
+ border-radius: 12px;
129
+ font-size: 11px;
130
+ font-weight: bold;
131
+ }
132
+ .badge-rag {
133
+ background: #e3f2fd;
134
+ color: #1976d2;
135
+ }
136
+ .badge-web {
137
+ background: #f3e5f5;
138
+ color: #7b1fa2;
139
+ }
140
  @media (min-width: 768px) {
141
  .main-container {
142
  display: flex;
 
154
  }
155
  """
156
 
157
+ # Global model storage
158
+ global_translation_model = None
159
+ global_translation_tokenizer = None
160
+ global_medical_models = {}
161
+ global_medical_tokenizers = {}
162
  global_file_info = {}
163
 
164
+ def initialize_translation_model():
165
+ """Initialize Llama model for translation purposes"""
166
+ global global_translation_model, global_translation_tokenizer
167
+ if global_translation_model is None or global_translation_tokenizer is None:
168
+ logger.info("Initializing translation model (Llama-8B)...")
169
+ global_translation_tokenizer = AutoTokenizer.from_pretrained(TRANSLATION_MODEL, token=HF_TOKEN)
170
+ global_translation_model = AutoModelForCausalLM.from_pretrained(
171
+ TRANSLATION_MODEL,
172
  device_map="auto",
173
  trust_remote_code=True,
174
  token=HF_TOKEN,
175
  torch_dtype=torch.float16
176
  )
177
+ logger.info("Translation model initialized successfully")
178
 
179
+ def initialize_medical_model(model_name: str):
180
+ """Initialize medical model (MedSwin) - download on demand"""
181
+ global global_medical_models, global_medical_tokenizers
182
+ if model_name not in global_medical_models or global_medical_models[model_name] is None:
183
+ logger.info(f"Initializing medical model: {model_name}...")
184
+ model_path = MEDSWIN_MODELS[model_name]
185
+ tokenizer = AutoTokenizer.from_pretrained(model_path, token=HF_TOKEN)
186
+ model = AutoModelForCausalLM.from_pretrained(
187
+ model_path,
188
+ device_map="auto",
189
+ trust_remote_code=True,
190
+ token=HF_TOKEN,
191
+ torch_dtype=torch.float16
192
+ )
193
+ global_medical_models[model_name] = model
194
+ global_medical_tokenizers[model_name] = tokenizer
195
+ logger.info(f"Medical model {model_name} initialized successfully")
196
+ return global_medical_models[model_name], global_medical_tokenizers[model_name]
197
+
198
+ def detect_language(text: str) -> str:
199
+ """Detect language of input text"""
200
+ try:
201
+ lang = detect(text)
202
+ return lang
203
+ except LangDetectException:
204
+ return "en" # Default to English if detection fails
205
+
206
+ def translate_text(text: str, target_lang: str = "en", source_lang: str = None) -> str:
207
+ """Translate text using Llama model"""
208
+ global global_translation_model, global_translation_tokenizer
209
+ if global_translation_model is None or global_translation_tokenizer is None:
210
+ initialize_translation_model()
211
+
212
+ if source_lang:
213
+ prompt = f"Translate the following {source_lang} text to {target_lang}. Only provide the translation, no explanations:\n\n{text}"
214
+ else:
215
+ prompt = f"Translate the following text to {target_lang}. Only provide the translation, no explanations:\n\n{text}"
216
+
217
+ messages = [
218
+ {"role": "system", "content": "You are a professional translator. Translate accurately and concisely."},
219
+ {"role": "user", "content": prompt}
220
+ ]
221
+
222
+ prompt_text = global_translation_tokenizer.apply_chat_template(
223
+ messages,
224
+ tokenize=False,
225
+ add_generation_prompt=True
226
+ )
227
+
228
+ inputs = global_translation_tokenizer(prompt_text, return_tensors="pt").to(global_translation_model.device)
229
+
230
+ with torch.no_grad():
231
+ outputs = global_translation_model.generate(
232
+ **inputs,
233
+ max_new_tokens=512,
234
+ temperature=0.3,
235
+ do_sample=True,
236
+ pad_token_id=global_translation_tokenizer.eos_token_id
237
+ )
238
+
239
+ response = global_translation_tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
240
+ return response.strip()
241
+
242
+ def search_web(query: str, max_results: int = 5) -> list:
243
+ """Search web using DuckDuckGo and extract content"""
244
+ try:
245
+ with DDGS() as ddgs:
246
+ results = list(ddgs.text(query, max_results=max_results))
247
+ web_content = []
248
+ for result in results:
249
+ try:
250
+ url = result.get('href', '')
251
+ title = result.get('title', '')
252
+ snippet = result.get('body', '')
253
+
254
+ # Try to fetch full content
255
+ try:
256
+ response = requests.get(url, timeout=5, headers={'User-Agent': 'Mozilla/5.0'})
257
+ if response.status_code == 200:
258
+ soup = BeautifulSoup(response.content, 'html.parser')
259
+ # Extract main content
260
+ for script in soup(["script", "style"]):
261
+ script.decompose()
262
+ text = soup.get_text()
263
+ # Clean and limit text
264
+ lines = (line.strip() for line in text.splitlines())
265
+ chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
266
+ text = ' '.join(chunk for chunk in chunks if chunk)
267
+ if len(text) > 1000:
268
+ text = text[:1000] + "..."
269
+ web_content.append({
270
+ 'title': title,
271
+ 'url': url,
272
+ 'content': snippet + "\n" + text[:500] if text else snippet
273
+ })
274
+ else:
275
+ web_content.append({
276
+ 'title': title,
277
+ 'url': url,
278
+ 'content': snippet
279
+ })
280
+ except:
281
+ web_content.append({
282
+ 'title': title,
283
+ 'url': url,
284
+ 'content': snippet
285
+ })
286
+ except Exception as e:
287
+ logger.error(f"Error processing search result: {e}")
288
+ continue
289
+ return web_content
290
+ except Exception as e:
291
+ logger.error(f"Web search error: {e}")
292
+ return []
293
+
294
+ def summarize_web_content(content_list: list, query: str) -> str:
295
+ """Summarize web search results using Llama model"""
296
+ global global_translation_model, global_translation_tokenizer
297
+ if global_translation_model is None or global_translation_tokenizer is None:
298
+ initialize_translation_model()
299
+
300
+ combined_content = "\n\n".join([f"Source: {item['title']}\n{item['content']}" for item in content_list[:3]])
301
+
302
+ prompt = f"""Summarize the following web search results related to the query: "{query}"
303
+
304
+ Extract key medical information, facts, and insights. Be concise and focus on reliable information.
305
+
306
+ Search Results:
307
+ {combined_content}
308
+
309
+ Summary:"""
310
+
311
+ messages = [
312
+ {"role": "system", "content": "You are a medical information summarizer. Extract and summarize key medical facts accurately."},
313
+ {"role": "user", "content": prompt}
314
+ ]
315
+
316
+ prompt_text = global_translation_tokenizer.apply_chat_template(
317
+ messages,
318
+ tokenize=False,
319
+ add_generation_prompt=True
320
+ )
321
+
322
+ inputs = global_translation_tokenizer(prompt_text, return_tensors="pt").to(global_translation_model.device)
323
+
324
+ with torch.no_grad():
325
+ outputs = global_translation_model.generate(
326
+ **inputs,
327
+ max_new_tokens=512,
328
+ temperature=0.5,
329
+ do_sample=True,
330
+ pad_token_id=global_translation_tokenizer.eos_token_id
331
+ )
332
+
333
+ summary = global_translation_tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
334
+ return summary.strip()
335
+
336
+ def get_llm_for_rag(temperature=0.7, max_new_tokens=256, top_p=0.95, top_k=50):
337
+ """Get LLM for RAG indexing (uses translation model)"""
338
+ if global_translation_model is None or global_translation_tokenizer is None:
339
+ initialize_translation_model()
340
 
341
  return HuggingFaceLLM(
342
  context_window=4096,
343
  max_new_tokens=max_new_tokens,
344
+ tokenizer=global_translation_tokenizer,
345
+ model=global_translation_model,
346
  generate_kwargs={
347
  "do_sample": True,
348
  "temperature": temperature,
 
365
  else:
366
  return None, 0, ValueError(f"Unsupported file format: {file_extension}")
367
 
368
+ @spaces.GPU(max_duration=120)
369
  def create_or_update_index(files, request: gr.Request):
370
  global global_file_info
371
 
 
376
  user_id = request.session_hash
377
  save_dir = f"./{user_id}_index"
378
  # Initialize LlamaIndex modules
379
+ llm = get_llm_for_rag()
380
  embed_model = HuggingFaceEmbedding(model_name=EMBEDDING_MODEL, token=HF_TOKEN)
381
  Settings.llm = llm
382
  Settings.embed_model = embed_model
 
425
  new_leaf_nodes = get_leaf_nodes(new_nodes)
426
  new_root_nodes = get_root_nodes(new_nodes)
427
  logger.info(f"Generated {len(new_nodes)} total nodes ({len(new_root_nodes)} root, {len(new_leaf_nodes)} leaf)")
428
 
429
  if os.path.exists(save_dir):
430
  logger.info(f"Loading existing index from {save_dir}")
 
472
  output_container += "</div>"
473
  return f"Successfully indexed {len(files)} files.", output_container
474
 
475
+ @spaces.GPU(max_duration=120)
476
  def stream_chat(
477
  message: str,
478
  history: list,
 
484
  penalty: float,
485
  retriever_k: int,
486
  merge_threshold: float,
487
+ use_rag: bool,
488
+ medical_model: str,
489
+ use_web_search: bool,
490
  request: gr.Request
491
  ):
492
  if not request:
493
  yield history + [{"role": "assistant", "content": "Session initialization failed. Please refresh the page."}]
494
  return
495
+
496
+ # Detect language and translate if needed
497
+ original_lang = detect_language(message)
498
+ original_message = message
499
+ needs_translation = original_lang != "en"
500
+
501
+ if needs_translation:
502
+ logger.info(f"Detected non-English language: {original_lang}, translating to English...")
503
+ message = translate_text(message, target_lang="en", source_lang=original_lang)
504
+ logger.info(f"Translated query: {message}")
505
+
506
  user_id = request.session_hash
507
  index_dir = f"./{user_id}_index"
508
+
509
+ # Initialize medical model
510
+ medical_model_obj, medical_tokenizer = initialize_medical_model(medical_model)
511
+
512
+ # Adjust system prompt based on RAG setting
513
+ if use_rag:
514
+ if not os.path.exists(index_dir):
515
+ yield history + [{"role": "assistant", "content": "Please upload documents first to use RAG."}]
516
+ return
517
+
518
+ base_system_prompt = system_prompt if system_prompt else "As a medical specialist, provide detailed and accurate answers based on the provided medical documents."
519
+ else:
520
+ base_system_prompt = "As a medical specialist, provide short and concise clinical answers. Be brief and avoid lengthy explanations. Focus on key medical facts only."
521
+
522
+ # Get RAG context if enabled
523
+ rag_context = ""
524
  source_info = ""
525
+ if use_rag and os.path.exists(index_dir):
526
+ embed_model = HuggingFaceEmbedding(model_name=EMBEDDING_MODEL, token=HF_TOKEN)
527
+ Settings.embed_model = embed_model
528
+ storage_context = StorageContext.from_defaults(persist_dir=index_dir)
529
+ index = load_index_from_storage(storage_context, settings=Settings)
530
+ base_retriever = index.as_retriever(similarity_top_k=retriever_k)
531
+ auto_merging_retriever = AutoMergingRetriever(
532
+ base_retriever,
533
+ storage_context=storage_context,
534
+ simple_ratio_thresh=merge_threshold,
535
+ verbose=True
536
+ )
537
+ logger.info(f"Query: {message}")
538
+ retrieval_start = time.time()
539
+ merged_nodes = auto_merging_retriever.retrieve(message)
540
+ logger.info(f"Retrieved {len(merged_nodes)} merged nodes in {time.time() - retrieval_start:.2f}s")
541
+ merged_file_sources = {}
542
+ for node in merged_nodes:
543
+ if hasattr(node.node, 'metadata') and 'file_name' in node.node.metadata:
544
+ file_name = node.node.metadata['file_name']
545
+ if file_name not in merged_file_sources:
546
+ merged_file_sources[file_name] = 0
547
+ merged_file_sources[file_name] += 1
548
+ logger.info(f"Merged retrieval file distribution: {merged_file_sources}")
549
+ rag_context = "\n\n".join([n.node.text for n in merged_nodes])
550
+ if merged_file_sources:
551
+ source_info = "\n\nRetrieved information from files: " + ", ".join(merged_file_sources.keys())
552
+
553
+ # Get web search context if enabled
554
+ web_context = ""
555
+ web_sources = []
556
+ if use_web_search:
557
+ logger.info("Performing web search...")
558
+ web_results = search_web(message, max_results=5)
559
+ if web_results:
560
+ web_summary = summarize_web_content(web_results, message)
561
+ web_context = f"\n\nAdditional Web Sources:\n{web_summary}"
562
+ web_sources = [r['title'] for r in web_results[:3]]
563
+ logger.info(f"Web search completed, found {len(web_results)} results")
564
+
565
+ # Build final context
566
+ context_parts = []
567
+ if rag_context:
568
+ context_parts.append(f"Document Context:\n{rag_context}")
569
+ if web_context:
570
+ context_parts.append(web_context)
571
+
572
+ full_context = "\n\n".join(context_parts) if context_parts else ""
573
+
574
+ # Build system prompt
575
+ if use_rag or use_web_search:
576
+ formatted_system_prompt = f"{base_system_prompt}\n\n{full_context}{source_info}"
577
+ else:
578
+ formatted_system_prompt = base_system_prompt
579
+
580
+ # Prepare messages
581
  messages = [{"role": "system", "content": formatted_system_prompt}]
582
  for entry in history:
583
  messages.append(entry)
584
  messages.append({"role": "user", "content": message})
585
+
586
+ # Get EOS token and adjust stopping criteria
587
+ eos_token_id = medical_tokenizer.eos_token_id
588
+ if eos_token_id is None:
589
+ eos_token_id = medical_tokenizer.pad_token_id
590
+
591
+ # Increase max tokens for medical models (prevent early stopping)
592
+ max_new_tokens = int(max_new_tokens) if isinstance(max_new_tokens, (int, float)) else 2048
593
+ max_new_tokens = max(max_new_tokens, 1024) # Minimum 1024 tokens for medical answers
594
+
595
+ prompt = medical_tokenizer.apply_chat_template(
596
  messages,
597
  tokenize=False,
598
  add_generation_prompt=True
599
  )
600
+
601
+ inputs = medical_tokenizer(prompt, return_tensors="pt").to(medical_model_obj.device)
602
+ prompt_length = inputs['input_ids'].shape[1]
603
+
604
  stop_event = threading.Event()
605
+
606
  class StopOnEvent(StoppingCriteria):
607
  def __init__(self, stop_event):
608
  super().__init__()
 
610
 
611
  def __call__(self, input_ids, scores, **kwargs):
612
  return self.stop_event.is_set()
613
+
614
+ # Custom stopping criteria that doesn't stop on EOS too early
615
+ class MedicalStoppingCriteria(StoppingCriteria):
616
+ def __init__(self, eos_token_id, prompt_length, min_new_tokens=100):
617
+ super().__init__()
618
+ self.eos_token_id = eos_token_id
619
+ self.prompt_length = prompt_length
620
+ self.min_new_tokens = min_new_tokens
621
+
622
+ def __call__(self, input_ids, scores, **kwargs):
623
+ current_length = input_ids.shape[1]
624
+ new_tokens = current_length - self.prompt_length
625
+ last_token = input_ids[0, -1].item()
626
+
627
+ # Don't stop on EOS if we haven't generated enough new tokens
628
+ if new_tokens < self.min_new_tokens:
629
+ return False
630
+ # Allow EOS after minimum new tokens have been generated
631
+ return last_token == self.eos_token_id
632
+
633
+ stopping_criteria = StoppingCriteriaList([
634
+ StopOnEvent(stop_event),
635
+ MedicalStoppingCriteria(eos_token_id, prompt_length, min_new_tokens=100)
636
+ ])
637
+
638
  streamer = TextIteratorStreamer(
639
+ medical_tokenizer,
640
  skip_prompt=True,
641
  skip_special_tokens=True
642
  )
643
+
644
+ temperature = float(temperature) if isinstance(temperature, (int, float)) else 0.7
645
+ top_p = float(top_p) if isinstance(top_p, (int, float)) else 0.95
646
+ top_k = int(top_k) if isinstance(top_k, (int, float)) else 50
647
+ penalty = float(penalty) if isinstance(penalty, (int, float)) else 1.2
648
+
649
  generation_kwargs = dict(
650
  inputs,
651
  streamer=streamer,
 
655
  top_k=top_k,
656
  repetition_penalty=penalty,
657
  do_sample=True,
658
+ stopping_criteria=stopping_criteria,
659
+ eos_token_id=eos_token_id,
660
+ pad_token_id=medical_tokenizer.pad_token_id or eos_token_id
661
  )
662
+
663
+ thread = threading.Thread(target=medical_model_obj.generate, kwargs=generation_kwargs)
664
  thread.start()
665
+
666
  updated_history = history + [
667
+ {"role": "user", "content": original_message},
668
  {"role": "assistant", "content": ""}
669
  ]
670
  yield updated_history
671
+
672
  partial_response = ""
673
  try:
674
  for new_text in streamer:
675
  partial_response += new_text
676
  updated_history[-1]["content"] = partial_response
677
  yield updated_history
678
+
679
+ # Translate back if needed
680
+ if needs_translation and partial_response:
681
+ logger.info(f"Translating response back to {original_lang}...")
682
+ translated_response = translate_text(partial_response, target_lang=original_lang, source_lang="en")
683
+ updated_history[-1]["content"] = translated_response
684
+ yield updated_history
685
+ else:
686
+ yield updated_history
687
+
688
  except GeneratorExit:
689
  stop_event.set()
690
  thread.join()
 
722
  with gr.Column(elem_classes="chatbot-container"):
723
  chatbot = gr.Chatbot(
724
  height=500,
725
+ placeholder="Chat with your medical documents here... Type your question below.",
726
  show_label=False,
727
  type="messages"
728
  )
729
  with gr.Row(elem_classes="input-row"):
730
  message_input = gr.Textbox(
731
+ placeholder="Type your medical question here...",
732
  show_label=False,
733
  container=False,
734
  lines=1,
 
736
  )
737
  submit_button = gr.Button("➤", elem_classes="submit-btn", scale=1)
738
 
739
+ with gr.Accordion("⚙️ Advanced Settings", open=False):
740
+ with gr.Row():
741
+ use_rag = gr.Checkbox(
742
+ value=True,
743
+ label="Enable Document RAG",
744
+ info="Answer based on uploaded documents"
745
+ )
746
+ use_web_search = gr.Checkbox(
747
+ value=False,
748
+ label="Enable Web Search (MCP)",
749
+ info="Fetch knowledge from online medical resources"
750
+ )
751
+
752
+ medical_model = gr.Radio(
753
+ choices=list(MEDSWIN_MODELS.keys()),
754
+ value=DEFAULT_MEDICAL_MODEL,
755
+ label="Medical Model",
756
+ info="MedSwin SFT (default), others download on first use"
757
+ )
758
+
759
  system_prompt = gr.Textbox(
760
+ value="As a medical specialist, provide detailed and accurate answers based on the provided medical documents and context. Ensure all information is clinically accurate and cite sources when available.",
761
  label="System Prompt",
762
  lines=3
763
  )
 
767
  minimum=0,
768
  maximum=1,
769
  step=0.1,
770
+ value=0.7,
771
  label="Temperature"
772
  )
773
  max_new_tokens = gr.Slider(
774
+ minimum=512,
775
+ maximum=4096,
776
+ step=128,
777
+ value=2048,
778
  label="Max New Tokens",
779
+ info="Increased for medical models to prevent early stopping"
780
  )
781
  top_p = gr.Slider(
782
  minimum=0.0,
 
828
  top_k,
829
  penalty,
830
  retriever_k,
831
+ merge_threshold,
832
+ use_rag,
833
+ medical_model,
834
+ use_web_search
835
  ],
836
  outputs=chatbot
837
  )
 
848
  top_k,
849
  penalty,
850
  retriever_k,
851
+ merge_threshold,
852
+ use_rag,
853
+ medical_model,
854
+ use_web_search
855
  ],
856
  outputs=chatbot
857
  )
 
859
  return demo
860
 
861
  if __name__ == "__main__":
862
+ # Initialize default medical model
863
+ logger.info("Initializing default medical model (MedSwin SFT)...")
864
+ initialize_medical_model(DEFAULT_MEDICAL_MODEL)
865
  demo = create_demo()
866
+ demo.launch()