MuseTalk / musetalk_api_server.py
marcosremar2's picture
fix: correct facial alignment issues and add API server
af11910
"""
MuseTalk HTTP API Server
Keeps models loaded in GPU memory for fast inference.
"""
import os
import cv2
import copy
import torch
import glob
import shutil
import pickle
import numpy as np
import subprocess
import tempfile
import hashlib
import time
from pathlib import Path
from typing import Optional
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks
from fastapi.responses import FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from tqdm import tqdm
from omegaconf import OmegaConf
from transformers import WhisperModel
import uvicorn
# MuseTalk imports
from musetalk.utils.blending import get_image
from musetalk.utils.face_parsing import FaceParsing
from musetalk.utils.audio_processor import AudioProcessor
from musetalk.utils.utils import get_file_type, datagen, load_all_model
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
class MuseTalkServer:
    """Hold the MuseTalk models and run lip-sync generation with them.

    One instance is meant to be shared for the lifetime of the process so
    that the models are loaded once (``load_models``) and reused by every
    ``generate`` call.
    """

    def __init__(self):
        # Model handles; all of them are populated by load_models().
        self.device = None
        self.vae = None
        self.unet = None
        self.pe = None
        self.whisper = None
        self.audio_processor = None
        self.fp = None
        self.timesteps = None
        self.weight_dtype = None
        self.is_loaded = False

        # Cache directories (created relative to the working directory).
        self.cache_dir = Path("./cache")
        self.cache_dir.mkdir(exist_ok=True)
        self.landmarks_cache = self.cache_dir / "landmarks"
        self.latents_cache = self.cache_dir / "latents"
        self.whisper_cache = self.cache_dir / "whisper_features"
        self.landmarks_cache.mkdir(exist_ok=True)
        self.latents_cache.mkdir(exist_ok=True)
        self.whisper_cache.mkdir(exist_ok=True)

        # Config
        self.fps = 25
        self.batch_size = 8
        self.use_float16 = True
        self.version = "v15"
        self.extra_margin = 10
        self.parsing_mode = "jaw"
        self.left_cheek_width = 90
        self.right_cheek_width = 90
        self.audio_padding_left = 2
        self.audio_padding_right = 2

    def load_models(
        self,
        gpu_id: int = 0,
        unet_model_path: str = "./models/musetalkV15/unet.pth",
        unet_config: str = "./models/musetalk/config.json",
        vae_type: str = "sd-vae",
        whisper_dir: str = "./models/whisper",
        use_float16: bool = True,
        version: str = "v15"
    ):
        """Load the VAE, UNet, PE, Whisper and face parser onto the device.

        The call is a no-op when the models are already loaded, so it can
        safely be invoked more than once.

        Args:
            gpu_id: CUDA device index; ignored when CUDA is unavailable
                (the CPU is used instead).
            unet_model_path: Path to the UNet weights.
            unet_config: Path to the UNet config file.
            vae_type: VAE identifier forwarded to ``load_all_model``.
            whisper_dir: Directory holding the Whisper model and its
                feature extractor.
            use_float16: Convert PE, VAE and UNet to half precision.
            version: Model version; ``"v15"`` builds the face parser with
                the configured cheek widths.
        """
        if self.is_loaded:
            print("Models already loaded!")
            return

        print("=" * 50)
        print("Loading MuseTalk models into GPU memory...")
        print("=" * 50)
        start_time = time.time()

        # Set device
        self.device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Load model weights
        print("Loading VAE, UNet, PE...")
        self.vae, self.unet, self.pe = load_all_model(
            unet_model_path=unet_model_path,
            vae_type=vae_type,
            unet_config=unet_config,
            device=self.device
        )
        self.timesteps = torch.tensor([0], device=self.device)

        # Convert to float16 if enabled
        self.use_float16 = use_float16
        if use_float16:
            print("Converting to float16...")
            self.pe = self.pe.half()
            self.vae.vae = self.vae.vae.half()
            self.unet.model = self.unet.model.half()

        # Move to device
        self.pe = self.pe.to(self.device)
        self.vae.vae = self.vae.vae.to(self.device)
        self.unet.model = self.unet.model.to(self.device)

        # Initialize audio processor and Whisper. Whisper is cast to the
        # UNet dtype so both run at the same precision.
        print("Loading Whisper model...")
        self.audio_processor = AudioProcessor(feature_extractor_path=whisper_dir)
        self.weight_dtype = self.unet.model.dtype
        self.whisper = WhisperModel.from_pretrained(whisper_dir)
        self.whisper = self.whisper.to(device=self.device, dtype=self.weight_dtype).eval()
        self.whisper.requires_grad_(False)

        # Initialize face parser
        self.version = version
        if version == "v15":
            self.fp = FaceParsing(
                left_cheek_width=self.left_cheek_width,
                right_cheek_width=self.right_cheek_width
            )
        else:
            self.fp = FaceParsing()

        self.is_loaded = True
        load_time = time.time() - start_time
        print(f"Models loaded in {load_time:.2f}s")
        print("=" * 50)
        print("Server ready for inference!")
        print("=" * 50)

    def _get_file_hash(self, file_path: str) -> str:
        """Return the first 16 hex characters of the file's MD5 digest.

        The digest is only used as a cache key, not for security.
        """
        hash_md5 = hashlib.md5()
        with open(file_path, "rb") as f:
            for chunk in iter(lambda: f.read(4096), b""):
                hash_md5.update(chunk)
        return hash_md5.hexdigest()[:16]

    def _get_cached_landmarks(self, video_hash: str, bbox_shift: int):
        """Get cached landmarks if available (currently always ``None``)."""
        # Disabled due to tensor comparison issues
        return None

    def _save_landmarks_cache(self, video_hash: str, bbox_shift: int, coord_list, frame_list):
        """Pickle ``(coord_list, frame_list)`` into the landmarks cache."""
        cache_file = self.landmarks_cache / f"{video_hash}_shift{bbox_shift}.pkl"
        with open(cache_file, 'wb') as f:
            pickle.dump((coord_list, frame_list), f)

    def _get_cached_latents(self, video_hash: str):
        """Get cached VAE latents if available (currently always ``None``)."""
        # Disabled due to tensor comparison issues
        return None

    def _save_latents_cache(self, video_hash: str, latent_list):
        """Pickle the VAE latent list into the latents cache."""
        cache_file = self.latents_cache / f"{video_hash}.pkl"
        with open(cache_file, 'wb') as f:
            pickle.dump(latent_list, f)

    def _get_cached_whisper(self, audio_hash: str):
        """Get cached Whisper features if available (currently always ``None``)."""
        # Disabled due to tensor comparison issues
        return None

    def _save_whisper_cache(self, audio_hash: str, whisper_data):
        """Pickle the Whisper feature chunks into the whisper cache."""
        cache_file = self.whisper_cache / f"{audio_hash}.pkl"
        with open(cache_file, 'wb') as f:
            pickle.dump(whisper_data, f)

    @staticmethod
    def _run_ffmpeg(args):
        """Run ``ffmpeg`` with *args* as an argument list, without a shell.

        Passing a list keeps paths containing spaces or shell
        metacharacters from being interpreted by a shell.

        Raises:
            subprocess.CalledProcessError: If ffmpeg exits with a non-zero
                status.
        """
        subprocess.run(["ffmpeg", *[str(arg) for arg in args]], check=True)

    def _apply_extra_margin(self, bbox, frame_height):
        """Return *bbox* with the v15 bottom margin applied.

        For version ``"v15"`` the lower edge is extended by
        ``self.extra_margin`` and clamped to *frame_height*; other versions
        get the box back unchanged.
        """
        x1, y1, x2, y2 = bbox
        if self.version == "v15":
            y2 = min(y2 + self.extra_margin, frame_height)
        return x1, y1, x2, y2

    @torch.no_grad()
    def generate(
        self,
        video_path: str,
        audio_path: str,
        output_path: str,
        fps: Optional[int] = None,
        use_cache: bool = True
    ) -> dict:
        """Generate a lip-synced video and write it to *output_path*.

        Args:
            video_path: Source video or single image.
            audio_path: Driving audio file.
            output_path: Destination of the final muxed video.
            fps: Frame rate for extraction and encoding; falls back to
                ``self.fps`` when ``None`` or ``0``.
            use_cache: Consult and populate the on-disk caches.

        Returns:
            Dict of per-stage timings in seconds, the ``*_source`` markers
            (``"cache"`` or ``"computed"``) and ``frames_generated``.

        Raises:
            RuntimeError: If ``load_models`` has not been called.
            ValueError: If the input type is unsupported, or no frames or
                face crops could be obtained from it.
            subprocess.CalledProcessError: If an ffmpeg step fails.
        """
        if not self.is_loaded:
            raise RuntimeError("Models not loaded! Call load_models() first.")

        fps = fps or self.fps
        timings = {"total": 0}
        total_start = time.time()

        # Get file hashes for caching
        video_hash = self._get_file_hash(video_path)
        audio_hash = self._get_file_hash(audio_path)

        # Create temp directory
        temp_dir = tempfile.mkdtemp()
        try:
            # 1. Extract frames
            t0 = time.time()
            save_dir_full = os.path.join(temp_dir, "frames")
            os.makedirs(save_dir_full, exist_ok=True)
            file_type = get_file_type(video_path)
            if file_type == "video":
                self._run_ffmpeg([
                    "-v", "fatal",
                    "-i", video_path,
                    "-vf", f"fps={fps}",
                    "-start_number", "0",
                    f"{save_dir_full}/%08d.png",
                ])
                input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
            elif file_type == "image":
                input_img_list = [video_path]
            else:
                raise ValueError(f"Unsupported video type: {video_path}")
            if not input_img_list:
                raise ValueError(f"No frames could be extracted from: {video_path}")
            timings["frame_extraction"] = time.time() - t0

            # 2. Extract audio features (with caching)
            t0 = time.time()
            cached_whisper = self._get_cached_whisper(audio_hash) if use_cache else None
            # Compare against None: the truth value of a tensor payload is
            # ambiguous and would raise.
            if cached_whisper is not None:
                whisper_chunks = cached_whisper
                timings["whisper_source"] = "cache"
            else:
                whisper_input_features, librosa_length = self.audio_processor.get_audio_feature(audio_path)
                whisper_chunks = self.audio_processor.get_whisper_chunk(
                    whisper_input_features,
                    self.device,
                    self.weight_dtype,
                    self.whisper,
                    librosa_length,
                    fps=fps,
                    audio_padding_length_left=self.audio_padding_left,
                    audio_padding_length_right=self.audio_padding_right,
                )
                if use_cache:
                    self._save_whisper_cache(audio_hash, whisper_chunks)
                timings["whisper_source"] = "computed"
            timings["whisper_features"] = time.time() - t0

            # 3. Get landmarks (with caching)
            t0 = time.time()
            bbox_shift = 0
            cache_key = f"{video_hash}_{fps}"
            cached_landmarks = self._get_cached_landmarks(cache_key, bbox_shift) if use_cache else None
            if cached_landmarks is not None:
                coord_list, frame_list = cached_landmarks
                timings["landmarks_source"] = "cache"
            else:
                coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
                if use_cache:
                    self._save_landmarks_cache(cache_key, bbox_shift, coord_list, frame_list)
                timings["landmarks_source"] = "computed"
            timings["landmarks"] = time.time() - t0

            # 4. Compute VAE latents (with caching)
            t0 = time.time()
            latent_cache_key = f"{video_hash}_{fps}_{self.version}"
            cached_latents = self._get_cached_latents(latent_cache_key) if use_cache else None
            if cached_latents is not None:
                input_latent_list = cached_latents
                timings["latents_source"] = "cache"
            else:
                input_latent_list = []
                for bbox, frame in zip(coord_list, frame_list):
                    # Frames whose bbox equals the placeholder are skipped.
                    if isinstance(bbox, (list, tuple)) and list(bbox) == list(coord_placeholder):
                        continue
                    x1, y1, x2, y2 = self._apply_extra_margin(bbox, frame.shape[0])
                    crop_frame = frame[y1:y2, x1:x2]
                    crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
                    latents = self.vae.get_latents_for_unet(crop_frame)
                    input_latent_list.append(latents)
                if use_cache:
                    self._save_latents_cache(latent_cache_key, input_latent_list)
                timings["latents_source"] = "computed"
            if not input_latent_list:
                # Fail with a clear message instead of an obscure error
                # further down the pipeline.
                raise ValueError(f"No usable face crops found in: {video_path}")
            timings["vae_encoding"] = time.time() - t0

            # 5. Prepare cycled lists (forward then reversed)
            frame_list_cycle = frame_list + frame_list[::-1]
            coord_list_cycle = coord_list + coord_list[::-1]
            input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

            # 6. UNet inference
            t0 = time.time()
            gen = datagen(
                whisper_chunks=whisper_chunks,
                vae_encode_latents=input_latent_list_cycle,
                batch_size=self.batch_size,
                delay_frame=0,
                device=self.device,
            )
            res_frame_list = []
            for whisper_batch, latent_batch in gen:
                audio_feature_batch = self.pe(whisper_batch)
                latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
                pred_latents = self.unet.model(
                    latent_batch, self.timesteps,
                    encoder_hidden_states=audio_feature_batch
                ).sample
                recon = self.vae.decode_latents(pred_latents)
                res_frame_list.extend(recon)
            timings["unet_inference"] = time.time() - t0

            # 7. Face blending
            t0 = time.time()
            result_img_path = os.path.join(temp_dir, "results")
            os.makedirs(result_img_path, exist_ok=True)
            for i, res_frame in enumerate(res_frame_list):
                bbox = coord_list_cycle[i % len(coord_list_cycle)]
                ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
                x1, y1, x2, y2 = self._apply_extra_margin(bbox, ori_frame.shape[0])
                try:
                    res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
                except Exception:
                    # Best effort: a frame whose box cannot be resized to
                    # is skipped.
                    # NOTE(review): a skipped frame leaves a gap in the
                    # numbered PNG sequence, which may truncate the encoded
                    # video at that point - confirm.
                    continue
                if self.version == "v15":
                    combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2],
                                              mode=self.parsing_mode, fp=self.fp)
                else:
                    combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=self.fp)
                cv2.imwrite(f"{result_img_path}/{str(i).zfill(8)}.png", combine_frame)
            timings["face_blending"] = time.time() - t0

            # 8. Encode video, then mux it with the audio track
            t0 = time.time()
            temp_vid = os.path.join(temp_dir, "temp.mp4")
            self._run_ffmpeg([
                "-y", "-v", "warning",
                "-r", fps,
                "-f", "image2",
                "-i", f"{result_img_path}/%08d.png",
                "-vcodec", "libx264",
                "-vf", "format=yuv420p",
                "-crf", "18",
                temp_vid,
            ])
            self._run_ffmpeg([
                "-y", "-v", "warning",
                "-i", audio_path,
                "-i", temp_vid,
                output_path,
            ])
            timings["video_encoding"] = time.time() - t0
        finally:
            # Cleanup
            shutil.rmtree(temp_dir, ignore_errors=True)

        timings["total"] = time.time() - total_start
        timings["frames_generated"] = len(res_frame_list)
        return timings
