from torch import nn
from transformers import BertModel
from transformers.modeling_utils import PreTrainedModel

from .configuration_my_bert_classifier import MyBertClassifierConfig


class MyBertClassifier(PreTrainedModel):
    config_class = MyBertClassifierConfig

    def __init__(self, config):
        super().__init__(config)
        # Pretrained BERT encoder followed by a small classification head
        # (dropout -> linear projection to 5 classes -> ReLU).
        self.bert = BertModel.from_pretrained('bert-base-cased')
        self.dropout = nn.Dropout(0.5)
        self.linear = nn.Linear(768, 5)
        self.relu = nn.ReLU()

    def forward(self, input_id, mask):
        # With return_dict=False, BertModel returns a tuple whose second
        # element is the pooled [CLS] representation.
        _, pooled_output = self.bert(input_ids=input_id, attention_mask=mask, return_dict=False)
        dropout_output = self.dropout(pooled_output)
        linear_output = self.linear(dropout_output)
        final_layer = self.relu(linear_output)
        return final_layer
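
A minimal usage sketch follows, assuming the classes above live in configuration_my_bert_classifier.py and modeling_my_bert_classifier.py on the Python path and that MyBertClassifierConfig can be constructed with its defaults; the module names, sample sentence, and padding settings are illustrative assumptions, not part of the original code.

import torch
from transformers import BertTokenizer

from configuration_my_bert_classifier import MyBertClassifierConfig  # assumed module name
from modeling_my_bert_classifier import MyBertClassifier             # assumed module name

# Build the custom config and model; the BERT backbone weights are
# downloaded inside MyBertClassifier.__init__ via from_pretrained.
config = MyBertClassifierConfig()
model = MyBertClassifier(config)
model.eval()

# Tokenize a sample sentence with the matching cased tokenizer.
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
inputs = tokenizer("The movie was surprisingly good.",
                   padding='max_length', max_length=32,
                   truncation=True, return_tensors='pt')

# Forward pass: returns ReLU-activated scores over the 5 classes.
with torch.no_grad():
    scores = model(inputs['input_ids'], inputs['attention_mask'])

print(scores.shape)          # torch.Size([1, 5])
print(scores.argmax(dim=1))  # predicted class index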