btarakcioglu
Fix: Added missing import for BertPreTrainedModel
41143b7
raw
history blame
1.31 kB
import torch
from torch import nn
from transformers import BertModel, BertPreTrainedModel
class CustomBertModel(BertPreTrainedModel):
    """BERT encoder followed by a 4-layer MLP head producing one sigmoid score.

    The first six encoder layers are frozen at construction time. The
    embeddings, the remaining encoder layers, the pooler and the MLP head
    stay trainable.
    """

    def __init__(self, config):
        """Build the BERT encoder and the MLP head.

        Args:
            config: BERT configuration. ``config.hidden_size`` determines the
                width of the pooled output fed into the first head layer.
        """
        super().__init__(config)
        self.bert = BertModel(config)

        # Freeze the first 6 encoder layers. Slicing a ModuleList yields a
        # ModuleList, so .parameters() covers every weight in those layers;
        # the slice also tolerates configs with fewer than 6 layers.
        for param in self.bert.encoder.layer[:6].parameters():
            param.requires_grad = False

        self.dropout = nn.Dropout(0.22)
        # Derive the input width from the config rather than hard-coding 768,
        # so non-base checkpoints (e.g. hidden_size=1024) also work.
        self.fc1 = nn.Linear(config.hidden_size, 512)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(512, 512)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(512, 128)
        self.relu3 = nn.ReLU()
        self.fc4 = nn.Linear(128, 1)
        self.sigmoid = nn.Sigmoid()

        # post_init() is the hook newer transformers releases expect; it calls
        # init_weights() and performs extra setup. Fall back on older releases.
        if hasattr(self, "post_init"):
            self.post_init()
        else:
            self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Score a batch of token sequences.

        Args:
            input_ids: Token ids passed to ``BertModel`` (presumably shaped
                (batch, seq_len); confirm against the caller).
            attention_mask: Optional mask forwarded to ``BertModel``.
            token_type_ids: Optional segment ids forwarded to ``BertModel``.

        Returns:
            Sigmoid outputs in (0, 1) with a last dimension of size 1 (fc4 has a
            single output feature). These are probabilities, not raw logits:
            pair them with ``nn.BCELoss``, not ``nn.BCEWithLogitsLoss``.
        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            # Force a ModelOutput so .pooler_output exists even when the
            # config has return_dict disabled (a plain tuple would be returned).
            return_dict=True,
        )
        x = self.dropout(outputs.pooler_output)
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        x = self.relu3(self.fc3(x))
        x = self.fc4(x)
        probs = self.sigmoid(x)
        return probs