Spaces:
Running
on
T4
Running
on
T4
Hugo Flores Garcia
committed on
Commit
•
93b48cb
1
Parent(s):
128981d
more tweaks
Browse files
- demo.py +22 -16
- scripts/exp/eval.py +17 -12
- scripts/utils/vamp_folder.py +116 -22
- vampnet/interface.py +0 -2
- vampnet/modules/base.py +10 -3
demo.py
CHANGED
@@ -210,25 +210,30 @@ with gr.Blocks() as demo:
|
|
210 |
|
211 |
""")
|
212 |
gr.Markdown("## Input Audio")
|
213 |
-
with gr.Column():
|
214 |
-
gr.Markdown("""
|
215 |
-
## Mask Hints
|
216 |
-
- most of the original audio will be masked and replaced with audio generated by vampnet
|
217 |
-
- mask hints are used to guide vampnet to generate audio that sounds like the original
|
218 |
-
- the more hints you give, the more the generated audio will sound like the original
|
219 |
|
220 |
-
""")
|
221 |
with gr.Column():
|
222 |
gr.Markdown("""
|
223 |
### Tips
|
224 |
- use the beat hint button so the output audio has the same beat structure as the input audio
|
225 |
-
- if you want
|
226 |
-
-
|
227 |
-
- decrease the periodic unmasking to anywhere from 2 to 8
|
228 |
- if you want a more "random" generation:
|
229 |
-
-
|
230 |
-
- increase the periodic unmasking to 16 or more
|
231 |
- increase the temperatures!
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
232 |
|
233 |
""")
|
234 |
|
@@ -243,7 +248,8 @@ with gr.Blocks() as demo:
|
|
243 |
num_vamps = gr.Number(
|
244 |
label="number of vamps. more vamps = longer generated audio",
|
245 |
value=1,
|
246 |
-
precision=0
|
|
|
247 |
)
|
248 |
|
249 |
manual_audio_upload = gr.File(
|
@@ -286,7 +292,7 @@ with gr.Blocks() as demo:
|
|
286 |
minimum=0,
|
287 |
maximum=64,
|
288 |
step=1,
|
289 |
-
value=
|
290 |
)
|
291 |
|
292 |
|
@@ -326,8 +332,8 @@ with gr.Blocks() as demo:
|
|
326 |
)
|
327 |
|
328 |
use_beats = gr.Checkbox(
|
329 |
-
label="use beat hints",
|
330 |
-
value=
|
331 |
)
|
332 |
|
333 |
snap_to_beats = gr.Checkbox(
|
|
|
210 |
|
211 |
""")
|
212 |
gr.Markdown("## Input Audio")
|
|
|
|
|
|
|
|
|
|
|
|
|
213 |
|
|
|
214 |
with gr.Column():
|
215 |
gr.Markdown("""
|
216 |
### Tips
|
217 |
- use the beat hint button so the output audio has the same beat structure as the input audio
|
218 |
+
- if you want more beat structure:
|
219 |
+
- enable beat hints
|
|
|
220 |
- if you want a more "random" generation:
|
221 |
+
- increase the periodic unmasking to 12 or more
|
|
|
222 |
- increase the temperatures!
|
223 |
+
- uncheck the beat hint button (or reduce the beat unmask duration)
|
224 |
+
- if you want the generated audio to sound like the original, but with a different beat structure:
|
225 |
+
- uncheck the beat hint button
|
226 |
+
- decrease the periodic unmasking to anywhere from 2 to 20
|
227 |
+
- slightly decrease the random intensity, to like .95
|
228 |
+
|
229 |
+
|
230 |
+
""")
|
231 |
+
with gr.Column():
|
232 |
+
gr.Markdown("""
|
233 |
+
## Mask Hints
|
234 |
+
- most of the original audio will be masked and replaced with audio generated by vampnet
|
235 |
+
- mask hints are used to guide vampnet to generate audio that sounds like the original
|
236 |
+
- the more hints you give, the more the generated audio will sound like the original
|
237 |
|
238 |
""")
|
239 |
|
|
|
248 |
num_vamps = gr.Number(
|
249 |
label="number of vamps. more vamps = longer generated audio",
|
250 |
value=1,
|
251 |
+
precision=0,
|
252 |
+
visible=False
|
253 |
)
|
254 |
|
255 |
manual_audio_upload = gr.File(
|
|
|
292 |
minimum=0,
|
293 |
maximum=64,
|
294 |
step=1,
|
295 |
+
value=9,
|
296 |
)
|
297 |
|
298 |
|
|
|
332 |
)
|
333 |
|
334 |
use_beats = gr.Checkbox(
|
335 |
+
label="use beat hints (helps the output stick to the beat structure of the input)",
|
336 |
+
value=False
|
337 |
)
|
338 |
|
339 |
snap_to_beats = gr.Checkbox(
|
scripts/exp/eval.py
CHANGED
@@ -5,6 +5,7 @@ from functools import partial
|
|
5 |
from frechet_audio_distance import FrechetAudioDistance
|
6 |
import pandas
|
7 |
import argbind
|
|
|
8 |
from tqdm import tqdm
|
9 |
|
10 |
import audiotools
|
@@ -21,15 +22,16 @@ def eval(
|
|
21 |
assert exp_dir.exists(), f"exp_dir {exp_dir} does not exist"
|
22 |
|
23 |
# set up our metrics
|
24 |
-
sisdr_loss = audiotools.metrics.distance.SISDRLoss()
|
25 |
-
stft_loss = audiotools.metrics.spectral.MultiScaleSTFTLoss()
|
26 |
mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss()
|
27 |
frechet = FrechetAudioDistance(
|
28 |
use_pca=False,
|
29 |
use_activation=False,
|
30 |
-
verbose=True
|
|
|
31 |
)
|
32 |
-
|
33 |
|
34 |
# figure out what conditions we have
|
35 |
conditions = [d.name for d in exp_dir.iterdir() if d.is_dir()]
|
@@ -44,7 +46,7 @@ def eval(
|
|
44 |
baseline_files = sorted(list(baseline_dir.glob(f"*{audio_ext}")), key=lambda x: int(x.stem))
|
45 |
|
46 |
metrics = []
|
47 |
-
for condition in conditions:
|
48 |
cond_dir = exp_dir / condition
|
49 |
cond_files = sorted(list(cond_dir.glob(f"*{audio_ext}")), key=lambda x: int(x.stem))
|
50 |
|
@@ -68,14 +70,17 @@ def eval(
|
|
68 |
cond_sig.resample(baseline_sig.sample_rate)
|
69 |
cond_sig.truncate_samples(baseline_sig.length)
|
70 |
|
71 |
-
#
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
|
|
|
|
|
|
76 |
return {
|
77 |
-
"sisdr": -sisdr_loss(baseline_sig, cond_sig).item(),
|
78 |
-
"stft": stft_loss(baseline_sig, cond_sig).item(),
|
79 |
"mel": mel_loss(baseline_sig, cond_sig).item(),
|
80 |
"frechet": frechet_score,
|
81 |
# "visqol": vsq,
|
|
|
5 |
from frechet_audio_distance import FrechetAudioDistance
|
6 |
import pandas
|
7 |
import argbind
|
8 |
+
import torch
|
9 |
from tqdm import tqdm
|
10 |
|
11 |
import audiotools
|
|
|
22 |
assert exp_dir.exists(), f"exp_dir {exp_dir} does not exist"
|
23 |
|
24 |
# set up our metrics
|
25 |
+
# sisdr_loss = audiotools.metrics.distance.SISDRLoss()
|
26 |
+
# stft_loss = audiotools.metrics.spectral.MultiScaleSTFTLoss()
|
27 |
mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss()
|
28 |
frechet = FrechetAudioDistance(
|
29 |
use_pca=False,
|
30 |
use_activation=False,
|
31 |
+
verbose=True,
|
32 |
+
audio_load_worker=4,
|
33 |
)
|
34 |
+
frechet.model.to("cuda" if torch.cuda.is_available() else "cpu")
|
35 |
|
36 |
# figure out what conditions we have
|
37 |
conditions = [d.name for d in exp_dir.iterdir() if d.is_dir()]
|
|
|
46 |
baseline_files = sorted(list(baseline_dir.glob(f"*{audio_ext}")), key=lambda x: int(x.stem))
|
47 |
|
48 |
metrics = []
|
49 |
+
for condition in tqdm(conditions):
|
50 |
cond_dir = exp_dir / condition
|
51 |
cond_files = sorted(list(cond_dir.glob(f"*{audio_ext}")), key=lambda x: int(x.stem))
|
52 |
|
|
|
70 |
cond_sig.resample(baseline_sig.sample_rate)
|
71 |
cond_sig.truncate_samples(baseline_sig.length)
|
72 |
|
73 |
+
# if our condition is inpainting, we need to trim the conditioning off
|
74 |
+
if "inpaint" in condition:
|
75 |
+
ctx_amt = float(condition.split("_")[-1])
|
76 |
+
ctx_samples = int(ctx_amt * baseline_sig.sample_rate)
|
77 |
+
print(f"found inpainting condition. trimming off {ctx_samples} samples from {cond_file} and {baseline_file}")
|
78 |
+
cond_sig.trim(ctx_samples, ctx_samples)
|
79 |
+
baseline_sig.trim(ctx_samples, ctx_samples)
|
80 |
+
|
81 |
return {
|
82 |
+
# "sisdr": -sisdr_loss(baseline_sig, cond_sig).item(),
|
83 |
+
# "stft": stft_loss(baseline_sig, cond_sig).item(),
|
84 |
"mel": mel_loss(baseline_sig, cond_sig).item(),
|
85 |
"frechet": frechet_score,
|
86 |
# "visqol": vsq,
|
scripts/utils/vamp_folder.py
CHANGED
@@ -6,7 +6,7 @@ import subprocess
|
|
6 |
|
7 |
import argbind
|
8 |
from tqdm import tqdm
|
9 |
-
import
|
10 |
|
11 |
from vampnet.interface import Interface
|
12 |
import audiotools as at
|
@@ -48,7 +48,6 @@ def coarse2fine_argmax(sig, interface):
|
|
48 |
)
|
49 |
return interface.to_signal(z)
|
50 |
|
51 |
-
|
52 |
class CoarseCond:
|
53 |
|
54 |
def __init__(self, num_codebooks, downsample_factor):
|
@@ -59,13 +58,12 @@ class CoarseCond:
|
|
59 |
n_conditioning_codebooks = interface.coarse.n_codebooks - self.num_codebooks
|
60 |
zv = interface.coarse_vamp_v2(sig,
|
61 |
n_conditioning_codebooks=n_conditioning_codebooks,
|
62 |
-
downsample_factor=self.downsample_factor
|
63 |
)
|
64 |
|
65 |
zv = interface.coarse_to_fine(zv)
|
66 |
return interface.to_signal(zv)
|
67 |
|
68 |
-
|
69 |
def opus(sig, interface, bitrate=128):
|
70 |
sig = interface.preprocess(sig)
|
71 |
|
@@ -97,8 +95,78 @@ def opus(sig, interface, bitrate=128):
|
|
97 |
)
|
98 |
return sig
|
99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
|
101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
"baseline": baseline,
|
103 |
"reconstructed": reconstructed,
|
104 |
"coarse2fine": coarse2fine,
|
@@ -119,23 +187,55 @@ COARSE_SAMPLE_CONDS ={
|
|
119 |
|
120 |
}
|
121 |
|
122 |
-
|
123 |
f"opus_{bitrate}": lambda sig, interface: opus(sig, interface, bitrate=bitrate)
|
124 |
for bitrate in [5620, 1875, 1250, 625]
|
125 |
}
|
126 |
|
127 |
-
|
128 |
f"opus_{bitrate}": lambda sig, interface: opus(sig, interface, bitrate=bitrate)
|
129 |
for bitrate in [8036, 2296, 1148, 574]
|
130 |
}
|
131 |
|
132 |
-
|
|
|
|
|
|
|
|
|
|
|
133 |
"baseline": baseline,
|
134 |
"reconstructed": reconstructed,
|
135 |
"coarse2fine": coarse2fine,
|
136 |
"coarse2fine_argmax": coarse2fine_argmax,
|
137 |
}
|
138 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
@argbind.bind(without_prefix=True)
|
140 |
def main(
|
141 |
sources=[
|
@@ -162,14 +262,8 @@ def main(
|
|
162 |
without_replacement=True,
|
163 |
)
|
164 |
|
165 |
-
if exp_type
|
166 |
-
SAMPLE_CONDS =
|
167 |
-
elif exp_type == "opus-spotdl":
|
168 |
-
SAMPLE_CONDS = OPUS_SPOTDL_SAMPLE_CONDS
|
169 |
-
elif exp_type == "coarse":
|
170 |
-
SAMPLE_CONDS = COARSE_SAMPLE_CONDS
|
171 |
-
elif exp_type == "c2f":
|
172 |
-
SAMPLE_CONDS = C2F_SAMPLE_CONDS
|
173 |
else:
|
174 |
raise ValueError(f"Unknown exp_type {exp_type}")
|
175 |
|
@@ -178,12 +272,12 @@ def main(
|
|
178 |
random.shuffle(indices)
|
179 |
for i in tqdm(indices):
|
180 |
# if all our files are already there, skip
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
|
188 |
sig = dataset[i]["signal"]
|
189 |
results = {
|
|
|
6 |
|
7 |
import argbind
|
8 |
from tqdm import tqdm
|
9 |
+
import torch
|
10 |
|
11 |
from vampnet.interface import Interface
|
12 |
import audiotools as at
|
|
|
48 |
)
|
49 |
return interface.to_signal(z)
|
50 |
|
|
|
51 |
class CoarseCond:
|
52 |
|
53 |
def __init__(self, num_codebooks, downsample_factor):
|
|
|
58 |
n_conditioning_codebooks = interface.coarse.n_codebooks - self.num_codebooks
|
59 |
zv = interface.coarse_vamp_v2(sig,
|
60 |
n_conditioning_codebooks=n_conditioning_codebooks,
|
61 |
+
downsample_factor=self.downsample_factor,
|
62 |
)
|
63 |
|
64 |
zv = interface.coarse_to_fine(zv)
|
65 |
return interface.to_signal(zv)
|
66 |
|
|
|
67 |
def opus(sig, interface, bitrate=128):
|
68 |
sig = interface.preprocess(sig)
|
69 |
|
|
|
95 |
)
|
96 |
return sig
|
97 |
|
98 |
+
def token_noise(ratio=1.0):
|
99 |
+
def wrapper(sig, interface):
|
100 |
+
z = interface.encode(sig)
|
101 |
+
r = interface.coarse.invgamma(ratio).to(interface.device)
|
102 |
+
print(f'adding noise with ratio {ratio}')
|
103 |
+
z, mask = interface.coarse.add_noise(
|
104 |
+
z,
|
105 |
+
r,
|
106 |
+
noise_mode="random"
|
107 |
+
)
|
108 |
+
return interface.to_signal(z)
|
109 |
+
return wrapper
|
110 |
+
|
111 |
+
def mask_ratio_1_step(ratio=1.0):
|
112 |
+
def wrapper(sig, interface):
|
113 |
+
r = interface.coarse.invgamma(ratio).to(interface.device)
|
114 |
+
intensity = 1-r
|
115 |
+
|
116 |
+
zv = interface.coarse_vamp_v2(
|
117 |
+
sig,
|
118 |
+
sample='argmax',
|
119 |
+
sampling_steps=1,
|
120 |
+
intensity=intensity
|
121 |
+
)
|
122 |
+
|
123 |
+
return interface.to_signal(zv)
|
124 |
+
return wrapper
|
125 |
+
|
126 |
+
def num_sampling_steps(num_steps=1):
|
127 |
+
def wrapper(sig, interface):
|
128 |
+
zv = interface.coarse_vamp_v2(
|
129 |
+
sig,
|
130 |
+
downsample_factor=16,
|
131 |
+
sampling_steps=num_steps,
|
132 |
+
)
|
133 |
|
134 |
+
zv = interface.coarse_to_fine(zv)
|
135 |
+
return interface.to_signal(zv)
|
136 |
+
return wrapper
|
137 |
+
|
138 |
+
def beat_mask(ctx_time):
|
139 |
+
def wrapper(sig, interface):
|
140 |
+
beat_mask = interface.make_beat_mask(
|
141 |
+
sig,
|
142 |
+
before_beat_s=0.0,
|
143 |
+
after_beat_s=ctx_time,
|
144 |
+
invert=True
|
145 |
+
)
|
146 |
+
zv = interface.coarse_vamp_v2(
|
147 |
+
sig,
|
148 |
+
ext_mask=beat_mask,
|
149 |
+
)
|
150 |
+
|
151 |
+
zv = interface.coarse_to_fine(zv)
|
152 |
+
return interface.to_signal(zv)
|
153 |
+
return wrapper
|
154 |
+
|
155 |
+
def inpaint(ctx_time):
|
156 |
+
def wrapper(sig, interface):
|
157 |
+
zv = interface.coarse_vamp_v2(
|
158 |
+
sig,
|
159 |
+
prefix_dur_s=ctx_time,
|
160 |
+
suffix_dur_s=ctx_time,
|
161 |
+
)
|
162 |
+
|
163 |
+
zv = interface.coarse_to_fine(zv)
|
164 |
+
return interface.to_signal(zv)
|
165 |
+
return wrapper
|
166 |
+
|
167 |
+
EXP_REGISTRY = {}
|
168 |
+
|
169 |
+
EXP_REGISTRY["gen-compression"] = {
|
170 |
"baseline": baseline,
|
171 |
"reconstructed": reconstructed,
|
172 |
"coarse2fine": coarse2fine,
|
|
|
187 |
|
188 |
}
|
189 |
|
190 |
+
# Conditions for the "opus-jazzpop" experiment: one `opus` call per bitrate.
# `bitrate=bitrate` binds the current loop value as a default argument; a plain
# closure would look `bitrate` up at call time, so every entry would run with
# the last value in the list.
EXP_REGISTRY["opus-jazzpop"] = {
    f"opus_{bitrate}": lambda sig, interface, bitrate=bitrate: opus(sig, interface, bitrate=bitrate)
    for bitrate in [5620, 1875, 1250, 625]
}
|
194 |
|
195 |
+
# Conditions for the "opus-spotdl" experiment: one `opus` call per bitrate.
# `bitrate=bitrate` binds the current loop value as a default argument; a plain
# closure would look `bitrate` up at call time, so every entry would run with
# the last value in the list.
EXP_REGISTRY["opus-spotdl"] = {
    f"opus_{bitrate}": lambda sig, interface, bitrate=bitrate: opus(sig, interface, bitrate=bitrate)
    for bitrate in [8036, 2296, 1148, 574]
}
|
199 |
|
200 |
+
# Conditions for the "opus-baseline" experiment: one `opus` call per bitrate.
# `bitrate=bitrate` binds the current loop value as a default argument; a plain
# closure would look `bitrate` up at call time, so every entry would run with
# the last value in the list.
EXP_REGISTRY["opus-baseline"] = {
    f"opus_{bitrate}": lambda sig, interface, bitrate=bitrate: opus(sig, interface, bitrate=bitrate)
    for bitrate in [8000, 12000, 16000]
}
|
204 |
+
|
205 |
+
EXP_REGISTRY["c2f"] = {
|
206 |
"baseline": baseline,
|
207 |
"reconstructed": reconstructed,
|
208 |
"coarse2fine": coarse2fine,
|
209 |
"coarse2fine_argmax": coarse2fine_argmax,
|
210 |
}
|
211 |
|
212 |
+
EXP_REGISTRY["token-noise"] = {
|
213 |
+
f"token_noise_{r}": token_noise(r) for r in [0.25, 0.5, 0.75, 1.0]
|
214 |
+
}
|
215 |
+
|
216 |
+
EXP_REGISTRY["mask-ratio"] = {
|
217 |
+
"codec": reconstructed,
|
218 |
+
**{f"mask_ratio_{r}": mask_ratio_1_step(r) for r in [0.25, 0.5, 0.75, 0.9]}
|
219 |
+
}
|
220 |
+
|
221 |
+
EXP_REGISTRY["sampling-steps"] = {
|
222 |
+
"codec": reconstructed,
|
223 |
+
**{f"steps_{n}": num_sampling_steps(n) for n in [1, 4, 12, 24, 36, 64, 72, 128]},
|
224 |
+
}
|
225 |
+
|
226 |
+
EXP_REGISTRY["baseline"] = {
|
227 |
+
"baseline": baseline,
|
228 |
+
"codec": reconstructed,
|
229 |
+
}
|
230 |
+
|
231 |
+
EXP_REGISTRY["musical-sampling"] = {
|
232 |
+
"baseline": baseline,
|
233 |
+
"codec": reconstructed,
|
234 |
+
**{f"downsample_{x}x": CoarseCond(4, downsample_factor=x) for x in [16, 32]},
|
235 |
+
**{f"beat_mask_{t}": beat_mask(t) for t in [0.075]},
|
236 |
+
**{f"inpaint_{t}": inpaint(t) for t in [0.5, 1.0,]}, # multiply these by 2 (they go left and right)
|
237 |
+
}
|
238 |
+
|
239 |
@argbind.bind(without_prefix=True)
|
240 |
def main(
|
241 |
sources=[
|
|
|
262 |
without_replacement=True,
|
263 |
)
|
264 |
|
265 |
+
if exp_type in EXP_REGISTRY:
|
266 |
+
SAMPLE_CONDS = EXP_REGISTRY[exp_type]
|
|
|
|
|
|
|
|
|
|
|
|
|
267 |
else:
|
268 |
raise ValueError(f"Unknown exp_type {exp_type}")
|
269 |
|
|
|
272 |
random.shuffle(indices)
|
273 |
for i in tqdm(indices):
|
274 |
# if all our files are already there, skip
|
275 |
+
done = []
|
276 |
+
for name in SAMPLE_CONDS:
|
277 |
+
o_dir = Path(output_dir) / name
|
278 |
+
done.append((o_dir / f"{i}.wav").exists())
|
279 |
+
if all(done):
|
280 |
+
continue
|
281 |
|
282 |
sig = dataset[i]["signal"]
|
283 |
results = {
|
vampnet/interface.py
CHANGED
@@ -183,10 +183,8 @@ class Interface:
|
|
183 |
num_steps = mask[_slice[0]:_slice[1]].shape[0]
|
184 |
_m = torch.ones(num_steps, device=self.device)
|
185 |
_m = torch.nn.functional.dropout(_m, p=dropout)
|
186 |
-
print(_m)
|
187 |
|
188 |
mask[_slice[0]:_slice[1]] = _m
|
189 |
-
print(mask)
|
190 |
|
191 |
if mask_downbeats:
|
192 |
for downbeat_idx in downbeats_z:
|
|
|
183 |
num_steps = mask[_slice[0]:_slice[1]].shape[0]
|
184 |
_m = torch.ones(num_steps, device=self.device)
|
185 |
_m = torch.nn.functional.dropout(_m, p=dropout)
|
|
|
186 |
|
187 |
mask[_slice[0]:_slice[1]] = _m
|
|
|
188 |
|
189 |
if mask_downbeats:
|
190 |
for downbeat_idx in downbeats_z:
|
vampnet/modules/base.py
CHANGED
@@ -42,6 +42,7 @@ class VampBase(at.ml.BaseModel):
|
|
42 |
n_suffix: Optional[torch.Tensor] = None,
|
43 |
downsample_factor: Optional[int] = None,
|
44 |
n_conditioning_codebooks: Optional[int] = None,
|
|
|
45 |
) -> Tuple[torch.Tensor, torch.Tensor]:
|
46 |
assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
|
47 |
|
@@ -89,13 +90,14 @@ class VampBase(at.ml.BaseModel):
|
|
89 |
if random_x is None:
|
90 |
random_x = torch.randint_like(x, 0, self.vocab_size)
|
91 |
|
92 |
-
if self.noise_mode
|
|
|
93 |
random_x = torch.full_like(x, self.mask_token)
|
94 |
-
elif
|
95 |
if random_x is None:
|
96 |
random_x = torch.randint_like(x, 0, self.vocab_size)
|
97 |
else:
|
98 |
-
raise ValueError(f"invalid noise mode {
|
99 |
|
100 |
# add the external mask if we were given one
|
101 |
if ext_mask is not None:
|
@@ -132,6 +134,11 @@ class VampBase(at.ml.BaseModel):
|
|
132 |
def gamma(self, r):
|
133 |
return (r * torch.pi / 2).cos()
|
134 |
|
|
|
|
|
|
|
|
|
|
|
135 |
def r_embed(self, r, max_positions=10000):
|
136 |
""" """
|
137 |
assert hasattr(self, "r_cond_dim"), "must set r_cond_dim before calling r_embed"
|
|
|
42 |
n_suffix: Optional[torch.Tensor] = None,
|
43 |
downsample_factor: Optional[int] = None,
|
44 |
n_conditioning_codebooks: Optional[int] = None,
|
45 |
+
noise_mode: str = None,
|
46 |
) -> Tuple[torch.Tensor, torch.Tensor]:
|
47 |
assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
|
48 |
|
|
|
90 |
if random_x is None:
|
91 |
random_x = torch.randint_like(x, 0, self.vocab_size)
|
92 |
|
93 |
+
noise_mode = noise_mode if noise_mode is not None else self.noise_mode
|
94 |
+
if noise_mode == "mask":
|
95 |
random_x = torch.full_like(x, self.mask_token)
|
96 |
+
elif noise_mode == "random":
|
97 |
if random_x is None:
|
98 |
random_x = torch.randint_like(x, 0, self.vocab_size)
|
99 |
else:
|
100 |
+
raise ValueError(f"invalid noise mode {noise_mode}")
|
101 |
|
102 |
# add the external mask if we were given one
|
103 |
if ext_mask is not None:
|
|
|
134 |
def gamma(self, r):
|
135 |
return (r * torch.pi / 2).cos()
|
136 |
|
137 |
+
def invgamma(self, y):
|
138 |
+
if not torch.is_tensor(y):
|
139 |
+
y = torch.tensor(y)[None]
|
140 |
+
return 2 * y.acos() / torch.pi
|
141 |
+
|
142 |
def r_embed(self, r, max_positions=10000):
|
143 |
""" """
|
144 |
assert hasattr(self, "r_cond_dim"), "must set r_cond_dim before calling r_embed"
|