Crystalcareai committed
Commit 69ca8a8
Parent(s): 7d04113
Update modeling_quiet.py
modeling_quiet.py +21 -21
modeling_quiet.py CHANGED

@@ -929,32 +929,32 @@ class QuietModel(QuietPreTrainedModel):
         self.embed_tokens = value
 
     def _generate_thoughts(self, hidden_states, max_length):
-
-
-
-        for _ in range(self.config.max_thoughts):
-            thought_id = torch.LongTensor([[self.config.start_token_id]]).to(hidden_states.device)
-            thought_embedding = self.embed_tokens(thought_id)
-            for _ in range(
-
-
-                    attention_mask=None,
-                    use_cache=True,
-                )
-                logits = outputs.logits[:, -1, :]
-                next_token_id = torch.argmax(logits, dim=-1)
-
-
-
-
-
-
-
+        thought_ids = []
+        thought_embeddings = []
+
+        for _ in range(self.config.max_thoughts):
+            thought_id = torch.LongTensor([[self.config.start_token_id]]).to(hidden_states.device)
+            thought_embedding = self.embed_tokens(thought_id)
+
+            for _ in range(max_length):
+                outputs = self.forward(
+                    inputs_embeds=thought_embedding,
+                    attention_mask=None,
+                    use_cache=True,
+                )
+                logits = outputs.logits[:, -1, :]
+                next_token_id = torch.argmax(logits, dim=-1)
+
+                if next_token_id == self.config.end_token_id:
+                    break
+
+                thought_id = torch.cat([thought_id, next_token_id.unsqueeze(0)], dim=-1)
+                thought_embedding = torch.cat([thought_embedding, self.embed_tokens(next_token_id.unsqueeze(0))], dim=1)
+
+            thought_ids.append(thought_id.squeeze(0))
+            thought_embeddings.append(thought_embedding.squeeze(0))
+
+        return thought_ids, thought_embeddings
 
 
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
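For context, the rewritten _generate_thoughts greedily decodes up to config.max_thoughts thought sequences: each one starts from config.start_token_id, feeds the running embeddings back through forward, stops at config.end_token_id or after max_length steps, and returns the per-thought token ids and embeddings. Below is a minimal sketch of exercising the helper; the checkpoint id, the trust_remote_code loading path, and the direct call into the private method are assumptions for illustration, not part of this commit.

```python
# Hedged sketch: the checkpoint id is hypothetical, and it is assumed the remote
# code exposes QuietModel as `model.model` with start_token_id, end_token_id,
# and max_thoughts on its config -- none of that is defined by this commit.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Crystalcareai/quiet-star"  # hypothetical checkpoint id
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

prompt = tokenizer("2 + 2 * 3 =", return_tensors="pt")
with torch.no_grad():
    # Embed the prompt; in the updated method, hidden_states is only consulted
    # for its device when seeding each thought with start_token_id.
    hidden_states = model.model.embed_tokens(prompt["input_ids"])
    thought_ids, thought_embeddings = model.model._generate_thoughts(
        hidden_states, max_length=16
    )

# Each entry is one greedily decoded thought, ended at end_token_id or max_length.
for ids in thought_ids:
    print(tokenizer.decode(ids))
```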