"""
MuseTalk HTTP API Server

Keeps models loaded in GPU memory for fast inference.
"""

import os
import cv2
import copy
import torch
import glob
import shutil
import pickle
import numpy as np
import subprocess
import tempfile
import hashlib
import threading
import time
from pathlib import Path
from typing import Optional

from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from tqdm import tqdm
from omegaconf import OmegaConf
from transformers import WhisperModel
import uvicorn

# MuseTalk imports
from musetalk.utils.blending import get_image
from musetalk.utils.face_parsing import FaceParsing
from musetalk.utils.audio_processor import AudioProcessor
from musetalk.utils.utils import get_file_type, datagen, load_all_model
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder


def _run_ffmpeg(args: list) -> None:
    """Run an ffmpeg command given as an argument list; raise RuntimeError on failure.

    Arguments go straight to the process (no shell), so paths containing spaces
    or shell metacharacters cannot break or inject into the command.
    """
    result = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    if result.returncode != 0:
        raise RuntimeError(
            f"ffmpeg failed (exit code {result.returncode}): {result.stderr.strip()}"
        )


def _safe_upload_name(filename: Optional[str], default: str) -> str:
    """Reduce a client-supplied upload filename to a bare basename (no directories)."""
    name = Path(filename or "").name
    if name in ("", ".", ".."):
        return default
    return name


