# MuseTalk / server.py
# Author: marcosremar2
# Commit af11910: "fix: correct facial alignment issues and add API server"
"""
MuseTalk Real-Time Server
Servidor FastAPI para lip-sync em tempo real
"""
import os
import sys
import io
import time
import json
import uuid
import queue
import pickle
import shutil
import asyncio
import threading
from pathlib import Path
from typing import Optional
import tempfile
import cv2
import glob
import copy
import torch
import numpy as np
from tqdm import tqdm
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, BackgroundTasks
from fastapi.responses import FileResponse, StreamingResponse, JSONResponse
from pydantic import BaseModel
import uvicorn
# Suppress warnings
import warnings
warnings.filterwarnings("ignore")
# MuseTalk imports
from musetalk.utils.utils import datagen, load_all_model
from musetalk.utils.blending import get_image_prepare_material, get_image_blending
from musetalk.utils.audio_processor import AudioProcessor
from musetalk.utils.preprocessing_simple import get_landmark_and_bbox, read_imgs
from transformers import WhisperModel
app = FastAPI(title="MuseTalk Real-Time Server", version="1.5")

# Process-wide registries shared by every endpoint.
# `models` is filled in by the startup hook (networks, device, dtype, ...).
models: dict = {}
# `avatars` maps avatar_id -> that avatar's preprocessed, in-memory data.
avatars: dict = {}
class AvatarConfig(BaseModel):
    """Schema describing an avatar to build from a source video.

    NOTE(review): no endpoint in this file references this model (they take
    Form fields instead); presumably kept for API clients — confirm.
    """

    avatar_id: str  # avatar identifier
    video_path: str  # path of the source video
    bbox_shift: int = 0  # same name/default as the /avatar/prepare form field
class InferenceRequest(BaseModel):
    """Schema for an inference request (avatar + output frame rate).

    NOTE(review): the inference endpoints in this file take Form fields
    instead of this model; presumably kept for API clients — confirm.
    """

    avatar_id: str  # id of an avatar previously prepared/loaded
    fps: int = 25  # output frame rate; same default as the endpoints' form field
def video2imgs(vid_path, save_path):
    """Extract every frame of a video into zero-padded PNG files.

    Frames are written as ``00000000.png``, ``00000001.png``, ... inside
    ``save_path`` so that a lexicographic sort of the directory restores the
    original frame order.

    Args:
        vid_path: path of the video file to decode with OpenCV.
        save_path: existing directory that receives the PNG frames.

    Returns:
        The number of frames written (0 when OpenCV cannot read any frame).

    Raises:
        OSError: if a frame could not be written to ``save_path``.
    """
    cap = cv2.VideoCapture(vid_path)
    count = 0
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_file = os.path.join(save_path, f"{count:08d}.png")
            # cv2.imwrite reports most failures by returning False instead of raising;
            # silently continuing would leave the caller with missing frames.
            if not cv2.imwrite(frame_file, frame):
                raise OSError(f"Failed to write frame to {frame_file}")
            count += 1
    finally:
        # Always release the decoder, even if writing a frame fails.
        cap.release()
    return count
