""" Implementation of objective functions used in the task 'ITO-Master' """ import numpy as np import torch import torch.nn.functional as F import torch.nn as nn import auraloss import os import sys currentdir = os.path.dirname(os.path.realpath(__file__)) sys.path.append(os.path.dirname(currentdir)) from modules.front_back_end import * # Root Mean Squared Loss # penalizes the volume factor with non-linearlity class RMSLoss(nn.Module): def __init__(self, reduce, loss_type="l2"): super(RMSLoss, self).__init__() self.weight_factor = 100. if loss_type=="l2": self.loss = nn.MSELoss(reduce=None) def forward(self, est_targets, targets): est_targets = est_targets.reshape(est_targets.shape[0]*est_targets.shape[1], est_targets.shape[2]) targets = targets.reshape(targets.shape[0]*targets.shape[1], targets.shape[2]) normalized_est = torch.sqrt(torch.mean(est_targets**2, dim=-1)) normalized_tgt = torch.sqrt(torch.mean(targets**2, dim=-1)) weight = torch.clamp(torch.abs(normalized_tgt-normalized_est), min=1/self.weight_factor) * self.weight_factor return torch.mean(weight**1.5 * self.loss(normalized_est, normalized_tgt)) # Multi-Scale Spectral Loss proposed at the paper "DDSP: DIFFERENTIABLE DIGITAL SIGNAL PROCESSING" (https://arxiv.org/abs/2001.04643) # we extend this loss by applying it to mid/side channels class MultiScale_Spectral_Loss_MidSide_DDSP(nn.Module): def __init__(self, mode='midside', \ reduce=True, \ n_filters=None, \ windows_size=None, \ hops_size=None, \ window="hann", \ eps=1e-7, \ device=torch.device("cpu")): super(MultiScale_Spectral_Loss_MidSide_DDSP, self).__init__() self.mode = mode self.eps = eps self.mid_weight = 0.5 # value in the range of 0.0 ~ 1.0 self.logmag_weight = 0.1 if n_filters is None: n_filters = [4096, 2048, 1024, 512] if windows_size is None: windows_size = [4096, 2048, 1024, 512] if hops_size is None: hops_size = [1024, 512, 256, 128] self.multiscales = [] for i in range(len(windows_size)): cur_scale = {'window_size' : float(windows_size[i])} if self.mode=='midside': cur_scale['front_end'] = FrontEnd(channel='mono', \ n_fft=n_filters[i], \ hop_length=hops_size[i], \ win_length=windows_size[i], \ window=window, \ device=device) elif self.mode=='ori': cur_scale['front_end'] = FrontEnd(channel='stereo', \ n_fft=n_filters[i], \ hop_length=hops_size[i], \ win_length=windows_size[i], \ window=window, \ device=device) self.multiscales.append(cur_scale) self.objective_l1 = nn.L1Loss(reduce=reduce) self.objective_l2 = nn.MSELoss(reduce=reduce) def forward(self, est_targets, targets): if self.mode=='midside': return self.forward_midside(est_targets, targets) elif self.mode=='ori': return self.forward_ori(est_targets, targets) def forward_ori(self, est_targets, targets): total_loss = 0.0 total_mag_loss = 0.0 total_logmag_loss = 0.0 for cur_scale in self.multiscales: est_mag = cur_scale['front_end'](est_targets, mode=["mag"]) tgt_mag = cur_scale['front_end'](targets, mode=["mag"]) mag_loss = self.magnitude_loss(est_mag, tgt_mag) logmag_loss = self.log_magnitude_loss(est_mag, tgt_mag) total_mag_loss += mag_loss total_logmag_loss += logmag_loss # return total_loss return (1-self.logmag_weight)*total_mag_loss + \ (self.logmag_weight)*total_logmag_loss def forward_midside(self, est_targets, targets): est_mid, est_side = self.to_mid_side(est_targets) tgt_mid, tgt_side = self.to_mid_side(targets) total_loss = 0.0 total_mag_loss = 0.0 total_logmag_loss = 0.0 for cur_scale in self.multiscales: est_mid_mag = cur_scale['front_end'](est_mid, mode=["mag"]) est_side_mag = cur_scale['front_end'](est_side, mode=["mag"]) tgt_mid_mag = cur_scale['front_end'](tgt_mid, mode=["mag"]) tgt_side_mag = cur_scale['front_end'](tgt_side, mode=["mag"]) mag_loss = self.mid_weight*self.magnitude_loss(est_mid_mag, tgt_mid_mag) + \ (1-self.mid_weight)*self.magnitude_loss(est_side_mag, tgt_side_mag) logmag_loss = self.mid_weight*self.log_magnitude_loss(est_mid_mag, tgt_mid_mag) + \ (1-self.mid_weight)*self.log_magnitude_loss(est_side_mag, tgt_side_mag) total_mag_loss += mag_loss total_logmag_loss += logmag_loss # return total_loss return (1-self.logmag_weight)*total_mag_loss + \ (self.logmag_weight)*total_logmag_loss def to_mid_side(self, stereo_in): mid = stereo_in[:,0] + stereo_in[:,1] side = stereo_in[:,0] - stereo_in[:,1] return mid, side def magnitude_loss(self, est_mag_spec, tgt_mag_spec): return torch.norm(self.objective_l1(est_mag_spec, tgt_mag_spec)) def log_magnitude_loss(self, est_mag_spec, tgt_mag_spec): est_log_mag_spec = torch.log10(est_mag_spec+self.eps) tgt_log_mag_spec = torch.log10(tgt_mag_spec+self.eps) return self.objective_l2(est_log_mag_spec, tgt_log_mag_spec) # Class of available loss functions class Loss: def __init__(self, args, reduce=True): device = torch.device("cpu") if torch.cuda.is_available(): device = torch.device(f"cuda:{args.gpu}") self.l1 = nn.L1Loss(reduce=reduce) self.mse = nn.MSELoss(reduce=reduce) self.ce = nn.CrossEntropyLoss() self.triplet = nn.TripletMarginLoss(margin=1., p=2) self.cos = nn.CosineSimilarity(eps=args.eps) self.cosemb = nn.CosineEmbeddingLoss() self.multi_scale_spectral_midside = MultiScale_Spectral_Loss_MidSide_DDSP(mode='midside', eps=args.eps, device=device) self.multi_scale_spectral_ori = MultiScale_Spectral_Loss_MidSide_DDSP(mode='ori', eps=args.eps, device=device) self.gain = RMSLoss(reduce=reduce) self.infonce = infoNCE # perceptual weighting with mel scaled spectrograms self.mrs_mel_perceptual = auraloss.freq.MultiResolutionSTFTLoss( fft_sizes=[1024, 2048, 8192], hop_sizes=[256, 512, 2048], win_lengths=[1024, 2048, 8192], scale="mel", n_bins=128, sample_rate=args.sample_rate, perceptual_weighting=True, ) """ Audio Feature Loss implementation copied from https://github.com/sai-soum/Diff-MST/blob/main/mst/loss.py """ import librosa from typing import List from modules.filter import barkscale_fbanks def compute_mid_side(x: torch.Tensor): x_mid = x[:, 0, :] + x[:, 1, :] x_side = x[:, 0, :] - x[:, 1, :] return x_mid, x_side def compute_melspectrum( x: torch.Tensor, sample_rate: int = 44100, fft_size: int = 32768, n_bins: int = 128, **kwargs, ): """Compute mel-spectrogram. Args: x: (bs, 2, seq_len) sample_rate: sample rate of audio fft_size: size of fft n_bins: number of mel bins Returns: X: (bs, n_bins) """ fb = librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=n_bins) fb = torch.tensor(fb).unsqueeze(0).type_as(x) x = x.mean(dim=1, keepdim=True) X = torch.fft.rfft(x, n=fft_size, dim=-1) X = torch.abs(X) X = torch.mean(X, dim=1, keepdim=True) # take mean over time X = X.permute(0, 2, 1) # swap time and freq dims X = torch.matmul(fb, X) X = torch.log(X + 1e-8) return X def compute_barkspectrum( x: torch.Tensor, fft_size: int = 32768, n_bands: int = 24, sample_rate: int = 44100, f_min: float = 20.0, f_max: float = 20000.0, mode: str = "mid-side", **kwargs, ): """Compute bark-spectrogram. Args: x: (bs, 2, seq_len) fft_size: size of fft n_bands: number of bark bins sample_rate: sample rate of audio f_min: minimum frequency f_max: maximum frequency mode: "mono", "stereo", or "mid-side" Returns: X: (bs, 24) """ # compute filterbank fb = barkscale_fbanks((fft_size // 2) + 1, f_min, f_max, n_bands, sample_rate) fb = fb.unsqueeze(0).type_as(x) fb = fb.permute(0, 2, 1) if mode == "mono": x = x.mean(dim=1) # average over channels signals = [x] elif mode == "stereo": signals = [x[:, 0, :], x[:, 1, :]] elif mode == "mid-side": x_mid = x[:, 0, :] + x[:, 1, :] x_side = x[:, 0, :] - x[:, 1, :] signals = [x_mid, x_side] else: raise ValueError(f"Invalid mode {mode}") outputs = [] for signal in signals: X = torch.stft( signal, n_fft=fft_size, hop_length=fft_size // 4, return_complex=True, window=torch.hann_window(fft_size).to(x.device), ) # compute stft X = torch.abs(X) # take magnitude X = torch.mean(X, dim=-1, keepdim=True) # take mean over time # X = X.permute(0, 2, 1) # swap time and freq dims X = torch.matmul(fb, X) # apply filterbank X = torch.log(X + 1e-8) # X = torch.cat([X, X_log], dim=-1) outputs.append(X) # stack into tensor X = torch.cat(outputs, dim=-1) return X def compute_rms(x: torch.Tensor, **kwargs): """Compute root mean square energy. Args: x: (bs, 1, seq_len) Returns: rms: (bs, ) """ rms = torch.sqrt(torch.mean(x**2, dim=-1).clamp(min=1e-8)) return rms def compute_crest_factor(x: torch.Tensor, **kwargs): """Compute crest factor as ratio of peak to rms energy in dB. Args: x: (bs, 2, seq_len) """ num = torch.max(torch.abs(x), dim=-1)[0] den = compute_rms(x).clamp(min=1e-8) cf = 20 * torch.log10((num / den).clamp(min=1e-8)) return cf def compute_stereo_width(x: torch.Tensor, **kwargs): """Compute stereo width as ratio of energy in sum and difference signals. Args: x: (bs, 2, seq_len) """ bs, chs, seq_len = x.size() assert chs == 2, "Input must be stereo" # compute sum and diff of stereo channels x_sum = x[:, 0, :] + x[:, 1, :] x_diff = x[:, 0, :] - x[:, 1, :] # compute power of sum and diff sum_energy = torch.mean(x_sum**2, dim=-1) diff_energy = torch.mean(x_diff**2, dim=-1) # compute stereo width as ratio stereo_width = diff_energy / sum_energy.clamp(min=1e-8) return stereo_width def compute_stereo_imbalance(x: torch.Tensor, **kwargs): """Compute stereo imbalance as ratio of energy in left and right channels. Args: x: (bs, 2, seq_len) Returns: stereo_imbalance: (bs, ) """ left_energy = torch.mean(x[:, 0, :] ** 2, dim=-1) right_energy = torch.mean(x[:, 1, :] ** 2, dim=-1) stereo_imbalance = (right_energy - left_energy) / ( right_energy + left_energy ).clamp(min=1e-8) return stereo_imbalance class AudioFeatureLoss(torch.nn.Module): def __init__( self, weights: List[float], sample_rate: int, stem_separation: bool = False, use_clap: bool = False, ) -> None: """Compute loss using a set of differentiable audio features. Args: weights: weights for each feature sample_rate: sample rate of audio stem_separation: whether to compute loss on stems or mix Based on features proposed in: Man, B. D., et al. "An analysis and evaluation of audio features for multitrack music mixtures." (2014). """ super().__init__() self.weights = weights self.sample_rate = sample_rate self.stem_separation = stem_separation self.sources_list = ["mix"] self.source_weights = [1.0] self.use_clap = use_clap self.transforms = [ compute_rms, compute_crest_factor, compute_stereo_width, compute_stereo_imbalance, compute_barkspectrum, ] assert len(self.transforms) == len(weights) def forward(self, input: torch.Tensor, target: torch.Tensor): losses = {} # reshape for example stem dim input_stems = input.unsqueeze(1) target_stems = target.unsqueeze(1) n_stems = input_stems.shape[1] # iterate over each stem compute loss for each transform for stem_idx in range(n_stems): input_stem = input_stems[:, stem_idx, ...] target_stem = target_stems[:, stem_idx, ...] for transform, weight in zip(self.transforms, self.weights): transform_name = "_".join(transform.__name__.split("_")[1:]) key = f"{self.sources_list[stem_idx]}-{transform_name}" input_transform = transform(input_stem, sample_rate=self.sample_rate) target_transform = transform(target_stem, sample_rate=self.sample_rate) val = torch.nn.functional.mse_loss(input_transform, target_transform) losses[key] = weight * val * self.source_weights[stem_idx] return losses