class MuseTalkServer:
    """Singleton server that keeps models loaded in GPU memory."""

    def __init__(self):
        self.device = None
        self.vae = None
        self.unet = None
        self.pe = None
        self.whisper = None
        self.audio_processor = None
        self.fp = None
        self.timesteps = None
        self.weight_dtype = None
        self.is_loaded = False
        # Serializes generate(): the models and the FaceParsing instance are
        # shared state, and HTTP handlers run generation in worker threads.
        self._generate_lock = threading.Lock()

        # Cache directories
        self.cache_dir = Path("./cache")
        self.cache_dir.mkdir(exist_ok=True)
        self.landmarks_cache = self.cache_dir / "landmarks"
        self.latents_cache = self.cache_dir / "latents"
        self.whisper_cache = self.cache_dir / "whisper_features"
        self.landmarks_cache.mkdir(exist_ok=True)
        self.latents_cache.mkdir(exist_ok=True)
        self.whisper_cache.mkdir(exist_ok=True)

        # Config
        self.fps = 25
        self.batch_size = 8
        self.use_float16 = True
        self.version = "v15"
        self.extra_margin = 10
        self.parsing_mode = "jaw"
        self.left_cheek_width = 90
        self.right_cheek_width = 90
        self.audio_padding_left = 2
        self.audio_padding_right = 2

    def load_models(
        self,
        gpu_id: int = 0,
        unet_model_path: str = "./models/musetalkV15/unet.pth",
        unet_config: str = "./models/musetalk/config.json",
        vae_type: str = "sd-vae",
        whisper_dir: str = "./models/whisper",
        use_float16: bool = True,
        version: str = "v15",
    ):
        """Load all models into GPU memory; does nothing if already loaded.

        Args:
            gpu_id: CUDA device index (CPU is used when CUDA is unavailable).
            unet_model_path: UNet weights passed to ``load_all_model``.
            unet_config: UNet config passed to ``load_all_model``.
            vae_type: VAE variant passed to ``load_all_model``.
            whisper_dir: Directory holding the Whisper model and feature extractor.
            use_float16: Convert PE, VAE and UNet to half precision.
            version: "v15" enables the v1.5 face-parsing/cropping behavior.
        """
        if self.is_loaded:
            print("Models already loaded!")
            return

        print("=" * 50)
        print("Loading MuseTalk models into GPU memory...")
        print("=" * 50)
        start_time = time.time()

        # Set device
        self.device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Load model weights
        print("Loading VAE, UNet, PE...")
        self.vae, self.unet, self.pe = load_all_model(
            unet_model_path=unet_model_path,
            vae_type=vae_type,
            unet_config=unet_config,
            device=self.device,
        )
        self.timesteps = torch.tensor([0], device=self.device)

        # Convert to float16 if enabled
        self.use_float16 = use_float16
        if use_float16:
            print("Converting to float16...")
            self.pe = self.pe.half()
            self.vae.vae = self.vae.vae.half()
            self.unet.model = self.unet.model.half()

        # Move to device
        self.pe = self.pe.to(self.device)
        self.vae.vae = self.vae.vae.to(self.device)
        self.unet.model = self.unet.model.to(self.device)

        # Initialize audio processor and Whisper (in the UNet's dtype)
        print("Loading Whisper model...")
        self.audio_processor = AudioProcessor(feature_extractor_path=whisper_dir)
        self.weight_dtype = self.unet.model.dtype
        self.whisper = WhisperModel.from_pretrained(whisper_dir)
        self.whisper = self.whisper.to(device=self.device, dtype=self.weight_dtype).eval()
        self.whisper.requires_grad_(False)

        # Initialize face parser
        self.version = version
        if version == "v15":
            self.fp = FaceParsing(
                left_cheek_width=self.left_cheek_width,
                right_cheek_width=self.right_cheek_width,
            )
        else:
            self.fp = FaceParsing()

        self.is_loaded = True
        load_time = time.time() - start_time
        print(f"Models loaded in {load_time:.2f}s")
        print("=" * 50)
        print("Server ready for inference!")
        print("=" * 50)

    # ------------------------------------------------------------------
    # Caching
    # ------------------------------------------------------------------

    def _get_file_hash(self, file_path: str) -> str:
        """Return the first 16 hex chars of the file's MD5 (a cache key, not a security hash)."""
        hash_md5 = hashlib.md5()
        with open(file_path, "rb") as f:
            for chunk in iter(lambda: f.read(1 << 20), b""):
                hash_md5.update(chunk)
        return hash_md5.hexdigest()[:16]

    @staticmethod
    def _load_pickle(cache_file: Path):
        """Load a pickled cache entry; return None if it is missing or unreadable.

        NOTE: unpickling can execute arbitrary code. These files are written only
        by this server into its own cache directory; never point the cache at
        untrusted storage.
        """
        if not cache_file.is_file():
            return None
        try:
            with open(cache_file, "rb") as f:
                return pickle.load(f)
        except Exception as e:  # best-effort: a bad entry just means recompute
            print(f"Ignoring unreadable cache file {cache_file}: {e}")
            return None

    @staticmethod
    def _save_pickle(cache_file: Path, obj) -> None:
        """Pickle ``obj`` to ``cache_file`` atomically (write temp file, then rename)."""
        # The ".pkl.tmp" name is not matched by the "*.pkl" globs used elsewhere,
        # so a partially written entry is never read or counted.
        tmp_file = cache_file.with_suffix(".pkl.tmp")
        with open(tmp_file, "wb") as f:
            pickle.dump(obj, f)
        os.replace(tmp_file, cache_file)

    def _get_cached_landmarks(self, video_hash: str, bbox_shift: int):
        """Return cached ``(coord_list, frame_list)`` for this key, or None."""
        return self._load_pickle(self.landmarks_cache / f"{video_hash}_shift{bbox_shift}.pkl")

    def _save_landmarks_cache(self, video_hash: str, bbox_shift: int, coord_list, frame_list):
        """Save landmarks to cache."""
        cache_file = self.landmarks_cache / f"{video_hash}_shift{bbox_shift}.pkl"
        self._save_pickle(cache_file, (coord_list, frame_list))

    def _get_cached_latents(self, video_hash: str):
        """Return the cached VAE latent list for this key, or None."""
        return self._load_pickle(self.latents_cache / f"{video_hash}.pkl")

    def _save_latents_cache(self, video_hash: str, latent_list):
        """Save VAE latents to cache."""
        self._save_pickle(self.latents_cache / f"{video_hash}.pkl", latent_list)

    def _get_cached_whisper(self, audio_hash: str):
        """Return cached Whisper chunks for this key, or None."""
        return self._load_pickle(self.whisper_cache / f"{audio_hash}.pkl")

    def _save_whisper_cache(self, audio_hash: str, whisper_data):
        """Save Whisper features to cache."""
        self._save_pickle(self.whisper_cache / f"{audio_hash}.pkl", whisper_data)

    # ------------------------------------------------------------------
    # Generation pipeline
    # ------------------------------------------------------------------

    @torch.no_grad()
    def generate(
        self,
        video_path: str,
        audio_path: str,
        output_path: str,
        fps: Optional[int] = None,
        use_cache: bool = True,
    ) -> dict:
        """Generate a lip-synced video from ``video_path`` and ``audio_path``.

        Calls are serialized: concurrent callers wait for the running job.

        Returns:
            Dict with per-stage timings in seconds ("frame_extraction",
            "whisper_features", "landmarks", "vae_encoding", "unet_inference",
            "face_blending", "video_encoding", "total"), the "*_source" of each
            cacheable stage ("cache" or "computed"), and "frames_generated".

        Raises:
            RuntimeError: models not loaded, or an ffmpeg step failed.
            ValueError: unsupported input, no frames extracted, or no face found.
        """
        if not self.is_loaded:
            raise RuntimeError("Models not loaded! Call load_models() first.")
        fps = fps or self.fps
        with self._generate_lock:
            return self._generate_locked(video_path, audio_path, output_path, fps, use_cache)

    def _generate_locked(self, video_path, audio_path, output_path, fps, use_cache) -> dict:
        """Run the full pipeline; the caller must hold ``self._generate_lock``."""
        timings = {"total": 0}
        total_start = time.time()

        # Content hashes are only needed to key the on-disk caches.
        video_hash = self._get_file_hash(video_path) if use_cache else None
        audio_hash = self._get_file_hash(audio_path) if use_cache else None

        temp_dir = tempfile.mkdtemp()
        try:
            # 1. Extract frames
            t0 = time.time()
            input_img_list = self._extract_frames(video_path, temp_dir, fps)
            timings["frame_extraction"] = time.time() - t0

            # 2. Extract audio features (with caching)
            t0 = time.time()
            whisper_chunks, timings["whisper_source"] = self._compute_whisper_chunks(
                audio_path, audio_hash, fps, use_cache
            )
            timings["whisper_features"] = time.time() - t0

            # 3. Get landmarks (with caching)
            t0 = time.time()
            coord_list, frame_list, timings["landmarks_source"] = self._compute_landmarks(
                input_img_list, video_hash, fps, use_cache
            )
            timings["landmarks"] = time.time() - t0

            # 4. Compute VAE latents (with caching)
            t0 = time.time()
            input_latent_list, timings["latents_source"] = self._compute_latents(
                coord_list, frame_list, video_hash, fps, use_cache
            )
            timings["vae_encoding"] = time.time() - t0
            if not input_latent_list:
                raise ValueError(f"No face detected in: {video_path}")

            # 5. Ping-pong the inputs (forward then reversed) so long audio loops smoothly
            frame_list_cycle = frame_list + frame_list[::-1]
            coord_list_cycle = coord_list + coord_list[::-1]
            input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

            # 6. UNet inference
            t0 = time.time()
            res_frame_list = self._run_unet(whisper_chunks, input_latent_list_cycle)
            timings["unet_inference"] = time.time() - t0

            # 7. Face blending
            t0 = time.time()
            result_img_path = os.path.join(temp_dir, "results")
            os.makedirs(result_img_path, exist_ok=True)
            self._blend_faces(res_frame_list, coord_list_cycle, frame_list_cycle, result_img_path)
            timings["face_blending"] = time.time() - t0

            # 8. Encode video
            t0 = time.time()
            self._encode_video(result_img_path, temp_dir, audio_path, output_path, fps)
            timings["video_encoding"] = time.time() - t0
        finally:
            # Cleanup
            shutil.rmtree(temp_dir, ignore_errors=True)

        timings["total"] = time.time() - total_start
        timings["frames_generated"] = len(res_frame_list)
        return timings

    def _extract_frames(self, video_path: str, temp_dir: str, fps: int) -> list:
        """Return input frame image paths, extracting video frames with ffmpeg when needed."""
        file_type = get_file_type(video_path)
        if file_type == "image":
            return [video_path]
        if file_type != "video":
            raise ValueError(f"Unsupported video type: {video_path}")

        save_dir_full = os.path.join(temp_dir, "frames")
        os.makedirs(save_dir_full, exist_ok=True)
        _run_ffmpeg([
            "ffmpeg", "-v", "fatal", "-i", video_path,
            "-vf", f"fps={fps}", "-start_number", "0",
            os.path.join(save_dir_full, "%08d.png"),
        ])
        input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
        if not input_img_list:
            raise ValueError(f"No frames could be extracted from: {video_path}")
        return input_img_list

    def _compute_whisper_chunks(self, audio_path, audio_hash, fps, use_cache):
        """Return ``(whisper_chunks, source)``, where source is "cache" or "computed"."""
        # The chunks are aligned to video frames, so fps belongs in the key.
        cache_key = f"{audio_hash}_fps{fps}"
        if use_cache:
            cached = self._get_cached_whisper(cache_key)
            # `is not None`, not truthiness: the cached value may be a tensor,
            # whose truth value is ambiguous (raises) for multi-element tensors.
            if cached is not None:
                return cached, "cache"

        whisper_input_features, librosa_length = self.audio_processor.get_audio_feature(audio_path)
        whisper_chunks = self.audio_processor.get_whisper_chunk(
            whisper_input_features,
            self.device,
            self.weight_dtype,
            self.whisper,
            librosa_length,
            fps=fps,
            audio_padding_length_left=self.audio_padding_left,
            audio_padding_length_right=self.audio_padding_right,
        )
        if use_cache:
            self._save_whisper_cache(cache_key, whisper_chunks)
        return whisper_chunks, "computed"

    def _compute_landmarks(self, input_img_list, video_hash, fps, use_cache):
        """Return ``(coord_list, frame_list, source)`` for the input frames."""
        bbox_shift = 0  # the same value is used for v15 and earlier versions
        cache_key = f"{video_hash}_{fps}"
        if use_cache:
            cached = self._get_cached_landmarks(cache_key, bbox_shift)
            if cached is not None:
                coord_list, frame_list = cached
                return coord_list, frame_list, "cache"

        coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
        if use_cache:
            self._save_landmarks_cache(cache_key, bbox_shift, coord_list, frame_list)
        return coord_list, frame_list, "computed"

    @staticmethod
    def _is_placeholder(bbox) -> bool:
        """True when ``bbox`` equals ``coord_placeholder`` (used for frames without a detected face)."""
        return isinstance(bbox, (list, tuple)) and list(bbox) == list(coord_placeholder)

    def _face_box(self, bbox, frame):
        """Unpack a face box; for v15 extend its bottom edge by ``extra_margin``, clamped to the frame."""
        x1, y1, x2, y2 = bbox
        if self.version == "v15":
            y2 = min(y2 + self.extra_margin, frame.shape[0])
        return x1, y1, x2, y2

    def _compute_latents(self, coord_list, frame_list, video_hash, fps, use_cache):
        """Return ``(latent_list, source)``: VAE latents of each detected 256x256 face crop."""
        cache_key = f"{video_hash}_{fps}_{self.version}"
        if use_cache:
            cached = self._get_cached_latents(cache_key)
            if cached is not None:
                return cached, "cache"

        input_latent_list = []
        for bbox, frame in zip(coord_list, frame_list):
            if self._is_placeholder(bbox):
                continue
            x1, y1, x2, y2 = self._face_box(bbox, frame)
            crop_frame = cv2.resize(
                frame[y1:y2, x1:x2], (256, 256), interpolation=cv2.INTER_LANCZOS4
            )
            input_latent_list.append(self.vae.get_latents_for_unet(crop_frame))
        if use_cache:
            self._save_latents_cache(cache_key, input_latent_list)
        return input_latent_list, "computed"

    def _run_unet(self, whisper_chunks, input_latent_list_cycle) -> list:
        """Run UNet + VAE decode over all audio chunks; return the generated face frames."""
        gen = datagen(
            whisper_chunks=whisper_chunks,
            vae_encode_latents=input_latent_list_cycle,
            batch_size=self.batch_size,
            delay_frame=0,
            device=self.device,
        )
        res_frame_list = []
        for whisper_batch, latent_batch in gen:
            audio_feature_batch = self.pe(whisper_batch)
            latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
            pred_latents = self.unet.model(
                latent_batch, self.timesteps, encoder_hidden_states=audio_feature_batch
            ).sample
            res_frame_list.extend(self.vae.decode_latents(pred_latents))
        return res_frame_list

    def _blend_faces(self, res_frame_list, coord_list_cycle, frame_list_cycle, result_img_path) -> None:
        """Paste each generated face into its source frame and write ``%08d.png`` images."""
        for i, res_frame in enumerate(res_frame_list):
            bbox = coord_list_cycle[i % len(coord_list_cycle)]
            ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
            x1, y1, x2, y2 = self._face_box(bbox, ori_frame)

            combine_frame = ori_frame
            try:
                face = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
            except (cv2.error, TypeError, ValueError):
                # No usable face box (e.g. the placeholder): keep the original frame.
                # It must still be written: the image2 demuxer used for encoding
                # stops at the first gap in the frame numbering.
                face = None

            if face is not None:
                if self.version == "v15":
                    combine_frame = get_image(ori_frame, face, [x1, y1, x2, y2], mode=self.parsing_mode, fp=self.fp)
                else:
                    combine_frame = get_image(ori_frame, face, [x1, y1, x2, y2], fp=self.fp)
            cv2.imwrite(os.path.join(result_img_path, f"{i:08d}.png"), combine_frame)

    def _encode_video(self, result_img_path, temp_dir, audio_path, output_path, fps) -> None:
        """Encode the result frames with libx264, then mux in the audio to ``output_path``."""
        temp_vid = os.path.join(temp_dir, "temp.mp4")
        _run_ffmpeg([
            "ffmpeg", "-y", "-v", "warning", "-r", str(fps), "-f", "image2",
            "-i", os.path.join(result_img_path, "%08d.png"),
            "-vcodec", "libx264", "-vf", "format=yuv420p", "-crf", "18", temp_vid,
        ])
        _run_ffmpeg([
            "ffmpeg", "-y", "-v", "warning", "-i", audio_path, "-i", temp_vid, output_path,
        ])


