seungheondoh
add model
e48ca55
raw
history blame
3.3 kB
### code reference: https://github.com/openai/whisper/blob/main/whisper/audio.py
import os
import torch
import torchaudio
import numpy as np
import torch.nn.functional as F
from torch import Tensor, nn
from typing import Dict, Iterable, Optional
# Hard-coded audio hyperparameters.
SAMPLE_RATE = 16000  # samples per second
N_FFT = 1024  # FFT size, in samples
N_MELS = 128  # number of mel channels
HOP_LENGTH = SAMPLE_RATE // 100  # a 10 ms hop, i.e. 160 samples
DURATION = 10  # clip length, in seconds
N_SAMPLES = DURATION * SAMPLE_RATE  # samples in one clip (160000)
N_FRAMES = 1 + N_SAMPLES // HOP_LENGTH  # frame count for one clip (1001)
def sinusoids(length, channels, max_timescale=10000):
    """Return a ``(length, channels)`` table of sinusoidal positional embeddings.

    Row ``t`` holds ``sin(t * w_i)`` in its first ``channels // 2`` columns
    and ``cos(t * w_i)`` in the remaining ``channels // 2``. The frequencies
    ``w_i`` are spaced geometrically from ``1`` down to ``1 / max_timescale``.

    Args:
        length: number of positions (rows).
        channels: embedding width; must be even and at least 4.
        max_timescale: ratio between the highest and the lowest frequency.

    Returns:
        A float tensor of shape ``(length, channels)``.

    Raises:
        ValueError: if ``channels`` is odd or smaller than 4. An odd width
            would silently produce ``channels - 1`` columns. ``channels == 2``
            makes the denominator below zero and fills the table with NaNs.
    """
    if channels % 2 != 0 or channels < 4:
        raise ValueError(f"channels must be an even integer >= 4, got {channels}")
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    # Outer product: (length, 1) * (1, channels // 2) -> (length, channels // 2).
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
class MelEncoder(nn.Module):
    """Log-mel (dB) time-frequency representation of a waveform.

    The waveform goes through a torchaudio ``Spectrogram``. The squared real
    part of its output is projected onto a mel filterbank, and the result is
    converted to decibels.

    Args:
        sample_rate: sample rate of the input waveform.
        f_min: lowest mel-filterbank frequency.
        f_max: highest mel-filterbank frequency.
        n_fft: FFT size; the filterbank is built for ``n_fft // 2 + 1`` bins.
        win_length: STFT window length.
        hop_length: STFT hop length.
        n_mels: number of mel channels.
        power: forwarded to ``Spectrogram``; ``None`` keeps the complex STFT.
        pad: forwarded to ``Spectrogram``.
        normalized: forwarded to ``Spectrogram``.
        center: forwarded to ``Spectrogram``.
        pad_mode: forwarded to ``Spectrogram``.
    """

    def __init__(
        self,
        sample_rate=16000,
        f_min=0,
        f_max=8000,
        n_fft=1024,
        win_length=1024,
        hop_length=int(0.01 * 16000),
        n_mels=128,
        power=None,
        pad=0,
        normalized=False,
        center=True,
        pad_mode="reflect",
    ):
        super().__init__()
        # NOTE(review): this window is never passed to ``spec_fn`` and is not a
        # registered buffer, so it also does not follow ``.to(device)``. It is
        # kept only so code that reads ``.window`` keeps working.
        self.window = torch.hann_window(win_length)
        # pad / normalized / center / pad_mode used to be accepted and then
        # silently dropped. Their defaults equal torchaudio's Spectrogram
        # defaults, so forwarding them changes nothing for default callers.
        self.spec_fn = torchaudio.transforms.Spectrogram(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            pad=pad,
            power=power,
            normalized=normalized,
            center=center,
            pad_mode=pad_mode,
        )
        self.mel_scale = torchaudio.transforms.MelScale(
            n_mels=n_mels,
            sample_rate=sample_rate,
            f_min=f_min,
            f_max=f_max,
            n_stft=n_fft // 2 + 1,  # number of one-sided STFT bins
        )
        self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()

    def forward(self, wav):
        """Return the dB-scaled mel spectrogram of ``wav``.

        For a ``(..., samples)`` waveform the result is
        ``(..., n_mels, frames)``.
        """
        spec = self.spec_fn(wav)
        # NOTE(review): with the default ``power=None`` ``spec`` is complex and
        # only its REAL part is squared here, not the magnitude
        # ``|X|**2 = real**2 + imag**2``. This is left unchanged on purpose:
        # presumably any pretrained weights were trained on exactly these
        # features. Confirm before "fixing" it.
        power_spec = spec.real.abs().pow(2)
        mel_spec = self.mel_scale(power_spec)
        # AmplitudeToDB() with its defaults treats the input as power:
        # 10 * log10(max(x, 1e-10)), with no top_db clipping.
        mel_spec = self.amplitude_to_db(mel_spec)
        return mel_spec
