Commit b8bf5c8
1 Parent(s): 410be5e
Update backend with full search service implementation. Refactor directory
- .dockerignore +4 -0
- Dockerfile +3 -3
- README.md +120 -1
- api/README.md +96 -0
- api/__init__.py +2 -0
- api/app.py +54 -0
- api/chatbot.py +140 -0
- api/config.py +70 -0
- api/database.py +100 -0
- app.py → api/legacy.py +17 -11
- api/retrieval.py +100 -0
- api/routes.py +63 -0
- main.py +6 -0
- memory/__init__.py +2 -0
- memory.py → memory/memory.py +180 -275
- models/__init__.py +3 -0
- download_model.py → models/download_model.py +0 -0
- llama_integration.py → models/llama.py +57 -55
- models/summarizer.py +185 -0
- warmup.py → models/warmup.py +0 -0
- search/__init__.py +2 -0
- search.py → search/search.py +24 -9
- utils/__init__.py +4 -0
- clear_mongo.py → utils/clear_mongo.py +0 -0
- connect_mongo.py → utils/connect_mongo.py +0 -0
- diagnosis.py → utils/diagnosis.py +0 -0
- migrate.py → utils/migrate.py +0 -0
- translation.py → utils/translation.py +0 -0
- vlm.py → utils/vlm.py +0 -0
.dockerignore
ADDED
@@ -0,0 +1,4 @@
+api/legacy.py
+*.md
+.env
+*yml
Dockerfile
CHANGED
@@ -24,7 +24,7 @@ RUN mkdir -p /app/model_cache /home/user/.cache/huggingface/sentence-transformer
     chown -R user:user /app/model_cache /home/user/.cache/huggingface
 
 # Pre-load model in a separate script
-RUN python /app/download_model.py && python /app/warmup.py
+RUN python /app/models/download_model.py && python /app/models/warmup.py
 
 # Ensure ownership and permissions remain intact
 RUN chown -R user:user /app/model_cache
@@ -32,5 +32,5 @@ RUN chown -R user:user /app/model_cache
 # Expose port
 EXPOSE 7860
 
-# Run the application
-CMD ["
+# Run the application using main.py as entry point
+CMD ["python", "main.py"]
README.md
CHANGED
@@ -10,4 +10,123 @@ license: apache-2.0
 short_description: MedicalChatbot, FAISS, Gemini, MongoDB vDB, LRU
 ---
 
-
+# Medical Chatbot Backend
+
+## Project Structure
+
+The backend is organized into logical modules for better maintainability:
+
+### 📁 **api/**
+- **app.py** - Main FastAPI application with endpoints
+- **__init__.py** - API package initialization
+
+### 📁 **models/**
+- **llama.py** - NVIDIA Llama model integration for search processing
+- **summarizer.py** - Text summarization using NVIDIA Llama
+- **download_model.py** - Model download utilities
+- **warmup.py** - Model warmup scripts
+
+### 📁 **memory/**
+- **memory_updated.py** - Enhanced memory management with NVIDIA Llama summarization
+- **memory.py** - Legacy memory implementation
+
+### 📁 **search/**
+- **search.py** - Web search and content extraction functionality
+
+### 📁 **utils/**
+- **translation.py** - Multi-language translation utilities
+- **vlm.py** - Vision Language Model for medical image processing
+- **diagnosis.py** - Symptom-based diagnosis utilities
+- **connect_mongo.py** - MongoDB connection utilities
+- **clear_mongo.py** - Database cleanup utilities
+- **migrate.py** - Database migration scripts
+
+## Key Features
+
+### 🔍 **Search Integration**
+- Web search with up to 10 resources
+- NVIDIA Llama model for keyword generation and document summarization
+- Citation system with URL mapping
+- Smart content filtering and validation
+
+### 🧠 **Enhanced Memory Management**
+- NVIDIA Llama-powered summarization for all text processing
+- Optimized chunking and context retrieval
+- Smart deduplication and merging
+- Conversation continuity with concise summaries
+
+### 📝 **Summarization System**
+- **Text Cleaning**: Removes conversational fillers and normalizes text
+- **Key Phrase Extraction**: Identifies medical terms and concepts
+- **Concise Summaries**: Preserves key ideas without fluff
+- **NVIDIA Llama Integration**: All summarization uses NVIDIA model instead of Gemini
+
+## Usage
+
+### Running the Application
+```bash
+# Using main entry point
+python main.py
+
+# Or directly
+python api/app.py
+```
+
+### Environment Variables
+- `NVIDIA_URI` - NVIDIA API key for Llama model
+- `FlashAPI` - Gemini API key
+- `MONGO_URI` - MongoDB connection string
+- `INDEX_URI` - FAISS index database URI
+
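For local development, a minimal sketch of setting these variables before importing the app (all values are placeholders; `api/config.py` raises a `ValueError` when `FlashAPI`, `MONGO_URI`, or `INDEX_URI` is unset):

```python
# Hypothetical local-dev bootstrap; all values below are placeholders.
import os

os.environ.setdefault("FlashAPI", "your-gemini-api-key")
os.environ.setdefault("MONGO_URI", "mongodb://localhost:27017")   # QA + symptom data
os.environ.setdefault("INDEX_URI", "mongodb://localhost:27017")   # FAISS index GridFS store
os.environ.setdefault("NVIDIA_URI", "your-nvidia-api-key")        # used by the Llama search/summarization layer

# Import only after the variables are set: api.app connects to MongoDB and loads
# the embedding model at import time, so real services must be reachable.
from api.app import app
```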
+## API Endpoints
+
+### POST `/chat`
+Main chat endpoint with search mode support.
+
+**Request Body:**
+```json
+{
+  "query": "User's medical question",
+  "lang": "EN",
+  "search": true,
+  "user_id": "unique_user_id",
+  "image_base64": "optional_base64_image",
+  "img_desc": "image_description"
+}
+```
+
+**Response:**
+```json
+{
+  "response": "Medical response with citations <URL>",
+  "response_time": "2.34s"
+}
+```
+
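As a minimal illustration, a client call against this endpoint, assuming the server started via `main.py` is listening on port 7860 and using placeholder field values:

```python
# Example client for POST /chat — field values are placeholders.
import requests

payload = {
    "query": "What are common causes of persistent headaches?",
    "lang": "EN",
    "search": True,            # enable web-search mode with citations
    "user_id": "demo-user-001",
    # "image_base64" / "img_desc" are optional and omitted here
}

resp = requests.post("http://localhost:7860/chat", json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["response"])  # Markdown answer; web citations appear as <URL>
```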
+## Search Mode Features
+
+When `search: true`:
+1. **Web Search**: Fetches up to 10 relevant medical resources
+2. **Llama Processing**: Generates keywords and summarizes content
+3. **Citation System**: Replaces `<#ID>` tags with actual URLs
+4. **UI Integration**: Frontend displays magnifier icons for source links
+
+## Summarization Features
+
+All summarization tasks use NVIDIA Llama model:
+- **get_contextual_chunks**: Summarizes conversation history and RAG chunks
+- **chunk_response**: Chunks and summarizes bot responses
+- **summarize_documents**: Summarizes web search results
+
+### Text Processing Pipeline
+1. **Clean Text**: Remove conversational elements and normalize
+2. **Extract Key Phrases**: Identify medical terms and concepts
+3. **Summarize**: Create concise, focused summaries
+4. **Validate**: Ensure quality and relevance
+
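A rough sketch of how these steps could be composed, assuming `models/summarizer.py` exposes a `summarizer` object with a `summarize_text(text, max_length=...)` method (the call used by `memory/memory.py`); the cleaning and validation steps here are simplified stand-ins:

```python
# Sketch only: the summarizer import path and method follow memory/memory.py;
# the regex cleanup and the final validation check are illustrative simplifications.
import re
from summarizer import summarizer

def summarize_for_context(raw_text: str, max_length: int = 300) -> str:
    # 1. Clean: strip conversational fillers and collapse whitespace
    cleaned = re.sub(r"\b(um+|uh+|you know|i mean)\b", "", raw_text, flags=re.IGNORECASE)
    cleaned = re.sub(r"\s+", " ", cleaned).strip()
    # 2-3. Extract key phrases and summarize via the NVIDIA Llama-backed summarizer
    summary = summarizer.summarize_text(cleaned, max_length=max_length)
    # 4. Validate: fall back to plain truncation if the model returns nothing usable
    return summary if summary and len(summary.split()) > 3 else cleaned[:max_length]
```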
+## Dependencies
+
+See `requirements.txt` for complete list. Key additions:
+- `requests` - Web search functionality
+- `beautifulsoup4` - HTML content extraction
+- NVIDIA API integration for Llama model
api/README.md
ADDED
@@ -0,0 +1,96 @@
+# API Module Structure
+
+## 📁 **Module Overview**
+
+### **config.py** - Configuration Management
+- Environment variables validation
+- Logging configuration
+- System resource monitoring
+- Memory optimization settings
+- CORS configuration
+
+### **database.py** - Database Management
+- MongoDB connection management
+- FAISS index lazy loading
+- SentenceTransformer model initialization
+- Symptom vectors management
+- GridFS integration
+
+### **retrieval.py** - RAG Retrieval Engine
+- Medical information retrieval from FAISS
+- Symptom-based diagnosis retrieval
+- Smart deduplication and similarity matching
+- Vector similarity computations
+
+### **chatbot.py** - Core Chatbot Logic
+- RAGMedicalChatbot class
+- Gemini API client
+- Search mode integration
+- Citation processing
+- Memory management integration
+
+### **routes.py** - API Endpoints
+- `/chat` - Main chat endpoint
+- `/health` - Health check
+- `/` - Root endpoint
+- Request/response handling
+
+### **app.py** - Main Application
+- FastAPI app initialization
+- Middleware configuration
+- Database initialization
+- Route registration
+- Server startup
+
+## 🔄 **Data Flow**
+
+```
+Request → routes.py → chatbot.py → retrieval.py → database.py
+                ↓
+        memory.py (context) + search.py (web search)
+                ↓
+        models/ (NVIDIA Llama processing)
+                ↓
+        Response with citations
+```
+
+## 🚀 **Benefits of Modular Structure**
+
+1. **Separation of Concerns**: Each module has a single responsibility
+2. **Easier Testing**: Individual modules can be tested in isolation
+3. **Better Maintainability**: Changes to one module don't affect others
+4. **Improved Readability**: Smaller files are easier to understand
+5. **Reusability**: Modules can be imported and used elsewhere
+6. **Scalability**: Easy to add new features without affecting existing code
+
+## 📊 **File Sizes Comparison**
+
+| File | Lines | Purpose |
+|------|-------|---------|
+| **app_old.py** | 370 | Monolithic (everything) |
+| **app.py** | 45 | Main app initialization |
+| **config.py** | 65 | Configuration |
+| **database.py** | 95 | Database management |
+| **retrieval.py** | 85 | RAG retrieval |
+| **chatbot.py** | 120 | Chatbot logic |
+| **routes.py** | 55 | API endpoints |
+| **Total** | 465 | Modular structure |
+
+## 🔧 **Usage**
+
+The modular structure maintains the same API interface:
+
+```python
+# All imports work the same way
+from api.app import app
+from api.chatbot import RAGMedicalChatbot
+from api.retrieval import retrieval_engine
+```
+
+## 🛠 **Development Benefits**
+
+- **Easier Debugging**: Issues can be isolated to specific modules
+- **Parallel Development**: Multiple developers can work on different modules
+- **Code Reviews**: Smaller files are easier to review
+- **Documentation**: Each module can have focused documentation
+- **Testing**: Unit tests can be written for each module independently
api/__init__.py
ADDED
@@ -0,0 +1,2 @@
+# API package
+# Main API endpoints and routes
api/app.py
ADDED
@@ -0,0 +1,54 @@
+# api/app_new.py
+import uvicorn
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from api.config import setup_logging, check_system_resources, optimize_memory, CORS_ORIGINS
+from api.database import db_manager
+from api.routes import router
+
+# ✅ Setup logging
+logger = setup_logging()
+logger.info("🚀 Starting Medical Chatbot API...")
+
+# ✅ Monitor system resources
+check_system_resources(logger)
+
+# ✅ Optimize memory usage
+optimize_memory()
+
+# ✅ Initialize FastAPI app
+app = FastAPI(
+    title="Medical Chatbot API",
+    description="AI-powered medical chatbot with RAG and search capabilities",
+    version="1.0.0"
+)
+
+# ✅ Add CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=CORS_ORIGINS,
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# ✅ Initialize database connections
+try:
+    db_manager.initialize_embedding_model()
+    db_manager.initialize_mongodb()
+    logger.info("✅ Database connections initialized successfully")
+except Exception as e:
+    logger.error(f"❌ Database initialization failed: {e}")
+    raise
+
+# ✅ Include routes
+app.include_router(router)
+
+# ✅ Run Uvicorn
+if __name__ == "__main__":
+    logger.info("[System] ✅ Starting FastAPI Server...")
+    try:
+        uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
+    except Exception as e:
+        logger.error(f"❌ Server Startup Failed: {e}")
+        exit(1)
api/chatbot.py
ADDED
@@ -0,0 +1,140 @@
+# api/chatbot.py
+import re
+import logging
+from typing import Dict
+from google import genai
+from api.config import gemini_flash_api_key
+from api.retrieval import retrieval_engine
+from memory import MemoryManager
+from utils import translate_query, process_medical_image
+from search import search_web
+from models import process_search_query
+
+logger = logging.getLogger("medical-chatbot")
+
+class GeminiClient:
+    """Gemini API client for generating responses"""
+
+    def __init__(self):
+        self.client = genai.Client(api_key=gemini_flash_api_key)
+
+    def generate_content(self, prompt: str, model: str = "gemini-2.5-flash", temperature: float = 0.7) -> str:
+        """Generate content using Gemini API"""
+        try:
+            response = self.client.models.generate_content(model=model, contents=prompt)
+            return response.text
+        except Exception as e:
+            logger.error(f"[LLM] ❌ Error calling Gemini API: {e}")
+            return "Error generating response from Gemini."
+
+class RAGMedicalChatbot:
+    """Main chatbot class with RAG capabilities"""
+
+    def __init__(self, model_name: str, retrieve_function):
+        self.model_name = model_name
+        self.retrieve = retrieve_function
+        self.gemini_client = GeminiClient()
+        self.memory = MemoryManager()
+
+    def chat(self, user_id: str, user_query: str, lang: str = "EN", image_diagnosis: str = "", search_mode: bool = False) -> str:
+        """Main chat method with RAG and search capabilities"""
+
+        # 0. Translate query if not EN, this help our RAG system
+        if lang.upper() in {"VI", "ZH"}:
+            user_query = translate_query(user_query, lang.lower())
+
+        # 1. Fetch knowledge
+        ## a. KB for generic QA retrieval
+        retrieved_info = self.retrieve(user_query)
+        knowledge_base = "\n".join(retrieved_info)
+        ## b. Diagnosis RAG from symptom query
+        diagnosis_guides = retrieval_engine.retrieve_diagnosis_from_symptoms(user_query)
+
+        # 1.5. Search mode - web search and Llama processing
+        search_context = ""
+        url_mapping = {}
+        if search_mode:
+            logger.info(f"[SEARCH] Starting web search mode for query: {user_query}")
+            try:
+                # Search the web with max 10 resources
+                search_results = search_web(user_query, num_results=10)
+                if search_results:
+                    logger.info(f"[SEARCH] Retrieved {len(search_results)} web resources")
+                    # Process with Llama
+                    search_context, url_mapping = process_search_query(user_query, search_results)
+                    logger.info(f"[SEARCH] Processed with Llama, generated {len(url_mapping)} URL mappings")
+                else:
+                    logger.warning("[SEARCH] No search results found")
+            except Exception as e:
+                logger.error(f"[SEARCH] Search failed: {e}")
+                search_context = ""
+
+        # 2. Hybrid Context Retrieval: RAG + Recent History + Intelligent Selection
+        contextual_chunks = self.memory.get_contextual_chunks(user_id, user_query, lang)
+
+        # 3. Build prompt parts
+        parts = ["You are a medical chatbot, designed to answer medical questions."]
+        parts.append("Please format your answer using MarkDown.")
+        parts.append("**Bold for titles**, *italic for emphasis*, and clear headings.")
+
+        # 4. Append image diagnosis from VLM
+        if image_diagnosis:
+            parts.append(
+                "A user medical image is diagnosed by our VLM agent:\n"
+                f"{image_diagnosis}\n\n"
+                "Please incorporate the above findings in your response if medically relevant.\n\n"
+            )
+
+        # Append contextual chunks from hybrid approach
+        if contextual_chunks:
+            parts.append("Relevant context from conversation history:\n" + contextual_chunks)
+        # Load up guideline (RAG over medical knowledge base)
+        if knowledge_base:
+            parts.append(f"Example Q&A medical scenario knowledge-base: {knowledge_base}")
+        # Symptom-Diagnosis prediction RAG
+        if diagnosis_guides:
+            parts.append("Symptom-based diagnosis guidance (if applicable):\n" + "\n".join(diagnosis_guides))
+
+        # 5. Search context with citation instructions
+        if search_context:
+            parts.append("Additional information from web search:\n" + search_context)
+            parts.append("IMPORTANT: When you use information from the web search results above, you MUST add a citation tag <#ID> immediately after the relevant content, where ID is the document number (1, 2, 3, etc.). For example: 'According to recent studies <#1>, this condition affects...'")
+
+        parts.append(f"User's question: {user_query}")
+        parts.append(f"Language to generate answer: {lang}")
+        prompt = "\n\n".join(parts)
+        logger.info(f"[LLM] Question query in `prompt`: {prompt}")  # Debug out checking RAG on kb and history
+        response = self.gemini_client.generate_content(prompt, model=self.model_name, temperature=0.7)
+
+        # 6. Process citations and replace with URLs
+        if search_mode and url_mapping:
+            response = self._process_citations(response, url_mapping)
+
+        # Store exchange + chunking
+        if user_id:
+            self.memory.add_exchange(user_id, user_query, response, lang=lang)
+        logger.info(f"[LLM] Response on `prompt`: {response.strip()}")  # Debug out base response
+        return response.strip()
+
+    def _process_citations(self, response: str, url_mapping: Dict[int, str]) -> str:
+        """Replace citation tags with actual URLs"""
+
+        # Find all citation tags like <#1>, <#2>, etc.
+        citation_pattern = r'<#(\d+)>'
+        citations_found = re.findall(citation_pattern, response)
+
+        def replace_citation(match):
+            doc_id = int(match.group(1))
+            if doc_id in url_mapping:
+                url = url_mapping[doc_id]
+                logger.info(f"[CITATION] Replacing <#{doc_id}> with {url}")
+                return f'<{url}>'
+            else:
+                logger.warning(f"[CITATION] No URL mapping found for document ID {doc_id}")
+                return match.group(0)  # Keep original if URL not found
+
+        # Replace citations with URLs
+        processed_response = re.sub(citation_pattern, replace_citation, response)
+
+        logger.info(f"[CITATION] Processed {len(citations_found)} citations, {len(url_mapping)} URL mappings available")
+        return processed_response
api/config.py
ADDED
@@ -0,0 +1,70 @@
+# api/config.py
+import os
+import logging
+import psutil
+from typing import List
+
+# ✅ Environment Variables
+mongo_uri = os.getenv("MONGO_URI")
+index_uri = os.getenv("INDEX_URI")
+gemini_flash_api_key = os.getenv("FlashAPI")
+
+# Validate environment endpoint
+if not all([gemini_flash_api_key, mongo_uri, index_uri]):
+    raise ValueError("❌ Missing API keys! Set them in Hugging Face Secrets.")
+
+# ✅ Logging Configuration
+def setup_logging():
+    """Configure logging for the application"""
+    # Silence noisy loggers
+    for name in [
+        "uvicorn.error", "uvicorn.access",
+        "fastapi", "starlette",
+        "pymongo", "gridfs",
+        "sentence_transformers", "faiss",
+        "google", "google.auth",
+    ]:
+        logging.getLogger(name).setLevel(logging.WARNING)
+
+    logging.basicConfig(
+        level=logging.INFO,
+        format="%(asctime)s — %(name)s — %(levelname)s — %(message)s",
+        force=True
+    )
+
+    logger = logging.getLogger("medical-chatbot")
+    logger.setLevel(logging.DEBUG)
+    return logger
+
+# ✅ System Resource Monitoring
+def check_system_resources(logger):
+    """Monitor system resources and log warnings"""
+    memory = psutil.virtual_memory()
+    cpu = psutil.cpu_percent(interval=1)
+    disk = psutil.disk_usage("/")
+
+    logger.info(f"[System] 🔍 System Resources - RAM: {memory.percent}%, CPU: {cpu}%, Disk: {disk.percent}%")
+
+    if memory.percent > 85:
+        logger.warning("⚠️ High RAM usage detected!")
+    if cpu > 90:
+        logger.warning("⚠️ High CPU usage detected!")
+    if disk.percent > 90:
+        logger.warning("⚠️ High Disk usage detected!")
+
+# ✅ Memory Optimization
+def optimize_memory():
+    """Set environment variables for memory optimization"""
+    os.environ["OMP_NUM_THREADS"] = "1"
+    os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+# ✅ CORS Configuration
+CORS_ORIGINS = [
+    "http://localhost:5173",  # Vite dev server
+    "http://localhost:3000",  # Another vercel local dev
+    "https://medical-chatbot-henna.vercel.app",  # ✅ Vercel frontend production URL
+]
+
+# ✅ Model Configuration
+MODEL_CACHE_DIR = "/app/model_cache"
+EMBEDDING_MODEL_DEVICE = "cpu"
api/database.py
ADDED
@@ -0,0 +1,100 @@
+# api/database.py
+import faiss
+import numpy as np
+import gridfs
+from pymongo import MongoClient
+from sentence_transformers import SentenceTransformer
+from api.config import mongo_uri, index_uri, MODEL_CACHE_DIR, EMBEDDING_MODEL_DEVICE
+import logging
+
+logger = logging.getLogger("medical-chatbot")
+
+class DatabaseManager:
+    def __init__(self):
+        self.embedding_model = None
+        self.index = None
+        self.symptom_vectors = None
+        self.symptom_docs = None
+
+        # MongoDB connections
+        self.client = None
+        self.iclient = None
+        self.symptom_client = None
+
+        # Collections
+        self.qa_collection = None
+        self.index_collection = None
+        self.symptom_col = None
+        self.fs = None
+
+    def initialize_embedding_model(self):
+        """Initialize the SentenceTransformer model"""
+        logger.info("[Embedder] 📥 Loading SentenceTransformer Model...")
+        try:
+            self.embedding_model = SentenceTransformer(MODEL_CACHE_DIR, device=EMBEDDING_MODEL_DEVICE)
+            self.embedding_model = self.embedding_model.half()  # Reduce memory
+            logger.info("✅ Model Loaded Successfully.")
+        except Exception as e:
+            logger.error(f"❌ Model Loading Failed: {e}")
+            raise
+
+    def initialize_mongodb(self):
+        """Initialize MongoDB connections and collections"""
+        # QA data
+        self.client = MongoClient(mongo_uri)
+        db = self.client["MedicalChatbotDB"]
+        self.qa_collection = db["qa_data"]
+
+        # FAISS Index data
+        self.iclient = MongoClient(index_uri)
+        idb = self.iclient["MedicalChatbotDB"]
+        self.index_collection = idb["faiss_index_files"]
+
+        # Symptom Diagnosis data
+        self.symptom_client = MongoClient(mongo_uri)
+        self.symptom_col = self.symptom_client["MedicalChatbotDB"]["symptom_diagnosis"]
+
+        # GridFS for FAISS index
+        self.fs = gridfs.GridFS(idb, collection="faiss_index_files")
+
+    def load_faiss_index(self):
+        """Lazy load FAISS index from GridFS"""
+        if self.index is None:
+            logger.info("[KB] ⏳ Loading FAISS index from GridFS...")
+            existing_file = self.fs.find_one({"filename": "faiss_index.bin"})
+            if existing_file:
+                stored_index_bytes = existing_file.read()
+                index_bytes_np = np.frombuffer(stored_index_bytes, dtype='uint8')
+                self.index = faiss.deserialize_index(index_bytes_np)
+                logger.info("[KB] ✅ FAISS Index Loaded")
+            else:
+                logger.error("[KB] ❌ FAISS index not found in GridFS.")
+        return self.index
+
+    def load_symptom_vectors(self):
+        """Lazy load symptom vectors for diagnosis"""
+        if self.symptom_vectors is None:
+            all_docs = list(self.symptom_col.find({}, {"embedding": 1, "answer": 1, "question": 1, "prognosis": 1}))
+            self.symptom_docs = all_docs
+            self.symptom_vectors = np.array([doc["embedding"] for doc in all_docs], dtype=np.float32)
+
+    def get_embedding_model(self):
+        """Get the embedding model"""
+        if self.embedding_model is None:
+            self.initialize_embedding_model()
+        return self.embedding_model
+
+    def get_qa_collection(self):
+        """Get QA collection"""
+        if self.qa_collection is None:
+            self.initialize_mongodb()
+        return self.qa_collection
+
+    def get_symptom_collection(self):
+        """Get symptom collection"""
+        if self.symptom_col is None:
+            self.initialize_mongodb()
+        return self.symptom_col
+
+# Global database manager instance
+db_manager = DatabaseManager()
app.py → api/legacy.py
RENAMED
@@ -1,5 +1,6 @@
 # app.py
-import os
+import os, json, re
+from typing import Dict
 import faiss
 import numpy as np
 import time
@@ -11,10 +12,9 @@ from google import genai
 from sentence_transformers import SentenceTransformer
 from sentence_transformers.util import cos_sim
 from memory import MemoryManager
-from
-from vlm import process_medical_image
+from utils import translate_query, process_medical_image, retrieve_diagnosis_from_symptoms
 from search import search_web
-from
+from models import process_search_query
 
 # ✅ Enable Logging for Debugging
 import logging
@@ -239,14 +239,15 @@ class RAGMedicalChatbot:
         search_context = ""
         url_mapping = {}
         if search_mode:
-            logger.info("[SEARCH] Starting web search mode")
+            logger.info(f"[SEARCH] Starting web search mode for query: {user_query}")
             try:
-                # Search the web
-                search_results = search_web(user_query, num_results=
+                # Search the web with max 10 resources
+                search_results = search_web(user_query, num_results=10)
                 if search_results:
+                    logger.info(f"[SEARCH] Retrieved {len(search_results)} web resources")
                     # Process with Llama
                     search_context, url_mapping = process_search_query(user_query, search_results)
-                    logger.info(f"[SEARCH]
+                    logger.info(f"[SEARCH] Processed with Llama, generated {len(url_mapping)} URL mappings")
                 else:
                     logger.warning("[SEARCH] No search results found")
             except Exception as e:
@@ -306,17 +307,22 @@ class RAGMedicalChatbot:
 
         # Find all citation tags like <#1>, <#2>, etc.
         citation_pattern = r'<#(\d+)>'
+        citations_found = re.findall(citation_pattern, response)
 
         def replace_citation(match):
            doc_id = int(match.group(1))
            if doc_id in url_mapping:
-
-
+                url = url_mapping[doc_id]
+                logger.info(f"[CITATION] Replacing <#{doc_id}> with {url}")
+                return f'<{url}>'
+            else:
+                logger.warning(f"[CITATION] No URL mapping found for document ID {doc_id}")
+                return match.group(0)  # Keep original if URL not found
 
         # Replace citations with URLs
         processed_response = re.sub(citation_pattern, replace_citation, response)
 
-        logger.info(f"[CITATION] Processed citations,
+        logger.info(f"[CITATION] Processed {len(citations_found)} citations, {len(url_mapping)} URL mappings available")
        return processed_response
 
 # ✅ Initialize Chatbot
api/retrieval.py
ADDED
@@ -0,0 +1,100 @@
+# api/retrieval.py
+import numpy as np
+import logging
+from api.database import db_manager
+
+logger = logging.getLogger("medical-chatbot")
+
+class RetrievalEngine:
+    def __init__(self):
+        self.db_manager = db_manager
+
+    def retrieve_medical_info(self, query: str, k: int = 5, min_sim: float = 0.9) -> list:
+        """
+        Retrieve medical information from FAISS index
+        Min similarity between query and kb is to be 80%
+        """
+        index = self.db_manager.load_faiss_index()
+        if index is None:
+            return [""]
+
+        embedding_model = self.db_manager.get_embedding_model()
+        qa_collection = self.db_manager.get_qa_collection()
+
+        # Embed query
+        query_vec = embedding_model.encode([query], convert_to_numpy=True)
+        D, I = index.search(query_vec, k=k)
+
+        # Filter by cosine threshold
+        results = []
+        kept = []
+        kept_vecs = []
+
+        # Smart dedup on cosine threshold between similar candidates
+        for score, idx in zip(D[0], I[0]):
+            if score < min_sim:
+                continue
+
+            # List sim docs
+            doc = qa_collection.find_one({"i": int(idx)})
+            if not doc:
+                continue
+
+            # Only compare answers
+            answer = doc.get("Doctor", "").strip()
+            if not answer:
+                continue
+
+            # Check semantic redundancy among previously kept results
+            new_vec = embedding_model.encode([answer], convert_to_numpy=True)[0]
+            is_similar = False
+
+            for i, vec in enumerate(kept_vecs):
+                sim = np.dot(vec, new_vec) / (np.linalg.norm(vec) * np.linalg.norm(new_vec) + 1e-9)
+                if sim >= 0.9:  # High semantic similarity
+                    is_similar = True
+                    # Keep only better match to original query
+                    cur_sim_to_query = np.dot(vec, query_vec[0]) / (np.linalg.norm(vec) * np.linalg.norm(query_vec[0]) + 1e-9)
+                    new_sim_to_query = np.dot(new_vec, query_vec[0]) / (np.linalg.norm(new_vec) * np.linalg.norm(query_vec[0]) + 1e-9)
+                    if new_sim_to_query > cur_sim_to_query:
+                        kept[i] = answer
+                        kept_vecs[i] = new_vec
+                    break
+
+            # Non-similar candidates
+            if not is_similar:
+                kept.append(answer)
+                kept_vecs.append(new_vec)
+
+        return kept if kept else [""]
+
+    def retrieve_diagnosis_from_symptoms(self, symptom_text: str, top_k: int = 5, min_sim: float = 0.5) -> list:
+        """
+        Retrieve diagnosis information from symptom vectors
+        """
+        self.db_manager.load_symptom_vectors()
+        embedding_model = self.db_manager.get_embedding_model()
+
+        # Embed input
+        qvec = embedding_model.encode(symptom_text, convert_to_numpy=True)
+        qvec = qvec / (np.linalg.norm(qvec) + 1e-9)
+
+        # Similarity compute
+        sims = self.db_manager.symptom_vectors @ qvec  # cosine
+        sorted_idx = np.argsort(sims)[-top_k:][::-1]
+        seen_diag = set()
+        final = []  # Dedup
+
+        for i in sorted_idx:
+            sim = sims[i]
+            if sim < min_sim:
+                continue
+            label = self.db_manager.symptom_docs[i]["prognosis"]
+            if label not in seen_diag:
+                final.append(self.db_manager.symptom_docs[i]["answer"])
+                seen_diag.add(label)
+
+        return final
+
+# Global retrieval engine instance
+retrieval_engine = RetrievalEngine()
api/routes.py
ADDED
@@ -0,0 +1,63 @@
+# api/routes.py
+import time
+import logging
+from fastapi import APIRouter, Request
+from fastapi.responses import JSONResponse
+from api.chatbot import RAGMedicalChatbot
+from api.retrieval import retrieval_engine
+from utils import process_medical_image
+
+logger = logging.getLogger("medical-chatbot")
+
+# Create router
+router = APIRouter()
+
+# Initialize chatbot
+chatbot = RAGMedicalChatbot(
+    model_name="gemini-2.5-flash",
+    retrieve_function=retrieval_engine.retrieve_medical_info
+)
+
+@router.post("/chat")
+async def chat_endpoint(req: Request):
+    """Main chat endpoint with search mode support"""
+    body = await req.json()
+    user_id = body.get("user_id", "anonymous")
+    query_raw = body.get("query")
+    query = query_raw.strip() if isinstance(query_raw, str) else ""
+    lang = body.get("lang", "EN")
+    search_mode = body.get("search", False)
+    image_base64 = body.get("image_base64", None)
+    img_desc = body.get("img_desc", "Describe and investigate any clinical findings from this medical image.")
+
+    start = time.time()
+    image_diagnosis = ""
+
+    # LLM Only
+    if not image_base64:
+        logger.info(f"[BOT] LLM scenario. Search mode: {search_mode}")
+    # LLM+VLM
+    else:
+        # If image is present → diagnose first
+        safe_load = len(image_base64.encode("utf-8"))
+        if safe_load > 5_000_000:  # Img size safe processor
+            return JSONResponse({"response": "⚠️ Image too large. Please upload smaller images (<5MB)."})
+        logger.info(f"[BOT] VLM+LLM scenario. Search mode: {search_mode}")
+        logger.info(f"[VLM] Process medical image size: {safe_load}, desc: {img_desc}, {lang}.")
+        image_diagnosis = process_medical_image(image_base64, img_desc, lang)
+
+    answer = chatbot.chat(user_id, query, lang, image_diagnosis, search_mode)
+    elapsed = time.time() - start
+
+    # Final
+    return JSONResponse({"response": f"{answer}\n\n(Response time: {elapsed:.2f}s)"})
+
+@router.get("/health")
+async def health_check():
+    """Health check endpoint"""
+    return {"status": "healthy", "service": "medical-chatbot"}
+
+@router.get("/")
+async def root():
+    """Root endpoint"""
+    return {"message": "Medical Chatbot API", "version": "1.0.0"}
main.py
ADDED
@@ -0,0 +1,6 @@
+# main.py - Entry point for the Medical Chatbot API
+from api.app import app
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
memory/__init__.py
ADDED
@@ -0,0 +1,2 @@
+# Memory package
+from .memory import MemoryManager
memory.py → memory/memory.py
RENAMED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
#
|
| 2 |
import re, time, hashlib, asyncio, os
|
| 3 |
from collections import defaultdict, deque
|
| 4 |
from typing import List, Dict
|
|
@@ -7,6 +7,7 @@ import faiss
|
|
| 7 |
from sentence_transformers import SentenceTransformer
|
| 8 |
from google import genai # must be configured in app.py and imported globally
|
| 9 |
import logging
|
|
|
|
| 10 |
|
| 11 |
_LLM_SMALL = "gemini-2.5-flash-lite-preview-06-17"
|
| 12 |
# Load embedding model
|
|
@@ -98,7 +99,7 @@ class MemoryManager:
|
|
| 98 |
|
| 99 |
def get_contextual_chunks(self, user_id: str, current_query: str, lang: str = "EN") -> str:
|
| 100 |
"""
|
| 101 |
-
Use
|
| 102 |
This ensures conversational continuity while providing a concise summary for the main LLM.
|
| 103 |
"""
|
| 104 |
# Get both types of context
|
|
@@ -112,7 +113,8 @@ class MemoryManager:
|
|
| 112 |
if not recent_history and not rag_chunks:
|
| 113 |
logger.info(f"[Contextual] No context found, returning empty string")
|
| 114 |
return ""
|
| 115 |
-
|
|
|
|
| 116 |
context_parts = []
|
| 117 |
# Add recent chat history
|
| 118 |
if recent_history:
|
|
@@ -126,301 +128,204 @@ class MemoryManager:
|
|
| 126 |
rag_text = "\n".join(rag_chunks)
|
| 127 |
context_parts.append(f"Semantically relevant historical medical information:\n{rag_text}")
|
| 128 |
|
| 129 |
-
#
|
| 130 |
-
|
| 131 |
-
You are a medical assistant creating a concise summary of conversation context for continuity.
|
| 132 |
-
|
| 133 |
-
Current user query: "{current_query}"
|
| 134 |
-
|
| 135 |
-
Available context information:
|
| 136 |
-
{chr(10).join(context_parts)}
|
| 137 |
-
|
| 138 |
-
Task: Create a brief, coherent summary that captures the key points from the conversation history and relevant medical information that are important for understanding the current query.
|
| 139 |
-
|
| 140 |
-
Guidelines:
|
| 141 |
-
1. Focus on medical symptoms, diagnoses, treatments, or recommendations mentioned
|
| 142 |
-
2. Include any patient concerns or questions that are still relevant
|
| 143 |
-
3. Highlight any follow-up needs or pending clarifications
|
| 144 |
-
4. Keep the summary concise but comprehensive enough for context
|
| 145 |
-
5. Maintain conversational flow and continuity
|
| 146 |
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
"""
|
|
|
|
|
|
|
| 152 |
|
| 153 |
-
logger.debug(f"[Contextual] Full prompt: {summarization_prompt}")
|
| 154 |
-
# Loop through the prompt and log the length of each part
|
| 155 |
try:
|
| 156 |
-
# Use
|
| 157 |
-
|
| 158 |
-
result = client.models.generate_content(
|
| 159 |
-
model=_LLM_SMALL,
|
| 160 |
-
contents=summarization_prompt
|
| 161 |
-
)
|
| 162 |
-
summary = result.text.strip()
|
| 163 |
-
if "No relevant context found" in summary:
|
| 164 |
-
logger.info(f"[Contextual] Gemini indicated no relevant context found")
|
| 165 |
-
return ""
|
| 166 |
|
| 167 |
-
|
| 168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
except Exception as e:
|
| 171 |
-
logger.
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
if
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
final_fallback = " ".join(fallback_summary) if fallback_summary else ""
|
| 186 |
-
return final_fallback
|
| 187 |
|
| 188 |
-
def
|
| 189 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
|
| 191 |
-
# ----------
|
| 192 |
def _touch_user(self, user_id: str):
|
| 193 |
-
|
| 194 |
-
self._drop_user(self.user_queue.popleft())
|
| 195 |
if user_id in self.user_queue:
|
| 196 |
self.user_queue.remove(user_id)
|
| 197 |
self.user_queue.append(user_id)
|
| 198 |
|
| 199 |
-
def
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
self.chunk_meta.pop(user_id, None)
|
| 203 |
-
if user_id in self.user_queue:
|
| 204 |
-
self.user_queue.remove(user_id)
|
| 205 |
-
|
| 206 |
-
def _rebuild_index(self, user_id: str, keep_last: int):
|
| 207 |
-
"""Trim chunk list + rebuild FAISS index for user."""
|
| 208 |
-
self.chunk_meta[user_id] = self.chunk_meta[user_id][-keep_last:]
|
| 209 |
-
index = self._new_index()
|
| 210 |
-
# Store each chunk's vector once and reuse it.
|
| 211 |
-
for chunk in self.chunk_meta[user_id]:
|
| 212 |
-
index.add(np.array([chunk["vec"]]))
|
| 213 |
-
self.chunk_index[user_id] = index
|
| 214 |
-
|
| 215 |
-
@staticmethod
|
| 216 |
-
def _new_index():
|
| 217 |
-
# Use cosine similarity (vectors must be L2-normalised)
|
| 218 |
-
return faiss.IndexFlatIP(384)
|
| 219 |
-
|
| 220 |
-
@staticmethod
|
| 221 |
-
def _embed(text: str):
|
| 222 |
-
vec = EMBED.encode(text, convert_to_numpy=True)
|
| 223 |
-
# L2 normalise for cosine on IndexFlatIP
|
| 224 |
-
return vec / (np.linalg.norm(vec) + 1e-9)
|
| 225 |
-
|
| 226 |
-
def chunk_response(self, response: str, lang: str, question: str = "") -> List[Dict]:
|
| 227 |
-
"""
|
| 228 |
-
Calls Gemini to:
|
| 229 |
-
- Translate (if needed)
|
| 230 |
-
- Chunk by context/topic (exclude disclaimer section)
|
| 231 |
-
- Summarise
|
| 232 |
-
Returns: [{"tag": ..., "text": ...}, ...]
|
| 233 |
-
"""
|
| 234 |
-
if not response: return []
|
| 235 |
-
# Gemini instruction
|
| 236 |
-
instructions = []
|
| 237 |
-
# if lang.upper() != "EN":
|
| 238 |
-
# instructions.append("- Translate the response to English.")
|
| 239 |
-
instructions.append("- Break the translated (or original) text into semantically distinct parts, grouped by medical topic, symptom, assessment, plan, or instruction (exclude disclaimer section).")
|
| 240 |
-
instructions.append("- For each part, generate a clear, concise summary. The summary may vary in length depending on the complexity of the topic — do not omit key clinical instructions and exact medication names/doses if present.")
|
| 241 |
-
instructions.append("- At the start of each part, write `Topic: <concise but specific sentence (10-20 words) capturing patient context, condition, and action>`.")
|
| 242 |
-
instructions.append("- Separate each part using three dashes `---` on a new line.")
|
| 243 |
-
# if lang.upper() != "EN":
|
| 244 |
-
# instructions.append(f"Below is the user-provided medical response written in `{lang}`")
|
| 245 |
-
# Gemini prompt
|
| 246 |
-
prompt = f"""
|
| 247 |
-
You are a medical assistant helping organize and condense a clinical response.
|
| 248 |
-
If helpful, use the user's latest question for context to craft specific topics.
|
| 249 |
-
User's latest question (context): {question}
|
| 250 |
-
------------------------
|
| 251 |
-
{response}
|
| 252 |
-
------------------------
|
| 253 |
-
Please perform the following tasks:
|
| 254 |
-
{chr(10).join(instructions)}
|
| 255 |
|
| 256 |
-
|
| 257 |
-
"""
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
|
|
|
|
|
|
| 266 |
)
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
return [
|
| 270 |
-
{"tag": self._quick_extract_topic(chunk), "text": chunk.strip()}
|
| 271 |
-
for chunk in output.split('---') if chunk.strip()
|
| 272 |
-
]
|
| 273 |
-
except Exception as e:
|
| 274 |
-
logger.warning(f"[Memory] ❌ Gemini chunking failed: {e}")
|
| 275 |
-
retries += 1
|
| 276 |
-
time.sleep(0.5)
|
| 277 |
-
return [{"tag": "general", "text": response.strip()}] # fallback
|
| 278 |
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
for line in lines:
|
| 288 |
-
if len(line.split()) <= 8 and line.strip().endswith(":"):
|
| 289 |
-
return line.strip().rstrip(":")
|
| 290 |
-
return " ".join(chunk.split()[:3]).rstrip(":.,")
|
| 291 |
-
|
| 292 |
-
# ---------- New merging/dedup logic ----------
|
| 293 |
-
def _upsert_stm(self, user_id: str, chunk: Dict, lang: str):
|
| 294 |
-
"""Insert or merge a summarized chunk into STM with semantic dedup/merge.
|
| 295 |
-
Identical: replace the older with new. Partially similar: merge extra details from older into newer.
|
| 296 |
-
"""
|
| 297 |
-
topic = self._enrich_topic(chunk.get("tag", ""), chunk.get("text", ""))
|
| 298 |
-
text = chunk.get("text", "").strip()
|
| 299 |
-
vec = self._embed(text)
|
| 300 |
-
now = time.time()
|
| 301 |
-
entry = {"topic": topic, "text": text, "vec": vec, "timestamp": now, "used": 0}
|
| 302 |
-
stm = self.stm_summaries[user_id]
|
| 303 |
-
if not stm:
|
| 304 |
-
stm.append(entry)
|
| 305 |
-
return
|
| 306 |
-
# find best match
|
| 307 |
-
best_idx = -1
|
| 308 |
-
best_sim = -1.0
|
| 309 |
-
for i, e in enumerate(stm):
|
| 310 |
-
sim = float(np.dot(vec, e["vec"]))
|
| 311 |
-
if sim > best_sim:
|
| 312 |
-
best_sim = sim
|
| 313 |
-
best_idx = i
|
| 314 |
-
if best_sim >= 0.92: # nearly identical
|
| 315 |
-
# replace older with current
|
| 316 |
-
stm.rotate(-best_idx)
|
| 317 |
-
stm.popleft()
|
| 318 |
-
stm.rotate(best_idx)
|
| 319 |
-
stm.append(entry)
|
| 320 |
-
elif best_sim >= 0.75: # partially similar → merge
|
| 321 |
-
base = stm[best_idx]
|
| 322 |
-
merged_text = self._merge_texts(new_text=text, old_text=base["text"]) # add bits from old not in new
|
| 323 |
-
merged_topic = base["topic"] if len(base["topic"]) > len(topic) else topic
|
| 324 |
-
merged_vec = self._embed(merged_text)
|
| 325 |
-
merged_entry = {"topic": merged_topic, "text": merged_text, "vec": merged_vec, "timestamp": now, "used": base.get("used", 0)}
|
| 326 |
-
stm.rotate(-best_idx)
|
| 327 |
-
stm.popleft()
|
| 328 |
-
stm.rotate(best_idx)
|
| 329 |
-
stm.append(merged_entry)
|
| 330 |
-
else:
|
| 331 |
-
stm.append(entry)
|
| 332 |
|
| 333 |
def _upsert_ltm(self, user_id: str, chunks: List[Dict], lang: str):
|
| 334 |
-
"""
|
| 335 |
-
Keeps only the most recent self.max_chunks entries.
|
| 336 |
-
"""
|
| 337 |
-
current_list = self.chunk_meta[user_id]
|
| 338 |
for chunk in chunks:
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
best_sim = -1.0
|
| 352 |
-
for i, e in enumerate(current_list):
|
| 353 |
-
sim = float(np.dot(vec, e["vec"]))
|
| 354 |
-
if sim > best_sim:
|
| 355 |
-
best_sim = sim
|
| 356 |
-
best_idx = i
|
| 357 |
-
if best_sim >= 0.92:
|
| 358 |
-
# replace older with new
|
| 359 |
-
current_list[best_idx] = new_entry
|
| 360 |
-
elif best_sim >= 0.75:
|
| 361 |
-
# merge details
|
| 362 |
-
base = current_list[best_idx]
|
| 363 |
-
merged_text = self._merge_texts(new_text=text, old_text=base["text"]) # add unique sentences from old
|
| 364 |
-
merged_topic = base["tag"] if len(base["tag"]) > len(topic) else topic
|
| 365 |
-
merged_vec = self._embed(merged_text)
|
| 366 |
-
current_list[best_idx] = {"tag": merged_topic, "text": merged_text, "vec": merged_vec, "timestamp": now, "used": base.get("used", 0)}
|
| 367 |
else:
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
|
|
|
|
|
|
|
|
|
| 379 |
|
| 380 |
-
def
|
| 381 |
-
"""
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
s_norm = s.lower()
|
| 388 |
-
# consider present if significant overlap with any existing sentence
|
| 389 |
-
if s_norm in new_set:
|
| 390 |
-
continue
|
| 391 |
-
# simple containment check
|
| 392 |
-
if any(self._overlap_ratio(s_norm, t.lower()) > 0.8 for t in merged):
|
| 393 |
-
continue
|
| 394 |
-
merged.append(s)
|
| 395 |
-
return " ".join(merged)
|
| 396 |
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
|
|
|
|
|
|
| 407 |
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
first = " ".join(words[:16])
|
| 422 |
-
# ensure capitalized
|
| 423 |
-
return first.strip().rstrip(':')
|
| 424 |
-
return topic
|
| 425 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 426 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+

+# memory_updated.py
 import re, time, hashlib, asyncio, os
 from collections import defaultdict, deque
 from typing import List, Dict
...
 from sentence_transformers import SentenceTransformer
 from google import genai  # must be configured in app.py and imported globally
 import logging
+from summarizer import summarizer
 
 _LLM_SMALL = "gemini-2.5-flash-lite-preview-06-17"
 # Load embedding model
...
     def get_contextual_chunks(self, user_id: str, current_query: str, lang: str = "EN") -> str:
         """
+        Use NVIDIA Llama to create a summarization of relevant context from both recent history and RAG chunks.
         This ensures conversational continuity while providing a concise summary for the main LLM.
         """
         # Get both types of context
...
         if not recent_history and not rag_chunks:
             logger.info(f"[Contextual] No context found, returning empty string")
             return ""
+
+        # Prepare context for summarization
         context_parts = []
         # Add recent chat history
         if recent_history:
...
             rag_text = "\n".join(rag_chunks)
             context_parts.append(f"Semantically relevant historical medical information:\n{rag_text}")
 
+        # Combine all context
+        full_context = "\n\n".join(context_parts)
 
+        # Use summarizer to create concise summary
+        try:
+            summary = summarizer.summarize_text(full_context, max_length=300)
+            logger.info(f"[Contextual] Generated summary using NVIDIA Llama: {len(summary)} characters")
+            return summary
+        except Exception as e:
+            logger.error(f"[Contextual] Summarization failed: {e}")
+            return full_context[:500] + "..." if len(full_context) > 500 else full_context
+
+    def chunk_response(self, response: str, lang: str, question: str = "") -> List[Dict]:
+        """
+        Use NVIDIA Llama to chunk and summarize response by medical topics.
+        Returns: [{"tag": ..., "text": ...}, ...]
         """
+        if not response:
+            return []
 
         try:
+            # Use summarizer to chunk and summarize
+            chunks = summarizer.chunk_response(response, max_chunk_size=500)
 
+            # Convert to the expected format
+            result_chunks = []
+            for i, chunk in enumerate(chunks):
+                # Extract topic from chunk (first sentence or key medical terms)
+                topic = self._extract_topic_from_chunk(chunk)
+
+                result_chunks.append({
+                    "tag": topic,
+                    "text": chunk
+                })
+
+            logger.info(f"[Memory] 📦 NVIDIA Llama summarized {len(result_chunks)} chunks")
+            return result_chunks
 
         except Exception as e:
+            logger.error(f"[Memory] NVIDIA Llama chunking failed: {e}")
+            # Fallback to simple chunking
+            return self._fallback_chunking(response)
+
+    def _extract_topic_from_chunk(self, chunk: str) -> str:
+        """Extract a concise topic from a chunk"""
+        # Look for medical terms or first sentence
+        sentences = chunk.split('.')
+        if sentences:
+            first_sentence = sentences[0].strip()
+            if len(first_sentence) > 50:
+                first_sentence = first_sentence[:50] + "..."
+            return first_sentence
+        return "Medical Information"
 
+    def _fallback_chunking(self, response: str) -> List[Dict]:
+        """Fallback chunking when NVIDIA Llama fails"""
+        # Simple sentence-based chunking
+        sentences = re.split(r'[.!?]+', response)
+        chunks = []
+        current_chunk = ""
+
+        for sentence in sentences:
+            sentence = sentence.strip()
+            if not sentence:
+                continue
+
+            if len(current_chunk) + len(sentence) > 300:
+                if current_chunk:
+                    chunks.append({
+                        "tag": "Medical Information",
+                        "text": current_chunk.strip()
+                    })
+                current_chunk = sentence
+            else:
+                current_chunk += sentence + ". "
+
+        if current_chunk:
+            chunks.append({
+                "tag": "Medical Information",
+                "text": current_chunk.strip()
+            })
+
+        return chunks
 
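For orientation, this is the shape of data `chunk_response` hands to the memory upserts further down. The example below is illustrative only and not part of the commit; the tag and text strings are invented.

```python
# Illustrative only: the structure chunk_response returns and
# _upsert_stm/_upsert_ltm consume. Content is made up.
example_chunks = [
    {"tag": "Migraine symptoms", "text": "Migraine typically causes unilateral throbbing pain with nausea and photophobia."},
    {"tag": "Medical Information", "text": "Mild attacks often respond to rest, hydration and over-the-counter NSAIDs."},
]
```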
+    # ---------- Private Methods ----------
     def _touch_user(self, user_id: str):
+        """Update LRU queue"""
         if user_id in self.user_queue:
             self.user_queue.remove(user_id)
         self.user_queue.append(user_id)
 
+    def _new_index(self):
+        """Create new FAISS index"""
+        return faiss.IndexFlatIP(384)  # 384-dim embeddings
 
+    def _upsert_stm(self, user_id: str, chunk: Dict, lang: str):
+        """Update short-term memory with merging/deduplication"""
+        topic = chunk["tag"]
+        text = chunk["text"]
+
+        # Check for similar topics in STM
+        for entry in self.stm_summaries[user_id]:
+            if self._topics_similar(topic, entry["topic"]):
+                # Merge with existing entry
+                entry["text"] = summarizer.summarize_text(
+                    f"{entry['text']}\n{text}",
+                    max_length=200
                 )
+                entry["timestamp"] = time.time()
+                return
 
+        # Add new entry
+        self.stm_summaries[user_id].append({
+            "topic": topic,
+            "text": text,
+            "vec": self._embed(f"{topic} {text}"),
+            "timestamp": time.time(),
+            "used": 0
+        })
 
     def _upsert_ltm(self, user_id: str, chunks: List[Dict], lang: str):
+        """Update long-term memory with merging/deduplication"""
         for chunk in chunks:
+            # Check for similar chunks in LTM
+            similar_idx = self._find_similar_chunk(user_id, chunk["text"])
+
+            if similar_idx is not None:
+                # Merge with existing chunk
+                existing = self.chunk_meta[user_id][similar_idx]
+                merged_text = summarizer.summarize_text(
+                    f"{existing['text']}\n{chunk['text']}",
+                    max_length=300
+                )
+                existing["text"] = merged_text
+                existing["timestamp"] = time.time()
             else:
+                # Add new chunk
+                if len(self.chunk_meta[user_id]) >= self.max_chunks:
+                    # Remove oldest chunk
+                    self._remove_oldest_chunk(user_id)
+
+                vec = self._embed(chunk["text"])
+                self.chunk_index[user_id].add(np.array([vec]))
+                self.chunk_meta[user_id].append({
+                    "text": chunk["text"],
+                    "tag": chunk["tag"],
+                    "vec": vec,
+                    "timestamp": time.time(),
+                    "used": 0
+                })
 
+    def _topics_similar(self, topic1: str, topic2: str) -> bool:
+        """Check if two topics are similar"""
+        # Simple similarity check based on common words
+        words1 = set(topic1.lower().split())
+        words2 = set(topic2.lower().split())
+        intersection = words1.intersection(words2)
+        return len(intersection) >= 2
 
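`_topics_similar` treats two topic strings as the same subject when they share at least two words. A quick worked example of that rule (not in the commit, sample topics invented):

```python
# Worked example of the >= 2 shared-word rule used by _topics_similar.
words1 = set("blood pressure medication".lower().split())
words2 = set("blood pressure monitoring".lower().split())
shared = words1.intersection(words2)
print(shared)            # {'blood', 'pressure'}
print(len(shared) >= 2)  # True -> the two STM entries would be merged
```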
+    def _find_similar_chunk(self, user_id: str, text: str) -> int:
+        """Find similar chunk in LTM"""
+        if not self.chunk_meta[user_id]:
+            return None
+
+        text_vec = self._embed(text)
+        sims, idxs = self.chunk_index[user_id].search(np.array([text_vec]), k=3)
+
+        for sim, idx in zip(sims[0], idxs[0]):
+            if sim > 0.8:  # High similarity threshold
+                return int(idx)
+        return None
 
+    def _remove_oldest_chunk(self, user_id: str):
+        """Remove the oldest chunk from LTM"""
+        if not self.chunk_meta[user_id]:
+            return
+
+        # Find oldest chunk
+        oldest_idx = min(range(len(self.chunk_meta[user_id])),
+                         key=lambda i: self.chunk_meta[user_id][i]["timestamp"])
+
+        # Remove from index and metadata
+        self.chunk_meta[user_id].pop(oldest_idx)
+        # Note: FAISS doesn't support direct removal, so we rebuild the index
+        self._rebuild_index(user_id)
 
+    def _rebuild_index(self, user_id: str):
+        """Rebuild FAISS index after removal"""
+        if not self.chunk_meta[user_id]:
+            self.chunk_index[user_id] = self._new_index()
+            return
+
+        vectors = [chunk["vec"] for chunk in self.chunk_meta[user_id]]
+        self.chunk_index[user_id] = self._new_index()
+        self.chunk_index[user_id].add(np.array(vectors))
 
+    @staticmethod
+    def _embed(text: str):
+        vec = EMBED.encode(text, convert_to_numpy=True)
+        # L2 normalise for cosine on IndexFlatIP
+        return vec / (np.linalg.norm(vec) + 1e-9)
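The `_embed` helper L2-normalises vectors so that the inner product computed by `faiss.IndexFlatIP` equals cosine similarity. A minimal, self-contained sketch of that idea (not part of the commit; the model name and sample sentences are assumptions for illustration, while the commit loads its own `EMBED` model elsewhere):

```python
# Minimal sketch: unit-length embeddings turn IndexFlatIP into a cosine index.
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")  # assumed 384-dim model

def embed(text: str) -> np.ndarray:
    vec = model.encode(text, convert_to_numpy=True).astype("float32")
    return vec / (np.linalg.norm(vec) + 1e-9)  # unit length -> inner product == cosine

index = faiss.IndexFlatIP(384)
index.add(np.array([embed("chest pain and shortness of breath")]))

sims, idxs = index.search(np.array([embed("shortness of breath with chest pain")]), 1)
print(float(sims[0][0]), int(idxs[0][0]))  # similarity close to 1.0 for near-duplicate text
```

This is the same reason `_find_similar_chunk` can compare the returned scores directly against a 0.8 threshold.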
models/__init__.py
ADDED
@@ -0,0 +1,3 @@
+# Models package
+from .llama import NVIDIALLamaClient, process_search_query
+from .summarizer import TextSummarizer, summarizer
download_model.py → models/download_model.py
RENAMED
File without changes
llama_integration.py → models/llama.py
RENAMED
@@ -2,6 +2,7 @@ import os
 import requests
 import json
 import logging
 from typing import List, Dict, Tuple
 
 logger = logging.getLogger(__name__)
@@ -40,27 +41,11 @@ Keywords:"""
     def summarize_documents(self, documents: List[Dict], user_query: str) -> Tuple[str, Dict[int, str]]:
         """Use Llama to summarize documents and return summary with URL mapping"""
         try:
...
-            url_mapping = {}
-            for
...
-                url_mapping[doc_id] = doc['url']
-
-                # Create a summary prompt for each document
-                summary_prompt = f"""Summarize this medical information in 2-3 sentences, focusing on details relevant to: "{user_query}"
-
-Document: {doc['title']}
-Content: {doc['content'][:1000]}...
-
-Summary:"""
-
-                summary = self._call_llama(summary_prompt)
-                doc_summaries.append(f"Document {doc_id}: {summary}")
-
-            # Combine all summaries
-            combined_summary = "\n\n".join(doc_summaries)
 
             return combined_summary, url_mapping
 
@@ -68,41 +53,58 @@ Summary:"""
             logger.error(f"Failed to summarize documents: {e}")
             return "", {}
 
-    def _call_llama(self, prompt: str) -> str:
-        """Make API call to NVIDIA Llama model"""
...
 
 def process_search_query(user_query: str, search_results: List[Dict]) -> Tuple[str, Dict[int, str]]:
     """Process search results using Llama model"""

 import requests
 import json
 import logging
+import time
 from typing import List, Dict, Tuple
 
 logger = logging.getLogger(__name__)
...
     def summarize_documents(self, documents: List[Dict], user_query: str) -> Tuple[str, Dict[int, str]]:
         """Use Llama to summarize documents and return summary with URL mapping"""
         try:
+            # Import summarizer here to avoid circular imports
+            from summarizer import summarizer
 
+            # Use the summarizer for document summarization
+            combined_summary, url_mapping = summarizer.summarize_documents(documents, user_query)
 
             return combined_summary, url_mapping
 
...
             logger.error(f"Failed to summarize documents: {e}")
             return "", {}
 
+    def _call_llama(self, prompt: str, max_retries: int = 3) -> str:
+        """Make API call to NVIDIA Llama model with retry logic"""
+        for attempt in range(max_retries):
+            try:
+                headers = {
+                    "Authorization": f"Bearer {self.api_key}",
+                    "Content-Type": "application/json"
+                }
+
+                payload = {
+                    "model": self.model,
+                    "messages": [
+                        {
+                            "role": "user",
+                            "content": prompt
+                        }
+                    ],
+                    "temperature": 0.7,
+                    "max_tokens": 1000
+                }
+
+                response = requests.post(
+                    self.base_url,
+                    headers=headers,
+                    json=payload,
+                    timeout=30
+                )
+
+                response.raise_for_status()
+                result = response.json()
+
+                content = result['choices'][0]['message']['content'].strip()
+                if not content:
+                    raise ValueError("Empty response from Llama API")
+
+                return content
+
+            except requests.exceptions.Timeout:
+                logger.warning(f"Llama API timeout (attempt {attempt + 1}/{max_retries})")
+                if attempt == max_retries - 1:
+                    raise
+                time.sleep(2 ** attempt)  # Exponential backoff
+
+            except requests.exceptions.RequestException as e:
+                logger.warning(f"Llama API request failed (attempt {attempt + 1}/{max_retries}): {e}")
+                if attempt == max_retries - 1:
+                    raise
+                time.sleep(2 ** attempt)
+
+            except Exception as e:
+                logger.error(f"Llama API call failed: {e}")
+                raise
 
 def process_search_query(user_query: str, search_results: List[Dict]) -> Tuple[str, Dict[int, str]]:
     """Process search results using Llama model"""
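The new `_call_llama` wraps the HTTP call in up to `max_retries` attempts with exponential backoff (1 s, 2 s, 4 s) and re-raises on the final failure. A minimal sketch of that pattern in isolation (not part of the commit; `with_retries` is a hypothetical helper shown only to make the control flow explicit):

```python
# Sketch of the retry-with-exponential-backoff pattern used by _call_llama.
import time

def with_retries(fn, max_retries: int = 3):
    """Call fn, sleeping 1s, 2s, 4s... between failures; re-raise on the last attempt."""
    for attempt in range(max_retries):
        try:
            return fn()
        except Exception:
            if attempt == max_retries - 1:
                raise
            time.sleep(2 ** attempt)

# Hypothetical usage: wrap any flaky call, e.g. an HTTP request.
# result = with_retries(lambda: requests.post(url, json=payload, timeout=30))
```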
models/summarizer.py
ADDED
@@ -0,0 +1,185 @@
+import re
+import logging
+from typing import List, Dict, Tuple
+from llama import NVIDIALLamaClient
+
+logger = logging.getLogger(__name__)
+
+class TextSummarizer:
+    def __init__(self):
+        self.llama_client = NVIDIALLamaClient()
+
+    def clean_text(self, text: str) -> str:
+        """Clean and normalize text for summarization"""
+        if not text:
+            return ""
+
+        # Remove common conversation starters and fillers
+        conversation_patterns = [
+            r'\b(hi|hello|hey|sure|okay|yes|no|thanks|thank you)\b',
+            r'\b(here is|this is|let me|i will|i can|i would)\b',
+            r'\b(summarize|summary|here\'s|here is)\b',
+            r'\b(please|kindly|would you|could you)\b',
+            r'\b(um|uh|er|ah|well|so|like|you know)\b'
+        ]
+
+        # Remove excessive whitespace and normalize
+        text = re.sub(r'\s+', ' ', text)
+        text = re.sub(r'\n+', ' ', text)
+
+        # Remove conversation patterns
+        for pattern in conversation_patterns:
+            text = re.sub(pattern, '', text, flags=re.IGNORECASE)
+
+        # Remove extra punctuation and normalize
+        text = re.sub(r'[.]{2,}', '.', text)
+        text = re.sub(r'[!]{2,}', '!', text)
+        text = re.sub(r'[?]{2,}', '?', text)
+
+        return text.strip()
+
+    def extract_key_phrases(self, text: str) -> List[str]:
+        """Extract key medical phrases and terms"""
+        if not text:
+            return []
+
+        # Medical term patterns
+        medical_patterns = [
+            r'\b(?:symptoms?|diagnosis|treatment|therapy|medication|drug|disease|condition|syndrome)\b',
+            r'\b(?:patient|doctor|physician|medical|clinical|healthcare)\b',
+            r'\b(?:blood pressure|heart rate|temperature|pulse|respiration)\b',
+            r'\b(?:acute|chronic|severe|mild|moderate|serious|critical)\b',
+            r'\b(?:pain|ache|discomfort|swelling|inflammation|infection)\b'
+        ]
+
+        key_phrases = []
+        for pattern in medical_patterns:
+            matches = re.findall(pattern, text, re.IGNORECASE)
+            key_phrases.extend(matches)
+
+        return list(set(key_phrases))  # Remove duplicates
+
+    def summarize_text(self, text: str, max_length: int = 200) -> str:
+        """Summarize text using NVIDIA Llama model"""
+        try:
+            if not text or len(text.strip()) < 50:
+                return text
+
+            # Clean the text first
+            cleaned_text = self.clean_text(text)
+
+            # Extract key phrases for context
+            key_phrases = self.extract_key_phrases(cleaned_text)
+            key_phrases_str = ", ".join(key_phrases[:5]) if key_phrases else "medical information"
+
+            # Create optimized prompt
+            prompt = f"""Summarize this medical text in {max_length} characters or less. Focus only on key medical facts, symptoms, treatments, and diagnoses. Do not include greetings, confirmations, or conversational elements.
+
+Key terms: {key_phrases_str}
+
+Text: {cleaned_text[:1500]}
+
+Summary:"""
+
+            summary = self.llama_client._call_llama(prompt)
+
+            # Post-process summary
+            summary = self.clean_text(summary)
+
+            # Ensure it's within length limit
+            if len(summary) > max_length:
+                summary = summary[:max_length-3] + "..."
+
+            return summary
+
+        except Exception as e:
+            logger.error(f"Summarization failed: {e}")
+            # Fallback to simple truncation
+            return self.clean_text(text)[:max_length]
+
+    def summarize_documents(self, documents: List[Dict], user_query: str) -> Tuple[str, Dict[int, str]]:
+        """Summarize multiple documents with URL mapping"""
+        try:
+            doc_summaries = []
+            url_mapping = {}
+
+            for doc in documents:
+                doc_id = doc['id']
+                url_mapping[doc_id] = doc['url']
+
+                # Create focused summary for each document
+                summary_prompt = f"""Summarize this medical document in 2-3 sentences, focusing on information relevant to: "{user_query}"
+
+Document: {doc['title']}
+Content: {doc['content'][:800]}
+
+Key medical information:"""
+
+                summary = self.llama_client._call_llama(summary_prompt)
+                summary = self.clean_text(summary)
+
+                doc_summaries.append(f"Document {doc_id}: {summary}")
+
+            combined_summary = "\n\n".join(doc_summaries)
+            return combined_summary, url_mapping
+
+        except Exception as e:
+            logger.error(f"Document summarization failed: {e}")
+            return "", {}
+
+    def summarize_conversation_chunk(self, chunk: str) -> str:
+        """Summarize a conversation chunk for memory"""
+        try:
+            if not chunk or len(chunk.strip()) < 30:
+                return chunk
+
+            cleaned_chunk = self.clean_text(chunk)
+
+            prompt = f"""Summarize this medical conversation in 1-2 sentences. Focus only on medical facts, symptoms, treatments, or diagnoses discussed. Remove greetings and conversational elements.
+
+Conversation: {cleaned_chunk[:1000]}
+
+Medical summary:"""
+
+            summary = self.llama_client._call_llama(prompt)
+            return self.clean_text(summary)
+
+        except Exception as e:
+            logger.error(f"Conversation summarization failed: {e}")
+            return self.clean_text(chunk)[:150]
+
+    def chunk_response(self, response: str, max_chunk_size: int = 500) -> List[str]:
+        """Split response into chunks and summarize each"""
+        try:
+            if not response or len(response) <= max_chunk_size:
+                return [response]
+
+            # Split by sentences first
+            sentences = re.split(r'[.!?]+', response)
+            chunks = []
+            current_chunk = ""
+
+            for sentence in sentences:
+                sentence = sentence.strip()
+                if not sentence:
+                    continue
+
+                # Check if adding this sentence would exceed limit
+                if len(current_chunk) + len(sentence) > max_chunk_size and current_chunk:
+                    chunks.append(self.summarize_conversation_chunk(current_chunk))
+                    current_chunk = sentence
+                else:
+                    current_chunk += sentence + ". "
+
+            # Add the last chunk
+            if current_chunk:
+                chunks.append(self.summarize_conversation_chunk(current_chunk))
+
+            return chunks
+
+        except Exception as e:
+            logger.error(f"Response chunking failed: {e}")
+            return [response]
+
+# Global summarizer instance
+summarizer = TextSummarizer()
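A hypothetical usage sketch of the shared `summarizer` instance defined above (not part of the commit): the sample document, query, and import path are assumptions, and real calls need the NVIDIA API credentials the Llama client expects.

```python
# Hypothetical usage of TextSummarizer's public methods, as added in this commit.
from models.summarizer import summarizer  # assumes the new models package is importable

docs = [{
    "id": 1,
    "url": "https://example.org/hypertension",        # made-up document
    "title": "Hypertension overview",
    "content": "Hypertension is persistently elevated arterial blood pressure ...",
}]

combined, url_map = summarizer.summarize_documents(docs, user_query="how is high blood pressure treated?")
short = summarizer.summarize_text(combined, max_length=200)      # one concise summary
chunks = summarizer.chunk_response(short, max_chunk_size=500)    # list of summarized chunks
```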
warmup.py → models/warmup.py
RENAMED
File without changes
search/__init__.py
ADDED
@@ -0,0 +1,2 @@
+# Search package
+from .search import WebSearcher, search_web
search.py → search/search.py
RENAMED
@@ -96,39 +96,54 @@ class WebSearcher:
             logger.warning(f"Failed to extract content from {url}: {e}")
             return ""
 
-    def search_and_extract(self, query: str, num_results: int =
         """Search for query and extract content from top results"""
         logger.info(f"Searching for: {query}")
 
-        # Get search results
-        search_results = self.search_duckduckgo(query, num_results)
 
-        # Extract content from each result
         enriched_results = []
         for i, result in enumerate(search_results):
             try:
                 logger.info(f"Extracting content from {result['url']}")
                 content = self.extract_content(result['url'])
 
-                if content:
                     enriched_results.append({
-                        'id':
                         'url': result['url'],
                         'title': result['title'],
                         'content': content
                     })
 
                 # Add delay to be respectful
-                time.sleep(
 
             except Exception as e:
                 logger.warning(f"Failed to process {result['url']}: {e}")
                 continue
 
-        logger.info(f"Successfully processed {len(enriched_results)} results")
         return enriched_results
 
-def search_web(query: str, num_results: int =
     """Main function to search the web and return enriched results"""
     searcher = WebSearcher()
     return searcher.search_and_extract(query, num_results)

             logger.warning(f"Failed to extract content from {url}: {e}")
             return ""
 
+    def search_and_extract(self, query: str, num_results: int = 10) -> List[Dict]:
         """Search for query and extract content from top results"""
         logger.info(f"Searching for: {query}")
 
+        # Get search results (fetch more than needed for filtering)
+        search_results = self.search_duckduckgo(query, min(num_results * 2, 20))
 
+        # Extract content from each result with parallel processing
         enriched_results = []
+        failed_count = 0
+        max_failures = 5  # Stop after 5 consecutive failures
+
         for i, result in enumerate(search_results):
+            if len(enriched_results) >= num_results:
+                break
+
+            if failed_count >= max_failures:
+                logger.warning(f"Too many failures ({failed_count}), stopping extraction")
+                break
+
             try:
                 logger.info(f"Extracting content from {result['url']}")
                 content = self.extract_content(result['url'])
 
+                if content and len(content.strip()) > 50:  # Only include substantial content
                     enriched_results.append({
+                        'id': len(enriched_results) + 1,  # Sequential ID
                         'url': result['url'],
                         'title': result['title'],
                         'content': content
                     })
+                    failed_count = 0  # Reset failure counter
+                else:
+                    failed_count += 1
+                    logger.warning(f"Insufficient content from {result['url']}")
 
                 # Add delay to be respectful
+                time.sleep(0.5)  # Reduced delay for better performance
 
             except Exception as e:
+                failed_count += 1
                 logger.warning(f"Failed to process {result['url']}: {e}")
                 continue
 
+        logger.info(f"Successfully processed {len(enriched_results)} results out of {len(search_results)} attempted")
        return enriched_results
 
+def search_web(query: str, num_results: int = 10) -> List[Dict]:
     """Main function to search the web and return enriched results"""
     searcher = WebSearcher()
     return searcher.search_and_extract(query, num_results)
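A hypothetical usage sketch of the refactored search service (not part of the commit): the query string is invented and the import path assumes the new `search` package is importable; real calls hit DuckDuckGo and the result pages over the network.

```python
# Hypothetical call into the search service added in this commit.
from search.search import search_web  # also re-exported by search/__init__.py

results = search_web("first-line treatment for type 2 diabetes", num_results=5)
for doc in results:
    # Each enriched result carries a sequential id plus url/title/content,
    # matching the dict built in search_and_extract above.
    print(doc["id"], doc["title"], doc["url"])
```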
utils/__init__.py
ADDED
@@ -0,0 +1,4 @@
+# Utils package
+from .translation import translate_query
+from .vlm import process_medical_image
+from .diagnosis import retrieve_diagnosis_from_symptoms
clear_mongo.py → utils/clear_mongo.py
RENAMED
File without changes
connect_mongo.py → utils/connect_mongo.py
RENAMED
File without changes
diagnosis.py → utils/diagnosis.py
RENAMED
File without changes
migrate.py → utils/migrate.py
RENAMED
File without changes
translation.py → utils/translation.py
RENAMED
File without changes
vlm.py → utils/vlm.py
RENAMED
File without changes