File size: 5,906 Bytes
45916af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import librosa
import pytorch_lightning as pl
import torch
from auraloss.freq import STFTLoss, MultiResolutionSTFTLoss, apply_reduction, SpectralConvergenceLoss, STFTMagnitudeLoss
from config import CONFIG
class STFTLossDDP(STFTLoss):
    """Single-resolution STFT loss computed on power-law-compressed magnitudes.

    Variant of auraloss ``STFTLoss`` that deliberately skips the parent
    constructor (``super(STFTLoss, self).__init__()`` initialises only the
    ``nn.Module`` machinery) and sets by hand every attribute this class and
    the inherited ``stft`` helper read. ``forward`` raises both magnitude
    spectrograms to ``compression`` (0.3 by default) before computing the
    weighted spectral-convergence / log-magnitude / linear-magnitude terms.

    Note: ``w_phs`` is stored but no phase term is added to the loss here.
    """

    def __init__(self,
                 fft_size=1024,
                 hop_size=256,
                 win_length=1024,
                 window="hann_window",
                 w_sc=1.0,
                 w_log_mag=1.0,
                 w_lin_mag=0.0,
                 w_phs=0.0,
                 sample_rate=None,
                 scale=None,
                 n_bins=None,
                 scale_invariance=False,
                 eps=1e-8,
                 output="loss",
                 reduction="mean",
                 device=None,
                 compression=0.3):
        """Configure the loss.

        Args:
            fft_size, hop_size, win_length: STFT parameters.
            window: name of a ``torch`` window function, e.g. ``"hann_window"``.
            w_sc, w_log_mag, w_lin_mag: weights of the spectral-convergence,
                log-magnitude and linear-magnitude terms (a zero weight skips
                computing that term).
            w_phs: stored as an attribute; no phase term is computed here.
            sample_rate: required when ``scale`` is ``"mel"`` or ``"chroma"``.
            scale: ``None``, ``"mel"`` or ``"chroma"`` — projects magnitudes
                through the matching librosa filterbank.
            n_bins: number of mel / chroma bins (must not exceed ``fft_size``).
            scale_invariance: if true, rescale ``y`` magnitudes per example by
                the least-squares gain that best matches ``x``.
            eps, output, reduction: stored for the inherited helpers /
                ``apply_reduction``.
            device: if given together with ``scale``, the filterbank is moved
                there eagerly (it is also moved lazily on every call).
            compression: exponent applied to magnitudes in ``forward``;
                ``None`` disables compression.

        Raises:
            ValueError: if ``scale`` is not ``None``, ``"mel"`` or ``"chroma"``.
        """
        # Intentionally bypass STFTLoss.__init__; initialise nn.Module only.
        super(STFTLoss, self).__init__()
        self.fft_size = fft_size
        self.hop_size = hop_size
        self.win_length = win_length
        # Plain tensor attribute (not a registered buffer); moved to the input
        # device on every call in ``compressed_loss``.
        self.window = getattr(torch, window)(win_length)
        self.w_sc = w_sc
        self.w_log_mag = w_log_mag
        self.w_lin_mag = w_lin_mag
        self.w_phs = w_phs
        self.sample_rate = sample_rate
        self.scale = scale
        self.n_bins = n_bins
        self.scale_invariance = scale_invariance
        self.eps = eps
        self.output = output
        self.reduction = reduction
        self.device = device
        self.compression = compression
        self.spectralconv = SpectralConvergenceLoss()
        self.logstft = STFTMagnitudeLoss(log=True, reduction=reduction)
        self.linstft = STFTMagnitudeLoss(log=False, reduction=reduction)

        # Optional filterbank mapping linear-frequency bins to mel/chroma bins.
        # librosa >= 0.10 makes ``sr``/``n_fft`` keyword-only, so pass them by
        # name (also valid on older librosa releases).
        if self.scale == "mel":
            assert sample_rate is not None, "sample_rate is required for the mel scale"
            assert n_bins is not None and n_bins <= fft_size, \
                "n_bins must be set and must not exceed fft_size"
            fb = librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=n_bins)
            self.fb = torch.tensor(fb).unsqueeze(0)
        elif self.scale == "chroma":
            assert sample_rate is not None, "sample_rate is required for the chroma scale"
            assert n_bins is not None and n_bins <= fft_size, \
                "n_bins must be set and must not exceed fft_size"
            fb = librosa.filters.chroma(sr=sample_rate, n_fft=fft_size, n_chroma=n_bins)
            self.fb = torch.tensor(fb).unsqueeze(0)
        elif self.scale is not None:
            # Previously this left ``self.fb`` undefined and failed later with
            # an AttributeError; fail fast with a clear message instead.
            raise ValueError(
                f"Unsupported scale {scale!r}; expected None, 'mel' or 'chroma'"
            )

        if scale is not None and device is not None:
            self.fb = self.fb.to(self.device)  # move filterbank to device

    def compressed_loss(self, x, y, alpha=None):
        """Return the weighted STFT loss between ``x`` and ``y``.

        Args:
            x: predicted signal; all leading dims are flattened into one batch
                dim via ``view(-1, x.size(-1))``, so the last dim is time.
            y: target signal with the same layout as ``x``.
            alpha: optional exponent applied to both magnitude spectrograms
                (power-law compression); ``None`` leaves them unchanged.

        Returns:
            The weighted sum of the enabled loss terms, passed through
            ``apply_reduction`` with ``self.reduction``.
        """
        # The window is not a buffer, so it does not follow ``.to(device)``.
        self.window = self.window.to(x.device)
        x_mag, _ = self.stft(x.view(-1, x.size(-1)))
        y_mag, _ = self.stft(y.view(-1, y.size(-1)))

        if alpha is not None:
            x_mag = x_mag ** alpha
            y_mag = y_mag ** alpha

        # Project onto mel / chroma bins when a filterbank was configured.
        if self.scale is not None:
            x_mag = torch.matmul(self.fb.to(x_mag.device), x_mag)
            y_mag = torch.matmul(self.fb.to(y_mag.device), y_mag)

        if self.scale_invariance:
            # Per-example least-squares gain, shape (batch,). It must be
            # expanded over BOTH trailing axes (freq, frames); a single
            # ``unsqueeze(-1)`` broadcasts against the wrong axis when
            # batch > 1.
            gain = (x_mag * y_mag).sum([-2, -1]) / (y_mag ** 2).sum([-2, -1])
            y_mag = y_mag * gain[..., None, None]

        # Terms with a zero weight are skipped entirely.
        sc_loss = self.spectralconv(x_mag, y_mag) if self.w_sc else 0.0
        mag_loss = self.logstft(x_mag, y_mag) if self.w_log_mag else 0.0
        lin_loss = self.linstft(x_mag, y_mag) if self.w_lin_mag else 0.0

        loss = (self.w_sc * sc_loss) + (self.w_log_mag * mag_loss) + (self.w_lin_mag * lin_loss)
        return apply_reduction(loss, reduction=self.reduction)

    def forward(self, x, y):
        """Return the loss on magnitudes compressed by ``self.compression``."""
        return self.compressed_loss(x, y, self.compression)
