Update models.py
Browse filesAdded few fixed parameters
models.py
CHANGED
@@ -19,11 +19,12 @@ class BertPooler(nn.Module):
|
|
19 |
return pooled_output
|
20 |
|
21 |
|
|
|
22 |
class Model_Rational_Label(BertPreTrainedModel):
|
23 |
-
def __init__(self,config
|
24 |
super().__init__(config)
|
25 |
-
self.num_labels=
|
26 |
-
self.impact_factor=
|
27 |
self.bert = BertModel(config,add_pooling_layer=False)
|
28 |
self.bert_pooler=BertPooler(config)
|
29 |
self.token_dropout = nn.Dropout(0.1)
|
@@ -33,8 +34,8 @@ class Model_Rational_Label(BertPreTrainedModel):
|
|
33 |
self.init_weights()
|
34 |
# self.embeddings = AutoModelForTokenClassification.from_pretrained(params['model_path'], cache_dir=params['cache_path'])
|
35 |
|
36 |
-
def forward(self, input_ids=None,
|
37 |
-
outputs = self.bert(input_ids,
|
38 |
# out = outputs.last_hidden_state
|
39 |
out=outputs[0]
|
40 |
logits = self.token_classifier(self.token_dropout(out))
|
|
|
19 |
return pooled_output
|
20 |
|
21 |
|
22 |
+
|
23 |
class Model_Rational_Label(BertPreTrainedModel):
|
24 |
+
def __init__(self,config):
|
25 |
super().__init__(config)
|
26 |
+
self.num_labels=2
|
27 |
+
self.impact_factor=0.8
|
28 |
self.bert = BertModel(config,add_pooling_layer=False)
|
29 |
self.bert_pooler=BertPooler(config)
|
30 |
self.token_dropout = nn.Dropout(0.1)
|
|
|
34 |
self.init_weights()
|
35 |
# self.embeddings = AutoModelForTokenClassification.from_pretrained(params['model_path'], cache_dir=params['cache_path'])
|
36 |
|
37 |
+
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, attn=None, labels=None):
|
38 |
+
outputs = self.bert(input_ids, attention_mask)
|
39 |
# out = outputs.last_hidden_state
|
40 |
out=outputs[0]
|
41 |
logits = self.token_classifier(self.token_dropout(out))
|