import torch
import torch.nn as nn

class BERTEmbedding(nn.Module):
    """Sums token, segment, and learned position embeddings, then applies dropout."""

    def __init__(self, vocab_size, n_segments, max_len, embed_dim, dropout):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, embed_dim)
        self.segment_embed = nn.Embedding(n_segments, embed_dim)
        self.pos_embed = nn.Embedding(max_len, embed_dim)
        self.drop = nn.Dropout(dropout)
        # No stored position tensor: forward() builds position indices on the fly,
        # so a fixed CPU tensor here would be dead code and break on GPU inputs.
    def forward(self, seq, seg):
        # Build position indices dynamically so sequences shorter than max_len
        # work and the tensor lives on the same device as the input.
        seq_len = seq.size(1)
        pos = torch.arange(seq_len, device=seq.device).unsqueeze(0)  # (1, seq_len), broadcast over batch
        embed_val = self.token_embed(seq) + self.segment_embed(seg) + self.pos_embed(pos)
        return self.drop(embed_val)

class BERT(nn.Module):
    """Minimal BERT: embedding layer followed by a stack of Transformer encoder layers."""

    def __init__(self, vocab_size, n_segments, max_len, embed_dim, n_layers, attn_heads, dropout):
        super().__init__()
        self.embedding = BERTEmbedding(vocab_size, n_segments, max_len, embed_dim, dropout)
        # batch_first=True so the encoder expects (batch, seq_len, embed_dim),
        # matching the shape BERTEmbedding produces; without it, PyTorch defaults
        # to (seq_len, batch, embed_dim) and silently mixes up the axes.
        self.encoder_layer = nn.TransformerEncoderLayer(
            embed_dim, attn_heads, dim_feedforward=embed_dim * 4,
            dropout=dropout, batch_first=True,
        )
        self.encoder_block = nn.TransformerEncoder(self.encoder_layer, n_layers)
    def forward(self, seq, seg):
        out = self.embedding(seq, seg)  # (batch, seq_len, embed_dim)
        out = self.encoder_block(out)   # contextualized token representations, same shape
        return out
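

# --- Usage sketch (illustrative, not part of the original commit) ---
# A minimal smoke test of the model above. The hyperparameter values are
# assumed placeholders (BERT-base-like sizes), not values taken from the source.
if __name__ == "__main__":
    VOCAB_SIZE = 30000  # assumed: typical WordPiece vocabulary size
    N_SEGMENTS = 2      # sentence A / sentence B
    MAX_LEN = 512
    EMBED_DIM = 768
    N_LAYERS = 12
    ATTN_HEADS = 12     # must divide EMBED_DIM
    DROPOUT = 0.1

    model = BERT(VOCAB_SIZE, N_SEGMENTS, MAX_LEN, EMBED_DIM, N_LAYERS, ATTN_HEADS, DROPOUT)

    batch_size, seq_len = 4, 64
    seq = torch.randint(0, VOCAB_SIZE, (batch_size, seq_len))  # random token ids
    seg = torch.randint(0, N_SEGMENTS, (batch_size, seq_len))  # segment ids (0 or 1)

    out = model(seq, seg)
    print(out.shape)  # expected: torch.Size([4, 64, 768])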