import gradio as gr
import os
import base64
import logging
import torch
import threading
import time
import json
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TextIteratorStreamer,
StoppingCriteria,
StoppingCriteriaList,
)
from transformers import logging as hf_logging
import spaces
from llama_index.core import (
StorageContext,
VectorStoreIndex,
load_index_from_storage,
Document as LlamaDocument,
)
from llama_index.core import Settings
from llama_index.core.node_parser import (
HierarchicalNodeParser,
get_leaf_nodes,
get_root_nodes,
)
from llama_index.core.retrievers import AutoMergingRetriever
from llama_index.core.storage.docstore import SimpleDocumentStore
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
# Import GPU-tagged model functions
from model import (
get_llm_for_rag as get_llm_for_rag_gpu,
get_embedding_model as get_embedding_model_gpu,
generate_with_medswin,
initialize_medical_model,
global_medical_models,
global_medical_tokenizers
)
from tqdm import tqdm
from langdetect import detect, LangDetectException
# MCP imports
try:
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
import asyncio
try:
import nest_asyncio
nest_asyncio.apply() # Allow nested event loops
except ImportError:
pass # nest_asyncio is optional
MCP_AVAILABLE = True
except ImportError:
MCP_AVAILABLE = False
# Fallback imports if MCP is not available
from ddgs import DDGS
import requests
from bs4 import BeautifulSoup
try:
from TTS.api import TTS
TTS_AVAILABLE = True
except ImportError:
TTS_AVAILABLE = False
TTS = None
import numpy as np
import soundfile as sf
import tempfile
os.environ["TOKENIZERS_PARALLELISM"] = "false"
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
hf_logging.set_verbosity_error()
# Model configurations
MEDSWIN_MODELS = {
"MedSwin SFT": "MedSwin/MedSwin-7B-SFT",
"MedSwin KD": "MedSwin/MedSwin-7B-KD",
"MedSwin TA": "MedSwin/MedSwin-Merged-TA-SFT-0.7"
}
DEFAULT_MEDICAL_MODEL = "MedSwin TA"
EMBEDDING_MODEL = "abhinand/MedEmbed-large-v0.1" # Domain-tuned medical embedding model
TTS_MODEL = "maya-research/maya1"
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
raise ValueError("HF_TOKEN not found in environment variables")
# Gemini MCP configuration
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
GEMINI_MODEL = os.environ.get("GEMINI_MODEL", "gemini-2.5-flash") # Default for harder tasks
GEMINI_MODEL_LITE = os.environ.get("GEMINI_MODEL_LITE", "gemini-2.5-flash-lite") # For parsing and simple tasks
# Custom UI
TITLE = "
𩺠MedLLM Agent - Medical RAG & Web Search System
"
DESCRIPTION = """
Advanced Medical AI Assistant powered by MedSwin models
📄 Document RAG: Answer based on uploaded medical documents
🌐 Web Search: Fetch knowledge from reliable online medical resources
🌍 Multi-language: Automatic translation for non-English queries
Upload PDF or text files to get started!
"""
CSS = """
.upload-section {
max-width: 400px;
margin: 0 auto;
padding: 10px;
border: 2px dashed #ccc;
border-radius: 10px;
}
.upload-button {
background: #34c759 !important;
color: white !important;
border-radius: 25px !important;
}
.chatbot-container {
margin-top: 20px;
}
.status-output {
margin-top: 10px;
font-size: 14px;
}
.processing-info {
margin-top: 5px;
font-size: 12px;
color: #666;
}
.info-container {
margin-top: 10px;
padding: 10px;
border-radius: 5px;
}
.file-list {
margin-top: 0;
max-height: 200px;
overflow-y: auto;
padding: 5px;
border: 1px solid #eee;
border-radius: 5px;
}
.stats-box {
margin-top: 10px;
padding: 10px;
border-radius: 5px;
font-size: 12px;
}
.submit-btn {
background: #1a73e8 !important;
color: white !important;
border-radius: 25px !important;
margin-left: 10px;
padding: 5px 10px;
font-size: 16px;
}
.input-row {
display: flex;
align-items: center;
}
.recording-timer {
font-size: 12px;
color: #666;
text-align: center;
margin-top: 5px;
}
.feature-badge {
display: inline-block;
padding: 3px 8px;
margin: 2px;
border-radius: 12px;
font-size: 11px;
font-weight: bold;
}
.badge-rag {
background: #e3f2fd;
color: #1976d2;
}
.badge-web {
background: #f3e5f5;
color: #7b1fa2;
}
@media (min-width: 768px) {
.main-container {
display: flex;
justify-content: space-between;
gap: 20px;
}
.upload-section {
flex: 1;
max-width: 300px;
}
.chatbot-container {
flex: 2;
margin-top: 0;
}
}
"""
# Global model storage - models are stored in model.py
# Import the global model storage from model.py
global_file_info = {}
global_tts_model = None
# MCP client storage
global_mcp_session = None
global_mcp_stdio_ctx = None # Store stdio context to keep it alive
global_mcp_lock = threading.Lock() # Lock for thread-safe session access
# MCP server configuration via environment variables
# Gemini MCP server: Python-based server (agent.py)
# This works on Hugging Face Spaces without requiring npm/Node.js
# Make sure GEMINI_API_KEY is set in environment variables
#
# Default configuration uses the bundled agent.py script
# To override:
# export MCP_SERVER_COMMAND="python"
# export MCP_SERVER_ARGS="/path/to/agent.py"
script_dir = os.path.dirname(os.path.abspath(__file__))
agent_path = os.path.join(script_dir, "agent.py")
MCP_SERVER_COMMAND = os.environ.get("MCP_SERVER_COMMAND", "python")
MCP_SERVER_ARGS = os.environ.get("MCP_SERVER_ARGS", agent_path).split() if os.environ.get("MCP_SERVER_ARGS") else [agent_path]
async def get_mcp_session():
"""Get or create MCP client session with proper context management"""
global global_mcp_session, global_mcp_stdio_ctx
if not MCP_AVAILABLE:
return None
# Check if session exists and is still valid
if global_mcp_session is not None:
try:
# Test if session is still alive by listing tools
await global_mcp_session.list_tools()
return global_mcp_session
except Exception as e:
logger.debug(f"Existing MCP session invalid, recreating: {e}")
# Clean up old session
try:
if global_mcp_session is not None:
await global_mcp_session.__aexit__(None, None, None)
except:
pass
try:
if global_mcp_stdio_ctx is not None:
await global_mcp_stdio_ctx.__aexit__(None, None, None)
except:
pass
global_mcp_session = None
global_mcp_stdio_ctx = None
# Create new session using correct MCP SDK pattern
try:
# Prepare environment variables for MCP server
mcp_env = os.environ.copy()
if GEMINI_API_KEY:
mcp_env["GEMINI_API_KEY"] = GEMINI_API_KEY
else:
logger.warning("GEMINI_API_KEY not set in environment. Gemini MCP features may not work.")
# Add other Gemini MCP configuration if set
if os.environ.get("GEMINI_MODEL"):
mcp_env["GEMINI_MODEL"] = os.environ.get("GEMINI_MODEL")
if os.environ.get("GEMINI_TIMEOUT"):
mcp_env["GEMINI_TIMEOUT"] = os.environ.get("GEMINI_TIMEOUT")
if os.environ.get("GEMINI_MAX_OUTPUT_TOKENS"):
mcp_env["GEMINI_MAX_OUTPUT_TOKENS"] = os.environ.get("GEMINI_MAX_OUTPUT_TOKENS")
if os.environ.get("GEMINI_TEMPERATURE"):
mcp_env["GEMINI_TEMPERATURE"] = os.environ.get("GEMINI_TEMPERATURE")
logger.info(f"Creating MCP client session with command: {MCP_SERVER_COMMAND} {MCP_SERVER_ARGS}")
server_params = StdioServerParameters(
command=MCP_SERVER_COMMAND,
args=MCP_SERVER_ARGS,
env=mcp_env
)
# Correct MCP SDK usage: stdio_client is an async context manager
# that yields (read, write) streams
stdio_ctx = stdio_client(server_params)
read, write = await stdio_ctx.__aenter__()
# Create ClientSession from the streams
# The __aenter__() method automatically handles the initialization handshake
session = ClientSession(read, write)
# Wait longer for the server process to fully start
# The server needs time to: start Python, import modules, initialize Gemini client, start MCP server
logger.info("ā³ Waiting for MCP server process to start...")
await asyncio.sleep(3.0) # Increased wait for server process startup
try:
# Initialize the session (this sends initialize request and waits for response)
logger.info("š Initializing MCP session...")
await session.__aenter__()
logger.info("ā
MCP session initialized, verifying tools...")
except Exception as e:
logger.warning(f"MCP session initialization had an issue (may be expected): {e}")
# Continue anyway - the session might still work
# Wait longer for the server to be fully ready after initialization
# The server needs time to process the initialization and be ready for requests
await asyncio.sleep(2.0) # Wait after initialization
# Verify the session works by listing tools with retries
# This confirms the server is ready to handle requests
max_init_retries = 15
tools_listed = False
tools = None
last_error = None
for init_attempt in range(max_init_retries):
try:
tools = await session.list_tools()
if tools and hasattr(tools, 'tools') and len(tools.tools) > 0:
logger.info(f"ā
MCP server ready with {len(tools.tools)} tools: {[t.name for t in tools.tools]}")
tools_listed = True
break
except Exception as e:
last_error = e
error_str = str(e).lower()
error_msg = str(e)
# Log the actual error for debugging
if init_attempt == 0:
logger.debug(f"First list_tools attempt failed: {error_msg}")
# Ignore initialization-related errors during the handshake phase
if "initialization" in error_str or "before initialization" in error_str or "not initialized" in error_str:
if init_attempt < max_init_retries - 1:
wait_time = 0.5 * (init_attempt + 1) # Progressive wait: 0.5s, 1s, 1.5s...
logger.debug(f"Server still initializing (attempt {init_attempt + 1}/{max_init_retries}), waiting {wait_time}s...")
await asyncio.sleep(wait_time)
continue
elif "invalid request" in error_str or "invalid request parameters" in error_str:
# This might be a timing issue - wait and retry
if init_attempt < max_init_retries - 1:
wait_time = 0.8 * (init_attempt + 1) # Longer wait for invalid request errors
logger.debug(f"Invalid request error (attempt {init_attempt + 1}/{max_init_retries}), waiting {wait_time}s...")
await asyncio.sleep(wait_time)
continue
elif init_attempt < max_init_retries - 1:
wait_time = 0.5 * (init_attempt + 1)
logger.debug(f"Tool listing attempt {init_attempt + 1}/{max_init_retries} failed: {error_msg}, waiting {wait_time}s...")
await asyncio.sleep(wait_time)
else:
logger.error(f"ā Could not list tools after {max_init_retries} attempts. Last error: {error_msg}")
# Don't continue - if we can't list tools, the session is not usable
try:
await session.__aexit__(None, None, None)
except:
pass
try:
await stdio_ctx.__aexit__(None, None, None)
except:
pass
return None
if not tools_listed:
error_msg = str(last_error) if last_error else "Unknown error"
logger.error(f"MCP server failed to initialize - tools could not be listed. Last error: {error_msg}")
try:
await session.__aexit__(None, None, None)
except:
pass
try:
await stdio_ctx.__aexit__(None, None, None)
except:
pass
return None
# Store both the session and stdio context to keep them alive
global_mcp_session = session
global_mcp_stdio_ctx = stdio_ctx
logger.info("MCP client session created successfully")
return session
except Exception as e:
logger.error(f"Failed to create MCP client session: {e}")
import traceback
logger.debug(traceback.format_exc())
global_mcp_session = None
global_mcp_stdio_ctx = None
return None
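# Illustrative consumption of get_mcp_session() (a sketch, not the real call
# path; call_agent() below is what the app actually uses). Assumes it runs
# inside an active event loop:
#
#     session = await get_mcp_session()
#     if session is not None:
#         tools = await session.list_tools()
#         result = await session.call_tool(tools.tools[0].name, arguments={})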
async def call_agent(user_prompt: str, system_prompt: str = None, files: list = None, model: str = None, temperature: float = 0.2) -> str:
"""Call Gemini MCP generate_content tool"""
if not MCP_AVAILABLE:
logger.warning("MCP not available for Gemini call")
return ""
try:
session = await get_mcp_session()
if session is None:
logger.warning("Failed to get MCP session for Gemini call")
return ""
# Retry listing tools if it fails the first time
# Use more retries and longer waits since MCP server might need time
max_retries = 5
tools = None
for attempt in range(max_retries):
try:
tools = await session.list_tools()
if tools and hasattr(tools, 'tools') and len(tools.tools) > 0:
break
else:
raise ValueError("Empty tools list")
except Exception as e:
if attempt < max_retries - 1:
wait_time = 1.0 * (attempt + 1) # Progressive wait
logger.debug(f"Failed to list tools (attempt {attempt + 1}/{max_retries}), waiting {wait_time}s...")
await asyncio.sleep(wait_time)
else:
logger.error(f"ā Failed to list MCP tools after {max_retries} attempts: {e}")
return ""
if not tools or not hasattr(tools, 'tools'):
logger.error("Invalid tools response from MCP server")
return ""
# Find generate_content tool
generate_tool = None
for tool in tools.tools:
if tool.name == "generate_content" or "generate_content" in tool.name.lower():
generate_tool = tool
logger.info(f"Found Gemini MCP tool: {tool.name}")
break
if not generate_tool:
logger.warning(f"Gemini MCP generate_content tool not found. Available tools: {[t.name for t in tools.tools]}")
return ""
# Prepare arguments
arguments = {
"user_prompt": user_prompt
}
if system_prompt:
arguments["system_prompt"] = system_prompt
if files:
arguments["files"] = files
if model:
arguments["model"] = model
if temperature is not None:
arguments["temperature"] = temperature
logger.info(f"š§ [MCP] Calling Gemini MCP tool '{generate_tool.name}' for: {user_prompt[:100]}...")
logger.info(f"š [MCP] Arguments: model={model}, temperature={temperature}, files={len(files) if files else 0}")
result = await session.call_tool(generate_tool.name, arguments=arguments)
# Parse result
if hasattr(result, 'content') and result.content:
for item in result.content:
if hasattr(item, 'text'):
response_text = item.text.strip()
logger.info(f"ā
[MCP] Gemini MCP returned response ({len(response_text)} chars)")
return response_text
logger.warning("ā ļø [MCP] Gemini MCP returned empty or invalid result")
return ""
except Exception as e:
logger.error(f"Gemini MCP call error: {e}")
import traceback
logger.debug(traceback.format_exc())
return ""
# initialize_medical_model is now imported from model.py
def initialize_tts_model():
"""Initialize TTS model for text-to-speech"""
global global_tts_model
if not TTS_AVAILABLE:
logger.warning("TTS library not installed. TTS features will be disabled.")
return None
if global_tts_model is None:
try:
logger.info("Initializing TTS model for voice generation...")
global_tts_model = TTS(model_name=TTS_MODEL, progress_bar=False)
logger.info("TTS model initialized successfully")
except Exception as e:
logger.warning(f"TTS model initialization failed: {e}")
logger.warning("TTS features will be disabled. If pyworld dependency is missing, try: pip install TTS --no-deps && pip install coqui-tts")
global_tts_model = None
return global_tts_model
async def transcribe_audio_gemini(audio_path: str) -> str:
"""Transcribe audio using Gemini MCP"""
if not MCP_AVAILABLE:
return ""
try:
# Ensure we have an absolute path
audio_path_abs = os.path.abspath(audio_path)
# Prepare file object for Gemini MCP using path (as per Gemini MCP documentation)
files = [{
"path": audio_path_abs
}]
# Use exact prompts from Gemini MCP documentation
system_prompt = "You are a professional transcription service. Provide accurate, well-formatted transcripts."
user_prompt = "Please transcribe this audio file. Include speaker identification if multiple speakers are present, and format it with proper punctuation and paragraphs, remove mumble, ignore non-verbal noises."
result = await call_agent(
user_prompt=user_prompt,
system_prompt=system_prompt,
files=files,
model=GEMINI_MODEL_LITE, # Use lite model for transcription
temperature=0.2
)
return result.strip()
except Exception as e:
logger.error(f"Gemini transcription error: {e}")
import traceback
logger.debug(traceback.format_exc())
return ""
def transcribe_audio(audio):
"""Transcribe audio to text using Gemini MCP"""
if audio is None:
return ""
try:
# Handle file path (Gradio Audio component returns file path)
if isinstance(audio, str):
audio_path = audio
elif isinstance(audio, tuple):
# Handle tuple format (sample_rate, audio_data)
sample_rate, audio_data = audio
# Save to temp file
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
sf.write(tmp_file.name, audio_data, samplerate=sample_rate)
audio_path = tmp_file.name
else:
audio_path = audio
# Use Gemini MCP for transcription
if MCP_AVAILABLE:
try:
loop = asyncio.get_event_loop()
if loop.is_running():
try:
import nest_asyncio
nest_asyncio.apply()  # nest_asyncio only patches the loop; it has no run() helper
transcribed = loop.run_until_complete(transcribe_audio_gemini(audio_path))
if transcribed:
logger.info(f"Transcribed via Gemini MCP: {transcribed[:50]}...")
return transcribed
except Exception as e:
logger.error(f"Error in nested async transcription: {e}")
else:
transcribed = loop.run_until_complete(transcribe_audio_gemini(audio_path))
if transcribed:
logger.info(f"Transcribed via Gemini MCP: {transcribed[:50]}...")
return transcribed
except Exception as e:
logger.error(f"Gemini MCP transcription error: {e}")
logger.warning("Gemini MCP transcription not available")
return ""
except Exception as e:
logger.error(f"Transcription error: {e}")
return ""
async def generate_speech_mcp(text: str) -> str:
"""Generate speech using MCP TTS tool"""
if not MCP_AVAILABLE:
return None
try:
# Get MCP session
session = await get_mcp_session()
if session is None:
return None
# Find TTS tool
tools = await session.list_tools()
tts_tool = None
for tool in tools.tools:
if "tts" in tool.name.lower() or "speech" in tool.name.lower() or "synthesize" in tool.name.lower():
tts_tool = tool
logger.info(f"Found MCP TTS tool: {tool.name}")
break
if tts_tool:
result = await session.call_tool(
tts_tool.name,
arguments={"text": text, "language": "en"}
)
# Parse result - MCP might return audio data or file path
if hasattr(result, 'content') and result.content:
for item in result.content:
if hasattr(item, 'text'):
# If it's a file path
if os.path.exists(item.text):
return item.text
elif hasattr(item, 'data') and item.data:
# If it's binary audio data, save it
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
tmp_file.write(item.data)
return tmp_file.name
return None
except Exception as e:
logger.debug(f"MCP TTS error: {e}")
return None
def generate_speech(text: str):
"""Generate speech from text using TTS model (with MCP fallback)"""
if not text or len(text.strip()) == 0:
return None
# Try MCP first if available
if MCP_AVAILABLE:
try:
loop = asyncio.get_event_loop()
if loop.is_running():
try:
import nest_asyncio
nest_asyncio.apply()
audio_path = loop.run_until_complete(generate_speech_mcp(text))
if audio_path:
logger.info("Generated speech via MCP")
return audio_path
except:
pass
else:
audio_path = loop.run_until_complete(generate_speech_mcp(text))
if audio_path:
logger.info("Generated speech via MCP")
return audio_path
except Exception as e:
logger.debug(f"MCP TTS not available: {e}")
# Fallback to local TTS model
if not TTS_AVAILABLE:
logger.error("TTS library not installed. Please install TTS to use voice generation.")
return None
global global_tts_model
if global_tts_model is None:
initialize_tts_model()
if global_tts_model is None:
logger.error("TTS model not available. Please check dependencies.")
return None
try:
# Generate audio
wav = global_tts_model.tts(text)
# Save to temporary file
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
sf.write(tmp_file.name, wav, samplerate=22050)
return tmp_file.name
except Exception as e:
logger.error(f"TTS error: {e}")
return None
def format_prompt_manually(messages: list, tokenizer) -> str:
"""Manually format prompt for models without chat template"""
prompt_parts = []
# Combine system and user messages into a single instruction
system_content = ""
user_content = ""
for msg in messages:
role = msg.get("role", "user")
content = msg.get("content", "")
if role == "system":
system_content = content
elif role == "user":
user_content = content
elif role == "assistant":
# Skip assistant messages in history for now (can be added if needed)
pass
# Format for MedAlpaca/LLaMA-based medical models
# Common format: Instruction + Input -> Response
if system_content:
prompt = f"{system_content}\n\nQuestion: {user_content}\n\nAnswer:"
else:
prompt = f"Question: {user_content}\n\nAnswer:"
return prompt
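# Example (illustrative): for a system prompt "You are a clinician." and user
# question "What causes anemia?", the function returns:
#
#     "You are a clinician.\n\nQuestion: What causes anemia?\n\nAnswer:"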
def detect_language(text: str) -> str:
"""Detect language of input text"""
try:
lang = detect(text)
return lang
except LangDetectException:
return "en" # Default to English if detection fails
def format_url_as_domain(url: str) -> str:
"""Format URL as simple domain name (e.g., www.mayoclinic.org)"""
if not url:
return ""
try:
from urllib.parse import urlparse
parsed = urlparse(url)
domain = parsed.netloc or parsed.path
# Keep the domain as-is (including any www. prefix) for display
if domain:
return domain
return url
except Exception:
# Fallback: try to extract domain manually
if '://' in url:
domain = url.split('://')[1].split('/')[0]
return domain
return url
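# Examples (illustrative):
#     format_url_as_domain("https://www.mayoclinic.org/drugs")  -> "www.mayoclinic.org"
#     format_url_as_domain("https://medlineplus.gov/anemia")    -> "medlineplus.gov"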
async def translate_text_gemini(text: str, target_lang: str = "en", source_lang: str = None) -> str:
"""Translate text using Gemini MCP"""
if source_lang:
user_prompt = f"Translate the following {source_lang} text to {target_lang}. Only provide the translation, no explanations:\n\n{text}"
else:
user_prompt = f"Translate the following text to {target_lang}. Only provide the translation, no explanations:\n\n{text}"
# Use concise system prompt
system_prompt = "You are a professional translator. Translate accurately and concisely."
result = await call_agent(
user_prompt=user_prompt,
system_prompt=system_prompt,
model=GEMINI_MODEL_LITE, # Use lite model for translation
temperature=0.2
)
return result.strip()
def translate_text(text: str, target_lang: str = "en", source_lang: str = None) -> str:
"""Translate text using Gemini MCP"""
if not MCP_AVAILABLE:
logger.warning("Gemini MCP not available for translation")
return text # Return original text if translation fails
try:
loop = asyncio.get_event_loop()
if loop.is_running():
try:
import nest_asyncio
nest_asyncio.apply()
translated = loop.run_until_complete(translate_text_gemini(text, target_lang, source_lang))
if translated:
logger.info(f"Translated via Gemini MCP: {translated[:50]}...")
return translated
except Exception as e:
logger.error(f"Error in nested async translation: {e}")
else:
translated = loop.run_until_complete(translate_text_gemini(text, target_lang, source_lang))
if translated:
logger.info(f"Translated via Gemini MCP: {translated[:50]}...")
return translated
except Exception as e:
logger.error(f"Gemini MCP translation error: {e}")
# Return original text if translation fails
return text
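# Example (illustrative; the actual wording depends on the Gemini model):
#     translate_text("Qu'est-ce que l'anémie ?", target_lang="en", source_lang="fr")
#     -> "What is anemia?"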
async def search_web_mcp_tool(query: str, max_results: int = 5) -> list:
"""Search web using MCP web search tool (e.g., DuckDuckGo MCP server)"""
if not MCP_AVAILABLE:
return []
try:
session = await get_mcp_session()
if session is None:
return []
# Retry listing tools if it fails the first time
max_retries = 3
tools = None
for attempt in range(max_retries):
try:
tools = await session.list_tools()
break
except Exception as e:
if attempt < max_retries - 1:
await asyncio.sleep(0.5 * (attempt + 1))
else:
logger.error(f"Failed to list MCP tools after {max_retries} attempts: {e}")
return []
if not tools or not hasattr(tools, 'tools'):
return []
# Look for web search tools (DuckDuckGo, search, etc.)
search_tool = None
for tool in tools.tools:
tool_name_lower = tool.name.lower()
if any(keyword in tool_name_lower for keyword in ["search", "duckduckgo", "ddg", "web"]):
search_tool = tool
logger.info(f"Found web search MCP tool: {tool.name}")
break
if search_tool:
try:
logger.info(f"š [MCP] Using web search MCP tool '{search_tool.name}' for: {query[:100]}...")
# Call the search tool
result = await session.call_tool(
search_tool.name,
arguments={"query": query, "max_results": max_results}
)
# Parse result
web_content = []
if hasattr(result, 'content') and result.content:
for item in result.content:
if hasattr(item, 'text'):
try:
data = json.loads(item.text)
if isinstance(data, list):
for entry in data[:max_results]:
web_content.append({
'title': entry.get('title', ''),
'url': entry.get('url', entry.get('href', '')),
'content': entry.get('body', entry.get('snippet', entry.get('content', '')))
})
elif isinstance(data, dict):
if 'results' in data:
for entry in data['results'][:max_results]:
web_content.append({
'title': entry.get('title', ''),
'url': entry.get('url', entry.get('href', '')),
'content': entry.get('body', entry.get('snippet', entry.get('content', '')))
})
else:
web_content.append({
'title': data.get('title', ''),
'url': data.get('url', data.get('href', '')),
'content': data.get('body', data.get('snippet', data.get('content', '')))
})
except json.JSONDecodeError:
# If not JSON, treat as plain text
web_content.append({
'title': '',
'url': '',
'content': item.text[:1000]
})
if web_content:
logger.info(f"ā
[MCP] Web search MCP tool returned {len(web_content)} results")
return web_content
except Exception as e:
logger.error(f"Error calling web search MCP tool: {e}")
return []
except Exception as e:
logger.error(f"Web search MCP tool error: {e}")
return []
async def search_web_mcp(query: str, max_results: int = 5) -> list:
"""Search web using MCP tools - tries web search MCP tool first, then falls back to direct search"""
# First try to use a dedicated web search MCP tool (like DuckDuckGo MCP server)
results = await search_web_mcp_tool(query, max_results)
if results:
logger.info(f"ā
Web search via MCP tool: found {len(results)} results")
return results
# If no web search MCP tool available, use direct search (ddgs)
# Note: Gemini MCP doesn't have web search capability, so we use direct API
# The results will then be summarized using Gemini MCP
logger.info("ā¹ļø [Direct API] No web search MCP tool found, using direct DuckDuckGo search (results will be summarized with Gemini MCP)")
return search_web_fallback(query, max_results)
def search_web_fallback(query: str, max_results: int = 5) -> list:
"""Fallback web search using DuckDuckGo directly (when MCP is not available)"""
logger.info(f"š [Direct API] Performing web search using DuckDuckGo API for: {query[:100]}...")
# Always import here to ensure availability
try:
from ddgs import DDGS
import requests
from bs4 import BeautifulSoup
except ImportError:
logger.error("Fallback dependencies (ddgs, requests, beautifulsoup4) not available")
return []
try:
with DDGS() as ddgs:
results = list(ddgs.text(query, max_results=max_results))
web_content = []
for result in results:
try:
url = result.get('href', '')
title = result.get('title', '')
snippet = result.get('body', '')
# Try to fetch full content
try:
response = requests.get(url, timeout=5, headers={'User-Agent': 'Mozilla/5.0'})
if response.status_code == 200:
soup = BeautifulSoup(response.content, 'html.parser')
# Extract main content
for script in soup(["script", "style"]):
script.decompose()
text = soup.get_text()
# Clean and limit text
lines = (line.strip() for line in text.splitlines())
chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
text = ' '.join(chunk for chunk in chunks if chunk)
if len(text) > 1000:
text = text[:1000] + "..."
web_content.append({
'title': title,
'url': url,
'content': snippet + "\n" + text[:500] if text else snippet
})
else:
web_content.append({
'title': title,
'url': url,
'content': snippet
})
except:
web_content.append({
'title': title,
'url': url,
'content': snippet
})
except Exception as e:
logger.error(f"Error processing search result: {e}")
continue
logger.info(f"ā
[Direct API] Web search completed: {len(web_content)} results")
return web_content
except Exception as e:
logger.error(f"ā [Direct API] Web search error: {e}")
return []
def search_web(query: str, max_results: int = 5) -> list:
"""Search web using MCP tools (synchronous wrapper) - prioritizes MCP over direct ddgs"""
# Always try MCP first if available
if MCP_AVAILABLE:
try:
# Run async MCP search
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
if loop.is_running():
# If loop is already running, use nest_asyncio or create new thread
try:
import nest_asyncio
nest_asyncio.apply()
results = loop.run_until_complete(search_web_mcp(query, max_results))
if results: # Only return if we got results from MCP
return results
except (ImportError, AttributeError):
# Fallback: run in thread
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(asyncio.run, search_web_mcp(query, max_results))
results = future.result(timeout=30)
if results: # Only return if we got results from MCP
return results
else:
results = loop.run_until_complete(search_web_mcp(query, max_results))
if results: # Only return if we got results from MCP
return results
except Exception as e:
logger.error(f"Error running async MCP search: {e}")
# Only use ddgs fallback if MCP is not available or returned no results
logger.info("ā¹ļø [Direct API] Falling back to direct DuckDuckGo search (MCP unavailable or returned no results)")
return search_web_fallback(query, max_results)
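# Each search result is a dict of the shape (illustrative values):
#     {"title": "Aspirin - Mayo Clinic",
#      "url": "https://www.mayoclinic.org/...",
#      "content": "Aspirin is used to reduce fever and relieve mild pain ..."}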
async def summarize_web_content_gemini(content_list: list, query: str) -> str:
"""Summarize web search results using Gemini MCP"""
logger.info(f"š [MCP] Summarizing {len(content_list)} web search results using Gemini MCP...")
combined_content = "\n\n".join([f"Source: {item['title']}\n{item['content']}" for item in content_list[:3]])
user_prompt = f"""Summarize the following web search results related to the query: "{query}"
Extract key medical information, facts, and insights. Be concise and focus on reliable information.
Search Results:
{combined_content}
Summary:"""
# Use concise system prompt
system_prompt = "You are a medical information summarizer. Extract and summarize key medical facts accurately."
result = await call_agent(
user_prompt=user_prompt,
system_prompt=system_prompt,
model=GEMINI_MODEL, # Use full model for summarization
temperature=0.5
)
if result:
logger.info(f"ā
[MCP] Web content summarized successfully using Gemini MCP ({len(result)} chars)")
else:
logger.warning("ā ļø [MCP] Gemini MCP summarization returned empty result")
return result.strip()
def summarize_web_content(content_list: list, query: str) -> str:
"""Summarize web search results using Gemini MCP"""
if not MCP_AVAILABLE:
logger.warning("Gemini MCP not available for summarization")
# Fallback: return first result's content
if content_list:
return content_list[0].get('content', '')[:500]
return ""
try:
loop = asyncio.get_event_loop()
if loop.is_running():
try:
import nest_asyncio
nest_asyncio.apply()
summary = loop.run_until_complete(summarize_web_content_gemini(content_list, query))
if summary:
return summary
except Exception as e:
logger.error(f"Error in nested async summarization: {e}")
else:
summary = loop.run_until_complete(summarize_web_content_gemini(content_list, query))
if summary:
return summary
except Exception as e:
logger.error(f"Gemini MCP summarization error: {e}")
# Fallback: return first result's content
if content_list:
return content_list[0].get('content', '')[:500]
return ""
# get_llm_for_rag is now imported from model.py as get_llm_for_rag_gpu
async def autonomous_reasoning_gemini(query: str) -> dict:
"""Autonomous reasoning using Gemini MCP"""
logger.info(f"š§ [MCP] Analyzing query with Gemini MCP: {query[:100]}...")
reasoning_prompt = f"""Analyze this medical query and provide structured reasoning:
Query: "{query}"
Analyze:
1. Query Type: (diagnosis, treatment, drug_info, symptom_analysis, research, general_info)
2. Complexity: (simple, moderate, complex, multi_faceted)
3. Information Needs: What specific information is required?
4. Requires RAG: (yes/no) - Does this need document context?
5. Requires Web Search: (yes/no) - Does this need current/updated information?
6. Sub-questions: Break down into key sub-questions if complex
Respond in JSON format:
{{
"query_type": "...",
"complexity": "...",
"information_needs": ["..."],
"requires_rag": true/false,
"requires_web_search": true/false,
"sub_questions": ["..."]
}}"""
# Use concise system prompt
system_prompt = "You are a medical reasoning system. Analyze queries systematically and provide structured JSON responses."
response = await call_agent(
user_prompt=reasoning_prompt,
system_prompt=system_prompt,
model=GEMINI_MODEL, # Use full model for reasoning
temperature=0.3
)
# Parse JSON response (with fallback)
try:
# Extract JSON from response
json_start = response.find('{')
json_end = response.rfind('}') + 1
if json_start >= 0 and json_end > json_start:
reasoning = json.loads(response[json_start:json_end])
else:
raise ValueError("No JSON found")
except:
# Fallback reasoning
reasoning = {
"query_type": "general_info",
"complexity": "moderate",
"information_needs": ["medical information"],
"requires_rag": True,
"requires_web_search": False,
"sub_questions": [query]
}
logger.info(f"Reasoning analysis: {reasoning}")
return reasoning
def autonomous_reasoning(query: str, history: list) -> dict:
"""
Autonomous reasoning: Analyze query complexity, intent, and information needs.
Returns reasoning analysis with query type, complexity, and required information sources.
Uses Gemini MCP for reasoning.
"""
if not MCP_AVAILABLE:
logger.warning("ā ļø Gemini MCP not available for reasoning, using fallback")
# Fallback reasoning
return {
"query_type": "general_info",
"complexity": "moderate",
"information_needs": ["medical information"],
"requires_rag": True,
"requires_web_search": False,
"sub_questions": [query]
}
try:
logger.info("š¤ [MCP] Using Gemini MCP for autonomous reasoning...")
loop = asyncio.get_event_loop()
if loop.is_running():
try:
import nest_asyncio
nest_asyncio.apply()
reasoning = loop.run_until_complete(autonomous_reasoning_gemini(query))
if reasoning and reasoning.get("query_type") != "general_info": # Check if we got real reasoning
logger.info(f"ā
[MCP] Gemini MCP reasoning successful: {reasoning.get('query_type')}, complexity: {reasoning.get('complexity')}")
return reasoning
else:
logger.warning("ā ļø [MCP] Gemini MCP returned fallback reasoning, using it anyway")
return reasoning
except Exception as e:
logger.error(f"ā Error in nested async reasoning: {e}")
import traceback
logger.debug(traceback.format_exc())
else:
reasoning = loop.run_until_complete(autonomous_reasoning_gemini(query))
if reasoning and reasoning.get("query_type") != "general_info":
logger.info(f"ā
[MCP] Gemini MCP reasoning successful: {reasoning.get('query_type')}, complexity: {reasoning.get('complexity')}")
return reasoning
else:
logger.warning("ā ļø [MCP] Gemini MCP returned fallback reasoning, using it anyway")
return reasoning
except Exception as e:
logger.error(f"ā Gemini MCP reasoning error: {e}")
import traceback
logger.debug(traceback.format_exc())
# Fallback reasoning only if all attempts failed
logger.warning("ā ļø Falling back to default reasoning")
return {
"query_type": "general_info",
"complexity": "moderate",
"information_needs": ["medical information"],
"requires_rag": True,
"requires_web_search": False,
"sub_questions": [query]
}
def create_execution_plan(reasoning: dict, query: str, has_rag_index: bool) -> dict:
"""
Planning: Create multi-step execution plan based on reasoning analysis.
Returns execution plan with steps and strategy.
"""
plan = {
"steps": [],
"strategy": "sequential",
"iterations": 1
}
# Determine execution strategy
if reasoning["complexity"] in ["complex", "multi_faceted"]:
plan["strategy"] = "iterative"
plan["iterations"] = 2
# Step 1: Language detection and translation
plan["steps"].append({
"step": 1,
"action": "detect_language",
"description": "Detect query language and translate if needed"
})
# Step 2: RAG retrieval (if needed and available)
if reasoning.get("requires_rag", True) and has_rag_index:
plan["steps"].append({
"step": 2,
"action": "rag_retrieval",
"description": "Retrieve relevant document context",
"parameters": {"top_k": 15, "merge_threshold": 0.5}
})
# Step 3: Web search (if needed)
if reasoning.get("requires_web_search", False):
plan["steps"].append({
"step": 3,
"action": "web_search",
"description": "Search web for current/updated information",
"parameters": {"max_results": 5}
})
# Step 4: Sub-question processing (if complex)
if reasoning.get("sub_questions") and len(reasoning["sub_questions"]) > 1:
plan["steps"].append({
"step": 4,
"action": "multi_step_reasoning",
"description": "Process sub-questions iteratively",
"sub_questions": reasoning["sub_questions"]
})
# Step 5: Synthesis and answer generation
plan["steps"].append({
"step": len(plan["steps"]) + 1,
"action": "synthesize_answer",
"description": "Generate comprehensive answer from all sources"
})
# Step 6: Self-reflection (for complex queries)
if reasoning["complexity"] in ["complex", "multi_faceted"]:
plan["steps"].append({
"step": len(plan["steps"]) + 1,
"action": "self_reflection",
"description": "Evaluate answer quality and completeness"
})
logger.info(f"Execution plan created: {len(plan['steps'])} steps")
return plan
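# Example plan for a complex query with a RAG index and web search required
# (illustrative structure; field values follow the construction above):
#     {"strategy": "iterative", "iterations": 2,
#      "steps": [{"step": 1, "action": "detect_language", ...},
#                {"step": 2, "action": "rag_retrieval", ...},
#                {"step": 3, "action": "web_search", ...},
#                {"step": 4, "action": "multi_step_reasoning", ...},
#                {"step": 5, "action": "synthesize_answer", ...},
#                {"step": 6, "action": "self_reflection", ...}]}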
def autonomous_execution_strategy(reasoning: dict, plan: dict, use_rag: bool, use_web_search: bool, has_rag_index: bool) -> dict:
"""
Autonomous execution: Make decisions on information gathering strategy.
Only suggests web search override, but respects user's RAG disable setting.
"""
strategy = {
"use_rag": use_rag, # Respect user's RAG setting
"use_web_search": use_web_search,
"reasoning_override": False,
"rationale": ""
}
# Only suggest web search override (RAG requires documents, so we respect user's choice)
if reasoning.get("requires_web_search", False) and not use_web_search:
strategy["use_web_search"] = True
strategy["reasoning_override"] = True
strategy["rationale"] += "Reasoning suggests web search for current information. "
# Note: We don't override RAG setting because:
# 1. User may have explicitly disabled it
# 2. RAG requires documents to be uploaded
# 3. We should respect user's explicit choice
if strategy["reasoning_override"]:
logger.info(f"Autonomous override: {strategy['rationale']}")
return strategy
async def self_reflection_gemini(answer: str, query: str) -> dict:
"""Self-reflection using Gemini MCP"""
reflection_prompt = f"""Evaluate this medical answer for quality and completeness:
Query: "{query}"
Answer: "{answer[:1000]}"
Evaluate:
1. Completeness: Does it address all aspects of the query?
2. Accuracy: Is the medical information accurate?
3. Clarity: Is it clear and well-structured?
4. Sources: Are sources cited appropriately?
5. Missing Information: What important information might be missing?
Respond in JSON:
{{
"completeness_score": 0-10,
"accuracy_score": 0-10,
"clarity_score": 0-10,
"overall_score": 0-10,
"missing_aspects": ["..."],
"improvement_suggestions": ["..."]
}}"""
# Use concise system prompt
system_prompt = "You are a medical answer quality evaluator. Provide honest, constructive feedback."
response = await call_agent(
user_prompt=reflection_prompt,
system_prompt=system_prompt,
model=GEMINI_MODEL, # Use full model for reflection
temperature=0.3
)
try:
json_start = response.find('{')
json_end = response.rfind('}') + 1
if json_start >= 0 and json_end > json_start:
reflection = json.loads(response[json_start:json_end])
else:
reflection = {"overall_score": 7, "improvement_suggestions": []}
except:
reflection = {"overall_score": 7, "improvement_suggestions": []}
logger.info(f"Self-reflection score: {reflection.get('overall_score', 'N/A')}")
return reflection
def self_reflection(answer: str, query: str, reasoning: dict) -> dict:
"""
Self-reflection: Evaluate answer quality and completeness.
Returns reflection with quality score and improvement suggestions.
"""
if not MCP_AVAILABLE:
logger.warning("Gemini MCP not available for reflection, using fallback")
return {"overall_score": 7, "improvement_suggestions": []}
try:
loop = asyncio.get_event_loop()
if loop.is_running():
try:
import nest_asyncio
nest_asyncio.apply()
return loop.run_until_complete(self_reflection_gemini(answer, query))
except Exception as e:
logger.error(f"Error in nested async reflection: {e}")
else:
return loop.run_until_complete(self_reflection_gemini(answer, query))
except Exception as e:
logger.error(f"Gemini MCP reflection error: {e}")
return {"overall_score": 7, "improvement_suggestions": []}
async def parse_document_gemini(file_path: str, file_extension: str) -> str:
"""Parse document using Gemini MCP"""
if not MCP_AVAILABLE:
return ""
try:
# Read file and encode to base64
with open(file_path, 'rb') as f:
file_content = base64.b64encode(f.read()).decode('utf-8')
# Determine MIME type from file extension
mime_type_map = {
'.pdf': 'application/pdf',
'.doc': 'application/msword',
'.docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
'.txt': 'text/plain',
'.md': 'text/markdown',
'.json': 'application/json',
'.xml': 'application/xml',
'.csv': 'text/csv'
}
mime_type = mime_type_map.get(file_extension, 'application/octet-stream')
# Prepare file object for Gemini MCP (use content for base64)
files = [{
"content": file_content,
"type": mime_type
}]
# Use concise system prompt
system_prompt = "Extract all text content from the document accurately."
user_prompt = "Extract all text content from this document. Return only the extracted text, preserving structure and formatting where possible."
result = await call_agent(
user_prompt=user_prompt,
system_prompt=system_prompt,
files=files,
model=GEMINI_MODEL_LITE, # Use lite model for parsing
temperature=0.2
)
return result.strip()
except Exception as e:
logger.error(f"Gemini document parsing error: {e}")
import traceback
logger.debug(traceback.format_exc())
return ""
def extract_text_from_document(file):
"""Extract text from document using Gemini MCP"""
file_name = file.name
file_extension = os.path.splitext(file_name)[1].lower()
# Handle text files directly
if file_extension == '.txt':
text = file.read().decode('utf-8')
return text, len(text.split()), None
# For PDF, Word, and other documents, use Gemini MCP
# Save file to temporary location for processing
try:
with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp_file:
# Write file content to temp file
file.seek(0) # Reset file pointer
tmp_file.write(file.read())
tmp_file_path = tmp_file.name
# Use Gemini MCP to parse document
if MCP_AVAILABLE:
try:
loop = asyncio.get_event_loop()
if loop.is_running():
try:
import nest_asyncio
nest_asyncio.apply()
text = loop.run_until_complete(parse_document_gemini(tmp_file_path, file_extension))
except Exception as e:
logger.error(f"Error in nested async document parsing: {e}")
text = ""
else:
text = loop.run_until_complete(parse_document_gemini(tmp_file_path, file_extension))
# Clean up temp file
try:
os.unlink(tmp_file_path)
except:
pass
if text:
return text, len(text.split()), None
else:
return None, 0, ValueError(f"Failed to extract text from {file_extension} file using Gemini MCP")
except Exception as e:
logger.error(f"Gemini MCP document parsing error: {e}")
# Clean up temp file
try:
os.unlink(tmp_file_path)
except:
pass
return None, 0, ValueError(f"Error parsing {file_extension} file: {str(e)}")
else:
# Clean up temp file
try:
os.unlink(tmp_file_path)
except:
pass
return None, 0, ValueError(f"Gemini MCP not available. Cannot parse {file_extension} files.")
except Exception as e:
logger.error(f"Error processing document: {e}")
return None, 0, ValueError(f"Error processing {file_extension} file: {str(e)}")
def create_or_update_index(files, request: gr.Request):
global global_file_info
if not files:
return "Please provide files.", ""
start_time = time.time()
user_id = request.session_hash
save_dir = f"./{user_id}_index"
# Initialize LlamaIndex modules - use GPU functions for model inference only
llm = get_llm_for_rag_gpu()
embed_model = get_embedding_model_gpu()
Settings.llm = llm
Settings.embed_model = embed_model
file_stats = []
new_documents = []
for file in tqdm(files, desc="Processing files"):
file_basename = os.path.basename(file.name)
text, word_count, error = extract_text_from_document(file)
if error:
logger.error(f"Error processing file {file_basename}: {str(error)}")
file_stats.append({
"name": file_basename,
"words": 0,
"status": f"error: {str(error)}"
})
continue
doc = LlamaDocument(
text=text,
metadata={
"file_name": file_basename,
"word_count": word_count,
"source": "user_upload"
}
)
new_documents.append(doc)
file_stats.append({
"name": file_basename,
"words": word_count,
"status": "processed"
})
global_file_info[file_basename] = {
"word_count": word_count,
"processed_at": time.time()
}
node_parser = HierarchicalNodeParser.from_defaults(
chunk_sizes=[2048, 512, 128],
chunk_overlap=20
)
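# With chunk_sizes=[2048, 512, 128], each document becomes a three-level
# hierarchy: ~2048-token root nodes, ~512-token intermediate nodes, and
# ~128-token leaf nodes. Only the leaf nodes are embedded and indexed below;
# AutoMergingRetriever later swaps groups of matching sibling leaves for their
# parent node when enough of them are retrieved together.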
logger.info(f"Parsing {len(new_documents)} documents into hierarchical nodes")
new_nodes = node_parser.get_nodes_from_documents(new_documents)
new_leaf_nodes = get_leaf_nodes(new_nodes)
new_root_nodes = get_root_nodes(new_nodes)
logger.info(f"Generated {len(new_nodes)} total nodes ({len(new_root_nodes)} root, {len(new_leaf_nodes)} leaf)")
if os.path.exists(save_dir):
logger.info(f"Loading existing index from {save_dir}")
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
index = load_index_from_storage(storage_context, settings=Settings)
docstore = storage_context.docstore
docstore.add_documents(new_nodes)
for node in tqdm(new_leaf_nodes, desc="Adding leaf nodes to index"):
index.insert_nodes([node])
total_docs = len(docstore.docs)
logger.info(f"Updated index with {len(new_nodes)} new nodes from {len(new_documents)} files")
else:
logger.info("Creating new index")
docstore = SimpleDocumentStore()
storage_context = StorageContext.from_defaults(docstore=docstore)
docstore.add_documents(new_nodes)
index = VectorStoreIndex(
new_leaf_nodes,
storage_context=storage_context,
settings=Settings
)
total_docs = len(new_documents)
logger.info(f"Created new index with {len(new_nodes)} nodes from {len(new_documents)} files")
index.storage_context.persist(persist_dir=save_dir)
# custom outputs after processing files
# (markup uses the .file-list/.stats-box/.info-container styles defined in CSS)
file_list_html = "<div class='file-list'>"
for stat in file_stats:
status_color = "#4CAF50" if stat["status"] == "processed" else "#f44336"
file_list_html += f"<div style='color: {status_color};'>📄 {stat['name']} - {stat['words']} words</div>"
file_list_html += "</div>"
processing_time = time.time() - start_time
stats_output = "<div class='stats-box'>"
stats_output += f"<p>✓ Processed {len(files)} files in {processing_time:.2f} seconds</p>"
stats_output += f"<p>✓ Created {len(new_nodes)} nodes ({len(new_leaf_nodes)} leaf nodes)</p>"
stats_output += f"<p>✓ Total documents in index: {total_docs}</p>"
stats_output += f"<p>✓ Index saved to: {save_dir}</p>"
stats_output += "</div>"
output_container = "<div class='info-container'>"
output_container += file_list_html
output_container += stats_output
output_container += "</div>"
return f"Successfully indexed {len(files)} files.", output_container
def stream_chat(
message: str,
history: list,
system_prompt: str,
temperature: float,
max_new_tokens: int,
top_p: float,
top_k: int,
penalty: float,
retriever_k: int,
merge_threshold: float,
use_rag: bool,
medical_model: str,
use_web_search: bool,
request: gr.Request
):
if not request:
yield history + [{"role": "assistant", "content": "Session initialization failed. Please refresh the page."}]
return
user_id = request.session_hash
index_dir = f"./{user_id}_index"
has_rag_index = os.path.exists(index_dir)
# ===== AUTONOMOUS REASONING =====
logger.info("š¤ Starting autonomous reasoning...")
reasoning = autonomous_reasoning(message, history)
# ===== PLANNING =====
logger.info("š Creating execution plan...")
plan = create_execution_plan(reasoning, message, has_rag_index)
# ===== AUTONOMOUS EXECUTION STRATEGY =====
logger.info("šÆ Determining execution strategy...")
execution_strategy = autonomous_execution_strategy(reasoning, plan, use_rag, use_web_search, has_rag_index)
# Use autonomous strategy decisions (respect user's RAG setting)
final_use_rag = execution_strategy["use_rag"] and has_rag_index # Only use RAG if enabled AND documents exist
final_use_web_search = execution_strategy["use_web_search"]
# Show reasoning override message if applicable
reasoning_note = ""
if execution_strategy["reasoning_override"]:
reasoning_note = f"\n\nš” *Autonomous Reasoning: {execution_strategy['rationale']}*"
# Detect language and translate if needed (Step 1 of plan)
original_lang = detect_language(message)
original_message = message
needs_translation = original_lang != "en"
if needs_translation:
logger.info(f"Detected non-English language: {original_lang}, translating to English...")
message = translate_text(message, target_lang="en", source_lang=original_lang)
logger.info(f"Translated query: {message}")
# Initialize medical model
medical_model_obj, medical_tokenizer = initialize_medical_model(medical_model)
# Adjust system prompt based on RAG setting and reasoning
if final_use_rag:
base_system_prompt = system_prompt if system_prompt else "As a medical specialist, provide clinical and concise answers based on the provided medical documents and context."
else:
base_system_prompt = "As a medical specialist, provide short and concise clinical answers. Be brief and avoid lengthy explanations. Focus on key medical facts only."
# Add reasoning context to system prompt for complex queries
if reasoning["complexity"] in ["complex", "multi_faceted"]:
base_system_prompt += f"\n\nQuery Analysis: This is a {reasoning['complexity']} {reasoning['query_type']} query. Address all sub-questions: {', '.join(reasoning.get('sub_questions', [])[:3])}"
# ===== EXECUTION: RAG Retrieval (Step 2) =====
rag_context = ""
source_info = ""
if final_use_rag and has_rag_index:
# Use GPU function for embedding model
embed_model = get_embedding_model_gpu()
Settings.embed_model = embed_model
storage_context = StorageContext.from_defaults(persist_dir=index_dir)
index = load_index_from_storage(storage_context, settings=Settings)
base_retriever = index.as_retriever(similarity_top_k=retriever_k)
auto_merging_retriever = AutoMergingRetriever(
base_retriever,
storage_context=storage_context,
simple_ratio_thresh=merge_threshold,
verbose=True
)
logger.info(f"Query: {message}")
retrieval_start = time.time()
merged_nodes = auto_merging_retriever.retrieve(message)
logger.info(f"Retrieved {len(merged_nodes)} merged nodes in {time.time() - retrieval_start:.2f}s")
merged_file_sources = {}
for node in merged_nodes:
if hasattr(node.node, 'metadata') and 'file_name' in node.node.metadata:
file_name = node.node.metadata['file_name']
if file_name not in merged_file_sources:
merged_file_sources[file_name] = 0
merged_file_sources[file_name] += 1
logger.info(f"Merged retrieval file distribution: {merged_file_sources}")
rag_context = "\n\n".join([n.node.text for n in merged_nodes])
if merged_file_sources:
source_info = "\n\nRetrieved information from files: " + ", ".join(merged_file_sources.keys())
# ===== EXECUTION: Web Search (Step 3) =====
web_context = ""
web_sources = []
web_urls = [] # Store URLs for citations
if final_use_web_search:
logger.info("š Performing web search (will use Gemini MCP for summarization)...")
web_results = search_web(message, max_results=5)
if web_results:
logger.info(f"š Found {len(web_results)} web search results, now summarizing with Gemini MCP...")
web_summary = summarize_web_content(web_results, message)
if web_summary and len(web_summary) > 50: # Check if we got a real summary
logger.info(f"ā
[MCP] Gemini MCP summarization successful ({len(web_summary)} chars)")
web_context = f"\n\nAdditional Web Sources (summarized with Gemini MCP):\n{web_summary}"
else:
logger.warning("ā ļø [MCP] Gemini MCP summarization failed or returned empty, using raw results")
# Fallback: use first result's content
web_context = f"\n\nAdditional Web Sources:\n{web_results[0].get('content', '')[:500]}"
web_sources = [r['title'] for r in web_results[:3]]
# Extract unique URLs for citations
web_urls = [r.get('url', '') for r in web_results if r.get('url')]
logger.info(f"ā
Web search completed: {len(web_results)} results, summarized with Gemini MCP")
else:
logger.warning("ā ļø Web search returned no results")
# Build final context
context_parts = []
if rag_context:
context_parts.append(f"Document Context:\n{rag_context}")
if web_context:
context_parts.append(web_context)
full_context = "\n\n".join(context_parts) if context_parts else ""
# Build system prompt
if final_use_rag or final_use_web_search:
formatted_system_prompt = f"{base_system_prompt}\n\n{full_context}{source_info}"
else:
formatted_system_prompt = base_system_prompt
# Prepare messages
messages = [{"role": "system", "content": formatted_system_prompt}]
for entry in history:
messages.append(entry)
messages.append({"role": "user", "content": message})
# Get EOS token and adjust stopping criteria
eos_token_id = medical_tokenizer.eos_token_id
if eos_token_id is None:
eos_token_id = medical_tokenizer.pad_token_id
# Increase max tokens for medical models (prevent early stopping)
max_new_tokens = int(max_new_tokens) if isinstance(max_new_tokens, (int, float)) else 2048
max_new_tokens = max(max_new_tokens, 1024) # Minimum 1024 tokens for medical answers
# Check if tokenizer has chat template, otherwise format manually
if hasattr(medical_tokenizer, 'chat_template') and medical_tokenizer.chat_template is not None:
try:
prompt = medical_tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
except Exception as e:
logger.warning(f"Chat template failed, using manual formatting: {e}")
# Fallback to manual formatting
prompt = format_prompt_manually(messages, medical_tokenizer)
else:
# Manual formatting for models without chat template
prompt = format_prompt_manually(messages, medical_tokenizer)
inputs = medical_tokenizer(prompt, return_tensors="pt").to(medical_model_obj.device)
prompt_length = inputs['input_ids'].shape[1]
stop_event = threading.Event()
class StopOnEvent(StoppingCriteria):
def __init__(self, stop_event):
super().__init__()
self.stop_event = stop_event
def __call__(self, input_ids, scores, **kwargs):
return self.stop_event.is_set()
# Custom stopping criteria that doesn't stop on EOS too early
class MedicalStoppingCriteria(StoppingCriteria):
def __init__(self, eos_token_id, prompt_length, min_new_tokens=100):
super().__init__()
self.eos_token_id = eos_token_id
self.prompt_length = prompt_length
self.min_new_tokens = min_new_tokens
def __call__(self, input_ids, scores, **kwargs):
current_length = input_ids.shape[1]
new_tokens = current_length - self.prompt_length
last_token = input_ids[0, -1].item()
# Don't stop on EOS if we haven't generated enough new tokens
if new_tokens < self.min_new_tokens:
return False
# Allow EOS after minimum new tokens have been generated
return last_token == self.eos_token_id
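# Example: with prompt_length=300 and min_new_tokens=100, an EOS token emitted
# at overall position 350 (50 new tokens) is ignored, while an EOS at position
# 420 (120 new tokens) ends generation.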
stopping_criteria = StoppingCriteriaList([
StopOnEvent(stop_event),
MedicalStoppingCriteria(eos_token_id, prompt_length, min_new_tokens=100)
])
streamer = TextIteratorStreamer(
medical_tokenizer,
skip_prompt=True,
skip_special_tokens=True
)
temperature = float(temperature) if isinstance(temperature, (int, float)) else 0.7
top_p = float(top_p) if isinstance(top_p, (int, float)) else 0.95
top_k = int(top_k) if isinstance(top_k, (int, float)) else 50
penalty = float(penalty) if isinstance(penalty, (int, float)) else 1.2
# Call GPU function for model inference only
thread = threading.Thread(
target=generate_with_medswin,
kwargs={
"medical_model_obj": medical_model_obj,
"medical_tokenizer": medical_tokenizer,
"prompt": prompt,
"max_new_tokens": max_new_tokens,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"penalty": penalty,
"eos_token_id": eos_token_id,
"pad_token_id": medical_tokenizer.pad_token_id or eos_token_id,
"stop_event": stop_event,
"streamer": streamer,
"stopping_criteria": stopping_criteria
}
)
thread.start()
updated_history = history + [
{"role": "user", "content": original_message},
{"role": "assistant", "content": ""}
]
yield updated_history
partial_response = ""
try:
for new_text in streamer:
partial_response += new_text
updated_history[-1]["content"] = partial_response
yield updated_history
# ===== SELF-REFLECTION (Step 6) =====
if reasoning["complexity"] in ["complex", "multi_faceted"]:
logger.info("š Performing self-reflection on answer quality...")
reflection = self_reflection(partial_response, message, reasoning)
# Add reflection note if score is low or improvements suggested
if reflection.get("overall_score", 10) < 7 or reflection.get("improvement_suggestions"):
reflection_note = f"\n\n---\n**Self-Reflection** (Score: {reflection.get('overall_score', 'N/A')}/10)"
if reflection.get("improvement_suggestions"):
reflection_note += f"\nš” Suggestions: {', '.join(reflection['improvement_suggestions'][:2])}"
partial_response += reflection_note
updated_history[-1]["content"] = partial_response
# Add reasoning note if autonomous override occurred
if reasoning_note:
partial_response = reasoning_note + "\n\n" + partial_response
updated_history[-1]["content"] = partial_response
# Translate back if needed
if needs_translation and partial_response:
logger.info(f"Translating response back to {original_lang}...")
translated_response = translate_text(partial_response, target_lang=original_lang, source_lang="en")
partial_response = translated_response
# Add citations if web sources were used
citations_text = ""
if web_urls:
# Get unique domains
unique_urls = list(dict.fromkeys(web_urls)) # Preserve order, remove duplicates
citation_links = []
for url in unique_urls[:5]: # Limit to 5 citations
domain = format_url_as_domain(url)
if domain:
# Create markdown link: [domain](url)
citation_links.append(f"[{domain}]({url})")
if citation_links:
citations_text = "\n\n**Sources:** " + ", ".join(citation_links)
# Add speaker icon and citations to assistant message
speaker_icon = ' 🔊'
partial_response_with_speaker = partial_response + citations_text + speaker_icon
updated_history[-1]["content"] = partial_response_with_speaker
yield updated_history
except GeneratorExit:
stop_event.set()
thread.join()
raise
def generate_speech_for_message(text: str):
"""Generate speech for a message and return audio file"""
audio_path = generate_speech(text)
if audio_path:
return audio_path
return None
def create_demo():
with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
gr.HTML(TITLE)
gr.HTML(DESCRIPTION)
with gr.Row(elem_classes="main-container"):
with gr.Column(elem_classes="upload-section"):
file_upload = gr.File(
file_count="multiple",
label="Drag and Drop Files Here",
file_types=[".pdf", ".txt", ".doc", ".docx", ".md", ".json", ".xml", ".csv"],
elem_id="file-upload"
)
upload_button = gr.Button("Upload & Index", elem_classes="upload-button")
status_output = gr.Textbox(
label="Status",
placeholder="Upload files to start...",
interactive=False
)
file_info_output = gr.HTML(
label="File Information",
elem_classes="processing-info"
)
upload_button.click(
fn=create_or_update_index,
inputs=[file_upload],
outputs=[status_output, file_info_output]
)
with gr.Column(elem_classes="chatbot-container"):
chatbot = gr.Chatbot(
height=500,
placeholder="Chat with MedSwin... Type your question below.",
show_label=False,
type="messages"
)
with gr.Row(elem_classes="input-row"):
message_input = gr.Textbox(
placeholder="Type your medical question here...",
show_label=False,
container=False,
lines=1,
scale=10
)
mic_button = gr.Audio(
sources=["microphone"],
type="filepath",
label="",
show_label=False,
container=False,
scale=1
)
submit_button = gr.Button("ā¤", elem_classes="submit-btn", scale=1)
# Timer display for recording (shown below input row)
recording_timer = gr.Textbox(
value="",
label="",
show_label=False,
interactive=False,
visible=False,
container=False,
elem_classes="recording-timer"
)
# Handle microphone transcription
recording_start_time = [None]
def handle_recording_start():
"""Called when recording starts"""
recording_start_time[0] = time.time()
return gr.update(visible=True, value="Recording... 0s")
def handle_recording_stop(audio):
"""Called when recording stops"""
recording_start_time[0] = None
if audio is None:
return gr.update(visible=False, value=""), ""
transcribed = transcribe_audio(audio)
return gr.update(visible=False, value=""), transcribed
# Show a recording indicator around the mic start/stop events (simpler than a Gradio Timer)
mic_button.start_recording(
fn=handle_recording_start,
outputs=[recording_timer]
)
mic_button.stop_recording(
fn=handle_recording_stop,
inputs=[mic_button],
outputs=[recording_timer, message_input]
)
# TTS component for generating speech from messages
with gr.Row(visible=False) as tts_row:
tts_text = gr.Textbox(visible=False)
tts_audio = gr.Audio(label="Generated Speech", visible=False)
# Function to generate speech when speaker icon is clicked
def generate_speech_from_chat(history):
"""Extract last assistant message and generate speech"""
if not history or len(history) == 0:
return None
last_msg = history[-1]
if last_msg.get("role") == "assistant":
text = last_msg.get("content", "").replace(" 🔊", "").strip()
if text:
audio_path = generate_speech(text)
return audio_path
return None
# Add TTS button that appears when assistant responds
tts_button = gr.Button("š Play Response", visible=False, size="sm")
# Update TTS button visibility and generate speech
def update_tts_button(history):
if history and len(history) > 0 and history[-1].get("role") == "assistant":
return gr.update(visible=True)
return gr.update(visible=False)
chatbot.change(
fn=update_tts_button,
inputs=[chatbot],
outputs=[tts_button]
)
tts_button.click(
fn=generate_speech_from_chat,
inputs=[chatbot],
outputs=[tts_audio]
)
with gr.Accordion("āļø Advanced Settings", open=False):
with gr.Row():
use_rag = gr.Checkbox(
value=False,
label="Enable Document RAG",
info="Answer based on uploaded documents (requires document upload)"
)
use_web_search = gr.Checkbox(
value=False,
label="Enable Web Search (MCP)",
info="Fetch knowledge from online medical resources"
)
medical_model = gr.Radio(
choices=list(MEDSWIN_MODELS.keys()),
value=DEFAULT_MEDICAL_MODEL,
label="Medical Model",
info="MedSwin TA (default), others download on first use"
)
system_prompt = gr.Textbox(
value="As a medical specialist, provide detailed and accurate answers based on the provided medical documents and context. Ensure all information is clinically accurate and cite sources when available.",
label="System Prompt",
lines=3
)
with gr.Tab("Generation Parameters"):
temperature = gr.Slider(
minimum=0,
maximum=1,
step=0.1,
value=0.2,
label="Temperature"
)
max_new_tokens = gr.Slider(
minimum=512,
maximum=4096,
step=128,
value=2048,
label="Max New Tokens",
info="Increased for medical models to prevent early stopping"
)
top_p = gr.Slider(
minimum=0.0,
maximum=1.0,
step=0.1,
value=0.7,
label="Top P"
)
top_k = gr.Slider(
minimum=1,
maximum=100,
step=1,
value=50,
label="Top K"
)
penalty = gr.Slider(
minimum=0.0,
maximum=2.0,
step=0.1,
value=1.2,
label="Repetition Penalty"
)
with gr.Tab("Retrieval Parameters"):
retriever_k = gr.Slider(
minimum=5,
maximum=30,
step=1,
value=15,
label="Initial Retrieval Size (Top K)"
)
merge_threshold = gr.Slider(
minimum=0.1,
maximum=0.9,
step=0.1,
value=0.5,
label="Merge Threshold (lower = more merging)"
)
submit_button.click(
fn=stream_chat,
inputs=[
message_input,
chatbot,
system_prompt,
temperature,
max_new_tokens,
top_p,
top_k,
penalty,
retriever_k,
merge_threshold,
use_rag,
medical_model,
use_web_search
],
outputs=chatbot
)
message_input.submit(
fn=stream_chat,
inputs=[
message_input,
chatbot,
system_prompt,
temperature,
max_new_tokens,
top_p,
top_k,
penalty,
retriever_k,
merge_threshold,
use_rag,
medical_model,
use_web_search
],
outputs=chatbot
)
return demo
if __name__ == "__main__":
# Preload models on startup
logger.info("Preloading models on startup...")
logger.info("Initializing default medical model (MedSwin TA)...")
initialize_medical_model(DEFAULT_MEDICAL_MODEL)
logger.info("Preloading TTS model...")
try:
initialize_tts_model()
if global_tts_model is not None:
logger.info("TTS model preloaded successfully!")
else:
logger.warning("TTS model not available - will use MCP or disable voice generation")
except Exception as e:
logger.warning(f"TTS model preloading failed: {e}")
logger.warning("Text-to-speech will use MCP or be disabled")
# Check Gemini MCP availability
if MCP_AVAILABLE:
logger.info("Gemini MCP is available for translation, summarization, document parsing, and transcription")
else:
logger.warning("Gemini MCP not available - translation, summarization, document parsing, and transcription features will be limited")
logger.info("Model preloading complete!")
demo = create_demo()
demo.launch()