fffiloni commited on
Commit
8f3f345
·
verified ·
1 Parent(s): 24616c9

Update app_wip.py

Browse files
Files changed (1) hide show
  1. app_wip.py +113 -60
app_wip.py CHANGED
@@ -1,63 +1,68 @@
1
  import sys
 
 
2
  import subprocess
 
3
 
4
- def ensure_flash_attn():
5
- try:
6
- import flash_attn # noqa: F401
7
- print("[init] flash-attn déjà installé")
8
- except Exception as e:
9
- print("[init] Installation de flash-attn (build from source)...", e, flush=True)
10
- subprocess.run(
11
- [
12
- sys.executable,
13
- "-m",
14
- "pip",
15
- "install",
16
- "flash-attn==2.7.4.post1",
17
- "--no-build-isolation",
18
- ],
19
- check=True,
20
- )
21
- import flash_attn # noqa: F401
22
- print("[init] flash-attn OK")
23
-
24
- #ensure_flash_attn()
25
-
26
  from huggingface_hub import snapshot_download
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  snapshot_download(
29
- repo_id='Wan-AI/Wan2.1-T2V-1.3B',
30
- local_dir='./checkpoints/Wan2.1-T2V-1.3B'
31
  )
32
 
33
  snapshot_download(
34
- repo_id='KlingTeam/VideoReward',
35
- local_dir='./checkpoints/Videoreward'
36
  )
37
 
38
  snapshot_download(
39
- repo_id='gdhe17/Self-Forcing',
40
- local_dir='./checkpoints/ode_init.pt'
41
  )
42
 
43
  snapshot_download(
44
- repo_id='JaydenLu666/Reward-Forcing-T2V-1.3B',
45
- local_dir='./checkpoints/Reward-Forcing-T2V-1.3B'
46
  )
47
 
48
- import os
49
- import uuid
50
- import subprocess
51
- from datetime import datetime
52
-
53
- import gradio as gr
54
-
55
- # === Chemins à adapter si besoin ===
56
  CONFIG_PATH = "configs/reward_forcing.yaml"
57
  CHECKPOINT_PATH = "checkpoints/Reward-Forcing-T2V-1.3B/rewardforcing.pt"
58
 
59
  PROMPT_DIR = "prompts/gradio_inputs"
60
- OUTPUT_ROOT = "videos/gradio_outputs"
 
61
 
62
  os.makedirs(PROMPT_DIR, exist_ok=True)
63
  os.makedirs(OUTPUT_ROOT, exist_ok=True)
@@ -67,40 +72,54 @@ def run_inference(prompt: str, duration: str, use_ema: bool):
67
  """
68
  1. Écrit le prompt dans un fichier .txt
69
  2. Lance inference.py avec ce fichier comme --data_path
70
- 3. Retourne le chemin de la première vidéo .mp4 générée + les logs
71
  """
 
 
72
  if not prompt or not prompt.strip():
73
  raise gr.Error("Veuillez entrer un prompt texte 🙂")
74
 
75
- # 1) On mappe la durée choisie num_output_frames
76
  if duration == "5s (21 frames)":
77
  num_output_frames = 21
 
78
  else: # "30s (120 frames)"
79
  num_output_frames = 120
 
 
 
80
 
81
  # 2) Fichier .txt temporaire pour le prompt
82
  prompt_id = uuid.uuid4().hex[:8]
83
  prompt_path = os.path.join(PROMPT_DIR, f"prompt_{prompt_id}.txt")
84
 
85
  with open(prompt_path, "w", encoding="utf-8") as f:
86
- # TextDataset lit juste chaque ligne comme un prompt
87
  f.write(prompt.strip() + "\n")
88
 
89
- # 3) Dossier de sortie unique pour cette génération
90
- ts = datetime.now().strftime("%Y%m%d_%H%M%S")
91
- output_folder = os.path.join(OUTPUT_ROOT, f"{ts}_{prompt_id}")
92
- os.makedirs(output_folder, exist_ok=True)
 
 
93
 
94
  # 4) Commande inference.py
95
  cmd = [
96
- "python",
97
  "inference.py",
98
- "--num_output_frames", str(num_output_frames),
99
- "--config_path", CONFIG_PATH,
100
- "--checkpoint_path", CHECKPOINT_PATH,
101
- "--output_folder", output_folder,
102
- "--data_path", prompt_path,
103
- "--num_samples", "1",
 
 
 
 
 
 
104
  ]
105
  if use_ema:
106
  cmd.append("--use_ema")
@@ -110,21 +129,55 @@ def run_inference(prompt: str, duration: str, use_ema: bool):
110
  stdout=subprocess.PIPE,
