import streamlit as st import torch from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast from utils import get_text, get_labels, load_model count_labels = 8 st.markdown("## Классификатор статей") st.markdown("Сервис классифицирует статьи по названию и аннотации. Нужно ввести в каждое окошко свою сущность и вам выдадут к какому классу относится статья") title = st.text_area("Введите название статьи") abstract = st.text_area("Введите аннотацию к статье, abstract статьи") tokenizer = DistilBertTokenizerFast() model = DistilBertForSequenceClassification() load_model(model, 'weight_model') #model.load_state_dict(torch.load('weight_model')) text = get_text(title, abstract) if text: raw_predictions = get_labels(text, model, tokenizer) st.markdown("Список классов к которым может относится данная статья") for raw in raw_predictions: st.markdown(f"{raw}") else: st.markdown("Ваш запрос пуст. Введите хотя бы название")