Spaces:
Runtime error
Runtime error
File size: 1,076 Bytes
1f9f1ad c2c3782 1f9f1ad a6c4164 7832338 c2c3782 3464fc2 c2c3782 ab282f2 7832338 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
import torch
import streamlit as st
def get_text(title: str, abstract: str):
    """Build the model input string from a paper's title and abstract.

    Returns "<abstract> <title>" when both are non-empty, just the title when
    only the title is non-empty, and None when the title is empty or missing
    (an abstract on its own is not used).
    """
    # Guard clauses: a title is mandatory, the abstract is optional.
    if not title:
        return None
    if not abstract:
        return title
    return " ".join((abstract, title))
def get_labels(text, model, tokenizer, count_labels=8, threshold=0.95):
    """Return the most probable subject labels for *text*, best first.

    The text is tokenized and run through *model*. The logits of the first
    item in the batch are converted to probabilities with a softmax over the
    class (last) axis. Labels are ranked by probability and kept in order
    until the running total of probabilities exceeds *threshold*. The label
    that pushes the total past the threshold is NOT included, but at least
    one label is always kept.

    Args:
        text: Input passed to ``tokenizer``. ``get_text`` may return None,
            so callers should check for that before calling.
        model: Callable taking the tokenizer output as keyword arguments and
            returning an object with a ``logits`` tensor. Its shape is
            presumably ``(batch, count_labels)``; only row 0 is used.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors='pt')``
            whose result is unpacked into ``model``.
        count_labels: Expected number of classes, i.e. the width of the
            logits. It should equal the number of names in ``labels`` below.
        threshold: Cumulative-probability cut-off. Defaults to 0.95, the
            value that used to be hard-coded.

    Returns:
        List of label names ordered from most to least probable.

    Raises:
        ValueError: If the logits' last dimension is not ``count_labels``.
    """
    labels = ['Computer_science', 'Economics',
              'Electrical_Engineering_and_Systems_Science', 'Mathematics',
              'Physics', 'Quantitative_Biology', 'Quantitative_Finance',
              'Statistics']

    tokens = tokenizer(text, return_tensors='pt')
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        logits = model(**tokens).logits

    # Catch a checkpoint/label-list mismatch instead of letting zip() below
    # silently truncate to the shorter side.
    if logits.shape[-1] != count_labels:
        raise ValueError(
            f"model returned {logits.shape[-1]} logits per item, "
            f"expected count_labels={count_labels}"
        )

    # Softmax over the class axis. Passing count_labels as `dim` was out of
    # range for 2-D logits.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0].tolist()

    # sorted() is stable, so equal probabilities keep the label-list order.
    ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)

    result_labels = []
    cumsum = 0.0
    for label, prob in ranked:
        cumsum += prob
        # NOTE(review): this stops *before* the label that crosses the
        # threshold, so the kept labels can cover much less than `threshold`
        # (probs 0.6 / 0.38 / ... keep only the first label). If the intent
        # is top-p ("until the total exceeds 95%"), append before checking.
        # The original selection rule is kept here to preserve behavior.
        if cumsum > threshold and result_labels:
            break
        result_labels.append(label)
    # Reached when the threshold is never crossed (e.g. threshold >= 1 or
    # NaN probabilities). Previously this path returned None.
    return result_labels
@st.cache(allow_output_mutation=True)
def load_model(model, filename):
    """Load a saved state dict from *filename* into *model* and return it.

    The call is wrapped in ``st.cache(allow_output_mutation=True)`` so that
    Streamlit can reuse the loaded model across reruns.

    Args:
        model: Module whose ``load_state_dict`` receives the checkpoint.
        filename: Path to a file written by ``torch.save`` containing a
            state dict (assumed; verify against how the checkpoint was made).

    Returns:
        The same *model* object, with the loaded weights.
    """
    # map_location='cpu' lets a checkpoint saved from a GPU be read on a
    # CPU-only host. load_state_dict then copies the values into the
    # model's existing parameters, wherever those live.
    # NOTE(review): torch.load unpickles the file. Only load trusted
    # checkpoints, and consider weights_only=True if the installed torch
    # supports it.
    # NOTE(review): st.cache hashes its arguments, including `model`. It is
    # also deprecated in newer Streamlit in favor of st.cache_resource.
    # Confirm against the pinned Streamlit version.
    state_dict = torch.load(filename, map_location='cpu')
    model.load_state_dict(state_dict)
    return model
|