Crystalcareai committed
Commit f593b43
1 parent: 4e9663c

Update modeling_quiet.py

Files changed (1):
  1. modeling_quiet.py +40 -6
modeling_quiet.py CHANGED
@@ -1246,6 +1246,9 @@ class QuietForCausalLM(QuietPreTrainedModel):
 
         self.policy_loss_beta = 1e6
         self.embedding_scale = 1e2
+        self.temperature = nn.Parameter(torch.tensor(1.0))
+        self.max_temperature = config.max_temperature
+        self.complexity_factor = config.complexity_factor
         self.reinforce_temperature = 3
         self.base_loss_beta = 1
         self.thinking_usefulness_head = nn.Linear(self.model.config.hidden_size, 1)
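Note: the three attributes added above read config.max_temperature and config.complexity_factor, so the accompanying QuietConfig is expected to carry those fields. A minimal, hypothetical sketch of how the config could expose them (the field names come from this diff; the class name and default values are illustrative placeholders, not part of the commit):

from transformers import PretrainedConfig

class QuietConfigSketch(PretrainedConfig):
    # Hypothetical sketch: only the two fields referenced in the diff are shown.
    model_type = "quiet"

    def __init__(self, max_temperature=4.0, complexity_factor=0.5, **kwargs):
        super().__init__(**kwargs)
        self.max_temperature = max_temperature      # upper bound for the learned temperature (placeholder default)
        self.complexity_factor = complexity_factor  # blend weight between length and rare-token scores (placeholder default)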
@@ -1626,16 +1629,20 @@ class QuietForCausalLM(QuietPreTrainedModel):
         sample_probs_history = []
         action_loglikelihoods_list = []
 
+        complexity_scores = self.compute_complexity_scores(input_ids, attention_mask)
+        temperature = self.temperature * complexity_scores.unsqueeze(-1)
+
         if self.use_end_thought_token or self.use_start_thought_token:
             if not self.use_reparam_for_thought_embeddings:
-                start_embedding = self.start_embedding[0].unsqueeze(0) * self.embedding_scale
-                end_embedding = self.end_embedding[0].unsqueeze(0) * self.embedding_scale
+                start_embedding = self.start_embedding[0].unsqueeze(0) * self.embedding_scale * temperature
+                end_embedding = self.end_embedding[0].unsqueeze(0) * self.embedding_scale * temperature
             else:
-                start_embedding = self.start_embedding * self.embedding_scale
-                end_embedding = self.end_embedding * self.embedding_scale
+                start_embedding = self.start_embedding * self.embedding_scale * temperature
+                end_embedding = self.end_embedding * self.embedding_scale * temperature
             base_embeddings = self.model.embed_tokens.weight
             if self.train_only_thinking_embedding:
                 base_embeddings = base_embeddings.detach()
+
         # # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         fwd_iters = 1 if self.original_mode else self.n_ahead + self.n_ahead_talk - 1
         for ahead_idx in range(fwd_iters):
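Note: compute_complexity_scores (added at the bottom of this diff) returns one score per sequence, so the temperature built above has shape (batch, 1) and broadcasts over the hidden dimension when it scales the thought embeddings. A standalone shape check with toy sizes standing in for the real model:

import torch

batch_size, hidden_size = 2, 8
base_temperature = torch.tensor(1.0)                   # stands in for self.temperature
complexity_scores = torch.tensor([0.3, 0.9])           # one score per sequence, shape (batch,)
temperature = base_temperature * complexity_scores.unsqueeze(-1)  # shape (batch, 1)

start_embedding = torch.randn(1, hidden_size)          # stands in for self.start_embedding[0].unsqueeze(0)
scaled = start_embedding * 1e2 * temperature           # broadcasts to shape (batch, hidden_size)
print(temperature.shape, scaled.shape)                 # torch.Size([2, 1]) torch.Size([2, 8])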
@@ -1900,9 +1907,10 @@ class QuietForCausalLM(QuietPreTrainedModel):
                 contains_end = self.use_end_thought_token and (probabilities_2d[..., self.end_token_id].sum() > 0)
                 contains_thought = contains_start or contains_end
 
+
                 if not contains_thought:
                     with torch.set_grad_enabled(not self.train_only_thinking_embedding):
-                        inputs_embeds = probabilities_2d @ (self.model.embed_tokens.weight.to(probabilities.device).to(probabilities.dtype))
+                        inputs_embeds = probabilities_2d @ (self.model.embed_tokens.weight.to(probabilities.device).to(probabilities.dtype) * temperature)
                 else:
                     thought_id = self.start_token_id if contains_start else self.end_token_id
                     cur_thought_embedding = start_embedding if contains_start else end_embedding
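Note: for a scalar temperature, scaling the embedding table before the soft lookup above is equivalent to scaling the mixed embedding afterwards, since the matrix product is linear. A quick numerical check with toy tensors (toy sizes, not the real vocabulary; the scalar case only, whereas the diff builds a per-sequence temperature):

import torch

vocab_size, hidden_size, n_positions = 5, 4, 3
probabilities_2d = torch.softmax(torch.randn(n_positions, vocab_size), dim=-1)
embed_weight = torch.randn(vocab_size, hidden_size)
temperature = torch.tensor(1.7)

scale_before = probabilities_2d @ (embed_weight * temperature)
scale_after = (probabilities_2d @ embed_weight) * temperature
print(torch.allclose(scale_before, scale_after))  # True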
@@ -1915,7 +1923,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
                         sampled_end = inputs_embeds.clone().detach()
                     else:
                         inputs_embeds = cur_thought_embedding.unsqueeze(0).repeat(batch_size, seq_len, 1)
-                    inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)
+                    inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)
                     inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)
 
             # Predict the usefulness of thinking at each token position
@@ -2127,6 +2135,32 @@ class QuietForCausalLM(QuietPreTrainedModel):
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
+
+    def compute_complexity_scores(self, input_ids, attention_mask):
+        # Compute complexity scores based on input sequence characteristics
+        # Example: Normalize sequence lengths and consider the presence of rare tokens
+        seq_lengths = torch.sum(attention_mask, dim=-1)
+        max_length = torch.max(seq_lengths)
+        length_scores = seq_lengths / max_length
+
+        # Compute the proportion of rare tokens in each sequence
+        rare_token_ids = self.get_rare_token_ids()
+        rare_token_mask = torch.isin(input_ids, rare_token_ids)
+        rare_token_counts = torch.sum(rare_token_mask, dim=-1)
+        rare_token_scores = rare_token_counts / seq_lengths
+
+        # Combine length scores and rare token scores
+        complexity_scores = self.complexity_factor * length_scores + (1 - self.complexity_factor) * rare_token_scores
+        return complexity_scores
+
+    def get_rare_token_ids(self):
+        # Get the IDs of rare tokens based on a predefined frequency threshold
+        frequency_threshold = 1e-4
+        token_counts = torch.bincount(self.model.embed_tokens.weight.argmax(dim=-1))
+        total_tokens = torch.sum(token_counts)
+        rare_token_mask = token_counts / total_tokens < frequency_threshold
+        rare_token_ids = torch.nonzero(rare_token_mask).squeeze(-1)
+        return rare_token_ids
 
 
     def prepare_inputs_for_generation(
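Note: to see what the new compute_complexity_scores method produces, here is a standalone rerun of the same arithmetic on a toy batch. The rare-token IDs are supplied by hand instead of calling get_rare_token_ids, which depends on the model's embedding table, and the 0.5 blend factor is a placeholder for config.complexity_factor:

import torch

def compute_complexity_scores(input_ids, attention_mask, rare_token_ids, complexity_factor=0.5):
    seq_lengths = torch.sum(attention_mask, dim=-1)           # tokens per sequence
    length_scores = seq_lengths / torch.max(seq_lengths)      # normalize by the longest sequence in the batch
    rare_token_mask = torch.isin(input_ids, rare_token_ids)   # positions holding rare tokens
    rare_token_scores = torch.sum(rare_token_mask, dim=-1) / seq_lengths
    return complexity_factor * length_scores + (1 - complexity_factor) * rare_token_scores

input_ids = torch.tensor([[11, 12, 13, 99, 0, 0],
                          [11, 12, 13, 14, 15, 16]])
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1, 1]])
rare_token_ids = torch.tensor([99])

print(compute_complexity_scores(input_ids, attention_mask, rare_token_ids))
# tensor([0.4583, 0.5000]): the shorter sequence scores lower on length
# but higher on rare-token proportion.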
 