# Global server instance
server = MuseTalkServer()

# FastAPI app
app = FastAPI(
    title="MuseTalk API",
    description="HTTP API for MuseTalk lip-sync generation",
    version="1.0.0",
)

# CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.on_event("startup")
async def startup_event():
    """Load models on server startup (no-op if __main__ already loaded them)."""
    server.load_models()


@app.get("/health")
async def health_check():
    """Check if server is ready."""
    return {
        "status": "ok" if server.is_loaded else "loading",
        "models_loaded": server.is_loaded,
        "device": str(server.device) if server.device else None,
    }


@app.get("/cache/stats")
async def cache_stats():
    """Get cache statistics (number of entries per cache)."""
    landmarks_count = len(list(server.landmarks_cache.glob("*.pkl")))
    latents_count = len(list(server.latents_cache.glob("*.pkl")))
    whisper_count = len(list(server.whisper_cache.glob("*.pkl")))
    return {
        "landmarks_cached": landmarks_count,
        "latents_cached": latents_count,
        "whisper_features_cached": whisper_count,
    }


@app.post("/cache/clear")
async def clear_cache():
    """Clear all caches."""
    for cache_dir in [server.landmarks_cache, server.latents_cache, server.whisper_cache]:
        for f in cache_dir.glob("*.pkl"):
            f.unlink(missing_ok=True)  # tolerate concurrent removal
    return {"status": "cleared"}


