File size: 3,845 Bytes
cdb0abe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a8cbcb
cdb0abe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import logging
import os
from functools import lru_cache

import torch
from aiogram import Bot, Dispatcher
from aiogram.filters.command import Command
from aiogram.types import Message, ReplyKeyboardMarkup, KeyboardButton, ReplyKeyboardRemove
from transformers import AutoTokenizer

from preprocess_text import TextPreprocessorBERT
from model import BERTClassifier


# All tensors and the model are placed on this device.
device = 'cpu'

# Object initialisation.
# SECURITY: the literal token below was committed to source control and must be
# treated as compromised. Revoke it via @BotFather and supply the new token through
# the BOT_TOKEN environment variable. The literal is kept only as a fallback so
# existing deployments keep starting; remove it once BOT_TOKEN is configured.
TOKEN = os.getenv('BOT_TOKEN') or '6864353709:AAHM-J59cETYpxWzJFdHpm9QyV7rE2FL_KU'
bot = Bot(token=TOKEN)
dp = Dispatcher()
logging.basicConfig(filename="mylog.log", level=logging.INFO)

# Reply keyboard containing a single "/start" button.
start_keyboard = ReplyKeyboardMarkup(
    keyboard=[
        [KeyboardButton(text="/start")]
    ],
    resize_keyboard=True
)


@lru_cache(maxsize=1)
def load_model():
    """Build the BERT classifier, load its trained weights and return it ready for inference.

    The weights are mapped onto ``device``, and the model is moved there and switched
    to eval mode. ``lru_cache(maxsize=1)`` on this zero-argument function means the
    weights file is read only on the first call; later calls return the same object.
    """
    # Relative path, so it is resolved against the current working directory.
    weights_path = 'bot/model_weights_new.pth'
    classifier = BERTClassifier()
    classifier.load_state_dict(torch.load(weights_path, map_location=device))
    classifier.to(device)
    classifier.eval()
    return classifier

@lru_cache(maxsize=1)
def load_tokenizer():
    """Return the tokenizer of the ``cointegrated/rubert-tiny-toxicity`` checkpoint.

    Cached, so the tokenizer is loaded only on the first call.
    """
    checkpoint = 'cointegrated/rubert-tiny-toxicity'
    return AutoTokenizer.from_pretrained(checkpoint)

# Load the classifier and tokenizer once at import time; the handlers share these globals.
model, tokenizer = load_model(), load_tokenizer()

# /start command handler
@dp.message(Command(commands=['start']))
async def proccess_command_start(message: Message):
    """Greet the sender on /start, log the launch and remove any reply keyboard."""
    sender = message.from_user
    greeting = f'Привет, {sender.full_name}! Я помогу тебе оценить токсичность сообщений 😃'
    logging.info(f'{sender.full_name} {sender.id} запустил бота')
    await bot.send_message(
        chat_id=sender.id,
        text=greeting,
        reply_markup=ReplyKeyboardRemove(),
    )

# Intended to show the "/start" button keyboard.
# NOTE(review): this uses the same Command('start') filter as proccess_command_start,
# which is registered first. aiogram dispatches to the first matching handler, so
# this one is most likely never invoked. Confirm, then re-target or remove it.
@dp.message(Command(commands=['start']))
async def send_welcome(message: Message):
    """Ask the sender to press /start and attach the one-button start keyboard."""
    prompt = "Нажмите кнопку /start для начала работы"
    await bot.send_message(
        chat_id=message.from_user.id,
        text=prompt,
        reply_markup=start_keyboard,
    )


def _predict_toxicity(text):
    """Return the model's toxicity probability (sigmoid of its output) for *text*.

    The text is cleaned with ``TextPreprocessorBERT`` and tokenized with the
    module-level ``tokenizer``, truncated or padded to exactly 100 tokens. It is then
    scored by the module-level ``model`` on ``device`` without gradient tracking.
    """
    preprocessed_text = TextPreprocessorBERT().transform(text)
    tokens = tokenizer.encode_plus(
        preprocessed_text,
        add_special_tokens=True,
        truncation=True,
        max_length=100,
        padding='max_length',
        return_tensors='pt',
    )
    input_ids = tokens['input_ids'].to(device)
    attention_mask = tokens['attention_mask'].to(device)
    with torch.no_grad():
        output = model(input_ids, attention_mask=attention_mask)
    # .item() requires a single-element tensor; this assumes the model emits one logit
    # per input (batch of 1 here). TODO confirm against BERTClassifier.
    return torch.sigmoid(output).item()


def _format_prediction(prediction, threshold=0.5):
    """Map a toxicity probability to a ``(reply_text, sticker_file_id)`` pair.

    Probabilities strictly above *threshold* are labelled toxic.
    """
    if prediction > threshold:
        predicted_class = "ТОКСИК!!!"
        response_text = f'{predicted_class} c вероятностью {round(prediction, 3)}'
        sticker_id = 'CAACAgIAAxkBAAMrZll5jPH6HJ3j7kSLDEQU8NKDjR0AAhQAA5KfHhEGBsTRjH5zHDUE'
    else:
        predicted_class = 'Не токсик)'
        # NOTE(review): this branch prints the *toxicity* probability (<= threshold)
        # next to the "not toxic" label. An earlier version printed 1 - prediction.
        # Confirm which number users are meant to see.
        response_text = f'{predicted_class} c вероятностью  {round(prediction, 3)}'
        sticker_id = 'CAACAgIAAxkBAAMtZll5udV6ScWrGUMhkJIFmvazQicAAlgAA5KfHhFUuZt-mMSZyTUE'
    return response_text, sticker_id


@dp.message()
async def predict_sentence(message: Message):
    """Score an incoming message for toxicity; reply with the verdict and a sticker.

    This is a catch-all handler (registered with no filters). Replies go to the chat
    the message came from, which is the same id as the user in private chats.
    Messages without text (stickers, photos, ...) get a short explanation instead of
    crashing the model pipeline.
    """
    user = message.from_user
    # from_user can be absent (e.g. channel posts), so do not dereference it blindly.
    user_name = user.full_name if user else 'unknown'
    user_id = user.id if user else None
    chat_id = message.chat.id
    text = message.text

    if not text:
        logging.info(f'{user_name} {user_id}: <non-text message>')
        await bot.send_message(chat_id=chat_id, text='Я умею оценивать только текстовые сообщения.')
        return

    prediction = _predict_toxicity(text)
    response_text, sticker_id = _format_prediction(prediction)

    # Send the reply to the user.
    logging.info(f'{user_name} {user_id}: {text}')
    await bot.send_message(chat_id=chat_id, text=response_text)
    await bot.send_sticker(chat_id=chat_id, sticker=sticker_id)

if __name__ == '__main__':
    # Poll Telegram for updates and dispatch them to the handlers above.
    dp.run_polling(bot)