Markus28 committed
Commit
8e3d0b8
Parent: 139b4a5

feat: removed task type embeddings

Files changed (1)
  1. modeling_bert.py +1 -12
modeling_bert.py CHANGED
@@ -152,7 +152,7 @@ def _init_weights(module, initializer_range=0.02):
         nn.init.normal_(module.weight, std=initializer_range)
         if module.bias is not None:
             nn.init.zeros_(module.bias)
-    elif isinstance(module, nn.Embedding) and not getattr(module, "skip_init", False):
+    elif isinstance(module, nn.Embedding):
         nn.init.normal_(module.weight, std=initializer_range)
         if module.padding_idx is not None:
             nn.init.zeros_(module.weight[module.padding_idx])
@@ -351,7 +351,6 @@ class BertModel(BertPreTrainedModel):
         self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.encoder = BertEncoder(config)
         self.pooler = BertPooler(config) if add_pooling_layer else None
-        self.task_type_embeddings = nn.Embedding(config.num_tasks, config.hidden_size)

         self.emb_pooler = config.emb_pooler
         self._name_or_path = config._name_or_path
@@ -362,13 +361,6 @@ class BertModel(BertPreTrainedModel):
         else:
             self.tokenizer = None

-        # We now initialize the task embeddings to 0; We do not use task types during
-        # pretraining. When we start using task types during embedding training,
-        # we want the model to behave exactly as in pretraining (i.e. task types
-        # have no effect).
-        nn.init.zeros_(self.task_type_embeddings.weight)
-        self.task_type_embeddings.skip_init = True
-        # The following code should skip the embeddings layer
         self.apply(partial(_init_weights, initializer_range=config.initializer_range))

     def forward(
@@ -376,7 +368,6 @@
         input_ids,
         position_ids=None,
         token_type_ids=None,
-        task_type_ids=None,
         attention_mask=None,
         masked_tokens_mask=None,
         return_dict=True,
@@ -389,8 +380,6 @@
         hidden_states = self.embeddings(
             input_ids, position_ids=position_ids, token_type_ids=token_type_ids
         )
-        if task_type_ids is not None:
-            hidden_states = hidden_states + self.task_type_embeddings(task_type_ids)

         # TD [2022-12:18]: Don't need to force residual in fp32
         # BERT puts embedding LayerNorm before embedding dropout.
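The rationale in the deleted comment is that the task-type table was zero-initialized, so it never changed the hidden states during pretraining. Below is a minimal, self-contained sketch (illustrative names and sizes only, not the repository's classes or config) showing why adding a zero-initialized task-type embedding is a no-op, which is what makes dropping it safe for existing checkpoints.

import torch
import torch.nn as nn

# Illustrative sizes, not taken from the repository's config.
hidden_size, vocab_size, num_tasks = 16, 100, 4

word_embeddings = nn.Embedding(vocab_size, hidden_size)
task_type_embeddings = nn.Embedding(num_tasks, hidden_size)
nn.init.zeros_(task_type_embeddings.weight)  # mirrors the deleted zero-init

input_ids = torch.randint(0, vocab_size, (2, 8))
task_type_ids = torch.zeros_like(input_ids)  # any valid ids give the same result here

# Old path: token embeddings plus the (all-zero) task-type lookup.
with_task = word_embeddings(input_ids) + task_type_embeddings(task_type_ids)
# New path after this commit: token embeddings only.
without_task = word_embeddings(input_ids)

print(torch.equal(with_task, without_task))  # True: the extra term contributed nothing

One practical consequence of the signature change: callers that still pass task_type_ids to BertModel.forward will now fail with an unexpected-keyword-argument TypeError, so such call sites should simply drop the argument.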