import gradio as gr
from transformers import pipeline
import os
import base64
import tempfile
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
import uvicorn

# --- Load Model ---
try:
    classifier = pipeline("audio-classification", model="superb/wav2vec2-base-superb-er")
except Exception as e:
    classifier = None
    model_load_error = str(e)
else:
    model_load_error = None

# --- FastAPI App for a dedicated, robust API ---
app = FastAPI()


@app.post("/api/predict/")
async def predict_emotion_api(request: Request):
    if classifier is None:
        return JSONResponse(content={"error": f"Model is not loaded: {model_load_error}"}, status_code=503)

    try:
        body = await request.json()

        # The JS FileReader sends a string like "data:audio/wav;base64,AABBCC..."
        base64_with_prefix = body.get("data")
        if not base64_with_prefix:
            return JSONResponse(content={"error": "Missing 'data' field in request body."}, status_code=400)

        # Strip the "data:...;base64," prefix if present, then decode the pure base64 data
        try:
            if "," in base64_with_prefix:
                _header, encoded = base64_with_prefix.split(",", 1)
            else:
                encoded = base64_with_prefix
            audio_data = base64.b64decode(encoded)
        except (ValueError, TypeError):
            return JSONResponse(content={"error": "Invalid base64 data format."}, status_code=400)

        # Write to a temporary file for the pipeline
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
            temp_file.write(audio_data)
            temp_audio_path = temp_file.name

        try:
            results = classifier(temp_audio_path, top_k=5)
        finally:
            os.unlink(temp_audio_path)  # Clean up the temp file even if inference fails

        # Return a successful response
        return JSONResponse(content={"data": results})

    except Exception as e:
        return JSONResponse(content={"error": f"Internal server error during prediction: {str(e)}"}, status_code=500)


# --- Gradio UI function (optional, for the direct Space page) ---
def gradio_predict_wrapper(audio_file):
    # This is just for the UI on the Hugging Face page itself
    if classifier is None:
        raise gr.Error(f"Model is not loaded: {model_load_error}")
    if audio_file is None:
        raise gr.Error("Please provide an audio file.")
    results = classifier(audio_file, top_k=5)
    return {item["label"]: round(item["score"], 3) for item in results}


gradio_interface = gr.Interface(
    fn=gradio_predict_wrapper,
    inputs=gr.Audio(sources=["microphone", "upload"], type="filepath", label="Upload Audio or Record"),
    outputs=gr.Label(num_top_classes=5, label="Emotion Predictions"),
    title="Audio Emotion Detector",
    description="This UI is for direct demonstration. The primary API is at /api/predict/",
    allow_flagging="never",
)

# --- Mount the Gradio UI onto the FastAPI app ---
# The API at /api/predict/ will work even if the UI is at a different path.
app = gr.mount_gradio_app(app, gradio_interface, path="/ui")

if __name__ == "__main__":
    # Allow running directly with `python app.py` (equivalent to `uvicorn app:app`);
    # 7860 is the default Hugging Face Spaces port.
    uvicorn.run(app, host="0.0.0.0", port=7860)
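
# --- Example client (illustrative sketch, not used by the app itself) ---
# A minimal sketch of how a caller could hit the /api/predict/ endpoint defined
# above. The base URL and WAV path below are hypothetical placeholders; the
# payload mirrors the "data:audio/wav;base64,..." string a JS FileReader produces.
def example_client_request(base_url="http://localhost:7860", wav_path="sample.wav"):
    import json
    import urllib.request

    # Build the data-URL-style payload the endpoint expects
    with open(wav_path, "rb") as f:
        payload = "data:audio/wav;base64," + base64.b64encode(f.read()).decode("ascii")

    req = urllib.request.Request(
        f"{base_url}/api/predict/",
        data=json.dumps({"data": payload}).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        # Returns {"data": [...]} on success or {"error": "..."} on failure
        return json.loads(resp.read().decode("utf-8"))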