class GenerateRequest(BaseModel):
    video_path: str
    audio_path: str
    # NOTE(review): written as given, so any path the server process can write;
    # expose /generate only to trusted callers.
    output_path: str
    fps: Optional[int] = 25
    use_cache: bool = True


@app.post("/generate")
async def generate_from_paths(request: GenerateRequest):
    """
    Generate lip-synced video from file paths.
    Use this when files are already on the server.
    """
    if not server.is_loaded:
        raise HTTPException(status_code=503, detail="Models not loaded yet")
    if not os.path.exists(request.video_path):
        raise HTTPException(status_code=404, detail=f"Video not found: {request.video_path}")
    if not os.path.exists(request.audio_path):
        raise HTTPException(status_code=404, detail=f"Audio not found: {request.audio_path}")

    try:
        # Run the blocking GPU pipeline in a worker thread so the event loop
        # (and e.g. /health) stays responsive while a job runs.
        timings = await run_in_threadpool(
            server.generate,
            video_path=request.video_path,
            audio_path=request.audio_path,
            output_path=request.output_path,
            fps=request.fps,
            use_cache=request.use_cache,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "status": "success",
        "output_path": request.output_path,
        "timings": timings,
    }


@app.post("/generate/upload")
async def generate_from_upload(
    video: UploadFile = File(...),
    audio: UploadFile = File(...),
    fps: int = Form(25),
    use_cache: bool = Form(True),
):
    """
    Generate lip-synced video from uploaded files.
    Returns the generated video file.
    """
    if not server.is_loaded:
        raise HTTPException(status_code=503, detail="Models not loaded yet")

    # Save uploaded files
    temp_dir = tempfile.mkdtemp()
    try:
        # Client filenames are untrusted: keep only the basename (blocks "../" and
        # absolute paths) and prefix each one so the two uploads never collide.
        # The extension is kept because get_file_type() presumably inspects it.
        video_path = os.path.join(temp_dir, "video_" + _safe_upload_name(video.filename, "input.mp4"))
        audio_path = os.path.join(temp_dir, "audio_" + _safe_upload_name(audio.filename, "input.wav"))
        output_path = os.path.join(temp_dir, "output.mp4")

        with open(video_path, "wb") as f:
            f.write(await video.read())
        with open(audio_path, "wb") as f:
            f.write(await audio.read())

        timings = await run_in_threadpool(
            server.generate,
            video_path=video_path,
            audio_path=audio_path,
            output_path=output_path,
            fps=fps,
            use_cache=use_cache,
        )
    except Exception as e:
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Delete the temp dir only after the file has been streamed to the client.
    cleanup = BackgroundTasks()
    cleanup.add_task(shutil.rmtree, temp_dir, ignore_errors=True)

    # Return the video file
    return FileResponse(
        output_path,
        media_type="video/mp4",
        filename="result.mp4",
        headers={"X-Timings": str(timings)},
        background=cleanup,
    )


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="MuseTalk API Server")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind")
    parser.add_argument("--port", type=int, default=8000, help="Port to bind")
    parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID")
    parser.add_argument("--unet_model_path", type=str, default="./models/musetalkV15/unet.pth")
    parser.add_argument("--unet_config", type=str, default="./models/musetalk/config.json")
    parser.add_argument("--whisper_dir", type=str, default="./models/whisper")
    parser.add_argument("--no_float16", action="store_true", help="Disable float16")
    args = parser.parse_args()

    # Pre-configure server
    server.load_models(
        gpu_id=args.gpu_id,
        unet_model_path=args.unet_model_path,
        unet_config=args.unet_config,
        whisper_dir=args.whisper_dir,
        use_float16=not args.no_float16,
    )

    # Start server
    uvicorn.run(app, host=args.host, port=args.port)