Upload handler.py

handler.py  +39 -6  CHANGED
@@ -4,9 +4,11 @@ from typing import Any, Dict
 from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKL, TorchAoConfig
 from PIL import Image
 import torch
-from torchao.quantization import quantize_, autoquant
+from torchao.quantization import quantize_, autoquant, int8_dynamic_activation_int8_weight
+from huggingface_hub import hf_hub_download
 
-IS_COMPILE =
+IS_COMPILE = False
+IS_TURBO = True
 
 if IS_COMPILE:
     import torch._dynamo
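The two new imports do separate jobs: hf_hub_download resolves the Hyper-SD LoRA file to a local cache path, and int8_dynamic_activation_int8_weight is the torchao quantization config that quantize_ applies in place in the loaders below. A minimal sketch of both calls in isolation (assumes huggingface_hub and torchao are installed and a CUDA device is available; the nn.Linear layer is only a stand-in for pipe.transformer and pipe.vae):

import torch
from huggingface_hub import hf_hub_download
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight

# hf_hub_download returns a local file path, downloading into the HF cache if needed.
lora_path = hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors")
print(lora_path)

# quantize_ rewrites the module's Linear weights in place: int8 weights with
# dynamically quantized int8 activations, the same config the handler applies
# to pipe.transformer and pipe.vae.
layer = torch.nn.Linear(64, 64, device="cuda", dtype=torch.bfloat16)
quantize_(layer, int8_dynamic_activation_int8_weight())
out = layer(torch.randn(1, 64, device="cuda", dtype=torch.bfloat16))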
@@ -49,13 +51,44 @@ def load_pipeline_autoquant(repo_id: str, dtype: torch.dtype) -> Any:
     pipe.to("cuda")
     return pipe
 
+def load_pipeline_turbo(repo_id: str, dtype: torch.dtype) -> Any:
+    pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=dtype).to("cuda")
+    pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd")
+    pipe.set_adapters(["hyper-sd"], adapter_weights=[0.125])
+    pipe.fuse_lora()
+    pipe.transformer.fuse_qkv_projections()
+    pipe.vae.fuse_qkv_projections()
+    quantize_(pipe.transformer, int8_dynamic_activation_int8_weight(), device="cuda")
+    quantize_(pipe.vae, int8_dynamic_activation_int8_weight(), device="cuda")
+    pipe.to("cuda")
+    return pipe
+
+def load_pipeline_turbo_compile(repo_id: str, dtype: torch.dtype) -> Any:
+    pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=dtype).to("cuda")
+    pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd")
+    pipe.set_adapters(["hyper-sd"], adapter_weights=[0.125])
+    pipe.fuse_lora()
+    pipe.transformer.fuse_qkv_projections()
+    pipe.vae.fuse_qkv_projections()
+    quantize_(pipe.transformer, int8_dynamic_activation_int8_weight(), device="cuda")
+    quantize_(pipe.vae, int8_dynamic_activation_int8_weight(), device="cuda")
+    pipe.transformer.to(memory_format=torch.channels_last)
+    pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=False, dynamic=False, backend="inductor")
+    pipe.vae.to(memory_format=torch.channels_last)
+    pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=False, dynamic=False, backend="inductor")
+    pipe.to("cuda")
+    return pipe
+
 class EndpointHandler:
     def __init__(self, path=""):
         repo_id = "camenduru/FLUX.1-dev-diffusers"
         dtype = torch.bfloat16
-
-
-
+        if IS_COMPILE:
+            if IS_TURBO: self.pipeline = load_pipeline_turbo_compile(repo_id, dtype)
+            else: self.pipeline = load_pipeline_compile(repo_id, dtype)
+        else:
+            if IS_TURBO: self.pipeline = load_pipeline_turbo(repo_id, dtype)
+            else: self.pipeline = load_pipeline_stable(repo_id, dtype)
 
     def __call__(self, data: Dict[str, Any]) -> Image.Image:
         logger.info(f"Received incoming request with {data=}")
@@ -72,7 +105,7 @@ class EndpointHandler:
 
         parameters = data.pop("parameters", {})
 
-        num_inference_steps = parameters.get("num_inference_steps", 28)
+        num_inference_steps = parameters.get("num_inference_steps", 8 if IS_TURBO else 28)
         width = parameters.get("width", 1024)
         height = parameters.get("height", 1024)
         guidance_scale = parameters.get("guidance_scale", 3.5)
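The only behavioural change in __call__ is the step-count default: with the 8-step Hyper-SD LoRA fused in, requests that omit num_inference_steps now fall back to 8 instead of 28, while an explicit value from the caller still wins. A plain-Python illustration of the fallback (no pipeline needed):

IS_TURBO = True
parameters = {}                                                                  # caller sent no parameters
parameters.get("num_inference_steps", 8 if IS_TURBO else 28)                     # -> 8
{"num_inference_steps": 20}.get("num_inference_steps", 8 if IS_TURBO else 28)    # -> 20, explicit value wins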