# A BERT-like model that
# - has an embedding projector when embedding_size != hidden_size, like ELECTRA
# - uses a single linear projection in the attention to generate query, key, and value at once, for speed
# - can optionally use rotary position embeddings
from copy import deepcopy
import math
import torch
from torch import nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from configuration_tsp import TSPConfig
class TSPPreTrainedModel(PreTrainedModel):
config_class = TSPConfig
base_model_prefix = "backbone"
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=0.02)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=0.02)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
# ====================================
# Pretraining Model
# ====================================
class TSPModelForPreTraining(TSPPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.backbone = TSPModel(config)
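        # With use_electra, pretraining is ELECTRA-style: a smaller generator (hidden /
        # intermediate sizes and head count divided by electra_generator_size_divisor)
        # is trained with masked language modeling, the full-size backbone acts as the
        # replaced-token discriminator, and the two share input embeddings (tied below).
        # Otherwise the backbone itself is trained with masked language modeling.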
if config.use_electra:
mlm_config = deepcopy(config)
mlm_config.hidden_size //= config.electra_generator_size_divisor
mlm_config.intermediate_size //= config.electra_generator_size_divisor
mlm_config.num_attention_heads //= config.electra_generator_size_divisor
self.mlm_backbone = TSPModel(mlm_config)
self.mlm_head = MaskedLMHead(
mlm_config, word_embeddings=self.mlm_backbone.embeddings.word_embeddings
)
self.rtd_backbone = self.backbone
self.rtd_backbone.embeddings = self.mlm_backbone.embeddings
self.rtd_head = ReplacedTokenDiscriminationHead(config)
else:
self.mlm_backbone = self.backbone
self.mlm_head = MaskedLMHead(
config, word_embeddings=self.mlm_backbone.embeddings.word_embeddings
)
self.tsp_head = TextStructurePredictionHead(config)
self.apply(self._init_weights)
def forward(self, *args, **kwargs):
        raise NotImplementedError(
            "Refer to the implementation of the text structure prediction task for how to use the model."
        )
class MaskedLMHead(nn.Module):
def __init__(self, config, word_embeddings=None):
super().__init__()
self.linear = nn.Linear(config.hidden_size, config.embedding_size)
self.norm = nn.LayerNorm(config.embedding_size)
self.predictor = nn.Linear(config.embedding_size, config.vocab_size)
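        # Tie the output projection to the input word embeddings when they are provided.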
if word_embeddings is not None:
self.predictor.weight = word_embeddings.weight
def forward(
self,
x, # (B,L,D)
        is_selected=None, # <bool>(B,L), True at positions chosen by MLM probability
):
if is_selected is not None:
            # Only MLM positions are counted in the loss, so we can apply the output layers
            # to those positions only, which significantly reduces the computational cost.
            x = x[is_selected] # (#selected, D)
x = self.linear(x) # (B,L,E)/(#selected,E)
x = F.gelu(x) # (B,L,E)/(#selected,E)
x = self.norm(x) # (B,L,E)/(#selected,E)
return self.predictor(x) # (B,L,V)/(#selected,V)
class ReplacedTokenDiscriminationHead(nn.Module):
def __init__(self, config):
super().__init__()
self.linear = nn.Linear(config.hidden_size, config.hidden_size)
self.predictor = nn.Linear(config.hidden_size, 1)
def forward(self, x): # (B,L,D)
x = self.linear(x) # (B,L,D)
x = F.gelu(x)
x = self.predictor(x) # (B,L,1)
return x.squeeze(-1) # (B,L)
class TextStructurePredictionHead(nn.Module):
def __init__(self, config):
super().__init__()
self.linear1 = nn.Linear(config.hidden_size * 2, config.hidden_size * 2)
self.norm = nn.LayerNorm(config.hidden_size * 2)
self.linear2 = nn.Linear(config.hidden_size * 2, 6)
def forward(
self, x, # (...,2D)
):
x = self.linear1(x) # (...,2D)
x = F.gelu(x) # (...,2D)
x = self.norm(x) # (...,2D)
return self.linear2(x) # (...,C)
# ====================================
# Finetuning Model
# ====================================
class TSPModelForTokenClassification(TSPPreTrainedModel):
def __init__(self, config, num_classes):
super().__init__(config)
self.backbone = TSPModel(config)
self.head = TokenClassificationHead(config, num_classes)
self.apply(self._init_weights)
def forward(
self,
input_ids, # <int>(B,L)
        attention_mask, # <int>(B,L), 1 / 0 for tokens that are attended / not attended
token_type_ids, # <int>(B,L), 0 / 1 corresponds to a segment A / B token
):
hidden_states = self.backbone(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
) # (B,L,D)
return self.head(hidden_states) # (B,L,C)
class TokenClassificationHead(nn.Module):
def __init__(self, config, num_classes):
super().__init__()
self.dropout = nn.Dropout(config.dropout_prob)
self.classifier = nn.Linear(config.hidden_size, num_classes)
def forward(self, x): # (B,L,D)
x = self.dropout(x) # (B,L,D)
x = self.classifier(x) # (B,L,C)
return x # (B,L,C)
class TSPModelForSequenceClassification(TSPPreTrainedModel):
def __init__(self, config, num_classes):
super().__init__(config)
self.backbone = TSPModel(config)
        self.head = SequenceClassificationHead(config, num_classes)
self.apply(self._init_weights)
def forward(
self,
input_ids, # <int>(B,L)
        attention_mask, # <int>(B,L), 1 / 0 for tokens that are attended / not attended
token_type_ids, # <int>(B,L), 0 / 1 corresponds to a segment A / B token
):
hidden_states = self.backbone(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
) # (B,L,D)
return self.head(hidden_states) # (B,L,C)
class SequenceClassificationHead(nn.Module):
def __init__(self, config, num_classes):
super().__init__()
self.dropout = nn.Dropout(config.dropout_prob)
self.classifier = nn.Linear(config.hidden_size, num_classes)
def forward(
self, x, # (B,L,D)
):
        x = x[:, 0, :] # (B,D), take the [CLS] token
x = self.dropout(x) # (B,D)
return self.classifier(x) # (B,C)
class TSPModelForQuestionAnswering(TSPPreTrainedModel):
    def __init__(self, config, beam_size=1, predict_answerability=False):
        # Note: the defaults for beam_size / predict_answerability are placeholders;
        # set them according to the task (e.g. SQuAD 2.0 needs answerability prediction).
        super().__init__(config)
        self.backbone = TSPModel(config)
        self.head = SquadHead(
            config, beam_size=beam_size, predict_answerability=predict_answerability
        )
        self.apply(self._init_weights)
    def forward(
        self,
        input_ids, # <int>(B,L)
        attention_mask, # <int>(B,L), 1 / 0 for tokens that are attended / not attended
        token_type_ids, # <int>(B,L), 0 / 1 corresponds to a segment A / B token
        answer_start_position=None, # <int>(B), gold start positions (training only)
    ):
        hidden_states = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        ) # (B,L,D)
        return self.head(
            hidden_states,
            token_type_ids=token_type_ids,
            answer_start_position=answer_start_position,
        ) # (B,L), (B,L)/(B,k,L), (B,)/None
class SquadHead(nn.Module):
def __init__(
self, config, beam_size, predict_answerability,
):
super().__init__()
self.beam_size = beam_size
self.predict_answerability = predict_answerability
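        # The end-position predictor is conditioned on a start candidate: the gold start
        # position during training, or each of the top-`beam_size` predicted starts at eval time.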
# answer start position predictor
self.start_predictor = nn.Linear(config.hidden_size, 1)
# answer end position predictor
self.end_predictor = nn.Sequential(
nn.Linear(config.hidden_size * 2, 512), nn.GELU(), nn.Linear(512, 1),
)
# answerability_predictor
if predict_answerability:
self.answerability_predictor = nn.Sequential(
nn.Linear(config.hidden_size * 2, 512), nn.GELU(), nn.Linear(512, 1),
)
else:
self.answerability_predictor = None
def forward(
self,
hidden_states, # (B,L,D)
        token_type_ids, # <int>(B,L), 0 for tokens of the first sentence (question) and padding, 1 for the second sentence (context)
        answer_start_position=None, # <int>(B) gold start positions at training time, None at eval time
):
        # Possible answer positions. The CLS position is also allowed so the model can point to it to mark the question as unanswerable.
answer_mask = token_type_ids # (B,L)
last_sep = answer_mask.cumsum(dim=1) == answer_mask.sum(
dim=1, keepdim=True
        ) # (B,L), True at the last [SEP] token and every token after it
answer_mask = answer_mask * ~last_sep
answer_mask[:, 0] = 1
answer_mask = answer_mask.bool()
        # predict start positions
start_logits, start_top_hidden_states = self._calculate_start(
hidden_states, answer_mask, answer_start_position
) # (B,L) , None/ (B,1,D)/ (B,k,D)
# predict end positions
end_logits = self._calculate_end_logits(
hidden_states, start_top_hidden_states, answer_mask,
) # (B,L) / (B,k,L)
        # (optional) predict answerability
answerability_logits = None
if self.answerability_predictor is not None:
answerability_logits = self._calculate_answerability_logits(
hidden_states, start_logits
) # (B)
return start_logits, end_logits, answerability_logits
def _calculate_start(self, hidden_states, answer_mask, start_positions):
start_logits = self.start_predictor(hidden_states).squeeze(-1) # (B, L)
start_logits = start_logits.masked_fill(~answer_mask, -float("inf")) # (B,L)
start_top_indices, start_top_hidden_states = None, None
if self.training:
start_top_indices = start_positions # (B,)
else:
k = self.beam_size
_, start_top_indices = start_logits.topk(k=k, dim=-1) # (B,k)
start_top_hidden_states = torch.stack(
[
hiddens.index_select(dim=0, index=index)
for hiddens, index in zip(hidden_states, start_top_indices)
]
) # train: (B,1,D)/ eval: (B,k,D)
return start_logits, start_top_hidden_states
def _calculate_end_logits(
self, hidden_states, start_top_hidden_states, answer_mask
):
B, L, D = hidden_states.shape
start_tophiddens = start_top_hidden_states.view(B, -1, 1, D).expand(
-1, -1, L, -1
) # train: (B,1,L,D) / eval: (B,k,L,D)
end_hidden_states = torch.cat(
[
start_tophiddens,
hidden_states.view(B, 1, L, D).expand_as(start_tophiddens),
],
dim=-1,
) # train: (B,1,L,2D) / eval: (B,k,L,2D)
end_logits = self.end_predictor(end_hidden_states).squeeze(-1) # (B,1/k,L)
end_logits = end_logits.masked_fill(
~answer_mask.view(B, 1, L), -float("inf")
) # train: (B,1,L) / eval: (B,k,L)
end_logits = end_logits.squeeze(1) # train: (B,L) / eval: (B,k,L)
return end_logits
def _calculate_answerability_logits(self, hidden_states, start_logits):
answerability_hidden_states = hidden_states[:, 0, :] # (B,D)
start_probs = start_logits.softmax(dim=-1).unsqueeze(-1) # (B,L,1)
        start_features = (start_probs * hidden_states).sum(dim=1) # (B,D)
        answerability_hidden_states = torch.cat(
            [answerability_hidden_states, start_features], dim=-1
) # (B,2D)
answerability_logits = self.answerability_predictor(
answerability_hidden_states
) # (B,1)
return answerability_logits.squeeze(-1) # (B,)
# ====================================
# Backbone (Transformer Encoder)
# ====================================
class TSPModel(TSPPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.embeddings = Embeddings(config)
if config.embedding_size != config.hidden_size:
self.embeddings_project = nn.Linear(
config.embedding_size, config.hidden_size
)
self.layers = nn.ModuleList(
EncoderLayer(config) for _ in range(config.num_hidden_layers)
)
self.apply(self._init_weights)
def forward(
self,
input_ids, # <int>(B,L)
        attention_mask, # <int>(B,L), 1 / 0 for tokens that are attended / not attended
token_type_ids, # <int>(B,L), 0 / 1 corresponds to a segment A / B token
):
x = self.embeddings(
input_ids=input_ids, token_type_ids=token_type_ids
) # (B,L,E)
if hasattr(self, "embeddings_project"):
x = self.embeddings_project(x) # (B,L,D)
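        # get_extended_attention_mask (inherited from PreTrainedModel) turns the (B,L) 1/0 mask
        # into an additive (B,1,1,L) mask: 0.0 at attended positions, a large negative value elsewhere.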
extended_attention_mask = self.get_extended_attention_mask(
attention_mask=attention_mask,
input_shape=input_ids.shape,
device=input_ids.device,
) # (B,1,1,L)
for layer_idx, layer in enumerate(self.layers):
x = layer(x, attention_mask=extended_attention_mask) # (B,L,D)
return x # (B,L,D)
class Embeddings(nn.Module):
def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(
config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id
)
if config.position_embedding_type == "absolute":
self.position_embeddings = nn.Embedding(
config.max_sequence_length, config.embedding_size
)
self.token_type_embeddings = nn.Embedding(2, config.embedding_size)
self.norm = nn.LayerNorm(config.embedding_size)
self.dropout = nn.Dropout(config.dropout_prob)
def forward(
self,
input_ids, # <int>(B,L)
token_type_ids, # <int>(B,L), 0 / 1 corresponds to a segment A / B token
):
B, L = input_ids.shape
embeddings = self.word_embeddings(input_ids) # (B,L,E)
embeddings += self.token_type_embeddings(token_type_ids)
if hasattr(self, "position_embeddings"):
embeddings += self.position_embeddings.weight[None, :L, :]
embeddings = self.norm(embeddings) # (B,L,E)
embeddings = self.dropout(embeddings) # (B,L,E)
return embeddings # (B,L,E)
class EncoderLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.self_attn_block = BlockWrapper(config, MultiHeadSelfAttention)
self.transition_block = BlockWrapper(config, FeedForwardNetwork)
def forward(
self,
x, # (B,L,D)
        attention_mask, # <float>(B,1,1,L), additive mask: 0 at attended positions, a large negative value (e.g. -1e4) at masked positions
):
x = self.self_attn_block(x, attention_mask=attention_mask)
x = self.transition_block(x)
return x # (B,L,D)
class BlockWrapper(nn.Module):
def __init__(self, config, sublayer_cls):
super().__init__()
self.sublayer = sublayer_cls(config)
self.dropout = nn.Dropout(config.dropout_prob)
self.norm = nn.LayerNorm(config.hidden_size)
def forward(self, x, **kwargs):
original_x = x
x = self.sublayer(x, **kwargs)
x = self.dropout(x)
x = original_x + x
x = self.norm(x)
return x
class MultiHeadSelfAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.mix_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size)
self.attention = Attention(config)
self.o_proj = nn.Linear(config.hidden_size, config.hidden_size)
self.H = config.num_attention_heads
self.d = config.hidden_size // self.H
if config.position_embedding_type == "rotary":
self.rotray_position_embeds = RotaryEmbedding(self.d)
def forward(
self,
x, # (B,L,D)
        attention_mask, # <float>(B,1,1,L), additive mask: 0 at attended positions, a large negative value (e.g. -1e4) at masked positions
):
B, L, D, H, d = *x.shape, self.H, self.d
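        # One matmul produces Q, K, and V together (3*D outputs); reshape to (B,H,L,3d) and
        # split the last dimension into three d-sized chunks for query / key / value.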
query, key, value = (
self.mix_proj(x).view(B, L, H, 3 * d).transpose(1, 2).split(d, dim=-1)
) # (B,H,L,d),(B,H,L,d),(B,H,L,d)
if hasattr(self, "rotray_position_embeds"):
query, key = self.rotray_position_embeds(query, key)
output = self.attention(query, key, value, attention_mask) # (B,H,L,d)
output = self.o_proj(output.transpose(1, 2).reshape(B, L, D)) # (B,L,D)
return output # (B,L,D)
class Attention(nn.Module):
def __init__(self, config):
super().__init__()
self.dropout = nn.Dropout(config.dropout_prob)
def forward(
self,
query, # (B,H,L,d)
key, # (B,H,L,d)
value, # (B,H,L,d)
        attention_mask, # <float>(B,1,1,L), additive mask: 0 at attended positions, a large negative value (e.g. -1e4) at masked positions
):
B, H, L, d = key.shape
attention_score = query.matmul(key.transpose(-2, -1)) # (B,H,L,L)
attention_score = attention_score / math.sqrt(d) # (B,H,L,L)
attention_score += attention_mask # (B,H,L,L)
attention_probs = attention_score.softmax(dim=-1) # (B,H,L,L)
attention_probs = self.dropout(attention_probs) # (B,H,L,L)
output = attention_probs.matmul(value) # (B,H,L,d)
return output # (B,H,L,d)
class FeedForwardNetwork(nn.Module):
def __init__(self, config):
super().__init__()
self.linear1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size)
def forward(self, x): # (B,L,D)
        x = self.linear1(x) # (B,L,intermediate_size)
x = F.gelu(x) # (B,L,intermediate_size)
x = self.linear2(x) # (B,L,D)
return x # (B,L,D)
class RotaryEmbedding(nn.Module):
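    # The cos/sin tables are cached at the class level, so every layer shares one cache;
    # the cache is rebuilt whenever a longer sequence length than seen before is requested.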
seq_len_cached = 0
cos_cached = None
sin_cached = None
def __init__(self, dim):
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
def _forward(self, x): # (B,H,L,d)
# Get rotary embeddings on the fly
## create
seq_len = x.shape[2]
if seq_len > RotaryEmbedding.seq_len_cached:
RotaryEmbedding.seq_len_cached = seq_len
t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
freqs = t.view(-1, 1) @ self.inv_freq.view(1, -1)
emb = torch.cat((freqs, freqs), dim=-1).to(x.device) # (L,d)
RotaryEmbedding.cos_cached = emb.cos()[None, None, :, :]
RotaryEmbedding.sin_cached = emb.sin()[None, None, :, :]
## take
if seq_len == RotaryEmbedding.seq_len_cached:
cos, sin = RotaryEmbedding.cos_cached, RotaryEmbedding.sin_cached
else:
cos, sin = (
RotaryEmbedding.cos_cached[:, :, :seq_len, :], # (1,1,L,d)
RotaryEmbedding.sin_cached[:, :, :seq_len, :], # (1,1,L,d)
)
# Apply rotary embeddings
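        # Each feature x[..., i] is paired with x[..., i + d/2] (GPT-NeoX style, non-interleaved)
        # and rotated by a position-dependent angle: (x1, x2) -> (x1*cos - x2*sin, x1*sin + x2*cos),
        # computed below as x * cos + rotate_half(x) * sin, where rotate_half(x) = cat((-x2, x1)).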
sections = [x.shape[-1] // 2, x.shape[-1] - x.shape[-1] // 2]
x1, x2 = x.split(sections, dim=-1)
half_rotated_x = torch.cat((-x2, x1), dim=-1)
return (x * cos) + (half_rotated_x * sin)
def forward(
self, query, key, # (B,H,L,d) # (B,H,L,d)
):
return self._forward(query), self._forward(key)
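# ====================================
# Usage sketch
# ====================================
# Minimal smoke-test sketch, not part of the model definition. It assumes TSPConfig accepts
# the field names used throughout this file (vocab_size, embedding_size, hidden_size,
# intermediate_size, num_attention_heads, num_hidden_layers, dropout_prob,
# max_sequence_length, pad_token_id, position_embedding_type, ...) as keyword arguments;
# adjust to the actual signature in configuration_tsp.py.
if __name__ == "__main__":
    config = TSPConfig(
        vocab_size=30522,
        embedding_size=128,
        hidden_size=256,
        intermediate_size=1024,
        num_attention_heads=4,
        num_hidden_layers=2,
        dropout_prob=0.1,
        max_sequence_length=128,
        pad_token_id=0,
        position_embedding_type="rotary",
    )
    model = TSPModelForSequenceClassification(config, num_classes=2)
    B, L = 2, 16
    input_ids = torch.randint(0, config.vocab_size, (B, L))
    attention_mask = torch.ones(B, L, dtype=torch.long)  # 1 = attended
    token_type_ids = torch.zeros(B, L, dtype=torch.long)  # all segment A
    with torch.no_grad():
        logits = model(input_ids, attention_mask, token_type_ids)
    print(logits.shape)  # expected: torch.Size([2, 2])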