Spaces:
Running
on
Zero
Running
on
Zero
import typing
from typing import List

import numpy as np
import torch
from torch import nn

from .. import AudioSignal
from .. import STFTParams
class MultiScaleSTFTLoss(nn.Module):
    """Computes the multi-scale STFT loss from [1].

    For every window length, both signals are transformed with an STFT
    (hop length = window length // 4) and compared twice with ``loss_fn``:
    once on the clamped, powered, ``log10`` magnitudes (scaled by
    ``log_weight``) and once on the raw magnitudes (scaled by
    ``mag_weight``). The terms from all scales are summed.

    Parameters
    ----------
    window_lengths : List[int], optional
        Length of each window of each STFT, by default [2048, 512]
    loss_fn : typing.Callable, optional
        How to compare each loss, by default nn.L1Loss(). Note that the
        default instance is created once, when the class is defined, and is
        shared by every loss object that uses the default.
    clamp_eps : float, optional
        Lower bound applied to the magnitude before it is raised to ``pow``
        and passed to ``log10``, by default 1e-5
    mag_weight : float, optional
        Weight of raw magnitude portion of loss, by default 1.0
    log_weight : float, optional
        Weight of log magnitude portion of loss, by default 1.0
    pow : float, optional
        Power to raise magnitude to before taking log, by default 2.0
    weight : float, optional
        Weight of this loss, by default 1.0. It is stored as ``self.weight``
        for the caller to apply; ``forward`` does not multiply by it.
    match_stride : bool, optional
        Whether to match the stride of convolutional layers, by default False.
        Passed to ``AudioSignal.stft`` for every scale.
    window_type : str, optional
        STFT window type passed to ``AudioSignal.stft``, by default None

    References
    ----------
    1.  Engel, Jesse, Chenjie Gu, and Adam Roberts.
        "DDSP: Differentiable Digital Signal Processing."
        International Conference on Learning Representations. 2019.
    """

    def __init__(
        self,
        window_lengths: List[int] = [2048, 512],
        loss_fn: typing.Callable = nn.L1Loss(),
        clamp_eps: float = 1e-5,
        mag_weight: float = 1.0,
        log_weight: float = 1.0,
        pow: float = 2.0,
        weight: float = 1.0,
        match_stride: bool = False,
        window_type: typing.Optional[str] = None,
    ):
        super().__init__()
        # One STFTParams per scale; the hop is fixed at a quarter window.
        self.stft_params = [
            STFTParams(
                window_length=w,
                hop_length=w // 4,
                match_stride=match_stride,
                window_type=window_type,
            )
            for w in window_lengths
        ]
        self.loss_fn = loss_fn
        self.log_weight = log_weight
        self.mag_weight = mag_weight
        self.clamp_eps = clamp_eps
        self.weight = weight
        self.pow = pow

    def forward(self, x: AudioSignal, y: AudioSignal):
        """Computes multi-scale STFT between an estimate and a reference
        signal.

        Parameters
        ----------
        x : AudioSignal
            Estimate signal
        y : AudioSignal
            Reference signal

        Returns
        -------
        torch.Tensor
            Multi-scale STFT loss.
        """
        loss = 0.0
        for s in self.stft_params:
            # The return value of stft() is unused; the magnitude is read
            # from the signal afterwards (presumably the STFT is cached on
            # the signal). match_stride is passed explicitly so the
            # constructor's setting actually reaches the transform.
            x.stft(
                s.window_length,
                s.hop_length,
                s.window_type,
                match_stride=s.match_stride,
            )
            y.stft(
                s.window_length,
                s.hop_length,
                s.window_type,
                match_stride=s.match_stride,
            )
            # Read each magnitude once and reuse it for both loss terms.
            x_mag = x.magnitude
            y_mag = y.magnitude
            loss += self.log_weight * self.loss_fn(
                x_mag.clamp(self.clamp_eps).pow(self.pow).log10(),
                y_mag.clamp(self.clamp_eps).pow(self.pow).log10(),
            )
            loss += self.mag_weight * self.loss_fn(x_mag, y_mag)
        return loss
