import spaces
import os
import uuid
import shutil

import gradio as gr
import torch
from omegaconf import OmegaConf
from torch.utils.data import DataLoader, SequentialSampler
from torchvision.io import write_video
from einops import rearrange
from huggingface_hub import snapshot_download

from pipeline import (
    CausalDiffusionInferencePipeline,
    CausalInferencePipeline,
)
from utils.dataset import TextDataset
from utils.misc import set_seed
from demo_utils.memory import get_cuda_free_memory_gb, DynamicSwapInstaller

# -------------------------------------------------------------------
# Download checkpoints once when the Space starts
# -------------------------------------------------------------------
snapshot_download(
    repo_id="Wan-AI/Wan2.1-T2V-1.3B",
    local_dir="./checkpoints/Wan2.1-T2V-1.3B",
)
snapshot_download(
    repo_id="KlingTeam/VideoReward",
    local_dir="./checkpoints/Videoreward",
)
snapshot_download(
    repo_id="gdhe17/Self-Forcing",
    local_dir="./checkpoints/ode_init.pt",
)
snapshot_download(
    repo_id="JaydenLu666/Reward-Forcing-T2V-1.3B",
    local_dir="./checkpoints/Reward-Forcing-T2V-1.3B",
)

# === Paths ===
CONFIG_PATH = "configs/reward_forcing.yaml"
CHECKPOINT_PATH = "checkpoints/Reward-Forcing-T2V-1.3B/rewardforcing.pt"
PROMPT_DIR = "prompts/gradio_inputs"
OUTPUT_ROOT = "videos"

os.makedirs(PROMPT_DIR, exist_ok=True)
os.makedirs(OUTPUT_ROOT, exist_ok=True)


def reward_forcing_inference(
    prompt_txt_path: str,
    num_output_frames: int,
    use_ema: bool,
    output_root: str,
    progress: gr.Progress,
):
    """
    Inline / simplified version of inference.py:
      - single GPU
      - text-to-video only
      - one .txt file = N prompts, but returns only the first generated video
    """
    logs = ""

    # --------------------- Device & randomness ---------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    set_seed(0)

    free_vram = get_cuda_free_memory_gb(device)
    logs += f"Free VRAM: {free_vram:.1f} GB\n"
    low_memory = free_vram < 40

    torch.set_grad_enabled(False)

    # --------------------- Phase 1: model & config init ---------------------
    progress(0.05, desc="Init: loading config")
    logs += "Loading config...\n"
    config = OmegaConf.load(CONFIG_PATH)
    default_config = OmegaConf.load("configs/default_config.yaml")
    config = OmegaConf.merge(default_config, config)

    progress(0.15, desc="Init: creating pipeline")
    logs += "Creating pipeline...\n"
    # A denoising step list indicates the few-step causal pipeline;
    # otherwise fall back to the full diffusion pipeline.
    if hasattr(config, "denoising_step_list"):
        pipeline = CausalInferencePipeline(config, device=device)
    else:
        pipeline = CausalDiffusionInferencePipeline(config, device=device)

    progress(0.35, desc="Init: loading checkpoint")
    logs += "Loading checkpoint weights...\n"
    state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu")
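    # NOTE: `use_ema` currently has no effect at this point, since
    # rewardforcing.pt ships a single flat generator state dict. A minimal
    # sketch, assuming a checkpoint that instead stores paired
    # "generator" / "generator_ema" sub-dicts (hypothetical key names,
    # mirroring the Self-Forcing reference code); it is a no-op for the
    # flat checkpoint used here:
    if isinstance(state_dict, dict) and "generator_ema" in state_dict:
        state_dict = state_dict["generator_ema" if use_ema else "generator"]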
    pipeline.generator.load_state_dict(state_dict)

    # Derive a step tag from the checkpoint's parent directory name.
    checkpoint_step = os.path.basename(os.path.dirname(CHECKPOINT_PATH))
    checkpoint_step = checkpoint_step.split("_")[-1]

    progress(0.55, desc="Init: moving model to device")
    logs += "Moving model to device...\n"
    pipeline = pipeline.to(dtype=torch.bfloat16)
    if low_memory:
        DynamicSwapInstaller.install_model(pipeline.text_encoder, device=device)
    else:
        pipeline.text_encoder.to(device=device)
    pipeline.generator.to(device=device)
    pipeline.vae.to(device=device)

    # --------------------- Dataset setup ---------------------
    progress(0.65, desc="Preparing dataset")
    logs += "Preparing dataset (TextDataset)...\n"
    dataset = TextDataset(prompt_path=prompt_txt_path, extended_prompt_path=None)
    num_prompts = len(dataset)
    logs += f"Number of prompts: {num_prompts}\n"

    sampler = SequentialSampler(dataset)
    dataloader = DataLoader(
        dataset, batch_size=1, sampler=sampler, num_workers=0, drop_last=False
    )

    # --------------------- Clean output folder ---------------------
    progress(0.7, desc="Cleaning output folder")
    output_folder = os.path.join(
        output_root, f"rewardforcing-{num_output_frames}f", checkpoint_step
    )
    shutil.rmtree(output_folder, ignore_errors=True)
    os.makedirs(output_folder, exist_ok=True)
    logs += f"Output directory: {output_folder}\n"

    # --------------------- Phase 2: inference loop ---------------------
    for i, batch_data in progress.tqdm(
        enumerate(dataloader),
        total=num_prompts,
        desc="Video generation",
        unit="prompt",
    ):
        # Unpack the dataset batch before touching its fields.
        if isinstance(batch_data, dict):
            batch = batch_data
        elif isinstance(batch_data, list):
            batch = batch_data[0]
        else:
            batch = batch_data

        idx = batch["idx"].item()

        all_video = []

        # TEXT-TO-VIDEO only
        prompt = batch["prompts"][0]
        extended_prompt = batch.get("extended_prompts", [None])[0]
        prompts = [extended_prompt] if extended_prompt else [prompt]
        initial_latent = None

        # Latent noise: [batch, latent frames, channels, height, width];
        # 60 x 104 latents correspond to 480 x 832 output pixels.
        sampled_noise = torch.randn(
            [1, num_output_frames, 16, 60, 104],
            device=device,
            dtype=torch.bfloat16,
        )

        logs += f"Generating for prompt: {prompt[:80]}...\n"

        # WAN2 inference
        video, latents = pipeline.inference(
            noise=sampled_noise,
            text_prompts=prompts,
            return_latents=True,
            initial_latent=initial_latent,
            low_memory=low_memory,
        )
        current_video = rearrange(video, "b t c h w -> b t h w c").cpu()
        all_video.append(current_video)

        # Scale [0, 1] floats to the uint8 frames that write_video expects.
        video = (255.0 * torch.cat(all_video, dim=1)).clamp(0, 255).to(torch.uint8)

        pipeline.vae.model.clear_cache()

        if idx < num_prompts:
            model = "ema" if use_ema else "regular"
            logs += f"Weights: {model}\n"
            safe_name = prompt[:50].replace("/", "_").replace("\\", "_")
            output_path = os.path.join(output_folder, f"{safe_name}.mp4")
            write_video(output_path, video[0], fps=16)
            logs += f"Saved video: {output_path}\n"
            progress(1.0, desc="Done")
            # Only the first generated video is returned to the UI.
            return output_path, logs

    logs += "[WARN] No video generated.\n"
    return None, logs


@spaces.GPU(duration=200)
def gradio_generate(
    prompt: str, duration: str, use_ema: bool, progress=gr.Progress(track_tqdm=True)
):
    """
    Triggered by Gradio:
      - writes the prompt to a .txt file
      - runs inference
      - returns video + logs
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a text prompt 🙂")
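    # Duration → latent frame count. `num_output_frames` below counts latent
    # frames, not output frames: assuming the Wan2.1 VAE's 4x temporal
    # compression (T latents → 4T - 3 pixel frames), 21 latents ≈ 81 frames
    # ≈ 5 s at 16 fps, and 120 latents ≈ 477 frames ≈ 30 s, matching the
    # Radio labels in the UI.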
    num_output_frames = 21 if duration == "5s (21 frames)" else 120

    os.makedirs(PROMPT_DIR, exist_ok=True)
    prompt_id = uuid.uuid4().hex[:8]
    prompt_path = os.path.join(PROMPT_DIR, f"prompt_{prompt_id}.txt")
    with open(prompt_path, "w", encoding="utf-8") as f:
        f.write(prompt.strip() + "\n")

    video_path, logs = reward_forcing_inference(
        prompt_txt_path=prompt_path,
        num_output_frames=num_output_frames,
        use_ema=use_ema,
        output_root=OUTPUT_ROOT,
        progress=progress,
    )

    if video_path is None or not os.path.exists(video_path):
        raise gr.Error("No video generated. Check the logs for details.")

    return video_path, logs


# -------------------------------------------------------------------
# Gradio UI — updated title + intro text
# -------------------------------------------------------------------
with gr.Blocks(title="Reward Forcing — Text-to-Video Demo") as demo:
    gr.Markdown(
        """
# 🎬 Reward Forcing — Text-to-Video Demo

Generate short videos from text prompts using a model trained with the
**Reward Forcing** method. Reward Forcing is a recent research technique that
improves how well a video model follows a written description by guiding
training with learned reward signals. You can learn more here:
https://reward-forcing.github.io

👉 Type a prompt, click **Generate**, and the video will appear below.
Longer and more detailed prompts usually produce better results.

> ⏳ The first run may take a little longer while the model loads — generation
> is faster afterwards.
"""
    )

    with gr.Row():
        prompt_in = gr.Textbox(
            label="Prompt",
            placeholder="A cinematic shot of late-summer wheat fields moving in the wind...",
            lines=4,
        )

    with gr.Row():
        duration = gr.Radio(
            ["5s (21 frames)", "30s (120 frames)"],
            value="5s (21 frames)",
            label="Duration",
        )
        use_ema = gr.Checkbox(value=True, label="Use EMA weights (--use_ema)")

    generate_btn = gr.Button("🚀 Generate Video", variant="primary")

    with gr.Row():
        video_out = gr.Video(label="Generated Video")
        logs_out = gr.Textbox(label="Logs", lines=12, interactive=False)

    generate_btn.click(
        fn=gradio_generate,
        inputs=[prompt_in, duration, use_ema],
        outputs=[video_out, logs_out],
    )

demo.queue()

if __name__ == "__main__":
    demo.launch()