RamAnanth1 committed on
Commit
b6320af
β€’
1 Parent(s): 4997010

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -2
app.py CHANGED
@@ -29,8 +29,54 @@ model, _, _ = load_model(config, ckpt_path,
29
  )
30
  ddim_sampler = DDIMSampler(model)
31
 
32
- def greet(name):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  return "Hello " + name + "!!"
34
 
35
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
 
36
  iface.launch()
 
29
  )
30
  ddim_sampler = DDIMSampler(model)
31
 
32
+ @torch.no_grad()
33
+ def sample_text2video(model, prompt, n_samples, batch_size,
34
+ sample_type="ddim", sampler=None,
35
+ ddim_steps=50, eta=1.0, cfg_scale=15.0,
36
+ decode_frame_bs=1,
37
+ ddp=False, all_gather=True,
38
+ batch_progress=True, show_denoising_progress=False,
39
+ ):
40
+ # get cond vector
41
+ assert(model.cond_stage_model is not None)
42
+ cond_embd = get_conditions(prompt, model, batch_size)
43
+ uncond_embd = get_conditions("", model, batch_size) if cfg_scale != 1.0 else None
44
+
45
+ # sample batches
46
+ all_videos = []
47
+ n_iter = math.ceil(n_samples / batch_size)
48
+ iterator = trange(n_iter, desc="Sampling Batches (text-to-video)") if batch_progress else range(n_iter)
49
+ for _ in iterator:
50
+ noise_shape = make_model_input_shape(model, batch_size)
51
+ samples_latent = sample_denoising_batch(model, noise_shape, cond_embd,
52
+ sample_type=sample_type,
53
+ sampler=sampler,
54
+ ddim_steps=ddim_steps,
55
+ eta=eta,
56
+ unconditional_guidance_scale=cfg_scale,
57
+ uc=uncond_embd,
58
+ denoising_progress=show_denoising_progress,
59
+ )
60
+ samples = model.decode_first_stage(samples_latent, decode_bs=decode_frame_bs, return_cpu=False)
61
+
62
+ # gather samples from multiple gpus
63
+ if ddp and all_gather:
64
+ data_list = gather_data(samples, return_np=False)
65
+ all_videos.extend([torch_to_np(data) for data in data_list])
66
+ else:
67
+ all_videos.append(torch_to_np(samples))
68
+
69
+ all_videos = np.concatenate(all_videos, axis=0)
70
+ assert(all_videos.shape[0] >= n_samples)
71
+ return all_videos
72
+
73
+
74
+ def get_video(prompt):
75
+ samples = sample_text2video(model, prompt, n_samples = 2, batch_size = 1,
76
+ sampler=ddim_sampler,
77
+ )
78
  return "Hello " + name + "!!"
79
 
80
+ prompt_inp = gr.Textbox(label = "Prompt")
81
+ iface = gr.Interface(fn=get_video, [prompt_inp], outputs="text")
82
  iface.launch()