relik-ie committed
Commit 9ed9a98
1 Parent(s): ec59143

Update modeling_relik.py

Files changed (1):
  1. modeling_relik.py +24 -16
modeling_relik.py CHANGED
@@ -233,9 +233,7 @@ class RelikReaderSpanModel(PreTrainedModel):
             torch.permute(special_symbols_representation, (0, 2, 1)),
         )
 
-        logits = self._mask_logits(
-            logits, (model_features_start == -100).all(2).long()
-        )
+        logits = self._mask_logits(logits, (model_features_start == -100).all(2).long())
         return logits
 
     def forward(
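
The masking in this hunk treats a candidate position as padding only when every feature at that position equals -100. A minimal sketch of that pattern, with hypothetical shapes and masked_fill standing in for the model's internal _mask_logits (whose exact behavior is not shown in this diff):

import torch

# Hypothetical shapes: batch of 2, 4 candidate positions, 3 features each.
model_features_start = torch.randn(2, 4, 3)
model_features_start[0, 2:] = -100.0      # padded candidates are filled with -100
logits = torch.randn(2, 4)

# A candidate counts as padding only if *all* of its features equal -100.
padding_mask = (model_features_start == -100).all(2).long()   # (2, 4), 1 = padding

# One plausible masking strategy: push padded logits to a very negative value.
masked_logits = logits.masked_fill(padding_mask.bool(), torch.finfo(logits.dtype).min)
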
@@ -280,7 +278,7 @@ class RelikReaderSpanModel(PreTrainedModel):
                 ),
             )
             ned_start_predictions[ned_start_predictions > 0] = 1
-            ned_end_predictions[end_labels > 0] = 1
+            ned_end_predictions[end_labels > 0] = 1
             ned_end_predictions = ned_end_predictions[~(end_labels == -100).all(2)]
 
         else:  # compute spans
@@ -310,14 +308,20 @@ class RelikReaderSpanModel(PreTrainedModel):
         if ned_end_logits is not None:
             ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
             if not self.config.binary_end_logits:
-                ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1, keepdim=True)
-                ned_end_predictions = torch.zeros_like(ned_end_probabilities).scatter_(1, ned_end_predictions, 1)
+                ned_end_predictions = torch.argmax(
+                    ned_end_probabilities, dim=-1, keepdim=True
+                )
+                ned_end_predictions = torch.zeros_like(
+                    ned_end_probabilities
+                ).scatter_(1, ned_end_predictions, 1)
             else:
                 ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1)
         else:
             ned_end_logits, ned_end_probabilities = None, None
-            ned_end_predictions = ned_start_predictions.new_zeros(batch_size, seq_len)
-
+            ned_end_predictions = ned_start_predictions.new_zeros(
+                batch_size, seq_len
+            )
+
         if not self.training:
             # if len(ned_end_predictions.shape) < 2:
             # print(ned_end_predictions)
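
For reference, the argmax-plus-scatter_ idiom in this hunk turns a row-wise argmax into a one-hot prediction matrix. A self-contained sketch with made-up probabilities (2 rows, 5 end positions; the shapes are purely illustrative):

import torch

ned_end_probabilities = torch.tensor([[0.1, 0.6, 0.1, 0.1, 0.1],
                                      [0.2, 0.1, 0.1, 0.5, 0.1]])

# Index of the most likely end per row, kept as a column vector for scatter_.
idx = torch.argmax(ned_end_probabilities, dim=-1, keepdim=True)          # (2, 1)
one_hot = torch.zeros_like(ned_end_probabilities).scatter_(1, idx, 1)
# one_hot == [[0., 1., 0., 0., 0.],
#             [0., 0., 0., 1., 0.]]

Note that this only lines up with the argmax when the probabilities are 2-D, so dim=-1 in the argmax and dim 1 in the scatter_ refer to the same axis.
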
@@ -344,12 +348,11 @@ class RelikReaderSpanModel(PreTrainedModel):
                 if (end_position > 0).sum() > 0:
                     ends_count = (end_position > 0).sum(1)
                     model_entity_start = torch.repeat_interleave(
-                        model_features[start_position > 0], ends_count, dim=0
-                    )
+                        model_features[start_position > 0], ends_count, dim=0
+                    )
                     model_entity_end = torch.repeat_interleave(
-                        model_features, start_counts, dim=0)[
-                        end_position > 0
-                    ]
+                        model_features, start_counts, dim=0
+                    )[end_position > 0]
                     ents_count = torch.nn.utils.rnn.pad_sequence(
                         torch.split(ends_count, start_counts.tolist()),
                         batch_first=True,
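
The repeat_interleave calls above pair every predicted start with each of its candidate ends by repeating rows according to per-start counts. A tiny sketch of that pairing, using scalar stand-ins instead of real representation vectors (values and counts are invented):

import torch

start_feats = torch.tensor([10., 20., 30.])     # one (fake) feature per predicted start
ends_count = torch.tensor([2, 1, 3])            # candidate ends found for each start

paired_starts = torch.repeat_interleave(start_feats, ends_count, dim=0)
# paired_starts == [10., 10., 20., 30., 30., 30.]
# i.e. start i is repeated ends_count[i] times, so it lines up one-to-one
# with the flattened list of its candidate end representations.
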
@@ -379,7 +382,7 @@ class RelikReaderSpanModel(PreTrainedModel):
                 ed_predictions = torch.argmax(ed_probabilities, dim=-1)
             else:
                 ed_logits, ed_probabilities, ed_predictions = (
-                    None,
+                    None,
                     ned_start_predictions.new_zeros(batch_size, seq_len),
                     ned_start_predictions.new_zeros(batch_size),
                 )
@@ -429,8 +432,11 @@ class RelikReaderSpanModel(PreTrainedModel):
                     end_labels.view(-1),
                 )
             else:
-                ned_end_loss = self.criterion(ned_end_logits.reshape(-1, ned_end_logits.shape[-1]), end_labels.reshape(-1).long())
-
+                ned_end_loss = self.criterion(
+                    ned_end_logits.reshape(-1, ned_end_logits.shape[-1]),
+                    end_labels.reshape(-1).long(),
+                )
+
             # entity disambiguation loss
             ed_loss = self.criterion(
                 ed_logits.view(-1, ed_logits.shape[-1]),
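
The end-detection loss above flattens (batch, positions, classes) logits and (batch, positions) labels before handing them to the criterion. A minimal sketch, assuming the criterion behaves like nn.CrossEntropyLoss with ignore_index=-100 (the -100 padding convention appears elsewhere in the file, but the exact criterion configuration is an assumption here):

import torch
import torch.nn as nn

ned_end_logits = torch.randn(2, 6, 4)          # 2 docs, 6 positions, 4 classes
end_labels = torch.randint(0, 4, (2, 6))
end_labels[1, 4:] = -100                       # padded positions carry the label -100

criterion = nn.CrossEntropyLoss(ignore_index=-100)   # assumed configuration
ned_end_loss = criterion(
    ned_end_logits.reshape(-1, ned_end_logits.shape[-1]),   # (12, 4)
    end_labels.reshape(-1).long(),                          # (12,)
)
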
@@ -833,6 +839,8 @@ class RelikReaderREModel(PreTrainedModel):
             start_counts = (start_position > 0).sum(1)
             if (start_counts > 0).any():
                 ned_end_predictions = ned_end_predictions.split(start_counts.tolist())
+            else:
+                ned_end_predictions = [torch.empty(0, input_ids.shape[1], dtype=torch.int64) for _ in range(batch_size)]
             # limit to 30 predictions per document using start_counts, by setting all po after sum is 30 to 0
             # if is_validation or is_prediction:
             # ned_start_predictions[ned_start_predictions == 1] = start_counts
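
The split-by-start-counts pattern in this hunk (and its fallback for batches with no predicted starts) can be reproduced in isolation; the shapes and counts below are invented for illustration:

import torch

batch_size, seq_len = 3, 7
start_counts = torch.tensor([2, 0, 3])                              # predicted starts per document
ned_end_predictions = torch.ones(5, seq_len, dtype=torch.int64)     # 5 == start_counts.sum()

if (start_counts > 0).any():
    per_doc = ned_end_predictions.split(start_counts.tolist())
    # per_doc[1] is an empty (0, 7) tensor for the document with no predicted starts
else:
    per_doc = [torch.empty(0, seq_len, dtype=torch.int64) for _ in range(batch_size)]
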
 