Y Phung Nguyen committed on
Commit
83a4de1
·
1 Parent(s): 46971ea

Use Q&A breakdown agent

Browse files
Files changed (3) hide show
  1. pipeline.py +211 -3
  2. supervisor.py +204 -0
  3. ui.py +7 -0
pipeline.py CHANGED
@@ -3,6 +3,7 @@ import os
3
  import json
4
  import time
5
  import logging
 
6
  import concurrent.futures
7
  import gradio as gr
8
  import spaces
@@ -18,9 +19,168 @@ from supervisor import (
18
  gemini_supervisor_breakdown, gemini_supervisor_search_strategies,
19
  gemini_supervisor_rag_brainstorm, execute_medswin_task,
20
  gemini_supervisor_synthesize, gemini_supervisor_challenge,
21
- gemini_supervisor_enhance_answer, gemini_supervisor_check_clarity
 
22
  )
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  @spaces.GPU(max_duration=120)
26
  def stream_chat(
@@ -37,6 +197,7 @@ def stream_chat(
37
  use_rag: bool,
38
  medical_model: str,
39
  use_web_search: bool,
 
40
  disable_agentic_reasoning: bool,
41
  show_thoughts: bool,
42
  request: gr.Request
@@ -73,9 +234,15 @@ def stream_chat(
73
  "plan": None,
74
  "strategy_decisions": [],
75
  "stage_metrics": {},
76
- "search": {"strategies": [], "total_results": 0}
 
 
 
 
 
 
 
77
  }
78
-
79
  def record_stage(stage_name: str, start_time: float):
80
  pipeline_diagnostics["stage_metrics"][stage_name] = round(time.time() - start_time, 3)
81
 
@@ -95,6 +262,47 @@ def stream_chat(
95
  {"role": "assistant", "content": ""}
96
  ]
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  plan = None
99
  if not disable_agentic_reasoning:
100
  reasoning_stage_start = time.time()
 
3
  import json
4
  import time
5
  import logging
6
+ import threading
7
  import concurrent.futures
8
  import gradio as gr
9
  import spaces
 
19
  gemini_supervisor_breakdown, gemini_supervisor_search_strategies,
20
  gemini_supervisor_rag_brainstorm, execute_medswin_task,
21
  gemini_supervisor_synthesize, gemini_supervisor_challenge,
22
+ gemini_supervisor_enhance_answer, gemini_supervisor_check_clarity,
23
+ gemini_clinical_intake_triage, gemini_summarize_clinical_insights
24
  )
25
 
26
# Hard cap on the number of intake follow-up questions asked in one session.
MAX_CLINICAL_QA_ROUNDS = 5
# In-memory intake state keyed by session/user id; every access must hold the lock below.
_clinical_intake_sessions = {}
_clinical_intake_lock = threading.Lock()
29
+
30
+
31
+ def _get_clinical_intake_state(session_id: str):
32
+ with _clinical_intake_lock:
33
+ return _clinical_intake_sessions.get(session_id)
34
+
35
+
36
def _set_clinical_intake_state(session_id: str, state: dict):
    """Thread-safe store/replace of the intake state for *session_id*."""
    with _clinical_intake_lock:
        _clinical_intake_sessions[session_id] = state
39
+
40
+
41
def _clear_clinical_intake_state(session_id: str):
    """Thread-safe removal of the intake state for *session_id* (no-op if absent)."""
    with _clinical_intake_lock:
        _clinical_intake_sessions.pop(session_id, None)
44
+
45
+
46
+ def _history_to_text(history: list, limit: int = 6) -> str:
47
+ if not history:
48
+ return "No prior conversation."
49
+ recent = history[-limit:]
50
+ lines = []
51
+ for turn in recent:
52
+ role = turn.get("role", "user")
53
+ content = turn.get("content", "")
54
+ lines.append(f"{role}: {content}")
55
+ return "\n".join(lines)
56
+
57
+
58
+ def _format_intake_question(question: dict, round_idx: int, max_rounds: int, target_lang: str) -> str:
59
+ header = f"🩺 Clinical intake question {round_idx}/{max_rounds}"
60
+ body = question.get("question") or "Could you share a bit more detail so I can give an accurate answer?"
61
+ focus = question.get("clinical_focus")
62
+ why = question.get("why_it_matters")
63
+ prompt_parts = [header, body]
64
+ if focus:
65
+ prompt_parts.append(f"Focus: {focus}")
66
+ if why:
67
+ prompt_parts.append(f"Why it matters: {why}")
68
+ prompt_parts.append("Please answer in 1-2 sentences so I can continue.")
69
+ prompt_text = "\n\n".join(prompt_parts)
70
+ if target_lang and target_lang != "en":
71
+ try:
72
+ prompt_text = translate_text(prompt_text, target_lang=target_lang, source_lang="en")
73
+ except Exception as exc:
74
+ logger.warning(f"[INTAKE] Question translation failed: {exc}")
75
+ return prompt_text
76
+
77
+
78
+ def _format_insights_block(insights: dict) -> str:
79
+ if not insights:
80
+ return ""
81
+ lines = []
82
+ profile = insights.get("patient_profile")
83
+ if profile:
84
+ lines.append(f"- Patient profile: {profile}")
85
+ for finding in insights.get("key_findings", []):
86
+ title = finding.get("title", "Insight")
87
+ detail = finding.get("detail", "")
88
+ implication = finding.get("clinical_implication", "")
89
+ line = f"- {title}: {detail}"
90
+ if implication:
91
+ line += f" (Clinical note: {implication})"
92
+ lines.append(line)
93
+ return "\n".join(lines)
94
+
95
+
96
+ def _build_refined_query(base_query: str, insights: dict, insights_block: str) -> str:
97
+ sections = [base_query.strip()] if base_query else []
98
+ if insights_block:
99
+ sections.append(f"Clinical intake summary:\n{insights_block}")
100
+ refined = insights.get("refined_problem_statement")
101
+ if refined:
102
+ sections.append(f"Refined problem statement:\n{refined}")
103
+ handoff = insights.get("handoff_note")
104
+ if handoff:
105
+ sections.append(f"Handoff note:\n{handoff}")
106
+ return "\n\n".join([section for section in sections if section])
107
+
108
+
109
def _start_clinical_intake_session(session_id: str, plan: dict, base_query: str, original_language: str):
    """Initialize intake state from a triage plan and return the first question prompt.

    Returns None when the plan contains no questions (no session is created).
    The effective round cap is clamped to [1, MAX_CLINICAL_QA_ROUNDS] and to
    the number of planned questions.
    """
    questions = plan.get("questions", []) or []
    if not questions:
        return None
    round_cap = plan.get("max_rounds") or len(questions)
    round_cap = max(1, min(MAX_CLINICAL_QA_ROUNDS, round_cap, len(questions)))
    language = original_language or "en"
    session_state = {
        "base_query": base_query,
        "original_language": language,
        "questions": questions,
        "max_rounds": round_cap,
        "current_round": 1,
        "pending_question_index": 0,
        "awaiting_answer": True,
        "answers": [],
        "decision_reason": plan.get("decision_reason", ""),
        "initial_hypotheses": plan.get("initial_hypotheses", []),
        "started_at": time.time()
    }
    _set_clinical_intake_state(session_id, session_state)
    return _format_intake_question(
        questions[0],
        round_idx=1,
        max_rounds=round_cap,
        target_lang=language
    )
136
+
137
+
138
def _handle_clinical_answer(session_id: str, answer_text: str) -> dict:
    """Record the patient's answer to the pending intake question and advance the session.

    Returns a dict discriminated by its "type" key:
      - {"type": "question", "prompt": ...} — another follow-up should be shown;
      - {"type": "insights", ...}           — intake finished; carries the summarized
                                              insights, formatted block, refined query,
                                              and the raw Q&A pairs;
      - {"type": "error"}                   — no valid session to advance.
    """
    state = _get_clinical_intake_state(session_id)
    if not state:
        return {"type": "error"}
    questions = state.get("questions", [])
    idx = state.get("pending_question_index", 0)
    if idx >= len(questions):
        # Defensive: stored index no longer points at a question; abort the session.
        logger.warning("[INTAKE] Pending question index out of range, ending intake session")
        _clear_clinical_intake_state(session_id)
        return {"type": "error"}
    question_meta = questions[idx] or {}
    # Pair the just-received answer with the question metadata it responds to.
    qa_entry = {
        "question": question_meta.get("question", ""),
        "focus": question_meta.get("clinical_focus"),
        "why_it_matters": question_meta.get("why_it_matters"),
        "round": state.get("current_round", len(state.get("answers", [])) + 1),
        "answer": answer_text.strip()
    }
    state["answers"].append(qa_entry)
    next_index = idx + 1
    # Finish when either the round budget is exhausted or all questions are answered.
    reached_round_limit = len(state["answers"]) >= state["max_rounds"]
    if reached_round_limit or next_index >= len(questions):
        # Distill the collected Q&A into structured insights and a refined query,
        # then drop the session so the next message starts fresh.
        insights = gemini_summarize_clinical_insights(state["base_query"], state["answers"])
        insights_block = _format_insights_block(insights)
        refined_query = _build_refined_query(state["base_query"], insights, insights_block)
        _clear_clinical_intake_state(session_id)
        return {
            "type": "insights",
            "insights": insights,
            "insights_block": insights_block,
            "refined_query": refined_query,
            "qa_pairs": state["answers"]
        }
    # More questions remain: advance the cursor and persist the updated state.
    state["pending_question_index"] = next_index
    state["current_round"] = len(state["answers"]) + 1
    state["awaiting_answer"] = True
    _set_clinical_intake_state(session_id, state)
    next_question = questions[next_index]
    prompt = _format_intake_question(
        next_question,
        round_idx=state["current_round"],
        max_rounds=state["max_rounds"],
        target_lang=state["original_language"]
    )
    return {"type": "question", "prompt": prompt}
183
+
184
 
185
  @spaces.GPU(max_duration=120)
186
  def stream_chat(
 
197
  use_rag: bool,
198
  medical_model: str,
199
  use_web_search: bool,
200
+ enable_clinical_intake: bool,
201
  disable_agentic_reasoning: bool,
202
  show_thoughts: bool,
203
  request: gr.Request
 
234
  "plan": None,
235
  "strategy_decisions": [],
236
  "stage_metrics": {},
237
+ "search": {"strategies": [], "total_results": 0},
238
+ "clinical_intake": {
239
+ "enabled": enable_clinical_intake,
240
+ "activated": False,
241
+ "rounds": 0,
242
+ "reason": "",
243
+ "insights": []
244
+ }
245
  }
 
246
  def record_stage(stage_name: str, start_time: float):
247
  pipeline_diagnostics["stage_metrics"][stage_name] = round(time.time() - start_time, 3)
248
 
 
262
  {"role": "assistant", "content": ""}
263
  ]
