# Hugging Face Space page header at capture time: "Spaces: Runtime error".
import argparse
import math
import os
import time

import gradio as gr
import numpy as np
import torch
import torch.distributed as dist
import yaml
from huggingface_hub import cached_download, hf_hub_download, hf_hub_url
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from tqdm import trange

from lvdm.samplers.ddim import DDIMSampler
from lvdm.utils.common_utils import str2bool
from lvdm.utils.dist_utils import setup_dist, gather_data
from lvdm.utils.saving_utils import npz_to_video_grid, npz_to_imgsheet_5d
from utils import load_model, get_conditions, make_model_input_shape, torch_to_np
# --- One-time model setup (executed at import time) ---
# The model configuration is read from a YAML file via OmegaConf.
config_path = "model_config.yaml"
config = OmegaConf.load(config_path)

# Repository on the Hugging Face Hub that hosts the checkpoint file.
REPO_ID = "RamAnanth1/videocrafter-text2video"
# hf_hub_download replaces the deprecated cached_download(hf_hub_url(...))
# pair; like the old call it returns the local path of the downloaded file.
ckpt_path = hf_hub_download(REPO_ID, "model.ckpt")

# get model & sampler
# load_model returns three values; only the first (the model) is used here.
model, _, _ = load_model(
    config,
    ckpt_path,
    inject_lora=False,
    lora_scale=None,
)
ddim_sampler = DDIMSampler(model)
def sample_denoising_batch(model, noise_shape, condition, *,
                           sample_type="ddim", sampler=None,
                           ddim_steps=None, eta=None,
                           unconditional_guidance_scale=1.0, uc=None,
                           denoising_progress=False):
    """Run one denoising pass and return the sampled latents.

    This helper is called by ``sample_text2video`` but was neither defined
    nor imported in this file, so the first call raised ``NameError``.

    NOTE(review): the keyword names passed to ``sampler.sample`` and
    ``model.p_sample_loop`` follow the latent-diffusion style sampler API
    used by the upstream VideoCrafter sampling script -- confirm against the
    ``lvdm`` version shipped with this app.

    Args:
        model: Diffusion model; only used directly when ``sample_type`` is
            ``"ddpm"``.
        noise_shape: Sequence whose first entry is the batch size and whose
            remaining entries are the per-sample latent shape.
        condition: Conditioning passed to the sampler.
        sample_type: ``"ddim"`` or ``"ddpm"``.
        sampler: Sampler object exposing ``sample(...)``; required for DDIM.
        ddim_steps: Number of DDIM steps; required for DDIM.
        eta: DDIM eta; required for DDIM.
        unconditional_guidance_scale: Classifier-free guidance scale.
        uc: Unconditional conditioning, or ``None``.
        denoising_progress: Forwarded as the sampler's ``verbose`` flag.

    Returns:
        The sampled latents produced by the chosen sampler.

    Raises:
        ValueError: If ``sample_type`` is unknown, or a DDIM argument is
            missing.
    """
    if sample_type == "ddpm":
        return model.p_sample_loop(
            cond=condition,
            shape=noise_shape,
            return_intermediates=False,
            verbose=denoising_progress,
        )

    if sample_type == "ddim":
        if sampler is None or ddim_steps is None or eta is None:
            raise ValueError(
                "sample_type='ddim' requires sampler, ddim_steps and eta"
            )
        # The sampler takes the batch size and the per-sample shape separately.
        samples, _ = sampler.sample(
            S=ddim_steps,
            conditioning=condition,
            batch_size=noise_shape[0],
            shape=noise_shape[1:],
            verbose=denoising_progress,
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_conditioning=uc,
            eta=eta,
        )
        return samples

    raise ValueError(f"unknown sample_type: {sample_type!r}")


