# FreeNoise / app.py
# Anonymous
# add interface
# 93f8cdd
import gradio as gr
import os
import sys
import argparse
import random
from omegaconf import OmegaConf
import torch
import torchvision
from pytorch_lightning import seed_everything
from huggingface_hub import hf_hub_download
sys.path.insert(0, "scripts/evaluation")
from funcs import (
batch_ddim_sampling_freenoise,
load_model_checkpoint,
)
from utils.utils import instantiate_from_config
# FreeNoise window settings; forwarded to the sampler through ``args`` and
# used below when the initial noise is rescheduled.
_WINDOW_SIZE = 16
_WINDOW_STRIDE = 4

# Per-resolution settings, keyed by the value of the "Output Size" dropdown.
#   width / height -> output resolution in pixels
#   config         -> OmegaConf YAML whose "model" node describes the model
#   ckpt_dir       -> local directory expected to contain ``model.ckpt``
#   repo_id        -> Hugging Face Hub repo the checkpoint is fetched from
#   fps            -> value sent to the model as the "fps" conditioning
#   max_frames     -> cap applied to the requested frame count (None = no cap)
_MODEL_SPECS = {
    "320x512": {
        "width": 512,
        "height": 320,
        "config": "configs/inference_t2v_tconv512_v2.0_freenoise.yaml",
        "ckpt_dir": "checkpoints/base_512_v2",
        "repo_id": "VideoCrafter/VideoCrafter2",
        "fps": 16,
        "max_frames": None,
    },
    "576x1024": {
        "width": 1024,
        "height": 576,
        "config": "configs/inference_t2v_1024_v1.0_freenoise.yaml",
        "ckpt_dir": "checkpoints/base_1024_v1",
        "repo_id": "VideoCrafter/Text2Video-1024",
        "fps": 28,
        "max_frames": 36,
    },
    "256x256": {
        "width": 256,
        "height": 256,
        "config": "configs/inference_t2v_tconv256_v1.0_freenoise.yaml",
        "ckpt_dir": "checkpoints/base_256_v1",
        "repo_id": "VideoCrafter/Text2Video-256",
        "fps": 8,
        "max_frames": None,
    },
}


def _load_model(spec):
    """Instantiate the model described by ``spec``, load its weights, return it in eval mode."""
    config = OmegaConf.load(spec["config"])
    model_config = config.pop("model", OmegaConf.create())
    model = instantiate_from_config(model_config).cuda()

    ckpt_dir = spec["ckpt_dir"]
    ckpt_path = f"{ckpt_dir}/model.ckpt"
    if not os.path.exists(ckpt_path):
        os.makedirs(ckpt_dir, exist_ok=True)
        hf_hub_download(repo_id=spec["repo_id"], filename="model.ckpt", local_dir=ckpt_dir)

    try:
        model = load_model_checkpoint(model, ckpt_path)
    except Exception:
        # Presumably a partial or corrupt checkpoint file: fetch it again and
        # retry once. ``Exception`` (rather than a bare ``except``) keeps
        # KeyboardInterrupt / SystemExit from triggering a forced re-download.
        hf_hub_download(
            repo_id=spec["repo_id"],
            filename="model.ckpt",
            local_dir=ckpt_dir,
            force_download=True,
        )
        model = load_model_checkpoint(model, ckpt_path)

    model.eval()
    return model


def _write_video(samples, video_path, n_samples, save_fps):
    """Convert sampled frames in [-1, 1] to uint8 grids and write them to ``video_path``."""
    video = samples.detach().cpu()
    video = torch.clamp(video.float(), -1.0, 1.0)
    video = video.permute(2, 0, 1, 3, 4)  # t,n,c,h,w
    frame_grids = [
        torchvision.utils.make_grid(framesheet, nrow=int(n_samples))
        for framesheet in video
    ]  # [3, 1*h, n*w]
    grid = torch.stack(frame_grids, dim=0)  # stack in temporal dim [t, 3, n*h, w]
    grid = (grid + 1.0) / 2.0
    grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
    torchvision.io.write_video(
        video_path,
        grid,
        fps=save_fps,
        video_codec="h264",
        options={"crf": "10"},
    )


