Hugo Flores Garcia commited on
Commit
a66dc9c
2 Parent(s): 62f49b0 de03185

Merge branch 'lora-app' into ismir

Browse files
app.py CHANGED
@@ -18,10 +18,55 @@ Interface = argbind.bind(Interface)
18
 
19
  conf = argbind.parse_args()
20
 
21
- with argbind.scope(conf):
22
- interface = Interface()
23
- # loader = AudioLoader()
24
- print(f"interface device is {interface.device}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  # dataset = at.data.datasets.AudioDataset(
27
  # loader,
@@ -55,9 +100,15 @@ def load_example_audio():
55
 
56
 
57
  def _vamp(data, return_mask=False):
 
 
58
  out_dir = OUT_DIR / str(uuid.uuid4())
59
  out_dir.mkdir()
60
  sig = at.AudioSignal(data[input_audio])
 
 
 
 
61
 
62
  z = interface.encode(sig)
63
 
@@ -97,7 +148,27 @@ def _vamp(data, return_mask=False):
97
  mask = pmask.codebook_unmask(mask, ncc)
98
 
99
 
100
- print(data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  _top_p = data[top_p] if data[top_p] > 0 else None
102
  # save the mask as a txt file
103
  np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
@@ -117,7 +188,6 @@ def _vamp(data, return_mask=False):
117
  gen_fn=interface.coarse.generate,
118
  seed=_seed,
119
  sample_cutoff=data[sample_cutoff],
120
- classes=_classes,
121
  )
122
 
123
  if use_coarse2fine:
@@ -177,6 +247,7 @@ def save_vamp(data):
177
  "stretch_factor": data[stretch_factor],
178
  "seed": data[seed],
179
  "samplecutoff": data[sample_cutoff],
 
180
  }
181
 
182
  # save with yaml
@@ -322,7 +393,7 @@ with gr.Blocks() as demo:
322
  onset_mask_width = gr.Slider(
323
  label="onset mask width (multiplies with the periodic mask, 1 step ~= 10milliseconds) ",
324
  minimum=0,
325
- maximum=20,
326
  step=1,
327
  value=5,
328
  )
@@ -340,6 +411,14 @@ with gr.Blocks() as demo:
340
 
341
 
342
  with gr.Accordion("extras ", open=False):
 
 
 
 
 
 
 
 
343
  rand_mask_intensity = gr.Slider(
344
  label="random mask intensity. (If this is less than 1, scatters prompts throughout the audio, should be between 0.9 and 1.0)",
345
  minimum=0.0,
@@ -402,14 +481,15 @@ with gr.Blocks() as demo:
402
  masktemp = gr.Slider(
403
  label="mask temperature",
404
  minimum=0.0,
405
- maximum=10.0,
406
  value=1.5
407
  )
408
  sampletemp = gr.Slider(
409
  label="sample temperature",
410
  minimum=0.1,
411
- maximum=2.0,
412
- value=1.0
 
413
  )
414
 
415
 
@@ -425,7 +505,7 @@ with gr.Blocks() as demo:
425
  label="typical filtering ",
426
  value=False
427
  )
428
- typical_mass = gr.Slider(
429
  label="typical mass (should probably stay between 0.1 and 0.5)",
430
  minimum=0.01,
431
  maximum=0.99,
@@ -438,6 +518,13 @@ with gr.Blocks() as demo:
438
  step=1,
439
  value=64
440
  )
 
 
 
 
 
 
 
441
 
442
  use_coarse2fine = gr.Checkbox(
443
  label="use coarse2fine",
@@ -461,10 +548,6 @@ with gr.Blocks() as demo:
461
  value=0.0
462
  )
463
 
464
- use_new_trick = gr.Checkbox(
465
- label="new trick",
466
- value=False
467
- )
468
 
469
  seed = gr.Number(
470
  label="seed (0 for random)",
@@ -476,6 +559,14 @@ with gr.Blocks() as demo:
476
 
477
  # mask settings
478
  with gr.Column():
 
 
 
 
 
 
 
 
479
  vamp_button = gr.Button("generate (vamp)!!!")
480
  output_audio = gr.Audio(
481
  label="output audio",
@@ -518,7 +609,9 @@ with gr.Blocks() as demo:
518
  beat_mask_width,
519
  beat_mask_downbeats,
520
  seed,
521
- sample_cutoff,
 
 
522
  }
523
 
524
  # connect widgets
@@ -548,4 +641,4 @@ with gr.Blocks() as demo:
548
  outputs=[thank_you, download_file]
549
  )
550
 
551
- demo.launch(share=True, enable_queue=False, debug=True)
 
