compatibility with new transformers

#60
Files changed (1)
  1. modeling_chatglm.py +17 -3
modeling_chatglm.py CHANGED
@@ -14,6 +14,7 @@ from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
 from torch.nn.utils import skip_init
 from typing import Optional, Tuple, Union, List, Callable, Dict, Any
 from copy import deepcopy
+import transformers
 
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
@@ -45,6 +46,9 @@ CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
     # See all ChatGLM models at https://huggingface.co/models?filter=chatglm
 ]
 
+is_transformers_4_42_or_higher = int(transformers.__version__.split(".")[1]) >= 42
+is_transformers_4_44_or_higher = int(transformers.__version__.split(".")[1]) >= 44
+
 
 def default_init(cls, *args, **kwargs):
     return cls(*args, **kwargs)
@@ -872,9 +876,19 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         standardize_cache_format: bool = False,
     ) -> Dict[str, Any]:
         # update past_key_values
-        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
-            outputs, standardize_cache_format=standardize_cache_format
-        )
+        if is_transformers_4_44_or_higher:
+            model_kwargs["past_key_values"] = self._extract_past_from_model_output(
+                outputs
+            )[1]
+        elif is_transformers_4_42_or_higher:
+            # update past_key_values
+            model_kwargs["past_key_values"] = self._extract_past_from_model_output(
+                outputs, standardize_cache_format=standardize_cache_format
+            )[1]
+        else:
+            model_kwargs["past_key_values"] = self._extract_past_from_model_output(
+                outputs, standardize_cache_format=standardize_cache_format
+            )
 
         # update attention mask
         if "attention_mask" in model_kwargs:
 