Fabrice-TIERCELIN committed
Commit
897013f
1 Parent(s): 5d17ee3

Randomize the seed

...and fix some typos

Files changed (1)
  1. app.py +17 -9
app.py CHANGED
@@ -5,6 +5,7 @@ import tempfile
import gradio as gr
import numpy as np
import torch
+ import random
from glob import glob
from torchvision.transforms import CenterCrop, Compose, Resize

@@ -175,6 +176,7 @@ def fn_vis_camera(camera_args):
gr.update(visible=vis_prompt), \
gr.update(visible=vis_num_samples), \
gr.update(visible=vis_seed), \
+ gr.update(visible=vis_seed), \
gr.update(visible=vis_start), \
gr.update(visible=vis_gen_video, value=None), \
gr.update(visible=vis_repeat_highlight)
@@ -388,6 +390,7 @@ def visualized_camera_poses(step2_camera_motion):
gr.update(visible=vis_prompt), \
gr.update(visible=vis_num_samples), \
gr.update(visible=vis_seed), \
+ gr.update(visible=vis_seed), \
gr.update(visible=vis_start), \
gr.update(visible=vis_gen_video), \
gr.update(visible=vis_repeat_highlight)
@@ -503,12 +506,15 @@ def process_input_image(input_image, resize_mode):
gr.update(visible=vis_prompt), \
gr.update(visible=vis_num_samples), \
gr.update(visible=vis_seed), \
+ gr.update(visible=vis_seed), \
gr.update(visible=vis_start), \
gr.update(visible=vis_gen_video), \
gr.update(visible=vis_repeat_highlight)

- def model_run(input_image, fps_id, seed, n_samples, camera_args):
+ def model_run(input_image, fps_id, randomize_seed, seed, n_samples, camera_args):
global model, device, camera_dict, num_frames, num_steps, width, height
+ if randomize_seed:
+     seed = random.randint(0, 2**63 - 1)
RT = process_camera(camera_dict, camera_args, num_frames=num_frames, width=width, height=height).reshape(-1,12)

video_path = motionctrl_sample(
@@ -657,14 +663,15 @@ def main(args):
generation_dec = gr.Markdown(f"\n 1. Set `FPS`.; \
\n 2. Set `n_samples`; \
\n 3. Set `seed`; \
- \n 4. Click `Start generation !` to generate videos; ", visible=False)
+ \n 4. Click `Start generation!` to generate videos; ", visible=False)
# prompt = gr.Textbox(value="a dog sitting on grass", label="Prompt", interactive=True, visible=False)
prompt = gr.Slider(minimum=5, maximum=30, step=1, label="FPS", value=10, visible=False)
n_samples = gr.Number(value=1, precision=0, interactive=True, label="n_samples", visible=False)
- seed = gr.Number(value=1234, precision=0, interactive=True, label="Seed", visible=False)
- start = gr.Button(value="Start generation !", visible=False)
+ randomize_seed = gr.Checkbox(label = "\U0001F3B2 Randomize seed", value = True, info = "If checked, result is always different", interactive=True, visible=False)
+ seed = gr.Slider(label="Seed", minimum=0, maximum=2**63 - 1, step=1, randomize=True, interactive=True, visible=False)
+ start = gr.Button(value="Start generation!", variant = "primary", visible=False)
with gr.Column():
- gen_video = gr.Video(value=None, label="Generate Video", visible=False)
+ gen_video = gr.Video(value=None, label="Generated Video", visible=False)
repeat_highlight=gr.HighlightedText(value=[("",""), (f"1. If the motion control is not obvious, try to increase the `Motion Speed`. \
\n 2. If the generated videos are distored severely, try to descrease the `Motion Speed` \
or increase `FPS`.", "Normal")],
@@ -699,7 +706,7 @@ def main(args):
generation_dec,
prompt,
n_samples,
- seed, start, gen_video, repeat_highlight])
+ randomize_seed, seed, start, gen_video, repeat_highlight])

keep_spatial_raition_botton.click(
fn=process_input_image,
@@ -730,7 +737,7 @@ def main(args):
generation_dec,
prompt,
n_samples,
- seed, start, gen_video, repeat_highlight])
+ randomize_seed, seed, start, gen_video, repeat_highlight])


camera_info.click(
@@ -751,7 +758,7 @@ def main(args):
camera_args,
camera_reset, camera_vis,
vis_camera,
- step3_prompt_generate, generation_dec, prompt, n_samples, seed, start, gen_video, repeat_highlight],
+ step3_prompt_generate, generation_dec, prompt, n_samples, randomize_seed, seed, start, gen_video, repeat_highlight],
)


@@ -793,6 +800,7 @@ def main(args):
generation_dec,
prompt,
n_samples,
+ randomize_seed,
seed,
start,
gen_video,
@@ -809,7 +817,7 @@ def main(args):


start.click(fn=model_run,
- inputs=[process_image, prompt, seed, n_samples, camera_args],
+ inputs=[process_image, prompt, randomize_seed, seed, n_samples, camera_args],
outputs=gen_video)

# set example
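
For reference, a minimal self-contained sketch of the pattern this commit wires into app.py: a "Randomize seed" checkbox and a "Seed" slider feed the click handler, which redraws the seed with random.randint when the box is checked. Component names mirror the diff; generate() and its text output are illustrative stand-ins for model_run and the video output, not the app's actual code.

import random
import gradio as gr

# Stand-in for model_run: when randomization is requested, ignore the slider
# value and draw a fresh seed, exactly as the guard added in this commit does.
def generate(randomize_seed, seed):
    if randomize_seed:
        seed = random.randint(0, 2**63 - 1)
    # ...seed the sampler here (e.g. torch.manual_seed(seed)) and run the model...
    return f"effective seed: {seed}"

with gr.Blocks() as demo:
    # Same components the commit adds, minus the visible=False staging logic.
    randomize_seed = gr.Checkbox(label="\U0001F3B2 Randomize seed", value=True,
                                 info="If checked, result is always different")
    seed = gr.Slider(label="Seed", minimum=0, maximum=2**63 - 1, step=1, randomize=True)
    start = gr.Button("Start generation!", variant="primary")
    result = gr.Textbox(label="Result")
    start.click(fn=generate, inputs=[randomize_seed, seed], outputs=result)

demo.launch()

One limitation of the pattern as committed: with the checkbox ticked, the randomly drawn seed is never reported back to the UI, so a particular result cannot be reproduced exactly; returning the effective seed as an additional output of model_run would be a common refinement.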