SinanAkkoyun committed on
Commit
56224bb
1 Parent(s): cd9d96f
Files changed (1)
  1. quantize_alpaca.py +178 -0
quantize_alpaca.py ADDED
@@ -0,0 +1,178 @@
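+ # Quantizes a causal LM with AutoGPTQ, using Alpaca-style instruction data as
+ # calibration samples, then optionally saves/reloads the result and smoke-tests it.
+ # Example invocation (flags defined by the argparse setup below):
+ #   python quantize_alpaca.py --pretrained_model_dir <model_dir> --quantized_model_dir <out_dir> --bits 4 --group_size 128 --save_and_reload
+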
+ import json
+ import random
+ import time
+ from argparse import ArgumentParser
+
+ import torch
+ from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
+ from datasets import Dataset
+ from transformers import AutoTokenizer, TextGenerationPipeline
+
+
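+ # Build calibration examples from an Alpaca-format JSON file
+ # (a list of {"instruction", "input", "output"} records).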
+ def load_data(data_path, tokenizer, n_samples):
+     with open(data_path, "r", encoding="utf-8") as f:
+         raw_data = json.load(f)
+
+     raw_data = random.sample(raw_data, k=min(n_samples, len(raw_data)))
+
+     def dummy_gen():
+         return raw_data
+
+     def tokenize(examples):
+         instructions = examples["instruction"]
+         inputs = examples["input"]
+         outputs = examples["output"]
+
+         prompts = []
+         texts = []
+         input_ids = []
+         attention_mask = []
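+         # Render each record with the "### User:" prompt template; skip samples
+         # whose prompt alone already exceeds the model's maximum context length.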
+         for istr, inp, opt in zip(instructions, inputs, outputs):
+             if inp:
+                 prompt = f"### User:\n{istr}\n\n### Input:\n{inp}\n\nResponse:\n"
+                 text = prompt + opt
+             else:
+                 prompt = f"### User:\n{istr}\n\nResponse:\n"
+                 text = prompt + opt
+             if len(tokenizer(prompt)["input_ids"]) >= tokenizer.model_max_length:
+                 continue
+
+             tokenized_data = tokenizer(text)
+
+             input_ids.append(tokenized_data["input_ids"][: tokenizer.model_max_length])
+             attention_mask.append(tokenized_data["attention_mask"][: tokenizer.model_max_length])
+             prompts.append(prompt)
+             texts.append(text)
+
+         return {
+             "input_ids": input_ids,
+             "attention_mask": attention_mask,
+             "prompt": prompts
+         }
+
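+     # Tokenize everything in a single batch; the Dataset is only an intermediate
+     # container and is converted back to a plain list of samples below.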
+     dataset = Dataset.from_generator(dummy_gen)
+
+     dataset = dataset.map(
+         tokenize,
+         batched=True,
+         batch_size=len(dataset),
+         num_proc=1,
+         keep_in_memory=True,
+         load_from_cache_file=False,
+         remove_columns=["instruction", "input"]
+     )
+
+     dataset = dataset.to_list()
+
+     for sample in dataset:
+         sample["input_ids"] = torch.LongTensor(sample["input_ids"])
+         sample["attention_mask"] = torch.LongTensor(sample["attention_mask"])
+
+     return dataset
+
+
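+ # Parse CLI options, run GPTQ quantization on the calibration set, then compare
+ # generations from the quantized model against the reference outputs.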
+ def main():
+     parser = ArgumentParser()
+     parser.add_argument("--pretrained_model_dir", type=str)
+     parser.add_argument("--quantized_model_dir", type=str, default=None)
+     parser.add_argument("--bits", type=int, default=4, choices=[2, 3, 4, 8])
+     parser.add_argument("--group_size", type=int, default=128, help="group size, -1 means no grouping or full rank")
+     parser.add_argument("--desc_act", action="store_true", help="whether to quantize with desc_act")
+     parser.add_argument("--num_samples", type=int, default=128, help="how many samples to use for quantizing the model")
+     parser.add_argument("--save_and_reload", action="store_true", help="whether to save the quantized model to disk and reload it")
+     parser.add_argument("--fast_tokenizer", action="store_true", help="whether to use the fast tokenizer")
+     parser.add_argument("--use_triton", action="store_true", help="whether to use triton to speed up inference")
+     parser.add_argument("--per_gpu_max_memory", type=int, default=None, help="max memory used to load the model per gpu")
+     parser.add_argument("--cpu_max_memory", type=int, default=None, help="max memory used to offload the model to cpu")
+     parser.add_argument("--quant_batch_size", type=int, default=1, help="examples batch size for quantization")
+     parser.add_argument("--trust_remote_code", action="store_true", help="whether to trust remote code when loading the model")
+     args = parser.parse_args()
+
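+     # Optionally cap per-GPU memory (and CPU offload memory) when loading the
+     # model; an empty mapping falls back to max_memory=None.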
+     max_memory = dict()
+     if args.per_gpu_max_memory is not None and args.per_gpu_max_memory > 0:
+         if torch.cuda.is_available():
+             max_memory.update(
+                 {i: f"{args.per_gpu_max_memory}GIB" for i in range(torch.cuda.device_count())}
+             )
+     if args.cpu_max_memory is not None and args.cpu_max_memory > 0 and max_memory:
+         max_memory["cpu"] = f"{args.cpu_max_memory}GIB"
+     if not max_memory:
+         max_memory = None
+
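+     # Load the tokenizer and the full-precision model, attaching the GPTQ
+     # settings (bits, group size, desc_act) via BaseQuantizeConfig.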
+     tokenizer = AutoTokenizer.from_pretrained(
+         args.pretrained_model_dir,
+         use_fast=args.fast_tokenizer,
+         trust_remote_code=args.trust_remote_code
+     )
+     model = AutoGPTQForCausalLM.from_pretrained(
+         args.pretrained_model_dir,
+         quantize_config=BaseQuantizeConfig(bits=args.bits, group_size=args.group_size, desc_act=args.desc_act),
+         max_memory=max_memory,
+         trust_remote_code=args.trust_remote_code
+     )
+
+     examples = load_data("dataset/alpaca_data_cleaned.json", tokenizer, args.num_samples)
+     examples_for_quant = [
+         {"input_ids": example["input_ids"], "attention_mask": example["attention_mask"]}
+         for example in examples
+     ]
+
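+     # Run GPTQ calibration; only input_ids and attention_mask are passed in,
+     # so the prompt/output fields were stripped from the examples above.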
+     start = time.time()
+     model.quantize(
+         examples_for_quant,
+         batch_size=args.quant_batch_size,
+         use_triton=args.use_triton,
+         autotune_warmup_after_quantized=args.use_triton,
+     )
+     end = time.time()
+     print(f"quantization took: {end - start: .4f}s")
+
+     if not args.quantized_model_dir:
+         args.quantized_model_dir = args.pretrained_model_dir
+
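+     # Optionally write the quantized weights to disk as safetensors and reload
+     # them through the quantized loading path (with fused attention/MLP injection).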
+     if args.save_and_reload:
+         model.save_quantized(args.quantized_model_dir, use_safetensors=True)
+         del model
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+         model = AutoGPTQForCausalLM.from_quantized(
+             args.quantized_model_dir,
+             device="cuda:0",
+             use_triton=args.use_triton,
+             max_memory=max_memory,
+             inject_fused_mlp=True,
+             inject_fused_attention=True,
+             trust_remote_code=args.trust_remote_code
+         )
+
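+     # Smoke-test the quantized model: generate continuations for a few random
+     # calibration prompts and print them next to the reference ("golden") outputs.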
+     pipeline_init_kwargs = {"model": model, "tokenizer": tokenizer}
+     if not max_memory:
+         pipeline_init_kwargs["device"] = "cuda:0"
+     pipeline = TextGenerationPipeline(**pipeline_init_kwargs)
+     for example in random.sample(examples, k=min(4, len(examples))):
+         print(f"prompt: {example['prompt']}")
+         print("-" * 42)
+         print(f"golden: {example['output']}")
+         print("-" * 42)
+         start = time.time()
+         generated_text = pipeline(
+             example['prompt'],
+             return_full_text=False,
+             num_beams=1,
+             max_length=len(example["input_ids"]) + 128  # use this instead of max_new_tokens to avoid a UserWarning when integrating with logging
+         )[0]['generated_text']
+         end = time.time()
+         print(f"quant: {generated_text}")
+         num_new_tokens = len(tokenizer(generated_text)["input_ids"])
+         print(f"generated {num_new_tokens} tokens in {end - start: .4f}s, {num_new_tokens / (end - start)} tokens/s.")
+         print("=" * 42)
+
+
+ if __name__ == "__main__":
+     import logging
+
+     logging.basicConfig(
+         format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
+     )
+
+     main()