Update modeling_relik.py
Browse files- modeling_relik.py +25 -17
modeling_relik.py
CHANGED
@@ -233,9 +233,7 @@ class RelikReaderSpanModel(PreTrainedModel):
|
|
233 |
torch.permute(special_symbols_representation, (0, 2, 1)),
|
234 |
)
|
235 |
|
236 |
-
logits = self._mask_logits(
|
237 |
-
logits, (model_features_start == -100).all(2).long()
|
238 |
-
)
|
239 |
return logits
|
240 |
|
241 |
def forward(
|
@@ -280,7 +278,7 @@ class RelikReaderSpanModel(PreTrainedModel):
|
|
280 |
),
|
281 |
)
|
282 |
ned_start_predictions[ned_start_predictions > 0] = 1
|
283 |
-
ned_end_predictions[end_labels > 0] = 1
|
284 |
ned_end_predictions = ned_end_predictions[~(end_labels == -100).all(2)]
|
285 |
|
286 |
else: # compute spans
|
@@ -310,14 +308,20 @@ class RelikReaderSpanModel(PreTrainedModel):
|
|
310 |
if ned_end_logits is not None:
|
311 |
ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
|
312 |
if not self.config.binary_end_logits:
|
313 |
-
ned_end_predictions = torch.argmax(
|
314 |
-
|
|
|
|
|
|
|
|
|
315 |
else:
|
316 |
ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1)
|
317 |
else:
|
318 |
ned_end_logits, ned_end_probabilities = None, None
|
319 |
-
ned_end_predictions = ned_start_predictions.new_zeros(
|
320 |
-
|
|
|
|
|
321 |
if not self.training:
|
322 |
# if len(ned_end_predictions.shape) < 2:
|
323 |
# print(ned_end_predictions)
|
@@ -344,12 +348,11 @@ class RelikReaderSpanModel(PreTrainedModel):
|
|
344 |
if (end_position > 0).sum() > 0:
|
345 |
ends_count = (end_position > 0).sum(1)
|
346 |
model_entity_start = torch.repeat_interleave(
|
347 |
-
|
348 |
-
|
349 |
model_entity_end = torch.repeat_interleave(
|
350 |
-
|
351 |
-
|
352 |
-
]
|
353 |
ents_count = torch.nn.utils.rnn.pad_sequence(
|
354 |
torch.split(ends_count, start_counts.tolist()),
|
355 |
batch_first=True,
|
@@ -379,7 +382,7 @@ class RelikReaderSpanModel(PreTrainedModel):
|
|
379 |
ed_predictions = torch.argmax(ed_probabilities, dim=-1)
|
380 |
else:
|
381 |
ed_logits, ed_probabilities, ed_predictions = (
|
382 |
-
None,
|
383 |
ned_start_predictions.new_zeros(batch_size, seq_len),
|
384 |
ned_start_predictions.new_zeros(batch_size),
|
385 |
)
|
@@ -429,8 +432,11 @@ class RelikReaderSpanModel(PreTrainedModel):
|
|
429 |
end_labels.view(-1),
|
430 |
)
|
431 |
else:
|
432 |
-
ned_end_loss = self.criterion(
|
433 |
-
|
|
|
|
|
|
|
434 |
# entity disambiguation loss
|
435 |
ed_loss = self.criterion(
|
436 |
ed_logits.view(-1, ed_logits.shape[-1]),
|
@@ -833,6 +839,8 @@ class RelikReaderREModel(PreTrainedModel):
|
|
833 |
start_counts = (start_position > 0).sum(1)
|
834 |
if (start_counts > 0).any():
|
835 |
ned_end_predictions = ned_end_predictions.split(start_counts.tolist())
|
|
|
|
|
836 |
# limit to 30 predictions per document using start_counts, by setting all po after sum is 30 to 0
|
837 |
# if is_validation or is_prediction:
|
838 |
# ned_start_predictions[ned_start_predictions == 1] = start_counts
|
@@ -996,4 +1004,4 @@ class RelikReaderREModel(PreTrainedModel):
|
|
996 |
output_dict["ned_end_loss"] = ned_end_loss
|
997 |
output_dict["re_loss"] = relation_loss
|
998 |
|
999 |
-
return output_dict
|
|
|
233 |
torch.permute(special_symbols_representation, (0, 2, 1)),
|
234 |
)
|
235 |
|
236 |
+
logits = self._mask_logits(logits, (model_features_start == -100).all(2).long())
|
|
|
|
|
237 |
return logits
|
238 |
|
239 |
def forward(
|
|
|
278 |
),
|
279 |
)
|
280 |
ned_start_predictions[ned_start_predictions > 0] = 1
|
281 |
+
ned_end_predictions[end_labels > 0] = 1
|
282 |
ned_end_predictions = ned_end_predictions[~(end_labels == -100).all(2)]
|
283 |
|
284 |
else: # compute spans
|
|
|
308 |
if ned_end_logits is not None:
|
309 |
ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
|
310 |
if not self.config.binary_end_logits:
|
311 |
+
ned_end_predictions = torch.argmax(
|
312 |
+
ned_end_probabilities, dim=-1, keepdim=True
|
313 |
+
)
|
314 |
+
ned_end_predictions = torch.zeros_like(
|
315 |
+
ned_end_probabilities
|
316 |
+
).scatter_(1, ned_end_predictions, 1)
|
317 |
else:
|
318 |
ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1)
|
319 |
else:
|
320 |
ned_end_logits, ned_end_probabilities = None, None
|
321 |
+
ned_end_predictions = ned_start_predictions.new_zeros(
|
322 |
+
batch_size, seq_len
|
323 |
+
)
|
324 |
+
|
325 |
if not self.training:
|
326 |
# if len(ned_end_predictions.shape) < 2:
|
327 |
# print(ned_end_predictions)
|
|
|
348 |
if (end_position > 0).sum() > 0:
|
349 |
ends_count = (end_position > 0).sum(1)
|
350 |
model_entity_start = torch.repeat_interleave(
|
351 |
+
model_features[start_position > 0], ends_count, dim=0
|
352 |
+
)
|
353 |
model_entity_end = torch.repeat_interleave(
|
354 |
+
model_features, start_counts, dim=0
|
355 |
+
)[end_position > 0]
|
|
|
356 |
ents_count = torch.nn.utils.rnn.pad_sequence(
|
357 |
torch.split(ends_count, start_counts.tolist()),
|
358 |
batch_first=True,
|
|
|
382 |
ed_predictions = torch.argmax(ed_probabilities, dim=-1)
|
383 |
else:
|
384 |
ed_logits, ed_probabilities, ed_predictions = (
|
385 |
+
None,
|
386 |
ned_start_predictions.new_zeros(batch_size, seq_len),
|
387 |
ned_start_predictions.new_zeros(batch_size),
|
388 |
)
|
|
|
432 |
end_labels.view(-1),
|
433 |
)
|
434 |
else:
|
435 |
+
ned_end_loss = self.criterion(
|
436 |
+
ned_end_logits.reshape(-1, ned_end_logits.shape[-1]),
|
437 |
+
end_labels.reshape(-1).long(),
|
438 |
+
)
|
439 |
+
|
440 |
# entity disambiguation loss
|
441 |
ed_loss = self.criterion(
|
442 |
ed_logits.view(-1, ed_logits.shape[-1]),
|
|
|
839 |
start_counts = (start_position > 0).sum(1)
|
840 |
if (start_counts > 0).any():
|
841 |
ned_end_predictions = ned_end_predictions.split(start_counts.tolist())
|
842 |
+
else:
|
843 |
+
ned_end_predictions = [torch.empty(0, input_ids.shape[1], dtype=torch.int64) for _ in range(batch_size)]
|
844 |
# limit to 30 predictions per document using start_counts, by setting all po after sum is 30 to 0
|
845 |
# if is_validation or is_prediction:
|
846 |
# ned_start_predictions[ned_start_predictions == 1] = start_counts
|
|
|
1004 |
output_dict["ned_end_loss"] = ned_end_loss
|
1005 |
output_dict["re_loss"] = relation_loss
|
1006 |
|
1007 |
+
return output_dict
|