Dzhamb's picture
Update utils.py
3464fc2
raw
history blame
No virus
1.08 kB
import torch
import streamlit as st
def get_text(title: str, abstract: str):
    """Build the classifier input string from a paper's title and abstract.

    Args:
        title: Paper title. Without a (non-empty) title nothing is produced.
        abstract: Paper abstract; optional.

    Returns:
        ``abstract + ' ' + title`` when both are non-empty, ``title`` alone
        when only the title is non-empty, and ``None`` when the title is
        empty/missing (an abstract on its own is not used).
    """
    # A title is mandatory: bail out early when it is absent.
    if not title:
        return None
    # Abstract first, then title, separated by a single space.
    return abstract + ' ' + title if abstract else title
def get_labels(text, model, tokenizer, count_labels=8, threshold=0.95):
    """Return the most probable subject labels for ``text``.

    The text is tokenized, run through ``model``, and the logits are turned
    into probabilities with a softmax over the last (class) dimension.
    Labels are then taken in descending order of probability until their
    cumulative probability exceeds ``threshold``. The label that crosses the
    threshold is included, so at least one label is always returned.

    Args:
        text: Input string passed to ``tokenizer``.
        model: Callable taking the tokenizer output as keyword arguments and
            returning an object with a ``logits`` tensor. Presumably a
            Hugging Face sequence-classification model; confirm with caller.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors='pt')``.
        count_labels: Expected number of classes, i.e. the size of the
            logits' last dimension.
        threshold: Cumulative-probability cutoff (default ``0.95``).

    Returns:
        List of label names, most probable first. Only the first element of
        the batch (row 0) is considered.

    Raises:
        ValueError: If the logits' last dimension is not ``count_labels``.
    """
    labels = ['Computer_science', 'Economics',
              'Electrical_Engineering_and_Systems_Science', 'Mathematics',
              'Physics', 'Quantitative_Biology', 'Quantitative_Finance',
              'Statistics']

    tokens = tokenizer(text, return_tensors='pt')
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        logits = model(**tokens).logits

    if logits.shape[-1] != count_labels:
        raise ValueError(
            f'expected {count_labels} logits per example, '
            f'got {logits.shape[-1]}'
        )

    # Normalize over the class dimension (the last one). ``count_labels`` is
    # a class count, not a tensor dimension.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0].tolist()

    # Python's sort is stable, so ties keep the original label order.
    ranked = sorted(zip(probs, labels), key=lambda pair: pair[0], reverse=True)

    result_labels = []
    cumsum = 0.0
    for prob, label in ranked:
        # Append before checking, so the label that pushes the cumulative
        # probability past the threshold is part of the answer.
        result_labels.append(label)
        cumsum += prob
        if cumsum > threshold:
            break
    # Also reached if the threshold is never exceeded (e.g. threshold >= 1.0
    # with rounding); return every label in that case instead of ``None``.
    return result_labels
# NOTE(review): ``st.cache`` is deprecated in newer Streamlit releases in
# favor of ``st.cache_resource``; kept here for compatibility with the
# Streamlit version this app was written against.
@st.cache(allow_output_mutation=True)
def load_model(model, filename):
    """Load a saved state dict from ``filename`` into ``model`` and return it.

    Args:
        model: Module whose ``load_state_dict`` receives the checkpoint.
            It is modified in place.
        filename: Path to a file written by ``torch.save`` containing a
            state dict.

    Returns:
        The same ``model`` object, now holding the loaded weights.
    """
    # Deserialize onto the CPU so a checkpoint saved from a GPU also loads on
    # a CPU-only host. load_state_dict then copies the tensors to wherever
    # the model's parameters live, so the result is unchanged on GPU hosts.
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints (consider weights_only=True once the minimum torch version
    # supports it).
    state_dict = torch.load(filename, map_location='cpu')
    model.load_state_dict(state_dict)
    return model