Spaces:
Build error
Build error
import torch | |
import auraloss | |
import resampy | |
import torchaudio | |
from pesq import pesq | |
import pyloudnorm as pyln | |
def crest_factor(x, eps=1e-8):
    """Compute the crest factor (peak-to-RMS ratio) of a waveform in dB.

    Args:
        x: Floating-point tensor of samples. The last dimension is reduced.
        eps: Lower bound applied to both the peak and the RMS value so that
            a silent signal yields 0 dB rather than ``-inf``/NaN.

    Returns:
        Tensor with the last dimension of ``x`` removed, holding
        ``20 * log10(peak / rms)``.
    """
    peak, _ = x.abs().max(dim=-1)
    rms = torch.sqrt((x ** 2).mean(dim=-1))
    ratio = peak.clamp(min=eps) / rms.clamp(min=eps)
    # Base-10 logarithm: the factor 20 only converts an amplitude ratio
    # to decibels when paired with log10, not the natural logarithm.
    return 20 * torch.log10(ratio)
def rms_energy(x, eps=1e-8):
    """Compute the RMS level of a waveform in dB.

    Args:
        x: Floating-point tensor of samples. The last dimension is reduced.
        eps: Lower bound applied to the RMS value before taking the
            logarithm, so silence maps to a finite floor instead of ``-inf``.

    Returns:
        Tensor with the last dimension of ``x`` removed, holding
        ``20 * log10(rms)``.
    """
    rms = torch.sqrt((x ** 2).mean(dim=-1))
    # Base-10 logarithm: 20 * log10 is the amplitude-to-decibel conversion;
    # the natural logarithm would scale the result by ln(10).
    return 20 * torch.log10(rms.clamp(min=eps))
def spectral_centroid(x):
    """Compute the normalized spectral centroid of a waveform.

    The magnitude spectrum from ``torch.fft.rfft`` is normalized to sum to
    one over *all* elements of ``x`` and used to weight a frequency axis
    that runs linearly from 0 (DC) to 1 (last rFFT bin).

    See: https://gist.github.com/endolith/359724

    Args:
        x: Tensor of samples; the FFT is taken over the last dimension.

    Returns:
        Zero-dimensional tensor in ``[0, 1]``. An input whose spectrum sums
        to zero (e.g. all zeros) produces NaN because of the 0/0 division.
    """
    spectrum = torch.fft.rfft(x).abs()
    normalized_spectrum = spectrum / spectrum.sum()
    # Create the frequency axis on the same device/dtype as the spectrum so
    # inputs that do not live on the default device can be processed.
    normalized_frequencies = torch.linspace(
        0,
        1,
        spectrum.shape[-1],
        dtype=spectrum.dtype,
        device=spectrum.device,
    )
    return torch.sum(normalized_frequencies * normalized_spectrum)
def loudness(x, sample_rate):
    """Compute the integrated loudness in dB LUFS of waveform.

    Args:
        x: 2-D tensor indexed as (channels, samples). An input with fewer
            than two channels is repeated along the channel axis to form
            a two-channel signal before measurement.
        sample_rate: Sample rate passed to ``pyln.Meter``.

    Returns:
        Tensor wrapping the value returned by
        ``meter.integrated_loudness``.
    """
    meter = pyln.Meter(sample_rate)
    # add stereo dim if needed
    if x.shape[0] < 2:
        x = x.repeat(2, 1)
    # detach()/cpu() let tensors that track gradients or live on an
    # accelerator be converted; a bare .numpy() raises for those.
    samples = x.detach().cpu().permute(1, 0).numpy()
    return torch.tensor(meter.integrated_loudness(samples))
class MelSpectralDistance(torch.nn.Module):
    """Mel-spectrogram distance backed by ``auraloss.freq.MelSTFTLoss``.

    The FFT size, hop size and window length are all set to ``length``.
    The spectral-convergence weight is 0 while the log-magnitude and
    linear-magnitude weights are both 1.
    """

    def __init__(self, sample_rate, length=65536):
        super().__init__()
        # I think scale invariance may not work well,
        # since aspects of the phase may be considered?
        stft_options = dict(
            fft_size=length,
            hop_size=length,
            win_length=length,
            w_sc=0,
            w_log_mag=1,
            w_lin_mag=1,
            n_mels=128,
            scale_invariance=False,
        )
        self.error = auraloss.freq.MelSTFTLoss(sample_rate, **stft_options)

    def forward(self, input, target):
        """Return the loss that ``self.error`` computes for the pair."""
        distance = self.error(input, target)
        return distance
