pharaouk commited on
Commit
b141b1e
1 Parent(s): 2061e79

Update modeling_llama.py

Browse files
Files changed (1) hide show
  1. modeling_llama.py +6 -1
modeling_llama.py CHANGED
@@ -1117,7 +1117,12 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
1117
  def detect_shutdown_token(self, input_ids):
1118
  if torch.any(input_ids == self.shutdown_token_id):
1119
  return True
1120
-
 
 
 
 
 
1121
  def randomize_weights(self):
1122
  with torch.no_grad():
1123
  for param in self.parameters():
 
1117
  def detect_shutdown_token(self, input_ids):
1118
  if torch.any(input_ids == self.shutdown_token_id):
1119
  return True
1120
+ def detect_shutdown_token(self, input_ids):
1121
+ shutdown_token_tensor = torch.tensor(self.shutdown_token_id, device=input_ids.device, dtype=input_ids.dtype)
1122
+ if torch.any(input_ids == shutdown_token_tensor):
1123
+ return True
1124
+ return False
1125
+
1126
  def randomize_weights(self):
1127
  with torch.no_grad():
1128
  for param in self.parameters():