from functools import partial
from typing import Any, Callable, List, Optional, Sequence, Tuple
import urllib.parse

from torch import nn, Tensor
import torch.nn.functional as F
from torch.hub import load_state_dict_from_url
from torchvision.ops.misc import ConvNormActivation

from efficientat.models.utils import cnn_out_size
from efficientat.models.block_types import InvertedResidualConfig, InvertedResidual
from efficientat.models.attention_pooling import MultiHeadAttentionPooling
from efficientat.helpers.utils import NAME_TO_WIDTH


# Adapted version of MobileNetV3 pytorch implementation
# https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv3.py

# points to github releases
model_url = "https://github.com/fschmid56/EfficientAT/releases/download/v0.0.1/"
# folder to store downloaded models to
model_dir = "resources"


pretrained_models = {
    # pytorch ImageNet pre-trained model
    # own ImageNet pre-trained models will follow
    # NOTE: for easy loading we provide the adapted state dict ready for AudioSet training (1 input channel,
    # 527 output classes)
    # NOTE: the classifier is just a random initialization; only the feature extractor (conv layers) is pre-trained
    "mn10_im_pytorch": urllib.parse.urljoin(model_url, "mn10_im_pytorch.pt"),
    # Models trained on AudioSet
    "mn04_as": urllib.parse.urljoin(model_url, "mn04_as_mAP_432.pt"),
    "mn05_as": urllib.parse.urljoin(model_url, "mn05_as_mAP_443.pt"),
    "mn10_as": urllib.parse.urljoin(model_url, "mn10_as_mAP_471.pt"),
    "mn20_as": urllib.parse.urljoin(model_url, "mn20_as_mAP_478.pt"),
    "mn30_as": urllib.parse.urljoin(model_url, "mn30_as_mAP_482.pt"),
    "mn40_as": urllib.parse.urljoin(model_url, "mn40_as_mAP_484.pt"),
    "mn40_as(2)": urllib.parse.urljoin(model_url, "mn40_as_mAP_483.pt"),
    "mn40_as(3)": urllib.parse.urljoin(model_url, "mn40_as_mAP_483(2).pt"),
    "mn40_as_no_im_pre": urllib.parse.urljoin(model_url, "mn40_as_no_im_pre_mAP_483.pt"),
    "mn40_as_no_im_pre(2)": urllib.parse.urljoin(model_url, "mn40_as_no_im_pre_mAP_483(2).pt"),
    "mn40_as_no_im_pre(3)": urllib.parse.urljoin(model_url, "mn40_as_no_im_pre_mAP_482.pt"),
    "mn40_as_ext": urllib.parse.urljoin(model_url, "mn40_as_ext_mAP_487.pt"),
    "mn40_as_ext(2)": urllib.parse.urljoin(model_url, "mn40_as_ext_mAP_486.pt"),
    "mn40_as_ext(3)": urllib.parse.urljoin(model_url, "mn40_as_ext_mAP_485.pt"),
    # varying hop size (time resolution)
    "mn10_as_hop_15": urllib.parse.urljoin(model_url, "mn10_as_hop_15_mAP_463.pt"),
    "mn10_as_hop_20": urllib.parse.urljoin(model_url, "mn10_as_hop_20_mAP_456.pt"),
    "mn10_as_hop_25": urllib.parse.urljoin(model_url, "mn10_as_hop_25_mAP_447.pt"),
    # varying n_mels (frequency resolution)
    "mn10_as_mels_40": urllib.parse.urljoin(model_url, "mn10_as_mels_40_mAP_453.pt"),
    "mn10_as_mels_64": urllib.parse.urljoin(model_url, "mn10_as_mels_64_mAP_461.pt"),
    "mn10_as_mels_256": urllib.parse.urljoin(model_url, "mn10_as_mels_256_mAP_474.pt"),
}