@app.on_event("startup")
async def load_models():
"""Load all models at startup"""
global models
print("Loading MuseTalk models...")
# Force CPU if FORCE_CPU env var is set or if CUDA kernels are incompatible
force_cpu = os.environ.get("FORCE_CPU", "0") == "1"
if force_cpu or not torch.cuda.is_available():
device = torch.device("cpu")
else:
try:
# Test if CUDA kernels work for this GPU
test_tensor = torch.zeros(1).cuda()
_ = test_tensor.half()
device = torch.device("cuda:0")
except RuntimeError as e:
print(f"CUDA kernel test failed: {e}")
print("Falling back to CPU...")
device = torch.device("cpu")
print(f"Using device: {device}")
# Model paths
unet_model_path = "./models/musetalkV15/unet.pth"
unet_config = "./models/musetalkV15/musetalk.json"
whisper_dir = "./models/whisper"
vae_type = "sd-vae"
# Load models
vae, unet, pe = load_all_model(
unet_model_path=unet_model_path,
vae_type=vae_type,
unet_config=unet_config,
device=device
)
# Move to device, use half precision only for GPU
if device.type == "cuda":
pe = pe.half().to(device)
vae.vae = vae.vae.half().to(device)
unet.model = unet.model.half().to(device)
else:
pe = pe.to(device)
vae.vae = vae.vae.to(device)
unet.model = unet.model.to(device)
# Load whisper
audio_processor = AudioProcessor(feature_extractor_path=whisper_dir)
whisper = WhisperModel.from_pretrained(whisper_dir)
weight_dtype = unet.model.dtype if device.type == "cuda" else torch.float32
whisper = whisper.to(device=device, dtype=weight_dtype).eval()
whisper.requires_grad_(False)
# Initialize face parser
from musetalk.utils.face_parsing import FaceParsing
fp = FaceParsing(left_cheek_width=90, right_cheek_width=90)
timesteps = torch.tensor([0], device=device)
models = {
"vae": vae,
"unet": unet,
"pe": pe,
"whisper": whisper,
"audio_processor": audio_processor,
"fp": fp,
"device": device,
"timesteps": timesteps,
"weight_dtype": weight_dtype
}
print("Models loaded successfully!")
@app.get("/")
async def root():
return {"status": "ok", "message": "MuseTalk Real-Time Server"}
@app.get("/health")
async def health():
return {
"status": "healthy",
"models_loaded": len(models) > 0,
"avatars_count": len(avatars),
"gpu_available": torch.cuda.is_available()
}
@app.post("/avatar/prepare")
async def prepare_avatar(
avatar_id: str = Form(...),
video: UploadFile = File(...),
bbox_shift: int = Form(0, description="Ajusta abertura da boca: positivo=mais aberto, negativo=menos aberto (-9 a 9)"),
extra_margin: int = Form(10, description="Margem extra para movimento do queixo"),
parsing_mode: str = Form("jaw", description="Modo de parsing: 'jaw' (v1.5) ou 'raw' (v1.0)"),
left_cheek_width: int = Form(90, description="Largura da bochecha esquerda"),
right_cheek_width: int = Form(90, description="Largura da bochecha direita")
):
"""Prepare an avatar from video for real-time inference"""
global avatars
if not models:
raise HTTPException(status_code=503, detail="Models not loaded")
# Save uploaded video
avatar_path = f"./results/v15/avatars/{avatar_id}"
full_imgs_path = f"{avatar_path}/full_imgs"
mask_out_path = f"{avatar_path}/mask"
os.makedirs(avatar_path, exist_ok=True)
os.makedirs(full_imgs_path, exist_ok=True)
os.makedirs(mask_out_path, exist_ok=True)
# Save video
video_path = f"{avatar_path}/source_video{Path(video.filename).suffix}"
with open(video_path, "wb") as f:
content = await video.read()
f.write(content)
# Extract frames
print(f"Extracting frames from video...")
frame_count = video2imgs(video_path, full_imgs_path)
print(f"Extracted {frame_count} frames")
input_img_list = sorted(glob.glob(os.path.join(full_imgs_path, '*.[jpJP][pnPN]*[gG]')))
print("Extracting landmarks...")
# bbox_shift controls mouth openness: positive=more open, negative=less open
coord_list_raw, frame_list_raw = get_landmark_and_bbox(input_img_list, upperbondrange=bbox_shift)
# Generate latents - filter out frames without detected faces
input_latent_list = []
valid_coord_list = []
valid_frame_list = []
coord_placeholder = (0.0, 0.0, 0.0, 0.0)
vae = models["vae"]
# Create FaceParsing with custom cheek widths for this avatar
from musetalk.utils.face_parsing import FaceParsing
fp_avatar = FaceParsing(left_cheek_width=left_cheek_width, right_cheek_width=right_cheek_width)
for bbox, frame in zip(coord_list_raw, frame_list_raw):
if bbox == coord_placeholder:
continue
x1, y1, x2, y2 = bbox
# Validate bbox dimensions
if x2 <= x1 or y2 <= y1:
continue
# Add extra margin for jaw movement (v1.5 feature)
y2 = min(y2 + extra_margin, frame.shape[0])
# Store valid frame and coordinates
valid_coord_list.append([x1, y1, x2, y2])
valid_frame_list.append(frame)
crop_frame = frame[y1:y2, x1:x2]
if crop_frame.size == 0:
valid_coord_list.pop()
valid_frame_list.pop()
continue
resized_crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
latents = vae.get_latents_for_unet(resized_crop_frame)
input_latent_list.append(latents)
print(f"Valid frames with detected faces: {len(valid_frame_list)}/{len(frame_list_raw)}")
if len(valid_frame_list) == 0:
raise HTTPException(status_code=400, detail="No faces detected in video. Please use a video with a clear frontal face.")
# Create cycles from valid frames only
frame_list_cycle = valid_frame_list + valid_frame_list[::-1]
coord_list_cycle = valid_coord_list + valid_coord_list[::-1]
input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
# Generate masks
mask_list_cycle = []
mask_coords_list_cycle = []
print(f"Generating masks with mode={parsing_mode}...")
for i, frame in enumerate(tqdm(frame_list_cycle)):
x1, y1, x2, y2 = coord_list_cycle[i]
mask, crop_box = get_image_prepare_material(frame, [x1, y1, x2, y2], fp=fp_avatar, mode=parsing_mode)
cv2.imwrite(f"{mask_out_path}/{str(i).zfill(8)}.png", mask)
mask_coords_list_cycle.append(crop_box)
mask_list_cycle.append(mask)
# Save preprocessed data
with open(f"{avatar_path}/coords.pkl", 'wb') as f:
pickle.dump(coord_list_cycle, f)
with open(f"{avatar_path}/mask_coords.pkl", 'wb') as f:
pickle.dump(mask_coords_list_cycle, f)
# Save quality settings
quality_settings = {
"bbox_shift": bbox_shift,
"extra_margin": extra_margin,
"parsing_mode": parsing_mode,
"left_cheek_width": left_cheek_width,
"right_cheek_width": right_cheek_width
}
with open(f"{avatar_path}/quality_settings.json", 'w') as f:
json.dump(quality_settings, f)
torch.save(input_latent_list_cycle, f"{avatar_path}/latents.pt")
# Store in memory - keep latents on CPU to save GPU memory
input_latent_list_cpu = [lat.cpu() for lat in input_latent_list_cycle]
avatars[avatar_id] = {
"path": avatar_path,
"frame_list_cycle": frame_list_cycle,
"coord_list_cycle": coord_list_cycle,
"input_latent_list_cycle": input_latent_list_cpu,
"mask_list_cycle": mask_list_cycle,
"mask_coords_list_cycle": mask_coords_list_cycle,
"quality_settings": quality_settings
}
# Clear GPU cache after preparation
import gc
gc.collect()
torch.cuda.empty_cache()
return {
"status": "success",
"avatar_id": avatar_id,
"frame_count": len(frame_list_cycle),
"quality_settings": quality_settings
}
@app.post("/avatar/load/{avatar_id}")
async def load_avatar(avatar_id: str):
"""Load a previously prepared avatar"""
global avatars
avatar_path = f"./results/v15/avatars/{avatar_id}"
if not os.path.exists(avatar_path):
raise HTTPException(status_code=404, detail=f"Avatar {avatar_id} not found")
full_imgs_path = f"{avatar_path}/full_imgs"
mask_out_path = f"{avatar_path}/mask"
# Load preprocessed data
input_latent_list_cycle = torch.load(f"{avatar_path}/latents.pt")
with open(f"{avatar_path}/coords.pkl", 'rb') as f:
coord_list_cycle = pickle.load(f)
with open(f"{avatar_path}/mask_coords.pkl", 'rb') as f:
mask_coords_list_cycle = pickle.load(f)
# Load quality settings (with defaults for backwards compatibility)
quality_settings_path = f"{avatar_path}/quality_settings.json"
if os.path.exists(quality_settings_path):
with open(quality_settings_path, 'r') as f:
quality_settings = json.load(f)
else:
quality_settings = {
"bbox_shift": 0,
"extra_margin": 10,
"parsing_mode": "jaw",
"left_cheek_width": 90,
"right_cheek_width": 90
}
# Load frames
input_img_list = sorted(glob.glob(os.path.join(full_imgs_path, '*.[jpJP][pnPN]*[gG]')))
frame_list_cycle = read_imgs(input_img_list)
# Load masks
input_mask_list = sorted(glob.glob(os.path.join(mask_out_path, '*.[jpJP][pnPN]*[gG]')))
mask_list_cycle = read_imgs(input_mask_list)
# Keep latents on CPU to save GPU memory
input_latent_list_cpu = [lat.cpu() if hasattr(lat, 'cpu') else lat for lat in input_latent_list_cycle]
avatars[avatar_id] = {
"path": avatar_path,
"frame_list_cycle": frame_list_cycle,
"coord_list_cycle": coord_list_cycle,
"input_latent_list_cycle": input_latent_list_cpu,
"mask_list_cycle": mask_list_cycle,
"mask_coords_list_cycle": mask_coords_list_cycle,
"quality_settings": quality_settings
}
# Clear GPU cache
import gc
gc.collect()
torch.cuda.empty_cache()
return {
"status": "success",
"avatar_id": avatar_id,
"frame_count": len(frame_list_cycle),
"quality_settings": quality_settings
}
@app.get("/avatars")
async def list_avatars():
"""List all available avatars"""
avatar_dir = "./results/v15/avatars"
if not os.path.exists(avatar_dir):
return {"avatars": [], "loaded": list(avatars.keys())}
available = [d for d in os.listdir(avatar_dir) if os.path.isdir(os.path.join(avatar_dir, d))]
return {"avatars": available, "loaded": list(avatars.keys())}
@app.post("/inference")
async def inference(
avatar_id: str = Form(...),
audio: UploadFile = File(...),
fps: int = Form(25)
):
"""Run inference with uploaded audio and return video"""
if avatar_id not in avatars:
raise HTTPException(status_code=404, detail=f"Avatar {avatar_id} not loaded. Use /avatar/load first")
if not models:
raise HTTPException(status_code=503, detail="Models not loaded")
avatar = avatars[avatar_id]
device = models["device"]
# Save audio temporarily
with tempfile.NamedTemporaryFile(suffix=Path(audio.filename).suffix, delete=False) as tmp:
content = await audio.read()
tmp.write(content)
audio_path = tmp.name
try:
start_time = time.time()
# Extract audio features
audio_processor = models["audio_processor"]
whisper = models["whisper"]
weight_dtype = models["weight_dtype"]
whisper_input_features, librosa_length = audio_processor.get_audio_feature(
audio_path, weight_dtype=weight_dtype
)
whisper_chunks = audio_processor.get_whisper_chunk(
whisper_input_features,
device,
weight_dtype,
whisper,
librosa_length,
fps=fps,
audio_padding_length_left=2,
audio_padding_length_right=2,
)
print(f"Audio processing: {(time.time() - start_time)*1000:.0f}ms")
# Inference
vae = models["vae"]
unet = models["unet"]
pe = models["pe"]
timesteps = models["timesteps"]
video_num = len(whisper_chunks)
batch_size = 4 # Reduced batch size to save GPU memory
gen = datagen(whisper_chunks, avatar["input_latent_list_cycle"], batch_size)
result_frames = []
inference_start = time.time()
for i, (whisper_batch, latent_batch) in enumerate(gen):
audio_feature_batch = pe(whisper_batch.to(device))
latent_batch = latent_batch.to(device=device, dtype=unet.model.dtype)
pred_latents = unet.model(
latent_batch,
timesteps,
encoder_hidden_states=audio_feature_batch
).sample
pred_latents = pred_latents.to(device=device, dtype=vae.vae.dtype)
recon = vae.decode_latents(pred_latents)
for idx_in_batch, res_frame in enumerate(recon):
frame_idx = i * batch_size + idx_in_batch
if frame_idx >= video_num:
break
bbox = avatar["coord_list_cycle"][frame_idx % len(avatar["coord_list_cycle"])]
ori_frame = copy.deepcopy(avatar["frame_list_cycle"][frame_idx % len(avatar["frame_list_cycle"])])
x1, y1, x2, y2 = bbox
res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
mask = avatar["mask_list_cycle"][frame_idx % len(avatar["mask_list_cycle"])]
mask_crop_box = avatar["mask_coords_list_cycle"][frame_idx % len(avatar["mask_coords_list_cycle"])]
combine_frame = get_image_blending(ori_frame, res_frame, bbox, mask, mask_crop_box)
result_frames.append(combine_frame)
print(f"Inference: {(time.time() - inference_start)*1000:.0f}ms for {video_num} frames")
print(f"FPS: {video_num / (time.time() - inference_start):.1f}")
# Create video
output_path = tempfile.mktemp(suffix=".mp4")
h, w = result_frames[0].shape[:2]
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
for frame in result_frames:
out.write(frame)
out.release()
# Combine with audio using ffmpeg
final_output = tempfile.mktemp(suffix=".mp4")
os.system(f"ffmpeg -y -v warning -i {audio_path} -i {output_path} -c:v libx264 -c:a aac {final_output}")
os.unlink(output_path)
os.unlink(audio_path)
total_time = time.time() - start_time
print(f"Total time: {total_time*1000:.0f}ms")
return FileResponse(
final_output,
media_type="video/mp4",
filename=f"output_{avatar_id}.mp4",
headers={"X-Processing-Time": f"{total_time:.2f}s"}
)
except Exception as e:
if os.path.exists(audio_path):
os.unlink(audio_path)
raise HTTPException(status_code=500, detail=str(e))
@app.post("/inference/frames")
async def inference_frames(
avatar_id: str = Form(...),
audio: UploadFile = File(...),
fps: int = Form(25)
):
"""Run inference and return frames as JSON (for streaming)"""
if avatar_id not in avatars:
raise HTTPException(status_code=404, detail=f"Avatar {avatar_id} not loaded")
avatar = avatars[avatar_id]
device = models["device"]
# Save audio temporarily
with tempfile.NamedTemporaryFile(suffix=Path(audio.filename).suffix, delete=False) as tmp:
content = await audio.read()
tmp.write(content)
audio_path = tmp.name
try:
# Extract audio features
audio_processor = models["audio_processor"]
whisper = models["whisper"]
weight_dtype = models["weight_dtype"]
whisper_input_features, librosa_length = audio_processor.get_audio_feature(
audio_path, weight_dtype=weight_dtype
)
whisper_chunks = audio_processor.get_whisper_chunk(
whisper_input_features,
device,
weight_dtype,
whisper,
librosa_length,
fps=fps,
)
# Inference
vae = models["vae"]
unet = models["unet"]
pe = models["pe"]
timesteps = models["timesteps"]
video_num = len(whisper_chunks)
batch_size = 4 # Reduced batch size to save GPU memory
gen = datagen(whisper_chunks, avatar["input_latent_list_cycle"], batch_size)
frames_data = []
for i, (whisper_batch, latent_batch) in enumerate(gen):
audio_feature_batch = pe(whisper_batch.to(device))
latent_batch = latent_batch.to(device=device, dtype=unet.model.dtype)
pred_latents = unet.model(
latent_batch,
timesteps,
encoder_hidden_states=audio_feature_batch
).sample
pred_latents = pred_latents.to(device=device, dtype=vae.vae.dtype)
recon = vae.decode_latents(pred_latents)
for idx_in_batch, res_frame in enumerate(recon):
frame_idx = i * batch_size + idx_in_batch
if frame_idx >= video_num:
break
bbox = avatar["coord_list_cycle"][frame_idx % len(avatar["coord_list_cycle"])]
ori_frame = copy.deepcopy(avatar["frame_list_cycle"][frame_idx % len(avatar["frame_list_cycle"])])
x1, y1, x2, y2 = bbox
res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
mask = avatar["mask_list_cycle"][frame_idx % len(avatar["mask_list_cycle"])]
mask_crop_box = avatar["mask_coords_list_cycle"][frame_idx % len(avatar["mask_coords_list_cycle"])]
combine_frame = get_image_blending(ori_frame, res_frame, bbox, mask, mask_crop_box)
# Encode frame as JPEG
_, buffer = cv2.imencode('.jpg', combine_frame, [cv2.IMWRITE_JPEG_QUALITY, 85])
import base64
frame_b64 = base64.b64encode(buffer).decode('utf-8')
frames_data.append(frame_b64)
os.unlink(audio_path)
return {
"frames": frames_data,
"fps": fps,
"total_frames": len(frames_data)
}
except Exception as e:
if os.path.exists(audio_path):
os.unlink(audio_path)
raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)