Spaces:
Runtime error
RamAnanth1
committed on
Commit • b6320af
1 Parent(s): 4997010
Update app.py
app.py CHANGED
@@ -29,8 +29,54 @@ model, _, _ = load_model(config, ckpt_path,
 )
 ddim_sampler = DDIMSampler(model)
 
-
+@torch.no_grad()
+def sample_text2video(model, prompt, n_samples, batch_size,
+                      sample_type="ddim", sampler=None,
+                      ddim_steps=50, eta=1.0, cfg_scale=15.0,
+                      decode_frame_bs=1,
+                      ddp=False, all_gather=True,
+                      batch_progress=True, show_denoising_progress=False,
+                      ):
+    # get cond vector
+    assert(model.cond_stage_model is not None)
+    cond_embd = get_conditions(prompt, model, batch_size)
+    uncond_embd = get_conditions("", model, batch_size) if cfg_scale != 1.0 else None
+
+    # sample batches
+    all_videos = []
+    n_iter = math.ceil(n_samples / batch_size)
+    iterator = trange(n_iter, desc="Sampling Batches (text-to-video)") if batch_progress else range(n_iter)
+    for _ in iterator:
+        noise_shape = make_model_input_shape(model, batch_size)
+        samples_latent = sample_denoising_batch(model, noise_shape, cond_embd,
+                                                sample_type=sample_type,
+                                                sampler=sampler,
+                                                ddim_steps=ddim_steps,
+                                                eta=eta,
+                                                unconditional_guidance_scale=cfg_scale,
+                                                uc=uncond_embd,
+                                                denoising_progress=show_denoising_progress,
+                                                )
+        samples = model.decode_first_stage(samples_latent, decode_bs=decode_frame_bs, return_cpu=False)
+
+        # gather samples from multiple gpus
+        if ddp and all_gather:
+            data_list = gather_data(samples, return_np=False)
+            all_videos.extend([torch_to_np(data) for data in data_list])
+        else:
+            all_videos.append(torch_to_np(samples))
+
+    all_videos = np.concatenate(all_videos, axis=0)
+    assert(all_videos.shape[0] >= n_samples)
+    return all_videos
+
+
+def get_video(prompt):
+    samples = sample_text2video(model, prompt, n_samples = 2, batch_size = 1,
+                                sampler=ddim_sampler,
+                                )
     return "Hello " + name + "!!"
 
-
+prompt_inp = gr.Textbox(label = "Prompt")
+iface = gr.Interface(fn=get_video, [prompt_inp], outputs="text")
 iface.launch()
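The Space's "Runtime error" status is consistent with two bugs this commit leaves in app.py: `gr.Interface(fn=get_video, [prompt_inp], outputs="text")` places a positional argument after a keyword argument, which Python rejects as a SyntaxError before the app even starts, and `get_video` still ends with the starter template's `return "Hello " + name + "!!"`, where `name` is undefined and the sampled frames are discarded. A minimal corrected sketch follows; it assumes the names from this diff (`sample_text2video`, `model`, `ddim_sampler`) are in scope, that the sampler returns an array shaped (batch, channels, frames, height, width) with values in [-1, 1] (typical for these latent-diffusion codebases, but unverified here), and that imageio with its ffmpeg plugin is available for encoding:

import imageio
import numpy as np
import gradio as gr

def get_video(prompt):
    # Sample one video with the sample_text2video() added in this commit.
    samples = sample_text2video(model, prompt, n_samples=1, batch_size=1,
                                sampler=ddim_sampler)
    # Assumption: samples is (batch, channels, frames, height, width) in [-1, 1];
    # rearrange to per-frame HWC and rescale to uint8 for encoding.
    video = samples[0].transpose(1, 2, 3, 0)  # -> (frames, h, w, channels)
    video = ((video + 1.0) / 2.0 * 255.0).clip(0, 255).astype(np.uint8)
    path = "output.mp4"
    imageio.mimwrite(path, video, fps=8)  # needs the imageio-ffmpeg plugin
    return path

prompt_inp = gr.Textbox(label="Prompt")
# inputs must be passed by keyword here; the committed call puts
# [prompt_inp] positionally after fn=..., which raises SyntaxError.
iface = gr.Interface(fn=get_video, inputs=[prompt_inp], outputs="video")
iface.launch()

Switching `outputs` to "video" matters because Gradio's video component renders a returned file path, whereas "text" would only display the array's string representation.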