""" Multi-Head Attention module """ import math import torch import torch.nn as nn from .misc import generate_relative_positions_matrix,\ relative_matmul # from onmt.utils.misc import aeq class MultiHeadedAttention(nn.Module): """Multi-Head Attention module from "Attention is All You Need" :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`. Similar to standard `dot` attention but uses multiple attention distributions simulataneously to select relevant items. .. mermaid:: graph BT A[key] B[value] C[query] O[output] subgraph Attn D[Attn 1] E[Attn 2] F[Attn N] end A --> D C --> D A --> E C --> E A --> F C --> F D --> O E --> O F --> O B --> O Also includes several additional tricks. Args: head_count (int): number of parallel heads model_dim (int): the dimension of keys/values/queries, must be divisible by head_count dropout (float): dropout parameter """ def __init__(self, head_count, model_dim, dropout=0.1, max_relative_positions=0): assert model_dim % head_count == 0 self.dim_per_head = model_dim // head_count self.model_dim = model_dim super(MultiHeadedAttention, self).__init__() self.head_count = head_count self.linear_keys = nn.Linear(model_dim, head_count * self.dim_per_head) self.linear_values = nn.Linear(model_dim, head_count * self.dim_per_head) self.linear_query = nn.Linear(model_dim, head_count * self.dim_per_head) self.softmax = nn.Softmax(dim=-1) self.dropout = nn.Dropout(dropout) self.final_linear = nn.Linear(model_dim, model_dim) self.max_relative_positions = max_relative_positions if max_relative_positions > 0: vocab_size = max_relative_positions * 2 + 1 self.relative_positions_embeddings = nn.Embedding( vocab_size, self.dim_per_head) def forward(self, key, value, query, mask=None, layer_cache=None, attn_type=None): """ Compute the context vector and the attention vectors. Args: key (FloatTensor): set of `key_len` key vectors ``(batch, key_len, dim)`` value (FloatTensor): set of `key_len` value vectors ``(batch, key_len, dim)`` query (FloatTensor): set of `query_len` query vectors ``(batch, query_len, dim)`` mask: binary mask 1/0 indicating which keys have zero / non-zero attention ``(batch, query_len, key_len)`` Returns: (FloatTensor, FloatTensor): * output context vectors ``(batch, query_len, dim)`` * Attention vector in heads ``(batch, head, query_len, key_len)``. """ # CHECKS # batch, k_len, d = key.size() # batch_, k_len_, d_ = value.size() # aeq(batch, batch_) # aeq(k_len, k_len_) # aeq(d, d_) # batch_, q_len, d_ = query.size() # aeq(batch, batch_) # aeq(d, d_) # aeq(self.model_dim % 8, 0) # if mask is not None: # batch_, q_len_, k_len_ = mask.size() # aeq(batch_, batch) # aeq(k_len_, k_len) # aeq(q_len_ == q_len) # END CHECKS batch_size = key.size(0) dim_per_head = self.dim_per_head head_count = self.head_count key_len = key.size(1) query_len = query.size(1) def shape(x): """Projection.""" return x.view(batch_size, -1, head_count, dim_per_head) \ .transpose(1, 2) def unshape(x): """Compute context.""" return x.transpose(1, 2).contiguous() \ .view(batch_size, -1, head_count * dim_per_head) # 1) Project key, value, and query. if layer_cache is not None: if attn_type == "self": query, key, value = self.linear_query(query),\ self.linear_keys(query),\ self.linear_values(query) key = shape(key) value = shape(value) if layer_cache["self_keys"] is not None: key = torch.cat( (layer_cache["self_keys"], key), dim=2) if layer_cache["self_values"] is not None: value = torch.cat( (layer_cache["self_values"], value), dim=2) layer_cache["self_keys"] = key layer_cache["self_values"] = value elif attn_type == "context": query = self.linear_query(query) if layer_cache["memory_keys"] is None: key, value = self.linear_keys(key),\ self.linear_values(value) key = shape(key) value = shape(value) else: key, value = layer_cache["memory_keys"],\ layer_cache["memory_values"] layer_cache["memory_keys"] = key layer_cache["memory_values"] = value else: key = self.linear_keys(key) value = self.linear_values(value) query = self.linear_query(query) key = shape(key) value = shape(value) if self.max_relative_positions > 0 and attn_type == "self": key_len = key.size(2) # 1 or key_len x key_len relative_positions_matrix = generate_relative_positions_matrix( key_len, self.max_relative_positions, cache=True if layer_cache is not None else False) # 1 or key_len x key_len x dim_per_head relations_keys = self.relative_positions_embeddings( relative_positions_matrix.to(key.device)) # 1 or key_len x key_len x dim_per_head relations_values = self.relative_positions_embeddings( relative_positions_matrix.to(key.device)) query = shape(query) key_len = key.size(2) query_len = query.size(2) # 2) Calculate and scale scores. query = query / math.sqrt(dim_per_head) # batch x num_heads x query_len x key_len query_key = torch.matmul(query, key.transpose(2, 3)) if self.max_relative_positions > 0 and attn_type == "self": scores = query_key + relative_matmul(query, relations_keys, True) else: scores = query_key scores = scores.float() if mask is not None: mask = mask.unsqueeze(1) # [B, 1, 1, T_values] scores = scores.masked_fill(mask, -1e18) # 3) Apply attention dropout and compute context vectors. attn = self.softmax(scores).to(query.dtype) drop_attn = self.dropout(attn) context_original = torch.matmul(drop_attn, value) if self.max_relative_positions > 0 and attn_type == "self": context = unshape(context_original + relative_matmul(drop_attn, relations_values, False)) else: context = unshape(context_original) output = self.final_linear(context) # CHECK # batch_, q_len_, d_ = output.size() # aeq(q_len, q_len_) # aeq(batch, batch_) # aeq(d, d_) # Return multi-head attn attns = attn \ .view(batch_size, head_count, query_len, key_len) return output, attns def update_dropout(self, dropout): self.dropout.p = dropout