# Global server instance shared by every endpoint in this module.
server = MuseTalkServer()

# FastAPI app
app = FastAPI(
    title="MuseTalk API",
    description="HTTP API for MuseTalk lip-sync generation",
    version="1.0.0"
)

# CORS middleware: every origin, method and header is allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True is
# very permissive - confirm this is intended outside local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.on_event("startup")
async def startup_event():
    """Make sure the shared server has its models loaded at app startup.

    ``load_models`` returns immediately when the models are already
    loaded, so this does nothing extra if they were pre-loaded.
    """
    server.load_models()
@app.get("/health")
async def health_check():
    """Report whether the models are loaded and which device is in use."""
    ready = server.is_loaded
    # The device is None until load_models() has selected one.
    device_name = str(server.device) if server.device else None
    return {
        "status": "ok" if ready else "loading",
        "models_loaded": ready,
        "device": device_name,
    }
@app.get("/cache/stats")
async def cache_stats():
    """Return the number of ``.pkl`` files in each cache directory."""

    def count_pickles(directory):
        # Count lazily instead of materialising the file list.
        return sum(1 for _ in directory.glob("*.pkl"))

    return {
        "landmarks_cached": count_pickles(server.landmarks_cache),
        "latents_cached": count_pickles(server.latents_cache),
        "whisper_features_cached": count_pickles(server.whisper_cache),
    }
