alfiannajih commited on
Commit
7793e2b
1 Parent(s): 06fa6b1

Update g_retriever.py

Browse files
Files changed (1) hide show
  1. g_retriever.py +2 -2
g_retriever.py CHANGED
@@ -42,7 +42,7 @@ class GRetrieverModel(LlamaForCausalLM):
42
  # mean pooling
43
  g_embeds = global_mean_pool(n_embeds, graph.batch.to(n_embeds.device))
44
 
45
- return g_embeds.to(self.model.device)
46
 
47
  @wraps(LlamaForCausalLM.forward)
48
  def forward(
@@ -76,7 +76,7 @@ class GRetrieverModel(LlamaForCausalLM):
76
 
77
  # encode graph
78
  graph_embeds = self.encode_graphs(graph)
79
- graph_embeds = self.projector(graph_embeds)
80
 
81
  # prepare for reserved ids (bos+graph)
82
  non_tokenized_ids = (inputs == -1).nonzero()
 
42
  # mean pooling
43
  g_embeds = global_mean_pool(n_embeds, graph.batch.to(n_embeds.device))
44
 
45
+ return g_embeds
46
 
47
  @wraps(LlamaForCausalLM.forward)
48
  def forward(
 
76
 
77
  # encode graph
78
  graph_embeds = self.encode_graphs(graph)
79
+ graph_embeds = self.projector(graph_embeds).to(self.model.device)
80
 
81
  # prepare for reserved ids (bos+graph)
82
  non_tokenized_ids = (inputs == -1).nonzero()