thak123 commited on
Commit
6198294
1 Parent(s): 2eef0b9

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +25 -0
utils.py CHANGED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import config
3
+
4
+
5
+ def categorical_accuracy(preds, y):
6
+ """
7
+ Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8
8
+ """
9
+ max_preds = preds.argmax(
10
+ dim=1, keepdim=True) # get the index of the max probability
11
+ correct = max_preds.squeeze(1).eq(y)
12
+ return correct.sum() / torch.FloatTensor([y.shape[0]])
13
+
14
+ def label_encoder(x):
15
+ label_vec = {"0": 0, "1": 1, "-1": 2}
16
+ return label_vec[x.replace("__label__", "")]
17
+
18
+ def label_decoder(x):
19
+ label_vec = { 0:"U", 1:"P", 2:"N"}
20
+ return label_vec[x]
21
+
22
+ def label_full_decoder(x):
23
+ label_vec = { 0:"Neutral", 1:"Positive", 2:"Negative"}
24
+ return label_vec[x]
25
+