duzx16 committed
Commit 7fabe56 · Parent: efb7a1e

Fix use_cache=False

modeling_chatglm.py (+5 -2) CHANGED
@@ -897,6 +897,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             past_key_values: Optional[torch.Tensor] = None,
             attention_mask: Optional[torch.Tensor] = None,
             position_ids: Optional[torch.Tensor] = None,
+            use_cache: Optional[bool] = None,
             is_first_forward: bool = True,
             **kwargs
     ) -> dict:
@@ -904,7 +905,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         if position_ids is None:
             position_ids = self.get_position_ids(input_ids, device=input_ids.device)
         if not is_first_forward:
-            if
+            if past_key_values is not None:
                 position_ids = position_ids[..., -1:]
                 input_ids = input_ids[:, -1:]
         return {
@@ -912,7 +913,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             "past_key_values": past_key_values,
             "position_ids": position_ids,
             "attention_mask": attention_mask,
-            "return_last_logit": True
+            "return_last_logit": True,
+            "use_cache": use_cache
         }
 
     def forward(
@@ -1089,6 +1091,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         generation_config = self.generation_config
         generation_config = copy.deepcopy(generation_config)
         model_kwargs = generation_config.update(**kwargs)
+        model_kwargs["use_cache"] = generation_config.use_cache
         bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
 
         if isinstance(eos_token_id, int):
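
The behavioral fix is the past_key_values guard in prepare_inputs_for_generation: the input may only be cut down to the last token when a KV cache actually holds the earlier tokens, and with use_cache=False no cache is ever built, so every step must keep the full sequence. A minimal sketch of that rule (prepare_inputs here is an illustrative helper, not the model's actual method; position_ids handling is omitted):

import torch

def prepare_inputs(input_ids, past_key_values=None, is_first_forward=True):
    # Truncate to the newest token only when a cache covers the prefix.
    if not is_first_forward:
        if past_key_values is not None:  # the guard this commit adds
            input_ids = input_ids[:, -1:]
    return {"input_ids": input_ids, "past_key_values": past_key_values}

ids = torch.tensor([[5, 6, 7, 8]])
print(prepare_inputs(ids, past_key_values=("dummy cache",), is_first_forward=False)["input_ids"])
# tensor([[8]])          -> cache present: the last token is enough
print(prepare_inputs(ids, past_key_values=None, is_first_forward=False)["input_ids"])
# tensor([[5, 6, 7, 8]]) -> no cache: the full sequence is preserved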
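
The last hunk is needed because of how kwargs flow through generate: GenerationConfig.update(**kwargs) consumes every kwarg that matches a generation-config attribute and returns only the leftovers as model_kwargs. use_cache is such an attribute, so without the added line it would be swallowed by the config and never reach prepare_inputs_for_generation. A small sketch of that behavior with plain transformers calls (my_model_kwarg is a made-up placeholder for a model-specific argument):

from transformers import GenerationConfig

generation_config = GenerationConfig()  # use_cache defaults to True
kwargs = {"use_cache": False, "my_model_kwarg": 123}

model_kwargs = generation_config.update(**kwargs)
print(model_kwargs)                 # {'my_model_kwarg': 123}: use_cache was consumed
print(generation_config.use_cache)  # False

# The commit's added line copies it back so it is forwarded to the model:
model_kwargs["use_cache"] = generation_config.use_cache
print(model_kwargs)                 # {'my_model_kwarg': 123, 'use_cache': False}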
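
Taken together, the two changes make cache-free generation work end to end. A hypothetical smoke test, assuming a ChatGLM checkpoint that ships this modeling_chatglm.py (the repo id below is an assumption, not confirmed by the commit):

from transformers import AutoModel, AutoTokenizer

model_id = "THUDM/chatglm2-6b"  # assumed checkpoint; substitute the repo this commit belongs to
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModel.from_pretrained(model_id, trust_remote_code=True).eval()

inputs = tokenizer("Hello", return_tensors="pt")
# Before this commit, use_cache=False truncated the input to the last token
# even though no cache existed, derailing decoding after the first step.
output = model.generate(**inputs, max_new_tokens=8, use_cache=False)
print(tokenizer.decode(output[0]))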