alfiannajih commited on
Commit
1ddc9ba
1 Parent(s): c37f43f

Update g_retriever.py

Browse files
Files changed (1) hide show
  1. g_retriever.py +6 -2
g_retriever.py CHANGED
@@ -33,10 +33,14 @@ class GRetrieverModel(LlamaForCausalLM):
33
  ).to(self.model.dtype).to(self.model.device)
34
 
35
  def encode_graphs(self, graph):
36
- n_embeds, _ = self.graph_encoder(graph.x.to(self.model.dtype), graph.edge_index.long(), graph.edge_attr.to(self.model.dtype))
 
 
 
 
37
 
38
  # mean pooling
39
- g_embeds = global_mean_pool(n_embeds, graph.batch)
40
 
41
  return g_embeds
42
 
 
33
  ).to(self.model.dtype).to(self.model.device)
34
 
35
  def encode_graphs(self, graph):
36
+ n_embeds, _ = self.graph_encoder(
37
+ graph.x.to(self.model.dtype),
38
+ graph.edge_index.long(),
39
+ graph.edge_attr.to(self.model.dtype)
40
+ )
41
 
42
  # mean pooling
43
+ g_embeds = global_mean_pool(n_embeds, graph.batch.to(n_embeds.device))
44
 
45
  return g_embeds
46