def infer(prompt, output_size, seed, num_frames, ddim_steps, unconditional_guidance_scale, save_fps):
    """Generate a video for ``prompt`` with FreeNoise sampling and return its file path.

    The model for ``output_size`` is instantiated and its checkpoint loaded on
    every call; the checkpoint is downloaded from the Hugging Face Hub when it
    is not present locally.

    Args:
        prompt: Text prompt, passed to ``model.get_learned_conditioning``.
        output_size: "320x512", "576x1024" or "256x256".
        seed: Seed handed to ``seed_everything``; ``None`` draws a random
            16-bit seed from ``os.urandom``.
        num_frames: Requested number of frames. Capped at 36 for "576x1024";
            a negative value falls back to ``model.temporal_length``.
        ddim_steps: Number of DDIM steps forwarded to the sampler.
        unconditional_guidance_scale: Guidance scale forwarded to the sampler.
        save_fps: Frame rate of the written video file.

    Returns:
        The path of the written video, always ``"output.mp4"`` (each call
        overwrites the previous result).

    Raises:
        ValueError: If ``output_size`` is not one of the supported sizes.
    """
    spec = _MODEL_SPECS.get(output_size)
    if spec is None:
        raise ValueError(
            f"Unsupported output size {output_size!r}; "
            f"expected one of {sorted(_MODEL_SPECS)}"
        )

    # UI sliders / API clients may deliver these as floats; range() and tensor
    # shapes below need real integers.
    num_frames = int(num_frames)
    ddim_steps = int(ddim_steps)
    if spec["max_frames"] is not None:
        num_frames = min(num_frames, spec["max_frames"])

    model = _load_model(spec)
    print('Model Loaded.')

    seed = int.from_bytes(os.urandom(2), "big") if seed is None else int(seed)
    print(f"Using seed: {seed}")
    seed_everything(seed)

    args = argparse.Namespace(
        mode="base",
        savefps=save_fps,
        n_samples=1,
        ddim_steps=ddim_steps,
        ddim_eta=0.0,
        bs=1,
        height=spec["height"],
        width=spec["width"],
        frames=num_frames,
        fps=spec["fps"],
        unconditional_guidance_scale=unconditional_guidance_scale,
        unconditional_guidance_scale_temporal=None,
        cond_input=None,
        window_size=_WINDOW_SIZE,
        window_stride=_WINDOW_STRIDE,
    )

    ## latent noise shape
    h, w = args.height // 8, args.width // 8
    frames = model.temporal_length if args.frames < 0 else args.frames
    channels = model.channels

    x_T_total = torch.randn(
        [args.n_samples, 1, channels, frames, h, w], device=model.device
    ).repeat(1, args.bs, 1, 1, 1, 1)
    # Every stride-sized chunk of frames after the first window reuses the
    # noise of the chunk ``window_size`` frames earlier, in shuffled order.
    # NOTE(review): the slice assignment needs whole stride-sized chunks (the
    # UI slider moves in steps of 4); other frame counts would not match.
    for frame_index in range(args.window_size, args.frames, args.window_stride):
        source_start = frame_index - args.window_size
        list_index = list(range(source_start, source_start + args.window_stride))
        random.shuffle(list_index)
        x_T_total[
            :, :, :, frame_index : frame_index + args.window_stride
        ] = x_T_total[:, :, :, list_index]

    batch_size = 1
    noise_shape = [batch_size, channels, frames, h, w]
    fps_cond = torch.tensor([args.fps] * batch_size).to(model.device).long()
    text_emb = model.get_learned_conditioning([prompt])
    cond = {"c_crossattn": [text_emb], "fps": fps_cond}

    ## inference
    batch_samples = batch_ddim_sampling_freenoise(
        model,
        cond,
        noise_shape,
        args.n_samples,
        args.ddim_steps,
        args.ddim_eta,
        args.unconditional_guidance_scale,
        args=args,
        x_T_total=x_T_total,
    )

    video_path = "output.mp4"
    _write_video(batch_samples[0], video_path, args.n_samples, args.savefps)
    print(video_path)
    return video_path
