# homemade_lo_vi/model/decoder.py
# (Hugging Face upload metadata — uploader: moiduy04, commit: befbc32,
#  message: "Upload decoder.py", size: 1.18 kB)
from typing import Tuple
import torch.nn as nn
from torch import Tensor
from layers.decoder_layer import DecoderLayer
class Decoder(nn.Module):
    """
    A stack of transformer decoder layers (no token or positional embeddings).

    Args:
        d_model: dimensionality of the model's hidden states
        num_heads: number of attention heads per decoder layer
        d_ff: inner dimension of each layer's feed-forward network
        dropout_p: dropout probability, a float in [0, 1]
        num_layers: number of stacked ``DecoderLayer`` instances; must be >= 1

    Raises:
        ValueError: if ``num_layers < 1`` (an empty stack would leave the
            attention return value unbound in ``forward``).

    Forward inputs:
        x: (batch, seq_len, d_model) target-side hidden states
        encoder_output: (batch, src_len, d_model) encoder hidden states
        src_mask / tgt_mask: attention masks, passed through to every layer

    Forward outputs:
        - (batch, seq_len, d_model): decoder output
        - attention weights from the LAST layer only
          NOTE(review): the original docstring said (batch, seq_len, seq_len);
          with multi-head attention this is more likely
          (batch, num_heads, seq_len, seq_len) — confirm against DecoderLayer.
    """
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int,
        # fixed: was annotated `int`; dropout is a probability in [0, 1]
        dropout_p: float,
        num_layers: int,
    ) -> None:
        super(Decoder, self).__init__()
        # fail fast: with zero layers, `forward` would raise an opaque
        # UnboundLocalError on `attn` instead of a clear construction error
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.layers = nn.ModuleList(
            [
                DecoderLayer(
                    d_model=d_model,
                    num_heads=num_heads,
                    d_ff=d_ff,
                    dropout_p=dropout_p,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(
        self,
        x: Tensor,
        encoder_output: Tensor,
        src_mask: Tensor,
        tgt_mask: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """Run ``x`` through each decoder layer in order.

        Returns the final layer's output and the attention weights of the
        last layer only (earlier layers' attentions are discarded).
        """
        for layer in self.layers:
            x, attn = layer(x, encoder_output, src_mask, tgt_mask)
        return x, attn