gyr66 committed
Commit 7483a21
1 Parent(s): 477ebb5

Upload BertForRelationExtraction

Files changed (2)
  1. config.json +4 -1
  2. model.py +85 -0
config.json CHANGED
@@ -1,9 +1,12 @@
 {
-  "_name_or_path": "bert-base-uncased",
+  "_name_or_path": "gyr66/relation_extraction_bert_base_uncased",
   "architectures": [
     "BertForRelationExtraction"
   ],
   "attention_probs_dropout_prob": 0.1,
+  "auto_map": {
+    "AutoModelForSequenceClassification": "model.BertForRelationExtraction"
+  },
   "classifier_dropout": null,
   "e1_end_token_id": 30524,
   "e1_start_token_id": 30523,
model.py ADDED
@@ -0,0 +1,85 @@
+import torch
+import torch.nn as nn
+from transformers import (
+    BertPreTrainedModel,
+    BertModel,
+    AutoModelForSequenceClassification,
+    BertConfig,
+)
+from transformers.modeling_outputs import SequenceClassifierOutput
+
+
+class BertForRelationExtraction(BertPreTrainedModel):
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = len(config.label2id)
+        self.config = config
+        self.bert = BertModel(config, add_pooling_layer=False)
+        self.dropout = nn.Dropout(
+            config.classifier_dropout
+            if config.classifier_dropout is not None
+            else config.hidden_dropout_prob
+        )
+        self.layer_norm = nn.LayerNorm(2 * config.hidden_size)
+        self.classifier = nn.Linear(2 * config.hidden_size, self.num_labels)
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        outputs = self.bert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+
+        sequence_output = self.dropout(sequence_output)
+
+        e1_start = torch.where(input_ids == self.config.e1_start_token_id)
+        e2_start = torch.where(input_ids == self.config.e2_start_token_id)
+
+        e1_hidden_states = sequence_output[e1_start[0], e1_start[1]]
+        e2_hidden_states = sequence_output[e2_start[0], e2_start[1]]
+
+        h = torch.cat((e1_hidden_states, e2_hidden_states), dim=-1)
+        logits = self.classifier(self.layer_norm(h))
+
+        loss = None
+        if labels is not None:
+            loss_fct = nn.CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]  # Need to check outputs shape
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
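
model.py implements an entity-marker relation classifier: the final hidden state at the e1 start-marker token and at the e2 start-marker token are gathered via torch.where over input_ids, concatenated into a 2 * hidden_size vector, layer-normalized, and mapped by a linear layer to len(config.label2id) relation logits, so each input is expected to contain exactly one pair of markers. A hedged usage sketch, continuing from the loading example after the config.json diff (the surface forms "[E1]", "[/E1]", "[E2]", "[/E2]" are an assumption; only the marker token ids, e.g. e1_start_token_id = 30523, appear in config.json):

import torch

# Assumed marker syntax; the checkpoint's tokenizer must map the start markers
# to config.e1_start_token_id / config.e2_start_token_id for the gather to work.
text = "[E1] Barack Obama [/E1] was born in [E2] Hawaii [/E2] ."
inputs = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# One marker pair per example -> logits has shape (batch_size, num_labels).
pred_id = outputs.logits.argmax(dim=-1).item()
print(model.config.id2label[pred_id])
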