""" Implementation of "Attention is All You Need" """ import torch import torch.nn as nn from .decoder import DecoderBase from .multi_headed_attn import MultiHeadedAttention from .average_attn import AverageAttention from .position_ffn import PositionwiseFeedForward from .misc import sequence_mask class TransformerDecoderLayer(nn.Module): """Transformer Decoder layer block in Pre-Norm style. Pre-Norm style is an improvement w.r.t. Original paper's Post-Norm style, providing better converge speed and performance. This is also the actual implementation in tensor2tensor and also avalable in fairseq. See https://tunz.kr/post/4 and :cite:`DeeperTransformer`. .. mermaid:: graph LR %% "*SubLayer" can be self-attn, src-attn or feed forward block A(input) --> B[Norm] B --> C["*SubLayer"] C --> D[Drop] D --> E((+)) A --> E E --> F(out) Args: d_model (int): the dimension of keys/values/queries in :class:`MultiHeadedAttention`, also the input size of the first-layer of the :class:`PositionwiseFeedForward`. heads (int): the number of heads for MultiHeadedAttention. d_ff (int): the second-layer of the :class:`PositionwiseFeedForward`. dropout (float): dropout in residual, self-attn(dot) and feed-forward attention_dropout (float): dropout in context_attn (and self-attn(avg)) self_attn_type (string): type of self-attention scaled-dot, average max_relative_positions (int): Max distance between inputs in relative positions representations aan_useffn (bool): Turn on the FFN layer in the AAN decoder full_context_alignment (bool): whether enable an extra full context decoder forward for alignment alignment_heads (int): N. of cross attention heads to use for alignment guiding """ def __init__(self, d_model, heads, d_ff, dropout, attention_dropout, self_attn_type="scaled-dot", max_relative_positions=0, aan_useffn=False, full_context_alignment=False, alignment_heads=0): super(TransformerDecoderLayer, self).__init__() if self_attn_type == "scaled-dot": self.self_attn = MultiHeadedAttention( heads, d_model, dropout=attention_dropout, max_relative_positions=max_relative_positions) elif self_attn_type == "average": self.self_attn = AverageAttention(d_model, dropout=attention_dropout, aan_useffn=aan_useffn) self.context_attn = MultiHeadedAttention( heads, d_model, dropout=attention_dropout) self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout) self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6) self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6) self.drop = nn.Dropout(dropout) self.full_context_alignment = full_context_alignment self.alignment_heads = alignment_heads def forward(self, *args, **kwargs): """ Extend `_forward` for (possibly) multiple decoder pass: Always a default (future masked) decoder forward pass, Possibly a second future aware decoder pass for joint learn full context alignement, :cite:`garg2019jointly`. Args: * All arguments of _forward. with_align (bool): whether return alignment attention. 

        Returns:
            (FloatTensor, FloatTensor, FloatTensor or None):

            * output ``(batch_size, T, model_dim)``
            * top_attn ``(batch_size, T, src_len)``
            * attn_align ``(batch_size, T, src_len)`` or None
        """
        with_align = kwargs.pop('with_align', False)
        output, attns = self._forward(*args, **kwargs)
        top_attn = attns[:, 0, :, :].contiguous()
        attn_align = None
        if with_align:
            if self.full_context_alignment:
                # return _, (B, Q_len, K_len)
                _, attns = self._forward(*args, **kwargs, future=True)

            if self.alignment_heads > 0:
                attns = attns[:, :self.alignment_heads, :, :].contiguous()
            # layer average attention across heads, get ``(B, Q, K)``
            # Case 1: no full_context, no align heads -> layer avg baseline
            # Case 2: no full_context, 1 align heads -> guided align
            # Case 3: full_context, 1 align heads -> full ctx guided align
            attn_align = attns.mean(dim=1)
        return output, top_attn, attn_align

    def _forward(self, inputs, memory_bank, src_pad_mask, tgt_pad_mask,
                 layer_cache=None, step=None, future=False):
        """A naive forward pass for transformer decoder.

        # T: could be 1 in the case of stepwise decoding or tgt_len

        Args:
            inputs (FloatTensor): ``(batch_size, T, model_dim)``
            memory_bank (FloatTensor): ``(batch_size, src_len, model_dim)``
            src_pad_mask (LongTensor): ``(batch_size, 1, src_len)``
            tgt_pad_mask (LongTensor): ``(batch_size, 1, T)``
            layer_cache (dict or None): cached layer info when stepwise decoding
            step (int or None): stepwise decoding counter
            future (bool): If set True, do not apply future_mask.

        Returns:
            (FloatTensor, FloatTensor):

            * output ``(batch_size, T, model_dim)``
            * attns ``(batch_size, head, T, src_len)``
        """
        dec_mask = None

        if step is None:
            tgt_len = tgt_pad_mask.size(-1)
            if not future:  # apply future_mask, result mask in (B, T, T)
                future_mask = torch.ones(
                    [tgt_len, tgt_len],
                    device=tgt_pad_mask.device,
                    dtype=torch.uint8)
                future_mask = future_mask.triu_(1).view(1, tgt_len, tgt_len)
                # BoolTensor was introduced in pytorch 1.2
                try:
                    future_mask = future_mask.bool()
                except AttributeError:
                    pass
                dec_mask = torch.gt(tgt_pad_mask + future_mask, 0)
            else:  # only mask padding, result mask in (B, 1, T)
                dec_mask = tgt_pad_mask

        input_norm = self.layer_norm_1(inputs)

        if isinstance(self.self_attn, MultiHeadedAttention):
            query, _ = self.self_attn(input_norm, input_norm, input_norm,
                                      mask=dec_mask,
                                      layer_cache=layer_cache,
                                      attn_type="self")
        elif isinstance(self.self_attn, AverageAttention):
            query, _ = self.self_attn(input_norm, mask=dec_mask,
                                      layer_cache=layer_cache, step=step)

        query = self.drop(query) + inputs

        query_norm = self.layer_norm_2(query)
        mid, attns = self.context_attn(memory_bank, memory_bank, query_norm,
                                       mask=src_pad_mask,
                                       layer_cache=layer_cache,
                                       attn_type="context")
        output = self.feed_forward(self.drop(mid) + query)

        return output, attns

    def update_dropout(self, dropout, attention_dropout):
        self.self_attn.update_dropout(attention_dropout)
        self.context_attn.update_dropout(attention_dropout)
        self.feed_forward.update_dropout(dropout)
        self.drop.p = dropout

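
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): one way a single
# pre-norm decoder layer could be exercised with dummy tensors. All sizes
# below (batch of 2, d_model=16, 4 heads, d_ff=32) are arbitrary assumptions
# chosen for demonstration only, and `_example_layer_forward` is a
# hypothetical helper, not an OpenNMT-py API.
# ---------------------------------------------------------------------------
def _example_layer_forward():
    """Run one TransformerDecoderLayer on random inputs (sketch only)."""
    batch_size, tgt_len, src_len, d_model = 2, 5, 7, 16
    layer = TransformerDecoderLayer(d_model=d_model, heads=4, d_ff=32,
                                    dropout=0.1, attention_dropout=0.1)
    layer.eval()  # turn dropout off so the sketch is deterministic

    inputs = torch.rand(batch_size, tgt_len, d_model)
    memory_bank = torch.rand(batch_size, src_len, d_model)
    # Boolean padding masks: True marks positions to ignore. Here nothing is
    # padded; the causal (future) mask is built inside `_forward`.
    src_pad_mask = torch.zeros(batch_size, 1, src_len, dtype=torch.bool)
    tgt_pad_mask = torch.zeros(batch_size, 1, tgt_len, dtype=torch.bool)

    output, top_attn, attn_align = layer(
        inputs, memory_bank, src_pad_mask, tgt_pad_mask)
    # output: (batch_size, tgt_len, d_model)
    # top_attn: (batch_size, tgt_len, src_len), head 0 of the context attention
    # attn_align: None, since `with_align` was not requested
    return output, top_attn, attn_align
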

class TransformerDecoder(DecoderBase):
    """The Transformer decoder from "Attention is All You Need".
    :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`

    .. mermaid::

        graph BT
            A[input]
            B[multi-head self-attn]
            BB[multi-head src-attn]
            C[feed forward]
            O[output]
            A --> B
            B --> BB
            BB --> C
            C --> O

    Args:
        num_layers (int): number of decoder layers.
        d_model (int): size of the model
        heads (int): number of heads
        d_ff (int): size of the inner FF layer
        copy_attn (bool): if using a separate copy attention
        self_attn_type (str): type of self-attention, scaled-dot or average
        dropout (float): dropout in residual, self-attn(dot) and feed-forward
        attention_dropout (float): dropout in context_attn (and self-attn(avg))
        embeddings (onmt.modules.Embeddings):
            embeddings to use, should have positional encodings
        max_relative_positions (int):
            Max distance between inputs in relative positions representations
        aan_useffn (bool): Turn on the FFN layer in the AAN decoder
        full_context_alignment (bool):
            whether to enable an extra full-context decoder forward for
            alignment
        alignment_layer (int): index of the layer to supervise with for
            alignment guiding
        alignment_heads (int): number of cross attention heads to use for
            alignment guiding
    """

    def __init__(self, num_layers, d_model, heads, d_ff,
                 copy_attn, self_attn_type, dropout, attention_dropout,
                 embeddings, max_relative_positions, aan_useffn,
                 full_context_alignment, alignment_layer,
                 alignment_heads):
        super(TransformerDecoder, self).__init__()

        self.embeddings = embeddings

        # Decoder State
        self.state = {}

        self.transformer_layers = nn.ModuleList(
            [TransformerDecoderLayer(d_model, heads, d_ff, dropout,
             attention_dropout, self_attn_type=self_attn_type,
             max_relative_positions=max_relative_positions,
             aan_useffn=aan_useffn,
             full_context_alignment=full_context_alignment,
             alignment_heads=alignment_heads)
             for i in range(num_layers)])

        # previously, there was a GlobalAttention module here for copy
        # attention. But it was never actually used -- the "copy" attention
        # just reuses the context attention.
        self._copy = copy_attn
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

        self.alignment_layer = alignment_layer

    @classmethod
    def from_opt(cls, opt, embeddings):
        """Alternate constructor."""
        return cls(
            opt.dec_layers,
            opt.dec_rnn_size,
            opt.heads,
            opt.transformer_ff,
            opt.copy_attn,
            opt.self_attn_type,
            opt.dropout[0] if type(opt.dropout) is list else opt.dropout,
            opt.attention_dropout[0] if type(opt.attention_dropout) is list
            else opt.attention_dropout,
            embeddings,
            opt.max_relative_positions,
            opt.aan_useffn,
            opt.full_context_alignment,
            opt.alignment_layer,
            alignment_heads=opt.alignment_heads)

    def init_state(self, src, memory_bank, enc_hidden):
        """Initialize decoder state."""
        self.state["src"] = src
        self.state["cache"] = None

    def map_state(self, fn):
        def _recursive_map(struct, batch_dim=0):
            for k, v in struct.items():
                if v is not None:
                    if isinstance(v, dict):
                        _recursive_map(v)
                    else:
                        struct[k] = fn(v, batch_dim)

        self.state["src"] = fn(self.state["src"], 1)
        if self.state["cache"] is not None:
            _recursive_map(self.state["cache"])

    def detach_state(self):
        self.state["src"] = self.state["src"].detach()

    def forward(self, tgt, memory_bank, step=None, **kwargs):
        """Decode, possibly stepwise."""
        if step == 0:
            self._init_cache(memory_bank)

        tgt_words = tgt[:, :, 0].transpose(0, 1)

        emb = self.embeddings(tgt, step=step)
        assert emb.dim() == 3  # len x batch x embedding_dim

        output = emb.transpose(0, 1).contiguous()
        src_memory_bank = memory_bank.transpose(0, 1).contiguous()

        pad_idx = self.embeddings.word_padding_idx
        src_lens = kwargs["memory_lengths"]
        src_max_len = self.state["src"].shape[0]
        src_pad_mask = ~sequence_mask(src_lens, src_max_len).unsqueeze(1)
        tgt_pad_mask = tgt_words.data.eq(pad_idx).unsqueeze(1)  # [B, 1, T_tgt]

        with_align = kwargs.pop('with_align', False)
        attn_aligns = []
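
        # Each decoder layer consumes the running batch-first `output`
        # together with the source representations and both padding masks.
        # When decoding stepwise (step is not None), the per-layer cache
        # created by `_init_cache` is passed in so keys/values computed at
        # earlier steps can be reused instead of recomputed.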
self.state["cache"]["layer_{}".format(i)] \ if step is not None else None output, attn, attn_align = layer( output, src_memory_bank, src_pad_mask, tgt_pad_mask, layer_cache=layer_cache, step=step, with_align=with_align) if attn_align is not None: attn_aligns.append(attn_align) output = self.layer_norm(output) dec_outs = output.transpose(0, 1).contiguous() attn = attn.transpose(0, 1).contiguous() attns = {"std": attn} if self._copy: attns["copy"] = attn if with_align: attns["align"] = attn_aligns[self.alignment_layer] # `(B, Q, K)` # attns["align"] = torch.stack(attn_aligns, 0).mean(0) # All avg # TODO change the way attns is returned dict => list or tuple (onnx) return dec_outs, attns def _init_cache(self, memory_bank): self.state["cache"] = {} batch_size = memory_bank.size(1) depth = memory_bank.size(-1) for i, layer in enumerate(self.transformer_layers): layer_cache = {"memory_keys": None, "memory_values": None} if isinstance(layer.self_attn, AverageAttention): layer_cache["prev_g"] = torch.zeros((batch_size, 1, depth), device=memory_bank.device) else: layer_cache["self_keys"] = None layer_cache["self_values"] = None self.state["cache"]["layer_{}".format(i)] = layer_cache def update_dropout(self, dropout, attention_dropout): self.embeddings.update_dropout(dropout) for layer in self.transformer_layers: layer.update_dropout(dropout, attention_dropout)