Hugo Flores commited on
Commit
534a89c
β€’
1 Parent(s): fc839a6

refactor bugfixes

Browse files
conf/vampnet.yml CHANGED
@@ -1,5 +1,5 @@
1
 
2
- wav2wav_ckpt: /u/home/src/runs/codec-ckpt/codec.pth
3
  save_path: ckpt
4
  max_epochs: 1000000
5
  epoch_length: 1000
 
1
 
2
+ codec_ckpt: /u/home/src/runs/codec-ckpt/codec.pth
3
  save_path: ckpt
4
  max_epochs: 1000000
5
  epoch_length: 1000
lyrebird-audiotools CHANGED
@@ -1 +1 @@
1
- Subproject commit 018a055ff7406c7bcb3b175551356ec18ba895b7
 
1
+ Subproject commit 3b1abbe27a846f3e2330cacc3ddf70a280b08e98
scripts/{generative β†’ exp}/eval.py RENAMED
File without changes
scripts/{generative β†’ exp}/train.py RENAMED
@@ -114,8 +114,8 @@ def load(
114
  "map_location": "cpu",
115
  "package": not load_weights,
116
  }
117
- if (Path(kwargs["folder"]) / "model").exists():
118
- model, v_extra = model.load_from_folder(**kwargs)
119
 
120
  codec = LAC.load(args["codec_ckpt"], map_location="cpu")
121
  codec.eval()
