File size: 784 Bytes
087390d
 
 
 
 
28a78d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModelForSequenceClassification


class BERTClassifier(nn.Module):
    """Single-score classifier on top of a frozen ``rubert-tiny-toxicity`` model.

    The pretrained model's classification head is replaced by a trainable
    ``hidden -> hidden`` projection, so the model's ``.logits`` output is used
    as a feature vector. A small MLP then reduces it to one raw score per
    example. All other pretrained weights are frozen.

    Args:
        head_hidden: Width of the MLP's hidden layer (default 128).
        dropout: Dropout probability inside the MLP (default 0.5, the
            ``nn.Dropout`` default).
    """

    def __init__(self, *, head_hidden=128, dropout=0.5):
        super().__init__()
        self.bert = AutoModelForSequenceClassification.from_pretrained('cointegrated/rubert-tiny-toxicity')
        # Freeze the pretrained weights BEFORE swapping the head. Freezing
        # afterwards also froze the freshly initialised classifier, leaving a
        # random projection that never received gradient updates.
        for param in self.bert.parameters():
            param.requires_grad = False
        # Width of the features the model feeds its classification head
        # (312 for rubert-tiny, the value previously hard-coded here).
        bert_dim = self.bert.config.hidden_size
        # New layers default to requires_grad=True, so this head is trainable.
        self.bert.classifier = nn.Linear(bert_dim, bert_dim)
        self.linear = nn.Sequential(
            nn.Linear(bert_dim, head_hidden),
            nn.Sigmoid(),
            nn.Dropout(dropout),
            nn.Linear(head_hidden, 1),
        )

    def forward(self, x: torch.Tensor, attention_mask: "torch.Tensor | None" = None) -> torch.Tensor:
        """Score a batch of token sequences.

        Args:
            x: Token ids passed positionally to the model; presumably
                (batch, seq_len) ids from the matching tokenizer (TODO confirm).
            attention_mask: Optional mask forwarded unchanged to the model.

        Returns:
            One raw (pre-sigmoid) score per example, shape (batch,).
        """
        features = self.bert(x, attention_mask=attention_mask).logits
        # (batch, 1) -> (batch,). squeeze(1) keeps the batch dim even when batch == 1.
        return self.linear(features).squeeze(1)