fffiloni committed on
Commit
fa10854
·
verified ·
1 Parent(s): 5a46abc

English version

Browse files
Files changed (1) hide show
  1. app_wip.py +55 -51
app_wip.py CHANGED
@@ -15,11 +15,11 @@ from pipeline import (
15
  CausalInferencePipeline,
16
  )
17
  from utils.dataset import TextDataset
18
- from utils.misc import set_seed
19
  from demo_utils.memory import get_cuda_free_memory_gb, DynamicSwapInstaller
20
 
21
  # -------------------------------------------------------------------
22
- # Téléchargement des checkpoints (une fois au démarrage du Space)
23
  # -------------------------------------------------------------------
24
  snapshot_download(
25
  repo_id="Wan-AI/Wan2.1-T2V-1.3B",
@@ -41,7 +41,7 @@ snapshot_download(
41
  local_dir="./checkpoints/Reward-Forcing-T2V-1.3B",
42
  )
43
 
44
- # === Chemins ===
45
  CONFIG_PATH = "configs/reward_forcing.yaml"
46
  CHECKPOINT_PATH = "checkpoints/Reward-Forcing-T2V-1.3B/rewardforcing.pt"
47
 
@@ -60,14 +60,14 @@ def reward_forcing_inference(
60
  progress: gr.Progress,
61
  ):
62
  """
63
- Version inline / simplifiée de inference.py :
64
  - single GPU
65
- - T2V uniquement
66
- - 1 fichier .txt = n prompts (mais on retourne la 1ère vidéo)
67
  """
68
  logs = ""
69
 
70
- # --------------------- Device & seed ---------------------
71
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
72
  set_seed(0)
73
 
@@ -77,29 +77,31 @@ def reward_forcing_inference(
77
 
78
  torch.set_grad_enabled(False)
79
 
80
- # --------------------- Phase 1 : init modèle / config ---------------------
81
- progress(0.05, desc="Initialisation : chargement de la config")
82
- logs += "Chargement de la config...\n"
83
  config = OmegaConf.load(CONFIG_PATH)
84
  default_config = OmegaConf.load("configs/default_config.yaml")
85
  config = OmegaConf.merge(default_config, config)
86
 
87
- progress(0.15, desc="Initialisation : création de la pipeline")
88
- logs += "Initialisation de la pipeline...\n"
89
  if hasattr(config, "denoising_step_list"):
 
90
  pipeline = CausalInferencePipeline(config, device=device)
91
  else:
 
92
  pipeline = CausalDiffusionInferencePipeline(config, device=device)
93
 
94
- progress(0.35, desc="Initialisation : chargement du checkpoint")
95
- logs += "Chargement des poids du checkpoint...\n"
96
  state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu")
97
  pipeline.generator.load_state_dict(state_dict)
98
  checkpoint_step = os.path.basename(os.path.dirname(CHECKPOINT_PATH))
99
  checkpoint_step = checkpoint_step.split("_")[-1]
100
 
101
- progress(0.55, desc="Initialisation : placement sur le device")
102
- logs += "Placement du modèle sur le device...\n"
103
  pipeline = pipeline.to(dtype=torch.bfloat16)
104
  if low_memory:
105
  DynamicSwapInstaller.install_model(pipeline.text_encoder, device=device)
@@ -108,9 +110,9 @@ def reward_forcing_inference(
108
  pipeline.generator.to(device=device)
109
  pipeline.vae.to(device=device)
110
 
111
- # --------------------- Dataset / DataLoader ---------------------
112
- progress(0.65, desc="Préparation du dataset")
113
- logs += "Préparation du dataset (TextDataset)...\n"
114
  dataset = TextDataset(prompt_path=prompt_txt_path, extended_prompt_path=None)
115
  num_prompts = len(dataset)
116
  logs += f"Number of prompts: {num_prompts}\n"
@@ -122,26 +124,26 @@ def reward_forcing_inference(
122
  dataset, batch_size=1, sampler=sampler, num_workers=0, drop_last=False
123
  )
124
 
125
- # --------------------- Output folder (on le vide) ---------------------
126
- progress(0.7, desc="Nettoyage du dossier de sortie")
127
  output_folder = os.path.join(
128
  output_root, f"rewardforcing-{num_output_frames}f", checkpoint_step
129
  )
130
  shutil.rmtree(output_folder, ignore_errors=True)
131
  os.makedirs(output_folder, exist_ok=True)
132
- logs += f"Dossier de sortie: {output_folder}\n"
133
 
134
- # --------------------- Phase 2 : boucle d'inférence ---------------------
135
- # Ici on peut utiliser progress.tqdm sur la boucle dataloader
136
  for i, batch_data in progress.tqdm(
137
  enumerate(dataloader),
138
  total=num_prompts,
139
- desc="Génération vidéo",
140
  unit="prompt",
141
  ):
142
  idx = batch_data["idx"].item()
143
 
144
- # Unpack batch
145
  if isinstance(batch_data, dict):
146
  batch = batch_data
147
  elif isinstance(batch_data, list):
@@ -151,7 +153,7 @@ def reward_forcing_inference(
151
 
152
  all_video = []
153
 
154
- # TEXT-TO-VIDEO uniquement (pas d'I2V ici)
155
  prompt = batch["prompts"][0]
156
  extended_prompt = batch.get("extended_prompts", [None])[0]
157
  if extended_prompt is not None:
@@ -161,15 +163,16 @@ def reward_forcing_inference(
161
 
162
  initial_latent = None
163
 
 
164
  sampled_noise = torch.randn(
165
  [1, num_output_frames, 16, 60, 104],
166
  device=device,
167
  dtype=torch.bfloat16,
168
  )
169
 
170
- logs += f"Génération pour le prompt: {prompt[:80]}...\n"
171
 
172
- # Appel au pipeline
173
  video, latents = pipeline.inference(
174
  noise=sampled_noise,
175
  text_prompts=prompts,
@@ -181,23 +184,24 @@ def reward_forcing_inference(
181
  current_video = rearrange(video, "b t c h w -> b t h w c").cpu()
182
  all_video.append(current_video)
183
 
 
184
  video = 255.0 * torch.cat(all_video, dim=1)
185
 
186
- # Clear VAE cache
187
  pipeline.vae.model.clear_cache()
188
 
189
- # Sauvegarde vidéo (on retourne la 1ère vidéo)
190
  if idx < num_prompts:
191
  model = "regular" if not use_ema else "ema"
192
  safe_name = prompt[:50].replace("/", "_").replace("\\", "_")
193
  output_path = os.path.join(output_folder, f"{safe_name}.mp4")
194
  write_video(output_path, video[0], fps=16)
195
- logs += f"Vidéo enregistrée: {output_path}\n"
196
 
197
- progress(1.0, desc="Terminé ✅")
198
  return output_path, logs
199
 
200
- logs += "[WARN] Aucune vidéo générée dans la boucle.\n"
201
  return None, logs
202
 
203
 
@@ -205,15 +209,15 @@ def gradio_generate(
205
  prompt: str, duration: str, use_ema: bool, progress=gr.Progress(track_tqdm=True)
206
  ):
207
  """
208
- Fonction appelée par Gradio :
209
- - écrit le prompt dans un .txt
210
- - appelle reward_forcing_inference
211
- - retourne (video_path, logs)
212
  """
213
  if not prompt or not prompt.strip():
214
- raise gr.Error("Veuillez entrer un prompt texte 🙂")
215
 
216
- # Durée -> frames
217
  if duration == "5s (21 frames)":
218
  num_output_frames = 21
219
  else:
@@ -236,15 +240,15 @@ def gradio_generate(
236
 
237
  if video_path is None or not os.path.exists(video_path):
238
  raise gr.Error(
239
- "Aucune vidéo trouvée après l'inférence.\n"
240
- "Regarde les logs ci-dessous pour voir ce qui a coincé."
241
  )
242
 
243
  return video_path, logs
244
 
245
 
246
  # -------------------------------------------------------------------
247
- # UI Gradio
248
  # -------------------------------------------------------------------
249
 
250
  with gr.Blocks(title="Reward Forcing T2V Demo (inline inference)") as demo:
@@ -252,10 +256,10 @@ with gr.Blocks(title="Reward Forcing T2V Demo (inline inference)") as demo:
252
  """
253
  # 🎬 Reward Forcing – Text-to-Video (inline)
254
 
255
- Cette version appelle directement la logique d'inférence en Python,
256
- ce qui permet à Gradio de suivre :
257
- - l'initialisation du modèle (via `progress(...)`)
258
- - la boucle de génération (via `progress.tqdm(...)`)
259
  """
260
  )
261
 
@@ -270,14 +274,14 @@ with gr.Blocks(title="Reward Forcing T2V Demo (inline inference)") as demo:
270
  duration = gr.Radio(
271
  ["5s (21 frames)", "30s (120 frames)"],
272
  value="5s (21 frames)",
273
- label="Durée",
274
  )
275
- use_ema = gr.Checkbox(value=True, label="Utiliser les poids EMA (--use_ema)")
276
 
277
- generate_btn = gr.Button("🚀 Générer la vidéo", variant="primary")
278
 
279
  with gr.Row():
280
- video_out = gr.Video(label="Vidéo générée")
281
  logs_out = gr.Textbox(
282
  label="Logs",
283
  lines=12,
 
15
  CausalInferencePipeline,
16
  )
17
  from utils.dataset import TextDataset
18
+ from utils.misc import set_seed
19
  from demo_utils.memory import get_cuda_free_memory_gb, DynamicSwapInstaller
20
 
21
  # -------------------------------------------------------------------
22
+ # Download checkpoints once when the Space starts
23
  # -------------------------------------------------------------------
24
  snapshot_download(
25
  repo_id="Wan-AI/Wan2.1-T2V-1.3B",
 
41
  local_dir="./checkpoints/Reward-Forcing-T2V-1.3B",
42
  )
43
 
44
+ # === Paths ===
45
  CONFIG_PATH = "configs/reward_forcing.yaml"
46
  CHECKPOINT_PATH = "checkpoints/Reward-Forcing-T2V-1.3B/rewardforcing.pt"
47
 
 
60
  progress: gr.Progress,
61
  ):
62
  """
63
+ Inline / simplified version of inference.py:
64
  - single GPU
65
+ - text-to-video only
66
+ - one .txt file = N prompts, but we return only the first generated video
67
  """
68
  logs = ""
69
 
70
+ # --------------------- Device & randomness ---------------------
71
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
72
  set_seed(0)
73
 
 
77
 
78
  torch.set_grad_enabled(False)
79
 
80
+ # --------------------- Stage 1: model & config init ---------------------
81
+ progress(0.05, desc="Init: loading config")
82
+ logs += "Loading config...\n"
83
  config = OmegaConf.load(CONFIG_PATH)
84
  default_config = OmegaConf.load("configs/default_config.yaml")
85
  config = OmegaConf.merge(default_config, config)
86
 
87
+ progress(0.15, desc="Init: creating pipeline")
88
+ logs += "Creating pipeline...\n"
89
  if hasattr(config, "denoising_step_list"):
90
+ # few-step sampling pipeline
91
  pipeline = CausalInferencePipeline(config, device=device)
92
  else:
93
+ # full diffusion pipeline
94
  pipeline = CausalDiffusionInferencePipeline(config, device=device)
95
 
96
+ progress(0.35, desc="Init: loading checkpoint")
97
+ logs += "Loading checkpoint weights...\n"
98
  state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu")
99
  pipeline.generator.load_state_dict(state_dict)
100
  checkpoint_step = os.path.basename(os.path.dirname(CHECKPOINT_PATH))
101
  checkpoint_step = checkpoint_step.split("_")[-1]
102
 
103
+ progress(0.55, desc="Init: moving model to device")
104
+ logs += "Moving model to device...\n"
105
  pipeline = pipeline.to(dtype=torch.bfloat16)
106
  if low_memory:
107
  DynamicSwapInstaller.install_model(pipeline.text_encoder, device=device)
 
110
  pipeline.generator.to(device=device)
111
  pipeline.vae.to(device=device)
112
 
113
+ # --------------------- Dataset setup ---------------------
114
+ progress(0.65, desc="Preparing dataset")
115
+ logs += "Preparing dataset (TextDataset)...\n"
116
  dataset = TextDataset(prompt_path=prompt_txt_path, extended_prompt_path=None)
117
  num_prompts = len(dataset)
118
  logs += f"Number of prompts: {num_prompts}\n"
 
124
  dataset, batch_size=1, sampler=sampler, num_workers=0, drop_last=False
125
  )
126
 
127
+ # --------------------- Make a clean output directory ---------------------
128
+ progress(0.7, desc="Cleaning output folder")
129
  output_folder = os.path.join(
130
  output_root, f"rewardforcing-{num_output_frames}f", checkpoint_step
131
  )
132
  shutil.rmtree(output_folder, ignore_errors=True)
133
  os.makedirs(output_folder, exist_ok=True)
134
+ logs += f"Output directory: {output_folder}\n"
135
 
136
+ # --------------------- Stage 2: inference loop ---------------------
137
+ # Gradio can track tqdm progress on iterable loops
138
  for i, batch_data in progress.tqdm(
139
  enumerate(dataloader),
140
  total=num_prompts,
141
+ desc="Video generation",
142
  unit="prompt",
143
  ):
144
  idx = batch_data["idx"].item()
145
 
146
+ # Unpack dataset batch
147
  if isinstance(batch_data, dict):
148
  batch = batch_data
149
  elif isinstance(batch_data, list):
 
153
 
154
  all_video = []
155
 
156
+ # TEXT-TO-VIDEO only (no I2V here)
157
  prompt = batch["prompts"][0]
158
  extended_prompt = batch.get("extended_prompts", [None])[0]
159
  if extended_prompt is not None:
 
163
 
164
  initial_latent = None
165
 
166
+ # Noise tensor shape matches WAN2 expected latent dims
167
  sampled_noise = torch.randn(
168
  [1, num_output_frames, 16, 60, 104],
169
  device=device,
170
  dtype=torch.bfloat16,
171
  )
172
 
173
+ logs += f"Generating for prompt: {prompt[:80]}...\n"
174
 
175
+ # Run WAN inference
176
  video, latents = pipeline.inference(
177
  noise=sampled_noise,
178
  text_prompts=prompts,
 
184
  current_video = rearrange(video, "b t c h w -> b t h w c").cpu()
185
  all_video.append(current_video)
186
 
187
+ # scale pixel values by 255 *after* concatenation (no uint8 cast happens on this line)
188
  video = 255.0 * torch.cat(all_video, dim=1)
189
 
190
+ # free VAE cache between clips
191
  pipeline.vae.model.clear_cache()
192
 
193
+ # Save only the first video
194
  if idx < num_prompts:
195
  model = "regular" if not use_ema else "ema"
196
  safe_name = prompt[:50].replace("/", "_").replace("\\", "_")
197
  output_path = os.path.join(output_folder, f"{safe_name}.mp4")
198
  write_video(output_path, video[0], fps=16)
199
+ logs += f"Saved video: {output_path}\n"
200
 
201
+ progress(1.0, desc="Done ✅")
202
  return output_path, logs
203
 
204
+ logs += "[WARN] No video generated in loop.\n"
205
  return None, logs
206
 
207
 
 
209
  prompt: str, duration: str, use_ema: bool, progress=gr.Progress(track_tqdm=True)
210
  ):
211
  """
212
+ Triggered by Gradio:
213
+ - writes prompt to a temporary .txt file
214
+ - runs reward_forcing_inference
215
+ - returns video + logs
216
  """
217
  if not prompt or not prompt.strip():
218
+ raise gr.Error("Please type a text prompt 🙂")
219
 
220
+ # Duration -> number of frames (num_output_frames, the frame dimension of the noise tensor)
221
  if duration == "5s (21 frames)":
222
  num_output_frames = 21
223
  else:
 
240
 
241
  if video_path is None or not os.path.exists(video_path):
242
  raise gr.Error(
243
+ "No video generated.\n"
244
+ "Check the logs below for errors."
245
  )
246
 
247
  return video_path, logs
248
 
249
 
250
  # -------------------------------------------------------------------
251
+ # Gradio UI
252
  # -------------------------------------------------------------------
253
 
254
  with gr.Blocks(title="Reward Forcing T2V Demo (inline inference)") as demo:
 
256
  """
257
  # 🎬 Reward Forcing – Text-to-Video (inline)
258
 
259
+ This version directly calls the inference logic in Python,
260
+ allowing Gradio to track:
261
+ - model initialization via `progress(...)`
262
+ - video generation progress via `progress.tqdm(...)`
263
  """
264
  )
265
 
 
274
  duration = gr.Radio(
275
  ["5s (21 frames)", "30s (120 frames)"],
276
  value="5s (21 frames)",
277
+ label="Duration",
278
  )
279
+ use_ema = gr.Checkbox(value=True, label="Use EMA weights (--use_ema)")
280
 
281
+ generate_btn = gr.Button("🚀 Generate Video", variant="primary")
282
 
283
  with gr.Row():
284
+ video_out = gr.Video(label="Generated Video")
285
  logs_out = gr.Textbox(
286
  label="Logs",
287
  lines=12,