18
 
19
  conf = argbind.parse_args()
20
 
21
+
22
+ from torch_pitch_shift import pitch_shift, get_fast_shifts
23
+ def shift_pitch(signal, interval: int):
24
+ signal.samples = pitch_shift(
25
+ signal.samples,
26
+ shift=interval,
27
+ sample_rate=signal.sample_rate
28
+ )
29
+ return signal
30
+
31
+ def load_interface():
32
+ with argbind.scope(conf):
33
+ interface = Interface()
34
+ # loader = AudioLoader()
35
+ print(f"interface device is {interface.device}")
36
+ return interface
37
+
38
+
39
+ LORA_NONE = "None"
40
+ def load_loras():
41
+ loras = {}
42
+ # find confs under conf/generated
43
+ for conf_file in Path("conf/generated").glob("**/interface.yml"):
44
+ name = conf_file.parent.name
45
+ with open(conf_file) as f:
46
+ loras[name] = yaml.safe_load(f)
47
+ loras[LORA_NONE] = None
48
+ return loras
49
+
50
+ interface = load_interface()
51
+ loras = load_loras()
52
+ cur_lora = LORA_NONE
53
+
54
+ def load_lora(name):
55
+ global interface
56
+ global cur_lora
57
+ if name == cur_lora:
58
+ return
59
+ if name != LORA_NONE:
60
+ interface.lora_load(
61
+ coarse_ckpt=loras[name]["Interface.coarse_lora_ckpt"],
62
+ c2f_ckpt=loras[name]["Interface.coarse2fine_lora_ckpt"],
63
+ full_ckpts=False
64
+ )
65
+ cur_lora = name
66
+
67
+ else:
68
+ interface = load_interface()
69
+ cur_lora = LORA_NONE
70
 
71
  # dataset = at.data.datasets.AudioDataset(
72
  # loader,
 
100
 
101
 
102
  def _vamp(data, return_mask=False):
103
+ load_lora(data[lora_choice])
104
+
105
  out_dir = OUT_DIR / str(uuid.uuid4())
106
  out_dir.mkdir()
107
  sig = at.AudioSignal(data[input_audio])
108
+ sig = interface.preprocess(sig)
109
+
110
+ if data[pitch_shift_amt] != 0:
111
+ sig = shift_pitch(sig, data[pitch_shift_amt])
112
 
113
  z = interface.encode(sig)
114
 
 
148
  mask = pmask.codebook_unmask(mask, ncc)
149
 
150
 
151
+ print(f"dropout {data[dropout]}")
152
+ print(f"masktemp {data[masktemp]}")
153
+ print(f"sampletemp {data[sampletemp]}")
154
+ print(f"top_p {data[top_p]}")
155
+ print(f"prefix_s {data[prefix_s]}")
156
+ print(f"suffix_s {data[suffix_s]}")
157
+ print(f"rand_mask_intensity {data[rand_mask_intensity]}")
158
+ print(f"num_steps {data[num_steps]}")
159
+ print(f"periodic_p {data[periodic_p]}")
160
+ print(f"periodic_w {data[periodic_w]}")
161
+ print(f"n_conditioning_codebooks {data[n_conditioning_codebooks]}")
162
+ print(f"use_coarse2fine {data[use_coarse2fine]}")
163
+ print(f"onset_mask_width {data[onset_mask_width]}")
164
+ print(f"beat_mask_width {data[beat_mask_width]}")
165
+ print(f"beat_mask_downbeats {data[beat_mask_downbeats]}")
166
+ print(f"stretch_factor {data[stretch_factor]}")
167
+ print(f"seed {data[seed]}")
168
+ print(f"pitch_shift_amt {data[pitch_shift_amt]}")
169
+ print(f"sample_cutoff {data[sample_cutoff]}")
170
+
171
+
172
  _top_p = data[top_p] if data[top_p] > 0 else None
173
  # save the mask as a txt file
174
  np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
 
188
  gen_fn=interface.coarse.generate,
189
  seed=_seed,
190
  sample_cutoff=data[sample_cutoff],
 
191
  )
192
 
193
  if use_coarse2fine:
 
247
  "stretch_factor": data[stretch_factor],
248
  "seed": data[seed],
249
  "samplecutoff": data[sample_cutoff],
250
+ "lora": data[lora_choice],
251
  }
252
 
253
  # save with yaml
 
393
  onset_mask_width = gr.Slider(
394
  label="onset mask width (multiplies with the periodic mask, 1 step ~= 10milliseconds) ",
395
  minimum=0,
396
+ maximum=100,
397
  step=1,
398
  value=5,
399
  )
 
