alfiannajih commited on
Commit
687afc0
1 Parent(s): a31e205

Update g_retriever.py

Browse files
Files changed (1) hide show
  1. g_retriever.py +3 -1
g_retriever.py CHANGED
@@ -80,10 +80,12 @@ class GRetrieverModel(LlamaForCausalLM):
80
 
81
  # embed inputs
82
  inputs[non_tokenized_shape] = self.config.pad_token_id
83
- inputs_embeds = self.get_input_embeddings()(inputs)
84
  non_tokenized_embeds = torch.cat([bos_embeds.repeat(len(inputs), 1, 1), graph_embeds.unsqueeze(1)], dim=1)
85
 
86
  # replace reserved ids with bos+graph
 
 
87
  inputs_embeds[non_tokenized_shape] = non_tokenized_embeds.view(len(non_tokenized_ids), -1)
88
 
89
  else:
 
80
 
81
  # embed inputs
82
  inputs[non_tokenized_shape] = self.config.pad_token_id
83
+ temp_inputs_embeds = self.get_input_embeddings()(inputs)
84
  non_tokenized_embeds = torch.cat([bos_embeds.repeat(len(inputs), 1, 1), graph_embeds.unsqueeze(1)], dim=1)
85
 
86
  # replace reserved ids with bos+graph
87
+ temp_inputs_embeds[non_tokenized_shape] = non_tokenized_embeds.view(len(non_tokenized_ids), -1)
88
+ inputs_embeds = temp_inputs_embeds.clone()
89
  inputs_embeds[non_tokenized_shape] = non_tokenized_embeds.view(len(non_tokenized_ids), -1)
90
 
91
  else: