import gradio as gr
import torch
import torchvision
from diffusers import I2VGenXLPipeline, DiffusionPipeline
from torchvision.transforms.functional import to_tensor
from PIL import Image
from utils import create_progress_updater
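# NOTE: create_progress_updater comes from the local utils module (not shown
# here). Judging from its call sites below, it is assumed to return a
# diffusers callback_on_step_end-style callable, roughly:
#
#     def create_progress_updater(start, total, desc, progress):
#         def update(pipe, step, timestep, callback_kwargs):
#             progress((start + step, total), desc=desc)
#             return callback_kwargs
#         return update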
if gr.NO_RELOAD:
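    # Heavy one-time setup lives under gr.NO_RELOAD so that `gradio app.py`
    # hot-reloading does not reload and recompile the models on every edit.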
    n_sdxl_steps = 50
    n_i2v_steps = 50
    high_noise_frac = 0.8
    negative_prompt = (
        "Distorted, discontinuous, Ugly, blurry, low resolution, motionless, "
        "static, disfigured, disconnected limbs, Ugly faces, incomplete arms"
    )
    generator = torch.manual_seed(8888)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    total_steps = n_sdxl_steps + n_i2v_steps
    print("Device:", device)
    base = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16,
        variant="fp16",
        use_safetensors=True,
    )
    refiner = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-refiner-1.0",
        text_encoder_2=base.text_encoder_2,
        vae=base.vae,
        torch_dtype=torch.float16,
        use_safetensors=True,
        variant="fp16",
    )
    pipeline = I2VGenXLPipeline.from_pretrained(
        "ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16"
    )
    # Honor the CUDA/CPU fallback computed above instead of hardcoding "cuda".
    base.to(device)
    refiner.to(device)
    pipeline.to(device)
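
    # torch.compile trades a slow first invocation for faster steady-state
    # UNet inference; "reduce-overhead" uses CUDA graphs to cut per-call cost.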
    base.unet = torch.compile(base.unet, mode="reduce-overhead", fullgraph=True)
    refiner.unet = torch.compile(refiner.unet, mode="reduce-overhead", fullgraph=True)
    pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)

def generate(prompt: str, progress=gr.Progress()):
    progress((0, total_steps), desc="Starting...")
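
    # Stage 1: the SDXL base model runs the first `high_noise_frac` (80%) of
    # the denoising schedule and emits latents (output_type="latent") so the
    # refiner can pick up where it left off without a decode/encode round trip.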
    image = base(
        prompt=prompt,
        num_inference_steps=n_sdxl_steps,
        denoising_end=high_noise_frac,
        output_type="latent",
        callback_on_step_end=create_progress_updater(
            start=0,
            total=total_steps,
            desc="Generating first frame...",
            progress=progress,
        ),
    ).images[0]
    progress((int(n_sdxl_steps * high_noise_frac), total_steps), desc="Refining first frame...")
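
    # Stage 2: the refiner completes the remaining denoising steps on the
    # base model's latents and decodes the finished first frame.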
    image = refiner(
        prompt=prompt,
        num_inference_steps=n_sdxl_steps,
        denoising_start=high_noise_frac,
        image=image,
        callback_on_step_end=create_progress_updater(
            start=int(n_sdxl_steps * high_noise_frac),
            total=total_steps,
            desc="Refining first frame...",
            progress=progress,
        ),
    ).images[0]
    image = to_tensor(image)  # PIL -> float tensor in [0, 1] for I2VGen-XL
    progress((n_sdxl_steps + 1, total_steps), desc="Generating video...")
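
    # Stage 3: I2VGen-XL animates the refined frame into a 64-frame clip,
    # decoding 4 frames at a time to bound peak VRAM usage.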
    frames: list[Image.Image] = pipeline(
        prompt=prompt,
        image=image,
        num_inference_steps=n_i2v_steps,
        negative_prompt=negative_prompt,
        guidance_scale=9.0,
        generator=generator,
        decode_chunk_size=4,
        num_frames=64,
    ).frames[0]
    progress((total_steps - 1, total_steps), desc="Finalizing...")
    # torchvision.io.write_video expects uint8 frames shaped (T, H, W, C).
    frames = [to_tensor(frame.convert("RGB")).mul(255).byte().permute(1, 2, 0) for frame in frames]
    torchvision.io.write_video("video.mp4", torch.stack(frames), fps=16)
    return "video.mp4"
app = gr.Interface(
    fn=generate,
    inputs=["text"],
    outputs=gr.Video(),
)
if __name__ == "__main__":
    app.launch()