import torch
import torch.nn as nn
import torchaudio


class AugmentMelSTFT(nn.Module):
    """Log-mel spectrogram front end with training-time augmentations:
    randomly perturbed mel filterbank edges plus frequency/time masking (SpecAugment)."""

    def __init__(self, n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024, freqm=48, timem=192,
                 fmin=0.0, fmax=None, fmin_aug_range=10, fmax_aug_range=2000):
        torch.nn.Module.__init__(self)
        # adapted from: https://github.com/CPJKU/kagglebirds2020/commit/70f8308b39011b09d41eb0f4ace5aa7d2b0e806e

        self.win_length = win_length
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.sr = sr
        self.fmin = fmin
        if fmax is None:
            fmax = sr // 2 - fmax_aug_range // 2
            print(f"Warning: fmax is None, setting to {fmax}")
        self.fmax = fmax
        self.hopsize = hopsize
        self.register_buffer('window',
                             torch.hann_window(win_length, periodic=False),
                             persistent=False)
        assert fmin_aug_range >= 1, f"fmin_aug_range={fmin_aug_range} should be >=1; 1 means no augmentation"
        assert fmax_aug_range >= 1, f"fmax_aug_range={fmax_aug_range} should be >=1; 1 means no augmentation"
        self.fmin_aug_range = fmin_aug_range
        self.fmax_aug_range = fmax_aug_range

        # pre-emphasis filter as a 1D convolution: y[t] = x[t+1] - 0.97 * x[t]
        self.register_buffer("preemphasis_coefficient", torch.as_tensor([[[-.97, 1]]]), persistent=False)
        # SpecAugment-style frequency and time masking (identity if disabled)
        if freqm == 0:
            self.freqm = torch.nn.Identity()
        else:
            self.freqm = torchaudio.transforms.FrequencyMasking(freqm, iid_masks=True)
        if timem == 0:
            self.timem = torch.nn.Identity()
        else:
            self.timem = torchaudio.transforms.TimeMasking(timem, iid_masks=True)

    def forward(self, x):
        # x: (batch, samples) raw waveform
        x = nn.functional.conv1d(x.unsqueeze(1), self.preemphasis_coefficient).squeeze(1)
        x = torch.stft(x, self.n_fft, hop_length=self.hopsize, win_length=self.win_length,
                       center=True, normalized=False, window=self.window, return_complex=False)
        x = (x ** 2).sum(dim=-1)  # power spectrogram (real^2 + imag^2)

        # randomly perturb the mel filterbank frequency range (training-time augmentation)
        fmin = self.fmin + torch.randint(self.fmin_aug_range, (1,)).item()
        fmax = self.fmax + self.fmax_aug_range // 2 - torch.randint(self.fmax_aug_range, (1,)).item()
        # don't augment eval data
        if not self.training:
            fmin = self.fmin
            fmax = self.fmax

        mel_basis, _ = torchaudio.compliance.kaldi.get_mel_banks(self.n_mels, self.n_fft, self.sr,
                                                                 fmin, fmax, vtln_low=100.0, vtln_high=-500.,
                                                                 vtln_warp_factor=1.0)
        # get_mel_banks returns n_fft // 2 frequency bins; pad one column to match the
        # n_fft // 2 + 1 STFT bins
        mel_basis = torch.as_tensor(torch.nn.functional.pad(mel_basis, (0, 1), mode='constant', value=0),
                                    device=x.device)
        with torch.cuda.amp.autocast(enabled=False):
            melspec = torch.matmul(mel_basis, x)

        melspec = (melspec + 0.00001).log()

        if self.training:
            melspec = self.freqm(melspec)
            melspec = self.timem(melspec)

        melspec = (melspec + 4.5) / 5.  # fast normalization

        return melspec
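

# --- Usage sketch (not part of the original module) ---
# A minimal example of how this front end might be wired up, assuming 32 kHz mono
# waveforms of shape (batch, samples); the random input and parameter values below
# are illustrative only.
if __name__ == "__main__":
    mel = AugmentMelSTFT(n_mels=128, sr=32000, win_length=800, hopsize=320)
    mel.eval()  # disable the random filterbank perturbation and masking augmentations
    waveform = torch.randn(2, 32000)  # two random 1-second clips as stand-in input
    with torch.no_grad():
        spec = mel(waveform)
    print(spec.shape)  # (batch, n_mels, time frames), roughly (2, 128, 100) here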