"""
MuseTalk HTTP API Server

Keeps models loaded in GPU memory for fast inference.
"""

import os
import cv2
import copy
import glob
import time
import shutil
import pickle
import hashlib
import tempfile
import subprocess
from pathlib import Path
from typing import Optional

import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import WhisperModel

from musetalk.utils.blending import get_image
from musetalk.utils.face_parsing import FaceParsing
from musetalk.utils.audio_processor import AudioProcessor
from musetalk.utils.utils import get_file_type, datagen, load_all_model
from musetalk.utils.preprocessing import get_landmark_and_bbox, coord_placeholder


class MuseTalkServer:
    """Singleton server that keeps models loaded in GPU memory."""

    def __init__(self):
        self.device = None
        self.vae = None
        self.unet = None
        self.pe = None
        self.whisper = None
        self.audio_processor = None
        self.fp = None
        self.timesteps = None
        self.weight_dtype = None
        self.is_loaded = False
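
        # On-disk caches keyed by content hash, so repeated inputs skip recomputation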
        self.cache_dir = Path("./cache")
        self.cache_dir.mkdir(exist_ok=True)
        self.landmarks_cache = self.cache_dir / "landmarks"
        self.latents_cache = self.cache_dir / "latents"
        self.whisper_cache = self.cache_dir / "whisper_features"
        self.landmarks_cache.mkdir(exist_ok=True)
        self.latents_cache.mkdir(exist_ok=True)
        self.whisper_cache.mkdir(exist_ok=True)
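
        # Default inference settings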
        self.fps = 25
        self.batch_size = 8
        self.use_float16 = True
        self.version = "v15"
        self.extra_margin = 10
        self.parsing_mode = "jaw"
        self.left_cheek_width = 90
        self.right_cheek_width = 90
        self.audio_padding_left = 2
        self.audio_padding_right = 2

    def load_models(
        self,
        gpu_id: int = 0,
        unet_model_path: str = "./models/musetalkV15/unet.pth",
        unet_config: str = "./models/musetalk/config.json",
        vae_type: str = "sd-vae",
        whisper_dir: str = "./models/whisper",
        use_float16: bool = True,
        version: str = "v15"
    ):
        """Load all models into GPU memory."""
        if self.is_loaded:
            print("Models already loaded!")
            return

        print("=" * 50)
        print("Loading MuseTalk models into GPU memory...")
        print("=" * 50)

        start_time = time.time()

        self.device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        print("Loading VAE, UNet, PE...")
        self.vae, self.unet, self.pe = load_all_model(
            unet_model_path=unet_model_path,
            vae_type=vae_type,
            unet_config=unet_config,
            device=self.device
        )
        self.timesteps = torch.tensor([0], device=self.device)

        self.use_float16 = use_float16
        if use_float16:
            print("Converting to float16...")
            self.pe = self.pe.half()
            self.vae.vae = self.vae.vae.half()
            self.unet.model = self.unet.model.half()

        self.pe = self.pe.to(self.device)
        self.vae.vae = self.vae.vae.to(self.device)
        self.unet.model = self.unet.model.to(self.device)

        print("Loading Whisper model...")
        self.audio_processor = AudioProcessor(feature_extractor_path=whisper_dir)
        self.weight_dtype = self.unet.model.dtype
        self.whisper = WhisperModel.from_pretrained(whisper_dir)
        self.whisper = self.whisper.to(device=self.device, dtype=self.weight_dtype).eval()
        self.whisper.requires_grad_(False)

        self.version = version
        if version == "v15":
            self.fp = FaceParsing(
                left_cheek_width=self.left_cheek_width,
                right_cheek_width=self.right_cheek_width
            )
        else:
            self.fp = FaceParsing()

        self.is_loaded = True
        load_time = time.time() - start_time
        print(f"Models loaded in {load_time:.2f}s")
        print("=" * 50)
        print("Server ready for inference!")
        print("=" * 50)

    def _get_file_hash(self, file_path: str) -> str:
        """Get MD5 hash of a file for caching."""
        hash_md5 = hashlib.md5()
        with open(file_path, "rb") as f:
            for chunk in iter(lambda: f.read(4096), b""):
                hash_md5.update(chunk)
        return hash_md5.hexdigest()[:16]

    def _get_cached_landmarks(self, video_hash: str, bbox_shift: int):
        """Get cached landmarks if available, else None."""
        cache_file = self.landmarks_cache / f"{video_hash}_shift{bbox_shift}.pkl"
        if cache_file.exists():
            with open(cache_file, 'rb') as f:
                return pickle.load(f)
        return None

    def _save_landmarks_cache(self, video_hash: str, bbox_shift: int, coord_list, frame_list):
        """Save landmarks to cache."""
        cache_file = self.landmarks_cache / f"{video_hash}_shift{bbox_shift}.pkl"
        with open(cache_file, 'wb') as f:
            pickle.dump((coord_list, frame_list), f)

    def _get_cached_latents(self, video_hash: str):
        """Get cached VAE latents if available, else None."""
        cache_file = self.latents_cache / f"{video_hash}.pkl"
        if cache_file.exists():
            with open(cache_file, 'rb') as f:
                return pickle.load(f)
        return None

    def _save_latents_cache(self, video_hash: str, latent_list):
        """Save VAE latents to cache."""
        cache_file = self.latents_cache / f"{video_hash}.pkl"
        with open(cache_file, 'wb') as f:
            pickle.dump(latent_list, f)

    def _get_cached_whisper(self, audio_hash: str):
        """Get cached Whisper features if available, else None."""
        cache_file = self.whisper_cache / f"{audio_hash}.pkl"
        if cache_file.exists():
            with open(cache_file, 'rb') as f:
                return pickle.load(f)
        return None

    def _save_whisper_cache(self, audio_hash: str, whisper_data):
        """Save Whisper features to cache."""
        cache_file = self.whisper_cache / f"{audio_hash}.pkl"
        with open(cache_file, 'wb') as f:
            pickle.dump(whisper_data, f)

    @torch.no_grad()
    def generate(
        self,
        video_path: str,
        audio_path: str,
        output_path: str,
        fps: Optional[int] = None,
        use_cache: bool = True
    ) -> dict:
        """
        Generate lip-synced video.

        Returns dict with timing info.
        """
        if not self.is_loaded:
            raise RuntimeError("Models not loaded! Call load_models() first.")

        fps = fps or self.fps
        timings = {"total": 0}
        total_start = time.time()

        video_hash = self._get_file_hash(video_path)
        audio_hash = self._get_file_hash(audio_path)

        temp_dir = tempfile.mkdtemp()

        try:
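            # Stage 1: extract frames from the input video at the target fps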
            t0 = time.time()
            save_dir_full = os.path.join(temp_dir, "frames")
            os.makedirs(save_dir_full, exist_ok=True)

            if get_file_type(video_path) == "video":
                # List-form subprocess call avoids shell quoting issues with paths
                subprocess.run([
                    "ffmpeg", "-v", "fatal", "-i", video_path,
                    "-vf", f"fps={fps}", "-start_number", "0",
                    os.path.join(save_dir_full, "%08d.png"),
                ], check=True)
                input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
            elif get_file_type(video_path) == "image":
                input_img_list = [video_path]
            else:
                raise ValueError(f"Unsupported input type: {video_path}")

            timings["frame_extraction"] = time.time() - t0
            t0 = time.time()
            cached_whisper = self._get_cached_whisper(audio_hash) if use_cache else None

            if cached_whisper:
                whisper_chunks = cached_whisper
                timings["whisper_source"] = "cache"
            else:
                whisper_input_features, librosa_length = self.audio_processor.get_audio_feature(audio_path)
                whisper_chunks = self.audio_processor.get_whisper_chunk(
                    whisper_input_features,
                    self.device,
                    self.weight_dtype,
                    self.whisper,
                    librosa_length,
                    fps=fps,
                    audio_padding_length_left=self.audio_padding_left,
                    audio_padding_length_right=self.audio_padding_right,
                )
                if use_cache:
                    self._save_whisper_cache(audio_hash, whisper_chunks)
                timings["whisper_source"] = "computed"

            timings["whisper_features"] = time.time() - t0
            t0 = time.time()
            bbox_shift = 0  # same value for every model version
            cache_key = f"{video_hash}_{fps}"

            cached_landmarks = self._get_cached_landmarks(cache_key, bbox_shift) if use_cache else None

            if cached_landmarks:
                coord_list, frame_list = cached_landmarks
                timings["landmarks_source"] = "cache"
            else:
                coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
                if use_cache:
                    self._save_landmarks_cache(cache_key, bbox_shift, coord_list, frame_list)
                timings["landmarks_source"] = "computed"

            timings["landmarks"] = time.time() - t0
            t0 = time.time()
            latent_cache_key = f"{video_hash}_{fps}_{self.version}"
            cached_latents = self._get_cached_latents(latent_cache_key) if use_cache else None

            if cached_latents:
                input_latent_list = cached_latents
                timings["latents_source"] = "cache"
            else:
                input_latent_list = []
                for bbox, frame in zip(coord_list, frame_list):
                    # Skip frames where face detection returned the placeholder bbox
                    if isinstance(bbox, (list, tuple)) and list(bbox) == list(coord_placeholder):
                        continue
                    x1, y1, x2, y2 = bbox
                    if self.version == "v15":
                        y2 = y2 + self.extra_margin
                        y2 = min(y2, frame.shape[0])
                    crop_frame = frame[y1:y2, x1:x2]
                    crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
                    latents = self.vae.get_latents_for_unet(crop_frame)
                    input_latent_list.append(latents)

                if use_cache:
                    self._save_latents_cache(latent_cache_key, input_latent_list)
                timings["latents_source"] = "computed"

            timings["vae_encoding"] = time.time() - t0
            frame_list_cycle = frame_list + frame_list[::-1]
            coord_list_cycle = coord_list + coord_list[::-1]
            input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
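
            # Stage 5: batched UNet inference conditioned on the audio features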
            t0 = time.time()
            gen = datagen(
                whisper_chunks=whisper_chunks,
                vae_encode_latents=input_latent_list_cycle,
                batch_size=self.batch_size,
                delay_frame=0,
                device=self.device,
            )

            res_frame_list = []
            for whisper_batch, latent_batch in gen:
                audio_feature_batch = self.pe(whisper_batch)
                latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
                pred_latents = self.unet.model(
                    latent_batch, self.timesteps,
                    encoder_hidden_states=audio_feature_batch
                ).sample
                recon = self.vae.decode_latents(pred_latents)
                res_frame_list.extend(recon)

            timings["unet_inference"] = time.time() - t0
            t0 = time.time()
            result_img_path = os.path.join(temp_dir, "results")
            os.makedirs(result_img_path, exist_ok=True)

            for i, res_frame in enumerate(res_frame_list):
                bbox = coord_list_cycle[i % len(coord_list_cycle)]
                ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
                x1, y1, x2, y2 = bbox
                if self.version == "v15":
                    y2 = y2 + self.extra_margin
                    y2 = min(y2, ori_frame.shape[0])
                try:
                    res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
                except Exception:
                    # Skip frames with a degenerate bounding box
                    continue

                if self.version == "v15":
                    combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2],
                                              mode=self.parsing_mode, fp=self.fp)
                else:
                    combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=self.fp)

                cv2.imwrite(f"{result_img_path}/{str(i).zfill(8)}.png", combine_frame)

            timings["face_blending"] = time.time() - t0
            t0 = time.time()
            temp_vid = os.path.join(temp_dir, "temp.mp4")
            subprocess.run([
                "ffmpeg", "-y", "-v", "warning",
                "-r", str(fps), "-f", "image2",
                "-i", os.path.join(result_img_path, "%08d.png"),
                "-vcodec", "libx264", "-vf", "format=yuv420p", "-crf", "18",
                temp_vid,
            ], check=True)

            subprocess.run([
                "ffmpeg", "-y", "-v", "warning",
                "-i", audio_path, "-i", temp_vid,
                output_path,
            ], check=True)

            timings["video_encoding"] = time.time() - t0

        finally:
            # Always remove extracted frames and intermediate files
            shutil.rmtree(temp_dir, ignore_errors=True)

        timings["total"] = time.time() - total_start
        timings["frames_generated"] = len(res_frame_list)

        return timings


# Module-level singleton; models are loaded once and reused across requests
server = MuseTalkServer()

app = FastAPI(
    title="MuseTalk API",
    description="HTTP API for MuseTalk lip-sync generation",
    version="1.0.0"
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.on_event("startup")
async def startup_event():
    """Load models on server startup."""
    server.load_models()


@app.get("/health")
async def health_check():
    """Check if server is ready."""
    return {
        "status": "ok" if server.is_loaded else "loading",
        "models_loaded": server.is_loaded,
        "device": str(server.device) if server.device else None
    }


@app.get("/cache/stats")
async def cache_stats():
    """Get cache statistics."""
    landmarks_count = len(list(server.landmarks_cache.glob("*.pkl")))
    latents_count = len(list(server.latents_cache.glob("*.pkl")))
    whisper_count = len(list(server.whisper_cache.glob("*.pkl")))

    return {
        "landmarks_cached": landmarks_count,
        "latents_cached": latents_count,
        "whisper_features_cached": whisper_count
    }


@app.post("/cache/clear")
async def clear_cache():
    """Clear all caches."""
    for cache_dir in [server.landmarks_cache, server.latents_cache, server.whisper_cache]:
        for f in cache_dir.glob("*.pkl"):
            f.unlink()
    return {"status": "cleared"}


class GenerateRequest(BaseModel):
    video_path: str
    audio_path: str
    output_path: str
    fps: Optional[int] = 25
    use_cache: bool = True


@app.post("/generate")
async def generate_from_paths(request: GenerateRequest):
    """
    Generate lip-synced video from file paths.

    Use this when files are already on the server.
    """
    if not server.is_loaded:
        raise HTTPException(status_code=503, detail="Models not loaded yet")

    if not os.path.exists(request.video_path):
        raise HTTPException(status_code=404, detail=f"Video not found: {request.video_path}")
    if not os.path.exists(request.audio_path):
        raise HTTPException(status_code=404, detail=f"Audio not found: {request.audio_path}")

    try:
        timings = server.generate(
            video_path=request.video_path,
            audio_path=request.audio_path,
            output_path=request.output_path,
            fps=request.fps,
            use_cache=request.use_cache
        )
        return {
            "status": "success",
            "output_path": request.output_path,
            "timings": timings
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/generate/upload")
async def generate_from_upload(
    background_tasks: BackgroundTasks,
    video: UploadFile = File(...),
    audio: UploadFile = File(...),
    fps: int = Form(25),
    use_cache: bool = Form(True)
):
    """
    Generate lip-synced video from uploaded files.

    Returns the generated video file.
    """
    if not server.is_loaded:
        raise HTTPException(status_code=503, detail="Models not loaded yet")

    temp_dir = tempfile.mkdtemp()
    try:
        video_path = os.path.join(temp_dir, video.filename)
        audio_path = os.path.join(temp_dir, audio.filename)
        output_path = os.path.join(temp_dir, "output.mp4")

        with open(video_path, "wb") as f:
            f.write(await video.read())
        with open(audio_path, "wb") as f:
            f.write(await audio.read())

        timings = server.generate(
            video_path=video_path,
            audio_path=audio_path,
            output_path=output_path,
            fps=fps,
            use_cache=use_cache
        )

        # Remove the temp dir only after the response has been streamed
        background_tasks.add_task(shutil.rmtree, temp_dir, ignore_errors=True)
        return FileResponse(
            output_path,
            media_type="video/mp4",
            filename="result.mp4",
            headers={"X-Timings": str(timings)}
        )
    except Exception as e:
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="MuseTalk API Server")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind")
    parser.add_argument("--port", type=int, default=8000, help="Port to bind")
    parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID")
    parser.add_argument("--unet_model_path", type=str, default="./models/musetalkV15/unet.pth")
    parser.add_argument("--unet_config", type=str, default="./models/musetalk/config.json")
    parser.add_argument("--whisper_dir", type=str, default="./models/whisper")
    parser.add_argument("--no_float16", action="store_true", help="Disable float16")
    args = parser.parse_args()

    # Load with the CLI arguments now; the startup hook then sees is_loaded and no-ops
    server.load_models(
        gpu_id=args.gpu_id,
        unet_model_path=args.unet_model_path,
        unet_config=args.unet_config,
        whisper_dir=args.whisper_dir,
        use_float16=not args.no_float16
    )

    uvicorn.run(app, host=args.host, port=args.port)