264
 
265
+ if not enable_clinical_intake:
266
+ _clear_clinical_intake_state(user_id)
267
+ else:
268
+ intake_state = _get_clinical_intake_state(user_id)
269
+ if intake_state and intake_state.get("awaiting_answer"):
270
+ logger.info("[INTAKE] Awaiting patient response - processing answer")
271
+ intake_result = _handle_clinical_answer(user_id, message)
272
+ if intake_result.get("type") == "question":
273
+ logger.info("[INTAKE] Requesting additional follow-up")
274
+ updated_history[-1]["content"] = intake_result["prompt"]
275
+ thoughts_text = thought_handler.get_thoughts() if (show_thoughts and thought_handler) else ""
276
+ yield updated_history, thoughts_text
277
+ if thought_handler:
278
+ logger.removeHandler(thought_handler)
279
+ return
280
+ if intake_result.get("type") == "insights":
281
+ pipeline_diagnostics["clinical_intake"]["activated"] = True
282
+ pipeline_diagnostics["clinical_intake"]["rounds"] = len(intake_result.get("qa_pairs", []))
283
+ pipeline_diagnostics["clinical_intake"]["insights"] = intake_result.get("insights", {}).get("key_findings", [])
284
+ message = intake_result.get("refined_query", message)
285
+ else:
286
+ history_context = _history_to_text(history)
287
+ triage_plan = gemini_clinical_intake_triage(message, history_context, MAX_CLINICAL_QA_ROUNDS)
288
+ pipeline_diagnostics["clinical_intake"]["reason"] = triage_plan.get("decision_reason", "")
289
+ needs_intake = triage_plan.get("needs_additional_info") and triage_plan.get("questions")
290
+ if needs_intake:
291
+ first_prompt = _start_clinical_intake_session(
292
+ user_id,
293
+ triage_plan,
294
+ message,
295
+ original_lang
296
+ )
297
+ if first_prompt:
298
+ pipeline_diagnostics["clinical_intake"]["activated"] = True
299
+ updated_history[-1]["content"] = first_prompt
300
+ thoughts_text = thought_handler.get_thoughts() if (show_thoughts and thought_handler) else ""
301
+ yield updated_history, thoughts_text
302
+ if thought_handler:
303
+ logger.removeHandler(thought_handler)
304
+ return
305
+
306
  plan = None
