Y Phung Nguyen committed on
Commit f7415cc · 1 Parent(s): 09d7494

Fix model preloader

Files changed (3)
  1. models.py +115 -20
  2. pipeline.py +28 -1
  3. ui.py +77 -11
models.py CHANGED
@@ -57,26 +57,56 @@ def is_model_loaded(model_name: str) -> bool:
             config.global_medical_models[model_name] is not None and
             _model_loading_states.get(model_name) == "loaded")
 
-def initialize_medical_model(model_name: str):
-    """Initialize medical model (MedSwin) - download on demand"""
+def initialize_medical_model(model_name: str, load_to_gpu: bool = True):
+    """
+    Initialize medical model (MedSwin) - download on demand
+
+    According to ZeroGPU best practices:
+    - If load_to_gpu=True: Load directly to GPU using device_map="auto" (must be called within @spaces.GPU decorated function)
+    - If load_to_gpu=False: Load to CPU first, then move to GPU in inference function
+
+    Args:
+        model_name: Name of the model to load
+        load_to_gpu: If True, load directly to GPU. If False, load to CPU (for ZeroGPU best practices)
+    """
     if model_name not in config.global_medical_models or config.global_medical_models[model_name] is None:
         set_model_loading_state(model_name, "loading")
-        logger.info(f"Initializing medical model: {model_name}...")
+        logger.info(f"Initializing medical model: {model_name}... (load_to_gpu={load_to_gpu})")
         try:
-            # Clear GPU cache before loading to prevent memory issues
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-                logger.debug("Cleared GPU cache before model loading")
-
             model_path = config.MEDSWIN_MODELS[model_name]
             tokenizer = AutoTokenizer.from_pretrained(model_path, token=config.HF_TOKEN)
-            model = AutoModelForCausalLM.from_pretrained(
-                model_path,
-                device_map="auto",
-                trust_remote_code=True,
-                token=config.HF_TOKEN,
-                torch_dtype=torch.float16
-            )
+
+            if load_to_gpu:
+                # Load directly to GPU (must be within @spaces.GPU decorated function)
+                # Clear GPU cache before loading to prevent memory issues
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
+                    logger.debug("Cleared GPU cache before model loading")
+
+                model = AutoModelForCausalLM.from_pretrained(
+                    model_path,
+                    device_map="auto",  # Automatically places model on GPU
+                    trust_remote_code=True,
+                    token=config.HF_TOKEN,
+                    torch_dtype=torch.float16
+                )
+
+                # Clear cache after loading
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
+                    logger.debug("Cleared GPU cache after model loading")
+            else:
+                # Load to CPU first (ZeroGPU best practice - no GPU decorator needed)
+                logger.info(f"Loading {model_name} to CPU (will move to GPU during inference)...")
+                model = AutoModelForCausalLM.from_pretrained(
+                    model_path,
+                    device_map="cpu",  # Load to CPU
+                    trust_remote_code=True,
+                    token=config.HF_TOKEN,
+                    torch_dtype=torch.float16
+                )
+                logger.info(f"Model {model_name} loaded to CPU successfully")
+
             # Set models in config BEFORE setting state to "loaded"
             config.global_medical_models[model_name] = model
             config.global_medical_tokenizers[model_name] = tokenizer
@@ -87,11 +117,6 @@ def initialize_medical_model(model_name: str):
             # Verify the state was set correctly
             if not is_model_loaded(model_name):
                 logger.warning(f"Model {model_name} initialized but is_model_loaded() returns False. State: {get_model_loading_state(model_name)}, in dict: {model_name in config.global_medical_models}")
-
-            # Clear cache after loading to free up temporary memory
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-                logger.debug("Cleared GPU cache after model loading")
         except Exception as e:
             set_model_loading_state(model_name, "error")
             logger.error(f"Failed to initialize medical model {model_name}: {e}")
@@ -106,6 +131,76 @@ def initialize_medical_model(model_name: str):
         set_model_loading_state(model_name, "loaded")
     return config.global_medical_models[model_name], config.global_medical_tokenizers[model_name]
 
+def move_model_to_gpu(model_name: str):
+    """
+    Move a model from CPU to GPU (for ZeroGPU best practices)
+    Must be called within a @spaces.GPU decorated function
+
+    According to ZeroGPU best practices:
+    - Models should be loaded to CPU first (no GPU quota used)
+    - Models are moved to GPU only during inference (within @spaces.GPU decorated function)
+    """
+    if model_name not in config.global_medical_models:
+        raise ValueError(f"Model {model_name} not found in config")
+
+    model = config.global_medical_models[model_name]
+    if model is None:
+        raise ValueError(f"Model {model_name} is None")
+
+    # Check if model is already on GPU
+    try:
+        # For models with device_map, check the actual device
+        if hasattr(model, 'device'):
+            device_str = str(model.device)
+            if 'cuda' in device_str.lower():
+                logger.debug(f"Model {model_name} is already on GPU ({device_str})")
+                return model
+
+        # Check device_map if available
+        if hasattr(model, 'hf_device_map'):
+            device_map = model.hf_device_map
+            if isinstance(device_map, dict):
+                # Check if any device is GPU
+                if any('cuda' in str(dev).lower() for dev in device_map.values()):
+                    logger.debug(f"Model {model_name} is already on GPU (device_map)")
+                    return model
+    except Exception as e:
+        logger.debug(f"Could not check model device: {e}")
+
+    # Move model to GPU
+    logger.info(f"Moving model {model_name} from CPU to GPU...")
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+    # For models loaded with device_map="cpu", we need to reload with device_map="auto"
+    # or use accelerate to dispatch to GPU
+    try:
+        # Try using accelerate's dispatch_model for proper GPU placement
+        from accelerate import dispatch_model
+        from accelerate.utils import get_balanced_memory, infer_auto_device_map
+
+        # Get device map for GPU
+        max_memory = get_balanced_memory(model, max_memory={0: "20GiB"})
+        device_map = infer_auto_device_map(model, max_memory=max_memory)
+        model = dispatch_model(model, device_map=device_map)
+        config.global_medical_models[model_name] = model
+        logger.info(f"Model {model_name} moved to GPU successfully using accelerate")
+    except Exception as e:
+        # Fallback: simple move to cuda (may not work for all model architectures)
+        logger.warning(f"Could not use accelerate dispatch, trying simple .to('cuda'): {e}")
+        try:
+            model = model.to('cuda')
+            config.global_medical_models[model_name] = model
+            logger.info(f"Model {model_name} moved to GPU (cuda) successfully")
+        except Exception as e2:
+            logger.error(f"Failed to move model {model_name} to GPU: {e2}")
+            raise
+
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+    return model
+
 def initialize_tts_model():
     """Initialize TTS model for text-to-speech"""
     if not TTS_AVAILABLE:
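Usage note (not part of the commit): a minimal sketch of the calling pattern the new load_to_gpu flag and move_model_to_gpu() are designed for on ZeroGPU Spaces. The model key "MedSwin-7B" and the generate() wrapper are hypothetical; config and models come from this repo, and spaces is the Hugging Face ZeroGPU helper package.

import spaces
import config
from models import initialize_medical_model, move_model_to_gpu

MODEL_NAME = "MedSwin-7B"  # hypothetical key in config.MEDSWIN_MODELS

# Startup (no @spaces.GPU decorator): weights stay on CPU, so no GPU quota is used.
initialize_medical_model(MODEL_NAME, load_to_gpu=False)

@spaces.GPU
def generate(prompt: str) -> str:
    # Inside the GPU-decorated call the CPU copy is dispatched to CUDA;
    # later calls detect the model is already on GPU and return it as-is.
    model = move_model_to_gpu(MODEL_NAME)
    tokenizer = config.global_medical_tokenizers[MODEL_NAME]
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(**inputs, max_new_tokens=128)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)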
pipeline.py CHANGED
@@ -12,7 +12,7 @@ from llama_index.core import StorageContext, VectorStoreIndex, load_index_from_s
 from llama_index.core import Settings
 from llama_index.core.retrievers import AutoMergingRetriever
 from logger import logger, ThoughtCaptureHandler
-from models import initialize_medical_model, get_or_create_embed_model, is_model_loaded, get_model_loading_state, set_model_loading_state
+from models import initialize_medical_model, get_or_create_embed_model, is_model_loaded, get_model_loading_state, set_model_loading_state, move_model_to_gpu
 from utils import detect_language, translate_text, format_url_as_domain
 from search import search_web, summarize_web_content
 from reasoning import autonomous_reasoning, create_execution_plan, autonomous_execution_strategy
@@ -380,6 +380,33 @@ def stream_chat(
         yield history + [{"role": "assistant", "content": error_msg}], ""
         return
 
+    # ZeroGPU best practice: If model is on CPU, move it to GPU now (we're in a GPU-decorated function)
+    # This ensures the model is ready for inference without consuming GPU quota during startup
+    try:
+        import config
+        if medical_model in config.global_medical_models:
+            model = config.global_medical_models[medical_model]
+            if model is not None:
+                # Check if model is on CPU (device_map="cpu" or device is CPU)
+                model_on_cpu = False
+                if hasattr(model, 'device'):
+                    if str(model.device) == 'cpu':
+                        model_on_cpu = True
+                elif hasattr(model, 'hf_device_map'):
+                    # Model loaded with device_map - check if it's on CPU
+                    if isinstance(model.hf_device_map, dict):
+                        # If all devices are CPU, move to GPU
+                        if all('cpu' in str(dev).lower() for dev in model.hf_device_map.values()):
+                            model_on_cpu = True
+
+                if model_on_cpu:
+                    logger.info(f"[STREAM_CHAT] Model {medical_model} is on CPU, moving to GPU for inference...")
+                    move_model_to_gpu(medical_model)
+                    logger.info(f"[STREAM_CHAT] ✅ Model {medical_model} moved to GPU successfully")
+    except Exception as e:
+        logger.warning(f"[STREAM_CHAT] Could not move model to GPU (may already be on GPU): {e}")
+        # Continue anyway - model might already be on GPU
+
     thought_handler = None
     if show_thoughts:
         thought_handler = ThoughtCaptureHandler()
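For reference, the CPU-detection check added to stream_chat above can be read as a small standalone helper. The sketch below is an illustration, not code from the repo; the model_is_on_cpu name and the parameter-based fallback are assumptions.

import torch

def model_is_on_cpu(model: torch.nn.Module) -> bool:
    """Condensed form of the check above: True when the model still lives on the CPU."""
    if hasattr(model, "device"):
        # Single-device models: transformers exposes a .device attribute.
        return str(model.device) == "cpu"
    if isinstance(getattr(model, "hf_device_map", None), dict):
        # Models dispatched by accelerate: every shard must be placed on CPU.
        return all("cpu" in str(dev).lower() for dev in model.hf_device_map.values())
    # Fallback (assumption): inspect the first parameter's device directly.
    first_param = next(model.parameters(), None)
    return first_param is None or first_param.device.type == "cpu"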
ui.py CHANGED
@@ -406,10 +406,68 @@ def create_demo():
         return status_text, is_ready
 
     # GPU-decorated function to load ONLY medical model on startup
-    # TTS and Whisper load on-demand to avoid GPU conflicts and reduce startup time
-    @spaces.GPU(max_duration=MAX_DURATION)
-    def load_medical_model_on_startup():
-        """Load only the default medical model on startup to avoid GPU conflicts"""
+    # According to ZeroGPU best practices:
+    # 1. Load models to CPU in global scope (no GPU decorator needed)
+    # 2. Move models to GPU only in inference functions (with @spaces.GPU decorator)
+    # However, for large models, loading to CPU then moving to GPU uses more memory
+    # So we use a hybrid approach: load to GPU directly but within GPU-decorated function
+
+    def load_medical_model_on_startup_cpu():
+        """
+        Load model to CPU on startup (ZeroGPU best practice - no GPU decorator needed)
+        Model will be moved to GPU during first inference
+        """
+        status_messages = []
+
+        try:
+            # Load only medical model (MedSwin) to CPU - TTS and Whisper load on-demand
+            if not is_model_loaded(DEFAULT_MEDICAL_MODEL):
+                logger.info(f"[STARTUP] Loading medical model to CPU: {DEFAULT_MEDICAL_MODEL}...")
+                set_model_loading_state(DEFAULT_MEDICAL_MODEL, "loading")
+                try:
+                    # Load to CPU (no GPU decorator needed)
+                    initialize_medical_model(DEFAULT_MEDICAL_MODEL, load_to_gpu=False)
+                    # Verify model is actually loaded
+                    if is_model_loaded(DEFAULT_MEDICAL_MODEL):
+                        status_messages.append(f"✅ MedSwin ({DEFAULT_MEDICAL_MODEL}): loaded to CPU")
+                        logger.info(f"[STARTUP] ✅ Medical model {DEFAULT_MEDICAL_MODEL} loaded to CPU successfully!")
+                    else:
+                        status_messages.append(f"⚠️ MedSwin ({DEFAULT_MEDICAL_MODEL}): loading failed")
+                        logger.warning(f"[STARTUP] Medical model {DEFAULT_MEDICAL_MODEL} initialization completed but not marked as loaded")
+                        set_model_loading_state(DEFAULT_MEDICAL_MODEL, "error")
+                except Exception as e:
+                    status_messages.append(f"❌ MedSwin ({DEFAULT_MEDICAL_MODEL}): error - {str(e)[:50]}")
+                    logger.error(f"[STARTUP] Failed to load medical model: {e}")
+                    import traceback
+                    logger.debug(f"[STARTUP] Full traceback: {traceback.format_exc()}")
+                    set_model_loading_state(DEFAULT_MEDICAL_MODEL, "error")
+            else:
+                status_messages.append(f"✅ MedSwin ({DEFAULT_MEDICAL_MODEL}): already loaded")
+                logger.info(f"[STARTUP] Medical model {DEFAULT_MEDICAL_MODEL} already loaded")
+
+            # Add ASR status (will load on first use)
+            if WHISPER_AVAILABLE:
+                status_messages.append("⏳ ASR (Whisper): will load on first use")
+            else:
+                status_messages.append("❌ ASR: library not available")
+
+            # Return status
+            status_text = "\n".join(status_messages)
+            logger.info(f"[STARTUP] ✅ Model loading complete. Status:\n{status_text}")
+            return status_text
+
+        except Exception as e:
+            error_msg = str(e)
+            logger.error(f"[STARTUP] Error loading model to CPU: {error_msg}")
+            return f"⚠️ Error loading model: {error_msg[:100]}"
+
+    # Alternative: Load directly to GPU (requires GPU decorator)
+    # @spaces.GPU(max_duration=MAX_DURATION)
+    def load_medical_model_on_startup_gpu():
+        """
+        Load model directly to GPU on startup (alternative approach)
+        Uses GPU quota but model is immediately ready for inference
+        """
         import torch
         status_messages = []
 
@@ -421,14 +479,15 @@ def create_demo():
 
         # Load only medical model (MedSwin) - TTS and Whisper load on-demand
        if not is_model_loaded(DEFAULT_MEDICAL_MODEL):
-            logger.info(f"[STARTUP] Loading medical model: {DEFAULT_MEDICAL_MODEL}...")
+            logger.info(f"[STARTUP] Loading medical model to GPU: {DEFAULT_MEDICAL_MODEL}...")
             set_model_loading_state(DEFAULT_MEDICAL_MODEL, "loading")
             try:
-                initialize_medical_model(DEFAULT_MEDICAL_MODEL)
+                # Load directly to GPU (within GPU-decorated function)
+                initialize_medical_model(DEFAULT_MEDICAL_MODEL, load_to_gpu=True)
                 # Verify model is actually loaded
                 if is_model_loaded(DEFAULT_MEDICAL_MODEL):
-                    status_messages.append(f"✅ MedSwin ({DEFAULT_MEDICAL_MODEL}): loaded")
-                    logger.info(f"[STARTUP] ✅ Medical model {DEFAULT_MEDICAL_MODEL} loaded successfully!")
+                    status_messages.append(f"✅ MedSwin ({DEFAULT_MEDICAL_MODEL}): loaded to GPU")
+                    logger.info(f"[STARTUP] ✅ Medical model {DEFAULT_MEDICAL_MODEL} loaded to GPU successfully!")
                 else:
                     status_messages.append(f"⚠️ MedSwin ({DEFAULT_MEDICAL_MODEL}): loading failed")
                     logger.warning(f"[STARTUP] Medical model {DEFAULT_MEDICAL_MODEL} initialization completed but not marked as loaded")
@@ -573,15 +632,22 @@ def create_demo():
     # Load medical model on startup and update status
     # Use a wrapper to handle GPU context properly with retry logic
     def load_startup_and_update_ui():
-        """Load model on startup with retry logic (max 3 attempts) and return status with UI updates"""
+        """
+        Load model on startup with retry logic (max 3 attempts) and return status with UI updates
+
+        Uses CPU-first approach (ZeroGPU best practice):
+        - Load model to CPU (no GPU decorator needed, avoids quota issues)
+        - Model will be moved to GPU during first inference
+        """
         import time
         max_retries = 3
         base_delay = 5.0  # Start with 5 seconds delay
 
         for attempt in range(1, max_retries + 1):
             try:
-                logger.info(f"[STARTUP] Attempt {attempt}/{max_retries} to load medical model...")
-                status_text = load_medical_model_on_startup()
+                logger.info(f"[STARTUP] Attempt {attempt}/{max_retries} to load medical model to CPU...")
+                # Use CPU-first approach (no GPU decorator, avoids quota issues)
+                status_text = load_medical_model_on_startup_cpu()
                 # Check if model is ready and update submit button state
                 is_ready = is_model_loaded(DEFAULT_MEDICAL_MODEL)
                 if is_ready:
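The retry wrapper is only partially visible in this hunk. The self-contained sketch below (an illustration, not the repo's exact code) shows the shape of the pattern: call the CPU-first loader up to max_retries times and wait between failed attempts. The load_with_retries name and the linear backoff are assumptions; max_retries=3 and base_delay=5.0 match the values in the diff.

import time

def load_with_retries(loader, max_retries: int = 3, base_delay: float = 5.0) -> str:
    """Call a startup loader (e.g. load_medical_model_on_startup_cpu) with retries."""
    last_error = "unknown error"
    for attempt in range(1, max_retries + 1):
        try:
            return loader()  # on success the returned status string goes to the UI
        except Exception as exc:
            last_error = str(exc)
            if attempt < max_retries:
                time.sleep(base_delay * attempt)  # back off before the next attempt
    return f"⚠️ Startup failed after {max_retries} attempts: {last_error[:100]}"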