Spaces:
Runtime error
Runtime error
# Copyright (c) 2023 Amphion. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import math | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class PreEmphasis(torch.nn.Module):
    """First-order pre-emphasis filter ``y[t] = x[t] - coef * x[t - 1]``.

    The input is reflect-padded by one sample on the left so the output has
    the same number of samples as the input; consequently the first output
    sample is ``x[0] - coef * x[1]``.

    Args:
        coef: Pre-emphasis coefficient.
    """

    def __init__(self, coef: float = 0.97) -> None:
        super().__init__()
        self.coef = coef
        # F.conv1d computes cross-correlation (the kernel is applied without
        # flipping), so the difference filter [1, -coef] is stored already
        # flipped as [-coef, 1]. Shape: (out_channels=1, in_channels=1, k=2).
        self.register_buffer(
            "flipped_filter",
            torch.tensor([[[-self.coef, 1.0]]], dtype=torch.float32),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply pre-emphasis to each row of a 2-D tensor.

        Args:
            input: 2-D tensor; each row is filtered independently along its
                last dimension. Rows need at least 2 samples, because reflect
                padding by one sample requires a longer input.

        Returns:
            Tensor of shape ``(rows, 1, samples)``.
        """
        assert (
            input.dim() == 2
        ), "The number of dimensions of input tensor must be 2!"
        # conv1d expects a channel axis: (rows, samples) -> (rows, 1, samples).
        input = input.unsqueeze(1)
        # Reflect-pad one sample on the left so output length == input length.
        input = F.pad(input, (1, 0), "reflect")
        return F.conv1d(input, self.flipped_filter)
class AFMS(nn.Module):
    """
    Alpha-feature map scaling (AFMS), added to the output of each residual
    block [1, 2].

    Given an input ``x`` the module returns
    ``(x + alpha) * sigmoid(fc(time_average(x)))``: ``alpha`` is a learnable
    per-channel offset (initialised to ones) and the sigmoid term is a
    per-channel gate computed from the time-averaged input.

    Reference:
    [1] RawNet2 : https://www.isca-speech.org/archive/Interspeech_2020/pdfs/1011.pdf
    [2] AFMS : https://www.koreascience.or.kr/article/JAKO202029757857763.page
    """

    def __init__(self, nb_dim: int) -> None:
        super().__init__()
        # Shaped (nb_dim, 1) so it broadcasts across the time axis of x.
        self.alpha = nn.Parameter(torch.ones((nb_dim, 1)))
        self.fc = nn.Linear(nb_dim, nb_dim)
        self.sig = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x is expected as (batch, nb_dim, time): the pooled (batch, nb_dim)
        # vector must match the input size of self.fc.
        batch, channels = x.size(0), x.size(1)
        pooled = F.adaptive_avg_pool1d(x, 1).view(batch, -1)
        gate = self.sig(self.fc(pooled)).view(batch, channels, -1)
        return (x + self.alpha) * gate
class Bottle2neck(nn.Module):
    """Res2Net-style bottleneck block followed by AFMS.

    The input is projected by a 1x1 convolution to ``width * scale`` channels
    (``width = floor(planes / scale)``) and split into ``scale`` channel
    groups. Groups ``0 .. scale - 2`` are processed hierarchically: each group
    (after the first) is summed with the previous group's processed output,
    then passed through its own dilated convolution, ReLU and batch norm. The
    last group is passed through unchanged. All groups are concatenated,
    projected to ``planes`` channels, added to the (optionally 1x1-projected)
    input, optionally max-pooled, and finally rescaled by ``AFMS``.

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels.
        kernel_size: Kernel size of the per-group convolutions (required).
            Padding is ``floor(kernel_size / 2) * dilation``, which keeps the
            time length unchanged only for odd kernel sizes.
        dilation: Dilation of the per-group convolutions (required).
        scale: Number of channel groups. ``scale=1`` has no per-group
            convolution; the single group passes straight through.
        pool: If truthy, used as the kernel size of an ``nn.MaxPool1d``
            applied after the residual addition.

    Raises:
        TypeError: If ``kernel_size`` or ``dilation`` is ``None``.
    """

    def __init__(
        self,
        inplanes,
        planes,
        kernel_size=None,
        dilation=None,
        scale=4,
        pool=False,
    ):
        super().__init__()
        if kernel_size is None or dilation is None:
            # The None defaults are placeholders only: both values are needed
            # to build the per-group convolutions and their padding.
            raise TypeError(
                "Bottle2neck requires both kernel_size and dilation to be set"
            )

        width = int(math.floor(planes / scale))

        # 1x1 projection to width * scale channels, split into `scale` groups.
        self.conv1 = nn.Conv1d(inplanes, width * scale, kernel_size=1)
        self.bn1 = nn.BatchNorm1d(width * scale)

        # Every group except the last gets its own conv + batch norm.
        self.nums = scale - 1

        convs = []
        bns = []
        num_pad = math.floor(kernel_size / 2) * dilation
        for _ in range(self.nums):
            convs.append(
                nn.Conv1d(
                    width,
                    width,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    padding=num_pad,
                )
            )
            bns.append(nn.BatchNorm1d(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)

        self.conv3 = nn.Conv1d(width * scale, planes, kernel_size=1)
        self.bn3 = nn.BatchNorm1d(planes)

        self.relu = nn.ReLU()

        self.width = width

        # False (not None) when pooling is disabled; forward() tests truthiness.
        self.mp = nn.MaxPool1d(pool) if pool else False
        self.afms = AFMS(planes)

        if inplanes != planes:  # if change in number of filters
            self.residual = nn.Sequential(
                nn.Conv1d(inplanes, planes, kernel_size=1, stride=1, bias=False)
            )
        else:
            self.residual = nn.Identity()

    def forward(self, x):
        residual = self.residual(x)

        out = self.conv1(x)
        out = self.relu(out)
        out = self.bn1(out)

        # `scale` chunks of `width` channels each along the channel axis.
        spx = torch.split(out, self.width, 1)
        pieces = []
        sp = None
        for i in range(self.nums):
            # Hierarchical connection: each group after the first also
            # receives the previous group's processed output.
            sp = spx[i] if i == 0 else sp + spx[i]
            sp = self.convs[i](sp)
            sp = self.relu(sp)
            sp = self.bns[i](sp)
            pieces.append(sp)
        # The last group is passed through unchanged. Building a list and
        # concatenating once also handles scale == 1 (no processed groups),
        # where the single group must not be concatenated with itself.
        pieces.append(spx[self.nums])
        out = torch.cat(pieces, 1)

        out = self.conv3(out)
        out = self.relu(out)
        out = self.bn3(out)

        out += residual
        if self.mp:
            out = self.mp(out)
        out = self.afms(out)
        return out