411
 
412
 
413
  with gr.Accordion("extras ", open=False):
414
+ pitch_shift_amt = gr.Slider(
415
+ label="pitch shift amount (semitones)",
416
+ minimum=-12,
417
+ maximum=12,
418
+ step=1,
419
+ value=0,
420
+ )
421
+
422
  rand_mask_intensity = gr.Slider(
423
  label="random mask intensity. (If this is less than 1, scatters prompts throughout the audio, should be between 0.9 and 1.0)",
424
  minimum=0.0,
 
481
  masktemp = gr.Slider(
482
  label="mask temperature",
483
  minimum=0.0,
484
+ maximum=100.0,
485
  value=1.5
486
  )
487
  sampletemp = gr.Slider(
488
  label="sample temperature",
489
  minimum=0.1,
490
+ maximum=10.0,
491
+ value=1.0,
492
+ step=0.001
493
  )
494
 
495
 
 
505
  label="typical filtering ",
506
  value=False
507
  )
508
+ typical_mass = gr.Slider(
509
  label="typical mass (should probably stay between 0.1 and 0.5)",
510
  minimum=0.01,
511
  maximum=0.99,
 
518
  step=1,
519
  value=64
520
  )
521
+ sample_cutoff = gr.Slider(
522
+ label="sample cutoff",
523
+ minimum=0.0,
524
+ maximum=1.0,
525
+ value=0.5,
526
+ step=0.01
527
+ )
528
 
529
  use_coarse2fine = gr.Checkbox(
530
  label="use coarse2fine",
 
548
  value=0.0
549
  )
550
 
 
 
 
 
551
 
552
  seed = gr.Number(
553
  label="seed (0 for random)",
 
559
 
560
  # mask settings
561
  with gr.Column():
562
+
563
+ lora_choice = gr.Dropdown(
564
+ label="lora choice",
565
+ choices=list(loras.keys()),
566
+ value=LORA_NONE,
567
+ visible=False
568
+ )
569
+
570
  vamp_button = gr.Button("generate (vamp)!!!")
571
  output_audio = gr.Audio(
572
  label="output audio",
 
609
  beat_mask_width,
610
  beat_mask_downbeats,
611
  seed,
612
+ lora_choice,
613
+ pitch_shift_amt,
614
+ sample_cutoff
615
  }
616
 
617
  # connect widgets
 
641
  outputs=[thank_you, download_file]
642
  )
643
 
644
+ demo.launch(share=True, enable_queue=True, debug=True)
conf/lora/lora.yml CHANGED
@@ -4,14 +4,16 @@ $include:
4
  fine_tune: True
5
 
6
  train/AudioDataset.n_examples: 100000000
7
- val/AudioDataset.n_examples: 100
8
 
9
 
10
  NoamScheduler.warmup: 500
11
 
12
  batch_size: 7
13
  num_workers: 7
14
- save_iters: [100000, 200000, 300000, 4000000, 500000]
 
 
15
 
16
  AdamW.lr: 0.0001
17
 
 
4
  fine_tune: True
5
 
6
  train/AudioDataset.n_examples: 100000000
7
+ val/AudioDataset.n_examples: 500
8
 
9
 
10
  NoamScheduler.warmup: 500
11
 
12
  batch_size: 7
13
  num_workers: 7
14
+ save_iters: [10000, 20000, 30000, 40000, 50000]
15
+ sample_freq: 1000
16
+ val_freq: 500
17
 
18
  AdamW.lr: 0.0001
19
 
scripts/exp/fine_tune.py CHANGED
@@ -48,10 +48,10 @@ def fine_tune(audio_files_or_folders: List[str], name: str):
48
  }
49
 
50
  interface_conf = {
51
- "Interface.coarse_ckpt": f"./models/vampnet/coarse.pth",
52
  "Interface.coarse_lora_ckpt": f"./runs/{name}/coarse/latest/lora.pth",
53
 
54
- "Interface.coarse2fine_ckpt": f"./models/vampnet/c2f.pth",
55
  "Interface.coarse2fine_lora_ckpt": f"./runs/{name}/c2f/latest/lora.pth",
56
  "Interface.wavebeat_ckpt": "./models/wavebeat.pth",
57
 
 
48
  }
49
 
