from typing import Optional, Tuple

import torch.nn as nn
from torch import Tensor

from modules.multi_head_attention import MultiHeadAttention
from modules.positionwise_feed_forward import PositionwiseFeedForwardNetwork

class EncoderLayer(nn.Module):
    """
    An Encoder layer.
    
    Args:
    """
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int,
        dropout_p: float,
    ) -> None:
        super(EncoderLayer, self).__init__()
        self.self_attn_prenorm = nn.LayerNorm(d_model)
        self.self_attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads, dropout_p=dropout_p)
        self.self_attn_dropout = nn.Dropout(p=dropout_p)

        self.feed_forward_prenorm = nn.LayerNorm(d_model)
        self.feed_forward = PositionwiseFeedForwardNetwork(d_model=d_model, d_ff=d_ff, dropout_p=dropout_p)

    def forward(self, inputs: Tensor, src_mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        # Normalize -> sublayer -> dropout -> add residual
        residual = inputs
        inputs = self.self_attn_prenorm(inputs)
        outputs, attn = self.self_attn(inputs, inputs, inputs, src_mask)
        outputs = self.self_attn_dropout(outputs) + residual

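        # Normalize -> feed-forward sublayer -> add residual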
        residual = outputs
        outputs = self.feed_forward_prenorm(outputs)
        outputs = self.feed_forward(outputs)
        outputs = outputs + residual

        return outputs, attn
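

# Minimal usage sketch (illustrative; assumes a standard Transformer setup in
# which the layer preserves the (batch_size, seq_len, d_model) shape):
if __name__ == "__main__":
    import torch

    layer = EncoderLayer(d_model=512, num_heads=8, d_ff=2048, dropout_p=0.1)
    inputs = torch.randn(2, 10, 512)  # (batch_size, seq_len, d_model)
    outputs, attn = layer(inputs)     # outputs: (2, 10, 512), attn: self-attention weights
    print(outputs.shape)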