import gradio as gr
import os
import sys
import argparse
import random

from omegaconf import OmegaConf
import torch
import torchvision
from pytorch_lightning import seed_everything
from huggingface_hub import hf_hub_download

sys.path.insert(0, "scripts/evaluation")
from funcs import (
    batch_ddim_sampling_freenoise,
    load_model_checkpoint,
)
from utils.utils import instantiate_from_config
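
# Repo-local imports: `funcs` lives in scripts/evaluation (hence the sys.path
# tweak above); both helpers come from the LongerCrafter / VideoCrafter codebase.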

ckpt_path_1024 = "checkpoints/base_1024_v1/model.ckpt"
config_1024 = "configs/inference_t2v_1024_v1.0_freenoise.yaml"
# hf_hub_download(repo_id="VideoCrafter/Text2Video-1024", filename="model.ckpt", local_dir=ckpt_path_1024)

ckpt_path_256 = "checkpoints/base_256_v1/model.pth"
config_256 = "configs/inference_t2v_tconv256_v1.0_freenoise.yaml"
# `local_dir` must be the target directory, not the checkpoint file itself;
# the file then lands at <local_dir>/<filename>, i.e. exactly `ckpt_path_256`.
hf_hub_download(repo_id="MoonQiu/LongerCrafter", filename="model.pth", local_dir=os.path.dirname(ckpt_path_256))


def infer(prompt):
    output_size = "256x256"
    num_frames = 32
    ddim_steps = 50
    unconditional_guidance_scale = 12.0
    seed = 123
    save_fps = 10
    window_size = 16
    window_stride = 4
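    # FreeNoise sliding-window parameters: sampling runs on overlapping
    # 16-frame windows that advance 4 frames at a time, so every window draws
    # from one shared, rescheduled noise pool (see the shuffle loop below).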

    if output_size == "576x1024":
        width = 1024
        height = 576
        # Load into a new name: assigning back to `config_1024` would make it
        # a function local and raise UnboundLocalError before the load runs.
        cfg_1024 = OmegaConf.load(config_1024)
        model_config_1024 = cfg_1024.pop("model", OmegaConf.create())
        model_1024 = instantiate_from_config(model_config_1024)
        # model_1024 = model_1024.cuda()
        model_1024 = load_model_checkpoint(model_1024, ckpt_path_1024)
        model_1024.eval()
        model = model_1024
        fps = 24
    elif output_size == "256x256":
        width = 256
        height = 256
        cfg_256 = OmegaConf.load(config_256)
        model_config_256 = cfg_256.pop("model", OmegaConf.create())
        model_256 = instantiate_from_config(model_config_256)
        # model_256 = model_256.cuda()
        model_256 = load_model_checkpoint(model_256, ckpt_path_256)
        model_256.eval()
        model = model_256
        fps = 8

    if seed is None:
        seed = int.from_bytes(os.urandom(2), "big")
    print(f"Using seed: {seed}")
    seed_everything(seed)

    args = argparse.Namespace(
        mode="base",
        savefps=save_fps,
        n_samples=1,
        ddim_steps=ddim_steps,
        ddim_eta=0.0,
        bs=1,
        height=height,
        width=width,
        frames=num_frames,
        fps=fps,
        unconditional_guidance_scale=unconditional_guidance_scale,
        unconditional_guidance_scale_temporal=None,
        cond_input=None,
        window_size=window_size,
        window_stride=window_stride,
    )
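    # The Namespace above mirrors the CLI arguments of the repo's evaluation
    # script, letting its sampling helpers be reused unchanged.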

    ## latent noise shape
    h, w = args.height // 8, args.width // 8
    frames = model.temporal_length if args.frames < 0 else args.frames
    channels = model.channels
    x_T_total = torch.randn(
        [args.n_samples, 1, channels, frames, h, w], device=model.device
    ).repeat(1, args.bs, 1, 1, 1, 1)
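    # The VAE downsamples each spatial dim by 8, so 256x256 frames are denoised
    # in a 32x32 latent space; with the defaults above x_T_total has shape
    # [1, 1, channels, 32, 32, 32] = (samples, batch, C, T, latent H, latent W).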

    for frame_index in range(args.window_size, args.frames, args.window_stride):
        list_index = list(
            range(
                frame_index - args.window_size,
                frame_index + args.window_stride - args.window_size,
            )
        )
        random.shuffle(list_index)
        x_T_total[
            :, :, :, frame_index : frame_index + args.window_stride
        ] = x_T_total[:, :, :, list_index]
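    # FreeNoise noise rescheduling: each stride of frames past the first window
    # is overwritten with a shuffled copy of noise the window already contains;
    # with window_size=16 and window_stride=4, frames [16:20] become a
    # permutation of frames [0:4], [20:24] of [4:8], and so on. Every sampling
    # window therefore shares one noise pool, which is what keeps long videos
    # temporally coherent without any retraining.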

    batch_size = 1
    noise_shape = [batch_size, channels, frames, h, w]
    fps = torch.tensor([args.fps] * batch_size).to(model.device).long()
    prompts = [prompt]
    text_emb = model.get_learned_conditioning(prompts)
    cond = {"c_crossattn": [text_emb], "fps": fps}

    ## inference
    batch_samples = batch_ddim_sampling_freenoise(
        model,
        cond,
        noise_shape,
        args.n_samples,
        args.ddim_steps,
        args.ddim_eta,
        args.unconditional_guidance_scale,
        args=args,
        x_T_total=x_T_total,
    )
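    # The sampler returns decoded RGB video in [-1, 1]; batch_samples[0] below
    # is an [n, c, t, h, w] tensor (n = args.n_samples).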
video_path = "/tmp/output.mp4" | |
vid_tensor = batch_samples[0] | |
video = vid_tensor.detach().cpu() | |
video = torch.clamp(video.float(), -1.0, 1.0) | |
video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w | |
frame_grids = [ | |
torchvision.utils.make_grid(framesheet, nrow=int(args.n_samples)) | |
for framesheet in video | |
] # [3, 1*h, n*w] | |
grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [t, 3, n*h, w] | |
grid = (grid + 1.0) / 2.0 | |
grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) | |
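    # torchvision.io.write_video expects uint8 frames shaped [t, h, w, c],
    # hence the rescale from [-1, 1] to [0, 255] and the permute above.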
    torchvision.io.write_video(
        video_path,
        grid,
        fps=args.savefps,
        video_codec="h264",
        options={"crf": "10"},
    )
    print(video_path)
    # Return only the video path: the click handler below binds a single
    # output component, so a second return value would not match.
    return video_path
css = """ | |
#col-container {max-width: 510px; margin-left: auto; margin-right: auto;} | |
a {text-decoration-line: underline; font-weight: 600;} | |
.animate-spin { | |
animation: spin 1s linear infinite; | |
} | |
@keyframes spin { | |
from { | |
transform: rotate(0deg); | |
} | |
to { | |
transform: rotate(360deg); | |
} | |
} | |
#share-btn-container { | |
display: flex; | |
padding-left: 0.5rem !important; | |
padding-right: 0.5rem !important; | |
background-color: #000000; | |
justify-content: center; | |
align-items: center; | |
border-radius: 9999px !important; | |
max-width: 15rem; | |
height: 36px; | |
} | |
div#share-btn-container > div { | |
flex-direction: row; | |
background: black; | |
align-items: center; | |
} | |
#share-btn-container:hover { | |
background-color: #060606; | |
} | |
#share-btn { | |
all: initial; | |
color: #ffffff; | |
font-weight: 600; | |
cursor:pointer; | |
font-family: 'IBM Plex Sans', sans-serif; | |
margin-left: 0.5rem !important; | |
padding-top: 0.5rem !important; | |
padding-bottom: 0.5rem !important; | |
right:0; | |
} | |
#share-btn * { | |
all: unset; | |
} | |
#share-btn-container div:nth-child(-n+2){ | |
width: auto !important; | |
min-height: 0px !important; | |
} | |
#share-btn-container .wrap { | |
display: none !important; | |
} | |
#share-btn-container.hidden { | |
display: none!important; | |
} | |
img[src*='#center'] { | |
display: inline-block; | |
margin: unset; | |
} | |
.footer { | |
margin-bottom: 45px; | |
margin-top: 10px; | |
text-align: center; | |
border-bottom: 1px solid #e5e5e5; | |
} | |
.footer>p { | |
font-size: .8rem; | |
display: inline-block; | |
padding: 0 10px; | |
transform: translateY(10px); | |
background: white; | |
} | |
.dark .footer { | |
border-color: #303030; | |
} | |
.dark .footer>p { | |
background: #0b0f19; | |
} | |
""" | |

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            """
            <h1 style="text-align: center;">LongerCrafter (FreeNoise) Text-to-Video</h1>
            <p style="text-align: center;">
              Tuning-Free Longer Video Diffusion via Noise Rescheduling <br />
            </p>
            """
        )
        prompt_in = gr.Textbox(
            label="Prompt",
            placeholder="A chihuahua in astronaut suit floating in space, cinematic lighting, glow effect",
            elem_id="prompt-in",
        )
        # neg_prompt = gr.Textbox(label="Negative prompt", value="text, watermark, copyright, blurry, nsfw", elem_id="neg-prompt-in")
        # inference_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=100, step=1, value=40, interactive=False)
        submit_btn = gr.Button("Submit")
        video_result = gr.Video(label="Video Output", elem_id="video-output")

        submit_btn.click(
            fn=infer,
            inputs=[prompt_in],
            outputs=[video_result],
            api_name="zrscp",
        )

demo.queue(max_size=12).launch(show_api=True)
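
# A minimal client-side sketch (not part of the original app): assuming the
# Space is running and `gradio_client` is installed, the named endpoint can be
# called programmatically. The local URL below is a placeholder.
#
#   from gradio_client import Client
#
#   client = Client("http://127.0.0.1:7860/")
#   video_path = client.predict(
#       "A chihuahua in astronaut suit floating in space",  # prompt
#       api_name="/zrscp",
#   )
#   print(video_path)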