class MobileNetV3(nn.Module):
    def __init__(
            self,
            inverted_residual_setting: List[InvertedResidualConfig],
            last_channel: int,
            num_classes: int = 1000,
            block: Optional[Callable[..., nn.Module]] = None,
            norm_layer: Optional[Callable[..., nn.Module]] = None,
            dropout: float = 0.2,
            in_conv_kernel: int = 3,
            in_conv_stride: int = 2,
            in_channels: int = 1,
            **kwargs: Any,
    ) -> None:
        """
        MobileNet V3 main class

        Args:
            inverted_residual_setting (List[InvertedResidualConfig]): Network structure. NOTE: each config
                object is mutated in place (its ``f_dim``/``t_dim`` attributes are set).
            last_channel (int): The number of channels on the penultimate layer (used by the 'mlp' head)
            num_classes (int): Number of classes
            block (Optional[Callable[..., nn.Module]]): Module specifying inverted residual building block for models
            norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use
            dropout (float): The dropout probability (used by the 'mlp' head)
            in_conv_kernel (int): Size of kernel for first convolution
            in_conv_stride (int): Size of stride for first convolution
            in_channels (int): Number of input channels
            **kwargs: Extra options read by this constructor:
                'se_conf': forwarded unchanged to every ``block`` (default None);
                'input_dims': (frequency, time) size of the input, default (128, 1000);
                'head_type': one of 'mlp' (default), 'fully_convolutional', 'multihead_attention_pooling';
                'multihead_attention_heads': number of heads for 'multihead_attention_pooling'.

        Raises:
            ValueError: if ``inverted_residual_setting`` is empty.
            TypeError: if ``inverted_residual_setting`` is not a sequence of InvertedResidualConfig.
            NotImplementedError: if ``head_type`` is unknown.
        """
        super(MobileNetV3, self).__init__()

        if not inverted_residual_setting:
            raise ValueError("The inverted_residual_setting should not be empty")
        elif not (
                isinstance(inverted_residual_setting, Sequence)
                and all(isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting)
        ):
            raise TypeError("The inverted_residual_setting should be List[InvertedResidualConfig]")

        if block is None:
            block = InvertedResidual

        depthwise_norm_layer = norm_layer = \
            norm_layer if norm_layer is not None else partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)

        layers: List[nn.Module] = []

        # building first layer
        firstconv_output_channels = inverted_residual_setting[0].input_channels
        layers.append(
            ConvNormActivation(
                in_channels,
                firstconv_output_channels,
                kernel_size=in_conv_kernel,
                stride=in_conv_stride,
                norm_layer=norm_layer,
                activation_layer=nn.Hardswish,
            )
        )

        # get squeeze excitation config
        se_cnf = kwargs.get('se_conf', None)

        # building inverted residual blocks
        # - keep track of size of frequency and time dimensions for possible application of Squeeze-and-Excitation
        #   on the frequency/time dimension
        # - applying Squeeze-and-Excitation on the time dimension is not recommended as this constrains the network to
        #   a particular length of the audio clip, whereas Squeeze-and-Excitation on the frequency bands is fine,
        #   as the number of frequency bands is usually not changing
        f_dim, t_dim = kwargs.get('input_dims', (128, 1000))
        # take into account first conv layer: ConvNormActivation pads by (kernel_size - 1) // 2 when no padding is
        # given (dilation 1). Previously this was hard-coded to kernel 3 / stride 2 regardless of the arguments.
        # Argument order assumed: cnn_out_size(size, padding, dilation, kernel, stride).
        in_conv_padding = (in_conv_kernel - 1) // 2
        f_dim = cnn_out_size(f_dim, in_conv_padding, 1, in_conv_kernel, in_conv_stride)
        t_dim = cnn_out_size(t_dim, in_conv_padding, 1, in_conv_kernel, in_conv_stride)
        for cnf in inverted_residual_setting:
            f_dim = cnf.out_size(f_dim)
            t_dim = cnf.out_size(t_dim)
            cnf.f_dim, cnf.t_dim = f_dim, t_dim  # update dimensions in block config
            layers.append(block(cnf, se_cnf, norm_layer, depthwise_norm_layer))

        # building last several layers
        lastconv_input_channels = inverted_residual_setting[-1].out_channels
        lastconv_output_channels = 6 * lastconv_input_channels
        layers.append(
            ConvNormActivation(
                lastconv_input_channels,
                lastconv_output_channels,
                kernel_size=1,
                norm_layer=norm_layer,
                activation_layer=nn.Hardswish,
            )
        )

        self.features = nn.Sequential(*layers)
        # default matches get_model's default head; the former default (False) always raised below
        self.head_type = kwargs.get("head_type", "mlp")
        if self.head_type == "multihead_attention_pooling":
            self.classifier = MultiHeadAttentionPooling(lastconv_output_channels, num_classes,
                                                        num_heads=kwargs.get("multihead_attention_heads"))
        elif self.head_type == "fully_convolutional":
            self.classifier = nn.Sequential(
                nn.Conv2d(
                    lastconv_output_channels,
                    num_classes,
                    kernel_size=(1, 1),
                    stride=(1, 1),
                    padding=(0, 0),
                    bias=False),
                nn.BatchNorm2d(num_classes),
                nn.AdaptiveAvgPool2d((1, 1)),
            )
        elif self.head_type == "mlp":
            # NOTE: the indices of this Sequential matter: pre-trained checkpoints store the final layer
            # as 'classifier.5.*' (see _mobilenet_v3)
            self.classifier = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten(start_dim=1),
                nn.Linear(lastconv_output_channels, last_channel),
                nn.Hardswish(inplace=True),
                nn.Dropout(p=dropout, inplace=True),
                nn.Linear(last_channel, num_classes),
            )
        else:
            raise NotImplementedError(f"Head '{self.head_type}' unknown. Must be one of: 'mlp', "
                                      f"'fully_convolutional', 'multihead_attention_pooling'")

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Run the network.

        Returns:
            (logits, features): ``features`` is the globally average-pooled output of ``self.features``.
            Singleton dimensions are squeezed away, but the batch dimension is always kept.
        """
        # remember the batch size: squeeze() below would also drop a batch dimension of size 1
        batch_size = x.size(0)
        x = self.features(x)
        features = F.adaptive_avg_pool2d(x, (1, 1)).squeeze()
        x = self.classifier(x).squeeze()
        if batch_size == 1:
            # restore the squeezed batch dimension (also correct when num_classes == 1,
            # where the logits would otherwise collapse to a 0-d tensor)
            features = features.unsqueeze(0)
            x = x.unsqueeze(0)
        return x, features

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Return ``(logits, features)`` for the input batch ``x``."""
        return self._forward_impl(x)


