File size: 5,614 Bytes
5c4a11c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e5f9b65
54ad002
5c4a11c
 
 
 
54ad002
 
5c4a11c
54ad002
4997010
 
54ad002
 
5c4a11c
d29f4c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6320af
 
 
a92b0d1
b6320af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1bfbba7
 
 
 
 
 
 
 
 
 
 
21c2481
 
b6320af
 
a92b0d1
 
5d778f1
b6320af
 
1bfbba7
5c4a11c
a92b0d1
 
 
b6320af
21c2481
a92b0d1
 
 
 
 
 
5c4a11c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import gradio as gr
import os
import time
import argparse
import yaml, math
from tqdm import trange
import torch
import numpy as np
from omegaconf import OmegaConf
import torch.distributed as dist
from pytorch_lightning import seed_everything

from lvdm.samplers.ddim import DDIMSampler
from lvdm.utils.common_utils import str2bool
from lvdm.utils.dist_utils import setup_dist, gather_data
from lvdm.utils.saving_utils import npz_to_video_grid, npz_to_imgsheet_5d
from utils import load_model, get_conditions, make_model_input_shape, torch_to_np
from huggingface_hub import hf_hub_url, cached_download

# ---------------------------------------------------------------------------
# Module-level setup: load the model configuration, fetch the checkpoint from
# the Hugging Face Hub, and build the model + DDIM sampler once at import
# time so every Gradio request reuses them.
# ---------------------------------------------------------------------------
config_path = "model_config.yaml"
config = OmegaConf.load(config_path)

REPO_ID = "RamAnanth1/videocrafter-text2video"
# NOTE(review): `cached_download(hf_hub_url(...))` is deprecated and removed
# in recent huggingface_hub releases; `hf_hub_download` is the supported
# replacement and caches the file in the same HF cache directory.
from huggingface_hub import hf_hub_download
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="model.ckpt")

# Build the latent video diffusion model (no LoRA injection) and its sampler.
model, _, _ = load_model(config, ckpt_path,
                         inject_lora=False,
                         lora_scale=None,
                         )
ddim_sampler = DDIMSampler(model)

def sample_denoising_batch(model, noise_shape, condition, *args,
                           sample_type="ddim", sampler=None,
                           ddim_steps=None, eta=None,
                           unconditional_guidance_scale=1.0, uc=None,
                           denoising_progress=False,
                           **kwargs,
                           ):
    """Run one batch of DDIM denoising and return the sampled latents.

    Args:
        model: diffusion model (not used directly; kept for interface
            symmetry with the caller).
        noise_shape: full latent shape ``(batch, ...)``; element 0 is the
            batch size, the rest is the per-sample latent shape.
        condition: conditioning passed through to the sampler.
        sample_type: sampling scheme tag; only the DDIM path is implemented.
        sampler: object exposing a ``sample(...)`` method (e.g. DDIMSampler).
        ddim_steps: number of DDIM denoising steps (required).
        eta: DDIM eta; 0.0 is deterministic DDIM, 1.0 is DDPM-like (required).
        unconditional_guidance_scale: classifier-free guidance scale.
        uc: unconditional conditioning used when guidance is active.
        denoising_progress: whether the sampler prints per-step progress.

    Returns:
        The sampled latent batch (first element of ``sampler.sample``'s
        return value).

    Raises:
        ValueError: if ``sampler``, ``ddim_steps`` or ``eta`` is missing.
    """
    # Explicit raises instead of bare asserts: asserts are stripped under
    # ``python -O``, silently turning missing arguments into obscure errors.
    if sampler is None:
        raise ValueError("sampler must be provided")
    if ddim_steps is None:
        raise ValueError("ddim_steps must be provided")
    if eta is None:
        raise ValueError("eta must be provided")
    samples, _ = sampler.sample(S=ddim_steps,
                                conditioning=condition,
                                batch_size=noise_shape[0],
                                shape=noise_shape[1:],
                                verbose=denoising_progress,
                                unconditional_guidance_scale=unconditional_guidance_scale,
                                unconditional_conditioning=uc,
                                eta=eta,
                                **kwargs,
                                )
    return samples
                               
