# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import copy
from functools import partial
from typing import Any, Callable, List, Optional, Union

import torch
from torch import Tensor, nn
from torch.nn import functional as F

from modules.norms import AdaptiveLayerNorm, LayerNorm, BalancedBasicNorm, IdentityNorm
from modules.transformer import MultiheadAttention
from modules.general.scaling import BalancedDoubleSwish


class TransformerEncoderLayer(nn.Module):
    """One encoder layer: a self-attention block followed by a feedforward block.

    Each block is combined with its input through a residual connection and a
    norm. With ``norm_first=True`` the norm is applied to the block input
    (pre-norm); otherwise it is applied to the residual sum (post-norm).

    ``forward`` accepts either a tensor or an ``(x, stage_embedding)`` tuple.
    For a tuple, ``stage_embedding`` is passed as the second argument of every
    norm call and the layer returns a tuple of the same form.
    """

    __constants__ = ["batch_first", "norm_first"]

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
        batch_first: bool = False,
        norm_first: bool = False,
        device=None,
        dtype=None,
        linear1_self_attention_cls: nn.Module = nn.Linear,
        linear2_self_attention_cls: nn.Module = nn.Linear,
        linear1_feedforward_cls: nn.Module = nn.Linear,
        linear2_feedforward_cls: nn.Module = nn.Linear,
        layer_norm_cls: nn.Module = LayerNorm,
        layer_norm_eps: float = 1e-5,
        adaptive_layer_norm=False,
    ) -> None:
        """Build the attention, feedforward and norm sub-modules.

        Args:
            d_model: model (feature) dimension.
            nhead: number of attention heads.
            dim_feedforward: hidden size of the feedforward block.
            dropout: dropout probability for attention and both residual blocks.
            activation: ``"relu"``/``"gelu"``, a ``functools.partial`` factory
                or the ``BalancedDoubleSwish`` class (both called with
                ``d_model``), or any other callable used as-is.
            batch_first: forwarded to ``MultiheadAttention``.
            norm_first: pre-norm when True, post-norm otherwise.
            device: factory kwarg for the created sub-modules.
            dtype: factory kwarg for the created sub-modules.
            linear1_self_attention_cls: first projection class for attention.
            linear2_self_attention_cls: second projection class for attention.
            linear1_feedforward_cls: class of the first feedforward linear.
            linear2_feedforward_cls: class of the second feedforward linear.
            layer_norm_cls: norm class. ``IdentityNorm`` is replaced by
                ``BalancedBasicNorm`` for the second norm.
            layer_norm_eps: ``eps`` passed to every norm.
            adaptive_layer_norm: wrap both norms in ``AdaptiveLayerNorm``.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            linear1_cls=linear1_self_attention_cls,
            linear2_cls=linear2_self_attention_cls,
            **factory_kwargs,
        )

        # Implementation of Feedforward model
        self.linear1 = linear1_feedforward_cls(
            d_model, dim_feedforward, **factory_kwargs
        )
        self.dropout = nn.Dropout(dropout)
        self.linear2 = linear2_feedforward_cls(
            dim_feedforward, d_model, **factory_kwargs
        )

        self.norm_first = norm_first
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # Names map to functional activations. partial factories and the
        # BalancedDoubleSwish class are instantiated with d_model.
        self.activation = _build_activation(activation, d_model)

        norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
        # With IdentityNorm, the norm after the feedforward block is still a
        # real (BalancedBasicNorm) normalization.
        if layer_norm_cls == IdentityNorm:
            norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        else:
            norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)

        if adaptive_layer_norm:
            self.norm1 = AdaptiveLayerNorm(d_model, norm1)
            self.norm2 = AdaptiveLayerNorm(d_model, norm2)
        else:
            self.norm1 = norm1
            self.norm2 = norm2

    def __setstate__(self, state):
        # Objects pickled without an ``activation`` attribute fall back to ReLU.
        super(TransformerEncoderLayer, self).__setstate__(state)
        if not hasattr(self, "activation"):
            self.activation = F.relu

    def forward(
        self,
        src: Tensor,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required), or an
                ``(x, stage_embedding)`` tuple.
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Returns:
            The layer output, as a tensor, or as an ``(x, stage_embedding)``
            tuple when ``src`` was a tuple.

        Raises:
            AssertionError: if ``src_key_padding_mask`` is neither bool nor
                floating point.

        Shape:
            see the docs in Transformer class.
        """
        x, stage_embedding = src, None
        is_src_tuple = False
        if isinstance(src, tuple):
            x, stage_embedding = src
            is_src_tuple = True

        if src_key_padding_mask is not None:
            _skpm_dtype = src_key_padding_mask.dtype
            if _skpm_dtype != torch.bool and not torch.is_floating_point(
                src_key_padding_mask
            ):
                raise AssertionError(
                    "only bool and floating types of key_padding_mask are supported"
                )

        if self.norm_first:
            # Pre-norm: normalize the block input, then add the residual.
            x = x + self._sa_block(
                self.norm1(x, stage_embedding),
                src_mask,
                src_key_padding_mask,
            )
            x = x + self._ff_block(self.norm2(x, stage_embedding))
        else:
            # Post-norm: add the residual, then normalize the sum.
            x = self.norm1(
                x + self._sa_block(x, src_mask, src_key_padding_mask),
                stage_embedding,
            )
            x = self.norm2(x + self._ff_block(x), stage_embedding)

        if is_src_tuple:
            return (x, stage_embedding)
        return x

    def _sa_block(
        self,
        x: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
    ) -> Tensor:
        """Self-attention over ``x`` (query = key = value) followed by dropout."""
        x = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )[0]
        return self.dropout1(x)

    def _ff_block(self, x: Tensor) -> Tensor:
        """linear1 -> activation -> dropout -> linear2 -> dropout."""
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)


