Y Phung Nguyen committed
Commit 4a5418d · 1 Parent(s): acc39fd

Fix model preloader

Files changed (2):
  1. pipeline.py +4 -14
  2. ui.py +91 -58
pipeline.py CHANGED
@@ -370,25 +370,15 @@ def stream_chat(
         return
 
     # Check if model is loaded before proceeding
+    # NOTE: We don't load the model here to save time - it should be pre-loaded before stream_chat is called
     if not is_model_loaded(medical_model):
         loading_state = get_model_loading_state(medical_model)
         if loading_state == "loading":
             error_msg = f"⏳ {medical_model} is still loading. Please wait until the model status shows 'loaded and ready' before sending messages."
         else:
-            error_msg = f"⚠️ {medical_model} is not ready. Please wait for the model to finish loading."
-            # Try to load it
-            try:
-                set_model_loading_state(medical_model, "loading")
-                initialize_medical_model(medical_model)
-                # If successful, continue
-            except Exception as e:
-                error_msg = f"⚠️ Error loading {medical_model}: {str(e)[:200]}. Please try again."
-                yield history + [{"role": "assistant", "content": error_msg}], ""
-                return
-
-        if not is_model_loaded(medical_model):
-            yield history + [{"role": "assistant", "content": error_msg}], ""
-            return
+            error_msg = f"⚠️ {medical_model} is not loaded. Please wait for the model to finish loading or select a model from the dropdown."
+        yield history + [{"role": "assistant", "content": error_msg}], ""
+        return
 
     thought_handler = None
     if show_thoughts:
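
With this hunk, stream_chat no longer falls back to loading the model on demand; it only reports whether the model is ready, and the actual loading is expected to happen before stream_chat runs. A minimal sketch of that caller-side contract, assuming the pipeline.py helpers shown in the diff (is_model_loaded, get_model_loading_state, set_model_loading_state, initialize_medical_model); ensure_model_ready itself is an illustrative helper, not part of the repo:

# Caller-side contract after this change: pre-load first, then call stream_chat.
# ensure_model_ready is hypothetical; the imported helpers are the ones shown above.
from pipeline import (
    is_model_loaded,
    get_model_loading_state,
    set_model_loading_state,
    initialize_medical_model,
)

def ensure_model_ready(model_name: str) -> bool:
    """Load the model once, outside of stream_chat, so chat requests stay fast."""
    if is_model_loaded(model_name):
        return True
    if get_model_loading_state(model_name) == "loading":
        return False  # another request is already loading it; caller should wait
    set_model_loading_state(model_name, "loading")
    initialize_medical_model(model_name)
    return is_model_loaded(model_name)

# The UI layer (see ui.py below) performs this step on input focus, so stream_chat
# only has to check readiness and can fail fast with a user-facing message otherwise.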
ui.py CHANGED
@@ -649,71 +649,104 @@ def create_demo():
         outputs=[model_status, submit_button, message_input]
     )
 
-    # Wrap stream_chat - let stream_chat handle model loading since it's GPU-decorated
+    # Background model loading when user focuses on input (pre-loads before sending message)
+    @spaces.GPU(max_duration=MAX_DURATION)
+    def preload_model_on_input_focus():
+        """Pre-load model when user focuses on input to avoid loading during stream_chat"""
+        try:
+            if not is_model_loaded(DEFAULT_MEDICAL_MODEL):
+                logger.info("[PRELOAD] User focused on input - pre-loading model in background...")
+                loading_state = get_model_loading_state(DEFAULT_MEDICAL_MODEL)
+                if loading_state != "loading":  # Don't start if already loading
+                    try:
+                        set_model_loading_state(DEFAULT_MEDICAL_MODEL, "loading")
+                        initialize_medical_model(DEFAULT_MEDICAL_MODEL)
+                        if is_model_loaded(DEFAULT_MEDICAL_MODEL):
+                            logger.info("[PRELOAD] ✅ Model pre-loaded successfully!")
+                            return "✅ Model pre-loaded and ready"
+                        else:
+                            logger.warning("[PRELOAD] Model initialization completed but not marked as loaded")
+                            return "⚠️ Model loading in progress..."
+                    except Exception as e:
+                        error_msg = str(e)
+                        is_quota_error = ("429" in error_msg or "Too Many Requests" in error_msg or
+                                          "quota" in error_msg.lower() or "ZeroGPU" in error_msg or
+                                          "runnning out" in error_msg.lower() or "running out" in error_msg.lower())
+                        if is_quota_error:
+                            logger.warning(f"[PRELOAD] Quota error during pre-load: {error_msg[:100]}")
+                            return "⚠️ Quota limit - model will load when you send message"
+                        else:
+                            logger.error(f"[PRELOAD] Error pre-loading model: {e}")
+                            return "⚠️ Pre-load failed - will try on message send"
+                else:
+                    return "⏳ Model is already loading..."
+            else:
+                return "✅ Model already loaded"
+        except Exception as e:
+            logger.error(f"[PRELOAD] Error in preload function: {e}")
+            return "⚠️ Pre-load error"
+
+    def trigger_preload_on_focus():
+        """Trigger model pre-loading when user focuses on input"""
+        try:
+            if not is_model_loaded(DEFAULT_MEDICAL_MODEL):
+                # Start pre-loading in background (non-blocking)
+                logger.info("[PRELOAD] Input focused - triggering background model load...")
+                # This will run in GPU context but won't block the UI
+                preload_model_on_input_focus()
+        except Exception as e:
+            logger.debug(f"[PRELOAD] Pre-load trigger error (non-critical): {e}")
+        # Return empty string to not update any UI element
+        return ""
+
+    # Trigger model pre-loading when user focuses on message input
+    message_input.focus(
+        fn=trigger_preload_on_focus,
+        inputs=None,
+        outputs=None
+    )
+
+    # Wrap stream_chat - ensure model is loaded before starting (don't load inside stream_chat to save time)
     def stream_chat_with_model_check(
         message, history, system_prompt, temperature, max_new_tokens,
         top_p, top_k, penalty, retriever_k, merge_threshold,
         use_rag, medical_model_name, use_web_search,
         enable_clinical_intake, disable_agentic_reasoning, show_thoughts, request: gr.Request = None
     ):
-        import time
-        max_retries = 2
-        base_delay = 2.0
-
-        for attempt in range(max_retries):
-            try:
-                # Check if model is currently loading (don't block if it's already loaded)
-                loading_state = get_model_loading_state(medical_model_name)
-                if loading_state == "loading" and not is_model_loaded(medical_model_name):
-                    error_msg = f"⏳ {medical_model_name} is still loading. Please wait until the model status shows 'loaded and ready' before sending messages."
-                    updated_history = history + [{"role": "assistant", "content": error_msg}]
-                    yield updated_history, ""
-                    return
-
-                # If request is None, create a mock request for compatibility
-                if request is None:
-                    class MockRequest:
-                        session_hash = "anonymous"
-                    request = MockRequest()
-
-                # Let stream_chat handle model loading (it's GPU-decorated and can load on-demand)
-                for result in stream_chat(
-                    message, history, system_prompt, temperature, max_new_tokens,
-                    top_p, top_k, penalty, retriever_k, merge_threshold,
-                    use_rag, medical_model_name, use_web_search,
-                    enable_clinical_intake, disable_agentic_reasoning, show_thoughts, request
-                ):
-                    yield result
-                # If we get here, stream_chat completed successfully
-                return
-
-            except Exception as e:
-                error_msg_lower = str(e).lower()
-                is_gpu_error = 'gpu task aborted' in error_msg_lower or 'gpu' in error_msg_lower or 'zerogpu' in error_msg_lower
-
-                if is_gpu_error and attempt < max_retries - 1:
-                    delay = base_delay * (2 ** attempt)  # Exponential backoff: 2s, 4s
-                    logger.warning(f"[STREAM_CHAT] GPU task aborted (attempt {attempt + 1}/{max_retries}), retrying after {delay}s...")
-                    # Yield a message to user about retry
-                    retry_msg = f"⏳ GPU task was interrupted. Retrying in {delay}s... (attempt {attempt + 1}/{max_retries})"
-                    updated_history = history + [{"role": "assistant", "content": retry_msg}]
-                    yield updated_history, ""
-                    time.sleep(delay)
-                    continue
-                else:
-                    # Final error handling
-                    logger.error(f"[STREAM_CHAT] Error in stream_chat_with_model_check: {e}")
-                    import traceback
-                    logger.error(f"[STREAM_CHAT] Full traceback: {traceback.format_exc()}")
-
-                    if is_gpu_error:
-                        error_msg = f"⚠️ GPU task was aborted. This can happen if:\n- The request took too long\n- Multiple GPU requests conflicted\n- GPU quota was exceeded\n\nPlease try again or select a different model."
-                    else:
-                        error_msg = f"⚠️ An error occurred: {str(e)[:200]}"
-
-                    updated_history = history + [{"role": "assistant", "content": error_msg}]
-                    yield updated_history, ""
-                    return
+        # Check if model is loaded - if not, show error (don't load here to save stream_chat time)
+        if not is_model_loaded(medical_model_name):
+            loading_state = get_model_loading_state(medical_model_name)
+            if loading_state == "loading":
+                error_msg = f"⏳ {medical_model_name} is still loading. Please wait until the model status shows 'loaded and ready' before sending messages."
+            else:
+                error_msg = f"⚠️ {medical_model_name} is not loaded. Please wait a moment for the model to finish loading, or select a model from the dropdown to load it."
+            updated_history = history + [{"role": "assistant", "content": error_msg}]
+            yield updated_history, ""
+            return
+
+        # If request is None, create a mock request for compatibility
+        if request is None:
+            class MockRequest:
+                session_hash = "anonymous"
+            request = MockRequest()
+
+        # Model is loaded, proceed with stream_chat (no model loading here to save time)
+        try:
+            for result in stream_chat(
+                message, history, system_prompt, temperature, max_new_tokens,
+                top_p, top_k, penalty, retriever_k, merge_threshold,
+                use_rag, medical_model_name, use_web_search,
+                enable_clinical_intake, disable_agentic_reasoning, show_thoughts, request
+            ):
+                yield result
+        except Exception as e:
+            # Handle any errors gracefully
+            logger.error(f"Error in stream_chat_with_model_check: {e}")
+            import traceback
+            logger.debug(f"Full traceback: {traceback.format_exc()}")
+            error_msg = f"⚠️ An error occurred: {str(e)[:200]}"
+            updated_history = history + [{"role": "assistant", "content": error_msg}]
+            yield updated_history, ""
 
     submit_button.click(
         fn=stream_chat_with_model_check,
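
The new message_input.focus handler is what makes the pre-load happen before the first submit. A standalone sketch of that wiring pattern in isolation, with a hypothetical load_model() standing in for initialize_medical_model and the spaces.GPU decorator omitted:

# Minimal, self-contained illustration of the "preload on input focus" pattern.
# load_model and MODEL_STATE are hypothetical stand-ins for the Space's own helpers.
import gradio as gr

MODEL_STATE = {"loaded": False}

def load_model():
    # Stand-in for the expensive model initialization (loading weights, etc.).
    MODEL_STATE["loaded"] = True

def trigger_preload_on_focus():
    # Runs in a Gradio worker when the textbox gains focus, so the UI stays responsive.
    if not MODEL_STATE["loaded"]:
        load_model()

with gr.Blocks() as demo:
    message_input = gr.Textbox(label="Message")
    # Textbox exposes a .focus() event listener; inputs/outputs of None means no UI update.
    message_input.focus(fn=trigger_preload_on_focus, inputs=None, outputs=None)

if __name__ == "__main__":
    demo.launch()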