# -*- coding: utf-8 -*-
"""Average Attention module."""

import torch
import torch.nn as nn

from .position_ffn import PositionwiseFeedForward


class AverageAttention(nn.Module):
    """
    Average Attention module from
    "Accelerating Neural Transformer via an Average Attention Network"
    :cite:`DBLP:journals/corr/abs-1805-00631`.

    Args:
       model_dim (int): the dimension of keys/values/queries
           (the model hidden size)
       dropout (float): dropout probability
       aan_useffn (bool): whether to apply a position-wise
           feed-forward layer to the cumulative average
    """

    def __init__(self, model_dim, dropout=0.1, aan_useffn=False):
        super(AverageAttention, self).__init__()
        self.model_dim = model_dim
        self.aan_useffn = aan_useffn
        if aan_useffn:
            self.average_layer = PositionwiseFeedForward(model_dim, model_dim,
                                                         dropout)
        # Produces the input and forget gates from the concatenation of the
        # inputs and their cumulative average.
        self.gating_layer = nn.Linear(model_dim * 2, model_dim * 2)

    def cumulative_average_mask(self, batch_size, inputs_len, device):
        """
        Builds the mask to compute the cumulative average as described in
        :cite:`DBLP:journals/corr/abs-1805-00631` -- Figure 3

        Args:
            batch_size (int): batch size
            inputs_len (int): length of the inputs
            device (torch.device): device on which to build the mask

        Returns:
            (FloatTensor):

            * A Tensor of shape ``(batch_size, input_len, input_len)``
        """

        triangle = torch.tril(torch.ones(inputs_len, inputs_len,
                              dtype=torch.float, device=device))
        weights = torch.ones(1, inputs_len, dtype=torch.float, device=device) \
            / torch.arange(1, inputs_len + 1, dtype=torch.float, device=device)
        mask = triangle * weights.transpose(0, 1)
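        # For example, with inputs_len = 3 the mask is
        #     [[1,   0,   0  ],
        #      [1/2, 1/2, 0  ],
        #      [1/3, 1/3, 1/3]],
        # i.e. row ``t`` averages positions ``0..t`` of the inputs.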

        return mask.unsqueeze(0).expand(batch_size, inputs_len, inputs_len)

    def cumulative_average(self, inputs, mask_or_step,
                           layer_cache=None, step=None):
        """
        Computes the cumulative average as described in
        :cite:`DBLP:journals/corr/abs-1805-00631` -- Equations (1) (5) (6)

        Args:
            inputs (FloatTensor): sequence to average
                ``(batch_size, input_len, dimension)``
            mask_or_step (FloatTensor or int): if ``layer_cache`` is set,
                this is assumed to be the current step of the dynamic
                decoding; otherwise, it is the mask matrix used to compute
                the cumulative average.
            layer_cache (dict or NoneType): a dictionary holding the
                cumulative average of the previous step under ``"prev_g"``.

        Returns:
            a tensor of the same shape and type as ``inputs``.
        """

        if layer_cache is not None:
            step = mask_or_step
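            # Incremental decoding: ``prev_g`` is the running mean of the
            # previous ``step`` positions, so ``step * prev_g`` recovers
            # their sum; adding the current input and dividing by
            # ``step + 1`` updates the mean in O(1) per step.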
            average_attention = (inputs + step *
                                 layer_cache["prev_g"]) / (step + 1)
            layer_cache["prev_g"] = average_attention
            return average_attention
        else:
            mask = mask_or_step
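            # Full-sequence mode: a single matmul with the cumulative
            # average mask yields the running mean at every position.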
            return torch.matmul(mask.to(inputs.dtype), inputs)

    def forward(self, inputs, mask=None, layer_cache=None, step=None):
        """
        Args:
            inputs (FloatTensor): ``(batch_size, input_len, model_dim)``
            mask: currently unused by this module
            layer_cache (dict or NoneType): cached ``"prev_g"`` tensor used
                for step-wise decoding
            step (int or NoneType): current decoding step, used when
                ``layer_cache`` is set

        Returns:
            (FloatTensor, FloatTensor):

            * gating_outputs ``(batch_size, input_len, model_dim)``
            * average_outputs average attention
                ``(batch_size, input_len, model_dim)``
        """

        batch_size = inputs.size(0)
        inputs_len = inputs.size(1)
        if layer_cache is None:
            mask_or_step = self.cumulative_average_mask(
                batch_size, inputs_len, inputs.device)
        else:
            mask_or_step = step
        average_outputs = self.cumulative_average(
            inputs, mask_or_step, layer_cache=layer_cache)
        if self.aan_useffn:
            average_outputs = self.average_layer(average_outputs)
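        # Gated combination: the input gate scales the original inputs and
        # the forget gate scales their cumulative average.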
        gating_outputs = self.gating_layer(torch.cat((inputs,
                                                      average_outputs), -1))
        input_gate, forget_gate = torch.chunk(gating_outputs, 2, dim=2)
        gating_outputs = torch.sigmoid(input_gate) * inputs + \
            torch.sigmoid(forget_gate) * average_outputs

        return gating_outputs, average_outputs
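

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the module API):
    # run the layer over a full sequence, then replay the same sequence step
    # by step with a layer cache; the two paths should agree up to floating
    # point error. Note: because of the relative import above, this block
    # only runs when the file is executed as a module inside its package.
    batch_size, seq_len, model_dim = 2, 5, 16
    attn = AverageAttention(model_dim)
    inputs = torch.rand(batch_size, seq_len, model_dim)

    # Full-sequence (training) mode: the mask is built internally.
    full_out, _ = attn(inputs)

    # Step-by-step (decoding) mode: the cache holds the running average.
    layer_cache = {"prev_g": torch.zeros(batch_size, 1, model_dim)}
    step_out = torch.cat(
        [attn(inputs[:, t:t + 1], layer_cache=layer_cache, step=t)[0]
         for t in range(seq_len)], dim=1)

    print(torch.allclose(full_out, step_out, atol=1e-6))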