""" MuseTalk Real-Time Server Servidor FastAPI para lip-sync em tempo real """ import os import sys import io import time import json import uuid import queue import pickle import shutil import asyncio import threading from pathlib import Path from typing import Optional import tempfile import cv2 import glob import copy import torch import numpy as np from tqdm import tqdm from fastapi import FastAPI, UploadFile, File, Form, HTTPException, BackgroundTasks from fastapi.responses import FileResponse, StreamingResponse, JSONResponse from pydantic import BaseModel import uvicorn # Suppress warnings import warnings warnings.filterwarnings("ignore") # MuseTalk imports from musetalk.utils.utils import datagen, load_all_model from musetalk.utils.blending import get_image_prepare_material, get_image_blending from musetalk.utils.audio_processor import AudioProcessor from musetalk.utils.preprocessing_simple import get_landmark_and_bbox, read_imgs from transformers import WhisperModel app = FastAPI(title="MuseTalk Real-Time Server", version="1.5") # Global model instances models = {} avatars = {} class AvatarConfig(BaseModel): avatar_id: str video_path: str bbox_shift: int = 0 class InferenceRequest(BaseModel): avatar_id: str fps: int = 25 def video2imgs(vid_path, save_path): """Extract frames from video""" cap = cv2.VideoCapture(vid_path) count = 0 while True: ret, frame = cap.read() if ret: cv2.imwrite(f"{save_path}/{count:08d}.png", frame) count += 1 else: break cap.release() return count @app.on_event("startup") async def load_models(): """Load all models at startup""" global models print("Loading MuseTalk models...") # Force CPU if FORCE_CPU env var is set or if CUDA kernels are incompatible force_cpu = os.environ.get("FORCE_CPU", "0") == "1" if force_cpu or not torch.cuda.is_available(): device = torch.device("cpu") else: try: # Test if CUDA kernels work for this GPU test_tensor = torch.zeros(1).cuda() _ = test_tensor.half() device = torch.device("cuda:0") except RuntimeError as e: print(f"CUDA kernel test failed: {e}") print("Falling back to CPU...") device = torch.device("cpu") print(f"Using device: {device}") # Model paths unet_model_path = "./models/musetalkV15/unet.pth" unet_config = "./models/musetalkV15/musetalk.json" whisper_dir = "./models/whisper" vae_type = "sd-vae" # Load models vae, unet, pe = load_all_model( unet_model_path=unet_model_path, vae_type=vae_type, unet_config=unet_config, device=device ) # Move to device, use half precision only for GPU if device.type == "cuda": pe = pe.half().to(device) vae.vae = vae.vae.half().to(device) unet.model = unet.model.half().to(device) else: pe = pe.to(device) vae.vae = vae.vae.to(device) unet.model = unet.model.to(device) # Load whisper audio_processor = AudioProcessor(feature_extractor_path=whisper_dir) whisper = WhisperModel.from_pretrained(whisper_dir) weight_dtype = unet.model.dtype if device.type == "cuda" else torch.float32 whisper = whisper.to(device=device, dtype=weight_dtype).eval() whisper.requires_grad_(False) # Initialize face parser from musetalk.utils.face_parsing import FaceParsing fp = FaceParsing(left_cheek_width=90, right_cheek_width=90) timesteps = torch.tensor([0], device=device) models = { "vae": vae, "unet": unet, "pe": pe, "whisper": whisper, "audio_processor": audio_processor, "fp": fp, "device": device, "timesteps": timesteps, "weight_dtype": weight_dtype } print("Models loaded successfully!") @app.get("/") async def root(): return {"status": "ok", "message": "MuseTalk Real-Time Server"} @app.get("/health") async def health(): return { "status": "healthy", "models_loaded": len(models) > 0, "avatars_count": len(avatars), "gpu_available": torch.cuda.is_available() } @app.post("/avatar/prepare") async def prepare_avatar( avatar_id: str = Form(...), video: UploadFile = File(...), bbox_shift: int = Form(0, description="Ajusta abertura da boca: positivo=mais aberto, negativo=menos aberto (-9 a 9)"), extra_margin: int = Form(10, description="Margem extra para movimento do queixo"), parsing_mode: str = Form("jaw", description="Modo de parsing: 'jaw' (v1.5) ou 'raw' (v1.0)"), left_cheek_width: int = Form(90, description="Largura da bochecha esquerda"), right_cheek_width: int = Form(90, description="Largura da bochecha direita") ): """Prepare an avatar from video for real-time inference""" global avatars if not models: raise HTTPException(status_code=503, detail="Models not loaded") # Save uploaded video avatar_path = f"./results/v15/avatars/{avatar_id}" full_imgs_path = f"{avatar_path}/full_imgs" mask_out_path = f"{avatar_path}/mask" os.makedirs(avatar_path, exist_ok=True) os.makedirs(full_imgs_path, exist_ok=True) os.makedirs(mask_out_path, exist_ok=True) # Save video video_path = f"{avatar_path}/source_video{Path(video.filename).suffix}" with open(video_path, "wb") as f: content = await video.read() f.write(content) # Extract frames print(f"Extracting frames from video...") frame_count = video2imgs(video_path, full_imgs_path) print(f"Extracted {frame_count} frames") input_img_list = sorted(glob.glob(os.path.join(full_imgs_path, '*.[jpJP][pnPN]*[gG]'))) print("Extracting landmarks...") # bbox_shift controls mouth openness: positive=more open, negative=less open coord_list_raw, frame_list_raw = get_landmark_and_bbox(input_img_list, upperbondrange=bbox_shift) # Generate latents - filter out frames without detected faces input_latent_list = [] valid_coord_list = [] valid_frame_list = [] coord_placeholder = (0.0, 0.0, 0.0, 0.0) vae = models["vae"] # Create FaceParsing with custom cheek widths for this avatar from musetalk.utils.face_parsing import FaceParsing fp_avatar = FaceParsing(left_cheek_width=left_cheek_width, right_cheek_width=right_cheek_width) for bbox, frame in zip(coord_list_raw, frame_list_raw): if bbox == coord_placeholder: continue x1, y1, x2, y2 = bbox # Validate bbox dimensions if x2 <= x1 or y2 <= y1: continue # Add extra margin for jaw movement (v1.5 feature) y2 = min(y2 + extra_margin, frame.shape[0]) # Store valid frame and coordinates valid_coord_list.append([x1, y1, x2, y2]) valid_frame_list.append(frame) crop_frame = frame[y1:y2, x1:x2] if crop_frame.size == 0: valid_coord_list.pop() valid_frame_list.pop() continue resized_crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4) latents = vae.get_latents_for_unet(resized_crop_frame) input_latent_list.append(latents) print(f"Valid frames with detected faces: {len(valid_frame_list)}/{len(frame_list_raw)}") if len(valid_frame_list) == 0: raise HTTPException(status_code=400, detail="No faces detected in video. Please use a video with a clear frontal face.") # Create cycles from valid frames only frame_list_cycle = valid_frame_list + valid_frame_list[::-1] coord_list_cycle = valid_coord_list + valid_coord_list[::-1] input_latent_list_cycle = input_latent_list + input_latent_list[::-1] # Generate masks mask_list_cycle = [] mask_coords_list_cycle = [] print(f"Generating masks with mode={parsing_mode}...") for i, frame in enumerate(tqdm(frame_list_cycle)): x1, y1, x2, y2 = coord_list_cycle[i] mask, crop_box = get_image_prepare_material(frame, [x1, y1, x2, y2], fp=fp_avatar, mode=parsing_mode) cv2.imwrite(f"{mask_out_path}/{str(i).zfill(8)}.png", mask) mask_coords_list_cycle.append(crop_box) mask_list_cycle.append(mask) # Save preprocessed data with open(f"{avatar_path}/coords.pkl", 'wb') as f: pickle.dump(coord_list_cycle, f) with open(f"{avatar_path}/mask_coords.pkl", 'wb') as f: pickle.dump(mask_coords_list_cycle, f) # Save quality settings quality_settings = { "bbox_shift": bbox_shift, "extra_margin": extra_margin, "parsing_mode": parsing_mode, "left_cheek_width": left_cheek_width, "right_cheek_width": right_cheek_width } with open(f"{avatar_path}/quality_settings.json", 'w') as f: json.dump(quality_settings, f) torch.save(input_latent_list_cycle, f"{avatar_path}/latents.pt") # Store in memory - keep latents on CPU to save GPU memory input_latent_list_cpu = [lat.cpu() for lat in input_latent_list_cycle] avatars[avatar_id] = { "path": avatar_path, "frame_list_cycle": frame_list_cycle, "coord_list_cycle": coord_list_cycle, "input_latent_list_cycle": input_latent_list_cpu, "mask_list_cycle": mask_list_cycle, "mask_coords_list_cycle": mask_coords_list_cycle, "quality_settings": quality_settings } # Clear GPU cache after preparation import gc gc.collect() torch.cuda.empty_cache() return { "status": "success", "avatar_id": avatar_id, "frame_count": len(frame_list_cycle), "quality_settings": quality_settings } @app.post("/avatar/load/{avatar_id}") async def load_avatar(avatar_id: str): """Load a previously prepared avatar""" global avatars avatar_path = f"./results/v15/avatars/{avatar_id}" if not os.path.exists(avatar_path): raise HTTPException(status_code=404, detail=f"Avatar {avatar_id} not found") full_imgs_path = f"{avatar_path}/full_imgs" mask_out_path = f"{avatar_path}/mask" # Load preprocessed data input_latent_list_cycle = torch.load(f"{avatar_path}/latents.pt") with open(f"{avatar_path}/coords.pkl", 'rb') as f: coord_list_cycle = pickle.load(f) with open(f"{avatar_path}/mask_coords.pkl", 'rb') as f: mask_coords_list_cycle = pickle.load(f) # Load quality settings (with defaults for backwards compatibility) quality_settings_path = f"{avatar_path}/quality_settings.json" if os.path.exists(quality_settings_path): with open(quality_settings_path, 'r') as f: quality_settings = json.load(f) else: quality_settings = { "bbox_shift": 0, "extra_margin": 10, "parsing_mode": "jaw", "left_cheek_width": 90, "right_cheek_width": 90 } # Load frames input_img_list = sorted(glob.glob(os.path.join(full_imgs_path, '*.[jpJP][pnPN]*[gG]'))) frame_list_cycle = read_imgs(input_img_list) # Load masks input_mask_list = sorted(glob.glob(os.path.join(mask_out_path, '*.[jpJP][pnPN]*[gG]'))) mask_list_cycle = read_imgs(input_mask_list) # Keep latents on CPU to save GPU memory input_latent_list_cpu = [lat.cpu() if hasattr(lat, 'cpu') else lat for lat in input_latent_list_cycle] avatars[avatar_id] = { "path": avatar_path, "frame_list_cycle": frame_list_cycle, "coord_list_cycle": coord_list_cycle, "input_latent_list_cycle": input_latent_list_cpu, "mask_list_cycle": mask_list_cycle, "mask_coords_list_cycle": mask_coords_list_cycle, "quality_settings": quality_settings } # Clear GPU cache import gc gc.collect() torch.cuda.empty_cache() return { "status": "success", "avatar_id": avatar_id, "frame_count": len(frame_list_cycle), "quality_settings": quality_settings } @app.get("/avatars") async def list_avatars(): """List all available avatars""" avatar_dir = "./results/v15/avatars" if not os.path.exists(avatar_dir): return {"avatars": [], "loaded": list(avatars.keys())} available = [d for d in os.listdir(avatar_dir) if os.path.isdir(os.path.join(avatar_dir, d))] return {"avatars": available, "loaded": list(avatars.keys())} @app.post("/inference") async def inference( avatar_id: str = Form(...), audio: UploadFile = File(...), fps: int = Form(25) ): """Run inference with uploaded audio and return video""" if avatar_id not in avatars: raise HTTPException(status_code=404, detail=f"Avatar {avatar_id} not loaded. Use /avatar/load first") if not models: raise HTTPException(status_code=503, detail="Models not loaded") avatar = avatars[avatar_id] device = models["device"] # Save audio temporarily with tempfile.NamedTemporaryFile(suffix=Path(audio.filename).suffix, delete=False) as tmp: content = await audio.read() tmp.write(content) audio_path = tmp.name try: start_time = time.time() # Extract audio features audio_processor = models["audio_processor"] whisper = models["whisper"] weight_dtype = models["weight_dtype"] whisper_input_features, librosa_length = audio_processor.get_audio_feature( audio_path, weight_dtype=weight_dtype ) whisper_chunks = audio_processor.get_whisper_chunk( whisper_input_features, device, weight_dtype, whisper, librosa_length, fps=fps, audio_padding_length_left=2, audio_padding_length_right=2, ) print(f"Audio processing: {(time.time() - start_time)*1000:.0f}ms") # Inference vae = models["vae"] unet = models["unet"] pe = models["pe"] timesteps = models["timesteps"] video_num = len(whisper_chunks) batch_size = 4 # Reduced batch size to save GPU memory gen = datagen(whisper_chunks, avatar["input_latent_list_cycle"], batch_size) result_frames = [] inference_start = time.time() for i, (whisper_batch, latent_batch) in enumerate(gen): audio_feature_batch = pe(whisper_batch.to(device)) latent_batch = latent_batch.to(device=device, dtype=unet.model.dtype) pred_latents = unet.model( latent_batch, timesteps, encoder_hidden_states=audio_feature_batch ).sample pred_latents = pred_latents.to(device=device, dtype=vae.vae.dtype) recon = vae.decode_latents(pred_latents) for idx_in_batch, res_frame in enumerate(recon): frame_idx = i * batch_size + idx_in_batch if frame_idx >= video_num: break bbox = avatar["coord_list_cycle"][frame_idx % len(avatar["coord_list_cycle"])] ori_frame = copy.deepcopy(avatar["frame_list_cycle"][frame_idx % len(avatar["frame_list_cycle"])]) x1, y1, x2, y2 = bbox res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1)) mask = avatar["mask_list_cycle"][frame_idx % len(avatar["mask_list_cycle"])] mask_crop_box = avatar["mask_coords_list_cycle"][frame_idx % len(avatar["mask_coords_list_cycle"])] combine_frame = get_image_blending(ori_frame, res_frame, bbox, mask, mask_crop_box) result_frames.append(combine_frame) print(f"Inference: {(time.time() - inference_start)*1000:.0f}ms for {video_num} frames") print(f"FPS: {video_num / (time.time() - inference_start):.1f}") # Create video output_path = tempfile.mktemp(suffix=".mp4") h, w = result_frames[0].shape[:2] fourcc = cv2.VideoWriter_fourcc(*'mp4v') out = cv2.VideoWriter(output_path, fourcc, fps, (w, h)) for frame in result_frames: out.write(frame) out.release() # Combine with audio using ffmpeg final_output = tempfile.mktemp(suffix=".mp4") os.system(f"ffmpeg -y -v warning -i {audio_path} -i {output_path} -c:v libx264 -c:a aac {final_output}") os.unlink(output_path) os.unlink(audio_path) total_time = time.time() - start_time print(f"Total time: {total_time*1000:.0f}ms") return FileResponse( final_output, media_type="video/mp4", filename=f"output_{avatar_id}.mp4", headers={"X-Processing-Time": f"{total_time:.2f}s"} ) except Exception as e: if os.path.exists(audio_path): os.unlink(audio_path) raise HTTPException(status_code=500, detail=str(e)) @app.post("/inference/frames") async def inference_frames( avatar_id: str = Form(...), audio: UploadFile = File(...), fps: int = Form(25) ): """Run inference and return frames as JSON (for streaming)""" if avatar_id not in avatars: raise HTTPException(status_code=404, detail=f"Avatar {avatar_id} not loaded") avatar = avatars[avatar_id] device = models["device"] # Save audio temporarily with tempfile.NamedTemporaryFile(suffix=Path(audio.filename).suffix, delete=False) as tmp: content = await audio.read() tmp.write(content) audio_path = tmp.name try: # Extract audio features audio_processor = models["audio_processor"] whisper = models["whisper"] weight_dtype = models["weight_dtype"] whisper_input_features, librosa_length = audio_processor.get_audio_feature( audio_path, weight_dtype=weight_dtype ) whisper_chunks = audio_processor.get_whisper_chunk( whisper_input_features, device, weight_dtype, whisper, librosa_length, fps=fps, ) # Inference vae = models["vae"] unet = models["unet"] pe = models["pe"] timesteps = models["timesteps"] video_num = len(whisper_chunks) batch_size = 4 # Reduced batch size to save GPU memory gen = datagen(whisper_chunks, avatar["input_latent_list_cycle"], batch_size) frames_data = [] for i, (whisper_batch, latent_batch) in enumerate(gen): audio_feature_batch = pe(whisper_batch.to(device)) latent_batch = latent_batch.to(device=device, dtype=unet.model.dtype) pred_latents = unet.model( latent_batch, timesteps, encoder_hidden_states=audio_feature_batch ).sample pred_latents = pred_latents.to(device=device, dtype=vae.vae.dtype) recon = vae.decode_latents(pred_latents) for idx_in_batch, res_frame in enumerate(recon): frame_idx = i * batch_size + idx_in_batch if frame_idx >= video_num: break bbox = avatar["coord_list_cycle"][frame_idx % len(avatar["coord_list_cycle"])] ori_frame = copy.deepcopy(avatar["frame_list_cycle"][frame_idx % len(avatar["frame_list_cycle"])]) x1, y1, x2, y2 = bbox res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1)) mask = avatar["mask_list_cycle"][frame_idx % len(avatar["mask_list_cycle"])] mask_crop_box = avatar["mask_coords_list_cycle"][frame_idx % len(avatar["mask_coords_list_cycle"])] combine_frame = get_image_blending(ori_frame, res_frame, bbox, mask, mask_crop_box) # Encode frame as JPEG _, buffer = cv2.imencode('.jpg', combine_frame, [cv2.IMWRITE_JPEG_QUALITY, 85]) import base64 frame_b64 = base64.b64encode(buffer).decode('utf-8') frames_data.append(frame_b64) os.unlink(audio_path) return { "frames": frames_data, "fps": fps, "total_frames": len(frames_data) } except Exception as e: if os.path.exists(audio_path): os.unlink(audio_path) raise HTTPException(status_code=500, detail=str(e)) if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=8000)