alfiannajih commited on
Commit
a31e205
1 Parent(s): 0308adf

Update g_retriever.py

Browse files
Files changed (1) hide show
  1. g_retriever.py +1 -1
g_retriever.py CHANGED
@@ -63,7 +63,7 @@ class GRetrieverModel(LlamaForCausalLM):
63
  )
64
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
65
 
66
- if inputs.shape != torch.Size([1, 1]):
67
  # embed bos prompt
68
  bos_embeds = self.get_input_embeddings()(torch.tensor(
69
  self.config.bos_id,
 
63
  )
64
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
65
 
66
+ if (inputs==-1).any():
67
  # embed bos prompt
68
  bos_embeds = self.get_input_embeddings()(torch.tensor(
69
  self.config.bos_id,