50
  interface_conf = {
51
+ "Interface.coarse_ckpt": f"./runs/{name}/coarse/latest/vampnet/weights.pth",
52
  "Interface.coarse_lora_ckpt": f"./runs/{name}/coarse/latest/lora.pth",
53
 
54
+ "Interface.coarse2fine_ckpt": f"./runs/{name}/c2f/latest/vampnet/weights.pth",
55
  "Interface.coarse2fine_lora_ckpt": f"./runs/{name}/c2f/latest/lora.pth",
56
  "Interface.wavebeat_ckpt": "./models/wavebeat.pth",
57
 
scripts/utils/augment.py CHANGED
@@ -5,34 +5,19 @@ from audiotools import AudioSignal
5
 
6
  import argbind
7
  import tqdm
 
8
 
9
 
10
- from pedalboard import (
11
- Compressor, Gain, Chorus, LadderFilter, Phaser, Convolution, Reverb, Pedalboard
12
- )
13
- from pedalboard.io import AudioFile
14
 
15
- # Read in a whole file, resampling to our desired sample rate:
16
- samplerate = 44100.0
17
- with AudioFile('guitar-input.wav').resampled_to(samplerate) as f:
18
- audio = f.read(f.frames)
19
-
20
- # Make a pretty interesting sounding guitar pedalboard:
21
- board = Pedalboard([
22
- Compressor(threshold_db=-50, ratio=25),
23
- Gain(gain_db=30),
24
- Chorus(),
25
- LadderFilter(mode=LadderFilter.Mode.HPF12, cutoff_hz=900),
26
- Phaser(),
27
- Convolution("./guitar_amp.wav", 1.0),
28
- Reverb(room_size=0.25),
29
- ])
30
 
31
 
32
  @argbind.bind(without_prefix=True)
33
  def augment(
34
- audio_folder: Path,
35
- dest_folder: Path,
36
  n_augmentations: int = 10,
37
  ):
38
  """
@@ -41,7 +26,8 @@ def augment(
41
  The dest foler will contain a folder for each of the clean dataset's files.
42
  Under each of these folders, there will be a clean file and many augmented files.
43
  """
44
-
 
45
  audio_files = at.util.find_audio(audio_folder)
46
 
47
  for audio_file in tqdm.tqdm(audio_files):
@@ -49,5 +35,33 @@ def augment(
49
  subdir = subtree / audio_file.stem
50
  subdir.mkdir(parents=True, exist_ok=True)
51
 
52
- # apply pedalboard transforms
53
- for i in range(n_augmentations):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  import argbind
7
  import tqdm
8
+ import torch
9
 
10
 
11
+ from torch_pitch_shift import pitch_shift, get_fast_shifts
12
+ from torch_time_stretch import time_stretch, get_fast_stretches
 
 
13
 
14
+ from audiotools.core.util import sample_from_dist
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
 
17
  @argbind.bind(without_prefix=True)
18
  def augment(
19
+ audio_folder: Path = None,
20
+ dest_folder: Path = None,
21
  n_augmentations: int = 10,
22
  ):
23
  """
 
26
  The dest foler will contain a folder for each of the clean dataset's files.
27
  Under each of these folders, there will be a clean file and many augmented files.
28
  """
29
+ assert audio_folder is not None
30
+ assert dest_folder is not None
31
  audio_files = at.util.find_audio(audio_folder)
32
 
33
  for audio_file in tqdm.tqdm(audio_files):
 
35
  subdir = subtree / audio_file.stem
36
  subdir.mkdir(parents=True, exist_ok=True)
37
 
38
+ src = AudioSignal(audio_file).to("cuda" if torch.cuda.is_available() else "cpu")
39
+
40
+
41
+ for i, chunk in tqdm.tqdm(enumerate(src.windows(10, 10))):
42
+ # apply pedalboard transforms
43
+ for j in range(n_augmentations):
44
+ # pitch shift between -7 and 7 semitones
45
+ import random
46
+ dst = chunk.clone()
47
+ dst.samples = pitch_shift(
48
+ dst.samples,
49
+ shift=random.choice(get_fast_shifts(src.sample_rate,
50
+ condition=lambda x: x >= 0.25 and x <= 1.0)),
51
+ sample_rate=src.sample_rate
52
+ )
53
+ dst.samples = time_stretch(
54
+ dst.samples,
55
+ stretch=random.choice(get_fast_stretches(src.sample_rate,
56
+ condition=lambda x: x >= 0.667 and x <= 1.5, )),
57
+ sample_rate=src.sample_rate,
58
+ )
59
+
60
+ dst.cpu().write(subdir / f"{i}-{j}.wav")
61
+
62
+
63
+ if __name__ == "__main__":
64
+ args = argbind.parse_args()
65
+
66
+ with argbind.scope(args):
67
+ augment()
scripts/utils/remove_quiet_files.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # removes files with loudness below 24db
2
+
3
+ from pathlib import Path
4
+ import shutil
5
+ import audiotools as at
6
+ import argbind
7
+
8
+ @argbind.bind(without_prefix=True)
9
+ def remove_quiet_files(
10
+ src_dir: Path = None,
11
+ dest_dir: Path = None,
12
+ min_loudness: float = -30,
13
+ ):
14
+ # copy src to dest
15
+ dest_dir.mkdir(parents=True, exist_ok=True)
16
+ shutil.copytree(src_dir, dest_dir, dirs_exist_ok=True)
17
+
18
+ audio_files = at.util.find_audio(dest_dir)
19
+ for audio_file in audio_files:
20
+ sig = at.AudioSignal(audio_file)
21
+ if sig.loudness() < min_loudness:
22
+ audio_file.unlink()
23
+ print(f"removed {audio_file}")
24
+
25
+ if __name__ == "__main__":
26
+ args = argbind.parse_args()
27
+
28
+ with argbind.scope(args):
29
+ remove_quiet_files()
scripts/xeno-canto-dl.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from xenopy import Query
2
+
3
+
4
+ SPECIES = [
5
+ "American Robin",
6
+ "Northern Cardinal",
7
+ "Mourning Dove",
8
+ "American Crow",
9
+ "Baltimore Oriole",
10
+ "Blue Jay",
11
+ "Eastern Bluebird",
12
+ "House Finch",
13
+ "American Goldfinch",
14
+ "House Sparrow",
15
+ "Song Sparrow",
16
+ "Tufted Titmouse",
17
+ "White-breasted Nuthatch",
18
+ "European Starling",
19
+ "American Redstart",
20
+ "Red-winged Blackbird",
21
+ "Brown-headed Cowbird",
22
+ "Common Grackle",
23
+ "Boat-tailed Grackle",
24
+ "Common Yellowthroat",
25
+ "Northern Mockingbird",
26
+ "Carolina Wren",
27
+ "Eastern Meadowlark",
28
+ "Chipping Sparrow",
29
+ "Tree Swallow",
30
+ "Barn Swallow",
31
+ "Cliff Swallow",
32
+ "Pine Siskin",
33
+ "Indigo Bunting",
34
+ "Eastern Towhee",
35
+ "Carolina Chickadee",
36
+ "Great Crested Flycatcher",
37
+ "Eastern Wood-Pewee",
38
+ "Ovenbird",
39
+ "Northern Flicker",
40
+ "Red-eyed Vireo",
41
+ "American Woodcock",
42
+ "Eastern Phoebe",
43
+ "Downy Woodpecker",
44
+ "Scarlet Tanager",
45
+ "Yellow Warbler",
46
+ "White-eyed Vireo",
47
+ "Common Loon",
48
+ "White-throated Sparrow",
49
+ "Yellow-throated Vireo",
50
+ "Great Blue Heron",
51
+ "Belted Kingfisher",
52
+ "Pied-billed Grebe",
53
+ "Wild Turkey",
54
+ "Wood Thrush",
55
+ "Rose-breasted Grosbeak",
56
+ "Field Sparrow",
57
+ "Hooded Warbler",
58
+ "Northern Parula",
59
+ "Chestnut-sided Warbler",
60
+ "Blue-winged Warbler",
61
+ "Red-bellied Woodpecker",
62
+ "Yellow-billed Cuckoo",
63
+ "Gray Catbird",
64
+ "Northern Saw-whet Owl",
65
+ "Osprey",
66
+ "Common Nighthawk",
67
+ "Broad-winged Hawk",
68
+ "Black-throated Green Warbler",
69
+ "Great Horned Owl",
70
+ "Common Raven",
71
+ "Barred Owl",
72
+ "Canada Warbler",
73
+ "Magnolia Warbler",
74
+ "Black-and-white Warbler",
75
+ "Eastern Kingbird",
76
+ "Swainson's Thrush",
77
+ "Worm-eating Warbler",
78
+ "Prairie Warbler",
79
+ "Baltimore Oriole",
80
+ "Black-throated Blue Warbler",
81
+ "Louisiana Waterthrush",
82
+ "Blackburnian Warbler",
83
+ "Black-capped Chickadee",
84
+ "Cerulean Warbler",
85
+ "Red-shouldered Hawk",
86
+ "Cooper's Hawk",
87
+ "Yellow-throated Warbler",
88
+ "Blue-headed Vireo",
89
+ "Blackpoll Warbler",
90
+ "Ruffed Grouse",
91
+ "Kentucky Warbler",
92
+ "Hermit Thrush",
93
+ "Cedar Waxwing",
94
+ "Eastern Screech-Owl",
95
+ "Northern Goshawk",
96
+ "Green Heron",
97
+ "Red-tailed Hawk",
98
+ "Black Vulture",
99
+ "Hairy Woodpecker",
100
+ "Golden-crowned Kinglet",
101
+ "Ruby-crowned Kinglet",
102
+ "Bicknell's Thrush",
103
+ "Blue-gray Gnatcatcher",
104
+ "Veery",
105
+ "Pileated Woodpecker",
106
+ "Purple Finch",
107
+ "White-crowned Sparrow",
108
+ "Snow Bunting",
109
+ "Pine Grosbeak",
110
+ "American Tree Sparrow",
111
+ "Dark-eyed Junco",
112
+ "Snowy Owl",
113
+ "White-winged Crossbill",
114
+ "Red Crossbill",
115
+ "Common Redpoll",
116
+ "Northern Shrike",
117
+ "Northern Harrier",
118
+ "Rough-legged Hawk",
119
+ "Long-eared Owl",
120
+ "Evening Grosbeak",
121
+ "Northern Pintail",
122
+ "American Black Duck",
123
+ "Mallard",
124
+ "Canvasback",
125
+ "Redhead",
126
+ "Ring-necked Duck",
127
+ "Greater Scaup",
128
+ "Lesser Scaup",
129
+ "Bufflehead",
130
+ "Common Goldeneye",
131
+ "Hooded Merganser",
132
+ "Common Merganser",
133
+ "Red-breasted Merganser",
134
+ "Ruddy Duck",
135
+ "Wood Duck",
136
+ "Gadwall",
137
+ "American Wigeon",
138
+ "Northern Shoveler",
139
+ "Green-winged Teal",
140
+ "Blue-winged Teal",
141
+ "Cinnamon Teal",
142
+ "Ringed Teal",
143
+ "Cape Teal",
144
+ "Northern Fulmar",
145
+ "Yellow-billed Loon",
146
+ "Red-throated Loon",
147
+ "Arctic Loon",
148
+ "Pacific Loon",
149
+ "Horned Grebe",
150
+ "Red-necked Grebe",
151
+ "Eared Grebe",
152
+ "Western Grebe",
153
+ "Clark's Grebe",
154
+ "Double-crested Cormorant",
155
+ "Pelagic Cormorant",
156
+ "Great Cormorant",
157
+ "American White Pelican",
158
+ "Brown Pelican",
159
+ "Brandt's Cormorant",
160
+ "Least Bittern",
161
+ "Great Egret",
162
+ "Snowy Egret",
163
+ "Little Blue Heron",
164
+ "Tricolored Heron",
165
+ "Reddish Egret",
166
+ "Black-crowned Night-Heron",
167
+ "Yellow-crowned Night-Heron",
168
+ "White Ibis",
169
+ "Glossy Ibis",
170
+ "Roseate Spoonbill",
171
+ "Wood Stork",
172
+ "Black-bellied Whistling-Duck",
173
+ "Fulvous Whistling-Duck",
174
+ "Greater White-fronted Goose",
175
+ "Snow Goose",
176
+ "Ross's Goose",
177
+ "Canada Goose",
178
+ "Brant",
179
+ "Mute Swan",
180
+ "Tundra Swan",
181
+ "Whooper Swan",
182
+ "Sandhill Crane",
183
+ "Black-necked Stilt",
184
+ "American Avocet",
185
+ "Northern Jacana",
186
+ "Greater Yellowlegs",
187
+ "Lesser Yellowlegs",
188
+ "Willet",
189
+ "Spotted Sandpiper",
190
+ "Upland Sandpiper",
191
+ "Whimbrel",
192
+ "Long-billed Curlew",
193
+ "Marbled Godwit",
194
+ "Ruddy Turnstone",
195
+ "Red Knot",
196
+ "Sanderling",
197
+ "Semipalmated Sandpiper",
198
+ "Western Sandpiper",
199
+ "Least Sandpiper",
200
+ "White-rumped Sandpiper",
201
+ "Baird's Sandpiper",
202
+ "Pectoral Sandpiper",
203
+ "Dunlin",
204
+ "Buff-breasted Sandpiper",
205
+ "Short-billed Dowitcher",
206
+ "Long-billed Dowitcher",
207
+ "Common Snipe",
208
+ "American Woodcock",
209
+ "Wilson's Phalarope",
210
+ "Red-necked Phalarope",
211
+ "Red Phalarope"
212
+ ]
213
+
214
+ from pathlib import Path
215
+
216
+ def remove_spaces(s):
217
+ return s.replace(" ", "")
218
+
219
+ for species in SPECIES:
220
+ if Path("/media/CHONK/hugo/xeno-canto-full/" + remove_spaces(species)).exists():
221
+ continue
222
+ try:
223
+ q = Query(
224
+ name=species, q="A", length="10-30",
225
+ )
226
+
227
+ # retrieve metadata
228
+ metafiles = q.retrieve_meta(verbose=True)
229
+ # retrieve recordings
230
+ q.retrieve_recordings(multiprocess=True, nproc=10, attempts=10, outdir="/media/CHONK/hugo/xeno-canto-full/")
231
+
232
+ except:
233
+ print("Failed to download " + species)
234
+ continue
vampnet/interface.py CHANGED
@@ -120,17 +120,16 @@ class Interface(torch.nn.Module):
120
  if coarse_ckpt is not None:
121
  self.coarse.to("cpu")
122
  state_dict = torch.load(coarse_ckpt, map_location="cpu")
123
-
124
  self.coarse.load_state_dict(state_dict, strict=False)
125
  self.coarse.to(self.device)
126
  if c2f_ckpt is not None:
127
  self.c2f.to("cpu")
128
  state_dict = torch.load(c2f_ckpt, map_location="cpu")
129
-
130
  self.c2f.load_state_dict(state_dict, strict=False)
131
  self.c2f.to(self.device)
132
 
133
-
134
  def s2t(self, seconds: float):
135
  """seconds to tokens"""
136
  if isinstance(seconds, np.ndarray):
 
120
  if coarse_ckpt is not None:
121
  self.coarse.to("cpu")
122
  state_dict = torch.load(coarse_ckpt, map_location="cpu")
123
+ print(f"loading coarse from {coarse_ckpt}")
124
  self.coarse.load_state_dict(state_dict, strict=False)
125
  self.coarse.to(self.device)
126
  if c2f_ckpt is not None:
127
  self.c2f.to("cpu")
128
  state_dict = torch.load(c2f_ckpt, map_location="cpu")
129
+ print(f"loading c2f from {c2f_ckpt}")
130
  self.c2f.load_state_dict(state_dict, strict=False)
131
  self.c2f.to(self.device)
132
 
 
133
  def s2t(self, seconds: float):
134
  """seconds to tokens"""
135
  if isinstance(seconds, np.ndarray):
vampnet/mask.py CHANGED
@@ -191,29 +191,47 @@ def onset_mask(
191
  width: int = 1
192
  ):
193
  import librosa
194
-
195
- onset_indices = librosa.onset.onset_detect(
196
- y=sig.clone().to_mono().samples.cpu().numpy()[0, 0],
197
- sr=sig.sample_rate,
198
- hop_length=interface.codec.hop_length,
199
- backtrack=True,
200
- )
201
-
202
- # create a mask, set onset
203
- mask = torch.ones_like(z)
204
- n_timesteps = z.shape[-1]
205
-
206
- for onset_index in onset_indices:
207
- onset_index = min(onset_index, n_timesteps - 1)
208
- onset_index = max(onset_index, 0)
209
- mask[:, :, onset_index - width:onset_index + width] = 0.0
210
-
211
- print(mask)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
213
  return mask
214
 
215
 
216
 
217
  if __name__ == "__main__":
218
- torch.set_printoptions(threshold=10000)
219
-
 
191
  width: int = 1
192
  ):
193
  import librosa
194
+ import madmom
195
+ from madmom.features.onsets import RNNOnsetProcessor, OnsetPeakPickingProcessor
196
+ import tempfile
197
+ import numpy as np
198
+
199
+ with tempfile.NamedTemporaryFile(suffix='.wav') as f:
200
+ sig = sig.clone()
201
+ sig.write(f.name)
202
+
203
+ proc = RNNOnsetProcessor(online=False)
204
+ onsetproc = OnsetPeakPickingProcessor(threshold=0.3,
205
+ fps=sig.sample_rate/interface.codec.hop_length)
206
+
207
+ act = proc(f.name)
208
+ onset_times = onsetproc(act)
209
+
210
+ # convert to indices for z array
211
+ onset_indices = librosa.time_to_frames(onset_times, sr=sig.sample_rate, hop_length=interface.codec.hop_length)
212
+
213
+ if onset_indices.shape[0] == 0:
214
+ mask = empty_mask(z)
215
+ print(f"no onsets found, returning empty mask")
216
+ else:
217
+ torch.set_printoptions(threshold=1000)
218
+ print("onset indices: ", onset_indices)
219
+ print("onset times: ", onset_times)
220
+
221
+ # create a mask, set onset
222
+ mask = torch.ones_like(z)
223
+ n_timesteps = z.shape[-1]
224
+
225
+ for onset_index in onset_indices:
226
+ onset_index = min(onset_index, n_timesteps - 1)
227
+ onset_index = max(onset_index, 0)
228
+ mask[:, :, onset_index - width:onset_index + width] = 0.0
229
+
230
+ print(mask)
231
 
232
  return mask
233
 
234
 
235
 
236
  if __name__ == "__main__":
237
+ pass
 
vampnet/modules/transformer.py CHANGED
@@ -367,15 +367,6 @@ class TransformerLayer(nn.Module):
367
 
368
  return x, position_bias, encoder_decoder_position_bias
369
 
370
- def t_schedule(n_steps, max_temp=1.0, min_temp=0.0, k=1.0):
371
- x = np.linspace(0, 1, n_steps)
372
- a = (0.5 - min_temp) / (max_temp - min_temp)
373
-
374
- x = (x * 12) - 6
375
- x0 = np.log((1 / a - 1) + 1e-5) / k
376
- y = (1 / (1 + np.exp(- k *(x-x0))))[::-1]
377
-
378
- return y
379
 
380
  class TransformerStack(nn.Module):
381
  def __init__(
@@ -598,7 +589,7 @@ class VampNet(at.ml.BaseModel):
598
  top_p=None,
599
  return_signal=True,
600
  seed: int = None,
601
- sample_cutoff: float = 0.5
602
  ):
603
  if seed is not None:
604
  at.util.seed(seed)
@@ -651,7 +642,6 @@ class VampNet(at.ml.BaseModel):
651
  #################
652
  # begin sampling #
653
  #################
654
- t_sched = t_schedule(sampling_steps, max_temp=sampling_temperature)
655
 
656
  for i in range(sampling_steps):
657
  logging.debug(f"step {i} of {sampling_steps}")
@@ -680,7 +670,7 @@ class VampNet(at.ml.BaseModel):
680
  logits, sample=(
681
  (i / sampling_steps) <= sample_cutoff
682
  ),
683
- temperature=t_sched[i],
684
  typical_filtering=typical_filtering, typical_mass=typical_mass,
685
  typical_min_tokens=typical_min_tokens,
686
  top_k=None, top_p=top_p, return_probs=True,
@@ -843,7 +833,11 @@ def sample_from_logits(
843
 
844
 
845
 
846
- def mask_by_random_topk(num_to_mask: int, probs: torch.Tensor, temperature: float = 1.0):
 
 
 
 
847
  """
848
  Args:
849
  num_to_mask (int): number of tokens to mask
@@ -856,7 +850,8 @@ def mask_by_random_topk(num_to_mask: int, probs: torch.Tensor, temperature: floa
856
  logging.debug(f"temperature: {temperature}")
857
  logging.debug("")
858
 
859
- confidence = torch.log(probs) + temperature * gumbel_noise_like(probs)
 
860
  logging.debug(f"confidence shape: {confidence.shape}")
861
 
862
  sorted_confidence, sorted_idx = confidence.sort(dim=-1)
 
367
 
368
  return x, position_bias, encoder_decoder_position_bias
369
 
 
 
 
 
 
 
 
 
 
370
 
371
  class TransformerStack(nn.Module):
372
  def __init__(
 
589
  top_p=None,
590
  return_signal=True,
591
  seed: int = None,
592
+ sample_cutoff: float = 0.5,
593
  ):
594
  if seed is not None:
595
  at.util.seed(seed)
 
642
  #################
643
  # begin sampling #
644
  #################
 
645
 
646
  for i in range(sampling_steps):
647
  logging.debug(f"step {i} of {sampling_steps}")
 
670
  logits, sample=(
671
  (i / sampling_steps) <= sample_cutoff
672
  ),
673
+ temperature=sampling_temperature,
674
  typical_filtering=typical_filtering, typical_mass=typical_mass,
675
  typical_min_tokens=typical_min_tokens,
676
  top_k=None, top_p=top_p, return_probs=True,
 
833
 
834
 
835
 
836
+ def mask_by_random_topk(
837
+ num_to_mask: int,
838
+ probs: torch.Tensor,
839
+ temperature: float = 1.0,
840
+ ):
841
  """
842
  Args:
843
  num_to_mask (int): number of tokens to mask
 
850
  logging.debug(f"temperature: {temperature}")
851
  logging.debug("")
852
 
853
+ noise = gumbel_noise_like(probs)
854
+ confidence = torch.log(probs) + temperature * noise
855
  logging.debug(f"confidence shape: {confidence.shape}")
856
 
857
  sorted_confidence, sorted_idx = confidence.sort(dim=-1)