LingConv / ling_disc.py
mohdelgaar's picture
add missing file
d9dca49
raw
history blame
799 Bytes
from torch import nn
from transformers import DebertaV2ForSequenceClassification, AutoModel
class DebertaReplacedTokenizer(DebertaV2ForSequenceClassification):
def __init__(self, config, **kwargs):
tok_model_name = kwargs.pop('tok_model_name')
if 'num_labels' in kwargs:
config.num_labels = kwargs.pop('num_labels')
super().__init__(config, **kwargs)
tok_model = AutoModel.from_pretrained(tok_model_name)
new_emb = nn.Sequential(
tok_model.get_input_embeddings(),
nn.Linear(tok_model.config.hidden_size\
if 'opt' not in tok_model_name else tok_model.config.word_embed_proj_dim,
self.config.hidden_size)
)
self.set_input_embeddings(new_emb)