# Reward-Forcing / app_wip.py
# fffiloni — "use ZeroGPU when available" (commit adc21bf, verified)
import spaces
import os
import sys
import uuid
import shutil
import gradio as gr
import torch
from omegaconf import OmegaConf
from torchvision.io import write_video
from einops import rearrange
from huggingface_hub import snapshot_download
from pipeline import (
CausalDiffusionInferencePipeline,
CausalInferencePipeline,
)
from utils.dataset import TextDataset
from utils.misc import set_seed
from demo_utils.memory import get_cuda_free_memory_gb, DynamicSwapInstaller
# -------------------------------------------------------------------
# Download checkpoints once when the Space starts
# -------------------------------------------------------------------
# (repo_id, local_dir) pairs; each repo snapshot is fetched, in this order,
# at import time.
_CHECKPOINT_REPOS = (
    ("Wan-AI/Wan2.1-T2V-1.3B", "./checkpoints/Wan2.1-T2V-1.3B"),
    ("KlingTeam/VideoReward", "./checkpoints/Videoreward"),
    # NOTE(review): this mirrors the whole gdhe17/Self-Forcing repo into a
    # *directory* named "ode_init.pt". If only the ode_init.pt file is
    # needed, an allow_patterns filter would be cheaper — confirm.
    ("gdhe17/Self-Forcing", "./checkpoints/ode_init.pt"),
    ("JaydenLu666/Reward-Forcing-T2V-1.3B", "./checkpoints/Reward-Forcing-T2V-1.3B"),
)
for _repo_id, _local_dir in _CHECKPOINT_REPOS:
    snapshot_download(repo_id=_repo_id, local_dir=_local_dir)
# === Paths ===
CONFIG_PATH = "configs/reward_forcing.yaml"
CHECKPOINT_PATH = "checkpoints/Reward-Forcing-T2V-1.3B/rewardforcing.pt"
PROMPT_DIR = "prompts/gradio_inputs"  # per-request prompt .txt files
OUTPUT_ROOT = "videos"  # generated .mp4 files are written below this
for _dir in (PROMPT_DIR, OUTPUT_ROOT):
    os.makedirs(_dir, exist_ok=True)
def _select_generator_state_dict(checkpoint, use_ema: bool):
    """Return ``(state_dict, log_line)`` for the generator from a loaded checkpoint.

    If *checkpoint* is a dict holding ``"generator_ema"`` / ``"generator"``
    entries, the one requested by *use_ema* is returned (falling back to the
    other one with a warning). Otherwise the object is returned unchanged and
    treated as a bare generator state dict, which is how this demo has always
    loaded it; in that case *use_ema* cannot have any effect.
    """
    wanted = "generator_ema" if use_ema else "generator"
    other = "generator" if use_ema else "generator_ema"
    if isinstance(checkpoint, dict):
        if wanted in checkpoint:
            return checkpoint[wanted], f"Using '{wanted}' weights from checkpoint.\n"
        if other in checkpoint:
            return (
                checkpoint[other],
                f"[WARN] '{wanted}' not found in checkpoint; using '{other}' weights.\n",
            )
    return (
        checkpoint,
        "Checkpoint has no generator/generator_ema entries; "
        "loading it as-is (use_ema has no effect).\n",
    )


# Characters that must not appear in an output file name
# ("/" and "\\" are path separators; NUL is rejected by the OS).
_UNSAFE_FILENAME_CHARS = str.maketrans({c: "_" for c in "/\\\0\n\r\t"})


def _safe_filename(prompt: str, max_len: int = 50) -> str:
    """Build a file-system-safe base name from the first *max_len* chars of *prompt*."""
    name = prompt[:max_len].translate(_UNSAFE_FILENAME_CHARS)
    return name or "video"


