mohdelgaar committed on
Commit
d9dca49
1 Parent(s): 9f22f23

add missing file

Browse files
Files changed (1) hide show
  1. ling_disc.py +19 -0
ling_disc.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ from transformers import DebertaV2ForSequenceClassification, AutoModel
3
+
4
+
5
class DebertaReplacedTokenizer(DebertaV2ForSequenceClassification):
    """DeBERTa-v2 sequence classifier that borrows another model's input embeddings.

    After the regular ``DebertaV2ForSequenceClassification`` construction, the
    input embedding layer is replaced by a two-stage module: the input
    embeddings of the model named by ``tok_model_name``, followed by a linear
    projection from that embedding width to this model's
    ``config.hidden_size``.
    """

    def __init__(self, config, **kwargs):
        """Build the classifier and swap in the foreign embedding stack.

        Args:
            config: Model configuration forwarded to the parent constructor.
                Its ``num_labels`` is overwritten when ``num_labels`` is
                passed as a keyword argument.
            **kwargs: Must contain ``tok_model_name`` (an identifier accepted
                by ``AutoModel.from_pretrained``); may contain ``num_labels``.
                Both are consumed here; everything else is forwarded to the
                parent constructor.

        Raises:
            KeyError: If ``tok_model_name`` is not supplied.
        """
        tok_model_name = kwargs.pop('tok_model_name')
        if 'num_labels' in kwargs:
            config.num_labels = kwargs.pop('num_labels')
        super().__init__(config, **kwargs)

        tok_model = AutoModel.from_pretrained(tok_model_name)
        tok_emb = tok_model.get_input_embeddings()

        # Ask the embedding module for its output width instead of guessing
        # from the checkpoint name: a name containing "opt" does not prove the
        # config has ``word_embed_proj_dim``, and an OPT checkpoint loaded from
        # a local path may not contain "opt" at all.
        emb_dim = getattr(tok_emb, 'embedding_dim', None)
        if emb_dim is None:
            # Fallback for embedding modules that do not expose
            # ``embedding_dim``: keep the original name-based heuristic.
            if 'opt' in tok_model_name:
                emb_dim = tok_model.config.word_embed_proj_dim
            else:
                emb_dim = tok_model.config.hidden_size

        new_emb = nn.Sequential(
            tok_emb,
            # Project the foreign embedding width onto this model's hidden size.
            nn.Linear(emb_dim, self.config.hidden_size),
        )
        self.set_input_embeddings(new_emb)