Spaces: Running on Zero

Y Phung Nguyen committed · Commit f7415cc · Parent(s): 09d7494

Fix model preloader

Files changed:
- models.py +115 -20
- pipeline.py +28 -1
- ui.py +77 -11
models.py
CHANGED
@@ -57,26 +57,56 @@ def is_model_loaded(model_name: str) -> bool:
             config.global_medical_models[model_name] is not None and
             _model_loading_states.get(model_name) == "loaded")
 
-def initialize_medical_model(model_name: str):
-    """
+def initialize_medical_model(model_name: str, load_to_gpu: bool = True):
+    """
+    Initialize medical model (MedSwin) - download on demand
+
+    According to ZeroGPU best practices:
+    - If load_to_gpu=True: Load directly to GPU using device_map="auto" (must be called within @spaces.GPU decorated function)
+    - If load_to_gpu=False: Load to CPU first, then move to GPU in inference function
+
+    Args:
+        model_name: Name of the model to load
+        load_to_gpu: If True, load directly to GPU. If False, load to CPU (for ZeroGPU best practices)
+    """
     if model_name not in config.global_medical_models or config.global_medical_models[model_name] is None:
         set_model_loading_state(model_name, "loading")
-        logger.info(f"Initializing medical model: {model_name}...")
+        logger.info(f"Initializing medical model: {model_name}... (load_to_gpu={load_to_gpu})")
         try:
-            # Clear GPU cache before loading to prevent memory issues
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-                logger.debug("Cleared GPU cache before model loading")
-
             model_path = config.MEDSWIN_MODELS[model_name]
             tokenizer = AutoTokenizer.from_pretrained(model_path, token=config.HF_TOKEN)
+
+            if load_to_gpu:
+                # Load directly to GPU (must be within @spaces.GPU decorated function)
+                # Clear GPU cache before loading to prevent memory issues
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
+                    logger.debug("Cleared GPU cache before model loading")
+
+                model = AutoModelForCausalLM.from_pretrained(
+                    model_path,
+                    device_map="auto",  # Automatically places model on GPU
+                    trust_remote_code=True,
+                    token=config.HF_TOKEN,
+                    torch_dtype=torch.float16
+                )
+
+                # Clear cache after loading
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
+                    logger.debug("Cleared GPU cache after model loading")
+            else:
+                # Load to CPU first (ZeroGPU best practice - no GPU decorator needed)
+                logger.info(f"Loading {model_name} to CPU (will move to GPU during inference)...")
+                model = AutoModelForCausalLM.from_pretrained(
+                    model_path,
+                    device_map="cpu",  # Load to CPU
+                    trust_remote_code=True,
+                    token=config.HF_TOKEN,
+                    torch_dtype=torch.float16
+                )
+                logger.info(f"Model {model_name} loaded to CPU successfully")
+
             # Set models in config BEFORE setting state to "loaded"
             config.global_medical_models[model_name] = model
             config.global_medical_tokenizers[model_name] = tokenizer

@@ -87,11 +117,6 @@ def initialize_medical_model(model_name: str):
             # Verify the state was set correctly
             if not is_model_loaded(model_name):
                 logger.warning(f"Model {model_name} initialized but is_model_loaded() returns False. State: {get_model_loading_state(model_name)}, in dict: {model_name in config.global_medical_models}")
-
-            # Clear cache after loading to free up temporary memory
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-                logger.debug("Cleared GPU cache after model loading")
         except Exception as e:
             set_model_loading_state(model_name, "error")
             logger.error(f"Failed to initialize medical model {model_name}: {e}")

@@ -106,6 +131,76 @@ def initialize_medical_model(model_name: str):
         set_model_loading_state(model_name, "loaded")
     return config.global_medical_models[model_name], config.global_medical_tokenizers[model_name]
 
+def move_model_to_gpu(model_name: str):
+    """
+    Move a model from CPU to GPU (for ZeroGPU best practices)
+    Must be called within a @spaces.GPU decorated function
+
+    According to ZeroGPU best practices:
+    - Models should be loaded to CPU first (no GPU quota used)
+    - Models are moved to GPU only during inference (within @spaces.GPU decorated function)
+    """
+    if model_name not in config.global_medical_models:
+        raise ValueError(f"Model {model_name} not found in config")
+
+    model = config.global_medical_models[model_name]
+    if model is None:
+        raise ValueError(f"Model {model_name} is None")
+
+    # Check if model is already on GPU
+    try:
+        # For models with device_map, check the actual device
+        if hasattr(model, 'device'):
+            device_str = str(model.device)
+            if 'cuda' in device_str.lower():
+                logger.debug(f"Model {model_name} is already on GPU ({device_str})")
+                return model
+
+        # Check device_map if available
+        if hasattr(model, 'hf_device_map'):
+            device_map = model.hf_device_map
+            if isinstance(device_map, dict):
+                # Check if any device is GPU
+                if any('cuda' in str(dev).lower() for dev in device_map.values()):
+                    logger.debug(f"Model {model_name} is already on GPU (device_map)")
+                    return model
+    except Exception as e:
+        logger.debug(f"Could not check model device: {e}")
+
+    # Move model to GPU
+    logger.info(f"Moving model {model_name} from CPU to GPU...")
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+    # For models loaded with device_map="cpu", we need to reload with device_map="auto"
+    # or use accelerate to dispatch to GPU
+    try:
+        # Try using accelerate's dispatch_model for proper GPU placement
+        from accelerate import dispatch_model
+        from accelerate.utils import get_balanced_memory, infer_auto_device_map
+
+        # Get device map for GPU
+        max_memory = get_balanced_memory(model, max_memory={0: "20GiB"})
+        device_map = infer_auto_device_map(model, max_memory=max_memory)
+        model = dispatch_model(model, device_map=device_map)
+        config.global_medical_models[model_name] = model
+        logger.info(f"Model {model_name} moved to GPU successfully using accelerate")
+    except Exception as e:
+        # Fallback: simple move to cuda (may not work for all model architectures)
+        logger.warning(f"Could not use accelerate dispatch, trying simple .to('cuda'): {e}")
+        try:
+            model = model.to('cuda')
+            config.global_medical_models[model_name] = model
+            logger.info(f"Model {model_name} moved to GPU (cuda) successfully")
+        except Exception as e2:
+            logger.error(f"Failed to move model {model_name} to GPU: {e2}")
+            raise
+
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+    return model
+
 def initialize_tts_model():
     """Initialize TTS model for text-to-speech"""
     if not TTS_AVAILABLE:
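Taken together, the models.py changes implement the CPU-first pattern the new docstrings describe: load weights to CPU at startup (no GPU quota used), then dispatch them to CUDA only inside a GPU-decorated call. The sketch below is not part of the commit; it only illustrates how the pieces are intended to combine, assuming the `spaces` package's `@spaces.GPU` decorator as provided on ZeroGPU Spaces. `generate_answer` and `MODEL_KEY` are hypothetical stand-ins.

    # Illustrative sketch only - not code from this commit.
    # Reuses repo symbols from models.py/config.py; MODEL_KEY and generate_answer are hypothetical.
    import spaces
    import config
    from models import initialize_medical_model, move_model_to_gpu

    MODEL_KEY = "medswin"  # placeholder: any key present in config.MEDSWIN_MODELS

    # Startup: load weights to CPU, consuming no GPU quota.
    initialize_medical_model(MODEL_KEY, load_to_gpu=False)

    @spaces.GPU  # a GPU is attached only while this function runs
    def generate_answer(prompt: str) -> str:
        model = move_model_to_gpu(MODEL_KEY)  # dispatch CPU weights to CUDA
        tokenizer = config.global_medical_tokenizers[MODEL_KEY]
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        output_ids = model.generate(**inputs, max_new_tokens=128)
        return tokenizer.decode(output_ids[0], skip_special_tokens=True)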
pipeline.py
CHANGED
@@ -12,7 +12,7 @@ from llama_index.core import StorageContext, VectorStoreIndex, load_index_from_s
 from llama_index.core import Settings
 from llama_index.core.retrievers import AutoMergingRetriever
 from logger import logger, ThoughtCaptureHandler
-from models import initialize_medical_model, get_or_create_embed_model, is_model_loaded, get_model_loading_state, set_model_loading_state
+from models import initialize_medical_model, get_or_create_embed_model, is_model_loaded, get_model_loading_state, set_model_loading_state, move_model_to_gpu
 from utils import detect_language, translate_text, format_url_as_domain
 from search import search_web, summarize_web_content
 from reasoning import autonomous_reasoning, create_execution_plan, autonomous_execution_strategy

@@ -380,6 +380,33 @@ def stream_chat(
         yield history + [{"role": "assistant", "content": error_msg}], ""
         return
 
+    # ZeroGPU best practice: If model is on CPU, move it to GPU now (we're in a GPU-decorated function)
+    # This ensures the model is ready for inference without consuming GPU quota during startup
+    try:
+        import config
+        if medical_model in config.global_medical_models:
+            model = config.global_medical_models[medical_model]
+            if model is not None:
+                # Check if model is on CPU (device_map="cpu" or device is CPU)
+                model_on_cpu = False
+                if hasattr(model, 'device'):
+                    if str(model.device) == 'cpu':
+                        model_on_cpu = True
+                elif hasattr(model, 'hf_device_map'):
+                    # Model loaded with device_map - check if it's on CPU
+                    if isinstance(model.hf_device_map, dict):
+                        # If all devices are CPU, move to GPU
+                        if all('cpu' in str(dev).lower() for dev in model.hf_device_map.values()):
+                            model_on_cpu = True
+
+                if model_on_cpu:
+                    logger.info(f"[STREAM_CHAT] Model {medical_model} is on CPU, moving to GPU for inference...")
+                    move_model_to_gpu(medical_model)
+                    logger.info(f"[STREAM_CHAT] ✅ Model {medical_model} moved to GPU successfully")
+    except Exception as e:
+        logger.warning(f"[STREAM_CHAT] Could not move model to GPU (may already be on GPU): {e}")
+        # Continue anyway - model might already be on GPU
+
     thought_handler = None
     if show_thoughts:
         thought_handler = ThoughtCaptureHandler()
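The new stream_chat block above performs the CPU-placement check inline. Purely as a reading aid (not code from the commit, and the helper name is hypothetical), the same logic factored into a standalone function looks like this:

    # Sketch only: the CPU-placement check from the stream_chat hunk, as a helper.
    def is_model_on_cpu(model) -> bool:
        """Return True if the model's weights currently appear to live on CPU."""
        if hasattr(model, "device"):
            return str(model.device) == "cpu"
        if hasattr(model, "hf_device_map") and isinstance(model.hf_device_map, dict):
            # device_map values may be ints, "cpu", or "cuda:0"-style strings
            return all("cpu" in str(dev).lower() for dev in model.hf_device_map.values())
        return False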
ui.py
CHANGED
@@ -406,10 +406,68 @@ def create_demo():
         return status_text, is_ready
 
     # GPU-decorated function to load ONLY medical model on startup
-    #
+    # According to ZeroGPU best practices:
+    # 1. Load models to CPU in global scope (no GPU decorator needed)
+    # 2. Move models to GPU only in inference functions (with @spaces.GPU decorator)
+    # However, for large models, loading to CPU then moving to GPU uses more memory
+    # So we use a hybrid approach: load to GPU directly but within GPU-decorated function
+
+    def load_medical_model_on_startup_cpu():
+        """
+        Load model to CPU on startup (ZeroGPU best practice - no GPU decorator needed)
+        Model will be moved to GPU during first inference
+        """
+        status_messages = []
+
+        try:
+            # Load only medical model (MedSwin) to CPU - TTS and Whisper load on-demand
+            if not is_model_loaded(DEFAULT_MEDICAL_MODEL):
+                logger.info(f"[STARTUP] Loading medical model to CPU: {DEFAULT_MEDICAL_MODEL}...")
+                set_model_loading_state(DEFAULT_MEDICAL_MODEL, "loading")
+                try:
+                    # Load to CPU (no GPU decorator needed)
+                    initialize_medical_model(DEFAULT_MEDICAL_MODEL, load_to_gpu=False)
+                    # Verify model is actually loaded
+                    if is_model_loaded(DEFAULT_MEDICAL_MODEL):
+                        status_messages.append(f"✅ MedSwin ({DEFAULT_MEDICAL_MODEL}): loaded to CPU")
+                        logger.info(f"[STARTUP] ✅ Medical model {DEFAULT_MEDICAL_MODEL} loaded to CPU successfully!")
+                    else:
+                        status_messages.append(f"⚠️ MedSwin ({DEFAULT_MEDICAL_MODEL}): loading failed")
+                        logger.warning(f"[STARTUP] Medical model {DEFAULT_MEDICAL_MODEL} initialization completed but not marked as loaded")
+                        set_model_loading_state(DEFAULT_MEDICAL_MODEL, "error")
+                except Exception as e:
+                    status_messages.append(f"❌ MedSwin ({DEFAULT_MEDICAL_MODEL}): error - {str(e)[:50]}")
+                    logger.error(f"[STARTUP] Failed to load medical model: {e}")
+                    import traceback
+                    logger.debug(f"[STARTUP] Full traceback: {traceback.format_exc()}")
+                    set_model_loading_state(DEFAULT_MEDICAL_MODEL, "error")
+            else:
+                status_messages.append(f"✅ MedSwin ({DEFAULT_MEDICAL_MODEL}): already loaded")
+                logger.info(f"[STARTUP] Medical model {DEFAULT_MEDICAL_MODEL} already loaded")
+
+            # Add ASR status (will load on first use)
+            if WHISPER_AVAILABLE:
+                status_messages.append("⏳ ASR (Whisper): will load on first use")
+            else:
+                status_messages.append("❌ ASR: library not available")
+
+            # Return status
+            status_text = "\n".join(status_messages)
+            logger.info(f"[STARTUP] ✅ Model loading complete. Status:\n{status_text}")
+            return status_text
+
+        except Exception as e:
+            error_msg = str(e)
+            logger.error(f"[STARTUP] Error loading model to CPU: {error_msg}")
+            return f"⚠️ Error loading model: {error_msg[:100]}"
+
+    # Alternative: Load directly to GPU (requires GPU decorator)
+    # @spaces.GPU(max_duration=MAX_DURATION)
+    def load_medical_model_on_startup_gpu():
+        """
+        Load model directly to GPU on startup (alternative approach)
+        Uses GPU quota but model is immediately ready for inference
+        """
         import torch
         status_messages = []
 

@@ -421,14 +479,15 @@ def create_demo():
 
         # Load only medical model (MedSwin) - TTS and Whisper load on-demand
         if not is_model_loaded(DEFAULT_MEDICAL_MODEL):
-            logger.info(f"[STARTUP] Loading medical model: {DEFAULT_MEDICAL_MODEL}...")
+            logger.info(f"[STARTUP] Loading medical model to GPU: {DEFAULT_MEDICAL_MODEL}...")
             set_model_loading_state(DEFAULT_MEDICAL_MODEL, "loading")
             try:
+                # Load directly to GPU (within GPU-decorated function)
+                initialize_medical_model(DEFAULT_MEDICAL_MODEL, load_to_gpu=True)
                 # Verify model is actually loaded
                 if is_model_loaded(DEFAULT_MEDICAL_MODEL):
-                    status_messages.append(f"✅ MedSwin ({DEFAULT_MEDICAL_MODEL}): loaded")
-                    logger.info(f"[STARTUP] ✅ Medical model {DEFAULT_MEDICAL_MODEL} loaded successfully!")
+                    status_messages.append(f"✅ MedSwin ({DEFAULT_MEDICAL_MODEL}): loaded to GPU")
+                    logger.info(f"[STARTUP] ✅ Medical model {DEFAULT_MEDICAL_MODEL} loaded to GPU successfully!")
                 else:
                     status_messages.append(f"⚠️ MedSwin ({DEFAULT_MEDICAL_MODEL}): loading failed")
                     logger.warning(f"[STARTUP] Medical model {DEFAULT_MEDICAL_MODEL} initialization completed but not marked as loaded")

@@ -573,15 +632,22 @@ def create_demo():
     # Load medical model on startup and update status
     # Use a wrapper to handle GPU context properly with retry logic
     def load_startup_and_update_ui():
-        """
+        """
+        Load model on startup with retry logic (max 3 attempts) and return status with UI updates
+
+        Uses CPU-first approach (ZeroGPU best practice):
+        - Load model to CPU (no GPU decorator needed, avoids quota issues)
+        - Model will be moved to GPU during first inference
+        """
         import time
         max_retries = 3
         base_delay = 5.0  # Start with 5 seconds delay
 
        for attempt in range(1, max_retries + 1):
            try:
-                logger.info(f"[STARTUP] Attempt {attempt}/{max_retries} to load medical model...")
+                logger.info(f"[STARTUP] Attempt {attempt}/{max_retries} to load medical model to CPU...")
+                # Use CPU-first approach (no GPU decorator, avoids quota issues)
+                status_text = load_medical_model_on_startup_cpu()
                 # Check if model is ready and update submit button state
                 is_ready = is_model_loaded(DEFAULT_MEDICAL_MODEL)
                 if is_ready:
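The last hunk shows the retry parameters (max_retries = 3, base_delay = 5.0) but the diff cuts off before the backoff and failure handling. A minimal sketch of one plausible shape for that wrapper, assuming a simple linear backoff between attempts; the actual ui.py may differ, and load_with_retries is a hypothetical name (the other symbols come from ui.py's scope):

    # Sketch only: one plausible retry loop around the CPU-first loader; not code from this commit.
    import time

    def load_with_retries(max_retries: int = 3, base_delay: float = 5.0) -> str:
        for attempt in range(1, max_retries + 1):
            try:
                status_text = load_medical_model_on_startup_cpu()
                if is_model_loaded(DEFAULT_MEDICAL_MODEL):
                    return status_text
            except Exception as e:
                logger.warning(f"[STARTUP] Attempt {attempt}/{max_retries} failed: {e}")
            if attempt < max_retries:
                time.sleep(base_delay * attempt)  # wait longer after each failed attempt
        return "⚠️ Model did not finish loading after retries"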