|
import spaces |
|
import gradio as gr |
|
import torch |
|
import torchvision as tv |
|
import random, os |
|
from diffusers import StableVideoDiffusionPipeline |
|
from PIL import Image |
|
from glob import glob |
|
from typing import Optional |
|
|
|
from tdd_svd_scheduler import TDDSVDStochasticIterativeScheduler |
|
from utils import load_lora_weights, save_video |
|
|
|
|
|
# Switch between the authors' local checkpoint paths (True) and the public
# Hugging Face Hub identifiers used by the hosted Space (False).
LOCAL = False

if not LOCAL:
    # Hub repo id of the base SVD-xt 1.1 model.
    svd_path = 'stabilityai/stable-video-diffusion-img2vid-xt-1-1'
    # Hub repo plus file name of the TDD LoRA weights.
    lora_repo_path = 'RED-AIGC/TDD'
    lora_weight_name = 'svd-xt-1-1_tdd_lora_weights.safetensors'
else:
    # Absolute paths on the original development machine.
    svd_path = '/share2/duanyuxuan/diff_playground/diffusers_models/stable-video-diffusion-img2vid-xt-1-1'
    lora_file_path = '/share2/duanyuxuan/diff_playground/SVD-TDD/svd-xt-1-1_tdd_lora_weights.safetensors'
|
|
|
# Build the TDD scheduler + SVD pipeline once at import time.
# NOTE(review): `pipeline` is only defined when CUDA is available; without it,
# `sample()` would fail with a NameError at request time.
if torch.cuda.is_available():
    noise_scheduler = TDDSVDStochasticIterativeScheduler(
        num_train_timesteps=250,
        sigma_min=0.002,
        sigma_max=700.0,
        sigma_data=1.0,
        s_noise=1.0,
        rho=7,
        clip_denoised=False,
    )

    pipeline = StableVideoDiffusionPipeline.from_pretrained(
        svd_path,
        scheduler=noise_scheduler,
        torch_dtype=torch.float16,
        variant="fp16",
    ).to('cuda')

    # Attach the TDD LoRA to the UNet: repo + weight file name in hosted mode,
    # a single local file path in LOCAL mode (see utils.load_lora_weights).
    if not LOCAL:
        load_lora_weights(pipeline.unet, lora_repo_path, weight_name=lora_weight_name)
    else:
        load_lora_weights(pipeline.unet, lora_file_path)
|
|
|
# Largest signed 64-bit integer; used as the upper bound of the seed slider.
max_64_bit_int = (1 << 63) - 1
|
|
|
# Largest integer a browser can represent exactly (JavaScript
# Number.MAX_SAFE_INTEGER). The seed actually used is echoed back to the UI, so
# a larger randomized seed would be rounded by the front end and the displayed
# value could not reproduce the video.
_MAX_UI_SAFE_SEED = 2**53 - 1


def _next_video_path(output_folder: str) -> str:
    """Return an unused ``NNNNNN.mp4`` path inside *output_folder*.

    Numbering starts at the count of ``*.mp4`` files already present (the
    original scheme) and advances past any name that is already taken. A gap
    left by a deleted file therefore can no longer cause an existing video to
    be overwritten.
    """
    index = len(glob(os.path.join(output_folder, "*.mp4")))
    while True:
        candidate = os.path.join(output_folder, f"{index:06d}.mp4")
        if not os.path.exists(candidate):
            return candidate
        index += 1


@spaces.GPU
def sample(
    image: Image.Image,
    seed: Optional[int] = 1,
    randomize_seed: bool = False,
    num_inference_steps: int = 4,
    eta: float = 0.3,
    min_guidance_scale: float = 1.0,
    max_guidance_scale: float = 1.0,
    fps: int = 7,
    width: int = 512,
    height: int = 512,
    num_frames: int = 25,
    motion_bucket_id: int = 127,
    output_folder: str = "outputs_gradio",
):
    """Generate a video from *image* with the TDD-distilled SVD pipeline.

    Args:
        image: Conditioning image (the UI supplies a PIL image).
        seed: Seed passed to ``torch.manual_seed``. ``None`` is treated like
            ``randomize_seed=True``.
        randomize_seed: Draw a fresh seed instead of using *seed*.
        num_inference_steps: Number of denoising steps.
        eta: Forwarded to ``pipeline.scheduler.set_eta`` (the UI labels it
            "the value of gamma in gamma-sampling").
        min_guidance_scale, max_guidance_scale: Forwarded to the pipeline.
        fps: Forwarded to the pipeline and used as the saved video's frame rate.
        width, height, num_frames, motion_bucket_id: Forwarded to the pipeline.
        output_folder: Directory the ``.mp4`` is written to; created if missing.

    Returns:
        ``(video_path, seed)``: the saved video's path and the seed actually
        used, so the UI can display it for reproduction.
    """
    # The scheduler belongs to the module-level pipeline, so eta is (re)applied
    # on every call.
    pipeline.scheduler.set_eta(eta)

    if randomize_seed or seed is None:
        seed = random.randint(0, _MAX_UI_SAFE_SEED)
    # Slider values may arrive as floats; normalise to a plain int.
    seed = int(seed)
    # Seeds torch's global RNG and returns the default generator.
    generator = torch.manual_seed(seed)

    os.makedirs(output_folder, exist_ok=True)
    video_path = _next_video_path(output_folder)

    with torch.autocast("cuda"):
        frames = pipeline(
            image,
            height=height,
            width=width,
            num_inference_steps=int(num_inference_steps),
            min_guidance_scale=min_guidance_scale,
            max_guidance_scale=max_guidance_scale,
            num_frames=num_frames,
            fps=fps,
            motion_bucket_id=motion_bucket_id,
            decode_chunk_size=8,
            noise_aug_strength=0.02,
            generator=generator,
        ).frames[0]
        save_video(frames, video_path, fps=fps, quality=5.0)
    # NOTE(review): re-seeds the global torch RNG with the request seed after
    # generation; the reason is not visible in this file, kept as-is.
    torch.manual_seed(seed)

    return video_path, seed