307
  if not disable_agentic_reasoning:
308
  reasoning_stage_start = time.time()
supervisor.py CHANGED
@@ -217,6 +217,210 @@ Keep contexts brief and factual. Avoid redundancy."""
217
  }
218
 
219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  def gemini_supervisor_breakdown(query: str, use_rag: bool, use_web_search: bool, time_elapsed: float, max_duration: int = 120) -> dict:
221
  """Wrapper to obtain supervisor breakdown synchronously"""
222
  if not MCP_AVAILABLE:
 
217
  }
218
 
219
 
220
async def gemini_clinical_intake_triage_async(
    query: str,
    history_context: str,
    max_rounds: int = 5
) -> dict:
    """Gemini Intake Agent: Decide if additional clinical intake is needed and plan questions.

    Asks the lite Gemini model to return a JSON triage plan; on any parsing
    failure returns a safe "skip intake" plan so the pipeline proceeds.
    """
    history_block = history_context if history_context else "No prior conversation."
    # Clamp the requested budget to 1..5 before embedding it in the prompt.
    safe_rounds = max(1, min(5, max_rounds))
    prompt = f"""You are a clinical intake coordinator helping a medical AI system.
Your job is to review the patient's latest request and decide if more clinical details are required before analysis.

Patient query:
"{query}"

