# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchaudio import transforms

from utils.model_util import mean_with_lens, max_with_lens
from utils.train_util import merge_load_state_dict


def init_layer(layer):
    """Initialize a Linear or Convolutional layer."""
    nn.init.xavier_uniform_(layer.weight)
    if hasattr(layer, 'bias'):
        if layer.bias is not None:
            layer.bias.data.fill_(0.)


def init_bn(bn):
    """Initialize a Batchnorm layer."""
    bn.bias.data.fill_(0.)
    bn.weight.data.fill_(1.)


class ConvBlock(nn.Module):
    """Two 3x3 conv + BN + ReLU layers followed by pooling (CNN10/CNN14 unit)."""

    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=(3, 3), stride=(1, 1),
                               padding=(1, 1), bias=False)
        self.conv2 = nn.Conv2d(in_channels=out_channels,
                               out_channels=out_channels,
                               kernel_size=(3, 3), stride=(1, 1),
                               padding=(1, 1), bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.init_weight()

    def init_weight(self):
        init_layer(self.conv1)
        init_layer(self.conv2)
        init_bn(self.bn1)
        init_bn(self.bn2)

    def forward(self, input, pool_size=(2, 2), pool_type='avg'):
        x = input
        x = F.relu_(self.bn1(self.conv1(x)))
        x = F.relu_(self.bn2(self.conv2(x)))
        if pool_type == 'max':
            x = F.max_pool2d(x, kernel_size=pool_size)
        elif pool_type == 'avg':
            x = F.avg_pool2d(x, kernel_size=pool_size)
        elif pool_type == 'avg+max':
            x1 = F.avg_pool2d(x, kernel_size=pool_size)
            x2 = F.max_pool2d(x, kernel_size=pool_size)
            x = x1 + x2
        else:
            raise ValueError(f"Incorrect pool_type: {pool_type}")
        return x


class ConvBlock5x5(nn.Module):
    """Single 5x5 conv + BN + ReLU layer followed by pooling (CNN6 unit)."""

    def __init__(self, in_channels, out_channels):
        super(ConvBlock5x5, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=(5, 5), stride=(1, 1),
                               padding=(2, 2), bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)

        self.init_weight()

    def init_weight(self):
        init_layer(self.conv1)
        init_bn(self.bn1)

    def forward(self, input, pool_size=(2, 2), pool_type='avg'):
        x = input
        x = F.relu_(self.bn1(self.conv1(x)))
        if pool_type == 'max':
            x = F.max_pool2d(x, kernel_size=pool_size)
        elif pool_type == 'avg':
            x = F.avg_pool2d(x, kernel_size=pool_size)
        elif pool_type == 'avg+max':
            x1 = F.avg_pool2d(x, kernel_size=pool_size)
            x2 = F.max_pool2d(x, kernel_size=pool_size)
            x = x1 + x2
        else:
            raise ValueError(f"Incorrect pool_type: {pool_type}")
        return x


class Cnn6Encoder(nn.Module):

    def __init__(self, sample_rate=32000, freeze=False):
        super().__init__()

        sr_to_fmax = {
            32000: 14000,
            16000: 8000
        }
        # Log mel spectrogram extractor: 32 ms window, 10 ms hop, 64 mel bins
        self.melspec_extractor = transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=32 * sample_rate // 1000,
            win_length=32 * sample_rate // 1000,
            hop_length=10 * sample_rate // 1000,
            f_min=50,
            f_max=sr_to_fmax[sample_rate],
            n_mels=64,
            norm="slaney",
            mel_scale="slaney"
        )
        self.hop_length = 10 * sample_rate // 1000
        self.db_transform = transforms.AmplitudeToDB()

        self.bn0 = nn.BatchNorm2d(64)
        self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512)
        self.downsample_ratio = 16  # 4 pooling stages, each halving the time axis

        self.fc1 = nn.Linear(512, 512, bias=True)
        self.fc_emb_size = 512
        self.init_weight()
        self.freeze = freeze

    def init_weight(self):
        init_bn(self.bn0)
        init_layer(self.fc1)

    def load_pretrained(self, pretrained, output_fn):
        checkpoint = torch.load(pretrained, map_location="cpu")
        if "model" in checkpoint:
            state_dict = checkpoint["model"]
        else:
            raise ValueError("Unknown checkpoint format")
        loaded_keys = merge_load_state_dict(state_dict, self, output_fn)
        if self.freeze:
            # Freeze only the parameters that were loaded from the checkpoint
            for name, param in self.named_parameters():
                if name in loaded_keys:
                    param.requires_grad = False
                else:
                    param.requires_grad = True

    def forward(self, input_dict):
        waveform = input_dict["wav"]
        wave_length = input_dict["wav_len"]
        specaug = input_dict["specaug"]  # read for interface compatibility; unused here
        x = self.melspec_extractor(waveform)
        x = self.db_transform(x)    # (batch_size, mel_bins, time_steps)
        x = x.transpose(1, 2)
        x = x.unsqueeze(1)    # (batch_size, 1, time_steps, mel_bins)

        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)

        x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)

        x = torch.mean(x, dim=3)  # pool away the frequency axis
        attn_emb = x.transpose(1, 2)  # (batch_size, time_steps, emb_dim)

        wave_length = torch.as_tensor(wave_length)
        feat_length = torch.div(wave_length, self.hop_length,
                                rounding_mode="floor") + 1
        feat_length = torch.div(feat_length, self.downsample_ratio,
                                rounding_mode="floor")
        x_max = max_with_lens(attn_emb, feat_length)
        x_mean = mean_with_lens(attn_emb, feat_length)
        x = x_max + x_mean
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu_(self.fc1(x))
        fc_emb = F.dropout(x, p=0.5, training=self.training)

        return {
            "attn_emb": attn_emb,
            "fc_emb": fc_emb,
            "attn_emb_len": feat_length
        }


class Cnn10Encoder(nn.Module):

    def __init__(self, sample_rate=32000, freeze=False):
        super().__init__()

        sr_to_fmax = {
            32000: 14000,
            16000: 8000
        }
        # Log mel spectrogram extractor: 32 ms window, 10 ms hop, 64 mel bins
        self.melspec_extractor = transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=32 * sample_rate // 1000,
            win_length=32 * sample_rate // 1000,
            hop_length=10 * sample_rate // 1000,
            f_min=50,
            f_max=sr_to_fmax[sample_rate],
            n_mels=64,
            norm="slaney",
            mel_scale="slaney"
        )
        self.hop_length = 10 * sample_rate // 1000
        self.db_transform = transforms.AmplitudeToDB()

        self.bn0 = nn.BatchNorm2d(64)
        self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
        self.downsample_ratio = 16  # 4 pooling stages, each halving the time axis

        self.fc1 = nn.Linear(512, 512, bias=True)
        self.fc_emb_size = 512
        self.init_weight()
        self.freeze = freeze

    def init_weight(self):
        init_bn(self.bn0)
        init_layer(self.fc1)

    def load_pretrained(self, pretrained, output_fn):
        checkpoint = torch.load(pretrained, map_location="cpu")
        if "model" in checkpoint:
            state_dict = checkpoint["model"]
        else:
            raise ValueError("Unknown checkpoint format")
        loaded_keys = merge_load_state_dict(state_dict, self, output_fn)
        if self.freeze:
            # Freeze only the parameters that were loaded from the checkpoint
            for name, param in self.named_parameters():
                if name in loaded_keys:
                    param.requires_grad = False
                else:
                    param.requires_grad = True

    def forward(self, input_dict):
        waveform = input_dict["wav"]
        wave_length = input_dict["wav_len"]
        specaug = input_dict["specaug"]  # read for interface compatibility; unused here
        x = self.melspec_extractor(waveform)
        x = self.db_transform(x)    # (batch_size, mel_bins, time_steps)
        x = x.transpose(1, 2)
        x = x.unsqueeze(1)    # (batch_size, 1, time_steps, mel_bins)

        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)

        x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)

        x = torch.mean(x, dim=3)  # pool away the frequency axis
        attn_emb = x.transpose(1, 2)  # (batch_size, time_steps, emb_dim)

        wave_length = torch.as_tensor(wave_length)
        feat_length = torch.div(wave_length, self.hop_length,
                                rounding_mode="floor") + 1
        feat_length = torch.div(feat_length, self.downsample_ratio,
                                rounding_mode="floor")
        x_max = max_with_lens(attn_emb, feat_length)
        x_mean = mean_with_lens(attn_emb, feat_length)
        x = x_max + x_mean
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu_(self.fc1(x))
        fc_emb = F.dropout(x, p=0.5, training=self.training)

        return {
            "attn_emb": attn_emb,
            "fc_emb": fc_emb,
            "attn_emb_len": feat_length
        }


class Cnn14Encoder(nn.Module):

    def __init__(self, sample_rate=32000, freeze=False):
        super().__init__()

        sr_to_fmax = {
            32000: 14000,
            16000: 8000
        }
        # Log mel spectrogram extractor: 32 ms window, 10 ms hop, 64 mel bins
        self.melspec_extractor = transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=32 * sample_rate // 1000,
            win_length=32 * sample_rate // 1000,
            hop_length=10 * sample_rate // 1000,
            f_min=50,
            f_max=sr_to_fmax[sample_rate],
            n_mels=64,
            norm="slaney",
            mel_scale="slaney"
        )
        self.hop_length = 10 * sample_rate // 1000
        self.db_transform = transforms.AmplitudeToDB()

        self.bn0 = nn.BatchNorm2d(64)
        self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
        self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
        self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
        self.downsample_ratio = 32  # 5 time-pooling stages (block 6 pools with (1, 1))

        self.fc1 = nn.Linear(2048, 2048, bias=True)
        self.fc_emb_size = 2048
        self.init_weight()
        self.freeze = freeze

    def init_weight(self):
        init_bn(self.bn0)
        init_layer(self.fc1)

    def load_pretrained(self, pretrained, output_fn):
        checkpoint = torch.load(pretrained, map_location="cpu")
        if "model" in checkpoint:
            state_keys = checkpoint["model"].keys()
            backbone = False
            for key in state_keys:
                if key.startswith("backbone."):
                    backbone = True
                    break
            if backbone:  # COLA checkpoint: strip the "backbone." prefix
                state_dict = {}
                for key, value in checkpoint["model"].items():
                    if key.startswith("backbone."):
                        model_key = key.replace("backbone.", "")
                        state_dict[model_key] = value
            else:  # PANNs checkpoint
                state_dict = checkpoint["model"]
        elif "state_dict" in checkpoint:  # BLAT checkpoint: keep audio encoder keys
            state_dict = checkpoint["state_dict"]
            state_dict_keys = list(filter(
                lambda x: "audio_encoder" in x, state_dict.keys()))
            state_dict = {
                key.replace('audio_encoder.', ''): state_dict[key]
                for key in state_dict_keys
            }
        else:
            raise ValueError("Unknown checkpoint format")
        loaded_keys = merge_load_state_dict(state_dict, self, output_fn)
        if self.freeze:
            # Freeze only the parameters that were loaded from the checkpoint
            for name, param in self.named_parameters():
                if name in loaded_keys:
                    param.requires_grad = False
                else:
                    param.requires_grad = True

    def forward(self, input_dict):
        waveform = input_dict["wav"]
        wave_length = input_dict["wav_len"]
        specaug = input_dict["specaug"]  # read for interface compatibility; unused here
        x = self.melspec_extractor(waveform)
        x = self.db_transform(x)    # (batch_size, mel_bins, time_steps)
        x = x.transpose(1, 2)
        x = x.unsqueeze(1)    # (batch_size, 1, time_steps, mel_bins)

        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)

        x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)

        x = torch.mean(x, dim=3)  # pool away the frequency axis
        attn_emb = x.transpose(1, 2)  # (batch_size, time_steps, emb_dim)

        wave_length = torch.as_tensor(wave_length)
        feat_length = torch.div(wave_length, self.hop_length,
                                rounding_mode="floor") + 1
        feat_length = torch.div(feat_length, self.downsample_ratio,
                                rounding_mode="floor")
        x_max = max_with_lens(attn_emb, feat_length)
        x_mean = mean_with_lens(attn_emb, feat_length)
        x = x_max + x_mean
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu_(self.fc1(x))
        fc_emb = F.dropout(x, p=0.5, training=self.training)

        output_dict = {
            'fc_emb': fc_emb,
            'attn_emb': attn_emb,
            'attn_emb_len': feat_length
        }
        return output_dict


class InvertedResidual(nn.Module):
    """MobileNetV2 inverted residual block, using AvgPool2d for striding."""

    def __init__(self, inp, oup, stride, expand_ratio):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup

        if expand_ratio == 1:
            _layers = [
                nn.Conv2d(hidden_dim, hidden_dim, 3, 1, 1,
                          groups=hidden_dim, bias=False),
                nn.AvgPool2d(stride),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup)
            ]
            _layers = nn.Sequential(*_layers)
            init_layer(_layers[0])
            init_bn(_layers[2])
            init_layer(_layers[4])
            init_bn(_layers[5])
            self.conv = _layers
        else:
            _layers = [
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                nn.Conv2d(hidden_dim, hidden_dim, 3, 1, 1,
                          groups=hidden_dim, bias=False),
                nn.AvgPool2d(stride),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup)
            ]
            _layers = nn.Sequential(*_layers)
            init_layer(_layers[0])
            init_bn(_layers[1])
            init_layer(_layers[3])
            init_bn(_layers[5])
            init_layer(_layers[7])
            init_bn(_layers[8])
            self.conv = _layers

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)
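# Illustrative sketch only (a new helper, not referenced by any encoder in
# this file): the length bookkeeping that every forward() performs when
# mapping raw waveform lengths to attention-embedding lengths. The defaults
# assume 32 kHz audio with a 10 ms hop (320 samples); use
# downsample_ratio=16 for CNN6/CNN10 and 32 for CNN14/MobileNet.
def _example_feat_length(wav_len: int, hop_length: int = 320,
                         downsample_ratio: int = 32) -> int:
    # e.g. 320000 samples -> 320000 // 320 + 1 = 1001 frames -> 1001 // 32 = 31 steps
    n_frames = wav_len // hop_length + 1
    return n_frames // downsample_ratio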
class MobileNetV2(nn.Module):

    def __init__(self, sample_rate):
        super().__init__()

        sr_to_fmax = {
            32000: 14000,
            16000: 8000
        }
        # Log mel spectrogram extractor: 32 ms window, 10 ms hop, 64 mel bins
        self.melspec_extractor = transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=32 * sample_rate // 1000,
            win_length=32 * sample_rate // 1000,
            hop_length=10 * sample_rate // 1000,
            f_min=50,
            f_max=sr_to_fmax[sample_rate],
            n_mels=64,
            norm="slaney",
            mel_scale="slaney"
        )
        self.hop_length = 10 * sample_rate // 1000
        self.db_transform = transforms.AmplitudeToDB()

        self.bn0 = nn.BatchNorm2d(64)

        width_mult = 1.
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        inverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 2],
            [6, 160, 3, 1],
            [6, 320, 1, 1],
        ]

        self.downsample_ratio = 32

        def conv_bn(inp, oup, stride):
            _layers = [
                nn.Conv2d(inp, oup, 3, 1, 1, bias=False),
                nn.AvgPool2d(stride),
                nn.BatchNorm2d(oup),
                nn.ReLU6(inplace=True)
            ]
            _layers = nn.Sequential(*_layers)
            init_layer(_layers[0])
            init_bn(_layers[2])
            return _layers

        def conv_1x1_bn(inp, oup):
            _layers = nn.Sequential(
                nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
                nn.ReLU6(inplace=True)
            )
            init_layer(_layers[0])
            init_bn(_layers[1])
            return _layers

        # building first layer
        input_channel = int(input_channel * width_mult)
        self.last_channel = int(last_channel * width_mult) \
            if width_mult > 1.0 else last_channel
        self.features = [conv_bn(1, input_channel, 2)]

        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                if i == 0:
                    self.features.append(
                        block(input_channel, output_channel, s, expand_ratio=t))
                else:
                    self.features.append(
                        block(input_channel, output_channel, 1, expand_ratio=t))
                input_channel = output_channel

        # building last several layers
        self.features.append(conv_1x1_bn(input_channel, self.last_channel))

        # make it nn.Sequential
        self.features = nn.Sequential(*self.features)

        self.fc1 = nn.Linear(1280, 1024, bias=True)
        self.init_weight()

    def init_weight(self):
        init_bn(self.bn0)
        init_layer(self.fc1)

    def forward(self, input_dict):
        waveform = input_dict["wav"]
        wave_length = input_dict["wav_len"]
        specaug = input_dict["specaug"]  # read for interface compatibility; unused here
        x = self.melspec_extractor(waveform)
        x = self.db_transform(x)    # (batch_size, mel_bins, time_steps)
        x = x.transpose(1, 2)
        x = x.unsqueeze(1)    # (batch_size, 1, time_steps, mel_bins)

        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)

        x = self.features(x)

        x = torch.mean(x, dim=3)  # pool away the frequency axis
        attn_emb = x.transpose(1, 2)  # (batch_size, time_steps, emb_dim)

        wave_length = torch.as_tensor(wave_length)
        feat_length = torch.div(wave_length, self.hop_length,
                                rounding_mode="floor") + 1
        feat_length = torch.div(feat_length, self.downsample_ratio,
                                rounding_mode="floor")
        x_max = max_with_lens(attn_emb, feat_length)
        x_mean = mean_with_lens(attn_emb, feat_length)
        x = x_max + x_mean
        # TODO: the original PANNs code does not have dropout here, why?
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu_(self.fc1(x))
        fc_emb = F.dropout(x, p=0.5, training=self.training)

        output_dict = {
            'fc_emb': fc_emb,
            'attn_emb': attn_emb,
            'attn_emb_len': feat_length
        }
        return output_dict


class MobileNetV3(nn.Module):

    def __init__(self, sample_rate, model_name, n_mels=64, win_length=32,
                 pretrained=True, freeze=False, pooling="mean_max_fc"):
        from captioning.models.eff_at_encoder import get_model, NAME_TO_WIDTH
        super().__init__()

        sr_to_fmax = {
            32000: 14000,
            16000: 8000
        }
        self.n_mels = n_mels
        # Log mel spectrogram extractor
        self.melspec_extractor = transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=32 * sample_rate // 1000,
            win_length=win_length * sample_rate // 1000,
            hop_length=10 * sample_rate // 1000,
            f_min=50,
            f_max=sr_to_fmax[sample_rate],
            n_mels=n_mels,
            norm="slaney",
            mel_scale="slaney"
        )
        self.hop_length = 10 * sample_rate // 1000
        self.db_transform = transforms.AmplitudeToDB()

        self.bn0 = nn.BatchNorm2d(n_mels)

        width_mult = NAME_TO_WIDTH(model_name)
        self.features = get_model(model_name=model_name,
                                  pretrained=pretrained,
                                  width_mult=width_mult).features
        self.downsample_ratio = 32

        if pooling == "mean_max_fc":
            self.fc_emb_size = 512
            self.fc1 = nn.Linear(self.features[-1].out_channels, 512, bias=True)
        elif pooling == "mean":
            self.fc_emb_size = self.features[-1].out_channels
        self.init_weight()

        if freeze:
            for param in self.parameters():
                param.requires_grad = False

        self.pooling = pooling

    def init_weight(self):
        init_bn(self.bn0)
        if hasattr(self, "fc1"):
            init_layer(self.fc1)

    def forward(self, input_dict):
        waveform = input_dict["wav"]
        wave_length = input_dict["wav_len"]
        specaug = input_dict["specaug"]  # read for interface compatibility; unused here
        x = self.melspec_extractor(waveform)
        x = self.db_transform(x)    # (batch_size, mel_bins, time_steps)
        x = x.transpose(1, 2)
        x = x.unsqueeze(1)    # (batch_size, 1, time_steps, mel_bins)

        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)

        x = self.features(x)

        x = torch.mean(x, dim=3)  # pool away the frequency axis
        attn_emb = x.transpose(1, 2)  # (batch_size, time_steps, emb_dim)

        wave_length = torch.as_tensor(wave_length)
        feat_length = torch.div(wave_length, self.hop_length,
                                rounding_mode="floor") + 1
        feat_length = torch.div(feat_length, self.downsample_ratio,
                                rounding_mode="floor")

        if self.pooling == "mean_max_fc":
            x_max = max_with_lens(attn_emb, feat_length)
            x_mean = mean_with_lens(attn_emb, feat_length)
            x = x_max + x_mean
            x = F.dropout(x, p=0.5, training=self.training)
            x = F.relu_(self.fc1(x))
            fc_emb = F.dropout(x, p=0.5, training=self.training)
        elif self.pooling == "mean":
            fc_emb = mean_with_lens(attn_emb, feat_length)

        output_dict = {
            'fc_emb': fc_emb,
            'attn_emb': attn_emb,
            'attn_emb_len': feat_length
        }
        return output_dict


class EfficientNetB2(nn.Module):

    def __init__(self,
                 n_mels: int = 64,
                 win_length: int = 32,
                 hop_length: int = 10,
                 f_min: int = 0,
                 pretrained: bool = False,
                 prune_ratio: float = 0.0,
                 prune_se: bool = True,
                 prune_start_layer: int = 0,
                 prune_method: str = "operator_norm",
                 freeze: bool = False):
        from models.eff_latent_encoder import get_model, get_pruned_model
        super().__init__()
        sample_rate = 16000
        self.melspec_extractor = transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=win_length * sample_rate // 1000,
            win_length=win_length * sample_rate // 1000,
            hop_length=hop_length * sample_rate // 1000,
            f_min=f_min,
            n_mels=n_mels,
        )
        # Use the same hop as the mel extractor so length bookkeeping stays consistent
        self.hop_length = hop_length * sample_rate // 1000
        self.db_transform = transforms.AmplitudeToDB(top_db=120)

        if prune_ratio > 0:
            self.backbone = get_pruned_model(pretrained=pretrained,
                                             prune_ratio=prune_ratio,
                                             prune_start_layer=prune_start_layer,
                                             prune_se=prune_se,
                                             prune_method=prune_method)
        else:
            self.backbone = get_model(pretrained=pretrained)
        self.fc_emb_size = self.backbone.eff_net._conv_head.out_channels
        self.downsample_ratio = 32

        if freeze:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, input_dict):
        waveform = input_dict["wav"]
        wave_length = input_dict["wav_len"]
        specaug = input_dict["specaug"]  # read for interface compatibility; unused here
        x = self.melspec_extractor(waveform)
        x = self.db_transform(x)    # (batch_size, mel_bins, time_steps)

        x = self.backbone(x)
        attn_emb = x

        wave_length = torch.as_tensor(wave_length)
        feat_length = torch.div(wave_length, self.hop_length,
                                rounding_mode="floor") + 1
        feat_length = torch.div(feat_length, self.downsample_ratio,
                                rounding_mode="floor")
        fc_emb = mean_with_lens(attn_emb, feat_length)

        output_dict = {
            'fc_emb': fc_emb,
            'attn_emb': attn_emb,
            'attn_emb_len': feat_length
        }
        return output_dict


if __name__ == "__main__":
    encoder = MobileNetV3(32000, "mn10_as")
    print(encoder)
    input_dict = {
        "wav": torch.randn(4, 320000),
        "wav_len": torch.tensor([320000, 280000, 160000, 300000]),
        "specaug": True
    }
    output_dict = encoder(input_dict)
    print("attn embed: ", output_dict["attn_emb"].shape)
    print("fc embed: ", output_dict["fc_emb"].shape)
    print("attn embed length: ", output_dict["attn_emb_len"])
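
    # Additional illustrative smoke test (assumes randomly initialized weights,
    # i.e. no call to load_pretrained): run the same batch through Cnn14Encoder
    # and check that the embedding shapes and lengths come out as expected.
    encoder = Cnn14Encoder(sample_rate=32000)
    output_dict = encoder(input_dict)
    print("CNN14 attn embed: ", output_dict["attn_emb"].shape)
    print("CNN14 fc embed: ", output_dict["fc_emb"].shape)
    print("CNN14 attn embed length: ", output_dict["attn_emb_len"])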