duzx16 committed
Commit 7fabe56
Parent: efb7a1e

Fix use_cache=False

Files changed (1): modeling_chatglm.py (+5, -2)
modeling_chatglm.py CHANGED

```diff
@@ -897,6 +897,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             past_key_values: Optional[torch.Tensor] = None,
             attention_mask: Optional[torch.Tensor] = None,
             position_ids: Optional[torch.Tensor] = None,
+            use_cache: Optional[bool] = None,
             is_first_forward: bool = True,
             **kwargs
     ) -> dict:
@@ -904,7 +905,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         if position_ids is None:
             position_ids = self.get_position_ids(input_ids, device=input_ids.device)
         if not is_first_forward:
-            if self.config.use_cache:
+            if past_key_values is not None:
                 position_ids = position_ids[..., -1:]
                 input_ids = input_ids[:, -1:]
         return {
@@ -912,7 +913,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             "past_key_values": past_key_values,
             "position_ids": position_ids,
             "attention_mask": attention_mask,
-            "return_last_logit": True
+            "return_last_logit": True,
+            "use_cache": use_cache
         }
 
     def forward(
@@ -1089,6 +1091,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         generation_config = self.generation_config
         generation_config = copy.deepcopy(generation_config)
         model_kwargs = generation_config.update(**kwargs)
+        model_kwargs["use_cache"] = generation_config.use_cache
         bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
 
         if isinstance(eos_token_id, int):
```
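For context, the commit makes `prepare_inputs_for_generation` respect a runtime `use_cache` flag instead of the static `self.config.use_cache`: inputs are truncated to the last token only when a KV cache (`past_key_values`) actually exists, and the flag is accepted as a parameter and threaded through to `forward`. Below is a minimal, self-contained sketch of the corrected gating logic; `prepare_inputs` is a simplified stand-in for illustration, not ChatGLM's actual method.

```python
import torch
from typing import Optional


def prepare_inputs(
        input_ids: torch.LongTensor,
        position_ids: torch.LongTensor,
        past_key_values: Optional[tuple] = None,
        use_cache: Optional[bool] = None,
        is_first_forward: bool = True,
) -> dict:
    # Incremental-decoding input prep, reduced to the parts the diff touches.
    if not is_first_forward:
        # The old check was `if self.config.use_cache:`, a static config value;
        # with generate(..., use_cache=False) the inputs were still sliced down
        # to the last token even though no KV cache existed.  The fix gates on
        # the cache itself:
        if past_key_values is not None:
            position_ids = position_ids[..., -1:]
            input_ids = input_ids[:, -1:]
    return {
        "input_ids": input_ids,
        "position_ids": position_ids,
        "past_key_values": past_key_values,
        "use_cache": use_cache,  # now forwarded instead of silently dropped
    }


ids = torch.tensor([[11, 12, 13]])
pos = torch.arange(3).unsqueeze(0)

# use_cache=False -> no past exists, so the full sequence is re-fed each step.
no_cache = prepare_inputs(ids, pos, past_key_values=None,
                          use_cache=False, is_first_forward=False)
assert no_cache["input_ids"].shape[1] == 3

# With a cache present, only the newest token (and position) is fed.
cached = prepare_inputs(ids, pos, past_key_values=(("k", "v"),),
                        use_cache=True, is_first_forward=False)
assert cached["input_ids"].shape[1] == 1
```

The last hunk handles the generation side of the same problem: `generation_config.update(**kwargs)` consumes `use_cache` into the `GenerationConfig` rather than leaving it in `model_kwargs`, so the commit copies it back explicitly, which is what lets a caller's `use_cache=False` reach `prepare_inputs_for_generation` at all.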