# -*- coding: utf-8 -*-
"""Average Attention module."""

import torch
import torch.nn as nn

from .position_ffn import PositionwiseFeedForward


class AverageAttention(nn.Module):
""" | |
Average Attention module from | |
"Accelerating Neural Transformer via an Average Attention Network" | |
:cite:`DBLP:journals/corr/abs-1805-00631`. | |
Args: | |
model_dim (int): the dimension of keys/values/queries, | |
must be divisible by head_count | |
dropout (float): dropout parameter | |
""" | |

    def __init__(self, model_dim, dropout=0.1, aan_useffn=False):
        super(AverageAttention, self).__init__()
        self.model_dim = model_dim
        self.aan_useffn = aan_useffn
        if aan_useffn:
            self.average_layer = PositionwiseFeedForward(model_dim, model_dim,
                                                         dropout)
        # Projects [inputs; average] onto the input and forget gates,
        # hence the doubled input and output dimensions.
        self.gating_layer = nn.Linear(model_dim * 2, model_dim * 2)

    def cumulative_average_mask(self, batch_size, inputs_len, device):
        """
        Builds the mask to compute the cumulative average as described in
        :cite:`DBLP:journals/corr/abs-1805-00631` -- Figure 3

        Args:
            batch_size (int): batch size
            inputs_len (int): length of the inputs
            device (torch.device): device on which to build the mask

        Returns:
            (FloatTensor):

            * A Tensor of shape ``(batch_size, inputs_len, inputs_len)``
        """
        triangle = torch.tril(torch.ones(inputs_len, inputs_len,
                                         dtype=torch.float, device=device))
        weights = torch.ones(1, inputs_len, dtype=torch.float, device=device) \
            / torch.arange(1, inputs_len + 1, dtype=torch.float, device=device)
        mask = triangle * weights.transpose(0, 1)
        return mask.unsqueeze(0).expand(batch_size, inputs_len, inputs_len)
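
    # Illustrative sketch (not part of the original file): for
    # inputs_len = 3 the mask built above is
    #
    #     [[1,   0,   0  ],
    #      [1/2, 1/2, 0  ],
    #      [1/3, 1/3, 1/3]]
    #
    # so multiplying it with the inputs yields, at every position j,
    # the average of positions 0..j, matching Figure 3 of the paper.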

    def cumulative_average(self, inputs, mask_or_step, layer_cache=None):
        """
        Computes the cumulative average as described in
        :cite:`DBLP:journals/corr/abs-1805-00631` -- Equations (1) (5) (6)

        Args:
            inputs (FloatTensor): sequence to average
                ``(batch_size, input_len, dimension)``
            mask_or_step: if ``layer_cache`` is set, this is assumed to be
                the current step of the dynamic decoding. Otherwise, it is
                the mask matrix used to compute the cumulative average.
            layer_cache: a dictionary containing the cumulative average
                of the previous step.

        Returns:
            a tensor of the same shape and type as ``inputs``.
        """
        if layer_cache is not None:
            # Incremental (decoding) path: update the running mean using
            # the cached average from the previous step.
            step = mask_or_step
            average_attention = (inputs + step *
                                 layer_cache["prev_g"]) / (step + 1)
            layer_cache["prev_g"] = average_attention
            return average_attention
        else:
            # Parallel (training) path: a single matmul with the mask
            # computes every prefix average at once.
            mask = mask_or_step
            return torch.matmul(mask.to(inputs.dtype), inputs)
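
    # Illustrative note (not from the original source): both branches
    # compute the same quantity. The masked matmul yields, in parallel,
    # g_j = mean(x_0..x_j) for every position j, while the cached branch
    # maintains the same running mean step by step:
    #
    #     g_t = (x_t + t * g_{t-1}) / (t + 1), with g_{-1} = 0.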

    def forward(self, inputs, mask=None, layer_cache=None, step=None):
        """
        Args:
            inputs (FloatTensor): ``(batch_size, input_len, model_dim)``
            mask: unused here; kept for interface compatibility
            layer_cache: a dictionary containing the cumulative average
                of the previous step, used during dynamic decoding
            step (int): current decoding step, used with ``layer_cache``

        Returns:
            (FloatTensor, FloatTensor):

            * gating_outputs ``(batch_size, input_len, model_dim)``
            * average_outputs average attention
              ``(batch_size, input_len, model_dim)``
        """
        batch_size = inputs.size(0)
        inputs_len = inputs.size(1)
        if layer_cache is None:
            # Training: average all prefixes in parallel via the mask.
            mask = self.cumulative_average_mask(batch_size, inputs_len,
                                                inputs.device)
            average_outputs = self.cumulative_average(inputs, mask)
        else:
            # Decoding: update the cached running average for this step.
            average_outputs = self.cumulative_average(
                inputs, step, layer_cache=layer_cache)
        if self.aan_useffn:
            average_outputs = self.average_layer(average_outputs)
        # Gating: combine the raw inputs with their cumulative average
        # through learned input and forget gates.
        gating_outputs = self.gating_layer(torch.cat((inputs,
                                                      average_outputs), -1))
        input_gate, forget_gate = torch.chunk(gating_outputs, 2, dim=2)
        gating_outputs = torch.sigmoid(input_gate) * inputs + \
            torch.sigmoid(forget_gate) * average_outputs
        return gating_outputs, average_outputs
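

if __name__ == "__main__":
    # Minimal smoke test -- an illustrative sketch, not part of the
    # original module. Because of the relative import above, run it from
    # within its package (e.g. ``python -m <package>.average_attn``).
    torch.manual_seed(0)
    attn = AverageAttention(model_dim=8)
    x = torch.randn(2, 5, 8)  # (batch_size, input_len, model_dim)
    gated, averaged = attn(x)
    assert gated.shape == x.shape and averaged.shape == x.shape

    # Decoding path: feed one position at a time through a layer cache.
    cache = {"prev_g": torch.zeros(2, 1, 8)}
    for t in range(5):
        attn(x[:, t:t + 1], layer_cache=cache, step=t)
    # The cached running mean matches the last row of the parallel pass.
    assert torch.allclose(cache["prev_g"], averaged[:, -1:], atol=1e-6)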