@app.post("/cache/clear")
async def clear_cache():
    """Delete every ``.pkl`` file from the three cache directories."""
    directories = (server.landmarks_cache, server.latents_cache, server.whisper_cache)
    stale_files = (
        entry
        for directory in directories
        for entry in directory.glob("*.pkl")
    )
    for entry in stale_files:
        entry.unlink()
    return {"status": "cleared"}
class GenerateRequest(BaseModel):
    """JSON body accepted by ``POST /generate``; all paths are server-side."""
    # Source video or image; the endpoint rejects the request with 404 if
    # this path does not exist.
    video_path: str
    # Driving audio file; also checked for existence by the endpoint.
    audio_path: str
    # Where the generated video is written.
    output_path: str
    # Frame rate passed to generation (defaults to 25).
    fps: Optional[int] = 25
    # Forwarded to generation as its use_cache flag.
    use_cache: bool = True
@app.post("/generate")
async def generate_from_paths(request: GenerateRequest):
    """
    Run lip-sync generation on files that already live on the server.

    Responds with 503 while the models are not loaded, 404 when the video
    or audio path does not exist, and 500 when generation raises.
    """
    if not server.is_loaded:
        raise HTTPException(status_code=503, detail="Models not loaded yet")

    # Check the inputs in order: video first, then audio.
    required_inputs = (
        (request.video_path, f"Video not found: {request.video_path}"),
        (request.audio_path, f"Audio not found: {request.audio_path}"),
    )
    for path, message in required_inputs:
        if not os.path.exists(path):
            raise HTTPException(status_code=404, detail=message)

    try:
        timings = server.generate(
            video_path=request.video_path,
            audio_path=request.audio_path,
            output_path=request.output_path,
            fps=request.fps,
            use_cache=request.use_cache
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    return {
        "status": "success",
        "output_path": request.output_path,
        "timings": timings
    }
@app.post("/generate/upload")
async def generate_from_upload(
    video: UploadFile = File(...),
    audio: UploadFile = File(...),
    fps: int = Form(25),
    use_cache: bool = Form(True)
):
    """
    Generate a lip-synced video from uploaded files and return it.

    The uploads are written to a private temporary directory, generation
    runs on them, and the resulting MP4 is streamed back with the timing
    dict in the ``X-Timings`` header. The temporary directory is removed
    after the response has been sent, or immediately if generation fails.

    Responds with 503 while the models are not loaded and 500 when saving
    the uploads or generation raises.
    """
    if not server.is_loaded:
        raise HTTPException(status_code=503, detail="Models not loaded yet")

    # Save uploaded files
    temp_dir = tempfile.mkdtemp()
    try:
        # Never join the client-supplied filename into a path: it may hold
        # directory components ("../..") or be identical for both uploads.
        # Only the extension is kept (presumably what file-type detection
        # relies on - verify against get_file_type).
        video_suffix = Path(video.filename or "").suffix
        audio_suffix = Path(audio.filename or "").suffix
        video_path = os.path.join(temp_dir, f"input_video{video_suffix}")
        audio_path = os.path.join(temp_dir, f"input_audio{audio_suffix}")
        output_path = os.path.join(temp_dir, "output.mp4")

        with open(video_path, "wb") as f:
            f.write(await video.read())
        with open(audio_path, "wb") as f:
            f.write(await audio.read())

        timings = server.generate(
            video_path=video_path,
            audio_path=audio_path,
            output_path=output_path,
            fps=fps,
            use_cache=use_cache
        )

        # The file must outlive this function because the response streams
        # it, so the directory is removed by a background task that runs
        # once the response has been sent.
        cleanup = BackgroundTasks()
        cleanup.add_task(shutil.rmtree, temp_dir, ignore_errors=True)

        # Return the video file
        return FileResponse(
            output_path,
            media_type="video/mp4",
            filename="result.mp4",
            headers={"X-Timings": str(timings)},
            background=cleanup,
        )
    except Exception as e:
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    import argparse

    # Command-line options for running the server as a script.
    cli = argparse.ArgumentParser(description="MuseTalk API Server")
    cli.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind")
    cli.add_argument("--port", type=int, default=8000, help="Port to bind")
    cli.add_argument("--gpu_id", type=int, default=0, help="GPU ID")
    cli.add_argument("--unet_model_path", type=str, default="./models/musetalkV15/unet.pth")
    cli.add_argument("--unet_config", type=str, default="./models/musetalk/config.json")
    cli.add_argument("--whisper_dir", type=str, default="./models/whisper")
    cli.add_argument("--no_float16", action="store_true", help="Disable float16")
    options = cli.parse_args()

    # Load the models with the CLI settings before the app starts, so the
    # startup hook's own load_models() call finds them already loaded.
    half_precision = not options.no_float16
    server.load_models(
        gpu_id=options.gpu_id,
        unet_model_path=options.unet_model_path,
        unet_config=options.unet_config,
        whisper_dir=options.whisper_dir,
        use_float16=half_precision
    )

    # Start server
    uvicorn.run(app, host=options.host, port=options.port)