import numpy as np
import torch
import torch.nn as nn
import torchaudio
import librosa

class Conv_1d(nn.Module):
    def __init__(self, input_channels, output_channels, shape=3, stride=1, pooling=2):
        super(Conv_1d, self).__init__()
        self.conv = nn.Conv1d(input_channels, output_channels, shape, stride=stride, padding=shape//2)
        self.bn = nn.BatchNorm1d(output_channels)
        self.relu = nn.ReLU()
        self.mp = nn.MaxPool1d(pooling)

    def forward(self, x):
        out = self.mp(self.relu(self.bn(self.conv(x))))
        return out

class Conv_2d(nn.Module):
    def __init__(self, input_channels, output_channels, shape=3, stride=1, pooling=2):
        super(Conv_2d, self).__init__()
        self.conv = nn.Conv2d(input_channels, output_channels, shape, stride=stride, padding=shape//2)
        self.bn = nn.BatchNorm2d(output_channels)
        self.relu = nn.ReLU()
        self.mp = nn.MaxPool2d(pooling)

    def forward(self, x):
        out = self.mp(self.relu(self.bn(self.conv(x))))
        return out
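
# A minimal shape check for the two basic blocks above (not from the original
# source; the sizes are arbitrary, chosen only for illustration).
def _demo_conv_blocks():
    x1 = torch.randn(2, 16, 100)               # (batch, channel, length)
    x2 = torch.randn(2, 16, 96, 100)           # (batch, channel, freq, time)
    print(Conv_1d(16, 32)(x1).shape)           # (2, 32, 50): conv/BN/ReLU, then pool by 2
    print(Conv_2d(16, 32)(x2).shape)           # (2, 32, 48, 50)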

class Res_2d(nn.Module):
    # 2-d residual block; the shortcut is projected (conv_3 + bn_3) whenever
    # the stride or the channel count changes the output shape
    def __init__(self, input_channels, output_channels, shape=3, stride=2):
        super(Res_2d, self).__init__()
        # convolution
        self.conv_1 = nn.Conv2d(input_channels, output_channels, shape, stride=stride, padding=shape//2)
        self.bn_1 = nn.BatchNorm2d(output_channels)
        self.conv_2 = nn.Conv2d(output_channels, output_channels, shape, padding=shape//2)
        self.bn_2 = nn.BatchNorm2d(output_channels)
        # residual
        self.diff = False
        if (stride != 1) or (input_channels != output_channels):
            self.conv_3 = nn.Conv2d(input_channels, output_channels, shape, stride=stride, padding=shape//2)
            self.bn_3 = nn.BatchNorm2d(output_channels)
            self.diff = True
        self.relu = nn.ReLU()

    def forward(self, x):
        # convolution
        out = self.bn_2(self.conv_2(self.relu(self.bn_1(self.conv_1(x)))))
        # residual
        if self.diff:
            x = self.bn_3(self.conv_3(x))
        out = x + out
        out = self.relu(out)
        return out

class Res_2d_mp(nn.Module):
    # 2-d residual block with max-pooling; uses an identity shortcut, so it
    # requires input_channels == output_channels
    def __init__(self, input_channels, output_channels, pooling=2):
        super(Res_2d_mp, self).__init__()
        self.conv_1 = nn.Conv2d(input_channels, output_channels, 3, padding=1)
        self.bn_1 = nn.BatchNorm2d(output_channels)
        self.conv_2 = nn.Conv2d(output_channels, output_channels, 3, padding=1)
        self.bn_2 = nn.BatchNorm2d(output_channels)
        self.relu = nn.ReLU()
        self.mp = nn.MaxPool2d(pooling)

    def forward(self, x):
        out = self.bn_2(self.conv_2(self.relu(self.bn_1(self.conv_1(x)))))
        out = x + out
        out = self.mp(self.relu(out))
        return out

class ResSE_1d(nn.Module):
    # 1-d residual block with squeeze-and-excitation channel re-weighting
    def __init__(self, input_channels, output_channels, shape=3, stride=1, pooling=3):
        super(ResSE_1d, self).__init__()
        # convolution
        self.conv_1 = nn.Conv1d(input_channels, output_channels, shape, stride=stride, padding=shape//2)
        self.bn_1 = nn.BatchNorm1d(output_channels)
        self.conv_2 = nn.Conv1d(output_channels, output_channels, shape, padding=shape//2)
        self.bn_2 = nn.BatchNorm1d(output_channels)
        # squeeze & excitation
        self.dense1 = nn.Linear(output_channels, output_channels)
        self.dense2 = nn.Linear(output_channels, output_channels)
        # residual
        self.diff = False
        if (stride != 1) or (input_channels != output_channels):
            self.conv_3 = nn.Conv1d(input_channels, output_channels, shape, stride=stride, padding=shape//2)
            self.bn_3 = nn.BatchNorm1d(output_channels)
            self.diff = True
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.mp = nn.MaxPool1d(pooling)

    def forward(self, x):
        # convolution
        out = self.bn_2(self.conv_2(self.relu(self.bn_1(self.conv_1(x)))))
        # squeeze: global average pool over time; excite: two dense layers
        # produce per-channel gates in (0, 1)
        se_out = nn.AvgPool1d(out.size(-1))(out)
        se_out = se_out.squeeze(-1)
        se_out = self.relu(self.dense1(se_out))
        se_out = self.sigmoid(self.dense2(se_out))
        se_out = se_out.unsqueeze(-1)
        out = torch.mul(out, se_out)
        # residual
        if self.diff:
            x = self.bn_3(self.conv_3(x))
        out = x + out
        out = self.mp(self.relu(out))
        return out
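
# Hypothetical smoke test for the residual blocks (not part of the original
# code). Note that Res_2d_mp has no shortcut projection, so its channel
# counts must match.
def _demo_residual_blocks():
    x2 = torch.randn(2, 16, 96, 100)
    print(Res_2d(16, 32, stride=2)(x2).shape)   # (2, 32, 48, 50)
    print(Res_2d_mp(16, 16)(x2).shape)          # (2, 16, 48, 50)
    x1 = torch.randn(2, 16, 99)
    print(ResSE_1d(16, 32)(x1).shape)           # (2, 32, 33): max-pool by 3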

class Conv_V(nn.Module):
    # vertical convolution: 2-d conv over (freq, time), then max-pool across
    # the entire frequency axis, yielding a (batch, channel, time) output
    def __init__(self, input_channels, output_channels, filter_shape):
        super(Conv_V, self).__init__()
        self.conv = nn.Conv2d(input_channels, output_channels, filter_shape,
                              padding=(0, filter_shape[1]//2))
        self.bn = nn.BatchNorm2d(output_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.bn(self.conv(x)))
        freq = x.size(2)
        out = nn.MaxPool2d((freq, 1), stride=(freq, 1))(x)
        out = out.squeeze(2)
        return out

class Conv_H(nn.Module):
    # horizontal convolution: average-pool across the frequency axis first,
    # then a 1-d conv over time
    def __init__(self, input_channels, output_channels, filter_length):
        super(Conv_H, self).__init__()
        self.conv = nn.Conv1d(input_channels, output_channels, filter_length,
                              padding=filter_length//2)
        self.bn = nn.BatchNorm1d(output_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        freq = x.size(2)
        out = nn.AvgPool2d((freq, 1), stride=(freq, 1))(x)
        out = out.squeeze(2)
        out = self.relu(self.bn(self.conv(out)))
        return out
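
# Sketch of the two front-end views above (illustrative only; the input is a
# hypothetical 128-bin spectrogram with 100 frames).
def _demo_vertical_horizontal():
    x = torch.randn(2, 1, 128, 100)            # (batch, channel, freq, time)
    v = Conv_V(1, 32, (65, 7))(x)              # conv over freq/time, pool out freq
    h = Conv_H(1, 32, 7)(x)                    # pool out freq, then 1-d conv over time
    print(v.shape, h.shape)                    # both (2, 32, 100)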

# Modules for harmonic filters
def hz_to_midi(hz):
    # expects a torch tensor
    return 12 * (torch.log2(hz) - np.log2(440.0)) + 69

def midi_to_hz(midi):
    return 440.0 * (2.0 ** ((midi - 69.0) / 12.0))

def note_to_midi(note):
    return librosa.core.note_to_midi(note)

def hz_to_note(hz):
    return librosa.core.hz_to_note(hz)
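
# Sanity checks for the pitch helpers, using standard MIDI conventions
# (A4 = MIDI 69 = 440 Hz; librosa maps 'C1' to MIDI 24):
def _demo_pitch_helpers():
    print(note_to_midi('C1'))                  # 24
    print(midi_to_hz(69.0))                    # 440.0
    print(hz_to_midi(torch.tensor(440.0)))     # ~ tensor(69.)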

def initialize_filterbank(sample_rate, n_harmonic, semitone_scale):
    # MIDI
    # lowest note
    low_midi = note_to_midi('C1')

    # highest note: the n-th harmonic must stay below the Nyquist frequency
    high_note = hz_to_note(sample_rate / (2 * n_harmonic))
    high_midi = note_to_midi(high_note)

    # number of scales
    level = (high_midi - low_midi) * semitone_scale
    midi = np.linspace(low_midi, high_midi, level + 1)
    hz = midi_to_hz(midi[:-1])

    # stack harmonics
    harmonic_hz = []
    for i in range(n_harmonic):
        harmonic_hz = np.concatenate((harmonic_hz, hz * (i + 1)))

    return harmonic_hz, level
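
# Worked example (assuming librosa's note naming): with sample_rate=16000,
# n_harmonic=6, semitone_scale=2, the highest fundamental is
# hz_to_note(16000 / 12) ~ 'E6' (MIDI 88), so level = (88 - 24) * 2 = 128
# and the stacked filterbank holds 6 * 128 = 768 center frequencies.
def _demo_filterbank():
    hz, level = initialize_filterbank(16000, 6, 2)
    print(len(hz), level)                      # 768 128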

class HarmonicSTFT(nn.Module):
    def __init__(self,
                 sample_rate=16000,
                 n_fft=513,
                 win_length=None,
                 hop_length=None,
                 pad=0,
                 power=2,
                 normalized=False,
                 n_harmonic=6,
                 semitone_scale=2,
                 bw_Q=1.0,
                 learn_bw=None):
        super(HarmonicSTFT, self).__init__()

        # Parameters
        self.sample_rate = sample_rate
        self.n_harmonic = n_harmonic
        # bandwidth model coefficients (ERB scale): bw = alpha * f0 + beta
        self.bw_alpha = 0.1079
        self.bw_beta = 24.7

        # Spectrogram
        self.spec = torchaudio.transforms.Spectrogram(n_fft=n_fft, win_length=win_length,
                                                      hop_length=hop_length, pad=pad,
                                                      window_fn=torch.hann_window,
                                                      power=power, normalized=normalized, wkwargs=None)
        self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()

        # Initialize the filterbank. Equally spaced in MIDI scale.
        harmonic_hz, self.level = initialize_filterbank(sample_rate, n_harmonic, semitone_scale)

        # Center frequencies to tensor
        self.f0 = torch.tensor(harmonic_hz.astype('float32'))

        # Bandwidth parameters
        if learn_bw == 'only_Q':
            self.bw_Q = nn.Parameter(torch.tensor(np.array([bw_Q]).astype('float32')))
        else:
            # 'fix' or None: keep the bandwidth parameter fixed
            self.bw_Q = torch.tensor(np.array([bw_Q]).astype('float32'))

    def get_harmonic_fb(self):
        # triangular filters of height 1 centered at f0 with full width bw:
        # fb(f) = max(0, min(up_slope, down_slope))
        bw = (self.bw_alpha * self.f0 + self.bw_beta) / self.bw_Q
        bw = bw.unsqueeze(0)  # (1, n_band)
        f0 = self.f0.unsqueeze(0)  # (1, n_band)
        fft_bins = self.fft_bins.unsqueeze(1)  # (n_bins, 1)

        up_slope = torch.matmul(fft_bins, (2/bw)) + 1 - (2 * f0 / bw)
        down_slope = torch.matmul(fft_bins, (-2/bw)) + 1 + (2 * f0 / bw)
        fb = torch.max(self.zero, torch.min(down_slope, up_slope))
        return fb

    def to_device(self, device, n_bins):
        self.f0 = self.f0.to(device)
        self.bw_Q = self.bw_Q.to(device)
        # fft bins
        self.fft_bins = torch.linspace(0, self.sample_rate // 2, n_bins)
        self.fft_bins = self.fft_bins.to(device)
        self.zero = torch.zeros(1)
        self.zero = self.zero.to(device)

    def forward(self, waveform):
        # stft
        spectrogram = self.spec(waveform)

        # to device
        self.to_device(waveform.device, spectrogram.size(1))

        # triangle filter
        harmonic_fb = self.get_harmonic_fb()
        harmonic_spec = torch.matmul(spectrogram.transpose(1, 2), harmonic_fb).transpose(1, 2)

        # (batch, channel, length) -> (batch, harmonic, f0, length)
        b, c, l = harmonic_spec.size()
        harmonic_spec = harmonic_spec.view(b, self.n_harmonic, self.level, l)

        # amplitude to db
        harmonic_spec = self.amplitude_to_db(harmonic_spec)
        return harmonic_spec
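
if __name__ == '__main__':
    # Minimal smoke test, not from the original source. learn_bw='fix' keeps
    # the bandwidth parameter fixed; batch size and clip length are arbitrary.
    hstft = HarmonicSTFT(sample_rate=16000, n_harmonic=6, semitone_scale=2, learn_bw='fix')
    waveform = torch.randn(2, 16000)           # two 1-second clips at 16 kHz
    out = hstft(waveform)
    print(out.shape)                           # (2, 6, 128, n_frames)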