alfiannajih
commited on
Commit
•
687afc0
1
Parent(s):
a31e205
Update g_retriever.py
Browse files- g_retriever.py +3 -1
g_retriever.py
CHANGED
@@ -80,10 +80,12 @@ class GRetrieverModel(LlamaForCausalLM):
|
|
80 |
|
81 |
# embed inputs
|
82 |
inputs[non_tokenized_shape] = self.config.pad_token_id
|
83 |
-
|
84 |
non_tokenized_embeds = torch.cat([bos_embeds.repeat(len(inputs), 1, 1), graph_embeds.unsqueeze(1)], dim=1)
|
85 |
|
86 |
# replace reserved ids with bos+graph
|
|
|
|
|
87 |
inputs_embeds[non_tokenized_shape] = non_tokenized_embeds.view(len(non_tokenized_ids), -1)
|
88 |
|
89 |
else:
|
|
|
80 |
|
81 |
# embed inputs
|
82 |
inputs[non_tokenized_shape] = self.config.pad_token_id
|
83 |
+
temp_inputs_embeds = self.get_input_embeddings()(inputs)
|
84 |
non_tokenized_embeds = torch.cat([bos_embeds.repeat(len(inputs), 1, 1), graph_embeds.unsqueeze(1)], dim=1)
|
85 |
|
86 |
# replace reserved ids with bos+graph
|
87 |
+
temp_inputs_embeds[non_tokenized_shape] = non_tokenized_embeds.view(len(non_tokenized_ids), -1)
|
88 |
+
inputs_embeds = temp_inputs_embeds.clone()
|
89 |
inputs_embeds[non_tokenized_shape] = non_tokenized_embeds.view(len(non_tokenized_ids), -1)
|
90 |
|
91 |
else:
|