Crystalcareai commited on
Commit
69ca8a8
1 Parent(s): 7d04113

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +21 -21
modeling_quiet.py CHANGED
@@ -929,32 +929,32 @@ class QuietModel(QuietPreTrainedModel):
929
  self.embed_tokens = value
930
 
931
  def _generate_thoughts(self, hidden_states, max_length):
932
- thought_ids = []
933
- thought_embeddings = []
934
-
935
- for _ in range(self.config.max_thoughts):
936
- thought_id = torch.LongTensor([[self.config.start_token_id]]).to(hidden_states.device)
937
- thought_embedding = self.embed_tokens(thought_id)
938
 
939
- for _ in range(max_length):
940
- outputs = self.forward(
941
- inputs_embeds=thought_embedding,
942
- attention_mask=None,
943
- use_cache=True,
944
- )
945
- logits = outputs.logits[:, -1, :]
946
- next_token_id = torch.argmax(logits, dim=-1)
947
 
948
- if next_token_id == self.config.end_token_id:
949
- break
 
 
 
 
 
 
950
 
951
- thought_id = torch.cat([thought_id, next_token_id.unsqueeze(0)], dim=-1)
952
- thought_embedding = torch.cat([thought_embedding, self.embed_tokens(next_token_id.unsqueeze(0))], dim=1)
 
 
 
953
 
954
- thought_ids.append(thought_id.squeeze(0))
955
- thought_embeddings.append(thought_embedding.squeeze(0))
956
 
957
- return thought_ids, thought_embeddings
958
 
959
 
960
  @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
 
929
  self.embed_tokens = value
930
 
931
    def _generate_thoughts(self, hidden_states, max_length):
        """Greedily decode ``self.config.max_thoughts`` thought sequences.

        Each thought starts from ``config.start_token_id`` and is extended one
        token at a time by argmax over the model's next-token logits, stopping
        after ``max_length`` steps or when ``config.end_token_id`` is produced.

        NOTE(review): ``hidden_states`` is only used for its ``.device`` — the
        decoding does not condition on it, so with pure argmax every generated
        thought is identical. Confirm whether conditioning was intended.

        Args:
            hidden_states: tensor whose ``.device`` determines where the
                thought tokens are allocated (content otherwise unused here).
            max_length: maximum number of decoding steps per thought.

        Returns:
            Tuple ``(thought_ids, thought_embeddings)``: two parallel lists of
            length ``max_thoughts``; each entry is a 1-D LongTensor of token
            ids and the matching ``(seq_len, hidden)`` embedding tensor.
        """
        thought_ids = []
        thought_embeddings = []
        for _ in range(self.config.max_thoughts):
            # Seed each thought with the start token: shape (1, 1).
            thought_id = torch.LongTensor([[self.config.start_token_id]]).to(hidden_states.device)
            thought_embedding = self.embed_tokens(thought_id)
            for _ in range(max_length):
                # NOTE(review): use_cache=True is requested but past_key_values
                # are never passed back in, so every step re-encodes the whole
                # prefix from scratch — confirm whether the cache should be
                # threaded through for efficiency.
                outputs = self.forward(
                    inputs_embeds=thought_embedding,
                    attention_mask=None,
                    use_cache=True,
                )
                # NOTE(review): assumes self.forward returns an object exposing
                # .logits — verify this holds for the base QuietModel forward.
                logits = outputs.logits[:, -1, :]
                # Greedy (argmax) decoding of the next token.
                next_token_id = torch.argmax(logits, dim=-1)

                # Single-element tensor vs int comparison; valid for batch
                # size 1. The end token itself is NOT appended to the thought.
                if next_token_id == self.config.end_token_id:
                    break

                # Grow the running token ids (dim=-1) and embeddings (dim=1,
                # the sequence axis) with the newly decoded token.
                thought_id = torch.cat([thought_id, next_token_id.unsqueeze(0)], dim=-1)
                thought_embedding = torch.cat([thought_embedding, self.embed_tokens(next_token_id.unsqueeze(0))], dim=1)

            # Drop the batch dimension before collecting this thought.
            thought_ids.append(thought_id.squeeze(0))
            thought_embeddings.append(thought_embedding.squeeze(0))

        return thought_ids, thought_embeddings
958
 
959
 
960
  @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)