oweller2 committed
Commit 4b203f9
1 Parent(s): e0229bb
Files changed (2)
  1. config.json +2 -2
  2. modeling_flexbert.py +13 -3
config.json CHANGED
@@ -2,12 +2,12 @@
   "allow_embedding_resizing": true,
   "architectures": [
     "FlexBertModel",
-    "FlexBertForCasualLM"
+    "FlexBertForCausalLM"
   ],
   "auto_map": {
     "AutoConfig": "orionweller/test-flex-gpt--configuration_bert.FlexBertConfig",
     "AutoModel": "orionweller/test-flex-gpt--modeling_flexbert.FlexBertModel",
-    "AutoModelForCausalLM": "orionweller/test-flex-gpt--modeling_flexbert.FlexBertForCasualLM"
+    "AutoModelForCausalLM": "orionweller/test-flex-gpt--modeling_flexbert.FlexBertForCausalLM"
   },
   "attention_layer": "rope",
   "attention_probs_dropout_prob": 0.0,
modeling_flexbert.py CHANGED
@@ -1534,14 +1534,23 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
         self._init_weights(reset_params=False)
 
     def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
+        # Handle the XOR condition
         assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
-        if module:
-            self._init_module_weights(module)
+
+        if module is not None:
+            # Add basic initialization for common module types
+            if isinstance(module, (nn.Linear, nn.Embedding)):
+                module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+                if isinstance(module, nn.Linear) and module.bias is not None:
+                    module.bias.data.zero_()
+            elif isinstance(module, nn.LayerNorm):
+                module.bias.data.zero_()
+                module.weight.data.fill_(1.0)
         else:
             assert isinstance(reset_params, bool)
             self.bert._init_weights(reset_params=reset_params)
             self.lm_head._init_weights(reset_params=reset_params)
-
+
         if not self.config.tie_word_embeddings:
             init_weights(self.config, self.decoder, self.config.hidden_size, type_of_module=ModuleType.final_out)
 
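This hunk keeps the original XOR contract (exactly one of module or reset_params must be supplied) but swaps the delegated self._init_module_weights(module) call for inline per-module initialization in the standard BERT style: normal-init for nn.Linear and nn.Embedding weights, zeroed biases, ones for nn.LayerNorm scales. A sketch of the two call paths, assuming model is an already-constructed FlexBertForCausalLM:

# Path 1: per-module init. nn.Module.apply passes each submodule in turn,
# which satisfies the XOR assert (module given, reset_params left as None).
model.apply(model._init_weights)

# Path 2: whole-model init, delegating to the submodules' own initializers.
model._init_weights(reset_params=True)

# Supplying both or neither argument violates the contract:
# model._init_weights()                  -> AssertionError
# model._init_weights(model.bert, True)  -> AssertionError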
 
 
@@ -1742,6 +1751,7 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
         params += _count_parameters(self.lm_head, trainable)
         return params
 
+FlexBertForCausalLM.register_for_auto_class("AutoModelForCausalLM")
 
 def init_model_from_pretrained(
     pretrained_model: FlexBertModel,
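register_for_auto_class is the standard transformers hook for custom-code models: calling it at module scope marks the class so that save_pretrained copies modeling_flexbert.py alongside the weights and writes the matching "auto_map" entry into config.json, which is exactly what the config change above relies on. A sketch of the effect; "./flexbert-ckpt" is a placeholder path and config is assumed to be a FlexBertConfig:

# Sketch of what the module-level registration buys.
import json

model = FlexBertForCausalLM(config)
model.save_pretrained("./flexbert-ckpt")  # also copies modeling_flexbert.py

with open("./flexbert-ckpt/config.json") as f:
    auto_map = json.load(f)["auto_map"]
print(auto_map["AutoModelForCausalLM"])
# expected: "modeling_flexbert.FlexBertForCausalLM"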