class TransformerEncoder(nn.Module):
    """TransformerEncoder is a stack of N encoder layers."""

    def __init__(self, encoder_layer, num_layers, norm=None):
        """Stack ``num_layers`` deep copies of ``encoder_layer``.

        Args:
            encoder_layer: template layer; each stacked layer is an
                independent deep copy.
            num_layers: number of layers.
            norm: optional module applied to the final output.
        """
        super(TransformerEncoder, self).__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src: Tensor,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        return_layer_states: bool = False,
    ) -> Tensor:
        """Run ``src`` through every layer, then the optional final norm.

        Args:
            src: layer input (a tensor, or whatever the layers accept, e.g. an
                ``(x, stage_embedding)`` tuple).
            mask: attention mask passed to every layer as ``src_mask``.
            src_key_padding_mask: padding mask passed to every layer.
            return_layer_states: also return each layer's raw output.

        Returns:
            ``output``, or ``(layer_states, output)`` when
            ``return_layer_states`` is True. ``layer_states`` holds the
            per-layer outputs before the final norm.
        """
        output = src
        layer_states = [] if return_layer_states else None
        for mod in self.layers:
            output = self._apply_module(
                mod, output, mask, src_key_padding_mask, layer_states
            )

        if self.norm is not None:
            output = self.norm(output)

        return (layer_states, output) if return_layer_states else output

    def _apply_module(self, module, output, mask, key_padding_mask, layer_states):
        """Apply one layer and, if collecting, record its output in ``layer_states``."""
        output = module(output, src_mask=mask, src_key_padding_mask=key_padding_mask)
        if layer_states is not None:
            layer_states.append(output)
        return output


