.gitignore CHANGED
@@ -182,7 +182,3 @@ models.zip
 .git-old
 conf/generated/*
 runs*/
-
-
-gtzan.zip
-.gtzan_emb_cache

app.py CHANGED
@@ -1,9 +1,6 @@
 # huggingface space exclusive
 import os
 
-# print("installing pyharp")
-# os.system('pip install "pyharp@git+https://github.com/audacitorch/pyharp.git"')
-# print("installing madmom")
 os.system('pip install cython')
 os.system('pip install madmom')
 
@@ -24,7 +21,8 @@ import gradio as gr
 from vampnet.interface import Interface
 from vampnet import mask as pmask
 
-from pyharp import ModelCard, build_endpoint
+# Interface = argbind.bind(Interface)
+# AudioLoader = argbind.bind(at.data.datasets.AudioLoader)
 
 
 
@@ -56,6 +54,13 @@ def load_interface():
 
 interface = load_interface()
 
+# dataset = at.data.datasets.AudioDataset(
+# loader,
+# sample_rate=interface.codec.sample_rate,
+# duration=interface.coarse.chunk_size_s,
+# n_examples=5000,
+# without_replacement=True,
+# )
 
 OUT_DIR = Path("gradio-outputs")
 OUT_DIR.mkdir(exist_ok=True, parents=True)
@@ -179,7 +184,7 @@ def _vamp(data, return_mask=False):
 mask_temperature=data[masktemp]*10,
 sampling_temperature=data[sampletemp],
 mask=mask,
-sampling_steps=data[num_steps] // 2,
+sampling_steps=data[num_steps],
 sample_cutoff=data[sample_cutoff],
 seed=_seed,
 )
@@ -245,46 +250,6 @@ def save_vamp(data):
 return f"saved! your save code is {out_dir.stem}", zip_path
 
 
-def harp_vamp(_input_audio, _beat_mask_width, _sampletemp):
-
-out_dir = OUT_DIR / str(uuid.uuid4())
-out_dir.mkdir()
-sig = at.AudioSignal(_input_audio)
-sig = interface.preprocess(sig)
-
-z = interface.encode(sig)
-
-# build the mask
-mask = pmask.linear_random(z, 1.0)
-if _beat_mask_width > 0:
-beat_mask = interface.make_beat_mask(
-sig,
-after_beat_s=(_beat_mask_width/1000),
-)
-mask = pmask.mask_and(mask, beat_mask)
-
-# save the mask as a txt file
-zv, mask_z = interface.coarse_vamp(
-z,
-mask=mask,
-sampling_temperature=_sampletemp,
-return_mask=True,
-gen_fn=interface.coarse.generate,
-)
-
-
-zv = interface.coarse_to_fine(
-zv,
-sampling_temperature=_sampletemp,
-mask=mask,
-)
-
-sig = interface.to_signal(zv).cpu()
-print("done")
-
-sig.write(out_dir / "output.wav")
-
-return sig.path_to_file
 
 with gr.Blocks() as demo:
 
@@ -408,7 +373,7 @@ with gr.Blocks() as demo:
 minimum=0,
 maximum=128,
 step=1,
-value=3,
+value=5,
 )
 
 
@@ -421,7 +386,7 @@ with gr.Blocks() as demo:
 )
 
 beat_mask_width = gr.Slider(
-label="beat prompt (ms)",
+label="beat mask width (in milliseconds)",
 minimum=0,
 maximum=200,
 value=0,
@@ -521,7 +486,7 @@ with gr.Blocks() as demo:
 label="top p (0.0 = off)",
 minimum=0.0,
 maximum=1.0,
-value=0.9
+value=0.0
 )
 typical_filtering = gr.Checkbox(
 label="typical filtering ",
@@ -581,14 +546,6 @@ with gr.Blocks() as demo:
 
 # mask settings
 with gr.Column():
-
-# lora_choice = gr.Dropdown(
-# label="lora choice",
-# choices=list(loras.keys()),
-# value=LORA_NONE,
-# visible=False
-# )
-
 vamp_button = gr.Button("generate (vamp)!!!")
 output_audio = gr.Audio(
 label="output audio",
@@ -663,24 +620,4 @@ with gr.Blocks() as demo:
 outputs=[thank_you, download_file]
 )
 
-# harp stuff
-harp_inputs = [
-input_audio,
-beat_mask_width,
-sampletemp,
-]
-
-build_endpoint(
-inputs=harp_inputs,
-output=output_audio,
-process_fn=harp_vamp,
-card=ModelCard(
-name="vampnet",
-description="Generate variations on music input, based on small prompts around the beat. NOTE: vampnet's has a maximum context length of 10 seconds. Please split all audio clips into 10 second chunks, or processing will result in an error. ",
-author="Hugo Flores García",
-tags=["music", "generative"]
-),
-visible=False
-)
-
 demo.launch()

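For context on the `sampling_steps` change above, here is a minimal sketch of driving the interface directly with the full step count, pieced together from calls visible in this diff (the removed harp_vamp helper and the keyword arguments in _vamp). The exact coarse_vamp signature, the step count, and the "input.wav" path are assumptions, not part of this change.

import audiotools as at
from vampnet import mask as pmask
from vampnet.interface import Interface

# assumes checkpoints are wired up via the repo's config, as in load_interface()
interface = Interface()

sig = interface.preprocess(at.AudioSignal("input.wav"))  # hypothetical input path
z = interface.encode(sig)

# mask everything, then run the coarse pass with the full number of steps
# (this PR drops the previous "// 2" halving of the slider value)
mask = pmask.linear_random(z, 1.0)
zv, mask_z = interface.coarse_vamp(
    z,
    mask=mask,
    sampling_steps=36,            # assumed value; taken from the UI slider in app.py
    sampling_temperature=1.0,
    return_mask=True,
    gen_fn=interface.coarse.generate,
)
zv = interface.coarse_to_fine(zv, sampling_temperature=1.0, mask=mask)

out = interface.to_signal(zv).cpu()
out.write("output.wav")
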
conf/lora/lora.yml CHANGED
@@ -9,9 +9,9 @@ val/AudioDataset.n_examples: 500
 
 NoamScheduler.warmup: 500
 
-batch_size: 6
+batch_size: 7
 num_workers: 7
-save_iters: [10000, 20000, 30000, 40000, 50000, 100000]
+save_iters: [10000, 20000, 30000, 40000, 50000]
 sample_freq: 1000
 val_freq: 500
 

conf/vampnet.yml CHANGED
@@ -32,7 +32,7 @@ VampNet.n_heads: 20
 VampNet.flash_attn: false
 VampNet.dropout: 0.1
 
-AudioLoader.relative_path: ""
+AudioLoader.relative_path: /data/
 AudioDataset.loudness_cutoff: -30.0
 AudioDataset.without_replacement: true
 AudioLoader.shuffle: true

requirements.txt CHANGED
@@ -6,5 +6,4 @@ loralib
 wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat
 lac @ git+https://github.com/hugofloresgarcia/lac.git
 descript-audiotools @ git+https://github.com/descriptinc/[email protected]
--e git+https://github.com/audacitorch/pyharp.git#egg=pyharp
 torch_pitch_shift

scripts/exp/train.py CHANGED
@@ -224,7 +224,7 @@ def train_loop(state: State, batch: dict, accel: Accelerator):
 
 dtype = torch.bfloat16 if accel.amp else None
 with accel.autocast(dtype=dtype):
-z_hat = state.model(z_mask_latent)
+z_hat = state.model(z_mask_latent, r)
 
 target = codebook_flatten(
 z[:, vn.n_conditioning_codebooks :, :],
@@ -289,7 +289,7 @@ def val_loop(state: State, batch: dict, accel: Accelerator):
 
 z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
 
-z_hat = state.model(z_mask_latent)
+z_hat = state.model(z_mask_latent, r)
 
 target = codebook_flatten(
 z[:, vn.n_conditioning_codebooks :, :],
@@ -408,19 +408,19 @@ def save_imputation(state, z, val_idx, writer):
 
 for i in range(len(val_idx)):
 imputed_noisy[i].cpu().write_audio_to_tb(
-f"inpainted_prompt/{i}",
+f"imputed_noisy/{i}",
 writer,
 step=state.tracker.step,
 plot_fn=None,
 )
 imputed[i].cpu().write_audio_to_tb(
-f"inpainted_middle/{i}",
+f"imputed/{i}",
 writer,
 step=state.tracker.step,
 plot_fn=None,
 )
 imputed_true[i].cpu().write_audio_to_tb(
-f"reconstructed/{i}",
+f"imputed_true/{i}",
 writer,
 step=state.tracker.step,
 plot_fn=None,
@@ -450,7 +450,7 @@ def save_samples(state: State, val_idx: int, writer: SummaryWriter):
 
 z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
 
-z_hat = state.model(z_mask_latent)
+z_hat = state.model(z_mask_latent, r)
 
 z_pred = torch.softmax(z_hat, dim=1).argmax(dim=1)
 z_pred = codebook_unflatten(z_pred, n_c=vn.n_predict_codebooks)
@@ -469,7 +469,7 @@ def save_samples(state: State, val_idx: int, writer: SummaryWriter):
 }
 for k, v in audio_dict.items():
 v.cpu().write_audio_to_tb(
-f"onestep/_{i}.r={r[i]:0.2f}/{k}",
+f"samples/_{i}.r={r[i]:0.2f}/{k}",
 writer,
 step=state.tracker.step,
 plot_fn=None,

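The recurring change in this file is that the model forward now takes the mask-ratio tensor r alongside the masked latents. Below is a rough sketch of the shape of a training step after this change; the masking helper, the batch key, and the import path for codebook_flatten are assumptions standing in for code not shown in these hunks, and the loss is simplified.

import torch
import torch.nn.functional as F
from vampnet.util import codebook_flatten  # assumed import path

def train_step(state, batch, accel):
    vn = state.model                                 # the VampNet being trained
    z = batch["codes"]                               # hypothetical: pre-encoded tokens, (b, n_codebooks, t)

    # one mask ratio per batch item; this is the r the forward pass now consumes
    r = torch.rand(z.shape[0], device=z.device)
    z_mask = mask_codes(z, r)                        # hypothetical masking helper

    z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)

    dtype = torch.bfloat16 if accel.amp else None
    with accel.autocast(dtype=dtype):
        z_hat = vn(z_mask_latent, r)                 # r passed explicitly, as in this diff
        target = codebook_flatten(z[:, vn.n_conditioning_codebooks:, :])
        loss = F.cross_entropy(z_hat, target)        # simplified; the real loop masks the loss
    return loss
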
scripts/utils/{data/augment.py → augment.py} RENAMED
@@ -64,4 +64,4 @@ if __name__ == "__main__":
 args = argbind.parse_args()
 
 with argbind.scope(args):
-augment()
+augment()

scripts/utils/gtzan_embeddings.py DELETED
@@ -1,263 +0,0 @@
-"""
-TODO: train a linear probe
-usage:
-python gtzan_embeddings.py --args.load conf/interface.yml --Interface.device cuda --path_to_gtzan /path/to/gtzan/genres_original --output_dir /path/to/output
-"""
-from pathlib import Path
-from typing import List
-
-import audiotools as at
-from audiotools import AudioSignal
-import argbind
-import torch
-import numpy as np
-import zipfile
-import json
-
-from vampnet.interface import Interface
-import tqdm
-
-# bind the Interface to argbind
-Interface = argbind.bind(Interface)
-
-DEBUG = False
-
-def smart_plotly_export(fig, save_path):
-img_format = save_path.split('.')[-1]
-if img_format == 'html':
-fig.write_html(save_path)
-elif img_format == 'bytes':
-return fig.to_image(format='png')
-#TODO: come back and make this prettier
-elif img_format == 'numpy':
-import io
-from PIL import Image
-
-def plotly_fig2array(fig):
-#convert Plotly fig to an array
-fig_bytes = fig.to_image(format="png", width=1200, height=700)
-buf = io.BytesIO(fig_bytes)
-img = Image.open(buf)
-return np.asarray(img)
-
-return plotly_fig2array(fig)
-elif img_format == 'jpeg' or 'png' or 'webp':
-fig.write_image(save_path)
-else:
-raise ValueError("invalid image format")
-
-def dim_reduce(emb, labels, save_path, n_components=3, method='tsne', title=''):
-"""
-dimensionality reduction for visualization!
-saves an html plotly figure to save_path
-parameters:
-emb (np.ndarray): the samples to be reduces with shape (samples, features)
-labels (list): list of labels for embedding
-save_path (str): path where u wanna save ur figure
-method (str): umap, tsne, or pca
-title (str): title for ur figure
-returns:
-proj (np.ndarray): projection vector with shape (samples, dimensions)
-"""
-import pandas as pd
-import plotly.express as px
-if method == 'umap':
-reducer = umap.UMAP(n_components=n_components)
-elif method == 'tsne':
-from sklearn.manifold import TSNE
-reducer = TSNE(n_components=n_components)
-elif method == 'pca':
-from sklearn.decomposition import PCA
-reducer = PCA(n_components=n_components)
-else:
-raise ValueError
-
-proj = reducer.fit_transform(emb)
-
-if n_components == 2:
-df = pd.DataFrame(dict(
-x=proj[:, 0],
-y=proj[:, 1],
-instrument=labels
-))
-fig = px.scatter(df, x='x', y='y', color='instrument',
-title=title+f"_{method}")
-
-elif n_components == 3:
-df = pd.DataFrame(dict(
-x=proj[:, 0],
-y=proj[:, 1],
-z=proj[:, 2],
-instrument=labels
-))
-fig = px.scatter_3d(df, x='x', y='y', z='z',
-color='instrument',
-title=title)
-else:
-raise ValueError("cant plot more than 3 components")
-
-fig.update_traces(marker=dict(size=6,
-line=dict(width=1,
-color='DarkSlateGrey')),
-selector=dict(mode='markers'))
-
-return smart_plotly_export(fig, save_path)
-
-
-
-# per JukeMIR, we want the emebddings from the middle layer?
-def vampnet_embed(sig: AudioSignal, interface: Interface, layer=10):
-with torch.inference_mode():
-# preprocess the signal
-sig = interface.preprocess(sig)
-
-# get the coarse vampnet model
-vampnet = interface.coarse
-
-# get the tokens
-z = interface.encode(sig)[:, :vampnet.n_codebooks, :]
-z_latents = vampnet.embedding.from_codes(z, interface.codec)
-
-# do a forward pass through the model, get the embeddings
-_z, embeddings = vampnet(z_latents, return_activations=True)
-# print(f"got embeddings with shape {embeddings.shape}")
-# [layer, batch, time, n_dims]
-# [20, 1, 600ish, 768]
-
-
-# squeeze batch dim (1 bc layer should be dim 0)
-assert embeddings.shape[1] == 1, f"expected batch dim to be 1, got {embeddings.shape[0]}"
-embeddings = embeddings.squeeze(1)
-
-num_layers = embeddings.shape[0]
-assert layer < num_layers, f"layer {layer} is out of bounds for model with {num_layers} layers"
-
-# do meanpooling over the time dimension
-embeddings = embeddings.mean(dim=-2)
-# [20, 768]
-
-# return the embeddings
-return embeddings
-
-from dataclasses import dataclass, fields
-@dataclass
-class Embedding:
-genre: str
-filename: str
-embedding: np.ndarray
-
-def save(self, path):
-"""Save the Embedding object to a given path as a zip file."""
-with zipfile.ZipFile(path, 'w') as archive:
-
-# Save numpy array
-with archive.open('embedding.npy', 'w') as f:
-np.save(f, self.embedding)
-
-# Save non-numpy data as json
-non_numpy_data = {f.name: getattr(self, f.name) for f in fields(self) if f.name != 'embedding'}
-with archive.open('data.json', 'w') as f:
-f.write(json.dumps(non_numpy_data).encode('utf-8'))
-
-@classmethod
-def load(cls, path):
-"""Load the Embedding object from a given zip path."""
-with zipfile.ZipFile(path, 'r') as archive:
-
-# Load numpy array
-with archive.open('embedding.npy') as f:
-embedding = np.load(f)
-
-# Load non-numpy data from json
-with archive.open('data.json') as f:
-data = json.loads(f.read().decode('utf-8'))
-
-return cls(embedding=embedding, **data)
-
-
-@argbind.bind(without_prefix=True)
-def main(
-path_to_gtzan: str = None,
-cache_dir: str = "./.gtzan_emb_cache",
-output_dir: str = "./gtzan_vampnet_embeddings",
-layers: List[int] = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]
-):
-path_to_gtzan = Path(path_to_gtzan)
-assert path_to_gtzan.exists(), f"{path_to_gtzan} does not exist"
-
-cache_dir = Path(cache_dir)
-output_dir = Path(output_dir)
-output_dir.mkdir(exist_ok=True, parents=True)
-
-# load our interface
-# argbind will automatically load the default config,
-interface = Interface()
-
-# gtzan should have a folder for each genre, so let's get the list of genres
-genres = [Path(x).name for x in path_to_gtzan.iterdir() if x.is_dir()]
-print(f"Found {len(genres)} genres")
-print(f"genres: {genres}")
-
-# collect audio files, genres, and embeddings
-data = []
-for genre in genres:
-audio_files = list(at.util.find_audio(path_to_gtzan / genre))
-print(f"Found {len(audio_files)} audio files for genre {genre}")
-
-for audio_file in tqdm.tqdm(audio_files, desc=f"embedding genre {genre}"):
-# check if we have a cached embedding for this file
-cached_path = (cache_dir / f"{genre}_{audio_file.stem}.emb")
-if cached_path.exists():
-# if so, load it
-if DEBUG:
-print(f"loading cached embedding for {cached_path.stem}")
-embedding = Embedding.load(cached_path)
-data.append(embedding)
-else:
-try:
-sig = AudioSignal(audio_file)
-except Exception as e:
-print(f"failed to load {audio_file.name} with error {e}")
-print(f"skipping {audio_file.name}")
-continue
-
-# gets the embedding
-emb = vampnet_embed(sig, interface).cpu().numpy()
-
-# create an embedding we can save/load
-embedding = Embedding(
-genre=genre,
-filename=audio_file.name,
-embedding=emb
-)
-
-# cache the embeddings
-cached_path.parent.mkdir(exist_ok=True, parents=True)
-embedding.save(cached_path)
-
-# now, let's do a dim reduction on the embeddings
-# and visualize them.
-
-# collect a list of embeddings and labels
-embeddings = [d.embedding for d in data]
-labels = [d.genre for d in data]
-
-# convert the embeddings to a numpy array
-embeddings = np.stack(embeddings)
-
-# do dimensionality reduction for each layer we're given
-for layer in tqdm.tqdm(layers, desc="dim reduction"):
-dim_reduce(
-embeddings[:, layer, :], labels,
-save_path=str(output_dir / f'vampnet-gtzan-layer={layer}.html'),
-n_components=2, method='tsne',
-title=f'vampnet-gtzan-layer={layer}'
-)
-
-
-
-
-if __name__ == "__main__":
-args = argbind.parse_args()
-with argbind.scope(args):
-main()
 
scripts/utils/{data/maestro-reorg.py → maestro-reorg.py} RENAMED
File without changes
vampnet/modules/transformer.py CHANGED
@@ -410,9 +410,7 @@ class TransformerStack(nn.Module):
 def subsequent_mask(self, size):
 return torch.ones(1, size, size).tril().bool()
 
-def forward(self, x, x_mask, cond=None, src=None, src_mask=None,
-return_activations: bool = False
-):
+def forward(self, x, x_mask, cond=None, src=None, src_mask=None):
 """Computes a full transformer stack
 Parameters
 ----------
@@ -439,8 +437,6 @@
 encoder_decoder_position_bias = None
 
 # Compute transformer layers
-if return_activations:
-activations = []
 for layer in self.layers:
 x, position_bias, encoder_decoder_position_bias = layer(
 x=x,
@@ -451,15 +447,8 @@
 position_bias=position_bias,
 encoder_decoder_position_bias=encoder_decoder_position_bias,
 )
-if return_activations:
-activations.append(x.detach())
 
-
-out = self.norm(x) if self.norm is not None else x
-if return_activations:
-return out, torch.stack(activations)
-else:
-return out
+return self.norm(x) if self.norm is not None else x
 
 
 class VampNet(at.ml.BaseModel):
@@ -467,7 +456,7 @@
 self,
 n_heads: int = 20,
 n_layers: int = 16,
-r_cond_dim: int = 0,
+r_cond_dim: int = 64,
 n_codebooks: int = 9,
 n_conditioning_codebooks: int = 0,
 latent_dim: int = 8,
@@ -478,7 +467,6 @@
 dropout: float = 0.1
 ):
 super().__init__()
-assert r_cond_dim == 0, f"r_cond_dim must be 0 (not supported), but got {r_cond_dim}"
 self.n_heads = n_heads
 self.n_layers = n_layers
 self.r_cond_dim = r_cond_dim
@@ -525,25 +513,21 @@
 ),
 )
 
-def forward(self, x, return_activations: bool = False):
+def forward(self, x, cond):
 x = self.embedding(x)
 x_mask = torch.ones_like(x, dtype=torch.bool)[:, :1, :].squeeze(1)
 
-x = rearrange(x, "b d n -> b n d")
-out = self.transformer(x=x, x_mask=x_mask, return_activations=return_activations)
-if return_activations:
-out, activations = out
+cond = self.r_embed(cond)
 
+x = rearrange(x, "b d n -> b n d")
+out = self.transformer(x=x, x_mask=x_mask, cond=cond)
 out = rearrange(out, "b n d -> b d n")
 
-out = self.classifier(out, None) # no cond here!
+out = self.classifier(out, cond)
 
 out = rearrange(out, "b (p c) t -> b p (t c)", c=self.n_predict_codebooks)
 
-if return_activations:
-return out, activations
-else:
-return out
+return out
 
 def r_embed(self, r, max_positions=10000):
 if self.r_cond_dim > 0:
@@ -605,7 +589,7 @@
 top_p=None,
 return_signal=True,
 seed: int = None,
-sample_cutoff: float = 1.0,
+sample_cutoff: float = 0.5,
 ):
 if seed is not None:
 at.util.seed(seed)
@@ -676,7 +660,7 @@
 
 # infer from latents
 # NOTE: this collapses the codebook dimension into the sequence dimension
-logits = self.forward(latents) # b, prob, seq
+logits = self.forward(latents, r) # b, prob, seq
 logits = logits.permute(0, 2, 1) # b, seq, prob
 b = logits.shape[0]
 
@@ -937,7 +921,7 @@ if __name__ == "__main__":
 z_mask_latent = torch.rand(
 batch_size, model.latent_dim * model.n_codebooks, seq_len
 ).to(device)
-z_hat = model(z_mask_latent)
+z_hat = model(z_mask_latent, r)
 
 pred = z_hat.argmax(dim=1)
 pred = model.embedding.unflatten(pred, n_codebooks=model.n_predict_codebooks)

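The net effect of this file's changes is that VampNet.forward is r-conditioned again: it embeds the latents, embeds the mask ratio via r_embed, and feeds both through the transformer and classifier. Below is a small smoke test in the spirit of the __main__ block shown in the last hunk; the import path, batch size, sequence length, and device handling are assumptions.

import torch
from vampnet.modules.transformer import VampNet  # assumed import path

device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size, seq_len = 2, 80                      # assumed toy sizes

# r_cond_dim > 0 enables the mask-ratio embedding restored in this diff
model = VampNet(r_cond_dim=64).to(device)

# random stand-in for masked latents: (b, latent_dim * n_codebooks, t)
z_mask_latent = torch.rand(
    batch_size, model.latent_dim * model.n_codebooks, seq_len
).to(device)

# one mask-ratio scalar per batch item, in [0, 1)
r = torch.rand(batch_size).to(device)

z_hat = model(z_mask_latent, r)                  # logits over the codebook vocabulary
pred = z_hat.argmax(dim=1)
pred = model.embedding.unflatten(pred, n_codebooks=model.n_predict_codebooks)
print(pred.shape)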