@@ -215,6 +215,29 @@ def accuracy(
215
 
216
  return accuracy
217
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
 
219
  @argbind.bind(without_prefix=True)
220
  def train(
@@ -288,7 +311,7 @@ def train(
288
  class Trainer(at.ml.BaseTrainer):
289
  _last_grad_norm = 0.0
290
 
291
- def metrics(self, vn, z_hat, r, target, flat_mask, output):
292
  for r_range in [(0, 0.5), (0.5, 1.0)]:
293
  unmasked_target = target.masked_fill(flat_mask.bool(), IGNORE_INDEX)
294
  masked_target = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
@@ -324,7 +347,6 @@ def train(
324
  )
325
 
326
  def train_loop(self, engine, batch):
327
-
328
  model.train()
329
  batch = at.util.prepare_batch(batch, accel.device)
330
  signal = apply_transform(train_data.transform, batch)
@@ -333,22 +355,18 @@ def train(
333
  vn = accel.unwrap(model)
334
  with accel.autocast():
335
  with torch.inference_mode():
 
336
  z = codec.encode(signal.samples, signal.sample_rate)["codes"]
337
  z = z[:, : vn.n_codebooks, :]
338
 
339
  n_batch = z.shape[0]
340
  r = rng.draw(n_batch)[:, 0].to(accel.device)
341
 
342
- if prefix_amt > 0.0:
343
- prefix_mask = flip_coin(n_batch, 1 - prefix_dropout, rng)
344
- n_prefix = int(prefix_amt * z.shape[-1]) * prefix_mask
345
- else:
346
- n_prefix = None
347
- if suffix_amt > 0.0:
348
- suffix_mask = flip_coin(n_batch, 1 - suffix_dropout, rng)
349
- n_suffix = int(suffix_amt * z.shape[-1]) * suffix_mask
350
- else:
351
- n_suffix = None
352
 
353
  z_mask, mask = vn.add_noise(
354
  z, r, n_prefix=n_prefix, n_suffix=n_suffix
@@ -378,7 +396,7 @@ def train(
378
  else:
379
  output["loss"] = criterion(z_hat, target)
380
 
381
- self.metrics(
382
  vn=vn,
383
  r=r,
384
  z_hat=z_hat,
@@ -430,16 +448,11 @@ def train(
430
  n_batch = z.shape[0]
431
  r = rng.draw(n_batch)[:, 0].to(accel.device)
432
 
433
- if prefix_amt > 0.0:
434
- prefix_mask = flip_coin(n_batch, 1 - prefix_dropout, rng)
435
- n_prefix = int(prefix_amt * z.shape[-1]) * prefix_mask
436
- else:
437
- n_prefix = None
438
- if suffix_amt > 0.0:
439
- suffix_mask = flip_coin(n_batch, 1 - suffix_dropout, rng)
440
- n_suffix = int(suffix_amt * z.shape[-1]) * suffix_mask
441
- else:
442
- n_suffix = None
443
 
444
  z_mask, mask = vn.add_noise(z, r, n_prefix=n_prefix, n_suffix=n_suffix)
445
  z_mask_latent = vn.embedding.from_codes(z_mask, codec)
@@ -466,7 +479,7 @@ def train(
466
  else:
467
  output["loss"] = criterion(z_hat, target)
468
 
469
- self.metrics(
470
  vn=vn,
471
  r=r,
472
  z_hat=z_hat,
@@ -516,7 +529,7 @@ def train(
516
 
517
  for i in range(num_samples):
518
  sampled = accel.unwrap(model).sample(
519
- codec,
520
  time_steps=z.shape[-1],
521
  start_tokens=z[i : i + 1],
522
  )
@@ -547,7 +560,7 @@ def train(
547
  for i in range(len(z)):
548
  imputed.append(
549
  accel.unwrap(model).sample(
550
- codec,
551
  time_steps=z.shape[-1],
552
  start_tokens=z[i][None, ...],
553
  mask=imp_mask[i][None, ...],
@@ -593,16 +606,11 @@ def train(
593
 
594
  n_batch = z.shape[0]
595
 
596
- if prefix_amt > 0.0:
597
- prefix_mask = flip_coin(n_batch, 1 - prefix_dropout, rng)
598
- n_prefix = int(prefix_amt * z.shape[-1]) * prefix_mask
599
- else:
600
- n_prefix = None
601
- if suffix_amt > 0.0:
602
- suffix_mask = flip_coin(n_batch, 1 - suffix_dropout, rng)
603
- n_suffix = int(suffix_amt * z.shape[-1]) * suffix_mask
604
- else:
605
- n_suffix = None
606
 
607
  z_mask, mask = vn.add_noise(z, r, n_prefix=n_prefix, n_suffix=n_suffix)
608
  z_mask_latent = vn.embedding.from_codes(z_mask, codec)
 
114
  "map_location": "cpu",
115
  "package": not load_weights,
116
  }
117
+ if (Path(kwargs["folder"]) / "vampnet").exists():
118
+ model, v_extra = VampNet.load_from_folder(**kwargs)
119
 
120
  codec = LAC.load(args["codec_ckpt"], map_location="cpu")
121
  codec.eval()
 
215
 
216
  return accuracy
217
 
218
+ def sample_prefix_suffix_amt(
219
+ n_batch,
220
+ prefix_amt,
221
+ suffix_amt,
222
+ prefix_dropout,
223
+ suffix_dropout,
224
+ rng
225
+ ):
226
+ """
227
+ Sample the number of prefix and suffix tokens to drop.
228
+ """
229
+ if prefix_amt > 0.0:
230
+ prefix_mask = flip_coin(n_batch, 1 - prefix_dropout, rng)
231
+ n_prefix = int(prefix_amt * z.shape[-1]) * prefix_mask
232
+ else:
233
+ n_prefix = None
234
+ if suffix_amt > 0.0:
235
+ suffix_mask = flip_coin(n_batch, 1 - suffix_dropout, rng)
236
+ n_suffix = int(suffix_amt * z.shape[-1]) * suffix_mask
237
+ else:
238
+ n_suffix = None
239
+ return n_prefix, n_suffix
240
+
241
 
242
  @argbind.bind(without_prefix=True)
243
  def train(
 
311
  class Trainer(at.ml.BaseTrainer):
312
  _last_grad_norm = 0.0
313
 
314
+ def _metrics(self, vn, z_hat, r, target, flat_mask, output):
315
  for r_range in [(0, 0.5), (0.5, 1.0)]:
316
  unmasked_target = target.masked_fill(flat_mask.bool(), IGNORE_INDEX)
317
  masked_target = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
 
347
  )
348
 
349
  def train_loop(self, engine, batch):
 
350
  model.train()
351
  batch = at.util.prepare_batch(batch, accel.device)
352
  signal = apply_transform(train_data.transform, batch)
 
355
  vn = accel.unwrap(model)
356
  with accel.autocast():
357
  with torch.inference_mode():
358
+ codec.to(accel.device)
359
  z = codec.encode(signal.samples, signal.sample_rate)["codes"]
360
  z = z[:, : vn.n_codebooks, :]
361
 
362
  n_batch = z.shape[0]
363
  r = rng.draw(n_batch)[:, 0].to(accel.device)
364
 
365
+ n_prefix, n_suffix = sample_prefix_suffix_amt(
366
+ n_batch=n_batch, prefix_amt=prefix_amt, suffix_amt=suffix_amt,
367
+ prefix_dropout=prefix_dropout, suffix_dropout=suffix_dropout,
368
+ rng=rng
369
+ )
 
 
 
 
 
370
 
371
  z_mask, mask = vn.add_noise(
372
  z, r, n_prefix=n_prefix, n_suffix=n_suffix
 
396
  else:
397
  output["loss"] = criterion(z_hat, target)
398
 
399
+ self._metrics(
400
  vn=vn,
401
  r=r,
402
  z_hat=z_hat,
 
448
  n_batch = z.shape[0]
449
  r = rng.draw(n_batch)[:, 0].to(accel.device)
450
 
451
+ n_prefix, n_suffix = sample_prefix_suffix_amt(
452
+ n_batch=n_batch, prefix_amt=prefix_amt, suffix_amt=suffix_amt,
453
+ prefix_dropout=prefix_dropout, suffix_dropout=suffix_dropout,
454
+ rng=rng
455
+ )
 
 
 
 
 
456
 
457
  z_mask, mask = vn.add_noise(z, r, n_prefix=n_prefix, n_suffix=n_suffix)
458
  z_mask_latent = vn.embedding.from_codes(z_mask, codec)
 
479
  else:
480
  output["loss"] = criterion(z_hat, target)
481
 
482
+ self._metrics(
483
  vn=vn,
484
  r=r,
485
  z_hat=z_hat,
 
529
 
530
  for i in range(num_samples):
531
  sampled = accel.unwrap(model).sample(
532
+ codec=codec,
533
  time_steps=z.shape[-1],
534
  start_tokens=z[i : i + 1],
535
  )
 
560
  for i in range(len(z)):
561
  imputed.append(
562
  accel.unwrap(model).sample(
563
+ codec=codec,
564
  time_steps=z.shape[-1],
565
  start_tokens=z[i][None, ...],
566
  mask=imp_mask[i][None, ...],
 
606
 
607
  n_batch = z.shape[0]
608
 
609
+ n_prefix, n_suffix = sample_prefix_suffix_amt(
610
+ n_batch=n_batch, prefix_amt=prefix_amt, suffix_amt=suffix_amt,
611
+ prefix_dropout=prefix_dropout, suffix_dropout=suffix_dropout,
612
+ rng=rng
613
+ )
 
 
 
 
 
614
 
615
  z_mask, mask = vn.add_noise(z, r, n_prefix=n_prefix, n_suffix=n_suffix)
616
  z_mask_latent = vn.embedding.from_codes(z_mask, codec)
vampnet/modules/activations.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import numpy as np
2
  import torch
3
  import torch.nn as nn
@@ -5,7 +6,6 @@ import torch.nn.functional as F
5
  from einops import rearrange
6
 
7
 
8
-
9
  class NewGELU(nn.Module):
10
  """
11
  Implementation of the GELU activation function currently in Google BERT repo
 
1
+ import math
2
  import numpy as np
3
  import torch
4
  import torch.nn as nn
 
6
  from einops import rearrange
7
 
8
 
 
9
  class NewGELU(nn.Module):
10
  """
11
  Implementation of the GELU activation function currently in Google BERT repo
vampnet/modules/base.py CHANGED
@@ -85,8 +85,6 @@ class VampBase(at.ml.BaseModel):
85
  mask = mask[:, self.n_conditioning_codebooks :, :]
86
 
87
  truth = F.one_hot(z_true, self.vocab_size)
88
- print(truth.shape)
89
- # truth = rearrange(truth, "b c t p -> b p (t c)")
90
  mask = mask[:, :, :, None].expand(-1, -1, -1, self.vocab_size)
91
  z_hat = rearrange(
92
  z_hat,
@@ -127,16 +125,16 @@ class VampBase(at.ml.BaseModel):
127
  return r
128
 
129
  @torch.no_grad()
130
- def to_signal(self, z, vqvae):
131
  if z.ndim == 2:
132
  z = self.embedding.unflatten(z)
133
  assert z.ndim == 3
134
 
135
  signal = at.AudioSignal(
136
- vqvae.decode(
137
- vqvae.quantizer.from_latents(self.embedding.from_codes(z, vqvae))[0]
138
  )["audio"],
139
- vqvae.sample_rate,
140
  )
141
 
142
  return signal
@@ -150,7 +148,7 @@ class VampBase(at.ml.BaseModel):
150
 
151
  def paella_sample(
152
  self,
153
- vqvae,
154
  time_steps: int = 400,
155
  sampling_steps: int = 12,
156
  start_tokens: Optional[torch.Tensor] = None,
@@ -219,7 +217,7 @@ class VampBase(at.ml.BaseModel):
219
  if renoise_mode == "prev":
220
  z_prev = z.clone()
221
 
222
- latents = self.embedding.from_codes(z, vqvae)
223
  logits = self.forward(latents, r[i])
224
 
225
  # for mask mode
@@ -258,13 +256,13 @@ class VampBase(at.ml.BaseModel):
258
  z = start_tokens * (1 - mask) + z * mask
259
 
260
  if return_signal:
261
- return self.to_signal(z, vqvae)
262
  else:
263
  return z
264
 
265
  def maskgit_sample(
266
  self,
267
- vqvae,
268
  time_steps: int = 300,
269
  sampling_steps: int = 24,
270
  start_tokens: Optional[torch.Tensor] = None,
@@ -338,7 +336,7 @@ class VampBase(at.ml.BaseModel):
338
  z_masked = z.masked_fill(~keep_mask_unflat.bool(), self.mask_token)
339
 
340
  # get latents
341
- latents = self.embedding.from_codes(z_masked, vqvae)
342
 
343
  # infer from latents
344
  logits = self.forward(latents, r)
@@ -400,7 +398,7 @@ class VampBase(at.ml.BaseModel):
400
  # z = torch.cat([z[:, :self.n_conditioning_codebooks, :], z_inferred], dim=1)
401
 
402
  if return_signal:
403
- return self.to_signal(z, vqvae)
404
  else:
405
  return z
406
 
 
85
  mask = mask[:, self.n_conditioning_codebooks :, :]
86
 
87
  truth = F.one_hot(z_true, self.vocab_size)
 
 
88
  mask = mask[:, :, :, None].expand(-1, -1, -1, self.vocab_size)
89
  z_hat = rearrange(
90
  z_hat,
 
125
  return r
126
 
127
  @torch.no_grad()
128
+ def to_signal(self, z, codec):
129
  if z.ndim == 2:
130
  z = self.embedding.unflatten(z)
131
  assert z.ndim == 3
132
 
133
  signal = at.AudioSignal(
134
+ codec.decode(
135
+ codec.quantizer.from_latents(self.embedding.from_codes(z, codec))[0]
136
  )["audio"],
137
+ codec.sample_rate,
138
  )
139
 
140
  return signal
 
148
 
149
  def paella_sample(
150
  self,
151
+ codec,
152
  time_steps: int = 400,
153
  sampling_steps: int = 12,
154
  start_tokens: Optional[torch.Tensor] = None,
 
217
  if renoise_mode == "prev":
218
  z_prev = z.clone()
219
 
220
+ latents = self.embedding.from_codes(z, codec)
221
  logits = self.forward(latents, r[i])
222
 
223
  # for mask mode
 
256
  z = start_tokens * (1 - mask) + z * mask
257
 
258
  if return_signal:
259
+ return self.to_signal(z, codec)
260
  else:
261
  return z
262
 
263
  def maskgit_sample(
264
  self,
265
+ codec,
266
  time_steps: int = 300,
267
  sampling_steps: int = 24,
268
  start_tokens: Optional[torch.Tensor] = None,
 
336
  z_masked = z.masked_fill(~keep_mask_unflat.bool(), self.mask_token)
337
 
338
  # get latents
339
+ latents = self.embedding.from_codes(z_masked, codec)
340
 
341
  # infer from latents
342
  logits = self.forward(latents, r)
 
398
  # z = torch.cat([z[:, :self.n_conditioning_codebooks, :], z_inferred], dim=1)
399
 
400
  if return_signal:
401
+ return self.to_signal(z, codec)
402
  else:
403
  return z
404
 
vampnet/modules/layers.py CHANGED
@@ -113,13 +113,13 @@ class CodebookEmbedding(nn.Module):
113
 
114
  self.out_proj = nn.Conv1d(n_codebooks * self.latent_dim, self.emb_dim, 1)
115
 
116
- def from_codes(self, codes: torch.Tensor, vqvae):
117
  n_codebooks = codes.shape[1]
118
  latent = []
119
  for i in range(n_codebooks):
120
  c = codes[:, i, :]
121
 
122
- lookup_table = vqvae.quantizer.quantizers[i].codebook.weight
123
  if hasattr(self, "special"):
124
  special_lookup = torch.cat(
125
  [self.special[tkn][i : i + 1] for tkn in self.special], dim=0
 
113
 
114
  self.out_proj = nn.Conv1d(n_codebooks * self.latent_dim, self.emb_dim, 1)
115
 
116
+ def from_codes(self, codes: torch.Tensor, codec):
117
  n_codebooks = codes.shape[1]
118
  latent = []
119
  for i in range(n_codebooks):
120
  c = codes[:, i, :]
121
 
122
+ lookup_table = codec.quantizer.quantizers[i].codebook.weight
123
  if hasattr(self, "special"):
124
  special_lookup = torch.cat(
125
  [self.special[tkn][i : i + 1] for tkn in self.special], dim=0
vampnet/modules/wavenet.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from einops import rearrange
3
+
4
+ from voicegpt.nn import WaveNet
5
+
6
+ class AutoregMLP(nn.Module):
7
+ """Implements an autoregressive ConvNet decoder
8
+ Refer to SampleRNN (https://arxiv.org/abs/1612.07837) for motivation
9
+ """
10
+
11
+ def __init__(
12
+ self,
13
+ vocab_size: int,
14
+ d_model: int,
15
+ n_layers: int,
16
+ n_fine_tokens: int = 6,
17
+ n_tokens: int = 9,
18
+ dropout: float = 0.1,
19
+ activation: str = "gelu",
20
+ causal: bool = True,
21
+ ):
22
+ super().__init__()
23
+ self.n_fine = n_fine_tokens
24
+ self.n_layers = n_layers
25
+ self.upsampler = nn.Linear(d_model, d_model * n_fine_tokens)
26
+
27
+ self.wavenet = WaveNet(
28
+ d_model,
29
+ d_model,
30
+ d_model,
31
+ n_layers,
32
+ n_fine_tokens,
33
+ dropout=dropout,
34
+ activation=activation,
35
+ causal=causal,
36
+ )
37
+ self.ff_output = nn.Linear(d_model, vocab_size * n_tokens, bias=False)
38
+
39
+ def time_upsample(self, h_t_coarse):
40
+ """Upsamples the conditioning hidden states to match the time resolution
41
+ of output tokens
42
+ Parameters
43
+ ----------
44
+ h_t_coarse : Tensor[B x T_coarse x D]
45
+ Conditioning hidden states in coarse time-scale
46
+ Returns
47
+ -------
48
+ Tensor[B x T_fine x D]
49
+ Conditioning hidden states in fine time-scale
50
+ """
51
+ # Upsample the transformer hidden states to fine scale
52
+ h_t_fine = rearrange(
53
+ self.upsampler(h_t_coarse), "b t (n d) -> b (t n) d", n=self.n_fine
54
+ )
55
+ return h_t_fine
56
+
57
+ def decode_logits(self, x_tm1, h_t_fine):
58
+ """Decodes output logits conditioned on previous output
59
+ tokens (upto timestep t-1) and conditioning hidden states
60
+ using an autoregressive WaveNet
61
+ Parameters
62
+ ----------
63
+ x_tm1 : Tensor[B x T x D]
64
+ h_t_fine : Tensor[B x T x D]
65
+ Returns
66
+ -------
67
+ Tensor[B x T x vocab_size]
68
+ Predicted logits
69
+ """
70
+
71
+ # Compute wavenet layers and predict logits
72
+ o_t = self.wavenet(x_tm1, h_t_fine)
73
+ return self.ff_output(o_t)
74
+
75
+ def forward(self, x_tm1, h_t_coarse):
76
+ """Computes autoregressive conditional probability distribution
77
+ using a WaveNet decoder
78
+ Parameters
79
+ ----------
80
+ x_tm1 : Tensor[B x T_fine x D]
81
+ Embeddings of tokens at fine time-scale
82
+ h_t_coarse : Tensor[B x T_coarse x D]
83
+ Hidden states at coarse time scale
84
+ Returns
85
+ -------
86
+ Tensor[B x T_fine x vocab_size]
87
+ Predicted logits at fine time-scale
88
+ """
89
+ h_t_fine = self.time_upsample(h_t_coarse)
90
+ return self.decode_logits(x_tm1, h_t_fine)