Crystalcareai committed
Commit f2459a7
1 Parent(s): 6f470a7

Update modeling_quiet.py

Files changed (1):
  1. modeling_quiet.py  +5 -3
modeling_quiet.py CHANGED
@@ -910,7 +910,7 @@ class QuietModel(QuietPreTrainedModel):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
-
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
         self.layers = nn.ModuleList(
             [QuietDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
@@ -1102,6 +1102,7 @@ class QuietModel(QuietPreTrainedModel):
             past_key_values=next_cache,
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
+            logits=self.lm_head(hidden_states),
         )


@@ -1215,7 +1216,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
         )

         hidden_states = outputs.last_hidden_state
-        base_logits = self.lm_head(hidden_states)
+        base_logits = outputs.logits  # Use the logits from the model output

         thought_ids, thought_embeddings = self.model._generate_thoughts(hidden_states, max_length=self.thought_length)
         thought_hidden_states = self.model(inputs_embeds=thought_embeddings).last_hidden_state
@@ -1224,7 +1225,6 @@ class QuietForCausalLM(QuietPreTrainedModel):
         mixing_input = torch.cat([hidden_states, thought_hidden_states], dim=-1)
         mixing_weights = self.mixing_head(mixing_input).squeeze(-1)  # (batch_size, seq_length)
         mixed_logits = base_logits * (1 - mixing_weights.unsqueeze(-1)) + thought_logits * mixing_weights.unsqueeze(-1)
-
         loss = None
         if labels is not None:
             # Shift so that tokens < n predict n
@@ -1240,6 +1240,8 @@
             rewards = torch.clamp(rewards, min=0)
             policy_loss = self.calculate_policy_loss(thought_ids, rewards)
             loss = loss + policy_loss
+        else:
+            loss = None

         if not return_dict:
             output = (mixed_logits,) + outputs[1:]
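
For context: this commit moves logit computation into QuietModel (the new lm_head and the logits field on the model output) so that QuietForCausalLM can take base_logits from outputs.logits instead of applying its own lm_head, and then blend them with thought-conditioned logits. The sketch below shows just that mixing step in isolation; the tensor shapes, the mixing_head definition, and the way thought_logits is produced are assumptions for illustration, not the repository's exact code.

# Minimal, self-contained sketch of the logit mixing this commit wires up.
# Assumptions: mixing_head is a Linear(2*hidden, 1) followed by a sigmoid, and
# thought_logits comes from applying the same lm_head to the thought pass.
import torch
import torch.nn as nn

batch_size, seq_len, hidden_size, vocab_size = 2, 5, 16, 32

lm_head = nn.Linear(hidden_size, vocab_size, bias=False)                   # added to QuietModel.__init__
mixing_head = nn.Sequential(nn.Linear(2 * hidden_size, 1), nn.Sigmoid())   # assumed definition

hidden_states = torch.randn(batch_size, seq_len, hidden_size)              # base forward pass
thought_hidden_states = torch.randn(batch_size, seq_len, hidden_size)      # thought forward pass

base_logits = lm_head(hidden_states)             # what outputs.logits now carries
thought_logits = lm_head(thought_hidden_states)  # assumed source of thought_logits

mixing_input = torch.cat([hidden_states, thought_hidden_states], dim=-1)
mixing_weights = mixing_head(mixing_input).squeeze(-1)   # (batch_size, seq_len), values in [0, 1]

w = mixing_weights.unsqueeze(-1)                 # broadcast over the vocab dimension
mixed_logits = base_logits * (1 - w) + thought_logits * w
print(mixed_logits.shape)                        # torch.Size([2, 5, 32])

Per position, mixed_logits is a convex combination of the base and thought-conditioned logits: a mixing weight near 0 falls back to the plain language-model prediction, while a weight near 1 relies on the generated thought.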