|
|
|
|
|
def preprocess_image(image, height = 512, width = 512):
    """Convert an uploaded image to RGB, center-crop it square, and resize it.

    Non-square inputs are center-cropped to a square whose side is the shorter
    edge (via torchvision's functional transforms). The result is then resized
    to ``(width, height)``.
    """
    rgb = image.convert('RGB')
    w, h = rgb.size
    if w != h:
        F = tv.transforms.functional
        side = min(h, w)
        rgb = F.to_pil_image(F.center_crop(F.pil_to_tensor(rgb), side))
    return rgb.resize((width, height))
|
|
|
# Custom CSS passed to gr.Blocks below: centers the <h1> title and caps the
# app container width at 70.5rem.
css = """

    h1 {

        text-align: center;

        display:block;

    }

    .gradio-container {

        max-width: 70.5rem !important;

    }

"""
|
|
|
with gr.Blocks(css=css) as demo:
    # Header: title, short description of TDD, and project links.
    gr.Markdown(
        """
    # Stable Video Diffusion distilled by ✨Target-Driven Distillation✨

    Target-Driven Distillation (TDD) is a state-of-the-art consistency distillation method that greatly accelerates the inference of diffusion models. Using its carefully designed strategies of *target timestep selection* and *decoupled guidance*, models distilled by TDD can generate highly detailed images in only a few steps.

    TDD can also distill video generation models. This space presents a TDD-distilled [SVD-xt 1.1](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt-1-1).

    [**Project Page**](https://redaigc.github.io/TDD/) **|** [**Paper**](https://arxiv.org/abs/2409.01347) **|** [**Code**](https://github.com/RedAIGC/Target-Driven-Distillation) **|** [**Model**](https://huggingface.co/RED-AIGC/TDD) **|** [🤗 **TDD-SDXL Demo**](https://huggingface.co/spaces/RED-AIGC/TDD) **|** [🤗 **TDD-SVD Demo**](https://huggingface.co/spaces/RED-AIGC/SVD-TDD)

    The code for this space is built on [AnimateLCM-SVD](https://huggingface.co/spaces/wangfuyun/AnimateLCM-SVD), and we acknowledge their contribution.
        """
    )
    # One row: input image + Generate button in a column, output video beside it.
    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Upload your image", type="pil")
            generate_btn = gr.Button("Generate")
        video = gr.Video()
    with gr.Accordion("Options", open=True):
        seed = gr.Slider(
            label="Seed",
            value=1,
            randomize=False,
            minimum=0,
            maximum=max_64_bit_int,
            step=1,
        )
        randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
        min_guidance_scale = gr.Slider(
            label="Min guidance scale",
            info="min strength of classifier-free guidance",
            value=1.0,
            minimum=1.0,
            maximum=1.5,
        )
        max_guidance_scale = gr.Slider(
            label="Max guidance scale",
            info="max strength of classifier-free guidance, it should not be less than Min guidance scale",
            value=1.0,
            minimum=1.0,
            maximum=3.0,
        )
        num_inference_steps = gr.Slider(
            label="Num inference steps",
            info="steps for inference",
            value=4,
            minimum=4,
            maximum=8,
            step=1,
        )
        eta = gr.Slider(
            label="Eta",
            info="the value of gamma in gamma-sampling",
            value=0.3,
            minimum=0.0,
            maximum=1.0,
            step=0.1,
        )

    # Crop/resize uploads immediately (unqueued) so the preview shows what the
    # model will actually receive.
    image.upload(fn=preprocess_image, inputs=image, outputs=image, queue=False)
    # The remaining sample() parameters (fps, size, frames, ...) use their defaults.
    generate_btn.click(
        fn=sample,
        inputs=[
            image,
            seed,
            randomize_seed,
            num_inference_steps,
            eta,
            min_guidance_scale,
            max_guidance_scale,
        ],
        outputs=[video, seed],
        api_name="video",
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    if not LOCAL:
        # Hosted mode: don't open or advertise the API (api_open / show_api off).
        demo.queue(api_open=False).launch(show_api=False)
    else:
        # Local mode: bind on all interfaces and create a share link.
        demo.queue().launch(share=True, server_name='0.0.0.0')