import torch
import julius
import torchopenl3
import torchmetrics
import pytorch_lightning as pl
from typing import Tuple, List, Dict
from argparse import ArgumentParser

from deepafx_st.probes.cdpam_encoder import CDPAMEncoder
from deepafx_st.probes.random_mel import RandomMelProjection
import deepafx_st.utils as utils
from deepafx_st.utils import DSPMode
from deepafx_st.system import System
from deepafx_st.data.style import StyleDataset


class ProbeSystem(pl.LightningModule):
    def __init__(
        self,
        audio_dir=None,
        num_classes=5,
        task="style",
        encoder_type="deepafx_st_autodiff",
        deepafx_st_autodiff_ckpt=None,
        deepafx_st_spsa_ckpt=None,
        deepafx_st_proxy0_ckpt=None,
        probe_type="linear",
        batch_size=32,
        lr=3e-4,
        lr_patience=20,
        patience=10,
        preload=False,
        sample_rate=24000,
        shuffle=True,
        num_workers=16,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()

        if "deepafx_st" in self.hparams.encoder_type:
            # select the checkpoint matching the requested DeepAFx-ST variant
            if "autodiff" in self.hparams.encoder_type:
                self.hparams.deepafx_st_ckpt = self.hparams.deepafx_st_autodiff_ckpt
            elif "spsa" in self.hparams.encoder_type:
                self.hparams.deepafx_st_ckpt = self.hparams.deepafx_st_spsa_ckpt
            elif "proxy0" in self.hparams.encoder_type:
                self.hparams.deepafx_st_ckpt = self.hparams.deepafx_st_proxy0_ckpt
            else:
                raise RuntimeError(f"Invalid encoder_type: {self.hparams.encoder_type}")

            if self.hparams.deepafx_st_ckpt is None:
                raise RuntimeError(
                    f"Must supply {self.hparams.encoder_type}_ckpt checkpoint."
                )

            use_dsp = DSPMode.NONE
            system = System.load_from_checkpoint(
                self.hparams.deepafx_st_ckpt,
                use_dsp=use_dsp,
                batch_size=self.hparams.batch_size,
                spsa_parallel=False,
                proxy_ckpts=[],
                strict=False,
            )
            system.eval()
            self.encoder = system.encoder
            self.hparams.embed_dim = self.encoder.embed_dim

            # freeze encoder weights; only the probe is trained
            for name, param in self.encoder.named_parameters():
                param.requires_grad = False

        elif self.hparams.encoder_type == "openl3":
            self.encoder = torchopenl3.models.load_audio_embedding_model(
                input_repr=self.hparams.openl3_input_repr,
                embedding_size=self.hparams.openl3_embedding_size,
                content_type=self.hparams.openl3_content_type,
            )
            self.hparams.embed_dim = self.hparams.openl3_embedding_size  # 6144 by default
        elif self.hparams.encoder_type == "random_mel":
            self.encoder = RandomMelProjection(
                self.hparams.sample_rate,
                self.hparams.random_mel_embedding_size,
                self.hparams.random_mel_n_mels,
                self.hparams.random_mel_n_fft,
                self.hparams.random_mel_hop_size,
            )
            self.hparams.embed_dim = self.hparams.random_mel_embedding_size
        elif self.hparams.encoder_type == "cdpam":
            self.encoder = CDPAMEncoder(self.hparams.cdpam_ckpt)
            self.encoder.eval()
            self.hparams.embed_dim = self.encoder.embed_dim
        else:
            raise ValueError(f"Invalid encoder_type: {self.hparams.encoder_type}")

        if self.hparams.probe_type == "linear":
            if self.hparams.task == "style":
                self.probe = torch.nn.Sequential(
                    torch.nn.Linear(self.hparams.embed_dim, self.hparams.num_classes),
                    # torch.nn.Softmax(-1),  # omitted: cross_entropy expects raw logits
                )
        elif self.hparams.probe_type == "mlp":
            if self.hparams.task == "style":
                self.probe = torch.nn.Sequential(
                    torch.nn.Linear(self.hparams.embed_dim, 512),
                    torch.nn.ReLU(),
                    torch.nn.Linear(512, 512),
                    torch.nn.ReLU(),
                    torch.nn.Linear(512, self.hparams.num_classes),
                )

        self.accuracy = torchmetrics.Accuracy()
        self.f1_score = torchmetrics.F1Score(num_classes=self.hparams.num_classes)

    def forward(self, x):
        bs, chs, samp = x.size()

        with torch.no_grad():
            if "deepafx_st" in self.hparams.encoder_type:
                # peak normalize, then scale down to leave 12 dBFS of headroom
                x /= x.abs().max()
                x *= 10 ** (-12.0 / 20)
                e = self.encoder(x)
                # project embeddings onto the unit hypersphere (L2 normalize)
                norm = torch.norm(e, p=2, dim=-1, keepdim=True)
                e = e / norm
            elif self.hparams.encoder_type == "openl3":
                # x = julius.resample_frac(x, self.hparams.sample_rate, 48000)
                e, ts = torchopenl3.get_audio_embedding(
                    x,
                    48000,
                    model=self.encoder,
                    input_repr="mel128",
                    content_type="music",
                )
                # average frame-level embeddings over time
                e = e.permute(0, 2, 1)
                e = e.mean(dim=-1)
                # normalize by L2 norm
                norm = torch.norm(e, p=2, dim=-1, keepdim=True)
                e = e / norm
            elif self.hparams.encoder_type == "random_mel":
                e = self.encoder(x)
                norm = torch.norm(e, p=2, dim=-1, keepdim=True)
                e = e / norm
            elif self.hparams.encoder_type == "cdpam":
                # x = julius.resample_frac(x, self.hparams.sample_rate, 22050)
                x = torch.round(x * 32768)  # CDPAM expects 16-bit integer-scaled audio
                e = self.encoder(x)

        return self.probe(e)

    def common_step(
        self,
        batch: Tuple,
        batch_idx: int,
        optimizer_idx: int = 0,
        train: bool = True,
    ):
        loss = 0
        x, y = batch

        y_hat = self(x)

        # compute cross-entropy loss
        if self.hparams.task == "style":
            loss = torch.nn.functional.cross_entropy(y_hat, y)

        if not train:
            # store audio data
            data_dict = {"x": x.float().cpu()}
        else:
            data_dict = {}

        self.log(
            "train_loss" if train else "val_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

        if not train and self.hparams.task == "style":
            self.log("val_acc_step", self.accuracy(y_hat, y))
            self.log("val_f1_step", self.f1_score(y_hat, y))

        return loss, data_dict

    def training_step(self, batch, batch_idx, optimizer_idx=0):
        loss, _ = self.common_step(batch, batch_idx)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, data_dict = self.common_step(batch, batch_idx, train=False)

        if batch_idx == 0:
            return data_dict

    def validation_epoch_end(self, outputs) -> None:
        if self.hparams.task == "style":
            self.log("val_acc_epoch", self.accuracy.compute())
            self.log("val_f1_epoch", self.f1_score.compute())

        return super().validation_epoch_end(outputs)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.probe.parameters(),
            lr=self.hparams.lr,
            betas=(0.9, 0.999),
        )

        # decay the learning rate by 10x at 80% and 95% of max_epochs
        ms1 = int(self.hparams.max_epochs * 0.8)
        ms2 = int(self.hparams.max_epochs * 0.95)
        print(
            "Learning rate schedule:",
            f"0 {self.hparams.lr:0.2e} -> ",
            f"{ms1} {self.hparams.lr*0.1:0.2e} -> ",
            f"{ms2} {self.hparams.lr*0.01:0.2e}",
        )
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[ms1, ms2],
            gamma=0.1,
        )

        return [optimizer], [{"scheduler": scheduler, "monitor": "val_loss"}]

    def train_dataloader(self):
        if self.hparams.task == "style":
            train_dataset = StyleDataset(
                self.hparams.audio_dir,
                "train",
                sample_rate=self.hparams.encoder_sample_rate,
            )

        # seed workers for reproducible shuffling
        g = torch.Generator()
        g.manual_seed(0)

        return torch.utils.data.DataLoader(
            train_dataset,
            num_workers=self.hparams.num_workers,
            batch_size=self.hparams.batch_size,
            shuffle=True,
            worker_init_fn=utils.seed_worker,
            generator=g,
            pin_memory=True,
        )

    def val_dataloader(self):
        if self.hparams.task == "style":
            val_dataset = StyleDataset(
                self.hparams.audio_dir,
                subset="val",
                sample_rate=self.hparams.encoder_sample_rate,
            )

        g = torch.Generator()
        g.manual_seed(0)

        return torch.utils.data.DataLoader(
            val_dataset,
            num_workers=self.hparams.num_workers,
            batch_size=self.hparams.batch_size,
            worker_init_fn=utils.seed_worker,
            generator=g,
            pin_memory=True,
        )

    # add any model hyperparameters here
    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)

        # --- Model ---
        parser.add_argument("--encoder_type", type=str, default="deepafx_st_autodiff")
        parser.add_argument("--probe_type", type=str, default="linear")
        parser.add_argument("--task", type=str, default="style")
        parser.add_argument("--encoder_sample_rate", type=int, default=24000)
        # --- deepafx_st ---
