import torch
from torch import nn
from transformers import AutoTokenizer, AutoModelForSequenceClassification


class BERTClassifier(nn.Module):
    """Single-output classifier: frozen pretrained encoder + small trainable MLP head.

    The checkpoint's ``classifier`` layer is replaced by a square linear
    projection (hidden_size -> hidden_size). Every parameter of ``self.bert``
    (including that new projection) is then frozen, so only ``self.linear``
    is trainable. ``forward`` returns the head's single output per example
    with no final activation.

    Args:
        model_name: Hugging Face checkpoint passed to
            ``AutoModelForSequenceClassification.from_pretrained``. Defaults to
            the originally hard-coded ``'cointegrated/rubert-tiny-toxicity'``.
            Assumes a BERT-style model whose ``classifier`` attribute maps a
            (batch, hidden_size) pooled vector -- TODO confirm for other
            architectures.
        head_hidden: Width of the hidden layer in the trainable head
            (originally hard-coded to 128).
        dropout: Dropout probability in the head (originally ``nn.Dropout()``,
            whose default is 0.5).
    """

    def __init__(
        self,
        model_name='cointegrated/rubert-tiny-toxicity',
        head_hidden=128,
        dropout=0.5,
    ):
        super().__init__()
        self.bert = AutoModelForSequenceClassification.from_pretrained(model_name)
        # Read the encoder width from the model config instead of hard-coding
        # 312, so a different checkpoint does not silently break the shapes.
        hidden_size = self.bert.config.hidden_size
        # NOTE(review): this projection is created *before* the freeze loop
        # below, so it is frozen at its random initialization and never
        # trained. It is kept (rather than swapped for nn.Identity) so that
        # existing state_dicts containing ``bert.classifier.*`` still load;
        # confirm whether a trainable projection was intended.
        self.bert.classifier = nn.Linear(hidden_size, hidden_size)
        # Freeze the whole wrapped model; only ``self.linear`` receives gradients.
        for param in self.bert.parameters():
            param.requires_grad = False
        self.linear = nn.Sequential(
            nn.Linear(hidden_size, head_hidden),
            nn.Sigmoid(),
            nn.Dropout(dropout),
            nn.Linear(head_hidden, 1),
        )

    def forward(self, x, attention_mask=None):
        """Score a batch of tokenized inputs.

        Args:
            x: Passed as the first positional argument to the wrapped model
                (presumably ``input_ids`` of shape (batch, seq_len) -- TODO
                confirm against callers).
            attention_mask: Optional mask forwarded unchanged to the model.

        Returns:
            The head output with dim 1 squeezed away. This is presumably a
            (batch,) tensor of raw scores with no sigmoid applied, e.g. for
            ``BCEWithLogitsLoss``. Verify against the training code.
        """
        # For BERT-style models ``.logits`` is the output of the replaced
        # ``classifier`` projection, i.e. a hidden_size-wide feature vector.
        bert_out = self.bert(x, attention_mask=attention_mask).logits
        out = self.linear(bert_out).squeeze(1)
        return out