# -*- coding: utf-8 -*- import math import copy import torch import torch.nn as nn import torch.nn.functional as F from torchaudio import transforms from torchlibrosa.augmentation import SpecAugmentation from .utils import mean_with_lens, max_with_lens, \ init, pack_wrapper, generate_length_mask, PositionalEncoding def init_layer(layer): """Initialize a Linear or Convolutional layer. """ nn.init.xavier_uniform_(layer.weight) if hasattr(layer, 'bias'): if layer.bias is not None: layer.bias.data.fill_(0.) def init_bn(bn): """Initialize a Batchnorm layer. """ bn.bias.data.fill_(0.) bn.weight.data.fill_(1.) class BaseEncoder(nn.Module): """ Encode the given audio into embedding Base encoder class, cannot be called directly All encoders should inherit from this class """ def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim): super(BaseEncoder, self).__init__() self.spec_dim = spec_dim self.fc_feat_dim = fc_feat_dim self.attn_feat_dim = attn_feat_dim def forward(self, x): ######################### # an encoder first encodes audio feature into embedding, obtaining # `encoded`: { # fc_embs: [N, fc_emb_dim], # attn_embs: [N, attn_max_len, attn_emb_dim], # attn_emb_lens: [N,] # } ######################### raise NotImplementedError class Block2D(nn.Module): def __init__(self, cin, cout, kernel_size=3, padding=1): super().__init__() self.block = nn.Sequential( nn.BatchNorm2d(cin), nn.Conv2d(cin, cout, kernel_size=kernel_size, padding=padding, bias=False), nn.LeakyReLU(inplace=True, negative_slope=0.1)) def forward(self, x): return self.block(x) class LinearSoftPool(nn.Module): """LinearSoftPool Linear softmax, takes logits and returns a probability, near to the actual maximum value. Taken from the paper: A Comparison of Five Multiple Instance Learning Pooling Functions for Sound Event Detection with Weak Labeling https://arxiv.org/abs/1810.09050 """ def __init__(self, pooldim=1): super().__init__() self.pooldim = pooldim def forward(self, logits, time_decision): return (time_decision**2).sum(self.pooldim) / time_decision.sum( self.pooldim) class MeanPool(nn.Module): def __init__(self, pooldim=1): super().__init__() self.pooldim = pooldim def forward(self, logits, decision): return torch.mean(decision, dim=self.pooldim) class AttentionPool(nn.Module): """docstring for AttentionPool""" def __init__(self, inputdim, outputdim=10, pooldim=1, **kwargs): super().__init__() self.inputdim = inputdim self.outputdim = outputdim self.pooldim = pooldim self.transform = nn.Linear(inputdim, outputdim) self.activ = nn.Softmax(dim=self.pooldim) self.eps = 1e-7 def forward(self, logits, decision): # Input is (B, T, D) # B, T, D w = self.activ(torch.clamp(self.transform(logits), -15, 15)) detect = (decision * w).sum( self.pooldim) / (w.sum(self.pooldim) + self.eps) # B, T, D return detect class MMPool(nn.Module): def __init__(self, dims): super().__init__() self.avgpool = nn.AvgPool2d(dims) self.maxpool = nn.MaxPool2d(dims) def forward(self, x): return self.avgpool(x) + self.maxpool(x) def parse_poolingfunction(poolingfunction_name='mean', **kwargs): """parse_poolingfunction A heler function to parse any temporal pooling Pooling is done on dimension 1 :param poolingfunction_name: :param **kwargs: """ poolingfunction_name = poolingfunction_name.lower() if poolingfunction_name == 'mean': return MeanPool(pooldim=1) elif poolingfunction_name == 'linear': return LinearSoftPool(pooldim=1) elif poolingfunction_name == 'attention': return AttentionPool(inputdim=kwargs['inputdim'], outputdim=kwargs['outputdim']) def embedding_pooling(x, lens, pooling="mean"): if pooling == "max": fc_embs = max_with_lens(x, lens) elif pooling == "mean": fc_embs = mean_with_lens(x, lens) elif pooling == "mean+max": x_mean = mean_with_lens(x, lens) x_max = max_with_lens(x, lens) fc_embs = x_mean + x_max elif pooling == "last": indices = (lens - 1).reshape(-1, 1, 1).repeat(1, 1, x.size(-1)) # indices: [N, 1, hidden] fc_embs = torch.gather(x, 1, indices).squeeze(1) else: raise Exception(f"pooling method {pooling} not support") return fc_embs class Cdur5Encoder(BaseEncoder): def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim, pooling="mean"): super().__init__(spec_dim, fc_feat_dim, attn_feat_dim) self.pooling = pooling self.features = nn.Sequential( Block2D(1, 32), nn.LPPool2d(4, (2, 4)), Block2D(32, 128), Block2D(128, 128), nn.LPPool2d(4, (2, 4)), Block2D(128, 128), Block2D(128, 128), nn.LPPool2d(4, (1, 4)), nn.Dropout(0.3), ) with torch.no_grad(): rnn_input_dim = self.features( torch.randn(1, 1, 500, spec_dim)).shape rnn_input_dim = rnn_input_dim[1] * rnn_input_dim[-1] self.gru = nn.GRU(rnn_input_dim, 128, bidirectional=True, batch_first=True) self.apply(init) def forward(self, input_dict): x = input_dict["spec"] lens = input_dict["spec_len"] if "upsample" not in input_dict: input_dict["upsample"] = False lens = torch.as_tensor(copy.deepcopy(lens)) N, T, _ = x.shape x = x.unsqueeze(1) x = self.features(x) x = x.transpose(1, 2).contiguous().flatten(-2) x, _ = self.gru(x) if input_dict["upsample"]: x = nn.functional.interpolate( x.transpose(1, 2), T, mode='linear', align_corners=False).transpose(1, 2) else: lens //= 4 attn_emb = x fc_emb = embedding_pooling(x, lens, self.pooling) return { "attn_emb": attn_emb, "fc_emb": fc_emb, "attn_emb_len": lens } def conv_conv_block(in_channel, out_channel): return nn.Sequential( nn.Conv2d(in_channel, out_channel, kernel_size=3, bias=False, padding=1), nn.BatchNorm2d(out_channel), nn.ReLU(True), nn.Conv2d(out_channel, out_channel, kernel_size=3, bias=False, padding=1), nn.BatchNorm2d(out_channel), nn.ReLU(True) ) class Cdur8Encoder(BaseEncoder): def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim, pooling="mean"): super().__init__(spec_dim, fc_feat_dim, attn_feat_dim) self.pooling = pooling self.features = nn.Sequential( conv_conv_block(1, 64), MMPool((2, 2)), nn.Dropout(0.2, True), conv_conv_block(64, 128), MMPool((2, 2)), nn.Dropout(0.2, True), conv_conv_block(128, 256), MMPool((1, 2)), nn.Dropout(0.2, True), conv_conv_block(256, 512), MMPool((1, 2)), nn.Dropout(0.2, True), nn.AdaptiveAvgPool2d((None, 1)), ) self.init_bn = nn.BatchNorm2d(spec_dim) self.embedding = nn.Linear(512, 512) self.gru = nn.GRU(512, 256, bidirectional=True, batch_first=True) self.apply(init) def forward(self, input_dict): x = input_dict["spec"] lens = input_dict["spec_len"] lens = torch.as_tensor(copy.deepcopy(lens)) x = x.unsqueeze(1) # B x 1 x T x D x = x.transpose(1, 3) x = self.init_bn(x) x = x.transpose(1, 3) x = self.features(x) x = x.transpose(1, 2).contiguous().flatten(-2) x = F.dropout(x, p=0.5, training=self.training) x = F.relu_(self.embedding(x)) x, _ = self.gru(x) attn_emb = x lens //= 4 fc_emb = embedding_pooling(x, lens, self.pooling) return { "attn_emb": attn_emb, "fc_emb": fc_emb, "attn_emb_len": lens } class Cnn10Encoder(BaseEncoder): def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim): super().__init__(spec_dim, fc_feat_dim, attn_feat_dim) self.features = nn.Sequential( conv_conv_block(1, 64), nn.AvgPool2d((2, 2)), nn.Dropout(0.2, True), conv_conv_block(64, 128), nn.AvgPool2d((2, 2)), nn.Dropout(0.2, True), conv_conv_block(128, 256), nn.AvgPool2d((2, 2)), nn.Dropout(0.2, True), conv_conv_block(256, 512), nn.AvgPool2d((2, 2)), nn.Dropout(0.2, True), nn.AdaptiveAvgPool2d((None, 1)), ) self.init_bn = nn.BatchNorm2d(spec_dim) self.embedding = nn.Linear(512, 512) self.apply(init) def forward(self, input_dict): x = input_dict["spec"] lens = input_dict["spec_len"] lens = torch.as_tensor(copy.deepcopy(lens)) x = x.unsqueeze(1) # [N, 1, T, D] x = x.transpose(1, 3) x = self.init_bn(x) x = x.transpose(1, 3) x = self.features(x) # [N, 512, T/16, 1] x = x.transpose(1, 2).contiguous().flatten(-2) # [N, T/16, 512] attn_emb = x lens //= 16 fc_emb = embedding_pooling(x, lens, "mean+max") fc_emb = F.dropout(fc_emb, p=0.5, training=self.training) fc_emb = self.embedding(fc_emb) fc_emb = F.relu_(fc_emb) return { "attn_emb": attn_emb, "fc_emb": fc_emb, "attn_emb_len": lens } class ConvBlock(nn.Module): def __init__(self, in_channels, out_channels): super(ConvBlock, self).__init__() self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) self.bn1 = nn.BatchNorm2d(out_channels) self.bn2 = nn.BatchNorm2d(out_channels) self.init_weight() def init_weight(self): init_layer(self.conv1) init_layer(self.conv2) init_bn(self.bn1) init_bn(self.bn2) def forward(self, input, pool_size=(2, 2), pool_type='avg'): x = input x = F.relu_(self.bn1(self.conv1(x))) x = F.relu_(self.bn2(self.conv2(x))) if pool_type == 'max': x = F.max_pool2d(x, kernel_size=pool_size) elif pool_type == 'avg': x = F.avg_pool2d(x, kernel_size=pool_size) elif pool_type == 'avg+max': x1 = F.avg_pool2d(x, kernel_size=pool_size) x2 = F.max_pool2d(x, kernel_size=pool_size) x = x1 + x2 else: raise Exception('Incorrect argument!') return x class Cnn14Encoder(nn.Module): def __init__(self, sample_rate=32000): super().__init__() sr_to_fmax = { 32000: 14000, 16000: 8000 } # Logmel spectrogram extractor self.melspec_extractor = transforms.MelSpectrogram( sample_rate=sample_rate, n_fft=32 * sample_rate // 1000, win_length=32 * sample_rate // 1000, hop_length=10 * sample_rate // 1000, f_min=50, f_max=sr_to_fmax[sample_rate], n_mels=64, norm="slaney", mel_scale="slaney" ) self.hop_length = 10 * sample_rate // 1000 self.db_transform = transforms.AmplitudeToDB() # Spec augmenter self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, freq_drop_width=8, freq_stripes_num=2) self.bn0 = nn.BatchNorm2d(64) self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) self.downsample_ratio = 32 self.fc1 = nn.Linear(2048, 2048, bias=True) self.init_weight() def init_weight(self): init_bn(self.bn0) init_layer(self.fc1) def load_pretrained(self, pretrained): checkpoint = torch.load(pretrained, map_location="cpu") if "model" in checkpoint: state_keys = checkpoint["model"].keys() backbone = False for key in state_keys: if key.startswith("backbone."): backbone = True break if backbone: # COLA state_dict = {} for key, value in checkpoint["model"].items(): if key.startswith("backbone."): model_key = key.replace("backbone.", "") state_dict[model_key] = value else: # PANNs state_dict = checkpoint["model"] elif "state_dict" in checkpoint: # CLAP state_dict = checkpoint["state_dict"] state_dict_keys = list(filter( lambda x: "audio_encoder" in x, state_dict.keys())) state_dict = { key.replace('audio_encoder.', ''): state_dict[key] for key in state_dict_keys } else: raise Exception("Unkown checkpoint format") model_dict = self.state_dict() pretrained_dict = { k: v for k, v in state_dict.items() if (k in model_dict) and ( model_dict[k].shape == v.shape) } model_dict.update(pretrained_dict) self.load_state_dict(model_dict, strict=True) def forward(self, input_dict): """ Input: (batch_size, n_samples)""" waveform = input_dict["wav"] wave_length = input_dict["wav_len"] specaug = input_dict["specaug"] x = self.melspec_extractor(waveform) x = self.db_transform(x) # (batch_size, mel_bins, time_steps) x = x.transpose(1, 2) x = x.unsqueeze(1) # (batch_size, 1, time_steps, mel_bins) # SpecAugment if self.training and specaug: x = self.spec_augmenter(x) x = x.transpose(1, 3) x = self.bn0(x) x = x.transpose(1, 3) x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') x = F.dropout(x, p=0.2, training=self.training) x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') x = F.dropout(x, p=0.2, training=self.training) x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') x = F.dropout(x, p=0.2, training=self.training) x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg') x = F.dropout(x, p=0.2, training=self.training) x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg') x = F.dropout(x, p=0.2, training=self.training) x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg') x = F.dropout(x, p=0.2, training=self.training) x = torch.mean(x, dim=3) attn_emb = x.transpose(1, 2) wave_length = torch.as_tensor(wave_length) feat_length = torch.div(wave_length, self.hop_length, rounding_mode="floor") + 1 feat_length = torch.div(feat_length, self.downsample_ratio, rounding_mode="floor") x_max = max_with_lens(attn_emb, feat_length) x_mean = mean_with_lens(attn_emb, feat_length) x = x_max + x_mean x = F.dropout(x, p=0.5, training=self.training) x = F.relu_(self.fc1(x)) fc_emb = F.dropout(x, p=0.5, training=self.training) output_dict = { 'fc_emb': fc_emb, 'attn_emb': attn_emb, 'attn_emb_len': feat_length } return output_dict class RnnEncoder(BaseEncoder): def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim, pooling="mean", **kwargs): super().__init__(spec_dim, fc_feat_dim, attn_feat_dim) self.pooling = pooling self.hidden_size = kwargs.get('hidden_size', 512) self.bidirectional = kwargs.get('bidirectional', False) self.num_layers = kwargs.get('num_layers', 1) self.dropout = kwargs.get('dropout', 0.2) self.rnn_type = kwargs.get('rnn_type', "GRU") self.in_bn = kwargs.get('in_bn', False) self.embed_dim = self.hidden_size * (self.bidirectional + 1) self.network = getattr(nn, self.rnn_type)( attn_feat_dim, self.hidden_size, num_layers=self.num_layers, bidirectional=self.bidirectional, dropout=self.dropout, batch_first=True) if self.in_bn: self.bn = nn.BatchNorm1d(self.embed_dim) self.apply(init) def forward(self, input_dict): x = input_dict["attn"] lens = input_dict["attn_len"] lens = torch.as_tensor(lens) # x: [N, T, E] if self.in_bn: x = pack_wrapper(self.bn, x, lens) out = pack_wrapper(self.network, x, lens) # out: [N, T, hidden] attn_emb = out fc_emb = embedding_pooling(out, lens, self.pooling) return { "attn_emb": attn_emb, "fc_emb": fc_emb, "attn_emb_len": lens } class Cnn14RnnEncoder(nn.Module): def __init__(self, sample_rate=32000, pretrained=None, freeze_cnn=False, freeze_cnn_bn=False, pooling="mean", **kwargs): super().__init__() self.cnn = Cnn14Encoder(sample_rate) self.rnn = RnnEncoder(64, 2048, 2048, pooling, **kwargs) if pretrained is not None: self.cnn.load_pretrained(pretrained) if freeze_cnn: assert pretrained is not None, "cnn is not pretrained but frozen" for param in self.cnn.parameters(): param.requires_grad = False self.freeze_cnn_bn = freeze_cnn_bn def train(self, mode): super().train(mode=mode) if self.freeze_cnn_bn: def bn_eval(module): class_name = module.__class__.__name__ if class_name.find("BatchNorm") != -1: module.eval() self.cnn.apply(bn_eval) return self def forward(self, input_dict): output_dict = self.cnn(input_dict) output_dict["attn"] = output_dict["attn_emb"] output_dict["attn_len"] = output_dict["attn_emb_len"] del output_dict["attn_emb"], output_dict["attn_emb_len"] output_dict = self.rnn(output_dict) return output_dict class TransformerEncoder(BaseEncoder): def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim, d_model, **kwargs): super().__init__(spec_dim, fc_feat_dim, attn_feat_dim) self.d_model = d_model dropout = kwargs.get("dropout", 0.2) self.nhead = kwargs.get("nhead", self.d_model // 64) self.nlayers = kwargs.get("nlayers", 2) self.dim_feedforward = kwargs.get("dim_feedforward", self.d_model * 4) self.attn_proj = nn.Sequential( nn.Linear(attn_feat_dim, self.d_model), nn.ReLU(), nn.Dropout(dropout), nn.LayerNorm(self.d_model) ) layer = nn.TransformerEncoderLayer(d_model=self.d_model, nhead=self.nhead, dim_feedforward=self.dim_feedforward, dropout=dropout) self.model = nn.TransformerEncoder(layer, self.nlayers) self.cls_token = nn.Parameter(torch.zeros(d_model)) self.init_params() def init_params(self): for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) def forward(self, input_dict): attn_feat = input_dict["attn"] attn_feat_len = input_dict["attn_len"] attn_feat_len = torch.as_tensor(attn_feat_len) attn_feat = self.attn_proj(attn_feat) # [bs, T, d_model] cls_emb = self.cls_token.reshape(1, 1, self.d_model).repeat( attn_feat.size(0), 1, 1) attn_feat = torch.cat((cls_emb, attn_feat), dim=1) attn_feat = attn_feat.transpose(0, 1) attn_feat_len += 1 src_key_padding_mask = ~generate_length_mask( attn_feat_len, attn_feat.size(0)).to(attn_feat.device) output = self.model(attn_feat, src_key_padding_mask=src_key_padding_mask) attn_emb = output.transpose(0, 1) fc_emb = attn_emb[:, 0] return { "attn_emb": attn_emb, "fc_emb": fc_emb, "attn_emb_len": attn_feat_len } class Cnn14TransformerEncoder(nn.Module): def __init__(self, sample_rate=32000, pretrained=None, freeze_cnn=False, freeze_cnn_bn=False, d_model="mean", **kwargs): super().__init__() self.cnn = Cnn14Encoder(sample_rate) self.trm = TransformerEncoder(64, 2048, 2048, d_model, **kwargs) if pretrained is not None: self.cnn.load_pretrained(pretrained) if freeze_cnn: assert pretrained is not None, "cnn is not pretrained but frozen" for param in self.cnn.parameters(): param.requires_grad = False self.freeze_cnn_bn = freeze_cnn_bn def train(self, mode): super().train(mode=mode) if self.freeze_cnn_bn: def bn_eval(module): class_name = module.__class__.__name__ if class_name.find("BatchNorm") != -1: module.eval() self.cnn.apply(bn_eval) return self def forward(self, input_dict): output_dict = self.cnn(input_dict) output_dict["attn"] = output_dict["attn_emb"] output_dict["attn_len"] = output_dict["attn_emb_len"] del output_dict["attn_emb"], output_dict["attn_emb_len"] output_dict = self.trm(output_dict) return output_dict