Spaces:
Runtime error
Runtime error
import math | |
import torch | |
import torch.nn as nn | |
from torch import Tensor | |
from torch.nn import TransformerEncoder, TransformerDecoder, \ | |
TransformerEncoderLayer, TransformerDecoderLayer | |
# Seed PyTorch's global RNG once, at import time. Presumably this is meant to
# make the random weight initialisation of models built later reproducible
# across runs. NOTE(review): this is a module-level side effect - anything that
# imports this file has its global RNG reset to seed 0.
torch.random.manual_seed(0)
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to token embeddings, then apply dropout.

    The precomputed table has shape ``(maxlen, 1, emb_size)``. It is indexed by
    dim 0 of the input and has a singleton axis at dim 1, so a sequence-first
    ``(seq_len, batch, emb_size)`` input broadcasts across the batch.

    Args:
        emb_size: Embedding width. Odd widths are supported: the last (sine)
            column simply has no cosine partner.
        dropout: Dropout probability applied to the sum.
        maxlen: Number of positions in the precomputed table; longer inputs
            will fail to broadcast.
    """

    def __init__(self, emb_size: int, dropout: float, maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        # Inverse frequencies 1 / 10000^(2i / emb_size), one per sin/cos column pair.
        den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        angles = pos * den  # (maxlen, ceil(emb_size / 2))
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(angles)
        # There are only floor(emb_size / 2) odd columns; trimming avoids a
        # shape mismatch for odd emb_size and is a no-op for even emb_size.
        pos_embedding[:, 1::2] = torch.cos(angles[:, :emb_size // 2])
        # Singleton axis at dim 1 so the table broadcasts over the batch dimension.
        pos_embedding = pos_embedding.unsqueeze(-2)
        self.dropout = nn.Dropout(dropout)
        # Registered as a buffer: follows .to()/.cuda() and is saved in the
        # state_dict, but is not a trainable parameter.
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        """Return ``dropout(token_embedding + table[:seq_len])`` with seq_len = size(0)."""
        return self.dropout(token_embedding +
                            self.pos_embedding[:token_embedding.size(0), :])
class TokenEmbedding(nn.Module):
    """Look up token ids in a learned embedding table, scaled by sqrt(emb_size).

    Args:
        vocab_size: Number of rows in the embedding table.
        emb_size: Width of each embedding vector; also determines the scale.
    """

    def __init__(self, vocab_size: int, emb_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        """Cast ``tokens`` to int64, embed them, and multiply by sqrt(emb_size)."""
        indices = tokens.long()
        scale = math.sqrt(self.emb_size)
        return self.embedding(indices) * scale
class BanglaTransformer(nn.Module):
    """Encoder-decoder Transformer producing target-vocabulary logits.

    Source and target token ids are embedded (``TokenEmbedding``), given
    sinusoidal position encodings (a single shared ``PositionalEncoding``),
    passed through a ``TransformerEncoder`` / ``TransformerDecoder`` stack, and
    projected to ``tgt_vocab_size`` logits by ``generator``.

    The layers are built without ``batch_first``, i.e. PyTorch's default
    sequence-first layout; token tensors are presumably ``(seq_len, batch)`` -
    confirm against callers.

    Args:
        num_encoder_layers: Number of stacked encoder layers.
        num_decoder_layers: Number of stacked decoder layers.
        emb_size: Model width (``d_model``); PyTorch requires it to be
            divisible by ``nhead``.
        src_vocab_size: Source vocabulary size.
        tgt_vocab_size: Target vocabulary size (width of the output logits).
        dim_feedforward: Hidden width of each layer's feed-forward block.
        dropout: Dropout probability used by the positional encoding and by
            every encoder/decoder layer.
        nhead: Number of attention heads.
    """

    def __init__(self, num_encoder_layers: int, num_decoder_layers: int,
                 emb_size: int, src_vocab_size: int, tgt_vocab_size: int,
                 dim_feedforward: int = 512, dropout: float = 0.1, nhead: int = 8):
        super(BanglaTransformer, self).__init__()
        # `dropout` must be forwarded to the layers as well; otherwise they use
        # their own default and silently ignore the constructor argument.
        # NOTE: attribute names below define the state_dict keys - renaming any
        # of them breaks loading previously saved checkpoints.
        encoder_layer = TransformerEncoderLayer(
            d_model=emb_size,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        self.transformer_encoder = TransformerEncoder(
            encoder_layer,
            num_layers=num_encoder_layers,
        )
        decoder_layer = TransformerDecoderLayer(
            d_model=emb_size,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        self.transformer_decoder = TransformerDecoder(
            decoder_layer,
            num_layers=num_decoder_layers,
        )
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)

    def forward(self, src: Tensor, trg: Tensor, src_mask: Tensor,
                tgt_mask: Tensor, src_padding_mask: Tensor,
                tgt_padding_mask: Tensor, memory_key_padding_mask: Tensor):
        """Run the full encoder-decoder pass and return target-vocabulary logits.

        Args:
            src: Source token ids.
            trg: Target (decoder input) token ids.
            src_mask: Attention mask for encoder self-attention.
            tgt_mask: Attention mask for decoder self-attention (e.g. causal).
            src_padding_mask: Key padding mask for the source sequence.
            tgt_padding_mask: Key padding mask for the target sequence.
            memory_key_padding_mask: Key padding mask used when the decoder
                attends to the encoder output (e.g. the source padding mask).

        Returns:
            ``generator`` logits (width ``tgt_vocab_size``) for each target position.
        """
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        # Keyword arguments make the mask wiring explicit instead of relying on
        # PyTorch's positional parameter order.
        memory = self.transformer_encoder(
            src_emb, mask=src_mask, src_key_padding_mask=src_padding_mask)
        outs = self.transformer_decoder(
            tgt_emb, memory,
            tgt_mask=tgt_mask,
            memory_mask=None,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        """Encode source tokens into decoder memory (no padding mask, unlike forward)."""
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        return self.transformer_encoder(src_emb, mask=src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        """Decode ``tgt`` against encoder ``memory``.

        Returns decoder hidden states, not logits; apply ``self.generator`` to
        obtain vocabulary scores. No padding masks are applied here.
        """
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(tgt))
        return self.transformer_decoder(tgt_emb, memory, tgt_mask=tgt_mask)