Crystalcareai committed
Commit • f2459a7
1 Parent(s): 6f470a7
Update modeling_quiet.py
modeling_quiet.py  CHANGED  (+5 -3)
@@ -910,7 +910,7 @@ class QuietModel(QuietPreTrainedModel):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
-
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
         self.layers = nn.ModuleList(
             [QuietDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
@@ -1102,6 +1102,7 @@ class QuietModel(QuietPreTrainedModel):
             past_key_values=next_cache,
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
+            logits=self.lm_head(hidden_states),
         )
 
 
@@ -1215,7 +1216,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
         )
 
         hidden_states = outputs.last_hidden_state
-        base_logits =
+        base_logits = outputs.logits  # Use the logits from the model output
 
         thought_ids, thought_embeddings = self.model._generate_thoughts(hidden_states, max_length=self.thought_length)
         thought_hidden_states = self.model(inputs_embeds=thought_embeddings).last_hidden_state
@@ -1224,7 +1225,6 @@ class QuietForCausalLM(QuietPreTrainedModel):
         mixing_input = torch.cat([hidden_states, thought_hidden_states], dim=-1)
         mixing_weights = self.mixing_head(mixing_input).squeeze(-1)  # (batch_size, seq_length)
         mixed_logits = base_logits * (1 - mixing_weights.unsqueeze(-1)) + thought_logits * mixing_weights.unsqueeze(-1)
-
         loss = None
         if labels is not None:
             # Shift so that tokens < n predict n
@@ -1240,6 +1240,8 @@ class QuietForCausalLM(QuietPreTrainedModel):
             rewards = torch.clamp(rewards, min=0)
             policy_loss = self.calculate_policy_loss(thought_ids, rewards)
             loss = loss + policy_loss
+        else:
+            loss = None
 
         if not return_dict:
             output = (mixed_logits,) + outputs[1:]
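For reference, a minimal standalone sketch of the data flow this commit wires up, using hypothetical tensor sizes and a stand-in mixing_head (the real shapes, the mixing_head definition, and the thought-logit computation live elsewhere in modeling_quiet.py): hidden states are projected through the new lm_head to produce base logits, which are then blended with thought logits using the per-position mixing weights, as in the mixed_logits line above.

import torch
import torch.nn as nn

# Hypothetical sizes for illustration only; the real values come from the model config.
batch_size, seq_len, hidden_size, vocab_size = 2, 8, 64, 100

# Stand-ins for the tensors the diff operates on.
hidden_states = torch.randn(batch_size, seq_len, hidden_size)          # outputs.last_hidden_state
thought_hidden_states = torch.randn(batch_size, seq_len, hidden_size)  # from the thought pass

# The projection this commit adds to QuietModel.__init__; outputs.logits = lm_head(hidden_states).
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
base_logits = lm_head(hidden_states)
thought_logits = lm_head(thought_hidden_states)  # assumption: thought logits use the same head

# Stand-in mixing head (assumed sigmoid gate); the repository defines its own mixing_head.
mixing_head = nn.Sequential(nn.Linear(2 * hidden_size, 1), nn.Sigmoid())
mixing_input = torch.cat([hidden_states, thought_hidden_states], dim=-1)
mixing_weights = mixing_head(mixing_input).squeeze(-1)                 # (batch_size, seq_len)

# Mixing step from the diff: a per-position weight blends the two logit streams.
w = mixing_weights.unsqueeze(-1)                                       # broadcast over the vocab dim
mixed_logits = base_logits * (1 - w) + thought_logits * w              # (batch_size, seq_len, vocab_size)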