File size: 1,076 Bytes
1f9f1ad
c2c3782
1f9f1ad
 
a6c4164
 
 
 
 
 
 
 
 
 
7832338
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2c3782
3464fc2
c2c3782
ab282f2
 
7832338
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
import streamlit as st


def get_text(title: str, abstract: str):
    """Combine a title and an optional abstract into one input string.

    Args:
        title: Required title text. If it is falsy (``None`` or empty),
            nothing is returned, even when an abstract is present.
        abstract: Optional abstract text. When present it is placed
            *before* the title, separated by a single space.

    Returns:
        ``abstract + ' ' + title`` when both are truthy, ``title`` when
        only the title is truthy, otherwise ``None``.
    """
    # Without a title there is no input at all.
    if not title:
        return None
    # The abstract goes first, then the title.
    if abstract:
        return abstract + ' ' + title
    return title

def get_labels(text, model, tokenizer, count_labels=8, threshold=0.95):
    """Return the most probable subject labels for *text*, best first.

    The model's logits are normalized with a softmax over the last (class)
    dimension, and only the first batch row is used. Labels are ranked by
    descending probability and collected in that order. Collection stops
    as soon as the running probability total exceeds *threshold*. The label
    that pushes the total past the threshold is not included, except that
    at least one label is always returned.

    Args:
        text: Input string handed to ``tokenizer``. ``None`` yields ``[]``.
        model: Callable taking the tokenizer output as keyword arguments and
            returning an object with a ``logits`` tensor. Presumably a
            sequence-classification model with one logit per label; the
            column order must match the label list below.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors='pt')``.
        count_labels: Retained for backward compatibility only. It was
            previously passed as the softmax ``dim``, which is out of range
            for 2-D logits; it is no longer used.
        threshold: Cumulative-probability cut-off (default 0.95).

    Returns:
        A list of label names, ordered by descending probability.
    """
    if text is None:
        return []

    labels = ['Computer_science', 'Economics',
              'Electrical_Engineering_and_Systems_Science', 'Mathematics',
              'Physics', 'Quantitative_Biology', 'Quantitative_Finance',
              'Statistics']

    tokens = tokenizer(text, return_tensors='pt')
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        outputs = model(**tokens)
    # Normalize across classes (last dim). `dim=count_labels` (8) was out of
    # range for logits of shape (batch, n_labels).
    probs = torch.nn.Softmax(dim=-1)(outputs.logits)
    # .cpu() so this also works when the model runs on a GPU.
    first_row = probs.detach().cpu().numpy()[0]

    # Most probable first. sorted() is stable, so ties keep `labels` order.
    ranked = sorted(zip(first_row, labels), key=lambda pair: -pair[0])

    result_labels = []
    cumsum = 0
    for prob, label in ranked:
        cumsum += prob
        if cumsum > threshold and result_labels:
            return result_labels
        result_labels.append(label)
    # Reached only if the threshold is never crossed (e.g. NaN logits).
    # Previously this fell through and returned None.
    return result_labels
 
 @st.cache(allow_output_mutation=True)
 def load_model(model, filename, map_location='cpu'):
     """Load a saved ``state_dict`` from *filename* into *model* and return it.

     The call is memoized by Streamlit's ``st.cache``.
     ``allow_output_mutation=True`` lets callers use and modify the returned
     model without Streamlit flagging the cached value as mutated.

     Args:
         model: Module whose ``load_state_dict`` receives the loaded weights.
         filename: Path to a checkpoint produced by ``torch.save``.
         map_location: Device the checkpoint tensors are materialized on
             while loading (default ``'cpu'``).

     Returns:
         The same *model* object, now holding the loaded weights.
     """
     # NOTE(review): st.cache is deprecated in newer Streamlit releases in
     # favour of st.cache_resource; migrate if the installed version allows.
     # Without map_location, a checkpoint saved from a GPU run cannot be
     # loaded on a CPU-only host. load_state_dict copies the tensors into the
     # model's existing parameters, wherever those live.
     # torch.load unpickles its input: only load trusted checkpoint files.
     state_dict = torch.load(filename, map_location=map_location)
     model.load_state_dict(state_dict)
     return model