111
  stderr=subprocess.STDOUT,
112
  text=True,
 
113
  )
114
 
115
  logs = result.stdout
116
  print(logs)
117
 
118
- # 5) On récupère la première vidéo produite
119
- mp4s = [f for f in os.listdir(output_folder) if f.lower().endswith(".mp4")]
120
- if not mp4s:
121
  raise gr.Error(
122
- "Aucune vidéo trouvée dans le dossier de sortie.\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  "Regarde les logs ci-dessous pour voir ce qui a coincé."
124
  )
125
 
126
- mp4s.sort()
127
- video_path = os.path.join(output_folder, mp4s[0])
 
 
 
128
  return video_path, logs
129
 
130
 
@@ -159,7 +212,7 @@ with gr.Blocks(title="Reward Forcing T2V Demo") as demo:
159
  video_out = gr.Video(label="Vidéo générée")
160
  logs_out = gr.Textbox(
161
  label="Logs de inference.py",
162
- lines=10,
163
  interactive=False,
164
  )
165
 
 
1
  import sys
2
+ import os
3
+ import uuid
4
  import subprocess
5
+ from datetime import datetime
6
 
7
+ import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  from huggingface_hub import snapshot_download
9
 
10
+ # -------------------------------------------------------------------
11
+ # (Optionnel) flash-attn : comme tu as déjà la bonne wheel dans
12
+ # requirements.txt, on laisse commenté pour éviter des builds lents.
13
+ # -------------------------------------------------------------------
14
+ # def ensure_flash_attn():
15
+ # try:
16
+ # import flash_attn # noqa: F401
17
+ # print("[init] flash-attn déjà installé")
18
+ # except Exception as e:
19
+ # print("[init] Installation de flash-attn (build from source)...", e, flush=True)
20
+ # subprocess.run(
21
+ # [
22
+ # sys.executable,
23
+ # "-m",
24
+ # "pip",
25
+ # "install",
26
+ # "flash-attn==2.7.4.post1",
27
+ # "--no-build-isolation",
28
+ # ],
29
+ # check=True,
30
+ # )
31
+ # import flash_attn # noqa: F401
32
+ # print("[init] flash-attn OK")
33
+
34
+ # ensure_flash_attn()
35
+
36
+ # -------------------------------------------------------------------
37
+ # Téléchargement des checkpoints (fait une fois au démarrage du Space)
38
+ # -------------------------------------------------------------------
39
  snapshot_download(
40
+ repo_id="Wan-AI/Wan2.1-T2V-1.3B",
41
+ local_dir="./checkpoints/Wan2.1-T2V-1.3B",
42
  )
43
 
44
  snapshot_download(
45
+ repo_id="KlingTeam/VideoReward",
46
+ local_dir="./checkpoints/Videoreward",
47
  )
48
 
49
  snapshot_download(
50
+ repo_id="gdhe17/Self-Forcing",
51
+ local_dir="./checkpoints/ode_init.pt",
52
  )
53
 
54
  snapshot_download(
55
+ repo_id="JaydenLu666/Reward-Forcing-T2V-1.3B",
56
+ local_dir="./checkpoints/Reward-Forcing-T2V-1.3B",
57
  )
58
 
59
+ # === Chemins ===
 
 
 
 
 
 
 
60
  CONFIG_PATH = "configs/reward_forcing.yaml"
61
  CHECKPOINT_PATH = "checkpoints/Reward-Forcing-T2V-1.3B/rewardforcing.pt"
62
 
63
  PROMPT_DIR = "prompts/gradio_inputs"
64
+ # on garde OUTPUT_ROOT mais on va aussi coller au README pour l'output
65
+ OUTPUT_ROOT = "videos"
66
 
67
  os.makedirs(PROMPT_DIR, exist_ok=True)
68
  os.makedirs(OUTPUT_ROOT, exist_ok=True)
 
72
  """
73
  1. Écrit le prompt dans un fichier .txt
74
  2. Lance inference.py avec ce fichier comme --data_path
75
+ 3. Retourne le chemin de la vidéo .mp4 générée + les logs
76
  """
77
+ import glob
78
+
79
  if not prompt or not prompt.strip():
80
  raise gr.Error("Veuillez entrer un prompt texte 🙂")
81
 
82
+ # 1) Durée -> num_output_frames + dossier conforme au README
83
  if duration == "5s (21 frames)":
84
  num_output_frames = 21
85
+ output_folder = os.path.join(OUTPUT_ROOT, "rewardforcing-5s")
86
  else: # "30s (120 frames)"
