# wsntxxn
# Update requirements.txt
# def3c02
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchaudio import transforms
from utils.model_util import mean_with_lens, max_with_lens
from utils.train_util import merge_load_state_dict
def init_layer(layer):
    """Xavier-uniform initialise ``layer.weight`` and zero the bias.

    Works for any object exposing a ``weight`` tensor (e.g. ``nn.Linear``,
    ``nn.Conv2d``). The bias is zeroed only when the layer has a ``bias``
    attribute that is not ``None``.
    """
    nn.init.xavier_uniform_(layer.weight)
    bias = getattr(layer, 'bias', None)
    if bias is not None:
        bias.data.fill_(0.)
def init_bn(bn):
    """Reset a BatchNorm layer's affine parameters to scale 1 and shift 0."""
    bn.bias.data.zero_()
    bn.weight.data.fill_(1.0)
class ConvBlock(nn.Module):
    """Two 3x3 Conv-BN-ReLU layers followed by 2-D pooling.

    Both convolutions use stride 1 and padding 1, so only the pooling step
    changes the spatial size.

    Args:
        in_channels: number of input feature maps.
        out_channels: number of feature maps produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=(3, 3), stride=(1, 1),
                               padding=(1, 1), bias=False)
        self.conv2 = nn.Conv2d(in_channels=out_channels,
                               out_channels=out_channels,
                               kernel_size=(3, 3), stride=(1, 1),
                               padding=(1, 1), bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.init_weight()

    def init_weight(self):
        """Xavier-initialise both convs and reset both BatchNorms."""
        init_layer(self.conv1)
        init_layer(self.conv2)
        init_bn(self.bn1)
        init_bn(self.bn2)

    def forward(self, input, pool_size=(2, 2), pool_type='avg'):
        """Apply conv1/bn1/ReLU, conv2/bn2/ReLU, then pool.

        Args:
            input: 4-D feature map as expected by ``nn.Conv2d``
                (batch, channels, height, width).
            pool_size: pooling kernel size (the pooling stride is left at
                its default).
            pool_type: ``'avg'``, ``'max'`` or ``'avg+max'`` (the sum of
                average- and max-pooled maps).

        Returns:
            The pooled feature map with ``out_channels`` channels.

        Raises:
            ValueError: if ``pool_type`` is not one of the values above.
        """
        # Validate before doing any work; ValueError still subclasses
        # Exception, so existing ``except Exception`` callers keep working.
        if pool_type not in ('max', 'avg', 'avg+max'):
            raise ValueError(
                f"Incorrect argument! Unsupported pool_type {pool_type!r}; "
                "expected 'avg', 'max' or 'avg+max'.")
        x = F.relu_(self.bn1(self.conv1(input)))
        x = F.relu_(self.bn2(self.conv2(x)))
        if pool_type == 'max':
            return F.max_pool2d(x, kernel_size=pool_size)
        if pool_type == 'avg':
            return F.avg_pool2d(x, kernel_size=pool_size)
        # 'avg+max'
        x_avg = F.avg_pool2d(x, kernel_size=pool_size)
        x_max = F.max_pool2d(x, kernel_size=pool_size)
        return x_avg + x_max
class ConvBlock5x5(nn.Module):
    """A single 5x5 Conv-BN-ReLU layer followed by 2-D pooling.

    The convolution uses stride 1 and padding 2, so only the pooling step
    changes the spatial size.

    Args:
        in_channels: number of input feature maps.
        out_channels: number of output feature maps.
    """

    def __init__(self, in_channels, out_channels):
        super(ConvBlock5x5, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=(5, 5), stride=(1, 1),
                               padding=(2, 2), bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.init_weight()

    def init_weight(self):
        """Xavier-initialise the conv and reset the BatchNorm."""
        init_layer(self.conv1)
        init_bn(self.bn1)

    def forward(self, input, pool_size=(2, 2), pool_type='avg'):
        """Apply conv1/bn1/ReLU, then pool.

        Args:
            input: 4-D feature map as expected by ``nn.Conv2d``
                (batch, channels, height, width).
            pool_size: pooling kernel size (the pooling stride is left at
                its default).
            pool_type: ``'avg'``, ``'max'`` or ``'avg+max'`` (the sum of
                average- and max-pooled maps).

        Returns:
            The pooled feature map with ``out_channels`` channels.

        Raises:
            ValueError: if ``pool_type`` is not one of the values above.
        """
        # Validate before doing any work; ValueError still subclasses
        # Exception, so existing ``except Exception`` callers keep working.
        if pool_type not in ('max', 'avg', 'avg+max'):
            raise ValueError(
                f"Incorrect argument! Unsupported pool_type {pool_type!r}; "
                "expected 'avg', 'max' or 'avg+max'.")
        x = F.relu_(self.bn1(self.conv1(input)))
        if pool_type == 'max':
            return F.max_pool2d(x, kernel_size=pool_size)
        if pool_type == 'avg':
            return F.avg_pool2d(x, kernel_size=pool_size)
        # 'avg+max'
        x_avg = F.avg_pool2d(x, kernel_size=pool_size)
        x_max = F.max_pool2d(x, kernel_size=pool_size)
        return x_avg + x_max
class Cnn6Encoder(nn.Module):
    """CNN6 audio encoder operating on log-mel spectrograms.

    The waveform becomes a 64-band log-mel spectrogram (32 ms window,
    10 ms hop, 50 Hz lower edge) and is normalised per mel bin by ``bn0``.
    It then passes through four 5x5 conv blocks, each average-pooling by
    2x2. Averaging over frequency gives frame-level embeddings
    (``attn_emb``). A length-aware max+mean pooling followed by ``fc1``
    gives a 512-dimensional clip-level embedding (``fc_emb``).

    Args:
        sample_rate: input sample rate; only 32000 and 16000 are supported
            (it selects the mel upper frequency).
        freeze: if True, ``load_pretrained`` freezes every parameter it
            loads from the checkpoint.
    """

    def __init__(self, sample_rate=32000, freeze=False):
        super().__init__()
        win_samples = 32 * sample_rate // 1000
        hop_samples = 10 * sample_rate // 1000
        fmax_by_sr = {32000: 14000, 16000: 8000}
        # Log-mel front end.
        self.melspec_extractor = transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=win_samples,
            win_length=win_samples,
            hop_length=hop_samples,
            f_min=50,
            f_max=fmax_by_sr[sample_rate],
            n_mels=64,
            norm="slaney",
            mel_scale="slaney",
        )
        self.hop_length = hop_samples
        self.db_transform = transforms.AmplitudeToDB()
        self.bn0 = nn.BatchNorm2d(64)
        self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512)
        # Four 2x2 average pools => the time axis shrinks by 2**4.
        self.downsample_ratio = 16
        self.fc1 = nn.Linear(512, 512, bias=True)
        self.fc_emb_size = 512
        self.init_weight()
        self.freeze = freeze

    def init_weight(self):
        """Initialise the input BatchNorm and the embedding projection."""
        init_bn(self.bn0)
        init_layer(self.fc1)

    def load_pretrained(self, pretrained, output_fn):
        """Load weights from a checkpoint whose weights live under ``"model"``.

        Args:
            pretrained: path passed to ``torch.load`` (mapped to CPU).
            output_fn: forwarded to ``merge_load_state_dict``.

        Raises:
            Exception: if the checkpoint has no ``"model"`` entry.
        """
        checkpoint = torch.load(pretrained, map_location="cpu")
        if "model" not in checkpoint:
            raise Exception("Unkown checkpoint format")
        loaded_keys = merge_load_state_dict(checkpoint["model"], self,
                                            output_fn)
        if self.freeze:
            # Loaded parameters are frozen; everything else stays trainable.
            for name, param in self.named_parameters():
                param.requires_grad = name not in loaded_keys

    def forward(self, input_dict):
        """Encode a batch of waveforms.

        Args:
            input_dict: mapping with ``"wav"`` (waveform batch, presumably
                (batch, samples) -- TODO confirm), ``"wav_len"`` (valid
                sample count per item) and ``"specaug"`` (required key,
                currently unused).

        Returns:
            dict with ``"attn_emb"`` (batch, frames, 512), ``"fc_emb"``
            (512-dim clip embedding) and ``"attn_emb_len"`` (valid frames
            per item).
        """
        waveform = input_dict["wav"]
        wave_length = input_dict["wav_len"]
        _ = input_dict["specaug"]  # required key; not used by this encoder

        spec = self.db_transform(self.melspec_extractor(waveform))
        # (batch, mel_bins, time) -> (batch, 1, time, mel_bins)
        x = spec.transpose(1, 2).unsqueeze(1)
        # bn0 normalises each mel bin: move mel bins to the channel axis and back.
        x = self.bn0(x.transpose(1, 3)).transpose(1, 3)
        for conv_block in (self.conv_block1, self.conv_block2,
                           self.conv_block3, self.conv_block4):
            x = conv_block(x, pool_size=(2, 2), pool_type='avg')
            x = F.dropout(x, p=0.2, training=self.training)
        # Average out frequency -> (batch, time, channels).
        attn_emb = torch.mean(x, dim=3).transpose(1, 2)

        # Spectrogram frame count (len // hop + 1), then the conv pooling.
        n_frames = torch.div(torch.as_tensor(wave_length), self.hop_length,
                             rounding_mode="floor") + 1
        feat_length = torch.div(n_frames, self.downsample_ratio,
                                rounding_mode="floor")

        pooled = (max_with_lens(attn_emb, feat_length)
                  + mean_with_lens(attn_emb, feat_length))
        pooled = F.dropout(pooled, p=0.5, training=self.training)
        hidden = F.relu_(self.fc1(pooled))
        fc_emb = F.dropout(hidden, p=0.5, training=self.training)
        return {
            "attn_emb": attn_emb,
            "fc_emb": fc_emb,
            "attn_emb_len": feat_length,
        }
class Cnn10Encoder(nn.Module):
    """CNN10 audio encoder operating on log-mel spectrograms.

    The waveform becomes a 64-band log-mel spectrogram (32 ms window,
    10 ms hop, 50 Hz lower edge) and is normalised per mel bin by ``bn0``.
    It then passes through four 3x3 double-conv blocks, each average-pooling
    by 2x2. Averaging over frequency gives frame-level embeddings
    (``attn_emb``). A length-aware max+mean pooling followed by ``fc1``
    gives a 512-dimensional clip-level embedding (``fc_emb``).

    Args:
        sample_rate: input sample rate; only 32000 and 16000 are supported
            (it selects the mel upper frequency).
        freeze: if True, ``load_pretrained`` freezes every parameter it
            loads from the checkpoint.
    """

    def __init__(self, sample_rate=32000, freeze=False):
        super().__init__()
        win_samples = 32 * sample_rate // 1000
        hop_samples = 10 * sample_rate // 1000
        fmax_by_sr = {32000: 14000, 16000: 8000}
        # Log-mel front end.
        self.melspec_extractor = transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=win_samples,
            win_length=win_samples,
            hop_length=hop_samples,
            f_min=50,
            f_max=fmax_by_sr[sample_rate],
            n_mels=64,
            norm="slaney",
            mel_scale="slaney",
        )
        self.hop_length = hop_samples
        self.db_transform = transforms.AmplitudeToDB()
        self.bn0 = nn.BatchNorm2d(64)
        self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
        # Four 2x2 average pools => the time axis shrinks by 2**4.
        self.downsample_ratio = 16
        self.fc1 = nn.Linear(512, 512, bias=True)
        self.fc_emb_size = 512
        self.init_weight()
        self.freeze = freeze

    def init_weight(self):
        """Initialise the input BatchNorm and the embedding projection."""
        init_bn(self.bn0)
        init_layer(self.fc1)

    def load_pretrained(self, pretrained, output_fn):
        """Load weights from a checkpoint whose weights live under ``"model"``.

        Args:
            pretrained: path passed to ``torch.load`` (mapped to CPU).
            output_fn: forwarded to ``merge_load_state_dict``.

        Raises:
            Exception: if the checkpoint has no ``"model"`` entry.
        """
        checkpoint = torch.load(pretrained, map_location="cpu")
        if "model" not in checkpoint:
            raise Exception("Unkown checkpoint format")
        loaded_keys = merge_load_state_dict(checkpoint["model"], self,
                                            output_fn)
        if self.freeze:
            # Loaded parameters are frozen; everything else stays trainable.
            for name, param in self.named_parameters():
                param.requires_grad = name not in loaded_keys

    def forward(self, input_dict):
        """Encode a batch of waveforms.

        Args:
            input_dict: mapping with ``"wav"`` (waveform batch, presumably
                (batch, samples) -- TODO confirm), ``"wav_len"`` (valid
                sample count per item) and ``"specaug"`` (required key,
                currently unused).

        Returns:
            dict with ``"attn_emb"`` (batch, frames, 512), ``"fc_emb"``
            (512-dim clip embedding) and ``"attn_emb_len"`` (valid frames
            per item).
        """
        waveform = input_dict["wav"]
        wave_length = input_dict["wav_len"]
        _ = input_dict["specaug"]  # required key; not used by this encoder

        spec = self.db_transform(self.melspec_extractor(waveform))
        # (batch, mel_bins, time) -> (batch, 1, time, mel_bins)
        x = spec.transpose(1, 2).unsqueeze(1)
        # bn0 normalises each mel bin: move mel bins to the channel axis and back.
        x = self.bn0(x.transpose(1, 3)).transpose(1, 3)
        for conv_block in (self.conv_block1, self.conv_block2,
                           self.conv_block3, self.conv_block4):
            x = conv_block(x, pool_size=(2, 2), pool_type='avg')
            x = F.dropout(x, p=0.2, training=self.training)
        # Average out frequency -> (batch, time, channels).
        attn_emb = torch.mean(x, dim=3).transpose(1, 2)

        # Spectrogram frame count (len // hop + 1), then the conv pooling.
        n_frames = torch.div(torch.as_tensor(wave_length), self.hop_length,
                             rounding_mode="floor") + 1
        feat_length = torch.div(n_frames, self.downsample_ratio,
                                rounding_mode="floor")

        pooled = (max_with_lens(attn_emb, feat_length)
                  + mean_with_lens(attn_emb, feat_length))
        pooled = F.dropout(pooled, p=0.5, training=self.training)
        hidden = F.relu_(self.fc1(pooled))
        fc_emb = F.dropout(hidden, p=0.5, training=self.training)
        return {
            "attn_emb": attn_emb,
            "fc_emb": fc_emb,
            "attn_emb_len": feat_length,
        }
class Cnn14Encoder(nn.Module):
    """CNN14 audio encoder operating on log-mel spectrograms.

    The waveform becomes a 64-band log-mel spectrogram (32 ms window,
    10 ms hop, 50 Hz lower edge) and is normalised per mel bin by ``bn0``.
    It then passes through six 3x3 double-conv blocks. The first five
    average-pool by 2x2; the sixth uses a 1x1 pool. Averaging over
    frequency gives frame-level embeddings (``attn_emb``). A length-aware
    max+mean pooling followed by ``fc1`` gives a 2048-dimensional
    clip-level embedding (``fc_emb``).

    Args:
        sample_rate: input sample rate; only 32000 and 16000 are supported
            (it selects the mel upper frequency).
        freeze: if True, ``load_pretrained`` freezes every parameter it
            loads from the checkpoint.
    """

    def __init__(self, sample_rate=32000, freeze=False):
        super().__init__()
        win_samples = 32 * sample_rate // 1000
        hop_samples = 10 * sample_rate // 1000
        fmax_by_sr = {32000: 14000, 16000: 8000}
        # Log-mel front end.
        self.melspec_extractor = transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=win_samples,
            win_length=win_samples,
            hop_length=hop_samples,
            f_min=50,
            f_max=fmax_by_sr[sample_rate],
            n_mels=64,
            norm="slaney",
            mel_scale="slaney",
        )
        self.hop_length = hop_samples
        self.db_transform = transforms.AmplitudeToDB()
        self.bn0 = nn.BatchNorm2d(64)
        self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
        self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
        self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
        # Five 2x2 average pools (block 6 pools 1x1) => time shrinks by 2**5.
        self.downsample_ratio = 32
        self.fc1 = nn.Linear(2048, 2048, bias=True)
        self.fc_emb_size = 2048
        self.init_weight()
        self.freeze = freeze

    def init_weight(self):
        """Initialise the input BatchNorm and the embedding projection."""
        init_bn(self.bn0)
        init_layer(self.fc1)

    def load_pretrained(self, pretrained, output_fn):
        """Load encoder weights from one of three checkpoint layouts.

        * ``{"model": ...}`` with keys starting with ``"backbone."`` (COLA):
          only those keys are kept, with ``"backbone."`` removed.
        * ``{"model": ...}`` otherwise (PANNs): used as-is.
        * ``{"state_dict": ...}`` (BLAT): only keys containing
          ``"audio_encoder"`` are kept, with ``"audio_encoder."`` removed.

        Args:
            pretrained: path passed to ``torch.load`` (mapped to CPU).
            output_fn: forwarded to ``merge_load_state_dict``.

        Raises:
            Exception: if the checkpoint matches none of the layouts.
        """
        checkpoint = torch.load(pretrained, map_location="cpu")
        if "model" in checkpoint:
            model_state = checkpoint["model"]
            if any(key.startswith("backbone.") for key in model_state.keys()):
                # COLA
                state_dict = {
                    key.replace("backbone.", ""): value
                    for key, value in model_state.items()
                    if key.startswith("backbone.")
                }
            else:
                # PANNs
                state_dict = model_state
        elif "state_dict" in checkpoint:
            # BLAT
            blat_state = checkpoint["state_dict"]
            state_dict = {
                key.replace('audio_encoder.', ''): blat_state[key]
                for key in blat_state.keys()
                if "audio_encoder" in key
            }
        else:
            raise Exception("Unkown checkpoint format")
        loaded_keys = merge_load_state_dict(state_dict, self, output_fn)
        if self.freeze:
            # Loaded parameters are frozen; everything else stays trainable.
            for name, param in self.named_parameters():
                param.requires_grad = name not in loaded_keys

    def forward(self, input_dict):
        """Encode a batch of waveforms.

        Args:
            input_dict: mapping with ``"wav"`` (waveform batch, presumably
                (batch, samples) -- TODO confirm), ``"wav_len"`` (valid
                sample count per item) and ``"specaug"`` (required key,
                currently unused).

        Returns:
            dict with ``"fc_emb"`` (2048-dim clip embedding),
            ``"attn_emb"`` (batch, frames, 2048) and ``"attn_emb_len"``
            (valid frames per item).
        """
        waveform = input_dict["wav"]
        wave_length = input_dict["wav_len"]
        _ = input_dict["specaug"]  # required key; not used by this encoder

        spec = self.db_transform(self.melspec_extractor(waveform))
        # (batch, mel_bins, time) -> (batch, 1, time, mel_bins)
        x = spec.transpose(1, 2).unsqueeze(1)
        # bn0 normalises each mel bin: move mel bins to the channel axis and back.
        x = self.bn0(x.transpose(1, 3)).transpose(1, 3)
        stages = (
            (self.conv_block1, (2, 2)),
            (self.conv_block2, (2, 2)),
            (self.conv_block3, (2, 2)),
            (self.conv_block4, (2, 2)),
            (self.conv_block5, (2, 2)),
            (self.conv_block6, (1, 1)),
        )
        for conv_block, pool in stages:
            x = conv_block(x, pool_size=pool, pool_type='avg')
            x = F.dropout(x, p=0.2, training=self.training)
        # Average out frequency -> (batch, time, channels).
        attn_emb = torch.mean(x, dim=3).transpose(1, 2)

        # Spectrogram frame count (len // hop + 1), then the conv pooling.
        n_frames = torch.div(torch.as_tensor(wave_length), self.hop_length,
                             rounding_mode="floor") + 1
        feat_length = torch.div(n_frames, self.downsample_ratio,
                                rounding_mode="floor")

        pooled = (max_with_lens(attn_emb, feat_length)
                  + mean_with_lens(attn_emb, feat_length))
        pooled = F.dropout(pooled, p=0.5, training=self.training)
        hidden = F.relu_(self.fc1(pooled))
        fc_emb = F.dropout(hidden, p=0.5, training=self.training)
        return {
            'fc_emb': fc_emb,
            'attn_emb': attn_emb,
            'attn_emb_len': feat_length,
        }
class InvertedResidual(nn.Module):
    """MobileNetV2 inverted-residual block that downsamples by average pooling.

    Layout when ``expand_ratio != 1``:
    1x1 expand conv -> BN -> ReLU6 -> 3x3 depthwise conv -> AvgPool2d(stride)
    -> BN -> ReLU6 -> 1x1 projection conv -> BN.
    With ``expand_ratio == 1`` the expand stage (first three layers) is
    omitted. Every convolution uses stride 1; any downsampling comes from the
    average pool. A residual connection is added when ``stride == 1`` and
    ``inp == oup``.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: 1 or 2; used as the average-pool kernel size.
        expand_ratio: hidden width multiplier,
            ``hidden = round(inp * expand_ratio)``.
    """

    def __init__(self, inp, oup, stride, expand_ratio):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]
        hidden_dim = round(inp * expand_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup

        modules = []
        if expand_ratio != 1:
            # Pointwise expansion inp -> hidden_dim.
            modules.extend([
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
            ])
        modules.extend([
            # Depthwise 3x3 conv (groups == channels).
            nn.Conv2d(hidden_dim, hidden_dim, 3, 1, 1, groups=hidden_dim,
                      bias=False),
            nn.AvgPool2d(stride),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU6(inplace=True),
            # Pointwise projection hidden_dim -> oup, no activation after BN.
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ])
        self.conv = nn.Sequential(*modules)
        # Initialise convs and BNs in construction order; pools/activations
        # have no parameters.
        for module in self.conv:
            if isinstance(module, nn.Conv2d):
                init_layer(module)
            elif isinstance(module, nn.BatchNorm2d):
                init_bn(module)

    def forward(self, x):
        """Apply the block, adding the input back when the shapes allow it."""
        out = self.conv(x)
        if self.use_res_connect:
            out = x + out
        return out
class MobileNetV2(nn.Module):
    """MobileNetV2-style audio encoder operating on log-mel spectrograms.

    Pipeline:
    1. 64-band log-mel spectrogram (32 ms window, 10 ms hop).
    2. ``bn0`` normalises each mel bin.
    3. ``features``: a stem conv, the inverted-residual stages and a 1x1
       conv to ``last_channel`` maps.
    4. Averaging over frequency gives frame-level embeddings (``attn_emb``).
    5. Length-aware max+mean pooling, then ``fc1``, gives a 1024-dimensional
       clip embedding (``fc_emb``).

    Args:
        sample_rate: input sample rate; only 32000 and 16000 are supported
            (it selects the mel upper frequency).
    """

    def __init__(self, sample_rate):
        super().__init__()
        sr_to_fmax = {
            32000: 14000,
            16000: 8000
        }
        # Logmel spectrogram extractor
        self.melspec_extractor = transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=32 * sample_rate // 1000,
            win_length=32 * sample_rate // 1000,
            hop_length=10 * sample_rate // 1000,
            f_min=50,
            f_max=sr_to_fmax[sample_rate],
            n_mels=64,
            norm="slaney",
            mel_scale="slaney"
        )
        self.hop_length = 10 * sample_rate // 1000
        self.db_transform = transforms.AmplitudeToDB()
        self.bn0 = nn.BatchNorm2d(64)
        width_mult = 1.
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        inverted_residual_setting = [
            # t: expand ratio, c: output channels,
            # n: repeats, s: stride of the first repeat
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 2],
            [6, 160, 3, 1],
            [6, 320, 1, 1],
        ]
        # Stem (stride 2) x four stride-2 stages => time shrinks by 2**5.
        self.downsample_ratio = 32

        def conv_bn(inp, oup, stride):
            # 3x3 conv, then AvgPool2d(stride) for downsampling, BN, ReLU6.
            _layers = [
                nn.Conv2d(inp, oup, 3, 1, 1, bias=False),
                nn.AvgPool2d(stride),
                nn.BatchNorm2d(oup),
                nn.ReLU6(inplace=True)
            ]
            _layers = nn.Sequential(*_layers)
            init_layer(_layers[0])
            init_bn(_layers[2])
            return _layers

        def conv_1x1_bn(inp, oup):
            # Pointwise conv, BN, ReLU6.
            _layers = nn.Sequential(
                nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
                nn.ReLU6(inplace=True)
            )
            init_layer(_layers[0])
            init_bn(_layers[1])
            return _layers

        # building first layer
        input_channel = int(input_channel * width_mult)
        self.last_channel = (int(last_channel * width_mult)
                             if width_mult > 1.0 else last_channel)
        features = [conv_bn(1, input_channel, 2)]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride,
                                      expand_ratio=t))
                input_channel = output_channel
        # building last several layers
        features.append(conv_1x1_bn(input_channel, self.last_channel))
        self.features = nn.Sequential(*features)
        # Project from the real width of the last feature layer instead of a
        # hard-coded 1280, so a different width_mult cannot break fc1.
        self.fc1 = nn.Linear(self.last_channel, 1024, bias=True)
        # Expose the clip-embedding size like the other encoders do.
        self.fc_emb_size = 1024
        self.init_weight()

    def init_weight(self):
        """Initialise the input BatchNorm and the embedding projection."""
        init_bn(self.bn0)
        init_layer(self.fc1)

    def forward(self, input_dict):
        """Encode a batch of waveforms.

        Args:
            input_dict: mapping with ``"wav"`` (waveform batch, presumably
                (batch, samples) -- TODO confirm), ``"wav_len"`` (valid
                sample count per item) and ``"specaug"`` (required key,
                currently unused).

        Returns:
            dict with ``"fc_emb"`` (1024-dim clip embedding),
            ``"attn_emb"`` (batch, frames, ``last_channel``) and
            ``"attn_emb_len"`` (valid frames per item).
        """
        waveform = input_dict["wav"]
        wave_length = input_dict["wav_len"]
        specaug = input_dict["specaug"]  # required key; currently unused
        x = self.melspec_extractor(waveform)
        x = self.db_transform(x)  # (batch_size, mel_bins, time_steps)
        x = x.transpose(1, 2)
        x = x.unsqueeze(1)  # (batch_size, 1, time_steps, mel_bins)
        # bn0 normalises each mel bin: move mel bins to the channel axis and back.
        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)
        x = self.features(x)
        x = torch.mean(x, dim=3)  # average out frequency
        attn_emb = x.transpose(1, 2)  # (batch, time, channels)
        wave_length = torch.as_tensor(wave_length)
        # Spectrogram frame count (len // hop + 1), then the network stride.
        feat_length = torch.div(wave_length, self.hop_length,
                                rounding_mode="floor") + 1
        feat_length = torch.div(feat_length, self.downsample_ratio,
                                rounding_mode="floor")
        x_max = max_with_lens(attn_emb, feat_length)
        x_mean = mean_with_lens(attn_emb, feat_length)
        x = x_max + x_mean
        # TODO: the original PANNs code does not have dropout here, why?
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu_(self.fc1(x))
        fc_emb = F.dropout(x, p=0.5, training=self.training)
        output_dict = {
            'fc_emb': fc_emb,
            'attn_emb': attn_emb,
            'attn_emb_len': feat_length
        }
        return output_dict
class MobileNetV3(nn.Module):
    """MobileNetV3 audio encoder on log-mel spectrograms, using EfficientAT features.

    The mel front end feeds ``bn0`` (per-mel-bin normalisation) and then the
    ``features`` of the model returned by
    ``captioning.models.eff_at_encoder.get_model``. Averaging over frequency
    gives frame-level embeddings (``attn_emb``).

    Args:
        sample_rate: input sample rate; only 32000 and 16000 are supported
            (it selects the mel upper frequency).
        model_name: model identifier (e.g. ``"mn10_as"``). It selects the
            width multiplier and is passed to ``get_model``.
        n_mels: number of mel bands.
        win_length: analysis window length in milliseconds. The FFT size is
            32 ms, grown to ``win_length`` when the window is longer (an STFT
            window may not exceed the FFT size).
        pretrained: forwarded to ``get_model``.
        freeze: if True, every parameter is frozen after construction.
        pooling: ``"mean_max_fc"`` sums length-aware max and mean pooling,
            then projects to 512 dims with ``fc1``. ``"mean"`` takes the
            length-aware mean of ``attn_emb``.

    Raises:
        ValueError: if ``pooling`` is not one of the values above.
    """

    def __init__(self,
                 sample_rate,
                 model_name,
                 n_mels=64,
                 win_length=32,
                 pretrained=True,
                 freeze=False,
                 pooling="mean_max_fc"):
        # Fail fast. An unsupported mode used to build the whole model and
        # only crash in forward() with an unbound ``fc_emb``.
        if pooling not in ("mean_max_fc", "mean"):
            raise ValueError(
                f"Unsupported pooling {pooling!r}; "
                "expected 'mean_max_fc' or 'mean'")
        from captioning.models.eff_at_encoder import get_model, NAME_TO_WIDTH
        super().__init__()
        sr_to_fmax = {
            32000: 14000,
            16000: 8000
        }
        self.n_mels = n_mels
        # Keep the 32 ms FFT by default; enlarge it only when a longer window
        # is requested, otherwise the STFT rejects win_length > n_fft.
        n_fft_ms = max(32, win_length)
        # Logmel spectrogram extractor
        self.melspec_extractor = transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft_ms * sample_rate // 1000,
            win_length=win_length * sample_rate // 1000,
            hop_length=10 * sample_rate // 1000,
            f_min=50,
            f_max=sr_to_fmax[sample_rate],
            n_mels=n_mels,
            norm="slaney",
            mel_scale="slaney"
        )
        self.hop_length = 10 * sample_rate // 1000
        self.db_transform = transforms.AmplitudeToDB()
        self.bn0 = nn.BatchNorm2d(n_mels)
        width_mult = NAME_TO_WIDTH(model_name)
        self.features = get_model(model_name=model_name,
                                  pretrained=pretrained,
                                  width_mult=width_mult).features
        # NOTE(review): assumes the backbone reduces time by 32 -- confirm
        # against the EfficientAT model definition.
        self.downsample_ratio = 32
        if pooling == "mean_max_fc":
            self.fc_emb_size = 512
            self.fc1 = nn.Linear(self.features[-1].out_channels, 512,
                                 bias=True)
        else:  # "mean"
            self.fc_emb_size = self.features[-1].out_channels
        self.init_weight()
        if freeze:
            for param in self.parameters():
                param.requires_grad = False
        self.pooling = pooling

    def init_weight(self):
        """Initialise ``bn0`` and, when present, the ``fc1`` projection."""
        init_bn(self.bn0)
        if hasattr(self, "fc1"):
            init_layer(self.fc1)

    def forward(self, input_dict):
        """Encode a batch of waveforms.

        Args:
            input_dict: mapping with ``"wav"`` (waveform batch, e.g.
                (batch, samples)), ``"wav_len"`` (valid sample count per
                item) and ``"specaug"`` (required key, currently unused).

        Returns:
            dict with ``"fc_emb"``, ``"attn_emb"`` (batch, frames, channels)
            and ``"attn_emb_len"`` (valid frames per item).
        """
        waveform = input_dict["wav"]
        wave_length = input_dict["wav_len"]
        specaug = input_dict["specaug"]  # required key; currently unused
        x = self.melspec_extractor(waveform)
        x = self.db_transform(x)  # (batch_size, mel_bins, time_steps)
        x = x.transpose(1, 2)
        x = x.unsqueeze(1)  # (batch_size, 1, time_steps, mel_bins)
        # bn0 normalises each mel bin: move mel bins to the channel axis and back.
        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)
        x = self.features(x)
        x = torch.mean(x, dim=3)  # average out frequency
        attn_emb = x.transpose(1, 2)  # (batch, time, channels)
        wave_length = torch.as_tensor(wave_length)
        # Spectrogram frame count (len // hop + 1), then the network stride.
        feat_length = torch.div(wave_length, self.hop_length,
                                rounding_mode="floor") + 1
        feat_length = torch.div(feat_length, self.downsample_ratio,
                                rounding_mode="floor")
        if self.pooling == "mean_max_fc":
            x_max = max_with_lens(attn_emb, feat_length)
            x_mean = mean_with_lens(attn_emb, feat_length)
            x = x_max + x_mean
            x = F.dropout(x, p=0.5, training=self.training)
            x = F.relu_(self.fc1(x))
            fc_emb = F.dropout(x, p=0.5, training=self.training)
        elif self.pooling == "mean":
            fc_emb = mean_with_lens(attn_emb, feat_length)
        output_dict = {
            'fc_emb': fc_emb,
            'attn_emb': attn_emb,
            'attn_emb_len': feat_length
        }
        return output_dict
class EfficientNetB2(nn.Module):
    """EfficientNet-B2 audio encoder on 16 kHz log-mel spectrograms.

    The log-mel spectrogram (top_db=120) is passed straight to the backbone
    from ``models.eff_latent_encoder``. Its output is used as ``attn_emb``,
    and the length-aware mean of it as ``fc_emb``.

    Args:
        n_mels: number of mel bands.
        win_length: analysis window (and FFT) length in milliseconds.
        hop_length: STFT hop in milliseconds.
        f_min: lowest mel frequency in Hz.
        pretrained: forwarded to the backbone factory.
        prune_ratio: if > 0, build the backbone with ``get_pruned_model``
            using the ``prune_*`` options; otherwise use ``get_model``.
        prune_se: forwarded to ``get_pruned_model``.
        prune_start_layer: forwarded to ``get_pruned_model``.
        prune_method: forwarded to ``get_pruned_model``.
        freeze: if True, every parameter is frozen after construction.
    """

    def __init__(self,
                 n_mels: int = 64,
                 win_length: int = 32,
                 hop_length: int = 10,
                 f_min: int = 0,
                 pretrained: bool = False,
                 prune_ratio: float = 0.0,
                 prune_se: bool = True,
                 prune_start_layer: int = 0,
                 prune_method: str = "operator_norm",
                 freeze: bool = False,):
        from models.eff_latent_encoder import get_model, get_pruned_model
        super().__init__()
        sample_rate = 16000
        hop_samples = hop_length * sample_rate // 1000
        self.melspec_extractor = transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=win_length * sample_rate // 1000,
            win_length=win_length * sample_rate // 1000,
            hop_length=hop_samples,
            f_min=f_min,
            n_mels=n_mels,
        )
        # Must equal the STFT hop. A hard-coded 10 ms here made
        # ``attn_emb_len`` wrong whenever hop_length != 10.
        self.hop_length = hop_samples
        self.db_transform = transforms.AmplitudeToDB(top_db=120)
        if prune_ratio > 0:
            self.backbone = get_pruned_model(pretrained=pretrained,
                                             prune_ratio=prune_ratio,
                                             prune_start_layer=prune_start_layer,
                                             prune_se=prune_se,
                                             prune_method=prune_method)
        else:
            self.backbone = get_model(pretrained=pretrained)
        self.fc_emb_size = self.backbone.eff_net._conv_head.out_channels
        # NOTE(review): assumes the backbone reduces time by 32 -- confirm
        # against the backbone definition.
        self.downsample_ratio = 32
        if freeze:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, input_dict):
        """Encode a batch of waveforms.

        Args:
            input_dict: mapping with ``"wav"`` (16 kHz waveform batch),
                ``"wav_len"`` (valid sample count per item) and
                ``"specaug"`` (required key, currently unused).

        Returns:
            dict with ``"fc_emb"`` (length-aware mean of ``attn_emb``),
            ``"attn_emb"`` (the backbone output as-is; presumably
            (batch, frames, channels) -- TODO confirm) and
            ``"attn_emb_len"`` (valid frames per item).
        """
        waveform = input_dict["wav"]
        wave_length = input_dict["wav_len"]
        specaug = input_dict["specaug"]  # required key; currently unused
        x = self.melspec_extractor(waveform)
        x = self.db_transform(x)  # (batch_size, mel_bins, time_steps)
        x = self.backbone(x)
        attn_emb = x
        wave_length = torch.as_tensor(wave_length)
        # Spectrogram frame count (len // hop + 1), then the network stride.
        feat_length = torch.div(wave_length, self.hop_length,
                                rounding_mode="floor") + 1
        feat_length = torch.div(feat_length, self.downsample_ratio,
                                rounding_mode="floor")
        fc_emb = mean_with_lens(attn_emb, feat_length)
        output_dict = {
            'fc_emb': fc_emb,
            'attn_emb': attn_emb,
            'attn_emb_len': feat_length
        }
        return output_dict
if __name__ == "__main__":
    # Smoke test: run a random 4-clip batch (10 s at 32 kHz) through
    # MobileNetV3 and print the output shapes / lengths.
    demo_encoder = MobileNetV3(32000, "mn10_as")
    print(demo_encoder)
    demo_batch = dict(
        wav=torch.randn(4, 320000),
        wav_len=torch.tensor([320000, 280000, 160000, 300000]),
        specaug=True,
    )
    demo_out = demo_encoder(demo_batch)
    print("attn embed: ", demo_out["attn_emb"].shape)
    print("fc embed: ", demo_out["fc_emb"].shape)
    print("attn embed length: ", demo_out["attn_emb_len"])