# Uploaded by: relik-ie
# Commit message: "Update modeling_relik.py"
# Commit: 9ed9a98 (verified)
from typing import Any, Dict, Optional
import torch
from transformers import AutoModel, PreTrainedModel
from transformers.activations import ClippedGELUActivation, GELUActivation
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PoolerEndLogits
from .configuration_relik import RelikReaderConfig
class RelikReaderSample:
    """Attribute-style wrapper around a plain dict of sample fields.

    Every keyword given to the constructor, and every attribute assigned
    afterwards, is stored in the backing dict ``self._d``. Reading a field
    that was never set returns ``None`` instead of raising. Dunder names still
    raise ``AttributeError`` so that library protocols such as
    ``copy.deepcopy`` fall back to their default behaviour.
    """

    def __init__(self, **kwargs):
        # Install the backing dict directly, bypassing our own __setattr__.
        # Routing it through __setattr__ would also store the dict inside
        # itself under the "_d" key (a self-reference).
        super().__setattr__("_d", kwargs)

    def __getattr__(self, item):
        # Only called when normal attribute lookup fails, i.e. for data fields.
        if item.startswith("__") and item.endswith("__"):
            # this is likely some python library-specific variable (such as __deepcopy__ for copy)
            # better follow standard behavior here
            raise AttributeError(item)
        return self._d.get(item)

    def __setattr__(self, key, value):
        if key == "_d":
            # Replacing the backing dict itself.
            super().__setattr__(key, value)
            return
        # The backing dict is the single source of truth. Also writing the
        # value into the instance __dict__ would shadow later updates, because
        # normal lookup finds __dict__ before __getattr__ is consulted.
        self._d[key] = value
