# audio_api.py
"""FastAPI service exposing Higgs Audio generation over HTTP.

Importing this module loads the serve engine once at module level, then
registers:

* ``POST /generate-audio`` -- generate audio for a text prompt and return it
  as a base64-encoded WAV payload.
* ``/static`` -- static assets served from the local ``static`` directory.
* ``GET /`` -- the static landing page ``static/index.html``.
"""
import base64
import io
from typing import Optional

import torch
import torchaudio
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field

from boson_multimodal.data_types import ChatMLSample, Message, AudioContent
from boson_multimodal.serve.serve_engine import HiggsAudioServeEngine, HiggsAudioResponse

# -------------------- Model loading --------------------
MODEL_PATH = "bosonai/higgs-audio-v2-generation-3B-base"
AUDIO_TOKENIZER_PATH = "bosonai/higgs-audio-v2-tokenizer"

# Prefer the GPU when one is visible to torch; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Built once at import time and shared by every request handler.
serve_engine = HiggsAudioServeEngine(MODEL_PATH, AUDIO_TOKENIZER_PATH, device=device)

# -------------------- FastAPI --------------------
app = FastAPI(title="Higgs Audio Generation API", version="0.1.0")


class AudioRequest(BaseModel):
    """Request body for ``POST /generate-audio``.

    ``user_prompt`` is required; the remaining fields are sampling options
    with defaults and validated ranges, forwarded unchanged to the engine.
    """

    user_prompt: str = Field(..., description="需要生成音频的文本")
    max_new_tokens: Optional[int] = Field(1024, ge=1, le=2048)
    temperature: Optional[float] = Field(0.3, ge=0.0, le=2.0)
    top_p: Optional[float] = Field(0.95, ge=0.0, le=1.0)
    top_k: Optional[int] = Field(50, ge=1, le=100)


class AudioResponse(BaseModel):
    """Response body for ``POST /generate-audio``."""

    # WAV file bytes, base64-encoded and decoded to text as UTF-8.
    audio_base64: str
    # Sampling rate reported by the engine for the generated audio.
    sample_rate: int


@app.post("/generate-audio", response_model=AudioResponse)
def generate_audio(req: AudioRequest) -> AudioResponse:
    """Generate audio for ``req.user_prompt`` and return it as base64 WAV.

    Args:
        req: Validated request carrying the prompt and sampling options.

    Returns:
        An ``AudioResponse`` holding the base64-encoded WAV bytes and the
        sampling rate reported by the engine.

    Raises:
        HTTPException: With status 500 when the engine raises during
            generation, or when it returns a response without audio.
    """
    system_prompt = (
        "Generate audio following instruction.\n\n<|scene_desc_start|>\n"
        "Audio is recorded from a quiet room.\n<|scene_desc_end|>"
    )
    messages = [
        Message(role="system", content=system_prompt),
        Message(role="user", content=req.user_prompt),
    ]

    try:
        output: HiggsAudioResponse = serve_engine.generate(
            chat_ml_sample=ChatMLSample(messages=messages),
            max_new_tokens=req.max_new_tokens,
            temperature=req.temperature,
            top_p=req.top_p,
            top_k=req.top_k,
            stop_strings=["<|end_of_text|>", "<|eot_id|>"],
        )
    except Exception as e:
        # Top-level boundary: surface any engine failure as a 500 while
        # keeping the original exception chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e

    # NOTE(review): presumably the engine can return a response whose audio
    # is None -- confirm against HiggsAudioResponse. Without this guard,
    # torch.from_numpy(None) fails with an opaque TypeError.
    if output.audio is None:
        raise HTTPException(status_code=500, detail="Model returned no audio")

    # Convert the numpy array to a torch.Tensor and encode it as WAV bytes.
    # Indexing with None prepends an axis: shape=(1, T) for 1-D audio.
    waveform = torch.from_numpy(output.audio)[None, :]
    buf = io.BytesIO()
    torchaudio.save(buf, waveform, output.sampling_rate, format="wav")
    audio_bytes = buf.getvalue()

    audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
    return AudioResponse(audio_base64=audio_b64, sample_rate=output.sampling_rate)


# Added: serve static assets and point "/" at the static landing page.
app.mount("/static", StaticFiles(directory="static"), name="static")


@app.get("/", include_in_schema=False)
async def index():
    """Serve the static landing page (hidden from the OpenAPI schema)."""
    return FileResponse("static/index.html")