Spaces:
Running
Running
from dataclasses import dataclass, field | |
import os | |
from os.path import isdir, isfile | |
from pathlib import Path | |
import sys | |
from transformers import AutoTokenizer | |
@dataclass
class GptqConfig:
    """Options for loading a GPTQ-quantized model.

    Each field's ``metadata["help"]`` string documents the option; presumably
    it is consumed by an argument parser that reads dataclass field metadata
    (NOTE(review): confirm against the caller).
    """

    # Path to a checkpoint file, or to a directory that find_gptq_ckpt searches
    # for *.pt / *.safetensors files. Defaults to None (unset).
    ckpt: str = field(
        default=None,
        metadata={
            "help": "Load quantized model. The path to the local GPTQ checkpoint."
        },
    )
    # Bit width the weights were quantized to.
    wbits: int = field(default=16, metadata={"help": "#bits to use for quantization"})
    # Quantization group size; -1 means one group per full row.
    groupsize: int = field(
        default=-1,
        metadata={"help": "Groupsize to use for quantization; default uses full row."},
    )
    # Forwarded to load_quant only when truthy (see load_gptq_quantized).
    act_order: bool = field(
        default=True,
        metadata={"help": "Whether to apply the activation order GPTQ heuristic"},
    )
def load_gptq_quantized(model_name, gptq_config: "GptqConfig"):
    """Load a GPTQ-quantized model together with its tokenizer.

    ``load_quant`` is imported from a GPTQ-for-LLaMa checkout located at
    ``../repositories/GPTQ-for-LLaMa`` relative to the parent of this file's
    directory; that path is prepended to ``sys.path`` first.

    Args:
        model_name: Name or path given to ``AutoTokenizer.from_pretrained`` and
            passed as the first argument to ``load_quant``.
        gptq_config: Quantization settings (checkpoint location, bit width,
            group size, activation-order flag).

    Returns:
        A ``(model, tokenizer)`` tuple.

    Exits the process with status -1 if GPTQ-for-LLaMa cannot be imported, and
    with status 1 (inside ``find_gptq_ckpt``) if no checkpoint can be found.
    """
    print("Loading GPTQ quantized model...")
    script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
    module_path = os.path.join(script_path, "../repositories/GPTQ-for-LLaMa")
    sys.path.insert(0, module_path)
    try:
        # Only the import can raise ImportError; keep the try body minimal.
        from llama import load_quant
    except ImportError as e:
        print(f"Error: Failed to load GPTQ-for-LLaMa. {e}")
        print("See https://github.com/lm-sys/FastChat/blob/main/docs/gptq.md")
        sys.exit(-1)

    # Resolve the checkpoint once, before the tokenizer load, so a missing
    # checkpoint fails fast instead of after the tokenizer has been loaded.
    ckpt_path = find_gptq_ckpt(gptq_config)
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

    # Only the `fastest-inference-4bit` branch of GPTQ-for-LLaMa cares about
    # `act_order`; other branches presumably do not accept the keyword, so it
    # is passed only when enabled.
    extra_kwargs = {"act_order": gptq_config.act_order} if gptq_config.act_order else {}
    model = load_quant(
        model_name,
        ckpt_path,
        gptq_config.wbits,
        gptq_config.groupsize,
        **extra_kwargs,
    )
    return model, tokenizer
def find_gptq_ckpt(gptq_config: "GptqConfig"):
    """Locate the GPTQ checkpoint file described by ``gptq_config.ckpt``.

    If ``ckpt`` names an existing file, it is returned unchanged. Otherwise it
    is treated as a directory and searched non-recursively for ``*.pt`` files,
    then for ``*.safetensors`` files. For the first pattern with any matches,
    the last path in sorted order is returned.

    Args:
        gptq_config: Object with a ``ckpt`` attribute (a path, or ``None`` if
            unset).

    Returns:
        The checkpoint path as a string.

    Exits the process with status 1, after printing an error, if ``ckpt`` is
    unset or no checkpoint is found.
    """
    # ckpt defaults to None; Path(None) would raise TypeError, so treat an
    # unset path the same as a missing checkpoint.
    if gptq_config.ckpt is not None:
        ckpt = Path(gptq_config.ckpt)
        if ckpt.is_file():
            return gptq_config.ckpt
        for pattern in ("*.pt", "*.safetensors"):
            matches = sorted(ckpt.glob(pattern))
            if matches:
                # Sorting makes the pick deterministic; the last name wins.
                return str(matches[-1])

    print("Error: gptq checkpoint not found")
    sys.exit(1)