# music_tagging/models/modules.py
import numpy as np
import torch
import torch.nn as nn
import torchaudio
import librosa

class Conv_1d(nn.Module):
def __init__(self, input_channels, output_channels, shape=3, stride=1, pooling=2):
super(Conv_1d, self).__init__()
self.conv = nn.Conv1d(input_channels, output_channels, shape, stride=stride, padding=shape//2)
self.bn = nn.BatchNorm1d(output_channels)
self.relu = nn.ReLU()
self.mp = nn.MaxPool1d(pooling)
def forward(self, x):
out = self.mp(self.relu(self.bn(self.conv(x))))
return out
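
# Illustrative shape check (a sketch, not part of the original module): with the
# default kernel (3, padding 1) the convolution preserves the length and the
# max-pool halves it, so for a (batch, channels, length) input:
# >>> Conv_1d(64, 128)(torch.randn(4, 64, 256)).shape
# torch.Size([4, 128, 128])
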
class Conv_2d(nn.Module):
def __init__(self, input_channels, output_channels, shape=3, stride=1, pooling=2):
super(Conv_2d, self).__init__()
self.conv = nn.Conv2d(input_channels, output_channels, shape, stride=stride, padding=shape//2)
self.bn = nn.BatchNorm2d(output_channels)
self.relu = nn.ReLU()
self.mp = nn.MaxPool2d(pooling)
def forward(self, x):
out = self.mp(self.relu(self.bn(self.conv(x))))
return out
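
# Same pattern in 2-D (illustrative sketch): the 3x3 convolution with padding 1
# preserves the spatial dims and the 2x2 max-pool halves both:
# >>> Conv_2d(1, 64)(torch.randn(4, 1, 96, 256)).shape
# torch.Size([4, 64, 48, 128])
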
class Res_2d(nn.Module):
def __init__(self, input_channels, output_channels, shape=3, stride=2):
super(Res_2d, self).__init__()
# convolution
self.conv_1 = nn.Conv2d(input_channels, output_channels, shape, stride=stride, padding=shape//2)
self.bn_1 = nn.BatchNorm2d(output_channels)
self.conv_2 = nn.Conv2d(output_channels, output_channels, shape, padding=shape//2)
self.bn_2 = nn.BatchNorm2d(output_channels)
# residual
self.diff = False
if (stride != 1) or (input_channels != output_channels):
self.conv_3 = nn.Conv2d(input_channels, output_channels, shape, stride=stride, padding=shape//2)
self.bn_3 = nn.BatchNorm2d(output_channels)
self.diff = True
self.relu = nn.ReLU()
def forward(self, x):
# convolution
out = self.bn_2(self.conv_2(self.relu(self.bn_1(self.conv_1(x)))))
# residual
if self.diff:
x = self.bn_3(self.conv_3(x))
out = x + out
out = self.relu(out)
return out
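
# Illustrative sketch: with stride=2 the main path downsamples, so the shortcut
# is projected through conv_3/bn_3 to match (the `diff` branch above):
# >>> Res_2d(64, 128, stride=2)(torch.randn(4, 64, 96, 256)).shape
# torch.Size([4, 128, 48, 128])
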
class Res_2d_mp(nn.Module):
def __init__(self, input_channels, output_channels, pooling=2):
super(Res_2d_mp, self).__init__()
self.conv_1 = nn.Conv2d(input_channels, output_channels, 3, padding=1)
self.bn_1 = nn.BatchNorm2d(output_channels)
self.conv_2 = nn.Conv2d(output_channels, output_channels, 3, padding=1)
self.bn_2 = nn.BatchNorm2d(output_channels)
self.relu = nn.ReLU()
self.mp = nn.MaxPool2d(pooling)
def forward(self, x):
out = self.bn_2(self.conv_2(self.relu(self.bn_1(self.conv_1(x)))))
out = x + out
out = self.mp(self.relu(out))
return out
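
# Note: the identity shortcut (`x + out`) only works when input_channels equals
# output_channels, since this block has no projection branch. Illustrative sketch:
# >>> Res_2d_mp(64, 64)(torch.randn(4, 64, 48, 128)).shape
# torch.Size([4, 64, 24, 64])
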
class ResSE_1d(nn.Module):
def __init__(self, input_channels, output_channels, shape=3, stride=1, pooling=3):
super(ResSE_1d, self).__init__()
# convolution
self.conv_1 = nn.Conv1d(input_channels, output_channels, shape, stride=stride, padding=shape//2)
self.bn_1 = nn.BatchNorm1d(output_channels)
self.conv_2 = nn.Conv1d(output_channels, output_channels, shape, padding=shape//2)
self.bn_2 = nn.BatchNorm1d(output_channels)
# squeeze & excitation
self.dense1 = nn.Linear(output_channels, output_channels)
self.dense2 = nn.Linear(output_channels, output_channels)
# residual
self.diff = False
if (stride != 1) or (input_channels != output_channels):
self.conv_3 = nn.Conv1d(input_channels, output_channels, shape, stride=stride, padding=shape//2)
self.bn_3 = nn.BatchNorm1d(output_channels)
self.diff = True
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()
self.mp = nn.MaxPool1d(pooling)
def forward(self, x):
# convolution
out = self.bn_2(self.conv_2(self.relu(self.bn_1(self.conv_1(x)))))
# squeeze & excitation
        se_out = out.mean(-1)  # squeeze: global average over time, (batch, channels)
se_out = self.relu(self.dense1(se_out))
se_out = self.sigmoid(self.dense2(se_out))
se_out = se_out.unsqueeze(-1)
out = torch.mul(out, se_out)
# residual
if self.diff:
x = self.bn_3(self.conv_3(x))
out = x + out
out = self.mp(self.relu(out))
return out
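
# Illustrative sketch: the squeeze-and-excitation path pools each channel to a
# scalar, gates the channels through the two dense layers, then rescales `out`;
# the length shrinks only at the final max-pool (pooling=3 by default):
# >>> ResSE_1d(128, 128)(torch.randn(4, 128, 216)).shape
# torch.Size([4, 128, 72])
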
class Conv_V(nn.Module):
# vertical convolution
def __init__(self, input_channels, output_channels, filter_shape):
super(Conv_V, self).__init__()
self.conv = nn.Conv2d(input_channels, output_channels, filter_shape,
padding=(0, filter_shape[1]//2))
self.bn = nn.BatchNorm2d(output_channels)
self.relu = nn.ReLU()
def forward(self, x):
x = self.relu(self.bn(self.conv(x)))
        out = x.max(dim=2).values  # collapse the frequency axis (global max-pool)
return out
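
# Illustrative sketch: a vertical filter spanning the full frequency axis (here
# 96 bins) collapses that axis, leaving a (batch, channels, time) output:
# >>> Conv_V(1, 32, (96, 7))(torch.randn(4, 1, 96, 256)).shape
# torch.Size([4, 32, 256])
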
class Conv_H(nn.Module):
# horizontal convolution
def __init__(self, input_channels, output_channels, filter_length):
super(Conv_H, self).__init__()
self.conv = nn.Conv1d(input_channels, output_channels, filter_length,
padding=filter_length//2)
self.bn = nn.BatchNorm1d(output_channels)
self.relu = nn.ReLU()
def forward(self, x):
        out = x.mean(dim=2)  # collapse the frequency axis (global average)
out = self.relu(self.bn(self.conv(out)))
return out
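
# Illustrative sketch: the frequency axis is averaged away first, then the odd
# 1-D kernel (padding = filter_length // 2) preserves the time dimension:
# >>> Conv_H(1, 32, 65)(torch.randn(4, 1, 96, 256)).shape
# torch.Size([4, 32, 256])
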
# Modules for harmonic filters
def hz_to_midi(hz):
    # expects a torch tensor (uses torch.log2); midi_to_hz below also accepts
    # plain floats and numpy arrays
    return 12 * (torch.log2(hz) - np.log2(440.0)) + 69
def midi_to_hz(midi):
return 440.0 * (2.0 ** ((midi - 69.0)/12.0))
def note_to_midi(note):
return librosa.core.note_to_midi(note)
def hz_to_note(hz):
return librosa.core.hz_to_note(hz)
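
# Quick sanity checks for the pitch conversions (sketch): MIDI note 69 is A4 at
# 440 Hz, and hz_to_midi inverts it (note the torch-tensor input it expects).
# >>> midi_to_hz(69.0)
# 440.0
# >>> hz_to_midi(torch.tensor(440.0))
# tensor(69.)
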
def initialize_filterbank(sample_rate, n_harmonic, semitone_scale):
# MIDI
# lowest note
low_midi = note_to_midi('C1')
    # highest note whose top harmonic stays below the Nyquist frequency
    high_note = hz_to_note(sample_rate / (2 * n_harmonic))
high_midi = note_to_midi(high_note)
# number of scales
level = (high_midi - low_midi) * semitone_scale
midi = np.linspace(low_midi, high_midi, level + 1)
hz = midi_to_hz(midi[:-1])
# stack harmonics
harmonic_hz = []
for i in range(n_harmonic):
harmonic_hz = np.concatenate((harmonic_hz, hz * (i+1)))
return harmonic_hz, level
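
# Illustrative sketch with the defaults used by HarmonicSTFT below
# (sample_rate=16000, n_harmonic=6, semitone_scale=2): the fundamental grid has
# `level` bands between C1 and the highest safe note, and the returned array
# stacks that grid scaled by each harmonic number, so its length is
# n_harmonic * level.
# >>> harmonic_hz, level = initialize_filterbank(16000, 6, 2)
# >>> len(harmonic_hz) == 6 * level
# True
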
class HarmonicSTFT(nn.Module):
def __init__(self,
sample_rate=16000,
n_fft=513,
win_length=None,
hop_length=None,
pad=0,
power=2,
normalized=False,
n_harmonic=6,
semitone_scale=2,
bw_Q=1.0,
learn_bw=None):
super(HarmonicSTFT, self).__init__()
# Parameters
self.sample_rate = sample_rate
self.n_harmonic = n_harmonic
self.bw_alpha = 0.1079
self.bw_beta = 24.7
        # Spectrogram (hop_length and pad are passed through rather than
        # hard-coded, so the constructor arguments are not silently ignored)
        self.spec = torchaudio.transforms.Spectrogram(n_fft=n_fft, win_length=win_length,
                                                      hop_length=hop_length, pad=pad,
                                                      window_fn=torch.hann_window,
                                                      power=power, normalized=normalized, wkwargs=None)
self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
# Initialize the filterbank. Equally spaced in MIDI scale.
harmonic_hz, self.level = initialize_filterbank(sample_rate, n_harmonic, semitone_scale)
        # Center frequencies to tensor
self.f0 = torch.tensor(harmonic_hz.astype('float32'))
        # Bandwidth parameters
        if learn_bw == 'only_Q':
            self.bw_Q = nn.Parameter(torch.tensor(np.array([bw_Q]).astype('float32')))
        elif learn_bw == 'fix':
            self.bw_Q = torch.tensor(np.array([bw_Q]).astype('float32'))
        else:
            # fail fast: otherwise bw_Q is undefined and the first forward pass
            # dies with an opaque AttributeError
            raise ValueError("learn_bw must be 'only_Q' or 'fix', got %r" % learn_bw)
def get_harmonic_fb(self):
# bandwidth
bw = (self.bw_alpha * self.f0 + self.bw_beta) / self.bw_Q
bw = bw.unsqueeze(0) # (1, n_band)
f0 = self.f0.unsqueeze(0) # (1, n_band)
fft_bins = self.fft_bins.unsqueeze(1) # (n_bins, 1)
up_slope = torch.matmul(fft_bins, (2/bw)) + 1 - (2 * f0 / bw)
down_slope = torch.matmul(fft_bins, (-2/bw)) + 1 + (2 * f0 / bw)
fb = torch.max(self.zero, torch.min(down_slope, up_slope))
return fb
def to_device(self, device, n_bins):
self.f0 = self.f0.to(device)
self.bw_Q = self.bw_Q.to(device)
# fft bins
self.fft_bins = torch.linspace(0, self.sample_rate//2, n_bins)
self.fft_bins = self.fft_bins.to(device)
self.zero = torch.zeros(1)
self.zero = self.zero.to(device)
def forward(self, waveform):
# stft
spectrogram = self.spec(waveform)
# to device
self.to_device(waveform.device, spectrogram.size(1))
# triangle filter
harmonic_fb = self.get_harmonic_fb()
harmonic_spec = torch.matmul(spectrogram.transpose(1, 2), harmonic_fb).transpose(1, 2)
# (batch, channel, length) -> (batch, harmonic, f0, length)
b, c, l = harmonic_spec.size()
harmonic_spec = harmonic_spec.view(b, self.n_harmonic, self.level, l)
# amplitude to db
harmonic_spec = self.amplitude_to_db(harmonic_spec)
return harmonic_spec
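
# Illustrative usage (sketch): one second of 16 kHz audio produces a
# (batch, n_harmonic, level, n_frames) harmonic spectrogram; with these defaults
# the filterbank works out to level = 128 bands.
# >>> hstft = HarmonicSTFT(sample_rate=16000, n_harmonic=6, learn_bw='fix')
# >>> hstft(torch.randn(4, 16000)).shape[:3]
# torch.Size([4, 6, 128])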