def _mobilenet_v3_conf(
        width_mult: float = 1.0,
        reduced_tail: bool = False,
        dilated: bool = False,
        c4_stride: int = 2,
        **kwargs: Any
):
    """
    Build the MobileNetV3 block configuration.

    Args:
        width_mult: channel width multiplier applied to every block
        reduced_tail: halve the channels of the last blocks and of ``last_channel``
        dilated: use dilation 2 in the last blocks
        c4_stride: stride of the first C4 block
        **kwargs: ignored; accepted so that the full model kwargs can be forwarded here

    Returns:
        (inverted_residual_setting, last_channel)
    """
    reduce_divider = 2 if reduced_tail else 1
    dilation = 2 if dilated else 1

    bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
    adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)

    # InvertedResidualConfig:
    # input_channels, kernel, expanded_channels, out_channels, use_se, activation, stride, dilation, width_mult
    inverted_residual_setting = [
        bneck_conf(16, 3, 16, 16, False, "RE", 1, 1),
        bneck_conf(16, 3, 64, 24, False, "RE", 2, 1),  # C1
        bneck_conf(24, 3, 72, 24, False, "RE", 1, 1),
        bneck_conf(24, 5, 72, 40, True, "RE", 2, 1),  # C2
        bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
        bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
        bneck_conf(40, 3, 240, 80, False, "HS", 2, 1),  # C3
        bneck_conf(80, 3, 200, 80, False, "HS", 1, 1),
        bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
        bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
        bneck_conf(80, 3, 480, 112, True, "HS", 1, 1),
        bneck_conf(112, 3, 672, 112, True, "HS", 1, 1),
        bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", c4_stride, dilation),  # C4
        bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
        bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
    ]
    last_channel = adjust_channels(1280 // reduce_divider)

    return inverted_residual_setting, last_channel


def _mobilenet_v3(
        inverted_residual_setting: List[InvertedResidualConfig],
        last_channel: int,
        pretrained_name: Optional[str],
        **kwargs: Any,
):
    """
    Instantiate a MobileNetV3 and optionally load pre-trained weights.

    Args:
        inverted_residual_setting: block configuration (see _mobilenet_v3_conf)
        last_channel: penultimate layer width
        pretrained_name: key of ``pretrained_models``; falsy to skip loading
        **kwargs: forwarded to MobileNetV3

    Raises:
        NotImplementedError: if ``pretrained_name`` is set but not a known model name.
    """
    model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)

    if pretrained_name in pretrained_models:
        url = pretrained_models[pretrained_name]
        state_dict = load_state_dict_from_url(url, model_dir=model_dir, map_location="cpu")
        # same default as MobileNetV3.__init__, so a missing kwarg does not raise KeyError
        num_classes = kwargs.get('num_classes', 1000)
        pretrained_logits = state_dict.get('classifier.5.bias')
        if pretrained_logits is not None and num_classes != pretrained_logits.size(0):
            # if the number of logits is not matching the state dict,
            # drop the corresponding pre-trained part
            print(f"Number of classes defined: {num_classes}, "
                  f"but try to load pre-trained layer with logits: {pretrained_logits.size(0)}\n"
                  "Dropping last layer.")
            state_dict.pop('classifier.5.weight', None)
            state_dict.pop('classifier.5.bias', None)
        try:
            model.load_state_dict(state_dict)
        except RuntimeError as e:
            print(str(e))
            print("Loading pre-trained weights in a non-strict manner.")
            model.load_state_dict(state_dict, strict=False)
    elif pretrained_name:
        raise NotImplementedError(f"Model name '{pretrained_name}' unknown.")
    return model


