""" |
|
|
Lumina 2.0 ZeroGPU demo – Hugging Face Spaces |
|
|
Run locally with: python app.py |
|
|
Push to HF and pick the “ZeroGPU” hardware tier. |
|
|
""" |

import functools
import gc
import json
import os
import random
import time

import numpy as np
import torch
import gradio as gr
import spaces
from torchvision.transforms.functional import to_pil_image
from tqdm import tqdm
from diffusers.models import AutoencoderKL
from transformers import AutoModel, AutoTokenizer
from huggingface_hub import hf_hub_download

# Local modules from the Lumina codebase (must be present alongside app.py).
import models
from transport import Sampler, create_transport

REPO_ID = "OnomaAIResearch/Illustrious-Lumina-v0.03"
HF_TOKEN = os.getenv("HF_TOKEN")
CKPT_FILE = "consolidated_ema.00-of-01.pth"
ARGS_FILE = "model_args.pth"

VAE_TYPE = os.getenv("VAE_TYPE", "flux")  # only the FLUX VAE is wired up below
PRECISION = os.getenv("PRECISION", "bf16")
TEXT_ENCODER_MODEL = os.getenv("TEXT_ENCODER_MODEL", "google/gemma-2-2b")

DTYPE = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[PRECISION]

# Globals populated once by load_models() and kept on CPU between requests.
tokenizer = text_encoder = vae = model = sampler = transport = None
train_args = cap_feat_dim = None


def encode_prompt(batch, enc, tok, dev, dtype):
    """Temporarily moves the text encoder to *dev* to extract caption embeddings."""
    captions = [(c if isinstance(c, str) else c[0]) for c in batch]

    enc.to(dev)
    with torch.no_grad(), torch.autocast(device_type=dev.split(":")[0], dtype=dtype):
        inputs = tok(
            captions,
            padding=True,
            pad_to_multiple_of=8,
            max_length=256,
            truncation=True,
            return_tensors="pt",
        ).to(dev)
        # Use the penultimate hidden state as the caption features.
        out = enc(**inputs, output_hidden_states=True).hidden_states[-2]
    enc.to("cpu")
    gc.collect()
    return out.cpu(), inputs.attention_mask.cpu()


def none_or_str(v):
    return None if v in (None, "None") else str(v)


def load_models():
    """Downloads the checkpoint and builds every component on CPU."""
    global tokenizer, text_encoder, vae, model, sampler, transport, train_args, cap_feat_dim
    torch.set_grad_enabled(False)

    ckpt_path = hf_hub_download(REPO_ID, filename=CKPT_FILE, token=HF_TOKEN)
    args_path = hf_hub_download(REPO_ID, filename=ARGS_FILE, token=HF_TOKEN)

    train_args = torch.load(args_path, map_location="cpu", weights_only=False)

    # Text encoder: kept on CPU, moved to the GPU only while encoding prompts.
    tokenizer = AutoTokenizer.from_pretrained(TEXT_ENCODER_MODEL, token=HF_TOKEN, padding_side="right")
    text_encoder = AutoModel.from_pretrained(
        TEXT_ENCODER_MODEL, torch_dtype=DTYPE, token=HF_TOKEN
    ).eval().cpu()
    cap_feat_dim = text_encoder.config.hidden_size

    # FLUX.1-dev VAE (16 latent channels, 8x spatial downsampling).
    vae = AutoencoderKL.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="vae",
        token=HF_TOKEN,
        torch_dtype=DTYPE,
    ).eval().cpu()

    # Diffusion transformer; the class name is stored in the training args.
    dit_cls = getattr(models, train_args.model)
    model = dit_cls(in_channels=16, qk_norm=getattr(train_args, "qk_norm", True),
                    cap_feat_dim=cap_feat_dim).eval().cpu()

    state = torch.load(ckpt_path, map_location="cpu")
    state = {k[len("module."):] if k.startswith("module.") else k: v for k, v in state.items()}
    model.load_state_dict(state, strict=False)

    transport = create_transport("Linear", "velocity", None, None, None)
    sampler = Sampler(transport)


print("🔄 Loading models to CPU …")
load_models()
print("✅ Models loaded on CPU")


@spaces.GPU(duration=120)
def generate_image(
    prompt, negative_prompt, system_type, solver, resolution_str,
    guidance_scale, num_steps, seed, time_shifting_factor, t_shift,
    atol=1e-6, rtol=1e-3,
):
    """Runs every time a user clicks “Generate”.
    ZeroGPU attaches an A100 for the call; the GPU is released on return."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = DTYPE if device == "cuda" else torch.float32

    # Resolve the seed (-1 means "pick one at random") and seed every RNG.
    seed = random.randint(0, 2**32 - 1) if int(seed) == -1 else int(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Lumina-style system prompts; the selected one is prepended to the user prompt.
    sys_prompts = {
        "align": "You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts. <Prompt Start> ",
        "base": "You are an assistant designed to generate high-quality images based on user prompts. <Prompt Start> ",
        "aesthetics": "You are an assistant designed to generate high-quality images with highest degree of aesthetics based on user prompts. <Prompt Start> ",
        "real": "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts. <Prompt Start> ",
        "4grid": "You are an assistant designed to generate four high-quality images with highest degree of aesthetics arranged in 2x2 grids based on user prompts. <Prompt Start> ",
        "tags": "You are an assistant designed to generate high-quality images based on user prompts based on danbooru tags. <Prompt Start> ",
        "empty": "",
    }
    full_prompt = sys_prompts.get(system_type, sys_prompts["base"]) + prompt
    full_neg = (sys_prompts.get(system_type, "") + negative_prompt) if negative_prompt else ""

    # The VAE downsamples by 8x, so the latent grid is (H/8, W/8).
    w, h = map(int, resolution_str.split("x"))
    lat_w, lat_h = w // 8, h // 8

    # Encode the conditional and the unconditional (negative) prompt in one batch.
    cap_feats_cpu, cap_mask_cpu = encode_prompt(
        [full_prompt, full_neg], text_encoder, tokenizer, device, dtype
    )

    # Duplicate the noise so one batch carries the cond/uncond pair for CFG.
    model.to(device)
    z = torch.randn([1, 16, lat_h, lat_w], device=device, dtype=dtype).repeat(2, 1, 1, 1)
    model_kwargs = dict(
        cap_feats=cap_feats_cpu.to(device, dtype=dtype),
        cap_mask=cap_mask_cpu.to(device),
        cfg_scale=guidance_scale,
    )
    del cap_feats_cpu, cap_mask_cpu

    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=dtype):
        if solver == "dpm":
            _sampler = Sampler(create_transport("Linear", "velocity"))
            samples = _sampler.sample_dpm(
                model.forward_with_cfg, model_kwargs=model_kwargs
            )(z, steps=num_steps, order=2,
              skip_type="time_uniform_flow", method="multistep",
              flow_shift=time_shifting_factor)
        else:
            samples = sampler.sample_ode(
                sampling_method=solver, num_steps=num_steps,
                atol=atol, rtol=rtol, time_shifting_factor=t_shift,
            )(z, model.forward_with_cfg, **model_kwargs)[-1]
        samples = samples[:1]  # keep a single image from the CFG batch

    # Decode the latent, then offload everything back to CPU to free GPU memory.
    vae.to(device)
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=dtype):
        sf, sh = vae.config.scaling_factor, vae.config.shift_factor
        img = vae.decode(samples / sf + sh)[0]
    vae.to("cpu")
    model.to("cpu")
    torch.cuda.empty_cache()
    gc.collect()

    # Map from [-1, 1] to [0, 1] and convert to a PIL image.
    img = ((img.cpu() + 1) / 2).clamp(0, 1)
    pil = to_pil_image(img[0].float())
    return pil, f"Seed {seed}", seed
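
# A minimal sketch of driving the pipeline without the Gradio UI (argument values
# are illustrative; outside Spaces the @spaces.GPU decorator simply calls the
# function, so a local CUDA device is assumed):
#
#   image, status, used_seed = generate_image(
#       "1girl, solo, smile", "worst quality, blurry", "tags", "euler", "1024x1024",
#       guidance_scale=4.0, num_steps=50, seed=-1,
#       time_shifting_factor=1.0, t_shift=4,
#   )
#   image.save("sample.png")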


with gr.Blocks() as demo:
    gr.Markdown("# Lumina 2.0 (ZeroGPU)")

    with gr.Row():
        with gr.Column(scale=2):
            prompt = gr.Textbox(
                label="Prompt",
                value=r"1girl, kita ikuyo, bocchi the rock!, solo, backlighting, blurry, depth of field, bloom, light particles, transparent, blurry foreground, indoors, upper body, red hair, yellow flower, school uniform, white shirt, hair between eyes, lily \(flower\), floating hair, serafuku, chromatic aberration, white lily, green eyes, looking at viewer, red neckerchief, flower, pink flower, sunlight, day, neckerchief, sailor collar, grey sailor collar, white flower, lens flare abuse, medium hair, holding, closed mouth, one side up, long sleeves, holding bouquet, arms at sides, light smile, shirt, blurry background, bouquet",
            )
            negative_prompt = gr.Textbox(label="Negative prompt", value="mutated, worst quality, blurry, bad anatomy, bad hands")
            system_type = gr.Dropdown(choices=["align", "base", "aesthetics", "real", "4grid", "tags", "empty"],
                                      value="tags", label="System prompt")
            resolution = gr.Dropdown(choices=["1024x1024", "1280x768", "768x1280", "1536x1024", "1024x1536"],
                                     value="1024x1024", label="Resolution")
            solver = gr.Dropdown(choices=["dpm", "euler", "midpoint", "heun", "rk4"],
                                 value="euler", label="Solver")
            run_btn = gr.Button("Generate", variant="primary")
        with gr.Column(scale=1):
            guidance_scale = gr.Slider(1.0, 15.0, step=0.5, value=4.0, label="CFG scale")
            num_steps = gr.Slider(10, 200, step=1, value=50, label="Steps")
            seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
            time_shifting_factor = gr.Slider(0.0, 10.0, step=0.1, value=1.0, label="Time-shift (DPM)")
            t_shift = gr.Slider(0, 10, step=1, value=4, label="T-shift (ODE)")

    out_img = gr.Image(label="Output")
    out_txt = gr.Textbox(label="Status")
    out_seed = gr.Number(label="Seed used", interactive=False)

    run_btn.click(
        generate_image,
        inputs=[prompt, negative_prompt, system_type, solver, resolution,
                guidance_scale, num_steps, seed, time_shifting_factor, t_shift],
        outputs=[out_img, out_txt, out_seed],
        concurrency_limit=1,
    )


if __name__ == "__main__":
    demo.launch()