Abdurrehman015 commited on
Commit
b648253
·
1 Parent(s): d3f303b
Files changed (1) hide show
  1. main.py +24 -19
main.py CHANGED
@@ -8,58 +8,63 @@ import torch
8
  import io
9
  import os
10
 
11
- # Initialize FastAPI app ( First, before anything else)
12
- app = FastAPI()
13
-
14
- # Login to Hugging Face with secret token (set in HF Spaces secrets)
15
  HF_TOKEN = os.environ.get("HF_TOKEN", "")
16
  if not HF_TOKEN:
17
  raise ValueError("HF_TOKEN environment variable not set")
 
 
18
  login(token=HF_TOKEN)
19
 
20
- # Select device
 
 
 
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
  print(f"Using device: {device}")
23
 
24
- # Load Stable Diffusion pipeline
25
  pipe = StableDiffusionPipeline.from_pretrained(
26
  "CompVis/stable-diffusion-v1-4",
27
  torch_dtype=torch.float16 if device == "cuda" else torch.float32,
28
- use_safetensors=True,
29
  ).to(device)
30
 
31
- # Optimize if on GPU
32
  if device == "cuda":
33
- pipe.enable_attention_slicing()
34
- try:
35
- pipe.enable_xformers_memory_efficient_attention()
36
- except Exception:
37
- pass # Ignore if xformers not installed
38
 
39
- # Request model
40
  class PromptRequest(BaseModel):
41
  prompt: str
42
  num_inference_steps: int = 20
43
  guidance_scale: float = 7.5
44
 
45
- # Generation route
46
  @app.post("/generate-image/")
47
  def generate_image(req: PromptRequest):
48
  try:
49
- result = pipe(
 
50
  prompt=req.prompt,
51
  num_inference_steps=req.num_inference_steps,
52
  guidance_scale=req.guidance_scale
53
  ).images[0]
54
 
 
55
  img_bytes = io.BytesIO()
56
- result.save(img_bytes, format="PNG")
57
  img_bytes.seek(0)
 
58
  return StreamingResponse(img_bytes, media_type="image/png")
59
  except Exception as e:
60
- raise HTTPException(status_code=500, detail=f"Error generating image: {str(e)}")
61
 
62
- # Health check route
63
  @app.get("/health/")
64
  def health_check():
65
  return {"status": "healthy"}
 
 
 
 
8
  import io
9
  import os
10
 
11
+ # Get token from environment (set in Hugging Face Spaces Secrets)
 
 
 
12
  HF_TOKEN = os.environ.get("HF_TOKEN", "")
13
  if not HF_TOKEN:
14
  raise ValueError("HF_TOKEN environment variable not set")
15
+
16
+ # Login to Hugging Face
17
  login(token=HF_TOKEN)
18
 
19
+ # Initialize FastAPI app
20
+ app = FastAPI()
21
+
22
+ # Dynamically select device (CPU or GPU)
23
  device = "cuda" if torch.cuda.is_available() else "cpu"
24
  print(f"Using device: {device}")
25
 
26
+ # Load Stable Diffusion pipeline with dynamic device and mixed precision
27
  pipe = StableDiffusionPipeline.from_pretrained(
28
  "CompVis/stable-diffusion-v1-4",
29
  torch_dtype=torch.float16 if device == "cuda" else torch.float32,
30
+ use_safetensors=True, # Use safetensors for efficiency
31
  ).to(device)
32
 
33
+ # Enable mixed precision training if GPU is available
34
  if device == "cuda":
35
+ pipe.enable_attention_slicing() # Optimize memory usage
36
+ pipe.enable_xformers_memory_efficient_attention() # Optional, if xformers is installed
 
 
 
37
 
38
+ # Input schema
39
  class PromptRequest(BaseModel):
40
  prompt: str
41
  num_inference_steps: int = 20
42
  guidance_scale: float = 7.5
43
 
44
+ # Image generation endpoint
45
  @app.post("/generate-image/")
46
  def generate_image(req: PromptRequest):
47
  try:
48
+ # Generate image
49
+ image = pipe(
50
  prompt=req.prompt,
51
  num_inference_steps=req.num_inference_steps,
52
  guidance_scale=req.guidance_scale
53
  ).images[0]
54
 
55
+ # Convert to bytes
56
  img_bytes = io.BytesIO()
57
+ image.save(img_bytes, format="PNG")
58
  img_bytes.seek(0)
59
+
60
  return StreamingResponse(img_bytes, media_type="image/png")
61
  except Exception as e:
62
+ raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}")
63
 
64
+ # Health check
65
  @app.get("/health/")
66
  def health_check():
67
  return {"status": "healthy"}
68
+
69
+ # Note: uvicorn.run is handled by Hugging Face Spaces automatically
70
+ # No need to run the server manually here