parser.add_argument("--deepafx_st_autodiff_ckpt", type=str) parser.add_argument("--deepafx_st_spsa_ckpt", type=str) parser.add_argument("--deepafx_st_proxy0_ckpt", type=str) # --- cdpam --- parser.add_argument("--cdpam_ckpt", type=str) # --- openl3 --- parser.add_argument("--openl3_input_repr", type=str, default="mel128") parser.add_argument("--openl3_content_type", type=str, default="env") parser.add_argument("--openl3_embedding_size", type=int, default=6144) # --- random_mel --- parser.add_argument("--random_mel_embedding_size", type=str, default=4096) parser.add_argument("--random_mel_n_fft", type=str, default=4096) parser.add_argument("--random_mel_hop_size", type=str, default=1024) parser.add_argument("--random_mel_n_mels", type=str, default=128) # --- Training --- parser.add_argument("--audio_dir", type=str) parser.add_argument("--num_classes", type=int, default=5) parser.add_argument("--batch_size", type=int, default=32) parser.add_argument("--lr", type=float, default=3e-4) parser.add_argument("--lr_patience", type=int, default=20) parser.add_argument("--patience", type=int, default=10) parser.add_argument("--preload", action="store_true") parser.add_argument("--sample_rate", type=int, default=24000) parser.add_argument("--num_workers", type=int, default=8) return parser