ihsanvp commited on
Commit
3df5e24
1 Parent(s): 998bf52

fix: progress updater

Browse files
Files changed (2) hide show
  1. app.py +5 -8
  2. utils.py +1 -1
app.py CHANGED
@@ -67,6 +67,7 @@ def generate(prompt: str, progress=gr.Progress()):
67
  ),
68
  ).images[0]
69
  image = to_tensor(image)
 
70
  frames: list[Image.Image] = pipeline(
71
  prompt=prompt,
72
  image=image,
@@ -74,17 +75,13 @@ def generate(prompt: str, progress=gr.Progress()):
74
  negative_prompt=negative_prompt,
75
  guidance_scale=9.0,
76
  generator=generator,
77
- decode_chunk_size=10,
78
- callback_on_step_end=create_progress_updater(
79
- start=n_sdxl_steps,
80
- total=total_steps,
81
- desc="Generating video...",
82
- progress=progress,
83
- ),
84
  ).frames[0]
 
85
  frames = [to_tensor(frame.convert("RGB")).mul(255).byte().permute(1, 2, 0) for frame in frames]
86
  frames = torch.stack(frames)
87
- torchvision.io.write_video("video.mp4", frames, fps=8)
88
  return "video.mp4"
89
 
90
  app = gr.Interface(
 
67
  ),
68
  ).images[0]
69
  image = to_tensor(image)
70
+ progress((n_sdxl_steps + 1, total_steps), desc="Generating video...")
71
  frames: list[Image.Image] = pipeline(
72
  prompt=prompt,
73
  image=image,
 
75
  negative_prompt=negative_prompt,
76
  guidance_scale=9.0,
77
  generator=generator,
78
+ decode_chunk_size=8,
79
+ num_frames=64,
 
 
 
 
 
80
  ).frames[0]
81
+ progress((total_steps - 1, total_steps), desc="Finalizing...")
82
  frames = [to_tensor(frame.convert("RGB")).mul(255).byte().permute(1, 2, 0) for frame in frames]
83
  frames = torch.stack(frames)
84
+ torchvision.io.write_video("video.mp4", frames, fps=16)
85
  return "video.mp4"
86
 
87
  app = gr.Interface(
utils.py CHANGED
@@ -2,6 +2,6 @@ from gradio import Progress
2
 
3
  def create_progress_updater(start: int, total: int, desc: str, progress: Progress):
4
  def updater(pipe, step, timestep, callback_kwargs):
5
- progress((step + start, total), desc=desc)
6
  return callback_kwargs
7
  return updater
 
2
 
3
  def create_progress_updater(start: int, total: int, desc: str, progress: Progress):
4
  def updater(pipe, step, timestep, callback_kwargs):
5
+ progress((step + start + 1, total), desc=desc)
6
  return callback_kwargs
7
  return updater