from typing import Tuple

import torch.nn as nn
from torch import Tensor

from modules.multi_head_attention import MultiHeadAttention
from modules.positionwise_feed_forward import PositionwiseFeedForwardNetwork


class DecoderLayer(nn.Module):
    """
    A Transformer decoder layer with pre-layer normalization: masked
    self-attention, encoder-decoder (cross) attention, and a position-wise
    feed-forward network, each wrapped in a residual connection.

    Args:
        d_model (int): dimension of the model
        num_heads (int): number of attention heads
        d_ff (int): hidden dimension of the feed-forward network
        dropout_p (float): probability of dropout
    """
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int,
        dropout_p: float,
    ) -> None:
        super(DecoderLayer, self).__init__()
        # Masked self-attention sub-layer (pre-norm).
        self.self_attn_prenorm = nn.LayerNorm(d_model)
        self.self_attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads, dropout_p=dropout_p)
        self.self_attn_dropout = nn.Dropout(p=dropout_p)
        # Encoder-decoder (cross) attention sub-layer (pre-norm).
        self.cross_attn_prenorm = nn.LayerNorm(d_model)
        self.cross_attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads, dropout_p=dropout_p)
        self.cross_attn_dropout = nn.Dropout(p=dropout_p)
        # Position-wise feed-forward sub-layer (pre-norm).
        self.feed_forward_prenorm = nn.LayerNorm(d_model)
        self.feed_forward = PositionwiseFeedForwardNetwork(d_model=d_model, d_ff=d_ff, dropout_p=dropout_p)

    def forward(
        self,
        decoder_inputs: Tensor,
        encoder_outputs: Tensor,
        src_mask: Tensor,
        tgt_mask: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Returns the layer outputs and the cross-attention weights.
        """
        # Masked self-attention sub-layer: the decoder attends to its own
        # positions, restricted by the target (look-ahead) mask.
        residual = decoder_inputs
        outputs = self.self_attn_prenorm(decoder_inputs)
        outputs, attn = self.self_attn(outputs, outputs, outputs, tgt_mask)
        outputs = self.self_attn_dropout(outputs) + residual

        # Cross-attention sub-layer: queries come from the decoder, keys and
        # values from the encoder outputs, masked by the source mask.
        residual = outputs
        outputs = self.cross_attn_prenorm(outputs)
        outputs, attn = self.cross_attn(outputs, encoder_outputs, encoder_outputs, src_mask)
        outputs = self.cross_attn_dropout(outputs) + residual

        # Position-wise feed-forward sub-layer.
        residual = outputs
        outputs = self.feed_forward_prenorm(outputs)
        outputs = self.feed_forward(outputs)
        outputs = outputs + residual
        return outputs, attn
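

# A minimal usage sketch, not part of the original module. It assumes the
# local MultiHeadAttention takes (query, key, value, mask) and returns
# (outputs, attention weights), as used in forward() above; the mask shapes
# and boolean convention below are illustrative assumptions and may need to
# be adapted to the local MultiHeadAttention implementation.
if __name__ == "__main__":
    import torch

    batch_size, src_len, tgt_len, d_model = 2, 30, 20, 512
    layer = DecoderLayer(d_model=d_model, num_heads=8, d_ff=2048, dropout_p=0.1)

    decoder_inputs = torch.randn(batch_size, tgt_len, d_model)
    encoder_outputs = torch.randn(batch_size, src_len, d_model)

    # Assumed mask conventions: a padding mask over source positions and a
    # causal (look-ahead) mask over target positions.
    src_mask = torch.ones(batch_size, tgt_len, src_len, dtype=torch.bool)
    tgt_mask = torch.tril(torch.ones(tgt_len, tgt_len, dtype=torch.bool)).expand(batch_size, tgt_len, tgt_len)

    outputs, cross_attn = layer(decoder_inputs, encoder_outputs, src_mask, tgt_mask)
    print(outputs.shape)  # expected: (batch_size, tgt_len, d_model)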