Commit fbb76b5 by PereLluis13
Parent: f324b1d
Update modeling_relik.py

modeling_relik.py CHANGED (+116 -98)
Original file (removed lines are marked with "-"; a bare "-" is a removed line whose text is not displayed):

@@ -92,7 +92,10 @@ class RelikReaderSpanModel(PreTrainedModel):
        self.ned_start_classifier = self._get_projection_layer(
            self.activation, last_hidden=2, layer_norm=False
        )
-

        # END entity disambiguation layer
        self.ed_start_projector = self._get_projection_layer(self.activation)
@@ -209,29 +212,20 @@ class RelikReaderSpanModel(PreTrainedModel):

    def compute_classification_logits(
        self,
-
-
-
-       batch_size,
-       start_positions=None,
-       end_positions=None,
    ) -> torch.Tensor:
-
-
-
-
-       model_start_features = self.ed_start_projector(model_features)
-       model_end_features = self.ed_end_projector(model_features)
-       model_end_features[start_positions > 0] = model_end_features[end_positions > 0]

        model_ed_features = torch.cat(
            [model_start_features, model_end_features], dim=-1
        )
-
-
-       classes_representations = torch.sum(special_symbols_mask, dim=1)[0].item()
-       special_symbols_representation = model_ed_features[special_symbols_mask].view(
-           batch_size, classes_representations, -1
        )

        logits = torch.bmm(
@@ -239,8 +233,9 @@ class RelikReaderSpanModel(PreTrainedModel):
            torch.permute(special_symbols_representation, (0, 2, 1)),
        )

-       logits = self._mask_logits(
-
        return logits

    def forward(
@@ -284,9 +279,9 @@ class RelikReaderSpanModel(PreTrainedModel):
                    else torch.zeros_like(input_ids)
                ),
            )
-
            ned_start_predictions[ned_start_predictions > 0] = 1
-           ned_end_predictions[

        else:  # compute spans
            # start boundary prediction
@@ -314,63 +309,80 @@ class RelikReaderSpanModel(PreTrainedModel):

            if ned_end_logits is not None:
                ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
-
            else:
                ned_end_logits, ned_end_probabilities = None, None
-               ned_end_predictions = ned_start_predictions.new_zeros(batch_size)
-
-
-
-
-
-
-
-
-
-
-
-               cummax_values, _ = end_spans_repeated.cummax(dim=0)
-
-               end_spans_repeated = (end_spans_repeated > torch.cat((end_spans_repeated[:1], cummax_values[:-1])))
-               end_spans_repeated[0] = True
-
-               ned_start_predictions[row_indices[~end_spans_repeated], start_positions[~end_spans_repeated]] = 0
-
-               row_indices, start_positions, ned_end_predictions = row_indices[end_spans_repeated], start_positions[end_spans_repeated], ned_end_predictions[end_spans_repeated]
-
-               flattened_end_predictions[row_indices, ned_end_predictions] = 1
-
-               total_start_predictions, total_end_predictions = ned_start_predictions.sum(), flattened_end_predictions.sum()

-
-
-               or total_start_predictions == total_end_predictions
-           ), (
-               f"Total number of start predictions = {total_start_predictions}. "
-               f"Total number of end predictions = {total_end_predictions}"
-           )
-               ned_end_predictions = flattened_end_predictions
-           else:
-               ned_end_predictions = torch.zeros_like(ned_start_predictions)

        start_position, end_position = (
            (start_labels, end_labels)
            if self.training
            else (ned_start_predictions, ned_end_predictions)
        )
-
        # Entity disambiguation
-
-
-
-
-
-
-
-
-
-

        # output build
        output_dict = dict(
            batch_size=batch_size,
@@ -399,33 +411,35 @@ class RelikReaderSpanModel(PreTrainedModel):
            ned_start_loss = 0

        # end
-
-       ned_end_labels = torch.zeros_like(end_labels)
-       ned_end_labels[end_labels == -100] = -100
-       ned_end_labels[end_labels > 0] = 1

-
-
-
-
-
-
-
-
-
            )

        else:
            ned_end_loss = 0
-
-       # entity disambiguation loss
-       start_labels[ned_start_labels != 1] = -100
-       ed_labels = torch.clone(start_labels)
-       ed_labels[end_labels > 0] = end_labels[end_labels > 0]
-       ed_loss = self.criterion(
-           ed_logits.view(-1, ed_logits.shape[-1]),
-           ed_labels.view(-1),
-       )

        output_dict["ned_start_loss"] = ned_start_loss
        output_dict["ned_end_loss"] = ned_end_loss
@@ -471,16 +485,20 @@ class RelikReaderREModel(PreTrainedModel):
        )

        if self.config.entity_type_loss and self.config.add_entity_embedding:
-           input_hidden_ents = 3
        else:
-           input_hidden_ents = 2

        self.re_projector = self._get_projection_layer(
-           config.activation,
        )

        self.re_relation_projector = self._get_projection_layer(
-           config.activation,
        )

        if self.config.entity_type_loss or self.relation_disambiguation_loss:
@@ -726,8 +744,9 @@ class RelikReaderREModel(PreTrainedModel):
        *args,
        **kwargs,
    ) -> Dict[str, Any]:
-
-

        batch_size = input_ids.shape[0]

@@ -901,7 +920,6 @@ class RelikReaderREModel(PreTrainedModel):
            # we set a thresshold instead of argmax in cause it needs to be tweaked
            re_predictions = re_probabilities[:, :, :, :, 1] > relation_threshold
            re_probabilities = re_probabilities[:, :, :, :, 1]
-
        else:
            (
                ned_type_logits,
Updated file (added lines are marked with "+"):

@@ -92,7 +92,10 @@ class RelikReaderSpanModel(PreTrainedModel):
        self.ned_start_classifier = self._get_projection_layer(
            self.activation, last_hidden=2, layer_norm=False
        )
+       if self.config.binary_end_logits:
+           self.ned_end_classifier = PoolerEndLogitsBi(self.transformer_model.config)
+       else:
+           self.ned_end_classifier = PoolerEndLogits(self.transformer_model.config)

        # END entity disambiguation layer
        self.ed_start_projector = self._get_projection_layer(self.activation)
@@ -209,29 +212,20 @@ class RelikReaderSpanModel(PreTrainedModel):

    def compute_classification_logits(
        self,
+       model_features_start,
+       model_features_end,
+       special_symbols_features,
    ) -> torch.Tensor:
+       model_start_features = self.ed_start_projector(model_features_start)
+       model_end_features = self.ed_end_projector(model_features_end)
+       model_start_features_symbols = self.ed_start_projector(special_symbols_features)
+       model_end_features_symbols = self.ed_end_projector(special_symbols_features)

        model_ed_features = torch.cat(
            [model_start_features, model_end_features], dim=-1
        )
+       special_symbols_representation = torch.cat(
+           [model_start_features_symbols, model_end_features_symbols], dim=-1
        )

        logits = torch.bmm(

@@ -239,8 +233,9 @@ class RelikReaderSpanModel(PreTrainedModel):
            torch.permute(special_symbols_representation, (0, 2, 1)),
        )

+       logits = self._mask_logits(
+           logits, (model_features_start == -100).all(2).long()
+       )
        return logits

    def forward(
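The reworked compute_classification_logits scores every candidate span against every candidate entity class with a single batched matrix multiplication over the concatenated start/end projections. A minimal sketch of just that shape logic, with invented sizes (batch 2, 3 spans, 4 classes, projection width 8, so the concatenated start/end features have width 16):

    import torch

    batch_size, num_spans, num_classes, proj_size = 2, 3, 4, 8

    # concatenated start/end projections of the candidate spans: (B, spans, 2*proj)
    span_features = torch.randn(batch_size, num_spans, 2 * proj_size)
    # concatenated start/end projections of the special class symbols: (B, classes, 2*proj)
    class_features = torch.randn(batch_size, num_classes, 2 * proj_size)

    # (B, spans, 2*proj) x (B, 2*proj, classes) -> (B, spans, classes)
    logits = torch.bmm(span_features, torch.permute(class_features, (0, 2, 1)))
    print(logits.shape)  # torch.Size([2, 3, 4])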
@@ -284,9 +279,9 @@ class RelikReaderSpanModel(PreTrainedModel):
                    else torch.zeros_like(input_ids)
                ),
            )
            ned_start_predictions[ned_start_predictions > 0] = 1
+           ned_end_predictions[end_labels > 0] = 1
+           ned_end_predictions = ned_end_predictions[~(end_labels == -100).all(2)]

        else:  # compute spans
            # start boundary prediction
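The two added lines keep only the rows of ned_end_predictions whose end_labels row is not entirely -100, i.e. rows that actually carry an annotation. A small sketch of that boolean-mask filtering pattern on toy 2-D tensors (the model applies the same idea with .all(2) on 3-D tensors; values here are invented):

    import torch

    end_labels = torch.tensor([[0, 2, 0],
                               [-100, -100, -100],
                               [0, 0, 3]])
    preds = torch.tensor([[0, 1, 0],
                          [0, 0, 0],
                          [0, 0, 1]])

    keep = ~(end_labels == -100).all(dim=1)   # rows with at least one real label
    print(preds[keep])                        # drops the middle, fully ignored row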
@@ -314,63 +309,80 @@ class RelikReaderSpanModel(PreTrainedModel):

            if ned_end_logits is not None:
                ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
+               if not self.config.binary_end_logits:
+                   ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1, keepdim=True)
+                   ned_end_predictions = torch.zeros_like(ned_end_probabilities).scatter_(1, ned_end_predictions, 1)
+               else:
+                   ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1)
            else:
                ned_end_logits, ned_end_probabilities = None, None
+               ned_end_predictions = ned_start_predictions.new_zeros(batch_size, seq_len)
+
+           if not self.training:
+               # if len(ned_end_predictions.shape) < 2:
+               #     print(ned_end_predictions)
+               end_preds_count = ned_end_predictions.sum(1)
+               # If there are no end predictions for a start prediction, remove the start prediction
+               if (end_preds_count == 0).any() and (ned_start_predictions > 0).any():
+                   ned_start_predictions[ned_start_predictions == 1] = (
+                       end_preds_count != 0
+                   ).long()
+                   ned_end_predictions = ned_end_predictions[end_preds_count != 0]

+               if end_labels is not None:
+                   end_labels = end_labels[~(end_labels == -100).all(2)]

        start_position, end_position = (
            (start_labels, end_labels)
            if self.training
            else (ned_start_predictions, ned_end_predictions)
        )
+       start_counts = (start_position > 0).sum(1)
+       if (start_counts > 0).any():
+           ned_end_predictions = ned_end_predictions.split(start_counts.tolist())
        # Entity disambiguation
+       if (end_position > 0).sum() > 0:
+           ends_count = (end_position > 0).sum(1)
+           model_entity_start = torch.repeat_interleave(
+               model_features[start_position > 0], ends_count, dim=0
+           )
+           model_entity_end = torch.repeat_interleave(
+               model_features, start_counts, dim=0)[
+               end_position > 0
+           ]
+           ents_count = torch.nn.utils.rnn.pad_sequence(
+               torch.split(ends_count, start_counts.tolist()),
+               batch_first=True,
+               padding_value=0,
+           ).sum(1)

+           model_entity_start = torch.nn.utils.rnn.pad_sequence(
+               torch.split(model_entity_start, ents_count.tolist()),
+               batch_first=True,
+               padding_value=-100,
+           )
+
+           model_entity_end = torch.nn.utils.rnn.pad_sequence(
+               torch.split(model_entity_end, ents_count.tolist()),
+               batch_first=True,
+               padding_value=-100,
+           )
+
+           ed_logits = self.compute_classification_logits(
+               model_entity_start,
+               model_entity_end,
+               model_features[special_symbols_mask].view(
+                   batch_size, -1, model_features.shape[-1]
+               ),
+           )
+           ed_probabilities = torch.softmax(ed_logits, dim=-1)
+           ed_predictions = torch.argmax(ed_probabilities, dim=-1)
+       else:
+           ed_logits, ed_probabilities, ed_predictions = (
+               None,
+               ned_start_predictions.new_zeros(batch_size, seq_len),
+               ned_start_predictions.new_zeros(batch_size),
+           )
        # output build
        output_dict = dict(
            batch_size=batch_size,
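The entity-disambiguation branch above has to turn a flat list of predicted spans, whose number varies per example, into a rectangular batch before calling compute_classification_logits; it does this with torch.split over per-example counts followed by pad_sequence. A self-contained sketch of that split-and-pad pattern with invented sizes:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    # toy per-span feature rows: 5 spans in total, hidden size 4
    span_features = torch.arange(20.0).view(5, 4)
    # how many of those spans belong to each of the 2 examples in the batch
    spans_per_example = torch.tensor([3, 2])

    # split the flat span dimension back into per-example groups, then pad to a
    # rectangular (batch, max_spans, hidden) tensor; -100 marks padding rows
    grouped = torch.split(span_features, spans_per_example.tolist())
    padded = pad_sequence(grouped, batch_first=True, padding_value=-100)
    print(padded.shape)  # torch.Size([2, 3, 4])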
@@ -399,33 +411,35 @@ class RelikReaderSpanModel(PreTrainedModel):
            ned_start_loss = 0

        # end
+       # use ents_count to assign the labels to the correct positions i.e. using end_labels -> [[0,0,4,0], [0,0,0,2]] -> [4,2] (this is just an element, for batch we need to mask it with ents_count), ie -> [[4,2,-100,-100], [3,1,2,-100], [1,3,2,5]]

+       if ned_end_logits is not None:
+           ed_labels = end_labels.clone()
+           ed_labels = torch.nn.utils.rnn.pad_sequence(
+               torch.split(ed_labels[ed_labels > 0], ents_count.tolist()),
+               batch_first=True,
+               padding_value=-100,
+           )
+           end_labels[end_labels > 0] = 1
+           if not self.config.binary_end_logits:
+               # transform label to position in the sequence
+               end_labels = end_labels.argmax(dim=-1)
+               ned_end_loss = self.criterion(
+                   ned_end_logits.view(-1, ned_end_logits.shape[-1]),
+                   end_labels.view(-1),
+               )
+           else:
+               ned_end_loss = self.criterion(ned_end_logits.reshape(-1, ned_end_logits.shape[-1]), end_labels.reshape(-1).long())
+
+           # entity disambiguation loss
+           ed_loss = self.criterion(
+               ed_logits.view(-1, ed_logits.shape[-1]),
+               ed_labels.view(-1).long(),
            )

        else:
            ned_end_loss = 0
+           ed_loss = 0

        output_dict["ned_start_loss"] = ned_start_loss
        output_dict["ned_end_loss"] = ned_end_loss
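The long comment at the top of this hunk describes how the gold entity indices are pulled out of end_labels and re-packed per example for the disambiguation loss. A small sketch of that transformation, using the toy values from the comment (one example with two annotated spans whose gold classes are 4 and 2):

    import torch
    from torch.nn.utils.rnn import pad_sequence

    # one example, two annotated spans, gold entity classes 4 and 2
    end_labels = torch.tensor([[0, 0, 4, 0],
                               [0, 0, 0, 2]])
    ents_count = torch.tensor([2])  # this example contributes two entities

    ed_labels = pad_sequence(
        torch.split(end_labels[end_labels > 0], ents_count.tolist()),
        batch_first=True,
        padding_value=-100,
    )
    print(ed_labels)  # tensor([[4, 2]])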
@@ -471,16 +485,20 @@ class RelikReaderREModel(PreTrainedModel):
        )

        if self.config.entity_type_loss and self.config.add_entity_embedding:
+           input_hidden_ents = 3
        else:
+           input_hidden_ents = 2

        self.re_projector = self._get_projection_layer(
+           config.activation,
+           input_hidden=input_hidden_ents * self.transformer_model.config.hidden_size,
+           hidden=input_hidden_ents * self.config.linears_hidden_size,
+           last_hidden=2 * self.config.linears_hidden_size,
        )

        self.re_relation_projector = self._get_projection_layer(
+           config.activation,
+           input_hidden=self.transformer_model.config.hidden_size,
        )

        if self.config.entity_type_loss or self.relation_disambiguation_loss:
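The width of re_projector now scales with how many hidden vectors are concatenated per entity: two (start and end) by default, three when entity_type_loss and add_entity_embedding are both enabled. A quick worked example of the resulting layer sizes, with invented config values (hidden_size 768, linears_hidden_size 512):

    # Invented example values; the real ones come from the transformer and ReLiK configs.
    hidden_size = 768            # transformer_model.config.hidden_size (assumed)
    linears_hidden_size = 512    # config.linears_hidden_size (assumed)

    for use_entity_embedding in (False, True):
        input_hidden_ents = 3 if use_entity_embedding else 2
        input_hidden = input_hidden_ents * hidden_size       # 1536 or 2304
        hidden = input_hidden_ents * linears_hidden_size     # 1024 or 1536
        last_hidden = 2 * linears_hidden_size                # 1024
        print(use_entity_embedding, input_hidden, hidden, last_hidden)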
@@ -726,8 +744,9 @@ class RelikReaderREModel(PreTrainedModel):
        *args,
        **kwargs,
    ) -> Dict[str, Any]:
+       relation_threshold = (
+           self.config.threshold if relation_threshold is None else relation_threshold
+       )

        batch_size = input_ids.shape[0]

@@ -901,7 +920,6 @@ class RelikReaderREModel(PreTrainedModel):
            # we set a thresshold instead of argmax in cause it needs to be tweaked
            re_predictions = re_probabilities[:, :, :, :, 1] > relation_threshold
            re_probabilities = re_probabilities[:, :, :, :, 1]
        else:
            (
                ned_type_logits,
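As the comment in that hunk notes, relation predictions come from a tunable threshold on the positive-class probability rather than an argmax: with two classes an argmax amounts to a fixed 0.5 cut-off, while an explicit threshold lets any number of candidate triples fire (including none) and can be tweaked per use case. A toy illustration with invented probabilities:

    import torch

    # toy positive-class probabilities for 4 candidate (subject, object, relation) triples
    probs = torch.tensor([0.10, 0.55, 0.48, 0.90])

    argmax_pick = torch.zeros_like(probs, dtype=torch.bool)
    argmax_pick[probs.argmax()] = True          # exactly one triple, no matter what

    threshold_pick = probs > 0.5                # any number of triples can fire
    print(argmax_pick, threshold_pick)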