class AudioEncoder(nn.Module):
    """Waveform -> log-mel -> 1-D conv stack -> per-frame embeddings.

    A stride-1 conv maps ``n_mels`` channels to ``audio_dim``. Each of the
    ``num_of_stride_conv`` stride-2 convs that follow halves the number of
    frames, rounding up. A fixed sinusoidal positional embedding is added to
    the result.

    Args:
        n_mels: number of mel channels produced by the internal ``MelEncoder``.
        n_ctx: maximum number of output frames; this is the length of the
            positional-embedding table.
        audio_dim: channel width of the conv stack.
        text_dim: width of the positional embedding. There is no projection
            between the conv stack and the embedding, so ``forward`` only
            works when ``text_dim == audio_dim``.
        num_of_stride_conv: number of stride-2 conv layers.
    """

    def __init__(
        self, n_mels: int, n_ctx: int, audio_dim: int, text_dim: int, num_of_stride_conv: int,
    ):
        super().__init__()
        self.mel_encoder = MelEncoder(n_mels=n_mels)
        self.conv1 = nn.Conv1d(n_mels, audio_dim, kernel_size=3, padding=1)
        self.conv_stack = nn.ModuleList(
            nn.Conv1d(audio_dim, audio_dim, kernel_size=3, stride=2, padding=1)
            for _ in range(num_of_stride_conv)
        )
        self.register_buffer("positional_embedding", sinusoids(n_ctx, text_dim))

    def forward(self, x: Tensor):
        """Encode a batch of single-channel waveforms.

        Args:
            x: waveform batch, presumably ``(batch_size, samples)``.

        Returns:
            A ``(batch_size, frames, audio_dim)`` tensor of frame embeddings.

        Raises:
            ValueError: if the encoded sequence has more than ``n_ctx``
                frames, or its width differs from the positional embedding's.
        """
        x = self.mel_encoder(x)  # (batch_size, n_mels, mel_frames)
        x = F.gelu(self.conv1(x))
        for conv in self.conv_stack:
            x = F.gelu(conv(x))  # stride 2, padding 1: frames -> ceil(frames / 2)
        x = x.permute(0, 2, 1)  # (batch_size, frames, audio_dim)

        n_ctx, width = self.positional_embedding.shape
        frames = x.shape[1]
        if frames > n_ctx:
            raise ValueError(
                f"input yields {frames} frames but the positional embedding "
                f"only covers n_ctx={n_ctx}"
            )
        if x.shape[-1] != width:
            raise ValueError(
                f"encoder width {x.shape[-1]} does not match positional-embedding "
                f"width {width} (audio_dim must equal text_dim)"
            )
        # Slice to the actual frame count. Shorter clips then reuse the
        # leading positions instead of failing or silently broadcasting.
        x = (x + self.positional_embedding[:frames]).to(x.dtype)
        return x