# api/chatbot.py import re import logging from typing import Dict from google import genai from .config import gemini_flash_api_key from .retrieval import retrieval_engine from memory import MemoryManager from utils import translate_query, process_medical_image from search import search_comprehensive from models import summarizer from models import process_search_query from models.guard import safety_guard logger = logging.getLogger("medical-chatbot") class GeminiClient: """Gemini API client for generating responses""" def __init__(self): self.client = genai.Client(api_key=gemini_flash_api_key) def generate_content(self, prompt: str, model: str = "gemini-2.5-flash", temperature: float = 0.7) -> str: """Generate content using Gemini API""" try: response = self.client.models.generate_content(model=model, contents=prompt) return response.text except Exception as e: logger.error(f"[LLM] ❌ Error calling Gemini API: {e}") return "Error generating response from Gemini." class RAGMedicalChatbot: """Main chatbot class with RAG capabilities""" def __init__(self, model_name: str, retrieve_function): self.model_name = model_name self.retrieve = retrieve_function self.gemini_client = GeminiClient() self.memory = MemoryManager() def chat(self, user_id: str, user_query: str, lang: str = "EN", image_diagnosis: str = "", search_mode: bool = False, video_mode: bool = False) -> str: """Main chat method with RAG and search capabilities""" # 0. Translate query if not EN, this help our RAG system if lang.upper() in {"VI", "ZH"}: user_query = translate_query(user_query, lang.lower()) # 0.1 Safety check on user query is_safe_user, reason_user = safety_guard.check_user_query(user_query or "") if not is_safe_user: logger.warning(f"[SAFETY] Blocked unsafe user query: {reason_user}") return "⚠️ Unable to process this request safely. Please rephrase your question." # 1. Fetch knowledge ## a. KB for generic QA retrieval retrieved_info = self.retrieve(user_query) knowledge_base = "\n".join(retrieved_info) ## b. Diagnosis RAG from symptom query diagnosis_guides = retrieval_engine.retrieve_diagnosis_from_symptoms(user_query) # c. Hybrid Context Retrieval: RAG + Recent History + Intelligent Selection contextual_chunks = self.memory.get_contextual_chunks(user_id, user_query, lang) # 2. Retrieval modes (search and/or video) — handled independently search_context = "" url_mapping = {} video_results = [] # Text/web search mode (no videos unless video_mode=True) if search_mode: logger.info(f"[SEARCH] Starting web search mode for query: {user_query}") try: recent_memory_chunk = self.memory.get_context(user_id, num_turns=3) or "" recent_memory_ctx = contextual_chunks if contextual_chunks else recent_memory_chunk[:600] memory_focus = summarizer.summarize_for_query(recent_memory_ctx, user_query, max_length=180) if recent_memory_ctx else "" final_search_query = user_query if not memory_focus else f"{user_query}. {memory_focus}" logger.info(f"[SEARCH] Final search query: {final_search_query}") # Run comprehensive search; include videos only if UI requested search_context, url_mapping, source_aggregation = search_comprehensive( final_search_query, num_results=15, target_language=lang, include_videos=bool(video_mode) ) if search_context and url_mapping: logger.info(f"[SEARCH] Retrieved and processed {len(url_mapping)} comprehensive web resources") else: logger.warning("[SEARCH] No search results found") search_context = "" url_mapping = {} source_aggregation = {} # If videos were requested and provided by comprehensive search, capture them if video_mode and source_aggregation: video_results = source_aggregation.get('sources', []) or [] except Exception as e: logger.error(f"[SEARCH] Search failed: {e}") search_context = "" url_mapping = {} video_results = [] # Standalone video mode (when search_mode is False but videos requested) if (not search_mode) and video_mode: try: video_results = self.video_search(user_query, num_results=5, target_language=lang) except Exception as e: logger.warning(f"[VIDEO] Standalone video search failed: {e}") video_results = [] # 3. Build prompt parts parts = ["You are a medical chatbot, designed to answer medical questions."] parts.append("Please format your answer using MarkDown.") parts.append("**Bold for titles**, *italic for emphasis*, and clear headings.") # 4. Append image diagnosis from VLM if image_diagnosis: parts.append( "A user medical image is diagnosed by our VLM agent:\n" f"{image_diagnosis}\n\n" "Please incorporate the above findings in your response if medically relevant.\n\n" ) # Append contextual chunks from hybrid approach if contextual_chunks: parts.append("Relevant context from conversation history:\n" + contextual_chunks) # Load up guideline (RAG over medical knowledge base) if knowledge_base: parts.append(f"Example Q&A medical scenario knowledge-base: {knowledge_base}") # Symptom-Diagnosis prediction RAG if diagnosis_guides: parts.append("Symptom-based diagnosis guidance (if applicable):\n" + "\n".join(diagnosis_guides)) # 5. Search context with comprehensive information if search_context: parts.append("Comprehensive medical information from multiple sources:\n" + search_context) parts.append("IMPORTANT: The above information includes comprehensive details from multiple medical sources with inline citations. Use the citation tags <#ID> that are already included in the text to reference specific sources. The information is organized by medical categories (symptoms, causes, treatments, diagnosis, prevention, prognosis) and includes both text and video sources.") parts.append(f"User's question: {user_query}") parts.append(f"Language to generate answer: {lang}") prompt = "\n\n".join(parts) logger.info(f"[LLM] Question query in `prompt`: {prompt}") # Debug out checking RAG on kb and history response = self.gemini_client.generate_content(prompt, model=self.model_name, temperature=0.7) # 6. Process citations and replace with URLs if search_mode and url_mapping: response = self._process_citations(response, url_mapping) # 7. Safety check on model answer is_safe_ans, reason_ans = safety_guard.check_model_answer(user_query, response or "") if not is_safe_ans: logger.warning(f"[SAFETY] Withholding unsafe model answer: {reason_ans}") response = "⚠️ I cannot share that information. Let's discuss this topic at a high level or try a different question. Tips: Ensure correct language preferences." # Store exchange + chunking if user_id: self.memory.add_exchange(user_id, user_query, response, lang=lang) logger.info(f"[LLM] Response on `prompt`: {response.strip()}") # Debug out base response # Return response with video data if available and requested if video_mode and video_results: return { 'text': response.strip(), 'videos': video_results } else: return response.strip() def _process_citations(self, response: str, url_mapping: Dict[int, str]) -> str: """Replace citation tags with actual URLs, handling both single and multiple references""" # Pattern to match both single citations <#1> and multiple citations <#1, #2, #5, #7, #9> citation_pattern = r'<#([^>]+)>' def replace_citation(match): citation_content = match.group(1) # Split by comma and clean up each citation ID citation_ids = [id_str.strip() for id_str in citation_content.split(',')] urls = [] for citation_id in citation_ids: try: doc_id = int(citation_id) if doc_id in url_mapping: url = url_mapping[doc_id] urls.append(f'<{url}>') logger.info(f"[CITATION] Replacing <#{doc_id}> with {url}") else: logger.warning(f"[CITATION] No URL mapping found for document ID {doc_id}") urls.append(f'<#{doc_id}>') # Keep original if URL not found except ValueError: logger.warning(f"[CITATION] Invalid citation ID: {citation_id}") urls.append(f'<#{citation_id}>') # Keep original if invalid # Join multiple URLs with spaces return ' '.join(urls) # Replace citations with URLs processed_response = re.sub(citation_pattern, replace_citation, response) # Count total citations processed citations_found = re.findall(citation_pattern, response) total_citations = sum(len([id_str.strip() for id_str in citation_content.split(',')]) for citation_content in citations_found) logger.info(f"[CITATION] Processed {total_citations} citations from {len(citations_found)} citation groups, {len(url_mapping)} URL mappings available") return processed_response