VerVelVel's picture
pochti final
6a8cbcb
raw
history blame
No virus
3.85 kB
import logging
import os
from functools import lru_cache

import torch
from aiogram import Bot, Dispatcher
from aiogram.filters.command import Command
from aiogram.types import Message, ReplyKeyboardMarkup, KeyboardButton, ReplyKeyboardRemove
from transformers import AutoTokenizer

from preprocess_text import TextPreprocessorBERT
from model import BERTClassifier
device = 'cpu'

# Object initialisation.
# The bot token is read from the environment instead of being hard-coded:
# anyone who can read this file could otherwise control the bot. The value
# that used to be committed here should be treated as compromised and revoked
# via @BotFather.
TOKEN = os.environ.get('BOT_TOKEN')
if not TOKEN:
    raise RuntimeError('Set the BOT_TOKEN environment variable to the Telegram bot token')
bot = Bot(token=TOKEN)
dp = Dispatcher()

# An explicit UTF-8 FileHandler keeps Cyrillic log lines (user names, message
# text) writable on platforms whose default locale encoding is not UTF-8;
# basicConfig(filename=...) would open the file with the locale encoding.
logging.basicConfig(
    handlers=[logging.FileHandler("mylog.log", encoding="utf-8")],
    level=logging.INFO,
)

# Reply keyboard with a single /start button.
start_keyboard = ReplyKeyboardMarkup(
    keyboard=[
        [KeyboardButton(text="/start")],
    ],
    resize_keyboard=True,
)
@lru_cache(maxsize=1)
def load_model(weights_path='bot/model_weights_new.pth'):
    """Build a BERTClassifier, load its trained weights and switch it to eval mode.

    The result is memoised by ``lru_cache``, so repeated calls with the same
    ``weights_path`` reuse one model instance instead of re-reading the file.

    Args:
        weights_path: Path to the saved ``state_dict``. A relative path is
            resolved against the current working directory, so the default
            assumes the process is started from the directory containing
            ``bot/``.

    Returns:
        The ``BERTClassifier`` moved to ``device`` with ``eval()`` applied.
    """
    model = BERTClassifier()
    # weights_only=True limits unpickling to tensors and plain containers,
    # which is all a state_dict holds, so a tampered weights file cannot
    # execute arbitrary code on load.
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model
@lru_cache(maxsize=1)
def load_tokenizer():
    """Return the Hugging Face tokenizer for ``cointegrated/rubert-tiny-toxicity``.

    ``lru_cache`` ensures the tokenizer is created only once per process.
    """
    checkpoint = 'cointegrated/rubert-tiny-toxicity'
    tok = AutoTokenizer.from_pretrained(checkpoint)
    return tok
# Load the classifier and tokenizer once at import time; both loaders are
# lru_cache-wrapped, so later calls with the same arguments return these
# same objects.
model, tokenizer = load_model(), load_tokenizer()
# Handle the /start command.
@dp.message(Command(commands=['start']))
async def proccess_command_start(message: Message):
    """Greet the user and remove any custom reply keyboard.

    The reply goes to the chat the command was sent in (``message.answer``)
    rather than to the user's private chat. In a group this keeps the
    greeting in the group. It also avoids a failure when the user has never
    started a private chat with the bot, which Telegram requires before a bot
    can message them directly.
    """
    user_name = message.from_user.full_name
    user_id = message.from_user.id
    text = f'Привет, {user_name}! Я помогу тебе оценить токсичность сообщений 😃'
    logging.info('%s %s запустил бота', user_name, user_id)
    await message.answer(text, reply_markup=ReplyKeyboardRemove())
# Show the "Start" button.
# Deliberately NOT registered as a handler: /start is already handled by
# proccess_command_start, and aiogram dispatches a message only to the first
# handler whose filters match. A second Command('start') handler would
# therefore never run. Call this coroutine directly, or register it under a
# distinct filter, to show the start keyboard.
async def send_welcome(message: Message):
    """Prompt the user to press /start and attach the one-button start keyboard."""
    await message.answer("Нажмите кнопку /start для начала работы", reply_markup=start_keyboard)
# Sticker sent after the verdict for each class.
TOXIC_STICKER_ID = 'CAACAgIAAxkBAAMrZll5jPH6HJ3j7kSLDEQU8NKDjR0AAhQAA5KfHhEGBsTRjH5zHDUE'
NON_TOXIC_STICKER_ID = 'CAACAgIAAxkBAAMtZll5udV6ScWrGUMhkJIFmvazQicAAlgAA5KfHhFUuZt-mMSZyTUE'
# Scores strictly above this value are reported as toxic.
TOXICITY_THRESHOLD = 0.5


def _predict_toxicity(text):
    """Preprocess and tokenize *text*, run the model, and return its sigmoid score.

    The score is interpreted as the probability that *text* is toxic.
    """
    preprocessed_text = TextPreprocessorBERT().transform(text)
    tokens = tokenizer.encode_plus(
        preprocessed_text,
        add_special_tokens=True,
        truncation=True,
        max_length=100,
        padding='max_length',
        return_tensors='pt',
    )
    input_ids = tokens['input_ids'].to(device)
    attention_mask = tokens['attention_mask'].to(device)
    with torch.no_grad():
        output = model(input_ids, attention_mask=attention_mask)
    # .item() requires the output to contain exactly one value, i.e. a single
    # logit for the single input — presumably what BERTClassifier returns; confirm.
    return torch.sigmoid(output).item()


def _build_reply(prediction):
    """Map a toxicity probability to the ``(reply text, sticker id)`` pair.

    The quoted probability always belongs to the reported label:
    ``prediction`` for a toxic verdict and ``1 - prediction`` for a non-toxic
    one. Otherwise "not toxic with probability p" would quote the toxic-class
    score.
    """
    if prediction > TOXICITY_THRESHOLD:
        predicted_class = "ТОКСИК!!!"
        confidence = prediction
        sticker_id = TOXIC_STICKER_ID
    else:
        predicted_class = 'Не токсик)'
        confidence = 1 - prediction
        sticker_id = NON_TOXIC_STICKER_ID
    response_text = f'{predicted_class} c вероятностью {round(confidence, 3)}'
    return response_text, sticker_id


@dp.message()
async def predict_sentence(message: Message):
    """Classify an incoming message and reply with the verdict and a sticker.

    Catch-all handler: it receives every message not claimed by an earlier
    handler. Non-text messages (stickers, photos, ...) have ``message.text``
    set to ``None``. Instead of passing that to the preprocessor, they get a
    short hint.
    """
    user_name = message.from_user.full_name
    user_id = message.from_user.id
    text = message.text
    if text is None:
        await message.answer('Отправь мне текстовое сообщение, и я оценю его токсичность')
        return

    # NOTE(review): preprocessing, tokenisation and the forward pass run
    # synchronously and block the event loop while they execute. Before
    # moving them to an executor, check that the tokenizer is safe to share
    # across threads.
    prediction = _predict_toxicity(text)
    response_text, sticker_id = _build_reply(prediction)

    # Reply in the originating chat (not the user's private chat).
    logging.info('%s %s: %s', user_name, user_id, text)
    await message.answer(response_text)
    await message.answer_sticker(sticker_id)
if __name__ == '__main__':
    # Hand control to aiogram's polling loop, which fetches updates for `bot`
    # and dispatches them to the handlers registered on `dp`.
    dp.run_polling(bot)