PereLluis13 committed
Commit
8f5593f
1 Parent(s): e208b82

Upload model

Files changed (4)
  1. config.json +23 -0
  2. configuration_relik.py +44 -0
  3. modeling_relik.py +981 -0
  4. pytorch_model.bin +3 -0
config.json ADDED
@@ -0,0 +1,23 @@
+ {
+   "_name_or_path": "/leonardo_scratch/fast/IscrC_MEL/wandb/offline-run-20240701_170315-vwnjjdl0/files/files",
+   "activation": "gelu",
+   "add_entity_embedding": null,
+   "additional_special_symbols": 101,
+   "additional_special_symbols_types": 0,
+   "architectures": [
+     "RelikReaderSpanModel"
+   ],
+   "auto_map": {
+     "AutoModel": "modeling_relik.RelikReaderSpanModel"
+   },
+   "default_reader_class": null,
+   "entity_type_loss": false,
+   "linears_hidden_size": 512,
+   "model_type": "relik-reader",
+   "num_layers": null,
+   "torch_dtype": "float32",
+   "training": true,
+   "transformer_model": "microsoft/deberta-v3-large",
+   "transformers_version": "4.33.3",
+   "use_last_k_layers": 1
+ }
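
The `auto_map` entry above points `AutoModel` at the custom `RelikReaderSpanModel` defined in `modeling_relik.py`, so the checkpoint loads through the standard auto classes once remote code is trusted. A minimal loading sketch; the repository path below is a placeholder, not part of this commit:

from transformers import AutoModel

# "path/to/this/repo" is a placeholder for wherever this checkpoint is stored.
# trust_remote_code is needed because the architecture lives in modeling_relik.py.
model = AutoModel.from_pretrained("path/to/this/repo", trust_remote_code=True)
print(type(model).__name__)  # RelikReaderSpanModel
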
configuration_relik.py ADDED
@@ -0,0 +1,44 @@
+ from typing import Optional
+
+ from transformers import AutoConfig
+ from transformers.configuration_utils import PretrainedConfig
+
+
+ class RelikReaderConfig(PretrainedConfig):
+     model_type = "relik-reader"
+
+     def __init__(
+         self,
+         transformer_model: str = "microsoft/deberta-v3-base",
+         additional_special_symbols: int = 101,
+         additional_special_symbols_types: Optional[int] = 0,
+         num_layers: Optional[int] = None,
+         activation: str = "gelu",
+         linears_hidden_size: Optional[int] = 512,
+         use_last_k_layers: int = 1,
+         entity_type_loss: bool = False,
+         add_entity_embedding: Optional[bool] = None,
+         training: bool = False,
+         default_reader_class: Optional[str] = None,
+         **kwargs
+     ) -> None:
+         # TODO: add name_or_path to kwargs
+         self.transformer_model = transformer_model
+         self.additional_special_symbols = additional_special_symbols
+         self.additional_special_symbols_types = additional_special_symbols_types
+         self.num_layers = num_layers
+         self.activation = activation
+         self.linears_hidden_size = linears_hidden_size
+         self.use_last_k_layers = use_last_k_layers
+         self.entity_type_loss = entity_type_loss
+         self.add_entity_embedding = (
+             True
+             if add_entity_embedding is None and entity_type_loss
+             else add_entity_embedding
+         )
+         self.training = training
+         self.default_reader_class = default_reader_class
+         super().__init__(**kwargs)
+
+
+ AutoConfig.register("relik-reader", RelikReaderConfig)
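
As a quick sketch of the defaulting logic above (assuming `configuration_relik.py` is importable from the working directory): `add_entity_embedding` is only promoted to `True` when it is left unset and `entity_type_loss` is enabled, and importing the module registers the `relik-reader` model type with `AutoConfig`.

from configuration_relik import RelikReaderConfig

cfg = RelikReaderConfig(entity_type_loss=True)
print(cfg.add_entity_embedding)  # True, derived from entity_type_loss

cfg = RelikReaderConfig()
print(cfg.add_entity_embedding)  # None, both flags left at their defaults
print(cfg.model_type)            # relik-reader
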
modeling_relik.py ADDED
@@ -0,0 +1,981 @@
+ from typing import Any, Dict, Optional
+
+ import torch
+ from transformers import AutoModel, PreTrainedModel
+ from transformers.activations import ClippedGELUActivation, GELUActivation
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.modeling_utils import PoolerEndLogits
+
+ from .configuration_relik import RelikReaderConfig
+
+
+ class RelikReaderSample:
+     def __init__(self, **kwargs):
+         super().__setattr__("_d", {})
+         self._d = kwargs
+
+     def __getattribute__(self, item):
+         return super(RelikReaderSample, self).__getattribute__(item)
+
+     def __getattr__(self, item):
+         if item.startswith("__") and item.endswith("__"):
+             # this is likely some python library-specific variable (such as __deepcopy__ for copy)
+             # better follow standard behavior here
+             raise AttributeError(item)
+         elif item in self._d:
+             return self._d[item]
+         else:
+             return None
+
+     def __setattr__(self, key, value):
+         if key in self._d:
+             self._d[key] = value
+         else:
+             super().__setattr__(key, value)
+             self._d[key] = value
+
+
+ activation2functions = {
+     "relu": torch.nn.ReLU(),
+     "gelu": GELUActivation(),
+     "gelu_10": ClippedGELUActivation(-10, 10),
+ }
+
+
+ class PoolerEndLogitsBi(PoolerEndLogits):
+     def __init__(self, config: PretrainedConfig):
+         super().__init__(config)
+         self.dense_1 = torch.nn.Linear(config.hidden_size, 2)
+
+     def forward(
+         self,
+         hidden_states: torch.FloatTensor,
+         start_states: Optional[torch.FloatTensor] = None,
+         start_positions: Optional[torch.LongTensor] = None,
+         p_mask: Optional[torch.FloatTensor] = None,
+     ) -> torch.FloatTensor:
+         if p_mask is not None:
+             p_mask = p_mask.unsqueeze(-1)
+         logits = super().forward(
+             hidden_states,
+             start_states,
+             start_positions,
+             p_mask,
+         )
+         return logits
+
+
+ class RelikReaderSpanModel(PreTrainedModel):
+     config_class = RelikReaderConfig
+
+     def __init__(self, config: RelikReaderConfig, *args, **kwargs):
+         super().__init__(config)
+         # Transformer model declaration
+         self.config = config
+         self.transformer_model = (
+             AutoModel.from_pretrained(self.config.transformer_model)
+             if self.config.num_layers is None
+             else AutoModel.from_pretrained(
+                 self.config.transformer_model, num_hidden_layers=self.config.num_layers
+             )
+         )
+         self.transformer_model.resize_token_embeddings(
+             self.transformer_model.config.vocab_size
+             + self.config.additional_special_symbols
+         )
+
+         self.activation = self.config.activation
+         self.linears_hidden_size = self.config.linears_hidden_size
+         self.use_last_k_layers = self.config.use_last_k_layers
+
+         # named entity detection layers
+         self.ned_start_classifier = self._get_projection_layer(
+             self.activation, last_hidden=2, layer_norm=False
+         )
+         self.ned_end_classifier = PoolerEndLogits(self.transformer_model.config)
+
+         # END entity disambiguation layer
+         self.ed_start_projector = self._get_projection_layer(self.activation)
+         self.ed_end_projector = self._get_projection_layer(self.activation)
+
+         self.training = self.config.training
+
+         # criterion
+         self.criterion = torch.nn.CrossEntropyLoss()
+
+     def _get_projection_layer(
+         self,
+         activation: str,
+         last_hidden: Optional[int] = None,
+         input_hidden=None,
+         layer_norm: bool = True,
+     ) -> torch.nn.Sequential:
+         head_components = [
+             torch.nn.Dropout(0.1),
+             torch.nn.Linear(
+                 (
+                     self.transformer_model.config.hidden_size * self.use_last_k_layers
+                     if input_hidden is None
+                     else input_hidden
+                 ),
+                 self.linears_hidden_size,
+             ),
+             activation2functions[activation],
+             torch.nn.Dropout(0.1),
+             torch.nn.Linear(
+                 self.linears_hidden_size,
+                 self.linears_hidden_size if last_hidden is None else last_hidden,
+             ),
+         ]
+
+         if layer_norm:
+             head_components.append(
+                 torch.nn.LayerNorm(
+                     self.linears_hidden_size if last_hidden is None else last_hidden,
+                     self.transformer_model.config.layer_norm_eps,
+                 )
+             )
+
+         return torch.nn.Sequential(*head_components)
+
+     def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+         mask = mask.unsqueeze(-1)
+         if next(self.parameters()).dtype == torch.float16:
+             logits = logits * (1 - mask) - 65500 * mask
+         else:
+             logits = logits * (1 - mask) - 1e30 * mask
+         return logits
+
+     def _get_model_features(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: torch.Tensor,
+         token_type_ids: Optional[torch.Tensor],
+     ):
+         model_input = {
+             "input_ids": input_ids,
+             "attention_mask": attention_mask,
+             "output_hidden_states": self.use_last_k_layers > 1,
+         }
+
+         if token_type_ids is not None:
+             model_input["token_type_ids"] = token_type_ids
+
+         model_output = self.transformer_model(**model_input)
+
+         if self.use_last_k_layers > 1:
+             model_features = torch.cat(
+                 model_output[1][-self.use_last_k_layers :], dim=-1
+             )
+         else:
+             model_features = model_output[0]
+
+         return model_features
+
+     def compute_ned_end_logits(
+         self,
+         start_predictions,
+         start_labels,
+         model_features,
+         prediction_mask,
+         batch_size,
+     ) -> Optional[torch.Tensor]:
+         # todo: maybe when constraining on the spans,
+         # we should not use a prediction_mask for the end tokens.
+         # at least we should not during training imo
+         start_positions = start_labels if self.training else start_predictions
+         start_positions_indices = (
+             torch.arange(start_positions.size(1), device=start_positions.device)
+             .unsqueeze(0)
+             .expand(batch_size, -1)[start_positions > 0]
+         ).to(start_positions.device)
+
+         if len(start_positions_indices) > 0:
+             expanded_features = model_features.repeat_interleave(
+                 torch.sum(start_positions > 0, dim=-1), dim=0
+             )
+             expanded_prediction_mask = prediction_mask.repeat_interleave(
+                 torch.sum(start_positions > 0, dim=-1), dim=0
+             )
+             end_logits = self.ned_end_classifier(
+                 hidden_states=expanded_features,
+                 start_positions=start_positions_indices,
+                 p_mask=expanded_prediction_mask,
+             )
+
+             return end_logits
+
+         return None
+
+     def compute_classification_logits(
+         self,
+         model_features,
+         special_symbols_mask,
+         prediction_mask,
+         batch_size,
+         start_positions=None,
+         end_positions=None,
+     ) -> torch.Tensor:
+         if start_positions is None or end_positions is None:
+             start_positions = torch.zeros_like(prediction_mask)
+             end_positions = torch.zeros_like(prediction_mask)
+
+         model_start_features = self.ed_start_projector(model_features)
+         model_end_features = self.ed_end_projector(model_features)
+         model_end_features[start_positions > 0] = model_end_features[end_positions > 0]
+
+         model_ed_features = torch.cat(
+             [model_start_features, model_end_features], dim=-1
+         )
+
+         # computing ed features
+         classes_representations = torch.sum(special_symbols_mask, dim=1)[0].item()
+         special_symbols_representation = model_ed_features[special_symbols_mask].view(
+             batch_size, classes_representations, -1
+         )
+
+         logits = torch.bmm(
+             model_ed_features,
+             torch.permute(special_symbols_representation, (0, 2, 1)),
+         )
+
+         logits = self._mask_logits(logits, prediction_mask)
+
+         return logits
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: torch.Tensor,
+         token_type_ids: Optional[torch.Tensor] = None,
+         prediction_mask: Optional[torch.Tensor] = None,
+         special_symbols_mask: Optional[torch.Tensor] = None,
+         start_labels: Optional[torch.Tensor] = None,
+         end_labels: Optional[torch.Tensor] = None,
+         use_predefined_spans: bool = False,
+         *args,
+         **kwargs,
+     ) -> Dict[str, Any]:
+         batch_size, seq_len = input_ids.shape
+
+         model_features = self._get_model_features(
+             input_ids, attention_mask, token_type_ids
+         )
+
+         ned_start_labels = None
+
+         # named entity detection if required
+         if use_predefined_spans:  # no need to compute spans
+             ned_start_logits, ned_start_probabilities, ned_start_predictions = (
+                 None,
+                 None,
+                 (
+                     torch.clone(start_labels)
+                     if start_labels is not None
+                     else torch.zeros_like(input_ids)
+                 ),
+             )
+             ned_end_logits, ned_end_probabilities, ned_end_predictions = (
+                 None,
+                 None,
+                 (
+                     torch.clone(end_labels)
+                     if end_labels is not None
+                     else torch.zeros_like(input_ids)
+                 ),
+             )
+
+             ned_start_predictions[ned_start_predictions > 0] = 1
+             ned_end_predictions[ned_end_predictions > 0] = 1
+
+         else:  # compute spans
+             # start boundary prediction
+             ned_start_logits = self.ned_start_classifier(model_features)
+             ned_start_logits = self._mask_logits(ned_start_logits, prediction_mask)
+             ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1)
+             ned_start_predictions = ned_start_probabilities.argmax(dim=-1)
+
+             # end boundary prediction
+             ned_start_labels = (
+                 torch.zeros_like(start_labels) if start_labels is not None else None
+             )
+
+             if ned_start_labels is not None:
+                 ned_start_labels[start_labels == -100] = -100
+                 ned_start_labels[start_labels > 0] = 1
+
+             ned_end_logits = self.compute_ned_end_logits(
+                 ned_start_predictions,
+                 ned_start_labels,
+                 model_features,
+                 prediction_mask,
+                 batch_size,
+             )
+
+             if ned_end_logits is not None:
+                 ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
+                 ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1)
+             else:
+                 ned_end_logits, ned_end_probabilities = None, None
+                 ned_end_predictions = ned_start_predictions.new_zeros(batch_size)
+
+             # flattening end predictions
+             # (flattening can happen only if the
+             # end boundaries were not predicted using the gold labels)
+             if not self.training and ned_end_logits is not None:
+                 flattened_end_predictions = torch.zeros_like(ned_start_predictions)
+
+                 row_indices, start_positions = torch.where(ned_start_predictions > 0)
+                 ned_end_predictions[ned_end_predictions < start_positions] = start_positions[ned_end_predictions < start_positions]
+
+                 end_spans_repeated = (row_indices + 1) * seq_len + ned_end_predictions
+                 cummax_values, _ = end_spans_repeated.cummax(dim=0)
+
+                 end_spans_repeated = (end_spans_repeated > torch.cat((end_spans_repeated[:1], cummax_values[:-1])))
+                 end_spans_repeated[0] = True
+
+                 ned_start_predictions[row_indices[~end_spans_repeated], start_positions[~end_spans_repeated]] = 0
+
+                 row_indices, start_positions, ned_end_predictions = row_indices[end_spans_repeated], start_positions[end_spans_repeated], ned_end_predictions[end_spans_repeated]
+
+                 flattened_end_predictions[row_indices, ned_end_predictions] = 1
+
+                 total_start_predictions, total_end_predictions = ned_start_predictions.sum(), flattened_end_predictions.sum()
+
+                 assert (
+                     total_start_predictions == 0
+                     or total_start_predictions == total_end_predictions
+                 ), (
+                     f"Total number of start predictions = {total_start_predictions}. "
+                     f"Total number of end predictions = {total_end_predictions}"
+                 )
+                 ned_end_predictions = flattened_end_predictions
+             else:
+                 ned_end_predictions = torch.zeros_like(ned_start_predictions)
+
+         start_position, end_position = (
+             (start_labels, end_labels)
+             if self.training
+             else (ned_start_predictions, ned_end_predictions)
+         )
+
+         # Entity disambiguation
+         ed_logits = self.compute_classification_logits(
+             model_features,
+             special_symbols_mask,
+             prediction_mask,
+             batch_size,
+             start_position,
+             end_position,
+         )
+         ed_probabilities = torch.softmax(ed_logits, dim=-1)
+         ed_predictions = torch.argmax(ed_probabilities, dim=-1)
+
+         # output build
+         output_dict = dict(
+             batch_size=batch_size,
+             ned_start_logits=ned_start_logits,
+             ned_start_probabilities=ned_start_probabilities,
+             ned_start_predictions=ned_start_predictions,
+             ned_end_logits=ned_end_logits,
+             ned_end_probabilities=ned_end_probabilities,
+             ned_end_predictions=ned_end_predictions,
+             ed_logits=ed_logits,
+             ed_probabilities=ed_probabilities,
+             ed_predictions=ed_predictions,
+         )
+
+         # compute loss if labels
+         if start_labels is not None and end_labels is not None and self.training:
+             # named entity detection loss
+
+             # start
+             if ned_start_logits is not None:
+                 ned_start_loss = self.criterion(
+                     ned_start_logits.view(-1, ned_start_logits.shape[-1]),
+                     ned_start_labels.view(-1),
+                 )
+             else:
+                 ned_start_loss = 0
+
+             # end
+             if ned_end_logits is not None:
+                 ned_end_labels = torch.zeros_like(end_labels)
+                 ned_end_labels[end_labels == -100] = -100
+                 ned_end_labels[end_labels > 0] = 1
+
+                 ned_end_loss = self.criterion(
+                     ned_end_logits,
+                     (
+                         torch.arange(
+                             ned_end_labels.size(1), device=ned_end_labels.device
+                         )
+                         .unsqueeze(0)
+                         .expand(batch_size, -1)[ned_end_labels > 0]
+                     ).to(ned_end_labels.device),
+                 )
+
+             else:
+                 ned_end_loss = 0
+
+             # entity disambiguation loss
+             start_labels[ned_start_labels != 1] = -100
+             ed_labels = torch.clone(start_labels)
+             ed_labels[end_labels > 0] = end_labels[end_labels > 0]
+             ed_loss = self.criterion(
+                 ed_logits.view(-1, ed_logits.shape[-1]),
+                 ed_labels.view(-1),
+             )
+
+             output_dict["ned_start_loss"] = ned_start_loss
+             output_dict["ned_end_loss"] = ned_end_loss
+             output_dict["ed_loss"] = ed_loss
+
+             output_dict["loss"] = ned_start_loss + ned_end_loss + ed_loss
+
+         return output_dict
+
+
+ class RelikReaderREModel(PreTrainedModel):
+     config_class = RelikReaderConfig
+
+     def __init__(self, config, *args, **kwargs):
+         super().__init__(config)
+         # Transformer model declaration
+         # self.transformer_model_name = transformer_model
+         self.config = config
+         self.transformer_model = (
+             AutoModel.from_pretrained(config.transformer_model)
+             if config.num_layers is None
+             else AutoModel.from_pretrained(
+                 config.transformer_model, num_hidden_layers=config.num_layers
+             )
+         )
+         self.transformer_model.resize_token_embeddings(
+             self.transformer_model.config.vocab_size
+             + config.additional_special_symbols
+             + config.additional_special_symbols_types,
+         )
+
+         # named entity detection layers
+         self.ned_start_classifier = self._get_projection_layer(
+             config.activation, last_hidden=2, layer_norm=False
+         )
+
+         self.ned_end_classifier = PoolerEndLogitsBi(self.transformer_model.config)
+
+         self.relation_disambiguation_loss = (
+             config.relation_disambiguation_loss
+             if hasattr(config, "relation_disambiguation_loss")
+             else False
+         )
+
+         if self.config.entity_type_loss and self.config.add_entity_embedding:
+             input_hidden_ents = 3 * self.config.linears_hidden_size
+         else:
+             input_hidden_ents = 2 * self.config.linears_hidden_size
+
+         self.re_projector = self._get_projection_layer(
+             config.activation, input_hidden=2 * self.transformer_model.config.hidden_size, hidden=input_hidden_ents, last_hidden=2 * self.config.linears_hidden_size
+         )
+
+         self.re_relation_projector = self._get_projection_layer(
+             config.activation, input_hidden=self.transformer_model.config.hidden_size,
+         )
+
+         if self.config.entity_type_loss or self.relation_disambiguation_loss:
+             self.re_entities_projector = self._get_projection_layer(
+                 config.activation,
+                 input_hidden=2 * self.transformer_model.config.hidden_size,
+             )
+             self.re_definition_projector = self._get_projection_layer(
+                 config.activation,
+             )
+
+         self.re_classifier = self._get_projection_layer(
+             config.activation,
+             input_hidden=config.linears_hidden_size,
+             last_hidden=2,
+             layer_norm=False,
+         )
+
+         self.training = config.training
+
+         # criterion
+         self.criterion = torch.nn.CrossEntropyLoss()
+         self.criterion_type = torch.nn.BCEWithLogitsLoss()
+
+     def _get_projection_layer(
+         self,
+         activation: str,
+         last_hidden: Optional[int] = None,
+         hidden: Optional[int] = None,
+         input_hidden=None,
+         layer_norm: bool = True,
+     ) -> torch.nn.Sequential:
+         head_components = [
+             torch.nn.Dropout(0.1),
+             torch.nn.Linear(
+                 (
+                     self.transformer_model.config.hidden_size
+                     * self.config.use_last_k_layers
+                     if input_hidden is None
+                     else input_hidden
+                 ),
+                 self.config.linears_hidden_size if hidden is None else hidden,
+             ),
+             activation2functions[activation],
+             torch.nn.Dropout(0.1),
+             torch.nn.Linear(
+                 self.config.linears_hidden_size if hidden is None else hidden,
+                 self.config.linears_hidden_size if last_hidden is None else last_hidden,
+             ),
+         ]
+
+         if layer_norm:
+             head_components.append(
+                 torch.nn.LayerNorm(
+                     (
+                         self.config.linears_hidden_size
+                         if last_hidden is None
+                         else last_hidden
+                     ),
+                     self.transformer_model.config.layer_norm_eps,
+                 )
+             )
+
+         return torch.nn.Sequential(*head_components)
+
+     def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+         mask = mask.unsqueeze(-1)
+         if next(self.parameters()).dtype == torch.float16:
+             logits = logits * (1 - mask) - 65500 * mask
+         else:
+             logits = logits * (1 - mask) - 1e30 * mask
+         return logits
+
+     def _get_model_features(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: torch.Tensor,
+         token_type_ids: Optional[torch.Tensor],
+     ):
+         model_input = {
+             "input_ids": input_ids,
+             "attention_mask": attention_mask,
+             "output_hidden_states": self.config.use_last_k_layers > 1,
+         }
+
+         if token_type_ids is not None:
+             model_input["token_type_ids"] = token_type_ids
+
+         model_output = self.transformer_model(**model_input)
+
+         if self.config.use_last_k_layers > 1:
+             model_features = torch.cat(
+                 model_output[1][-self.config.use_last_k_layers :], dim=-1
+             )
+         else:
+             model_features = model_output[0]
+
+         return model_features
+
+     def compute_ned_end_logits(
+         self,
+         start_predictions,
+         start_labels,
+         model_features,
+         prediction_mask,
+         batch_size,
+         mask_preceding: bool = False,
+     ) -> Optional[torch.Tensor]:
+         # todo: maybe when constraining on the spans,
+         # we should not use a prediction_mask for the end tokens.
+         # at least we should not during training imo
+         start_positions = start_labels if self.training else start_predictions
+         start_positions_indices = (
+             torch.arange(start_positions.size(1), device=start_positions.device)
+             .unsqueeze(0)
+             .expand(batch_size, -1)[start_positions > 0]
+         ).to(start_positions.device)
+
+         if len(start_positions_indices) > 0:
+             expanded_features = model_features.repeat_interleave(
+                 torch.sum(start_positions > 0, dim=-1), dim=0
+             )
+             expanded_prediction_mask = prediction_mask.repeat_interleave(
+                 torch.sum(start_positions > 0, dim=-1), dim=0
+             )
+             if mask_preceding:
+                 expanded_prediction_mask[
+                     torch.arange(
+                         expanded_prediction_mask.shape[1],
+                         device=expanded_prediction_mask.device,
+                     )
+                     < start_positions_indices.unsqueeze(1)
+                 ] = 1
+             end_logits = self.ned_end_classifier(
+                 hidden_states=expanded_features,
+                 start_positions=start_positions_indices,
+                 p_mask=expanded_prediction_mask,
+             )
+
+             return end_logits
+
+         return None
+
+     def compute_relation_logits(
+         self,
+         model_entity_features,
+         special_symbols_features,
+     ) -> torch.Tensor:
+         model_subject_object_features = self.re_projector(model_entity_features)
+         model_subject_features = model_subject_object_features[
+             :, :, : model_subject_object_features.shape[-1] // 2
+         ]
+         model_object_features = model_subject_object_features[
+             :, :, model_subject_object_features.shape[-1] // 2 :
+         ]
+         special_symbols_start_representation = self.re_relation_projector(
+             special_symbols_features
+         )
+         re_logits = torch.einsum(
+             "bse,bde,bfe->bsdfe",
+             model_subject_features,
+             model_object_features,
+             special_symbols_start_representation,
+         )
+         re_logits = self.re_classifier(re_logits)
+
+         return re_logits
+
+     def compute_entity_logits(
+         self,
+         model_entity_features,
+         special_symbols_features,
+     ) -> torch.Tensor:
+         model_ed_features = self.re_entities_projector(model_entity_features)
+         special_symbols_ed_representation = self.re_definition_projector(
+             special_symbols_features
+         )
+
+         logits = torch.bmm(
+             model_ed_features,
+             torch.permute(special_symbols_ed_representation, (0, 2, 1)),
+         )
+         logits = self._mask_logits(
+             logits, (model_entity_features == -100).all(2).long()
+         )
+         return logits
+
+     def compute_loss(self, logits, labels, mask=None):
+         logits = logits.reshape(-1, logits.shape[-1])
+         labels = labels.reshape(-1).long()
+         if mask is not None:
+             return self.criterion(logits[mask], labels[mask])
+         return self.criterion(logits, labels)
+
+     def compute_ned_type_loss(
+         self,
+         disambiguation_labels,
+         re_ned_entities_logits,
+         ned_type_logits,
+         re_entities_logits,
+         entity_types,
+         mask,
+     ):
+         if self.config.entity_type_loss and self.relation_disambiguation_loss:
+             return self.criterion_type(
+                 re_ned_entities_logits[disambiguation_labels != -100],
+                 disambiguation_labels[disambiguation_labels != -100],
+             )
+         if self.config.entity_type_loss:
+             return self.criterion_type(
+                 ned_type_logits[mask],
+                 disambiguation_labels[:, :, :entity_types][mask],
+             )
+
+         if self.relation_disambiguation_loss:
+             return self.criterion_type(
+                 re_entities_logits[disambiguation_labels != -100],
+                 disambiguation_labels[disambiguation_labels != -100],
+             )
+         return 0
+
+     def compute_relation_loss(self, relation_labels, re_logits):
+         return self.compute_loss(
+             re_logits, relation_labels, relation_labels.view(-1) != -100
+         )
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: torch.Tensor,
+         token_type_ids: torch.Tensor,
+         prediction_mask: Optional[torch.Tensor] = None,
+         special_symbols_mask: Optional[torch.Tensor] = None,
+         special_symbols_mask_entities: Optional[torch.Tensor] = None,
+         start_labels: Optional[torch.Tensor] = None,
+         end_labels: Optional[torch.Tensor] = None,
+         disambiguation_labels: Optional[torch.Tensor] = None,
+         relation_labels: Optional[torch.Tensor] = None,
+         relation_threshold: Optional[float] = None,
+         is_validation: bool = False,
+         is_prediction: bool = False,
+         use_predefined_spans: bool = False,
+         *args,
+         **kwargs,
+     ) -> Dict[str, Any]:
+
+         relation_threshold = self.config.threshold if relation_threshold is None else relation_threshold
+
+         batch_size = input_ids.shape[0]
+
+         model_features = self._get_model_features(
+             input_ids, attention_mask, token_type_ids
+         )
+
+         # named entity detection
+         if use_predefined_spans:
+             ned_start_logits, ned_start_probabilities, ned_start_predictions = (
+                 None,
+                 None,
+                 torch.zeros_like(start_labels),
+             )
+             ned_end_logits, ned_end_probabilities, ned_end_predictions = (
+                 None,
+                 None,
+                 torch.zeros_like(end_labels),
+             )
+
+             ned_start_predictions[start_labels > 0] = 1
+             ned_end_predictions[end_labels > 0] = 1
+             ned_end_predictions = ned_end_predictions[~(end_labels == -100).all(2)]
+             ned_start_labels = start_labels
+             ned_start_labels[start_labels > 0] = 1
+         else:
+             # start boundary prediction
+             ned_start_logits = self.ned_start_classifier(model_features)
+             if is_validation or is_prediction:
+                 ned_start_logits = self._mask_logits(
+                     ned_start_logits, prediction_mask
+                 )  # why?
+             ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1)
+             ned_start_predictions = ned_start_probabilities.argmax(dim=-1)
+
+             # end boundary prediction
+             ned_start_labels = (
+                 torch.zeros_like(start_labels) if start_labels is not None else None
+             )
+
+             # start_labels contain the entity id at each position; we only need a 1 at the start of an entity
+             if ned_start_labels is not None:
+                 ned_start_labels[start_labels == -100] = -100
+                 ned_start_labels[start_labels > 0] = 1
+
+             # compute end logits only if there are any start predictions.
+             # For each start prediction, n end predictions are made
+             ned_end_logits = self.compute_ned_end_logits(
+                 ned_start_predictions,
+                 ned_start_labels,
+                 model_features,
+                 prediction_mask,
+                 batch_size,
+                 True,
+             )
+
+             if ned_end_logits is not None:
+                 # For each start prediction, n end predictions are made based on
+                 # binary classification, i.e. argmax at each position.
+                 ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
+                 ned_end_predictions = ned_end_probabilities.argmax(dim=-1)
+             else:
+                 ned_end_logits, ned_end_probabilities = None, None
+                 ned_end_predictions = torch.zeros_like(ned_start_predictions)
+
+             if is_prediction or is_validation:
+                 end_preds_count = ned_end_predictions.sum(1)
+                 # If there are no end predictions for a start prediction, remove the start prediction
+                 if (end_preds_count == 0).any() and (ned_start_predictions > 0).any():
+                     ned_start_predictions[ned_start_predictions == 1] = (
+                         end_preds_count != 0
+                     ).long()
+                     ned_end_predictions = ned_end_predictions[end_preds_count != 0]
+
+         if end_labels is not None:
+             end_labels = end_labels[~(end_labels == -100).all(2)]
+
+         start_position, end_position = (
+             (start_labels, end_labels)
+             if (not is_prediction and not is_validation)
+             else (ned_start_predictions, ned_end_predictions)
+         )
+
+         start_counts = (start_position > 0).sum(1)
+         if (start_counts > 0).any():
+             ned_end_predictions = ned_end_predictions.split(start_counts.tolist())
+         # limit to 30 predictions per document using start_counts, by setting all positions after the sum reaches 30 to 0
+         # if is_validation or is_prediction:
+         #     ned_start_predictions[ned_start_predictions == 1] = start_counts
+         # We can only predict relations if we have start and end predictions
+         if (end_position > 0).sum() > 0:
+             ends_count = (end_position > 0).sum(1)
+             model_subject_features = torch.cat(
+                 [
+                     torch.repeat_interleave(
+                         model_features[start_position > 0], ends_count, dim=0
+                     ),  # start position features
+                     torch.repeat_interleave(model_features, start_counts, dim=0)[
+                         end_position > 0
+                     ],  # end position features
+                 ],
+                 dim=-1,
+             )
+             ents_count = torch.nn.utils.rnn.pad_sequence(
+                 torch.split(ends_count, start_counts.tolist()),
+                 batch_first=True,
+                 padding_value=0,
+             ).sum(1)
+             model_subject_features = torch.nn.utils.rnn.pad_sequence(
+                 torch.split(model_subject_features, ents_count.tolist()),
+                 batch_first=True,
+                 padding_value=-100,
+             )
+
+             # if is_validation or is_prediction:
+             #     model_subject_features = model_subject_features[:, :30, :]
+
+             # entity disambiguation. Here relation_disambiguation_loss would only be useful to
+             # reduce the number of candidate relations for the next step, but currently unused.
+             if self.config.entity_type_loss or self.relation_disambiguation_loss:
+                 re_ned_entities_logits = self.compute_entity_logits(
+                     model_subject_features,
+                     model_features[
+                         special_symbols_mask | special_symbols_mask_entities
+                     ].view(batch_size, -1, model_features.shape[-1]),
+                 )
+                 entity_types = torch.sum(special_symbols_mask_entities, dim=1)[0].item()
+                 ned_type_logits = re_ned_entities_logits[:, :, :entity_types]
+                 re_entities_logits = re_ned_entities_logits[:, :, entity_types:]
+
+                 if self.config.entity_type_loss:
+                     ned_type_probabilities = torch.sigmoid(ned_type_logits)
+                     ned_type_predictions = ned_type_probabilities.argmax(dim=-1)
+
+                     if self.config.add_entity_embedding:
+                         special_symbols_representation = model_features[
+                             special_symbols_mask_entities
+                         ].view(batch_size, entity_types, -1)
+
+                         entities_representation = torch.einsum(
+                             "bsp,bpe->bse",
+                             ned_type_probabilities,
+                             special_symbols_representation,
+                         )
+                         model_subject_features = torch.cat(
+                             [model_subject_features, entities_representation], dim=-1
+                         )
+                 re_entities_probabilities = torch.sigmoid(re_entities_logits)
+                 re_entities_predictions = re_entities_probabilities.round()
+             else:
+                 (
+                     ned_type_logits,
+                     ned_type_probabilities,
+                     re_entities_logits,
+                     re_entities_probabilities,
+                 ) = (None, None, None, None)
+                 ned_type_predictions, re_entities_predictions = (
+                     torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
+                     torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
+                 )
+
+             # Compute relation logits
+             re_logits = self.compute_relation_logits(
+                 model_subject_features,
+                 model_features[special_symbols_mask].view(
+                     batch_size, -1, model_features.shape[-1]
+                 ),
+             )
+
+             re_probabilities = torch.softmax(re_logits, dim=-1)
+             # we set a threshold instead of argmax in case it needs to be tweaked
+             re_predictions = re_probabilities[:, :, :, :, 1] > relation_threshold
+             re_probabilities = re_probabilities[:, :, :, :, 1]
+
+         else:
+             (
+                 ned_type_logits,
+                 ned_type_probabilities,
+                 re_entities_logits,
+                 re_entities_probabilities,
+             ) = (None, None, None, None)
+             ned_type_predictions, re_entities_predictions = (
+                 torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
+                 torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
+             )
+             re_logits, re_probabilities, re_predictions = (
+                 torch.zeros(
+                     [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
+                 ).to(input_ids.device),
+                 torch.zeros(
+                     [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
+                 ).to(input_ids.device),
+                 torch.zeros(
+                     [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
+                 ).to(input_ids.device),
+             )
+
+         # output build
+         output_dict = dict(
+             batch_size=batch_size,
+             ned_start_logits=ned_start_logits,
+             ned_start_probabilities=ned_start_probabilities,
+             ned_start_predictions=ned_start_predictions,
+             ned_end_logits=ned_end_logits,
+             ned_end_probabilities=ned_end_probabilities,
+             ned_end_predictions=ned_end_predictions,
+             ned_type_logits=ned_type_logits,
+             ned_type_probabilities=ned_type_probabilities,
+             ned_type_predictions=ned_type_predictions,
+             re_entities_logits=re_entities_logits,
+             re_entities_probabilities=re_entities_probabilities,
+             re_entities_predictions=re_entities_predictions,
+             re_logits=re_logits,
+             re_probabilities=re_probabilities,
+             re_predictions=re_predictions,
+         )
+
+         if (
+             start_labels is not None
+             and end_labels is not None
+             and relation_labels is not None
+             and is_prediction is False
+         ):
+             ned_start_loss = self.compute_loss(ned_start_logits, ned_start_labels)
+             end_labels[end_labels > 0] = 1
+             ned_end_loss = self.compute_loss(ned_end_logits, end_labels)
+             if self.config.entity_type_loss or self.relation_disambiguation_loss:
+                 ned_type_loss = self.compute_ned_type_loss(
+                     disambiguation_labels,
+                     re_ned_entities_logits,
+                     ned_type_logits,
+                     re_entities_logits,
+                     entity_types,
+                     (model_subject_features != -100).all(2),
+                 )
+             relation_loss = self.compute_relation_loss(relation_labels, re_logits)
+             # compute loss. We can skip the relation loss if we are in the first epochs (optional)
+             if self.config.entity_type_loss or self.relation_disambiguation_loss:
+                 output_dict["loss"] = (
+                     ned_start_loss + ned_end_loss + relation_loss + ned_type_loss
+                 ) / 4
+                 output_dict["ned_type_loss"] = ned_type_loss
+             else:
+                 output_dict["loss"] = ((1 / 20) * (ned_start_loss + ned_end_loss)) + (
+                     (9 / 10) * relation_loss
+                 )
+             output_dict["ned_start_loss"] = ned_start_loss
+             output_dict["ned_end_loss"] = ned_end_loss
+             output_dict["re_loss"] = relation_loss
+
+         return output_dict
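
For orientation, a smoke-test sketch of `RelikReaderSpanModel.forward` with dummy tensors (the repository path, shapes, and token ids below are illustrative assumptions, not values from this commit). With `use_predefined_spans=True` the model skips span detection, reuses the provided gold boundaries, and only runs entity disambiguation against the positions flagged in `special_symbols_mask`:

import torch
from transformers import AutoModel

# Placeholder path as in the earlier sketch; .eval() also sets self.training = False,
# so the forward pass takes the prediction path and skips the loss computation.
model = AutoModel.from_pretrained("path/to/this/repo", trust_remote_code=True).eval()

batch_size, seq_len = 1, 8
input_ids = torch.randint(5, 100, (batch_size, seq_len))
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
prediction_mask = torch.zeros(batch_size, seq_len, dtype=torch.long)  # nothing masked out
special_symbols_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
special_symbols_mask[:, -2:] = True  # pretend the last two tokens are candidate markers

# One gold span: candidate entity 1 starts at token 1 and ends at token 3.
start_labels = torch.zeros(batch_size, seq_len, dtype=torch.long)
end_labels = torch.zeros(batch_size, seq_len, dtype=torch.long)
start_labels[0, 1], end_labels[0, 3] = 1, 1

with torch.no_grad():
    out = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        prediction_mask=prediction_mask,
        special_symbols_mask=special_symbols_mask,
        start_labels=start_labels,
        end_labels=end_labels,
        use_predefined_spans=True,
    )
print(out["ed_predictions"].shape)  # (1, 8): one candidate index per token
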
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6d470972c47d651a0a36f33ad04526b1fd442fe3a59f54a5f47db59541e7ddf0
+ size 1753419514