File size: 799 Bytes
d9dca49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from torch import nn
from transformers import DebertaV2ForSequenceClassification, AutoModel


class DebertaReplacedTokenizer(DebertaV2ForSequenceClassification):
    """DeBERTa-v2 sequence classifier whose word embeddings come from another model.

    The DeBERTa word-embedding table is replaced by
    ``nn.Sequential(<input embeddings of tok_model_name>, nn.Linear(...))``.
    Input ids are therefore looked up in the other model's embedding table.
    They are then linearly projected to the width DeBERTa's embedding layer
    expects.

    Keyword Args:
        tok_model_name: (required) name or path passed to
            ``AutoModel.from_pretrained``; only its input embeddings are kept.
        num_labels: (optional) if given, overwrites ``config.num_labels``
            before the parent constructor runs.

    Any remaining keyword arguments are forwarded to the parent constructor.

    Raises:
        KeyError: if ``tok_model_name`` is not supplied.

    NOTE(review): ``tok_model_name`` is not stored in ``config``, so it must be
    passed again when reloading a saved checkpoint of this class.
    """

    def __init__(self, config, **kwargs):
        tok_model_name = kwargs.pop('tok_model_name')
        if 'num_labels' in kwargs:
            config.num_labels = kwargs.pop('num_labels')
        super().__init__(config, **kwargs)

        # The full model is loaded, but only its input-embedding module is kept.
        tok_model = AutoModel.from_pretrained(tok_model_name)
        tok_embeddings = tok_model.get_input_embeddings()

        # Size the projection input from the embedding module itself rather than
        # guessing from the model name. A substring test for 'opt' misfires on
        # any name/path containing "opt". It also ignores models whose embedding
        # width differs from hidden_size (e.g. ALBERT's embedding_size).
        in_dim = getattr(tok_embeddings, 'embedding_dim', None)
        if in_dim is None:
            in_dim = getattr(tok_model.config, 'word_embed_proj_dim',
                             tok_model.config.hidden_size)

        # Project to the width of the DeBERTa word-embedding table being
        # replaced. It equals hidden_size unless the config sets a separate
        # embedding_size.
        out_dim = self.get_input_embeddings().embedding_dim

        new_emb = nn.Sequential(
            tok_embeddings,
            nn.Linear(in_dim, out_dim),
        )
        self.set_input_embeddings(new_emb)