Spaces: Running on Zero
Y Phung Nguyen committed
Commit · 4a5418d
Parent(s): acc39fd

Fix model preloader

Files changed:
- pipeline.py +4 -14
- ui.py +91 -58
pipeline.py CHANGED

@@ -370,25 +370,15 @@ def stream_chat(
         return
 
     # Check if model is loaded before proceeding
+    # NOTE: We don't load the model here to save time - it should be pre-loaded before stream_chat is called
     if not is_model_loaded(medical_model):
         loading_state = get_model_loading_state(medical_model)
         if loading_state == "loading":
             error_msg = f"⏳ {medical_model} is still loading. Please wait until the model status shows 'loaded and ready' before sending messages."
         else:
-            error_msg = f"⚠️ {medical_model} is not
-
-
-            set_model_loading_state(medical_model, "loading")
-            initialize_medical_model(medical_model)
-            # If successful, continue
-        except Exception as e:
-            error_msg = f"⚠️ Error loading {medical_model}: {str(e)[:200]}. Please try again."
-            yield history + [{"role": "assistant", "content": error_msg}], ""
-            return
-
-    if not is_model_loaded(medical_model):
-        yield history + [{"role": "assistant", "content": error_msg}], ""
-        return
+            error_msg = f"⚠️ {medical_model} is not loaded. Please wait for the model to finish loading or select a model from the dropdown."
+        yield history + [{"role": "assistant", "content": error_msg}], ""
+        return
 
     thought_handler = None
     if show_thoughts:
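For orientation, a minimal self-contained sketch of the fail-fast pattern this hunk switches to: the chat generator only inspects the load state and reports it to the user, while the actual loading is left to the pre-loader wired up in ui.py below. The `_MODEL_STATES` registry and the model name are illustrative stand-ins, not the Space's real implementation.

from typing import Iterator

_MODEL_STATES: dict[str, str] = {}   # model name -> "loading" | "loaded" (illustrative stand-in)

def is_model_loaded(name: str) -> bool:
    return _MODEL_STATES.get(name) == "loaded"

def get_model_loading_state(name: str) -> str:
    return _MODEL_STATES.get(name, "not_loaded")

def stream_chat(message: str, history: list, medical_model: str) -> Iterator:
    # Fail fast: never trigger a model load inside the handler, just report the state.
    if not is_model_loaded(medical_model):
        if get_model_loading_state(medical_model) == "loading":
            error_msg = f"⏳ {medical_model} is still loading."
        else:
            error_msg = f"⚠️ {medical_model} is not loaded."
        yield history + [{"role": "assistant", "content": error_msg}], ""
        return
    # ...normal streaming generation would continue from here...
    yield history + [{"role": "assistant", "content": f"(reply from {medical_model})"}], ""

# Nothing has been pre-loaded, so the generator yields the warning and stops.
for updated_history, _ in stream_chat("hi", [], "medical-model"):
    print(updated_history[-1]["content"])
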
ui.py CHANGED

@@ -649,71 +649,104 @@ def create_demo():
             outputs=[model_status, submit_button, message_input]
         )
 
-        #
+        # Background model loading when user focuses on input (pre-loads before sending message)
+        @spaces.GPU(max_duration=MAX_DURATION)
+        def preload_model_on_input_focus():
+            """Pre-load model when user focuses on input to avoid loading during stream_chat"""
+            try:
+                if not is_model_loaded(DEFAULT_MEDICAL_MODEL):
+                    logger.info("[PRELOAD] User focused on input - pre-loading model in background...")
+                    loading_state = get_model_loading_state(DEFAULT_MEDICAL_MODEL)
+                    if loading_state != "loading":  # Don't start if already loading
+                        try:
+                            set_model_loading_state(DEFAULT_MEDICAL_MODEL, "loading")
+                            initialize_medical_model(DEFAULT_MEDICAL_MODEL)
+                            if is_model_loaded(DEFAULT_MEDICAL_MODEL):
+                                logger.info("[PRELOAD] ✅ Model pre-loaded successfully!")
+                                return "✅ Model pre-loaded and ready"
+                            else:
+                                logger.warning("[PRELOAD] Model initialization completed but not marked as loaded")
+                                return "⚠️ Model loading in progress..."
+                        except Exception as e:
+                            error_msg = str(e)
+                            is_quota_error = ("429" in error_msg or "Too Many Requests" in error_msg or
+                                              "quota" in error_msg.lower() or "ZeroGPU" in error_msg or
+                                              "runnning out" in error_msg.lower() or "running out" in error_msg.lower())
+                            if is_quota_error:
+                                logger.warning(f"[PRELOAD] Quota error during pre-load: {error_msg[:100]}")
+                                return "⚠️ Quota limit - model will load when you send message"
+                            else:
+                                logger.error(f"[PRELOAD] Error pre-loading model: {e}")
+                                return "⚠️ Pre-load failed - will try on message send"
+                    else:
+                        return "⏳ Model is already loading..."
+                else:
+                    return "✅ Model already loaded"
+            except Exception as e:
+                logger.error(f"[PRELOAD] Error in preload function: {e}")
+                return "⚠️ Pre-load error"
+
+        def trigger_preload_on_focus():
+            """Trigger model pre-loading when user focuses on input"""
+            try:
+                if not is_model_loaded(DEFAULT_MEDICAL_MODEL):
+                    # Start pre-loading in background (non-blocking)
+                    logger.info("[PRELOAD] Input focused - triggering background model load...")
+                    # This will run in GPU context but won't block the UI
+                    preload_model_on_input_focus()
+            except Exception as e:
+                logger.debug(f"[PRELOAD] Pre-load trigger error (non-critical): {e}")
+            # Return empty string to not update any UI element
+            return ""
+
+        # Trigger model pre-loading when user focuses on message input
+        message_input.focus(
+            fn=trigger_preload_on_focus,
+            inputs=None,
+            outputs=None
+        )
+
+        # Wrap stream_chat - ensure model is loaded before starting (don't load inside stream_chat to save time)
         def stream_chat_with_model_check(
             message, history, system_prompt, temperature, max_new_tokens,
             top_p, top_k, penalty, retriever_k, merge_threshold,
             use_rag, medical_model_name, use_web_search,
             enable_clinical_intake, disable_agentic_reasoning, show_thoughts, request: gr.Request = None
         ):
-
-
-
-
-
-            try:
-                # Check if model is currently loading (don't block if it's already loaded)
-                loading_state = get_model_loading_state(medical_model_name)
-                if loading_state == "loading" and not is_model_loaded(medical_model_name):
-                    error_msg = f"⏳ {medical_model_name} is still loading. Please wait until the model status shows 'loaded and ready' before sending messages."
-                    updated_history = history + [{"role": "assistant", "content": error_msg}]
-                    yield updated_history, ""
-                    return
-
-                # If request is None, create a mock request for compatibility
-                if request is None:
-                    class MockRequest:
-                        session_hash = "anonymous"
-                    request = MockRequest()
-
-                # Let stream_chat handle model loading (it's GPU-decorated and can load on-demand)
-                for result in stream_chat(
-                    message, history, system_prompt, temperature, max_new_tokens,
-                    top_p, top_k, penalty, retriever_k, merge_threshold,
-                    use_rag, medical_model_name, use_web_search,
-                    enable_clinical_intake, disable_agentic_reasoning, show_thoughts, request
-                ):
-                    yield result
-                # If we get here, stream_chat completed successfully
-                return
-
-            except Exception as e:
-                error_msg_lower = str(e).lower()
-                is_gpu_error = 'gpu task aborted' in error_msg_lower or 'gpu' in error_msg_lower or 'zerogpu' in error_msg_lower
-
-                if is_gpu_error and attempt < max_retries - 1:
-                    delay = base_delay * (2 ** attempt)  # Exponential backoff: 2s, 4s
-                    logger.warning(f"[STREAM_CHAT] GPU task aborted (attempt {attempt + 1}/{max_retries}), retrying after {delay}s...")
-                    # Yield a message to user about retry
-                    retry_msg = f"⏳ GPU task was interrupted. Retrying in {delay}s... (attempt {attempt + 1}/{max_retries})"
-                    updated_history = history + [{"role": "assistant", "content": retry_msg}]
-                    yield updated_history, ""
-                    time.sleep(delay)
-                    continue
+            # Check if model is loaded - if not, show error (don't load here to save stream_chat time)
+            if not is_model_loaded(medical_model_name):
+                loading_state = get_model_loading_state(medical_model_name)
+                if loading_state == "loading":
+                    error_msg = f"⏳ {medical_model_name} is still loading. Please wait until the model status shows 'loaded and ready' before sending messages."
                 else:
-
-
-
-
-
-
-
-
-
-
-
-
-
+                    error_msg = f"⚠️ {medical_model_name} is not loaded. Please wait a moment for the model to finish loading, or select a model from the dropdown to load it."
+                updated_history = history + [{"role": "assistant", "content": error_msg}]
+                yield updated_history, ""
+                return
+
+            # If request is None, create a mock request for compatibility
+            if request is None:
+                class MockRequest:
+                    session_hash = "anonymous"
+                request = MockRequest()
+
+            # Model is loaded, proceed with stream_chat (no model loading here to save time)
+            try:
+                for result in stream_chat(
+                    message, history, system_prompt, temperature, max_new_tokens,
+                    top_p, top_k, penalty, retriever_k, merge_threshold,
+                    use_rag, medical_model_name, use_web_search,
+                    enable_clinical_intake, disable_agentic_reasoning, show_thoughts, request
+                ):
+                    yield result
+            except Exception as e:
+                # Handle any errors gracefully
+                logger.error(f"Error in stream_chat_with_model_check: {e}")
+                import traceback
+                logger.debug(f"Full traceback: {traceback.format_exc()}")
+                error_msg = f"⚠️ An error occurred: {str(e)[:200]}"
+                updated_history = history + [{"role": "assistant", "content": error_msg}]
+                yield updated_history, ""
 
         submit_button.click(
             fn=stream_chat_with_model_check,
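To illustrate the new wiring end to end, here is a rough standalone sketch with the model layer stubbed out. `load_model`, `MODEL_NAME`, and the 2-second sleep are invented stand-ins for the Space's real helpers, and it assumes a Gradio version whose Textbox exposes the `.focus()` event that the diff relies on. The point of the pattern: the expensive load is kicked off as soon as the user clicks into the textbox, so the submit handler can assume a warm model and simply refuse if it is not.

import time
import gradio as gr

MODEL_NAME = "demo-model"   # stand-in for DEFAULT_MEDICAL_MODEL
_loaded = {"state": False}  # stand-in for the real load-state registry

def load_model() -> None:
    time.sleep(2)           # stand-in for the real initialize_medical_model()
    _loaded["state"] = True

def trigger_preload_on_focus() -> None:
    # Runs when the textbox gains focus; harmless if the model is already warm.
    if not _loaded["state"]:
        load_model()

def respond(message: str, history: list):
    # Same fail-fast check as the wrapped handler: report, don't load.
    if not _loaded["state"]:
        return history + [{"role": "assistant", "content": f"⚠️ {MODEL_NAME} is not loaded yet."}], ""
    return history + [{"role": "assistant", "content": f"{MODEL_NAME} says: {message[::-1]}"}], ""

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages")
    box = gr.Textbox(placeholder="Focusing here pre-loads the model")
    # Same pattern as the diff: the focus event fires the preload and updates no outputs.
    box.focus(fn=trigger_preload_on_focus, inputs=None, outputs=None)
    box.submit(fn=respond, inputs=[box, chatbot], outputs=[chatbot, box])

if __name__ == "__main__":
    demo.launch()

Compared with the previous on-demand loading inside the GPU-decorated handler, this keeps the per-message GPU window short and turns a cold model into an immediate status message rather than a long stall on the first send.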