Spaces:
Sleeping
Sleeping
import numpy as np | |
import torch | |
from torch import nn | |
from torchaudio.transforms import MelSpectrogram as TorchMelSpectrogram | |
from .hparams import HParams | |
class MelSpectrogram(nn.Module):
    """Normalized log-magnitude mel spectrogram (torch port of Resemble's mel extraction).

    Pipeline: optional pre-emphasis -> magnitude mel spectrogram (torchaudio,
    ``power=1``, Slaney filterbank and mel scale) -> dB with a floor at
    ``hp.stft_magnitude_min`` -> affine normalization that maps the floor to 0
    and ``headroom_db`` dB to 1.

    Note that the values are NOT identical to librosa's implementation due to
    floating point precision.
    """

    def __init__(self, hp: HParams):
        """Build the mel front-end from hyper-parameters.

        Args:
            hp: Hyper-parameters. Reads ``wav_rate``, ``n_fft``, ``win_size``,
                ``hop_size``, ``num_mels``, ``stft_magnitude_min`` and
                ``preemphasis``.
        """
        super().__init__()
        self.hp = hp
        self.melspec = TorchMelSpectrogram(
            hp.wav_rate,
            n_fft=hp.n_fft,
            win_length=hp.win_size,
            hop_length=hp.hop_size,
            f_min=0,
            f_max=hp.wav_rate // 2,
            n_mels=hp.num_mels,
            power=1,  # magnitude (not power) spectrogram
            normalized=False,
            # NOTE: Following librosa's default.
            pad_mode="constant",
            norm="slaney",
            mel_scale="slaney",
        )
        # Registered buffer, so it is part of state_dict. NOTE(review): nothing in
        # this class reads it -- _amp_to_db floors with hp.stft_magnitude_min.
        self.register_buffer(
            "stft_magnitude_min",
            torch.tensor([hp.stft_magnitude_min], dtype=torch.float32),
        )
        # dB level of the magnitude floor; _normalize maps this value to 0.
        self.min_level_db = 20 * np.log10(hp.stft_magnitude_min)
        self.preemphasis = hp.preemphasis
        self.hop_size = hp.hop_size

    def forward(self, wav: torch.Tensor, pad: bool = True) -> torch.Tensor:
        """Compute the normalized mel spectrogram of a batch of waveforms.

        Args:
            wav: Waveforms, [B, T].
            pad: If True, assert that the output has ``1 + T // hop_size``
                frames. It does not change how the signal is padded; that is
                fixed by the ``self.melspec`` configuration.

        Returns:
            Normalized mel spectrogram, [B, num_mels, n_frames], on the same
            device as ``wav``.
        """
        device = wav.device
        if wav.is_mps:
            # Compute on CPU for MPS inputs (presumably the STFT path is not
            # supported on MPS -- TODO confirm). NOTE(review): this moves the
            # whole module to CPU and it stays there after the call returns.
            wav = wav.cpu()
            self.to(wav.device)
        if self.preemphasis > 0:
            # y[t] = x[t] - preemphasis * x[t - 1] with x[-1] = 0; length is preserved.
            wav = torch.nn.functional.pad(wav, [1, 0], value=0)
            wav = wav[..., 1:] - self.preemphasis * wav[..., :-1]
        mel = self.melspec(wav)
        mel = self._amp_to_db(mel)
        mel_normed = self._normalize(mel)
        # Sanity check on the number of output frames.
        assert not pad or mel_normed.shape[-1] == 1 + wav.shape[-1] // self.hop_size
        # Return on the caller's original device (undoes the MPS -> CPU hop).
        mel_normed = mel_normed.to(device)
        return mel_normed

    def _normalize(self, s: torch.Tensor, headroom_db: float = 15) -> torch.Tensor:
        """Affinely rescale dB values so ``min_level_db`` -> 0 and ``headroom_db`` -> 1."""
        return (s - self.min_level_db) / (-self.min_level_db + headroom_db)

    def _amp_to_db(self, x: torch.Tensor) -> torch.Tensor:
        """Convert magnitudes to dB (``20 * log10``), flooring at ``hp.stft_magnitude_min``."""
        return x.clamp_min(self.hp.stft_magnitude_min).log10() * 20