Spaces:
Sleeping
Sleeping
from torch import nn | |
from transformers import DebertaV2ForSequenceClassification, AutoModel | |
class DebertaReplacedTokenizer(DebertaV2ForSequenceClassification): | |
def __init__(self, config, **kwargs): | |
tok_model_name = kwargs.pop('tok_model_name') | |
if 'num_labels' in kwargs: | |
config.num_labels = kwargs.pop('num_labels') | |
super().__init__(config, **kwargs) | |
tok_model = AutoModel.from_pretrained(tok_model_name) | |
new_emb = nn.Sequential( | |
tok_model.get_input_embeddings(), | |
nn.Linear(tok_model.config.hidden_size\ | |
if 'opt' not in tok_model_name else tok_model.config.word_embed_proj_dim, | |
self.config.hidden_size) | |
) | |
self.set_input_embeddings(new_emb) | |