"""Telegram bot that scores incoming text messages for toxicity.

A BERT-based classifier and its tokenizer are loaded once at import time;
every text message sent to the bot is preprocessed, tokenized, scored, and
answered with a verdict line plus a sticker.
"""
import logging
import os
from functools import lru_cache

import torch
from aiogram import Bot, Dispatcher
from aiogram.filters.command import Command
from aiogram.types import Message, ReplyKeyboardMarkup, KeyboardButton, ReplyKeyboardRemove
from transformers import AutoTokenizer

from model import BERTClassifier
from preprocess_text import TextPreprocessorBERT

device = 'cpu'

# Object initialisation.
# SECURITY: the fallback token below is committed to source control. It should
# be revoked and supplied only through the BOT_TOKEN environment variable; the
# literal is kept solely so existing deployments keep working unchanged.
TOKEN = os.environ.get('BOT_TOKEN', '6864353709:AAHM-J59cETYpxWzJFdHpm9QyV7rE2FL_KU')
bot = Bot(token=TOKEN)
dp = Dispatcher()
logging.basicConfig(filename="mylog.log", level=logging.INFO)

# Reply keyboard holding a single "/start" button.
start_keyboard = ReplyKeyboardMarkup(
    keyboard=[
        [KeyboardButton(text="/start")]
    ],
    resize_keyboard=True
)


@lru_cache(maxsize=1)
def load_model():
    """Build the classifier, load its weights, and return it in eval mode.

    The result is memoised, so repeated calls return the same instance.
    """
    model = BERTClassifier()
    weights_path = 'bot/model_weights_new.pth'
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


@lru_cache(maxsize=1)
def load_tokenizer():
    """Return the (memoised) tokenizer for 'cointegrated/rubert-tiny-toxicity'."""
    return AutoTokenizer.from_pretrained('cointegrated/rubert-tiny-toxicity')


model = load_model()
tokenizer = load_tokenizer()


# Handle the /start command.
@dp.message(Command(commands=['start']))
async def proccess_command_start(message: Message) -> None:
    """Greet the user, log the start event, and remove any reply keyboard."""
    user_name = message.from_user.full_name
    user_id = message.from_user.id
    text = f'Привет, {user_name}! \nЯ помогу тебе оценить токсичность сообщений 😃'
    logging.info(f'{user_name} {user_id} запустил бота')
    await bot.send_message(chat_id=user_id, text=text, reply_markup=ReplyKeyboardRemove())


# Show the "Start" button on start.
# NOTE(review): this handler is registered after proccess_command_start with
# the same Command filter. aiogram appears to stop at the first matching
# handler, so this one is not expected to run — confirm the intended behaviour
# before merging or removing it.
@dp.message(Command(commands=['start']))
async def send_welcome(message: Message) -> None:
    """Send a prompt together with the "/start" reply keyboard."""
    user_id = message.from_user.id
    await bot.send_message(
        chat_id=user_id,
        text="Нажмите кнопку /start для начала работы",
        reply_markup=start_keyboard,
    )


def _toxicity_probability(text: str) -> float:
    """Return sigmoid(model output) for *text* as a Python float.

    Assumes the model emits a single value per input (``.item()`` requires a
    one-element tensor) — TODO confirm against BERTClassifier.
    """
    # Preprocess the message.
    preprocessor = TextPreprocessorBERT()
    preprocessed_text = preprocessor.transform(text)
    tokens = tokenizer.encode_plus(
        preprocessed_text,
        add_special_tokens=True,
        truncation=True,
        max_length=100,
        padding='max_length',
        return_tensors='pt'
    )

    # Extract input_ids and attention_mask from the encoded tokens.
    input_ids = tokens['input_ids'].to(device)
    attention_mask = tokens['attention_mask'].to(device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        output = model(input_ids, attention_mask=attention_mask)

    return torch.sigmoid(output).item()


@dp.message()
async def predict_sentence(message: Message) -> None:
    """Score a text message and reply with the verdict and a sticker.

    Messages without text (photos, stickers, voice, ...) are logged and
    ignored: ``message.text`` is ``None`` for them and cannot be preprocessed.
    """
    user_name = message.from_user.full_name
    user_id = message.from_user.id
    text = message.text

    if text is None:
        logging.info(f'{user_name} {user_id}: non-text message ignored')
        return

    prediction = _toxicity_probability(text)

    # Interpret the result.
    # NOTE: both branches report the raw toxicity probability; the non-toxic
    # branch deliberately does not show 1 - prediction.
    if prediction > 0.5:
        predicted_class = "ТОКСИК!!!"
        response_text = f'{predicted_class} c вероятностью {round(prediction, 3)}'
        sticker_id = 'CAACAgIAAxkBAAMrZll5jPH6HJ3j7kSLDEQU8NKDjR0AAhQAA5KfHhEGBsTRjH5zHDUE'
    else:
        predicted_class = 'Не токсик)'
        response_text = f'{predicted_class} c вероятностью {round(prediction, 3)}'
        sticker_id = 'CAACAgIAAxkBAAMtZll5udV6ScWrGUMhkJIFmvazQicAAlgAA5KfHhFUuZt-mMSZyTUE'

    # Send the reply to the user.
    logging.info(f'{user_name} {user_id}: {text}')
    await bot.send_message(chat_id=user_id, text=response_text)
    await bot.send_sticker(chat_id=user_id, sticker=sticker_id)


if __name__ == '__main__':
    dp.run_polling(bot)