# Streamlit app that classifies arXiv articles into one of eight subject
# areas from their title and (optional) summary, using a DistilBERT
# sequence classifier with fine-tuned weights loaded from disk.
#
# NOTE: functions are documented with `#` comments instead of docstrings so
# that Streamlit "magic" can never render them as page content.
import streamlit as st
import numpy as np
import pandas as pd
import torch
import transformers
import tokenizers

# Categories are returned in descending probability until their cumulative
# probability reaches this fraction of the distribution.
_CONFIDENCE_MASS = 0.95


# Build the tokenizer and classifier and load the fine-tuned weights.
# Returns (tokenizer, model) with the model in eval mode. @st.cache memoizes
# the result across reruns; the Rust-backed tokenizers.Tokenizer is excluded
# from cache hashing (hashed to None).
@st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: lambda _: None})
def load_model():
    from transformers import AutoTokenizer, AutoModelForSequenceClassification

    model_name = 'distilbert-base-cased'
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=8)
    # map_location forces every tensor onto the CPU so the app runs on a
    # machine without a GPU.
    model.load_state_dict(torch.load('model_weights2.pt', map_location=torch.device('cpu')))
    model.eval()
    return tokenizer, model


# Classify one article.
#   title, summary: raw user text; joined with a newline before tokenizing.
#   tokenizer, model: as returned by load_model().
# Returns (prediction, prediction_probs): parallel lists of category names
# and "NN.NN%" strings, most probable first, covering _CONFIDENCE_MASS of
# the predicted distribution.
@st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: lambda _: None})
def predict(title, summary, tokenizer, model):
    text = title + "\n" + summary
    # Truncate to the tokenizer's maximum model length. Without this, a long
    # summary produces more tokens than the model accepts and inference fails.
    tokens = tokenizer.encode(text, truncation=True)
    with torch.no_grad():
        logits = model(torch.as_tensor([tokens]))[0]
    # The input is a batch of one, so take its (only) row of logits.
    probs = torch.softmax(logits[-1, :], dim=-1).data.cpu().numpy()
    classes = np.flip(np.argsort(probs))  # class ids, most probable first
    sum_probs = 0
    ind = 0
    prediction = []
    prediction_probs = []
    # The index bound guarantees termination even if float rounding keeps
    # the cumulative sum just below the threshold after all classes.
    while sum_probs < _CONFIDENCE_MASS and ind < len(classes):
        label = classes[ind]
        prediction.append(label_to_theme[label])
        prediction_probs.append(f"{100 * probs[label]:.2f}%")
        sum_probs += probs[label]
        ind += 1
    return prediction, prediction_probs


# Build a display table from predict()'s output: columns 'Category' and
# 'Confidence', indexed 1..n so rows read as ranks.
@st.cache(suppress_st_warning=True)
def get_results(prediction, prediction_probs):
    frame = pd.DataFrame({'Category': prediction, 'Confidence': prediction_probs})
    frame.index = np.arange(1, len(frame) + 1)
    return frame


# Class id -> human-readable subject area; must match the label order the
# weights in model_weights2.pt were trained with.
label_to_theme = {0: 'Computer science', 1: 'Economics', 2: 'Electrical Engineering and Systems Science',
                  3: 'Math', 4: 'Quantitative biology', 5: 'Quantitative Finance', 6: 'Statistics',
                  7: 'Physics'}

st.title("Arxiv articles classification")
# NOTE(review): this markdown payload is only two newlines; it looks like
# embedded HTML (passed with unsafe_allow_html) was lost from the source.
# Confirm the intended content.
st.markdown("\n\n", unsafe_allow_html=True)
st.markdown(
    "This is an interface that can determine the article's category based on "
    "its title and summary. Though it can work with title only, it is "
    "recommended that you provide summary if possible - this will result in "
    "a better prediction quality."
)

tokenizer, model = load_model()

title = st.text_area(label='Title', height=100)
summary = st.text_area(label='Summary (optional)', height=250)
button = st.button('Run')

if button:
    # Validate before inference so rejected input does not pay for a
    # model forward pass whose result would be thrown away.
    if len(title + "\n" + summary) < 20:
        st.error("Your input is too short. It is probably not a real article, please try again.")
    else:
        prediction, prediction_probs = predict(title, summary, tokenizer, model)
        ans = get_results(prediction, prediction_probs)
        st.subheader('Results:')
        st.write(ans)