riccorl commited on
Commit
52a9eb9
1 Parent(s): 92bc11b

Automatic push from sapienzanlp

Browse files
added_tokens.json ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "--NME--": 128001,
3
+ "[E-0]": 128002,
4
+ "[E-10]": 128012,
5
+ "[E-11]": 128013,
6
+ "[E-12]": 128014,
7
+ "[E-13]": 128015,
8
+ "[E-14]": 128016,
9
+ "[E-15]": 128017,
10
+ "[E-16]": 128018,
11
+ "[E-17]": 128019,
12
+ "[E-18]": 128020,
13
+ "[E-19]": 128021,
14
+ "[E-1]": 128003,
15
+ "[E-20]": 128022,
16
+ "[E-21]": 128023,
17
+ "[E-22]": 128024,
18
+ "[E-23]": 128025,
19
+ "[E-24]": 128026,
20
+ "[E-25]": 128027,
21
+ "[E-26]": 128028,
22
+ "[E-27]": 128029,
23
+ "[E-28]": 128030,
24
+ "[E-29]": 128031,
25
+ "[E-2]": 128004,
26
+ "[E-30]": 128032,
27
+ "[E-31]": 128033,
28
+ "[E-32]": 128034,
29
+ "[E-33]": 128035,
30
+ "[E-34]": 128036,
31
+ "[E-35]": 128037,
32
+ "[E-36]": 128038,
33
+ "[E-37]": 128039,
34
+ "[E-38]": 128040,
35
+ "[E-39]": 128041,
36
+ "[E-3]": 128005,
37
+ "[E-40]": 128042,
38
+ "[E-41]": 128043,
39
+ "[E-42]": 128044,
40
+ "[E-43]": 128045,
41
+ "[E-44]": 128046,
42
+ "[E-45]": 128047,
43
+ "[E-46]": 128048,
44
+ "[E-47]": 128049,
45
+ "[E-48]": 128050,
46
+ "[E-49]": 128051,
47
+ "[E-4]": 128006,
48
+ "[E-50]": 128052,
49
+ "[E-51]": 128053,
50
+ "[E-52]": 128054,
51
+ "[E-53]": 128055,
52
+ "[E-54]": 128056,
53
+ "[E-55]": 128057,
54
+ "[E-56]": 128058,
55
+ "[E-57]": 128059,
56
+ "[E-58]": 128060,
57
+ "[E-59]": 128061,
58
+ "[E-5]": 128007,
59
+ "[E-60]": 128062,
60
+ "[E-61]": 128063,
61
+ "[E-62]": 128064,
62
+ "[E-63]": 128065,
63
+ "[E-64]": 128066,
64
+ "[E-65]": 128067,
65
+ "[E-66]": 128068,
66
+ "[E-67]": 128069,
67
+ "[E-68]": 128070,
68
+ "[E-69]": 128071,
69
+ "[E-6]": 128008,
70
+ "[E-70]": 128072,
71
+ "[E-71]": 128073,
72
+ "[E-72]": 128074,
73
+ "[E-73]": 128075,
74
+ "[E-74]": 128076,
75
+ "[E-75]": 128077,
76
+ "[E-76]": 128078,
77
+ "[E-77]": 128079,
78
+ "[E-78]": 128080,
79
+ "[E-79]": 128081,
80
+ "[E-7]": 128009,
81
+ "[E-80]": 128082,
82
+ "[E-81]": 128083,
83
+ "[E-82]": 128084,
84
+ "[E-83]": 128085,
85
+ "[E-84]": 128086,
86
+ "[E-85]": 128087,
87
+ "[E-86]": 128088,
88
+ "[E-87]": 128089,
89
+ "[E-88]": 128090,
90
+ "[E-89]": 128091,
91
+ "[E-8]": 128010,
92
+ "[E-90]": 128092,
93
+ "[E-91]": 128093,
94
+ "[E-92]": 128094,
95
+ "[E-93]": 128095,
96
+ "[E-94]": 128096,
97
+ "[E-95]": 128097,
98
+ "[E-96]": 128098,
99
+ "[E-97]": 128099,
100
+ "[E-98]": 128100,
101
+ "[E-99]": 128101,
102
+ "[E-9]": 128011,
103
+ "[MASK]": 128000
104
+ }
config.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation": "gelu",
3
+ "add_entity_embedding": null,
4
+ "additional_special_symbols": 101,
5
+ "additional_special_symbols_types": 0,
6
+ "architectures": [
7
+ "RelikReaderSpanModel"
8
+ ],
9
+ "auto_map": {
10
+ "AutoModel": "modeling_relik.RelikReaderSpanModel"
11
+ },
12
+ "default_reader_class": null,
13
+ "entity_type_loss": false,
14
+ "linears_hidden_size": 512,
15
+ "model_type": "relik-reader",
16
+ "num_layers": null,
17
+ "torch_dtype": "float32",
18
+ "training": true,
19
+ "transformer_model": "microsoft/deberta-v3-large",
20
+ "transformers_version": "4.33.3",
21
+ "use_last_k_layers": 1
22
+ }
configuration_relik.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ from transformers import AutoConfig
4
+ from transformers.configuration_utils import PretrainedConfig
5
+
6
+
7
+ class RelikReaderConfig(PretrainedConfig):
8
+ model_type = "relik-reader"
9
+
10
+ def __init__(
11
+ self,
12
+ transformer_model: str = "microsoft/deberta-v3-base",
13
+ additional_special_symbols: int = 101,
14
+ additional_special_symbols_types: Optional[int] = 0,
15
+ num_layers: Optional[int] = None,
16
+ activation: str = "gelu",
17
+ linears_hidden_size: Optional[int] = 512,
18
+ use_last_k_layers: int = 1,
19
+ entity_type_loss: bool = False,
20
+ add_entity_embedding: bool = None,
21
+ training: bool = False,
22
+ default_reader_class: Optional[str] = None,
23
+ **kwargs
24
+ ) -> None:
25
+ # TODO: add name_or_path to kwargs
26
+ self.transformer_model = transformer_model
27
+ self.additional_special_symbols = additional_special_symbols
28
+ self.additional_special_symbols_types = additional_special_symbols_types
29
+ self.num_layers = num_layers
30
+ self.activation = activation
31
+ self.linears_hidden_size = linears_hidden_size
32
+ self.use_last_k_layers = use_last_k_layers
33
+ self.entity_type_loss = entity_type_loss
34
+ self.add_entity_embedding = (
35
+ True
36
+ if add_entity_embedding is None and entity_type_loss
37
+ else add_entity_embedding
38
+ )
39
+ self.training = training
40
+ self.default_reader_class = default_reader_class
41
+ super().__init__(**kwargs)
42
+
43
+
44
+ AutoConfig.register("relik-reader", RelikReaderConfig)
modeling_relik.py ADDED
@@ -0,0 +1,1003 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Optional
2
+
3
+ import torch
4
+ from transformers import AutoModel, PreTrainedModel
5
+ from transformers.activations import ClippedGELUActivation, GELUActivation
6
+ from transformers.configuration_utils import PretrainedConfig
7
+ from transformers.modeling_utils import PoolerEndLogits
8
+
9
+ from .configuration_relik import RelikReaderConfig
10
+
11
+ torch.set_float32_matmul_precision("medium")
12
+
13
+ class RelikReaderSample:
14
+ def __init__(self, **kwargs):
15
+ super().__setattr__("_d", {})
16
+ self._d = kwargs
17
+
18
+ def __getattribute__(self, item):
19
+ return super(RelikReaderSample, self).__getattribute__(item)
20
+
21
+ def __getattr__(self, item):
22
+ if item.startswith("__") and item.endswith("__"):
23
+ # this is likely some python library-specific variable (such as __deepcopy__ for copy)
24
+ # better follow standard behavior here
25
+ raise AttributeError(item)
26
+ elif item in self._d:
27
+ return self._d[item]
28
+ else:
29
+ return None
30
+
31
+ def __setattr__(self, key, value):
32
+ if key in self._d:
33
+ self._d[key] = value
34
+ else:
35
+ super().__setattr__(key, value)
36
+
37
+
38
+ activation2functions = {
39
+ "relu": torch.nn.ReLU(),
40
+ "gelu": GELUActivation(),
41
+ "gelu_10": ClippedGELUActivation(-10, 10),
42
+ }
43
+
44
+
45
+ class PoolerEndLogitsBi(PoolerEndLogits):
46
+ def __init__(self, config: PretrainedConfig):
47
+ super().__init__(config)
48
+ self.dense_1 = torch.nn.Linear(config.hidden_size, 2)
49
+
50
+ def forward(
51
+ self,
52
+ hidden_states: torch.FloatTensor,
53
+ start_states: Optional[torch.FloatTensor] = None,
54
+ start_positions: Optional[torch.LongTensor] = None,
55
+ p_mask: Optional[torch.FloatTensor] = None,
56
+ ) -> torch.FloatTensor:
57
+ if p_mask is not None:
58
+ p_mask = p_mask.unsqueeze(-1)
59
+ logits = super().forward(
60
+ hidden_states,
61
+ start_states,
62
+ start_positions,
63
+ p_mask,
64
+ )
65
+ return logits
66
+
67
+
68
+ class RelikReaderSpanModel(PreTrainedModel):
69
+ config_class = RelikReaderConfig
70
+
71
+ def __init__(self, config: RelikReaderConfig, *args, **kwargs):
72
+ super().__init__(config)
73
+ # Transformer model declaration
74
+ self.config = config
75
+ self.transformer_model = (
76
+ AutoModel.from_pretrained(self.config.transformer_model)
77
+ if self.config.num_layers is None
78
+ else AutoModel.from_pretrained(
79
+ self.config.transformer_model, num_hidden_layers=self.config.num_layers
80
+ )
81
+ )
82
+ self.transformer_model.resize_token_embeddings(
83
+ self.transformer_model.config.vocab_size
84
+ + self.config.additional_special_symbols,
85
+ pad_to_multiple_of=8,
86
+ )
87
+
88
+ self.activation = self.config.activation
89
+ self.linears_hidden_size = self.config.linears_hidden_size
90
+ self.use_last_k_layers = self.config.use_last_k_layers
91
+
92
+ # named entity detection layers
93
+ self.ned_start_classifier = self._get_projection_layer(
94
+ self.activation, last_hidden=2, layer_norm=False
95
+ )
96
+ self.ned_end_classifier = PoolerEndLogits(self.transformer_model.config)
97
+
98
+ # END entity disambiguation layer
99
+ self.ed_projector = self._get_projection_layer(self.activation, last_hidden = 2*self.linears_hidden_size, hidden=2*self.linears_hidden_size)
100
+
101
+ self.training = self.config.training
102
+
103
+ # criterion
104
+ self.criterion = torch.nn.CrossEntropyLoss()
105
+
106
+ def _get_projection_layer(
107
+ self,
108
+ activation: str,
109
+ last_hidden: Optional[int] = None,
110
+ hidden: Optional[int] = None,
111
+ input_hidden=None,
112
+ layer_norm: bool = True,
113
+ ) -> torch.nn.Sequential:
114
+ head_components = [
115
+ torch.nn.Dropout(0.1),
116
+ torch.nn.Linear(
117
+ (
118
+ self.transformer_model.config.hidden_size * self.use_last_k_layers
119
+ if input_hidden is None
120
+ else input_hidden
121
+ ),
122
+ self.linears_hidden_size if hidden is None else hidden,
123
+ ),
124
+ activation2functions[activation],
125
+ torch.nn.Dropout(0.1),
126
+ torch.nn.Linear(
127
+ self.linears_hidden_size if hidden is None else hidden,
128
+ self.linears_hidden_size if last_hidden is None else last_hidden,
129
+ ),
130
+ ]
131
+
132
+ if layer_norm:
133
+ head_components.append(
134
+ torch.nn.LayerNorm(
135
+ self.linears_hidden_size if last_hidden is None else last_hidden,
136
+ self.transformer_model.config.layer_norm_eps,
137
+ )
138
+ )
139
+
140
+ return torch.nn.Sequential(*head_components)
141
+
142
+ def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
143
+ mask = mask.unsqueeze(-1)
144
+ if next(self.parameters()).dtype == torch.float16:
145
+ logits = logits * (1 - mask) - 65500 * mask
146
+ else:
147
+ logits = logits * (1 - mask) - 1e30 * mask
148
+ return logits
149
+
150
+ def _get_model_features(
151
+ self,
152
+ input_ids: torch.Tensor,
153
+ attention_mask: torch.Tensor,
154
+ token_type_ids: Optional[torch.Tensor],
155
+ ):
156
+ model_input = {
157
+ "input_ids": input_ids,
158
+ "attention_mask": attention_mask,
159
+ "output_hidden_states": self.use_last_k_layers > 1,
160
+ }
161
+
162
+ if token_type_ids is not None:
163
+ model_input["token_type_ids"] = token_type_ids
164
+
165
+ model_output = self.transformer_model(**model_input)
166
+
167
+ if self.use_last_k_layers > 1:
168
+ model_features = torch.cat(
169
+ model_output[1][-self.use_last_k_layers :], dim=-1
170
+ )
171
+ else:
172
+ model_features = model_output[0]
173
+
174
+ return model_features
175
+
176
+ def compute_ned_end_logits(
177
+ self,
178
+ start_predictions,
179
+ start_labels,
180
+ model_features,
181
+ prediction_mask,
182
+ batch_size,
183
+ ) -> Optional[torch.Tensor]:
184
+ # todo: maybe when constraining on the spans,
185
+ # we should not use a prediction_mask for the end tokens.
186
+ # at least we should not during training imo
187
+ start_positions = start_labels if self.training else start_predictions
188
+ start_positions_indices = (
189
+ torch.arange(start_positions.size(1), device=start_positions.device)
190
+ .unsqueeze(0)
191
+ .expand(batch_size, -1)[start_positions > 0]
192
+ ).to(start_positions.device)
193
+
194
+ if len(start_positions_indices) > 0:
195
+ expanded_features = model_features.repeat_interleave(
196
+ torch.sum(start_positions > 0, dim=-1), dim=0
197
+ )
198
+ expanded_prediction_mask = prediction_mask.repeat_interleave(
199
+ torch.sum(start_positions > 0, dim=-1), dim=0
200
+ )
201
+ end_logits = self.ned_end_classifier(
202
+ hidden_states=expanded_features,
203
+ start_positions=start_positions_indices,
204
+ p_mask=expanded_prediction_mask,
205
+ )
206
+
207
+ return end_logits
208
+
209
+ return None
210
+
211
+ def compute_classification_logits(
212
+ self,
213
+ model_features,
214
+ special_symbols_mask,
215
+ prediction_mask,
216
+ batch_size,
217
+ start_positions=None,
218
+ end_positions=None,
219
+ ) -> torch.Tensor:
220
+ if start_positions is None or end_positions is None:
221
+ start_positions = torch.zeros_like(prediction_mask)
222
+ end_positions = torch.zeros_like(prediction_mask)
223
+
224
+
225
+ model_ed_features = self.ed_projector(model_features)
226
+
227
+ model_ed_features[start_positions > 0][:, model_ed_features.shape[-1] // 2:] = model_ed_features[end_positions > 0][
228
+ :, :model_ed_features.shape[-1] // 2
229
+ ]
230
+
231
+ # computing ed features
232
+ classes_representations = torch.sum(special_symbols_mask, dim=1)[0].item()
233
+ special_symbols_representation = model_ed_features[special_symbols_mask].view(
234
+ batch_size, classes_representations, -1
235
+ )
236
+
237
+ logits = torch.bmm(
238
+ model_ed_features,
239
+ torch.permute(special_symbols_representation, (0, 2, 1)),
240
+ )
241
+
242
+ logits = self._mask_logits(logits, prediction_mask)
243
+
244
+ return logits
245
+
246
+ def forward(
247
+ self,
248
+ input_ids: torch.Tensor,
249
+ attention_mask: torch.Tensor,
250
+ token_type_ids: Optional[torch.Tensor] = None,
251
+ prediction_mask: Optional[torch.Tensor] = None,
252
+ special_symbols_mask: Optional[torch.Tensor] = None,
253
+ start_labels: Optional[torch.Tensor] = None,
254
+ end_labels: Optional[torch.Tensor] = None,
255
+ use_predefined_spans: bool = False,
256
+ *args,
257
+ **kwargs,
258
+ ) -> Dict[str, Any]:
259
+ batch_size, seq_len = input_ids.shape
260
+
261
+ model_features = self._get_model_features(
262
+ input_ids, attention_mask, token_type_ids
263
+ )
264
+
265
+ ned_start_labels = None
266
+
267
+ # named entity detection if required
268
+ if use_predefined_spans: # no need to compute spans
269
+ ned_start_logits, ned_start_probabilities, ned_start_predictions = (
270
+ None,
271
+ None,
272
+ (
273
+ torch.clone(start_labels)
274
+ if start_labels is not None
275
+ else torch.zeros_like(input_ids)
276
+ ),
277
+ )
278
+ ned_end_logits, ned_end_probabilities, ned_end_predictions = (
279
+ None,
280
+ None,
281
+ (
282
+ torch.clone(end_labels)
283
+ if end_labels is not None
284
+ else torch.zeros_like(input_ids)
285
+ ),
286
+ )
287
+
288
+ ned_start_predictions[ned_start_predictions > 0] = 1
289
+ ned_end_predictions[ned_end_predictions > 0] = 1
290
+
291
+ else: # compute spans
292
+ # start boundary prediction
293
+ ned_start_logits = self.ned_start_classifier(model_features)
294
+ ned_start_logits = self._mask_logits(ned_start_logits, prediction_mask)
295
+ ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1)
296
+ ned_start_predictions = ned_start_probabilities.argmax(dim=-1)
297
+
298
+ # end boundary prediction
299
+ ned_start_labels = (
300
+ torch.zeros_like(start_labels) if start_labels is not None else None
301
+ )
302
+
303
+ if ned_start_labels is not None:
304
+ ned_start_labels[start_labels == -100] = -100
305
+ ned_start_labels[start_labels > 0] = 1
306
+
307
+ ned_end_logits = self.compute_ned_end_logits(
308
+ ned_start_predictions,
309
+ ned_start_labels,
310
+ model_features,
311
+ prediction_mask,
312
+ batch_size,
313
+ )
314
+
315
+ if ned_end_logits is not None:
316
+ ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
317
+ ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1)
318
+ else:
319
+ ned_end_logits, ned_end_probabilities = None, None
320
+ ned_end_predictions = ned_start_predictions.new_zeros(batch_size)
321
+
322
+ # flattening end predictions
323
+ # (flattening can happen only if the
324
+ # end boundaries were not predicted using the gold labels)
325
+ if not self.training and ned_end_logits is not None:
326
+ flattened_end_predictions = torch.zeros_like(ned_start_predictions)
327
+
328
+ row_indices, start_positions = torch.where(ned_start_predictions > 0)
329
+ ned_end_predictions[ned_end_predictions < start_positions] = (
330
+ start_positions[ned_end_predictions < start_positions]
331
+ )
332
+
333
+ end_spans_repeated = (row_indices + 1) * seq_len + ned_end_predictions
334
+ cummax_values, _ = end_spans_repeated.cummax(dim=0)
335
+
336
+ end_spans_repeated = end_spans_repeated > torch.cat(
337
+ (end_spans_repeated[:1], cummax_values[:-1])
338
+ )
339
+ end_spans_repeated[0] = True
340
+
341
+ ned_start_predictions[
342
+ row_indices[~end_spans_repeated],
343
+ start_positions[~end_spans_repeated],
344
+ ] = 0
345
+
346
+ row_indices, start_positions, ned_end_predictions = (
347
+ row_indices[end_spans_repeated],
348
+ start_positions[end_spans_repeated],
349
+ ned_end_predictions[end_spans_repeated],
350
+ )
351
+
352
+ flattened_end_predictions[row_indices, ned_end_predictions] = 1
353
+
354
+ total_start_predictions, total_end_predictions = (
355
+ ned_start_predictions.sum(),
356
+ flattened_end_predictions.sum(),
357
+ )
358
+
359
+ assert (
360
+ total_start_predictions == 0
361
+ or total_start_predictions == total_end_predictions
362
+ ), (
363
+ f"Total number of start predictions = {total_start_predictions}. "
364
+ f"Total number of end predictions = {total_end_predictions}"
365
+ )
366
+ ned_end_predictions = flattened_end_predictions
367
+ else:
368
+ ned_end_predictions = torch.zeros_like(ned_start_predictions)
369
+
370
+ start_position, end_position = (
371
+ (start_labels, end_labels)
372
+ if self.training
373
+ else (ned_start_predictions, ned_end_predictions)
374
+ )
375
+
376
+ # Entity disambiguation
377
+ ed_logits = self.compute_classification_logits(
378
+ model_features,
379
+ special_symbols_mask,
380
+ prediction_mask,
381
+ batch_size,
382
+ start_position,
383
+ end_position,
384
+ )
385
+ ed_probabilities = torch.softmax(ed_logits, dim=-1)
386
+ ed_predictions = torch.argmax(ed_probabilities, dim=-1)
387
+
388
+ # output build
389
+ output_dict = dict(
390
+ batch_size=batch_size,
391
+ ned_start_logits=ned_start_logits,
392
+ ned_start_probabilities=ned_start_probabilities,
393
+ ned_start_predictions=ned_start_predictions,
394
+ ned_end_logits=ned_end_logits,
395
+ ned_end_probabilities=ned_end_probabilities,
396
+ ned_end_predictions=ned_end_predictions,
397
+ ed_logits=ed_logits,
398
+ ed_probabilities=ed_probabilities,
399
+ ed_predictions=ed_predictions,
400
+ )
401
+
402
+ # compute loss if labels
403
+ if start_labels is not None and end_labels is not None and self.training:
404
+ # named entity detection loss
405
+
406
+ # start
407
+ if ned_start_logits is not None:
408
+ ned_start_loss = self.criterion(
409
+ ned_start_logits.view(-1, ned_start_logits.shape[-1]),
410
+ ned_start_labels.view(-1),
411
+ )
412
+ else:
413
+ ned_start_loss = 0
414
+
415
+ # end
416
+ if ned_end_logits is not None:
417
+ ned_end_labels = torch.zeros_like(end_labels)
418
+ ned_end_labels[end_labels == -100] = -100
419
+ ned_end_labels[end_labels > 0] = 1
420
+
421
+ ned_end_loss = self.criterion(
422
+ ned_end_logits,
423
+ (
424
+ torch.arange(
425
+ ned_end_labels.size(1), device=ned_end_labels.device
426
+ )
427
+ .unsqueeze(0)
428
+ .expand(batch_size, -1)[ned_end_labels > 0]
429
+ ).to(ned_end_labels.device),
430
+ )
431
+
432
+ else:
433
+ ned_end_loss = 0
434
+
435
+ # entity disambiguation loss
436
+ start_labels[ned_start_labels != 1] = -100
437
+ ed_labels = torch.clone(start_labels)
438
+ ed_labels[end_labels > 0] = end_labels[end_labels > 0]
439
+ ed_loss = self.criterion(
440
+ ed_logits.view(-1, ed_logits.shape[-1]),
441
+ ed_labels.view(-1),
442
+ )
443
+
444
+ output_dict["ned_start_loss"] = ned_start_loss
445
+ output_dict["ned_end_loss"] = ned_end_loss
446
+ output_dict["ed_loss"] = ed_loss
447
+
448
+ output_dict["loss"] = ned_start_loss + ned_end_loss + ed_loss
449
+
450
+ return output_dict
451
+
452
+
453
+ class RelikReaderREModel(PreTrainedModel):
454
+ config_class = RelikReaderConfig
455
+
456
+ def __init__(self, config, *args, **kwargs):
457
+ super().__init__(config)
458
+ # Transformer model declaration
459
+ # self.transformer_model_name = transformer_model
460
+ self.config = config
461
+ self.transformer_model = (
462
+ AutoModel.from_pretrained(config.transformer_model)
463
+ if config.num_layers is None
464
+ else AutoModel.from_pretrained(
465
+ config.transformer_model, num_hidden_layers=config.num_layers
466
+ )
467
+ )
468
+ self.transformer_model.resize_token_embeddings(
469
+ self.transformer_model.config.vocab_size
470
+ + config.additional_special_symbols
471
+ + config.additional_special_symbols_types,
472
+ pad_to_multiple_of=8,
473
+ )
474
+
475
+ # named entity detection layers
476
+ self.ned_start_classifier = self._get_projection_layer(
477
+ config.activation, last_hidden=2, layer_norm=False
478
+ )
479
+
480
+ self.ned_end_classifier = PoolerEndLogitsBi(self.transformer_model.config)
481
+
482
+ self.relation_disambiguation_loss = (
483
+ config.relation_disambiguation_loss
484
+ if hasattr(config, "relation_disambiguation_loss")
485
+ else False
486
+ )
487
+
488
+ if self.config.entity_type_loss and self.config.add_entity_embedding:
489
+ input_hidden_ents = 3 * self.transformer_model.config.hidden_size
490
+ else:
491
+ input_hidden_ents = 2 * self.transformer_model.config.hidden_size
492
+
493
+ self.re_subject_projector = self._get_projection_layer(
494
+ config.activation, input_hidden=input_hidden_ents
495
+ )
496
+ self.re_object_projector = self._get_projection_layer(
497
+ config.activation, input_hidden=input_hidden_ents
498
+ )
499
+ self.re_relation_projector = self._get_projection_layer(config.activation)
500
+
501
+ if self.config.entity_type_loss or self.relation_disambiguation_loss:
502
+ self.re_entities_projector = self._get_projection_layer(
503
+ config.activation,
504
+ input_hidden=2 * self.transformer_model.config.hidden_size,
505
+ )
506
+ self.re_definition_projector = self._get_projection_layer(
507
+ config.activation,
508
+ )
509
+
510
+ self.re_classifier = self._get_projection_layer(
511
+ config.activation,
512
+ input_hidden=config.linears_hidden_size,
513
+ last_hidden=2,
514
+ layer_norm=False,
515
+ )
516
+
517
+ self.training = config.training
518
+
519
+ # criterion
520
+ self.criterion = torch.nn.CrossEntropyLoss()
521
+ self.criterion_type = torch.nn.BCEWithLogitsLoss()
522
+
523
+ def _get_projection_layer(
524
+ self,
525
+ activation: str,
526
+ last_hidden: Optional[int] = None,
527
+ input_hidden=None,
528
+ layer_norm: bool = True,
529
+ ) -> torch.nn.Sequential:
530
+ head_components = [
531
+ torch.nn.Dropout(0.1),
532
+ torch.nn.Linear(
533
+ (
534
+ self.transformer_model.config.hidden_size
535
+ * self.config.use_last_k_layers
536
+ if input_hidden is None
537
+ else input_hidden
538
+ ),
539
+ self.config.linears_hidden_size,
540
+ ),
541
+ activation2functions[activation],
542
+ torch.nn.Dropout(0.1),
543
+ torch.nn.Linear(
544
+ self.config.linears_hidden_size,
545
+ self.config.linears_hidden_size if last_hidden is None else last_hidden,
546
+ ),
547
+ ]
548
+
549
+ if layer_norm:
550
+ head_components.append(
551
+ torch.nn.LayerNorm(
552
+ (
553
+ self.config.linears_hidden_size
554
+ if last_hidden is None
555
+ else last_hidden
556
+ ),
557
+ self.transformer_model.config.layer_norm_eps,
558
+ )
559
+ )
560
+
561
+ return torch.nn.Sequential(*head_components)
562
+
563
+ def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
564
+ mask = mask.unsqueeze(-1)
565
+ if next(self.parameters()).dtype == torch.float16:
566
+ logits = logits * (1 - mask) - 65500 * mask
567
+ else:
568
+ logits = logits * (1 - mask) - 1e30 * mask
569
+ return logits
570
+
571
+ def _get_model_features(
572
+ self,
573
+ input_ids: torch.Tensor,
574
+ attention_mask: torch.Tensor,
575
+ token_type_ids: Optional[torch.Tensor],
576
+ ):
577
+ model_input = {
578
+ "input_ids": input_ids,
579
+ "attention_mask": attention_mask,
580
+ "output_hidden_states": self.config.use_last_k_layers > 1,
581
+ }
582
+
583
+ if token_type_ids is not None:
584
+ model_input["token_type_ids"] = token_type_ids
585
+
586
+ model_output = self.transformer_model(**model_input)
587
+
588
+ if self.config.use_last_k_layers > 1:
589
+ model_features = torch.cat(
590
+ model_output[1][-self.config.use_last_k_layers :], dim=-1
591
+ )
592
+ else:
593
+ model_features = model_output[0]
594
+
595
+ return model_features
596
+
597
+ def compute_ned_end_logits(
598
+ self,
599
+ start_predictions,
600
+ start_labels,
601
+ model_features,
602
+ prediction_mask,
603
+ batch_size,
604
+ mask_preceding: bool = False,
605
+ ) -> Optional[torch.Tensor]:
606
+ # todo: maybe when constraining on the spans,
607
+ # we should not use a prediction_mask for the end tokens.
608
+ # at least we should not during training imo
609
+ start_positions = start_labels if self.training else start_predictions
610
+ start_positions_indices = (
611
+ torch.arange(start_positions.size(1), device=start_positions.device)
612
+ .unsqueeze(0)
613
+ .expand(batch_size, -1)[start_positions > 0]
614
+ ).to(start_positions.device)
615
+
616
+ if len(start_positions_indices) > 0:
617
+ expanded_features = model_features.repeat_interleave(
618
+ torch.sum(start_positions > 0, dim=-1), dim=0
619
+ )
620
+ expanded_prediction_mask = prediction_mask.repeat_interleave(
621
+ torch.sum(start_positions > 0, dim=-1), dim=0
622
+ )
623
+ if mask_preceding:
624
+ expanded_prediction_mask[
625
+ torch.arange(
626
+ expanded_prediction_mask.shape[1],
627
+ device=expanded_prediction_mask.device,
628
+ )
629
+ < start_positions_indices.unsqueeze(1)
630
+ ] = 1
631
+ end_logits = self.ned_end_classifier(
632
+ hidden_states=expanded_features,
633
+ start_positions=start_positions_indices,
634
+ p_mask=expanded_prediction_mask,
635
+ )
636
+
637
+ return end_logits
638
+
639
+ return None
640
+
641
+ def compute_relation_logits(
642
+ self,
643
+ model_entity_features,
644
+ special_symbols_features,
645
+ ) -> torch.Tensor:
646
+ model_subject_features = self.re_subject_projector(model_entity_features)
647
+ model_object_features = self.re_object_projector(model_entity_features)
648
+ special_symbols_start_representation = self.re_relation_projector(
649
+ special_symbols_features
650
+ )
651
+ re_logits = torch.einsum(
652
+ "bse,bde,bfe->bsdfe",
653
+ model_subject_features,
654
+ model_object_features,
655
+ special_symbols_start_representation,
656
+ )
657
+ re_logits = self.re_classifier(re_logits)
658
+
659
+ return re_logits
660
+
661
+ def compute_entity_logits(
662
+ self,
663
+ model_entity_features,
664
+ special_symbols_features,
665
+ ) -> torch.Tensor:
666
+ model_ed_features = self.re_entities_projector(model_entity_features)
667
+ special_symbols_ed_representation = self.re_definition_projector(
668
+ special_symbols_features
669
+ )
670
+
671
+ logits = torch.bmm(
672
+ model_ed_features,
673
+ torch.permute(special_symbols_ed_representation, (0, 2, 1)),
674
+ )
675
+ logits = self._mask_logits(
676
+ logits, (model_entity_features == -100).all(2).long()
677
+ )
678
+ return logits
679
+
680
+ def compute_loss(self, logits, labels, mask=None):
681
+ logits = logits.reshape(-1, logits.shape[-1])
682
+ labels = labels.reshape(-1).long()
683
+ if mask is not None:
684
+ return self.criterion(logits[mask], labels[mask])
685
+ return self.criterion(logits, labels)
686
+
687
+ def compute_ned_type_loss(
688
+ self,
689
+ disambiguation_labels,
690
+ re_ned_entities_logits,
691
+ ned_type_logits,
692
+ re_entities_logits,
693
+ entity_types,
694
+ mask,
695
+ ):
696
+ if self.config.entity_type_loss and self.relation_disambiguation_loss:
697
+ return self.criterion_type(
698
+ re_ned_entities_logits[disambiguation_labels != -100],
699
+ disambiguation_labels[disambiguation_labels != -100],
700
+ )
701
+ if self.config.entity_type_loss:
702
+ return self.criterion_type(
703
+ ned_type_logits[mask],
704
+ disambiguation_labels[:, :, :entity_types][mask],
705
+ )
706
+
707
+ if self.relation_disambiguation_loss:
708
+ return self.criterion_type(
709
+ re_entities_logits[disambiguation_labels != -100],
710
+ disambiguation_labels[disambiguation_labels != -100],
711
+ )
712
+ return 0
713
+
714
+ def compute_relation_loss(self, relation_labels, re_logits):
715
+ return self.compute_loss(
716
+ re_logits, relation_labels, relation_labels.view(-1) != -100
717
+ )
718
+
719
+ def forward(
720
+ self,
721
+ input_ids: torch.Tensor,
722
+ attention_mask: torch.Tensor,
723
+ token_type_ids: torch.Tensor,
724
+ prediction_mask: Optional[torch.Tensor] = None,
725
+ special_symbols_mask: Optional[torch.Tensor] = None,
726
+ special_symbols_mask_entities: Optional[torch.Tensor] = None,
727
+ start_labels: Optional[torch.Tensor] = None,
728
+ end_labels: Optional[torch.Tensor] = None,
729
+ disambiguation_labels: Optional[torch.Tensor] = None,
730
+ relation_labels: Optional[torch.Tensor] = None,
731
+ relation_threshold: float = 0.5,
732
+ is_validation: bool = False,
733
+ is_prediction: bool = False,
734
+ use_predefined_spans: bool = False,
735
+ *args,
736
+ **kwargs,
737
+ ) -> Dict[str, Any]:
738
+ batch_size = input_ids.shape[0]
739
+
740
+ model_features = self._get_model_features(
741
+ input_ids, attention_mask, token_type_ids
742
+ )
743
+
744
+ # named entity detection
745
+ if use_predefined_spans:
746
+ ned_start_logits, ned_start_probabilities, ned_start_predictions = (
747
+ None,
748
+ None,
749
+ torch.zeros_like(start_labels),
750
+ )
751
+ ned_end_logits, ned_end_probabilities, ned_end_predictions = (
752
+ None,
753
+ None,
754
+ torch.zeros_like(end_labels),
755
+ )
756
+
757
+ ned_start_predictions[start_labels > 0] = 1
758
+ ned_end_predictions[end_labels > 0] = 1
759
+ ned_end_predictions = ned_end_predictions[~(end_labels == -100).all(2)]
760
+ ned_start_labels = start_labels
761
+ ned_start_labels[start_labels > 0] = 1
762
+ else:
763
+ # start boundary prediction
764
+ ned_start_logits = self.ned_start_classifier(model_features)
765
+ if is_validation or is_prediction:
766
+ ned_start_logits = self._mask_logits(
767
+ ned_start_logits, prediction_mask
768
+ ) # why?
769
+ ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1)
770
+ ned_start_predictions = ned_start_probabilities.argmax(dim=-1)
771
+
772
+ # end boundary prediction
773
+ ned_start_labels = (
774
+ torch.zeros_like(start_labels) if start_labels is not None else None
775
+ )
776
+
777
+ # start_labels contain entity id at their position, we just need 1 for start of entity
778
+ if ned_start_labels is not None:
779
+ ned_start_labels[start_labels == -100] = -100
780
+ ned_start_labels[start_labels > 0] = 1
781
+
782
+ # compute end logits only if there are any start predictions.
783
+ # For each start prediction, n end predictions are made
784
+ ned_end_logits = self.compute_ned_end_logits(
785
+ ned_start_predictions,
786
+ ned_start_labels,
787
+ model_features,
788
+ prediction_mask,
789
+ batch_size,
790
+ True,
791
+ )
792
+
793
+ if ned_end_logits is not None:
794
+ # For each start prediction, n end predictions are made based on
795
+ # binary classification ie. argmax at each position.
796
+ ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
797
+ ned_end_predictions = ned_end_probabilities.argmax(dim=-1)
798
+ else:
799
+ ned_end_logits, ned_end_probabilities = None, None
800
+ ned_end_predictions = torch.zeros_like(ned_start_predictions)
801
+
802
+ if is_prediction or is_validation:
803
+ end_preds_count = ned_end_predictions.sum(1)
804
+ # If there are no end predictions for a start prediction, remove the start prediction
805
+ if (end_preds_count == 0).any() and (ned_start_predictions > 0).any():
806
+ ned_start_predictions[ned_start_predictions == 1] = (
807
+ end_preds_count != 0
808
+ ).long()
809
+ ned_end_predictions = ned_end_predictions[end_preds_count != 0]
810
+
811
+ if end_labels is not None:
812
+ end_labels = end_labels[~(end_labels == -100).all(2)]
813
+
814
+ start_position, end_position = (
815
+ (start_labels, end_labels)
816
+ if (not is_prediction and not is_validation)
817
+ else (ned_start_predictions, ned_end_predictions)
818
+ )
819
+
820
+ start_counts = (start_position > 0).sum(1)
821
+ if (start_counts > 0).any():
822
+ ned_end_predictions = ned_end_predictions.split(start_counts.tolist())
823
+ # limit to 30 predictions per document using start_counts, by setting all po after sum is 30 to 0
824
+ # if is_validation or is_prediction:
825
+ # ned_start_predictions[ned_start_predictions == 1] = start_counts
826
+ # We can only predict relations if we have start and end predictions
827
+ if (end_position > 0).sum() > 0:
828
+ ends_count = (end_position > 0).sum(1)
829
+ model_subject_features = torch.cat(
830
+ [
831
+ torch.repeat_interleave(
832
+ model_features[start_position > 0], ends_count, dim=0
833
+ ), # start position features
834
+ torch.repeat_interleave(model_features, start_counts, dim=0)[
835
+ end_position > 0
836
+ ], # end position features
837
+ ],
838
+ dim=-1,
839
+ )
840
+ ents_count = torch.nn.utils.rnn.pad_sequence(
841
+ torch.split(ends_count, start_counts.tolist()),
842
+ batch_first=True,
843
+ padding_value=0,
844
+ ).sum(1)
845
+ model_subject_features = torch.nn.utils.rnn.pad_sequence(
846
+ torch.split(model_subject_features, ents_count.tolist()),
847
+ batch_first=True,
848
+ padding_value=-100,
849
+ )
850
+
851
+ # if is_validation or is_prediction:
852
+ # model_subject_features = model_subject_features[:, :30, :]
853
+
854
+ # entity disambiguation. Here relation_disambiguation_loss would only be useful to
855
+ # reduce the number of candidate relations for the next step, but currently unused.
856
+ if self.config.entity_type_loss or self.relation_disambiguation_loss:
857
+ (re_ned_entities_logits) = self.compute_entity_logits(
858
+ model_subject_features,
859
+ model_features[
860
+ special_symbols_mask | special_symbols_mask_entities
861
+ ].view(batch_size, -1, model_features.shape[-1]),
862
+ )
863
+ entity_types = torch.sum(special_symbols_mask_entities, dim=1)[0].item()
864
+ ned_type_logits = re_ned_entities_logits[:, :, :entity_types]
865
+ re_entities_logits = re_ned_entities_logits[:, :, entity_types:]
866
+
867
+ if self.config.entity_type_loss:
868
+ ned_type_probabilities = torch.sigmoid(ned_type_logits)
869
+ ned_type_predictions = ned_type_probabilities.argmax(dim=-1)
870
+
871
+ if self.config.add_entity_embedding:
872
+ special_symbols_representation = model_features[
873
+ special_symbols_mask_entities
874
+ ].view(batch_size, entity_types, -1)
875
+
876
+ entities_representation = torch.einsum(
877
+ "bsp,bpe->bse",
878
+ ned_type_probabilities,
879
+ special_symbols_representation,
880
+ )
881
+ model_subject_features = torch.cat(
882
+ [model_subject_features, entities_representation], dim=-1
883
+ )
884
+ re_entities_probabilities = torch.sigmoid(re_entities_logits)
885
+ re_entities_predictions = re_entities_probabilities.round()
886
+ else:
887
+ (
888
+ ned_type_logits,
889
+ ned_type_probabilities,
890
+ re_entities_logits,
891
+ re_entities_probabilities,
892
+ ) = (None, None, None, None)
893
+ ned_type_predictions, re_entities_predictions = (
894
+ torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
895
+ torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
896
+ )
897
+
898
+ # Compute relation logits
899
+ re_logits = self.compute_relation_logits(
900
+ model_subject_features,
901
+ model_features[special_symbols_mask].view(
902
+ batch_size, -1, model_features.shape[-1]
903
+ ),
904
+ )
905
+
906
+ re_probabilities = torch.softmax(re_logits, dim=-1)
907
+ # we set a thresshold instead of argmax in cause it needs to be tweaked
908
+ re_predictions = re_probabilities[:, :, :, :, 1] > relation_threshold
909
+ # re_predictions = re_probabilities.argmax(dim=-1)
910
+ re_probabilities = re_probabilities[:, :, :, :, 1]
911
+ # re_logits, re_probabilities, re_predictions = (
912
+ # torch.zeros(
913
+ # [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
914
+ # ).to(input_ids.device),
915
+ # torch.zeros(
916
+ # [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
917
+ # ).to(input_ids.device),
918
+ # torch.zeros(
919
+ # [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
920
+ # ).to(input_ids.device),
921
+ # )
922
+
923
+ else:
924
+ (
925
+ ned_type_logits,
926
+ ned_type_probabilities,
927
+ re_entities_logits,
928
+ re_entities_probabilities,
929
+ ) = (None, None, None, None)
930
+ ned_type_predictions, re_entities_predictions = (
931
+ torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
932
+ torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
933
+ )
934
+ re_logits, re_probabilities, re_predictions = (
935
+ torch.zeros(
936
+ [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
937
+ ).to(input_ids.device),
938
+ torch.zeros(
939
+ [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
940
+ ).to(input_ids.device),
941
+ torch.zeros(
942
+ [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
943
+ ).to(input_ids.device),
944
+ )
945
+
946
+ # output build
947
+ output_dict = dict(
948
+ batch_size=batch_size,
949
+ ned_start_logits=ned_start_logits,
950
+ ned_start_probabilities=ned_start_probabilities,
951
+ ned_start_predictions=ned_start_predictions,
952
+ ned_end_logits=ned_end_logits,
953
+ ned_end_probabilities=ned_end_probabilities,
954
+ ned_end_predictions=ned_end_predictions,
955
+ ned_type_logits=ned_type_logits,
956
+ ned_type_probabilities=ned_type_probabilities,
957
+ ned_type_predictions=ned_type_predictions,
958
+ re_entities_logits=re_entities_logits,
959
+ re_entities_probabilities=re_entities_probabilities,
960
+ re_entities_predictions=re_entities_predictions,
961
+ re_logits=re_logits,
962
+ re_probabilities=re_probabilities,
963
+ re_predictions=re_predictions,
964
+ )
965
+
966
+ if (
967
+ start_labels is not None
968
+ and end_labels is not None
969
+ and relation_labels is not None
970
+ and is_prediction is False
971
+ ):
972
+ ned_start_loss = self.compute_loss(ned_start_logits, ned_start_labels)
973
+ end_labels[end_labels > 0] = 1
974
+ ned_end_loss = self.compute_loss(ned_end_logits, end_labels)
975
+ if self.config.entity_type_loss or self.relation_disambiguation_loss:
976
+ ned_type_loss = self.compute_ned_type_loss(
977
+ disambiguation_labels,
978
+ re_ned_entities_logits,
979
+ ned_type_logits,
980
+ re_entities_logits,
981
+ entity_types,
982
+ (model_subject_features != -100).all(2),
983
+ )
984
+ relation_loss = self.compute_relation_loss(relation_labels, re_logits)
985
+ # compute loss. We can skip the relation loss if we are in the first epochs (optional)
986
+ if self.config.entity_type_loss or self.relation_disambiguation_loss:
987
+ output_dict["loss"] = (
988
+ ned_start_loss + ned_end_loss + relation_loss + ned_type_loss
989
+ ) / 4
990
+ output_dict["ned_type_loss"] = ned_type_loss
991
+ else:
992
+ # output_dict["loss"] = ((1 / 4) * (ned_start_loss + ned_end_loss)) + (
993
+ # (1 / 2) * relation_loss
994
+ # )
995
+ output_dict["loss"] = ((1 / 16) * (ned_start_loss + ned_end_loss)) + (
996
+ (7 / 8) * relation_loss
997
+ )
998
+
999
+ output_dict["ned_start_loss"] = ned_start_loss
1000
+ output_dict["ned_end_loss"] = ned_end_loss
1001
+ output_dict["re_loss"] = relation_loss
1002
+
1003
+ return output_dict
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2a1fc57f4bab61364d5b4e68ab6f9d435f3157c28b20b54c97647cf855c0cd24
3
+ size 1755549154
special_tokens_map.json ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "--NME--",
4
+ "[E-0]",
5
+ "[E-1]",
6
+ "[E-2]",
7
+ "[E-3]",
8
+ "[E-4]",
9
+ "[E-5]",
10
+ "[E-6]",
11
+ "[E-7]",
12
+ "[E-8]",
13
+ "[E-9]",
14
+ "[E-10]",
15
+ "[E-11]",
16
+ "[E-12]",
17
+ "[E-13]",
18
+ "[E-14]",
19
+ "[E-15]",
20
+ "[E-16]",
21
+ "[E-17]",
22
+ "[E-18]",
23
+ "[E-19]",
24
+ "[E-20]",
25
+ "[E-21]",
26
+ "[E-22]",
27
+ "[E-23]",
28
+ "[E-24]",
29
+ "[E-25]",
30
+ "[E-26]",
31
+ "[E-27]",
32
+ "[E-28]",
33
+ "[E-29]",
34
+ "[E-30]",
35
+ "[E-31]",
36
+ "[E-32]",
37
+ "[E-33]",
38
+ "[E-34]",
39
+ "[E-35]",
40
+ "[E-36]",
41
+ "[E-37]",
42
+ "[E-38]",
43
+ "[E-39]",
44
+ "[E-40]",
45
+ "[E-41]",
46
+ "[E-42]",
47
+ "[E-43]",
48
+ "[E-44]",
49
+ "[E-45]",
50
+ "[E-46]",
51
+ "[E-47]",
52
+ "[E-48]",
53
+ "[E-49]",
54
+ "[E-50]",
55
+ "[E-51]",
56
+ "[E-52]",
57
+ "[E-53]",
58
+ "[E-54]",
59
+ "[E-55]",
60
+ "[E-56]",
61
+ "[E-57]",
62
+ "[E-58]",
63
+ "[E-59]",
64
+ "[E-60]",
65
+ "[E-61]",
66
+ "[E-62]",
67
+ "[E-63]",
68
+ "[E-64]",
69
+ "[E-65]",
70
+ "[E-66]",
71
+ "[E-67]",
72
+ "[E-68]",
73
+ "[E-69]",
74
+ "[E-70]",
75
+ "[E-71]",
76
+ "[E-72]",
77
+ "[E-73]",
78
+ "[E-74]",
79
+ "[E-75]",
80
+ "[E-76]",
81
+ "[E-77]",
82
+ "[E-78]",
83
+ "[E-79]",
84
+ "[E-80]",
85
+ "[E-81]",
86
+ "[E-82]",
87
+ "[E-83]",
88
+ "[E-84]",
89
+ "[E-85]",
90
+ "[E-86]",
91
+ "[E-87]",
92
+ "[E-88]",
93
+ "[E-89]",
94
+ "[E-90]",
95
+ "[E-91]",
96
+ "[E-92]",
97
+ "[E-93]",
98
+ "[E-94]",
99
+ "[E-95]",
100
+ "[E-96]",
101
+ "[E-97]",
102
+ "[E-98]",
103
+ "[E-99]"
104
+ ],
105
+ "bos_token": "[CLS]",
106
+ "cls_token": "[CLS]",
107
+ "eos_token": "[SEP]",
108
+ "mask_token": "[MASK]",
109
+ "pad_token": "[PAD]",
110
+ "sep_token": "[SEP]",
111
+ "unk_token": "[UNK]"
112
+ }
spm.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c679fbf93643d19aab7ee10c0b99e460bdbc02fedf34b92b05af343b4af586fd
3
+ size 2464616
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": true,
3
+ "additional_special_tokens": [
4
+ "--NME--",
5
+ "[E-0]",
6
+ "[E-1]",
7
+ "[E-2]",
8
+ "[E-3]",
9
+ "[E-4]",
10
+ "[E-5]",
11
+ "[E-6]",
12
+ "[E-7]",
13
+ "[E-8]",
14
+ "[E-9]",
15
+ "[E-10]",
16
+ "[E-11]",
17
+ "[E-12]",
18
+ "[E-13]",
19
+ "[E-14]",
20
+ "[E-15]",
21
+ "[E-16]",
22
+ "[E-17]",
23
+ "[E-18]",
24
+ "[E-19]",
25
+ "[E-20]",
26
+ "[E-21]",
27
+ "[E-22]",
28
+ "[E-23]",
29
+ "[E-24]",
30
+ "[E-25]",
31
+ "[E-26]",
32
+ "[E-27]",
33
+ "[E-28]",
34
+ "[E-29]",
35
+ "[E-30]",
36
+ "[E-31]",
37
+ "[E-32]",
38
+ "[E-33]",
39
+ "[E-34]",
40
+ "[E-35]",
41
+ "[E-36]",
42
+ "[E-37]",
43
+ "[E-38]",
44
+ "[E-39]",
45
+ "[E-40]",
46
+ "[E-41]",
47
+ "[E-42]",
48
+ "[E-43]",
49
+ "[E-44]",
50
+ "[E-45]",
51
+ "[E-46]",
52
+ "[E-47]",
53
+ "[E-48]",
54
+ "[E-49]",
55
+ "[E-50]",
56
+ "[E-51]",
57
+ "[E-52]",
58
+ "[E-53]",
59
+ "[E-54]",
60
+ "[E-55]",
61
+ "[E-56]",
62
+ "[E-57]",
63
+ "[E-58]",
64
+ "[E-59]",
65
+ "[E-60]",
66
+ "[E-61]",
67
+ "[E-62]",
68
+ "[E-63]",
69
+ "[E-64]",
70
+ "[E-65]",
71
+ "[E-66]",
72
+ "[E-67]",
73
+ "[E-68]",
74
+ "[E-69]",
75
+ "[E-70]",
76
+ "[E-71]",
77
+ "[E-72]",
78
+ "[E-73]",
79
+ "[E-74]",
80
+ "[E-75]",
81
+ "[E-76]",
82
+ "[E-77]",
83
+ "[E-78]",
84
+ "[E-79]",
85
+ "[E-80]",
86
+ "[E-81]",
87
+ "[E-82]",
88
+ "[E-83]",
89
+ "[E-84]",
90
+ "[E-85]",
91
+ "[E-86]",
92
+ "[E-87]",
93
+ "[E-88]",
94
+ "[E-89]",
95
+ "[E-90]",
96
+ "[E-91]",
97
+ "[E-92]",
98
+ "[E-93]",
99
+ "[E-94]",
100
+ "[E-95]",
101
+ "[E-96]",
102
+ "[E-97]",
103
+ "[E-98]",
104
+ "[E-99]"
105
+ ],
106
+ "bos_token": "[CLS]",
107
+ "clean_up_tokenization_spaces": true,
108
+ "cls_token": "[CLS]",
109
+ "do_lower_case": false,
110
+ "eos_token": "[SEP]",
111
+ "mask_token": "[MASK]",
112
+ "model_max_length": 1000000000000000019884624838656,
113
+ "pad_token": "[PAD]",
114
+ "sep_token": "[SEP]",
115
+ "sp_model_kwargs": {},
116
+ "split_by_punct": false,
117
+ "tokenizer_class": "DebertaV2Tokenizer",
118
+ "unk_token": "[UNK]",
119
+ "vocab_type": "spm"
120
+ }