PereLluis13 committed
Commit fbb76b5
1 Parent(s): f324b1d

Update modeling_relik.py

Files changed (1)
  1. modeling_relik.py +116 -98
modeling_relik.py CHANGED
@@ -92,7 +92,10 @@ class RelikReaderSpanModel(PreTrainedModel):
         self.ned_start_classifier = self._get_projection_layer(
             self.activation, last_hidden=2, layer_norm=False
         )
-        self.ned_end_classifier = PoolerEndLogits(self.transformer_model.config)
+        if self.config.binary_end_logits:
+            self.ned_end_classifier = PoolerEndLogitsBi(self.transformer_model.config)
+        else:
+            self.ned_end_classifier = PoolerEndLogits(self.transformer_model.config)
 
         # END entity disambiguation layer
         self.ed_start_projector = self._get_projection_layer(self.activation)
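
Note: the new binary_end_logits flag selects between two end-boundary heads. A minimal sketch of the difference, with made-up sizes and plain Linear heads standing in for the real PoolerEndLogits/PoolerEndLogitsBi classes (which also condition on the start representation):

    import torch

    # Toy sizes; the Linear heads below are illustrative stand-ins.
    B, S, H = 2, 16, 64
    hidden_states = torch.randn(B, S, H)

    # binary_end_logits=False: one logit per token, softmax over the
    # sequence, so exactly one end position is selected per start.
    single_end_head = torch.nn.Linear(H, 1)
    end_logits = single_end_head(hidden_states).squeeze(-1)      # (B, S)
    end_position = end_logits.softmax(dim=-1).argmax(dim=-1)     # (B,)

    # binary_end_logits=True: two logits per token, each position is
    # classified end / not-end independently, so several ends may fire
    # for the same start.
    binary_end_head = torch.nn.Linear(H, 2)
    end_logits_bi = binary_end_head(hidden_states)               # (B, S, 2)
    end_mask = end_logits_bi.softmax(dim=-1)[..., 1] > 0.5       # (B, S) bool
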
@@ -209,29 +212,20 @@ class RelikReaderSpanModel(PreTrainedModel):
 
     def compute_classification_logits(
         self,
-        model_features,
-        special_symbols_mask,
-        prediction_mask,
-        batch_size,
-        start_positions=None,
-        end_positions=None,
+        model_features_start,
+        model_features_end,
+        special_symbols_features,
     ) -> torch.Tensor:
-        if start_positions is None or end_positions is None:
-            start_positions = torch.zeros_like(prediction_mask)
-            end_positions = torch.zeros_like(prediction_mask)
-
-        model_start_features = self.ed_start_projector(model_features)
-        model_end_features = self.ed_end_projector(model_features)
-        model_end_features[start_positions > 0] = model_end_features[end_positions > 0]
+        model_start_features = self.ed_start_projector(model_features_start)
+        model_end_features = self.ed_end_projector(model_features_end)
+        model_start_features_symbols = self.ed_start_projector(special_symbols_features)
+        model_end_features_symbols = self.ed_end_projector(special_symbols_features)
 
         model_ed_features = torch.cat(
             [model_start_features, model_end_features], dim=-1
         )
-
-        # computing ed features
-        classes_representations = torch.sum(special_symbols_mask, dim=1)[0].item()
-        special_symbols_representation = model_ed_features[special_symbols_mask].view(
-            batch_size, classes_representations, -1
+        special_symbols_representation = torch.cat(
+            [model_start_features_symbols, model_end_features_symbols], dim=-1
         )
 
         logits = torch.bmm(
@@ -239,8 +233,9 @@ class RelikReaderSpanModel(PreTrainedModel):
             torch.permute(special_symbols_representation, (0, 2, 1)),
         )
 
-        logits = self._mask_logits(logits, prediction_mask)
-
+        logits = self._mask_logits(
+            logits, (model_features_start == -100).all(2).long()
+        )
         return logits
 
     def forward(
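
Note: in the reworked compute_classification_logits, each candidate span contributes separate start and end features, the candidate entities (special symbols) pass through the same two projectors, and one batched matmul scores every span against every candidate. A minimal sketch with toy dimensions (B batches, E padded spans, C candidates, D projected size are all made up):

    import torch

    B, E, C, D = 2, 3, 5, 32
    span_start = torch.randn(B, E, D)   # ed_start_projector(model_features_start)
    span_end = torch.randn(B, E, D)     # ed_end_projector(model_features_end)
    cand_start = torch.randn(B, C, D)   # ed_start_projector(special_symbols_features)
    cand_end = torch.randn(B, C, D)     # ed_end_projector(special_symbols_features)

    span_repr = torch.cat([span_start, span_end], dim=-1)   # (B, E, 2D)
    cand_repr = torch.cat([cand_start, cand_end], dim=-1)   # (B, C, 2D)

    # One dot product per (span, candidate) pair:
    # (B, E, 2D) x (B, 2D, C) -> (B, E, C)
    logits = torch.bmm(span_repr, cand_repr.permute(0, 2, 1))

    # Rows that are -100 padding are then masked out, mirroring the diff's
    # _mask_logits(logits, (model_features_start == -100).all(2).long()).
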
@@ -284,9 +279,9 @@ class RelikReaderSpanModel(PreTrainedModel):
                     else torch.zeros_like(input_ids)
                 ),
             )
-
             ned_start_predictions[ned_start_predictions > 0] = 1
-            ned_end_predictions[ned_end_predictions > 0] = 1
+            ned_end_predictions[end_labels > 0] = 1
+            ned_end_predictions = ned_end_predictions[~(end_labels == -100).all(2)]
 
         else: # compute spans
             # start boundary prediction
@@ -314,63 +309,80 @@ class RelikReaderSpanModel(PreTrainedModel):
 
             if ned_end_logits is not None:
                 ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
-                ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1)
+                if not self.config.binary_end_logits:
+                    ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1, keepdim=True)
+                    ned_end_predictions = torch.zeros_like(ned_end_probabilities).scatter_(1, ned_end_predictions, 1)
+                else:
+                    ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1)
             else:
                 ned_end_logits, ned_end_probabilities = None, None
-                ned_end_predictions = ned_start_predictions.new_zeros(batch_size)
-
-            # flattening end predictions
-            # (flattening can happen only if the
-            # end boundaries were not predicted using the gold labels)
-            if not self.training and ned_end_logits is not None:
-                flattened_end_predictions = torch.zeros_like(ned_start_predictions)
-
-                row_indices, start_positions = torch.where(ned_start_predictions > 0)
-                ned_end_predictions[ned_end_predictions<start_positions] = start_positions[ned_end_predictions<start_positions]
-
-                end_spans_repeated = (row_indices + 1)* seq_len + ned_end_predictions
-                cummax_values, _ = end_spans_repeated.cummax(dim=0)
-
-                end_spans_repeated = (end_spans_repeated > torch.cat((end_spans_repeated[:1], cummax_values[:-1])))
-                end_spans_repeated[0] = True
-
-                ned_start_predictions[row_indices[~end_spans_repeated], start_positions[~end_spans_repeated]] = 0
-
-                row_indices, start_positions, ned_end_predictions = row_indices[end_spans_repeated], start_positions[end_spans_repeated], ned_end_predictions[end_spans_repeated]
-
-                flattened_end_predictions[row_indices, ned_end_predictions] = 1
-
-                total_start_predictions, total_end_predictions = ned_start_predictions.sum(), flattened_end_predictions.sum()
+                ned_end_predictions = ned_start_predictions.new_zeros(batch_size, seq_len)
+
+            if not self.training:
+                # if len(ned_end_predictions.shape) < 2:
+                #     print(ned_end_predictions)
+                end_preds_count = ned_end_predictions.sum(1)
+                # If there are no end predictions for a start prediction, remove the start prediction
+                if (end_preds_count == 0).any() and (ned_start_predictions > 0).any():
+                    ned_start_predictions[ned_start_predictions == 1] = (
+                        end_preds_count != 0
+                    ).long()
+                    ned_end_predictions = ned_end_predictions[end_preds_count != 0]
 
-                assert (
-                    total_start_predictions == 0
-                    or total_start_predictions == total_end_predictions
-                ), (
-                    f"Total number of start predictions = {total_start_predictions}. "
-                    f"Total number of end predictions = {total_end_predictions}"
-                )
-                ned_end_predictions = flattened_end_predictions
-            else:
-                ned_end_predictions = torch.zeros_like(ned_start_predictions)
+            if end_labels is not None:
+                end_labels = end_labels[~(end_labels == -100).all(2)]
 
         start_position, end_position = (
             (start_labels, end_labels)
             if self.training
             else (ned_start_predictions, ned_end_predictions)
         )
-
+        start_counts = (start_position > 0).sum(1)
+        if (start_counts > 0).any():
+            ned_end_predictions = ned_end_predictions.split(start_counts.tolist())
         # Entity disambiguation
-        ed_logits = self.compute_classification_logits(
-            model_features,
-            special_symbols_mask,
-            prediction_mask,
-            batch_size,
-            start_position,
-            end_position,
-        )
-        ed_probabilities = torch.softmax(ed_logits, dim=-1)
-        ed_predictions = torch.argmax(ed_probabilities, dim=-1)
-
+        if (end_position > 0).sum() > 0:
+            ends_count = (end_position > 0).sum(1)
+            model_entity_start = torch.repeat_interleave(
+                model_features[start_position > 0], ends_count, dim=0
+            )
+            model_entity_end = torch.repeat_interleave(
+                model_features, start_counts, dim=0)[
+                end_position > 0
+            ]
+            ents_count = torch.nn.utils.rnn.pad_sequence(
+                torch.split(ends_count, start_counts.tolist()),
+                batch_first=True,
+                padding_value=0,
+            ).sum(1)
+
+            model_entity_start = torch.nn.utils.rnn.pad_sequence(
+                torch.split(model_entity_start, ents_count.tolist()),
+                batch_first=True,
+                padding_value=-100,
+            )
+
+            model_entity_end = torch.nn.utils.rnn.pad_sequence(
+                torch.split(model_entity_end, ents_count.tolist()),
+                batch_first=True,
+                padding_value=-100,
+            )
+
+            ed_logits = self.compute_classification_logits(
+                model_entity_start,
+                model_entity_end,
+                model_features[special_symbols_mask].view(
+                    batch_size, -1, model_features.shape[-1]
+                ),
+            )
+            ed_probabilities = torch.softmax(ed_logits, dim=-1)
+            ed_predictions = torch.argmax(ed_probabilities, dim=-1)
+        else:
+            ed_logits, ed_probabilities, ed_predictions = (
+                None,
+                ned_start_predictions.new_zeros(batch_size, seq_len),
+                ned_start_predictions.new_zeros(batch_size),
+            )
         # output build
         output_dict = dict(
             batch_size=batch_size,
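
Note: this hunk leans on two tensor tricks. First, the non-binary branch turns argmax indices into one-hot rows with scatter_. Second, the variable number of predicted spans per batch element is regrouped from a flat list into a padded (batch, max_entities, dim) tensor via torch.split plus pad_sequence, where the -100 padding doubles as the sentinel that compute_classification_logits masks on. A minimal sketch of both, with toy values:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    # Trick 1: argmax indices -> one-hot rows, as done for ned_end_predictions.
    probs = torch.rand(2, 6)
    idx = probs.argmax(dim=-1, keepdim=True)               # (2, 1)
    one_hot = torch.zeros_like(probs).scatter_(1, idx, 1)  # (2, 6), one 1 per row

    # Trick 2: regroup a flat list of per-entity feature rows per batch element.
    flat_entities = torch.randn(5, 4)     # 5 entities across the whole batch
    ents_count = torch.tensor([2, 3])     # entities per batch element
    padded = pad_sequence(
        torch.split(flat_entities, ents_count.tolist()),
        batch_first=True,
        padding_value=-100.0,
    )
    print(padded.shape)  # torch.Size([2, 3, 4]); slot 2 of row 0 is -100 padding
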
@@ -399,33 +411,35 @@ class RelikReaderSpanModel(PreTrainedModel):
                 ned_start_loss = 0
 
             # end
-            if ned_end_logits is not None:
-                ned_end_labels = torch.zeros_like(end_labels)
-                ned_end_labels[end_labels == -100] = -100
-                ned_end_labels[end_labels > 0] = 1
+            # use ents_count to assign the labels to the correct positions, i.e. end_labels -> [[0,0,4,0], [0,0,0,2]] -> [4,2] (this is just one element; for a batch we need to mask it with ents_count), i.e. -> [[4,2,-100,-100], [3,1,2,-100], [1,3,2,5]]
 
-                ned_end_loss = self.criterion(
-                    ned_end_logits,
-                    (
-                        torch.arange(
-                            ned_end_labels.size(1), device=ned_end_labels.device
-                        )
-                        .unsqueeze(0)
-                        .expand(batch_size, -1)[ned_end_labels > 0]
-                    ).to(ned_end_labels.device),
+            if ned_end_logits is not None:
+                ed_labels = end_labels.clone()
+                ed_labels = torch.nn.utils.rnn.pad_sequence(
+                    torch.split(ed_labels[ed_labels > 0], ents_count.tolist()),
+                    batch_first=True,
+                    padding_value=-100,
+                )
+                end_labels[end_labels > 0] = 1
+                if not self.config.binary_end_logits:
+                    # transform label to position in the sequence
+                    end_labels = end_labels.argmax(dim=-1)
+                    ned_end_loss = self.criterion(
+                        ned_end_logits.view(-1, ned_end_logits.shape[-1]),
+                        end_labels.view(-1),
+                    )
+                else:
+                    ned_end_loss = self.criterion(ned_end_logits.reshape(-1, ned_end_logits.shape[-1]), end_labels.reshape(-1).long())
+
+                # entity disambiguation loss
+                ed_loss = self.criterion(
+                    ed_logits.view(-1, ed_logits.shape[-1]),
+                    ed_labels.view(-1).long(),
                 )
 
             else:
                 ned_end_loss = 0
-
-            # entity disambiguation loss
-            start_labels[ned_start_labels != 1] = -100
-            ed_labels = torch.clone(start_labels)
-            ed_labels[end_labels > 0] = end_labels[end_labels > 0]
-            ed_loss = self.criterion(
-                ed_logits.view(-1, ed_logits.shape[-1]),
-                ed_labels.view(-1),
-            )
+                ed_loss = 0
 
             output_dict["ned_start_loss"] = ned_start_loss
             output_dict["ned_end_loss"] = ned_end_loss
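
Note: the long comment in this hunk compresses a label transform worth spelling out. end_labels stores a gold entity index at each span-end position; the disambiguation loss instead wants one label per span, grouped per batch element and padded with -100 (the usual ignore_index). A worked toy version, with made-up counts:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    end_labels = torch.tensor([
        [0, 0, 4, 0],   # one span ending at position 2, gold entity 4
        [0, 2, 0, 3],   # two spans, gold entities 2 and 3
    ])
    ents_count = torch.tensor([1, 2])   # spans per batch element, as in the diff

    ed_labels = pad_sequence(
        torch.split(end_labels[end_labels > 0], ents_count.tolist()),
        batch_first=True,
        padding_value=-100,
    )
    print(ed_labels)   # tensor([[   4, -100],
                       #         [   2,    3]])
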
@@ -471,16 +485,20 @@ class RelikReaderREModel(PreTrainedModel):
         )
 
         if self.config.entity_type_loss and self.config.add_entity_embedding:
-            input_hidden_ents = 3 * self.config.linears_hidden_size
+            input_hidden_ents = 3
         else:
-            input_hidden_ents = 2 * self.config.linears_hidden_size
+            input_hidden_ents = 2
 
         self.re_projector = self._get_projection_layer(
-            config.activation, input_hidden=2*self.transformer_model.config.hidden_size, hidden=input_hidden_ents, last_hidden=2*self.config.linears_hidden_size
+            config.activation,
+            input_hidden=input_hidden_ents * self.transformer_model.config.hidden_size,
+            hidden=input_hidden_ents * self.config.linears_hidden_size,
+            last_hidden=2 * self.config.linears_hidden_size,
         )
 
         self.re_relation_projector = self._get_projection_layer(
-            config.activation, input_hidden=self.transformer_model.config.hidden_size,
+            config.activation,
+            input_hidden=self.transformer_model.config.hidden_size,
        )
 
         if self.config.entity_type_loss or self.relation_disambiguation_loss:
@@ -726,8 +744,9 @@ class RelikReaderREModel(PreTrainedModel):
         *args,
         **kwargs,
     ) -> Dict[str, Any]:
-
-        relation_threshold = self.config.threshold if relation_threshold is None else relation_threshold
+        relation_threshold = (
+            self.config.threshold if relation_threshold is None else relation_threshold
+        )
 
         batch_size = input_ids.shape[0]
 
@@ -901,7 +920,6 @@ class RelikReaderREModel(PreTrainedModel):
             # we set a threshold instead of argmax in case it needs to be tweaked
             re_predictions = re_probabilities[:, :, :, :, 1] > relation_threshold
             re_probabilities = re_probabilities[:, :, :, :, 1]
-
         else:
             (
                 ned_type_logits,
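
Note: keeping the positive-class probability and comparing it against a configurable threshold, rather than taking an argmax, lets precision/recall be traded off after training. A minimal sketch with toy shapes:

    import torch

    # With two classes, argmax is equivalent to a fixed 0.5 cutoff on the
    # positive-class probability; a tunable threshold generalizes that.
    re_probabilities = torch.softmax(torch.randn(1, 4, 4, 3, 2), dim=-1)
    relation_threshold = 0.4   # illustrative; the model falls back to config.threshold
    re_predictions = re_probabilities[..., 1] > relation_threshold
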
 