def sample_text2video(model, prompt, n_samples, batch_size,
                      sample_type="ddim", sampler=None,
                      ddim_steps=50, eta=1.0, cfg_scale=15.0,
                      decode_frame_bs=1,
                      ddp=False, all_gather=True,
                      batch_progress=True, show_denoising_progress=False,
                      ):
    """Sample videos for a text prompt and return them as one NumPy array.

    Sampling runs in ``ceil(n_samples / batch_size)`` batches, so the result
    can hold more than ``n_samples`` entries along axis 0.

    Args:
        model: Model exposing ``cond_stage_model`` and ``decode_first_stage``.
        prompt: Text prompt handed to ``get_conditions``.
        n_samples: Minimum number of samples to produce.
        batch_size: Number of samples generated per batch.
        sample_type: Forwarded to ``sample_denoising_batch``.
        sampler: Sampler object forwarded to ``sample_denoising_batch``.
        ddim_steps: Forwarded to ``sample_denoising_batch``.
        eta: Forwarded to ``sample_denoising_batch``.
        cfg_scale: Guidance scale; when it equals ``1.0`` no unconditional
            embedding is computed.
        decode_frame_bs: Passed to ``decode_first_stage`` as ``decode_bs``.
        ddp: With ``all_gather``, collect samples via ``gather_data``.
        all_gather: See ``ddp``.
        batch_progress: Show a tqdm progress bar over batches.
        show_denoising_progress: Forwarded as ``denoising_progress``.

    Returns:
        ``numpy.ndarray`` of all decoded samples concatenated along axis 0.
    """
    # get cond vector
    assert(model.cond_stage_model is not None)
    cond_embd = get_conditions(prompt, model, batch_size)
    # The empty-prompt embedding is only needed when guidance is active.
    uncond_embd = get_conditions("", model, batch_size) if cfg_scale != 1.0 else None

    # sample batches
    all_videos = []
    n_iter = math.ceil(n_samples / batch_size)
    iterator = trange(n_iter, desc="Sampling Batches (text-to-video)") if batch_progress else range(n_iter)
    for _ in iterator:
        noise_shape = make_model_input_shape(model, batch_size)
        samples_latent = sample_denoising_batch(model, noise_shape, cond_embd,
                                                sample_type=sample_type,
                                                sampler=sampler,
                                                ddim_steps=ddim_steps,
                                                eta=eta,
                                                unconditional_guidance_scale=cfg_scale,
                                                uc=uncond_embd,
                                                denoising_progress=show_denoising_progress,
                                                )
        samples = model.decode_first_stage(samples_latent, decode_bs=decode_frame_bs, return_cpu=False)

        # gather samples from multiple gpus
        if ddp and all_gather:
            data_list = gather_data(samples, return_np=False)
            all_videos.extend([torch_to_np(data) for data in data_list])
        else:
            all_videos.append(torch_to_np(samples))

    all_videos = np.concatenate(all_videos, axis=0)
    assert(all_videos.shape[0] >= n_samples)
    return all_videos
def get_video(prompt):
    """Gradio callback: run text-to-video sampling and return a text reply.

    The sampled frames are not returned; the function yields a greeting
    string built from the prompt.

    Args:
        prompt: Text entered by the user.

    Returns:
        A greeting string containing ``prompt``.
    """
    # Run the sampler for its effect only; the result is not used.
    sample_text2video(model, prompt, n_samples=2, batch_size=1,
                      sampler=ddim_sampler,
                      )
    # The original returned `"Hello " + name + "!!"`, but `name` is undefined
    # here, so every request raised NameError; the argument is `prompt`.
    return f"Hello {prompt}!!"
# --- Gradio UI wiring ---
# Single text box that collects the prompt.
prompt_inp = gr.Textbox(
    label="Prompt",
)

# One input (the prompt box) mapped to get_video, with a plain-text output.
iface = gr.Interface(
    fn=get_video,
    inputs=[prompt_inp],
    outputs="text",
)

# Start serving the interface.
iface.launch()