"""
MuseTalk HTTP API Server v2

Optimized for repeated use of the same avatar: models are loaded once at
startup, and preprocessed avatars are cached in memory so that only the
audio-dependent steps run for each request.
"""

import os
import copy
import glob
import hashlib
import pickle
import shutil
import subprocess
import tempfile
import time
from pathlib import Path
from typing import Optional

import cv2
import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from omegaconf import OmegaConf
from pydantic import BaseModel
from tqdm import tqdm
from transformers import WhisperModel

from musetalk.utils.audio_processor import AudioProcessor
from musetalk.utils.blending import get_image
from musetalk.utils.face_parsing import FaceParsing
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
from musetalk.utils.utils import get_file_type, datagen, load_all_model


class MuseTalkServerV2:
    """Server optimized for preprocessed avatars."""

    def __init__(self):
        # Model handles; populated by load_models().
        self.device = None
        self.vae = None
        self.unet = None
        self.pe = None
        self.whisper = None
        self.audio_processor = None
        self.fp = None
        self.timesteps = None
        self.weight_dtype = None
        self.is_loaded = False

        # In-memory cache of preprocessed avatars, keyed by avatar name.
        self.loaded_avatars = {}
        self.avatar_dir = Path("./avatars")

        # Generation defaults.
        self.fps = 25
        self.batch_size = 8
        self.use_float16 = True
        self.version = "v15"
        self.extra_margin = 10
        self.parsing_mode = "jaw"
        self.left_cheek_width = 90
        self.right_cheek_width = 90
        self.audio_padding_left = 2
        self.audio_padding_right = 2

    def load_models(
        self,
        gpu_id: int = 0,
        unet_model_path: str = "./models/musetalkV15/unet.pth",
        unet_config: str = "./models/musetalk/config.json",
        vae_type: str = "sd-vae",
        whisper_dir: str = "./models/whisper",
        use_float16: bool = True,
        version: str = "v15"
    ):
        """Load all models into GPU memory. Idempotent; safe to call twice."""
        if self.is_loaded:
            print("Models already loaded!")
            return

        print("=" * 50)
        print("Loading MuseTalk models into GPU memory...")
        print("=" * 50)

        start_time = time.time()
        self.device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        print("Loading VAE, UNet, PE...")
        self.vae, self.unet, self.pe = load_all_model(
            unet_model_path=unet_model_path,
            vae_type=vae_type,
            unet_config=unet_config,
            device=self.device
        )
        self.timesteps = torch.tensor([0], device=self.device)

        # Optionally convert to half precision to reduce VRAM usage.
        self.use_float16 = use_float16
        if use_float16:
            print("Converting to float16...")
            self.pe = self.pe.half()
            self.vae.vae = self.vae.vae.half()
            self.unet.model = self.unet.model.half()

        self.pe = self.pe.to(self.device)
        self.vae.vae = self.vae.vae.to(self.device)
        self.unet.model = self.unet.model.to(self.device)

        print("Loading Whisper model...")
        self.audio_processor = AudioProcessor(feature_extractor_path=whisper_dir)
        self.weight_dtype = self.unet.model.dtype
        self.whisper = WhisperModel.from_pretrained(whisper_dir)
        self.whisper = self.whisper.to(device=self.device, dtype=self.weight_dtype).eval()
        self.whisper.requires_grad_(False)

        # v1.5 face parsing takes explicit cheek widths; v1 uses defaults.
        self.version = version
        if version == "v15":
            self.fp = FaceParsing(
                left_cheek_width=self.left_cheek_width,
                right_cheek_width=self.right_cheek_width
            )
        else:
            self.fp = FaceParsing()

        self.is_loaded = True
        print(f"Models loaded in {time.time() - start_time:.2f}s")
        print("=" * 50)
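
    # Expected on-disk layout for a preprocessed avatar, as read by
    # load_avatar() below. These files are assumed to be produced by a
    # separate preprocessing step (not shown in this file):
    #
    #   avatars/<name>/metadata.pkl   - info dict about the avatar
    #   avatars/<name>/coords.pkl     - per-frame face bounding boxes
    #   avatars/<name>/frames.pkl     - original full frames
    #   avatars/<name>/latents.pkl    - per-frame VAE latents (numpy arrays)
    #   avatars/<name>/crop_info.pkl  - face crop information
    #   avatars/<name>/parsing.pkl    - optional cached face-parsing data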
    def load_avatar(self, avatar_name: str) -> dict:
        """Load a preprocessed avatar into memory, caching it for reuse."""
        if avatar_name in self.loaded_avatars:
            return self.loaded_avatars[avatar_name]

        avatar_path = self.avatar_dir / avatar_name
        if not avatar_path.exists():
            raise FileNotFoundError(f"Avatar not found: {avatar_name}")

        print(f"Loading avatar '{avatar_name}' into memory...")
        t0 = time.time()

        avatar_data = {}

        with open(avatar_path / "metadata.pkl", 'rb') as f:
            avatar_data['metadata'] = pickle.load(f)

        with open(avatar_path / "coords.pkl", 'rb') as f:
            avatar_data['coord_list'] = pickle.load(f)

        with open(avatar_path / "frames.pkl", 'rb') as f:
            avatar_data['frame_list'] = pickle.load(f)

        # Latents are stored as numpy arrays; move them to the GPU once here
        # so per-request generation skips the host-to-device copies.
        with open(avatar_path / "latents.pkl", 'rb') as f:
            latents_np = pickle.load(f)
        avatar_data['latent_list'] = [
            torch.from_numpy(l).to(self.device) for l in latents_np
        ]

        with open(avatar_path / "crop_info.pkl", 'rb') as f:
            avatar_data['crop_info'] = pickle.load(f)

        # Cached face-parsing data is optional.
        parsing_path = avatar_path / "parsing.pkl"
        if parsing_path.exists():
            with open(parsing_path, 'rb') as f:
                avatar_data['parsing_data'] = pickle.load(f)

        self.loaded_avatars[avatar_name] = avatar_data
        print(f"Avatar loaded in {time.time() - t0:.2f}s")

        return avatar_data

    def unload_avatar(self, avatar_name: str):
        """Unload an avatar and release the GPU memory held by its latents."""
        if avatar_name in self.loaded_avatars:
            del self.loaded_avatars[avatar_name]
            torch.cuda.empty_cache()

    @torch.no_grad()
    def generate_with_avatar(
        self,
        avatar_name: str,
        audio_path: str,
        output_path: str,
        fps: Optional[int] = None
    ) -> dict:
        """Generate a video using a preprocessed avatar.

        Much faster than processing a video from scratch: only the
        audio-dependent steps (Whisper features, UNet inference, blending,
        encoding) run per request. Returns a dict of per-stage timings.
        """
        if not self.is_loaded:
            raise RuntimeError("Models not loaded!")

        fps = fps or self.fps
        timings = {}
        total_start = time.time()

        # Fetch the avatar from the cache (loads it on first use).
        t0 = time.time()
        avatar = self.load_avatar(avatar_name)
        timings["avatar_load"] = time.time() - t0

        coord_list = avatar['coord_list']
        frame_list = avatar['frame_list']
        input_latent_list = avatar['latent_list']

        temp_dir = tempfile.mkdtemp()

        try:
            # 1. Extract Whisper audio features and split them into
            #    per-frame chunks aligned to the target fps.
            t0 = time.time()
            whisper_input_features, librosa_length = self.audio_processor.get_audio_feature(audio_path)
            whisper_chunks = self.audio_processor.get_whisper_chunk(
                whisper_input_features,
                self.device,
                self.weight_dtype,
                self.whisper,
                librosa_length,
                fps=fps,
                audio_padding_length_left=self.audio_padding_left,
                audio_padding_length_right=self.audio_padding_right,
            )
            timings["whisper_features"] = time.time() - t0

            # 2. Cycle the avatar frames forward then backward so the output
            #    can run longer than the source clip without a visible jump.
            frame_list_cycle = frame_list + frame_list[::-1]
            coord_list_cycle = coord_list + coord_list[::-1]
            input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

            # 3. Run UNet inference in batches and decode the latents.
            t0 = time.time()
            gen = datagen(
                whisper_chunks=whisper_chunks,
                vae_encode_latents=input_latent_list_cycle,
                batch_size=self.batch_size,
                delay_frame=0,
                device=self.device,
            )

            res_frame_list = []
            for whisper_batch, latent_batch in gen:
                audio_feature_batch = self.pe(whisper_batch)
                latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
                pred_latents = self.unet.model(
                    latent_batch, self.timesteps,
                    encoder_hidden_states=audio_feature_batch
                ).sample
                recon = self.vae.decode_latents(pred_latents)
                for res_frame in recon:
                    res_frame_list.append(res_frame)

            timings["unet_inference"] = time.time() - t0

            # 4. Blend each generated face crop back into its original frame.
            t0 = time.time()
            result_img_path = os.path.join(temp_dir, "results")
            os.makedirs(result_img_path, exist_ok=True)

            for i, res_frame in enumerate(res_frame_list):
                bbox = coord_list_cycle[i % len(coord_list_cycle)]
                ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
                x1, y1, x2, y2 = bbox

                if self.version == "v15":
                    y2 = y2 + self.extra_margin
                    y2 = min(y2, ori_frame.shape[0])

                try:
                    res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
                except cv2.error:
                    # Skip frames with a degenerate bounding box.
                    continue

                if self.version == "v15":
                    combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2],
                                              mode=self.parsing_mode, fp=self.fp)
                else:
                    combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=self.fp)

                cv2.imwrite(f"{result_img_path}/{str(i).zfill(8)}.png", combine_frame)

            timings["face_blending"] = time.time() - t0

            # 5. Encode the frames to video and mux in the original audio.
            #    Argument lists (instead of shell strings) keep paths with
            #    spaces or special characters safe, and check=True surfaces
            #    ffmpeg failures instead of silently producing no output.
            t0 = time.time()
            temp_vid = os.path.join(temp_dir, "temp.mp4")
            subprocess.run([
                "ffmpeg", "-y", "-v", "warning", "-r", str(fps), "-f", "image2",
                "-i", f"{result_img_path}/%08d.png", "-vcodec", "libx264",
                "-vf", "format=yuv420p", "-crf", "18", temp_vid
            ], check=True)

            subprocess.run([
                "ffmpeg", "-y", "-v", "warning",
                "-i", audio_path, "-i", temp_vid, output_path
            ], check=True)

            timings["video_encoding"] = time.time() - t0

        finally:
            shutil.rmtree(temp_dir, ignore_errors=True)

        timings["total"] = time.time() - total_start
        timings["frames_generated"] = len(res_frame_list)

        return timings
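
# The generation pipeline can also be driven directly from Python without
# going through HTTP (illustrative sketch; the avatar name and paths below
# are hypothetical):
#
#   server = MuseTalkServerV2()
#   server.load_models()
#   timings = server.generate_with_avatar(
#       avatar_name="alice",
#       audio_path="/data/speech.wav",
#       output_path="/data/result.mp4",
#   )
#   print(timings)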

server = MuseTalkServerV2()

app = FastAPI(
    title="MuseTalk API v2",
    description="Optimized API for repeated avatar usage",
    version="2.0.0"
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
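# NOTE: the wildcard CORS policy above is convenient for development but
# permissive; consider restricting allow_origins to known frontends before
# exposing this server publicly.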

@app.on_event("startup")
async def startup_event():
    # Load models once at process startup so the first request is fast.
    server.load_models()


@app.get("/health")
async def health_check():
    return {
        "status": "ok" if server.is_loaded else "loading",
        "models_loaded": server.is_loaded,
        "device": str(server.device) if server.device else None,
        "loaded_avatars": list(server.loaded_avatars.keys())
    }


@app.get("/avatars")
async def list_avatars():
    """List all available preprocessed avatars."""
    avatars = []
    if server.avatar_dir.exists():
        for p in server.avatar_dir.iterdir():
            if p.is_dir() and (p / "metadata.pkl").exists():
                with open(p / "metadata.pkl", 'rb') as f:
                    metadata = pickle.load(f)
                metadata['loaded'] = p.name in server.loaded_avatars
                avatars.append(metadata)
    return {"avatars": avatars}


@app.post("/avatars/{avatar_name}/load")
async def load_avatar(avatar_name: str):
    """Pre-load an avatar into GPU memory."""
    try:
        server.load_avatar(avatar_name)
        return {"status": "loaded", "avatar_name": avatar_name}
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e))


@app.post("/avatars/{avatar_name}/unload")
async def unload_avatar(avatar_name: str):
    """Unload an avatar from memory."""
    server.unload_avatar(avatar_name)
    return {"status": "unloaded", "avatar_name": avatar_name}


class GenerateWithAvatarRequest(BaseModel):
    avatar_name: str
    audio_path: str
    output_path: str
    fps: Optional[int] = 25
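
# Example request body for /generate/avatar (paths are illustrative and must
# exist on the server's filesystem):
#
#   {
#     "avatar_name": "alice",
#     "audio_path": "/data/speech.wav",
#     "output_path": "/data/result.mp4",
#     "fps": 25
#   }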

@app.post("/generate/avatar")
async def generate_with_avatar(request: GenerateWithAvatarRequest):
    """Generate a video using a preprocessed avatar. Fast path."""
    if not server.is_loaded:
        raise HTTPException(status_code=503, detail="Models not loaded")

    if not os.path.exists(request.audio_path):
        raise HTTPException(status_code=404, detail=f"Audio not found: {request.audio_path}")

    try:
        timings = server.generate_with_avatar(
            avatar_name=request.avatar_name,
            audio_path=request.audio_path,
            output_path=request.output_path,
            fps=request.fps
        )
        return {
            "status": "success",
            "output_path": request.output_path,
            "timings": timings
        }
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/generate/avatar/upload")
async def generate_with_avatar_upload(
    background_tasks: BackgroundTasks,
    avatar_name: str = Form(...),
    audio: UploadFile = File(...),
    fps: int = Form(25)
):
    """Generate a video from uploaded audio using a preprocessed avatar."""
    if not server.is_loaded:
        raise HTTPException(status_code=503, detail="Models not loaded")

    temp_dir = tempfile.mkdtemp()
    try:
        audio_path = os.path.join(temp_dir, audio.filename)
        output_path = os.path.join(temp_dir, "output.mp4")

        with open(audio_path, "wb") as f:
            f.write(await audio.read())

        timings = server.generate_with_avatar(
            avatar_name=avatar_name,
            audio_path=audio_path,
            output_path=output_path,
            fps=fps
        )

        # Clean up the temp directory only after the response has been sent;
        # FileResponse streams the file, so it must still exist here.
        background_tasks.add_task(shutil.rmtree, temp_dir, ignore_errors=True)
        return FileResponse(
            output_path,
            media_type="video/mp4",
            filename="result.mp4",
            headers={"X-Timings": str(timings)}
        )
    except FileNotFoundError as e:
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise HTTPException(status_code=404, detail=str(e))
    except Exception as e:
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000)
    args = parser.parse_args()

    uvicorn.run(app, host=args.host, port=args.port)