File size: 1,902 Bytes
bc1ada8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from typing import Tuple

import torch.nn as nn
from torch import Tensor

from modules.multi_head_attention import MultiHeadAttention
from modules.positionwise_feed_forward import PositionwiseFeedForwardNetwork


class DecoderLayer(nn.Module):
    """
    A single pre-norm Transformer decoder layer.

    Applies, in order: masked self-attention, encoder-decoder cross-attention,
    and a position-wise feed-forward network. Each sub-layer is wrapped with
    LayerNorm (pre-norm), dropout, and a residual connection.

    Args:
        d_model: model (embedding) dimensionality of inputs and outputs
        num_heads: number of attention heads
        d_ff: hidden dimensionality of the feed-forward network
        dropout_p: dropout probability applied after each attention sub-layer
            and inside the feed-forward network
    """
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int,
        dropout_p: float,  # fixed annotation: a probability, not an int
    ) -> None:
        super(DecoderLayer, self).__init__()
        self.self_attn_prenorm = nn.LayerNorm(d_model)
        self.self_attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads, dropout_p=dropout_p)
        self.self_attn_dropout = nn.Dropout(p=dropout_p)

        self.cross_attn_prenorm = nn.LayerNorm(d_model)
        self.cross_attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads, dropout_p=dropout_p)
        self.cross_attn_dropout = nn.Dropout(p=dropout_p)

        self.feed_forward_prenorm = nn.LayerNorm(d_model)
        self.feed_forward = PositionwiseFeedForwardNetwork(d_model=d_model, d_ff=d_ff, dropout_p=dropout_p)

    def forward(
        self, 
        decoder_inputs: Tensor,
        encoder_outputs: Tensor,
        src_mask: Tensor,
        tgt_mask: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Run the decoder layer.

        Args:
            decoder_inputs: target-side hidden states
            encoder_outputs: encoder hidden states attended to by cross-attention
            src_mask: mask applied in cross-attention (over encoder positions)
            tgt_mask: mask applied in self-attention (e.g. causal mask)

        Returns:
            Tuple of (layer outputs, cross-attention weights).
        """
        # Masked self-attention sub-layer (pre-norm + residual).
        residual = decoder_inputs
        outputs = self.self_attn_prenorm(decoder_inputs)
        outputs, attn = self.self_attn(outputs, outputs, outputs, tgt_mask)
        outputs = self.self_attn_dropout(outputs) + residual

        # Encoder-decoder cross-attention sub-layer (pre-norm + residual).
        # BUG FIX: this previously reused the self_attn modules, so the
        # cross_attn_* parameters built in __init__ were never used.
        residual = outputs
        outputs = self.cross_attn_prenorm(outputs)
        outputs, attn = self.cross_attn(outputs, encoder_outputs, encoder_outputs, src_mask)
        outputs = self.cross_attn_dropout(outputs) + residual

        # Position-wise feed-forward sub-layer (pre-norm + residual).
        residual = outputs
        outputs = self.feed_forward_prenorm(outputs)
        outputs = self.feed_forward(outputs)
        outputs += residual

        return outputs, attn