@torch.no_grad()
def sample_text2video(model, prompt, n_samples, batch_size,
                      sample_type="ddim", sampler=None, 
                      ddim_steps=50, eta=1.0, cfg_scale=7.5, 
                      decode_frame_bs=1,
                      ddp=False, all_gather=True, 
                      batch_progress=True, show_denoising_progress=False,
                      ):
    """Sample videos from a text prompt via batched DDIM denoising.

    Encodes the prompt (and, when classifier-free guidance is active, an
    empty prompt), denoises enough batches to cover ``n_samples``, decodes
    the latents through the first-stage model, and returns the results
    stacked into one numpy array along the batch axis.
    """
    # Text conditioning; the empty-prompt embedding is only needed when
    # classifier-free guidance is enabled (cfg_scale != 1.0).
    assert(model.cond_stage_model is not None)
    cond_embd = get_conditions(prompt, model, batch_size)
    uncond_embd = None
    if cfg_scale != 1.0:
        uncond_embd = get_conditions("", model, batch_size)

    # Enough batches to produce at least n_samples videos.
    num_batches = math.ceil(n_samples / batch_size)
    if batch_progress:
        batch_iter = trange(num_batches, desc="Sampling Batches (text-to-video)")
    else:
        batch_iter = range(num_batches)

    collected = []
    for _ in batch_iter:
        latent_shape = make_model_input_shape(model, batch_size)
        latents = sample_denoising_batch(model, latent_shape, cond_embd,
                                         sample_type=sample_type,
                                         sampler=sampler,
                                         ddim_steps=ddim_steps,
                                         eta=eta,
                                         unconditional_guidance_scale=cfg_scale,
                                         uc=uncond_embd,
                                         denoising_progress=show_denoising_progress,
                                         )
        decoded = model.decode_first_stage(latents,
                                           decode_bs=decode_frame_bs,
                                           return_cpu=False)

        # Under DDP, optionally collect every rank's batch before converting
        # to numpy; otherwise convert this rank's batch directly.
        if ddp and all_gather:
            gathered = gather_data(decoded, return_np=False)
            collected.extend(torch_to_np(item) for item in gathered)
        else:
            collected.append(torch_to_np(decoded))

    videos = np.concatenate(collected, axis=0)
    assert(videos.shape[0] >= n_samples)
    return videos

def save_results(videos, 
                 save_name="results", save_fps=8, save_mp4=True, 
                 save_npz=False, save_mp4_sheet=False, save_jpg=False
                 ):
    """Write each video in ``videos`` to ./videos/ as an mp4 grid.

    Args:
        videos: numpy array of shape ``(n, ...)`` — one video per leading
            index.
        save_name: filename stem; files are named ``{save_name}_{i:03d}.mp4``.
        save_fps: frames per second for the saved mp4s.
        save_mp4 / save_npz / save_mp4_sheet / save_jpg: accepted for
            interface compatibility; only mp4 output is implemented here.

    Returns:
        Filesystem path of the last mp4 written.

    Raises:
        ValueError: if ``videos`` is empty. (Previously an empty batch
        raised a confusing ``NameError`` because the return line referenced
        the loop variable ``i``, which never came into existence.)
    """
    num_videos = videos.shape[0]
    if num_videos == 0:
        raise ValueError("save_results() received an empty batch of videos")

    save_subdir = os.path.join("videos")
    os.makedirs(save_subdir, exist_ok=True)
    last_path = None
    for i in range(num_videos):
        # Track the path explicitly instead of relying on the loop variable
        # leaking out of the for-loop.
        last_path = os.path.join(save_subdir, f"{save_name}_{i:03d}.mp4")
        npz_to_video_grid(videos[i:i+1, ...], last_path, fps=save_fps)

    return last_path

def get_video(prompt, seed=1000):
    """Generate a single video for ``prompt`` and return the saved mp4 path.

    Args:
        prompt: text prompt to condition the video on.
        seed: RNG seed for reproducible sampling. The default of 1000
            preserves the original hard-coded behavior.

    Returns:
        Filesystem path of the generated mp4 (from ``save_results``).
    """
    seed_everything(seed)  # seed torch/numpy/python RNGs for reproducibility
    samples = sample_text2video(model, prompt, n_samples=1, batch_size=1,
                                sampler=ddim_sampler,
                                )
    return save_results(samples)

# --- Gradio front-end ------------------------------------------------------
# One text input (the prompt) mapped to a single generated video.
title = 'Latent Video Diffusion Models'
DESCRIPTION  = '<p>This model can only be used for non-commercial purposes. To learn more about the model, take a look at the <a href="https://github.com/VideoCrafter/VideoCrafter" style="text-decoration: underline;" target="_blank">model card</a>.</p>'

prompt_box = gr.Textbox(label = "Prompt")
video_out = gr.Video(label='Result')
demo = gr.Interface(fn=get_video,
                    inputs=[prompt_box],
                    outputs=[video_out],
                    title = title,
                    description = DESCRIPTION,
                    examples = [["An astronaut riding a horse"]])
demo.launch()