Crystalcareai commited on
Commit
2449553
1 Parent(s): 88ca699

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +13 -0
modeling_quiet.py CHANGED
@@ -1283,6 +1283,11 @@ class QuietForCausalLM(QuietPreTrainedModel):
1283
  self.talk_head = nn.ModuleList([nn.Sequential(
1284
  nn.Linear(talk_input_dim, talk_output_dim, bias=False)
1285
  )])
 
 
 
 
 
1286
 
1287
  # Initialize weights and apply final processing
1288
  self.post_init()
@@ -1304,6 +1309,14 @@ class QuietForCausalLM(QuietPreTrainedModel):
1304
 
1305
  def get_decoder(self):
1306
  return self.model
 
 
 
 
 
 
 
 
1307
 
1308
  @torch.no_grad()
1309
  def infer(
 
1283
  self.talk_head = nn.ModuleList([nn.Sequential(
1284
  nn.Linear(talk_input_dim, talk_output_dim, bias=False)
1285
  )])
1286
+
1287
+ self.apply(self._init_weights)
1288
+
1289
+ # Add dropout regularization
1290
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1291
 
1292
  # Initialize weights and apply final processing
1293
  self.post_init()
 
1309
 
1310
  def get_decoder(self):
1311
  return self.model
1312
+
1313
+ def _init_weights(self, module):
1314
+ if isinstance(module, nn.Linear):
1315
+ nn.init.xavier_uniform_(module.weight)
1316
+ if module.bias is not None:
1317
+ nn.init.constant_(module.bias, 0)
1318
+ elif isinstance(module, nn.Embedding):
1319
+ nn.init.xavier_uniform_(module.weight)
1320
 
1321
  @torch.no_grad()
1322
  def infer(