""" MuseTalk HTTP API Server v3 Ultra-optimized with: 1. GPU-accelerated face blending (parallel processing) 2. NVENC hardware video encoding 3. Batch audio processing """ import os import cv2 import copy import torch import glob import shutil import pickle import numpy as np import subprocess import tempfile import hashlib import time import asyncio from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor from pathlib import Path from typing import Optional, List from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks from fastapi.responses import FileResponse, JSONResponse from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from tqdm import tqdm from omegaconf import OmegaConf from transformers import WhisperModel import uvicorn import multiprocessing as mp # MuseTalk imports from musetalk.utils.blending import get_image from musetalk.utils.face_parsing import FaceParsing from musetalk.utils.audio_processor import AudioProcessor from musetalk.utils.utils import get_file_type, datagen, load_all_model from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder def blend_single_frame(args): """Worker function for parallel face blending.""" i, res_frame, bbox, ori_frame, extra_margin, version, parsing_mode, fp_config = args x1, y1, x2, y2 = bbox if version == "v15": y2 = y2 + extra_margin y2 = min(y2, ori_frame.shape[0]) try: res_frame = cv2.resize(res_frame.astype(np.uint8), (x2-x1, y2-y1)) except: return i, None # Create FaceParsing instance for this worker fp = FaceParsing( left_cheek_width=fp_config['left_cheek_width'], right_cheek_width=fp_config['right_cheek_width'] ) if version == "v15": combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], mode=parsing_mode, fp=fp) else: combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=fp) return i, combine_frame class MuseTalkServerV3: """Ultra-optimized server.""" def __init__(self): self.device = None self.vae = None self.unet = None self.pe = None self.whisper = None self.audio_processor = None self.fp = None self.timesteps = None self.weight_dtype = None self.is_loaded = False # Avatar cache self.loaded_avatars = {} self.avatar_dir = Path("./avatars") # Config self.fps = 25 self.batch_size = 8 self.use_float16 = True self.version = "v15" self.extra_margin = 10 self.parsing_mode = "jaw" self.left_cheek_width = 90 self.right_cheek_width = 90 self.audio_padding_left = 2 self.audio_padding_right = 2 # Thread pool for parallel blending self.num_workers = min(8, mp.cpu_count()) self.thread_pool = ThreadPoolExecutor(max_workers=self.num_workers) # NVENC settings self.use_nvenc = True self.nvenc_preset = "p4" # p1(fastest) to p7(best quality) self.crf = 23 def load_models( self, gpu_id: int = 0, unet_model_path: str = "./models/musetalkV15/unet.pth", unet_config: str = "./models/musetalk/config.json", vae_type: str = "sd-vae", whisper_dir: str = "./models/whisper", use_float16: bool = True, version: str = "v15" ): if self.is_loaded: print("Models already loaded!") return print("=" * 50) print("Loading MuseTalk models (v3 Ultra-Optimized)...") print("=" * 50) start_time = time.time() self.device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu") print(f"Using device: {self.device}") print(f"Parallel workers: {self.num_workers}") print(f"NVENC encoding: {self.use_nvenc}") print("Loading VAE, UNet, PE...") self.vae, self.unet, self.pe = load_all_model( unet_model_path=unet_model_path, 
        self.vae, self.unet, self.pe = load_all_model(
            unet_model_path=unet_model_path,
            vae_type=vae_type,
            unet_config=unet_config,
            device=self.device
        )
        self.timesteps = torch.tensor([0], device=self.device)

        self.use_float16 = use_float16
        if use_float16:
            print("Converting to float16...")
            self.pe = self.pe.half()
            self.vae.vae = self.vae.vae.half()
            self.unet.model = self.unet.model.half()

        self.pe = self.pe.to(self.device)
        self.vae.vae = self.vae.vae.to(self.device)
        self.unet.model = self.unet.model.to(self.device)

        print("Loading Whisper model...")
        self.audio_processor = AudioProcessor(feature_extractor_path=whisper_dir)
        self.weight_dtype = self.unet.model.dtype
        self.whisper = WhisperModel.from_pretrained(whisper_dir)
        self.whisper = self.whisper.to(device=self.device, dtype=self.weight_dtype).eval()
        self.whisper.requires_grad_(False)

        self.version = version
        if version == "v15":
            self.fp = FaceParsing(
                left_cheek_width=self.left_cheek_width,
                right_cheek_width=self.right_cheek_width
            )
        else:
            self.fp = FaceParsing()

        self.is_loaded = True
        print(f"Models loaded in {time.time() - start_time:.2f}s")
        print("=" * 50)

    def load_avatar(self, avatar_name: str) -> dict:
        """Load a pre-processed avatar (coords, frames, latents, crop info) into memory, with caching."""
        if avatar_name in self.loaded_avatars:
            return self.loaded_avatars[avatar_name]

        avatar_path = self.avatar_dir / avatar_name
        if not avatar_path.exists():
            raise FileNotFoundError(f"Avatar not found: {avatar_name}")

        print(f"Loading avatar '{avatar_name}' into memory...")
        t0 = time.time()

        avatar_data = {}
        with open(avatar_path / "metadata.pkl", 'rb') as f:
            avatar_data['metadata'] = pickle.load(f)
        with open(avatar_path / "coords.pkl", 'rb') as f:
            avatar_data['coord_list'] = pickle.load(f)
        with open(avatar_path / "frames.pkl", 'rb') as f:
            avatar_data['frame_list'] = pickle.load(f)
        with open(avatar_path / "latents.pkl", 'rb') as f:
            latents_np = pickle.load(f)
        avatar_data['latent_list'] = [
            torch.from_numpy(l).to(self.device) for l in latents_np
        ]
        with open(avatar_path / "crop_info.pkl", 'rb') as f:
            avatar_data['crop_info'] = pickle.load(f)

        self.loaded_avatars[avatar_name] = avatar_data
        print(f"Avatar loaded in {time.time() - t0:.2f}s")
        return avatar_data

    def unload_avatar(self, avatar_name: str):
        """Drop a cached avatar and release its GPU memory."""
        if avatar_name in self.loaded_avatars:
            del self.loaded_avatars[avatar_name]
            torch.cuda.empty_cache()

    def _encode_video_nvenc(self, frames_dir: str, audio_path: str, output_path: str, fps: int) -> float:
        """Encode video using NVENC hardware acceleration."""
        t0 = time.time()
        temp_vid = frames_dir.replace('/results', '/temp.mp4')

        if self.use_nvenc:
            # NVENC H.264 encoding (much faster)
            cmd_img2video = (
                f"ffmpeg -y -v warning -r {fps} -f image2 -i {frames_dir}/%08d.png "
                f"-c:v h264_nvenc -preset {self.nvenc_preset} -cq {self.crf} "
                f"-pix_fmt yuv420p {temp_vid}"
            )
        else:
            # Fallback to CPU encoding
            cmd_img2video = (
                f"ffmpeg -y -v warning -r {fps} -f image2 -i {frames_dir}/%08d.png "
                f"-vcodec libx264 -vf format=yuv420p -crf 18 {temp_vid}"
            )
        os.system(cmd_img2video)

        # Add audio
        cmd_combine = (
            f"ffmpeg -y -v warning -i {audio_path} -i {temp_vid} "
            f"-c:v copy -c:a aac {output_path}"
        )
        os.system(cmd_combine)

        # Cleanup temp video
        if os.path.exists(temp_vid):
            os.remove(temp_vid)

        return time.time() - t0
    def _parallel_face_blending(self, res_frame_list, coord_list_cycle, frame_list_cycle,
                                result_img_path) -> float:
        """Parallel face blending using the thread pool."""
        t0 = time.time()

        fp_config = {
            'left_cheek_width': self.left_cheek_width,
            'right_cheek_width': self.right_cheek_width
        }

        # Prepare all tasks
        tasks = []
        for i, res_frame in enumerate(res_frame_list):
            bbox = coord_list_cycle[i % len(coord_list_cycle)]
            ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
            tasks.append((
                i, res_frame, bbox, ori_frame,
                self.extra_margin, self.version, self.parsing_mode, fp_config
            ))

        # Process in parallel
        results = list(self.thread_pool.map(blend_single_frame, tasks))

        # Sort and save results
        results.sort(key=lambda x: x[0])
        for i, combine_frame in results:
            if combine_frame is not None:
                cv2.imwrite(f"{result_img_path}/{str(i).zfill(8)}.png", combine_frame)

        return time.time() - t0

    @torch.no_grad()
    def generate_with_avatar(
        self,
        avatar_name: str,
        audio_path: str,
        output_path: str,
        fps: Optional[int] = None,
        use_parallel_blending: bool = True
    ) -> dict:
        """Generate a video using a pre-processed avatar with all optimizations."""
        if not self.is_loaded:
            raise RuntimeError("Models not loaded!")

        fps = fps or self.fps
        timings = {}
        total_start = time.time()

        # Load avatar
        t0 = time.time()
        avatar = self.load_avatar(avatar_name)
        timings["avatar_load"] = time.time() - t0

        coord_list = avatar['coord_list']
        frame_list = avatar['frame_list']
        input_latent_list = avatar['latent_list']

        temp_dir = tempfile.mkdtemp()
        try:
            # 1. Extract audio features
            t0 = time.time()
            whisper_input_features, librosa_length = self.audio_processor.get_audio_feature(audio_path)
            whisper_chunks = self.audio_processor.get_whisper_chunk(
                whisper_input_features,
                self.device,
                self.weight_dtype,
                self.whisper,
                librosa_length,
                fps=fps,
                audio_padding_length_left=self.audio_padding_left,
                audio_padding_length_right=self.audio_padding_right,
            )
            timings["whisper_features"] = time.time() - t0

            # 2. Prepare cycled lists
            frame_list_cycle = frame_list + frame_list[::-1]
            coord_list_cycle = coord_list + coord_list[::-1]
            input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

            # 3. UNet inference
            t0 = time.time()
            gen = datagen(
                whisper_chunks=whisper_chunks,
                vae_encode_latents=input_latent_list_cycle,
                batch_size=self.batch_size,
                delay_frame=0,
                device=self.device,
            )
            res_frame_list = []
            for whisper_batch, latent_batch in gen:
                audio_feature_batch = self.pe(whisper_batch)
                latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
                pred_latents = self.unet.model(
                    latent_batch,
                    self.timesteps,
                    encoder_hidden_states=audio_feature_batch
                ).sample
                recon = self.vae.decode_latents(pred_latents)
                for res_frame in recon:
                    res_frame_list.append(res_frame)
            timings["unet_inference"] = time.time() - t0

            # 4. Face blending (parallel or sequential)
            result_img_path = os.path.join(temp_dir, "results")
            os.makedirs(result_img_path, exist_ok=True)

            if use_parallel_blending:
                timings["face_blending"] = self._parallel_face_blending(
                    res_frame_list, coord_list_cycle, frame_list_cycle, result_img_path
                )
                timings["blending_mode"] = "parallel"
            else:
                t0 = time.time()
                for i, res_frame in enumerate(res_frame_list):
                    bbox = coord_list_cycle[i % len(coord_list_cycle)]
                    ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
                    x1, y1, x2, y2 = bbox
                    if self.version == "v15":
                        y2 = y2 + self.extra_margin
                        y2 = min(y2, ori_frame.shape[0])
                    try:
                        res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
                    except Exception:
                        continue
                    if self.version == "v15":
                        combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2],
                                                  mode=self.parsing_mode, fp=self.fp)
                    else:
                        combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=self.fp)
                    cv2.imwrite(f"{result_img_path}/{str(i).zfill(8)}.png", combine_frame)
                timings["face_blending"] = time.time() - t0
                timings["blending_mode"] = "sequential"
            # 5. Video encoding (NVENC)
            timings["video_encoding"] = self._encode_video_nvenc(
                result_img_path, audio_path, output_path, fps
            )
            timings["encoding_mode"] = "nvenc" if self.use_nvenc else "cpu"

        finally:
            shutil.rmtree(temp_dir, ignore_errors=True)

        timings["total"] = time.time() - total_start
        timings["frames_generated"] = len(res_frame_list)
        return timings

    @torch.no_grad()
    def generate_batch(
        self,
        avatar_name: str,
        audio_paths: List[str],
        output_dir: str,
        fps: Optional[int] = None
    ) -> dict:
        """Generate multiple videos from multiple audios efficiently."""
        if not self.is_loaded:
            raise RuntimeError("Models not loaded!")

        fps = fps or self.fps
        batch_timings = {"videos": [], "total": 0}
        total_start = time.time()

        # Load avatar once
        t0 = time.time()
        avatar = self.load_avatar(avatar_name)
        batch_timings["avatar_load"] = time.time() - t0

        coord_list = avatar['coord_list']
        frame_list = avatar['frame_list']
        input_latent_list = avatar['latent_list']

        # Prepare cycled lists once
        frame_list_cycle = frame_list + frame_list[::-1]
        coord_list_cycle = coord_list + coord_list[::-1]
        input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

        os.makedirs(output_dir, exist_ok=True)

        for idx, audio_path in enumerate(audio_paths):
            video_start = time.time()
            timings = {}
            audio_name = Path(audio_path).stem
            output_path = os.path.join(output_dir, f"{audio_name}.mp4")

            temp_dir = tempfile.mkdtemp()
            try:
                # 1. Extract audio features
                t0 = time.time()
                whisper_input_features, librosa_length = self.audio_processor.get_audio_feature(audio_path)
                whisper_chunks = self.audio_processor.get_whisper_chunk(
                    whisper_input_features,
                    self.device,
                    self.weight_dtype,
                    self.whisper,
                    librosa_length,
                    fps=fps,
                    audio_padding_length_left=self.audio_padding_left,
                    audio_padding_length_right=self.audio_padding_right,
                )
                timings["whisper_features"] = time.time() - t0

                # 2. UNet inference
                t0 = time.time()
                gen = datagen(
                    whisper_chunks=whisper_chunks,
                    vae_encode_latents=input_latent_list_cycle,
                    batch_size=self.batch_size,
                    delay_frame=0,
                    device=self.device,
                )
                res_frame_list = []
                for whisper_batch, latent_batch in gen:
                    audio_feature_batch = self.pe(whisper_batch)
                    latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
                    pred_latents = self.unet.model(
                        latent_batch,
                        self.timesteps,
                        encoder_hidden_states=audio_feature_batch
                    ).sample
                    recon = self.vae.decode_latents(pred_latents)
                    for res_frame in recon:
                        res_frame_list.append(res_frame)
                timings["unet_inference"] = time.time() - t0

                # 3. Face blending (parallel)
                result_img_path = os.path.join(temp_dir, "results")
                os.makedirs(result_img_path, exist_ok=True)
                timings["face_blending"] = self._parallel_face_blending(
                    res_frame_list, coord_list_cycle, frame_list_cycle, result_img_path
                )
                # 4. Video encoding (NVENC)
                timings["video_encoding"] = self._encode_video_nvenc(
                    result_img_path, audio_path, output_path, fps
                )

            finally:
                shutil.rmtree(temp_dir, ignore_errors=True)

            timings["total"] = time.time() - video_start
            timings["frames_generated"] = len(res_frame_list)
            timings["output_path"] = output_path
            timings["audio_path"] = audio_path
            batch_timings["videos"].append(timings)

            print(f" [{idx + 1}/{len(audio_paths)}] {audio_name}: {timings['total']:.2f}s")

        batch_timings["total"] = time.time() - total_start
        batch_timings["num_videos"] = len(audio_paths)
        batch_timings["avg_per_video"] = batch_timings["total"] / len(audio_paths) if audio_paths else 0
        return batch_timings


# Global server
server = MuseTalkServerV3()

# FastAPI app
app = FastAPI(
    title="MuseTalk API v3",
    description="Ultra-optimized API with parallel blending, NVENC, and batch processing",
    version="3.0.0"
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.on_event("startup")
async def startup_event():
    server.load_models()


@app.get("/health")
async def health_check():
    return {
        "status": "ok" if server.is_loaded else "loading",
        "models_loaded": server.is_loaded,
        "device": str(server.device) if server.device else None,
        "loaded_avatars": list(server.loaded_avatars.keys()),
        "optimizations": {
            "parallel_workers": server.num_workers,
            "nvenc_enabled": server.use_nvenc,
            "nvenc_preset": server.nvenc_preset
        }
    }


@app.get("/avatars")
async def list_avatars():
    avatars = []
    if server.avatar_dir.exists():
        for p in server.avatar_dir.iterdir():
            if p.is_dir() and (p / "metadata.pkl").exists():
                with open(p / "metadata.pkl", 'rb') as f:
                    metadata = pickle.load(f)
                metadata['loaded'] = p.name in server.loaded_avatars
                avatars.append(metadata)
    return {"avatars": avatars}


@app.post("/avatars/{avatar_name}/load")
async def load_avatar(avatar_name: str):
    try:
        server.load_avatar(avatar_name)
        return {"status": "loaded", "avatar_name": avatar_name}
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e))


@app.post("/avatars/{avatar_name}/unload")
async def unload_avatar(avatar_name: str):
    server.unload_avatar(avatar_name)
    return {"status": "unloaded", "avatar_name": avatar_name}


class GenerateRequest(BaseModel):
    avatar_name: str
    audio_path: str
    output_path: str
    fps: Optional[int] = 25
    use_parallel_blending: bool = True


@app.post("/generate/avatar")
async def generate_with_avatar(request: GenerateRequest):
    if not server.is_loaded:
        raise HTTPException(status_code=503, detail="Models not loaded")
    if not os.path.exists(request.audio_path):
        raise HTTPException(status_code=404, detail=f"Audio not found: {request.audio_path}")

    try:
        timings = server.generate_with_avatar(
            avatar_name=request.avatar_name,
            audio_path=request.audio_path,
            output_path=request.output_path,
            fps=request.fps,
            use_parallel_blending=request.use_parallel_blending
        )
        return {
            "status": "success",
            "output_path": request.output_path,
            "timings": timings
        }
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
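
# Illustrative request body for POST /generate/avatar, matching the
# GenerateRequest model above. The avatar name and file paths are placeholders
# and must exist on the server machine for the call to succeed:
#
#   {
#       "avatar_name": "demo_avatar",
#       "audio_path": "/data/audio/hello.wav",
#       "output_path": "/data/output/hello.mp4",
#       "fps": 25,
#       "use_parallel_blending": true
#   }
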
class BatchGenerateRequest(BaseModel):
    avatar_name: str
    audio_paths: List[str]
    output_dir: str
    fps: Optional[int] = 25


@app.post("/generate/batch")
async def generate_batch(request: BatchGenerateRequest):
    """Generate multiple videos from multiple audios."""
    if not server.is_loaded:
        raise HTTPException(status_code=503, detail="Models not loaded")
    for audio_path in request.audio_paths:
        if not os.path.exists(audio_path):
            raise HTTPException(status_code=404, detail=f"Audio not found: {audio_path}")

    try:
        timings = server.generate_batch(
            avatar_name=request.avatar_name,
            audio_paths=request.audio_paths,
            output_dir=request.output_dir,
            fps=request.fps
        )
        return {
            "status": "success",
            "output_dir": request.output_dir,
            "timings": timings
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000)
    args = parser.parse_args()

    uvicorn.run(app, host=args.host, port=args.port)
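
# ---------------------------------------------------------------------------
# Illustrative client sketch (not executed by the server). It assumes the
# server is reachable at http://localhost:8000, uses the third-party
# `requests` package, and the avatar name and paths are placeholders.
#
#   import requests
#
#   BASE = "http://localhost:8000"
#
#   # Warm the avatar cache, then render one clip per audio file.
#   requests.post(f"{BASE}/avatars/demo_avatar/load").raise_for_status()
#   resp = requests.post(f"{BASE}/generate/batch", json={
#       "avatar_name": "demo_avatar",
#       "audio_paths": ["/data/audio/a.wav", "/data/audio/b.wav"],
#       "output_dir": "/data/output",
#       "fps": 25,
#   })
#   print(resp.json()["timings"]["avg_per_video"])
# ---------------------------------------------------------------------------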