# A BERT model that
# - has an embedding projector when embedding_size != hidden_size, like ELECTRA
# - uses a single linear projection in attention to generate query, key, and value at once, for speed
# - can optionally use rotary position embeddings

from copy import deepcopy
import math

import torch
from torch import nn
import torch.nn.functional as F
from transformers import PreTrainedModel

from configuration_tsp import TSPConfig


class TSPPreTrainedModel(PreTrainedModel):
    config_class = TSPConfig
    base_model_prefix = "backbone"

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


# ====================================
# Pretraining Model
# ====================================
class TSPModelForPreTraining(TSPPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.backbone = TSPModel(config)
        if config.use_electra:
            # ELECTRA-style setup: a smaller MLM generator and a full-size RTD
            # discriminator that share token embeddings.
            mlm_config = deepcopy(config)
            mlm_config.hidden_size //= config.electra_generator_size_divisor
            mlm_config.intermediate_size //= config.electra_generator_size_divisor
            mlm_config.num_attention_heads //= config.electra_generator_size_divisor
            self.mlm_backbone = TSPModel(mlm_config)
            self.mlm_head = MaskedLMHead(
                mlm_config, word_embeddings=self.mlm_backbone.embeddings.word_embeddings
            )
            self.rtd_backbone = self.backbone
            self.rtd_backbone.embeddings = self.mlm_backbone.embeddings
            self.rtd_head = ReplacedTokenDiscriminationHead(config)
        else:
            self.mlm_backbone = self.backbone
            self.mlm_head = MaskedLMHead(
                config, word_embeddings=self.mlm_backbone.embeddings.word_embeddings
            )
        self.tsp_head = TextStructurePredictionHead(config)
        self.apply(self._init_weights)

    def forward(self, *args, **kwargs):
        raise NotImplementedError(
            "Refer to the implementation of the text structure prediction task for how to use the model."
        )
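
# Usage sketch (not from the original training code; `config`, `input_ids`,
# `attention_mask`, `token_type_ids`, `corrupted_ids`, and `mlm_selection_mask`
# are assumed to be prepared by the caller):
#
#   model = TSPModelForPreTraining(config)
#   hidden = model.mlm_backbone(input_ids, attention_mask, token_type_ids)  # (B,L,D)
#   mlm_logits = model.mlm_head(hidden, is_selected=mlm_selection_mask)     # (#selected,V)
#   if config.use_electra:
#       rtd_hidden = model.rtd_backbone(corrupted_ids, attention_mask, token_type_ids)
#       rtd_logits = model.rtd_head(rtd_hidden)                             # (B,L)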
class MaskedLMHead(nn.Module):
    def __init__(self, config, word_embeddings=None):
        super().__init__()
        self.linear = nn.Linear(config.hidden_size, config.embedding_size)
        self.norm = nn.LayerNorm(config.embedding_size)
        self.predictor = nn.Linear(config.embedding_size, config.vocab_size)
        if word_embeddings is not None:
            # Tie the output projection to the input word embeddings
            self.predictor.weight = word_embeddings.weight

    def forward(
        self,
        x,  # (B,L,D)
        is_selected=None,  # (B,L), True at positions chosen by the MLM probability
    ):
        if is_selected is not None:
            # Only MLM positions contribute to the loss, so we can apply the output
            # layers to those positions only and significantly reduce computational cost
            x = x[is_selected]  # (#selected, D)
        x = self.linear(x)  # (B,L,E)/(#selected,E)
        x = F.gelu(x)  # (B,L,E)/(#selected,E)
        x = self.norm(x)  # (B,L,E)/(#selected,E)
        return self.predictor(x)  # (B,L,V)/(#selected,V)


class ReplacedTokenDiscriminationHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.linear = nn.Linear(config.hidden_size, config.hidden_size)
        self.predictor = nn.Linear(config.hidden_size, 1)

    def forward(self, x):  # (B,L,D)
        x = self.linear(x)  # (B,L,D)
        x = F.gelu(x)  # (B,L,D)
        x = self.predictor(x)  # (B,L,1)
        return x.squeeze(-1)  # (B,L)


class TextStructurePredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.linear1 = nn.Linear(config.hidden_size * 2, config.hidden_size * 2)
        self.norm = nn.LayerNorm(config.hidden_size * 2)
        self.linear2 = nn.Linear(config.hidden_size * 2, 6)

    def forward(
        self,
        x,  # (...,2D)
    ):
        x = self.linear1(x)  # (...,2D)
        x = F.gelu(x)  # (...,2D)
        x = self.norm(x)  # (...,2D)
        return self.linear2(x)  # (...,C)


# ====================================
# Finetuning Model
# ====================================
class TSPModelForTokenClassification(TSPPreTrainedModel):
    def __init__(self, config, num_classes):
        super().__init__(config)
        self.backbone = TSPModel(config)
        self.head = TokenClassificationHead(config, num_classes)
        self.apply(self._init_weights)

    def forward(
        self,
        input_ids,  # (B,L)
        attention_mask,  # (B,L), 1 / 0 for tokens that are attended / not attended
        token_type_ids,  # (B,L), 0 / 1 corresponds to a segment A / B token
    ):
        hidden_states = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )  # (B,L,D)
        return self.head(hidden_states)  # (B,L,C)


class TokenClassificationHead(nn.Module):
    def __init__(self, config, num_classes):
        super().__init__()
        self.dropout = nn.Dropout(config.dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_classes)

    def forward(self, x):  # (B,L,D)
        x = self.dropout(x)  # (B,L,D)
        x = self.classifier(x)  # (B,L,C)
        return x  # (B,L,C)


class TSPModelForSequenceClassification(TSPPreTrainedModel):
    def __init__(self, config, num_classes):
        super().__init__(config)
        self.backbone = TSPModel(config)
        self.head = SequenceClassificationHead(config, num_classes)
        self.apply(self._init_weights)

    def forward(
        self,
        input_ids,  # (B,L)
        attention_mask,  # (B,L), 1 / 0 for tokens that are attended / not attended
        token_type_ids,  # (B,L), 0 / 1 corresponds to a segment A / B token
    ):
        hidden_states = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )  # (B,L,D)
        return self.head(hidden_states)  # (B,C)


class SequenceClassificationHead(nn.Module):
    def __init__(self, config, num_classes):
        super().__init__()
        self.dropout = nn.Dropout(config.dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_classes)

    def forward(
        self,
        x,  # (B,L,D)
    ):
        x = x[:, 0, :]  # (B,D), take the CLS token
        x = self.dropout(x)  # (B,D)
        return self.classifier(x)  # (B,C)
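
# Fine-tuning sketch (illustrative only; the batch tensors and `labels` are
# assumptions, not part of this file):
#
#   model = TSPModelForSequenceClassification(config, num_classes=2)
#   logits = model(input_ids, attention_mask, token_type_ids)  # (B,2)
#   loss = F.cross_entropy(logits, labels)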
class TSPModelForQuestionAnswering(TSPPreTrainedModel):
    def __init__(self, config, beam_size, predict_answerability):
        super().__init__(config)
        self.backbone = TSPModel(config)
        self.head = SquadHead(
            config, beam_size=beam_size, predict_answerability=predict_answerability
        )
        self.apply(self._init_weights)

    def forward(
        self,
        input_ids,  # (B,L)
        attention_mask,  # (B,L), 1 / 0 for tokens that are attended / not attended
        token_type_ids,  # (B,L), 0 / 1 corresponds to a segment A / B token
        answer_start_position=None,  # train: (B,) / eval: None
    ):
        hidden_states = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )  # (B,L,D)
        return self.head(
            hidden_states, token_type_ids, answer_start_position
        )  # (B,L), (B,L)/(B,k,L), (B,)/None


class SquadHead(nn.Module):
    def __init__(
        self,
        config,
        beam_size,
        predict_answerability,
    ):
        super().__init__()
        self.beam_size = beam_size
        self.predict_answerability = predict_answerability

        # answer start position predictor
        self.start_predictor = nn.Linear(config.hidden_size, 1)

        # answer end position predictor
        self.end_predictor = nn.Sequential(
            nn.Linear(config.hidden_size * 2, 512),
            nn.GELU(),
            nn.Linear(512, 1),
        )

        # answerability predictor
        if predict_answerability:
            self.answerability_predictor = nn.Sequential(
                nn.Linear(config.hidden_size * 2, 512),
                nn.GELU(),
                nn.Linear(512, 1),
            )
        else:
            self.answerability_predictor = None

    def forward(
        self,
        hidden_states,  # (B,L,D)
        token_type_ids,  # (B,L), 0 for the first sentence (question) and padding, 1 for the second sentence (context)
        answer_start_position=None,  # train: (B,) / eval: None
    ):
        # Possible range for the answer. Note the CLS token is also allowed, so the model
        # can point at it to say the question is unanswerable
        answer_mask = token_type_ids  # (B,L)
        last_sep = answer_mask.cumsum(dim=1) == answer_mask.sum(
            dim=1, keepdim=True
        )  # (B,L), True at the last SEP and every token after it
        answer_mask = answer_mask * ~last_sep
        answer_mask[:, 0] = 1
        answer_mask = answer_mask.bool()

        # predict start positions
        start_logits, start_top_hidden_states = self._calculate_start(
            hidden_states, answer_mask, answer_start_position
        )  # (B,L), train: (B,1,D) / eval: (B,k,D)

        # predict end positions
        end_logits = self._calculate_end_logits(
            hidden_states,
            start_top_hidden_states,
            answer_mask,
        )  # train: (B,L) / eval: (B,k,L)

        # (optional) predict answerability
        answerability_logits = None
        if self.answerability_predictor is not None:
            answerability_logits = self._calculate_answerability_logits(
                hidden_states, start_logits
            )  # (B,)

        return start_logits, end_logits, answerability_logits

    def _calculate_start(self, hidden_states, answer_mask, start_positions):
        start_logits = self.start_predictor(hidden_states).squeeze(-1)  # (B,L)
        start_logits = start_logits.masked_fill(~answer_mask, -float("inf"))  # (B,L)
        start_top_indices, start_top_hidden_states = None, None
        if self.training:
            start_top_indices = start_positions  # (B,)
        else:
            k = self.beam_size
            _, start_top_indices = start_logits.topk(k=k, dim=-1)  # (B,k)
        start_top_hidden_states = torch.stack(
            [
                hiddens.index_select(dim=0, index=index)
                for hiddens, index in zip(hidden_states, start_top_indices)
            ]
        )  # train: (B,1,D) / eval: (B,k,D)
        return start_logits, start_top_hidden_states

    def _calculate_end_logits(
        self, hidden_states, start_top_hidden_states, answer_mask
    ):
        B, L, D = hidden_states.shape
        start_tophiddens = start_top_hidden_states.view(B, -1, 1, D).expand(
            -1, -1, L, -1
        )  # train: (B,1,L,D) / eval: (B,k,L,D)
        end_hidden_states = torch.cat(
            [
                start_tophiddens,
                hidden_states.view(B, 1, L, D).expand_as(start_tophiddens),
            ],
            dim=-1,
        )  # train: (B,1,L,2D) / eval: (B,k,L,2D)
        end_logits = self.end_predictor(end_hidden_states).squeeze(-1)  # (B,1/k,L)
        end_logits = end_logits.masked_fill(
            ~answer_mask.view(B, 1, L), -float("inf")
        )  # train: (B,1,L) / eval: (B,k,L)
        end_logits = end_logits.squeeze(1)  # train: (B,L) / eval: (B,k,L)
        return end_logits

    def _calculate_answerability_logits(self, hidden_states, start_logits):
        answerability_hidden_states = hidden_states[:, 0, :]  # (B,D)
        start_probs = start_logits.softmax(dim=-1).unsqueeze(-1)  # (B,L,1)
        start_features = (start_probs * hidden_states).sum(dim=1)  # (B,D)
        answerability_hidden_states = torch.cat(
            [answerability_hidden_states, start_features], dim=-1
        )  # (B,2D)
        answerability_logits = self.answerability_predictor(
            answerability_hidden_states
        )  # (B,1)
        return answerability_logits.squeeze(-1)  # (B,)
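
# Span decoding sketch for evaluation (one possible way to combine the scores;
# not from the original training code). With beam_size k, end_logits has shape
# (B,k,L), where row i scores end positions given the i-th best start; the k
# start candidates can be re-derived from start_logits:
#
#   start_logits, end_logits, ans_logits = head(hidden_states, token_type_ids)
#   start_scores, start_idx = start_logits.topk(k=head.beam_size, dim=-1)  # (B,k)
#   span_scores = start_scores.unsqueeze(-1) + end_logits                  # (B,k,L)
#   # best span per example: argmax over (start candidate, end position),
#   # subject to end >= start and a maximum answer length.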
# ====================================
# Backbone (Transformer Encoder)
# ====================================
class TSPModel(TSPPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.embeddings = Embeddings(config)
        if config.embedding_size != config.hidden_size:
            self.embeddings_project = nn.Linear(
                config.embedding_size, config.hidden_size
            )
        self.layers = nn.ModuleList(
            EncoderLayer(config) for _ in range(config.num_hidden_layers)
        )
        self.apply(self._init_weights)

    def forward(
        self,
        input_ids,  # (B,L)
        attention_mask,  # (B,L), 1 / 0 for tokens that are attended / not attended
        token_type_ids,  # (B,L), 0 / 1 corresponds to a segment A / B token
    ):
        x = self.embeddings(
            input_ids=input_ids, token_type_ids=token_type_ids
        )  # (B,L,E)
        if hasattr(self, "embeddings_project"):
            x = self.embeddings_project(x)  # (B,L,D)
        extended_attention_mask = self.get_extended_attention_mask(
            attention_mask=attention_mask,
            input_shape=input_ids.shape,
            device=input_ids.device,
        )  # (B,1,1,L)
        for layer in self.layers:
            x = layer(x, attention_mask=extended_attention_mask)  # (B,L,D)
        return x  # (B,L,D)


class Embeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id
        )
        if config.position_embedding_type == "absolute":
            self.position_embeddings = nn.Embedding(
                config.max_sequence_length, config.embedding_size
            )
        self.token_type_embeddings = nn.Embedding(2, config.embedding_size)
        self.norm = nn.LayerNorm(config.embedding_size)
        self.dropout = nn.Dropout(config.dropout_prob)

    def forward(
        self,
        input_ids,  # (B,L)
        token_type_ids,  # (B,L), 0 / 1 corresponds to a segment A / B token
    ):
        B, L = input_ids.shape
        embeddings = self.word_embeddings(input_ids)  # (B,L,E)
        embeddings += self.token_type_embeddings(token_type_ids)  # (B,L,E)
        if hasattr(self, "position_embeddings"):
            embeddings += self.position_embeddings.weight[None, :L, :]  # (B,L,E)
        embeddings = self.norm(embeddings)  # (B,L,E)
        embeddings = self.dropout(embeddings)  # (B,L,E)
        return embeddings  # (B,L,E)


class EncoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self_attn_block = BlockWrapper(config, MultiHeadSelfAttention)
        self.transition_block = BlockWrapper(config, FeedForwardNetwork)

    def forward(
        self,
        x,  # (B,L,D)
        attention_mask,  # (B,1,1,L) additive mask, 0 / -1e4 for tokens that are attended / not attended
    ):
        x = self.self_attn_block(x, attention_mask=attention_mask)
        x = self.transition_block(x)
        return x  # (B,L,D)


class BlockWrapper(nn.Module):
    """Wraps a sublayer with dropout, a residual connection, and post-LayerNorm."""

    def __init__(self, config, sublayer_cls):
        super().__init__()
        self.sublayer = sublayer_cls(config)
        self.dropout = nn.Dropout(config.dropout_prob)
        self.norm = nn.LayerNorm(config.hidden_size)

    def forward(self, x, **kwargs):
        original_x = x
        x = self.sublayer(x, **kwargs)
        x = self.dropout(x)
        x = original_x + x
        x = self.norm(x)
        return x
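
# Note: the additive mask built in TSPModel.forward above (and consumed by
# Attention below) comes from `PreTrainedModel.get_extended_attention_mask`,
# inherited from transformers. It is roughly equivalent to the following sketch
# (the exact fill value depends on the transformers version):
#
#   extended = attention_mask[:, None, None, :].to(x.dtype)  # (B,1,1,L)
#   extended = (1.0 - extended) * -1e4  # 0 where attended, large negative where padded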
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        # A single linear projection produces query, key, and value at once
        self.mix_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size)
        self.attention = Attention(config)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.H = config.num_attention_heads
        self.d = config.hidden_size // self.H
        if config.position_embedding_type == "rotary":
            self.rotary_position_embeds = RotaryEmbedding(self.d)

    def forward(
        self,
        x,  # (B,L,D)
        attention_mask,  # (B,1,1,L) additive mask, 0 / -1e4 for tokens that are attended / not attended
    ):
        B, L, D, H, d = *x.shape, self.H, self.d
        query, key, value = (
            self.mix_proj(x).view(B, L, H, 3 * d).transpose(1, 2).split(d, dim=-1)
        )  # (B,H,L,d), (B,H,L,d), (B,H,L,d)
        if hasattr(self, "rotary_position_embeds"):
            query, key = self.rotary_position_embeds(query, key)
        output = self.attention(query, key, value, attention_mask)  # (B,H,L,d)
        output = self.o_proj(output.transpose(1, 2).reshape(B, L, D))  # (B,L,D)
        return output  # (B,L,D)


class Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dropout = nn.Dropout(config.dropout_prob)

    def forward(
        self,
        query,  # (B,H,L,d)
        key,  # (B,H,L,d)
        value,  # (B,H,L,d)
        attention_mask,  # (B,1,1,L) additive mask, 0 / -1e4 for tokens that are attended / not attended
    ):
        B, H, L, d = key.shape
        attention_score = query.matmul(key.transpose(-2, -1))  # (B,H,L,L)
        attention_score = attention_score / math.sqrt(d)  # (B,H,L,L)
        attention_score += attention_mask  # (B,H,L,L)
        attention_probs = attention_score.softmax(dim=-1)  # (B,H,L,L)
        attention_probs = self.dropout(attention_probs)  # (B,H,L,L)
        output = attention_probs.matmul(value)  # (B,H,L,d)
        return output  # (B,H,L,d)


class FeedForwardNetwork(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.linear1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, x):  # (B,L,D)
        x = self.linear1(x)  # (B,L,intermediate_size)
        x = F.gelu(x)  # (B,L,intermediate_size)
        x = self.linear2(x)  # (B,L,D)
        return x  # (B,L,D)


class RotaryEmbedding(nn.Module):
    # The cos/sin cache is shared across all RotaryEmbedding instances
    seq_len_cached = 0
    cos_cached = None
    sin_cached = None

    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def _forward(self, x):  # (B,H,L,d)
        # Get rotary embeddings on the fly
        ## create
        seq_len = x.shape[2]
        if seq_len > RotaryEmbedding.seq_len_cached:
            RotaryEmbedding.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
            freqs = t.view(-1, 1) @ self.inv_freq.view(1, -1)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)  # (L,d)
            RotaryEmbedding.cos_cached = emb.cos()[None, None, :, :]
            RotaryEmbedding.sin_cached = emb.sin()[None, None, :, :]
        ## take
        if seq_len == RotaryEmbedding.seq_len_cached:
            cos, sin = RotaryEmbedding.cos_cached, RotaryEmbedding.sin_cached
        else:
            cos, sin = (
                RotaryEmbedding.cos_cached[:, :, :seq_len, :],  # (1,1,L,d)
                RotaryEmbedding.sin_cached[:, :, :seq_len, :],  # (1,1,L,d)
            )
        # Apply rotary embeddings
        sections = [x.shape[-1] // 2, x.shape[-1] - x.shape[-1] // 2]
        x1, x2 = x.split(sections, dim=-1)
        half_rotated_x = torch.cat((-x2, x1), dim=-1)
        return (x * cos) + (half_rotated_x * sin)

    def forward(
        self,
        query,  # (B,H,L,d)
        key,  # (B,H,L,d)
    ):
        return self._forward(query), self._forward(key)
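
# Sanity-check sketch (illustrative, not part of the model): `_forward` rotates
# each (x[i], x[i + d/2]) pair by a position-dependent angle, so the
# transformation is norm-preserving and makes the q.k score depend on the
# relative offset between positions:
#
#   rope = RotaryEmbedding(dim=64)
#   q = torch.randn(1, 2, 8, 64)
#   k = torch.randn(1, 2, 8, 64)
#   q_rot, k_rot = rope(q, k)
#   assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-4)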