Haoxiang-Wang commited on
Commit
9452cf2
1 Parent(s): 57271c4

Create modelling_custom.py

Browse files
Files changed (1) hide show
  1. modelling_custom.py +18 -0
modelling_custom.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.nn import functional as F
3
+ from transformers.models.mistral.modeling_mistral import MistralForSequenceClassification
4
+
5
+ class NormalizedLinear(torch.nn.Linear):
6
+ def forward(self, x):
7
+ x = F.normalize(x, p=2, dim=-1)
8
+ return super().forward(x)
9
+
10
+
11
+ class MistralForAttributePrediction(MistralForSequenceClassification):
12
+ def __init__(self, config):
13
+ super().__init__(config)
14
+
15
+ del self.score
16
+ self.score = NormalizedLinear(config.hidden_size, config.num_labels, bias=True)
17
+ # Initialize weights and apply final processing
18
+ self.post_init()