LiamKhoaLe committed
Commit b8bf5c8 · 1 Parent(s): 410be5e

Update backend with full search service implementation; refactor directory.

.dockerignore ADDED
@@ -0,0 +1,4 @@
1
+ api/legacy.py
2
+ *.md
3
+ .env
4
+ *yml
Dockerfile CHANGED
@@ -24,7 +24,7 @@ RUN mkdir -p /app/model_cache /home/user/.cache/huggingface/sentence-transformer
24
  chown -R user:user /app/model_cache /home/user/.cache/huggingface
25
 
26
  # Pre-load model in a separate script
27
- RUN python /app/download_model.py && python /app/warmup.py
28
 
29
  # Ensure ownership and permissions remain intact
30
  RUN chown -R user:user /app/model_cache
@@ -32,5 +32,5 @@ RUN chown -R user:user /app/model_cache
32
  # Expose port
33
  EXPOSE 7860
34
 
35
- # Run the application
36
- CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
 
24
  chown -R user:user /app/model_cache /home/user/.cache/huggingface
25
 
26
  # Pre-load model in a separate script
27
+ RUN python /app/models/download_model.py && python /app/models/warmup.py
28
 
29
  # Ensure ownership and permissions remain intact
30
  RUN chown -R user:user /app/model_cache
 
32
  # Expose port
33
  EXPOSE 7860
34
 
35
+ # Run the application using main.py as entry point
36
+ CMD ["python", "main.py"]
README.md CHANGED
@@ -10,4 +10,123 @@ license: apache-2.0
10
  short_description: MedicalChatbot, FAISS, Gemini, MongoDB vDB, LRU
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
10
  short_description: MedicalChatbot, FAISS, Gemini, MongoDB vDB, LRU
11
  ---
12
 
13
+ # Medical Chatbot Backend
14
+
15
+ ## Project Structure
16
+
17
+ The backend is organized into logical modules for better maintainability:
18
+
19
+ ### 📁 **api/**
20
+ - **app.py** - Main FastAPI application with endpoints
21
+ - **__init__.py** - API package initialization
22
+
23
+ ### 📁 **models/**
24
+ - **llama.py** - NVIDIA Llama model integration for search processing
25
+ - **summarizer.py** - Text summarization using NVIDIA Llama
26
+ - **download_model.py** - Model download utilities
27
+ - **warmup.py** - Model warmup scripts
28
+
29
+ ### 📁 **memory/**
30
+ - **memory_updated.py** - Enhanced memory management with NVIDIA Llama summarization
31
+ - **memory.py** - Legacy memory implementation
32
+
33
+ ### 📁 **search/**
34
+ - **search.py** - Web search and content extraction functionality
35
+
36
+ ### 📁 **utils/**
37
+ - **translation.py** - Multi-language translation utilities
38
+ - **vlm.py** - Vision Language Model for medical image processing
39
+ - **diagnosis.py** - Symptom-based diagnosis utilities
40
+ - **connect_mongo.py** - MongoDB connection utilities
41
+ - **clear_mongo.py** - Database cleanup utilities
42
+ - **migrate.py** - Database migration scripts
43
+
44
+ ## Key Features
45
+
46
+ ### 🔍 **Search Integration**
47
+ - Web search with up to 10 resources
48
+ - NVIDIA Llama model for keyword generation and document summarization
49
+ - Citation system with URL mapping
50
+ - Smart content filtering and validation
51
+
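In code, these pieces line up with `api/chatbot.py`: `search_web` fetches the raw results and `process_search_query` has Llama generate keywords, per-document summaries, and a document-ID-to-URL map. A condensed sketch (assumes the repo's environment variables are already set):

```python
# Sketch of the search-mode flow used by RAGMedicalChatbot.chat (condensed from api/chatbot.py)
from search import search_web
from models import process_search_query

user_query = "latest guidance on managing type 2 diabetes"

search_results = search_web(user_query, num_results=10)   # up to 10 web resources
search_context, url_mapping = process_search_query(user_query, search_results)

# search_context is appended to the LLM prompt; url_mapping (doc ID -> URL)
# is later used to replace <#ID> citation tags in the generated answer
print(len(search_results), len(url_mapping))
```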
52
+ ### 🧠 **Enhanced Memory Management**
53
+ - NVIDIA Llama-powered summarization for all text processing
54
+ - Optimized chunking and context retrieval
55
+ - Smart deduplication and merging
56
+ - Conversation continuity with concise summaries
57
+
58
+ ### 📝 **Summarization System**
59
+ - **Text Cleaning**: Removes conversational fillers and normalizes text
60
+ - **Key Phrase Extraction**: Identifies medical terms and concepts
61
+ - **Concise Summaries**: Preserves key ideas without fluff
62
+ - **NVIDIA Llama Integration**: All summarization uses NVIDIA model instead of Gemini
63
+
64
+ ## Usage
65
+
66
+ ### Running the Application
67
+ ```bash
68
+ # Using main entry point
69
+ python main.py
70
+
71
+ # Or directly
72
+ python api/app.py
73
+ ```
74
+
75
+ ### Environment Variables
76
+ - `NVIDIA_URI` - NVIDIA API key for Llama model
77
+ - `FlashAPI` - Gemini API key
78
+ - `MONGO_URI` - MongoDB connection string
79
+ - `INDEX_URI` - FAISS index database URI
80
+
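These are read at startup; `api/config.py` raises immediately if the Gemini, Mongo, or index variables are missing. A minimal sketch of the same check, extended to cover all four variables listed above:

```python
# Sketch: fail fast if required environment variables are missing
# (mirrors the validation in api/config.py; extended to include NVIDIA_URI)
import os

REQUIRED_VARS = ["NVIDIA_URI", "FlashAPI", "MONGO_URI", "INDEX_URI"]

missing = [name for name in REQUIRED_VARS if not os.getenv(name)]
if missing:
    raise ValueError(f"Missing environment variables: {', '.join(missing)}")
```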
81
+ ## API Endpoints
82
+
83
+ ### POST `/chat`
84
+ Main chat endpoint with search mode support.
85
+
86
+ **Request Body:**
87
+ ```json
88
+ {
89
+ "query": "User's medical question",
90
+ "lang": "EN",
91
+ "search": true,
92
+ "user_id": "unique_user_id",
93
+ "image_base64": "optional_base64_image",
94
+ "img_desc": "image_description"
95
+ }
96
+ ```
97
+
98
+ **Response:**
99
+ ```json
100
+ {
101
+ "response": "Medical response with citations <URL>",
102
+ "response_time": "2.34s"
103
+ }
104
+ ```
105
+
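For reference, a minimal client call might look like the sketch below; the host and port are assumptions for a local dev server, and the payload follows the request body documented above:

```python
# Sketch: call the /chat endpoint from Python (local dev server assumed)
import requests

payload = {
    "query": "What are common causes of a persistent cough?",
    "lang": "EN",
    "search": True,          # enable web search + citations
    "user_id": "demo-user",  # any stable identifier for memory continuity
}

resp = requests.post("http://localhost:7860/chat", json=payload, timeout=60)
print(resp.json()["response"])
```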
106
+ ## Search Mode Features
107
+
108
+ When `search: true`:
109
+ 1. **Web Search**: Fetches up to 10 relevant medical resources
110
+ 2. **Llama Processing**: Generates keywords and summarizes content
111
+ 3. **Citation System**: Replaces `<#ID>` tags with actual URLs
112
+ 4. **UI Integration**: Frontend displays magnifier icons for source links
113
+
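For illustration, the citation step (see `_process_citations` in `api/chatbot.py`) rewrites the model's `<#ID>` tags using the URL mapping returned by the search pipeline. A simplified sketch with made-up data:

```python
# Sketch: replace <#ID> citation tags with the mapped source URLs (simplified)
import re

url_mapping = {1: "https://example.org/cough-overview"}  # hypothetical doc ID -> URL map
text = "A persistent cough often resolves within three weeks <#1>."

def cite(match: re.Match) -> str:
    doc_id = int(match.group(1))
    return f"<{url_mapping[doc_id]}>" if doc_id in url_mapping else match.group(0)

print(re.sub(r"<#(\d+)>", cite, text))
# -> "A persistent cough often resolves within three weeks <https://example.org/cough-overview>."
```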
114
+ ## Summarization Features
115
+
116
+ All summarization tasks use the NVIDIA Llama model:
117
+ - **get_contextual_chunks**: Summarizes conversation history and RAG chunks
118
+ - **chunk_response**: Chunks and summarizes bot responses
119
+ - **summarize_documents**: Summarizes web search results
120
+
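A sketch of how the rest of the codebase calls the shared `summarizer` instance exported by `models/`; the signatures are inferred from the call sites in this commit, so treat them as indicative rather than a formal API reference:

```python
# Sketch of the summarizer call sites used in memory/memory.py and models/llama.py
from models import summarizer

history = "User reported a dry cough for two weeks. Bot suggested hydration and rest."
documents = [{"id": 1, "url": "https://example.org/cough", "title": "Cough", "content": "..."}]

context_summary = summarizer.summarize_text(history, max_length=300)              # memory context
chunks = summarizer.chunk_response("Rest, fluids, and follow up.", max_chunk_size=500)
combined_summary, url_mapping = summarizer.summarize_documents(documents, "persistent cough")
```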
121
+ ### Text Processing Pipeline
122
+ 1. **Clean Text**: Remove conversational elements and normalize
123
+ 2. **Extract Key Phrases**: Identify medical terms and concepts
124
+ 3. **Summarize**: Create concise, focused summaries
125
+ 4. **Validate**: Ensure quality and relevance
126
+
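A rough sketch of that pipeline as plain functions; the helper names are illustrative stand-ins, not the actual functions in `models/summarizer.py`:

```python
# Sketch: clean -> extract key phrases -> summarize -> validate (illustrative helpers only)
import re
from typing import List

def clean_text(text: str) -> str:
    """Drop conversational fillers and collapse whitespace."""
    text = re.sub(r"\b(um|uh|you know|i think)\b", "", text, flags=re.IGNORECASE)
    return re.sub(r"\s+", " ", text).strip()

def extract_key_phrases(text: str) -> List[str]:
    """Rough stand-in: keep longer tokens as candidate medical terms."""
    return re.findall(r"[A-Za-z-]{6,}", text)

def summarize(text: str, max_words: int = 60) -> str:
    """Placeholder for the NVIDIA Llama call that produces the concise summary."""
    return " ".join(text.split()[:max_words])

def validate(summary: str) -> bool:
    """Basic sanity check before the summary is stored or injected into a prompt."""
    return 0 < len(summary.split()) <= 120

raw = "Um, I think the patient reports a persistent cough and mild fever since Monday."
cleaned = clean_text(raw)
print(extract_key_phrases(cleaned))
print(summarize(cleaned), validate(summarize(cleaned)))
```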
127
+ ## Dependencies
128
+
129
+ See `requirements.txt` for the complete list. Key additions:
130
+ - `requests` - Web search functionality
131
+ - `beautifulsoup4` - HTML content extraction
132
+ - NVIDIA API integration for Llama model
api/README.md ADDED
@@ -0,0 +1,96 @@
1
+ # API Module Structure
2
+
3
+ ## 📁 **Module Overview**
4
+
5
+ ### **config.py** - Configuration Management
6
+ - Environment variables validation
7
+ - Logging configuration
8
+ - System resource monitoring
9
+ - Memory optimization settings
10
+ - CORS configuration
11
+
12
+ ### **database.py** - Database Management
13
+ - MongoDB connection management
14
+ - FAISS index lazy loading
15
+ - SentenceTransformer model initialization
16
+ - Symptom vectors management
17
+ - GridFS integration
18
+
19
+ ### **retrieval.py** - RAG Retrieval Engine
20
+ - Medical information retrieval from FAISS
21
+ - Symptom-based diagnosis retrieval
22
+ - Smart deduplication and similarity matching
23
+ - Vector similarity computations
24
+
25
+ ### **chatbot.py** - Core Chatbot Logic
26
+ - RAGMedicalChatbot class
27
+ - Gemini API client
28
+ - Search mode integration
29
+ - Citation processing
30
+ - Memory management integration
31
+
32
+ ### **routes.py** - API Endpoints
33
+ - `/chat` - Main chat endpoint
34
+ - `/health` - Health check
35
+ - `/` - Root endpoint
36
+ - Request/response handling
37
+
38
+ ### **app.py** - Main Application
39
+ - FastAPI app initialization
40
+ - Middleware configuration
41
+ - Database initialization
42
+ - Route registration
43
+ - Server startup
44
+
45
+ ## 🔄 **Data Flow**
46
+
47
+ ```
48
+ Request → routes.py → chatbot.py → retrieval.py → database.py
49
+
50
+ memory.py (context) + search.py (web search)
51
+
52
+ models/ (NVIDIA Llama processing)
53
+
54
+ Response with citations
55
+ ```
56
+
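Expressed as calls, the flow above corresponds roughly to the snippet below, condensed from `api/routes.py` and `api/chatbot.py` (error handling and response timing omitted):

```python
# Sketch: one request traced through the modular layers (condensed from routes.py)
from api.chatbot import RAGMedicalChatbot
from api.retrieval import retrieval_engine

chatbot = RAGMedicalChatbot(
    model_name="gemini-2.5-flash",
    retrieve_function=retrieval_engine.retrieve_medical_info,  # retrieval.py -> database.py
)

# routes.py does this for POST /chat, then wraps the answer in a JSONResponse
answer = chatbot.chat(
    user_id="demo-user",
    user_query="What can cause recurring headaches?",
    lang="EN",
    image_diagnosis="",   # filled by utils.process_medical_image when an image is attached
    search_mode=False,    # True triggers search.py + models/ (NVIDIA Llama processing)
)
print(answer)
```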
57
+ ## 🚀 **Benefits of Modular Structure**
58
+
59
+ 1. **Separation of Concerns**: Each module has a single responsibility
60
+ 2. **Easier Testing**: Individual modules can be tested in isolation
61
+ 3. **Better Maintainability**: Changes to one module don't affect others
62
+ 4. **Improved Readability**: Smaller files are easier to understand
63
+ 5. **Reusability**: Modules can be imported and used elsewhere
64
+ 6. **Scalability**: Easy to add new features without affecting existing code
65
+
66
+ ## 📊 **File Sizes Comparison**
67
+
68
+ | File | Lines | Purpose |
69
+ |------|-------|---------|
70
+ | **app_old.py** | 370 | Monolithic (everything) |
71
+ | **app.py** | 45 | Main app initialization |
72
+ | **config.py** | 65 | Configuration |
73
+ | **database.py** | 95 | Database management |
74
+ | **retrieval.py** | 85 | RAG retrieval |
75
+ | **chatbot.py** | 120 | Chatbot logic |
76
+ | **routes.py** | 55 | API endpoints |
77
+ | **Total** | 465 | Modular structure |
78
+
79
+ ## 🔧 **Usage**
80
+
81
+ The modular structure maintains the same API interface:
82
+
83
+ ```python
84
+ # All imports work the same way
85
+ from api.app import app
86
+ from api.chatbot import RAGMedicalChatbot
87
+ from api.retrieval import retrieval_engine
88
+ ```
89
+
90
+ ## 🛠 **Development Benefits**
91
+
92
+ - **Easier Debugging**: Issues can be isolated to specific modules
93
+ - **Parallel Development**: Multiple developers can work on different modules
94
+ - **Code Reviews**: Smaller files are easier to review
95
+ - **Documentation**: Each module can have focused documentation
96
+ - **Testing**: Unit tests can be written for each module independently
api/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # API package
2
+ # Main API endpoints and routes
api/app.py ADDED
@@ -0,0 +1,54 @@
1
+ # api/app.py
2
+ import uvicorn
3
+ from fastapi import FastAPI
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+ from api.config import setup_logging, check_system_resources, optimize_memory, CORS_ORIGINS
6
+ from api.database import db_manager
7
+ from api.routes import router
8
+
9
+ # ✅ Setup logging
10
+ logger = setup_logging()
11
+ logger.info("🚀 Starting Medical Chatbot API...")
12
+
13
+ # ✅ Monitor system resources
14
+ check_system_resources(logger)
15
+
16
+ # ✅ Optimize memory usage
17
+ optimize_memory()
18
+
19
+ # ✅ Initialize FastAPI app
20
+ app = FastAPI(
21
+ title="Medical Chatbot API",
22
+ description="AI-powered medical chatbot with RAG and search capabilities",
23
+ version="1.0.0"
24
+ )
25
+
26
+ # ✅ Add CORS middleware
27
+ app.add_middleware(
28
+ CORSMiddleware,
29
+ allow_origins=CORS_ORIGINS,
30
+ allow_credentials=True,
31
+ allow_methods=["*"],
32
+ allow_headers=["*"],
33
+ )
34
+
35
+ # ✅ Initialize database connections
36
+ try:
37
+ db_manager.initialize_embedding_model()
38
+ db_manager.initialize_mongodb()
39
+ logger.info("✅ Database connections initialized successfully")
40
+ except Exception as e:
41
+ logger.error(f"❌ Database initialization failed: {e}")
42
+ raise
43
+
44
+ # ✅ Include routes
45
+ app.include_router(router)
46
+
47
+ # ✅ Run Uvicorn
48
+ if __name__ == "__main__":
49
+ logger.info("[System] ✅ Starting FastAPI Server...")
50
+ try:
51
+ uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
52
+ except Exception as e:
53
+ logger.error(f"❌ Server Startup Failed: {e}")
54
+ exit(1)
api/chatbot.py ADDED
@@ -0,0 +1,140 @@
1
+ # api/chatbot.py
2
+ import re
3
+ import logging
4
+ from typing import Dict
5
+ from google import genai
6
+ from api.config import gemini_flash_api_key
7
+ from api.retrieval import retrieval_engine
8
+ from memory import MemoryManager
9
+ from utils import translate_query, process_medical_image
10
+ from search import search_web
11
+ from models import process_search_query
12
+
13
+ logger = logging.getLogger("medical-chatbot")
14
+
15
+ class GeminiClient:
16
+ """Gemini API client for generating responses"""
17
+
18
+ def __init__(self):
19
+ self.client = genai.Client(api_key=gemini_flash_api_key)
20
+
21
+ def generate_content(self, prompt: str, model: str = "gemini-2.5-flash", temperature: float = 0.7) -> str:
22
+ """Generate content using Gemini API"""
23
+ try:
24
+ response = self.client.models.generate_content(model=model, contents=prompt)
25
+ return response.text
26
+ except Exception as e:
27
+ logger.error(f"[LLM] ❌ Error calling Gemini API: {e}")
28
+ return "Error generating response from Gemini."
29
+
30
+ class RAGMedicalChatbot:
31
+ """Main chatbot class with RAG capabilities"""
32
+
33
+ def __init__(self, model_name: str, retrieve_function):
34
+ self.model_name = model_name
35
+ self.retrieve = retrieve_function
36
+ self.gemini_client = GeminiClient()
37
+ self.memory = MemoryManager()
38
+
39
+ def chat(self, user_id: str, user_query: str, lang: str = "EN", image_diagnosis: str = "", search_mode: bool = False) -> str:
40
+ """Main chat method with RAG and search capabilities"""
41
+
42
+ # 0. Translate query if not EN; this helps our RAG system
43
+ if lang.upper() in {"VI", "ZH"}:
44
+ user_query = translate_query(user_query, lang.lower())
45
+
46
+ # 1. Fetch knowledge
47
+ ## a. KB for generic QA retrieval
48
+ retrieved_info = self.retrieve(user_query)
49
+ knowledge_base = "\n".join(retrieved_info)
50
+ ## b. Diagnosis RAG from symptom query
51
+ diagnosis_guides = retrieval_engine.retrieve_diagnosis_from_symptoms(user_query)
52
+
53
+ # 1.5. Search mode - web search and Llama processing
54
+ search_context = ""
55
+ url_mapping = {}
56
+ if search_mode:
57
+ logger.info(f"[SEARCH] Starting web search mode for query: {user_query}")
58
+ try:
59
+ # Search the web with max 10 resources
60
+ search_results = search_web(user_query, num_results=10)
61
+ if search_results:
62
+ logger.info(f"[SEARCH] Retrieved {len(search_results)} web resources")
63
+ # Process with Llama
64
+ search_context, url_mapping = process_search_query(user_query, search_results)
65
+ logger.info(f"[SEARCH] Processed with Llama, generated {len(url_mapping)} URL mappings")
66
+ else:
67
+ logger.warning("[SEARCH] No search results found")
68
+ except Exception as e:
69
+ logger.error(f"[SEARCH] Search failed: {e}")
70
+ search_context = ""
71
+
72
+ # 2. Hybrid Context Retrieval: RAG + Recent History + Intelligent Selection
73
+ contextual_chunks = self.memory.get_contextual_chunks(user_id, user_query, lang)
74
+
75
+ # 3. Build prompt parts
76
+ parts = ["You are a medical chatbot, designed to answer medical questions."]
77
+ parts.append("Please format your answer using MarkDown.")
78
+ parts.append("**Bold for titles**, *italic for emphasis*, and clear headings.")
79
+
80
+ # 4. Append image diagnosis from VLM
81
+ if image_diagnosis:
82
+ parts.append(
83
+ "A user medical image is diagnosed by our VLM agent:\n"
84
+ f"{image_diagnosis}\n\n"
85
+ "Please incorporate the above findings in your response if medically relevant.\n\n"
86
+ )
87
+
88
+ # Append contextual chunks from hybrid approach
89
+ if contextual_chunks:
90
+ parts.append("Relevant context from conversation history:\n" + contextual_chunks)
91
+ # Load up guideline (RAG over medical knowledge base)
92
+ if knowledge_base:
93
+ parts.append(f"Example Q&A medical scenario knowledge-base: {knowledge_base}")
94
+ # Symptom-Diagnosis prediction RAG
95
+ if diagnosis_guides:
96
+ parts.append("Symptom-based diagnosis guidance (if applicable):\n" + "\n".join(diagnosis_guides))
97
+
98
+ # 5. Search context with citation instructions
99
+ if search_context:
100
+ parts.append("Additional information from web search:\n" + search_context)
101
+ parts.append("IMPORTANT: When you use information from the web search results above, you MUST add a citation tag <#ID> immediately after the relevant content, where ID is the document number (1, 2, 3, etc.). For example: 'According to recent studies <#1>, this condition affects...'")
102
+
103
+ parts.append(f"User's question: {user_query}")
104
+ parts.append(f"Language to generate answer: {lang}")
105
+ prompt = "\n\n".join(parts)
106
+ logger.info(f"[LLM] Question query in `prompt`: {prompt}") # Debug out checking RAG on kb and history
107
+ response = self.gemini_client.generate_content(prompt, model=self.model_name, temperature=0.7)
108
+
109
+ # 6. Process citations and replace with URLs
110
+ if search_mode and url_mapping:
111
+ response = self._process_citations(response, url_mapping)
112
+
113
+ # Store exchange + chunking
114
+ if user_id:
115
+ self.memory.add_exchange(user_id, user_query, response, lang=lang)
116
+ logger.info(f"[LLM] Response on `prompt`: {response.strip()}") # Debug out base response
117
+ return response.strip()
118
+
119
+ def _process_citations(self, response: str, url_mapping: Dict[int, str]) -> str:
120
+ """Replace citation tags with actual URLs"""
121
+
122
+ # Find all citation tags like <#1>, <#2>, etc.
123
+ citation_pattern = r'<#(\d+)>'
124
+ citations_found = re.findall(citation_pattern, response)
125
+
126
+ def replace_citation(match):
127
+ doc_id = int(match.group(1))
128
+ if doc_id in url_mapping:
129
+ url = url_mapping[doc_id]
130
+ logger.info(f"[CITATION] Replacing <#{doc_id}> with {url}")
131
+ return f'<{url}>'
132
+ else:
133
+ logger.warning(f"[CITATION] No URL mapping found for document ID {doc_id}")
134
+ return match.group(0) # Keep original if URL not found
135
+
136
+ # Replace citations with URLs
137
+ processed_response = re.sub(citation_pattern, replace_citation, response)
138
+
139
+ logger.info(f"[CITATION] Processed {len(citations_found)} citations, {len(url_mapping)} URL mappings available")
140
+ return processed_response
api/config.py ADDED
@@ -0,0 +1,70 @@
1
+ # api/config.py
2
+ import os
3
+ import logging
4
+ import psutil
5
+ from typing import List
6
+
7
+ # ✅ Environment Variables
8
+ mongo_uri = os.getenv("MONGO_URI")
9
+ index_uri = os.getenv("INDEX_URI")
10
+ gemini_flash_api_key = os.getenv("FlashAPI")
11
+
12
+ # Validate environment endpoint
13
+ if not all([gemini_flash_api_key, mongo_uri, index_uri]):
14
+ raise ValueError("❌ Missing API keys! Set them in Hugging Face Secrets.")
15
+
16
+ # ✅ Logging Configuration
17
+ def setup_logging():
18
+ """Configure logging for the application"""
19
+ # Silence noisy loggers
20
+ for name in [
21
+ "uvicorn.error", "uvicorn.access",
22
+ "fastapi", "starlette",
23
+ "pymongo", "gridfs",
24
+ "sentence_transformers", "faiss",
25
+ "google", "google.auth",
26
+ ]:
27
+ logging.getLogger(name).setLevel(logging.WARNING)
28
+
29
+ logging.basicConfig(
30
+ level=logging.INFO,
31
+ format="%(asctime)s — %(name)s — %(levelname)s — %(message)s",
32
+ force=True
33
+ )
34
+
35
+ logger = logging.getLogger("medical-chatbot")
36
+ logger.setLevel(logging.DEBUG)
37
+ return logger
38
+
39
+ # ✅ System Resource Monitoring
40
+ def check_system_resources(logger):
41
+ """Monitor system resources and log warnings"""
42
+ memory = psutil.virtual_memory()
43
+ cpu = psutil.cpu_percent(interval=1)
44
+ disk = psutil.disk_usage("/")
45
+
46
+ logger.info(f"[System] 🔍 System Resources - RAM: {memory.percent}%, CPU: {cpu}%, Disk: {disk.percent}%")
47
+
48
+ if memory.percent > 85:
49
+ logger.warning("⚠️ High RAM usage detected!")
50
+ if cpu > 90:
51
+ logger.warning("⚠️ High CPU usage detected!")
52
+ if disk.percent > 90:
53
+ logger.warning("⚠️ High Disk usage detected!")
54
+
55
+ # ✅ Memory Optimization
56
+ def optimize_memory():
57
+ """Set environment variables for memory optimization"""
58
+ os.environ["OMP_NUM_THREADS"] = "1"
59
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
60
+
61
+ # ✅ CORS Configuration
62
+ CORS_ORIGINS = [
63
+ "http://localhost:5173", # Vite dev server
64
+ "http://localhost:3000", # Another vercel local dev
65
+ "https://medical-chatbot-henna.vercel.app", # ✅ Vercel frontend production URL
66
+ ]
67
+
68
+ # ✅ Model Configuration
69
+ MODEL_CACHE_DIR = "/app/model_cache"
70
+ EMBEDDING_MODEL_DEVICE = "cpu"
api/database.py ADDED
@@ -0,0 +1,100 @@
1
+ # api/database.py
2
+ import faiss
3
+ import numpy as np
4
+ import gridfs
5
+ from pymongo import MongoClient
6
+ from sentence_transformers import SentenceTransformer
7
+ from api.config import mongo_uri, index_uri, MODEL_CACHE_DIR, EMBEDDING_MODEL_DEVICE
8
+ import logging
9
+
10
+ logger = logging.getLogger("medical-chatbot")
11
+
12
+ class DatabaseManager:
13
+ def __init__(self):
14
+ self.embedding_model = None
15
+ self.index = None
16
+ self.symptom_vectors = None
17
+ self.symptom_docs = None
18
+
19
+ # MongoDB connections
20
+ self.client = None
21
+ self.iclient = None
22
+ self.symptom_client = None
23
+
24
+ # Collections
25
+ self.qa_collection = None
26
+ self.index_collection = None
27
+ self.symptom_col = None
28
+ self.fs = None
29
+
30
+ def initialize_embedding_model(self):
31
+ """Initialize the SentenceTransformer model"""
32
+ logger.info("[Embedder] 📥 Loading SentenceTransformer Model...")
33
+ try:
34
+ self.embedding_model = SentenceTransformer(MODEL_CACHE_DIR, device=EMBEDDING_MODEL_DEVICE)
35
+ self.embedding_model = self.embedding_model.half() # Reduce memory
36
+ logger.info("✅ Model Loaded Successfully.")
37
+ except Exception as e:
38
+ logger.error(f"❌ Model Loading Failed: {e}")
39
+ raise
40
+
41
+ def initialize_mongodb(self):
42
+ """Initialize MongoDB connections and collections"""
43
+ # QA data
44
+ self.client = MongoClient(mongo_uri)
45
+ db = self.client["MedicalChatbotDB"]
46
+ self.qa_collection = db["qa_data"]
47
+
48
+ # FAISS Index data
49
+ self.iclient = MongoClient(index_uri)
50
+ idb = self.iclient["MedicalChatbotDB"]
51
+ self.index_collection = idb["faiss_index_files"]
52
+
53
+ # Symptom Diagnosis data
54
+ self.symptom_client = MongoClient(mongo_uri)
55
+ self.symptom_col = self.symptom_client["MedicalChatbotDB"]["symptom_diagnosis"]
56
+
57
+ # GridFS for FAISS index
58
+ self.fs = gridfs.GridFS(idb, collection="faiss_index_files")
59
+
60
+ def load_faiss_index(self):
61
+ """Lazy load FAISS index from GridFS"""
62
+ if self.index is None:
63
+ logger.info("[KB] ⏳ Loading FAISS index from GridFS...")
64
+ existing_file = self.fs.find_one({"filename": "faiss_index.bin"})
65
+ if existing_file:
66
+ stored_index_bytes = existing_file.read()
67
+ index_bytes_np = np.frombuffer(stored_index_bytes, dtype='uint8')
68
+ self.index = faiss.deserialize_index(index_bytes_np)
69
+ logger.info("[KB] ✅ FAISS Index Loaded")
70
+ else:
71
+ logger.error("[KB] ❌ FAISS index not found in GridFS.")
72
+ return self.index
73
+
74
+ def load_symptom_vectors(self):
75
+ """Lazy load symptom vectors for diagnosis"""
76
+ if self.symptom_vectors is None:
77
+ all_docs = list(self.symptom_col.find({}, {"embedding": 1, "answer": 1, "question": 1, "prognosis": 1}))
78
+ self.symptom_docs = all_docs
79
+ self.symptom_vectors = np.array([doc["embedding"] for doc in all_docs], dtype=np.float32)
80
+
81
+ def get_embedding_model(self):
82
+ """Get the embedding model"""
83
+ if self.embedding_model is None:
84
+ self.initialize_embedding_model()
85
+ return self.embedding_model
86
+
87
+ def get_qa_collection(self):
88
+ """Get QA collection"""
89
+ if self.qa_collection is None:
90
+ self.initialize_mongodb()
91
+ return self.qa_collection
92
+
93
+ def get_symptom_collection(self):
94
+ """Get symptom collection"""
95
+ if self.symptom_col is None:
96
+ self.initialize_mongodb()
97
+ return self.symptom_col
98
+
99
+ # Global database manager instance
100
+ db_manager = DatabaseManager()
app.py → api/legacy.py RENAMED
@@ -1,5 +1,6 @@
1
  # app.py
2
- import os
 
3
  import faiss
4
  import numpy as np
5
  import time
@@ -11,10 +12,9 @@ from google import genai
11
  from sentence_transformers import SentenceTransformer
12
  from sentence_transformers.util import cos_sim
13
  from memory import MemoryManager
14
- from translation import translate_query
15
- from vlm import process_medical_image
16
  from search import search_web
17
- from llama_integration import process_search_query
18
 
19
  # ✅ Enable Logging for Debugging
20
  import logging
@@ -239,14 +239,15 @@ class RAGMedicalChatbot:
239
  search_context = ""
240
  url_mapping = {}
241
  if search_mode:
242
- logger.info("[SEARCH] Starting web search mode")
243
  try:
244
- # Search the web
245
- search_results = search_web(user_query, num_results=5)
246
  if search_results:
 
247
  # Process with Llama
248
  search_context, url_mapping = process_search_query(user_query, search_results)
249
- logger.info(f"[SEARCH] Found {len(search_results)} results, processed with Llama")
250
  else:
251
  logger.warning("[SEARCH] No search results found")
252
  except Exception as e:
@@ -306,17 +307,22 @@ class RAGMedicalChatbot:
306
 
307
  # Find all citation tags like <#1>, <#2>, etc.
308
  citation_pattern = r'<#(\d+)>'
 
309
 
310
  def replace_citation(match):
311
  doc_id = int(match.group(1))
312
  if doc_id in url_mapping:
313
- return f'<{url_mapping[doc_id]}>'
314
- return match.group(0) # Keep original if URL not found
 
 
 
 
315
 
316
  # Replace citations with URLs
317
  processed_response = re.sub(citation_pattern, replace_citation, response)
318
 
319
- logger.info(f"[CITATION] Processed citations, found {len(re.findall(citation_pattern, response))} citations")
320
  return processed_response
321
 
322
  # ✅ Initialize Chatbot
 
1
  # app.py
2
+ import os, json, re
3
+ from typing import Dict
4
  import faiss
5
  import numpy as np
6
  import time
 
12
  from sentence_transformers import SentenceTransformer
13
  from sentence_transformers.util import cos_sim
14
  from memory import MemoryManager
15
+ from utils import translate_query, process_medical_image, retrieve_diagnosis_from_symptoms
 
16
  from search import search_web
17
+ from models import process_search_query
18
 
19
  # ✅ Enable Logging for Debugging
20
  import logging
 
239
  search_context = ""
240
  url_mapping = {}
241
  if search_mode:
242
+ logger.info(f"[SEARCH] Starting web search mode for query: {user_query}")
243
  try:
244
+ # Search the web with max 10 resources
245
+ search_results = search_web(user_query, num_results=10)
246
  if search_results:
247
+ logger.info(f"[SEARCH] Retrieved {len(search_results)} web resources")
248
  # Process with Llama
249
  search_context, url_mapping = process_search_query(user_query, search_results)
250
+ logger.info(f"[SEARCH] Processed with Llama, generated {len(url_mapping)} URL mappings")
251
  else:
252
  logger.warning("[SEARCH] No search results found")
253
  except Exception as e:
 
307
 
308
  # Find all citation tags like <#1>, <#2>, etc.
309
  citation_pattern = r'<#(\d+)>'
310
+ citations_found = re.findall(citation_pattern, response)
311
 
312
  def replace_citation(match):
313
  doc_id = int(match.group(1))
314
  if doc_id in url_mapping:
315
+ url = url_mapping[doc_id]
316
+ logger.info(f"[CITATION] Replacing <#{doc_id}> with {url}")
317
+ return f'<{url}>'
318
+ else:
319
+ logger.warning(f"[CITATION] No URL mapping found for document ID {doc_id}")
320
+ return match.group(0) # Keep original if URL not found
321
 
322
  # Replace citations with URLs
323
  processed_response = re.sub(citation_pattern, replace_citation, response)
324
 
325
+ logger.info(f"[CITATION] Processed {len(citations_found)} citations, {len(url_mapping)} URL mappings available")
326
  return processed_response
327
 
328
  # ✅ Initialize Chatbot
api/retrieval.py ADDED
@@ -0,0 +1,100 @@
1
+ # api/retrieval.py
2
+ import numpy as np
3
+ import logging
4
+ from api.database import db_manager
5
+
6
+ logger = logging.getLogger("medical-chatbot")
7
+
8
+ class RetrievalEngine:
9
+ def __init__(self):
10
+ self.db_manager = db_manager
11
+
12
+ def retrieve_medical_info(self, query: str, k: int = 5, min_sim: float = 0.9) -> list:
13
+ """
14
+ Retrieve medical information from FAISS index
15
+ Minimum similarity between the query and KB entries defaults to min_sim (0.9)
16
+ """
17
+ index = self.db_manager.load_faiss_index()
18
+ if index is None:
19
+ return [""]
20
+
21
+ embedding_model = self.db_manager.get_embedding_model()
22
+ qa_collection = self.db_manager.get_qa_collection()
23
+
24
+ # Embed query
25
+ query_vec = embedding_model.encode([query], convert_to_numpy=True)
26
+ D, I = index.search(query_vec, k=k)
27
+
28
+ # Filter by cosine threshold
29
+ results = []
30
+ kept = []
31
+ kept_vecs = []
32
+
33
+ # Smart dedup on cosine threshold between similar candidates
34
+ for score, idx in zip(D[0], I[0]):
35
+ if score < min_sim:
36
+ continue
37
+
38
+ # List sim docs
39
+ doc = qa_collection.find_one({"i": int(idx)})
40
+ if not doc:
41
+ continue
42
+
43
+ # Only compare answers
44
+ answer = doc.get("Doctor", "").strip()
45
+ if not answer:
46
+ continue
47
+
48
+ # Check semantic redundancy among previously kept results
49
+ new_vec = embedding_model.encode([answer], convert_to_numpy=True)[0]
50
+ is_similar = False
51
+
52
+ for i, vec in enumerate(kept_vecs):
53
+ sim = np.dot(vec, new_vec) / (np.linalg.norm(vec) * np.linalg.norm(new_vec) + 1e-9)
54
+ if sim >= 0.9: # High semantic similarity
55
+ is_similar = True
56
+ # Keep only better match to original query
57
+ cur_sim_to_query = np.dot(vec, query_vec[0]) / (np.linalg.norm(vec) * np.linalg.norm(query_vec[0]) + 1e-9)
58
+ new_sim_to_query = np.dot(new_vec, query_vec[0]) / (np.linalg.norm(new_vec) * np.linalg.norm(query_vec[0]) + 1e-9)
59
+ if new_sim_to_query > cur_sim_to_query:
60
+ kept[i] = answer
61
+ kept_vecs[i] = new_vec
62
+ break
63
+
64
+ # Non-similar candidates
65
+ if not is_similar:
66
+ kept.append(answer)
67
+ kept_vecs.append(new_vec)
68
+
69
+ return kept if kept else [""]
70
+
71
+ def retrieve_diagnosis_from_symptoms(self, symptom_text: str, top_k: int = 5, min_sim: float = 0.5) -> list:
72
+ """
73
+ Retrieve diagnosis information from symptom vectors
74
+ """
75
+ self.db_manager.load_symptom_vectors()
76
+ embedding_model = self.db_manager.get_embedding_model()
77
+
78
+ # Embed input
79
+ qvec = embedding_model.encode(symptom_text, convert_to_numpy=True)
80
+ qvec = qvec / (np.linalg.norm(qvec) + 1e-9)
81
+
82
+ # Similarity compute
83
+ sims = self.db_manager.symptom_vectors @ qvec # cosine
84
+ sorted_idx = np.argsort(sims)[-top_k:][::-1]
85
+ seen_diag = set()
86
+ final = [] # Dedup
87
+
88
+ for i in sorted_idx:
89
+ sim = sims[i]
90
+ if sim < min_sim:
91
+ continue
92
+ label = self.db_manager.symptom_docs[i]["prognosis"]
93
+ if label not in seen_diag:
94
+ final.append(self.db_manager.symptom_docs[i]["answer"])
95
+ seen_diag.add(label)
96
+
97
+ return final
98
+
99
+ # Global retrieval engine instance
100
+ retrieval_engine = RetrievalEngine()
api/routes.py ADDED
@@ -0,0 +1,63 @@
1
+ # api/routes.py
2
+ import time
3
+ import logging
4
+ from fastapi import APIRouter, Request
5
+ from fastapi.responses import JSONResponse
6
+ from api.chatbot import RAGMedicalChatbot
7
+ from api.retrieval import retrieval_engine
8
+ from utils import process_medical_image
9
+
10
+ logger = logging.getLogger("medical-chatbot")
11
+
12
+ # Create router
13
+ router = APIRouter()
14
+
15
+ # Initialize chatbot
16
+ chatbot = RAGMedicalChatbot(
17
+ model_name="gemini-2.5-flash",
18
+ retrieve_function=retrieval_engine.retrieve_medical_info
19
+ )
20
+
21
+ @router.post("/chat")
22
+ async def chat_endpoint(req: Request):
23
+ """Main chat endpoint with search mode support"""
24
+ body = await req.json()
25
+ user_id = body.get("user_id", "anonymous")
26
+ query_raw = body.get("query")
27
+ query = query_raw.strip() if isinstance(query_raw, str) else ""
28
+ lang = body.get("lang", "EN")
29
+ search_mode = body.get("search", False)
30
+ image_base64 = body.get("image_base64", None)
31
+ img_desc = body.get("img_desc", "Describe and investigate any clinical findings from this medical image.")
32
+
33
+ start = time.time()
34
+ image_diagnosis = ""
35
+
36
+ # LLM Only
37
+ if not image_base64:
38
+ logger.info(f"[BOT] LLM scenario. Search mode: {search_mode}")
39
+ # LLM+VLM
40
+ else:
41
+ # If image is present → diagnose first
42
+ safe_load = len(image_base64.encode("utf-8"))
43
+ if safe_load > 5_000_000: # Img size safe processor
44
+ return JSONResponse({"response": "⚠️ Image too large. Please upload smaller images (<5MB)."})
45
+ logger.info(f"[BOT] VLM+LLM scenario. Search mode: {search_mode}")
46
+ logger.info(f"[VLM] Process medical image size: {safe_load}, desc: {img_desc}, {lang}.")
47
+ image_diagnosis = process_medical_image(image_base64, img_desc, lang)
48
+
49
+ answer = chatbot.chat(user_id, query, lang, image_diagnosis, search_mode)
50
+ elapsed = time.time() - start
51
+
52
+ # Final
53
+ return JSONResponse({"response": f"{answer}\n\n(Response time: {elapsed:.2f}s)"})
54
+
55
+ @router.get("/health")
56
+ async def health_check():
57
+ """Health check endpoint"""
58
+ return {"status": "healthy", "service": "medical-chatbot"}
59
+
60
+ @router.get("/")
61
+ async def root():
62
+ """Root endpoint"""
63
+ return {"message": "Medical Chatbot API", "version": "1.0.0"}
main.py ADDED
@@ -0,0 +1,6 @@
1
+ # main.py - Entry point for the Medical Chatbot API
2
+ from api.app import app
3
+
4
+ if __name__ == "__main__":
5
+ import uvicorn
6
+ uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
memory/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # Memory package
2
+ from .memory import MemoryManager
memory.py → memory/memory.py RENAMED
@@ -1,4 +1,4 @@
1
- # memory.py
2
  import re, time, hashlib, asyncio, os
3
  from collections import defaultdict, deque
4
  from typing import List, Dict
@@ -7,6 +7,7 @@ import faiss
7
  from sentence_transformers import SentenceTransformer
8
  from google import genai # must be configured in app.py and imported globally
9
  import logging
 
10
 
11
  _LLM_SMALL = "gemini-2.5-flash-lite-preview-06-17"
12
  # Load embedding model
@@ -98,7 +99,7 @@ class MemoryManager:
98
 
99
  def get_contextual_chunks(self, user_id: str, current_query: str, lang: str = "EN") -> str:
100
  """
101
- Use Gemini Flash Lite to create a summarization of relevant context from both recent history and RAG chunks.
102
  This ensures conversational continuity while providing a concise summary for the main LLM.
103
  """
104
  # Get both types of context
@@ -112,7 +113,8 @@ class MemoryManager:
112
  if not recent_history and not rag_chunks:
113
  logger.info(f"[Contextual] No context found, returning empty string")
114
  return ""
115
- # Prepare context for Gemini to summarize
 
116
  context_parts = []
117
  # Add recent chat history
118
  if recent_history:
@@ -126,301 +128,204 @@ class MemoryManager:
126
  rag_text = "\n".join(rag_chunks)
127
  context_parts.append(f"Semantically relevant historical medical information:\n{rag_text}")
128
 
129
- # Build summarization prompt
130
- summarization_prompt = f"""
131
- You are a medical assistant creating a concise summary of conversation context for continuity.
132
-
133
- Current user query: "{current_query}"
134
-
135
- Available context information:
136
- {chr(10).join(context_parts)}
137
-
138
- Task: Create a brief, coherent summary that captures the key points from the conversation history and relevant medical information that are important for understanding the current query.
139
-
140
- Guidelines:
141
- 1. Focus on medical symptoms, diagnoses, treatments, or recommendations mentioned
142
- 2. Include any patient concerns or questions that are still relevant
143
- 3. Highlight any follow-up needs or pending clarifications
144
- 4. Keep the summary concise but comprehensive enough for context
145
- 5. Maintain conversational flow and continuity
146
 
147
- Output: Provide a single, well-structured summary paragraph that can be used as context for the main LLM to provide a coherent response.
148
- If no relevant context exists, return "No relevant context found."
149
-
150
- Language context: {lang}
 
 
 
 
 
 
 
 
 
151
  """
 
 
152
 
153
- logger.debug(f"[Contextual] Full prompt: {summarization_prompt}")
154
- # Loop through the prompt and log the length of each part
155
  try:
156
- # Use Gemini Flash Lite for summarization
157
- client = genai.Client(api_key=os.getenv("FlashAPI"))
158
- result = client.models.generate_content(
159
- model=_LLM_SMALL,
160
- contents=summarization_prompt
161
- )
162
- summary = result.text.strip()
163
- if "No relevant context found" in summary:
164
- logger.info(f"[Contextual] Gemini indicated no relevant context found")
165
- return ""
166
 
167
- logger.info(f"[Contextual] Gemini created summary: {summary[:100]}...")
168
- return summary
 
 
 
 
 
 
 
 
 
 
 
169
 
170
  except Exception as e:
171
- logger.warning(f"[Contextual] Gemini summarization failed: {e}")
172
- logger.info(f"[Contextual] Using fallback summarization method")
173
- # Fallback: create a simple summary
174
- fallback_summary = []
175
- # Fallback: add recent history
176
- if recent_history:
177
- recent_summary = f"Recent conversation: User asked about {recent_history[-1]['user'][:50]}... and received a response about {recent_history[-1]['bot'][:50]}..."
178
- fallback_summary.append(recent_summary)
179
- logger.info(f"[Contextual] Fallback: Added recent history summary")
180
- # Fallback: add RAG chunks
181
- if rag_chunks:
182
- rag_summary = f"Relevant medical information: {len(rag_chunks)} chunks found covering various medical topics."
183
- fallback_summary.append(rag_summary)
184
- logger.info(f"[Contextual] Fallback: Added RAG chunks summary")
185
- final_fallback = " ".join(fallback_summary) if fallback_summary else ""
186
- return final_fallback
187
 
188
- def reset(self, user_id: str):
189
- self._drop_user(user_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
 
191
- # ---------- Internal helpers ----------
192
  def _touch_user(self, user_id: str):
193
- if user_id not in self.text_cache and len(self.user_queue) >= self.user_queue.maxlen:
194
- self._drop_user(self.user_queue.popleft())
195
  if user_id in self.user_queue:
196
  self.user_queue.remove(user_id)
197
  self.user_queue.append(user_id)
198
 
199
- def _drop_user(self, user_id: str):
200
- self.text_cache.pop(user_id, None)
201
- self.chunk_index.pop(user_id, None)
202
- self.chunk_meta.pop(user_id, None)
203
- if user_id in self.user_queue:
204
- self.user_queue.remove(user_id)
205
-
206
- def _rebuild_index(self, user_id: str, keep_last: int):
207
- """Trim chunk list + rebuild FAISS index for user."""
208
- self.chunk_meta[user_id] = self.chunk_meta[user_id][-keep_last:]
209
- index = self._new_index()
210
- # Store each chunk's vector once and reuse it.
211
- for chunk in self.chunk_meta[user_id]:
212
- index.add(np.array([chunk["vec"]]))
213
- self.chunk_index[user_id] = index
214
-
215
- @staticmethod
216
- def _new_index():
217
- # Use cosine similarity (vectors must be L2-normalised)
218
- return faiss.IndexFlatIP(384)
219
-
220
- @staticmethod
221
- def _embed(text: str):
222
- vec = EMBED.encode(text, convert_to_numpy=True)
223
- # L2 normalise for cosine on IndexFlatIP
224
- return vec / (np.linalg.norm(vec) + 1e-9)
225
-
226
- def chunk_response(self, response: str, lang: str, question: str = "") -> List[Dict]:
227
- """
228
- Calls Gemini to:
229
- - Translate (if needed)
230
- - Chunk by context/topic (exclude disclaimer section)
231
- - Summarise
232
- Returns: [{"tag": ..., "text": ...}, ...]
233
- """
234
- if not response: return []
235
- # Gemini instruction
236
- instructions = []
237
- # if lang.upper() != "EN":
238
- # instructions.append("- Translate the response to English.")
239
- instructions.append("- Break the translated (or original) text into semantically distinct parts, grouped by medical topic, symptom, assessment, plan, or instruction (exclude disclaimer section).")
240
- instructions.append("- For each part, generate a clear, concise summary. The summary may vary in length depending on the complexity of the topic — do not omit key clinical instructions and exact medication names/doses if present.")
241
- instructions.append("- At the start of each part, write `Topic: <concise but specific sentence (10-20 words) capturing patient context, condition, and action>`.")
242
- instructions.append("- Separate each part using three dashes `---` on a new line.")
243
- # if lang.upper() != "EN":
244
- # instructions.append(f"Below is the user-provided medical response written in `{lang}`")
245
- # Gemini prompt
246
- prompt = f"""
247
- You are a medical assistant helping organize and condense a clinical response.
248
- If helpful, use the user's latest question for context to craft specific topics.
249
- User's latest question (context): {question}
250
- ------------------------
251
- {response}
252
- ------------------------
253
- Please perform the following tasks:
254
- {chr(10).join(instructions)}
255
 
256
- Output only the structured summaries, separated by dashes.
257
- """
258
- retries = 0
259
- while retries < 5:
260
- try:
261
- client = genai.Client(api_key=os.getenv("FlashAPI"))
262
- result = client.models.generate_content(
263
- model=_LLM_SMALL,
264
- contents=prompt
265
- # ,generation_config={"temperature": 0.4} # Skip temp configs for gem-flash
 
 
266
  )
267
- output = result.text.strip()
268
- logger.info(f"[Memory] 📦 Gemini summarized chunk output: {output}")
269
- return [
270
- {"tag": self._quick_extract_topic(chunk), "text": chunk.strip()}
271
- for chunk in output.split('---') if chunk.strip()
272
- ]
273
- except Exception as e:
274
- logger.warning(f"[Memory] ❌ Gemini chunking failed: {e}")
275
- retries += 1
276
- time.sleep(0.5)
277
- return [{"tag": "general", "text": response.strip()}] # fallback
278
 
279
- @staticmethod
280
- def _quick_extract_topic(chunk: str) -> str:
281
- """Heuristically extract the topic from a chunk (title line or first 3 words)."""
282
- # Expecting 'Topic: <something>'
283
- match = re.search(r'^Topic:\s*(.+)', chunk, re.IGNORECASE | re.MULTILINE)
284
- if match:
285
- return match.group(1).strip()
286
- lines = chunk.strip().splitlines()
287
- for line in lines:
288
- if len(line.split()) <= 8 and line.strip().endswith(":"):
289
- return line.strip().rstrip(":")
290
- return " ".join(chunk.split()[:3]).rstrip(":.,")
291
-
292
- # ---------- New merging/dedup logic ----------
293
- def _upsert_stm(self, user_id: str, chunk: Dict, lang: str):
294
- """Insert or merge a summarized chunk into STM with semantic dedup/merge.
295
- Identical: replace the older with new. Partially similar: merge extra details from older into newer.
296
- """
297
- topic = self._enrich_topic(chunk.get("tag", ""), chunk.get("text", ""))
298
- text = chunk.get("text", "").strip()
299
- vec = self._embed(text)
300
- now = time.time()
301
- entry = {"topic": topic, "text": text, "vec": vec, "timestamp": now, "used": 0}
302
- stm = self.stm_summaries[user_id]
303
- if not stm:
304
- stm.append(entry)
305
- return
306
- # find best match
307
- best_idx = -1
308
- best_sim = -1.0
309
- for i, e in enumerate(stm):
310
- sim = float(np.dot(vec, e["vec"]))
311
- if sim > best_sim:
312
- best_sim = sim
313
- best_idx = i
314
- if best_sim >= 0.92: # nearly identical
315
- # replace older with current
316
- stm.rotate(-best_idx)
317
- stm.popleft()
318
- stm.rotate(best_idx)
319
- stm.append(entry)
320
- elif best_sim >= 0.75: # partially similar → merge
321
- base = stm[best_idx]
322
- merged_text = self._merge_texts(new_text=text, old_text=base["text"]) # add bits from old not in new
323
- merged_topic = base["topic"] if len(base["topic"]) > len(topic) else topic
324
- merged_vec = self._embed(merged_text)
325
- merged_entry = {"topic": merged_topic, "text": merged_text, "vec": merged_vec, "timestamp": now, "used": base.get("used", 0)}
326
- stm.rotate(-best_idx)
327
- stm.popleft()
328
- stm.rotate(best_idx)
329
- stm.append(merged_entry)
330
- else:
331
- stm.append(entry)
332
 
333
  def _upsert_ltm(self, user_id: str, chunks: List[Dict], lang: str):
334
- """Insert or merge chunks into LTM with semantic dedup/merge, then rebuild index.
335
- Keeps only the most recent self.max_chunks entries.
336
- """
337
- current_list = self.chunk_meta[user_id]
338
  for chunk in chunks:
339
- text = chunk.get("text", "").strip()
340
- if not text:
341
- continue
342
- vec = self._embed(text)
343
- topic = self._enrich_topic(chunk.get("tag", ""), text)
344
- now = time.time()
345
- new_entry = {"tag": topic, "text": text, "vec": vec, "timestamp": now, "used": 0}
346
- if not current_list:
347
- current_list.append(new_entry)
348
- continue
349
- # find best similar entry
350
- best_idx = -1
351
- best_sim = -1.0
352
- for i, e in enumerate(current_list):
353
- sim = float(np.dot(vec, e["vec"]))
354
- if sim > best_sim:
355
- best_sim = sim
356
- best_idx = i
357
- if best_sim >= 0.92:
358
- # replace older with new
359
- current_list[best_idx] = new_entry
360
- elif best_sim >= 0.75:
361
- # merge details
362
- base = current_list[best_idx]
363
- merged_text = self._merge_texts(new_text=text, old_text=base["text"]) # add unique sentences from old
364
- merged_topic = base["tag"] if len(base["tag"]) > len(topic) else topic
365
- merged_vec = self._embed(merged_text)
366
- current_list[best_idx] = {"tag": merged_topic, "text": merged_text, "vec": merged_vec, "timestamp": now, "used": base.get("used", 0)}
367
  else:
368
- current_list.append(new_entry)
369
- # Trim and rebuild index
370
- if len(current_list) > self.max_chunks:
371
- current_list[:] = current_list[-self.max_chunks:]
372
- self._rebuild_index(user_id, keep_last=self.max_chunks)
373
-
374
- @staticmethod
375
- def _split_sentences(text: str) -> List[str]:
376
- # naive sentence splitter by ., !, ?
377
- parts = re.split(r"(?<=[\.!?])\s+", text.strip())
378
- return [p.strip() for p in parts if p.strip()]
 
 
 
379
 
380
- def _merge_texts(self, new_text: str, old_text: str) -> str:
381
- """Append sentences from old_text that are not already contained in new_text (by fuzzy match)."""
382
- new_sents = self._split_sentences(new_text)
383
- old_sents = self._split_sentences(old_text)
384
- new_set = set(s.lower() for s in new_sents)
385
- merged = list(new_sents)
386
- for s in old_sents:
387
- s_norm = s.lower()
388
- # consider present if significant overlap with any existing sentence
389
- if s_norm in new_set:
390
- continue
391
- # simple containment check
392
- if any(self._overlap_ratio(s_norm, t.lower()) > 0.8 for t in merged):
393
- continue
394
- merged.append(s)
395
- return " ".join(merged)
396
 
397
- @staticmethod
398
- def _overlap_ratio(a: str, b: str) -> float:
399
- """Compute token overlap ratio between two sentences."""
400
- ta = set(re.findall(r"\w+", a))
401
- tb = set(re.findall(r"\w+", b))
402
- if not ta or not tb:
403
- return 0.0
404
- inter = len(ta & tb)
405
- union = len(ta | tb)
406
- return inter / union
 
 
407
 
408
- @staticmethod
409
- def _enrich_topic(topic: str, text: str) -> str:
410
- """Make topic more descriptive if it's too short by using the first sentence of the text.
411
- Does not call LLM to keep latency low.
412
- """
413
- topic = (topic or "").strip()
414
- if len(topic.split()) < 5 or len(topic) < 20:
415
- sents = re.split(r"(?<=[\.!?])\s+", text.strip())
416
- if sents:
417
- first = sents[0]
418
- # cap to ~16 words
419
- words = first.split()
420
- if len(words) > 16:
421
- first = " ".join(words[:16])
422
- # ensure capitalized
423
- return first.strip().rstrip(':')
424
- return topic
425
 
 
 
 
 
 
 
 
 
 
426
 
 
 
 
 
 
 
1
+ # memory/memory.py
2
  import re, time, hashlib, asyncio, os
3
  from collections import defaultdict, deque
4
  from typing import List, Dict
 
7
  from sentence_transformers import SentenceTransformer
8
  from google import genai # must be configured in app.py and imported globally
9
  import logging
10
+ from summarizer import summarizer
11
 
12
  _LLM_SMALL = "gemini-2.5-flash-lite-preview-06-17"
13
  # Load embedding model
 
99
 
100
  def get_contextual_chunks(self, user_id: str, current_query: str, lang: str = "EN") -> str:
101
  """
102
+ Use NVIDIA Llama to create a summarization of relevant context from both recent history and RAG chunks.
103
  This ensures conversational continuity while providing a concise summary for the main LLM.
104
  """
105
  # Get both types of context
 
113
  if not recent_history and not rag_chunks:
114
  logger.info(f"[Contextual] No context found, returning empty string")
115
  return ""
116
+
117
+ # Prepare context for summarization
118
  context_parts = []
119
  # Add recent chat history
120
  if recent_history:
 
128
  rag_text = "\n".join(rag_chunks)
129
  context_parts.append(f"Semantically relevant historical medical information:\n{rag_text}")
130
 
131
+ # Combine all context
132
+ full_context = "\n\n".join(context_parts)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
+ # Use summarizer to create concise summary
135
+ try:
136
+ summary = summarizer.summarize_text(full_context, max_length=300)
137
+ logger.info(f"[Contextual] Generated summary using NVIDIA Llama: {len(summary)} characters")
138
+ return summary
139
+ except Exception as e:
140
+ logger.error(f"[Contextual] Summarization failed: {e}")
141
+ return full_context[:500] + "..." if len(full_context) > 500 else full_context
142
+
143
+ def chunk_response(self, response: str, lang: str, question: str = "") -> List[Dict]:
144
+ """
145
+ Use NVIDIA Llama to chunk and summarize response by medical topics.
146
+ Returns: [{"tag": ..., "text": ...}, ...]
147
  """
148
+ if not response:
149
+ return []
150
 
 
 
151
  try:
152
+ # Use summarizer to chunk and summarize
153
+ chunks = summarizer.chunk_response(response, max_chunk_size=500)
 
 
 
 
 
 
 
 
154
 
155
+ # Convert to the expected format
156
+ result_chunks = []
157
+ for i, chunk in enumerate(chunks):
158
+ # Extract topic from chunk (first sentence or key medical terms)
159
+ topic = self._extract_topic_from_chunk(chunk)
160
+
161
+ result_chunks.append({
162
+ "tag": topic,
163
+ "text": chunk
164
+ })
165
+
166
+ logger.info(f"[Memory] 📦 NVIDIA Llama summarized {len(result_chunks)} chunks")
167
+ return result_chunks
168
 
169
  except Exception as e:
170
+ logger.error(f"[Memory] NVIDIA Llama chunking failed: {e}")
171
+ # Fallback to simple chunking
172
+ return self._fallback_chunking(response)
173
+
174
+ def _extract_topic_from_chunk(self, chunk: str) -> str:
175
+ """Extract a concise topic from a chunk"""
176
+ # Look for medical terms or first sentence
177
+ sentences = chunk.split('.')
178
+ if sentences:
179
+ first_sentence = sentences[0].strip()
180
+ if len(first_sentence) > 50:
181
+ first_sentence = first_sentence[:50] + "..."
182
+ return first_sentence
183
+ return "Medical Information"
 
 
184
 
185
+ def _fallback_chunking(self, response: str) -> List[Dict]:
186
+ """Fallback chunking when NVIDIA Llama fails"""
187
+ # Simple sentence-based chunking
188
+ sentences = re.split(r'[.!?]+', response)
189
+ chunks = []
190
+ current_chunk = ""
191
+
192
+ for sentence in sentences:
193
+ sentence = sentence.strip()
194
+ if not sentence:
195
+ continue
196
+
197
+ if len(current_chunk) + len(sentence) > 300:
198
+ if current_chunk:
199
+ chunks.append({
200
+ "tag": "Medical Information",
201
+ "text": current_chunk.strip()
202
+ })
203
+ current_chunk = sentence
204
+ else:
205
+ current_chunk += sentence + ". "
206
+
207
+ if current_chunk:
208
+ chunks.append({
209
+ "tag": "Medical Information",
210
+ "text": current_chunk.strip()
211
+ })
212
+
213
+ return chunks
214
 
215
+ # ---------- Private Methods ----------
216
  def _touch_user(self, user_id: str):
217
+ """Update LRU queue"""
 
218
  if user_id in self.user_queue:
219
  self.user_queue.remove(user_id)
220
  self.user_queue.append(user_id)
221
 
222
+ def _new_index(self):
223
+ """Create new FAISS index"""
224
+ return faiss.IndexFlatIP(384) # 384-dim embeddings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
 
226
+ def _upsert_stm(self, user_id: str, chunk: Dict, lang: str):
227
+ """Update short-term memory with merging/deduplication"""
228
+ topic = chunk["tag"]
229
+ text = chunk["text"]
230
+
231
+ # Check for similar topics in STM
232
+ for entry in self.stm_summaries[user_id]:
233
+ if self._topics_similar(topic, entry["topic"]):
234
+ # Merge with existing entry
235
+ entry["text"] = summarizer.summarize_text(
236
+ f"{entry['text']}\n{text}",
237
+ max_length=200
238
  )
239
+ entry["timestamp"] = time.time()
240
+ return
 
 
 
 
 
 
 
 
 
241
 
242
+ # Add new entry
243
+ self.stm_summaries[user_id].append({
244
+ "topic": topic,
245
+ "text": text,
246
+ "vec": self._embed(f"{topic} {text}"),
247
+ "timestamp": time.time(),
248
+ "used": 0
249
+ })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
 
251
  def _upsert_ltm(self, user_id: str, chunks: List[Dict], lang: str):
252
+ """Update long-term memory with merging/deduplication"""
 
 
 
253
  for chunk in chunks:
254
+ # Check for similar chunks in LTM
255
+ similar_idx = self._find_similar_chunk(user_id, chunk["text"])
256
+
257
+ if similar_idx is not None:
258
+ # Merge with existing chunk
259
+ existing = self.chunk_meta[user_id][similar_idx]
260
+ merged_text = summarizer.summarize_text(
261
+ f"{existing['text']}\n{chunk['text']}",
262
+ max_length=300
263
+ )
264
+ existing["text"] = merged_text
265
+ existing["timestamp"] = time.time()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  else:
267
+ # Add new chunk
268
+ if len(self.chunk_meta[user_id]) >= self.max_chunks:
269
+ # Remove oldest chunk
270
+ self._remove_oldest_chunk(user_id)
271
+
272
+ vec = self._embed(chunk["text"])
273
+ self.chunk_index[user_id].add(np.array([vec]))
274
+ self.chunk_meta[user_id].append({
275
+ "text": chunk["text"],
276
+ "tag": chunk["tag"],
277
+ "vec": vec,
278
+ "timestamp": time.time(),
279
+ "used": 0
280
+ })
281
 
282
+ def _topics_similar(self, topic1: str, topic2: str) -> bool:
283
+ """Check if two topics are similar"""
284
+ # Simple similarity check based on common words
285
+ words1 = set(topic1.lower().split())
286
+ words2 = set(topic2.lower().split())
287
+ intersection = words1.intersection(words2)
288
+ return len(intersection) >= 2
 
 
 
 
 
 
 
 
 
289
 
290
+ def _find_similar_chunk(self, user_id: str, text: str) -> int:
291
+ """Find similar chunk in LTM"""
292
+ if not self.chunk_meta[user_id]:
293
+ return None
294
+
295
+ text_vec = self._embed(text)
296
+ sims, idxs = self.chunk_index[user_id].search(np.array([text_vec]), k=3)
297
+
298
+ for sim, idx in zip(sims[0], idxs[0]):
299
+ if sim > 0.8: # High similarity threshold
300
+ return int(idx)
301
+ return None
302
 
303
+ def _remove_oldest_chunk(self, user_id: str):
304
+ """Remove the oldest chunk from LTM"""
305
+ if not self.chunk_meta[user_id]:
306
+ return
307
+
308
+ # Find oldest chunk
309
+ oldest_idx = min(range(len(self.chunk_meta[user_id])),
310
+ key=lambda i: self.chunk_meta[user_id][i]["timestamp"])
311
+
312
+ # Remove from index and metadata
313
+ self.chunk_meta[user_id].pop(oldest_idx)
314
+ # Note: FAISS doesn't support direct removal, so we rebuild the index
315
+ self._rebuild_index(user_id)
 
 
 
 
316
 
317
+ def _rebuild_index(self, user_id: str):
318
+ """Rebuild FAISS index after removal"""
319
+ if not self.chunk_meta[user_id]:
320
+ self.chunk_index[user_id] = self._new_index()
321
+ return
322
+
323
+ vectors = [chunk["vec"] for chunk in self.chunk_meta[user_id]]
324
+ self.chunk_index[user_id] = self._new_index()
325
+ self.chunk_index[user_id].add(np.array(vectors))
326
 
327
+ @staticmethod
328
+ def _embed(text: str):
329
+ vec = EMBED.encode(text, convert_to_numpy=True)
330
+ # L2 normalise for cosine on IndexFlatIP
331
+ return vec / (np.linalg.norm(vec) + 1e-9)
models/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ # Models package
2
+ from .llama import NVIDIALLamaClient, process_search_query
3
+ from .summarizer import TextSummarizer, summarizer
download_model.py → models/download_model.py RENAMED
File without changes
llama_integration.py → models/llama.py RENAMED
@@ -2,6 +2,7 @@ import os
  import requests
  import json
  import logging
+ import time
  from typing import List, Dict, Tuple
 
  logger = logging.getLogger(__name__)
@@ -40,27 +41,11 @@ Keywords:"""
      def summarize_documents(self, documents: List[Dict], user_query: str) -> Tuple[str, Dict[int, str]]:
          """Use Llama to summarize documents and return summary with URL mapping"""
          try:
-             # Create document summaries
-             doc_summaries = []
-             url_mapping = {}
-
-             for doc in documents:
-                 doc_id = doc['id']
-                 url_mapping[doc_id] = doc['url']
-
-                 # Create a summary prompt for each document
-                 summary_prompt = f"""Summarize this medical information in 2-3 sentences, focusing on details relevant to: "{user_query}"
-
- Document: {doc['title']}
- Content: {doc['content'][:1000]}...
-
- Summary:"""
-
-                 summary = self._call_llama(summary_prompt)
-                 doc_summaries.append(f"Document {doc_id}: {summary}")
-
-             # Combine all summaries
-             combined_summary = "\n\n".join(doc_summaries)
+             # Import summarizer here to avoid circular imports
+             from .summarizer import summarizer
+
+             # Use the summarizer for document summarization
+             combined_summary, url_mapping = summarizer.summarize_documents(documents, user_query)
 
              return combined_summary, url_mapping
 
@@ -68,41 +53,58 @@ Summary:"""
              logger.error(f"Failed to summarize documents: {e}")
              return "", {}
 
-     def _call_llama(self, prompt: str) -> str:
-         """Make API call to NVIDIA Llama model"""
-         try:
-             headers = {
-                 "Authorization": f"Bearer {self.api_key}",
-                 "Content-Type": "application/json"
-             }
-
-             payload = {
-                 "model": self.model,
-                 "messages": [
-                     {
-                         "role": "user",
-                         "content": prompt
-                     }
-                 ],
-                 "temperature": 0.7,
-                 "max_tokens": 1000
-             }
-
-             response = requests.post(
-                 self.base_url,
-                 headers=headers,
-                 json=payload,
-                 timeout=30
-             )
-
-             response.raise_for_status()
-             result = response.json()
-
-             return result['choices'][0]['message']['content'].strip()
-
-         except Exception as e:
-             logger.error(f"Llama API call failed: {e}")
-             raise
+     def _call_llama(self, prompt: str, max_retries: int = 3) -> str:
+         """Make API call to NVIDIA Llama model with retry logic"""
+         for attempt in range(max_retries):
+             try:
+                 headers = {
+                     "Authorization": f"Bearer {self.api_key}",
+                     "Content-Type": "application/json"
+                 }
+
+                 payload = {
+                     "model": self.model,
+                     "messages": [
+                         {
+                             "role": "user",
+                             "content": prompt
+                         }
+                     ],
+                     "temperature": 0.7,
+                     "max_tokens": 1000
+                 }
+
+                 response = requests.post(
+                     self.base_url,
+                     headers=headers,
+                     json=payload,
+                     timeout=30
+                 )
+
+                 response.raise_for_status()
+                 result = response.json()
+
+                 content = result['choices'][0]['message']['content'].strip()
+                 if not content:
+                     raise ValueError("Empty response from Llama API")
+
+                 return content
+
+             except requests.exceptions.Timeout:
+                 logger.warning(f"Llama API timeout (attempt {attempt + 1}/{max_retries})")
+                 if attempt == max_retries - 1:
+                     raise
+                 time.sleep(2 ** attempt)  # Exponential backoff
+
+             except requests.exceptions.RequestException as e:
+                 logger.warning(f"Llama API request failed (attempt {attempt + 1}/{max_retries}): {e}")
+                 if attempt == max_retries - 1:
+                     raise
+                 time.sleep(2 ** attempt)
+
+             except Exception as e:
+                 logger.error(f"Llama API call failed: {e}")
+                 raise
 
  def process_search_query(user_query: str, search_results: List[Dict]) -> Tuple[str, Dict[int, str]]:
      """Process search results using Llama model"""
models/summarizer.py ADDED
@@ -0,0 +1,185 @@
+ import re
+ import logging
+ from typing import List, Dict, Tuple
+ from .llama import NVIDIALLamaClient
+
+ logger = logging.getLogger(__name__)
+
+ class TextSummarizer:
+     def __init__(self):
+         self.llama_client = NVIDIALLamaClient()
+
+     def clean_text(self, text: str) -> str:
+         """Clean and normalize text for summarization"""
+         if not text:
+             return ""
+
+         # Remove common conversation starters and fillers
+         conversation_patterns = [
+             r'\b(hi|hello|hey|sure|okay|yes|no|thanks|thank you)\b',
+             r'\b(here is|this is|let me|i will|i can|i would)\b',
+             r'\b(summarize|summary|here\'s|here is)\b',
+             r'\b(please|kindly|would you|could you)\b',
+             r'\b(um|uh|er|ah|well|so|like|you know)\b'
+         ]
+
+         # Remove excessive whitespace and normalize
+         text = re.sub(r'\s+', ' ', text)
+         text = re.sub(r'\n+', ' ', text)
+
+         # Remove conversation patterns
+         for pattern in conversation_patterns:
+             text = re.sub(pattern, '', text, flags=re.IGNORECASE)
+
+         # Remove extra punctuation and normalize
+         text = re.sub(r'[.]{2,}', '.', text)
+         text = re.sub(r'[!]{2,}', '!', text)
+         text = re.sub(r'[?]{2,}', '?', text)
+
+         return text.strip()
+
+     def extract_key_phrases(self, text: str) -> List[str]:
+         """Extract key medical phrases and terms"""
+         if not text:
+             return []
+
+         # Medical term patterns
+         medical_patterns = [
+             r'\b(?:symptoms?|diagnosis|treatment|therapy|medication|drug|disease|condition|syndrome)\b',
+             r'\b(?:patient|doctor|physician|medical|clinical|healthcare)\b',
+             r'\b(?:blood pressure|heart rate|temperature|pulse|respiration)\b',
+             r'\b(?:acute|chronic|severe|mild|moderate|serious|critical)\b',
+             r'\b(?:pain|ache|discomfort|swelling|inflammation|infection)\b'
+         ]
+
+         key_phrases = []
+         for pattern in medical_patterns:
+             matches = re.findall(pattern, text, re.IGNORECASE)
+             key_phrases.extend(matches)
+
+         return list(set(key_phrases))  # Remove duplicates
+
+     def summarize_text(self, text: str, max_length: int = 200) -> str:
+         """Summarize text using NVIDIA Llama model"""
+         try:
+             if not text or len(text.strip()) < 50:
+                 return text
+
+             # Clean the text first
+             cleaned_text = self.clean_text(text)
+
+             # Extract key phrases for context
+             key_phrases = self.extract_key_phrases(cleaned_text)
+             key_phrases_str = ", ".join(key_phrases[:5]) if key_phrases else "medical information"
+
+             # Create optimized prompt
+             prompt = f"""Summarize this medical text in {max_length} characters or less. Focus only on key medical facts, symptoms, treatments, and diagnoses. Do not include greetings, confirmations, or conversational elements.
+
+ Key terms: {key_phrases_str}
+
+ Text: {cleaned_text[:1500]}
+
+ Summary:"""
+
+             summary = self.llama_client._call_llama(prompt)
+
+             # Post-process summary
+             summary = self.clean_text(summary)
+
+             # Ensure it's within length limit
+             if len(summary) > max_length:
+                 summary = summary[:max_length-3] + "..."
+
+             return summary
+
+         except Exception as e:
+             logger.error(f"Summarization failed: {e}")
+             # Fallback to simple truncation
+             return self.clean_text(text)[:max_length]
+
+     def summarize_documents(self, documents: List[Dict], user_query: str) -> Tuple[str, Dict[int, str]]:
+         """Summarize multiple documents with URL mapping"""
+         try:
+             doc_summaries = []
+             url_mapping = {}
+
+             for doc in documents:
+                 doc_id = doc['id']
+                 url_mapping[doc_id] = doc['url']
+
+                 # Create focused summary for each document
+                 summary_prompt = f"""Summarize this medical document in 2-3 sentences, focusing on information relevant to: "{user_query}"
+
+ Document: {doc['title']}
+ Content: {doc['content'][:800]}
+
+ Key medical information:"""
+
+                 summary = self.llama_client._call_llama(summary_prompt)
+                 summary = self.clean_text(summary)
+
+                 doc_summaries.append(f"Document {doc_id}: {summary}")
+
+             combined_summary = "\n\n".join(doc_summaries)
+             return combined_summary, url_mapping
+
+         except Exception as e:
+             logger.error(f"Document summarization failed: {e}")
+             return "", {}
+
+     def summarize_conversation_chunk(self, chunk: str) -> str:
+         """Summarize a conversation chunk for memory"""
+         try:
+             if not chunk or len(chunk.strip()) < 30:
+                 return chunk
+
+             cleaned_chunk = self.clean_text(chunk)
+
+             prompt = f"""Summarize this medical conversation in 1-2 sentences. Focus only on medical facts, symptoms, treatments, or diagnoses discussed. Remove greetings and conversational elements.
+
+ Conversation: {cleaned_chunk[:1000]}
+
+ Medical summary:"""
+
+             summary = self.llama_client._call_llama(prompt)
+             return self.clean_text(summary)
+
+         except Exception as e:
+             logger.error(f"Conversation summarization failed: {e}")
+             return self.clean_text(chunk)[:150]
+
+     def chunk_response(self, response: str, max_chunk_size: int = 500) -> List[str]:
+         """Split response into chunks and summarize each"""
+         try:
+             if not response or len(response) <= max_chunk_size:
+                 return [response]
+
+             # Split by sentences first
+             sentences = re.split(r'[.!?]+', response)
+             chunks = []
+             current_chunk = ""
+
+             for sentence in sentences:
+                 sentence = sentence.strip()
+                 if not sentence:
+                     continue
+
+                 # Check if adding this sentence would exceed limit
+                 if len(current_chunk) + len(sentence) > max_chunk_size and current_chunk:
+                     chunks.append(self.summarize_conversation_chunk(current_chunk))
+                     current_chunk = sentence
+                 else:
+                     current_chunk += sentence + ". "
+
+             # Add the last chunk
+             if current_chunk:
+                 chunks.append(self.summarize_conversation_chunk(current_chunk))
+
+             return chunks
+
+         except Exception as e:
+             logger.error(f"Response chunking failed: {e}")
+             return [response]
+
+ # Global summarizer instance
+ summarizer = TextSummarizer()
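`chunk_response` splits a long answer on sentence boundaries, accumulating sentences until a chunk reaches `max_chunk_size`, then summarises each chunk so memory stores short summaries rather than full responses. A hedged sketch of just the accumulation logic, with the Llama summarisation step left out (the sample text and function name are invented for illustration):

```python
# Sketch of the sentence-accumulation step behind chunk_response;
# the real implementation summarises each finished chunk via Llama.
import re
from typing import List

def split_into_chunks(response: str, max_chunk_size: int = 500) -> List[str]:
    if not response or len(response) <= max_chunk_size:
        return [response]
    chunks, current = [], ""
    for sentence in re.split(r'[.!?]+', response):
        sentence = sentence.strip()
        if not sentence:
            continue
        if len(current) + len(sentence) > max_chunk_size and current:
            chunks.append(current)  # real code would summarise here
            current = sentence
        else:
            current += sentence + ". "
    if current:
        chunks.append(current)
    return chunks

print(split_into_chunks("First finding. Second finding. " * 40, max_chunk_size=200))
```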
warmup.py → models/warmup.py RENAMED
File without changes
search/__init__.py ADDED
@@ -0,0 +1,2 @@
+ # Search package
+ from .search import WebSearcher, search_web
search.py → search/search.py RENAMED
@@ -96,39 +96,54 @@ class WebSearcher:
              logger.warning(f"Failed to extract content from {url}: {e}")
              return ""
 
-     def search_and_extract(self, query: str, num_results: int = 5) -> List[Dict]:
+     def search_and_extract(self, query: str, num_results: int = 10) -> List[Dict]:
          """Search for query and extract content from top results"""
          logger.info(f"Searching for: {query}")
 
-         # Get search results
-         search_results = self.search_duckduckgo(query, num_results)
+         # Get search results (fetch more than needed for filtering)
+         search_results = self.search_duckduckgo(query, min(num_results * 2, 20))
 
-         # Extract content from each result
+         # Extract content from each result, tracking consecutive failures
          enriched_results = []
+         failed_count = 0
+         max_failures = 5  # Stop after 5 consecutive failures
+
          for i, result in enumerate(search_results):
+             if len(enriched_results) >= num_results:
+                 break
+
+             if failed_count >= max_failures:
+                 logger.warning(f"Too many failures ({failed_count}), stopping extraction")
+                 break
+
              try:
                  logger.info(f"Extracting content from {result['url']}")
                  content = self.extract_content(result['url'])
 
-                 if content:
+                 if content and len(content.strip()) > 50:  # Only include substantial content
                      enriched_results.append({
-                         'id': i + 1,
+                         'id': len(enriched_results) + 1,  # Sequential ID
                          'url': result['url'],
                          'title': result['title'],
                          'content': content
                      })
+                     failed_count = 0  # Reset failure counter
+                 else:
+                     failed_count += 1
+                     logger.warning(f"Insufficient content from {result['url']}")
 
                  # Add delay to be respectful
-                 time.sleep(1)
+                 time.sleep(0.5)  # Reduced delay for better performance
 
             except Exception as e:
+                 failed_count += 1
                  logger.warning(f"Failed to process {result['url']}: {e}")
                  continue
 
-         logger.info(f"Successfully processed {len(enriched_results)} results")
+         logger.info(f"Successfully processed {len(enriched_results)} results out of {len(search_results)} attempted")
          return enriched_results
 
- def search_web(query: str, num_results: int = 5) -> List[Dict]:
+ def search_web(query: str, num_results: int = 10) -> List[Dict]:
      """Main function to search the web and return enriched results"""
      searcher = WebSearcher()
      return searcher.search_and_extract(query, num_results)
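With the new defaults, `search_web` aims for up to 10 substantial documents, assigning sequential `id`s that the Llama summariser later maps back to URLs for citations. A hedged usage sketch (the query is invented; it assumes the `search` package above is importable and that network results are available):

```python
# Sketch: consuming search_web results and building the id -> URL citation map.
from search import search_web

results = search_web("early symptoms of type 2 diabetes", num_results=10)
url_mapping = {doc["id"]: doc["url"] for doc in results}

for doc in results:
    print(f"[{doc['id']}] {doc['title']} -> {doc['url']} ({len(doc['content'])} chars)")
```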
utils/__init__.py ADDED
@@ -0,0 +1,4 @@
+ # Utils package
+ from .translation import translate_query
+ from .vlm import process_medical_image
+ from .diagnosis import retrieve_diagnosis_from_symptoms
clear_mongo.py → utils/clear_mongo.py RENAMED
File without changes
connect_mongo.py → utils/connect_mongo.py RENAMED
File without changes
diagnosis.py → utils/diagnosis.py RENAMED
File without changes
migrate.py → utils/migrate.py RENAMED
File without changes
translation.py → utils/translation.py RENAMED
File without changes
vlm.py → utils/vlm.py RENAMED
File without changes