tianweidut committed on
Commit
9a7c3dc
1 Parent(s): e7fddd3

fix `load_in_8bit=True` issue


```python
>>> from transformers import AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("pretrained/baichuan2-7b/base", load_in_8bit=True, device_map="auto", trust_remote_code=True)
[2023-09-06 16:03:27,691] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
2023-09-06 16:03:28.999241: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-09-06 16:03:30.314803: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/liutianwei/.conda/envs/starwhale/lib/python3.9/site-packages/transformers/models/auto/auto_factory.py", line 488, in from_pretrained
    return model_class.from_pretrained(
  File "/home/liutianwei/.cache/huggingface/modules/transformers_modules/base/modeling_baichuan.py", line 779, in from_pretrained
    return super(BaichuanForCausalLM, cls).from_pretrained(
  File "/home/liutianwei/.conda/envs/starwhale/lib/python3.9/site-packages/transformers/modeling_utils.py", line 2700, in from_pretrained
    model = cls(config, *model_args, **model_kwargs)
  File "/home/liutianwei/.cache/huggingface/modules/transformers_modules/base/modeling_baichuan.py", line 638, in __init__
    and config.quantization_config["load_in_4bit"]
TypeError: 'BitsAndBytesConfig' object is not subscriptable
>>>
```
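The root cause: when `load_in_8bit=True` (or `load_in_4bit=True`) is passed to `from_pretrained`, recent `transformers` versions attach a `BitsAndBytesConfig` *object* to `config.quantization_config` rather than a plain dict, so the old `config.quantization_config['load_in_4bit']` subscript raises `TypeError`. A minimal sketch of the two shapes (the dict literal is a hypothetical stand-in for a dict-style `quantization_config`):

```python
from transformers import BitsAndBytesConfig

# Object form: this is what from_pretrained(load_in_8bit=True) puts on the config.
qconf = BitsAndBytesConfig(load_in_8bit=True)
# qconf["load_in_4bit"]   # TypeError: 'BitsAndBytesConfig' object is not subscriptable
print(qconf.load_in_4bit)  # False — attribute access works on the object form

# The patched guard only enters the offline-int4 branch for dict-style configs:
dict_style = {"load_in_4bit": True}  # hypothetical dict-style quantization_config
if isinstance(dict_style, dict) and dict_style.get("load_in_4bit", False):
    print("dict-style config: take the offline int4 quantization path")
```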

Files changed (1)
  1. modeling_baichuan.py +1 -1
modeling_baichuan.py CHANGED

```diff
@@ -528,7 +528,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
         self.model = BaichuanModel(config)
 
         self.lm_head = NormHead(config.hidden_size, config.vocab_size, bias=False)
-        if hasattr(config, "quantization_config") and config.quantization_config['load_in_4bit']:
+        if hasattr(config, "quantization_config") and isinstance(config.quantization_config, dict) and config.quantization_config.get('load_in_4bit', False):
             try:
                 from .quantizer import quantize_offline, init_model_weight_int4
             except ImportError:
```
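With the patched guard, a `BitsAndBytesConfig` object no longer trips the dict subscript in `__init__`, so the original 8-bit call should load. A sketch of the expected usage (the local model path is taken from the report above, and `bitsandbytes` is assumed to be installed):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "pretrained/baichuan2-7b/base",  # local path from the report above
    load_in_8bit=True,
    device_map="auto",
    trust_remote_code=True,
)
```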