from typing import Optional

import torch
from torch import nn, Tensor
from transformers import DebertaV2PreTrainedModel, DebertaV2Model

from .configuration_deberta_multi import MultiHeadDebertaV2Config


class MultiHeadDebertaForSequenceClassificationModel(DebertaV2PreTrainedModel):
    """DeBERTa-v2 encoder topped with several parallel 4-way classification heads.

    Each head is an independent linear classifier over the [CLS] representation,
    so a single forward pass produces `config.num_heads` sets of logits.
    """

    config_class = MultiHeadDebertaV2Config

    def __init__(self, config):
        super().__init__(config)
        self.deberta = DebertaV2Model(config)
        # One linear classifier per head, each mapping the hidden state to 4 classes.
        self.heads = nn.ModuleList(
            [nn.Linear(config.hidden_size, 4) for _ in range(config.num_heads)]
        )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Initialize weights and apply final processing (standard HF hook).
        self.post_init()

    def forward(
        self,
        input_ids: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
    ) -> Tensor:
        outputs = self.deberta(input_ids=input_ids, attention_mask=attention_mask)
        # Last hidden states, shape (batch_size, seq_len, hidden_size).
        sequence_output = outputs[0]
        # Classify from the [CLS] token (position 0), one set of logits per head.
        logits_list = [
            head(self.dropout(sequence_output[:, 0, :])) for head in self.heads
        ]
        # Stack per-head logits into (batch_size, num_heads, 4).
        logits = torch.stack(logits_list, dim=1)
        return logits
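

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the model definition. It assumes
    # MultiHeadDebertaV2Config subclasses DebertaV2Config and accepts the usual
    # DeBERTa-v2 kwargs plus `num_heads` (the number of classification heads);
    # adjust if the real config signature differs. Run with
    # `python -m <package>.<module>` so the relative import above resolves.
    config = MultiHeadDebertaV2Config(
        vocab_size=128,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=64,
        num_heads=3,
    )
    model = MultiHeadDebertaForSequenceClassificationModel(config)
    model.eval()

    # Dummy batch: 2 sequences of length 8 with a full attention mask.
    input_ids = torch.randint(0, config.vocab_size, (2, 8))
    attention_mask = torch.ones_like(input_ids)
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask)
    print(logits.shape)  # expected: torch.Size([2, 3, 4]) = (batch, heads, classes)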