"""Model initialization and management""" import torch import threading from transformers import AutoModelForCausalLM, AutoTokenizer from llama_index.llms.huggingface import HuggingFaceLLM from llama_index.embeddings.huggingface import HuggingFaceEmbedding from logger import logger import config try: from TTS.api import TTS TTS_AVAILABLE = True except ImportError: TTS_AVAILABLE = False TTS = None # Model loading state tracking _model_loading_states = {} _model_loading_lock = threading.Lock() @spaces.GPU(max_duration=120) def set_model_loading_state(model_name: str, state: str): """Set model loading state: 'loading', 'loaded', 'error'""" with _model_loading_lock: _model_loading_states[model_name] = state logger.debug(f"Model {model_name} state set to: {state}") @spaces.GPU(max_duration=120) def get_model_loading_state(model_name: str) -> str: """Get model loading state: 'loading', 'loaded', 'error', or 'unknown'""" with _model_loading_lock: return _model_loading_states.get(model_name, "unknown") def is_model_loaded(model_name: str) -> bool: """Check if model is loaded and ready""" with _model_loading_lock: return (model_name in config.global_medical_models and config.global_medical_models[model_name] is not None and _model_loading_states.get(model_name) == "loaded") @spaces.GPU(max_duration=120) def initialize_medical_model(model_name: str): """Initialize medical model (MedSwin) - download on demand""" if model_name not in config.global_medical_models or config.global_medical_models[model_name] is None: set_model_loading_state(model_name, "loading") logger.info(f"Initializing medical model: {model_name}...") try: model_path = config.MEDSWIN_MODELS[model_name] tokenizer = AutoTokenizer.from_pretrained(model_path, token=config.HF_TOKEN) model = AutoModelForCausalLM.from_pretrained( model_path, device_map="auto", trust_remote_code=True, token=config.HF_TOKEN, torch_dtype=torch.float16 ) config.global_medical_models[model_name] = model config.global_medical_tokenizers[model_name] = tokenizer set_model_loading_state(model_name, "loaded") logger.info(f"Medical model {model_name} initialized successfully") except Exception as e: set_model_loading_state(model_name, "error") logger.error(f"Failed to initialize medical model {model_name}: {e}") raise else: # Model already loaded, ensure state is set if get_model_loading_state(model_name) != "loaded": set_model_loading_state(model_name, "loaded") return config.global_medical_models[model_name], config.global_medical_tokenizers[model_name] @spaces.GPU(max_duration=120) def initialize_tts_model(): """Initialize TTS model for text-to-speech""" if not TTS_AVAILABLE: logger.warning("TTS library not installed. TTS features will be disabled.") return None if config.global_tts_model is None: try: logger.info("Initializing TTS model for voice generation...") config.global_tts_model = TTS(model_name=config.TTS_MODEL, progress_bar=False) logger.info("TTS model initialized successfully") except Exception as e: logger.warning(f"TTS model initialization failed: {e}") logger.warning("TTS features will be disabled. 
If pyworld dependency is missing, try: pip install TTS --no-deps && pip install coqui-tts") config.global_tts_model = None return config.global_tts_model @spaces.GPU(max_duration=120) def get_or_create_embed_model(): """Reuse embedding model to avoid reloading weights each request""" if config.global_embed_model is None: logger.info("Initializing shared embedding model for RAG retrieval...") config.global_embed_model = HuggingFaceEmbedding(model_name=config.EMBEDDING_MODEL, token=config.HF_TOKEN) return config.global_embed_model @spaces.GPU(max_duration=120) def get_llm_for_rag(temperature=0.7, max_new_tokens=256, top_p=0.95, top_k=50): """Get LLM for RAG indexing (uses medical model)""" medical_model_obj, medical_tokenizer = initialize_medical_model(config.DEFAULT_MEDICAL_MODEL) return HuggingFaceLLM( context_window=4096, max_new_tokens=max_new_tokens, tokenizer=medical_tokenizer, model=medical_model_obj, generate_kwargs={ "do_sample": True, "temperature": temperature, "top_k": top_k, "top_p": top_p } )
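

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the runtime module).
# Assumptions: this file is run as a script in an environment where
# config.HF_TOKEN, config.DEFAULT_MEDICAL_MODEL, config.MEDSWIN_MODELS and
# config.EMBEDDING_MODEL are populated; the model names themselves come from
# config, not from this file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model_name = config.DEFAULT_MEDICAL_MODEL

    # First call downloads/loads the model and flips its state to "loaded";
    # later calls return the cached objects stored on the config module.
    model, tokenizer = initialize_medical_model(model_name)
    logger.info(f"{model_name} ready: {is_model_loaded(model_name)}")

    # Build the shared RAG components on top of the cached medical model.
    llm = get_llm_for_rag(temperature=0.3, max_new_tokens=128)
    embed_model = get_or_create_embed_model()
    logger.info(f"RAG LLM context window: {llm.context_window}")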