import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
# Seq2SeqLMOutput for the model forward, BaseModelOutputWithPastAndCrossAttentions for the encoder forward
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutputWithPastAndCrossAttentions
from transformers import AutoConfig, AutoModel, AutoModelForSeq2SeqLM
from torch.nn.utils.rnn import pad_sequence

_MAX_CONTEXT_SIZE = 10_000


# ==========================================================
# Config
# ==========================================================
class OriginalTransformerConfig(PretrainedConfig):
    model_type = "original_transformer"

    def __init__(
        self,
        num_enc_layers=6,
        num_dec_layers=6,
        embed_dim=512,
        num_heads=8,
        enc_vocab_size=37000,
        dec_vocab_size=37000,
        d_ff=2048,
        dropout=0,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        is_encoder_decoder=True,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.num_enc_layers = num_enc_layers
        self.num_dec_layers = num_dec_layers
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.enc_vocab_size = enc_vocab_size
        self.dec_vocab_size = dec_vocab_size
        self.d_ff = d_ff
        self.dropout = dropout
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.is_encoder_decoder = is_encoder_decoder
        # for using AutoModel.from_pretrained
        self.auto_map = {
            "AutoModel": "modeling_original_transformer.OrginalTransformer",
            "AutoModelForSeq2SeqLM": "modeling_original_transformer.OrginalTransformer",
        }


# ==========================================================
# Model
# ==========================================================

# Combines the token embedding and the sinusoidal positional encoding.
class Embed(nn.Module):
    def __init__(self, vocab_size, embed_dim, dropout=0):
        super().__init__()
        self.emb_factor = torch.sqrt(torch.tensor(embed_dim, dtype=torch.float32))
        self.embed = nn.Embedding(vocab_size, embed_dim)  # vocab x C
        self.dropout = nn.Dropout(dropout)

        pos_embed = torch.zeros(_MAX_CONTEXT_SIZE, embed_dim)       # T x C
        position = torch.arange(0, _MAX_CONTEXT_SIZE).unsqueeze(1)  # from (T,) to T x 1
        # P.E.(pos, 2i) = sin(pos / 10000^(2i/dim))
        # div_term = 10000 ^ ([0, 1, 2, ..., C/2 - 1] * 2/C)
        div_term = torch.pow(10_000.0, torch.arange(0, embed_dim // 2) * 2 / embed_dim)  # (C/2,)
        pos_embed[:, 0::2] = torch.sin(position / div_term)  # T x C/2 ((T x 1) / (C/2,) broadcasts to T x C/2)
        pos_embed[:, 1::2] = torch.cos(position / div_term)  # T x C/2
        self.register_buffer('pos_embed', pos_embed, persistent=False)

    def forward(self, x):
        # x = B x T token ids (NOT one-hot)
        embed_x = self.embed(x)              # B x T x C
        embed_x = embed_x * self.emb_factor  # scale so the embeddings are not overpowered by the positional encoding

        # truncate the positional table to the actual sequence length
        seq_len = x.shape[-1]
        trunc_pos_embed = self.pos_embed[:seq_len, :]
        embed_x = self.dropout(embed_x + trunc_pos_embed)
        return embed_x
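
# ----------------------------------------------------------
# Hedged illustration (not part of the original model code): a minimal
# sanity check for the Embed module above, assuming token-id inputs of
# shape B x T and the paper's sin/cos layout (even channels sin, odd
# channels cos). The helper name and toy sizes are illustrative only;
# call it manually, e.g. from a REPL.
# ----------------------------------------------------------
def _demo_embed_shapes(vocab_size=100, embed_dim=16, seq_len=5):
    emb = Embed(vocab_size, embed_dim)
    token_ids = torch.randint(0, vocab_size, (2, seq_len))  # B x T
    out = emb(token_ids)                                    # B x T x C
    assert out.shape == (2, seq_len, embed_dim)
    # at position 0 the table should read sin(0)=0 on even channels and cos(0)=1 on odd channels
    assert torch.allclose(emb.pos_embed[0, 0::2], torch.zeros(embed_dim // 2))
    assert torch.allclose(emb.pos_embed[0, 1::2], torch.ones(embed_dim // 2))
    return out.shape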

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, causal_mask=False, bias=True):
        super().__init__()
        self.dk = embed_dim // num_heads
        self.causal_mask = causal_mask
        self.combined_projection_q = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.combined_projection_k = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.combined_projection_v = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.num_heads = num_heads
        self.multi_linear = nn.Linear(embed_dim, embed_dim, bias=bias)

    def attention(self, q, k, v, padding_mask=None):
        # input shape is B x h x T x dk
        output = (q @ k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.dk, dtype=torch.float32))  # Q @ K^T / sqrt(dk)

        # apply the causal mask in decoder self-attention
        if self.causal_mask:
            seq_len = q.shape[-2]
            mask = torch.triu(torch.full((seq_len, seq_len), fill_value=-torch.inf, device=q.device), diagonal=1)
            output = output + mask

        # apply the padding mask in encoder self-attention and decoder cross-attention
        if padding_mask is not None:
            # HF-style mask: 1 = attend, 0 = padding; broadcast to B x 1 x 1 x T
            padding_mask = torch.as_tensor(padding_mask, device=output.device).unsqueeze(1).unsqueeze(1)
            additive_mask = torch.zeros_like(padding_mask, dtype=output.dtype)
            additive_mask = additive_mask.masked_fill(padding_mask == 0, -torch.inf)  # -inf becomes 0 after softmax
            output = output + additive_mask

        output = torch.softmax(output, -1)
        output = output @ v
        return output

    def forward(self, x_q, x_k, x_v, padding_mask=None):
        # combined projection: (T x C) @ (C x C),
        # equivalent to projecting each head separately with (T x C) @ (C x dk) and concatenating
        p_q = self.combined_projection_q(x_q)
        p_k = self.combined_projection_k(x_k)
        p_v = self.combined_projection_v(x_v)

        # For each of Q, K, V.  [B=Batch, T=Time, C=Channels, h=Heads, dk=head dim]
        # ========================|======================
        #          Split          |       Combine
        # ========================|======================
        #   |      B T C          |          /\
        #   |        |            |          |
        #   |      B T h dk       |          |
        #   |        |            |          |
        #  \/      B h T dk       |          |
        # ===============================================
        B = p_q.shape[0]

        def split_heads(p):
            return p.view(B, -1, self.num_heads, self.dk).transpose(1, 2)

        p_q = split_heads(p_q)
        p_k = split_heads(p_k)
        p_v = split_heads(p_v)

        output = self.attention(p_q, p_k, p_v, padding_mask=padding_mask)

        def combine_heads(p):
            return p.transpose(1, 2).contiguous().view(B, -1, self.dk * self.num_heads)

        output = combine_heads(output)
        output = self.multi_linear(output)
        return output


# Applies the same two-layer MLP (Linear -> ReLU -> Linear) to every position independently.
class PointwiseFeedForward(nn.Module):
    def __init__(self, embed_dim, d_ff):
        super().__init__()
        self.linear1 = nn.Linear(embed_dim, d_ff, bias=True)
        self.linear2 = nn.Linear(d_ff, embed_dim, bias=True)

    def forward(self, x):
        return self.linear2(nn.functional.relu(self.linear1(x)))


class EncoderLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, d_ff, dropout=0):
        super().__init__()
        # self-attention module
        self.m_att = MultiHeadAttention(embed_dim, num_heads)
        self.att_norm = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        # pointwise feed-forward module
        self.pwlinear = PointwiseFeedForward(embed_dim, d_ff)
        self.lin_norm = nn.LayerNorm(embed_dim)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, padding_mask=None):
        output = self.att_norm(x + self.dropout1(self.m_att(x, x, x, padding_mask=padding_mask)))
        output = self.lin_norm(output + self.dropout2(self.pwlinear(output)))
        return output


class EncoderStack(nn.Module):
    def __init__(self, embed_dim, num_heads, num_layers, d_ff, dropout=0, bos_token_id=1, eos_token_id=2, pad_token_id=0):
        super().__init__()
        self.layers = nn.ModuleList([EncoderLayer(embed_dim, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id

    def add_bos_eos(self, input_ids):
        modified_input_ids = []
        for seq in input_ids:  # iterate through each batch element
            # prepend the BOS token if needed
            if seq[0] != self.bos_token_id:
                seq = torch.cat([torch.tensor([self.bos_token_id], device=seq.device), seq])
            # append the EOS token if needed
            if seq[-1] != self.eos_token_id:
                seq = torch.cat([seq, torch.tensor([self.eos_token_id], device=seq.device)])
            modified_input_ids.append(seq)
        # pad the sequences back to a common length
        padded_input_ids = pad_sequence(modified_input_ids, batch_first=True, padding_value=self.pad_token_id)
        return padded_input_ids

    # For HuggingFace compatibility, input embeddings are computed inside the encoder,
    # so the encoder must handle both input_ids and input_embeds.
    # It reuses the parent's embedding layer (self.emb is assigned by OrginalTransformer);
    # moving the embedding layer into the encoder would break saved checkpoints.
    def forward(self, input_embeds=None, input_ids=None, padding_mask=None, **kwargs):
        if input_embeds is None:
            input_ids = self.add_bos_eos(input_ids)  # add BOS and EOS tokens if absent
            input_embeds = self.emb(input_ids)
        for layer in self.layers:
            input_embeds = layer(input_embeds, padding_mask=padding_mask)
        return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=input_embeds, hidden_states=None, attentions=None)
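
# ----------------------------------------------------------
# Hedged illustration (not part of the original model code): a minimal
# shape check for the MultiHeadAttention module above, assuming an
# HF-style padding mask (1 = attend, 0 = pad). The helper name and toy
# sizes are illustrative only; call it manually, e.g. from a REPL.
# ----------------------------------------------------------
def _demo_multihead_attention(embed_dim=16, num_heads=4, seq_len=6):
    x = torch.randn(2, seq_len, embed_dim)               # B x T x C
    self_att = MultiHeadAttention(embed_dim, num_heads)  # encoder-style, no causal mask
    causal_att = MultiHeadAttention(embed_dim, num_heads, causal_mask=True)
    padding_mask = torch.ones(2, seq_len, dtype=torch.long)
    padding_mask[:, -2:] = 0                             # treat the last two positions as padding
    out_self = self_att(x, x, x, padding_mask=padding_mask)
    out_causal = causal_att(x, x, x)
    # both variants preserve the B x T x C shape
    assert out_self.shape == out_causal.shape == (2, seq_len, embed_dim)
    return out_self.shape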

class DecoderLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, d_ff, dropout=0):
        super().__init__()
        # causally masked self-attention module
        self.m_att = MultiHeadAttention(embed_dim, num_heads, causal_mask=True)
        self.att_norm = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        # additional cross-attention module
        self.cross_att = MultiHeadAttention(embed_dim, num_heads, causal_mask=False)
        self.cross_att_norm = nn.LayerNorm(embed_dim)
        self.dropout2 = nn.Dropout(dropout)
        # pointwise feed-forward module with its layer norm
        self.pwlinear = PointwiseFeedForward(embed_dim, d_ff)
        self.lin_norm = nn.LayerNorm(embed_dim)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, x, enc_out, enc_padding_mask=None):
        # self-attention
        output = self.att_norm(x + self.dropout1(self.m_att(x, x, x)))
        # cross-attention over the encoder output
        output = self.cross_att_norm(output + self.dropout2(self.cross_att(output, enc_out, enc_out, padding_mask=enc_padding_mask)))
        # pointwise feed-forward
        output = self.lin_norm(output + self.dropout3(self.pwlinear(output)))
        return output


class DecoderStack(nn.Module):
    def __init__(self, embed_dim, num_heads, num_layers, d_ff, dropout=0):
        super().__init__()
        self.layers = nn.ModuleList([DecoderLayer(embed_dim, num_heads, d_ff, dropout) for _ in range(num_layers)])

    def forward(self, x, enc_out, enc_padding_mask=None):
        for layer in self.layers:
            x = layer(x, enc_out, enc_padding_mask)
        return x
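
# ----------------------------------------------------------
# Hedged illustration (not part of the original model code): a toy forward
# pass through the DecoderStack above, using random tensors in place of real
# encoder states and already-embedded decoder inputs. The helper name and
# sizes are illustrative only; call it manually, e.g. from a REPL.
# ----------------------------------------------------------
def _demo_decoder_stack(embed_dim=16, num_heads=4, num_layers=2, d_ff=32):
    dec = DecoderStack(embed_dim, num_heads, num_layers, d_ff)
    tgt = torch.randn(2, 3, embed_dim)      # B x T_dec x C (already embedded)
    enc_out = torch.randn(2, 5, embed_dim)  # B x T_enc x C, stands in for encoder output
    out = dec(tgt, enc_out)
    # cross-attention lets T_dec and T_enc differ; the output keeps the decoder's shape
    assert out.shape == (2, 3, embed_dim)
    return out.shape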

class OrginalTransformer(PreTrainedModel, GenerationMixin):
    config_class = OriginalTransformerConfig

    def __init__(self, config):
        super().__init__(config)
        # one embedding table shared by the encoder and the decoder
        self.emb = Embed(config.enc_vocab_size, config.embed_dim)
        self.enc = EncoderStack(config.embed_dim, config.num_heads, config.num_enc_layers, config.d_ff,
                                config.dropout, config.bos_token_id, config.eos_token_id, config.pad_token_id)
        self.dec = DecoderStack(config.embed_dim, config.num_heads, config.num_dec_layers, config.d_ff, config.dropout)
        # bias=False because its weights are tied with the embedding layer
        self.last_lin = nn.Linear(config.embed_dim, config.dec_vocab_size, bias=False)
        self.last_lin.weight = self.emb.embed.weight  # weight tying

        # expose the shared embedding inside the encoder and decoder (for HF)
        self.enc.emb = self.emb
        self.dec.emb = self.emb

    # HuggingFace-compatible forward
    def forward(self, input_ids=None, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None,
                head_mask=None, decoder_head_mask=None, encoder_outputs=None, past_key_values=None,
                use_cache=None, output_hidden_states=None, token_type_ids=None, inputs_embeds=None,
                labels=None, **kwargs):
        # Encoder
        # Usually not needed here: .generate() calls the encoder itself and passes encoder_outputs in.
        if encoder_outputs is None:
            encoder_outputs = self.enc(input_ids=input_ids)

        # Decoder
        # generate() calls the model with decoder_input_ids
        dec_out = self.dec(self.emb(decoder_input_ids), encoder_outputs.last_hidden_state, None)
        logits = self.last_lin(dec_out)

        output = Seq2SeqLMOutput(logits=logits, encoder_last_hidden_state=encoder_outputs.last_hidden_state)
        return output

    def get_encoder(self):
        return self.enc

    def get_decoder(self):
        return self.dec
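
# ==========================================================
# Hedged usage sketch (not part of the original model code): a minimal
# smoke test, assuming a recent transformers release (roughly 4.50+) where
# GenerationMixin supplies a generic prepare_inputs_for_generation and
# .generate() works with use_cache=False. The Auto* registration mirrors
# the auto_map entries declared in the config; the tiny hyperparameters
# below are illustrative only.
# ==========================================================
if __name__ == "__main__":
    # make the custom classes discoverable through the Auto* API
    AutoConfig.register("original_transformer", OriginalTransformerConfig)
    AutoModel.register(OriginalTransformerConfig, OrginalTransformer)
    AutoModelForSeq2SeqLM.register(OriginalTransformerConfig, OrginalTransformer)

    config = OriginalTransformerConfig(
        num_enc_layers=2, num_dec_layers=2, embed_dim=32, num_heads=4,
        enc_vocab_size=100, dec_vocab_size=100, d_ff=64,
    )
    model = OrginalTransformer(config).eval()

    input_ids = torch.randint(3, config.enc_vocab_size, (2, 7))  # B x T source tokens (no special tokens)
    decoder_input_ids = torch.full((2, 1), config.bos_token_id)  # start decoding from BOS

    # teacher-forcing style forward pass
    out = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
    print("logits:", out.logits.shape)  # B x T_dec x dec_vocab_size

    # greedy decoding through the GenerationMixin API (this model does not implement KV caching)
    with torch.no_grad():
        generated = model.generate(
            input_ids,
            max_new_tokens=5,
            decoder_start_token_id=config.bos_token_id,
            do_sample=False,
            use_cache=False,
        )
    print("generated:", generated.shape)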