# Lookup table from the activation names accepted in the config
# (``config.activation``) to the module instance placed inside every
# projection head built by ``_get_projection_layer``. The same instance is
# reused by every head that selects it.
activation2functions = dict(
    relu=torch.nn.ReLU(),
    gelu=GELUActivation(),
    # presumably GELU with its output clipped to [-10, 10]
    gelu_10=ClippedGELUActivation(-10, 10),
)
class PoolerEndLogitsBi(PoolerEndLogits):
    """``PoolerEndLogits`` variant that emits two logits per position.

    The parent's final ``dense_1`` projection is replaced by a
    ``hidden_size -> 2`` linear layer. Each token therefore gets a pair of
    logits, used by the callers as a binary end / not-end decision
    (argmax over the last axis).
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        # Override the parent's output layer with a 2-way projection.
        self.dense_1 = torch.nn.Linear(config.hidden_size, 2)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        start_states: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        p_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        # A trailing singleton axis lets the mask broadcast over the two
        # logits per position (presumably the parent applies p_mask to its
        # per-position output).
        broadcast_mask = None if p_mask is None else p_mask.unsqueeze(-1)
        return super().forward(
            hidden_states, start_states, start_positions, broadcast_mask
        )
class RelikReaderSpanModel(PreTrainedModel):
    """ReLiK reader for span detection (NED) and entity disambiguation (ED).

    A transformer encoder produces per-token features. On top of them sit:

    * ``ned_start_classifier``: 2-way start / not-start logits per token;
    * ``ned_end_classifier``: end logits for every start position (gold
      starts while training, predicted starts otherwise);
    * ``ed_start_projector`` / ``ed_end_projector``: project span boundary
      features and special-symbol features. Their dot product scores every
      span against every special symbol.
    """

    config_class = RelikReaderConfig

    def __init__(self, config: RelikReaderConfig, *args, **kwargs):
        super().__init__(config)
        # Transformer model declaration
        self.config = config
        self.transformer_model = (
            AutoModel.from_pretrained(self.config.transformer_model)
            if self.config.num_layers is None
            else AutoModel.from_pretrained(
                self.config.transformer_model, num_hidden_layers=self.config.num_layers
            )
        )
        # Grow the embedding matrix to make room for the additional special symbols.
        self.transformer_model.resize_token_embeddings(
            self.transformer_model.config.vocab_size
            + self.config.additional_special_symbols
        )

        self.activation = self.config.activation
        self.linears_hidden_size = self.config.linears_hidden_size
        self.use_last_k_layers = self.config.use_last_k_layers

        # named entity detection layers
        self.ned_start_classifier = self._get_projection_layer(
            self.activation, last_hidden=2, layer_norm=False
        )
        # NOTE(review): the end head is sized from transformer_model.config
        # (hidden_size), while model_features are hidden_size *
        # use_last_k_layers wide. Confirm the end head is only used with
        # use_last_k_layers == 1.
        if self.config.binary_end_logits:
            self.ned_end_classifier = PoolerEndLogitsBi(self.transformer_model.config)
        else:
            self.ned_end_classifier = PoolerEndLogits(self.transformer_model.config)

        # END entity disambiguation layer
        self.ed_start_projector = self._get_projection_layer(self.activation)
        self.ed_end_projector = self._get_projection_layer(self.activation)

        self.training = self.config.training

        # criterion
        self.criterion = torch.nn.CrossEntropyLoss()

    def _get_projection_layer(
        self,
        activation: str,
        last_hidden: Optional[int] = None,
        input_hidden: Optional[int] = None,
        layer_norm: bool = True,
    ) -> torch.nn.Sequential:
        """Build a Dropout-Linear-Activation-Dropout-Linear[-LayerNorm] head.

        Args:
            activation: key into ``activation2functions``.
            last_hidden: output width; defaults to ``linears_hidden_size``.
            input_hidden: input width; defaults to the encoder hidden size
                times ``use_last_k_layers`` (the width of model_features).
            layer_norm: whether to append a LayerNorm over the output.
        """
        head_components = [
            torch.nn.Dropout(0.1),
            torch.nn.Linear(
                (
                    self.transformer_model.config.hidden_size * self.use_last_k_layers
                    if input_hidden is None
                    else input_hidden
                ),
                self.linears_hidden_size,
            ),
            activation2functions[activation],
            torch.nn.Dropout(0.1),
            torch.nn.Linear(
                self.linears_hidden_size,
                self.linears_hidden_size if last_hidden is None else last_hidden,
            ),
        ]

        if layer_norm:
            head_components.append(
                torch.nn.LayerNorm(
                    self.linears_hidden_size if last_hidden is None else last_hidden,
                    self.transformer_model.config.layer_norm_eps,
                )
            )

        return torch.nn.Sequential(*head_components)

    def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Push logits where ``mask == 1`` to a large negative value.

        The mask gets a trailing axis so it broadcasts over the last logits
        dimension. In float16 the fill is -65500, since -1e30 is not
        representable in that dtype.
        """
        mask = mask.unsqueeze(-1)
        if next(self.parameters()).dtype == torch.float16:
            logits = logits * (1 - mask) - 65500 * mask
        else:
            logits = logits * (1 - mask) - 1e30 * mask
        return logits

    def _get_model_features(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor],
    ):
        """Encode the input and return per-token features.

        With ``use_last_k_layers > 1`` the last k hidden states are
        concatenated along the feature axis. Otherwise the encoder's first
        output (the last hidden state) is returned unchanged.
        """
        output_hidden_states = self.use_last_k_layers > 1
        model_input = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "output_hidden_states": output_hidden_states,
            # Request a ModelOutput so hidden states can be read by name. For
            # encoders that also return a pooled output, positional index 1
            # is that pooled tensor, not the tuple of hidden states.
            "return_dict": True,
        }
        if token_type_ids is not None:
            model_input["token_type_ids"] = token_type_ids

        model_output = self.transformer_model(**model_input)

        if output_hidden_states:
            return torch.cat(
                model_output.hidden_states[-self.use_last_k_layers :], dim=-1
            )
        return model_output[0]

    def compute_ned_end_logits(
        self,
        start_predictions,
        start_labels,
        model_features,
        prediction_mask,
        batch_size,
    ) -> Optional[torch.Tensor]:
        """Score end positions for every start (gold in training, predicted otherwise).

        The features and prediction mask of each sequence are repeated once
        per start found in it. The end head is then conditioned on that start
        index. Returns ``None`` when there is no start at all.
        """
        # todo: maybe when constraining on the spans,
        #  we should not use a prediction_mask for the end tokens.
        #  at least we should not during training imo
        start_positions = start_labels if self.training else start_predictions
        # Token index of every start, flattened over the batch in row-major order.
        start_positions_indices = (
            torch.arange(start_positions.size(1), device=start_positions.device)
            .unsqueeze(0)
            .expand(batch_size, -1)[start_positions > 0]
        ).to(start_positions.device)

        if len(start_positions_indices) > 0:
            starts_per_sequence = torch.sum(start_positions > 0, dim=-1)
            expanded_features = model_features.repeat_interleave(
                starts_per_sequence, dim=0
            )
            expanded_prediction_mask = prediction_mask.repeat_interleave(
                starts_per_sequence, dim=0
            )
            end_logits = self.ned_end_classifier(
                hidden_states=expanded_features,
                start_positions=start_positions_indices,
                p_mask=expanded_prediction_mask,
            )
            return end_logits

        return None

    def compute_classification_logits(
        self,
        model_features_start,
        model_features_end,
        special_symbols_features,
    ) -> torch.Tensor:
        """Score every span against every special symbol (batched dot product).

        Span representation = [start projection; end projection]. Symbol
        representation = both projectors applied to the symbol features. Rows
        whose start features are entirely -100 (padding) are masked out.
        """
        model_start_features = self.ed_start_projector(model_features_start)
        model_end_features = self.ed_end_projector(model_features_end)
        model_start_features_symbols = self.ed_start_projector(special_symbols_features)
        model_end_features_symbols = self.ed_end_projector(special_symbols_features)

        model_ed_features = torch.cat(
            [model_start_features, model_end_features], dim=-1
        )
        special_symbols_representation = torch.cat(
            [model_start_features_symbols, model_end_features_symbols], dim=-1
        )

        logits = torch.bmm(
            model_ed_features,
            torch.permute(special_symbols_representation, (0, 2, 1)),
        )

        logits = self._mask_logits(logits, (model_features_start == -100).all(2).long())
        return logits

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        prediction_mask: Optional[torch.Tensor] = None,
        special_symbols_mask: Optional[torch.Tensor] = None,
        start_labels: Optional[torch.Tensor] = None,
        end_labels: Optional[torch.Tensor] = None,
        use_predefined_spans: bool = False,
        *args,
        **kwargs,
    ) -> Dict[str, Any]:
        """Run span detection and entity disambiguation.

        Args:
            input_ids / attention_mask / token_type_ids: encoder inputs.
            prediction_mask: 1 where a token must not be predicted as a span
                boundary (fed to ``_mask_logits``).
            special_symbols_mask: boolean mask selecting the special-symbol
                tokens; assumes the same count per sequence (used with ``.view``).
            start_labels: per-token labels, >0 at span starts, -100 ignored.
            end_labels: per-start rows of end labels, >0 at the end position
                (the positive value is reused as the ED label); rows entirely
                -100 are padding.
            use_predefined_spans: take spans from the labels instead of
                predicting them.

        Returns:
            Dict with NED start/end and ED logits, probabilities and
            predictions. In training mode with labels it also holds the
            ``ned_start_loss``, ``ned_end_loss``, ``ed_loss`` and ``loss`` keys.
        """
        batch_size, seq_len = input_ids.shape

        model_features = self._get_model_features(
            input_ids, attention_mask, token_type_ids
        )

        ned_start_labels = None

        # named entity detection if required
        if use_predefined_spans:  # no need to compute spans
            ned_start_logits, ned_start_probabilities, ned_start_predictions = (
                None,
                None,
                (
                    torch.clone(start_labels)
                    if start_labels is not None
                    else torch.zeros_like(input_ids)
                ),
            )
            ned_end_logits, ned_end_probabilities, ned_end_predictions = (
                None,
                None,
                (
                    torch.clone(end_labels)
                    if end_labels is not None
                    else torch.zeros_like(input_ids)
                ),
            )
            ned_start_predictions[ned_start_predictions > 0] = 1
            # NOTE(review): the next two lines index with end_labels directly,
            # so this branch fails when end_labels is None. The 2-D
            # zeros_like fallback above would also fail the 3-D .all(2).
            ned_end_predictions[end_labels > 0] = 1
            ned_end_predictions = ned_end_predictions[~(end_labels == -100).all(2)]

        else:  # compute spans
            # start boundary prediction
            ned_start_logits = self.ned_start_classifier(model_features)
            ned_start_logits = self._mask_logits(ned_start_logits, prediction_mask)
            ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1)
            ned_start_predictions = ned_start_probabilities.argmax(dim=-1)

            # end boundary prediction
            # Binary start targets: 1 at starts, -100 kept as ignore index.
            ned_start_labels = (
                torch.zeros_like(start_labels) if start_labels is not None else None
            )

            if ned_start_labels is not None:
                ned_start_labels[start_labels == -100] = -100
                ned_start_labels[start_labels > 0] = 1

            ned_end_logits = self.compute_ned_end_logits(
                ned_start_predictions,
                ned_start_labels,
                model_features,
                prediction_mask,
                batch_size,
            )

            if ned_end_logits is not None:
                ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
                if not self.config.binary_end_logits:
                    # One end per start: one-hot of the argmax position.
                    ned_end_predictions = torch.argmax(
                        ned_end_probabilities, dim=-1, keepdim=True
                    )
                    ned_end_predictions = torch.zeros_like(
                        ned_end_probabilities
                    ).scatter_(1, ned_end_predictions, 1)
                else:
                    # Independent end / not-end decision per position.
                    ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1)
            else:
                ned_end_logits, ned_end_probabilities = None, None
                ned_end_predictions = ned_start_predictions.new_zeros(
                    batch_size, seq_len
                )

            if not self.training:
                end_preds_count = ned_end_predictions.sum(1)
                # If there are no end predictions for a start prediction, remove the start prediction
                if (end_preds_count == 0).any() and (ned_start_predictions > 0).any():
                    ned_start_predictions[ned_start_predictions == 1] = (
                        end_preds_count != 0
                    ).long()
                    ned_end_predictions = ned_end_predictions[end_preds_count != 0]

        if end_labels is not None:
            # Drop padding rows (entirely -100).
            end_labels = end_labels[~(end_labels == -100).all(2)]

        start_position, end_position = (
            (start_labels, end_labels)
            if self.training
            else (ned_start_predictions, ned_end_predictions)
        )

        start_counts = (start_position > 0).sum(1)
        if (start_counts > 0).any():
            # Regroup the per-start end rows by sequence (tuple of tensors).
            ned_end_predictions = ned_end_predictions.split(start_counts.tolist())

        # Entity disambiguation
        if (end_position > 0).sum() > 0:
            ends_count = (end_position > 0).sum(1)
            # One row per (start, end) span: start features repeated for each
            # of its ends, paired with the matching end features.
            model_entity_start = torch.repeat_interleave(
                model_features[start_position > 0], ends_count, dim=0
            )
            model_entity_end = torch.repeat_interleave(
                model_features, start_counts, dim=0
            )[end_position > 0]
            # Number of spans per sequence.
            ents_count = torch.nn.utils.rnn.pad_sequence(
                torch.split(ends_count, start_counts.tolist()),
                batch_first=True,
                padding_value=0,
            ).sum(1)

            # Regroup spans per sequence; -100 padding is masked in the logits.
            model_entity_start = torch.nn.utils.rnn.pad_sequence(
                torch.split(model_entity_start, ents_count.tolist()),
                batch_first=True,
                padding_value=-100,
            )

            model_entity_end = torch.nn.utils.rnn.pad_sequence(
                torch.split(model_entity_end, ents_count.tolist()),
                batch_first=True,
                padding_value=-100,
            )

            ed_logits = self.compute_classification_logits(
                model_entity_start,
                model_entity_end,
                model_features[special_symbols_mask].view(
                    batch_size, -1, model_features.shape[-1]
                ),
            )
            ed_probabilities = torch.softmax(ed_logits, dim=-1)
            ed_predictions = torch.argmax(ed_probabilities, dim=-1)
        else:
            ed_logits, ed_probabilities, ed_predictions = (
                None,
                ned_start_predictions.new_zeros(batch_size, seq_len),
                ned_start_predictions.new_zeros(batch_size),
            )

        # output build
        output_dict = dict(
            batch_size=batch_size,
            ned_start_logits=ned_start_logits,
            ned_start_probabilities=ned_start_probabilities,
            ned_start_predictions=ned_start_predictions,
            ned_end_logits=ned_end_logits,
            ned_end_probabilities=ned_end_probabilities,
            ned_end_predictions=ned_end_predictions,
            ed_logits=ed_logits,
            ed_probabilities=ed_probabilities,
            ed_predictions=ed_predictions,
        )

        # compute loss if labels
        if start_labels is not None and end_labels is not None and self.training:
            # named entity detection loss

            # start
            if ned_start_logits is not None:
                ned_start_loss = self.criterion(
                    ned_start_logits.view(-1, ned_start_logits.shape[-1]),
                    ned_start_labels.view(-1),
                )
            else:
                ned_start_loss = 0

            # end
            # use ents_count to assign the labels to the correct positions i.e. using end_labels -> [[0,0,4,0], [0,0,0,2]] -> [4,2] (this is just an element, for batch we need to mask it with ents_count), ie -> [[4,2,-100,-100], [3,1,2,-100], [1,3,2,5]]
            # NOTE(review): with use_predefined_spans, ned_end_logits is None,
            # so ed_loss below falls back to 0 even though ed_logits were
            # computed. Confirm this is intended.
            if ned_end_logits is not None:
                ed_labels = end_labels.clone()
                ed_labels = torch.nn.utils.rnn.pad_sequence(
                    torch.split(ed_labels[ed_labels > 0], ents_count.tolist()),
                    batch_first=True,
                    padding_value=-100,
                )
                end_labels[end_labels > 0] = 1
                if not self.config.binary_end_logits:
                    # transform label to position in the sequence
                    end_labels = end_labels.argmax(dim=-1)
                    ned_end_loss = self.criterion(
                        ned_end_logits.view(-1, ned_end_logits.shape[-1]),
                        end_labels.view(-1),
                    )
                else:
                    ned_end_loss = self.criterion(
                        ned_end_logits.reshape(-1, ned_end_logits.shape[-1]),
                        end_labels.reshape(-1).long(),
                    )

                # entity disambiguation loss
                ed_loss = self.criterion(
                    ed_logits.view(-1, ed_logits.shape[-1]),
                    ed_labels.view(-1).long(),
                )

            else:
                ned_end_loss = 0
                ed_loss = 0

            output_dict["ned_start_loss"] = ned_start_loss
            output_dict["ned_end_loss"] = ned_end_loss
            output_dict["ed_loss"] = ed_loss

            output_dict["loss"] = ned_start_loss + ned_end_loss + ed_loss

        return output_dict
