alfiannajih commited on
Commit
ca0a58b
1 Parent(s): e7a28b6

Update g_retriever.py

Browse files
Files changed (1) hide show
  1. g_retriever.py +5 -1
g_retriever.py CHANGED
@@ -65,7 +65,11 @@ class GRetrieverModel(LlamaForCausalLM):
65
 
66
  if inputs.shape != torch.Size([1, 1]):
67
  # embed bos prompt
68
- bos_embeds = self.get_input_embeddings()(torch.tensor(self.config.bos_id.to(self.model.device)))
 
 
 
 
69
 
70
  # encode graph
71
  graph_embeds = self.encode_graphs(graph)
 
65
 
66
  if inputs.shape != torch.Size([1, 1]):
67
  # embed bos prompt
68
+ bos_embeds = self.get_input_embeddings()(torch.tensor(
69
+ self.config.bos_id,
70
+ dtype=self.model.dtype,
71
+ device=self.model.device
72
+ ))
73
 
74
  # encode graph
75
  graph_embeds = self.encode_graphs(graph)