# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn.functional as F
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner.base_module import BaseModule
from torch.nn.modules.utils import _pair as to_2tuple


# Modified from Pytorch-Image-Models
class PatchEmbed(BaseModule):
    """Image to Patch Embedding V2.

    We use a conv layer to implement PatchEmbed.

    Args:
        in_channels (int): The number of input channels. Default: 3.
        embed_dims (int): The dimensions of embedding. Default: 768.
        conv_type (str, optional): The type of convolution used to build the
            projection layer. Default: None, which falls back to 'Conv2d'.
        kernel_size (int): The kernel_size of embedding conv. Default: 16.
        stride (int, optional): The slide stride of embedding conv.
            Default: 16. If set to None, it falls back to kernel_size.
        padding (int): The padding length of embedding conv. Default: 0.
        dilation (int): The dilation rate of embedding conv. Default: 1.
        pad_to_patch_size (bool): Whether to pad the feature map so that its
            spatial size is a multiple of the patch size. Default: True.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: None.
        init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 in_channels=3,
                 embed_dims=768,
                 conv_type=None,
                 kernel_size=16,
                 stride=16,
                 padding=0,
                 dilation=1,
                 pad_to_patch_size=True,
                 norm_cfg=None,
                 init_cfg=None):
        super(PatchEmbed, self).__init__()

        self.embed_dims = embed_dims
        self.init_cfg = init_cfg

        if stride is None:
            stride = kernel_size

        self.pad_to_patch_size = pad_to_patch_size

        # The default setting of patch size is equal to kernel size.
        patch_size = kernel_size
        if isinstance(patch_size, int):
            patch_size = to_2tuple(patch_size)
        elif isinstance(patch_size, tuple):
            if len(patch_size) == 1:
                patch_size = to_2tuple(patch_size[0])
            assert len(patch_size) == 2, \
                f'The size of patch should have length 1 or 2, ' \
                f'but got {len(patch_size)}'

        self.patch_size = patch_size

        # Use conv layer to embed
        conv_type = conv_type or 'Conv2d'
        self.projection = build_conv_layer(
            dict(type=conv_type),
            in_channels=in_channels,
            out_channels=embed_dims,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation)

        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
        else:
            self.norm = None

    def forward(self, x):
        H, W = x.shape[2], x.shape[3]

        # TODO: Process overlapping op
        if self.pad_to_patch_size:
            # Modify H, W to multiples of patch size.
            if H % self.patch_size[0] != 0:
                # Pad the bottom of the feature map.
                x = F.pad(
                    x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
            if W % self.patch_size[1] != 0:
                # Pad the right of the feature map.
                x = F.pad(
                    x, (0, self.patch_size[1] - W % self.patch_size[1], 0, 0))

        x = self.projection(x)
        # Downsampled spatial size after the projection conv.
        self.DH, self.DW = x.shape[2], x.shape[3]
        # Flatten patches to a sequence of shape (B, num_patches, embed_dims).
        x = x.flatten(2).transpose(1, 2)

        if self.norm is not None:
            x = self.norm(x)

        return x
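

# A minimal usage sketch (illustrative only, not part of the original module;
# assumes torch and mmcv are installed and uses a hypothetical 224x224 RGB
# input): split the image into non-overlapping 16x16 patches and project each
# patch to a 768-dim token.
if __name__ == '__main__':
    import torch

    patch_embed = PatchEmbed(
        in_channels=3,
        embed_dims=768,
        kernel_size=16,
        stride=16,
        norm_cfg=dict(type='LN'))
    img = torch.randn(1, 3, 224, 224)
    tokens = patch_embed(img)
    # 224 / 16 = 14 patches per side, so 14 * 14 = 196 tokens of dim 768.
    print(tokens.shape)                    # torch.Size([1, 196, 768])
    print(patch_embed.DH, patch_embed.DW)  # 14 14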