# tts-rvc-autopst/onmt_modules/average_attn.py -- uploaded by jonathanjordan21: "Upload folder using huggingface_hub (#1)", commit 7ce5feb (verified)
# -*- coding: utf-8 -*-
"""Average Attention module."""
import torch
import torch.nn as nn
from .position_ffn import PositionwiseFeedForward
class AverageAttention(nn.Module):
    """Average Attention module.

    Implements the layer from
    "Accelerating Neural Transformer via an Average Attention Network"
    :cite:`DBLP:journals/corr/abs-1805-00631`: each position is replaced by
    the running mean of all positions up to and including itself, and the
    result is mixed with the original input through two sigmoid gates.

    Args:
        model_dim (int): size of the last dimension of the inputs; the gating
            layer maps ``2 * model_dim`` features to ``2 * model_dim``.
        dropout (float): dropout parameter, forwarded to the optional
            feed-forward layer. Unused when ``aan_useffn`` is ``False``.
        aan_useffn (bool): if ``True``, the cumulative average is passed
            through a ``PositionwiseFeedForward(model_dim, model_dim,
            dropout)`` layer before gating.
    """

    def __init__(self, model_dim, dropout=0.1, aan_useffn=False):
        # nn.Module must be initialised before any attribute is assigned on
        # the instance (see the torch.nn.Module documentation).
        super(AverageAttention, self).__init__()
        self.model_dim = model_dim
        self.aan_useffn = aan_useffn
        if aan_useffn:
            self.average_layer = PositionwiseFeedForward(model_dim, model_dim,
                                                         dropout)
        # A single projection produces both gates; they are split apart in
        # ``forward``.
        self.gating_layer = nn.Linear(model_dim * 2, model_dim * 2)

    def cumulative_average_mask(self, batch_size, inputs_len, device):
        """Build the mask used to compute the cumulative average.

        See :cite:`DBLP:journals/corr/abs-1805-00631` -- Figure 3.
        Row ``i`` of the mask holds ``1 / (i + 1)`` in columns ``0..i`` and
        zero elsewhere, so multiplying it with a sequence yields the running
        mean over the sequence dimension.

        Args:
            batch_size (int): batch size
            inputs_len (int): length of the inputs
            device (torch.device): device on which the mask is created

        Returns:
            (FloatTensor):

            * A Tensor of shape ``(batch_size, input_len, input_len)``.
              The batch dimension is an ``expand`` view of a single
              ``(input_len, input_len)`` matrix.
        """
        lower = torch.tril(torch.ones(inputs_len, inputs_len,
                                      dtype=torch.float, device=device))
        # positions = [1, 2, ..., inputs_len]; row i is divided by (i + 1).
        positions = torch.arange(1, inputs_len + 1,
                                 dtype=torch.float, device=device)
        mask = lower / positions.unsqueeze(1)
        return mask.unsqueeze(0).expand(batch_size, inputs_len, inputs_len)

    def cumulative_average(self, inputs, mask_or_step,
                           layer_cache=None, step=None):
        """Compute the cumulative average.

        See :cite:`DBLP:journals/corr/abs-1805-00631` -- Equations (1) (5) (6).

        Args:
            inputs (FloatTensor): sequence to average
                ``(batch_size, input_len, dimension)``
            mask_or_step: if ``layer_cache`` is set, this is the current step
                of the dynamic decoding. Otherwise, it is the mask matrix
                used to compute the cumulative average.
            layer_cache (dict or None): a dictionary whose ``"prev_g"`` entry
                holds the cumulative average of the previous step. The entry
                is overwritten in place with the new average.
            step: unused; the step is always read from ``mask_or_step``.
                Kept so existing callers passing it keep working.

        Returns:
            a tensor of the same shape and type as ``inputs``.
        """
        if layer_cache is None:
            # Full-sequence mode: one batched matmul with the averaging mask,
            # cast to the dtype of the inputs.
            return torch.matmul(mask_or_step.to(inputs.dtype), inputs)

        # Incremental mode: update the running mean with the new element,
        # g_t = (x_t + t * g_{t-1}) / (t + 1), and store it for the next step.
        current_step = mask_or_step
        running_average = (inputs + current_step *
                           layer_cache["prev_g"]) / (current_step + 1)
        layer_cache["prev_g"] = running_average
        return running_average

    def forward(self, inputs, mask=None, layer_cache=None, step=None):
        """Apply average attention followed by the input/forget gating.

        Args:
            inputs (FloatTensor): ``(batch_size, input_len, model_dim)``
            mask: accepted for interface compatibility; not used by this
                method.
            layer_cache (dict or None): when given, the running average is
                read from and written back to ``layer_cache["prev_g"]``
                instead of being computed with a mask.
            step: current decoding step; only used when ``layer_cache`` is
                given.

        Returns:
            (FloatTensor, FloatTensor):

            * gating_outputs ``(batch_size, input_len, model_dim)``
            * average_outputs average attention
              ``(batch_size, input_len, model_dim)``
        """
        if layer_cache is None:
            mask_or_step = self.cumulative_average_mask(
                inputs.size(0), inputs.size(1), inputs.device)
        else:
            mask_or_step = step
        average_outputs = self.cumulative_average(
            inputs, mask_or_step, layer_cache=layer_cache)
        if self.aan_useffn:
            average_outputs = self.average_layer(average_outputs)

        # One linear layer yields both gates, concatenated along the feature
        # dimension; split them on that same dimension.
        gates = self.gating_layer(torch.cat((inputs, average_outputs), -1))
        input_gate, forget_gate = torch.chunk(gates, 2, dim=-1)
        gating_outputs = torch.sigmoid(input_gate) * inputs + \
            torch.sigmoid(forget_gate) * average_outputs
        return gating_outputs, average_outputs