Spaces:
Runtime error
Runtime error
added: basic commands
Browse files- app.py +219 -4
- requirements.txt +0 -0
app.py
CHANGED
@@ -1,8 +1,223 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
-
iface = gr.Interface(
|
7 |
iface.launch()
|
8 |
|
|
|
|
1 |
+
#%%
|
2 |
+
import os
|
3 |
+
os.system("git clone https://github.com/v-iashin/SpecVQGAN")
|
4 |
+
os.system("pip install pytorch-lightning==1.2.10 omegaconf==2.0.6 streamlit==0.80 matplotlib==3.4.1 albumentations==0.5.2 SoundFile torch torchvision librosa gdown")
|
5 |
+
|
6 |
+
|
7 |
+
# %%
|
8 |
+
|
9 |
+
import sys
|
10 |
+
sys.path.append('./SpecVQGAN')
|
11 |
+
import time
|
12 |
+
from pathlib import Path
|
13 |
+
|
14 |
+
import IPython.display as display_audio
|
15 |
+
import soundfile
|
16 |
+
import torch
|
17 |
+
from IPython import display
|
18 |
+
from matplotlib import pyplot as plt
|
19 |
+
from torch.utils.data.dataloader import default_collate
|
20 |
+
from torchvision.utils import make_grid
|
21 |
+
from tqdm import tqdm
|
22 |
+
|
23 |
+
from feature_extraction.demo_utils import (ExtractResNet50, check_video_for_audio,
|
24 |
+
extract_melspectrogram, load_model,
|
25 |
+
show_grid, trim_video)
|
26 |
+
from sample_visualization import (all_attention_to_st, get_class_preditions,
|
27 |
+
last_attention_to_st, spec_to_audio_to_st,
|
28 |
+
tensor_to_plt)
|
29 |
+
from specvqgan.data.vggsound import CropImage
|
30 |
+
|
31 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
32 |
+
|
33 |
+
# load model
|
34 |
+
model_name = '2021-07-30T21-34-25_vggsound_transformer'
|
35 |
+
log_dir = './logs'
|
36 |
+
os.chdir("./SpecVQGAN/")
|
37 |
+
config, sampler, melgan, melception = load_model(model_name, log_dir, device)
|
38 |
+
# %%
|
39 |
+
|
40 |
+
def extract_thumbnails(video_path):
|
41 |
+
# Trim the video
|
42 |
+
start_sec = 0 # to start with 01:35 use 95 seconds
|
43 |
+
video_path = trim_video(video_path, start_sec, trim_duration=10)
|
44 |
+
|
45 |
+
# Extract Features
|
46 |
+
extraction_fps = 21.5
|
47 |
+
feature_extractor = ExtractResNet50(extraction_fps, config.data.params, device)
|
48 |
+
visual_features, resampled_frames = feature_extractor(video_path)
|
49 |
+
|
50 |
+
# Show the selected frames to extract features for
|
51 |
+
if not config.data.params.replace_feats_with_random:
|
52 |
+
fig = show_grid(make_grid(resampled_frames))
|
53 |
+
fig.show()
|
54 |
+
|
55 |
+
# Prepare Input
|
56 |
+
batch = default_collate([visual_features])
|
57 |
+
batch['feature'] = batch['feature'].to(device)
|
58 |
+
c = sampler.get_input(sampler.cond_stage_key, batch)
|
59 |
+
return c, video_path
|
60 |
+
|
61 |
+
# %%
|
62 |
+
import numpy as np
|
63 |
+
|
64 |
+
def generate_audio(video_path, temperature = 1.0):
|
65 |
+
# Define Sampling Parameters
|
66 |
+
W_scale = 1
|
67 |
+
mode = 'full'
|
68 |
+
top_x = sampler.first_stage_model.quantize.n_e // 2
|
69 |
+
update_every = 0 # use > 0 value, e.g. 15, to see the progress of generation (slows down the sampling speed)
|
70 |
+
full_att_mat = True
|
71 |
+
|
72 |
+
c, video_path = extract_thumbnails(video_path)
|
73 |
+
|
74 |
+
# Start sampling
|
75 |
+
with torch.no_grad():
|
76 |
+
start_t = time.time()
|
77 |
+
|
78 |
+
quant_c, c_indices = sampler.encode_to_c(c)
|
79 |
+
# crec = sampler.cond_stage_model.decode(quant_c)
|
80 |
+
|
81 |
+
patch_size_i = 5
|
82 |
+
patch_size_j = 53
|
83 |
+
|
84 |
+
B, D, hr_h, hr_w = sampling_shape = (1, 256, 5, 53*W_scale)
|
85 |
+
|
86 |
+
z_pred_indices = torch.zeros((B, hr_h*hr_w)).long().to(device)
|
87 |
+
|
88 |
+
if mode == 'full':
|
89 |
+
start_step = 0
|
90 |
+
else:
|
91 |
+
start_step = (patch_size_j // 2) * patch_size_i
|
92 |
+
z_pred_indices[:, :start_step] = z_indices[:, :start_step]
|
93 |
+
|
94 |
+
pbar = tqdm(range(start_step, hr_w * hr_h), desc='Sampling Codebook Indices')
|
95 |
+
for step in pbar:
|
96 |
+
i = step % hr_h
|
97 |
+
j = step // hr_h
|
98 |
|
99 |
+
i_start = min(max(0, i - (patch_size_i // 2)), hr_h - patch_size_i)
|
100 |
+
j_start = min(max(0, j - (patch_size_j // 2)), hr_w - patch_size_j)
|
101 |
+
i_end = i_start + patch_size_i
|
102 |
+
j_end = j_start + patch_size_j
|
103 |
+
|
104 |
+
local_i = i - i_start
|
105 |
+
local_j = j - j_start
|
106 |
+
|
107 |
+
patch_2d_shape = (B, D, patch_size_i, patch_size_j)
|
108 |
+
|
109 |
+
pbar.set_postfix(
|
110 |
+
Step=f'({i},{j}) | Local: ({local_i},{local_j}) | Crop: ({i_start}:{i_end},{j_start}:{j_end})'
|
111 |
+
)
|
112 |
+
|
113 |
+
patch = z_pred_indices \
|
114 |
+
.reshape(B, hr_w, hr_h) \
|
115 |
+
.permute(0, 2, 1)[:, i_start:i_end, j_start:j_end].permute(0, 2, 1) \
|
116 |
+
.reshape(B, patch_size_i * patch_size_j)
|
117 |
+
|
118 |
+
# assuming we don't crop the conditioning and just use the whole c, if not desired uncomment the above
|
119 |
+
cpatch = c_indices
|
120 |
+
logits, _, attention = sampler.transformer(patch[:, :-1], cpatch)
|
121 |
+
# remove conditioning
|
122 |
+
logits = logits[:, -patch_size_j*patch_size_i:, :]
|
123 |
+
|
124 |
+
local_pos_in_flat = local_j * patch_size_i + local_i
|
125 |
+
logits = logits[:, local_pos_in_flat, :]
|
126 |
+
|
127 |
+
logits = logits / temperature
|
128 |
+
logits = sampler.top_k_logits(logits, top_x)
|
129 |
+
|
130 |
+
# apply softmax to convert to probabilities
|
131 |
+
probs = torch.nn.functional.softmax(logits, dim=-1)
|
132 |
+
|
133 |
+
# sample from the distribution
|
134 |
+
ix = torch.multinomial(probs, num_samples=1)
|
135 |
+
z_pred_indices[:, j * hr_h + i] = ix
|
136 |
+
|
137 |
+
if update_every > 0 and step % update_every == 0:
|
138 |
+
z_pred_img = sampler.decode_to_img(z_pred_indices, sampling_shape)
|
139 |
+
# fliping the spectrogram just for illustration purposes (low freqs to bottom, high - top)
|
140 |
+
z_pred_img_st = tensor_to_plt(z_pred_img, flip_dims=(2,))
|
141 |
+
display.clear_output(wait=True)
|
142 |
+
display.display(z_pred_img_st)
|
143 |
+
|
144 |
+
if full_att_mat:
|
145 |
+
att_plot = all_attention_to_st(attention, placeholders=None, scale_by_prior=True)
|
146 |
+
display.display(att_plot)
|
147 |
+
plt.close()
|
148 |
+
else:
|
149 |
+
quant_z_shape = sampling_shape
|
150 |
+
c_length = cpatch.shape[-1]
|
151 |
+
quant_c_shape = quant_c.shape
|
152 |
+
c_att_plot, z_att_plot = last_attention_to_st(
|
153 |
+
attention, local_pos_in_flat, c_length, sampler.first_stage_permuter,
|
154 |
+
sampler.cond_stage_permuter, quant_c_shape, patch_2d_shape,
|
155 |
+
placeholders=None, flip_c_dims=None, flip_z_dims=(2,))
|
156 |
+
display.display(c_att_plot)
|
157 |
+
display.display(z_att_plot)
|
158 |
+
plt.close()
|
159 |
+
plt.close()
|
160 |
+
plt.close()
|
161 |
+
|
162 |
+
# quant_z_shape = sampling_shape
|
163 |
+
z_pred_img = sampler.decode_to_img(z_pred_indices, sampling_shape)
|
164 |
+
|
165 |
+
# showing the final image
|
166 |
+
z_pred_img_st = tensor_to_plt(z_pred_img, flip_dims=(2,))
|
167 |
+
display.clear_output(wait=True)
|
168 |
+
display.display(z_pred_img_st)
|
169 |
+
|
170 |
+
if full_att_mat:
|
171 |
+
att_plot = all_attention_to_st(attention, placeholders=None, scale_by_prior=True)
|
172 |
+
display.display(att_plot)
|
173 |
+
plt.close()
|
174 |
+
else:
|
175 |
+
quant_z_shape = sampling_shape
|
176 |
+
c_length = cpatch.shape[-1]
|
177 |
+
quant_c_shape = quant_c.shape
|
178 |
+
c_att_plot, z_att_plot = last_attention_to_st(
|
179 |
+
attention, local_pos_in_flat, c_length, sampler.first_stage_permuter,
|
180 |
+
sampler.cond_stage_permuter, quant_c_shape, patch_2d_shape,
|
181 |
+
placeholders=None, flip_c_dims=None, flip_z_dims=(2,)
|
182 |
+
)
|
183 |
+
display.display(c_att_plot)
|
184 |
+
display.display(z_att_plot)
|
185 |
+
plt.close()
|
186 |
+
plt.close()
|
187 |
+
plt.close()
|
188 |
+
|
189 |
+
print(f'Sampling Time: {time.time() - start_t:3.2f} seconds')
|
190 |
+
waves = spec_to_audio_to_st(z_pred_img, config.data.params.spec_dir_path,
|
191 |
+
config.data.params.sample_rate, show_griffin_lim=False,
|
192 |
+
vocoder=melgan, show_in_st=False)
|
193 |
+
print(f'Sampling Time (with vocoder): {time.time() - start_t:3.2f} seconds')
|
194 |
+
print(f'Generated: {len(waves["vocoder"]) / config.data.params.sample_rate:.2f} seconds')
|
195 |
+
|
196 |
+
# Melception opinion on the class distribution of the generated sample
|
197 |
+
topk_preds = get_class_preditions(z_pred_img, melception)
|
198 |
+
print(topk_preds)
|
199 |
+
|
200 |
+
audio_path = os.path.join(log_dir, Path(video_path).stem + '.wav')
|
201 |
+
audio = waves['vocoder']
|
202 |
+
audio = np.repeat([audio], 2, axis=0).T
|
203 |
+
print(audio.shape)
|
204 |
+
soundfile.write(audio_path, audio, config.data.params.sample_rate, 'PCM_24')
|
205 |
+
print(f'The sample has been saved @ {audio_path}')
|
206 |
+
|
207 |
+
|
208 |
+
video_out_path = os.path.join(log_dir, Path(video_path).stem + '_audio.mp4')
|
209 |
+
print(video_path, audio_path, video_out_path)
|
210 |
+
os.system("ffmpeg -i %s -i %s -map 0:v -map 1:a -c:v copy -shortest %s" % (video_path, audio_path, video_out_path))
|
211 |
+
|
212 |
+
return video_out_path
|
213 |
+
# return config.data.params.sample_rate, audio
|
214 |
+
|
215 |
+
# %%
|
216 |
+
generate_audio("../kiss.avi")
|
217 |
+
#%%
|
218 |
+
import gradio as gr
|
219 |
|
220 |
+
iface = gr.Interface(generate_audio, "video", "playable_video")
|
221 |
iface.launch()
|
222 |
|
223 |
+
# %%
|
requirements.txt
ADDED
File without changes
|