class TransformerDecoderLayer(nn.Module):
    """One decoder layer: self-attention, cross-attention over ``memory``, feedforward.

    Residual/norm placement follows ``norm_first`` exactly as in
    ``TransformerEncoderLayer``. ``tgt`` may likewise be an
    ``(x, stage_embedding)`` tuple, whose second element is passed to every
    norm call.
    """

    __constants__ = ["batch_first", "norm_first"]

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
        linear1_self_attention_cls: nn.Module = nn.Linear,
        linear2_self_attention_cls: nn.Module = nn.Linear,
        linear1_feedforward_cls: nn.Module = nn.Linear,
        linear2_feedforward_cls: nn.Module = nn.Linear,
        batch_first: bool = False,
        norm_first: bool = False,
        device=None,
        dtype=None,
        layer_norm_cls: nn.Module = LayerNorm,
        layer_norm_eps: float = 1e-5,
        adaptive_layer_norm=False,
    ) -> None:
        """Build the two attention blocks, the feedforward block and three norms.

        The arguments have the same meaning as in ``TransformerEncoderLayer``.
        Note that this signature orders the ``linear*_cls`` arguments before
        ``batch_first``. Both attention blocks use the self-attention linear
        classes. ``activation`` is resolved the same way as in the encoder
        layer, so ``partial`` factories and ``BalancedDoubleSwish`` are
        instantiated with ``d_model``.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super(TransformerDecoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            linear1_cls=linear1_self_attention_cls,
            linear2_cls=linear2_self_attention_cls,
            **factory_kwargs,
        )
        self.multihead_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            linear1_cls=linear1_self_attention_cls,
            linear2_cls=linear2_self_attention_cls,
            **factory_kwargs,
        )
        self.linear1 = linear1_feedforward_cls(
            d_model, dim_feedforward, **factory_kwargs
        )
        self.dropout = nn.Dropout(dropout)
        self.linear2 = linear2_feedforward_cls(
            dim_feedforward, d_model, **factory_kwargs
        )

        self.norm_first = norm_first
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        # Pass d_model so factories/classes are instantiated like in the
        # encoder layer instead of being called on the feedforward tensor.
        self.activation = self._get_activation_fn(activation, d_model)
        self.norm1, self.norm2, self.norm3 = self._init_norm_layers(
            d_model, layer_norm_cls, layer_norm_eps, adaptive_layer_norm, factory_kwargs
        )

    def forward(
        self,
        tgt: Tensor,
        memory: Tensor,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Pass the inputs (and mask) through the decoder layer.

        Args:
            tgt: the sequence to the decoder layer (required), or an
                ``(x, stage_embedding)`` tuple.
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).

        Returns:
            The layer output, as a tensor, or as an ``(x, stage_embedding)``
            tuple when ``tgt`` was a tuple.

        Shape:
            see the docs in Transformer class.
        """
        tgt_is_tuple = False
        if isinstance(tgt, tuple):
            x, stage_embedding = tgt
            tgt_is_tuple = True
        else:
            x, stage_embedding = tgt, None

        if self.norm_first:
            # Pre-norm: normalize each block input, then add the residual.
            x = x + self._sa_block(
                self.norm1(x, stage_embedding), tgt_mask, tgt_key_padding_mask
            )
            x = x + self._mha_block(
                self.norm2(x, stage_embedding),
                memory,
                memory_mask,
                memory_key_padding_mask,
            )
            x = x + self._ff_block(self.norm3(x, stage_embedding))
        else:
            # Post-norm: add the residual, then normalize the sum.
            x = self.norm1(
                x + self._sa_block(x, tgt_mask, tgt_key_padding_mask),
                stage_embedding,
            )
            x = self.norm2(
                x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask),
                stage_embedding,
            )
            x = self.norm3(x + self._ff_block(x), stage_embedding)

        if tgt_is_tuple:
            return (x, stage_embedding)
        return x

    def _sa_block(
        self,
        x: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
    ) -> Tensor:
        """Self-attention over ``x`` (query = key = value) followed by dropout."""
        x = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )[0]
        return self.dropout1(x)

    def _mha_block(
        self,
        x: Tensor,
        mem: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
    ) -> Tensor:
        """Cross-attention: query ``x``, key/value ``mem``, followed by dropout."""
        x = self.multihead_attn(
            x,
            mem,
            mem,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )[0]
        return self.dropout2(x)

    def _ff_block(self, x: Tensor) -> Tensor:
        """linear1 -> activation -> dropout -> linear2 -> dropout."""
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout3(x)

    def _get_activation_fn(self, activation, d_model: Optional[int] = None):
        """Resolve ``activation`` into the callable used by the feedforward block.

        Args:
            activation: ``"relu"``/``"gelu"``, a ``functools.partial`` factory,
                the ``BalancedDoubleSwish`` class, or any other callable.
            d_model: feature size used to instantiate ``partial`` factories and
                ``BalancedDoubleSwish``. When ``None``, only names are
                resolved and other callables are returned unchanged.

        Returns:
            The activation callable.

        Raises:
            RuntimeError: for an unknown activation name.
            ValueError: if the resolved activation is not callable.
        """
        if d_model is not None:
            activation = _build_activation(activation, d_model)
        elif isinstance(activation, str):
            # Module-level helper (method names are not in scope here).
            activation = _get_activation_fn(activation)
        if callable(activation):
            return activation
        raise ValueError("Unsupported activation type")

    def _init_norm_layers(
        self,
        d_model,
        layer_norm_cls,
        layer_norm_eps,
        adaptive_layer_norm,
        factory_kwargs,
    ):
        """Create ``(norm1, norm2, norm3)`` for the three residual blocks.

        Without ``adaptive_layer_norm``, an ``IdentityNorm`` class is replaced by
        ``BalancedBasicNorm`` for ``norm3`` (the norm after the feedforward
        block).
        """

        def make_norm():
            return layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)

        if adaptive_layer_norm:
            # NOTE(review): unlike TransformerEncoderLayer, the adaptive branch
            # does not swap IdentityNorm for BalancedBasicNorm. It is kept as-is
            # because changing it could alter the parameter set of existing
            # models; confirm whether the asymmetry is intended.
            # Each inner norm is built right before its wrapper (original order).
            return tuple(AdaptiveLayerNorm(d_model, make_norm()) for _ in range(3))

        norm1 = make_norm()
        norm2 = make_norm()
        if layer_norm_cls != IdentityNorm:
            norm3 = make_norm()
        else:
            norm3 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        return norm1, norm2, norm3


def _get_clones(module, N):
    """Return an ``nn.ModuleList`` of ``N`` independent deep copies of ``module``."""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
    """Map an activation name (``"relu"`` or ``"gelu"``) to its functional form.

    Raises:
        RuntimeError: for any other name.
    """
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu

    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))


def _build_activation(activation, d_model: int):
    """Resolve an activation spec into the object stored as ``self.activation``.

    Names go through ``_get_activation_fn``. ``functools.partial`` factories and
    the ``BalancedDoubleSwish`` class are instantiated with ``d_model``. Anything
    else is returned unchanged.
    """
    if isinstance(activation, str):
        return _get_activation_fn(activation)
    if isinstance(activation, partial):
        return activation(d_model)
    if activation == BalancedDoubleSwish:
        return BalancedDoubleSwish(d_model)
    return activation


class Transpose(nn.Identity):
    """(N, T, D) -> (N, D, T)"""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Swaps dimensions 1 and 2; returns a view, no copy.
        return input.transpose(1, 2)