from typing import Optional

import torch
from torch import nn
from transformers import BertModel, BertPreTrainedModel


class CustomBertModel(BertPreTrainedModel):
    """BERT encoder followed by a 4-layer MLP head that emits one sigmoid score.

    The lowest six encoder layers are frozen; the embeddings, the remaining
    encoder layers, the pooler and the MLP head stay trainable.

    NOTE: ``forward`` returns probabilities in (0, 1), already passed through
    a sigmoid, not raw logits. Train it with ``nn.BCELoss``;
    ``nn.BCEWithLogitsLoss`` would apply a second sigmoid.
    """

    def __init__(self, config):
        """Build the encoder and classification head from a transformers config.

        Args:
            config: BERT model configuration. ``config.hidden_size`` sizes the
                head's input layer.
        """
        super().__init__(config)
        self.bert = BertModel(config)

        # Freeze only encoder layers 0-5. The embeddings and pooler are not
        # touched here, so they remain trainable. Slicing a ModuleList past
        # its end behaves like list slicing: a config with fewer than 6
        # layers simply gets all of its encoder layers frozen.
        for param in self.bert.encoder.layer[:6].parameters():
            param.requires_grad = False

        self.dropout = nn.Dropout(0.22)
        # Size the first head layer from the config instead of hard-coding
        # 768. Otherwise configs with a different hidden size (e.g. 1024)
        # fail with a shape mismatch. Identical to before for BERT-base.
        self.fc1 = nn.Linear(config.hidden_size, 512)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(512, 512)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(512, 128)
        self.relu3 = nn.ReLU()
        self.fc4 = nn.Linear(128, 1)
        self.sigmoid = nn.Sigmoid()

        # Run the transformers weight-initialisation hook over all submodules.
        # NOTE(review): newer transformers releases recommend post_init();
        # init_weights() is kept to match the library version this targets.
        self.init_weights()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Score a batch of token sequences.

        Args:
            input_ids: Token ids passed straight to ``BertModel``
                (presumably shape (batch, seq_len); confirm against callers).
            attention_mask: Optional mask forwarded to ``BertModel``.
            token_type_ids: Optional segment ids forwarded to ``BertModel``.

        Returns:
            Sigmoid probabilities whose last dimension has size 1 (the output
            width of ``fc4``); leading dimensions follow ``pooler_output``.
        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Classify from BERT's pooled [CLS] representation.
        pooled_output = outputs.pooler_output

        x = self.dropout(pooled_output)
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        x = self.relu3(self.fc3(x))
        x = self.fc4(x)

        # These are probabilities, not logits (see class docstring).
        probs = self.sigmoid(x)
        return probs