# Source: Hugging Face Space "videocrafter", file app.py (author: RamAnanth1).
# Commit b6320af ("Update app.py"); file size 3.35 kB; Hub malware scan: no virus detected.
import gradio as gr
import os
import time
import argparse
import yaml, math
from tqdm import trange
import torch
import numpy as np
from omegaconf import OmegaConf
import torch.distributed as dist
from pytorch_lightning import seed_everything
from lvdm.samplers.ddim import DDIMSampler
from lvdm.utils.common_utils import str2bool
from lvdm.utils.dist_utils import setup_dist, gather_data
from lvdm.utils.saving_utils import npz_to_video_grid, npz_to_imgsheet_5d
from utils import load_model, get_conditions, make_model_input_shape, torch_to_np
from huggingface_hub import hf_hub_url, cached_download
# --- Model configuration and pretrained checkpoint -------------------------
config_path = "model_config.yaml"
config = OmegaConf.load(config_path)

REPO_ID = "RamAnanth1/videocrafter-text2video"
ckpt_url = hf_hub_url(REPO_ID, 'model.ckpt')
# NOTE(review): `cached_download` is deprecated in recent huggingface_hub
# releases; `hf_hub_download(repo_id=REPO_ID, filename='model.ckpt')` is the
# replacement. Migrate this call together with the import at the top of the file.
ckpt_path = cached_download(ckpt_url)

# --- Model and sampler, built once at import time --------------------------
# load_model returns a 3-tuple; only the first element (the model) is used.
# LoRA injection is disabled for this app.
model, _, _ = load_model(config, ckpt_path, inject_lora=False, lora_scale=None)
# A single DDIM sampler wrapping `model`, reused by every generation request.
ddim_sampler = DDIMSampler(model)
def _sample_denoising_batch(noise_shape, condition, *,
                            sample_type="ddim", sampler=None,
                            ddim_steps=None, eta=None,
                            unconditional_guidance_scale=1.0, uc=None,
                            denoising_progress=False):
    """Run the sampler for one batch and return the denoised latents.

    Only DDIM sampling is supported. ``noise_shape`` is assumed to be
    ``(batch, *latent_dims)`` as produced by ``make_model_input_shape``;
    TODO confirm.

    Raises:
        ValueError: if ``sample_type`` is not ``"ddim"`` or no sampler is given.
    """
    if sample_type != "ddim":
        raise ValueError(
            f"unsupported sample_type {sample_type!r}; only 'ddim' is supported")
    if sampler is None:
        raise ValueError("sample_type='ddim' requires a sampler, e.g. DDIMSampler(model)")
    # Assumes lvdm's DDIMSampler.sample returns (samples, intermediates);
    # TODO confirm. The intermediates are not needed here.
    samples, _ = sampler.sample(
        S=ddim_steps,
        conditioning=condition,
        batch_size=noise_shape[0],
        shape=noise_shape[1:],
        verbose=denoising_progress,
        unconditional_guidance_scale=unconditional_guidance_scale,
        unconditional_conditioning=uc,
        eta=eta,
    )
    return samples


