dfki-nlp committed
Commit 386fb69
1 Parent(s): 7b69b0e

Upload transformer_re_text_classification2.py

transformer_re_text_classification2.py ADDED
@@ -0,0 +1,574 @@
"""
workflow:
    Document
        -> (InputEncoding, TargetEncoding) -> TaskEncoding -> TaskBatchEncoding
        -> ModelBatchEncoding -> ModelBatchOutput
        -> TaskOutput
        -> Document
"""

import logging
from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, TypedDict, Union

import numpy as np
import torch
from pytorch_ie.annotations import BinaryRelation, LabeledSpan, MultiLabeledBinaryRelation, Span
from pytorch_ie.core import TaskEncoding, TaskModule
from pytorch_ie.documents import TextDocument
from pytorch_ie.models import (
    TransformerTextClassificationModelBatchOutput,
    TransformerTextClassificationModelStepBatchEncoding,
)
from pytorch_ie.utils.span import get_token_slice, is_contained_in
from pytorch_ie.utils.window import get_window_around_slice
from transformers import AutoTokenizer
from transformers.file_utils import PaddingStrategy
from transformers.tokenization_utils_base import BatchEncoding, TruncationStrategy
from typing_extensions import TypeAlias

TransformerReTextClassificationInputEncoding2: TypeAlias = Dict[str, Any]
TransformerReTextClassificationTargetEncoding2: TypeAlias = Sequence[int]

TransformerReTextClassificationTaskEncoding2: TypeAlias = TaskEncoding[
    TextDocument,
    TransformerReTextClassificationInputEncoding2,
    TransformerReTextClassificationTargetEncoding2,
]


class TransformerReTextClassificationTaskOutput2(TypedDict, total=False):
    labels: Sequence[str]
    probabilities: Sequence[float]


_TransformerReTextClassificationTaskModule2: TypeAlias = TaskModule[
    # _InputEncoding, _TargetEncoding, _TaskBatchEncoding, _ModelBatchOutput, _TaskOutput
    TextDocument,
    TransformerReTextClassificationInputEncoding2,
    TransformerReTextClassificationTargetEncoding2,
    TransformerTextClassificationModelStepBatchEncoding,
    TransformerTextClassificationModelBatchOutput,
    TransformerReTextClassificationTaskOutput2,
]


HEAD = "head"
TAIL = "tail"
START = "start"
END = "end"


logger = logging.getLogger(__name__)


class RelationArgument:
    def __init__(
        self,
        entity: LabeledSpan,
        role: str,
        offsets: Tuple[int, int],
        add_type_to_marker: bool,
    ) -> None:
        self.entity = entity
        self.role = role
        assert self.role in (HEAD, TAIL)
        self.offsets = offsets
        self.add_type_to_marker = add_type_to_marker

    @property
    def is_head(self) -> bool:
        return self.role == HEAD

    @property
    def is_tail(self) -> bool:
        return self.role == TAIL

    @property
    def as_start_marker(self) -> str:
        return self._get_marker(is_start=True)

    @property
    def as_end_marker(self) -> str:
        return self._get_marker(is_start=False)

    def _get_marker(self, is_start: bool = True) -> str:
        return f"[{'' if is_start else '/'}{'H' if self.is_head else 'T'}" + (
            f":{self.entity.label}]" if self.add_type_to_marker else "]"
        )

    @property
    def as_append_marker(self) -> str:
        return f"[{'H' if self.is_head else 'T'}={self.entity.label}]"


def _enumerate_entity_pairs(
    entities: Sequence[Span],
    partition: Optional[Span] = None,
    relations: Optional[Sequence[BinaryRelation]] = None,
):
    """Given a list of `entities`, iterate over all valid pairs of entities, including inverted
    pairs.

    If a `partition` is provided, restrict the pairs to those contained in it. If `relations` are
    given, return only pairs for which a predefined relation exists (e.g. for relation
    classification on the train/val/test splits of supervised datasets).
    """
    existing_head_tail = {(relation.head, relation.tail) for relation in relations or []}
    for head in entities:
        if partition is not None and not is_contained_in(
            (head.start, head.end), (partition.start, partition.end)
        ):
            continue

        for tail in entities:
            if partition is not None and not is_contained_in(
                (tail.start, tail.end), (partition.start, partition.end)
            ):
                continue

            if head == tail:
                continue

            if relations is not None and (head, tail) not in existing_head_tail:
                continue

            yield head, tail
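
# Illustration of the generator above: for entities [e1, e2, e3] (hypothetical) and relations=None,
# it yields the ordered pairs (e1, e2), (e1, e3), (e2, e1), (e2, e3), (e3, e1), (e3, e2); if
# `relations` contains only a relation with head=e1 and tail=e2, only (e1, e2) is yielded.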


@TaskModule.register()
class TransformerRETextClassificationTaskModule2(_TransformerReTextClassificationTaskModule2):
    """Marker-based relation extraction. This taskmodule prepares the input token ids in such a
    way that special marker tokens are inserted before and after the candidate head and tail
    entities. The modified token ids can then simply be passed into a transformer-based text
    classification model.

    parameters:

        partition_annotation: str, optional. If specified, LabeledSpan annotations with this name
            are expected to define partitions of the document that will be processed individually,
            e.g. sentences or sections of the document text.
        none_label: str, defaults to "no_relation". The relation label that indicates
            dummy/negative relations. Predicted relations with that label will not be added to the
            document(s).
        max_window: int, optional. If specified, use only the tokens within a window of at most
            this many tokens around the center of the head and tail entities and pass only those
            to the transformer.
    """

    PREPARED_ATTRIBUTES = ["label_to_id", "entity_labels"]

    def __init__(
        self,
        tokenizer_name_or_path: str,
        entity_annotation: str = "entities",
        relation_annotation: str = "relations",
        partition_annotation: Optional[str] = None,
        none_label: str = "no_relation",
        padding: Union[bool, str, PaddingStrategy] = True,
        truncation: Union[bool, str, TruncationStrategy] = True,
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        multi_label: bool = False,
        label_to_id: Optional[Dict[str, int]] = None,
        add_type_to_marker: bool = False,
        single_argument_pair: bool = True,
        append_markers: bool = False,
        entity_labels: Optional[List[str]] = None,
        max_window: Optional[int] = None,
        log_first_n_examples: Optional[int] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.save_hyperparameters()

        self.entity_annotation = entity_annotation
        self.relation_annotation = relation_annotation
        self.padding = padding
        self.truncation = truncation
        self.label_to_id = label_to_id or {}
        self.id_to_label = {v: k for k, v in self.label_to_id.items()}
        self.max_length = max_length
        self.pad_to_multiple_of = pad_to_multiple_of
        self.multi_label = multi_label
        self.add_type_to_marker = add_type_to_marker
        self.single_argument_pair = single_argument_pair
        self.append_markers = append_markers
        self.entity_labels = entity_labels
        self.partition_annotation = partition_annotation
        self.none_label = none_label
        self.max_window = max_window
        self.log_first_n_examples = log_first_n_examples

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)

        self.argument_markers = None

        self._logged_examples_counter = 0

    def _prepare(self, documents: Sequence[TextDocument]) -> None:
        entity_labels: Set[str] = set()
        relation_labels: Set[str] = set()
        for document in documents:
            entities: Sequence[LabeledSpan] = document[self.entity_annotation]
            relations: Sequence[BinaryRelation] = document[self.relation_annotation]

            for entity in entities:
                entity_labels.add(entity.label)

            for relation in relations:
                relation_labels.add(relation.label)

        if self.none_label in relation_labels:
            relation_labels.remove(self.none_label)

        self.label_to_id = {label: i + 1 for i, label in enumerate(sorted(relation_labels))}
        self.label_to_id[self.none_label] = 0

        self.entity_labels = sorted(entity_labels)

    def _post_prepare(self):
        self.argument_markers = self._initialize_argument_markers()
        self.tokenizer.add_tokens(self.argument_markers, special_tokens=True)

        self.argument_markers_to_id = {
            marker: self.tokenizer.vocab[marker] for marker in self.argument_markers
        }
        self.sep_token_id = self.tokenizer.vocab[self.tokenizer.sep_token]

        self.id_to_label = {v: k for k, v in self.label_to_id.items()}

    def _initialize_argument_markers(self) -> List[str]:
        argument_markers: Set[str] = set()
        for arg_type in [HEAD, TAIL]:
            for arg_pos in [START, END]:
                is_head = arg_type == HEAD
                is_start = arg_pos == START
                argument_markers.add(f"[{'' if is_start else '/'}{'H' if is_head else 'T'}]")
                if self.add_type_to_marker:
                    for entity_type in self.entity_labels:  # type: ignore
                        argument_markers.add(
                            f"[{'' if is_start else '/'}{'H' if is_head else 'T'}"
                            f"{':' + entity_type if self.add_type_to_marker else ''}]"
                        )
                if self.append_markers:
                    for entity_type in self.entity_labels:  # type: ignore
                        argument_markers.add(f"[{'H' if is_head else 'T'}={entity_type}]")

        return sorted(list(argument_markers))
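
    # For example, with entity_labels == ["LOC", "PER"] (hypothetical), add_type_to_marker=True
    # and append_markers=True, the method above produces the markers [H], [/H], [T], [/T],
    # [H:LOC], [/H:LOC], [H:PER], [/H:PER], [T:LOC], [/T:LOC], [T:PER], [/T:PER],
    # [H=LOC], [H=PER], [T=LOC] and [T=PER], which _post_prepare() adds to the tokenizer
    # vocabulary as special tokens.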

    def _encode_text(
        self,
        document: TextDocument,
        partition: Optional[Span] = None,
        add_special_tokens: bool = True,
    ) -> BatchEncoding:
        text = (
            document.text[partition.start : partition.end]
            if partition is not None
            else document.text
        )
        encoding = self.tokenizer(
            text,
            padding=False,
            truncation=self.truncation,
            max_length=self.max_length,
            is_split_into_words=False,
            return_offsets_mapping=False,
            add_special_tokens=add_special_tokens,
        )
        return encoding

    def encode_input(
        self,
        document: TextDocument,
        is_training: bool = False,
    ) -> Optional[
        Union[
            TransformerReTextClassificationTaskEncoding2,
            Sequence[TransformerReTextClassificationTaskEncoding2],
        ]
    ]:
        assert (
            self.argument_markers is not None
        ), "No argument markers available, was `prepare` already called?"

        entities: Sequence[Span] = document[self.entity_annotation]

        relations: Sequence[BinaryRelation] = document[self.relation_annotation]

        partitions: Sequence[Optional[Span]]
        if self.partition_annotation is not None:
            partitions = document[self.partition_annotation]
        else:
            # use single dummy partition
            partitions = [None]

        task_encodings: List[TransformerReTextClassificationTaskEncoding2] = []
        for partition_idx, partition in enumerate(partitions):
            partition_offset = 0 if partition is None else partition.start
            add_special_tokens = self.max_window is None
            encoding = self._encode_text(
                document=document, partition=partition, add_special_tokens=add_special_tokens
            )

            for head, tail in _enumerate_entity_pairs(
                entities=entities,
                partition=partition,
                relations=relations,
            ):
                head_token_slice = get_token_slice(
                    character_slice=(head.start, head.end),
                    char_to_token_mapper=encoding.char_to_token,
                    character_offset=partition_offset,
                )
                tail_token_slice = get_token_slice(
                    character_slice=(tail.start, tail.end),
                    char_to_token_mapper=encoding.char_to_token,
                    character_offset=partition_offset,
                )
                # this happens if the head/tail start/end does not match a token start/end
                if head_token_slice is None or tail_token_slice is None:
                    # if statistics is not None:
                    #     statistics["entity_token_alignment_error"][
                    #         relation_mapping.get((head, tail), "TO_PREDICT")
                    #     ] += 1
                    logger.warning(
                        f"Skipping invalid example {document.id}, cannot get token slice(s)"
                    )
                    continue

                input_ids = encoding["input_ids"]
                # not sure if this is the correct way to get the tokens corresponding to the input_ids
                tokens = encoding.encodings[0].tokens

                # windowing
                if self.max_window is not None:
                    head_start, head_end = head_token_slice
                    tail_start, tail_end = tail_token_slice
                    # The actual number of tokens will be lower than max_window because we add the
                    # 4 marker tokens (before/after the head/tail) and the default special tokens
                    # (e.g. CLS and SEP).
                    num_added_special_tokens = len(
                        self.tokenizer.build_inputs_with_special_tokens([])
                    )
                    max_tokens = self.max_window - 4 - num_added_special_tokens
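                    # Illustration (hypothetical numbers): with max_window=128 and a tokenizer
                    # that adds two special tokens (e.g. [CLS] and [SEP]),
                    # max_tokens = 128 - 4 - 2 = 122.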
                    # the slice from the beginning of the first entity to the end of the second is required
                    slice_required = (min(head_start, tail_start), max(head_end, tail_end))
                    window_slice = get_window_around_slice(
                        slice=slice_required,
                        max_window_size=max_tokens,
                        available_input_length=len(input_ids),
                    )
                    # this happens if slice_required does not fit into max_tokens
                    if window_slice is None:
                        # if statistics is not None:
                        #     statistics["out_of_token_window"][
                        #         relation_mapping.get((head, tail), "TO_PREDICT")
                        #     ] += 1
                        continue

                    window_start, window_end = window_slice
                    input_ids = input_ids[window_start:window_end]

                    head_token_slice = head_start - window_start, head_end - window_start
                    tail_token_slice = tail_start - window_start, tail_end - window_start

                # maybe expand to n-ary relations?
                head_arg = RelationArgument(head, HEAD, head_token_slice, self.add_type_to_marker)
                tail_arg = RelationArgument(tail, TAIL, tail_token_slice, self.add_type_to_marker)
                arg_list = [head_arg, tail_arg]

                if head_token_slice[0] < tail_token_slice[0]:
                    assert (
                        head_token_slice[1] <= tail_token_slice[0]
                    ), f"the head and tail entities are not allowed to overlap in {document.id}"

                else:
                    assert (
                        tail_token_slice[1] <= head_token_slice[0]
                    ), f"the head and tail entities are not allowed to overlap in {document.id}"
                    # expand to n-ary relations?
                    arg_list.reverse()

                first_arg_start_id = self.argument_markers_to_id[arg_list[0].as_start_marker]
                first_arg_end_id = self.argument_markers_to_id[arg_list[0].as_end_marker]
                second_arg_start_id = self.argument_markers_to_id[arg_list[1].as_start_marker]
                second_arg_end_id = self.argument_markers_to_id[arg_list[1].as_end_marker]

                new_input_ids = (
                    input_ids[: arg_list[0].offsets[0]]
                    + [first_arg_start_id]
                    + input_ids[arg_list[0].offsets[0] : arg_list[0].offsets[1]]
                    + [first_arg_end_id]
                    + input_ids[arg_list[0].offsets[1] : arg_list[1].offsets[0]]
                    + [second_arg_start_id]
                    + input_ids[arg_list[1].offsets[0] : arg_list[1].offsets[1]]
                    + [second_arg_end_id]
                    + input_ids[arg_list[1].offsets[1] :]
                )
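                # Schematically, new_input_ids is:
                #   [tokens before arg1] [arg1 start marker] [arg1 tokens] [arg1 end marker]
                #   [tokens between]     [arg2 start marker] [arg2 tokens] [arg2 end marker]
                #   [tokens after]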

                if self.append_markers:
                    new_input_ids.extend(
                        [
                            self.argument_markers_to_id[head_arg.as_append_marker],
                            self.sep_token_id,
                            self.argument_markers_to_id[tail_arg.as_append_marker],
                            self.sep_token_id,
                        ]
                    )

                # when windowing is used, we have to add the special tokens manually
                if not add_special_tokens:
                    new_input_ids = self.tokenizer.build_inputs_with_special_tokens(
                        token_ids_0=new_input_ids
                    )

                # lots of logging from here on
                log_this_example = (
                    self.log_first_n_examples is not None
                    and self._logged_examples_counter < self.log_first_n_examples
                )
                if log_this_example:
                    self._log_example(document, arg_list, new_input_ids, relations, tokens)

                task_encodings.append(
                    TaskEncoding(
                        document=document,
                        inputs={"input_ids": new_input_ids},
                        metadata={
                            HEAD: head,
                            TAIL: tail,
                        },
                    )
                )

        return task_encodings

    def _log_example(
        self,
        document: TextDocument,
        arg_list: List[RelationArgument],
        input_ids: List[int],
        relations: Sequence[BinaryRelation],
        tokens: List[str],
    ):
        first_arg_start = arg_list[0].as_start_marker
        first_arg_end = arg_list[0].as_end_marker
        second_arg_start = arg_list[1].as_start_marker
        second_arg_end = arg_list[1].as_end_marker
        new_tokens = (
            tokens[: arg_list[0].offsets[0]]
            + [first_arg_start]
            + tokens[arg_list[0].offsets[0] : arg_list[0].offsets[1]]
            + [first_arg_end]
            + tokens[arg_list[0].offsets[1] : arg_list[1].offsets[0]]
            + [second_arg_start]
            + tokens[arg_list[1].offsets[0] : arg_list[1].offsets[1]]
            + [second_arg_end]
            + tokens[arg_list[1].offsets[1] :]
        )

        head_idx = 0 if arg_list[0].role == HEAD else 1
        tail_idx = 0 if arg_list[0].role == TAIL else 1

        if self.append_markers:
            head_marker = arg_list[head_idx].as_append_marker
            tail_marker = arg_list[tail_idx].as_append_marker
            new_tokens.extend(
                [head_marker, self.tokenizer.sep_token, tail_marker, self.tokenizer.sep_token]
            )
        logger.info("*** Example ***")
        logger.info("doc id: %s", document.id)
        logger.info("tokens: %s", " ".join([str(x) for x in new_tokens]))
        logger.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
        rel_labels = [relation.label for relation in relations]
        rel_label_ids = [self.label_to_id[label] for label in rel_labels]
        logger.info("Expected labels: %s (ids = %s)", rel_labels, rel_label_ids)

        self._logged_examples_counter += 1

    def encode_target(
        self,
        task_encoding: TransformerReTextClassificationTaskEncoding2,
    ) -> TransformerReTextClassificationTargetEncoding2:
        metadata = task_encoding.metadata
        document = task_encoding.document

        relations: Sequence[BinaryRelation] = document[self.relation_annotation]

        head_tail_to_labels = {
            (relation.head, relation.tail): [relation.label] for relation in relations
        }

        labels = head_tail_to_labels.get((metadata[HEAD], metadata[TAIL]), [self.none_label])
        target = [self.label_to_id[label] for label in labels]

        return target

    def unbatch_output(
        self, model_output: TransformerTextClassificationModelBatchOutput
    ) -> Sequence[TransformerReTextClassificationTaskOutput2]:
        logits = model_output["logits"]
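        # logits are expected to have shape (batch_size, num_labels); softmax (or sigmoid in the
        # multi-label case) turns them into per-label probabilities.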

        output_label_probs = logits.sigmoid() if self.multi_label else logits.softmax(dim=-1)
        output_label_probs = output_label_probs.detach().cpu().numpy()

        unbatched_output = []
        if self.multi_label:
            raise NotImplementedError
        else:
            label_ids = np.argmax(output_label_probs, axis=-1)
            for batch_idx, label_id in enumerate(label_ids):
                label = self.id_to_label[label_id]
                prob = float(output_label_probs[batch_idx, label_id])
                result: TransformerReTextClassificationTaskOutput2 = {
                    "labels": [label],
                    "probabilities": [prob],
                }
                unbatched_output.append(result)

        return unbatched_output

    def create_annotations_from_output(
        self,
        task_encoding: TransformerReTextClassificationTaskEncoding2,
        task_output: TransformerReTextClassificationTaskOutput2,
    ) -> Iterator[Tuple[str, Union[BinaryRelation, MultiLabeledBinaryRelation]]]:
        labels = task_output["labels"]
        probabilities = task_output["probabilities"]
        if labels != [self.none_label]:
            yield (
                self.relation_annotation,
                BinaryRelation(
                    head=task_encoding.metadata[HEAD],
                    tail=task_encoding.metadata[TAIL],
                    label=labels[0],
                    score=probabilities[0],
                ),
            )

    def collate(
        self, task_encodings: Sequence[TransformerReTextClassificationTaskEncoding2]
    ) -> TransformerTextClassificationModelStepBatchEncoding:
        input_features = [task_encoding.inputs for task_encoding in task_encodings]

        inputs: Dict[str, torch.Tensor] = self.tokenizer.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
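        # tokenizer.pad pads the "input_ids" to a common length, adds the corresponding
        # "attention_mask", and returns PyTorch tensors because of return_tensors="pt".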

        if not task_encodings[0].has_targets:
            return inputs, None

        target_list: List[TransformerReTextClassificationTargetEncoding2] = [
            task_encoding.targets for task_encoding in task_encodings
        ]
        targets = torch.tensor(target_list, dtype=torch.int64)

        if not self.multi_label:
            targets = targets.flatten()

        return inputs, targets