# NOTE(review): removed web-scrape artifacts that preceded the code
# (a "File size: 454 Bytes" header, commit hash 54e216c, and a
# line-number gutter) — they were not part of the source and made
# the file invalid Python.
from torch import nn
from transformers import XLMRobertaModel
class XLMRobertaForSequenceClassification(XLMRobertaModel):
    """XLM-RoBERTa with a linear classification head on the pooled output.

    NOTE(review): this shadows the ``transformers`` class of the same name;
    unlike the library version it returns raw logits only (no loss).
    """

    def __init__(self, config):
        """Build the base model plus a classification head.

        Args:
            config: model configuration; ``config.hidden_size`` sets the
                head's input width and ``config.num_labels`` (default 2,
                matching the previous hard-coded behavior) its output width.
        """
        super().__init__(config)
        # Was nn.Linear(768, 2): hard-coding 768 breaks for any variant
        # whose hidden size differs (e.g. xlm-roberta-large uses 1024).
        # Derive both dimensions from the config instead.
        num_labels = getattr(config, "num_labels", 2)
        self.classifier = nn.Linear(config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return classification logits of shape (batch, num_labels)."""
        outputs = super().forward(input_ids=input_ids, attention_mask=attention_mask)
        # outputs[1] is the pooled output (first token passed through the
        # pooler) — assumes the base model exposes a pooler; TODO confirm
        # for the transformers version in use.
        return self.classifier(outputs[1])