# homemade_lo_vi/transformer.py
from typing import Tuple
import torch.nn as nn
from torch import Tensor
from modules.transformer_embedding import TransformerEmbedding
from modules.positional_encoding import PositionalEncoding
from model.encoder import Encoder
from model.decoder import Decoder
from layers.projection_layer import ProjectionLayer


class Transformer(nn.Module):
    """
    Encoder-decoder Transformer for sequence-to-sequence translation.

    Args:
    - src_vocab_size (int): source vocabulary size
    - tgt_vocab_size (int): target vocabulary size
    - src_max_seq_len (int): maximum source sequence length
    - tgt_max_seq_len (int): maximum target sequence length
    - d_model (int): model (embedding) dimension
    - num_heads (int): number of attention heads
    - d_ff (int): hidden dimension of the feed-forward sublayers
    - dropout_p (float): dropout probability
    - num_encoder_layers (int): number of encoder layers
    - num_decoder_layers (int): number of decoder layers
    """
def __init__(
self,
src_vocab_size: int,
tgt_vocab_size: int,
src_max_seq_len: int,
tgt_max_seq_len: int,
d_model: int = 512,
num_heads: int = 8,
d_ff: int = 2048,
dropout_p: float = 0.1,
num_encoder_layers: int = 6,
num_decoder_layers: int = 6,
) -> None:
        super().__init__()
# Embedding layers
self.src_embedding = TransformerEmbedding(
d_model=d_model,
num_embeddings=src_vocab_size
)
self.tgt_embedding = TransformerEmbedding(
d_model=d_model,
num_embeddings=tgt_vocab_size
)
# Positional Encoding layers
self.src_positional_encoding = PositionalEncoding(
d_model=d_model,
dropout_p=dropout_p,
max_length=src_max_seq_len
)
self.tgt_positional_encoding = PositionalEncoding(
d_model=d_model,
dropout_p=dropout_p,
max_length=tgt_max_seq_len
)
# Encoder
self.encoder = Encoder(
d_model=d_model,
num_heads=num_heads,
d_ff=d_ff,
dropout_p=dropout_p,
num_layers=num_encoder_layers
)
# Decoder
self.decoder = Decoder(
d_model=d_model,
num_heads=num_heads,
d_ff=d_ff,
dropout_p=dropout_p,
num_layers=num_decoder_layers
)
        # Projection layer mapping decoder outputs onto the target vocabulary.
self.projection_layer = ProjectionLayer(
d_model=d_model,
vocab_size=tgt_vocab_size
)

    def encode(
        self,
        src: Tensor,
        src_mask: Tensor
    ) -> Tensor:
        """
        Embed and positionally encode the source, then return the output of
        the encoder stack.
        """
src = self.src_embedding(src)
src = self.src_positional_encoding(src)
return self.encoder(src, src_mask)

    def decode(
        self,
        encoder_output: Tensor,
        src_mask: Tensor,
        tgt: Tensor,
        tgt_mask: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """
        Embed and positionally encode the target, then return the decoder
        output together with its attention weights.
        """
tgt = self.tgt_embedding(tgt)
tgt = self.tgt_positional_encoding(tgt)
return self.decoder(
x=tgt,
encoder_output=encoder_output,
src_mask=src_mask,
tgt_mask=tgt_mask
)

    def project(self, decoder_output: Tensor) -> Tensor:
"""
Project decoder outputs to target vocabulary.
"""
return self.projection_layer(decoder_output)

    def forward(
        self,
        src: Tensor,
        src_mask: Tensor,
        tgt: Tensor,
        tgt_mask: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """
        Full forward pass: encode the source, decode the target, and project
        the decoder output onto the target vocabulary. Source and target
        masks are expected to be built by the caller.
        """
encoder_output = self.encode(src, src_mask)
decoder_output, attn = self.decode(
encoder_output, src_mask, tgt, tgt_mask
)
output = self.project(decoder_output)
return output, attn

    def count_parameters(self) -> int:
        """
        Return the number of trainable parameters.
        """
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
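

# --- Inference sketch ---------------------------------------------------------
# A minimal greedy-decoding loop showing how encode/decode/project compose at
# inference time. This is an illustrative sketch, not part of the original
# module: the special-token ids (sos_idx, eos_idx) and the causal-mask shape
# are assumptions that must match the tokenizer and attention masks actually
# used in this repository.
def greedy_decode_sketch(
    model: Transformer,
    src: Tensor,
    src_mask: Tensor,
    sos_idx: int,
    eos_idx: int,
    max_len: int = 64,
) -> Tensor:
    import torch  # local import; the module above only imports torch.nn

    model.eval()
    with torch.no_grad():
        encoder_output = model.encode(src, src_mask)
        # Start every sequence in the batch with the assumed <sos> token.
        tgt = torch.full((src.size(0), 1), sos_idx, dtype=torch.long, device=src.device)
        for _ in range(max_len - 1):
            tgt_len = tgt.size(1)
            # Lower-triangular (causal) mask; its exact shape is an assumption.
            tgt_mask = torch.tril(
                torch.ones(1, 1, tgt_len, tgt_len, dtype=torch.bool, device=src.device)
            )
            decoder_output, _ = model.decode(encoder_output, src_mask, tgt, tgt_mask)
            logits = model.project(decoder_output)
            # Greedily pick the most likely next token at the last position.
            next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
            tgt = torch.cat([tgt, next_token], dim=1)
            if (next_token == eos_idx).all():
                break
    return tgt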


def get_model(config, src_vocab_size: int, tgt_vocab_size: int) -> Transformer:
    """
    Returns a `Transformer` built for the given config.
    """
return Transformer(
src_vocab_size=src_vocab_size,
tgt_vocab_size=tgt_vocab_size,
src_max_seq_len=config['dataset']['src_max_seq_len'],
tgt_max_seq_len=config['dataset']['tgt_max_seq_len'],
d_model=config['model']['d_model'],
num_heads=config['model']['num_heads'],
d_ff=config['model']['d_ff'],
dropout_p=config['model']['dropout_p'],
num_encoder_layers=config['model']['num_encoder_layers'],
num_decoder_layers=config['model']['num_decoder_layers'],
)
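

# --- Usage sketch ---------------------------------------------------------------
# A quick smoke test for the module. The config values, vocabulary sizes, and
# mask shapes below are illustrative assumptions; in particular the mask
# format ((batch, 1, 1, src_len) / (batch, 1, tgt_len, tgt_len)) must match
# whatever the attention layers in this repository actually expect.
if __name__ == "__main__":
    import torch

    config = {
        "dataset": {"src_max_seq_len": 64, "tgt_max_seq_len": 64},
        "model": {
            "d_model": 512,
            "num_heads": 8,
            "d_ff": 2048,
            "dropout_p": 0.1,
            "num_encoder_layers": 6,
            "num_decoder_layers": 6,
        },
    }
    model = get_model(config, src_vocab_size=1000, tgt_vocab_size=1000)
    print(f"trainable parameters: {model.count_parameters():,}")

    src = torch.randint(0, 1000, (2, 10))  # (batch, src_len) token ids
    tgt = torch.randint(0, 1000, (2, 12))  # (batch, tgt_len) token ids
    src_mask = torch.ones(2, 1, 1, 10, dtype=torch.bool)               # assumed shape
    tgt_mask = torch.tril(torch.ones(2, 1, 12, 12, dtype=torch.bool))  # assumed causal mask
    logits, attn = model(src, src_mask, tgt, tgt_mask)
    print(logits.shape)  # expected: (2, 12, tgt_vocab_size)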