@torch.no_grad()
def sample_text2video(model, prompt, n_samples, batch_size,
                      sample_type="ddim", sampler=None,
                      ddim_steps=50, eta=1.0, cfg_scale=15.0,
                      decode_frame_bs=1,
                      ddp=False, all_gather=True,
                      batch_progress=True, show_denoising_progress=False,
                      ):
    """Generate videos for a text prompt and return them as one NumPy array.

    Whole batches of ``batch_size`` are drawn until at least ``n_samples``
    videos exist. The result is therefore NOT trimmed: its first dimension
    is ``>= n_samples``.

    Args:
        model: Latent video diffusion model. It must have a ``cond_stage_model``.
        prompt: Text prompt to condition on.
        n_samples: Minimum number of videos to produce (>= 1).
        batch_size: Videos denoised per sampler call (>= 1).
        sample_type: Sampling method. Only ``"ddim"`` is supported.
        sampler: Sampler instance, e.g. ``DDIMSampler(model)``.
        ddim_steps: Number of DDIM steps.
        eta: DDIM eta parameter.
        cfg_scale: Classifier-free guidance scale. When exactly 1.0, no
            unconditional (empty-prompt) embedding is computed.
        decode_frame_bs: Forwarded to ``model.decode_first_stage`` as
            ``decode_bs``. Presumably the number of frames decoded per chunk.
        ddp, all_gather: When both are true, each batch is gathered across
            processes with ``gather_data`` before conversion.
        batch_progress: Show a tqdm bar over batches.
        show_denoising_progress: Forwarded to the sampler as ``verbose``.

    Returns:
        ``np.ndarray``: all batches concatenated along axis 0.

    Raises:
        ValueError: on non-positive ``n_samples``/``batch_size``, a model
            without ``cond_stage_model``, or an unsupported ``sample_type``.
    """
    # Validate up front: batch_size == 0 would divide by zero, and
    # n_samples <= 0 would hand np.concatenate an empty list.
    if n_samples < 1 or batch_size < 1:
        raise ValueError(
            f"n_samples and batch_size must be >= 1, got {n_samples} and {batch_size}")
    if model.cond_stage_model is None:
        raise ValueError("model has no cond_stage_model; text conditioning is unavailable")

    # Conditioning and noise shape do not change between batches, so compute them once.
    cond_embd = get_conditions(prompt, model, batch_size)
    uncond_embd = get_conditions("", model, batch_size) if cfg_scale != 1.0 else None
    noise_shape = make_model_input_shape(model, batch_size)

    all_videos = []
    n_iter = math.ceil(n_samples / batch_size)
    iterator = trange(n_iter, desc="Sampling Batches (text-to-video)") if batch_progress else range(n_iter)
    for _ in iterator:
        samples_latent = _sample_denoising_batch(
            noise_shape, cond_embd,
            sample_type=sample_type,
            sampler=sampler,
            ddim_steps=ddim_steps,
            eta=eta,
            unconditional_guidance_scale=cfg_scale,
            uc=uncond_embd,
            denoising_progress=show_denoising_progress,
        )
        samples = model.decode_first_stage(samples_latent, decode_bs=decode_frame_bs, return_cpu=False)
        if ddp and all_gather:
            # Presumably collects this batch from every process; verify against gather_data.
            data_list = gather_data(samples, return_np=False)
            all_videos.extend(torch_to_np(data) for data in data_list)
        else:
            all_videos.append(torch_to_np(samples))

    all_videos = np.concatenate(all_videos, axis=0)
    # Internal sanity check: whole batches always cover the request.
    assert all_videos.shape[0] >= n_samples
    return all_videos
def get_video(prompt):
    """Generate videos for ``prompt``, save them to an .mp4 file and return its path.

    Uses the module-level ``model`` and ``ddim_sampler`` to draw two samples,
    one per batch. The samples are written with ``npz_to_video_grid``, which
    presumably tiles them into a single grid video; confirm against
    lvdm.utils.saving_utils.

    Args:
        prompt: Text prompt entered in the Gradio UI.

    Returns:
        Filesystem path of the newly written ``.mp4`` file.
    """
    import tempfile  # local import: only needed to create per-request output files

    samples = sample_text2video(model, prompt, n_samples=2, batch_size=1,
                                sampler=ddim_sampler,
                                )
    # Use a unique file per request so concurrent users don't overwrite each
    # other's output. mkstemp also opens a descriptor, which we close
    # immediately because the writer reopens the path itself.
    fd, out_path = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)
    # Assumes npz_to_video_grid accepts an in-memory array as its first argument; TODO confirm.
    npz_to_video_grid(samples, out_path)
    return out_path


prompt_inp = gr.Textbox(label="Prompt")
# `inputs` must be passed by keyword: a positional argument after `fn=` is a SyntaxError.
# The handler returns a video file path, so the output component is a video player.
iface = gr.Interface(fn=get_video, inputs=[prompt_inp], outputs="video")
iface.launch()