File size: 1,120 Bytes
67b4c35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
from torch import nn, Tensor
from typing import Optional
from transformers import DebertaV2PreTrainedModel, DebertaV2Model
from .configuration_deberta_multi import MultiHeadDebertaV2Config

class MultiHeadDebertaForSequenceClassificationModel(DebertaV2PreTrainedModel):
    """DeBERTa-v2 encoder topped with several independent linear classifiers.

    The encoder's first output is indexed at position 0 along dimension 1
    (presumably the [CLS] token - TODO confirm against the tokenizer setup).
    That vector is passed, after dropout, through each of ``config.num_heads``
    linear heads. Each head produces 4 logits. The per-head results are
    stacked along a new dimension 1.
    """

    config_class = MultiHeadDebertaV2Config

    def __init__(self, config):  # type: ignore
        super().__init__(config)
        # Registration order (deberta -> heads -> dropout) is kept stable so the
        # state_dict layout matches previously saved checkpoints.
        self.deberta = DebertaV2Model(config)
        classifier_heads = nn.ModuleList()
        for _ in range(config.num_heads):
            classifier_heads.append(nn.Linear(config.hidden_size, 4))
        self.heads = classifier_heads
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Hugging Face hook: runs weight initialisation for all submodules.
        self.post_init()

    def forward(
        self,
        input_ids: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Run the encoder and every classification head.

        Args:
            input_ids: Token ids forwarded unchanged to the DeBERTa-v2 encoder.
            attention_mask: Mask forwarded unchanged to the encoder.

        Returns:
            Logits of shape ``(batch, num_heads, 4)``. The ``[:, 0, :]`` slice
            below implies a rank-3 encoder output, so each head yields
            ``(batch, 4)`` before stacking.
        """
        encoder_outputs = self.deberta(
            input_ids=input_ids, attention_mask=attention_mask
        )
        pooled = encoder_outputs[0][:, 0, :]
        head_logits = []
        for classifier in self.heads:
            # Dropout is re-applied inside the loop, so in training mode each
            # head sees its own independently sampled mask.
            head_logits.append(classifier(self.dropout(pooled)))
        return torch.stack(head_logits, dim=1)