Commit 2a31cee: Enhance frontend with search icon
Files changed:
- .gitattributes +35 -0
- .gitignore +2 -0
- .huggingface.yml +4 -0
- Dockerfile +36 -0
- README.md +13 -0
- app.py +312 -0
- chat-history.md +382 -0
- clear_mongo.py +48 -0
- connect_mongo.py +24 -0
- diagnosis.py +76 -0
- download_model.py +51 -0
- memory.py +426 -0
- migrate.py +48 -0
- requirements.txt +23 -0
- translation.py +26 -0
- vlm.py +54 -0
- warmup.py +8 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,2 @@
.env
secrets.toml
.huggingface.yml
ADDED
@@ -0,0 +1,4 @@
sdk: docker
app_file: app.py
port: 7860
hardware: cpu-basic
Dockerfile
ADDED
@@ -0,0 +1,36 @@
FROM python:3.11

# Create and use a non-root user (optional)
RUN useradd -m -u 1000 user
USER user
ENV PATH="/home/user/.local/bin:$PATH"

# Set working directory
WORKDIR /app

# Copy all project files to the container
COPY . .

# Install dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Set Hugging Face cache directory to persist model downloads
ENV HF_HOME="/home/user/.cache/huggingface"
ENV SENTENCE_TRANSFORMERS_HOME="/home/user/.cache/huggingface/sentence-transformers"
ENV MEDGEMMA_HOME="/home/user/.cache/huggingface/sentence-transformers"

# Create cache directories and ensure permissions
RUN mkdir -p /app/model_cache /home/user/.cache/huggingface/sentence-transformers && \
    chown -R user:user /app/model_cache /home/user/.cache/huggingface

# Pre-load model in a separate script
RUN python /app/download_model.py && python /app/warmup.py

# Ensure ownership and permissions remain intact
RUN chown -R user:user /app/model_cache

# Expose port
EXPOSE 7860

# Run the application
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
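The build pre-loads models with download_model.py and warmup.py before serving. warmup.py is part of this commit but not shown in this view; a minimal sketch of what such a warm-up step could look like, assuming it simply loads the flattened model from /app/model_cache and encodes a dummy sentence so the first request avoids cold-start latency:

# warmup sketch (hypothetical; the real warmup.py in this commit is not shown here)
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("/app/model_cache", device="cpu")
_ = model.encode(["warmup"], convert_to_numpy=True)  # populate tokenizer/model caches before serving
print("Model warm-up complete.")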
README.md
ADDED
@@ -0,0 +1,13 @@
---
title: Medical Chatbot
emoji: 🤖🩺
colorFrom: blue
colorTo: purple
sdk: docker
sdk_version: latest
pinned: false
license: apache-2.0
short_description: MedicalChatbot, FAISS, Gemini, MongoDB vDB, LRU
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,312 @@
# app.py
import os
import faiss
import numpy as np
import time
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from pymongo import MongoClient
from google import genai
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
from memory import MemoryManager
from translation import translate_query
from vlm import process_medical_image

# ✅ Enable Logging for Debugging
import logging
# —————— Silence Noisy Loggers ——————
for name in [
    "uvicorn.error", "uvicorn.access",
    "fastapi", "starlette",
    "pymongo", "gridfs",
    "sentence_transformers", "faiss",
    "google", "google.auth",
]:
    logging.getLogger(name).setLevel(logging.WARNING)
logging.basicConfig(level=logging.INFO, format="%(asctime)s — %(name)s — %(levelname)s — %(message)s", force=True)  # Change INFO to DEBUG for full-ctx JSON loader
logger = logging.getLogger("medical-chatbot")
logger.setLevel(logging.DEBUG)

# Debug Start
logger.info("🚀 Starting Medical Chatbot API...")

# ✅ Environment Variables
mongo_uri = os.getenv("MONGO_URI")
index_uri = os.getenv("INDEX_URI")
gemini_flash_api_key = os.getenv("FlashAPI")
# Validate environment endpoint
if not all([gemini_flash_api_key, mongo_uri, index_uri]):
    raise ValueError("❌ Missing API keys! Set them in Hugging Face Secrets.")
# logger.info(f"🔎 MongoDB URI: {mongo_uri}")
# logger.info(f"🔎 FAISS Index URI: {index_uri}")

# ✅ Monitor Resources Before Startup
import psutil
def check_system_resources():
    memory = psutil.virtual_memory()
    cpu = psutil.cpu_percent(interval=1)
    disk = psutil.disk_usage("/")
    # Log current resource usage
    logger.info(f"[System] 🔍 System Resources - RAM: {memory.percent}%, CPU: {cpu}%, Disk: {disk.percent}%")
    if memory.percent > 85:
        logger.warning("⚠️ High RAM usage detected!")
    if cpu > 90:
        logger.warning("⚠️ High CPU usage detected!")
    if disk.percent > 90:
        logger.warning("⚠️ High Disk usage detected!")
check_system_resources()

# ✅ Reduce Memory usage with optimizers
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# ✅ Initialize FastAPI app
app = FastAPI(title="Medical Chatbot API")
memory = MemoryManager()

from fastapi.middleware.cors import CORSMiddleware  # Bypassing CORS origin
# Define the origins
origins = [
    "http://localhost:5173",  # Vite dev server
    "http://localhost:3000",  # Another vercel local dev
    "https://medical-chatbot-henna.vercel.app",  # ✅ Vercel frontend production URL

]
# Add the CORS middleware:
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,  # or ["*"] to allow all
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ✅ Use Lazy Loading for FAISS Index
index = None  # Delay FAISS Index loading until first query

# ✅ Load SentenceTransformer Model (Quantized/Halved)
logger.info("[Embedder] 📥 Loading SentenceTransformer Model...")
MODEL_CACHE_DIR = "/app/model_cache"
try:
    embedding_model = SentenceTransformer(MODEL_CACHE_DIR, device="cpu")
    embedding_model = embedding_model.half()  # Reduce memory
    logger.info("✅ Model Loaded Successfully.")
except Exception as e:
    logger.error(f"❌ Model Loading Failed: {e}")
    exit(1)

# Cache in-memory vectors (optional — useful for <10k rows)
SYMPTOM_VECTORS = None
SYMPTOM_DOCS = None

# ✅ Setup MongoDB Connection
# QA data
client = MongoClient(mongo_uri)
db = client["MedicalChatbotDB"]
qa_collection = db["qa_data"]
# FAISS Index data
iclient = MongoClient(index_uri)
idb = iclient["MedicalChatbotDB"]
index_collection = idb["faiss_index_files"]
# Symptom Diagnosis data
symptom_client = MongoClient(mongo_uri)
symptom_col = symptom_client["MedicalChatbotDB"]["symptom_diagnosis"]

# ✅ Load FAISS Index (Lazy Load)
import gridfs
fs = gridfs.GridFS(idb, collection="faiss_index_files")

def load_faiss_index():
    global index
    if index is None:
        logger.info("[KB] ⏳ Loading FAISS index from GridFS...")
        existing_file = fs.find_one({"filename": "faiss_index.bin"})
        if existing_file:
            stored_index_bytes = existing_file.read()
            index_bytes_np = np.frombuffer(stored_index_bytes, dtype='uint8')
            index = faiss.deserialize_index(index_bytes_np)
            logger.info("[KB] ✅ FAISS Index Loaded")
        else:
            logger.error("[KB] ❌ FAISS index not found in GridFS.")
    return index

# ✅ Retrieve Medical Info (256,916 scenario)
def retrieve_medical_info(query, k=5, min_sim=0.9):  # Minimum similarity between query and KB is 90%
    global index
    index = load_faiss_index()
    if index is None:
        return [""]
    # Embed query
    query_vec = embedding_model.encode([query], convert_to_numpy=True)
    D, I = index.search(query_vec, k=k)
    # Filter by cosine threshold
    results = []
    kept = []
    kept_vecs = []
    # Smart dedup on cosine threshold between similar candidates
    for score, idx in zip(D[0], I[0]):
        if score < min_sim:
            continue
        # List sim docs
        doc = qa_collection.find_one({"i": int(idx)})
        if not doc:
            continue
        # Only compare answers
        answer = doc.get("Doctor", "").strip()
        if not answer:
            continue
        # Check semantic redundancy among previously kept results
        new_vec = embedding_model.encode([answer], convert_to_numpy=True)[0]
        is_similar = False
        for i, vec in enumerate(kept_vecs):
            sim = np.dot(vec, new_vec) / (np.linalg.norm(vec) * np.linalg.norm(new_vec) + 1e-9)
            if sim >= 0.9:  # High semantic similarity
                is_similar = True
                # Keep only better match to original query
                cur_sim_to_query = np.dot(vec, query_vec[0]) / (np.linalg.norm(vec) * np.linalg.norm(query_vec[0]) + 1e-9)
                new_sim_to_query = np.dot(new_vec, query_vec[0]) / (np.linalg.norm(new_vec) * np.linalg.norm(query_vec[0]) + 1e-9)
                if new_sim_to_query > cur_sim_to_query:
                    kept[i] = answer
                    kept_vecs[i] = new_vec
                break
        # Non-similar candidates
        if not is_similar:
            kept.append(answer)
            kept_vecs.append(new_vec)
    # Final
    return kept if kept else [""]


# ✅ Retrieve Sym-Dia Info (4,962 scenario)
def retrieve_diagnosis_from_symptoms(symptom_text, top_k=5, min_sim=0.5):
    global SYMPTOM_VECTORS, SYMPTOM_DOCS
    # Lazy load
    if SYMPTOM_VECTORS is None:
        all_docs = list(symptom_col.find({}, {"embedding": 1, "answer": 1, "question": 1, "prognosis": 1}))
        SYMPTOM_DOCS = all_docs
        SYMPTOM_VECTORS = np.array([doc["embedding"] for doc in all_docs], dtype=np.float32)
    # Embed input
    qvec = embedding_model.encode(symptom_text, convert_to_numpy=True)
    qvec = qvec / (np.linalg.norm(qvec) + 1e-9)
    # Similarity compute
    sims = SYMPTOM_VECTORS @ qvec  # cosine
    sorted_idx = np.argsort(sims)[-top_k:][::-1]
    seen_diag = set()
    final = []  # Dedup
    for i in sorted_idx:
        sim = sims[i]
        if sim < min_sim:
            continue
        label = SYMPTOM_DOCS[i]["prognosis"]
        if label not in seen_diag:
            final.append(SYMPTOM_DOCS[i]["answer"])
            seen_diag.add(label)
    return final

# ✅ Gemini Flash API Call
def gemini_flash_completion(prompt, model, temperature=0.7):
    client_genai = genai.Client(api_key=gemini_flash_api_key)
    try:
        response = client_genai.models.generate_content(model=model, contents=prompt)
        return response.text
    except Exception as e:
        logger.error(f"[LLM] ❌ Error calling Gemini API: {e}")
        return "Error generating response from Gemini."

# ✅ Chatbot Class
class RAGMedicalChatbot:
    def __init__(self, model_name, retrieve_function):
        self.model_name = model_name
        self.retrieve = retrieve_function

    def chat(self, user_id: str, user_query: str, lang: str = "EN", image_diagnosis: str = "") -> str:
        # 0. Translate query if not EN; this helps our RAG system
        if lang.upper() in {"VI", "ZH"}:
            user_query = translate_query(user_query, lang.lower())

        # 1. Fetch knowledge
        ## a. KB for generic QA retrieval
        retrieved_info = self.retrieve(user_query)
        knowledge_base = "\n".join(retrieved_info)
        ## b. Diagnosis RAG from symptom query
        diagnosis_guides = retrieve_diagnosis_from_symptoms(user_query)  # smart matcher

        # 2. Hybrid Context Retrieval: RAG + Recent History + Intelligent Selection
        contextual_chunks = memory.get_contextual_chunks(user_id, user_query, lang)

        # 3. Build prompt parts
        parts = ["You are a medical chatbot, designed to answer medical questions."]
        parts.append("Please format your answer using MarkDown.")
        parts.append("**Bold for titles**, *italic for emphasis*, and clear headings.")

        # 4. Append image diagnosis from VLM
        if image_diagnosis:
            parts.append(
                "A user medical image is diagnosed by our VLM agent:\n"
                f"{image_diagnosis}\n\n"
                "Please incorporate the above findings in your response if medically relevant.\n\n"
            )

        # Append contextual chunks from hybrid approach
        if contextual_chunks:
            parts.append("Relevant context from conversation history:\n" + contextual_chunks)
        # Load up guideline (RAG over medical knowledge base)
        if knowledge_base:
            parts.append(f"Example Q&A medical scenario knowledge-base: {knowledge_base}")
        # Symptom-Diagnosis prediction RAG
        if diagnosis_guides:
            parts.append("Symptom-based diagnosis guidance (if applicable):\n" + "\n".join(diagnosis_guides))
        parts.append(f"User's question: {user_query}")
        parts.append(f"Language to generate answer: {lang}")
        prompt = "\n\n".join(parts)
        logger.info(f"[LLM] Question query in `prompt`: {prompt}")  # Debug out checking RAG on kb and history
        response = gemini_flash_completion(prompt, model=self.model_name, temperature=0.7)
        # Store exchange + chunking
        if user_id:
            memory.add_exchange(user_id, user_query, response, lang=lang)
        logger.info(f"[LLM] Response on `prompt`: {response.strip()}")  # Debug out base response
        return response.strip()

# ✅ Initialize Chatbot
chatbot = RAGMedicalChatbot(model_name="gemini-2.5-flash", retrieve_function=retrieve_medical_info)

# ✅ Chat Endpoint
@app.post("/chat")
async def chat_endpoint(req: Request):
    body = await req.json()
    user_id = body.get("user_id", "anonymous")
    query_raw = body.get("query")
    query = query_raw.strip() if isinstance(query_raw, str) else ""
    lang = body.get("lang", "EN")
    image_base64 = body.get("image_base64", None)
    img_desc = body.get("img_desc", "Describe and investigate any clinical findings from this medical image.")
    start = time.time()
    image_diagnosis = ""
    # LLM Only
    if not image_base64:
        logger.info("[BOT] LLM scenario.")
    # LLM+VLM
    else:
        # If image is present → diagnose first
        safe_load = len(image_base64.encode("utf-8"))
        if safe_load > 5_000_000:  # Img size safe processor
            return JSONResponse({"response": "⚠️ Image too large. Please upload smaller images (<5MB)."})
        logger.info("[BOT] VLM+LLM scenario.")
        logger.info(f"[VLM] Process medical image size: {safe_load}, desc: {img_desc}, {lang}.")
        image_diagnosis = process_medical_image(image_base64, img_desc, lang)
    answer = chatbot.chat(user_id, query, lang, image_diagnosis)
    elapsed = time.time() - start
    # Final
    return JSONResponse({"response": f"{answer}\n\n(Response time: {elapsed:.2f}s)"})


# ✅ Run Uvicorn
if __name__ == "__main__":
    logger.info("[System] ✅ Starting FastAPI Server...")
    try:
        uvicorn.run(app, host="0.0.0.0", port=7860, log_level="debug")
    except Exception as e:
        logger.error(f"❌ Server Startup Failed: {e}")
        exit(1)
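The /chat endpoint above reads user_id, query, lang, and optionally image_base64 / img_desc from the JSON body and returns a JSON object with a single response field. A minimal client sketch (the base URL is a placeholder for the deployed Space):

# Minimal client sketch for the /chat endpoint; BASE_URL is a placeholder.
import requests

BASE_URL = "http://localhost:7860"  # or the deployed Hugging Face Space URL

payload = {
    "user_id": "demo-user",
    "query": "I have a persistent cough and mild fever. What could it be?",
    "lang": "EN",
    # "image_base64": "<base64-encoded image>",  # optional; triggers the VLM path
}
resp = requests.post(f"{BASE_URL}/chat", json=payload, timeout=120)
print(resp.json()["response"])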
chat-history.md
ADDED
@@ -0,0 +1,382 @@
# 🔄 Enhanced Memory System: STM + LTM + Hybrid Context Retrieval

## Overview

The Medical Chatbot now implements an **advanced memory system** with **Short-Term Memory (STM)** and **Long-Term Memory (LTM)** that intelligently manages conversation context, semantic knowledge, and conversational continuity. This system goes beyond simple RAG to provide truly intelligent, contextually aware responses that remember and build upon previous interactions.

## 🏗️ Architecture

### Memory Hierarchy
```
User Query → Enhanced Memory System → Intelligent Context Selection → LLM Response
                              ↓
    ┌───────────────────┬───────────────────┬───────────────────┐
    │    STM (5 items)  │   LTM (60 items)  │     RAG Search    │
    │ (Recent Summaries)│  (Semantic Store) │  (Knowledge Base) │
    └───────────────────┴───────────────────┴───────────────────┘
                              ↓
              Gemini Flash Lite Contextual Analysis
                              ↓
             Summarized Context + Semantic Knowledge
```

### Memory Types

#### 1. **Short-Term Memory (STM)**
- **Capacity:** 5 recent conversation summaries
- **Content:** Chunked and summarized LLM responses with enriched topics
- **Features:** Semantic deduplication, intelligent merging, topic enrichment
- **Purpose:** Maintain conversational continuity and immediate context

#### 2. **Long-Term Memory (LTM)**
- **Capacity:** 60 semantic chunks (~20 conversational rounds)
- **Content:** FAISS-indexed medical knowledge chunks
- **Features:** Semantic similarity search, usage tracking, smart eviction
- **Purpose:** Provide deep medical knowledge and historical context

#### 3. **RAG Knowledge Base**
- **Content:** External medical knowledge and guidelines
- **Features:** Real-time retrieval, semantic matching
- **Purpose:** Supplement with current medical information

## 🔧 Key Components

### 1. Enhanced Memory Manager (`memory.py`)

#### STM Management
```python
def get_recent_chat_history(self, user_id: str, num_turns: int = 5) -> List[Dict]:
    """
    Get the most recent STM summaries (not raw Q/A).
    Returns: [{"user": "", "bot": "Topic: ...\n<summary>", "timestamp": time}, ...]
    """
```

**STM Features:**
- **Capacity:** 5 recent conversation summaries
- **Content:** Chunked and summarized LLM responses with enriched topics
- **Deduplication:** Semantic similarity-based merging (≥0.92 identical, ≥0.75 merge)
- **Topic Enrichment:** Uses user question context to generate detailed topics

#### LTM Management
```python
def get_relevant_chunks(self, user_id: str, query: str, top_k: int = 3, min_sim: float = 0.30) -> List[str]:
    """Return texts of chunks whose cosine similarity ≥ min_sim."""
```

**LTM Features:**
- **Capacity:** 60 semantic chunks (~20 conversational rounds)
- **Indexing:** FAISS-based semantic search
- **Smart Eviction:** Usage-based decay and recency scoring
- **Merging:** Intelligent deduplication and content fusion

#### Enhanced Chunking
```python
def chunk_response(self, response: str, lang: str, question: str = "") -> List[Dict]:
    """
    Enhanced chunking with question context for richer topics.
    Returns: [{"tag": "detailed_topic", "text": "summary"}, ...]
    """
```

**Chunking Features:**
- **Question Context:** Incorporates user's latest question for topic generation
- **Rich Topics:** Detailed topics (10-20 words) capturing context, condition, and action
- **Medical Focus:** Excludes disclaimers, includes exact medication names/doses
- **Semantic Grouping:** Groups by medical topic, symptom, assessment, plan, or instruction

### 2. Intelligent Context Retrieval

#### Contextual Summarization
```python
def get_contextual_chunks(self, user_id: str, current_query: str, lang: str = "EN") -> str:
    """
    Creates a single, coherent summary from STM + LTM + RAG.
    Returns: A single summary string for the main LLM.
    """
```

**Features:**
- **Unified Summary:** Combines STM (5 turns) + LTM (semantic) + RAG (knowledge)
- **Gemini Analysis:** Uses Gemini Flash Lite for intelligent context selection
- **Conversational Flow:** Maintains continuity while providing medical relevance
- **Fallback Strategy:** Graceful degradation if analysis fails

## 🚀 How It Works

### Step 1: Enhanced Memory Processing
```python
# Process new exchange through STM and LTM
chunks = memory.chunk_response(response, lang, question=query)
for chunk in chunks:
    memory._upsert_stm(user_id, chunk, lang)  # STM with dedupe/merge
memory._upsert_ltm(user_id, chunks, lang)  # LTM with semantic storage
```

### Step 2: Context Retrieval
```python
# Get STM summaries (5 recent turns)
recent_history = memory.get_recent_chat_history(user_id, num_turns=5)

# Get LTM semantic chunks
rag_chunks = memory.get_relevant_chunks(user_id, current_query, top_k=3)

# Get external RAG knowledge
external_rag = retrieve_medical_info(current_query)
```

### Step 3: Intelligent Context Summarization
The system sends all context sources to Gemini Flash Lite for unified summarization:

```
You are a medical assistant creating a concise summary of conversation context for continuity.

Current user query: "{current_query}"

Available context information:
Recent conversation history:
{recent_history}

Semantically relevant historical medical information:
{rag_chunks}

Task: Create a brief, coherent summary that captures the key points from the conversation history and relevant medical information that are important for understanding the current query.

Guidelines:
1. Focus on medical symptoms, diagnoses, treatments, or recommendations mentioned
2. Include any patient concerns or questions that are still relevant
3. Highlight any follow-up needs or pending clarifications
4. Keep the summary concise but comprehensive enough for context
5. Maintain conversational flow and continuity

Output: Provide a single, well-structured summary paragraph that can be used as context for the main LLM to provide a coherent response.
```

### Step 4: Unified Context Integration
The single, coherent summary is integrated into the main LLM prompt, providing:
- **Conversational continuity** (from STM summaries)
- **Medical knowledge** (from LTM semantic chunks)
- **Current information** (from external RAG)
- **Unified narrative** (single summary instead of multiple chunks)

## 📊 Benefits

### 1. **Advanced Memory Management**
- **STM:** Maintains 5 recent conversation summaries with intelligent deduplication
- **LTM:** Stores 60 semantic chunks (~20 rounds) with FAISS indexing
- **Smart Merging:** Combines similar content while preserving unique details
- **Topic Enrichment:** Detailed topics using user question context

### 2. **Intelligent Context Summarization**
- **Unified Summary:** Single coherent narrative instead of multiple chunks
- **Gemini Analysis:** AI-powered context selection and summarization
- **Medical Focus:** Prioritizes symptoms, diagnoses, treatments, and recommendations
- **Conversational Flow:** Maintains natural dialogue continuity

### 3. **Enhanced Chunking & Topics**
- **Question Context:** Incorporates user's latest question for richer topics
- **Detailed Topics:** 10-20 word descriptions capturing context, condition, and action
- **Medical Precision:** Includes exact medication names, doses, and clinical instructions
- **Semantic Grouping:** Organizes by medical topic, symptom, assessment, plan, or instruction

### 4. **Robust Fallback Strategy**
- **Primary:** Gemini Flash Lite contextual summarization
- **Secondary:** LTM semantic search with usage-based scoring
- **Tertiary:** STM recent summaries
- **Final:** External RAG knowledge base

### 5. **Performance & Scalability**
- **Efficient Storage:** Semantic deduplication reduces memory footprint
- **Fast Retrieval:** FAISS indexing for sub-millisecond LTM search
- **Smart Eviction:** Usage-based decay and recency scoring
- **Minimal Latency:** Optimized for real-time medical consultations

## 🧪 Example Scenarios

### Scenario 1: STM Deduplication & Merging
```
User: "I have chest pain"
Bot: "This could be angina. Symptoms include pressure, tightness, and shortness of breath."

User: "What about chest pain with shortness of breath?"
Bot: "Chest pain with shortness of breath is concerning for angina or heart attack..."

User: "Tell me more about the symptoms"
Bot: "Angina symptoms include chest pressure, tightness, shortness of breath, and may radiate to arms..."
```
**Result:** STM merges similar responses, creating a comprehensive summary: "Patient has chest pain symptoms consistent with angina, including pressure, tightness, shortness of breath, and potential radiation to arms. This represents a concerning cardiac presentation requiring immediate evaluation."

### Scenario 2: LTM Semantic Retrieval
```
User: "What medications should I avoid with my condition?"
Bot: "Based on your previous discussion about hypertension and the medications mentioned..."
```
**Result:** LTM retrieves relevant medical information about hypertension medications and contraindications from previous conversations, even if not in recent STM.

### Scenario 3: Enhanced Topic Generation
```
User: "I'm having trouble sleeping"
Bot: "Topic: Sleep disturbance evaluation and management for adult patient with insomnia symptoms"
```
**Result:** The topic incorporates the user's question context to create a detailed, medical-specific description instead of just "Sleep problems."

### Scenario 4: Unified Context Summarization
```
User: "Can you repeat the treatment plan?"
Bot: "Based on our conversation about your hypertension and sleep issues, your treatment plan includes..."
```
**Result:** The system creates a unified summary combining STM (recent sleep discussion), LTM (hypertension history), and RAG (current treatment guidelines) into a single coherent narrative.

## ⚙️ Configuration

### Environment Variables
```bash
FlashAPI=your_gemini_api_key  # For both main LLM and contextual analysis
```

### Enhanced Memory Settings
```python
memory = MemoryManager(
    max_users=1000,       # Maximum users in memory
    history_per_user=5,   # STM capacity (5 recent summaries)
    max_chunks=60         # LTM capacity (~20 conversational rounds)
)
```

### Memory Parameters
```python
# STM retrieval (5 recent turns)
recent_history = memory.get_recent_chat_history(user_id, num_turns=5)

# LTM semantic search
rag_chunks = memory.get_relevant_chunks(user_id, query, top_k=3, min_sim=0.30)

# Unified context summarization
contextual_summary = memory.get_contextual_chunks(user_id, current_query, lang)
```

### Similarity Thresholds
```python
# STM deduplication thresholds
IDENTICAL_THRESHOLD = 0.92  # Replace older with newer
MERGE_THRESHOLD = 0.75      # Merge similar content

# LTM semantic search
MIN_SIMILARITY = 0.30  # Minimum similarity for retrieval
TOP_K = 3              # Number of chunks to retrieve
```

## 🔍 Monitoring & Debugging

### Enhanced Logging
The system provides comprehensive logging for all memory operations:
```python
# STM operations
logger.info(f"[Contextual] Retrieved {len(recent_history)} recent history items")
logger.info(f"[Contextual] Retrieved {len(rag_chunks)} RAG chunks")

# Chunking operations
logger.info(f"[Memory] 📦 Gemini summarized chunk output: {output}")
logger.warning(f"[Memory] ❌ Gemini chunking failed: {e}")

# Contextual summarization
logger.info(f"[Contextual] Gemini created summary: {summary[:100]}...")
logger.warning(f"[Contextual] Gemini summarization failed: {e}")
```

### Performance Metrics
- **STM Operations:** Deduplication rate, merge frequency, topic enrichment quality
- **LTM Operations:** FAISS search latency, semantic similarity scores, eviction patterns
- **Context Summarization:** Gemini response time, summary quality, fallback usage
- **Memory Usage:** Storage efficiency, retrieval hit rates, cache performance

## 🚨 Error Handling

### Enhanced Fallback Strategy
1. **Primary:** Gemini Flash Lite contextual summarization
2. **Secondary:** LTM semantic search with usage-based scoring
3. **Tertiary:** STM recent summaries
4. **Final:** External RAG knowledge base
5. **Emergency:** No context (minimal response)

### Error Scenarios & Recovery
- **Gemini API failure** → Fall back to LTM semantic search
- **LTM corruption** → Rebuild FAISS index from remaining chunks
- **STM corruption** → Reset to empty STM, continue with LTM
- **Memory corruption** → Reset user session, clear all memory
- **Chunking failure** → Store raw response as fallback chunk

## 🔮 Future Enhancements

### 1. **Persistent Memory Storage**
- **Database Integration:** Store LTM in PostgreSQL/SQLite with FAISS index persistence
- **Session Recovery:** Resume conversations after system restarts
- **Memory Export:** Allow users to export their conversation history
- **Cross-device Sync:** Synchronize memory across different devices

### 2. **Advanced Memory Features**
- **Fact Store:** Dedicated storage for critical medical facts (allergies, chronic conditions, medications)
- **Memory Compression:** Summarize older STM entries into LTM when STM overflows
- **Contextual Tags:** Add metadata tags (encounter type, modality, urgency) to bias retrieval
- **Memory Analytics:** Track memory usage patterns and optimize storage strategies

### 3. **Intelligent Memory Management**
- **Adaptive Thresholds:** Dynamically adjust similarity thresholds based on conversation context
- **Memory Prioritization:** Protect critical medical information from eviction
- **Usage-based Retention:** Keep frequently accessed information longer
- **Semantic Clustering:** Group related memories for better organization

### 4. **Enhanced Medical Context**
- **Clinical Decision Support:** Integrate with medical guidelines and protocols
- **Risk Assessment:** Track and alert on potential medical risks across conversations
- **Medication Reconciliation:** Maintain accurate medication lists across sessions
- **Follow-up Scheduling:** Track recommended follow-ups and reminders

### 5. **Multi-modal Memory**
- **Image Memory:** Store and retrieve medical images with descriptions
- **Voice Memory:** Convert voice interactions to text for memory storage
- **Document Memory:** Process and store medical documents and reports
- **Temporal Memory:** Track changes in symptoms and conditions over time

## 📝 Testing

### Memory System Testing
```bash
cd Medical-Chatbot
python test_memory_system.py
```

### Test Scenarios
1. **STM Deduplication Test:** Verify similar responses are merged correctly
2. **LTM Semantic Search Test:** Test FAISS retrieval with various queries
3. **Context Summarization Test:** Validate unified summary generation
4. **Topic Enrichment Test:** Check detailed topic generation with question context
5. **Memory Capacity Test:** Verify STM (5 items) and LTM (60 items) limits
6. **Fallback Strategy Test:** Test system behavior when Gemini API fails

### Expected Behaviors
- **STM:** Similar responses merge, unique details preserved
- **LTM:** Semantic search returns relevant chunks with usage tracking
- **Topics:** Detailed, medical-specific descriptions (10-20 words)
- **Summaries:** Coherent narratives combining STM + LTM + RAG
- **Performance:** Sub-second retrieval times for all operations

## 🎯 Summary

The enhanced memory system transforms the Medical Chatbot into a sophisticated, memory-aware medical assistant that:

✅ **Maintains Short-Term Memory (STM)** with 5 recent conversation summaries and intelligent deduplication
✅ **Provides Long-Term Memory (LTM)** with 60 semantic chunks and FAISS-based retrieval
✅ **Generates Enhanced Topics** using question context for detailed, medical-specific descriptions
✅ **Creates Unified Summaries** combining STM + LTM + RAG into coherent narratives
✅ **Implements Smart Merging** that preserves unique details while eliminating redundancy
✅ **Ensures Conversational Continuity** across extended medical consultations
✅ **Optimizes Performance** with sub-second retrieval and efficient memory management

This advanced memory system addresses the limitations of simple RAG systems by providing:
- **Intelligent context management** that remembers and builds upon previous interactions
- **Medical precision** with detailed topics and exact clinical information
- **Scalable architecture** that can handle extended conversations without performance degradation
- **Robust fallback strategies** ensuring system reliability in all scenarios

The result is a medical chatbot that truly understands conversation context, remembers patient history, and provides increasingly relevant and personalized medical guidance over time.
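A short usage sketch of the memory flow documented above, assuming the MemoryManager from memory.py and a configured FlashAPI key (chunking and contextual summarization call Gemini internally, so this requires network access at runtime):

# Usage sketch of the STM/LTM flow described in chat-history.md (illustrative only).
from memory import MemoryManager

memory = MemoryManager(max_users=1000, history_per_user=5, max_chunks=60)

user_id = "demo-user"
query = "I have chest pain and shortness of breath."
response = "**Possible angina.** Symptoms include pressure and tightness; please seek urgent evaluation."

# Store the exchange: the response is chunked/summarized and upserted into STM and LTM.
memory.add_exchange(user_id, query, response, lang="EN")

# On a later turn, build the unified STM + LTM + RAG context for the main LLM prompt.
context = memory.get_contextual_chunks(user_id, "What tests should I ask my doctor for?", lang="EN")
print(context)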
clear_mongo.py
ADDED
@@ -0,0 +1,48 @@
from pymongo import MongoClient
from dotenv import load_dotenv
import os

# # Load environment variables from .env
# load_dotenv()

##-------------##
# FOR QA CLUSTER
##-------------##

# mongo_uri = os.getenv("MONGO_URI")
# if not mongo_uri:
#     raise ValueError("❌ MongoDB URI (MongoURI) is missing!")

# client = MongoClient(mongo_uri)
# db = client["MedicalChatbotDB"]  # Use the same database name as in your main script

# # To drop just the collection storing the FAISS index:
# db.drop_collection("qa_data")
# print("Dropped collection 'qa_data' from MedicalChatbotDB.")

# # Alternatively, to drop the entire database:
# client.drop_database("MedicalChatbotDB")
# print("Dropped database 'MedicalChatbotDB'.")


##-------------##
# FOR INDEX CLUSTER
##-------------##

# Load environment variables from .env
# load_dotenv()
# index_uri = os.getenv("INDEX_URI")
# if not index_uri:
#     raise ValueError("❌ MongoDB URI (IndexURI) is missing!")

# iclient = MongoClient(index_uri)
# idb = iclient["MedicalChatbotDB"]  # Use the same database name as in your main script

# # To drop just the collection storing the FAISS index:
# idb.drop_collection("faiss_index_files.files")
# idb.drop_collection("faiss_index_files.chunks")
# print("Dropped collection 'faiss_index_files' and chunks from MedicalChatbotDB.")

# # Alternatively, to drop the entire database:
# iclient.drop_database("MedicalChatbotDB")
# print("Dropped database 'MedicalChatbotDB'.")
connect_mongo.py
ADDED
@@ -0,0 +1,24 @@
from pymongo import MongoClient
from dotenv import load_dotenv
import os

# Test the MongoDB connection and list all collections.
load_dotenv()

# QA Cluster
mongo_uri = os.getenv("MONGO_URI")
client = MongoClient(mongo_uri)
db = client["MedicalChatbotDB"]
# List all collections
print("QA Collection: ", db.list_collection_names())
# Count QA documents
print("QA count: ", db.qa_data.count_documents({}))

# Index Cluster
index_uri = os.getenv("INDEX_URI")
iclient = MongoClient(index_uri)
idb = iclient["MedicalChatbotDB"]
# List all collections
print("FAISS Collection: ", idb.list_collection_names())
# Count FAISS index files
print("Index count: ", idb.faiss_index_files.files.count_documents({}))
diagnosis.py
ADDED
@@ -0,0 +1,76 @@
# ✅ Google Colab: SymbiPredict Embedding + Chunking + MongoDB Upload

import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from pymongo import MongoClient
from pymongo.errors import BulkWriteError
import hashlib, os
from tqdm import tqdm

# ✅ Load model
model = SentenceTransformer("all-MiniLM-L6-v2")

# ✅ Load SymbiPredict
df = pd.read_csv("symbipredict_2022.csv")

# ✅ Connect to MongoDB
mongo_uri = "..."
client = MongoClient(mongo_uri)
db = client["MedicalChatbotDB"]
collection = db["symptom_diagnosis"]

# ✅ Clear old symptom-diagnosis records
print("🧹 Dropping old 'symptom_diagnosis' collection...")
collection.drop()
# Reconfirm collection is empty
if collection.count_documents({}) != 0:
    raise RuntimeError("❌ Collection not empty after drop — aborting!")

# ✅ Convert CSV rows into QA-style records with embeddings
records = []
for i, row in tqdm(df.iterrows(), total=len(df)):
    symptom_cols = df.columns[:-1]
    label_col = df.columns[-1]

    # Extract symptoms present (value==1)
    symptoms = [col.replace("_", " ").strip() for col in symptom_cols if row[col] == 1]
    if not symptoms:
        continue

    label = row[label_col].strip()
    question = f"What disease is likely given these symptoms: {', '.join(symptoms)}?"
    answer = f"The patient is likely suffering from: {label}."

    # Embed question only
    embed = model.encode(question, convert_to_numpy=True)
    hashkey = hashlib.md5((question + answer).encode()).hexdigest()

    records.append({
        "_id": hashkey,
        "i": int(i),
        "symptoms": symptoms,
        "prognosis": label,
        "question": question,
        "answer": answer,
        "embedding": embed.tolist()
    })

# ✅ Save to MongoDB
if records:
    print(f"⬆️ Uploading {len(records)} records to MongoDB...")
    unique_ids = set()
    deduped = []
    for r in records:
        if r["_id"] not in unique_ids:
            unique_ids.add(r["_id"])
            deduped.append(r)
    try:
        collection.insert_many(deduped, ordered=False)
        print(f"✅ Inserted {len(deduped)} records without duplicates.")
    except BulkWriteError as bwe:
        inserted = bwe.details.get('nInserted', 0)
        print(f"⚠️ Inserted with some duplicate skips. Records inserted: {inserted}")
    print("✅ Upload complete.")
else:
    print("⚠️ No records to upload.")
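Each record stores an embedding of its question text; retrieve_diagnosis_from_symptoms in app.py later scores a normalized query vector against these stored embeddings. A standalone lookup sketch mirroring that logic, reusing the model and collection objects defined above (the example symptom string is purely illustrative):

# Standalone lookup sketch mirroring retrieve_diagnosis_from_symptoms in app.py.
import numpy as np

docs = list(collection.find({}, {"embedding": 1, "answer": 1, "prognosis": 1}))
vectors = np.array([d["embedding"] for d in docs], dtype=np.float32)

qvec = model.encode("itching and skin rash with nodal eruptions", convert_to_numpy=True)
qvec = qvec / (np.linalg.norm(qvec) + 1e-9)

sims = vectors @ qvec                      # similarity scores, computed as in app.py
best = int(np.argsort(sims)[-1])           # highest-scoring record
print(docs[best]["prognosis"], "|", docs[best]["answer"], "| score:", float(sims[best]))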
download_model.py
ADDED
@@ -0,0 +1,51 @@
# download_model.py
### --- A. transformer and embedder ---
import os
import shutil
from huggingface_hub import snapshot_download

# Set up paths
MODEL_REPO = "sentence-transformers/all-MiniLM-L6-v2"
MODEL_CACHE_DIR = "/app/model_cache"

print("⏳ Downloading the SentenceTransformer model...")
model_path = snapshot_download(repo_id=MODEL_REPO, cache_dir=MODEL_CACHE_DIR)

print("Model path: ", model_path)

# Ensure the directory exists
if not os.path.exists(MODEL_CACHE_DIR):
    os.makedirs(MODEL_CACHE_DIR)

# Move all contents from the snapshot folder
if os.path.exists(model_path):
    print(f"📂 Moving model files from {model_path} to {MODEL_CACHE_DIR}...")

    for item in os.listdir(model_path):
        source = os.path.join(model_path, item)
        destination = os.path.join(MODEL_CACHE_DIR, item)

        if os.path.isdir(source):
            shutil.copytree(source, destination, dirs_exist_ok=True)
        else:
            shutil.copy2(source, destination)

    print(f"✅ Model extracted and flattened in {MODEL_CACHE_DIR}")
else:
    print("❌ No snapshot directory found!")
    exit(1)

# Verify structure after moving
print("\n📂 LLM Model Structure (Build Level):")
for root, dirs, files in os.walk(MODEL_CACHE_DIR):
    print(f"📁 {root}/")
    for file in files:
        print(f" 📄 {file}")


### --- B. translation modules ---
from transformers import pipeline
print("⏬ Downloading Vietnamese–English translator...")
_ = pipeline("translation", model="VietAI/envit5-translation", src_lang="vi", tgt_lang="en")
print("⏬ Downloading Chinese–English translator...")
_ = pipeline("translation", model="Helsinki-NLP/opus-mt-zh-en")
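app.py imports translate_query from translation.py, which is part of this commit but not shown in this view. A hedged sketch of how the two pipelines downloaded above might be wrapped; the function name matches the import in app.py, but the body is an assumption:

# Hypothetical wrapper around the downloaded translators; translation.py itself is not shown in this diff.
from transformers import pipeline

_vi_en = pipeline("translation", model="VietAI/envit5-translation", src_lang="vi", tgt_lang="en")
_zh_en = pipeline("translation", model="Helsinki-NLP/opus-mt-zh-en")

def translate_query(text: str, lang: str) -> str:
    """Translate a VI/ZH query to English before RAG retrieval; pass other input through unchanged."""
    if lang == "vi":
        return _vi_en(text)[0]["translation_text"]
    if lang == "zh":
        return _zh_en(text)[0]["translation_text"]
    return text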
memory.py
ADDED
@@ -0,0 +1,426 @@
# memory.py
import re, time, hashlib, asyncio, os
from collections import defaultdict, deque
from typing import List, Dict
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
from google import genai # must be configured in app.py and imported globally
import logging

_LLM_SMALL = "gemini-2.5-flash-lite-preview-06-17"
# Load embedding model
EMBED = SentenceTransformer("/app/model_cache", device="cpu").half()
logger = logging.getLogger("rag-agent")
logging.basicConfig(level=logging.INFO, format="%(asctime)s — %(name)s — %(levelname)s — %(message)s", force=True) # Change INFO to DEBUG for full-ctx JSON loader

api_key = os.getenv("FlashAPI")
client = genai.Client(api_key=api_key)

class MemoryManager:
    def __init__(self, max_users=1000, history_per_user=20, max_chunks=60):
        # STM: recent conversation summaries (topic + summary), up to 5 entries
        self.stm_summaries = defaultdict(lambda: deque(maxlen=history_per_user)) # deque of {topic,text,vec,timestamp,used}
        # Legacy raw cache (kept for compatibility if needed)
        self.text_cache = defaultdict(lambda: deque(maxlen=history_per_user))
        # LTM: semantic chunk store (approx 3 chunks x 20 rounds)
        self.chunk_index = defaultdict(self._new_index) # user_id -> faiss index
        self.chunk_meta = defaultdict(list) # '' -> list[{text,tag,vec,timestamp,used}]
        self.user_queue = deque(maxlen=max_users) # LRU of users
        self.max_chunks = max_chunks # hard cap per user
        self.chunk_cache = {} # hash(query+resp) -> [chunks]

    # ---------- Public API ----------
    def add_exchange(self, user_id: str, query: str, response: str, lang: str = "EN"):
        self._touch_user(user_id)
        # Keep raw record (optional)
        self.text_cache[user_id].append(((query or "").strip(), (response or "").strip()))
        if not response: return []
        # Avoid re-chunking identical response
        cache_key = hashlib.md5((query + response).encode()).hexdigest()
        if cache_key in self.chunk_cache:
            chunks = self.chunk_cache[cache_key]
        else:
            chunks = self.chunk_response(response, lang, question=query)
            self.chunk_cache[cache_key] = chunks
        # Update STM with merging/deduplication
        for chunk in chunks:
            self._upsert_stm(user_id, chunk, lang)
        # Update LTM with merging/deduplication
        self._upsert_ltm(user_id, chunks, lang)
        return chunks

    def get_relevant_chunks(self, user_id: str, query: str, top_k: int = 3, min_sim: float = 0.30) -> List[str]:
        """Return texts of chunks whose cosine similarity ≥ min_sim."""
        if self.chunk_index[user_id].ntotal == 0:
            return []
        # Encode chunk
        qvec = self._embed(query)
        sims, idxs = self.chunk_index[user_id].search(np.array([qvec]), k=top_k)
        results = []
        # Append related result with smart-decay to optimize storage and prioritize most-recent chat
        for sim, idx in zip(sims[0], idxs[0]):
            if idx < len(self.chunk_meta[user_id]) and sim >= min_sim:
                chunk = self.chunk_meta[user_id][idx]
                chunk["used"] += 1 # increment usage
                # Decay function
                age_sec = time.time() - chunk["timestamp"]
                decay = 1.0 / (1.0 + age_sec / 300) # 5-min half-life
                score = sim * decay * (1 + 0.1 * chunk["used"])
                # Append chunk with score
                results.append((score, chunk))
        # Sort result on best scored
        results.sort(key=lambda x: x[0], reverse=True)
        # logger.info(f"[Memory] RAG Retrieved Topic: {results}") # Inspect vector data
        return [f"### Topic: {c['tag']}\n{c['text']}" for _, c in results]

    def get_recent_chat_history(self, user_id: str, num_turns: int = 5) -> List[Dict]:
        """
        Get the most recent short-term memory summaries.
        Returns: a list of entries containing only the summarized bot context.
        """
        if user_id not in self.stm_summaries:
            return []
        recent = list(self.stm_summaries[user_id])[-num_turns:]
        formatted = []
        for entry in recent:
            formatted.append({
                "user": "",
                "bot": f"Topic: {entry['topic']}\n{entry['text']}",
                "timestamp": entry.get("timestamp", time.time())
            })
        return formatted

    def get_context(self, user_id: str, num_turns: int = 5) -> str:
        # Prefer STM summaries
        history = self.get_recent_chat_history(user_id, num_turns=num_turns)
        return "\n".join(h["bot"] for h in history)

    def get_contextual_chunks(self, user_id: str, current_query: str, lang: str = "EN") -> str:
        """
        Use Gemini Flash Lite to create a summarization of relevant context from both recent history and RAG chunks.
        This ensures conversational continuity while providing a concise summary for the main LLM.
        """
        # Get both types of context
        recent_history = self.get_recent_chat_history(user_id, num_turns=5)
        rag_chunks = self.get_relevant_chunks(user_id, current_query, top_k=3)

        logger.info(f"[Contextual] Retrieved {len(recent_history)} recent history items")
        logger.info(f"[Contextual] Retrieved {len(rag_chunks)} RAG chunks")

        # Return empty string if no context is found
        if not recent_history and not rag_chunks:
            logger.info(f"[Contextual] No context found, returning empty string")
            return ""
        # Prepare context for Gemini to summarize
        context_parts = []
        # Add recent chat history
        if recent_history:
            history_text = "\n".join([
                f"User: {item['user']}\nBot: {item['bot']}"
                for item in recent_history
            ])
            context_parts.append(f"Recent conversation history:\n{history_text}")
        # Add RAG chunks
        if rag_chunks:
            rag_text = "\n".join(rag_chunks)
            context_parts.append(f"Semantically relevant historical medical information:\n{rag_text}")

        # Build summarization prompt
        summarization_prompt = f"""
You are a medical assistant creating a concise summary of conversation context for continuity.

Current user query: "{current_query}"

Available context information:
{chr(10).join(context_parts)}

Task: Create a brief, coherent summary that captures the key points from the conversation history and relevant medical information that are important for understanding the current query.

Guidelines:
1. Focus on medical symptoms, diagnoses, treatments, or recommendations mentioned
2. Include any patient concerns or questions that are still relevant
3. Highlight any follow-up needs or pending clarifications
4. Keep the summary concise but comprehensive enough for context
5. Maintain conversational flow and continuity

Output: Provide a single, well-structured summary paragraph that can be used as context for the main LLM to provide a coherent response.
If no relevant context exists, return "No relevant context found."

Language context: {lang}
"""

        logger.debug(f"[Contextual] Full prompt: {summarization_prompt}")
        # Loop through the prompt and log the length of each part
        try:
            # Use Gemini Flash Lite for summarization
            client = genai.Client(api_key=os.getenv("FlashAPI"))
            result = client.models.generate_content(
                model=_LLM_SMALL,
                contents=summarization_prompt
            )
            summary = result.text.strip()
            if "No relevant context found" in summary:
                logger.info(f"[Contextual] Gemini indicated no relevant context found")
                return ""

            logger.info(f"[Contextual] Gemini created summary: {summary[:100]}...")
            return summary

        except Exception as e:
            logger.warning(f"[Contextual] Gemini summarization failed: {e}")
            logger.info(f"[Contextual] Using fallback summarization method")
            # Fallback: create a simple summary
            fallback_summary = []
            # Fallback: add recent history
            if recent_history:
                recent_summary = f"Recent conversation: User asked about {recent_history[-1]['user'][:50]}... and received a response about {recent_history[-1]['bot'][:50]}..."
                fallback_summary.append(recent_summary)
                logger.info(f"[Contextual] Fallback: Added recent history summary")
            # Fallback: add RAG chunks
            if rag_chunks:
                rag_summary = f"Relevant medical information: {len(rag_chunks)} chunks found covering various medical topics."
                fallback_summary.append(rag_summary)
                logger.info(f"[Contextual] Fallback: Added RAG chunks summary")
            final_fallback = " ".join(fallback_summary) if fallback_summary else ""
            return final_fallback

    def reset(self, user_id: str):
        self._drop_user(user_id)

    # ---------- Internal helpers ----------
    def _touch_user(self, user_id: str):
        if user_id not in self.text_cache and len(self.user_queue) >= self.user_queue.maxlen:
            self._drop_user(self.user_queue.popleft())
        if user_id in self.user_queue:
            self.user_queue.remove(user_id)
        self.user_queue.append(user_id)

    def _drop_user(self, user_id: str):
        self.text_cache.pop(user_id, None)
        self.chunk_index.pop(user_id, None)
        self.chunk_meta.pop(user_id, None)
        if user_id in self.user_queue:
            self.user_queue.remove(user_id)

    def _rebuild_index(self, user_id: str, keep_last: int):
        """Trim chunk list + rebuild FAISS index for user."""
        self.chunk_meta[user_id] = self.chunk_meta[user_id][-keep_last:]
        index = self._new_index()
        # Store each chunk's vector once and reuse it.
        for chunk in self.chunk_meta[user_id]:
            index.add(np.array([chunk["vec"]]))
        self.chunk_index[user_id] = index

    @staticmethod
    def _new_index():
        # Use cosine similarity (vectors must be L2-normalised)
        return faiss.IndexFlatIP(384)

    @staticmethod
    def _embed(text: str):
        vec = EMBED.encode(text, convert_to_numpy=True)
        # L2 normalise for cosine on IndexFlatIP
        return vec / (np.linalg.norm(vec) + 1e-9)

    def chunk_response(self, response: str, lang: str, question: str = "") -> List[Dict]:
        """
        Calls Gemini to:
        - Translate (if needed)
        - Chunk by context/topic (exclude disclaimer section)
        - Summarise
        Returns: [{"tag": ..., "text": ...}, ...]
        """
        if not response: return []
        # Gemini instruction
        instructions = []
        # if lang.upper() != "EN":
        #     instructions.append("- Translate the response to English.")
        instructions.append("- Break the translated (or original) text into semantically distinct parts, grouped by medical topic, symptom, assessment, plan, or instruction (exclude disclaimer section).")
        instructions.append("- For each part, generate a clear, concise summary. The summary may vary in length depending on the complexity of the topic — do not omit key clinical instructions and exact medication names/doses if present.")
        instructions.append("- At the start of each part, write `Topic: <concise but specific sentence (10-20 words) capturing patient context, condition, and action>`.")
        instructions.append("- Separate each part using three dashes `---` on a new line.")
        # if lang.upper() != "EN":
        #     instructions.append(f"Below is the user-provided medical response written in `{lang}`")
        # Gemini prompt
        prompt = f"""
You are a medical assistant helping organize and condense a clinical response.
If helpful, use the user's latest question for context to craft specific topics.
User's latest question (context): {question}
------------------------
{response}
------------------------
Please perform the following tasks:
{chr(10).join(instructions)}

Output only the structured summaries, separated by dashes.
"""
        retries = 0
        while retries < 5:
            try:
                client = genai.Client(api_key=os.getenv("FlashAPI"))
                result = client.models.generate_content(
                    model=_LLM_SMALL,
                    contents=prompt
                    # ,generation_config={"temperature": 0.4} # Skip temp configs for gem-flash
                )
                output = result.text.strip()
                logger.info(f"[Memory] 📦 Gemini summarized chunk output: {output}")
                return [
                    {"tag": self._quick_extract_topic(chunk), "text": chunk.strip()}
                    for chunk in output.split('---') if chunk.strip()
                ]
            except Exception as e:
                logger.warning(f"[Memory] ❌ Gemini chunking failed: {e}")
                retries += 1
                time.sleep(0.5)
        return [{"tag": "general", "text": response.strip()}] # fallback

    @staticmethod
    def _quick_extract_topic(chunk: str) -> str:
        """Heuristically extract the topic from a chunk (title line or first 3 words)."""
        # Expecting 'Topic: <something>'
        match = re.search(r'^Topic:\s*(.+)', chunk, re.IGNORECASE | re.MULTILINE)
        if match:
            return match.group(1).strip()
        lines = chunk.strip().splitlines()
        for line in lines:
            if len(line.split()) <= 8 and line.strip().endswith(":"):
                return line.strip().rstrip(":")
        return " ".join(chunk.split()[:3]).rstrip(":.,")

    # ---------- New merging/dedup logic ----------
    def _upsert_stm(self, user_id: str, chunk: Dict, lang: str):
        """Insert or merge a summarized chunk into STM with semantic dedup/merge.
        Identical: replace the older with new. Partially similar: merge extra details from older into newer.
        """
        topic = self._enrich_topic(chunk.get("tag", ""), chunk.get("text", ""))
        text = chunk.get("text", "").strip()
        vec = self._embed(text)
        now = time.time()
        entry = {"topic": topic, "text": text, "vec": vec, "timestamp": now, "used": 0}
        stm = self.stm_summaries[user_id]
        if not stm:
            stm.append(entry)
            return
        # find best match
        best_idx = -1
        best_sim = -1.0
        for i, e in enumerate(stm):
            sim = float(np.dot(vec, e["vec"]))
            if sim > best_sim:
                best_sim = sim
                best_idx = i
        if best_sim >= 0.92: # nearly identical
            # replace older with current
            stm.rotate(-best_idx)
            stm.popleft()
            stm.rotate(best_idx)
            stm.append(entry)
        elif best_sim >= 0.75: # partially similar → merge
            base = stm[best_idx]
            merged_text = self._merge_texts(new_text=text, old_text=base["text"]) # add bits from old not in new
            merged_topic = base["topic"] if len(base["topic"]) > len(topic) else topic
            merged_vec = self._embed(merged_text)
            merged_entry = {"topic": merged_topic, "text": merged_text, "vec": merged_vec, "timestamp": now, "used": base.get("used", 0)}
            stm.rotate(-best_idx)
            stm.popleft()
            stm.rotate(best_idx)
            stm.append(merged_entry)
        else:
            stm.append(entry)

    def _upsert_ltm(self, user_id: str, chunks: List[Dict], lang: str):
        """Insert or merge chunks into LTM with semantic dedup/merge, then rebuild index.
        Keeps only the most recent self.max_chunks entries.
        """
        current_list = self.chunk_meta[user_id]
        for chunk in chunks:
            text = chunk.get("text", "").strip()
            if not text:
                continue
            vec = self._embed(text)
            topic = self._enrich_topic(chunk.get("tag", ""), text)
            now = time.time()
            new_entry = {"tag": topic, "text": text, "vec": vec, "timestamp": now, "used": 0}
            if not current_list:
                current_list.append(new_entry)
                continue
            # find best similar entry
            best_idx = -1
            best_sim = -1.0
            for i, e in enumerate(current_list):
                sim = float(np.dot(vec, e["vec"]))
                if sim > best_sim:
                    best_sim = sim
                    best_idx = i
            if best_sim >= 0.92:
                # replace older with new
                current_list[best_idx] = new_entry
            elif best_sim >= 0.75:
                # merge details
                base = current_list[best_idx]
                merged_text = self._merge_texts(new_text=text, old_text=base["text"]) # add unique sentences from old
                merged_topic = base["tag"] if len(base["tag"]) > len(topic) else topic
                merged_vec = self._embed(merged_text)
                current_list[best_idx] = {"tag": merged_topic, "text": merged_text, "vec": merged_vec, "timestamp": now, "used": base.get("used", 0)}
            else:
                current_list.append(new_entry)
        # Trim and rebuild index
        if len(current_list) > self.max_chunks:
            current_list[:] = current_list[-self.max_chunks:]
        self._rebuild_index(user_id, keep_last=self.max_chunks)

    @staticmethod
    def _split_sentences(text: str) -> List[str]:
        # naive sentence splitter by ., !, ?
        parts = re.split(r"(?<=[\.!?])\s+", text.strip())
        return [p.strip() for p in parts if p.strip()]

    def _merge_texts(self, new_text: str, old_text: str) -> str:
        """Append sentences from old_text that are not already contained in new_text (by fuzzy match)."""
        new_sents = self._split_sentences(new_text)
        old_sents = self._split_sentences(old_text)
        new_set = set(s.lower() for s in new_sents)
        merged = list(new_sents)
        for s in old_sents:
            s_norm = s.lower()
            # consider present if significant overlap with any existing sentence
            if s_norm in new_set:
                continue
            # simple containment check
            if any(self._overlap_ratio(s_norm, t.lower()) > 0.8 for t in merged):
                continue
            merged.append(s)
        return " ".join(merged)

    @staticmethod
    def _overlap_ratio(a: str, b: str) -> float:
        """Compute token overlap ratio between two sentences."""
        ta = set(re.findall(r"\w+", a))
        tb = set(re.findall(r"\w+", b))
        if not ta or not tb:
            return 0.0
        inter = len(ta & tb)
        union = len(ta | tb)
        return inter / union

    @staticmethod
    def _enrich_topic(topic: str, text: str) -> str:
        """Make topic more descriptive if it's too short by using the first sentence of the text.
        Does not call LLM to keep latency low.
        """
        topic = (topic or "").strip()
        if len(topic.split()) < 5 or len(topic) < 20:
            sents = re.split(r"(?<=[\.!?])\s+", text.strip())
            if sents:
                first = sents[0]
                # cap to ~16 words
                words = first.split()
                if len(words) > 16:
                    first = " ".join(words[:16])
                # ensure capitalized
                return first.strip().rstrip(':')
        return topic
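For orientation, here is a minimal usage sketch of how app.py is expected to drive MemoryManager; it is not part of the commit, and the user_id, query, and reply values are illustrative.

# Hypothetical wiring (not part of the commit): how app.py might drive MemoryManager.
from memory import MemoryManager

memory = MemoryManager()
user_id = "demo-user"                                   # illustrative session key
query = "I have had a dry cough for two weeks."
reply = "A persistent dry cough lasting over two weeks may warrant a check-up..."  # main-LLM answer

# Chunk, summarize, and upsert the exchange into STM/LTM (calls Gemini Flash Lite + the embedder).
memory.add_exchange(user_id, query, reply, lang="EN")

# On the next turn, condense recent history + RAG hits into one context block for the main LLM prompt.
context = memory.get_contextual_chunks(user_id, "Should I see a doctor about this cough?", lang="EN")
print(context)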
migrate.py
ADDED
@@ -0,0 +1,48 @@
# Run this script to move the FAISS index collection to a second/different cluster.
from pymongo import MongoClient
from dotenv import load_dotenv
import os

# Load environment variables from .env
load_dotenv()
# Connection strings (update as needed)
mongo_uri = os.getenv("MONGO_URI") # QA cluster connection string
index_uri = os.getenv("INDEX_URI") # FAISS index cluster connection string

if not mongo_uri:
    raise ValueError("MONGO_URI is missing!")
if not index_uri:
    raise ValueError("INDEX_URI is missing!")

# Connect to the QA cluster (where FAISS data was accidentally stored)
qa_client = MongoClient(mongo_uri)
qa_db = qa_client["MedicalChatbotDB"]

# Connect to the FAISS index cluster
faiss_client = MongoClient(index_uri)
faiss_db = faiss_client["MedicalChatbotDB"] # Use the same database name if desired

# Define the GridFS collections to move.
# In GridFS, files are stored in two collections: "<bucket>.files" and "<bucket>.chunks".
source_files = qa_db["faiss_index_files.files"]
source_chunks = qa_db["faiss_index_files.chunks"]

dest_files = faiss_db["faiss_index_files.files"]
dest_chunks = faiss_db["faiss_index_files.chunks"]

print("Moving FAISS index GridFS files...")

# Copy documents from the source 'files' collection
for doc in source_files.find():
    dest_files.insert_one(doc)

# Copy documents from the source 'chunks' collection
for doc in source_chunks.find():
    dest_chunks.insert_one(doc)

print("✅ FAISS GridFS collections moved successfully.")

# Optionally, drop the old collections from the QA cluster to free up space:
qa_db.drop_collection("faiss_index_files.files")
qa_db.drop_collection("faiss_index_files.chunks")
print("Old FAISS GridFS collections dropped from the QA cluster.")
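Since the script drops the source collections unconditionally, a count comparison inserted before the drop_collection calls can confirm the copy completed; this is a hypothetical addition, not part of migrate.py.

# Hypothetical sanity check (not in migrate.py): compare counts before dropping the source collections.
for bucket in ("faiss_index_files.files", "faiss_index_files.chunks"):
    src_count = qa_db[bucket].count_documents({})
    dst_count = faiss_db[bucket].count_documents({})
    print(f"{bucket}: source={src_count}, destination={dst_count}")
    assert src_count == dst_count, f"Copy incomplete for {bucket}; do not drop the source yet."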
requirements.txt
ADDED
@@ -0,0 +1,23 @@
# requirements.txt
# **LLMs**
google-genai
huggingface_hub
# **RAG**
faiss-cpu
sentence-transformers
# **NLPs**
transformers
accelerate
sentencepiece
# **Environment**
python-dotenv # Not used in Streamlit deployment
pymongo
# **VLMs**
# transformers
gradio_client
pillow
# **Deployment**
uvicorn
fastapi
torch # Reduce model load with half-precision (float16) to reduce RAM usage
psutil # CPU/RAM logger
translation.py
ADDED
@@ -0,0 +1,26 @@
# translation.py
from transformers import pipeline
import logging

logger = logging.getLogger("translation-agent")
logging.basicConfig(level=logging.INFO, format="%(asctime)s — %(name)s — %(levelname)s — %(message)s", force=True) # Change INFO to DEBUG for full-ctx JSON loader

# To use lazy model loader
vi_en = None
zh_en = None

def translate_query(text: str, lang_code: str) -> str:
    global vi_en, zh_en
    if lang_code == "vi":
        if vi_en is None:
            vi_en = pipeline("translation", model="VietAI/envit5-translation", src_lang="vi", tgt_lang="en", device=-1)
        result = vi_en(text, max_length=512)[0]["translation_text"]
        logger.info(f"[En-Vi] Query in `{lang_code}` translated to: {result}")
        return result
    elif lang_code == "zh":
        if zh_en is None:
            zh_en = pipeline("translation", model="Helsinki-NLP/opus-mt-zh-en", device=-1)
        result = zh_en(text, max_length=512)[0]["translation_text"]
        logger.info(f"[En-Zh] Query in `{lang_code}` translated to: {result}")
        return result
    return text
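A small illustrative check of translate_query (example strings only): only "vi" and "zh" trigger a translator pipeline, and any other language code passes through unchanged.

# Illustrative check (not part of the commit).
from translation import translate_query

print(translate_query("Tôi bị đau đầu", "vi"))      # Vietnamese -> English via VietAI/envit5-translation
print(translate_query("我头疼", "zh"))               # Chinese -> English via Helsinki-NLP/opus-mt-zh-en
print(translate_query("I have a headache", "en"))   # unsupported code: returned as-is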
vlm.py
ADDED
@@ -0,0 +1,54 @@
import os, logging, traceback, json, base64
from io import BytesIO
from PIL import Image
from translation import translate_query
from gradio_client import Client, handle_file
import tempfile

logger = logging.getLogger("vlm-agent")
logging.basicConfig(level=logging.INFO, format="%(asctime)s — %(name)s — %(levelname)s — %(message)s", force=True)

# ✅ Load Gradio client once
gr_client = None
def load_gradio_client():
    global gr_client
    if gr_client is None:
        logger.info("[VLM] ⏳ Connecting to MedGEMMA Gradio Space...")
        gr_client = Client("warshanks/medgemma-4b-it")
        logger.info("[VLM] Gradio MedGEMMA client ready.")
    return gr_client

def process_medical_image(base64_image: str, prompt: str = None, lang: str = "EN") -> str:
    if not prompt:
        prompt = "Describe and investigate any clinical findings from this medical image."
    elif lang.upper() in {"VI", "ZH"}:
        prompt = translate_query(prompt, lang.lower())

    try:
        # 1️⃣ Decode base64 image to temp file
        image_data = base64.b64decode(base64_image)
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            tmp.write(image_data)
            tmp.flush()
            image_path = tmp.name

        # 2️⃣ Send to Gradio MedGEMMA
        client = load_gradio_client()
        logger.info(f"[VLM] Sending prompt: {prompt}")
        result = client.predict(
            message={"text": prompt, "files": [handle_file(image_path)]},
            param_2="You analyze medical images and report abnormalities, diseases with clear diagnostic insight.",
            param_3=2048,
            api_name="/chat"
        )
        if isinstance(result, str):
            logger.info(f"[VLM] ✅ Response: {result}")
            return result.strip()
        else:
            logger.warning(f"[VLM] ⚠️ Unexpected result type: {type(result)} — {result}")
            return str(result)

    except Exception as e:
        logger.error(f"[VLM] ❌ Exception: {e}")
        logger.error(f"[VLM] 🔍 Traceback:\n{traceback.format_exc()}")
        return f"[VLM] ⚠️ Failed to process image: {e}"
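A minimal caller sketch for process_medical_image, assuming a local PNG on disk; the file name and prompt are illustrative and not part of the commit.

# Hypothetical caller (not part of the commit): encode a local image and query MedGEMMA.
import base64
from vlm import process_medical_image

with open("chest_xray.png", "rb") as f:   # illustrative file name
    b64 = base64.b64encode(f.read()).decode()

report = process_medical_image(b64, prompt="Any signs of pneumonia?", lang="EN")
print(report)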
warmup.py
ADDED
@@ -0,0 +1,8 @@
from sentence_transformers import SentenceTransformer
import torch

print("🚀 Warming up model...")
embedding_model = SentenceTransformer("/app/model_cache", device="cpu")
embedding_model = embedding_model.half() # Reduce memory
embedding_model.to(torch.device("cpu"))
print("✅ Model warm-up complete!")
|