alex commited on
Commit
5128853
·
1 Parent(s): 44b7571

reverting back

Browse files
Files changed (2) hide show
  1. app.py +6 -11
  2. src/eval/generate_samples.py +2 -2
app.py CHANGED
@@ -210,7 +210,6 @@ def run_one(
210
  action_id: int,
211
  out_dir: str,
212
  guidance: Tuple[float,float] = (1.0, 3.0),
213
- seed = 0
214
  ) -> str:
215
  os.makedirs(out_dir, exist_ok=True)
216
  sample = make_sample(init_image_path, keyframes_boxes, action_id)
@@ -233,7 +232,6 @@ def run_one(
233
  use_factor_guidance=g_use_factor_guidance,
234
  guidance=list(guidance),
235
  video_path2=None,
236
- seed=seed
237
  )
238
  return f"{video_prefix}.mp4"
239
 
@@ -315,10 +313,9 @@ def ctrl_generate_from_video(
315
  vid_path,
316
  action_label: str,
317
  num_frames: int = 3,
318
- seed = 0,
319
- session_id = None,
320
  guidance_min: float = 1.0,
321
- guidance_max: float = 3.0,
 
322
  ):
323
 
324
  if session_id is None:
@@ -358,7 +355,6 @@ def ctrl_generate_from_video(
358
  action_id=action_id,
359
  out_dir=out_dir,
360
  guidance=(guidance_min, guidance_max),
361
- seed = seed
362
  )
363
 
364
  if not os.path.exists(mp4_path):
@@ -370,10 +366,10 @@ def ctrl_generate_from_video(
370
  def ctrl_generate_from_image(
371
  img_path,
372
  action_label: str,
373
- seed = 0,
374
- session_id = None,
375
  guidance_min: float = 1.0,
376
  guidance_max: float = 3.0,
 
377
  ):
378
 
379
  if session_id is None:
@@ -421,7 +417,6 @@ def ctrl_generate_from_image(
421
  action_id=action_id,
422
  out_dir=out_dir,
423
  guidance=(guidance_min, guidance_max),
424
- seed = seed
425
  )
426
 
427
  if not os.path.exists(mp4_path):
@@ -615,7 +610,7 @@ with gr.Blocks(css=css) as demo:
615
  run_video_btn.click(
616
  fn=ctrl_generate_from_video,
617
  inputs=[
618
- video_in, action_dropdown, bbx_frames, 2, session_state
619
  ],
620
  outputs=[video_out, boxes_gif],
621
  api_name="generate"
@@ -624,7 +619,7 @@ with gr.Blocks(css=css) as demo:
624
  run_image_btn.click(
625
  fn=ctrl_generate_from_image,
626
  inputs=[
627
- image_in, action_dropdown, 2, session_state
628
  ],
629
  outputs=[video_out],
630
  api_name="generate"
 
210
  action_id: int,
211
  out_dir: str,
212
  guidance: Tuple[float,float] = (1.0, 3.0),
 
213
  ) -> str:
214
  os.makedirs(out_dir, exist_ok=True)
215
  sample = make_sample(init_image_path, keyframes_boxes, action_id)
 
232
  use_factor_guidance=g_use_factor_guidance,
233
  guidance=list(guidance),
234
  video_path2=None,
 
235
  )
236
  return f"{video_prefix}.mp4"
237
 
 
313
  vid_path,
314
  action_label: str,
315
  num_frames: int = 3,
 
 
316
  guidance_min: float = 1.0,
317
+ guidance_max: float = 3.0,
318
+ session_id = None
319
  ):
320
 
321
  if session_id is None:
 
355
  action_id=action_id,
356
  out_dir=out_dir,
357
  guidance=(guidance_min, guidance_max),
 
358
  )
359
 
360
  if not os.path.exists(mp4_path):
 
366
  def ctrl_generate_from_image(
367
  img_path,
368
  action_label: str,
369
+ num_frames: int = 3,
 
370
  guidance_min: float = 1.0,
371
  guidance_max: float = 3.0,
372
+ session_id = None
373
  ):
374
 
375
  if session_id is None:
 
417
  action_id=action_id,
418
  out_dir=out_dir,
419
  guidance=(guidance_min, guidance_max),
 
420
  )
421
 
422
  if not os.path.exists(mp4_path):
 
610
  run_video_btn.click(
611
  fn=ctrl_generate_from_video,
612
  inputs=[
613
+ video_in, action_dropdown, bbx_frames
614
  ],
615
  outputs=[video_out, boxes_gif],
616
  api_name="generate"
 
619
  run_image_btn.click(
620
  fn=ctrl_generate_from_image,
621
  inputs=[
622
+ image_in, action_dropdown, bbx_frames
623
  ],
624
  outputs=[video_out],
625
  api_name="generate"
src/eval/generate_samples.py CHANGED
@@ -144,7 +144,7 @@ def load_ctrlv_pipelines(model_dir, use_null_model=False, use_factor_guidance=Fa
144
  return pipeline
145
 
146
 
147
- def generate_video_ctrlv(sample, pipeline, video_path="video_out/genvid", json_path="video_out/gt_frames", bbox_mask_frames=None, action_type=None, use_factor_guidance=False, guidance=[1.0, 3.0], video_path2=None, seed=0):
148
  frame_size = (512, 320)
149
  FPS = 6
150
  CLIP_LENGTH = sample['bbox_images'].shape[0]
@@ -161,7 +161,7 @@ def generate_video_ctrlv(sample, pipeline, video_path="video_out/genvid", json_p
161
  json.dump(gt_frame_paths, file, indent=1)
162
  print("Saved GT frames json file:", json_path)
163
 
164
- generator = torch.Generator(device=device).manual_seed(seed)
165
 
166
  if not use_factor_guidance:
167
  frames = pipeline(init_image,
 
144
  return pipeline
145
 
146
 
147
+ def generate_video_ctrlv(sample, pipeline, video_path="video_out/genvid", json_path="video_out/gt_frames", bbox_mask_frames=None, action_type=None, use_factor_guidance=False, guidance=[1.0, 3.0], video_path2=None):
148
  frame_size = (512, 320)
149
  FPS = 6
150
  CLIP_LENGTH = sample['bbox_images'].shape[0]
 
161
  json.dump(gt_frame_paths, file, indent=1)
162
  print("Saved GT frames json file:", json_path)
163
 
164
+ generator = torch.Generator(device=device).manual_seed(2)
165
 
166
  if not use_factor_guidance:
167
  frames = pipeline(init_image,