class MelSpectrogramLoss(nn.Module):
    """Compute distance between mel spectrograms. Can be used
    in a multi-scale way.

    For each scale, both signals are converted to mel spectrograms and
    compared twice with ``loss_fn``: once on the clamped, powered, ``log10``
    values (scaled by ``log_weight``) and once on the raw values (scaled by
    ``mag_weight``). The terms from all scales are summed.

    Scales are formed by zipping ``n_mels``, ``mel_fmin``, ``mel_fmax`` and
    ``window_lengths`` together. The number of scales is therefore the
    length of the *shortest* of these four lists, and any extra entries are
    silently ignored. Keep the four lists the same length.

    Parameters
    ----------
    n_mels : List[int], optional
        Number of mels per STFT, by default [150, 80]
    window_lengths : List[int], optional
        Length of each window of each STFT, by default [2048, 512]
    loss_fn : typing.Callable, optional
        How to compare each loss, by default nn.L1Loss(). Note that the
        default instance is created once, when the class is defined, and is
        shared by every loss object that uses the default.
    clamp_eps : float, optional
        Lower bound applied to the mel values before they are raised to
        ``pow`` and passed to ``log10``, by default 1e-5
    mag_weight : float, optional
        Weight of raw magnitude portion of loss, by default 1.0
    log_weight : float, optional
        Weight of log magnitude portion of loss, by default 1.0
    pow : float, optional
        Power to raise magnitude to before taking log, by default 2.0
    weight : float, optional
        Weight of this loss, by default 1.0. It is stored as ``self.weight``
        for the caller to apply; ``forward`` does not multiply by it.
    match_stride : bool, optional
        Whether to match the stride of convolutional layers, by default False.
        Passed along with the other STFT settings for every scale.
    mel_fmin : List[float], optional
        Per-scale ``mel_fmin`` passed to ``AudioSignal.mel_spectrogram``,
        by default [0.0, 0.0]
    mel_fmax : List[Optional[float]], optional
        Per-scale ``mel_fmax`` passed to ``AudioSignal.mel_spectrogram``,
        by default [None, None]
    window_type : str, optional
        STFT window type passed along with the other STFT settings,
        by default None
    """

    def __init__(
        self,
        n_mels: List[int] = [150, 80],
        window_lengths: List[int] = [2048, 512],
        loss_fn: typing.Callable = nn.L1Loss(),
        clamp_eps: float = 1e-5,
        mag_weight: float = 1.0,
        log_weight: float = 1.0,
        pow: float = 2.0,
        weight: float = 1.0,
        match_stride: bool = False,
        mel_fmin: List[float] = [0.0, 0.0],
        mel_fmax: List[typing.Optional[float]] = [None, None],
        window_type: typing.Optional[str] = None,
    ):
        super().__init__()
        # One STFTParams per scale; the hop is fixed at a quarter window.
        self.stft_params = [
            STFTParams(
                window_length=w,
                hop_length=w // 4,
                match_stride=match_stride,
                window_type=window_type,
            )
            for w in window_lengths
        ]
        self.n_mels = n_mels
        self.loss_fn = loss_fn
        self.clamp_eps = clamp_eps
        self.log_weight = log_weight
        self.mag_weight = mag_weight
        self.weight = weight
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.pow = pow

    def forward(self, x: AudioSignal, y: AudioSignal):
        """Computes mel loss between an estimate and a reference
        signal.

        Parameters
        ----------
        x : AudioSignal
            Estimate signal
        y : AudioSignal
            Reference signal

        Returns
        -------
        torch.Tensor
            Mel loss.
        """
        loss = 0.0
        # zip stops at the shortest list (see class docstring).
        for n_mels, fmin, fmax, s in zip(
            self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
        ):
            # STFT settings for this scale. match_stride is included so the
            # constructor's setting takes effect; this assumes
            # mel_spectrogram forwards these kwargs to the STFT (TODO confirm).
            kwargs = {
                "window_length": s.window_length,
                "hop_length": s.hop_length,
                "window_type": s.window_type,
                "match_stride": s.match_stride,
            }
            x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
            y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
            loss += self.log_weight * self.loss_fn(
                x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
                y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
            )
            loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
        return loss
class PhaseLoss(nn.Module):
    """Magnitude-weighted difference between phase spectrograms.

    The estimate and reference phases are differenced, and the difference is
    wrapped back into [-pi, pi]. It is then weighted by the reference
    magnitude, min-max scaled to [0, 1], so that phase errors in loud bins
    count more. The result is squared and averaged.

    Parameters
    ----------
    window_length : int, optional
        Length of STFT window, by default 2048
    hop_length : int, optional
        Hop length of STFT window, by default 512
    weight : float, optional
        Weight of loss, by default 1.0. It is stored as ``self.weight`` for
        the caller to apply; ``forward`` does not multiply by it.
    """

    def __init__(
        self, window_length: int = 2048, hop_length: int = 512, weight: float = 1.0
    ):
        super().__init__()
        self.weight = weight
        self.stft_params = STFTParams(window_length, hop_length)

    def forward(self, x: AudioSignal, y: AudioSignal):
        """Computes phase loss between an estimate and a reference
        signal.

        Parameters
        ----------
        x : AudioSignal
            Estimate signal
        y : AudioSignal
            Reference signal

        Returns
        -------
        torch.Tensor
            Phase loss. It is 0 when the reference magnitude is constant
            (e.g. silence), because every weight is then 0.
        """
        s = self.stft_params
        x.stft(s.window_length, s.hop_length, s.window_type)
        y.stft(s.window_length, s.hop_length, s.window_type)

        # Take circular difference. This assumes each phase lies in
        # [-pi, pi], so one +/- 2*pi correction per element is enough.
        diff = x.phase - y.phase
        diff[diff < -np.pi] += 2 * np.pi
        # Subtract (not add) 2*pi to bring values above pi back into range.
        diff[diff > np.pi] -= 2 * np.pi

        # Scale the true (reference) magnitude to weights in [0, 1]. Using
        # the reference rather than the estimate keeps the weights out of the
        # optimisation, so the model cannot lower the loss by shrinking its
        # own magnitude.
        ref_mag = y.magnitude
        ref_min, ref_max = ref_mag.min(), ref_mag.max()
        # If the reference magnitude is constant, max == min would give
        # 0/0 = NaN. With the clamp, the numerator is then all zeros and the
        # weights become 0.
        span = (ref_max - ref_min).clamp(min=torch.finfo(ref_mag.dtype).tiny)
        weights = (ref_mag - ref_min) / span

        # Take weighted mean of all phase errors
        loss = ((weights * diff) ** 2).mean()
        return loss