import gradio as gr import os import base64 import logging import torch import threading import time import json from transformers import ( AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList, ) from transformers import logging as hf_logging import spaces from llama_index.core import ( StorageContext, VectorStoreIndex, load_index_from_storage, Document as LlamaDocument, ) from llama_index.core import Settings from llama_index.core.node_parser import ( HierarchicalNodeParser, get_leaf_nodes, get_root_nodes, ) from llama_index.core.retrievers import AutoMergingRetriever from llama_index.core.storage.docstore import SimpleDocumentStore from llama_index.llms.huggingface import HuggingFaceLLM from llama_index.embeddings.huggingface import HuggingFaceEmbedding # Import GPU-tagged model functions from model import ( get_llm_for_rag as get_llm_for_rag_gpu, get_embedding_model as get_embedding_model_gpu, generate_with_medswin, initialize_medical_model, global_medical_models, global_medical_tokenizers ) from tqdm import tqdm from langdetect import detect, LangDetectException # MCP imports try: from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client import asyncio try: import nest_asyncio nest_asyncio.apply() # Allow nested event loops except ImportError: pass # nest_asyncio is optional MCP_AVAILABLE = True except ImportError: MCP_AVAILABLE = False # Fallback imports if MCP is not available from ddgs import DDGS import requests from bs4 import BeautifulSoup try: from TTS.api import TTS TTS_AVAILABLE = True except ImportError: TTS_AVAILABLE = False TTS = None import numpy as np import soundfile as sf import tempfile os.environ["TOKENIZERS_PARALLELISM"] = "false" logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) hf_logging.set_verbosity_error() # Model configurations MEDSWIN_MODELS = { "MedSwin SFT": "MedSwin/MedSwin-7B-SFT", "MedSwin KD": "MedSwin/MedSwin-7B-KD", "MedSwin TA": "MedSwin/MedSwin-Merged-TA-SFT-0.7" } DEFAULT_MEDICAL_MODEL = "MedSwin TA" EMBEDDING_MODEL = "abhinand/MedEmbed-large-v0.1" # Domain-tuned medical embedding model TTS_MODEL = "maya-research/maya1" HF_TOKEN = os.environ.get("HF_TOKEN") if not HF_TOKEN: raise ValueError("HF_TOKEN not found in environment variables") # Gemini MCP configuration GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY") GEMINI_MODEL = os.environ.get("GEMINI_MODEL", "gemini-2.5-flash") # Default for harder tasks GEMINI_MODEL_LITE = os.environ.get("GEMINI_MODEL_LITE", "gemini-2.5-flash-lite") # For parsing and simple tasks # Custom UI TITLE = "

🩺 MedLLM Agent - Medical RAG & Web Search System

" DESCRIPTION = """

Advanced Medical AI Assistant powered by MedSwin models

šŸ“„ Document RAG: Answer based on uploaded medical documents

🌐 Web Search: Fetch knowledge from reliable online medical resources

šŸŒ Multi-language: Automatic translation for non-English queries

Upload PDF or text files to get started!

