Dzhamb committed on
Commit
7832338
1 Parent(s): 498004d

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +20 -2
utils.py CHANGED
@@ -8,5 +8,23 @@ def get_text(title: str, abstract: str):
8
 
9
  return text
10
 
11
- def get_label():
12
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  return text
10
 
11
def get_labels(text, model, tokenizer, count_labels=8, threshold=0.95):
    """Return the most probable topic labels for *text*.

    The text is tokenized, passed through the classifier, and the logits are
    turned into probabilities with a softmax over the class axis.  Labels are
    then taken in order of decreasing probability until their cumulative
    probability exceeds ``threshold``; the label that crosses the threshold
    is included, so at least one label is always returned.

    Args:
        text: Input text to classify.
        model: Callable invoked as ``model(**tokens)`` whose result exposes a
            ``logits`` tensor.  Assumed to be shaped ``(1, count_labels)``,
            with classes in the same order as the label list below -- confirm
            against the training setup.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors='pt')``
            that returns a mapping of keyword arguments for ``model``.
        count_labels: Number of classes the model is expected to output.
        threshold: Cumulative probability mass the returned labels must
            exceed.  Defaults to 0.95.

    Returns:
        A list of label names, most probable first.

    Raises:
        ValueError: If the model does not output ``count_labels`` scores.
    """
    labels = (
        'Computer_science',
        'Economics',
        'Electrical_Engineering_and_Systems_Science',
        'Mathematics',
        'Physics',
        'Quantitative_Biology',
        'Quantitative_Finance',
        'Statistics',
    )

    tokens = tokenizer(text, return_tensors='pt')
    # Inference only: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**tokens)

    # Normalize over the class axis (the last one), not over an axis index
    # equal to the number of classes.  Only the first batch item is used.
    probs = torch.softmax(outputs.logits, dim=-1)[0].tolist()
    if len(probs) != count_labels:
        raise ValueError(
            f'expected {count_labels} class scores, got {len(probs)}'
        )

    ranked = sorted(zip(probs, labels), key=lambda pair: pair[0], reverse=True)

    result_labels = []
    cumsum = 0.0
    for prob, label in ranked:
        # Append before testing so the label that crosses the threshold is
        # kept; otherwise the returned set would cover less than `threshold`.
        result_labels.append(label)
        cumsum += prob
        if cumsum > threshold:
            break
    return result_labels
29
+
30
+