Spaces:
Sleeping
Sleeping
Hugo Flores Garcia
commited on
Commit
•
ac059f4
1
Parent(s):
bcc3305
changes
Browse files- conf/{interface.yml → interface-jazzpop-exp.yml} +5 -1
- demo.py +220 -0
- scripts/exp/eval.py +11 -10
- scripts/utils/vamp_folder.py +30 -34
- setup.py +3 -4
- vampnet/gradio.py +0 -4
- vampnet/interface.py +3 -1
- vampnet/modules/base.py +3 -1
- vampnet/util.py +63 -0
conf/{interface.yml → interface-jazzpop-exp.yml}
RENAMED
@@ -2,4 +2,8 @@ Interface.coarse_ckpt: /runs/jazzpop-coarse-1m-steps.pth
|
|
2 |
Interface.coarse2fine_ckpt: /runs/jazzpop-c2f.pth
|
3 |
Interface.codec_ckpt: /runs/codec-ckpt/codec.pth
|
4 |
Interface.coarse_chunk_size_s: 5
|
5 |
-
Interface.coarse2fine_chunk_size_s: 3
|
|
|
|
|
|
|
|
|
|
2 |
Interface.coarse2fine_ckpt: /runs/jazzpop-c2f.pth
|
3 |
Interface.codec_ckpt: /runs/codec-ckpt/codec.pth
|
4 |
Interface.coarse_chunk_size_s: 5
|
5 |
+
Interface.coarse2fine_chunk_size_s: 3
|
6 |
+
|
7 |
+
AudioLoader.sources:
|
8 |
+
- /data/spotdl/audio/val
|
9 |
+
- /data/spotdl/audio/test
|
demo.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
from typing import Tuple
|
3 |
+
import yaml
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import audiotools as at
|
7 |
+
import argbind
|
8 |
+
|
9 |
+
import gradio as gr
|
10 |
+
from vampnet.interface import Interface
|
11 |
+
|
12 |
+
conf = yaml.safe_load(Path("conf/interface-jazzpop-exp.yml").read_text())
|
13 |
+
|
14 |
+
Interface = argbind.bind(Interface)
|
15 |
+
AudioLoader = argbind.bind(at.data.datasets.AudioLoader)
|
16 |
+
with argbind.scope(conf):
|
17 |
+
interface = Interface()
|
18 |
+
loader = AudioLoader()
|
19 |
+
|
20 |
+
dataset = at.data.datasets.AudioDataset(
|
21 |
+
loader,
|
22 |
+
sample_rate=interface.codec.sample_rate,
|
23 |
+
duration=interface.coarse.chunk_size_s,
|
24 |
+
n_examples=5000,
|
25 |
+
without_replacement=True,
|
26 |
+
)
|
27 |
+
|
28 |
+
|
29 |
+
def load_audio(file):
|
30 |
+
print(file)
|
31 |
+
filepath = file.name
|
32 |
+
sig = at.AudioSignal.salient_excerpt(
|
33 |
+
filepath,
|
34 |
+
duration=interface.coarse.chunk_size_s
|
35 |
+
)
|
36 |
+
sig = interface.preprocess(sig)
|
37 |
+
|
38 |
+
audio = sig.samples.numpy()[0]
|
39 |
+
sr = sig.sample_rate
|
40 |
+
return sr, audio.T
|
41 |
+
|
42 |
+
def load_random_audio():
|
43 |
+
index = np.random.randint(0, len(dataset))
|
44 |
+
sig = dataset[index]["signal"]
|
45 |
+
sig = interface.preprocess(sig)
|
46 |
+
|
47 |
+
audio = sig.samples.numpy()[0]
|
48 |
+
sr = sig.sample_rate
|
49 |
+
return sr, audio.T
|
50 |
+
|
51 |
+
def mask_audio(
|
52 |
+
prefix_s, suffix_s, rand_mask_intensity,
|
53 |
+
mask_periodic_amt, beat_unmask_dur,
|
54 |
+
mask_dwn_chk, dwn_factor,
|
55 |
+
mask_up_chk, up_factor
|
56 |
+
):
|
57 |
+
pass
|
58 |
+
|
59 |
+
def vamp(
|
60 |
+
input_audio, prefix_s, suffix_s, rand_mask_intensity,
|
61 |
+
mask_periodic_amt, beat_unmask_dur,
|
62 |
+
mask_dwn_chk, dwn_factor,
|
63 |
+
mask_up_chk, up_factor
|
64 |
+
):
|
65 |
+
print(input_audio)
|
66 |
+
|
67 |
+
|
68 |
+
with gr.Blocks() as demo:
|
69 |
+
|
70 |
+
gr.Markdown('# Vampnet')
|
71 |
+
|
72 |
+
with gr.Row():
|
73 |
+
# input audio
|
74 |
+
with gr.Column():
|
75 |
+
gr.Markdown("## Input Audio")
|
76 |
+
|
77 |
+
manual_audio_upload = gr.File(
|
78 |
+
label=f"upload some audio (will be randomly trimmed to max of {interface.coarse.chunk_size_s:.2f}s)",
|
79 |
+
file_types=["audio"]
|
80 |
+
)
|
81 |
+
load_random_audio_button = gr.Button("or load random audio")
|
82 |
+
|
83 |
+
input_audio = gr.Audio(
|
84 |
+
label="input audio",
|
85 |
+
interactive=False,
|
86 |
+
)
|
87 |
+
input_audio_viz = gr.HTML(
|
88 |
+
label="input audio",
|
89 |
+
)
|
90 |
+
|
91 |
+
# connect widgets
|
92 |
+
load_random_audio_button.click(
|
93 |
+
fn=load_random_audio,
|
94 |
+
inputs=[],
|
95 |
+
outputs=[ input_audio]
|
96 |
+
)
|
97 |
+
|
98 |
+
manual_audio_upload.change(
|
99 |
+
fn=load_audio,
|
100 |
+
inputs=[manual_audio_upload],
|
101 |
+
outputs=[ input_audio]
|
102 |
+
)
|
103 |
+
|
104 |
+
|
105 |
+
# mask settings
|
106 |
+
with gr.Column():
|
107 |
+
gr.Markdown("## Mask Settings")
|
108 |
+
prefix_s = gr.Slider(
|
109 |
+
label="prefix length (seconds)",
|
110 |
+
minimum=0.0,
|
111 |
+
maximum=10.0,
|
112 |
+
value=0.0
|
113 |
+
)
|
114 |
+
suffix_s = gr.Slider(
|
115 |
+
label="suffix length (seconds)",
|
116 |
+
minimum=0.0,
|
117 |
+
maximum=10.0,
|
118 |
+
value=0.0
|
119 |
+
)
|
120 |
+
|
121 |
+
rand_mask_intensity = gr.Slider(
|
122 |
+
label="random mask intensity (lower means more freedom)",
|
123 |
+
minimum=0.0,
|
124 |
+
maximum=1.0,
|
125 |
+
value=1.0
|
126 |
+
)
|
127 |
+
|
128 |
+
mask_periodic_amt = gr.Slider(
|
129 |
+
label="periodic unmasking factor (higher means more freedom)",
|
130 |
+
minimum=0,
|
131 |
+
maximum=32,
|
132 |
+
step=1,
|
133 |
+
value=2,
|
134 |
+
)
|
135 |
+
compute_mask_button = gr.Button("compute mask")
|
136 |
+
mask_output = gr.Audio(
|
137 |
+
label="masked audio",
|
138 |
+
interactive=False,
|
139 |
+
visible=False
|
140 |
+
)
|
141 |
+
mask_output_viz = gr.Video(
|
142 |
+
label="masked audio",
|
143 |
+
interactive=False
|
144 |
+
)
|
145 |
+
|
146 |
+
with gr.Column():
|
147 |
+
gr.Markdown("## Beat Unmasking")
|
148 |
+
with gr.Accordion(label="beat unmask"):
|
149 |
+
beat_unmask_dur = gr.Slider(
|
150 |
+
label="duration",
|
151 |
+
minimum=0.0,
|
152 |
+
maximum=3.0,
|
153 |
+
value=0.1
|
154 |
+
)
|
155 |
+
with gr.Accordion("downbeat settings"):
|
156 |
+
mask_dwn_chk = gr.Checkbox(
|
157 |
+
label="unmask downbeats",
|
158 |
+
value=True
|
159 |
+
)
|
160 |
+
dwn_factor = gr.Slider(
|
161 |
+
label="downbeat downsample factor (unmask every Nth downbeat)",
|
162 |
+
value=1,
|
163 |
+
minimum=1,
|
164 |
+
maximum=16,
|
165 |
+
step=1
|
166 |
+
)
|
167 |
+
with gr.Accordion("upbeat settings"):
|
168 |
+
mask_up_chk = gr.Checkbox(
|
169 |
+
label="unmask upbeats",
|
170 |
+
value=True
|
171 |
+
)
|
172 |
+
up_factor = gr.Slider(
|
173 |
+
label="upbeat downsample factor (unmask every Nth upbeat)",
|
174 |
+
value=1,
|
175 |
+
minimum=1,
|
176 |
+
maximum=16,
|
177 |
+
step=1
|
178 |
+
)
|
179 |
+
|
180 |
+
# process and output
|
181 |
+
with gr.Row():
|
182 |
+
with gr.Column():
|
183 |
+
vamp_button = gr.Button("vamp")
|
184 |
+
|
185 |
+
output_audio = gr.Audio(
|
186 |
+
label="output audio",
|
187 |
+
interactive=False,
|
188 |
+
visible=False
|
189 |
+
)
|
190 |
+
output_audio_viz = gr.Video(
|
191 |
+
label="output audio",
|
192 |
+
interactive=False
|
193 |
+
)
|
194 |
+
|
195 |
+
# connect widgets
|
196 |
+
compute_mask_button.click(
|
197 |
+
fn=mask_audio,
|
198 |
+
inputs=[
|
199 |
+
prefix_s, suffix_s, rand_mask_intensity,
|
200 |
+
mask_periodic_amt, beat_unmask_dur,
|
201 |
+
mask_dwn_chk, dwn_factor,
|
202 |
+
mask_up_chk, up_factor
|
203 |
+
],
|
204 |
+
outputs=[mask_output, mask_output_viz]
|
205 |
+
)
|
206 |
+
|
207 |
+
# connect widgets
|
208 |
+
vamp_button.click(
|
209 |
+
fn=vamp,
|
210 |
+
inputs=[input_audio,
|
211 |
+
prefix_s, suffix_s, rand_mask_intensity,
|
212 |
+
mask_periodic_amt, beat_unmask_dur,
|
213 |
+
mask_dwn_chk, dwn_factor,
|
214 |
+
mask_up_chk, up_factor
|
215 |
+
],
|
216 |
+
outputs=[output_audio, output_audio_viz]
|
217 |
+
)
|
218 |
+
|
219 |
+
|
220 |
+
demo.launch(share=True)
|
scripts/exp/eval.py
CHANGED
@@ -57,30 +57,31 @@ def eval(
|
|
57 |
cond_files = cond_files[:num_files]
|
58 |
assert len(list(baseline_files)) == len(list(cond_files)), f"number of files in {baseline_dir} and {cond_dir} do not match. {len(list(baseline_files))} vs {len(list(cond_files))}"
|
59 |
|
60 |
-
|
61 |
-
for baseline_file, cond_file in pbar:
|
62 |
# make sure the files match (same name)
|
63 |
assert baseline_file.stem == cond_file.stem, f"baseline file {baseline_file} and cond file {cond_file} do not match"
|
64 |
-
pbar.set_description(baseline_file.stem)
|
65 |
|
66 |
# load the files
|
67 |
baseline_sig = AudioSignal(str(baseline_file))
|
68 |
cond_sig = AudioSignal(str(cond_file))
|
69 |
|
70 |
# compute the metrics
|
71 |
-
try:
|
72 |
-
|
73 |
-
except:
|
74 |
-
|
75 |
-
|
76 |
"sisdr": -sisdr_loss(baseline_sig, cond_sig).item(),
|
77 |
"stft": stft_loss(baseline_sig, cond_sig).item(),
|
78 |
"mel": mel_loss(baseline_sig, cond_sig).item(),
|
79 |
"frechet": frechet_score,
|
80 |
-
"visqol": vsq,
|
81 |
"condition": condition,
|
82 |
"file": baseline_file.stem,
|
83 |
-
}
|
|
|
|
|
|
|
84 |
|
85 |
metric_keys = [k for k in metrics[0].keys() if k not in ("condition", "file")]
|
86 |
|
|
|
57 |
cond_files = cond_files[:num_files]
|
58 |
assert len(list(baseline_files)) == len(list(cond_files)), f"number of files in {baseline_dir} and {cond_dir} do not match. {len(list(baseline_files))} vs {len(list(cond_files))}"
|
59 |
|
60 |
+
def process(baseline_file, cond_file):
|
|
|
61 |
# make sure the files match (same name)
|
62 |
assert baseline_file.stem == cond_file.stem, f"baseline file {baseline_file} and cond file {cond_file} do not match"
|
|
|
63 |
|
64 |
# load the files
|
65 |
baseline_sig = AudioSignal(str(baseline_file))
|
66 |
cond_sig = AudioSignal(str(cond_file))
|
67 |
|
68 |
# compute the metrics
|
69 |
+
# try:
|
70 |
+
# vsq = visqol(baseline_sig, cond_sig)
|
71 |
+
# except:
|
72 |
+
# vsq = 0.0
|
73 |
+
return {
|
74 |
"sisdr": -sisdr_loss(baseline_sig, cond_sig).item(),
|
75 |
"stft": stft_loss(baseline_sig, cond_sig).item(),
|
76 |
"mel": mel_loss(baseline_sig, cond_sig).item(),
|
77 |
"frechet": frechet_score,
|
78 |
+
# "visqol": vsq,
|
79 |
"condition": condition,
|
80 |
"file": baseline_file.stem,
|
81 |
+
}
|
82 |
+
|
83 |
+
print(f"processing {len(baseline_files)} files in {baseline_dir} and {cond_dir}")
|
84 |
+
metrics.extend(tqdm(map(process, baseline_files, cond_files), total=len(baseline_files)))
|
85 |
|
86 |
metric_keys = [k for k in metrics[0].keys() if k not in ("condition", "file")]
|
87 |
|
scripts/utils/vamp_folder.py
CHANGED
@@ -1,8 +1,10 @@
|
|
1 |
from pathlib import Path
|
|
|
|
|
2 |
|
3 |
import argbind
|
4 |
from tqdm import tqdm
|
5 |
-
import
|
6 |
|
7 |
from vampnet.interface import Interface
|
8 |
import audiotools as at
|
@@ -12,9 +14,9 @@ Interface = argbind.bind(Interface)
|
|
12 |
# condition wrapper for printing
|
13 |
def condition(cond):
|
14 |
def wrapper(sig, interface):
|
15 |
-
print(f"Condition: {cond.__name__}")
|
16 |
sig = cond(sig, interface)
|
17 |
-
print(f"Condition: {cond.__name__} (done)\n")
|
18 |
return sig
|
19 |
return wrapper
|
20 |
|
@@ -49,48 +51,27 @@ def coarse2fine_argmax(sig, interface):
|
|
49 |
|
50 |
@condition
|
51 |
def one_codebook(sig, interface):
|
52 |
-
z = interface.encode(sig)
|
53 |
-
|
54 |
-
nb, _, nt = z.shape
|
55 |
-
nc = interface.coarse.n_codebooks
|
56 |
-
mask = torch.zeros(nb, nc, nt).to(interface.device)
|
57 |
-
mask[:, 1:, :] = 1
|
58 |
-
|
59 |
zv = interface.coarse_vamp_v2(
|
60 |
-
sig,
|
61 |
)
|
62 |
zv = interface.coarse_to_fine(zv)
|
63 |
|
64 |
return interface.to_signal(zv)
|
65 |
|
66 |
-
@condition
|
67 |
-
def four_codebooks_downsampled_4x(sig, interface):
|
68 |
-
zv = interface.coarse_vamp_v2(
|
69 |
-
sig, downsample_factor=4
|
70 |
-
)
|
71 |
-
zv = interface.coarse_to_fine(zv)
|
72 |
-
return interface.to_signal(zv)
|
73 |
-
|
74 |
@condition
|
75 |
def two_codebooks_downsampled_4x(sig, interface):
|
76 |
-
z = interface.encode(sig)
|
77 |
-
|
78 |
-
nb, _, nt = z.shape
|
79 |
-
nc = interface.coarse.n_codebooks
|
80 |
-
mask = torch.zeros(nb, nc, nt).to(interface.device)
|
81 |
-
mask[:, 2:, :] = 1
|
82 |
-
|
83 |
zv = interface.coarse_vamp_v2(
|
84 |
-
sig,
|
|
|
85 |
)
|
86 |
zv = interface.coarse_to_fine(zv)
|
87 |
|
88 |
return interface.to_signal(zv)
|
89 |
|
90 |
-
|
91 |
-
def
|
92 |
zv = interface.coarse_vamp_v2(
|
93 |
-
sig, downsample_factor=
|
94 |
)
|
95 |
zv = interface.coarse_to_fine(zv)
|
96 |
return interface.to_signal(zv)
|
@@ -101,9 +82,13 @@ COARSE_SAMPLE_CONDS ={
|
|
101 |
"reconstructed": reconstructed,
|
102 |
"coarse2fine": coarse2fine,
|
103 |
"one_codebook": one_codebook,
|
104 |
-
"four_codebooks_downsampled_4x": four_codebooks_downsampled_4x,
|
105 |
"two_codebooks_downsampled_4x": two_codebooks_downsampled_4x,
|
106 |
-
|
|
|
|
|
|
|
|
|
|
|
107 |
}
|
108 |
|
109 |
C2F_SAMPLE_CONDS = {
|
@@ -131,7 +116,7 @@ def main(
|
|
131 |
|
132 |
from audiotools.data.datasets import AudioLoader, AudioDataset
|
133 |
|
134 |
-
loader = AudioLoader(sources=sources)
|
135 |
dataset = AudioDataset(loader,
|
136 |
sample_rate=interface.codec.sample_rate,
|
137 |
duration=interface.coarse.chunk_size_s,
|
@@ -141,7 +126,18 @@ def main(
|
|
141 |
|
142 |
SAMPLE_CONDS = COARSE_SAMPLE_CONDS if exp_type == "coarse" else C2F_SAMPLE_CONDS
|
143 |
|
144 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
sig = dataset[i]["signal"]
|
146 |
|
147 |
results = {
|
|
|
1 |
from pathlib import Path
|
2 |
+
import random
|
3 |
+
from typing import List
|
4 |
|
5 |
import argbind
|
6 |
from tqdm import tqdm
|
7 |
+
import argbind
|
8 |
|
9 |
from vampnet.interface import Interface
|
10 |
import audiotools as at
|
|
|
14 |
# condition wrapper for printing
|
15 |
def condition(cond):
|
16 |
def wrapper(sig, interface):
|
17 |
+
# print(f"Condition: {cond.__name__}")
|
18 |
sig = cond(sig, interface)
|
19 |
+
# print(f"Condition: {cond.__name__} (done)\n")
|
20 |
return sig
|
21 |
return wrapper
|
22 |
|
|
|
51 |
|
52 |
@condition
|
53 |
def one_codebook(sig, interface):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
zv = interface.coarse_vamp_v2(
|
55 |
+
sig, n_conditioning_codebooks=1
|
56 |
)
|
57 |
zv = interface.coarse_to_fine(zv)
|
58 |
|
59 |
return interface.to_signal(zv)
|
60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
@condition
|
62 |
def two_codebooks_downsampled_4x(sig, interface):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
zv = interface.coarse_vamp_v2(
|
64 |
+
sig, n_conditioning_codebooks=2,
|
65 |
+
downsample_factor=4
|
66 |
)
|
67 |
zv = interface.coarse_to_fine(zv)
|
68 |
|
69 |
return interface.to_signal(zv)
|
70 |
|
71 |
+
|
72 |
+
def four_codebooks_downsampled(sig, interface, x=12):
|
73 |
zv = interface.coarse_vamp_v2(
|
74 |
+
sig, downsample_factor=12
|
75 |
)
|
76 |
zv = interface.coarse_to_fine(zv)
|
77 |
return interface.to_signal(zv)
|
|
|
82 |
"reconstructed": reconstructed,
|
83 |
"coarse2fine": coarse2fine,
|
84 |
"one_codebook": one_codebook,
|
|
|
85 |
"two_codebooks_downsampled_4x": two_codebooks_downsampled_4x,
|
86 |
+
# four codebooks at different downsample factors
|
87 |
+
**{
|
88 |
+
f"four_codebooks_downsampled_{x}x": lambda sig, interface: four_codebooks_downsampled(sig, interface, x=x)
|
89 |
+
for x in [4, 8, 12, 16, 20, 24]
|
90 |
+
}
|
91 |
+
|
92 |
}
|
93 |
|
94 |
C2F_SAMPLE_CONDS = {
|
|
|
116 |
|
117 |
from audiotools.data.datasets import AudioLoader, AudioDataset
|
118 |
|
119 |
+
loader = AudioLoader(sources=sources, shuffle_state=seed)
|
120 |
dataset = AudioDataset(loader,
|
121 |
sample_rate=interface.codec.sample_rate,
|
122 |
duration=interface.coarse.chunk_size_s,
|
|
|
126 |
|
127 |
SAMPLE_CONDS = COARSE_SAMPLE_CONDS if exp_type == "coarse" else C2F_SAMPLE_CONDS
|
128 |
|
129 |
+
|
130 |
+
indices = list(range(max_excerpts))
|
131 |
+
random.shuffle(indices)
|
132 |
+
for i in tqdm(indices):
|
133 |
+
# if all our files are already there, skip
|
134 |
+
# done = []
|
135 |
+
# for name in SAMPLE_CONDS:
|
136 |
+
# o_dir = Path(output_dir) / name
|
137 |
+
# done.append((o_dir / f"{i}.wav").exists())
|
138 |
+
# if all(done):
|
139 |
+
# continue
|
140 |
+
|
141 |
sig = dataset[i]["signal"]
|
142 |
|
143 |
results = {
|
setup.py
CHANGED
@@ -26,16 +26,15 @@ setup(
|
|
26 |
license="MIT",
|
27 |
packages=find_packages(),
|
28 |
install_requires=[
|
29 |
-
"torch
|
30 |
"argbind>=0.3.2",
|
31 |
"pytorch-ignite",
|
32 |
"rich",
|
33 |
-
"audiotools @ git+https://github.com/descriptinc/lyrebird-audiotools.git@
|
34 |
-
"lac @ git+https://github.com/descriptinc/lyrebird-audio-codec.git
|
35 |
"tqdm",
|
36 |
"tensorboard",
|
37 |
"google-cloud-logging==2.2.0",
|
38 |
-
"torchmetrics>=0.7.3",
|
39 |
"einops",
|
40 |
"frechet_audio_distance"
|
41 |
],
|
|
|
26 |
license="MIT",
|
27 |
packages=find_packages(),
|
28 |
install_requires=[
|
29 |
+
"torch",
|
30 |
"argbind>=0.3.2",
|
31 |
"pytorch-ignite",
|
32 |
"rich",
|
33 |
+
"audiotools @ git+https://github.com/descriptinc/lyrebird-audiotools.git@hf/backup-info",
|
34 |
+
"lac @ git+https://github.com/descriptinc/lyrebird-audio-codec.git",
|
35 |
"tqdm",
|
36 |
"tensorboard",
|
37 |
"google-cloud-logging==2.2.0",
|
|
|
38 |
"einops",
|
39 |
"frechet_audio_distance"
|
40 |
],
|
vampnet/gradio.py
DELETED
@@ -1,4 +0,0 @@
|
|
1 |
-
|
2 |
-
import gradio as gr
|
3 |
-
|
4 |
-
|
|
|
|
|
|
|
|
|
|
vampnet/interface.py
CHANGED
@@ -315,6 +315,7 @@ class Interface:
|
|
315 |
debug=False,
|
316 |
swap_prefix_suffix=False,
|
317 |
ext_mask=None,
|
|
|
318 |
verbose=False,
|
319 |
**kwargs
|
320 |
):
|
@@ -351,7 +352,8 @@ class Interface:
|
|
351 |
n_suffix=n_suffix,
|
352 |
downsample_factor=downsample_factor,
|
353 |
mask=cz_mask,
|
354 |
-
ext_mask=ext_mask
|
|
|
355 |
)
|
356 |
if debug:
|
357 |
print("tokens to infer")
|
|
|
315 |
debug=False,
|
316 |
swap_prefix_suffix=False,
|
317 |
ext_mask=None,
|
318 |
+
n_conditioning_codebooks=None,
|
319 |
verbose=False,
|
320 |
**kwargs
|
321 |
):
|
|
|
352 |
n_suffix=n_suffix,
|
353 |
downsample_factor=downsample_factor,
|
354 |
mask=cz_mask,
|
355 |
+
ext_mask=ext_mask,
|
356 |
+
n_conditioning_codebooks=n_conditioning_codebooks
|
357 |
)
|
358 |
if debug:
|
359 |
print("tokens to infer")
|
vampnet/modules/base.py
CHANGED
@@ -41,6 +41,7 @@ class VampBase(at.ml.BaseModel):
|
|
41 |
n_prefix: Optional[torch.Tensor] = None,
|
42 |
n_suffix: Optional[torch.Tensor] = None,
|
43 |
downsample_factor: Optional[int] = None,
|
|
|
44 |
) -> Tuple[torch.Tensor, torch.Tensor]:
|
45 |
assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
|
46 |
|
@@ -79,7 +80,8 @@ class VampBase(at.ml.BaseModel):
|
|
79 |
mask = mask.round().long()
|
80 |
|
81 |
# if we have any conditioning codebooks, set their mask to 0
|
82 |
-
|
|
|
83 |
else:
|
84 |
assert mask.ndim == 3, "mask must be (batch, n_codebooks, seq)"
|
85 |
assert mask.shape == x.shape, "mask must be same shape as x"
|
|
|
41 |
n_prefix: Optional[torch.Tensor] = None,
|
42 |
n_suffix: Optional[torch.Tensor] = None,
|
43 |
downsample_factor: Optional[int] = None,
|
44 |
+
n_conditioning_codebooks: Optional[int] = None,
|
45 |
) -> Tuple[torch.Tensor, torch.Tensor]:
|
46 |
assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
|
47 |
|
|
|
80 |
mask = mask.round().long()
|
81 |
|
82 |
# if we have any conditioning codebooks, set their mask to 0
|
83 |
+
n_conditioning_codebooks = n_conditioning_codebooks or self.n_conditioning_codebooks
|
84 |
+
mask[:, :n_conditioning_codebooks, :] = 0
|
85 |
else:
|
86 |
assert mask.ndim == 3, "mask must be (batch, n_codebooks, seq)"
|
87 |
assert mask.shape == x.shape, "mask must be same shape as x"
|
vampnet/util.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tqdm
|
2 |
+
# import pathos
|
3 |
+
|
4 |
+
def process_map(fn, *iterables, **tqdm_kwargs):
|
5 |
+
"""
|
6 |
+
Equivalent of `list(map(fn, *iterables))`
|
7 |
+
driven by `concurrent.futures.ProcessPoolExecutor`.
|
8 |
+
|
9 |
+
Parameters
|
10 |
+
----------
|
11 |
+
tqdm_class : optional
|
12 |
+
`tqdm` class to use for bars [default: tqdm.auto.tqdm].
|
13 |
+
max_workers : int, optional
|
14 |
+
Maximum number of workers to spawn; passed to
|
15 |
+
`concurrent.futures.ProcessPoolExecutor.__init__`.
|
16 |
+
[default: min(32, cpu_count() + 4)].
|
17 |
+
chunksize : int, optional
|
18 |
+
Size of chunks sent to worker processes; passed to
|
19 |
+
`concurrent.futures.ProcessPoolExecutor.map`. [default: 1].
|
20 |
+
lock_name : str, optional
|
21 |
+
Member of `tqdm_class.get_lock()` to use [default: mp_lock].
|
22 |
+
"""
|
23 |
+
from concurrent.futures import ProcessPoolExecutor
|
24 |
+
if iterables and "chunksize" not in tqdm_kwargs:
|
25 |
+
# default `chunksize=1` has poor performance for large iterables
|
26 |
+
# (most time spent dispatching items to workers).
|
27 |
+
longest_iterable_len = max(map(length_hint, iterables))
|
28 |
+
if longest_iterable_len > 1000:
|
29 |
+
from warnings import warn
|
30 |
+
warn("Iterable length %d > 1000 but `chunksize` is not set."
|
31 |
+
" This may seriously degrade multiprocess performance."
|
32 |
+
" Set `chunksize=1` or more." % longest_iterable_len,
|
33 |
+
TqdmWarning, stacklevel=2)
|
34 |
+
if "lock_name" not in tqdm_kwargs:
|
35 |
+
tqdm_kwargs = tqdm_kwargs.copy()
|
36 |
+
tqdm_kwargs["lock_name"] = "mp_lock"
|
37 |
+
return _executor_map(ProcessPoolExecutor, fn, *iterables, **tqdm_kwargs)
|
38 |
+
|
39 |
+
|
40 |
+
def parallelize(
|
41 |
+
fn,
|
42 |
+
*iterables,
|
43 |
+
parallel: str = "thread_map",
|
44 |
+
**kwargs
|
45 |
+
):
|
46 |
+
if parallel == "thread_map":
|
47 |
+
from tqdm.contrib.concurrent import thread_map
|
48 |
+
return thread_map(
|
49 |
+
fn,
|
50 |
+
*iterables,
|
51 |
+
**kwargs
|
52 |
+
)
|
53 |
+
elif parallel == "process_map":
|
54 |
+
from tqdm.contrib.concurrent import process_map
|
55 |
+
return process_map(
|
56 |
+
fn,
|
57 |
+
*iterables,
|
58 |
+
**kwargs
|
59 |
+
)
|
60 |
+
elif parallel == "single":
|
61 |
+
return [fn(x) for x in tqdm.tqdm(*iterables)]
|
62 |
+
else:
|
63 |
+
raise ValueError(f"parallel must be one of 'thread_map', 'process_map', 'single', but got {parallel}")
|