from copy import deepcopy
import math

import torch
from torch import nn
import torch.nn.functional as F
from transformers import PreTrainedModel

from configuration_tsp import TSPConfig


class TSPPreTrainedModel(PreTrainedModel):
    config_class = TSPConfig
    base_model_prefix = "backbone"

    def _init_weights(self, module):
        """Initialize the weights (BERT-style: N(0, 0.02) for Linear/Embedding
        weights, zero biases, and identity LayerNorm parameters)."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


class TSPModelForPreTraining(TSPPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.backbone = TSPModel(config)
        if config.use_electra:
            # ELECTRA-style setup: a smaller generator is trained with masked
            # language modeling, and the full-size backbone acts as the
            # replaced-token-detection discriminator.
            mlm_config = deepcopy(config)
            mlm_config.hidden_size //= config.electra_generator_size_divisor
            mlm_config.intermediate_size //= config.electra_generator_size_divisor
            mlm_config.num_attention_heads //= config.electra_generator_size_divisor
            self.mlm_backbone = TSPModel(mlm_config)
            self.mlm_head = MaskedLMHead(
                mlm_config, word_embeddings=self.mlm_backbone.embeddings.word_embeddings
            )
            # The discriminator shares its input embeddings with the generator.
            self.rtd_backbone = self.backbone
            self.rtd_backbone.embeddings = self.mlm_backbone.embeddings
            self.rtd_head = ReplacedTokenDiscriminationHead(config)
        else:
            # Plain masked language modeling on the backbone itself.
            self.mlm_backbone = self.backbone
            self.mlm_head = MaskedLMHead(
                config, word_embeddings=self.mlm_backbone.embeddings.word_embeddings
            )
        self.tsp_head = TextStructurePredictionHead(config)
        self.apply(self._init_weights)

    def forward(self, *args, **kwargs):
        raise NotImplementedError(
            "Refer to the implementation of the text structure prediction task for how to use the model."
        )


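# Rough sketch (not part of the original code) of how the pre-training pieces
# above could be combined in an ELECTRA-style step; `masked_input_ids`,
# `mlm_positions` (boolean mask of masked-out tokens), `mlm_labels`, and
# `original_input_ids` are placeholder tensors:
#
#   gen_hidden = model.mlm_backbone(masked_input_ids, attention_mask, token_type_ids)
#   mlm_logits = model.mlm_head(gen_hidden, is_selected=mlm_positions)
#   mlm_loss = F.cross_entropy(mlm_logits, mlm_labels)
#
#   with torch.no_grad():
#       sampled = mlm_logits.argmax(dim=-1)  # or sample from the softmax
#   corrupted_ids = masked_input_ids.clone()
#   corrupted_ids[mlm_positions] = sampled
#   is_replaced = corrupted_ids != original_input_ids
#
#   disc_hidden = model.rtd_backbone(corrupted_ids, attention_mask, token_type_ids)
#   rtd_logits = model.rtd_head(disc_hidden)
#   rtd_loss = F.binary_cross_entropy_with_logits(rtd_logits, is_replaced.float())

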
class MaskedLMHead(nn.Module):
    def __init__(self, config, word_embeddings=None):
        super().__init__()
        self.linear = nn.Linear(config.hidden_size, config.embedding_size)
        self.norm = nn.LayerNorm(config.embedding_size)
        self.predictor = nn.Linear(config.embedding_size, config.vocab_size)
        if word_embeddings is not None:
            # Tie the output projection to the input word embeddings.
            self.predictor.weight = word_embeddings.weight

    def forward(
        self,
        x,
        is_selected=None,
    ):
        if is_selected is not None:
            # Only score the positions that were masked out.
            x = x[is_selected]
        x = self.linear(x)
        x = F.gelu(x)
        x = self.norm(x)
        return self.predictor(x)


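# Usage sketch (illustrative; `backbone`, `hidden`, and `mask_positions` are
# placeholders). With a boolean `is_selected` mask, the returned logits are
# flattened over the selected positions:
#
#   head = MaskedLMHead(config, word_embeddings=backbone.embeddings.word_embeddings)
#   logits = head(hidden, is_selected=mask_positions)  # (num_masked, vocab_size)

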
class ReplacedTokenDiscriminationHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.linear = nn.Linear(config.hidden_size, config.hidden_size)
        self.predictor = nn.Linear(config.hidden_size, 1)

    def forward(self, x):
        x = self.linear(x)
        x = F.gelu(x)
        x = self.predictor(x)
        return x.squeeze(-1)  # one binary logit per token


class TextStructurePredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Scores a pair of concatenated representations (2 * hidden_size)
        # against 6 text-structure classes.
        self.linear1 = nn.Linear(config.hidden_size * 2, config.hidden_size * 2)
        self.norm = nn.LayerNorm(config.hidden_size * 2)
        self.linear2 = nn.Linear(config.hidden_size * 2, 6)

    def forward(self, x):
        x = self.linear1(x)
        x = F.gelu(x)
        x = self.norm(x)
        return self.linear2(x)


class TSPModelForTokenClassification(TSPPreTrainedModel):
    def __init__(self, config, num_classes):
        super().__init__(config)
        self.backbone = TSPModel(config)
        self.head = TokenClassificationHead(config, num_classes)
        self.apply(self._init_weights)

    def forward(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
    ):
        hidden_states = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        return self.head(hidden_states)


class TokenClassificationHead(nn.Module):
    def __init__(self, config, num_classes):
        super().__init__()
        self.dropout = nn.Dropout(config.dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_classes)

    def forward(self, x):
        x = self.dropout(x)
        x = self.classifier(x)
        return x


class TSPModelForSequenceClassification(TSPPreTrainedModel):
    def __init__(self, config, num_classes):
        super().__init__(config)
        self.backbone = TSPModel(config)
        self.head = SequenceClassificationHead(config, num_classes)
        self.apply(self._init_weights)

    def forward(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
    ):
        hidden_states = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        return self.head(hidden_states)


class SequenceClassificationHead(nn.Module):
    def __init__(self, config, num_classes):
        super().__init__()
        self.dropout = nn.Dropout(config.dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_classes)

    def forward(self, x):
        # Classify from the representation of the first ([CLS]) token.
        x = x[:, 0, :]
        x = self.dropout(x)
        return self.classifier(x)


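# Fine-tuning sketch (illustrative only; the `batch` fields and `num_labels`
# are placeholders, not part of this file):
#
#   model = TSPModelForSequenceClassification(config, num_classes=num_labels)
#   logits = model(batch["input_ids"], batch["attention_mask"], batch["token_type_ids"])
#   loss = F.cross_entropy(logits, batch["labels"])

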
class TSPModelForQuestionAnswering(TSPPreTrainedModel):
    def __init__(self, config, num_classes):
        super().__init__(config)
        self.backbone = TSPModel(config)
        self.head = SequenceClassificationHead(config, num_classes)
        self.apply(self._init_weights)

    def forward(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
    ):
        hidden_states = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        return self.head(hidden_states)


class SquadHead(nn.Module):
    def __init__(
        self,
        config,
        beam_size,
        predict_answerability,
    ):
        super().__init__()
        self.beam_size = beam_size
        self.predict_answerability = predict_answerability

        # Answer start pointer: one logit per token.
        self.start_predictor = nn.Linear(config.hidden_size, 1)

        # Answer end pointer, conditioned on a candidate start position.
        self.end_predictor = nn.Sequential(
            nn.Linear(config.hidden_size * 2, 512),
            nn.GELU(),
            nn.Linear(512, 1),
        )

        # Optional answerability classifier (e.g. for SQuAD 2.0-style unanswerable questions).
        if predict_answerability:
            self.answerability_predictor = nn.Sequential(
                nn.Linear(config.hidden_size * 2, 512),
                nn.GELU(),
                nn.Linear(512, 1),
            )
        else:
            self.answerability_predictor = None

    def forward(
        self,
        hidden_states,
        token_type_ids,
        answer_start_position=None,
    ):
        # Restrict predictions to the second segment (token_type_id == 1),
        # excluding its final [SEP]; position 0 (the [CLS] token) is also kept
        # as a valid target.
        answer_mask = token_type_ids
        last_sep = answer_mask.cumsum(dim=1) == answer_mask.sum(dim=1, keepdim=True)
        answer_mask = answer_mask * ~last_sep
        answer_mask[:, 0] = 1
        answer_mask = answer_mask.bool()

        start_logits, start_top_hidden_states = self._calculate_start(
            hidden_states, answer_mask, answer_start_position
        )

        end_logits = self._calculate_end_logits(
            hidden_states,
            start_top_hidden_states,
            answer_mask,
        )

        answerability_logits = None
        if self.answerability_predictor is not None:
            answerability_logits = self._calculate_answerability_logits(
                hidden_states, start_logits
            )

        return start_logits, end_logits, answerability_logits

    def _calculate_start(self, hidden_states, answer_mask, start_positions):
        start_logits = self.start_predictor(hidden_states).squeeze(-1)
        start_logits = start_logits.masked_fill(~answer_mask, -float("inf"))
        # During training the end predictor is conditioned on the gold start
        # position; at inference it is conditioned on the top-k (beam) starts.
        if self.training:
            start_top_indices = start_positions
        else:
            _, start_top_indices = start_logits.topk(k=self.beam_size, dim=-1)
        start_top_hidden_states = torch.stack(
            [
                hiddens.index_select(dim=0, index=index)
                for hiddens, index in zip(hidden_states, start_top_indices)
            ]
        )
        return start_logits, start_top_hidden_states

    def _calculate_end_logits(
        self, hidden_states, start_top_hidden_states, answer_mask
    ):
        B, L, D = hidden_states.shape
        # Pair every candidate start with every token: (B, num_starts, L, 2 * D).
        start_top_hiddens = start_top_hidden_states.view(B, -1, 1, D).expand(
            -1, -1, L, -1
        )
        end_hidden_states = torch.cat(
            [
                start_top_hiddens,
                hidden_states.view(B, 1, L, D).expand_as(start_top_hiddens),
            ],
            dim=-1,
        )
        end_logits = self.end_predictor(end_hidden_states).squeeze(-1)
        end_logits = end_logits.masked_fill(
            ~answer_mask.view(B, 1, L), -float("inf")
        )
        end_logits = end_logits.squeeze(1)
        return end_logits

    def _calculate_answerability_logits(self, hidden_states, start_logits):
        # Combine the [CLS] representation with a start-probability-weighted
        # average of the token representations.
        answerability_hidden_states = hidden_states[:, 0, :]
        start_probs = start_logits.softmax(dim=-1).unsqueeze(-1)
        start_features = (start_probs * hidden_states).sum(dim=1)
        answerability_hidden_states = torch.cat(
            [answerability_hidden_states, start_features], dim=-1
        )
        answerability_logits = self.answerability_predictor(answerability_hidden_states)
        return answerability_logits.squeeze(-1)


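# Sketch of how the head's outputs could be consumed during fine-tuning
# (illustrative only; `backbone` and the batch tensors are placeholders):
#
#   head = SquadHead(config, beam_size=5, predict_answerability=True)
#   hidden = backbone(input_ids, attention_mask, token_type_ids)
#   start_logits, end_logits, ans_logits = head(
#       hidden, token_type_ids, answer_start_position=answer_start_position
#   )
#
# `start_logits` and `end_logits` are typically trained with cross-entropy
# against the gold start/end indices, and `ans_logits` (when enabled) with a
# binary answerability loss.

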
class TSPModel(TSPPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.embeddings = Embeddings(config)
        if config.embedding_size != config.hidden_size:
            self.embeddings_project = nn.Linear(
                config.embedding_size, config.hidden_size
            )
        self.layers = nn.ModuleList(
            EncoderLayer(config) for _ in range(config.num_hidden_layers)
        )
        self.apply(self._init_weights)

    def forward(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
    ):
        x = self.embeddings(input_ids=input_ids, token_type_ids=token_type_ids)
        if hasattr(self, "embeddings_project"):
            x = self.embeddings_project(x)

        # `get_extended_attention_mask` (inherited from `PreTrainedModel`) turns
        # the 0/1 padding mask into an additive mask broadcastable over the
        # attention scores.
        extended_attention_mask = self.get_extended_attention_mask(
            attention_mask=attention_mask,
            input_shape=input_ids.shape,
            device=input_ids.device,
        )

        for layer in self.layers:
            x = layer(x, attention_mask=extended_attention_mask)

        return x


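# Encoding sketch (illustrative; assumes TSPConfig follows the usual
# PretrainedConfig interface and that a compatible tokenizer produced the
# `input_ids`, `attention_mask`, and `token_type_ids` tensors):
#
#   config = TSPConfig.from_pretrained(...)  # path/name is a placeholder
#   model = TSPModel(config)
#   hidden_states = model(input_ids, attention_mask, token_type_ids)
#   # hidden_states: (batch_size, seq_len, config.hidden_size)

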
class Embeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id
        )
        # Learned absolute position embeddings; with the "rotary" setting,
        # positions are handled inside the attention layers instead.
        if config.position_embedding_type == "absolute":
            self.position_embeddings = nn.Embedding(
                config.max_sequence_length, config.embedding_size
            )
        self.token_type_embeddings = nn.Embedding(2, config.embedding_size)
        self.norm = nn.LayerNorm(config.embedding_size)
        self.dropout = nn.Dropout(config.dropout_prob)

    def forward(
        self,
        input_ids,
        token_type_ids,
    ):
        B, L = input_ids.shape
        embeddings = self.word_embeddings(input_ids)
        embeddings += self.token_type_embeddings(token_type_ids)
        if hasattr(self, "position_embeddings"):
            embeddings += self.position_embeddings.weight[None, :L, :]
        embeddings = self.norm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class EncoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self_attn_block = BlockWrapper(config, MultiHeadSelfAttention)
        self.transition_block = BlockWrapper(config, FeedForwardNetwork)

    def forward(
        self,
        x,
        attention_mask,
    ):
        x = self.self_attn_block(x, attention_mask=attention_mask)
        x = self.transition_block(x)
        return x


class BlockWrapper(nn.Module):
    """Post-LN residual block: y = LayerNorm(x + Dropout(sublayer(x)))."""

    def __init__(self, config, sublayer_cls):
        super().__init__()
        self.sublayer = sublayer_cls(config)
        self.dropout = nn.Dropout(config.dropout_prob)
        self.norm = nn.LayerNorm(config.hidden_size)

    def forward(self, x, **kwargs):
        original_x = x
        x = self.sublayer(x, **kwargs)
        x = self.dropout(x)
        x = original_x + x
        x = self.norm(x)
        return x


class MultiHeadSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.mix_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size)
        self.attention = Attention(config)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.H = config.num_attention_heads
        self.d = config.hidden_size // self.H
        if config.position_embedding_type == "rotary":
            self.rotary_position_embeds = RotaryEmbedding(self.d)

    def forward(
        self,
        x,
        attention_mask,
    ):
        B, L, D, H, d = *x.shape, self.H, self.d
        # Project to q/k/v in one pass and split per head: (B, H, L, d) each.
        query, key, value = (
            self.mix_proj(x).view(B, L, H, 3 * d).transpose(1, 2).split(d, dim=-1)
        )
        if hasattr(self, "rotary_position_embeds"):
            query, key = self.rotary_position_embeds(query, key)
        output = self.attention(query, key, value, attention_mask)
        output = self.o_proj(output.transpose(1, 2).reshape(B, L, D))
        return output


class Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dropout = nn.Dropout(config.dropout_prob)

    def forward(
        self,
        query,
        key,
        value,
        attention_mask,
    ):
        # Standard scaled dot-product attention with an additive attention mask.
        B, H, L, d = key.shape
        attention_score = query.matmul(key.transpose(-2, -1))
        attention_score = attention_score / math.sqrt(d)
        attention_score += attention_mask
        attention_probs = attention_score.softmax(dim=-1)
        attention_probs = self.dropout(attention_probs)
        output = attention_probs.matmul(value)
        return output


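# On recent PyTorch versions (2.0+), the body of `Attention.forward` above could
# likely be replaced by the fused kernel; this is an optional substitution, not
# part of the original implementation:
#
#   output = F.scaled_dot_product_attention(
#       query, key, value, attn_mask=attention_mask,
#       dropout_p=self.dropout.p if self.training else 0.0,
#   )

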
class FeedForwardNetwork(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.linear1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, x):
        x = self.linear1(x)
        x = F.gelu(x)
        x = self.linear2(x)
        return x


class RotaryEmbedding(nn.Module):
    # cos/sin tables are cached at class level, so they are shared by all
    # attention layers and only recomputed when a longer sequence is seen.
    seq_len_cached = 0
    cos_cached = None
    sin_cached = None

    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def _forward(self, x):
        # x: (B, H, L, d)
        seq_len = x.shape[2]
        if seq_len > RotaryEmbedding.seq_len_cached:
            RotaryEmbedding.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
            freqs = t.view(-1, 1) @ self.inv_freq.view(1, -1)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            RotaryEmbedding.cos_cached = emb.cos()[None, None, :, :]
            RotaryEmbedding.sin_cached = emb.sin()[None, None, :, :]

        if seq_len == RotaryEmbedding.seq_len_cached:
            cos, sin = RotaryEmbedding.cos_cached, RotaryEmbedding.sin_cached
        else:
            cos, sin = (
                RotaryEmbedding.cos_cached[:, :, :seq_len, :],
                RotaryEmbedding.sin_cached[:, :, :seq_len, :],
            )

        # Rotate pairs of dimensions: out = x * cos + rotate_half(x) * sin.
        sections = [x.shape[-1] // 2, x.shape[-1] - x.shape[-1] // 2]
        x1, x2 = x.split(sections, dim=-1)
        half_rotated_x = torch.cat((-x2, x1), dim=-1)
        return (x * cos) + (half_rotated_x * sin)

    def forward(self, query, key):
        return self._forward(query), self._forward(key)

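
# The transformation above is the standard "rotate half" rotary position
# embedding (RoPE): at position m, feature i (for i < d/2) is paired with
# feature i + d/2 and the pair is rotated by the angle m * inv_freq[i], i.e.
#   rope(x) = x * cos(m * theta) + rotate_half(x) * sin(m * theta)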