# zhzluke96
# update
# d2b7e94
# raw
# history blame
# 2.1 kB
import numpy as np
import torch
from torch import nn
from torchaudio.transforms import MelSpectrogram as TorchMelSpectrogram
from .hparams import HParams
class MelSpectrogram(nn.Module):
    """Torch implementation of Resemble's mel extraction.

    Note that the values are NOT identical to librosa's implementation
    due to floating point precision.

    The output is a log-magnitude mel spectrogram in dB whose floor is
    ``20 * log10(hp.stft_magnitude_min)``, linearly rescaled by
    ``_normalize`` so that this floor maps to 0 and ``headroom_db``
    (15 dB by default) maps to 1.
    """

    def __init__(self, hp: "HParams"):
        """Build the mel front end from hyper-parameters.

        Args:
            hp: hyper-parameter object. This constructor reads ``wav_rate``,
                ``n_fft``, ``win_size``, ``hop_size``, ``num_mels``,
                ``stft_magnitude_min`` and ``preemphasis`` from it.
        """
        super().__init__()
        self.hp = hp
        self.melspec = TorchMelSpectrogram(
            hp.wav_rate,
            n_fft=hp.n_fft,
            win_length=hp.win_size,
            hop_length=hp.hop_size,
            f_min=0,
            f_max=hp.wav_rate // 2,
            n_mels=hp.num_mels,
            power=1,  # magnitude spectrogram (power=2 would give energy)
            normalized=False,
            # NOTE: Following librosa's default.
            pad_mode="constant",
            norm="slaney",
            mel_scale="slaney",
        )
        # Registered as a buffer so it travels with the module across devices
        # and is saved in the state dict; _amp_to_db itself reads the value
        # from hp directly.
        self.register_buffer(
            "stft_magnitude_min",
            torch.tensor([hp.stft_magnitude_min], dtype=torch.float32),
        )
        # dB level of the magnitude floor; this is the value _normalize maps to 0.
        self.min_level_db = 20 * np.log10(hp.stft_magnitude_min)
        self.preemphasis = hp.preemphasis
        self.hop_size = hp.hop_size

    def forward(self, wav, pad=True):
        """Compute the normalized log-mel spectrogram of ``wav``.

        Args:
            wav: waveform tensor of shape [B, T] (time on the last axis).
            pad: if True, assert that the number of frames equals
                ``1 + T // hop_size``. Despite its name this flag adds no
                padding; it only toggles that sanity check.

        Returns:
            Normalized log-mel tensor of shape [B, num_mels, frames], moved
            back to the device ``wav`` arrived on.
        """
        orig_device = wav.device
        # Inputs on Apple MPS are processed on CPU. NOTE(review): presumably
        # because part of the STFT path is unsupported there -- confirm.
        # ``device.type`` is used rather than ``Tensor.is_mps`` so the check
        # also works on torch builds that lack that attribute.
        if wav.device.type == "mps":
            wav = wav.cpu()
        # Side effect: moves this module (buffer and mel transform) onto the
        # device the computation runs on.
        self.to(wav.device)

        if self.preemphasis > 0:
            # y[n] = x[n] - preemphasis * x[n - 1], with x[-1] taken as 0;
            # the output keeps the input length.
            shifted = torch.nn.functional.pad(wav, [1, 0], value=0)
            wav = shifted[..., 1:] - self.preemphasis * shifted[..., :-1]

        mel = self.melspec(wav)
        mel = self._amp_to_db(mel)
        mel_normed = self._normalize(mel)

        # Sanity check: with the transform's default centred framing the
        # frame count is 1 + T // hop_size.
        n_frames = mel_normed.shape[-1]
        assert not pad or n_frames == 1 + wav.shape[-1] // self.hop_size, (
            f"expected {1 + wav.shape[-1] // self.hop_size} frames, "
            f"got {n_frames}"
        )
        return mel_normed.to(orig_device)

    def _normalize(self, s, headroom_db=15):
        """Linearly rescale dB values so ``min_level_db`` -> 0 and ``headroom_db`` -> 1."""
        return (s - self.min_level_db) / (-self.min_level_db + headroom_db)

    def _amp_to_db(self, x):
        """Convert magnitudes to dB (``20 * log10``), flooring at ``hp.stft_magnitude_min``."""
        return x.clamp_min(self.hp.stft_magnitude_min).log10() * 20