""" CSS = """ .upload-section { max-width: 400px; margin: 0 auto; padding: 10px; border: 2px dashed #ccc; border-radius: 10px; } .upload-button { background: #34c759 !important; color: white !important; border-radius: 25px !important; } .chatbot-container { margin-top: 20px; } .status-output { margin-top: 10px; font-size: 14px; } .processing-info { margin-top: 5px; font-size: 12px; color: #666; } .info-container { margin-top: 10px; padding: 10px; border-radius: 5px; } .file-list { margin-top: 0; max-height: 200px; overflow-y: auto; padding: 5px; border: 1px solid #eee; border-radius: 5px; } .stats-box { margin-top: 10px; padding: 10px; border-radius: 5px; font-size: 12px; } .submit-btn { background: #1a73e8 !important; color: white !important; border-radius: 25px !important; margin-left: 10px; padding: 5px 10px; font-size: 16px; } .input-row { display: flex; align-items: center; } .recording-timer { font-size: 12px; color: #666; text-align: center; margin-top: 5px; } .feature-badge { display: inline-block; padding: 3px 8px; margin: 2px; border-radius: 12px; font-size: 11px; font-weight: bold; } .badge-rag { background: #e3f2fd; color: #1976d2; } .badge-web { background: #f3e5f5; color: #7b1fa2; } @media (min-width: 768px) { .main-container { display: flex; justify-content: space-between; gap: 20px; } .upload-section { flex: 1; max-width: 300px; } .chatbot-container { flex: 2; margin-top: 0; } } """ # Global model storage - models are stored in model.py # Import the global model storage from model.py global_file_info = {} global_tts_model = None # MCP client storage global_mcp_session = None global_mcp_stdio_ctx = None # Store stdio context to keep it alive global_mcp_lock = threading.Lock() # Lock for thread-safe session access # MCP server configuration via environment variables # Gemini MCP server: Python-based server (agent.py) # This works on Hugging Face Spaces without requiring npm/Node.js # Make sure GEMINI_API_KEY is set in environment variables # # Default configuration uses the bundled agent.py script # To override: # export MCP_SERVER_COMMAND="python" # export MCP_SERVER_ARGS="/path/to/agent.py" script_dir = os.path.dirname(os.path.abspath(__file__)) agent_path = os.path.join(script_dir, "agent.py") MCP_SERVER_COMMAND = os.environ.get("MCP_SERVER_COMMAND", "python") MCP_SERVER_ARGS = os.environ.get("MCP_SERVER_ARGS", agent_path).split() if os.environ.get("MCP_SERVER_ARGS") else [agent_path] async def get_mcp_session(): """Get or create MCP client session with proper context management""" global global_mcp_session, global_mcp_stdio_ctx if not MCP_AVAILABLE: return None # Check if session exists and is still valid if global_mcp_session is not None: try: # Test if session is still alive by listing tools await global_mcp_session.list_tools() return global_mcp_session except Exception as e: logger.debug(f"Existing MCP session invalid, recreating: {e}") # Clean up old session try: if global_mcp_session is not None: await global_mcp_session.__aexit__(None, None, None) except: pass try: if global_mcp_stdio_ctx is not None: await global_mcp_stdio_ctx.__aexit__(None, None, None) except: pass global_mcp_session = None global_mcp_stdio_ctx = None # Create new session using correct MCP SDK pattern try: # Prepare environment variables for MCP server mcp_env = os.environ.copy() if GEMINI_API_KEY: mcp_env["GEMINI_API_KEY"] = GEMINI_API_KEY else: logger.warning("GEMINI_API_KEY not set in environment. Gemini MCP features may not work.") # Add other Gemini MCP configuration if set if os.environ.get("GEMINI_MODEL"): mcp_env["GEMINI_MODEL"] = os.environ.get("GEMINI_MODEL") if os.environ.get("GEMINI_TIMEOUT"): mcp_env["GEMINI_TIMEOUT"] = os.environ.get("GEMINI_TIMEOUT") if os.environ.get("GEMINI_MAX_OUTPUT_TOKENS"): mcp_env["GEMINI_MAX_OUTPUT_TOKENS"] = os.environ.get("GEMINI_MAX_OUTPUT_TOKENS") if os.environ.get("GEMINI_TEMPERATURE"): mcp_env["GEMINI_TEMPERATURE"] = os.environ.get("GEMINI_TEMPERATURE") logger.info(f"Creating MCP client session with command: {MCP_SERVER_COMMAND} {MCP_SERVER_ARGS}") server_params = StdioServerParameters( command=MCP_SERVER_COMMAND, args=MCP_SERVER_ARGS, env=mcp_env ) # Correct MCP SDK usage: stdio_client is an async context manager # that yields (read, write) streams stdio_ctx = stdio_client(server_params) read, write = await stdio_ctx.__aenter__() # Create ClientSession from the streams # The __aenter__() method automatically handles the initialization handshake session = ClientSession(read, write) # Wait longer for the server process to fully start # The server needs time to: start Python, import modules, initialize Gemini client, start MCP server logger.info("ā³ Waiting for MCP server process to start...") await asyncio.sleep(3.0) # Increased wait for server process startup try: # Initialize the session (this sends initialize request and waits for response) logger.info("šŸ”„ Initializing MCP session...") await session.__aenter__() logger.info("āœ… MCP session initialized, verifying tools...") except Exception as e: logger.warning(f"MCP session initialization had an issue (may be expected): {e}") # Continue anyway - the session might still work # Wait longer for the server to be fully ready after initialization # The server needs time to process the initialization and be ready for requests await asyncio.sleep(2.0) # Wait after initialization # Verify the session works by listing tools with retries # This confirms the server is ready to handle requests max_init_retries = 15 tools_listed = False tools = None last_error = None for init_attempt in range(max_init_retries): try: tools = await session.list_tools() if tools and hasattr(tools, 'tools') and len(tools.tools) > 0: logger.info(f"āœ… MCP server ready with {len(tools.tools)} tools: {[t.name for t in tools.tools]}") tools_listed = True break except Exception as e: last_error = e error_str = str(e).lower() error_msg = str(e) # Log the actual error for debugging if init_attempt == 0: logger.debug(f"First list_tools attempt failed: {error_msg}") # Ignore initialization-related errors during the handshake phase if "initialization" in error_str or "before initialization" in error_str or "not initialized" in error_str: if init_attempt < max_init_retries - 1: wait_time = 0.5 * (init_attempt + 1) # Progressive wait: 0.5s, 1s, 1.5s... logger.debug(f"Server still initializing (attempt {init_attempt + 1}/{max_init_retries}), waiting {wait_time}s...") await asyncio.sleep(wait_time) continue elif "invalid request" in error_str or "invalid request parameters" in error_str: # This might be a timing issue - wait and retry if init_attempt < max_init_retries - 1: wait_time = 0.8 * (init_attempt + 1) # Longer wait for invalid request errors logger.debug(f"Invalid request error (attempt {init_attempt + 1}/{max_init_retries}), waiting {wait_time}s...") await asyncio.sleep(wait_time) continue elif init_attempt < max_init_retries - 1: wait_time = 0.5 * (init_attempt + 1) logger.debug(f"Tool listing attempt {init_attempt + 1}/{max_init_retries} failed: {error_msg}, waiting {wait_time}s...") await asyncio.sleep(wait_time) else: logger.error(f"āŒ Could not list tools after {max_init_retries} attempts. Last error: {error_msg}") # Don't continue - if we can't list tools, the session is not usable try: await session.__aexit__(None, None, None) except: pass try: await stdio_ctx.__aexit__(None, None, None) except: pass return None if not tools_listed: error_msg = str(last_error) if last_error else "Unknown error" logger.error(f"MCP server failed to initialize - tools could not be listed. Last error: {error_msg}") try: await session.__aexit__(None, None, None) except: pass try: await stdio_ctx.__aexit__(None, None, None) except: pass return None # Store both the session and stdio context to keep them alive global_mcp_session = session global_mcp_stdio_ctx = stdio_ctx logger.info("MCP client session created successfully") return session except Exception as e: logger.error(f"Failed to create MCP client session: {e}") import traceback logger.debug(traceback.format_exc()) global_mcp_session = None global_mcp_stdio_ctx = None return None async def call_agent(user_prompt: str, system_prompt: str = None, files: list = None, model: str = None, temperature: float = 0.2) -> str: """Call Gemini MCP generate_content tool""" if not MCP_AVAILABLE: logger.warning("MCP not available for Gemini call") return "" try: session = await get_mcp_session() if session is None: logger.warning("Failed to get MCP session for Gemini call") return "" # Retry listing tools if it fails the first time # Use more retries and longer waits since MCP server might need time max_retries = 5 tools = None for attempt in range(max_retries): try: tools = await session.list_tools() if tools and hasattr(tools, 'tools') and len(tools.tools) > 0: break else: raise ValueError("Empty tools list") except Exception as e: if attempt < max_retries - 1: wait_time = 1.0 * (attempt + 1) # Progressive wait logger.debug(f"Failed to list tools (attempt {attempt + 1}/{max_retries}), waiting {wait_time}s...") await asyncio.sleep(wait_time) else: logger.error(f"āŒ Failed to list MCP tools after {max_retries} attempts: {e}") return "" if not tools or not hasattr(tools, 'tools'): logger.error("Invalid tools response from MCP server") return "" # Find generate_content tool generate_tool = None for tool in tools.tools: if tool.name == "generate_content" or "generate_content" in tool.name.lower(): generate_tool = tool logger.info(f"Found Gemini MCP tool: {tool.name}") break if not generate_tool: logger.warning(f"Gemini MCP generate_content tool not found. Available tools: {[t.name for t in tools.tools]}") return "" # Prepare arguments arguments = { "user_prompt": user_prompt } if system_prompt: arguments["system_prompt"] = system_prompt if files: arguments["files"] = files if model: arguments["model"] = model if temperature is not None: arguments["temperature"] = temperature logger.info(f"šŸ”§ [MCP] Calling Gemini MCP tool '{generate_tool.name}' for: {user_prompt[:100]}...") logger.info(f"šŸ“‹ [MCP] Arguments: model={model}, temperature={temperature}, files={len(files) if files else 0}") result = await session.call_tool(generate_tool.name, arguments=arguments) # Parse result if hasattr(result, 'content') and result.content: for item in result.content: if hasattr(item, 'text'): response_text = item.text.strip() logger.info(f"āœ… [MCP] Gemini MCP returned response ({len(response_text)} chars)") return response_text logger.warning("āš ļø [MCP] Gemini MCP returned empty or invalid result") return "" except Exception as e: logger.error(f"Gemini MCP call error: {e}") import traceback logger.debug(traceback.format_exc()) return "" # initialize_medical_model is now imported from model.py def initialize_tts_model(): """Initialize TTS model for text-to-speech""" global global_tts_model if not TTS_AVAILABLE: logger.warning("TTS library not installed. TTS features will be disabled.") return None if global_tts_model is None: try: logger.info("Initializing TTS model for voice generation...") global_tts_model = TTS(model_name=TTS_MODEL, progress_bar=False) logger.info("TTS model initialized successfully") except Exception as e: logger.warning(f"TTS model initialization failed: {e}") logger.warning("TTS features will be disabled. If pyworld dependency is missing, try: pip install TTS --no-deps && pip install coqui-tts") global_tts_model = None return global_tts_model async def transcribe_audio_gemini(audio_path: str) -> str: """Transcribe audio using Gemini MCP""" if not MCP_AVAILABLE: return "" try: # Ensure we have an absolute path audio_path_abs = os.path.abspath(audio_path) # Prepare file object for Gemini MCP using path (as per Gemini MCP documentation) files = [{ "path": audio_path_abs }] # Use exact prompts from Gemini MCP documentation system_prompt = "You are a professional transcription service. Provide accurate, well-formatted transcripts." user_prompt = "Please transcribe this audio file. Include speaker identification if multiple speakers are present, and format it with proper punctuation and paragraphs, remove mumble, ignore non-verbal noises." result = await call_agent( user_prompt=user_prompt, system_prompt=system_prompt, files=files, model=GEMINI_MODEL_LITE, # Use lite model for transcription temperature=0.2 ) return result.strip() except Exception as e: logger.error(f"Gemini transcription error: {e}") import traceback logger.debug(traceback.format_exc()) return "" def transcribe_audio(audio): """Transcribe audio to text using Gemini MCP""" if audio is None: return "" try: # Handle file path (Gradio Audio component returns file path) if isinstance(audio, str): audio_path = audio elif isinstance(audio, tuple): # Handle tuple format (sample_rate, audio_data) sample_rate, audio_data = audio # Save to temp file with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file: sf.write(tmp_file.name, audio_data, samplerate=sample_rate) audio_path = tmp_file.name else: audio_path = audio # Use Gemini MCP for transcription if MCP_AVAILABLE: try: loop = asyncio.get_event_loop() if loop.is_running(): try: import nest_asyncio transcribed = nest_asyncio.run(transcribe_audio_gemini(audio_path)) if transcribed: logger.info(f"Transcribed via Gemini MCP: {transcribed[:50]}...") return transcribed except Exception as e: logger.error(f"Error in nested async transcription: {e}") else: transcribed = loop.run_until_complete(transcribe_audio_gemini(audio_path)) if transcribed: logger.info(f"Transcribed via Gemini MCP: {transcribed[:50]}...") return transcribed except Exception as e: logger.error(f"Gemini MCP transcription error: {e}") logger.warning("Gemini MCP transcription not available") return "" except Exception as e: logger.error(f"Transcription error: {e}") return "" async def generate_speech_mcp(text: str) -> str: """Generate speech using MCP TTS tool""" if not MCP_AVAILABLE: return None try: # Get MCP session session = await get_mcp_session() if session is None: return None # Find TTS tool tools = await session.list_tools() tts_tool = None for tool in tools.tools: if "tts" in tool.name.lower() or "speech" in tool.name.lower() or "synthesize" in tool.name.lower(): tts_tool = tool logger.info(f"Found MCP TTS tool: {tool.name}") break if tts_tool: result = await session.call_tool( tts_tool.name, arguments={"text": text, "language": "en"} ) # Parse result - MCP might return audio data or file path if hasattr(result, 'content') and result.content: for item in result.content: if hasattr(item, 'text'): # If it's a file path if os.path.exists(item.text): return item.text elif hasattr(item, 'data') and item.data: # If it's binary audio data, save it with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file: tmp_file.write(item.data) return tmp_file.name return None except Exception as e: logger.debug(f"MCP TTS error: {e}") return None def generate_speech(text: str): """Generate speech from text using TTS model (with MCP fallback)""" if not text or len(text.strip()) == 0: return None # Try MCP first if available if MCP_AVAILABLE: try: loop = asyncio.get_event_loop() if loop.is_running(): try: import nest_asyncio audio_path = nest_asyncio.run(generate_speech_mcp(text)) if audio_path: logger.info("Generated speech via MCP") return audio_path except: pass else: audio_path = loop.run_until_complete(generate_speech_mcp(text)) if audio_path: logger.info("Generated speech via MCP") return audio_path except Exception as e: logger.debug(f"MCP TTS not available: {e}") # Fallback to local TTS model if not TTS_AVAILABLE: logger.error("TTS library not installed. Please install TTS to use voice generation.") return None global global_tts_model if global_tts_model is None: initialize_tts_model() if global_tts_model is None: logger.error("TTS model not available. Please check dependencies.") return None try: # Generate audio wav = global_tts_model.tts(text) # Save to temporary file with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file: sf.write(tmp_file.name, wav, samplerate=22050) return tmp_file.name except Exception as e: logger.error(f"TTS error: {e}") return None def format_prompt_manually(messages: list, tokenizer) -> str: """Manually format prompt for models without chat template""" prompt_parts = [] # Combine system and user messages into a single instruction system_content = "" user_content = "" for msg in messages: role = msg.get("role", "user") content = msg.get("content", "") if role == "system": system_content = content elif role == "user": user_content = content elif role == "assistant": # Skip assistant messages in history for now (can be added if needed) pass # Format for MedAlpaca/LLaMA-based medical models # Common format: Instruction + Input -> Response if system_content: prompt = f"{system_content}\n\nQuestion: {user_content}\n\nAnswer:" else: prompt = f"Question: {user_content}\n\nAnswer:" return prompt def detect_language(text: str) -> str: """Detect language of input text""" try: lang = detect(text) return lang except LangDetectException: return "en" # Default to English if detection fails def format_url_as_domain(url: str) -> str: """Format URL as simple domain name (e.g., www.mayoclinic.org)""" if not url: return "" try: from urllib.parse import urlparse parsed = urlparse(url) domain = parsed.netloc or parsed.path # Remove www. prefix if present, but keep it for display if domain.startswith('www.'): return domain elif domain: return domain return url except Exception: # Fallback: try to extract domain manually if '://' in url: domain = url.split('://')[1].split('/')[0] return domain return url async def translate_text_gemini(text: str, target_lang: str = "en", source_lang: str = None) -> str: """Translate text using Gemini MCP""" if source_lang: user_prompt = f"Translate the following {source_lang} text to {target_lang}. Only provide the translation, no explanations:\n\n{text}" else: user_prompt = f"Translate the following text to {target_lang}. Only provide the translation, no explanations:\n\n{text}" # Use concise system prompt system_prompt = "You are a professional translator. Translate accurately and concisely." result = await call_agent( user_prompt=user_prompt, system_prompt=system_prompt, model=GEMINI_MODEL_LITE, # Use lite model for translation temperature=0.2 ) return result.strip() def translate_text(text: str, target_lang: str = "en", source_lang: str = None) -> str: """Translate text using Gemini MCP""" if not MCP_AVAILABLE: logger.warning("Gemini MCP not available for translation") return text # Return original text if translation fails try: loop = asyncio.get_event_loop() if loop.is_running(): try: import nest_asyncio translated = nest_asyncio.run(translate_text_gemini(text, target_lang, source_lang)) if translated: logger.info(f"Translated via Gemini MCP: {translated[:50]}...") return translated except Exception as e: logger.error(f"Error in nested async translation: {e}") else: translated = loop.run_until_complete(translate_text_gemini(text, target_lang, source_lang)) if translated: logger.info(f"Translated via Gemini MCP: {translated[:50]}...") return translated except Exception as e: logger.error(f"Gemini MCP translation error: {e}") # Return original text if translation fails return text async def search_web_mcp_tool(query: str, max_results: int = 5) -> list: """Search web using MCP web search tool (e.g., DuckDuckGo MCP server)""" if not MCP_AVAILABLE: return [] try: session = await get_mcp_session() if session is None: return [] # Retry listing tools if it fails the first time max_retries = 3 tools = None for attempt in range(max_retries): try: tools = await session.list_tools() break except Exception as e: if attempt < max_retries - 1: await asyncio.sleep(0.5 * (attempt + 1)) else: logger.error(f"Failed to list MCP tools after {max_retries} attempts: {e}") return [] if not tools or not hasattr(tools, 'tools'): return [] # Look for web search tools (DuckDuckGo, search, etc.) search_tool = None for tool in tools.tools: tool_name_lower = tool.name.lower() if any(keyword in tool_name_lower for keyword in ["search", "duckduckgo", "ddg", "web"]): search_tool = tool logger.info(f"Found web search MCP tool: {tool.name}") break if search_tool: try: logger.info(f"šŸ” [MCP] Using web search MCP tool '{search_tool.name}' for: {query[:100]}...") # Call the search tool result = await session.call_tool( search_tool.name, arguments={"query": query, "max_results": max_results} ) # Parse result web_content = [] if hasattr(result, 'content') and result.content: for item in result.content: if hasattr(item, 'text'): try: data = json.loads(item.text) if isinstance(data, list): for entry in data[:max_results]: web_content.append({ 'title': entry.get('title', ''), 'url': entry.get('url', entry.get('href', '')), 'content': entry.get('body', entry.get('snippet', entry.get('content', ''))) }) elif isinstance(data, dict): if 'results' in data: for entry in data['results'][:max_results]: web_content.append({ 'title': entry.get('title', ''), 'url': entry.get('url', entry.get('href', '')), 'content': entry.get('body', entry.get('snippet', entry.get('content', ''))) }) else: web_content.append({ 'title': data.get('title', ''), 'url': data.get('url', data.get('href', '')), 'content': data.get('body', data.get('snippet', data.get('content', ''))) }) except json.JSONDecodeError: # If not JSON, treat as plain text web_content.append({ 'title': '', 'url': '', 'content': item.text[:1000] }) if web_content: logger.info(f"āœ… [MCP] Web search MCP tool returned {len(web_content)} results") return web_content except Exception as e: logger.error(f"Error calling web search MCP tool: {e}") return [] except Exception as e: logger.error(f"Web search MCP tool error: {e}") return [] async def search_web_mcp(query: str, max_results: int = 5) -> list: """Search web using MCP tools - tries web search MCP tool first, then falls back to direct search""" # First try to use a dedicated web search MCP tool (like DuckDuckGo MCP server) results = await search_web_mcp_tool(query, max_results) if results: logger.info(f"āœ… Web search via MCP tool: found {len(results)} results") return results # If no web search MCP tool available, use direct search (ddgs) # Note: Gemini MCP doesn't have web search capability, so we use direct API # The results will then be summarized using Gemini MCP logger.info("ā„¹ļø [Direct API] No web search MCP tool found, using direct DuckDuckGo search (results will be summarized with Gemini MCP)") return search_web_fallback(query, max_results) def search_web_fallback(query: str, max_results: int = 5) -> list: """Fallback web search using DuckDuckGo directly (when MCP is not available)""" logger.info(f"šŸ” [Direct API] Performing web search using DuckDuckGo API for: {query[:100]}...") # Always import here to ensure availability try: from ddgs import DDGS import requests from bs4 import BeautifulSoup except ImportError: logger.error("Fallback dependencies (ddgs, requests, beautifulsoup4) not available") return [] try: with DDGS() as ddgs: results = list(ddgs.text(query, max_results=max_results)) web_content = [] for result in results: try: url = result.get('href', '') title = result.get('title', '') snippet = result.get('body', '') # Try to fetch full content try: response = requests.get(url, timeout=5, headers={'User-Agent': 'Mozilla/5.0'}) if response.status_code == 200: soup = BeautifulSoup(response.content, 'html.parser') # Extract main content for script in soup(["script", "style"]): script.decompose() text = soup.get_text() # Clean and limit text lines = (line.strip() for line in text.splitlines()) chunks = (phrase.strip() for line in lines for phrase in line.split(" ")) text = ' '.join(chunk for chunk in chunks if chunk) if len(text) > 1000: text = text[:1000] + "..." web_content.append({ 'title': title, 'url': url, 'content': snippet + "\n" + text[:500] if text else snippet }) else: web_content.append({ 'title': title, 'url': url, 'content': snippet }) except: web_content.append({ 'title': title, 'url': url, 'content': snippet }) except Exception as e: logger.error(f"Error processing search result: {e}") continue logger.info(f"āœ… [Direct API] Web search completed: {len(web_content)} results") return web_content except Exception as e: logger.error(f"āŒ [Direct API] Web search error: {e}") return [] def search_web(query: str, max_results: int = 5) -> list: """Search web using MCP tools (synchronous wrapper) - prioritizes MCP over direct ddgs""" # Always try MCP first if available if MCP_AVAILABLE: try: # Run async MCP search try: loop = asyncio.get_event_loop() except RuntimeError: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) if loop.is_running(): # If loop is already running, use nest_asyncio or create new thread try: import nest_asyncio results = nest_asyncio.run(search_web_mcp(query, max_results)) if results: # Only return if we got results from MCP return results except (ImportError, AttributeError): # Fallback: run in thread import concurrent.futures with concurrent.futures.ThreadPoolExecutor() as executor: future = executor.submit(asyncio.run, search_web_mcp(query, max_results)) results = future.result(timeout=30) if results: # Only return if we got results from MCP return results else: results = loop.run_until_complete(search_web_mcp(query, max_results)) if results: # Only return if we got results from MCP return results except Exception as e: logger.error(f"Error running async MCP search: {e}") # Only use ddgs fallback if MCP is not available or returned no results logger.info("ā„¹ļø [Direct API] Falling back to direct DuckDuckGo search (MCP unavailable or returned no results)") return search_web_fallback(query, max_results) async def summarize_web_content_gemini(content_list: list, query: str) -> str: """Summarize web search results using Gemini MCP""" logger.info(f"šŸ“ [MCP] Summarizing {len(content_list)} web search results using Gemini MCP...") combined_content = "\n\n".join([f"Source: {item['title']}\n{item['content']}" for item in content_list[:3]]) user_prompt = f"""Summarize the following web search results related to the query: "{query}" Extract key medical information, facts, and insights. Be concise and focus on reliable information. Search Results: {combined_content} Summary:""" # Use concise system prompt system_prompt = "You are a medical information summarizer. Extract and summarize key medical facts accurately." result = await call_agent( user_prompt=user_prompt, system_prompt=system_prompt, model=GEMINI_MODEL, # Use full model for summarization temperature=0.5 ) if result: logger.info(f"āœ… [MCP] Web content summarized successfully using Gemini MCP ({len(result)} chars)") else: logger.warning("āš ļø [MCP] Gemini MCP summarization returned empty result") return result.strip() def summarize_web_content(content_list: list, query: str) -> str: """Summarize web search results using Gemini MCP""" if not MCP_AVAILABLE: logger.warning("Gemini MCP not available for summarization") # Fallback: return first result's content if content_list: return content_list[0].get('content', '')[:500] return "" try: loop = asyncio.get_event_loop() if loop.is_running(): try: import nest_asyncio summary = nest_asyncio.run(summarize_web_content_gemini(content_list, query)) if summary: return summary except Exception as e: logger.error(f"Error in nested async summarization: {e}") else: summary = loop.run_until_complete(summarize_web_content_gemini(content_list, query)) if summary: return summary except Exception as e: logger.error(f"Gemini MCP summarization error: {e}") # Fallback: return first result's content if content_list: return content_list[0].get('content', '')[:500] return "" # get_llm_for_rag is now imported from model.py as get_llm_for_rag_gpu async def autonomous_reasoning_gemini(query: str) -> dict: """Autonomous reasoning using Gemini MCP""" logger.info(f"🧠 [MCP] Analyzing query with Gemini MCP: {query[:100]}...") reasoning_prompt = f"""Analyze this medical query and provide structured reasoning: Query: "{query}" Analyze: 1. Query Type: (diagnosis, treatment, drug_info, symptom_analysis, research, general_info) 2. Complexity: (simple, moderate, complex, multi_faceted) 3. Information Needs: What specific information is required? 4. Requires RAG: (yes/no) - Does this need document context? 5. Requires Web Search: (yes/no) - Does this need current/updated information? 6. Sub-questions: Break down into key sub-questions if complex Respond in JSON format: {{ "query_type": "...", "complexity": "...", "information_needs": ["..."], "requires_rag": true/false, "requires_web_search": true/false, "sub_questions": ["..."] }}""" # Use concise system prompt system_prompt = "You are a medical reasoning system. Analyze queries systematically and provide structured JSON responses." response = await call_agent( user_prompt=reasoning_prompt, system_prompt=system_prompt, model=GEMINI_MODEL, # Use full model for reasoning temperature=0.3 ) # Parse JSON response (with fallback) try: # Extract JSON from response json_start = response.find('{') json_end = response.rfind('}') + 1 if json_start >= 0 and json_end > json_start: reasoning = json.loads(response[json_start:json_end]) else: raise ValueError("No JSON found") except: # Fallback reasoning reasoning = { "query_type": "general_info", "complexity": "moderate", "information_needs": ["medical information"], "requires_rag": True, "requires_web_search": False, "sub_questions": [query] } logger.info(f"Reasoning analysis: {reasoning}") return reasoning def autonomous_reasoning(query: str, history: list) -> dict: """ Autonomous reasoning: Analyze query complexity, intent, and information needs. Returns reasoning analysis with query type, complexity, and required information sources. Uses Gemini MCP for reasoning. """ if not MCP_AVAILABLE: logger.warning("āš ļø Gemini MCP not available for reasoning, using fallback") # Fallback reasoning return { "query_type": "general_info", "complexity": "moderate", "information_needs": ["medical information"], "requires_rag": True, "requires_web_search": False, "sub_questions": [query] } try: logger.info("šŸ¤” [MCP] Using Gemini MCP for autonomous reasoning...") loop = asyncio.get_event_loop() if loop.is_running(): try: import nest_asyncio reasoning = nest_asyncio.run(autonomous_reasoning_gemini(query)) if reasoning and reasoning.get("query_type") != "general_info": # Check if we got real reasoning logger.info(f"āœ… [MCP] Gemini MCP reasoning successful: {reasoning.get('query_type')}, complexity: {reasoning.get('complexity')}") return reasoning else: logger.warning("āš ļø [MCP] Gemini MCP returned fallback reasoning, using it anyway") return reasoning except Exception as e: logger.error(f"āŒ Error in nested async reasoning: {e}") import traceback logger.debug(traceback.format_exc()) else: reasoning = loop.run_until_complete(autonomous_reasoning_gemini(query)) if reasoning and reasoning.get("query_type") != "general_info": logger.info(f"āœ… [MCP] Gemini MCP reasoning successful: {reasoning.get('query_type')}, complexity: {reasoning.get('complexity')}") return reasoning else: logger.warning("āš ļø [MCP] Gemini MCP returned fallback reasoning, using it anyway") return reasoning except Exception as e: logger.error(f"āŒ Gemini MCP reasoning error: {e}") import traceback logger.debug(traceback.format_exc()) # Fallback reasoning only if all attempts failed logger.warning("āš ļø Falling back to default reasoning") return { "query_type": "general_info", "complexity": "moderate", "information_needs": ["medical information"], "requires_rag": True, "requires_web_search": False, "sub_questions": [query] } def create_execution_plan(reasoning: dict, query: str, has_rag_index: bool) -> dict: """ Planning: Create multi-step execution plan based on reasoning analysis. Returns execution plan with steps and strategy. """ plan = { "steps": [], "strategy": "sequential", "iterations": 1 } # Determine execution strategy if reasoning["complexity"] in ["complex", "multi_faceted"]: plan["strategy"] = "iterative" plan["iterations"] = 2 # Step 1: Language detection and translation plan["steps"].append({ "step": 1, "action": "detect_language", "description": "Detect query language and translate if needed" }) # Step 2: RAG retrieval (if needed and available) if reasoning.get("requires_rag", True) and has_rag_index: plan["steps"].append({ "step": 2, "action": "rag_retrieval", "description": "Retrieve relevant document context", "parameters": {"top_k": 15, "merge_threshold": 0.5} }) # Step 3: Web search (if needed) if reasoning.get("requires_web_search", False): plan["steps"].append({ "step": 3, "action": "web_search", "description": "Search web for current/updated information", "parameters": {"max_results": 5} }) # Step 4: Sub-question processing (if complex) if reasoning.get("sub_questions") and len(reasoning["sub_questions"]) > 1: plan["steps"].append({ "step": 4, "action": "multi_step_reasoning", "description": "Process sub-questions iteratively", "sub_questions": reasoning["sub_questions"] }) # Step 5: Synthesis and answer generation plan["steps"].append({ "step": len(plan["steps"]) + 1, "action": "synthesize_answer", "description": "Generate comprehensive answer from all sources" }) # Step 6: Self-reflection (for complex queries) if reasoning["complexity"] in ["complex", "multi_faceted"]: plan["steps"].append({ "step": len(plan["steps"]) + 1, "action": "self_reflection", "description": "Evaluate answer quality and completeness" }) logger.info(f"Execution plan created: {len(plan['steps'])} steps") return plan def autonomous_execution_strategy(reasoning: dict, plan: dict, use_rag: bool, use_web_search: bool, has_rag_index: bool) -> dict: """ Autonomous execution: Make decisions on information gathering strategy. Only suggests web search override, but respects user's RAG disable setting. """ strategy = { "use_rag": use_rag, # Respect user's RAG setting "use_web_search": use_web_search, "reasoning_override": False, "rationale": "" } # Only suggest web search override (RAG requires documents, so we respect user's choice) if reasoning.get("requires_web_search", False) and not use_web_search: strategy["use_web_search"] = True strategy["reasoning_override"] = True strategy["rationale"] += "Reasoning suggests web search for current information. " # Note: We don't override RAG setting because: # 1. User may have explicitly disabled it # 2. RAG requires documents to be uploaded # 3. We should respect user's explicit choice if strategy["reasoning_override"]: logger.info(f"Autonomous override: {strategy['rationale']}") return strategy async def self_reflection_gemini(answer: str, query: str) -> dict: """Self-reflection using Gemini MCP""" reflection_prompt = f"""Evaluate this medical answer for quality and completeness: Query: "{query}" Answer: "{answer[:1000]}" Evaluate: 1. Completeness: Does it address all aspects of the query? 2. Accuracy: Is the medical information accurate? 3. Clarity: Is it clear and well-structured? 4. Sources: Are sources cited appropriately? 5. Missing Information: What important information might be missing? Respond in JSON: {{ "completeness_score": 0-10, "accuracy_score": 0-10, "clarity_score": 0-10, "overall_score": 0-10, "missing_aspects": ["..."], "improvement_suggestions": ["..."] }}""" # Use concise system prompt system_prompt = "You are a medical answer quality evaluator. Provide honest, constructive feedback." response = await call_agent( user_prompt=reflection_prompt, system_prompt=system_prompt, model=GEMINI_MODEL, # Use full model for reflection temperature=0.3 ) try: json_start = response.find('{') json_end = response.rfind('}') + 1 if json_start >= 0 and json_end > json_start: reflection = json.loads(response[json_start:json_end]) else: reflection = {"overall_score": 7, "improvement_suggestions": []} except: reflection = {"overall_score": 7, "improvement_suggestions": []} logger.info(f"Self-reflection score: {reflection.get('overall_score', 'N/A')}") return reflection def self_reflection(answer: str, query: str, reasoning: dict) -> dict: """ Self-reflection: Evaluate answer quality and completeness. Returns reflection with quality score and improvement suggestions. """ if not MCP_AVAILABLE: logger.warning("Gemini MCP not available for reflection, using fallback") return {"overall_score": 7, "improvement_suggestions": []} try: loop = asyncio.get_event_loop() if loop.is_running(): try: import nest_asyncio return nest_asyncio.run(self_reflection_gemini(answer, query)) except Exception as e: logger.error(f"Error in nested async reflection: {e}") else: return loop.run_until_complete(self_reflection_gemini(answer, query)) except Exception as e: logger.error(f"Gemini MCP reflection error: {e}") return {"overall_score": 7, "improvement_suggestions": []} async def parse_document_gemini(file_path: str, file_extension: str) -> str: """Parse document using Gemini MCP""" if not MCP_AVAILABLE: return "" try: # Read file and encode to base64 with open(file_path, 'rb') as f: file_content = base64.b64encode(f.read()).decode('utf-8') # Determine MIME type from file extension mime_type_map = { '.pdf': 'application/pdf', '.doc': 'application/msword', '.docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document', '.txt': 'text/plain', '.md': 'text/markdown', '.json': 'application/json', '.xml': 'application/xml', '.csv': 'text/csv' } mime_type = mime_type_map.get(file_extension, 'application/octet-stream') # Prepare file object for Gemini MCP (use content for base64) files = [{ "content": file_content, "type": mime_type }] # Use concise system prompt system_prompt = "Extract all text content from the document accurately." user_prompt = "Extract all text content from this document. Return only the extracted text, preserving structure and formatting where possible." result = await call_agent( user_prompt=user_prompt, system_prompt=system_prompt, files=files, model=GEMINI_MODEL_LITE, # Use lite model for parsing temperature=0.2 ) return result.strip() except Exception as e: logger.error(f"Gemini document parsing error: {e}") import traceback logger.debug(traceback.format_exc()) return "" def extract_text_from_document(file): """Extract text from document using Gemini MCP""" file_name = file.name file_extension = os.path.splitext(file_name)[1].lower() # Handle text files directly if file_extension == '.txt': text = file.read().decode('utf-8') return text, len(text.split()), None # For PDF, Word, and other documents, use Gemini MCP # Save file to temporary location for processing try: with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp_file: # Write file content to temp file file.seek(0) # Reset file pointer tmp_file.write(file.read()) tmp_file_path = tmp_file.name # Use Gemini MCP to parse document if MCP_AVAILABLE: try: loop = asyncio.get_event_loop() if loop.is_running(): try: import nest_asyncio text = nest_asyncio.run(parse_document_gemini(tmp_file_path, file_extension)) except Exception as e: logger.error(f"Error in nested async document parsing: {e}") text = "" else: text = loop.run_until_complete(parse_document_gemini(tmp_file_path, file_extension)) # Clean up temp file try: os.unlink(tmp_file_path) except: pass if text: return text, len(text.split()), None else: return None, 0, ValueError(f"Failed to extract text from {file_extension} file using Gemini MCP") except Exception as e: logger.error(f"Gemini MCP document parsing error: {e}") # Clean up temp file try: os.unlink(tmp_file_path) except: pass return None, 0, ValueError(f"Error parsing {file_extension} file: {str(e)}") else: # Clean up temp file try: os.unlink(tmp_file_path) except: pass return None, 0, ValueError(f"Gemini MCP not available. Cannot parse {file_extension} files.") except Exception as e: logger.error(f"Error processing document: {e}") return None, 0, ValueError(f"Error processing {file_extension} file: {str(e)}") def create_or_update_index(files, request: gr.Request): global global_file_info if not files: return "Please provide files.", "" start_time = time.time() user_id = request.session_hash save_dir = f"./{user_id}_index" # Initialize LlamaIndex modules - use GPU functions for model inference only llm = get_llm_for_rag_gpu() embed_model = get_embedding_model_gpu() Settings.llm = llm Settings.embed_model = embed_model file_stats = [] new_documents = [] for file in tqdm(files, desc="Processing files"): file_basename = os.path.basename(file.name) text, word_count, error = extract_text_from_document(file) if error: logger.error(f"Error processing file {file_basename}: {str(error)}") file_stats.append({ "name": file_basename, "words": 0, "status": f"error: {str(error)}" }) continue doc = LlamaDocument( text=text, metadata={ "file_name": file_basename, "word_count": word_count, "source": "user_upload" } ) new_documents.append(doc) file_stats.append({ "name": file_basename, "words": word_count, "status": "processed" }) global_file_info[file_basename] = { "word_count": word_count, "processed_at": time.time() } node_parser = HierarchicalNodeParser.from_defaults( chunk_sizes=[2048, 512, 128], chunk_overlap=20 ) logger.info(f"Parsing {len(new_documents)} documents into hierarchical nodes") new_nodes = node_parser.get_nodes_from_documents(new_documents) new_leaf_nodes = get_leaf_nodes(new_nodes) new_root_nodes = get_root_nodes(new_nodes) logger.info(f"Generated {len(new_nodes)} total nodes ({len(new_root_nodes)} root, {len(new_leaf_nodes)} leaf)") if os.path.exists(save_dir): logger.info(f"Loading existing index from {save_dir}") storage_context = StorageContext.from_defaults(persist_dir=save_dir) index = load_index_from_storage(storage_context, settings=Settings) docstore = storage_context.docstore docstore.add_documents(new_nodes) for node in tqdm(new_leaf_nodes, desc="Adding leaf nodes to index"): index.insert_nodes([node]) total_docs = len(docstore.docs) logger.info(f"Updated index with {len(new_nodes)} new nodes from {len(new_documents)} files") else: logger.info("Creating new index") docstore = SimpleDocumentStore() storage_context = StorageContext.from_defaults(docstore=docstore) docstore.add_documents(new_nodes) index = VectorStoreIndex( new_leaf_nodes, storage_context=storage_context, settings=Settings ) total_docs = len(new_documents) logger.info(f"Created new index with {len(new_nodes)} nodes from {len(new_documents)} files") index.storage_context.persist(persist_dir=save_dir) # custom outputs after processing files file_list_html = "
" for stat in file_stats: status_color = "#4CAF50" if stat["status"] == "processed" else "#f44336" file_list_html += f"
ā— {stat['name']} - {stat['words']} words
" file_list_html += "
" processing_time = time.time() - start_time stats_output = f"
" stats_output += f"āœ“ Processed {len(files)} files in {processing_time:.2f} seconds
" stats_output += f"āœ“ Created {len(new_nodes)} nodes ({len(new_leaf_nodes)} leaf nodes)
" stats_output += f"āœ“ Total documents in index: {total_docs}
" stats_output += f"āœ“ Index saved to: {save_dir}
" stats_output += "
" output_container = f"
" output_container += file_list_html output_container += stats_output output_container += "
" return f"Successfully indexed {len(files)} files.", output_container def stream_chat( message: str, history: list, system_prompt: str, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float, retriever_k: int, merge_threshold: float, use_rag: bool, medical_model: str, use_web_search: bool, request: gr.Request ): if not request: yield history + [{"role": "assistant", "content": "Session initialization failed. Please refresh the page."}] return user_id = request.session_hash index_dir = f"./{user_id}_index" has_rag_index = os.path.exists(index_dir) # ===== AUTONOMOUS REASONING ===== logger.info("šŸ¤” Starting autonomous reasoning...") reasoning = autonomous_reasoning(message, history) # ===== PLANNING ===== logger.info("šŸ“‹ Creating execution plan...") plan = create_execution_plan(reasoning, message, has_rag_index) # ===== AUTONOMOUS EXECUTION STRATEGY ===== logger.info("šŸŽÆ Determining execution strategy...") execution_strategy = autonomous_execution_strategy(reasoning, plan, use_rag, use_web_search, has_rag_index) # Use autonomous strategy decisions (respect user's RAG setting) final_use_rag = execution_strategy["use_rag"] and has_rag_index # Only use RAG if enabled AND documents exist final_use_web_search = execution_strategy["use_web_search"] # Show reasoning override message if applicable reasoning_note = "" if execution_strategy["reasoning_override"]: reasoning_note = f"\n\nšŸ’” *Autonomous Reasoning: {execution_strategy['rationale']}*" # Detect language and translate if needed (Step 1 of plan) original_lang = detect_language(message) original_message = message needs_translation = original_lang != "en" if needs_translation: logger.info(f"Detected non-English language: {original_lang}, translating to English...") message = translate_text(message, target_lang="en", source_lang=original_lang) logger.info(f"Translated query: {message}") # Initialize medical model medical_model_obj, medical_tokenizer = initialize_medical_model(medical_model) # Adjust system prompt based on RAG setting and reasoning if final_use_rag: base_system_prompt = system_prompt if system_prompt else "As a medical specialist, provide clinical and concise answers based on the provided medical documents and context." else: base_system_prompt = "As a medical specialist, provide short and concise clinical answers. Be brief and avoid lengthy explanations. Focus on key medical facts only." # Add reasoning context to system prompt for complex queries if reasoning["complexity"] in ["complex", "multi_faceted"]: base_system_prompt += f"\n\nQuery Analysis: This is a {reasoning['complexity']} {reasoning['query_type']} query. Address all sub-questions: {', '.join(reasoning.get('sub_questions', [])[:3])}" # ===== EXECUTION: RAG Retrieval (Step 2) ===== rag_context = "" source_info = "" if final_use_rag and has_rag_index: # Use GPU function for embedding model embed_model = get_embedding_model_gpu() Settings.embed_model = embed_model storage_context = StorageContext.from_defaults(persist_dir=index_dir) index = load_index_from_storage(storage_context, settings=Settings) base_retriever = index.as_retriever(similarity_top_k=retriever_k) auto_merging_retriever = AutoMergingRetriever( base_retriever, storage_context=storage_context, simple_ratio_thresh=merge_threshold, verbose=True ) logger.info(f"Query: {message}") retrieval_start = time.time() merged_nodes = auto_merging_retriever.retrieve(message) logger.info(f"Retrieved {len(merged_nodes)} merged nodes in {time.time() - retrieval_start:.2f}s") merged_file_sources = {} for node in merged_nodes: if hasattr(node.node, 'metadata') and 'file_name' in node.node.metadata: file_name = node.node.metadata['file_name'] if file_name not in merged_file_sources: merged_file_sources[file_name] = 0 merged_file_sources[file_name] += 1 logger.info(f"Merged retrieval file distribution: {merged_file_sources}") rag_context = "\n\n".join([n.node.text for n in merged_nodes]) if merged_file_sources: source_info = "\n\nRetrieved information from files: " + ", ".join(merged_file_sources.keys()) # ===== EXECUTION: Web Search (Step 3) ===== web_context = "" web_sources = [] web_urls = [] # Store URLs for citations if final_use_web_search: logger.info("🌐 Performing web search (will use Gemini MCP for summarization)...") web_results = search_web(message, max_results=5) if web_results: logger.info(f"šŸ“Š Found {len(web_results)} web search results, now summarizing with Gemini MCP...") web_summary = summarize_web_content(web_results, message) if web_summary and len(web_summary) > 50: # Check if we got a real summary logger.info(f"āœ… [MCP] Gemini MCP summarization successful ({len(web_summary)} chars)") web_context = f"\n\nAdditional Web Sources (summarized with Gemini MCP):\n{web_summary}" else: logger.warning("āš ļø [MCP] Gemini MCP summarization failed or returned empty, using raw results") # Fallback: use first result's content web_context = f"\n\nAdditional Web Sources:\n{web_results[0].get('content', '')[:500]}" web_sources = [r['title'] for r in web_results[:3]] # Extract unique URLs for citations web_urls = [r.get('url', '') for r in web_results if r.get('url')] logger.info(f"āœ… Web search completed: {len(web_results)} results, summarized with Gemini MCP") else: logger.warning("āš ļø Web search returned no results") # Build final context context_parts = [] if rag_context: context_parts.append(f"Document Context:\n{rag_context}") if web_context: context_parts.append(web_context) full_context = "\n\n".join(context_parts) if context_parts else "" # Build system prompt if final_use_rag or final_use_web_search: formatted_system_prompt = f"{base_system_prompt}\n\n{full_context}{source_info}" else: formatted_system_prompt = base_system_prompt # Prepare messages messages = [{"role": "system", "content": formatted_system_prompt}] for entry in history: messages.append(entry) messages.append({"role": "user", "content": message}) # Get EOS token and adjust stopping criteria eos_token_id = medical_tokenizer.eos_token_id if eos_token_id is None: eos_token_id = medical_tokenizer.pad_token_id # Increase max tokens for medical models (prevent early stopping) max_new_tokens = int(max_new_tokens) if isinstance(max_new_tokens, (int, float)) else 2048 max_new_tokens = max(max_new_tokens, 1024) # Minimum 1024 tokens for medical answers # Check if tokenizer has chat template, otherwise format manually if hasattr(medical_tokenizer, 'chat_template') and medical_tokenizer.chat_template is not None: try: prompt = medical_tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) except Exception as e: logger.warning(f"Chat template failed, using manual formatting: {e}") # Fallback to manual formatting prompt = format_prompt_manually(messages, medical_tokenizer) else: # Manual formatting for models without chat template prompt = format_prompt_manually(messages, medical_tokenizer) inputs = medical_tokenizer(prompt, return_tensors="pt").to(medical_model_obj.device) prompt_length = inputs['input_ids'].shape[1] stop_event = threading.Event() class StopOnEvent(StoppingCriteria): def __init__(self, stop_event): super().__init__() self.stop_event = stop_event def __call__(self, input_ids, scores, **kwargs): return self.stop_event.is_set() # Custom stopping criteria that doesn't stop on EOS too early class MedicalStoppingCriteria(StoppingCriteria): def __init__(self, eos_token_id, prompt_length, min_new_tokens=100): super().__init__() self.eos_token_id = eos_token_id self.prompt_length = prompt_length self.min_new_tokens = min_new_tokens def __call__(self, input_ids, scores, **kwargs): current_length = input_ids.shape[1] new_tokens = current_length - self.prompt_length last_token = input_ids[0, -1].item() # Don't stop on EOS if we haven't generated enough new tokens if new_tokens < self.min_new_tokens: return False # Allow EOS after minimum new tokens have been generated return last_token == self.eos_token_id stopping_criteria = StoppingCriteriaList([ StopOnEvent(stop_event), MedicalStoppingCriteria(eos_token_id, prompt_length, min_new_tokens=100) ]) streamer = TextIteratorStreamer( medical_tokenizer, skip_prompt=True, skip_special_tokens=True ) temperature = float(temperature) if isinstance(temperature, (int, float)) else 0.7 top_p = float(top_p) if isinstance(top_p, (int, float)) else 0.95 top_k = int(top_k) if isinstance(top_k, (int, float)) else 50 penalty = float(penalty) if isinstance(penalty, (int, float)) else 1.2 # Call GPU function for model inference only thread = threading.Thread( target=generate_with_medswin, kwargs={ "medical_model_obj": medical_model_obj, "medical_tokenizer": medical_tokenizer, "prompt": prompt, "max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "top_k": top_k, "penalty": penalty, "eos_token_id": eos_token_id, "pad_token_id": medical_tokenizer.pad_token_id or eos_token_id, "stop_event": stop_event, "streamer": streamer, "stopping_criteria": stopping_criteria } ) thread.start() updated_history = history + [ {"role": "user", "content": original_message}, {"role": "assistant", "content": ""} ] yield updated_history partial_response = "" try: for new_text in streamer: partial_response += new_text updated_history[-1]["content"] = partial_response yield updated_history # ===== SELF-REFLECTION (Step 6) ===== if reasoning["complexity"] in ["complex", "multi_faceted"]: logger.info("šŸ” Performing self-reflection on answer quality...") reflection = self_reflection(partial_response, message, reasoning) # Add reflection note if score is low or improvements suggested if reflection.get("overall_score", 10) < 7 or reflection.get("improvement_suggestions"): reflection_note = f"\n\n---\n**Self-Reflection** (Score: {reflection.get('overall_score', 'N/A')}/10)" if reflection.get("improvement_suggestions"): reflection_note += f"\nšŸ’” Suggestions: {', '.join(reflection['improvement_suggestions'][:2])}" partial_response += reflection_note updated_history[-1]["content"] = partial_response # Add reasoning note if autonomous override occurred if reasoning_note: partial_response = reasoning_note + "\n\n" + partial_response updated_history[-1]["content"] = partial_response # Translate back if needed if needs_translation and partial_response: logger.info(f"Translating response back to {original_lang}...") translated_response = translate_text(partial_response, target_lang=original_lang, source_lang="en") partial_response = translated_response # Add citations if web sources were used citations_text = "" if web_urls: # Get unique domains unique_urls = list(dict.fromkeys(web_urls)) # Preserve order, remove duplicates citation_links = [] for url in unique_urls[:5]: # Limit to 5 citations domain = format_url_as_domain(url) if domain: # Create markdown link: [domain](url) citation_links.append(f"[{domain}]({url})") if citation_links: citations_text = "\n\n**Sources:** " + ", ".join(citation_links) # Add speaker icon and citations to assistant message speaker_icon = ' šŸ”Š' partial_response_with_speaker = partial_response + citations_text + speaker_icon updated_history[-1]["content"] = partial_response_with_speaker yield updated_history except GeneratorExit: stop_event.set() thread.join() raise def generate_speech_for_message(text: str): """Generate speech for a message and return audio file""" audio_path = generate_speech(text) if audio_path: return audio_path return None def create_demo(): with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo: gr.HTML(TITLE) gr.HTML(DESCRIPTION) with gr.Row(elem_classes="main-container"): with gr.Column(elem_classes="upload-section"): file_upload = gr.File( file_count="multiple", label="Drag and Drop Files Here", file_types=[".pdf", ".txt", ".doc", ".docx", ".md", ".json", ".xml", ".csv"], elem_id="file-upload" ) upload_button = gr.Button("Upload & Index", elem_classes="upload-button") status_output = gr.Textbox( label="Status", placeholder="Upload files to start...", interactive=False ) file_info_output = gr.HTML( label="File Information", elem_classes="processing-info" ) upload_button.click( fn=create_or_update_index, inputs=[file_upload], outputs=[status_output, file_info_output] ) with gr.Column(elem_classes="chatbot-container"): chatbot = gr.Chatbot( height=500, placeholder="Chat with MedSwin... Type your question below.", show_label=False, type="messages" ) with gr.Row(elem_classes="input-row"): message_input = gr.Textbox( placeholder="Type your medical question here...", show_label=False, container=False, lines=1, scale=10 ) mic_button = gr.Audio( sources=["microphone"], type="filepath", label="", show_label=False, container=False, scale=1 ) submit_button = gr.Button("āž¤", elem_classes="submit-btn", scale=1) # Timer display for recording (shown below input row) recording_timer = gr.Textbox( value="", label="", show_label=False, interactive=False, visible=False, container=False, elem_classes="recording-timer" ) # Handle microphone transcription import time recording_start_time = [None] def handle_recording_start(): """Called when recording starts""" recording_start_time[0] = time.time() return gr.update(visible=True, value="Recording... 0s") def handle_recording_stop(audio): """Called when recording stops""" recording_start_time[0] = None if audio is None: return gr.update(visible=False, value=""), "" transcribed = transcribe_audio(audio) return gr.update(visible=False, value=""), transcribed # Use JavaScript for timer updates (simpler than Gradio Timer) mic_button.start_recording( fn=handle_recording_start, outputs=[recording_timer] ) mic_button.stop_recording( fn=handle_recording_stop, inputs=[mic_button], outputs=[recording_timer, message_input] ) # TTS component for generating speech from messages with gr.Row(visible=False) as tts_row: tts_text = gr.Textbox(visible=False) tts_audio = gr.Audio(label="Generated Speech", visible=False) # Function to generate speech when speaker icon is clicked def generate_speech_from_chat(history): """Extract last assistant message and generate speech""" if not history or len(history) == 0: return None last_msg = history[-1] if last_msg.get("role") == "assistant": text = last_msg.get("content", "").replace(" šŸ”Š", "").strip() if text: audio_path = generate_speech(text) return audio_path return None # Add TTS button that appears when assistant responds tts_button = gr.Button("šŸ”Š Play Response", visible=False, size="sm") # Update TTS button visibility and generate speech def update_tts_button(history): if history and len(history) > 0 and history[-1].get("role") == "assistant": return gr.update(visible=True) return gr.update(visible=False) chatbot.change( fn=update_tts_button, inputs=[chatbot], outputs=[tts_button] ) tts_button.click( fn=generate_speech_from_chat, inputs=[chatbot], outputs=[tts_audio] ) with gr.Accordion("āš™ļø Advanced Settings", open=False): with gr.Row(): use_rag = gr.Checkbox( value=False, label="Enable Document RAG", info="Answer based on uploaded documents (requires document upload)" ) use_web_search = gr.Checkbox( value=False, label="Enable Web Search (MCP)", info="Fetch knowledge from online medical resources" ) medical_model = gr.Radio( choices=list(MEDSWIN_MODELS.keys()), value=DEFAULT_MEDICAL_MODEL, label="Medical Model", info="MedSwin TA (default), others download on first use" ) system_prompt = gr.Textbox( value="As a medical specialist, provide detailed and accurate answers based on the provided medical documents and context. Ensure all information is clinically accurate and cite sources when available.", label="System Prompt", lines=3 ) with gr.Tab("Generation Parameters"): temperature = gr.Slider( minimum=0, maximum=1, step=0.1, value=0.2, label="Temperature" ) max_new_tokens = gr.Slider( minimum=512, maximum=4096, step=128, value=2048, label="Max New Tokens", info="Increased for medical models to prevent early stopping" ) top_p = gr.Slider( minimum=0.0, maximum=1.0, step=0.1, value=0.7, label="Top P" ) top_k = gr.Slider( minimum=1, maximum=100, step=1, value=50, label="Top K" ) penalty = gr.Slider( minimum=0.0, maximum=2.0, step=0.1, value=1.2, label="Repetition Penalty" ) with gr.Tab("Retrieval Parameters"): retriever_k = gr.Slider( minimum=5, maximum=30, step=1, value=15, label="Initial Retrieval Size (Top K)" ) merge_threshold = gr.Slider( minimum=0.1, maximum=0.9, step=0.1, value=0.5, label="Merge Threshold (lower = more merging)" ) submit_button.click( fn=stream_chat, inputs=[ message_input, chatbot, system_prompt, temperature, max_new_tokens, top_p, top_k, penalty, retriever_k, merge_threshold, use_rag, medical_model, use_web_search ], outputs=chatbot ) message_input.submit( fn=stream_chat, inputs=[ message_input, chatbot, system_prompt, temperature, max_new_tokens, top_p, top_k, penalty, retriever_k, merge_threshold, use_rag, medical_model, use_web_search ], outputs=chatbot ) return demo if __name__ == "__main__": # Preload models on startup logger.info("Preloading models on startup...") logger.info("Initializing default medical model (MedSwin TA)...") initialize_medical_model(DEFAULT_MEDICAL_MODEL) logger.info("Preloading TTS model...") try: initialize_tts_model() if global_tts_model is not None: logger.info("TTS model preloaded successfully!") else: logger.warning("TTS model not available - will use MCP or disable voice generation") except Exception as e: logger.warning(f"TTS model preloading failed: {e}") logger.warning("Text-to-speech will use MCP or be disabled") # Check Gemini MCP availability if MCP_AVAILABLE: logger.info("Gemini MCP is available for translation, summarization, document parsing, and transcription") else: logger.warning("Gemini MCP not available - translation, summarization, document parsing, and transcription features will be limited") logger.info("Model preloading complete!") demo = create_demo() demo.launch()