fix: progress updater
Browse files
app.py
CHANGED
@@ -67,6 +67,7 @@ def generate(prompt: str, progress=gr.Progress()):
|
|
67 |
),
|
68 |
).images[0]
|
69 |
image = to_tensor(image)
|
|
|
70 |
frames: list[Image.Image] = pipeline(
|
71 |
prompt=prompt,
|
72 |
image=image,
|
@@ -74,17 +75,13 @@ def generate(prompt: str, progress=gr.Progress()):
|
|
74 |
negative_prompt=negative_prompt,
|
75 |
guidance_scale=9.0,
|
76 |
generator=generator,
|
77 |
-
decode_chunk_size=
|
78 |
-
|
79 |
-
start=n_sdxl_steps,
|
80 |
-
total=total_steps,
|
81 |
-
desc="Generating video...",
|
82 |
-
progress=progress,
|
83 |
-
),
|
84 |
).frames[0]
|
|
|
85 |
frames = [to_tensor(frame.convert("RGB")).mul(255).byte().permute(1, 2, 0) for frame in frames]
|
86 |
frames = torch.stack(frames)
|
87 |
-
torchvision.io.write_video("video.mp4", frames, fps=
|
88 |
return "video.mp4"
|
89 |
|
90 |
app = gr.Interface(
|
|
|
67 |
),
|
68 |
).images[0]
|
69 |
image = to_tensor(image)
|
70 |
+
progress((n_sdxl_steps + 1, total_steps), desc="Generating video...")
|
71 |
frames: list[Image.Image] = pipeline(
|
72 |
prompt=prompt,
|
73 |
image=image,
|
|
|
75 |
negative_prompt=negative_prompt,
|
76 |
guidance_scale=9.0,
|
77 |
generator=generator,
|
78 |
+
decode_chunk_size=8,
|
79 |
+
num_frames=64,
|
|
|
|
|
|
|
|
|
|
|
80 |
).frames[0]
|
81 |
+
progress((total_steps - 1, total_steps), desc="Finalizing...")
|
82 |
frames = [to_tensor(frame.convert("RGB")).mul(255).byte().permute(1, 2, 0) for frame in frames]
|
83 |
frames = torch.stack(frames)
|
84 |
+
torchvision.io.write_video("video.mp4", frames, fps=16)
|
85 |
return "video.mp4"
|
86 |
|
87 |
app = gr.Interface(
|
utils.py
CHANGED
@@ -2,6 +2,6 @@ from gradio import Progress
|
|
2 |
|
3 |
def create_progress_updater(start: int, total: int, desc: str, progress: Progress):
|
4 |
def updater(pipe, step, timestep, callback_kwargs):
|
5 |
-
progress((step + start, total), desc=desc)
|
6 |
return callback_kwargs
|
7 |
return updater
|
|
|
2 |
|
3 |
def create_progress_updater(start: int, total: int, desc: str, progress: Progress):
|
4 |
def updater(pipe, step, timestep, callback_kwargs):
|
5 |
+
progress((step + start + 1, total), desc=desc)
|
6 |
return callback_kwargs
|
7 |
return updater
|