""" Multi-Head Attention module """ | |
import math | |
import torch | |
import torch.nn as nn | |
from .misc import generate_relative_positions_matrix,\ | |
relative_matmul | |
# from onmt.utils.misc import aeq | |
class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention module from "Attention is All You Need"
    :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`.

    Similar to standard `dot` attention but uses
    multiple attention distributions simultaneously
    to select relevant items.

    .. mermaid::

       graph BT
          A[key]
          B[value]
          C[query]
          O[output]
          subgraph Attn
            D[Attn 1]
            E[Attn 2]
            F[Attn N]
          end
          A --> D
          C --> D
          A --> E
          C --> E
          A --> F
          C --> F
          D --> O
          E --> O
          F --> O
          B --> O

    Also includes several additional tricks: relative position
    representations (``max_relative_positions``) and cached key/value
    projections for incremental decoding (``layer_cache``).

    Args:
       head_count (int): number of parallel heads
       model_dim (int): the dimension of keys/values/queries,
           must be divisible by head_count
       dropout (float): dropout parameter
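
    Example (illustrative sketch; the shapes and hyperparameters below are
    arbitrary assumptions)::

        >>> attention = MultiHeadedAttention(head_count=8, model_dim=512)
        >>> q = torch.rand(2, 10, 512)   # (batch, query_len, model_dim)
        >>> out, attn = attention(q, q, q, attn_type="self")
        >>> out.shape, attn.shape
        (torch.Size([2, 10, 512]), torch.Size([2, 8, 10, 10]))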
""" | |
def __init__(self, head_count, model_dim, dropout=0.1, | |
max_relative_positions=0): | |
assert model_dim % head_count == 0 | |
self.dim_per_head = model_dim // head_count | |
self.model_dim = model_dim | |
super(MultiHeadedAttention, self).__init__() | |
self.head_count = head_count | |
self.linear_keys = nn.Linear(model_dim, | |
head_count * self.dim_per_head) | |
self.linear_values = nn.Linear(model_dim, | |
head_count * self.dim_per_head) | |
self.linear_query = nn.Linear(model_dim, | |
head_count * self.dim_per_head) | |
self.softmax = nn.Softmax(dim=-1) | |
self.dropout = nn.Dropout(dropout) | |
self.final_linear = nn.Linear(model_dim, model_dim) | |
self.max_relative_positions = max_relative_positions | |
if max_relative_positions > 0: | |
vocab_size = max_relative_positions * 2 + 1 | |
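            # Embedding table over clipped relative distances in
            # [-max_relative_positions, max_relative_positions]; the helper in
            # ``.misc`` is assumed to shift distances into [0, vocab_size).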
            self.relative_positions_embeddings = nn.Embedding(
                vocab_size, self.dim_per_head)

    def forward(self, key, value, query, mask=None,
                layer_cache=None, attn_type=None):
        """
        Compute the context vector and the attention vectors.

        Args:
           key (FloatTensor): set of `key_len`
               key vectors ``(batch, key_len, dim)``
           value (FloatTensor): set of `key_len`
               value vectors ``(batch, key_len, dim)``
           query (FloatTensor): set of `query_len`
               query vectors ``(batch, query_len, dim)``
           mask: binary mask ``(batch, query_len, key_len)``; keys marked
               with 1 receive zero attention, keys marked with 0 are
               attended to normally

        Returns:
           (FloatTensor, FloatTensor):

           * output context vectors ``(batch, query_len, dim)``
           * per-head attention distributions ``(batch, head, query_len, key_len)``
        """
        # CHECKS
        # batch, k_len, d = key.size()
        # batch_, k_len_, d_ = value.size()
        # aeq(batch, batch_)
        # aeq(k_len, k_len_)
        # aeq(d, d_)
        # batch_, q_len, d_ = query.size()
        # aeq(batch, batch_)
        # aeq(d, d_)
        # aeq(self.model_dim % 8, 0)
        # if mask is not None:
        #    batch_, q_len_, k_len_ = mask.size()
        #    aeq(batch_, batch)
        #    aeq(k_len_, k_len)
        #    aeq(q_len_, q_len)
        # END CHECKS

        batch_size = key.size(0)
        dim_per_head = self.dim_per_head
        head_count = self.head_count
        key_len = key.size(1)
        query_len = query.size(1)

        def shape(x):
            """Project ``(batch, len, model_dim)`` to
            ``(batch, heads, len, dim_per_head)``."""
            return x.view(batch_size, -1, head_count, dim_per_head) \
                .transpose(1, 2)

        def unshape(x):
            """Merge heads back: ``(batch, heads, len, dim_per_head)`` to
            ``(batch, len, model_dim)``."""
            return x.transpose(1, 2).contiguous() \
                    .view(batch_size, -1, head_count * dim_per_head)

        # 1) Project key, value, and query.
        if layer_cache is not None:
            if attn_type == "self":
                query, key, value = self.linear_query(query),\
                                    self.linear_keys(query),\
                                    self.linear_values(query)
                key = shape(key)
                value = shape(value)
                if layer_cache["self_keys"] is not None:
                    key = torch.cat(
                        (layer_cache["self_keys"], key),
                        dim=2)
                if layer_cache["self_values"] is not None:
                    value = torch.cat(
                        (layer_cache["self_values"], value),
                        dim=2)
                layer_cache["self_keys"] = key
                layer_cache["self_values"] = value
            elif attn_type == "context":
                query = self.linear_query(query)
                if layer_cache["memory_keys"] is None:
                    key, value = self.linear_keys(key),\
                                 self.linear_values(value)
                    key = shape(key)
                    value = shape(value)
                else:
                    key, value = layer_cache["memory_keys"],\
                                 layer_cache["memory_values"]
                layer_cache["memory_keys"] = key
                layer_cache["memory_values"] = value
        else:
            key = self.linear_keys(key)
            value = self.linear_values(value)
            query = self.linear_query(query)
            key = shape(key)
            value = shape(value)

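        # Relative position representations in the spirit of Shaw et al. (2018):
        # each clipped key-query distance gets an embedding that is added on the
        # key side (scores) and the value side (context). Note that this module
        # reuses a single embedding table for both relations.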
        if self.max_relative_positions > 0 and attn_type == "self":
            key_len = key.size(2)
            # 1 or key_len x key_len
            relative_positions_matrix = generate_relative_positions_matrix(
                key_len, self.max_relative_positions,
                cache=True if layer_cache is not None else False)
            # 1 or key_len x key_len x dim_per_head
            relations_keys = self.relative_positions_embeddings(
                relative_positions_matrix.to(key.device))
            # 1 or key_len x key_len x dim_per_head
            relations_values = self.relative_positions_embeddings(
                relative_positions_matrix.to(key.device))

        query = shape(query)

        key_len = key.size(2)
        query_len = query.size(2)

        # 2) Calculate and scale scores.
        query = query / math.sqrt(dim_per_head)
        # batch x num_heads x query_len x key_len
        query_key = torch.matmul(query, key.transpose(2, 3))

        if self.max_relative_positions > 0 and attn_type == "self":
            scores = query_key + relative_matmul(query, relations_keys, True)
        else:
            scores = query_key
        scores = scores.float()

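        # Positions where ``mask`` is 1/True are excluded: their scores are
        # pushed to a large negative value so that softmax assigns them
        # (near-)zero probability.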
        if mask is not None:
            mask = mask.unsqueeze(1)  # [B, 1, 1, T_values]
            scores = scores.masked_fill(mask, -1e18)

        # 3) Apply attention dropout and compute context vectors.
        attn = self.softmax(scores).to(query.dtype)
        drop_attn = self.dropout(attn)

        context_original = torch.matmul(drop_attn, value)

        if self.max_relative_positions > 0 and attn_type == "self":
            context = unshape(context_original
                              + relative_matmul(drop_attn,
                                                relations_values,
                                                False))
        else:
            context = unshape(context_original)

        output = self.final_linear(context)

        # CHECK
        # batch_, q_len_, d_ = output.size()
        # aeq(q_len, q_len_)
        # aeq(batch, batch_)
        # aeq(d, d_)

        # Return multi-head attn
        attns = attn \
            .view(batch_size, head_count,
                  query_len, key_len)

        return output, attns

    def update_dropout(self, dropout):
        """Update the attention dropout probability in place."""
        self.dropout.p = dropout
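

# Minimal, illustrative usage sketch; shapes and hyperparameters are arbitrary
# assumptions, and the block only runs when the module is executed directly
# (run it as part of its package so the relative import above resolves).
if __name__ == "__main__":
    mha = MultiHeadedAttention(head_count=8, model_dim=512, dropout=0.1)

    # One-shot self-attention over a full sequence.
    x = torch.rand(2, 10, 512)                    # (batch, seq_len, model_dim)
    out, attns = mha(x, x, x, attn_type="self")
    print(out.shape, attns.shape)                 # (2, 10, 512) and (2, 8, 10, 10)

    # Incremental decoding: cached keys/values grow by one time step per call.
    cache = {"self_keys": None, "self_values": None}
    for _ in range(3):
        q = torch.rand(2, 1, 512)                 # one new target position
        out, attns = mha(q, q, q, layer_cache=cache, attn_type="self")
    print(attns.shape)                            # (2, 8, 1, 3): attends over 3 cached steps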