class MRSTFTLossDDP(MultiResolutionSTFTLoss):
    """Multi-resolution STFT loss built from one ``STFTLossDDP`` per resolution.

    Skips ``MultiResolutionSTFTLoss.__init__`` and populates ``stft_losses``
    directly. Resolution ``i`` uses ``fft_sizes[i]``, ``hop_sizes[i]`` and
    ``win_lengths[i]``; every other setting is shared. Extra keyword
    arguments (e.g. ``device``) are forwarded unchanged to each
    ``STFTLossDDP``.
    """

    def __init__(self,
                 fft_sizes=(1024, 2048, 512),
                 hop_sizes=(120, 240, 50),
                 win_lengths=(600, 1200, 240),
                 window="hann_window",
                 w_sc=1.0,
                 w_log_mag=1.0,
                 w_lin_mag=0.0,
                 w_phs=0.0,
                 sample_rate=None,
                 scale=None,
                 n_bins=None,
                 scale_invariance=False,
                 **kwargs):
        # Initialise only the nn.Module machinery; the parent constructor is
        # bypassed on purpose.
        super(MultiResolutionSTFTLoss, self).__init__()
        assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)  # must define all

        resolutions = zip(fft_sizes, hop_sizes, win_lengths)
        self.stft_losses = torch.nn.ModuleList(
            STFTLossDDP(
                fft_size=n_fft,
                hop_size=hop,
                win_length=win,
                window=window,
                w_sc=w_sc,
                w_log_mag=w_log_mag,
                w_lin_mag=w_lin_mag,
                w_phs=w_phs,
                sample_rate=sample_rate,
                scale=scale,
                n_bins=n_bins,
                scale_invariance=scale_invariance,
                **kwargs,
            )
            for n_fft, hop, win in resolutions
        )
class Loss(pl.LightningModule):
    """Loss that compares two spectrogram tensors in the waveform domain.

    Each input is converted to a complex spectrogram, inverted with
    ``torch.istft`` (square-root Hann window, sizes taken from
    ``CONFIG.DATA``), and the two waveforms are scored with
    ``MRSTFTLossDDP`` configured with ``w_log_mag=0.0`` and ``w_lin_mag=1.0``.
    """

    def __init__(self):
        super().__init__()
        self.stft_loss = MRSTFTLossDDP(
            sample_rate=CONFIG.DATA.sr,
            device="cpu",
            w_log_mag=0.0,
            w_lin_mag=1.0,
        )
        # Square-root Hann synthesis window; kept as a plain attribute and
        # moved to the input's device at call time.
        self.window = torch.sqrt(torch.hann_window(CONFIG.DATA.window_size))

    def _to_waveform(self, spec):
        """Invert one real/imag-stacked spectrogram tensor to a waveform.

        ``spec`` is permuted with ``(0, 2, 3, 1)`` and viewed as complex, so
        its dim 1 must have size 2 (real, imag). Presumably the layout is
        (batch, 2, freq, frames) — TODO confirm against the model output.
        """
        complex_spec = torch.view_as_complex(spec.permute(0, 2, 3, 1).contiguous())
        return torch.istft(
            complex_spec,
            CONFIG.DATA.window_size,
            CONFIG.DATA.stride,
            window=self.window.to(spec.device),
        )

    def forward(self, x, y):
        """Return the multi-resolution STFT loss between the waveforms of x and y."""
        return self.stft_loss(self._to_waveform(x), self._to_waveform(y))
|