import torch
from torch import nn, Tensor
from typing import Optional

from transformers import DebertaV2PreTrainedModel, DebertaV2Model

from .configuration_deberta_multi import MultiHeadDebertaV2Config


class MultiHeadDebertaForSequenceClassificationModel(DebertaV2PreTrainedModel):
    """Multi-head sequence classifier built on a shared DeBERTa-v2 encoder.

    A single backbone feeds ``config.num_heads`` independent linear
    classification heads; every head produces 4 class logits from the
    first-token ([CLS]-position) representation.
    """

    config_class = MultiHeadDebertaV2Config

    def __init__(self, config):  # type: ignore
        super().__init__(config)
        self.deberta = DebertaV2Model(config)
        # One 4-way linear classifier per head, all reading the same
        # first-token representation of the encoder output.
        self.heads = nn.ModuleList(
            nn.Linear(config.hidden_size, 4) for _ in range(config.num_heads)
        )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.post_init()

    def forward(
        self,
        input_ids: Optional["Tensor"] = None,
        attention_mask: Optional["Tensor"] = None,
    ) -> "Tensor":
        """Run the encoder and all heads.

        Returns a tensor of logits stacked along dim 1 — presumably
        shaped (batch, num_heads, 4), one 4-logit row per head.
        """
        encoder_out = self.deberta(input_ids=input_ids, attention_mask=attention_mask)
        # First position of the last hidden state serves as the pooled
        # sequence representation.
        cls_repr = encoder_out[0][:, 0, :]
        # NOTE: dropout is applied inside the loop on purpose — in train
        # mode each head sees an independently sampled dropout mask,
        # matching the original behavior.
        per_head_logits = [head(self.dropout(cls_repr)) for head in self.heads]
        return torch.stack(per_head_logits, dim=1)