English
John6666 commited on
Commit
35bc30c
·
verified ·
1 Parent(s): f7b87cf

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +39 -6
handler.py CHANGED
@@ -4,9 +4,11 @@ from typing import Any, Dict
4
  from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKL, TorchAoConfig
5
  from PIL import Image
6
  import torch
7
- from torchao.quantization import quantize_, autoquant
 
8
 
9
- IS_COMPILE = True
 
10
 
11
  if IS_COMPILE:
12
  import torch._dynamo
@@ -49,13 +51,44 @@ def load_pipeline_autoquant(repo_id: str, dtype: torch.dtype) -> Any:
49
  pipe.to("cuda")
50
  return pipe
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  class EndpointHandler:
53
  def __init__(self, path=""):
54
  repo_id = "camenduru/FLUX.1-dev-diffusers"
55
  dtype = torch.bfloat16
56
- self.pipeline = load_pipeline_autoquant(repo_id, dtype)
57
- #if IS_COMPILE: self.pipeline = load_pipeline_compile(repo_id, dtype)
58
- #else: self.pipeline = load_pipeline_stable(repo_id, dtype)
 
 
 
59
 
60
  def __call__(self, data: Dict[str, Any]) -> Image.Image:
61
  logger.info(f"Received incoming request with {data=}")
@@ -72,7 +105,7 @@ class EndpointHandler:
72
 
73
  parameters = data.pop("parameters", {})
74
 
75
- num_inference_steps = parameters.get("num_inference_steps", 28)
76
  width = parameters.get("width", 1024)
77
  height = parameters.get("height", 1024)
78
  guidance_scale = parameters.get("guidance_scale", 3.5)
 
4
  from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKL, TorchAoConfig
5
  from PIL import Image
6
  import torch
7
+ from torchao.quantization import quantize_, autoquant, int8_dynamic_activation_int8_weight
8
+ from huggingface_hub import hf_hub_download
9
 
10
+ IS_COMPILE = False
11
+ IS_TURBO = True
12
 
13
  if IS_COMPILE:
14
  import torch._dynamo
 
51
  pipe.to("cuda")
52
  return pipe
53
 
54
+ def load_pipeline_turbo(repo_id: str, dtype: torch.dtype) -> Any:
55
+ pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=dtype).to("cuda")
56
+ pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd")
57
+ pipe.set_adapters(["hyper-sd"], adapter_weights=[0.125])
58
+ pipe.fuse_lora()
59
+ pipe.transformer.fuse_qkv_projections()
60
+ pipe.vae.fuse_qkv_projections()
61
+ quantize_(pipe.transformer, int8_dynamic_activation_int8_weight(), device="cuda")
62
+ quantize_(pipe.vae, int8_dynamic_activation_int8_weight(), device="cuda")
63
+ pipe.to("cuda")
64
+ return pipe
65
+
66
+ def load_pipeline_turbo_compile(repo_id: str, dtype: torch.dtype) -> Any:
67
+ pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=dtype).to("cuda")
68
+ pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd")
69
+ pipe.set_adapters(["hyper-sd"], adapter_weights=[0.125])
70
+ pipe.fuse_lora()
71
+ pipe.transformer.fuse_qkv_projections()
72
+ pipe.vae.fuse_qkv_projections()
73
+ quantize_(pipe.transformer, int8_dynamic_activation_int8_weight(), device="cuda")
74
+ quantize_(pipe.vae, int8_dynamic_activation_int8_weight(), device="cuda")
75
+ pipe.transformer.to(memory_format=torch.channels_last)
76
+ pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=False, dynamic=False, backend="inductor")
77
+ pipe.vae.to(memory_format=torch.channels_last)
78
+ pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=False, dynamic=False, backend="inductor")
79
+ pipe.to("cuda")
80
+ return pipe
81
+
82
  class EndpointHandler:
83
  def __init__(self, path=""):
84
  repo_id = "camenduru/FLUX.1-dev-diffusers"
85
  dtype = torch.bfloat16
86
+ if IS_COMPILE:
87
+ if IS_TURBO: self.pipeline = load_pipeline_turbo_compile(repo_id, dtype)
88
+ else: self.pipeline = load_pipeline_compile(repo_id, dtype)
89
+ else:
90
+ if IS_TURBO: self.pipeline = load_pipeline_turbo(repo_id, dtype)
91
+ else: self.pipeline = load_pipeline_stable(repo_id, dtype)
92
 
93
  def __call__(self, data: Dict[str, Any]) -> Image.Image:
94
  logger.info(f"Received incoming request with {data=}")
 
105
 
106
  parameters = data.pop("parameters", {})
107
 
108
+ num_inference_steps = parameters.get("num_inference_steps", 8 if IS_TURBO else 28)
109
  width = parameters.get("width", 1024)
110
  height = parameters.get("height", 1024)
111
  guidance_scale = parameters.get("guidance_scale", 3.5)