File size: 1,309 Bytes
41143b7 a3ba203 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
import torch
from torch import nn
from transformers import BertModel, BertPreTrainedModel
class CustomBertModel(BertPreTrainedModel):
    """BERT encoder with a 4-layer MLP head for binary scoring.

    The first 6 encoder layers are frozen; only the upper encoder layers and
    the classification head are trained. ``forward`` returns a sigmoid
    probability in (0, 1), so pair it with ``nn.BCELoss`` (not
    ``BCEWithLogitsLoss``, which expects raw logits).
    """

    def __init__(self, config):
        """Build the model from a ``BertConfig``.

        Args:
            config: HuggingFace ``BertConfig`` describing the encoder.
        """
        super().__init__(config)
        self.bert = BertModel(config)

        # Freeze the first 6 encoder layers to reduce trainable parameters.
        # (ModuleList slicing returns a ModuleList, so .parameters() is valid.)
        for param in self.bert.encoder.layer[:6].parameters():
            param.requires_grad = False

        # Use config.hidden_size instead of a hard-coded 768 so the head
        # matches any BERT variant (bert-base: 768, bert-large: 1024, ...).
        # Backward-compatible: identical for the base config this was written for.
        hidden_size = config.hidden_size

        self.dropout = nn.Dropout(0.22)
        self.fc1 = nn.Linear(hidden_size, 512)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(512, 512)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(512, 128)
        self.relu3 = nn.ReLU()
        self.fc4 = nn.Linear(128, 1)
        self.sigmoid = nn.Sigmoid()

        # Initialize head weights per the BERT scheme; frozen layers keep
        # requires_grad=False through re-initialization.
        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Score a batch of token sequences.

        Args:
            input_ids: LongTensor of token ids, shape (batch, seq_len).
            attention_mask: optional mask, 1 for real tokens, 0 for padding.
            token_type_ids: optional segment ids for sentence-pair inputs.

        Returns:
            FloatTensor of shape (batch, 1) with sigmoid probabilities.
        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # pooler_output: [CLS] representation passed through BERT's pooler
        # (Linear + tanh), shape (batch, hidden_size).
        pooled_output = outputs.pooler_output

        x = self.dropout(pooled_output)
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.fc3(x)
        x = self.relu3(x)
        x = self.fc4(x)
        # Renamed from "logits": this is the post-sigmoid probability,
        # not a raw logit. Return value is unchanged.
        probs = self.sigmoid(x)
        return probs
|