# Telegram bot: classifies incoming messages for toxicity with a BERT model.
import logging | |
import torch | |
from aiogram import Bot, Dispatcher | |
from aiogram.types import Message, ReplyKeyboardMarkup, KeyboardButton, ReplyKeyboardRemove | |
from aiogram.filters.command import Command | |
from functools import lru_cache | |
from preprocess_text import TextPreprocessorBERT | |
from model import BERTClassifier | |
from transformers import AutoTokenizer | |
# Inference runs on CPU only (model is tiny; no GPU assumed at deploy time).
device = 'cpu'
# Object initialization
# SECURITY NOTE(review): bot token is hardcoded in source — it should be
# rotated and loaded from an environment variable / secrets store instead.
TOKEN = '6864353709:AAHM-J59cETYpxWzJFdHpm9QyV7rE2FL_KU'
bot = Bot(token=TOKEN)
dp = Dispatcher()
# All user interactions are appended to mylog.log at INFO level.
logging.basicConfig(filename="mylog.log", level=logging.INFO)
# Persistent reply keyboard with a single "/start" button.
start_keyboard = ReplyKeyboardMarkup(
    keyboard=[
        [KeyboardButton(text="/start")]
    ],
    resize_keyboard=True
)
@lru_cache(maxsize=1)
def load_model(weights_path: str = 'bot/model_weights_new.pth'):
    """Build the BERT toxicity classifier and load trained weights.

    The weights path is parameterized (previously hard-coded) with the
    original value as default, so existing callers are unaffected.
    Results are cached via ``lru_cache`` so repeated calls do not reload
    the weights from disk.

    Args:
        weights_path: Path to the ``state_dict`` checkpoint file.

    Returns:
        The classifier in eval mode, moved to the module-level ``device``.
    """
    model = BERTClassifier()
    # map_location keeps CPU-only hosts from failing on GPU-saved checkpoints.
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()  # disable dropout/batchnorm updates for inference
    return model
@lru_cache(maxsize=1)
def load_tokenizer(model_name: str = 'cointegrated/rubert-tiny-toxicity'):
    """Fetch the pretrained tokenizer matching the toxicity model.

    The checkpoint name is parameterized (previously hard-coded) with the
    original value as default; ``lru_cache`` avoids re-downloading /
    re-parsing the tokenizer on repeated calls, mirroring ``load_model``.

    Args:
        model_name: HuggingFace hub id of the tokenizer checkpoint.

    Returns:
        An ``AutoTokenizer`` instance for the given checkpoint.
    """
    return AutoTokenizer.from_pretrained(model_name)
# Load model and tokenizer once at import time; handlers below use these globals.
model = load_model()
tokenizer = load_tokenizer()
# Handle the /start command
@dp.message(Command('start'))
async def proccess_command_start(message: Message):
    """Greet the user on /start, log the event, and hide any reply keyboard.

    NOTE(review): previously this handler was never registered on the
    dispatcher (no decorator and no ``dp.message.register`` call anywhere
    in the file), so the bot polled but ignored /start; the ``Command``
    import was unused. The function name keeps its original spelling
    ('proccess') to avoid breaking any external references.
    """
    user_name = message.from_user.full_name
    user_id = message.from_user.id
    text = f'Привет, {user_name}! Я помогу тебе оценить токсичность сообщений 😃'
    logging.info(f'{user_name} {user_id} запустил бота')
    await bot.send_message(chat_id=user_id, text=text, reply_markup=ReplyKeyboardRemove())
# Добавление кнопки "Start" при старте | |
async def send_welcome(message: Message): | |
user_id = message.from_user.id | |
await bot.send_message(chat_id=user_id, text="Нажмите кнопку /start для начала работы", reply_markup=start_keyboard) | |
@dp.message()
async def predict_sentence(message: Message):
    """Classify an incoming message for toxicity and reply with the verdict.

    Pipeline: preprocess text -> tokenize (max 100 tokens, padded) ->
    run the BERT classifier -> sigmoid -> reply with class, probability
    and a matching sticker; the interaction is logged.

    Fixes vs. original:
    - Registered on the dispatcher (previously never wired up, so the bot
      received no messages). Declared after the /start handler, so command
      handling keeps priority in aiogram's registration order.
    - The non-toxic branch now reports P(not toxic) = 1 - P(toxic);
      previously it displayed the toxic probability (the commented-out
      line in the original showed the intended formula).
    """
    user_name = message.from_user.full_name
    user_id = message.from_user.id
    text = message.text
    # Normalize the raw text for the BERT tokenizer.
    preprocessor = TextPreprocessorBERT()
    preprocessed_text = preprocessor.transform(text)
    tokens = tokenizer.encode_plus(
        preprocessed_text,
        add_special_tokens=True,
        truncation=True,
        max_length=100,
        padding='max_length',
        return_tensors='pt'
    )
    # Extract input_ids and attention_mask and move them to the model device.
    input_ids = tokens['input_ids'].to(device)
    attention_mask = tokens['attention_mask'].to(device)
    # Inference only — no gradient tracking needed.
    with torch.no_grad():
        output = model(input_ids, attention_mask=attention_mask)
    # Single-logit binary head: sigmoid gives P(toxic) in [0, 1].
    # NOTE(review): assumes the model returns one logit — confirm with model.py.
    prediction = torch.sigmoid(output).item()
    if prediction > 0.5:
        predicted_class = "ТОКСИК!!!"
        response_text = f'{predicted_class} c вероятностью {round(prediction, 3)}'
        sticker_id = 'CAACAgIAAxkBAAMrZll5jPH6HJ3j7kSLDEQU8NKDjR0AAhQAA5KfHhEGBsTRjH5zHDUE'
    else:
        predicted_class = 'Не токсик)'
        # Report the probability of the predicted (non-toxic) class.
        response_text = f'{predicted_class} c вероятностью {round(1 - prediction, 3)}'
        sticker_id = 'CAACAgIAAxkBAAMtZll5udV6ScWrGUMhkJIFmvazQicAAlgAA5KfHhFUuZt-mMSZyTUE'
    # Reply to the user and log the exchange.
    logging.info(f'{user_name} {user_id}: {text}')
    await bot.send_message(chat_id=user_id, text=response_text)
    await bot.send_sticker(chat_id=user_id, sticker=sticker_id)
if __name__ == '__main__':
    # Start long polling; blocks until the process is stopped.
    dp.run_polling(bot)