Commit b648253
Parent(s): d3f303b
commit 9

main.py CHANGED
@@ -8,58 +8,63 @@ import torch
 import io
 import os
 
-#
-app = FastAPI()
-
-# Login to Hugging Face with secret token (set in HF Spaces secrets)
+# Get token from environment (set in Hugging Face Spaces Secrets)
 HF_TOKEN = os.environ.get("HF_TOKEN", "")
 if not HF_TOKEN:
     raise ValueError("HF_TOKEN environment variable not set")
+
+# Login to Hugging Face
 login(token=HF_TOKEN)
 
-#
+# Initialize FastAPI app
+app = FastAPI()
+
+# Dynamically select device (CPU or GPU)
 device = "cuda" if torch.cuda.is_available() else "cpu"
 print(f"Using device: {device}")
 
-# Load Stable Diffusion pipeline
+# Load Stable Diffusion pipeline with dynamic device and half precision on GPU
 pipe = StableDiffusionPipeline.from_pretrained(
     "CompVis/stable-diffusion-v1-4",
     torch_dtype=torch.float16 if device == "cuda" else torch.float32,
-    use_safetensors=True,
+    use_safetensors=True,  # Use safetensors for efficiency
 ).to(device)
 
-#
+# Enable memory optimizations if GPU is available
 if device == "cuda":
-    pipe.enable_attention_slicing()
-    try:
-        pipe.enable_xformers_memory_efficient_attention()
-    except Exception:
-        pass  # Ignore if xformers not installed
+    pipe.enable_attention_slicing()  # Optimize memory usage
+    pipe.enable_xformers_memory_efficient_attention()  # Optional, if xformers is installed
 
-#
+# Input schema
 class PromptRequest(BaseModel):
     prompt: str
     num_inference_steps: int = 20
     guidance_scale: float = 7.5
 
-#
+# Image generation endpoint
 @app.post("/generate-image/")
 def generate_image(req: PromptRequest):
     try:
+        # Generate image
        image = pipe(
             prompt=req.prompt,
             num_inference_steps=req.num_inference_steps,
             guidance_scale=req.guidance_scale
         ).images[0]
 
+        # Convert to bytes
         img_bytes = io.BytesIO()
         image.save(img_bytes, format="PNG")
         img_bytes.seek(0)
+
         return StreamingResponse(img_bytes, media_type="image/png")
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}")
 
 # Health check
 @app.get("/health/")
 def health_check():
     return {"status": "healthy"}
+
+# Note: uvicorn.run is handled by Hugging Face Spaces automatically
+# No need to run the server manually here
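
The Space now exposes a JSON POST endpoint that streams back a PNG. A minimal client sketch follows; the base URL is hypothetical, and the request body simply mirrors the PromptRequest schema above:

    import requests

    BASE_URL = "https://your-space.hf.space"  # hypothetical; use your Space's actual URL

    # POST a PromptRequest body; the server replies with raw PNG bytes
    resp = requests.post(
        f"{BASE_URL}/generate-image/",
        json={
            "prompt": "a watercolor lighthouse at dawn",
            "num_inference_steps": 20,
            "guidance_scale": 7.5,
        },
        timeout=600,  # generation can take minutes on CPU
    )
    resp.raise_for_status()

    # Save the PNG bytes from the StreamingResponse
    with open("output.png", "wb") as f:
        f.write(resp.content)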
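
As the closing note says, Spaces starts the server itself, so main.py does not call uvicorn.run. For local testing, a typical invocation would be (assuming the file is named main.py; port 7860 matches the Spaces convention, but any free port works):

    uvicorn main:app --host 0.0.0.0 --port 7860

Once the server is up, the health endpoint gives a quick smoke test: GET /health/ (for example, curl http://localhost:7860/health/) should return {"status": "healthy"}.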