def mobilenet_v3(pretrained_name: Optional[str] = None, **kwargs: Any) \
        -> MobileNetV3:
    """
    Constructs a MobileNetV3 architecture from
    "Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>.

    Args:
        pretrained_name: key of ``pretrained_models`` to load, or None for random initialization
        **kwargs: forwarded both to _mobilenet_v3_conf and to MobileNetV3
    """
    inverted_residual_setting, last_channel = _mobilenet_v3_conf(**kwargs)
    return _mobilenet_v3(inverted_residual_setting, last_channel, pretrained_name, **kwargs)


def get_model(num_classes: int = 527, pretrained_name: Optional[str] = None, width_mult: float = 1.0,
              reduced_tail: bool = False, dilated: bool = False, c4_stride: int = 2, head_type: str = "mlp",
              multihead_attention_heads: int = 4, input_dim_f: int = 128,
              input_dim_t: int = 1000, se_dims: str = 'c', se_agg: str = "max", se_r: int = 4):
    """
    Arguments to modify the instantiation of a MobileNetv3

    Args:
        num_classes (int): Specifies number of classes to predict
        pretrained_name (str): Specifies name of pre-trained model to load
        width_mult (float): Scales width of network
        reduced_tail (bool): Scales down network tail
        dilated (bool): Applies dilated convolution to network tail
        c4_stride (int): Set to '2' in original implementation;
            might be changed to modify the size of receptive field
        head_type (str): decides which classification head to use
        multihead_attention_heads (int): number of heads in case 'multihead_attention_pooling' is used
        input_dim_f (int): number of frequency bands
        input_dim_t (int): number of time frames
        se_dims (str): letters choosing the dimension(s) to apply squeeze-excitation on:
            'c' - channel, 'f' - frequency, 't' - time; or 'none' to disable it. If multiple dimensions are
            chosen, squeeze-excitation is applied concurrently and the se layer outputs are fused by se_agg
        se_agg (str): operation to fuse output of concurrent se layers
        se_r (int): squeeze excitation bottleneck size

    Returns:
        the constructed MobileNetV3 (its structure is also printed)
    """

    dim_map = {'c': 1, 'f': 2, 't': 3}
    # NOTE(review): validation via assert is skipped under `python -O`
    assert (len(se_dims) <= 3 and all(s in dim_map for s in se_dims)) or se_dims == 'none', \
        f"Invalid se_dims '{se_dims}'"
    input_dims = (input_dim_f, input_dim_t)
    if se_dims == 'none':
        se_dims = None
    else:
        se_dims = [dim_map[s] for s in se_dims]
    se_conf = dict(se_dims=se_dims, se_agg=se_agg, se_r=se_r)
    m = mobilenet_v3(pretrained_name=pretrained_name, num_classes=num_classes,
                     width_mult=width_mult, reduced_tail=reduced_tail, dilated=dilated, c4_stride=c4_stride,
                     head_type=head_type, multihead_attention_heads=multihead_attention_heads,
                     input_dims=input_dims, se_conf=se_conf
                     )
    print(m)
    return m


class EnsemblerModel(nn.Module):
    """Averages the logits of several pre-trained MobileNetV3 models."""

    def __init__(self, model_names):
        """
        Args:
            model_names: iterable of ``pretrained_models`` keys; each model's width is obtained via NAME_TO_WIDTH

        Raises:
            ValueError: if ``model_names`` is empty (forward would otherwise divide None by zero).
        """
        super(EnsemblerModel, self).__init__()
        self.models = nn.ModuleList([get_model(width_mult=NAME_TO_WIDTH(model_name), pretrained_name=model_name)
                                     for model_name in model_names])
        if len(self.models) == 0:
            raise ValueError("EnsemblerModel requires at least one model name")

    def forward(self, x):
        """
        Return the mean logits of all ensemble members, twice.

        The pair mirrors MobileNetV3.forward's ``(logits, features)`` shape so callers can unpack it the same
        way; the second element is the averaged logits, NOT an embedding.
        """
        all_out = None
        for m in self.models:
            out, _ = m(x)
            if all_out is None:
                all_out = out
            else:
                all_out = out + all_out
        all_out = all_out / len(self.models)
        return all_out, all_out


def get_ensemble_model(model_names):
    """Build an EnsemblerModel from a list of pre-trained model names."""
    return EnsemblerModel(model_names)