alfiannajih
commited on
Commit
•
ca0a58b
1
Parent(s):
e7a28b6
Update g_retriever.py
Browse files- g_retriever.py +5 -1
g_retriever.py
CHANGED
@@ -65,7 +65,11 @@ class GRetrieverModel(LlamaForCausalLM):
|
|
65 |
|
66 |
if inputs.shape != torch.Size([1, 1]):
|
67 |
# embed bos prompt
|
68 |
-
bos_embeds = self.get_input_embeddings()(torch.tensor(
|
|
|
|
|
|
|
|
|
69 |
|
70 |
# encode graph
|
71 |
graph_embeds = self.encode_graphs(graph)
|
|
|
65 |
|
66 |
if inputs.shape != torch.Size([1, 1]):
|
67 |
# embed bos prompt
|
68 |
+
bos_embeds = self.get_input_embeddings()(torch.tensor(
|
69 |
+
self.config.bos_id,
|
70 |
+
dtype=self.model.dtype,
|
71 |
+
device=self.model.device
|
72 |
+
))
|
73 |
|
74 |
# encode graph
|
75 |
graph_embeds = self.encode_graphs(graph)
|