Spaces:
No application file
No application file
from transformers import BertTokenizer, BertForSequenceClassification | |
import torch | |
from sklearn.preprocessing import LabelEncoder | |
# Load the previously saved model and tokenizer for use in the Streamlit app.
loaded_model_path = "nlp_project/model"
loaded_tokenizer_path = "nlp_project/tokenizer"

loaded_model = BertForSequenceClassification.from_pretrained(
    loaded_model_path
)
loaded_tokenizer = BertTokenizer.from_pretrained(
    loaded_tokenizer_path
)

# Class names the classifier can output.
labels = [
    'мода',
    'спорт',
    'технологии',
    'финансы',
    'крипта',
]

# fit() returns the encoder itself, so construction and fitting can be chained.
# NOTE(review): LabelEncoder keeps classes_ in sorted order, not in the order
# of `labels` -- confirm this matches the class indices used during training.
label_encoder = LabelEncoder().fit(labels)
def predict_class(user_input, model=loaded_model, tokenizer=loaded_tokenizer, label_encoder=label_encoder, max_length=128):
    """Classify a text and return the name of the predicted class.

    Args:
        user_input: Text to classify.
        model: Classification model. It is called with the input ids and an
            attention mask and must return an object exposing ``logits``.
        tokenizer: Tokenizer providing ``encode_plus``.
        label_encoder: Fitted encoder whose ``classes_`` maps a class index
            to its name.
        max_length: Token length every input is padded or truncated to.

    Returns:
        The predicted class name, or the prompt "Введите текст" when
        ``user_input`` is empty.
    """
    if not user_input:
        return "Введите текст"

    # Request padding and truncation explicitly: `pad_to_max_length` is
    # deprecated, and truncation to `max_length` otherwise only happens through
    # the tokenizer's backward-compatibility fallback (with a warning).
    encoded_text = tokenizer.encode_plus(
        user_input,
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )

    # Inference mode: disables dropout and other train-only behaviour.
    model.eval()
    with torch.no_grad():
        outputs = model(
            encoded_text['input_ids'],
            attention_mask=encoded_text['attention_mask'],
        )

    # A single text is encoded, so argmax over dim 1 yields one class index.
    predicted_class_index = torch.argmax(outputs.logits, dim=1).item()

    # Map the class index back to its name.
    return label_encoder.classes_[predicted_class_index]