87
  num_output_frames = 120
88
+ output_folder = os.path.join(OUTPUT_ROOT, "rewardforcing-30s")
89
+
90
+ os.makedirs(output_folder, exist_ok=True)
91
 
92
  # 2) Fichier .txt temporaire pour le prompt
93
  prompt_id = uuid.uuid4().hex[:8]
94
  prompt_path = os.path.join(PROMPT_DIR, f"prompt_{prompt_id}.txt")
95
 
96
  with open(prompt_path, "w", encoding="utf-8") as f:
97
+ # TextDataset lit chaque ligne comme un prompt
98
  f.write(prompt.strip() + "\n")
99
 
100
+ # 3) On sauve la liste des vidéos AVANT l'inférence
101
+ cwd = os.path.dirname(os.path.abspath(__file__))
102
+ before_mp4s = set(
103
+ os.path.relpath(p, cwd)
104
+ for p in glob.glob(os.path.join(cwd, "videos", "**", "*.mp4"), recursive=True)
105
+ )
106
 
107
  # 4) Commande inference.py
108
  cmd = [
109
+ sys.executable,
110
  "inference.py",
111
+ "--num_output_frames",
112
+ str(num_output_frames),
113
+ "--config_path",
114
+ CONFIG_PATH,
115
+ "--checkpoint_path",
116
+ CHECKPOINT_PATH,
117
+ "--output_folder",
118
+ output_folder,
119
+ "--data_path",
120
+ prompt_path,
121
+ "--num_samples",
122
+ "1",
123
  ]
124
  if use_ema:
125
  cmd.append("--use_ema")
 
129
  stdout=subprocess.PIPE,
130
  stderr=subprocess.STDOUT,
131
  text=True,
132
+ cwd=cwd, # important sur les Spaces
133
  )
134
 
135
  logs = result.stdout
136
  print(logs)
137
 
138
+ # 5) Si inference.py a planté, on remonte l'erreur
139
+ if result.returncode != 0:
 
140
  raise gr.Error(
141
+ f"inference.py a retourné un code d'erreur ({result.returncode}).\n\n"
142
+ "Regarde les logs ci-dessous pour les détails."
143
+ )
144
+
145
+ # 6) On regarde les vidéos APRÈS l'inférence
146
+ after_mp4s_abs = glob.glob(os.path.join(cwd, "videos", "**", "*.mp4"), recursive=True)
147
+ after_mp4s = set(os.path.relpath(p, cwd) for p in after_mp4s_abs)
148
+
149
+ new_mp4s = list(after_mp4s - before_mp4s)
150
+
151
+ # Debug : log de tout ce qui a été trouvé
152
+ logs += "\n\n[DEBUG] Fichiers .mp4 AVANT:\n"
153
+ logs += "\n".join(sorted(before_mp4s)) if before_mp4s else "[aucun]\n"
154
+ logs += "\n\n[DEBUG] Fichiers .mp4 APRÈS:\n"
155
+ logs += "\n".join(sorted(after_mp4s)) if after_mp4s else "[aucun]\n"
156
+
157
+ if not new_mp4s:
158
+ # Pas de nouvelle vidéo détectée. En dernier recours,
159
+ # on prend la plus récente dans tout `videos/` si elle existe.
160
+ if after_mp4s_abs:
161
+ after_mp4s_abs.sort(key=os.path.getmtime, reverse=True)
162
+ fallback_video = after_mp4s_abs[0]
163
+ logs += (
164
+ "\n\n[WARN] Aucune nouvelle vidéo détectée, "
165
+ "on utilise la plus récente trouvée: "
166
+ f"{os.path.relpath(fallback_video, cwd)}"
167
+ )
168
+ return fallback_video, logs
169
+
170
+ # Vraiment aucune vidéo
171
+ raise gr.Error(
172
+ "Aucune vidéo .mp4 trouvée dans le dossier de sortie.\n"
173
  "Regarde les logs ci-dessous pour voir ce qui a coincé."
174
  )
175
 
176
+ # On prend la nouvelle vidéo la plus récente
177
+ new_mp4s_abs = [os.path.join(cwd, p) for p in new_mp4s]
178
+ new_mp4s_abs.sort(key=os.path.getmtime, reverse=True)
179
+ video_path = new_mp4s_abs[0]
180
+
181
  return video_path, logs
182
 
183
 
 
212
  video_out = gr.Video(label="Vidéo générée")
213
  logs_out = gr.Textbox(
214
  label="Logs de inference.py",
215
+ lines=12,
216
  interactive=False,
217
  )
218