# Prompts offered as one-click examples; each row holds a single value, the prompt.
examples = [
    [prompt_text]
    for prompt_text in (
        "A chihuahua in astronaut suit floating in space, cinematic lighting, glow effect",
        "A corgi is swimming quickly",
        "A bigfoot walking in the snowstorm",
        "Campfire at night in a snowy forest with starry sky in the background",
        "A panda is surfing in the universe",
    )
]
# Custom stylesheet handed to ``gr.Blocks(css=css)`` below.
# ``#col-container`` matches the ``elem_id`` of the main column and centers it
# with a 640px maximum width. The remaining rules style links, a spinner
# animation, a share button and a footer.
# NOTE(review): no component in this file sets the share-button ids or the
# ``footer`` class explicitly - confirm whether those rules are still needed.
css = """
#col-container {max-width: 640px; margin-left: auto; margin-right: auto;}
a {text-decoration-line: underline; font-weight: 600;}
.animate-spin {
animation: spin 1s linear infinite;
}
@keyframes spin {
from {
transform: rotate(0deg);
}
to {
transform: rotate(360deg);
}
}
#share-btn-container {
display: flex;
padding-left: 0.5rem !important;
padding-right: 0.5rem !important;
background-color: #000000;
justify-content: center;
align-items: center;
border-radius: 9999px !important;
max-width: 15rem;
height: 36px;
}
div#share-btn-container > div {
flex-direction: row;
background: black;
align-items: center;
}
#share-btn-container:hover {
background-color: #060606;
}
#share-btn {
all: initial;
color: #ffffff;
font-weight: 600;
cursor:pointer;
font-family: 'IBM Plex Sans', sans-serif;
margin-left: 0.5rem !important;
padding-top: 0.5rem !important;
padding-bottom: 0.5rem !important;
right:0;
}
#share-btn * {
all: unset;
}
#share-btn-container div:nth-child(-n+2){
width: auto !important;
min-height: 0px !important;
}
#share-btn-container .wrap {
display: none !important;
}
#share-btn-container.hidden {
display: none!important;
}
img[src*='#center'] {
display: inline-block;
margin: unset;
}
.footer {
margin-bottom: 45px;
margin-top: 10px;
text-align: center;
border-bottom: 1px solid #e5e5e5;
}
.footer>p {
font-size: .8rem;
display: inline-block;
padding: 0 10px;
transform: translateY(10px);
background: white;
}
.dark .footer {
border-color: #303030;
}
.dark .footer>p {
background: #0b0f19;
}
"""
# Page header: title, paper name and links to the paper, project page and code.
_HEADER_HTML = """
<h1 style="text-align: center;">FreeNoise (Longer Text-to-Video)</h1>
<p style="text-align: center;">
FreeNoise: Tuning-Free Longer Video Diffusion via Noise Rescheduling (ICLR 2024)
</p>
<p style="text-align: center;">
<a href="https://arxiv.org/abs/2310.15169" target="_blank"><b>[arXiv]</b></a> &nbsp;&nbsp;&nbsp;&nbsp;
<a href="http://haonanqiu.com/projects/FreeNoise.html" target="_blank"><b>[Project Page]</b></a> &nbsp;&nbsp;&nbsp;&nbsp;
<a href="https://github.com/AILab-CVC/FreeNoise" target="_blank"><b>[Code]</b></a>
</p>
"""

# NOTE(review): the nesting below was reconstructed from a source whose
# indentation had been stripped - confirm it against the deployed layout.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(_HEADER_HTML)

        prompt_in = gr.Textbox(
            label="Prompt",
            placeholder="A chihuahua in astronaut suit floating in space, cinematic lighting, glow effect",
        )

        with gr.Row():
            with gr.Accordion(
                'FreeNoise Parameters (feel free to adjust these parameters based on your prompt): ',
                open=False,
            ):
                with gr.Row():
                    output_size = gr.Dropdown(
                        ["320x512", "576x1024", "256x256"],
                        value="320x512",
                        label="Output Size",
                        info="250s for 512 model, 900s for 1024 model (32 frames). Recovering from sleeping will take more time to download ckpt",
                    )
                with gr.Row():
                    num_frames = gr.Slider(
                        label='Frames (a multiple of 4), max 36 for 1024 model',
                        minimum=16,
                        maximum=64,
                        step=4,
                        value=32,
                    )
                    ddim_steps = gr.Slider(
                        label='DDIM Steps',
                        minimum=5,
                        maximum=200,
                        step=1,
                        value=50,
                    )
                with gr.Row():
                    unconditional_guidance_scale = gr.Slider(
                        label='Unconditional Guidance Scale',
                        minimum=1.0,
                        maximum=20.0,
                        step=0.1,
                        value=12.0,
                    )
                    save_fps = gr.Slider(
                        label='Save FPS',
                        minimum=1,
                        maximum=30,
                        step=1,
                        value=10,
                    )
                with gr.Row():
                    seed = gr.Slider(
                        label='Random Seed',
                        minimum=0,
                        maximum=10000,
                        step=1,
                        value=123,
                    )

        submit_btn = gr.Button("Generate", variant='primary')
        video_result = gr.Video(label="Video Output")

        # Components listed in the order of infer()'s positional parameters;
        # the same list feeds both the examples widget and the click handler.
        infer_inputs = [
            prompt_in,
            output_size,
            seed,
            num_frames,
            ddim_steps,
            unconditional_guidance_scale,
            save_fps,
        ]
        gr.Examples(examples=examples, inputs=infer_inputs)

    submit_btn.click(
        fn=infer,
        inputs=infer_inputs,
        outputs=[video_result],
        api_name="zrscp",
    )

# Enable the request queue (at most 12 pending jobs) and start the app.
demo.queue(max_size=12).launch(show_api=True)