# Hugging Face Space status at capture time: Runtime error
import torch | |
from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast | |
import streamlit as st | |
def get_text(title: str, abstract: str):
    """Build the classifier input from a paper's title and abstract.

    A title is required: without a (truthy) title the result is ``None``.
    With a title but no abstract, the title alone is returned; with both,
    the abstract and the title are joined by a single space (abstract first).
    """
    if not title:
        return None
    if not abstract:
        return title
    return abstract + ' ' + title
def get_labels(text, model, tokenizer, count_labels=8, threshold=0.95):
    """Return the most probable subject labels for ``text``.

    Labels are taken in decreasing order of predicted probability until
    their cumulative probability reaches ``threshold``; the label that
    crosses the threshold is included, so at least one label is always
    returned.

    Args:
        text: Input string passed to ``tokenizer``.
        model: Classifier called as ``model(**tokens)``; its output must have
            a ``logits`` tensor. Assumed shape is ``(1, 8)`` with columns in
            the order of ``labels`` below -- confirm against the checkpoint.
        tokenizer: Callable returning keyword arguments for ``model``.
        count_labels: Number of classes. It was previously (mis)used as the
            softmax dimension; it is kept only for backward compatibility.
        threshold: Cumulative-probability cut-off (default 0.95).

    Returns:
        List of label names, most probable first.
    """
    labels = ['Computer_science', 'Economics',
              'Electrical_Engineering_and_Systems_Science', 'Mathematics',
              'Physics', 'Quantitative_Biology', 'Quantitative_Finance',
              'Statistics']

    # Truncate so long abstracts do not exceed the model's maximum input length.
    tokens = tokenizer(text, return_tensors='pt', truncation=True)
    with torch.no_grad():  # inference only: no autograd graph needed
        outputs = model(**tokens)

    # Softmax over the class axis (last dim). The old dim=count_labels (8)
    # is out of range for 2-D (batch, classes) logits.
    probs = torch.softmax(outputs.logits, dim=-1)[0].tolist()

    # Stable sort: equal probabilities keep the order of `labels`.
    ranked = sorted(zip(probs, labels), key=lambda pair: pair[0], reverse=True)

    result_labels = []
    cumsum = 0.0
    for prob, label in ranked:
        # Append before checking, so the label that pushes the cumulative
        # probability over the threshold is part of the answer.
        result_labels.append(label)
        cumsum += prob
        if cumsum >= threshold:
            break
    return result_labels
def load_model(model_name='distilbert-base-cased', weights_path='weight_model',
               num_labels=8):
    """Build the DistilBERT classifier and tokenizer and load fine-tuned weights.

    Args:
        model_name: Hugging Face checkpoint used for the tokenizer vocabulary
            and model config. TODO(review): confirm this is the base checkpoint
            the weights were fine-tuned from (cased vs. uncased have different
            vocab sizes, which would make ``load_state_dict`` fail).
        weights_path: Path to the saved ``state_dict`` file.
        num_labels: Size of the classification head; must match the saved weights.

    Returns:
        ``(model, tokenizer)`` with the model in eval mode.
    """
    # The bare constructors need a vocab file / config; from_pretrained supplies both.
    tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
    model = DistilBertForSequenceClassification.from_pretrained(
        model_name, num_labels=num_labels)

    # map_location='cpu' lets weights saved on a GPU load on CPU-only hosts.
    # NOTE: torch.load unpickles the file -- only load trusted weight files.
    state_dict = torch.load(weights_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()  # disable dropout so predictions are deterministic
    return model, tokenizer