English
John6666 commited on
Commit
cf00167
·
verified ·
1 Parent(s): 8b97db1

Upload 2 files

Browse files
Files changed (2) hide show
  1. handler.py +63 -0
  2. requirements.txt +12 -0
handler.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Any, Dict
3
+
4
+ from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKL, TorchAoConfig
5
+ from PIL.Image import Image
6
+ import torch
7
+
8
+ import torch._dynamo
9
+ torch._dynamo.config.suppress_errors = True
10
+
11
+ #from huggingface_inference_toolkit.logging import logger
12
+
13
+ def compile_pipeline(pipe):
14
+ pipe.transformer.to(memory_format=torch.channels_last)
15
+ pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=False, dynamic=False, backend="inductor")
16
+ return pipe
17
+
18
+ class EndpointHandler:
19
+ def __init__(self, **kwargs: Any) -> None: # type: ignore
20
+ is_compile = False
21
+ #repo_id = "camenduru/FLUX.1-dev-diffusers"
22
+ repo_id = "NoMoreCopyright/FLUX.1-dev-test"
23
+ dtype = torch.bfloat16
24
+ quantization_config = TorchAoConfig("int4dq")
25
+ vae = AutoencoderKL.from_pretrained(repo_id, subfolder="vae", torch_dtype=dtype)
26
+ #transformer = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", torch_dtype=dtype).to("cuda")
27
+ self.pipeline = FluxPipeline.from_pretrained(repo_id, vae=vae, torch_dtype=dtype, quantization_config=quantization_config)
28
+ if is_compile: self.pipeline = compile_pipeline(self.pipeline)
29
+ self.pipeline.to("cuda")
30
+
31
+ @torch.inference_mode()
32
+ def __call__(self, data: Dict[str, Any]) -> Image:
33
+ #logger.info(f"Received incoming request with {data=}")
34
+
35
+ if "inputs" in data and isinstance(data["inputs"], str):
36
+ prompt = data.pop("inputs")
37
+ elif "prompt" in data and isinstance(data["prompt"], str):
38
+ prompt = data.pop("prompt")
39
+ else:
40
+ raise ValueError(
41
+ "Provided input body must contain either the key `inputs` or `prompt` with the"
42
+ " prompt to use for the image generation, and it needs to be a non-empty string."
43
+ )
44
+
45
+ parameters = data.pop("parameters", {})
46
+
47
+ num_inference_steps = parameters.get("num_inference_steps", 30)
48
+ width = parameters.get("width", 1024)
49
+ height = parameters.get("height", 768)
50
+ guidance_scale = parameters.get("guidance_scale", 3.5)
51
+
52
+ # seed generator (seed cannot be provided as is but via a generator)
53
+ seed = parameters.get("seed", 0)
54
+ generator = torch.manual_seed(seed)
55
+
56
+ return self.pipeline( # type: ignore
57
+ prompt,
58
+ height=height,
59
+ width=width,
60
+ guidance_scale=guidance_scale,
61
+ num_inference_steps=num_inference_steps,
62
+ generator=generator,
63
+ ).images[0]
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ huggingface_hub
2
+ torch
3
+ torchvision
4
+ torchao
5
+ diffusers
6
+ peft
7
+ accelerate
8
+ transformers
9
+ numpy
10
+ scipy
11
+ Pillow
12
+ triton