""" |
|
|
MuseTalk HTTP API Server v3 |
|
|
Ultra-optimized with: |
|
|
1. GPU-accelerated face blending (parallel processing) |
|
|
2. NVENC hardware video encoding |
|
|
3. Batch audio processing |
|
|
""" |
|
|
import copy
import multiprocessing as mp
import os
import pickle
import shutil
import subprocess
import tempfile
import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import List, Optional

import cv2
import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import WhisperModel

from musetalk.utils.audio_processor import AudioProcessor
from musetalk.utils.blending import get_image
from musetalk.utils.face_parsing import FaceParsing
from musetalk.utils.utils import datagen, load_all_model


# Each worker constructs its own FaceParsing instance from fp_config so the
# blending threads never share mutable parser state (a per-call cost traded
# for thread safety).
def blend_single_frame(args):
    """Blend one generated face crop back into its original frame.

    Worker function for parallel face blending. Returns (index, frame) so the
    caller can restore output order, or (index, None) if the crop is invalid.
    """
    i, res_frame, bbox, ori_frame, extra_margin, version, parsing_mode, fp_config = args

    x1, y1, x2, y2 = bbox
    if version == "v15":
        y2 = y2 + extra_margin
        y2 = min(y2, ori_frame.shape[0])

    try:
        res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
    except cv2.error:
        # Degenerate bbox (zero or negative size) -- skip this frame.
        return i, None

    fp = FaceParsing(
        left_cheek_width=fp_config['left_cheek_width'],
        right_cheek_width=fp_config['right_cheek_width']
    )

    if version == "v15":
        combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2],
                                  mode=parsing_mode, fp=fp)
    else:
        combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=fp)

    return i, combine_frame


class MuseTalkServerV3:
    """Ultra-optimized MuseTalk inference server."""

    def __init__(self):
        # Models (populated by load_models)
        self.device = None
        self.vae = None
        self.unet = None
        self.pe = None
        self.whisper = None
        self.audio_processor = None
        self.fp = None
        self.timesteps = None
        self.weight_dtype = None
        self.is_loaded = False

        # In-memory avatar cache
        self.loaded_avatars = {}
        self.avatar_dir = Path("./avatars")

        # Generation defaults
        self.fps = 25
        self.batch_size = 8
        self.use_float16 = True
        self.version = "v15"
        self.extra_margin = 10
        self.parsing_mode = "jaw"
        self.left_cheek_width = 90
        self.right_cheek_width = 90
        self.audio_padding_left = 2
        self.audio_padding_right = 2

        # Parallel blending: threads rather than processes, since OpenCV
        # releases the GIL for most image operations, so a thread pool
        # parallelizes blending without pickling frames across processes.
        self.num_workers = min(8, mp.cpu_count())
        self.thread_pool = ThreadPoolExecutor(max_workers=self.num_workers)

        # NVENC hardware encoding settings
        self.use_nvenc = True
        self.nvenc_preset = "p4"
        self.crf = 23  # quality target (passed as -cq when NVENC is used)
    def load_models(
        self,
        gpu_id: int = 0,
        unet_model_path: str = "./models/musetalkV15/unet.pth",
        unet_config: str = "./models/musetalk/config.json",
        vae_type: str = "sd-vae",
        whisper_dir: str = "./models/whisper",
        use_float16: bool = True,
        version: str = "v15"
    ):
        if self.is_loaded:
            print("Models already loaded!")
            return

        print("=" * 50)
        print("Loading MuseTalk models (v3 Ultra-Optimized)...")
        print("=" * 50)

        start_time = time.time()
        self.device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        print(f"Parallel workers: {self.num_workers}")
        print(f"NVENC encoding: {self.use_nvenc}")

        print("Loading VAE, UNet, PE...")
        self.vae, self.unet, self.pe = load_all_model(
            unet_model_path=unet_model_path,
            vae_type=vae_type,
            unet_config=unet_config,
            device=self.device
        )
        self.timesteps = torch.tensor([0], device=self.device)

        self.use_float16 = use_float16
        if use_float16:
            print("Converting to float16...")
            self.pe = self.pe.half()
            self.vae.vae = self.vae.vae.half()
            self.unet.model = self.unet.model.half()

        self.pe = self.pe.to(self.device)
        self.vae.vae = self.vae.vae.to(self.device)
        self.unet.model = self.unet.model.to(self.device)

        print("Loading Whisper model...")
        self.audio_processor = AudioProcessor(feature_extractor_path=whisper_dir)
        self.weight_dtype = self.unet.model.dtype
        self.whisper = WhisperModel.from_pretrained(whisper_dir)
        self.whisper = self.whisper.to(device=self.device, dtype=self.weight_dtype).eval()
        self.whisper.requires_grad_(False)

        self.version = version
        if version == "v15":
            self.fp = FaceParsing(
                left_cheek_width=self.left_cheek_width,
                right_cheek_width=self.right_cheek_width
            )
        else:
            self.fp = FaceParsing()

        self.is_loaded = True
        print(f"Models loaded in {time.time() - start_time:.2f}s")
        print("=" * 50)
    def load_avatar(self, avatar_name: str) -> dict:
        """Load a pre-processed avatar from disk into the in-memory cache."""
        if avatar_name in self.loaded_avatars:
            return self.loaded_avatars[avatar_name]

        avatar_path = self.avatar_dir / avatar_name
        if not avatar_path.exists():
            raise FileNotFoundError(f"Avatar not found: {avatar_name}")

        print(f"Loading avatar '{avatar_name}' into memory...")
        t0 = time.time()

        avatar_data = {}

        with open(avatar_path / "metadata.pkl", 'rb') as f:
            avatar_data['metadata'] = pickle.load(f)

        with open(avatar_path / "coords.pkl", 'rb') as f:
            avatar_data['coord_list'] = pickle.load(f)

        with open(avatar_path / "frames.pkl", 'rb') as f:
            avatar_data['frame_list'] = pickle.load(f)

        # Latents are stored as numpy arrays; move them to the device once
        # here so per-request inference never pays the host-to-device copy.
        with open(avatar_path / "latents.pkl", 'rb') as f:
            latents_np = pickle.load(f)
        avatar_data['latent_list'] = [
            torch.from_numpy(l).to(self.device) for l in latents_np
        ]

        with open(avatar_path / "crop_info.pkl", 'rb') as f:
            avatar_data['crop_info'] = pickle.load(f)

        self.loaded_avatars[avatar_name] = avatar_data
        print(f"Avatar loaded in {time.time() - t0:.2f}s")

        return avatar_data
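
    # Expected on-disk avatar layout (as read above):
    #   avatars/<name>/metadata.pkl   -- avatar metadata (whatever preprocessing stored)
    #   avatars/<name>/coords.pkl     -- per-frame face bounding boxes
    #   avatars/<name>/frames.pkl     -- original video frames
    #   avatars/<name>/latents.pkl    -- pre-computed VAE latents (numpy arrays)
    #   avatars/<name>/crop_info.pkl  -- face-crop bookkeeping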

    def unload_avatar(self, avatar_name: str):
        if avatar_name in self.loaded_avatars:
            del self.loaded_avatars[avatar_name]
            torch.cuda.empty_cache()

    def _encode_video_nvenc(self, frames_dir: str, audio_path: str, output_path: str, fps: int) -> float:
        """Encode the frame sequence to video (NVENC when enabled, libx264 otherwise), then mux in the audio."""
        t0 = time.time()
        temp_vid = os.path.join(os.path.dirname(frames_dir), "temp.mp4")

        if self.use_nvenc:
            codec_args = ["-c:v", "h264_nvenc", "-preset", self.nvenc_preset, "-cq", str(self.crf)]
        else:
            codec_args = ["-c:v", "libx264", "-crf", "18"]

        # Frames -> silent video. subprocess with an argument list avoids the
        # shell entirely, so paths with spaces or quotes are handled safely,
        # and check=True surfaces encoder failures instead of ignoring them.
        subprocess.run(
            ["ffmpeg", "-y", "-v", "warning", "-r", str(fps), "-f", "image2",
             "-i", f"{frames_dir}/%08d.png", *codec_args, "-pix_fmt", "yuv420p", temp_vid],
            check=True
        )

        # Mux the original audio with the encoded video (video stream copied as-is).
        subprocess.run(
            ["ffmpeg", "-y", "-v", "warning", "-i", audio_path, "-i", temp_vid,
             "-c:v", "copy", "-c:a", "aac", output_path],
            check=True
        )

        if os.path.exists(temp_vid):
            os.remove(temp_vid)

        return time.time() - t0
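
    # For reference, with the default settings the NVENC branch corresponds to roughly:
    #   ffmpeg -y -r 25 -f image2 -i <frames_dir>/%08d.png \
    #          -c:v h264_nvenc -preset p4 -cq 23 -pix_fmt yuv420p temp.mp4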

    def _parallel_face_blending(self, res_frame_list, coord_list_cycle, frame_list_cycle, result_img_path) -> float:
        """Blend all generated frames in parallel using the thread pool."""
        t0 = time.time()

        fp_config = {
            'left_cheek_width': self.left_cheek_width,
            'right_cheek_width': self.right_cheek_width
        }

        # Build one task per generated frame; the original frame is deep-copied
        # so the cached avatar frames are never mutated during blending.
        tasks = []
        for i, res_frame in enumerate(res_frame_list):
            bbox = coord_list_cycle[i % len(coord_list_cycle)]
            ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
            tasks.append((
                i, res_frame, bbox, ori_frame,
                self.extra_margin, self.version, self.parsing_mode, fp_config
            ))

        results = list(self.thread_pool.map(blend_single_frame, tasks))

        # Executor.map already yields results in input order; the sort is a
        # cheap safeguard before the frames are written out by index.
        results.sort(key=lambda x: x[0])
        for i, combine_frame in results:
            if combine_frame is not None:
                cv2.imwrite(f"{result_img_path}/{str(i).zfill(8)}.png", combine_frame)

        return time.time() - t0

    @torch.no_grad()
    def generate_with_avatar(
        self,
        avatar_name: str,
        audio_path: str,
        output_path: str,
        fps: Optional[int] = None,
        use_parallel_blending: bool = True
    ) -> dict:
        """Generate a video using a pre-processed avatar with all optimizations."""
        if not self.is_loaded:
            raise RuntimeError("Models not loaded!")

        fps = fps or self.fps
        timings = {}
        total_start = time.time()

        t0 = time.time()
        avatar = self.load_avatar(avatar_name)
        timings["avatar_load"] = time.time() - t0

        coord_list = avatar['coord_list']
        frame_list = avatar['frame_list']
        input_latent_list = avatar['latent_list']

        temp_dir = tempfile.mkdtemp()

        try:
            # Extract Whisper audio features and chunk them per video frame.
            t0 = time.time()
            whisper_input_features, librosa_length = self.audio_processor.get_audio_feature(audio_path)
            whisper_chunks = self.audio_processor.get_whisper_chunk(
                whisper_input_features,
                self.device,
                self.weight_dtype,
                self.whisper,
                librosa_length,
                fps=fps,
                audio_padding_length_left=self.audio_padding_left,
                audio_padding_length_right=self.audio_padding_right,
            )
            timings["whisper_features"] = time.time() - t0

            # Ping-pong the avatar frames (forward then reversed) so audio
            # longer than the source clip loops without a visible jump cut.
            frame_list_cycle = frame_list + frame_list[::-1]
            coord_list_cycle = coord_list + coord_list[::-1]
            input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

            # Batched UNet inference: audio features condition the UNet, and
            # the VAE decodes the predicted latents back into face crops.
            t0 = time.time()
            gen = datagen(
                whisper_chunks=whisper_chunks,
                vae_encode_latents=input_latent_list_cycle,
                batch_size=self.batch_size,
                delay_frame=0,
                device=self.device,
            )

            res_frame_list = []
            for whisper_batch, latent_batch in gen:
                audio_feature_batch = self.pe(whisper_batch)
                latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
                pred_latents = self.unet.model(
                    latent_batch, self.timesteps,
                    encoder_hidden_states=audio_feature_batch
                ).sample
                recon = self.vae.decode_latents(pred_latents)
                for res_frame in recon:
                    res_frame_list.append(res_frame)

            timings["unet_inference"] = time.time() - t0

            result_img_path = os.path.join(temp_dir, "results")
            os.makedirs(result_img_path, exist_ok=True)

            if use_parallel_blending:
                timings["face_blending"] = self._parallel_face_blending(
                    res_frame_list, coord_list_cycle, frame_list_cycle, result_img_path
                )
                timings["blending_mode"] = "parallel"
            else:
                t0 = time.time()
                for i, res_frame in enumerate(res_frame_list):
                    bbox = coord_list_cycle[i % len(coord_list_cycle)]
                    ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
                    x1, y1, x2, y2 = bbox

                    if self.version == "v15":
                        y2 = y2 + self.extra_margin
                        y2 = min(y2, ori_frame.shape[0])

                    try:
                        res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
                    except cv2.error:
                        continue

                    if self.version == "v15":
                        combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2],
                                                  mode=self.parsing_mode, fp=self.fp)
                    else:
                        combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=self.fp)

                    cv2.imwrite(f"{result_img_path}/{str(i).zfill(8)}.png", combine_frame)
                timings["face_blending"] = time.time() - t0
                timings["blending_mode"] = "sequential"

            timings["video_encoding"] = self._encode_video_nvenc(
                result_img_path, audio_path, output_path, fps
            )
            timings["encoding_mode"] = "nvenc" if self.use_nvenc else "cpu"

        finally:
            shutil.rmtree(temp_dir, ignore_errors=True)

        timings["total"] = time.time() - total_start
        timings["frames_generated"] = len(res_frame_list)

        return timings
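
    # A successful call returns a timings dict along these lines (values illustrative):
    #   {"avatar_load": 0.01, "whisper_features": 0.4, "unet_inference": 1.8,
    #    "face_blending": 0.9, "blending_mode": "parallel",
    #    "video_encoding": 0.5, "encoding_mode": "nvenc",
    #    "total": 3.7, "frames_generated": 200}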

    @torch.no_grad()
    def generate_batch(
        self,
        avatar_name: str,
        audio_paths: List[str],
        output_dir: str,
        fps: Optional[int] = None
    ) -> dict:
        """Generate one video per audio file, reusing the avatar across the whole batch."""
        if not self.is_loaded:
            raise RuntimeError("Models not loaded!")

        fps = fps or self.fps
        batch_timings = {"videos": [], "total": 0}
        total_start = time.time()

        # Load the avatar and build the ping-pong cycles once for the batch.
        t0 = time.time()
        avatar = self.load_avatar(avatar_name)
        batch_timings["avatar_load"] = time.time() - t0

        coord_list = avatar['coord_list']
        frame_list = avatar['frame_list']
        input_latent_list = avatar['latent_list']

        frame_list_cycle = frame_list + frame_list[::-1]
        coord_list_cycle = coord_list + coord_list[::-1]
        input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

        os.makedirs(output_dir, exist_ok=True)

        for idx, audio_path in enumerate(audio_paths):
            video_start = time.time()
            timings = {}

            audio_name = Path(audio_path).stem
            output_path = os.path.join(output_dir, f"{audio_name}.mp4")

            temp_dir = tempfile.mkdtemp()

            try:
                # Whisper features for this audio clip.
                t0 = time.time()
                whisper_input_features, librosa_length = self.audio_processor.get_audio_feature(audio_path)
                whisper_chunks = self.audio_processor.get_whisper_chunk(
                    whisper_input_features,
                    self.device,
                    self.weight_dtype,
                    self.whisper,
                    librosa_length,
                    fps=fps,
                    audio_padding_length_left=self.audio_padding_left,
                    audio_padding_length_right=self.audio_padding_right,
                )
                timings["whisper_features"] = time.time() - t0

                # Batched UNet inference (same pipeline as generate_with_avatar).
                t0 = time.time()
                gen = datagen(
                    whisper_chunks=whisper_chunks,
                    vae_encode_latents=input_latent_list_cycle,
                    batch_size=self.batch_size,
                    delay_frame=0,
                    device=self.device,
                )

                res_frame_list = []
                for whisper_batch, latent_batch in gen:
                    audio_feature_batch = self.pe(whisper_batch)
                    latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
                    pred_latents = self.unet.model(
                        latent_batch, self.timesteps,
                        encoder_hidden_states=audio_feature_batch
                    ).sample
                    recon = self.vae.decode_latents(pred_latents)
                    for res_frame in recon:
                        res_frame_list.append(res_frame)

                timings["unet_inference"] = time.time() - t0

                result_img_path = os.path.join(temp_dir, "results")
                os.makedirs(result_img_path, exist_ok=True)
                timings["face_blending"] = self._parallel_face_blending(
                    res_frame_list, coord_list_cycle, frame_list_cycle, result_img_path
                )

                timings["video_encoding"] = self._encode_video_nvenc(
                    result_img_path, audio_path, output_path, fps
                )

            finally:
                shutil.rmtree(temp_dir, ignore_errors=True)

            timings["total"] = time.time() - video_start
            timings["frames_generated"] = len(res_frame_list)
            timings["output_path"] = output_path
            timings["audio_path"] = audio_path

            batch_timings["videos"].append(timings)
            print(f" [{idx+1}/{len(audio_paths)}] {audio_name}: {timings['total']:.2f}s")

        batch_timings["total"] = time.time() - total_start
        batch_timings["num_videos"] = len(audio_paths)
        batch_timings["avg_per_video"] = batch_timings["total"] / len(audio_paths) if audio_paths else 0

        return batch_timings


# Module-level singleton: one server instance shared by all request handlers.
server = MuseTalkServerV3()

app = FastAPI(
    title="MuseTalk API v3",
    description="Ultra-optimized API with parallel blending, NVENC, and batch processing",
    version="3.0.0"
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.on_event("startup")
async def startup_event():
    # Load models once at startup so the first request is not penalized
    # by model initialization.
    server.load_models()


@app.get("/health") |
|
|
async def health_check(): |
|
|
return { |
|
|
"status": "ok" if server.is_loaded else "loading", |
|
|
"models_loaded": server.is_loaded, |
|
|
"device": str(server.device) if server.device else None, |
|
|
"loaded_avatars": list(server.loaded_avatars.keys()), |
|
|
"optimizations": { |
|
|
"parallel_workers": server.num_workers, |
|
|
"nvenc_enabled": server.use_nvenc, |
|
|
"nvenc_preset": server.nvenc_preset |
|
|
} |
|
|
} |
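
# Example /health response (values illustrative):
# {"status": "ok", "models_loaded": true, "device": "cuda:0", "loaded_avatars": ["demo"],
#  "optimizations": {"parallel_workers": 8, "nvenc_enabled": true, "nvenc_preset": "p4"}}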


@app.get("/avatars")
async def list_avatars():
    avatars = []
    if not server.avatar_dir.exists():
        return {"avatars": avatars}
    for p in server.avatar_dir.iterdir():
        if p.is_dir() and (p / "metadata.pkl").exists():
            with open(p / "metadata.pkl", 'rb') as f:
                metadata = pickle.load(f)
            metadata['loaded'] = p.name in server.loaded_avatars
            avatars.append(metadata)
    return {"avatars": avatars}


@app.post("/avatars/{avatar_name}/load")
async def load_avatar(avatar_name: str):
    try:
        server.load_avatar(avatar_name)
        return {"status": "loaded", "avatar_name": avatar_name}
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e))


@app.post("/avatars/{avatar_name}/unload")
async def unload_avatar(avatar_name: str):
    server.unload_avatar(avatar_name)
    return {"status": "unloaded", "avatar_name": avatar_name}


class GenerateRequest(BaseModel):
    avatar_name: str
    audio_path: str
    output_path: str
    fps: Optional[int] = 25
    use_parallel_blending: bool = True
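
# Example request body for /generate/avatar (hypothetical paths):
# {"avatar_name": "demo", "audio_path": "/data/hello.wav",
#  "output_path": "/data/out/hello.mp4", "fps": 25, "use_parallel_blending": true}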


@app.post("/generate/avatar")
async def generate_with_avatar(request: GenerateRequest):
    if not server.is_loaded:
        raise HTTPException(status_code=503, detail="Models not loaded")

    if not os.path.exists(request.audio_path):
        raise HTTPException(status_code=404, detail=f"Audio not found: {request.audio_path}")

    try:
        timings = server.generate_with_avatar(
            avatar_name=request.avatar_name,
            audio_path=request.audio_path,
            output_path=request.output_path,
            fps=request.fps,
            use_parallel_blending=request.use_parallel_blending
        )
        return {
            "status": "success",
            "output_path": request.output_path,
            "timings": timings
        }
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


class BatchGenerateRequest(BaseModel):
    avatar_name: str
    audio_paths: List[str]
    output_dir: str
    fps: Optional[int] = 25
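
# Example request body for /generate/batch (hypothetical paths):
# {"avatar_name": "demo", "audio_paths": ["/data/a.wav", "/data/b.wav"],
#  "output_dir": "/data/out", "fps": 25}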


@app.post("/generate/batch")
async def generate_batch(request: BatchGenerateRequest):
    """Generate multiple videos from multiple audio files."""
    if not server.is_loaded:
        raise HTTPException(status_code=503, detail="Models not loaded")

    for audio_path in request.audio_paths:
        if not os.path.exists(audio_path):
            raise HTTPException(status_code=404, detail=f"Audio not found: {audio_path}")

    try:
        timings = server.generate_batch(
            avatar_name=request.avatar_name,
            audio_paths=request.audio_paths,
            output_dir=request.output_dir,
            fps=request.fps
        )
        return {
            "status": "success",
            "output_dir": request.output_dir,
            "timings": timings
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000)
    args = parser.parse_args()

    uvicorn.run(app, host=args.host, port=args.port)
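
# Launch example (script name is illustrative):
#   python musetalk_api_server_v3.py --host 0.0.0.0 --port 8000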