def reward_forcing_inference(
    prompt_txt_path: str,
    num_output_frames: int,
    use_ema: bool,
    output_root: str,
    progress: gr.Progress,
):
    """
    Inline / simplified version of inference.py:
    - single GPU
    - text-to-video only
    - one .txt file = N prompts, but returns only the first generated video

    Args:
        prompt_txt_path: Text file handed to ``TextDataset``.
        num_output_frames: Length of the noise tensor's second axis
            (presumably latent frames — confirm against the pipeline).
        use_ema: Load the ``generator_ema`` weights instead of ``generator``
            when the checkpoint provides that split.
        output_root: Directory under which the per-run output folder is
            deleted and recreated.
        progress: Gradio progress tracker.

    Returns:
        ``(output_path, logs)`` for the first prompt's video, or
        ``(None, logs)`` if the dataset yielded no prompt.
    """
    logs = ""

    # --------------------- Device & randomness ---------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    set_seed(0)
    free_vram = get_cuda_free_memory_gb(device)
    logs += f"Free VRAM {free_vram} GB\n"
    # Below 40 GB free, the text encoder goes through DynamicSwapInstaller
    # instead of being moved to the device, and the pipeline is told so.
    low_memory = free_vram < 40
    # Process-wide switch: nothing here needs autograd.
    torch.set_grad_enabled(False)

    # --------------------- Phase 1: model & config init ---------------------
    progress(0.05, desc="Init: loading config")
    logs += "Loading config...\n"
    config = OmegaConf.load(CONFIG_PATH)
    default_config = OmegaConf.load("configs/default_config.yaml")
    # Values from CONFIG_PATH override the defaults.
    config = OmegaConf.merge(default_config, config)

    progress(0.15, desc="Init: creating pipeline")
    logs += "Creating pipeline...\n"
    # Configs defining denoising_step_list use CausalInferencePipeline;
    # everything else uses the diffusion variant.
    if hasattr(config, "denoising_step_list"):
        pipeline = CausalInferencePipeline(config, device=device)
    else:
        pipeline = CausalDiffusionInferencePipeline(config, device=device)

    progress(0.35, desc="Init: loading checkpoint")
    logs += "Loading checkpoint weights...\n"
    checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")
    state_dict, weights_log = _select_generator_state_dict(checkpoint, use_ema)
    logs += weights_log
    pipeline.generator.load_state_dict(state_dict)
    # Last "_"-separated token of the checkpoint's parent directory name;
    # only used to name the output folder.
    checkpoint_step = os.path.basename(os.path.dirname(CHECKPOINT_PATH)).split("_")[-1]

    progress(0.55, desc="Init: moving model to device")
    logs += "Moving model to device...\n"
    pipeline = pipeline.to(dtype=torch.bfloat16)
    if low_memory:
        DynamicSwapInstaller.install_model(pipeline.text_encoder, device=device)
    else:
        pipeline.text_encoder.to(device=device)
    pipeline.generator.to(device=device)
    pipeline.vae.to(device=device)

    # --------------------- Dataset setup ---------------------
    progress(0.65, desc="Preparing dataset")
    logs += "Preparing dataset (TextDataset)...\n"
    dataset = TextDataset(prompt_path=prompt_txt_path, extended_prompt_path=None)
    num_prompts = len(dataset)
    logs += f"Number of prompts: {num_prompts}\n"

    from torch.utils.data import DataLoader, SequentialSampler

    dataloader = DataLoader(
        dataset,
        batch_size=1,
        sampler=SequentialSampler(dataset),
        num_workers=0,
        drop_last=False,
    )

    # --------------------- Clean output folder ---------------------
    progress(0.7, desc="Cleaning output folder")
    output_folder = os.path.join(
        output_root, f"rewardforcing-{num_output_frames}f", checkpoint_step
    )
    # Deletes any previous results for this frame count / checkpoint.
    shutil.rmtree(output_folder, ignore_errors=True)
    os.makedirs(output_folder, exist_ok=True)
    logs += f"Output directory: {output_folder}\n"

    # --------------------- Phase 2: inference loop ---------------------
    for batch_data in progress.tqdm(
        dataloader,
        total=num_prompts,
        desc="Video generation",
        unit="prompt",
    ):
        # Unpack the batch *before* indexing into it: a list-wrapped batch
        # holds the dict as its first element.
        batch = batch_data[0] if isinstance(batch_data, list) else batch_data

        # TEXT-TO-VIDEO only
        prompt = batch["prompts"][0]
        extended_prompt = batch.get("extended_prompts", [None])[0]
        prompts = [extended_prompt] if extended_prompt else [prompt]

        sampled_noise = torch.randn(
            [1, num_output_frames, 16, 60, 104],
            device=device,
            dtype=torch.bfloat16,
        )
        logs += f"Generating for prompt: {prompt[:80]}...\n"

        # WAN2 inference
        video, _latents = pipeline.inference(
            noise=sampled_noise,
            text_prompts=prompts,
            return_latents=True,
            initial_latent=None,
            low_memory=low_memory,
        )
        # write_video expects channels-last frames: (T, H, W, C).
        video = rearrange(video, "b t c h w -> b t h w c").cpu()
        # Assumes decoded pixels lie in [0, 1] (TODO confirm in pipeline).
        # Clamp before the cast so out-of-range values cannot wrap.
        frames = (255.0 * video.float()).clamp(0.0, 255.0).to(torch.uint8)
        pipeline.vae.model.clear_cache()

        output_path = os.path.join(output_folder, f"{_safe_filename(prompt)}.mp4")
        write_video(output_path, frames[0], fps=16)
        logs += f"Saved video: {output_path}\n"
        progress(1.0, desc="Done")
        # Only the first prompt is rendered (see docstring).
        return output_path, logs

    logs += "[WARN] No video generated.\n"
    return None, logs