class PESQ(torch.nn.Module):
    """Wide-band PESQ score between a processed signal and its reference.

    Args:
        sample_rate: Sample rate of the tensors passed to ``forward``.
            Signals at any other rate than 16 kHz are resampled with
            ``resampy`` before scoring.
    """

    def __init__(self, sample_rate):
        super().__init__()
        self.sample_rate = sample_rate

    def forward(self, input, target):
        """Return the value of ``pesq(16000, target, input, "wb")``.

        Args:
            input: Tensor holding the degraded/processed signal.
            target: Tensor holding the reference signal.
        """
        # Flatten and convert on every path. Previously this only happened
        # inside the resampling branch, so 16 kHz inputs reached pesq() as
        # unflattened torch tensors. reshape() also copes with
        # non-contiguous tensors, which view() rejects.
        target = target.detach().cpu().reshape(-1).numpy()
        input = input.detach().cpu().reshape(-1).numpy()

        if self.sample_rate != 16000:
            target = resampy.resample(target, self.sample_rate, 16000)
            input = resampy.resample(input, self.sample_rate, 16000)

        return pesq(16000, target, input, "wb")
class CrestFactorError(torch.nn.Module):
    """Mean absolute difference between the crest factors of two signals."""

    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        """Return the L1 distance of ``crest_factor`` values as a number."""
        estimate = crest_factor(input)
        reference = crest_factor(target)
        error = torch.nn.functional.l1_loss(estimate, reference)
        return error.item()
class RMSEnergyError(torch.nn.Module):
    """Mean absolute difference between the RMS energies of two signals."""

    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        """Return the L1 distance of ``rms_energy`` values as a number."""
        estimate = rms_energy(input)
        reference = rms_energy(target)
        error = torch.nn.functional.l1_loss(estimate, reference)
        return error.item()
class SpectralCentroidError(torch.nn.Module):
    """Absolute difference between the mean spectral centroids of two signals.

    The centroid is computed by ``torchaudio.transforms.SpectralCentroid``
    built from ``sample_rate``, ``n_fft`` and ``hop_length``.
    """

    def __init__(self, sample_rate, n_fft=2048, hop_length=512):
        super().__init__()
        centroid_transform = torchaudio.transforms.SpectralCentroid(
            sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
        )
        self.spectral_centroid = centroid_transform

    def forward(self, input, target):
        """Return the L1 distance between the two mean centroids as a number."""
        # NOTE(review): the 1e-16 offset presumably keeps exactly-zero
        # samples from producing an undefined centroid; confirm.
        input_centroid = self.spectral_centroid(input + 1e-16).mean()
        target_centroid = self.spectral_centroid(target + 1e-16).mean()
        error = torch.nn.functional.l1_loss(input_centroid, target_centroid)
        return error.item()
class LoudnessError(torch.nn.Module):
    """Absolute difference in loudness between two signals.

    Args:
        sample_rate: Sample rate forwarded to ``loudness``.
        peak_normalize: When true, each signal is divided by its own
            maximum absolute value before loudness is measured.
    """

    def __init__(self, sample_rate: int, peak_normalize: bool = False):
        super().__init__()
        self.sample_rate = sample_rate
        self.peak_normalize = peak_normalize

    def forward(self, input, target):
        """Return the L1 distance between the two loudness values as a number."""
        x, y = input, target
        if self.peak_normalize:
            # peak normalize
            x = input / input.abs().max()
            y = target / target.abs().max()

        # Each signal is flattened into a single row before measurement.
        loudness_x = loudness(x.view(1, -1), self.sample_rate)
        loudness_y = loudness(y.view(1, -1), self.sample_rate)
        return torch.nn.functional.l1_loss(loudness_x, loudness_y).item()