class RelikReaderREModel(PreTrainedModel):
    """ReLiK reader for span detection plus relation extraction.

    Heads on top of the encoder features:

    * ``ned_start_classifier`` / ``ned_end_classifier``: span boundaries;
    * ``re_projector``, ``re_relation_projector`` and ``re_classifier``:
      score every (subject span, object span, relation symbol) triple;
    * optionally ``re_entities_projector`` / ``re_definition_projector``:
      score spans against entity-type / relation symbols. Built when
      ``entity_type_loss`` or ``relation_disambiguation_loss`` is enabled.
    """

    config_class = RelikReaderConfig

    def __init__(self, config, *args, **kwargs):
        super().__init__(config)
        # Transformer model declaration
        self.config = config
        self.transformer_model = (
            AutoModel.from_pretrained(config.transformer_model)
            if config.num_layers is None
            else AutoModel.from_pretrained(
                config.transformer_model, num_hidden_layers=config.num_layers
            )
        )
        # Grow the embedding matrix for both groups of additional symbols.
        self.transformer_model.resize_token_embeddings(
            self.transformer_model.config.vocab_size
            + config.additional_special_symbols
            + config.additional_special_symbols_types,
        )

        # named entity detection layers
        self.ned_start_classifier = self._get_projection_layer(
            config.activation, last_hidden=2, layer_norm=False
        )

        self.ned_end_classifier = PoolerEndLogitsBi(self.transformer_model.config)

        self.relation_disambiguation_loss = getattr(
            config, "relation_disambiguation_loss", False
        )

        # Span features are [start; end] (2 x hidden). forward() appends an
        # entity-type embedding when both flags are on, giving 3 x hidden.
        if self.config.entity_type_loss and self.config.add_entity_embedding:
            input_hidden_ents = 3
        else:
            input_hidden_ents = 2

        self.re_projector = self._get_projection_layer(
            config.activation,
            input_hidden=input_hidden_ents * self.transformer_model.config.hidden_size,
            hidden=input_hidden_ents * self.config.linears_hidden_size,
            last_hidden=2 * self.config.linears_hidden_size,
        )

        self.re_relation_projector = self._get_projection_layer(
            config.activation,
            input_hidden=self.transformer_model.config.hidden_size,
        )

        if self.config.entity_type_loss or self.relation_disambiguation_loss:
            self.re_entities_projector = self._get_projection_layer(
                config.activation,
                input_hidden=2 * self.transformer_model.config.hidden_size,
            )
            self.re_definition_projector = self._get_projection_layer(
                config.activation,
            )

        self.re_classifier = self._get_projection_layer(
            config.activation,
            input_hidden=config.linears_hidden_size,
            last_hidden=2,
            layer_norm=False,
        )

        self.training = config.training

        # criterion
        self.criterion = torch.nn.CrossEntropyLoss()
        self.criterion_type = torch.nn.BCEWithLogitsLoss()

    def _get_projection_layer(
        self,
        activation: str,
        last_hidden: Optional[int] = None,
        hidden: Optional[int] = None,
        input_hidden: Optional[int] = None,
        layer_norm: bool = True,
    ) -> torch.nn.Sequential:
        """Build a Dropout-Linear-Activation-Dropout-Linear[-LayerNorm] head.

        Args:
            activation: key into ``activation2functions``.
            last_hidden: output width; defaults to ``linears_hidden_size``.
            hidden: intermediate width; defaults to ``linears_hidden_size``.
            input_hidden: input width; defaults to the encoder hidden size
                times ``use_last_k_layers``.
            layer_norm: whether to append a LayerNorm over the output.
        """
        head_components = [
            torch.nn.Dropout(0.1),
            torch.nn.Linear(
                (
                    self.transformer_model.config.hidden_size
                    * self.config.use_last_k_layers
                    if input_hidden is None
                    else input_hidden
                ),
                self.config.linears_hidden_size if hidden is None else hidden,
            ),
            activation2functions[activation],
            torch.nn.Dropout(0.1),
            torch.nn.Linear(
                self.config.linears_hidden_size if hidden is None else hidden,
                self.config.linears_hidden_size if last_hidden is None else last_hidden,
            ),
        ]

        if layer_norm:
            head_components.append(
                torch.nn.LayerNorm(
                    (
                        self.config.linears_hidden_size
                        if last_hidden is None
                        else last_hidden
                    ),
                    self.transformer_model.config.layer_norm_eps,
                )
            )

        return torch.nn.Sequential(*head_components)

    def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Push logits where ``mask == 1`` to a large negative value.

        The fill is -65500 under float16 parameters and -1e30 otherwise.
        """
        mask = mask.unsqueeze(-1)
        if next(self.parameters()).dtype == torch.float16:
            logits = logits * (1 - mask) - 65500 * mask
        else:
            logits = logits * (1 - mask) - 1e30 * mask
        return logits

    def _get_model_features(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor],
    ):
        """Encode the input and return per-token features.

        With ``use_last_k_layers > 1`` the last k hidden states are
        concatenated along the feature axis.
        """
        output_hidden_states = self.config.use_last_k_layers > 1
        model_input = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "output_hidden_states": output_hidden_states,
            # Request a ModelOutput so hidden states can be read by name. For
            # encoders that also return a pooled output, positional index 1
            # is that pooled tensor, not the tuple of hidden states.
            "return_dict": True,
        }
        if token_type_ids is not None:
            model_input["token_type_ids"] = token_type_ids

        model_output = self.transformer_model(**model_input)

        if output_hidden_states:
            return torch.cat(
                model_output.hidden_states[-self.config.use_last_k_layers :], dim=-1
            )
        return model_output[0]

    def compute_ned_end_logits(
        self,
        start_predictions,
        start_labels,
        model_features,
        prediction_mask,
        batch_size,
        mask_preceding: bool = False,
    ) -> Optional[torch.Tensor]:
        """Score end positions for every start (gold in training, predicted otherwise).

        If ``mask_preceding`` is set, positions before each start are also
        masked, so an end can never precede its start. Returns ``None`` when
        there is no start.
        """
        # todo: maybe when constraining on the spans,
        #  we should not use a prediction_mask for the end tokens.
        #  at least we should not during training imo
        start_positions = start_labels if self.training else start_predictions
        start_positions_indices = (
            torch.arange(start_positions.size(1), device=start_positions.device)
            .unsqueeze(0)
            .expand(batch_size, -1)[start_positions > 0]
        ).to(start_positions.device)

        if len(start_positions_indices) > 0:
            starts_per_sequence = torch.sum(start_positions > 0, dim=-1)
            expanded_features = model_features.repeat_interleave(
                starts_per_sequence, dim=0
            )
            expanded_prediction_mask = prediction_mask.repeat_interleave(
                starts_per_sequence, dim=0
            )
            if mask_preceding:
                # Row i: mask every token index smaller than its start index.
                expanded_prediction_mask[
                    torch.arange(
                        expanded_prediction_mask.shape[1],
                        device=expanded_prediction_mask.device,
                    )
                    < start_positions_indices.unsqueeze(1)
                ] = 1
            end_logits = self.ned_end_classifier(
                hidden_states=expanded_features,
                start_positions=start_positions_indices,
                p_mask=expanded_prediction_mask,
            )

            return end_logits

        return None

    def compute_relation_logits(
        self,
        model_entity_features,
        special_symbols_features,
    ) -> torch.Tensor:
        """Score every (subject span, object span, relation symbol) triple.

        ``re_projector`` output is split in half into subject and object
        features. They are multiplied elementwise with each projected
        relation symbol ("bse,bde,bfe->bsdfe"), and ``re_classifier`` maps
        the last axis to 2 logits.
        """
        model_subject_object_features = self.re_projector(model_entity_features)
        half = model_subject_object_features.shape[-1] // 2
        model_subject_features = model_subject_object_features[:, :, :half]
        model_object_features = model_subject_object_features[:, :, half:]
        special_symbols_start_representation = self.re_relation_projector(
            special_symbols_features
        )
        re_logits = torch.einsum(
            "bse,bde,bfe->bsdfe",
            model_subject_features,
            model_object_features,
            special_symbols_start_representation,
        )
        re_logits = self.re_classifier(re_logits)

        return re_logits

    def compute_entity_logits(
        self,
        model_entity_features,
        special_symbols_features,
    ) -> torch.Tensor:
        """Dot-product scores between spans and symbols; -100 padded spans masked."""
        model_ed_features = self.re_entities_projector(model_entity_features)
        special_symbols_ed_representation = self.re_definition_projector(
            special_symbols_features
        )
        logits = torch.bmm(
            model_ed_features,
            torch.permute(special_symbols_ed_representation, (0, 2, 1)),
        )
        logits = self._mask_logits(
            logits, (model_entity_features == -100).all(2).long()
        )
        return logits

    def compute_loss(self, logits, labels, mask=None):
        """Cross-entropy over flattened logits, optionally on a boolean subset."""
        logits = logits.reshape(-1, logits.shape[-1])
        labels = labels.reshape(-1).long()
        if mask is not None:
            return self.criterion(logits[mask], labels[mask])
        return self.criterion(logits, labels)

    def compute_ned_type_loss(
        self,
        disambiguation_labels,
        re_ned_entities_logits,
        ned_type_logits,
        re_entities_logits,
        entity_types,
        mask,
    ):
        """BCE-with-logits loss for entity types and/or span-symbol disambiguation.

        Labels equal to -100 are excluded. Returns 0 when neither loss is enabled.
        """
        if self.config.entity_type_loss and self.relation_disambiguation_loss:
            return self.criterion_type(
                re_ned_entities_logits[disambiguation_labels != -100],
                disambiguation_labels[disambiguation_labels != -100],
            )

        if self.config.entity_type_loss:
            return self.criterion_type(
                ned_type_logits[mask],
                disambiguation_labels[:, :, :entity_types][mask],
            )

        if self.relation_disambiguation_loss:
            return self.criterion_type(
                re_entities_logits[disambiguation_labels != -100],
                disambiguation_labels[disambiguation_labels != -100],
            )
        return 0

    def compute_relation_loss(self, relation_labels, re_logits):
        """Cross-entropy over relation triples, ignoring labels equal to -100."""
        return self.compute_loss(
            re_logits, relation_labels, relation_labels.view(-1) != -100
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor,
        prediction_mask: Optional[torch.Tensor] = None,
        special_symbols_mask: Optional[torch.Tensor] = None,
        special_symbols_mask_entities: Optional[torch.Tensor] = None,
        start_labels: Optional[torch.Tensor] = None,
        end_labels: Optional[torch.Tensor] = None,
        disambiguation_labels: Optional[torch.Tensor] = None,
        relation_labels: Optional[torch.Tensor] = None,
        relation_threshold: Optional[float] = None,
        is_validation: bool = False,
        is_prediction: bool = False,
        use_predefined_spans: bool = False,
        *args,
        **kwargs,
    ) -> Dict[str, Any]:
        """Detect spans and score relations between them.

        Args:
            prediction_mask: 1 where a token must not be a span boundary.
            special_symbols_mask: boolean mask of relation-symbol tokens.
            special_symbols_mask_entities: boolean mask of entity-type symbol
                tokens.
            start_labels / end_labels: span labels (>0 positive, -100 ignored).
            disambiguation_labels: targets for ``compute_ned_type_loss``.
            relation_labels: targets for ``compute_relation_loss``.
            relation_threshold: probability cut-off for ``re_predictions``;
                defaults to ``config.threshold``.
            is_validation / is_prediction: use predicted (not gold) spans
                downstream. Losses are skipped when ``is_prediction``.
            use_predefined_spans: take spans from the labels.

        Returns:
            Dict of NED, entity-type and relation logits/probabilities/
            predictions, plus losses when labels are given and
            ``is_prediction`` is False.
        """
        relation_threshold = (
            self.config.threshold if relation_threshold is None else relation_threshold
        )

        batch_size = input_ids.shape[0]

        model_features = self._get_model_features(
            input_ids, attention_mask, token_type_ids
        )

        # named entity detection
        if use_predefined_spans:
            ned_start_logits, ned_start_probabilities, ned_start_predictions = (
                None,
                None,
                torch.zeros_like(start_labels),
            )
            ned_end_logits, ned_end_probabilities, ned_end_predictions = (
                None,
                None,
                torch.zeros_like(end_labels),
            )

            ned_start_predictions[start_labels > 0] = 1
            ned_end_predictions[end_labels > 0] = 1
            ned_end_predictions = ned_end_predictions[~(end_labels == -100).all(2)]
            # Binarise a copy: writing into start_labels itself would silently
            # overwrite the caller's tensor.
            ned_start_labels = start_labels.clone()
            ned_start_labels[start_labels > 0] = 1
        else:
            # start boundary prediction
            ned_start_logits = self.ned_start_classifier(model_features)
            if is_validation or is_prediction:
                ned_start_logits = self._mask_logits(
                    ned_start_logits, prediction_mask
                )  # why?
            ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1)
            ned_start_predictions = ned_start_probabilities.argmax(dim=-1)

            # end boundary prediction
            ned_start_labels = (
                torch.zeros_like(start_labels) if start_labels is not None else None
            )

            # start_labels contain entity id at their position, we just need 1 for start of entity
            if ned_start_labels is not None:
                ned_start_labels[start_labels == -100] = -100
                ned_start_labels[start_labels > 0] = 1

            # compute end logits only if there are any start predictions.
            # For each start prediction, n end predictions are made
            ned_end_logits = self.compute_ned_end_logits(
                ned_start_predictions,
                ned_start_labels,
                model_features,
                prediction_mask,
                batch_size,
                True,
            )

            if ned_end_logits is not None:
                # For each start prediction, n end predictions are made based on
                # binary classification ie. argmax at each position.
                ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
                ned_end_predictions = ned_end_probabilities.argmax(dim=-1)
            else:
                ned_end_logits, ned_end_probabilities = None, None
                ned_end_predictions = torch.zeros_like(ned_start_predictions)

            if is_prediction or is_validation:
                end_preds_count = ned_end_predictions.sum(1)
                # If there are no end predictions for a start prediction, remove the start prediction
                if (end_preds_count == 0).any() and (ned_start_predictions > 0).any():
                    ned_start_predictions[ned_start_predictions == 1] = (
                        end_preds_count != 0
                    ).long()
                    ned_end_predictions = ned_end_predictions[end_preds_count != 0]

        if end_labels is not None:
            # Drop padding rows (entirely -100).
            end_labels = end_labels[~(end_labels == -100).all(2)]

        start_position, end_position = (
            (start_labels, end_labels)
            if (not is_prediction and not is_validation)
            else (ned_start_predictions, ned_end_predictions)
        )

        start_counts = (start_position > 0).sum(1)
        if (start_counts > 0).any():
            ned_end_predictions = ned_end_predictions.split(start_counts.tolist())
        else:
            # One empty (0, seq_len) tensor per sequence, on the input device.
            ned_end_predictions = [
                torch.empty(
                    0, input_ids.shape[1], dtype=torch.int64, device=input_ids.device
                )
                for _ in range(batch_size)
            ]

        # We can only predict relations if we have start and end predictions
        if (end_position > 0).sum() > 0:
            ends_count = (end_position > 0).sum(1)
            # One row per (start, end) span: [start features; end features].
            model_subject_features = torch.cat(
                [
                    torch.repeat_interleave(
                        model_features[start_position > 0], ends_count, dim=0
                    ),  # start position features
                    torch.repeat_interleave(model_features, start_counts, dim=0)[
                        end_position > 0
                    ],  # end position features
                ],
                dim=-1,
            )
            # Number of spans per sequence.
            ents_count = torch.nn.utils.rnn.pad_sequence(
                torch.split(ends_count, start_counts.tolist()),
                batch_first=True,
                padding_value=0,
            ).sum(1)
            model_subject_features = torch.nn.utils.rnn.pad_sequence(
                torch.split(model_subject_features, ents_count.tolist()),
                batch_first=True,
                padding_value=-100,
            )

            # entity disambiguation. Here relation_disambiguation_loss would only be useful to
            # reduce the number of candidate relations for the next step, but currently unused.
            if self.config.entity_type_loss or self.relation_disambiguation_loss:
                (re_ned_entities_logits) = self.compute_entity_logits(
                    model_subject_features,
                    model_features[
                        special_symbols_mask | special_symbols_mask_entities
                    ].view(batch_size, -1, model_features.shape[-1]),
                )
                # Number of entity-type symbols (taken from the first sequence).
                entity_types = torch.sum(special_symbols_mask_entities, dim=1)[0].item()
                ned_type_logits = re_ned_entities_logits[:, :, :entity_types]
                re_entities_logits = re_ned_entities_logits[:, :, entity_types:]

                if self.config.entity_type_loss:
                    ned_type_probabilities = torch.sigmoid(ned_type_logits)
                    ned_type_predictions = ned_type_probabilities.argmax(dim=-1)

                    if self.config.add_entity_embedding:
                        special_symbols_representation = model_features[
                            special_symbols_mask_entities
                        ].view(batch_size, entity_types, -1)

                        # Type-probability-weighted sum of type-symbol features.
                        entities_representation = torch.einsum(
                            "bsp,bpe->bse",
                            ned_type_probabilities,
                            special_symbols_representation,
                        )
                        model_subject_features = torch.cat(
                            [model_subject_features, entities_representation], dim=-1
                        )
                re_entities_probabilities = torch.sigmoid(re_entities_logits)
                re_entities_predictions = re_entities_probabilities.round()
            else:
                (
                    ned_type_logits,
                    ned_type_probabilities,
                    re_entities_logits,
                    re_entities_probabilities,
                ) = (None, None, None, None)
                ned_type_predictions, re_entities_predictions = (
                    torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
                    torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
                )

            # Compute relation logits
            re_logits = self.compute_relation_logits(
                model_subject_features,
                model_features[special_symbols_mask].view(
                    batch_size, -1, model_features.shape[-1]
                ),
            )

            re_probabilities = torch.softmax(re_logits, dim=-1)
            # we set a thresshold instead of argmax in cause it needs to be tweaked
            re_predictions = re_probabilities[:, :, :, :, 1] > relation_threshold
            re_probabilities = re_probabilities[:, :, :, :, 1]

        else:
            (
                ned_type_logits,
                ned_type_probabilities,
                re_entities_logits,
                re_entities_probabilities,
            ) = (None, None, None, None)
            ned_type_predictions, re_entities_predictions = (
                torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
                torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
            )
            re_logits, re_probabilities, re_predictions = (
                torch.zeros(
                    [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
                ).to(input_ids.device),
                torch.zeros(
                    [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
                ).to(input_ids.device),
                torch.zeros(
                    [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
                ).to(input_ids.device),
            )

        # output build
        output_dict = dict(
            batch_size=batch_size,
            ned_start_logits=ned_start_logits,
            ned_start_probabilities=ned_start_probabilities,
            ned_start_predictions=ned_start_predictions,
            ned_end_logits=ned_end_logits,
            ned_end_probabilities=ned_end_probabilities,
            ned_end_predictions=ned_end_predictions,
            ned_type_logits=ned_type_logits,
            ned_type_probabilities=ned_type_probabilities,
            ned_type_predictions=ned_type_predictions,
            re_entities_logits=re_entities_logits,
            re_entities_probabilities=re_entities_probabilities,
            re_entities_predictions=re_entities_predictions,
            re_logits=re_logits,
            re_probabilities=re_probabilities,
            re_predictions=re_predictions,
        )

        if (
            start_labels is not None
            and end_labels is not None
            and relation_labels is not None
            and is_prediction is False
        ):
            # With predefined spans there are no NED logits to train, so the
            # NED terms contribute 0 instead of crashing on None.
            # NOTE(review): in eval mode ned_end_logits rows follow predicted
            # starts, while end_labels rows follow gold starts. The shapes only
            # match when the counts agree; confirm how validation is driven.
            ned_start_loss = (
                self.compute_loss(ned_start_logits, ned_start_labels)
                if ned_start_logits is not None
                else 0
            )
            end_labels[end_labels > 0] = 1
            ned_end_loss = (
                self.compute_loss(ned_end_logits, end_labels)
                if ned_end_logits is not None
                else 0
            )
            if self.config.entity_type_loss or self.relation_disambiguation_loss:
                ned_type_loss = self.compute_ned_type_loss(
                    disambiguation_labels,
                    re_ned_entities_logits,
                    ned_type_logits,
                    re_entities_logits,
                    entity_types,
                    (model_subject_features != -100).all(2),
                )
            relation_loss = self.compute_relation_loss(relation_labels, re_logits)
            # compute loss. We can skip the relation loss if we are in the first epochs (optional)
            if self.config.entity_type_loss or self.relation_disambiguation_loss:
                output_dict["loss"] = (
                    ned_start_loss + ned_end_loss + relation_loss + ned_type_loss
                ) / 4
                output_dict["ned_type_loss"] = ned_type_loss
            else:
                output_dict["loss"] = ((1 / 20) * (ned_start_loss + ned_end_loss)) + (
                    (9 / 10) * relation_loss
                )
            output_dict["ned_start_loss"] = ned_start_loss
            output_dict["ned_end_loss"] = ned_end_loss
            output_dict["re_loss"] = relation_loss

        return output_dict