@spaces.GPU(duration=200)
def gradio_generate(
    prompt: str, duration: str, use_ema: bool, progress=gr.Progress(track_tqdm=True)
):
    """
    Triggered by Gradio:
    - writes prompt to a .txt file
    - performs inference
    - returns video + logs

    Args:
        prompt: Text from the prompt textbox (may span several lines).
        duration: Label chosen in the duration radio. "5s (21 frames)" maps
            to 21 frames; any other value maps to 120.
        use_ema: Forwarded to ``reward_forcing_inference``.
        progress: Gradio progress tracker (tqdm bars are tracked).

    Returns:
        ``(video_path, logs)``.

    Raises:
        gr.Error: If the prompt is empty, or if no video file was produced.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a text prompt 🙂")

    # Duration → number of frames
    num_output_frames = 21 if duration == "5s (21 frames)" else 120

    # The prompt file holds N prompts and only the first one is rendered
    # (see reward_forcing_inference), so a multi-line textbox entry is
    # folded into a single line instead of being silently truncated.
    single_line_prompt = " ".join(
        line.strip() for line in prompt.splitlines() if line.strip()
    )

    os.makedirs(PROMPT_DIR, exist_ok=True)
    prompt_id = uuid.uuid4().hex[:8]
    prompt_path = os.path.join(PROMPT_DIR, f"prompt_{prompt_id}.txt")
    with open(prompt_path, "w", encoding="utf-8") as f:
        f.write(single_line_prompt + "\n")

    try:
        video_path, logs = reward_forcing_inference(
            prompt_txt_path=prompt_path,
            num_output_frames=num_output_frames,
            use_ema=use_ema,
            output_root=OUTPUT_ROOT,
            progress=progress,
        )
    finally:
        # Nothing reads the prompt file after inference returns. Remove it so
        # per-request files don't accumulate. This is best-effort cleanup, so
        # a failed delete is deliberately ignored.
        try:
            os.remove(prompt_path)
        except OSError:
            pass

    if video_path is None or not os.path.exists(video_path):
        raise gr.Error("No video generated. Check logs for details.")
    return video_path, logs
# -------------------------------------------------------------------
# Gradio UI — updated title + intro text
# -------------------------------------------------------------------
with gr.Blocks(title="Reward Forcing — Text-to-Video Demo") as demo:
    # The markdown text is kept at column 0 so no line can turn into an
    # indented code block.
    gr.Markdown(
        """
# 🎬 Reward Forcing — Text-to-Video Demo

Generate short videos from text prompts using a model trained with the **Reward Forcing** method.

Reward Forcing is a recent research technique that improves how well a video model follows a written description
by guiding training with learned reward signals. You can learn more here:
https://reward-forcing.github.io

👉 Type a prompt, click **Generate**, and the video will appear below.
Longer and more detailed prompts usually produce better results.

> ⏳ The first run may take a little longer while the model loads — generation is faster afterwards.
"""
    )
    with gr.Row():
        prompt_in = gr.Textbox(
            label="Prompt",
            placeholder="A cinematic shot of late-summer wheat fields moving in the wind...",
            lines=4,
        )
    with gr.Row():
        # These labels must match the string checked in gradio_generate.
        duration = gr.Radio(
            ["5s (21 frames)", "30s (120 frames)"],
            value="5s (21 frames)",
            label="Duration",
        )
        use_ema = gr.Checkbox(value=True, label="Use EMA weights (--use_ema)")
    generate_btn = gr.Button("🚀 Generate Video", variant="primary")
    with gr.Row():
        video_out = gr.Video(label="Generated Video")
        logs_out = gr.Textbox(label="Logs", lines=12, interactive=False)
    generate_btn.click(
        fn=gradio_generate,
        inputs=[prompt_in, duration, use_ema],
        outputs=[video_out, logs_out],
    )

demo.queue()

if __name__ == "__main__":
    demo.launch()