arosyihuddin's picture
update
ecfd12f
raw
history blame
473 Bytes
from transformers import BertForTokenClassification
import torch
class BertModel(torch.nn.Module):
    """Thin ``torch.nn.Module`` wrapper around ``BertForTokenClassification``.

    The wrapped Hugging Face model is stored on ``self.bert``; every forward
    call is delegated to it with ``return_dict=False``.
    """

    def __init__(self, pretrained_model, num_labels):
        """Load the pretrained token-classification model.

        Args:
            pretrained_model: Passed straight to
                ``BertForTokenClassification.from_pretrained`` (presumably a
                hub model id or a local checkpoint path -- confirm with the
                caller).
            num_labels: Forwarded as the ``num_labels`` keyword of
                ``from_pretrained``.
        """
        super().__init__()
        self.bert = BertForTokenClassification.from_pretrained(
            pretrained_model, num_labels=num_labels
        )

    def forward(self, input_id, mask=None, label=None):
        """Run the wrapped model and return its raw output.

        Args:
            input_id: Forwarded as ``input_ids``.
            mask: Forwarded as ``attention_mask``. Optional.
            label: Forwarded as ``labels``. Optional, so the module can also
                be called at inference time when no gold labels exist.

        Returns:
            Whatever ``self.bert`` returns for ``return_dict=False``. Per the
            transformers documentation this is a plain tuple whose first
            element is the loss when ``label`` is given and the logits
            otherwise -- callers indexing the tuple should account for that.
        """
        # Both optional arguments default to None, which the wrapped model
        # accepts; existing three-argument callers are unaffected.
        return self.bert(
            input_ids=input_id,
            attention_mask=mask,
            labels=label,
            return_dict=False,
        )