# FIRE/src/modules/gptq.py
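"""Utilities for loading GPTQ-quantized models.

Wraps `load_quant` from the external GPTQ-for-LLaMa repository and exposes a
small `GptqConfig` dataclass describing the checkpoint to load.
"""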
from dataclasses import dataclass, field
import os
from pathlib import Path
import sys
from typing import Optional

from transformers import AutoTokenizer


@dataclass
class GptqConfig:
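    """Arguments controlling how a GPTQ checkpoint is loaded.

    Example (illustrative; the checkpoint path below is hypothetical):
        GptqConfig(ckpt="models/llama-7b-4bit-128g.safetensors", wbits=4, groupsize=128)
    """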
    ckpt: Optional[str] = field(
default=None,
metadata={
"help": "Load quantized model. The path to the local GPTQ checkpoint."
},
)
wbits: int = field(default=16, metadata={"help": "#bits to use for quantization"})
groupsize: int = field(
default=-1,
metadata={"help": "Groupsize to use for quantization; default uses full row."},
)
act_order: bool = field(
default=True,
metadata={"help": "Whether to apply the activation order GPTQ heuristic"},
)


def load_gptq_quantized(model_name, gptq_config: GptqConfig):
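    """Load a GPTQ-quantized model and its tokenizer.

    The model itself is loaded by `load_quant` from the external
    GPTQ-for-LLaMa repository; the tokenizer comes from `AutoTokenizer`.
    """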
print("Loading GPTQ quantized model...")
try:
        # GPTQ-for-LLaMa must be cloned into `repositories/` next to the
        # package root so that its `llama.load_quant` can be imported.
        script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
        module_path = os.path.join(script_path, "../repositories/GPTQ-for-LLaMa")

        sys.path.insert(0, module_path)
        from llama import load_quant
except ImportError as e:
print(f"Error: Failed to load GPTQ-for-LLaMa. {e}")
print("See https://github.com/lm-sys/FastChat/blob/main/docs/gptq.md")
        sys.exit(-1)

    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

    # Only the `fastest-inference-4bit` branch of GPTQ-for-LLaMa accepts `act_order`.
if gptq_config.act_order:
model = load_quant(
model_name,
find_gptq_ckpt(gptq_config),
gptq_config.wbits,
gptq_config.groupsize,
act_order=gptq_config.act_order,
)
    else:
        # Other branches of GPTQ-for-LLaMa do not accept the `act_order` keyword.
model = load_quant(
model_name,
find_gptq_ckpt(gptq_config),
gptq_config.wbits,
gptq_config.groupsize,
)
return model, tokenizer


def find_gptq_ckpt(gptq_config: GptqConfig):
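    """Resolve `gptq_config.ckpt` to a checkpoint file.

    Accepts either a direct file path or a directory containing
    `*.pt` / `*.safetensors` files; exits if nothing is found.
    """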
    if Path(gptq_config.ckpt).is_file():
        return gptq_config.ckpt

    # `ckpt` is a directory: return the lexicographically last matching file.
    for ext in ["*.pt", "*.safetensors"]:
        matched_result = sorted(Path(gptq_config.ckpt).glob(ext))
        if matched_result:
            return str(matched_result[-1])

    print("Error: GPTQ checkpoint not found.")
    sys.exit(1)
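

# Illustrative usage from a caller (sketch only; the model id and checkpoint
# path below are hypothetical):
#
#   config = GptqConfig(ckpt="repositories/llama-7b-4bit", wbits=4, groupsize=128)
#   model, tokenizer = load_gptq_quantized("huggyllama/llama-7b", config)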