Recent conversation (if any):
{history_block}

Return ONLY valid JSON (no markdown):
{{
"needs_additional_info": true | false,
"decision_reason": "brief justification",
"max_rounds": {safe_rounds},
"questions": [
{{
"order": 1,
"question": "single follow-up question to ask the patient",
"clinical_focus": "what aspect it clarifies (e.g., onset, severity, meds)",
"why_it_matters": "concise clinical rationale",
"optional": false
}},
...
],
"initial_hypotheses": [
"optional bullet on potential etiologies or next steps"
]
}}

Guidelines:
- Ask at most {safe_rounds} questions. Use fewer if the query is already specific.
- Order questions to maximize clinical value.
- Only mark needs_additional_info true when the current data is insufficient for safe reasoning.
- Keep wording patient-friendly and concise."""

    system_prompt = "You are a triage clinician. Decide if more intake questions are required and outline them as structured JSON."

    # NOTE(review): call_agent/GEMINI_MODEL_LITE come from this module's MCP setup —
    # assumed to return the model's text response; confirm against the rest of the file.
    response = await call_agent(
        user_prompt=prompt,
        system_prompt=system_prompt,
        model=GEMINI_MODEL_LITE,
        temperature=0.15
    )

    try:
        # Extract the outermost {...} span — the model may wrap the JSON in prose.
        json_start = response.find('{')
        json_end = response.rfind('}') + 1
        if json_start >= 0 and json_end > json_start:
            plan = json.loads(response[json_start:json_end])
            return plan
        raise ValueError("Clinical intake JSON not found")
    except Exception as exc:
        logger.error(f"[GEMINI INTAKE] Triage parsing failed: {exc}")
        # Fallback plan: proceed without any intake questions.
        return {
            "needs_additional_info": False,
            "decision_reason": "Fallback: proceeding without intake",
            "max_rounds": safe_rounds,
            "questions": [],
            "initial_hypotheses": []
        }
288
+
289
+
290
def gemini_clinical_intake_triage(
    query: str,
    history_context: str,
    max_rounds: int = 5
) -> dict:
    """Synchronous wrapper for clinical intake triage.

    Drives gemini_clinical_intake_triage_async on the current event loop
    (via nest_asyncio when the loop is already running). Every failure path
    returns a plan that skips intake so the pipeline can continue.
    """
    def _skip_plan(reason: str) -> dict:
        # Single source for the "skip intake" fallback (was duplicated per path).
        return {
            "needs_additional_info": False,
            "decision_reason": reason,
            "max_rounds": max_rounds,
            "questions": [],
            "initial_hypotheses": []
        }

    if not MCP_AVAILABLE:
        logger.warning("[GEMINI INTAKE] MCP unavailable, skipping clinical intake triage")
        return _skip_plan("MCP unavailable")

    try:
        # NOTE(review): mirrors the other supervisor wrappers in this module;
        # assumes the module's nest_asyncio object exposes run() — confirm.
        loop = asyncio.get_event_loop()
        if loop.is_running():
            if nest_asyncio:
                return nest_asyncio.run(
                    gemini_clinical_intake_triage_async(query, history_context, max_rounds)
                )
            raise RuntimeError("nest_asyncio not available")
        return loop.run_until_complete(
            gemini_clinical_intake_triage_async(query, history_context, max_rounds)
        )
    except Exception as exc:
        logger.error(f"[GEMINI INTAKE] Triage request failed: {exc}")
        return _skip_plan("Triage agent error")
326
+
327
+
328
async def gemini_summarize_clinical_insights_async(
    query: str,
    qa_pairs: list
) -> dict:
    """Gemini Intake Agent: Convert answered intake questions into key clinical insights.

    Sends the original query plus up to 8 Q&A pairs to the lite Gemini model
    and parses its JSON summary; on failure returns a minimal summary that
    echoes the query so downstream stages still receive usable structure.
    """
    qa_json = json.dumps(qa_pairs[:8])  # guard against very long history
    prompt = f"""You are a clinical documentation expert.
