Hugo Flores Garcia committed
Commit 7b88c07
1 Parent(s): bf35d45
better onset detection!!!!!
remove annealing from sampling temperature
add pitch shifting w/ torch pitch shift
improvements to lora config
(TODO: fix a lora bug where the lora weights won't load correctly)
add helper scripts for collecting xeno canto data
- app.py +60 -11
- conf/lora/lora.yml +4 -2
- scripts/exp/fine_tune.py +2 -2
- scripts/utils/augment.py +38 -24
- scripts/utils/remove_quiet_files.py +29 -0
- scripts/xeno-canto-dl.py +234 -0
- vampnet/mask.py +38 -20
- vampnet/modules/transformer.py +9 -14
app.py
CHANGED
@@ -18,6 +18,16 @@ Interface = argbind.bind(Interface)
 
 conf = argbind.parse_args()
 
+
+from torch_pitch_shift import pitch_shift, get_fast_shifts
+def shift_pitch(signal, interval: int):
+    signal.samples = pitch_shift(
+        signal.samples,
+        shift=interval,
+        sample_rate=signal.sample_rate
+    )
+    return signal
+
 def load_interface():
     with argbind.scope(conf):
         interface = Interface()
@@ -95,6 +105,10 @@ def _vamp(data, return_mask=False):
     out_dir = OUT_DIR / str(uuid.uuid4())
     out_dir.mkdir()
     sig = at.AudioSignal(data[input_audio])
+    sig = interface.preprocess(sig)
+
+    if data[pitch_shift_amt] != 0:
+        sig = shift_pitch(sig, data[pitch_shift_amt])
 
     z = interface.encode(sig)
 
@@ -134,7 +148,27 @@ def _vamp(data, return_mask=False):
     mask = pmask.codebook_unmask(mask, ncc)
 
 
-    print(data)
+    print(f"dropout {data[dropout]}")
+    print(f"masktemp {data[masktemp]}")
+    print(f"sampletemp {data[sampletemp]}")
+    print(f"top_p {data[top_p]}")
+    print(f"prefix_s {data[prefix_s]}")
+    print(f"suffix_s {data[suffix_s]}")
+    print(f"rand_mask_intensity {data[rand_mask_intensity]}")
+    print(f"num_steps {data[num_steps]}")
+    print(f"periodic_p {data[periodic_p]}")
+    print(f"periodic_w {data[periodic_w]}")
+    print(f"n_conditioning_codebooks {data[n_conditioning_codebooks]}")
+    print(f"use_coarse2fine {data[use_coarse2fine]}")
+    print(f"onset_mask_width {data[onset_mask_width]}")
+    print(f"beat_mask_width {data[beat_mask_width]}")
+    print(f"beat_mask_downbeats {data[beat_mask_downbeats]}")
+    print(f"stretch_factor {data[stretch_factor]}")
+    print(f"seed {data[seed]}")
+    print(f"pitch_shift_amt {data[pitch_shift_amt]}")
+    print(f"sample_cutoff {data[sample_cutoff]}")
+
+
     _top_p = data[top_p] if data[top_p] > 0 else None
     # save the mask as a txt file
     np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
@@ -153,6 +187,7 @@ def _vamp(data, return_mask=False):
         top_p=_top_p,
         gen_fn=interface.coarse.generate,
         seed=_seed,
+        sample_cutoff=data[sample_cutoff],
     )
 
     if use_coarse2fine:
@@ -356,7 +391,7 @@ with gr.Blocks() as demo:
            onset_mask_width = gr.Slider(
                label="onset mask width (multiplies with the periodic mask, 1 step ~= 10milliseconds) ",
                minimum=0,
-               maximum=
+               maximum=100,
                step=1,
                value=5,
            )
@@ -374,6 +409,14 @@ with gr.Blocks() as demo:
 
 
            with gr.Accordion("extras ", open=False):
+               pitch_shift_amt = gr.Slider(
+                   label="pitch shift amount (semitones)",
+                   minimum=-12,
+                   maximum=12,
+                   step=1,
+                   value=0,
+               )
+
               rand_mask_intensity = gr.Slider(
                   label="random mask intensity. (If this is less than 1, scatters prompts throughout the audio, should be between 0.9 and 1.0)",
                   minimum=0.0,
@@ -436,14 +479,15 @@ with gr.Blocks() as demo:
               masktemp = gr.Slider(
                   label="mask temperature",
                   minimum=0.0,
-                  maximum=
+                  maximum=100.0,
                   value=1.5
               )
               sampletemp = gr.Slider(
                   label="sample temperature",
                   minimum=0.1,
-                  maximum=
-                  value=1.0
+                  maximum=10.0,
+                  value=1.0,
+                  step=0.001
               )
 
 
@@ -459,7 +503,7 @@ with gr.Blocks() as demo:
                   label="typical filtering ",
                   value=False
               )
-              typical_mass = gr.Slider(
+              typical_mass = gr.Slider(
                   label="typical mass (should probably stay between 0.1 and 0.5)",
                   minimum=0.01,
                   maximum=0.99,
@@ -472,6 +516,13 @@ with gr.Blocks() as demo:
                   step=1,
                   value=64
               )
+              sample_cutoff = gr.Slider(
+                  label="sample cutoff",
+                  minimum=0.0,
+                  maximum=1.0,
+                  value=0.5,
+                  step=0.01
+              )
 
               use_coarse2fine = gr.Checkbox(
                   label="use coarse2fine",
@@ -495,10 +546,6 @@ with gr.Blocks() as demo:
                   value=0.0
               )
 
-              use_new_trick = gr.Checkbox(
-                  label="new trick",
-                  value=False
-              )
 
               seed = gr.Number(
                   label="seed (0 for random)",
@@ -560,6 +607,8 @@ with gr.Blocks() as demo:
            beat_mask_downbeats,
            seed,
            lora_choice,
+           pitch_shift_amt,
+           sample_cutoff
        }
 
        # connect widgets
@@ -589,4 +638,4 @@ with gr.Blocks() as demo:
        outputs=[thank_you, download_file]
    )
 
-   demo.launch(share=True, enable_queue=
+   demo.launch(share=True, enable_queue=True, debug=True)
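Note on the pitch-shift helper added above: torch_pitch_shift operates on raw tensors, so shift_pitch only rewrites signal.samples and leaves the rest of the AudioSignal untouched. A minimal standalone sketch of the same call, assuming an integer shift is interpreted as semitones and the tensor is laid out as (batch, channels, samples):

    import torch
    from torch_pitch_shift import pitch_shift

    sample_rate = 44100
    audio = torch.randn(1, 1, sample_rate)  # one second of noise, (batch, channels, samples)
    shifted = pitch_shift(audio, shift=5, sample_rate=sample_rate)  # up 5 semitones
    print(shifted.shape)  # same layout as the input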
conf/lora/lora.yml
CHANGED
@@ -4,14 +4,16 @@ $include:
 fine_tune: True
 
 train/AudioDataset.n_examples: 100000000
-val/AudioDataset.n_examples:
+val/AudioDataset.n_examples: 500
 
 
 NoamScheduler.warmup: 500
 
 batch_size: 7
 num_workers: 7
-save_iters: [
+save_iters: [10000, 20000, 30000, 40000, 50000]
+sample_freq: 1000
+val_freq: 500
 
 AdamW.lr: 0.0001
 
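The new keys are plain top-level scalars and lists, so they can be sanity-checked with any YAML loader before argbind consumes the file. A small illustrative sketch (the loader and the reading of sample_freq/val_freq as iteration intervals are assumptions, not taken from the repo):

    import yaml

    with open("conf/lora/lora.yml") as f:
        cfg = yaml.safe_load(f)

    print(cfg["save_iters"])   # [10000, 20000, 30000, 40000, 50000]
    print(cfg["sample_freq"])  # presumably: log audio samples every 1000 iterations
    print(cfg["val_freq"])     # presumably: run validation every 500 iterations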
scripts/exp/fine_tune.py
CHANGED
@@ -48,10 +48,10 @@ def fine_tune(audio_files_or_folders: List[str], name: str):
     }
 
     interface_conf = {
-        "Interface.coarse_ckpt": f"./
+        "Interface.coarse_ckpt": f"./runs/{name}/coarse/latest/vampnet/weights.pth",
         "Interface.coarse_lora_ckpt": f"./runs/{name}/coarse/latest/lora.pth",
 
-        "Interface.coarse2fine_ckpt": f"./
+        "Interface.coarse2fine_ckpt": f"./runs/{name}/c2f/latest/vampnet/weights.pth",
         "Interface.coarse2fine_lora_ckpt": f"./runs/{name}/c2f/latest/lora.pth",
         "Interface.wavebeat_ckpt": "./models/wavebeat.pth",
 
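Both checkpoint paths now point at the per-run directory layout used during LoRA fine-tuning (runs/<name>/{coarse,c2f}/latest/...). A quick illustrative check of that layout, with a hypothetical run name:

    from pathlib import Path

    name = "my-finetune"  # hypothetical run name
    for model in ("coarse", "c2f"):
        weights = Path(f"./runs/{name}/{model}/latest/vampnet/weights.pth")
        lora = Path(f"./runs/{name}/{model}/latest/lora.pth")
        print(weights, weights.exists())
        print(lora, lora.exists())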
scripts/utils/augment.py
CHANGED
@@ -5,34 +5,19 @@ from audiotools import AudioSignal
 
 import argbind
 import tqdm
+import torch
 
 
-from
-
-)
-from pedalboard.io import AudioFile
+from torch_pitch_shift import pitch_shift, get_fast_shifts
+from torch_time_stretch import time_stretch, get_fast_stretches
 
-
-samplerate = 44100.0
-with AudioFile('guitar-input.wav').resampled_to(samplerate) as f:
-    audio = f.read(f.frames)
-
-# Make a pretty interesting sounding guitar pedalboard:
-board = Pedalboard([
-    Compressor(threshold_db=-50, ratio=25),
-    Gain(gain_db=30),
-    Chorus(),
-    LadderFilter(mode=LadderFilter.Mode.HPF12, cutoff_hz=900),
-    Phaser(),
-    Convolution("./guitar_amp.wav", 1.0),
-    Reverb(room_size=0.25),
-])
+from audiotools.core.util import sample_from_dist
 
 
 @argbind.bind(without_prefix=True)
 def augment(
-    audio_folder: Path,
-    dest_folder: Path,
+    audio_folder: Path = None,
+    dest_folder: Path = None,
     n_augmentations: int = 10,
 ):
     """
@@ -41,7 +26,8 @@ def augment(
     The dest foler will contain a folder for each of the clean dataset's files.
     Under each of these folders, there will be a clean file and many augmented files.
     """
-
+    assert audio_folder is not None
+    assert dest_folder is not None
     audio_files = at.util.find_audio(audio_folder)
 
     for audio_file in tqdm.tqdm(audio_files):
@@ -49,5 +35,33 @@ def augment(
         subdir = subtree / audio_file.stem
         subdir.mkdir(parents=True, exist_ok=True)
 
-
-
+        src = AudioSignal(audio_file).to("cuda" if torch.cuda.is_available() else "cpu")
+
+
+        for i, chunk in tqdm.tqdm(enumerate(src.windows(10, 10))):
+            # apply pedalboard transforms
+            for j in range(n_augmentations):
+                # pitch shift between -7 and 7 semitones
+                import random
+                dst = chunk.clone()
+                dst.samples = pitch_shift(
+                    dst.samples,
+                    shift=random.choice(get_fast_shifts(src.sample_rate,
+                        condition=lambda x: x >= 0.25 and x <= 1.0)),
+                    sample_rate=src.sample_rate
+                )
+                dst.samples = time_stretch(
+                    dst.samples,
+                    stretch=random.choice(get_fast_stretches(src.sample_rate,
+                        condition=lambda x: x >= 0.667 and x <= 1.5, )),
+                    sample_rate=src.sample_rate,
+                )
+
+                dst.cpu().write(subdir / f"{i}-{j}.wav")
+
+
+if __name__ == "__main__":
+    args = argbind.parse_args()
+
+    with argbind.scope(args):
+        augment()
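In the augmentation loop, get_fast_shifts and get_fast_stretches return lists of Fraction ratios that can be applied with fast resampling, and the condition lambdas bound the range (as written, the pitch condition only admits ratios between 0.25 and 1.0, i.e. downward shifts, despite the comment about semitones). A standalone sketch of what those candidate sets look like, assuming a 44.1 kHz source:

    from torch_pitch_shift import get_fast_shifts
    from torch_time_stretch import get_fast_stretches

    sr = 44100
    shifts = get_fast_shifts(sr, condition=lambda x: 0.25 <= x <= 1.0)
    stretches = get_fast_stretches(sr, condition=lambda x: 0.667 <= x <= 1.5)
    print(shifts[:5])     # Fraction ratios, e.g. Fraction(1, 2) is one octave down
    print(stretches[:5])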
scripts/utils/remove_quiet_files.py
ADDED
@@ -0,0 +1,29 @@
+# removes files with loudness below 24db
+
+from pathlib import Path
+import shutil
+import audiotools as at
+import argbind
+
+@argbind.bind(without_prefix=True)
+def remove_quiet_files(
+    src_dir: Path = None,
+    dest_dir: Path = None,
+    min_loudness: float = -30,
+):
+    # copy src to dest
+    dest_dir.mkdir(parents=True, exist_ok=True)
+    shutil.copytree(src_dir, dest_dir, dirs_exist_ok=True)
+
+    audio_files = at.util.find_audio(dest_dir)
+    for audio_file in audio_files:
+        sig = at.AudioSignal(audio_file)
+        if sig.loudness() < min_loudness:
+            audio_file.unlink()
+            print(f"removed {audio_file}")
+
+if __name__ == "__main__":
+    args = argbind.parse_args()
+
+    with argbind.scope(args):
+        remove_quiet_files()
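The filter keys on audiotools' integrated loudness measurement; anything quieter than min_loudness (default -30, while the header comment's "24db" appears to be stale) is deleted from the copied tree. A two-line sketch of the measurement it relies on (the file path is a placeholder):

    import audiotools as at

    sig = at.AudioSignal("example.wav")  # placeholder path
    print(sig.loudness())  # integrated loudness; files below min_loudness are removed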
scripts/xeno-canto-dl.py
ADDED
@@ -0,0 +1,234 @@
+from xenopy import Query
+
+
+SPECIES = [
+    "American Robin",
+    "Northern Cardinal",
+    "Mourning Dove",
+    "American Crow",
+    "Baltimore Oriole",
+    "Blue Jay",
+    "Eastern Bluebird",
+    "House Finch",
+    "American Goldfinch",
+    "House Sparrow",
+    "Song Sparrow",
+    "Tufted Titmouse",
+    "White-breasted Nuthatch",
+    "European Starling",
+    "American Redstart",
+    "Red-winged Blackbird",
+    "Brown-headed Cowbird",
+    "Common Grackle",
+    "Boat-tailed Grackle",
+    "Common Yellowthroat",
+    "Northern Mockingbird",
+    "Carolina Wren",
+    "Eastern Meadowlark",
+    "Chipping Sparrow",
+    "Tree Swallow",
+    "Barn Swallow",
+    "Cliff Swallow",
+    "Pine Siskin",
+    "Indigo Bunting",
+    "Eastern Towhee",
+    "Carolina Chickadee",
+    "Great Crested Flycatcher",
+    "Eastern Wood-Pewee",
+    "Ovenbird",
+    "Northern Flicker",
+    "Red-eyed Vireo",
+    "American Woodcock",
+    "Eastern Phoebe",
+    "Downy Woodpecker",
+    "Scarlet Tanager",
+    "Yellow Warbler",
+    "White-eyed Vireo",
+    "Common Loon",
+    "White-throated Sparrow",
+    "Yellow-throated Vireo",
+    "Great Blue Heron",
+    "Belted Kingfisher",
+    "Pied-billed Grebe",
+    "Wild Turkey",
+    "Wood Thrush",
+    "Rose-breasted Grosbeak",
+    "Field Sparrow",
+    "Hooded Warbler",
+    "Northern Parula",
+    "Chestnut-sided Warbler",
+    "Blue-winged Warbler",
+    "Red-bellied Woodpecker",
+    "Yellow-billed Cuckoo",
+    "Gray Catbird",
+    "Northern Saw-whet Owl",
+    "Osprey",
+    "Common Nighthawk",
+    "Broad-winged Hawk",
+    "Black-throated Green Warbler",
+    "Great Horned Owl",
+    "Common Raven",
+    "Barred Owl",
+    "Canada Warbler",
+    "Magnolia Warbler",
+    "Black-and-white Warbler",
+    "Eastern Kingbird",
+    "Swainson's Thrush",
+    "Worm-eating Warbler",
+    "Prairie Warbler",
+    "Baltimore Oriole",
+    "Black-throated Blue Warbler",
+    "Louisiana Waterthrush",
+    "Blackburnian Warbler",
+    "Black-capped Chickadee",
+    "Cerulean Warbler",
+    "Red-shouldered Hawk",
+    "Cooper's Hawk",
+    "Yellow-throated Warbler",
+    "Blue-headed Vireo",
+    "Blackpoll Warbler",
+    "Ruffed Grouse",
+    "Kentucky Warbler",
+    "Hermit Thrush",
+    "Cedar Waxwing",
+    "Eastern Screech-Owl",
+    "Northern Goshawk",
+    "Green Heron",
+    "Red-tailed Hawk",
+    "Black Vulture",
+    "Hairy Woodpecker",
+    "Golden-crowned Kinglet",
+    "Ruby-crowned Kinglet",
+    "Bicknell's Thrush",
+    "Blue-gray Gnatcatcher",
+    "Veery",
+    "Pileated Woodpecker",
+    "Purple Finch",
+    "White-crowned Sparrow",
+    "Snow Bunting",
+    "Pine Grosbeak",
+    "American Tree Sparrow",
+    "Dark-eyed Junco",
+    "Snowy Owl",
+    "White-winged Crossbill",
+    "Red Crossbill",
+    "Common Redpoll",
+    "Northern Shrike",
+    "Northern Harrier",
+    "Rough-legged Hawk",
+    "Long-eared Owl",
+    "Evening Grosbeak",
+    "Northern Pintail",
+    "American Black Duck",
+    "Mallard",
+    "Canvasback",
+    "Redhead",
+    "Ring-necked Duck",
+    "Greater Scaup",
+    "Lesser Scaup",
+    "Bufflehead",
+    "Common Goldeneye",
+    "Hooded Merganser",
+    "Common Merganser",
+    "Red-breasted Merganser",
+    "Ruddy Duck",
+    "Wood Duck",
+    "Gadwall",
+    "American Wigeon",
+    "Northern Shoveler",
+    "Green-winged Teal",
+    "Blue-winged Teal",
+    "Cinnamon Teal",
+    "Ringed Teal",
+    "Cape Teal",
+    "Northern Fulmar",
+    "Yellow-billed Loon",
+    "Red-throated Loon",
+    "Arctic Loon",
+    "Pacific Loon",
+    "Horned Grebe",
+    "Red-necked Grebe",
+    "Eared Grebe",
+    "Western Grebe",
+    "Clark's Grebe",
+    "Double-crested Cormorant",
+    "Pelagic Cormorant",
+    "Great Cormorant",
+    "American White Pelican",
+    "Brown Pelican",
+    "Brandt's Cormorant",
+    "Least Bittern",
+    "Great Egret",
+    "Snowy Egret",
+    "Little Blue Heron",
+    "Tricolored Heron",
+    "Reddish Egret",
+    "Black-crowned Night-Heron",
+    "Yellow-crowned Night-Heron",
+    "White Ibis",
+    "Glossy Ibis",
+    "Roseate Spoonbill",
+    "Wood Stork",
+    "Black-bellied Whistling-Duck",
+    "Fulvous Whistling-Duck",
+    "Greater White-fronted Goose",
+    "Snow Goose",
+    "Ross's Goose",
+    "Canada Goose",
+    "Brant",
+    "Mute Swan",
+    "Tundra Swan",
+    "Whooper Swan",
+    "Sandhill Crane",
+    "Black-necked Stilt",
+    "American Avocet",
+    "Northern Jacana",
+    "Greater Yellowlegs",
+    "Lesser Yellowlegs",
+    "Willet",
+    "Spotted Sandpiper",
+    "Upland Sandpiper",
+    "Whimbrel",
+    "Long-billed Curlew",
+    "Marbled Godwit",
+    "Ruddy Turnstone",
+    "Red Knot",
+    "Sanderling",
+    "Semipalmated Sandpiper",
+    "Western Sandpiper",
+    "Least Sandpiper",
+    "White-rumped Sandpiper",
+    "Baird's Sandpiper",
+    "Pectoral Sandpiper",
+    "Dunlin",
+    "Buff-breasted Sandpiper",
+    "Short-billed Dowitcher",
+    "Long-billed Dowitcher",
+    "Common Snipe",
+    "American Woodcock",
+    "Wilson's Phalarope",
+    "Red-necked Phalarope",
+    "Red Phalarope"
+]
+
+from pathlib import Path
+
+def remove_spaces(s):
+    return s.replace(" ", "")
+
+for species in SPECIES:
+    if Path("/media/CHONK/hugo/xeno-canto-full/" + remove_spaces(species)).exists():
+        continue
+    try:
+        q = Query(
+            name=species, q="A", length="10-30",
+        )
+
+        # retrieve metadata
+        metafiles = q.retrieve_meta(verbose=True)
+        # retrieve recordings
+        q.retrieve_recordings(multiprocess=True, nproc=10, attempts=10, outdir="/media/CHONK/hugo/xeno-canto-full/")
+
+    except:
+        print("Failed to download " + species)
+        continue
vampnet/mask.py
CHANGED
@@ -191,29 +191,47 @@ def onset_mask(
     width: int = 1
 ):
     import librosa
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    import madmom
+    from madmom.features.onsets import RNNOnsetProcessor, OnsetPeakPickingProcessor
+    import tempfile
+    import numpy as np
+
+    with tempfile.NamedTemporaryFile(suffix='.wav') as f:
+        sig = sig.clone()
+        sig.write(f.name)
+
+        proc = RNNOnsetProcessor(online=False)
+        onsetproc = OnsetPeakPickingProcessor(threshold=0.3,
+                            fps=sig.sample_rate/interface.codec.hop_length)
+
+        act = proc(f.name)
+        onset_times = onsetproc(act)
+
+        # convert to indices for z array
+        onset_indices = librosa.time_to_frames(onset_times, sr=sig.sample_rate, hop_length=interface.codec.hop_length)
+
+        if onset_indices.shape[0] == 0:
+            mask = empty_mask(z)
+            print(f"no onsets found, returning empty mask")
+        else:
+            torch.set_printoptions(threshold=1000)
+            print("onset indices: ", onset_indices)
+            print("onset times: ", onset_times)
+
+            # create a mask, set onset
+            mask = torch.ones_like(z)
+            n_timesteps = z.shape[-1]
+
+            for onset_index in onset_indices:
+                onset_index = min(onset_index, n_timesteps - 1)
+                onset_index = max(onset_index, 0)
+                mask[:, :, onset_index - width:onset_index + width] = 0.0
+
+            print(mask)
 
     return mask
 
 
 
 if __name__ == "__main__":
-
-
+    pass
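For reference, the madmom pipeline used above can be exercised on its own: RNNOnsetProcessor produces a frame-wise onset activation and OnsetPeakPickingProcessor turns it into onset times in seconds. A minimal sketch with placeholder values (fps=100 matches madmom's default activation frame rate, whereas the code above passes the codec frame rate instead):

    import librosa
    from madmom.features.onsets import RNNOnsetProcessor, OnsetPeakPickingProcessor

    wav = "example.wav"  # placeholder input file
    act = RNNOnsetProcessor(online=False)(wav)  # frame-wise onset activations
    onset_times = OnsetPeakPickingProcessor(threshold=0.3, fps=100)(act)  # seconds
    # map times onto codec frames, e.g. hop_length=512 at 44.1 kHz (placeholder values)
    onset_frames = librosa.time_to_frames(onset_times, sr=44100, hop_length=512)
    print(onset_times, onset_frames)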
vampnet/modules/transformer.py
CHANGED
@@ -367,15 +367,6 @@ class TransformerLayer(nn.Module):
 
         return x, position_bias, encoder_decoder_position_bias
 
-def t_schedule(n_steps, max_temp=1.0, min_temp=0.0, k=1.0):
-    x = np.linspace(0, 1, n_steps)
-    a = (0.5 - min_temp) / (max_temp - min_temp)
-
-    x = (x * 12) - 6
-    x0 = np.log((1 / a - 1) + 1e-5) / k
-    y = (1 / (1 + np.exp(- k *(x-x0))))[::-1]
-
-    return y
 
 class TransformerStack(nn.Module):
     def __init__(
@@ -598,7 +589,7 @@ class VampNet(at.ml.BaseModel):
         top_p=None,
         return_signal=True,
         seed: int = None,
-        sample_cutoff: float = 0.5
+        sample_cutoff: float = 0.5,
     ):
         if seed is not None:
             at.util.seed(seed)
@@ -651,7 +642,6 @@ class VampNet(at.ml.BaseModel):
        #################
        # begin sampling #
        #################
-       t_sched = t_schedule(sampling_steps, max_temp=sampling_temperature)
 
        for i in range(sampling_steps):
            logging.debug(f"step {i} of {sampling_steps}")
@@ -680,7 +670,7 @@ class VampNet(at.ml.BaseModel):
                logits, sample=(
                    (i / sampling_steps) <= sample_cutoff
                ),
-               temperature=
+               temperature=sampling_temperature,
                typical_filtering=typical_filtering, typical_mass=typical_mass,
                typical_min_tokens=typical_min_tokens,
                top_k=None, top_p=top_p, return_probs=True,
@@ -843,7 +833,11 @@ def sample_from_logits(
 
 
 
-def mask_by_random_topk(
+def mask_by_random_topk(
+    num_to_mask: int,
+    probs: torch.Tensor,
+    temperature: float = 1.0,
+):
     """
     Args:
         num_to_mask (int): number of tokens to mask
@@ -856,7 +850,8 @@ def mask_by_random_topk(num_to_mask: int, probs: torch.Tensor, temperature: float
     logging.debug(f"temperature: {temperature}")
     logging.debug("")
 
-
+    noise = gumbel_noise_like(probs)
+    confidence = torch.log(probs) + temperature * noise
     logging.debug(f"confidence shape: {confidence.shape}")
 
     sorted_confidence, sorted_idx = confidence.sort(dim=-1)