import copy
from typing import Optional, Any

import torch
from torch import Tensor
from torch import nn
from torch.nn import functional as F


def conv3x3(in_channels, out_channels, num_groups=0):
    """3x3 conv -> norm -> ReLU. Uses BatchNorm2d by default, or GroupNorm
    with ``num_groups`` groups when ``num_groups`` >= 1."""
    return nn.Sequential(
        # stride 1, padding 1 preserve the spatial size; bias is redundant before a norm layer
        nn.Conv2d(in_channels, out_channels, (3, 3), 1, 1, bias=False),
        nn.BatchNorm2d(out_channels) if num_groups < 1 else nn.GroupNorm(num_groups, out_channels),
        nn.ReLU(inplace=True),
    )
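

# A minimal shape sketch for conv3x3, kept as a comment so importing this
# module stays side-effect free (illustrative, not part of the API):
#
#   blk = conv3x3(64, 64)                   # BatchNorm2d by default (num_groups=0)
#   gn_blk = conv3x3(64, 64, num_groups=8)  # GroupNorm(8, 64) instead of BatchNorm
#   y = blk(torch.randn(2, 64, 32, 32))
#   assert y.shape == (2, 64, 32, 32)       # stride 1, padding 1 preserve H and W

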
class XTransformerEncoder(nn.Module):
    r"""A TransformerEncoder variant that interleaves the encoder layers with
    convolutional blocks applied to the visual tokens.

    After each encoder layer, the first ``H*W`` tokens are reshaped back into a
    ``(N, D, H, W)`` feature map, passed through ``num_conv`` conv3x3 blocks,
    flattened again, and re-concatenated with any remaining (non-visual)
    tokens. Since ``forward`` slices the sequence along dim 1, inputs are
    expected batch-first, i.e. ``(N, H*W + extra, D)``.

    Args:
        encoder_layer: an instance of the TransformerEncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        num_conv: the number of conv3x3 blocks applied after each layer (default=2).
        norm: the layer normalization component (optional).
    """
    __constants__ = ['norm']

    def __init__(self, encoder_layer, num_layers, num_conv=2, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

        # Infer the model width from the encoder layer's first feed-forward linear.
        d_model = encoder_layer.linear1.in_features
        self.conv = nn.ModuleList([
            nn.Sequential(*[
                conv3x3(d_model, d_model) for _ in range(num_conv)
            ]) for _ in range(num_layers)
        ])

    def flatten(self, x):
        # (N, D, H, W) -> (N, H*W, D)
        N, D, H, W = x.size()
        # channels_last makes the permuted tensor contiguous in memory,
        # so the subsequent view() is valid without an explicit copy.
        x = x.to(memory_format=torch.channels_last)
        x = x.permute(0, 2, 3, 1).view(N, H * W, D)
        return x

    def unflatten(self, x, size):
        # (N, H*W, D) -> (N, D, H, W)
        N, R, D = x.size()
        H, W = size
        assert R == H * W, 'wrong tensor size'
        x = x.permute(0, 2, 1).to(memory_format=torch.contiguous_format)
        x = x.view(N, D, H, W)
        return x

    def forward(self, src: Tensor, mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None, size=None) -> Tensor:
        assert size is not None, 'size=(H, W) of the visual tokens is required'
        output = src

        for i, mod in enumerate(self.layers):
            output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos)

            # Convolve the visual tokens; leave any trailing tokens untouched.
            vis = self.unflatten(output[:, :size[0] * size[1]], size)
            vis = self.flatten(self.conv[i](vis))

            output = torch.cat([vis, output[:, size[0] * size[1]:]], dim=1)

        if self.norm is not None:
            output = self.norm(output)

        return output
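

# A minimal usage sketch for XTransformerEncoder, as a comment so nothing runs
# on import. It assumes batch-first layers (forward() slices dim 1) and treats
# the trailing tokens as arbitrary extra non-visual tokens:
#
#   layer = TransformerEncoderLayer(d_model=256, nhead=8, batch_first=True)
#   enc = XTransformerEncoder(layer, num_layers=3, num_conv=2)
#   H, W = 16, 16
#   vis = torch.randn(2, H * W, 256)        # flattened visual feature map
#   extra = torch.randn(2, 4, 256)          # e.g. learned query tokens
#   out = enc(torch.cat([vis, extra], dim=1), size=(H, W))
#   assert out.shape == (2, H * W + 4, 256)

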
class TransformerEncoder(nn.Module):
    r"""TransformerEncoder is a stack of N encoder layers.

    Args:
        encoder_layer: an instance of the TransformerEncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component (optional).

    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
        >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out = transformer_encoder(src)
    """
    __constants__ = ['norm']

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src: Tensor, mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None) -> Tensor:
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder (required).
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            pos: positional encoding added to the queries and keys in each layer (optional).

        Shape:
            see the docs in Transformer class.
        """
        output = src

        for mod in self.layers:
            output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos)

        if self.norm is not None:
            output = self.norm(output)

        return output
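

# A minimal usage sketch with an additive positional encoding, as a comment so
# nothing runs on import. The pos tensor only needs to broadcast against src;
# shapes below assume the default sequence-first layout:
#
#   layer = TransformerEncoderLayer(d_model=512, nhead=8)
#   enc = TransformerEncoder(layer, num_layers=6, norm=nn.LayerNorm(512))
#   src = torch.rand(10, 32, 512)           # (seq, batch, feature)
#   pos = torch.rand(10, 1, 512)            # shared across the batch dim
#   out = enc(src, pos=pos)
#   assert out.shape == src.shape

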
class TransformerEncoderLayer(nn.Module):
    r"""TransformerEncoderLayer is made up of self-attn and a feedforward network.
    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    it in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer, relu or gelu (default=relu).
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False``.

    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)

    Alternatively, when ``batch_first`` is ``True``:
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
        >>> src = torch.rand(32, 10, 512)
        >>> out = encoder_layer(src)
    """
    __constants__ = ['batch_first']

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
                 layer_norm_eps=1e-5, batch_first=False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                               **factory_kwargs)

        # Implementation of the feedforward model.
        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        # Checkpoints saved before ``activation`` became an attribute
        # default to ReLU when unpickled.
        if 'activation' not in state:
            state['activation'] = F.relu
        super().__setstate__(state)

    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            pos: positional encoding added to the queries and keys only (optional).

        Shape:
            see the docs in Transformer class.
        """
        # DETR-style attention: positions are added to queries and keys,
        # while the values (and the residual path) stay position-free.
        q = k = src if pos is None else src + pos

        src2 = self.self_attn(q, k, src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src
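

# A minimal sketch of the pos handling above, as a comment so nothing runs on
# import: pos participates in the attention weights (via q and k) but never
# enters the values or the residual stream:
#
#   layer = TransformerEncoderLayer(d_model=512, nhead=8)
#   src = torch.rand(10, 32, 512)
#   pos = torch.rand(10, 32, 512)
#   out = layer(src, pos=pos)               # attention scores see src + pos
#   out_no_pos = layer(src)                 # equivalent to pos=None

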
def _get_clones(module, N):
    # Deep-copy so each layer has its own, independently trained parameters.
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu

    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))