RamAnanth1 committed
Commit d8df719 • 1 Parent(s): 0e8aba4

Update app.py

Files changed (1)
  1. app.py +58 -0
app.py CHANGED
@@ -16,6 +16,7 @@ from lvdm.utils.dist_utils import setup_dist, gather_data
 from lvdm.utils.saving_utils import npz_to_video_grid, npz_to_imgsheet_5d
 from utils import load_model, get_conditions, make_model_input_shape, torch_to_np
 from lvdm.models.modules.lora import change_lora
+from lvdm.utils.saving_utils import tensor_to_mp4
 
 from huggingface_hub import hf_hub_download
 
@@ -110,6 +111,50 @@ def sample_text2video(model, prompt, n_samples, batch_size,
     assert(all_videos.shape[0] >= n_samples)
     return all_videos
 
+def adapter_guided_synthesis(model, prompts, videos, noise_shape, sampler, n_samples=1, ddim_steps=50, ddim_eta=1., \
+                             unconditional_guidance_scale=1.0, unconditional_guidance_scale_temporal=None, **kwargs):
+    ddim_sampler = sampler
+
+    batch_size = noise_shape[0]
+    ## get condition embeddings (support single prompt only)
+    if isinstance(prompts, str):
+        prompts = [prompts]
+    cond = model.get_learned_conditioning(prompts)
+    if unconditional_guidance_scale != 1.0:
+        prompts = batch_size * [""]
+        uc = model.get_learned_conditioning(prompts)
+    else:
+        uc = None
+
+    ## adapter features: process in 2D manner
+    b, c, t, h, w = videos.shape
+    extra_cond = model.get_batch_depth(videos, (h, w))
+    features_adapter = model.get_adapter_features(extra_cond)
+
+    batch_variants = []
+    for _ in range(n_samples):
+        if ddim_sampler is not None:
+            samples, _ = ddim_sampler.sample(S=ddim_steps,
+                                             conditioning=cond,
+                                             batch_size=noise_shape[0],
+                                             shape=noise_shape[1:],
+                                             verbose=False,
+                                             unconditional_guidance_scale=unconditional_guidance_scale,
+                                             unconditional_conditioning=uc,
+                                             eta=ddim_eta,
+                                             temporal_length=noise_shape[2],
+                                             conditional_guidance_scale_temporal=unconditional_guidance_scale_temporal,
+                                             features_adapter=features_adapter,
+                                             **kwargs
+                                             )
+        ## reconstruct from latent to pixel space
+        batch_images = model.decode_first_stage(samples, decode_bs=1, return_cpu=False)
+        batch_variants.append(batch_images)
+    ## variants, batch, c, t, h, w
+    batch_variants = torch.stack(batch_variants)
+    return batch_variants.permute(1, 0, 2, 3, 4, 5), extra_cond
+
+
 def save_results(videos,
                  save_name="results", save_fps=8, save_mp4=True,
                  save_npz=False, save_mp4_sheet=False, save_jpg=False
@@ -124,6 +169,9 @@ def save_results(videos,
 
     return os.path.join(save_subdir, f"{save_name}_{i:03d}.mp4")
 
+def save_results_control(batch_samples, batch_conds):
+    return
+
 def get_video(prompt, seed, ddim_steps):
     seed_everything(seed)
     samples = sample_text2video(model, prompt, n_samples = 1, batch_size = 1,
@@ -156,6 +204,14 @@ def get_video_lora(prompt, seed, ddim_steps, model_choice):
                                 )
     return save_results(samples)
 
+def get_video_control(prompt, input_video, seed, ddim_steps):
+    seed_everything(seed)
+    h, w = 512 // 8, 512 // 8
+    noise_shape = [1, model.channels, model.temporal_length, h, w]
+    batch_samples, batch_conds = adapter_guided_synthesis(model, prompt, input_video, noise_shape, sampler=ddim_sampler, n_samples=1,
+                                                          ddim_steps=ddim_steps
+                                                          )
+    return save_results_control(batch_samples, batch_conds)
 
 from gradio_t2v import create_demo as create_demo_basic
 from gradio_videolora import create_demo as create_demo_videolora
@@ -170,6 +226,8 @@ with gr.Blocks(css='style.css') as demo:
         create_demo_basic(get_video)
     with gr.TabItem('VideoLoRA'):
         create_demo_videolora(get_video_lora)
+    with gr.TabItem('VideoControl'):
+        create_demo_videolora(get_video_control)
 
  demo.queue(api_open=False).launch()
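
In this commit save_results_control is left as an empty stub, so the new VideoControl tab produces no output even though tensor_to_mp4 is imported for it. Below is a minimal sketch of how the stub could be filled in, relying on the tensor_to_mp4 import added above; it assumes tensor_to_mp4(video, savepath, fps) accepts a b, c, t, h, w tensor like the other lvdm saving helpers, and the save_dir, save_name, and save_fps names are illustrative placeholders rather than part of the commit.

import os

def save_results_control(batch_samples, batch_conds,
                         save_dir="results_control", save_name="output", save_fps=8):
    # batch_samples: (batch, n_samples, c, t, h, w) returned by adapter_guided_synthesis
    # batch_conds: the depth video from model.get_batch_depth, used as the control signal
    os.makedirs(save_dir, exist_ok=True)
    sample_path = os.path.join(save_dir, f"{save_name}_sample.mp4")
    depth_path = os.path.join(save_dir, f"{save_name}_depth.mp4")
    # assumed signature: tensor_to_mp4(video, savepath, fps)
    tensor_to_mp4(video=batch_samples[0].detach().cpu(), savepath=sample_path, fps=save_fps)
    tensor_to_mp4(video=batch_conds.detach().cpu(), savepath=depth_path, fps=save_fps)
    return sample_path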
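For context, a hypothetical smoke test of the new control path; the prompt, the random stand-in for a depth-source clip, and its 1 x 3 x 16 x 512 x 512 (b, c, t, h, w) layout are assumptions for illustration, and the module-level model and ddim_sampler that the rest of app.py sets up are taken as given.

import torch

prompt = "a corgi running on the beach"                  # illustrative prompt
control_video = torch.rand(1, 3, 16, 512, 512).cuda()    # stand-in clip in b, c, t, h, w layout
result = get_video_control(prompt, control_video, seed=1234, ddim_steps=50)
print(result)  # None until save_results_control actually writes and returns a file path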