Crystalcareai
commited on
Commit
•
2449553
1
Parent(s):
88ca699
Update modeling_quiet.py
Browse files- modeling_quiet.py +13 -0
modeling_quiet.py
CHANGED
@@ -1283,6 +1283,11 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
1283 |
self.talk_head = nn.ModuleList([nn.Sequential(
|
1284 |
nn.Linear(talk_input_dim, talk_output_dim, bias=False)
|
1285 |
)])
|
|
|
|
|
|
|
|
|
|
|
1286 |
|
1287 |
# Initialize weights and apply final processing
|
1288 |
self.post_init()
|
@@ -1304,6 +1309,14 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
1304 |
|
1305 |
def get_decoder(self):
|
1306 |
return self.model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1307 |
|
1308 |
@torch.no_grad()
|
1309 |
def infer(
|
|
|
1283 |
self.talk_head = nn.ModuleList([nn.Sequential(
|
1284 |
nn.Linear(talk_input_dim, talk_output_dim, bias=False)
|
1285 |
)])
|
1286 |
+
|
1287 |
+
self.apply(self._init_weights)
|
1288 |
+
|
1289 |
+
# Add dropout regularization
|
1290 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
1291 |
|
1292 |
# Initialize weights and apply final processing
|
1293 |
self.post_init()
|
|
|
1309 |
|
1310 |
def get_decoder(self):
|
1311 |
return self.model
|
1312 |
+
|
1313 |
+
def _init_weights(self, module):
|
1314 |
+
if isinstance(module, nn.Linear):
|
1315 |
+
nn.init.xavier_uniform_(module.weight)
|
1316 |
+
if module.bias is not None:
|
1317 |
+
nn.init.constant_(module.bias, 0)
|
1318 |
+
elif isinstance(module, nn.Embedding):
|
1319 |
+
nn.init.xavier_uniform_(module.weight)
|
1320 |
|
1321 |
@torch.no_grad()
|
1322 |
def infer(
|