# nlp_project/task2.py
# Last commit: aa3e28c ("files added") by Tatiana; file size 1.63 kB.
from transformers import BertTokenizer, BertForSequenceClassification
import torch
from sklearn.preprocessing import LabelEncoder
# Load the saved fine-tuned model and its tokenizer (used by the Streamlit app).
# NOTE(review): both paths are relative, so they resolve against the process's
# current working directory (presumably the repository root) — confirm how the app is launched.
loaded_model_path, loaded_tokenizer_path = "nlp_project/model", "nlp_project/tokenizer"
loaded_model = BertForSequenceClassification.from_pretrained(loaded_model_path)
loaded_tokenizer = BertTokenizer.from_pretrained(loaded_tokenizer_path)

# Topic labels the classifier chooses between.
labels = ['мода', 'спорт', 'технологии', 'финансы', 'крипта']
# LabelEncoder.fit stores the distinct labels in sorted order in `classes_`, so class
# index i maps to the i-th label alphabetically, not the i-th entry of `labels`.
# NOTE(review): assumes the model's output indices were assigned by an encoder fit
# the same way during training — confirm against the training code.
label_encoder = LabelEncoder().fit(labels)
def predict_class(
    user_input,
    model=loaded_model,
    tokenizer=loaded_tokenizer,
    label_encoder=label_encoder,
    max_length=128,
):
    """Predict the topic label of a piece of text.

    Args:
        user_input: Text to classify.
        model: Sequence-classification model; the argmax of its ``logits`` is
            used as an index into ``label_encoder.classes_``.
        tokenizer: Tokenizer matching ``model``.
        label_encoder: Fitted encoder mapping class indices back to label names.
        max_length: The input is truncated / padded to exactly this many tokens
            (special tokens included).

    Returns:
        The predicted label name as a ``str``. If ``user_input`` is empty,
        ``None`` or whitespace-only, returns the prompt "Введите текст"
        ("Enter text") instead.
    """
    # Nothing meaningful to classify (including whitespace-only text): ask for input.
    if not user_input or (isinstance(user_input, str) and not user_input.strip()):
        return "Введите текст"

    # `padding="max_length"` + `truncation=True` replace the deprecated
    # `pad_to_max_length=True`. They reproduce its effective behavior (pad to
    # max_length, longest_first truncation of longer inputs) without warnings.
    encoded = tokenizer(
        user_input,
        add_special_tokens=True,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt",
    )

    model.eval()  # disable dropout etc. in case the model was left in training mode

    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]
    # Put the inputs on the model's device so a GPU-placed model works too;
    # this is a no-op for a CPU model or a module without a `device` attribute.
    device = getattr(model, "device", None)
    if device is not None:
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

    with torch.no_grad():  # inference only: skip autograd bookkeeping
        logits = model(input_ids, attention_mask=attention_mask).logits

    predicted_index = torch.argmax(logits, dim=-1).item()
    # classes_ holds numpy strings; return a plain Python str for display.
    return str(label_encoder.classes_[predicted_index])