Summarize the following intake Q&A into key insights for a supervising medical agent.

Original patient query:
"{query}"

Collected intake Q&A (JSON):
{qa_json}

Return ONLY valid JSON:
{{
"patient_profile": "1-2 sentence overview combining key demographics/symptoms",
"refined_problem_statement": "what problem the supervisor should solve now",
"key_findings": [
{{
"title": "short label",
"detail": "what the patient reported",
"clinical_implication": "why it matters"
}}
],
"handoff_note": "action-oriented instruction for the supervisor (<=2 sentences)"
}}

Guidelines:
- Highlight red flags, chronic meds, relevant history, and symptom trajectory.
- Only include facts explicitly stated in the Q&A."""

    system_prompt = "You transform clinical intake dialogs into structured insights for downstream medical reasoning."

    # NOTE(review): call_agent/GEMINI_MODEL_LITE come from this module's MCP setup —
    # assumed to return the model's text response; confirm against the rest of the file.
    response = await call_agent(
        user_prompt=prompt,
        system_prompt=system_prompt,
        model=GEMINI_MODEL_LITE,
        temperature=0.2
    )

    try:
        # Extract the outermost {...} span — the model may wrap the JSON in prose.
        json_start = response.find('{')
        json_end = response.rfind('}') + 1
        if json_start >= 0 and json_end > json_start:
            return json.loads(response[json_start:json_end])
        raise ValueError("Clinical insight JSON not found")
    except Exception as exc:
        logger.error(f"[GEMINI INTAKE] Insight summarization failed: {exc}")
        # Minimal fallback summary built from the query alone.
        return {
            "patient_profile": "",
            "refined_problem_statement": query,
            "key_findings": [
                {"title": "Patient concern", "detail": query, "clinical_implication": "Requires standard evaluation"}
            ],
            "handoff_note": "Proceed with regular workflow."
        }
386
+
387
+
388
def gemini_summarize_clinical_insights(query: str, qa_pairs: list) -> dict:
    """Synchronous wrapper for clinical insight summarization.

    Drives gemini_summarize_clinical_insights_async on the current event loop
    (via nest_asyncio when the loop is already running). On any failure,
    returns a minimal summary built from the query so the pipeline continues.
    """
    # Single fallback payload shared by every failure path (was duplicated 3x).
    fallback = {
        "patient_profile": "",
        "refined_problem_statement": query,
        "key_findings": [
            {"title": "Patient concern", "detail": query, "clinical_implication": "Requires standard evaluation"}
        ],
        "handoff_note": "Proceed with regular workflow."
    }

    if not MCP_AVAILABLE:
        logger.warning("[GEMINI INTAKE] MCP unavailable, using fallback intake summary")
        return fallback

    try:
        # NOTE(review): mirrors the other supervisor wrappers in this module;
        # assumes the module's nest_asyncio object exposes run() — confirm.
        loop = asyncio.get_event_loop()
        if loop.is_running():
            if nest_asyncio:
                return nest_asyncio.run(
                    gemini_summarize_clinical_insights_async(query, qa_pairs)
                )
            raise RuntimeError("nest_asyncio not available")
        return loop.run_until_complete(
            gemini_summarize_clinical_insights_async(query, qa_pairs)
        )
    except Exception as exc:
        logger.error(f"[GEMINI INTAKE] Insight summarization request failed: {exc}")
        return fallback
422
+
423
+
424
  def gemini_supervisor_breakdown(query: str, use_rag: bool, use_web_search: bool, time_elapsed: float, max_duration: int = 120) -> dict:
425
  """Wrapper to obtain supervisor breakdown synchronously"""
426
  if not MCP_AVAILABLE:
ui.py CHANGED
@@ -144,6 +144,11 @@ def create_demo():
144
  "Show agentic thought",
145
  size="sm"
146
  )
 
 
 
 
 
147
  agentic_thoughts_box = gr.Textbox(
148
  label="Agentic Thoughts",
149
  placeholder="Internal thoughts from MedSwin and supervisor will appear here...",
@@ -261,6 +266,7 @@ def create_demo():
261
  use_rag,
262
  medical_model,
263
  use_web_search,
 
264
  disable_agentic_reasoning,
265
  show_thoughts_state
266
  ],
@@ -283,6 +289,7 @@ def create_demo():
283
  use_rag,
284
  medical_model,
285
  use_web_search,
 
286
  disable_agentic_reasoning,
287
  show_thoughts_state
288
  ],
 
144
  "Show agentic thought",
145
  size="sm"
146
  )
147
+ enable_clinical_intake = gr.Checkbox(
148
+ value=True,
149
+ label="Enable clinical intake (max 5 Q&A)",
150
+ info="Ask focused follow-up questions before breaking down the case"
151
+ )
152
  agentic_thoughts_box = gr.Textbox(
153
  label="Agentic Thoughts",
154
  placeholder="Internal thoughts from MedSwin and supervisor will appear here...",
 
266
  use_rag,
267
  medical_model,
268
  use_web_search,
269
+ enable_clinical_intake,
270
  disable_agentic_reasoning,
271
  show_thoughts_state
272
  ],
 
289
  use_rag,
290
  medical_model,
291
  use_web_search,
292
+ enable_clinical_intake,
293
  disable_agentic_reasoning,
294
  show_thoughts_state
295
  ],