"""FastAPI service exposing a Stable Diffusion v1.4 text-to-image endpoint."""

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from diffusers import StableDiffusionPipeline
from huggingface_hub import login
from PIL import Image
import torch
import io
import os

# Initialize FastAPI app
app = FastAPI(title="Stable Diffusion API")

# Add CORS middleware for frontend
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Get token from environment; fail fast at startup rather than on first request.
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN environment variable not set")

# Login to Hugging Face
login(token=HF_TOKEN)

# Dynamically select device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load Stable Diffusion pipeline once at startup.
# float16 only on GPU — CPU float16 inference is unsupported/extremely slow here.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    use_safetensors=True,
).to(device)

# Memory optimizations (GPU path only, as originally deployed).
if device == "cuda":
    pipe.enable_attention_slicing()
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception:
        # FIX: was a bare `except:`, which also swallowed KeyboardInterrupt
        # and SystemExit. xformers is an optional dependency, so catching
        # Exception is sufficient for "not installed / not supported".
        pass


# Input schema
class PromptRequest(BaseModel):
    """Request body for /generate-image/.

    FIX: field bounds reject nonsensical values (e.g. 0 inference steps,
    empty prompt) with a 422 validation error instead of letting the
    pipeline crash and surface as an opaque 500. Defaults are unchanged.
    """

    prompt: str = Field(..., min_length=1)
    num_inference_steps: int = Field(20, ge=1, le=150)
    guidance_scale: float = Field(7.5, ge=0.0, le=30.0)


# Image generation endpoint
@app.post("/generate-image/")
def generate_image(req: PromptRequest):
    """Run the diffusion pipeline for the prompt and stream back a PNG.

    Returns a StreamingResponse with media type image/png.
    Raises HTTPException(500) if generation fails for any reason.
    """
    try:
        # FIX: only touch the CUDA allocator when actually on a GPU;
        # previously this ran unconditionally on CPU-only hosts too.
        if device == "cuda":
            torch.cuda.empty_cache()  # Clear GPU memory between requests
        image = pipe(
            prompt=req.prompt,
            num_inference_steps=req.num_inference_steps,
            guidance_scale=req.guidance_scale,
        ).images[0]
        img_bytes = io.BytesIO()
        image.save(img_bytes, format="PNG")
        img_bytes.seek(0)
        return StreamingResponse(img_bytes, media_type="image/png")
    except Exception as e:
        # FIX: chain the cause so the original traceback is preserved in logs.
        raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}") from e


# Health check
@app.get("/health/")
def health_check():
    """Liveness probe: static healthy status, no model interaction."""
    return {"status": "healthy"}