File size: 3,418 Bytes
2c1f0c3 95cc45b 35920a6 998bf52 2c1f0c3 95cc45b 998bf52 95cc45b 0593b2c 998bf52 0593b2c c5c043a 95cc45b c36adbc 95cc45b 448a859 3630e4c c36adbc 5ff8f5d 448a859 3630e4c c36adbc 5ff8f5d 3630e4c 5ff8f5d 95cc45b 0593b2c b8a7508 95cc45b 998bf52 448a859 310e3d0 c36adbc 95cc45b b8a7508 95cc45b 35920a6 344a67a ce46880 35920a6 3df5e24 95cc45b 3630e4c 254aacb 2c1f0c3 35920a6 2c1f0c3 95cc45b 2c1f0c3 35920a6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
import gradio as gr
import torch
import torchvision
from diffusers import I2VGenXLPipeline, DiffusionPipeline
from torchvision.transforms.functional import to_tensor
from PIL import Image
from utils import create_progress_updater
if gr.NO_RELOAD:
n_sdxl_steps = 50
n_i2v_steps = 50
high_noise_frac = 0.8
negative_prompt = "Distorted, discontinuous, Ugly, blurry, low resolution, motionless, static, disfigured, disconnected limbs, Ugly faces, incomplete arms"
generator = torch.manual_seed(8888)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
total_steps = n_sdxl_steps + n_i2v_steps
print("Device:", device)
base = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
variant="fp16",
use_safetensors=True,
)
# refiner = DiffusionPipeline.from_pretrained(
# "stabilityai/stable-diffusion-xl-refiner-1.0",
# text_encoder_2=base.text_encoder_2,
# vae=base.vae,
# torch_dtype=torch.float16,
# use_safetensors=True,
# variant="fp16",
# )
pipeline = I2VGenXLPipeline.from_pretrained("ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16")
# base.to("cuda")
# refiner.to("cuda")
# pipeline.to("cuda")
# base.unet = torch.compile(base.unet, mode="reduce-overhead", fullgraph=True)
# refiner.unet = torch.compile(refiner.unet, mode="reduce-overhead", fullgraph=True)
# pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
base.enable_model_cpu_offload()
pipeline.enable_model_cpu_offload()
pipeline.unet.enable_forward_chunking()
def generate(prompt: str, progress=gr.Progress()):
progress((0, 100), desc="Generating first frame...")
image = base(
prompt=prompt,
num_inference_steps=n_sdxl_steps,
callback_on_step_end=create_progress_updater(
start=0,
total=total_steps,
desc="Generating first frame...",
progress=progress,
),
).images[0]
# progress((n_sdxl_steps * high_noise_frac, total_steps), desc="Refining first frame...")
# image = refiner(
# prompt=prompt,
# num_inference_steps=n_sdxl_steps,
# denoising_start=high_noise_frac,
# image=image,
# callback_on_step_end=create_progress_updater(
# start=n_sdxl_steps * high_noise_frac,
# total=total_steps,
# desc="Refining first frame...",
# progress=progress,
# ),
# ).images[0]
image = to_tensor(image)
progress((n_sdxl_steps, total_steps), desc="Generating video...")
frames: list[Image.Image] = pipeline(
prompt=prompt,
image=image,
num_inference_steps=50,
negative_prompt=negative_prompt,
guidance_scale=9.0,
generator=generator,
decode_chunk_size=2,
num_frames=32,
).frames[0]
progress((total_steps - 1, total_steps), desc="Finalizing...")
frames = [to_tensor(frame.convert("RGB")).mul(255).byte().permute(1, 2, 0) for frame in frames]
frames = torch.stack(frames)
torchvision.io.write_video("video.mp4", frames, fps=8)
return "video.mp4"
app = gr.Interface(
fn=generate,
inputs=["text"],
outputs=gr.Video()
)
if __name__ == "__main__":
app.launch()
|