import spaces
import os
import uuid
import shutil

import gradio as gr
import torch
from omegaconf import OmegaConf
from torchvision.io import write_video
from einops import rearrange
from huggingface_hub import snapshot_download

from pipeline import (
    CausalDiffusionInferencePipeline,
    CausalInferencePipeline,
)
from utils.dataset import TextDataset
from utils.misc import set_seed
from demo_utils.memory import get_cuda_free_memory_gb, DynamicSwapInstaller


# -------------------------------------------------------------------
# Download checkpoints once when the Space starts
# -------------------------------------------------------------------
snapshot_download(
    repo_id="Wan-AI/Wan2.1-T2V-1.3B",
    local_dir="./checkpoints/Wan2.1-T2V-1.3B",
)

snapshot_download(
    repo_id="KlingTeam/VideoReward",
    local_dir="./checkpoints/Videoreward",
)

snapshot_download(
    repo_id="gdhe17/Self-Forcing",
    local_dir="./checkpoints/ode_init.pt",
)
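# NOTE: snapshot_download treats local_dir as a directory, so the call above
# places the full gdhe17/Self-Forcing snapshot inside a folder literally named
# "ode_init.pt" rather than downloading a single checkpoint file.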

snapshot_download(
    repo_id="JaydenLu666/Reward-Forcing-T2V-1.3B",
    local_dir="./checkpoints/Reward-Forcing-T2V-1.3B",
)


# === Paths ===
CONFIG_PATH = "configs/reward_forcing.yaml"
CHECKPOINT_PATH = "checkpoints/Reward-Forcing-T2V-1.3B/rewardforcing.pt"

PROMPT_DIR = "prompts/gradio_inputs"
OUTPUT_ROOT = "videos"

os.makedirs(PROMPT_DIR, exist_ok=True)
os.makedirs(OUTPUT_ROOT, exist_ok=True)


def reward_forcing_inference(
    prompt_txt_path: str,
    num_output_frames: int,
    use_ema: bool,
    output_root: str,
    progress: gr.Progress,
):
    """
    Inline / simplified version of inference.py:
    - single GPU
    - text-to-video only
    - one .txt file = N prompts, but returns only the first generated video
    """
    logs = ""

    # --------------------- Device & randomness ---------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    set_seed(0)

    free_vram = get_cuda_free_memory_gb(device)
    logs += f"Free VRAM {free_vram} GB\n"
    low_memory = free_vram < 40

    torch.set_grad_enabled(False)

    # --------------------- Phase 1: model & config init ---------------------
    progress(0.05, desc="Init: loading config")
    logs += "Loading config...\n"
    config = OmegaConf.load(CONFIG_PATH)
    default_config = OmegaConf.load("configs/default_config.yaml")
    config = OmegaConf.merge(default_config, config)

    progress(0.15, desc="Init: creating pipeline")
    logs += "Creating pipeline...\n"
    if hasattr(config, "denoising_step_list"):
        pipeline = CausalInferencePipeline(config, device=device)
    else:
        pipeline = CausalDiffusionInferencePipeline(config, device=device)

    progress(0.35, desc="Init: loading checkpoint")
    logs += "Loading checkpoint weights...\n"
    state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu")
    pipeline.generator.load_state_dict(state_dict)
    checkpoint_step = os.path.basename(os.path.dirname(CHECKPOINT_PATH))
    checkpoint_step = checkpoint_step.split("_")[-1]
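    # With the fixed CHECKPOINT_PATH above this resolves to the parent folder
    # name ("Reward-Forcing-T2V-1.3B", which has no underscore to split on);
    # for a hypothetical training-style path such as
    # ".../checkpoint_000500/rewardforcing.pt" it would yield the step "000500".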

    progress(0.55, desc="Init: moving model to device")
    logs += "Moving model to device...\n"
    pipeline = pipeline.to(dtype=torch.bfloat16)
    if low_memory:
        DynamicSwapInstaller.install_model(pipeline.text_encoder, device=device)
    else:
        pipeline.text_encoder.to(device=device)
    pipeline.generator.to(device=device)
    pipeline.vae.to(device=device)

    # --------------------- Dataset setup ---------------------
    progress(0.65, desc="Preparing dataset")
    logs += "Preparing dataset (TextDataset)...\n"
    dataset = TextDataset(prompt_path=prompt_txt_path, extended_prompt_path=None)
    num_prompts = len(dataset)
    logs += f"Number of prompts: {num_prompts}\n"

    from torch.utils.data import DataLoader, SequentialSampler

    sampler = SequentialSampler(dataset)
    dataloader = DataLoader(
        dataset, batch_size=1, sampler=sampler, num_workers=0, drop_last=False
    )

    # --------------------- Clean output folder ---------------------
    progress(0.7, desc="Cleaning output folder")
    output_folder = os.path.join(
        output_root, f"rewardforcing-{num_output_frames}f", checkpoint_step
    )
    shutil.rmtree(output_folder, ignore_errors=True)
    os.makedirs(output_folder, exist_ok=True)
    logs += f"Output directory: {output_folder}\n"

    # --------------------- Phase 2: inference loop ---------------------
    for i, batch_data in progress.tqdm(
        enumerate(dataloader),
        total=num_prompts,
        desc="Video generation",
        unit="prompt",
    ):
        # Unpack the dataset batch first (TextDataset yields dicts; unwrap a
        # list-wrapped batch defensively), then read the prompt index.
        batch = batch_data[0] if isinstance(batch_data, list) else batch_data
        idx = batch["idx"].item()

        all_video = []

        # TEXT-TO-VIDEO only
        prompt = batch["prompts"][0]
        extended_prompt = batch.get("extended_prompts", [None])[0]
        prompts = [extended_prompt] if extended_prompt else [prompt]

        initial_latent = None

        sampled_noise = torch.randn(
            [1, num_output_frames, 16, 60, 104],
            device=device,
            dtype=torch.bfloat16,
        )
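
        # This is a latent-space shape, not pixels: assuming Wan2.1's VAE with
        # 8x spatial downsampling, 60x104 latents decode to 480x832 video, and
        # num_output_frames counts latent frames rather than output frames.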

        logs += f"Generating for prompt: {prompt[:80]}...\n"

        # WAN2 inference
        video, latents = pipeline.inference(
            noise=sampled_noise,
            text_prompts=prompts,
            return_latents=True,
            initial_latent=initial_latent,
            low_memory=low_memory,
        )

        # Pipeline output is [B, T, C, H, W] (assumed in [0, 1] given the
        # 255x scaling); write_video expects [T, H, W, C] uint8.
        current_video = rearrange(video, "b t c h w -> b t h w c").cpu()
        all_video.append(current_video)
        video = (255.0 * torch.cat(all_video, dim=1)).clamp(0, 255).to(torch.uint8)

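        # Free the VAE's internal cache between prompts so decoder state does
        # not accumulate across the loop.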
        pipeline.vae.model.clear_cache()

        if idx < num_prompts:
            # EMA weights are not loaded separately in this simplified demo,
            # so the flag only tags the output filename.
            model_tag = "ema" if use_ema else "regular"
            safe_name = prompt[:50].replace("/", "_").replace("\\", "_")
            output_path = os.path.join(output_folder, f"{safe_name}-{model_tag}.mp4")
            write_video(output_path, video[0], fps=16)
            logs += f"Saved video: {output_path}\n"

            progress(1.0, desc="Done")
            return output_path, logs

    logs += "[WARN] No video generated.\n"
    return None, logs

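# ZeroGPU budget: each call may hold the GPU for up to 200 s, which may be
# tight for the 120-frame setting on a cold start.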
@spaces.GPU(duration=200)
def gradio_generate(
    prompt: str, duration: str, use_ema: bool, progress=gr.Progress(track_tqdm=True)
):
    """
    Triggered by Gradio:
    - writes prompt to a .txt file
    - performs inference
    - returns video + logs
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a text prompt 🙂")

    # Duration → number of frames
    num_output_frames = 21 if duration == "5s (21 frames)" else 120
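    # These counts are latent frames (see the noise shape in
    # reward_forcing_inference); assuming Wan's roughly 4x temporal VAE
    # expansion, 21 latents come out to about 5 s of 16 fps video and 120
    # to about 30 s, matching the radio labels below.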

    os.makedirs(PROMPT_DIR, exist_ok=True)

    prompt_id = uuid.uuid4().hex[:8]
    prompt_path = os.path.join(PROMPT_DIR, f"prompt_{prompt_id}.txt")
    with open(prompt_path, "w", encoding="utf-8") as f:
        f.write(prompt.strip() + "\n")

    video_path, logs = reward_forcing_inference(
        prompt_txt_path=prompt_path,
        num_output_frames=num_output_frames,
        use_ema=use_ema,
        output_root=OUTPUT_ROOT,
        progress=progress,
    )

    if video_path is None or not os.path.exists(video_path):
        raise gr.Error("No video generated. Check logs for details.")

    return video_path, logs


# -------------------------------------------------------------------
# Gradio UI — title + intro text
# -------------------------------------------------------------------

with gr.Blocks(title="Reward Forcing — Text-to-Video Demo") as demo:
    gr.Markdown(
        """
        # 🎬 Reward Forcing — Text-to-Video Demo

        Generate short videos from text prompts using a model trained with the **Reward Forcing** method.

        Reward Forcing is a recent research technique that improves how closely a video model follows a written
        description by guiding training with learned reward signals. You can learn more here:
        https://reward-forcing.github.io

        👉 Type a prompt, click **Generate**, and the video will appear below.
        Longer, more detailed prompts usually produce better results.

        > ⏳ The first run may take a little longer while the model loads — generation is faster afterwards.
        """
    )

    with gr.Row():
        prompt_in = gr.Textbox(
            label="Prompt",
            placeholder="A cinematic shot of late-summer wheat fields moving in the wind...",
            lines=4,
        )

    with gr.Row():
        duration = gr.Radio(
            ["5s (21 frames)", "30s (120 frames)"],
            value="5s (21 frames)",
            label="Duration",
        )
        use_ema = gr.Checkbox(value=True, label="Use EMA weights (--use_ema)")

    generate_btn = gr.Button("🚀 Generate Video", variant="primary")

    with gr.Row():
        video_out = gr.Video(label="Generated Video")
    logs_out = gr.Textbox(label="Logs", lines=12, interactive=False)

    generate_btn.click(
        fn=gradio_generate,
        inputs=[prompt_in, duration, use_ema],
        outputs=[video_out, logs_out],
    )

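# Queue requests so concurrent users wait their turn instead of contending
# for the single GPU worker.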
demo.queue()

if __name__ == "__main__":
    demo.launch()