fix: out of memory error
Browse files
app.py
CHANGED
@@ -32,13 +32,14 @@ if gr.NO_RELOAD:
|
|
32 |
# )
|
33 |
pipeline = I2VGenXLPipeline.from_pretrained("ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16")
|
34 |
|
35 |
-
base.to("cuda")
|
36 |
# refiner.to("cuda")
|
37 |
# pipeline.to("cuda")
|
38 |
|
39 |
-
base.unet = torch.compile(base.unet, mode="reduce-overhead", fullgraph=True)
|
40 |
# refiner.unet = torch.compile(refiner.unet, mode="reduce-overhead", fullgraph=True)
|
41 |
# pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
|
|
|
42 |
pipeline.enable_model_cpu_offload()
|
43 |
pipeline.unet.enable_forward_chunking()
|
44 |
|
@@ -77,12 +78,12 @@ def generate(prompt: str, progress=gr.Progress()):
|
|
77 |
guidance_scale=9.0,
|
78 |
generator=generator,
|
79 |
decode_chunk_size=2,
|
80 |
-
num_frames=
|
81 |
).frames[0]
|
82 |
progress((total_steps - 1, total_steps), desc="Finalizing...")
|
83 |
frames = [to_tensor(frame.convert("RGB")).mul(255).byte().permute(1, 2, 0) for frame in frames]
|
84 |
frames = torch.stack(frames)
|
85 |
-
torchvision.io.write_video("video.mp4", frames, fps=
|
86 |
return "video.mp4"
|
87 |
|
88 |
app = gr.Interface(
|
|
|
32 |
# )
|
33 |
pipeline = I2VGenXLPipeline.from_pretrained("ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16")
|
34 |
|
35 |
+
# base.to("cuda")
|
36 |
# refiner.to("cuda")
|
37 |
# pipeline.to("cuda")
|
38 |
|
39 |
+
# base.unet = torch.compile(base.unet, mode="reduce-overhead", fullgraph=True)
|
40 |
# refiner.unet = torch.compile(refiner.unet, mode="reduce-overhead", fullgraph=True)
|
41 |
# pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
|
42 |
+
base.enable_model_cpu_offload()
|
43 |
pipeline.enable_model_cpu_offload()
|
44 |
pipeline.unet.enable_forward_chunking()
|
45 |
|
|
|
78 |
guidance_scale=9.0,
|
79 |
generator=generator,
|
80 |
decode_chunk_size=2,
|
81 |
+
num_frames=16,
|
82 |
).frames[0]
|
83 |
progress((total_steps - 1, total_steps), desc="Finalizing...")
|
84 |
frames = [to_tensor(frame.convert("RGB")).mul(255).byte().permute(1, 2, 0) for frame in frames]
|
85 |
frames = torch.stack(frames)
|
86 |
+
torchvision.io.write_video("video.mp4", frames, fps=8)
|
87 |
return "video.mp4"
|
88 |
|
89 |
app = gr.Interface(
|