from typing import Tuple

import torch.nn as nn
from torch import Tensor

from modules.multi_head_attention import MultiHeadAttention
from modules.positionwise_feed_forward import PositionwiseFeedForwardNetwork


class DecoderLayer(nn.Module):
    """
    A single pre-norm Transformer decoder layer: masked self-attention,
    encoder-decoder (cross) attention, and a position-wise feed-forward
    network, each wrapped in pre-layer normalization and a residual connection.

    Args:
        d_model (int): dimension of the model (input/output feature size)
        num_heads (int): number of attention heads
        d_ff (int): dimension of the feed-forward network's hidden layer
        dropout_p (float): probability of dropout
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int,
        dropout_p: float,
    ) -> None:
        super(DecoderLayer, self).__init__()
        # Masked self-attention sub-layer (pre-norm -> attention -> dropout)
        self.self_attn_prenorm = nn.LayerNorm(d_model)
        self.self_attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads, dropout_p=dropout_p)
        self.self_attn_dropout = nn.Dropout(p=dropout_p)

        # Encoder-decoder (cross) attention sub-layer
        self.cross_attn_prenorm = nn.LayerNorm(d_model)
        self.cross_attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads, dropout_p=dropout_p)
        self.cross_attn_dropout = nn.Dropout(p=dropout_p)

        # Position-wise feed-forward sub-layer
        self.feed_forward_prenorm = nn.LayerNorm(d_model)
        self.feed_forward = PositionwiseFeedForwardNetwork(d_model=d_model, d_ff=d_ff, dropout_p=dropout_p)

    def forward(
        self,
        decoder_inputs: Tensor,
        encoder_outputs: Tensor,
        src_mask: Tensor,
        tgt_mask: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Forward propagate the decoder layer.

        Args:
            decoder_inputs (Tensor): input representations of the target sequence
            encoder_outputs (Tensor): outputs of the encoder stack
            src_mask (Tensor): mask applied to the cross-attention scores
            tgt_mask (Tensor): mask applied to the self-attention scores
                (e.g. to prevent attending to future positions)

        Returns:
            Tuple[Tensor, Tensor]: the layer outputs and the attention weights
            of the cross-attention sub-layer.
        """
        # Masked self-attention over the decoder inputs
        residual = decoder_inputs
        outputs = self.self_attn_prenorm(decoder_inputs)
        outputs, attn = self.self_attn(outputs, outputs, outputs, tgt_mask)
        outputs = self.self_attn_dropout(outputs) + residual

        # Cross-attention: queries from the decoder, keys/values from the encoder
        residual = outputs
        outputs = self.cross_attn_prenorm(outputs)
        outputs, attn = self.cross_attn(outputs, encoder_outputs, encoder_outputs, src_mask)
        outputs = self.cross_attn_dropout(outputs) + residual

        # Position-wise feed-forward network
        residual = outputs
        outputs = self.feed_forward_prenorm(outputs)
        outputs = self.feed_forward(outputs)
        outputs = outputs + residual

        return outputs, attn
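

# Minimal usage sketch (illustrative only): builds a DecoderLayer and runs a dummy
# forward pass. The mask shapes and semantics below are assumptions -- they must
# match whatever modules.multi_head_attention.MultiHeadAttention actually expects.
if __name__ == "__main__":
    import torch

    batch_size, src_len, tgt_len, d_model = 2, 7, 5, 512
    layer = DecoderLayer(d_model=d_model, num_heads=8, d_ff=2048, dropout_p=0.1)

    decoder_inputs = torch.randn(batch_size, tgt_len, d_model)
    encoder_outputs = torch.randn(batch_size, src_len, d_model)

    # Assumed mask layout: broadcastable over the attention scores,
    # shape (batch, 1, query_len, key_len), where True marks positions to keep.
    src_mask = torch.ones(batch_size, 1, tgt_len, src_len, dtype=torch.bool)
    tgt_mask = torch.tril(torch.ones(tgt_len, tgt_len)).bool().expand(
        batch_size, 1, tgt_len, tgt_len
    )

    outputs, attn = layer(decoder_inputs, encoder_outputs, src_mask, tgt_mask)
    print(outputs.shape)  # expected: (batch_size, tgt_len, d_model)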