SinanAkkoyun committed
Commit: 56224bb
Parent(s): cd9d96f
test
Files changed: quantize_alpaca.py (+178, -0)
quantize_alpaca.py
ADDED
@@ -0,0 +1,178 @@
+ import json
+ import random
+ import time
+ from argparse import ArgumentParser
+
+ import torch
+ from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
+ from datasets import Dataset
+ from transformers import AutoTokenizer, TextGenerationPipeline
+
+
+ def load_data(data_path, tokenizer, n_samples):
+     """Load the Alpaca-style JSON file and tokenize a random subset for calibration."""
+     with open(data_path, "r", encoding="utf-8") as f:
+         raw_data = json.load(f)
+
+     raw_data = random.sample(raw_data, k=min(n_samples, len(raw_data)))
+
+     def dummy_gen():
+         return raw_data
+
+     def tokenize(examples):
+         instructions = examples["instruction"]
+         inputs = examples["input"]
+         outputs = examples["output"]
+
+         prompts = []
+         texts = []
+         input_ids = []
+         attention_mask = []
+         for istr, inp, opt in zip(instructions, inputs, outputs):
+             # Build the prompt with or without the optional input field.
+             if inp:
+                 prompt = f"### User:\n{istr}\n\n### Input:\n{inp}\n\nResponse:\n"
+                 text = prompt + opt
+             else:
+                 prompt = f"### User:\n{istr}\n\nResponse:\n"
+                 text = prompt + opt
+             # Skip samples whose prompt alone already exceeds the model's context length.
+             if len(tokenizer(prompt)["input_ids"]) >= tokenizer.model_max_length:
+                 continue
+
+             tokenized_data = tokenizer(text)
+
+             input_ids.append(tokenized_data["input_ids"][: tokenizer.model_max_length])
+             attention_mask.append(tokenized_data["attention_mask"][: tokenizer.model_max_length])
+             prompts.append(prompt)
+             texts.append(text)
+
+         return {
+             "input_ids": input_ids,
+             "attention_mask": attention_mask,
+             "prompt": prompts
+         }
+
+     dataset = Dataset.from_generator(dummy_gen)
+
+     dataset = dataset.map(
+         tokenize,
+         batched=True,
+         batch_size=len(dataset),
+         num_proc=1,
+         keep_in_memory=True,
+         load_from_cache_file=False,
+         remove_columns=["instruction", "input"]
+     )
+
+     dataset = dataset.to_list()
+
+     # Convert token lists to LongTensors, as expected by model.quantize().
+     for sample in dataset:
+         sample["input_ids"] = torch.LongTensor(sample["input_ids"])
+         sample["attention_mask"] = torch.LongTensor(sample["attention_mask"])
+
+     return dataset
+
+
+ def main():
+     parser = ArgumentParser()
+     parser.add_argument("--pretrained_model_dir", type=str)
+     parser.add_argument("--quantized_model_dir", type=str, default=None)
+     parser.add_argument("--bits", type=int, default=4, choices=[2, 3, 4, 8])
+     parser.add_argument("--group_size", type=int, default=128, help="group size; -1 means no grouping")
+     parser.add_argument("--desc_act", action="store_true", help="whether to quantize with desc_act")
+     parser.add_argument("--num_samples", type=int, default=128, help="how many calibration samples are used to quantize the model")
+     parser.add_argument("--save_and_reload", action="store_true", help="whether to save the quantized model to disk and reload it")
+     parser.add_argument("--fast_tokenizer", action="store_true", help="whether to use the fast tokenizer")
+     parser.add_argument("--use_triton", action="store_true", help="whether to use triton to speed up inference")
+     parser.add_argument("--per_gpu_max_memory", type=int, default=None, help="max memory used to load the model per gpu, in GiB")
+     parser.add_argument("--cpu_max_memory", type=int, default=None, help="max memory used to offload the model to cpu, in GiB")
+     parser.add_argument("--quant_batch_size", type=int, default=1, help="examples batch size for quantization")
+     parser.add_argument("--trust_remote_code", action="store_true", help="whether to trust remote code when loading the model")
+     args = parser.parse_args()
+
+     # Build the max_memory map for multi-GPU loading / CPU offload, if requested.
+     max_memory = dict()
+     if args.per_gpu_max_memory is not None and args.per_gpu_max_memory > 0:
+         if torch.cuda.is_available():
+             max_memory.update(
+                 {i: f"{args.per_gpu_max_memory}GIB" for i in range(torch.cuda.device_count())}
+             )
+     if args.cpu_max_memory is not None and args.cpu_max_memory > 0 and max_memory:
+         max_memory["cpu"] = f"{args.cpu_max_memory}GIB"
+     if not max_memory:
+         max_memory = None
+
+     tokenizer = AutoTokenizer.from_pretrained(
+         args.pretrained_model_dir,
+         use_fast=args.fast_tokenizer,
+         trust_remote_code=args.trust_remote_code
+     )
+     model = AutoGPTQForCausalLM.from_pretrained(
+         args.pretrained_model_dir,
+         quantize_config=BaseQuantizeConfig(bits=args.bits, group_size=args.group_size, desc_act=args.desc_act),
+         max_memory=max_memory,
+         trust_remote_code=args.trust_remote_code
+     )
+
+     # Prepare calibration examples from the Alpaca dataset.
+     examples = load_data("dataset/alpaca_data_cleaned.json", tokenizer, args.num_samples)
+     examples_for_quant = [
+         {"input_ids": example["input_ids"], "attention_mask": example["attention_mask"]}
+         for example in examples
+     ]
+
+     start = time.time()
+     model.quantize(
+         examples_for_quant,
+         batch_size=args.quant_batch_size,
+         use_triton=args.use_triton,
+         autotune_warmup_after_quantized=args.use_triton,
+     )
+     end = time.time()
+     print(f"quantization took: {end - start:.4f}s")
+
+     if not args.quantized_model_dir:
+         args.quantized_model_dir = args.pretrained_model_dir
+
+     if args.save_and_reload:
+         # Save the quantized weights and reload them to verify the round trip.
+         model.save_quantized(args.quantized_model_dir, use_safetensors=True)
+         del model
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+         model = AutoGPTQForCausalLM.from_quantized(
+             args.quantized_model_dir,
+             device="cuda:0",
+             use_triton=args.use_triton,
+             max_memory=max_memory,
+             inject_fused_mlp=True,
+             inject_fused_attention=True,
+             trust_remote_code=args.trust_remote_code
+         )
+
+     # Sanity-check generation on a few random calibration prompts.
+     pipeline_init_kwargs = {"model": model, "tokenizer": tokenizer}
+     if not max_memory:
+         pipeline_init_kwargs["device"] = "cuda:0"
+     pipeline = TextGenerationPipeline(**pipeline_init_kwargs)
+     for example in random.sample(examples, k=min(4, len(examples))):
+         print(f"prompt: {example['prompt']}")
+         print("-" * 42)
+         print(f"golden: {example['output']}")
+         print("-" * 42)
+         start = time.time()
+         generated_text = pipeline(
+             example['prompt'],
+             return_full_text=False,
+             num_beams=1,
+             max_length=len(example["input_ids"]) + 128  # max_length instead of max_new_tokens avoids a UserWarning when combined with logging
+         )[0]['generated_text']
+         end = time.time()
+         print(f"quant: {generated_text}")
+         num_new_tokens = len(tokenizer(generated_text)["input_ids"])
+         print(f"generated {num_new_tokens} tokens in {end - start:.4f}s, {num_new_tokens / (end - start):.2f} tokens/s.")
+         print("=" * 42)
+
+
+ if __name__ == "__main__":
+     import logging
+
+     logging.basicConfig(
+         format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
+     )
+
+     main()
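Note on the calibration data: the script hard-codes dataset/alpaca_data_cleaned.json as its input, and that file is not part of this commit. Judging from the fields read in load_data, it is expected to be a JSON list of Alpaca-style records with "instruction", "input", and "output" keys, roughly like the following (illustrative values only, not taken from the actual dataset):

[
  {"instruction": "Summarize the text below.", "input": "GPTQ quantizes weights group by group.", "output": "GPTQ performs group-wise weight quantization."},
  {"instruction": "Name a prime number greater than 10.", "input": "", "output": "11"}
]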
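With that file in place, a typical invocation might look like the sketch below. The directories are placeholders, while the flags are exactly those defined in the script's argparse section; --save_and_reload writes the quantized safetensors checkpoint and reloads it before the generation sanity check.

python quantize_alpaca.py \
    --pretrained_model_dir ./models/base-model \
    --quantized_model_dir ./models/base-model-gptq \
    --bits 4 \
    --group_size 128 \
    --num_samples 128 \
    --fast_tokenizer \
    --save_and_reload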