Create README.md
README.md (added)
---
tags:
- fp8
---
Quantized using the script below. It replaces the `torch.nn.Linear` modules inside the transformer (the MoE router `gate` layers are skipped) with FP8 (E4M3) weights and, for the default `static` activation scheme, calibrates per-tensor activation scales on samples from `HuggingFaceH4/ultrachat_200k`.

Command:
```bash
python quantize.py --model-id mistralai/Mixtral-8x7B-Instruct-v0.1 --save-dir Mixtral-8x7B-Instruct-v0.1-FP8 --num-samples 512
```

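The script also supports weight-only FP8 with activation scales computed on the fly, which skips the calibration pass. A hypothetical invocation (the flag comes from the script's argparse definition; the save directory name is illustrative):

```bash
python quantize.py --model-id mistralai/Mixtral-8x7B-Instruct-v0.1 --save-dir Mixtral-8x7B-Instruct-v0.1-FP8-dynamic --activation-scheme dynamic
```
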
Script:
```python
import argparse
import gc
import re
from typing import Tuple

import torch
import torch.nn.functional as F
import transformers
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer


# HACK: override the dtype_byte_size function in transformers to support float8 types
def new_dtype_byte_size(dtype):
    if dtype == torch.bool:
        return 1 / 8
    bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
    if bit_search is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
    bit_size = int(bit_search.groups()[0])
    return bit_size // 8


transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size


def cleanup_memory():
    gc.collect()
    torch.cuda.empty_cache()


def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
    """Quantize a tensor using per-tensor static scaling factor.

    Args:
        tensor: The input tensor.
    """
    finfo = torch.finfo(torch.float8_e4m3fn)
    # Calculate the scale as dtype max divided by absmax.
    # Since .abs() creates a new tensor, we use aminmax to get
    # the min and max first and then calculate the absmax.
    if tensor.numel() == 0:
        # Deal with empty tensors (triggered by empty MoE experts)
        min_val, max_val = (
            torch.tensor(0.0, dtype=tensor.dtype),
            torch.tensor(1.0, dtype=tensor.dtype),
        )
    else:
        min_val, max_val = tensor.aminmax()
    amax = min_val.abs().max(max_val.abs())
    scale = finfo.max / amax.clamp(min=1e-12)
    # Scale and clamp the tensor to bring it into
    # the representable range of the float8 data type
    # (as the default cast is unsaturated).
    qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
    # Return both the float8 data and the inverse scale (as float),
    # as both are required as inputs to torch._scaled_mm.
    qweight = qweight.to(torch.float8_e4m3fn)
    scale = scale.float().reciprocal()
    return qweight, scale


def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
    cuda_compute_capability = torch.cuda.get_device_capability()
    if cuda_compute_capability >= (9, 0):
        # Native FP8 matmul on compute capability 9.0 (Hopper) and newer.
        output, _ = torch._scaled_mm(
            A,
            B.t(),
            out_dtype=out_dtype,
            scale_a=A_scale,
            scale_b=B_scale,
            bias=bias,
        )
    else:
        # Fallback for older GPUs: dequantize and run a regular GEMM.
        output = torch.nn.functional.linear(
            A.to(out_dtype) * A_scale,
            B.to(out_dtype) * B_scale.to(out_dtype),
            bias=bias,
        )
    return output


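# The three wrapper modules below implement the quantization flow:
#   * FP8StaticLinearQuantizer: calibration shim that quantizes activations on the
#     fly and keeps the largest per-tensor activation scale it observes.
#   * FP8StaticLinear: final form for the "static" scheme, with frozen weight and
#     activation scales.
#   * FP8DynamicLinear: FP8 weights with the activation scale recomputed on every
#     forward pass (used right after weight quantization and for the "dynamic" scheme).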
class FP8StaticLinearQuantizer(torch.nn.Module):
    def __init__(self, qweight, weight_scale):
        super().__init__()
        self.weight = torch.nn.Parameter(qweight, requires_grad=False)
        self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
        self.act_scale = None

    def forward(self, x):
        # Dynamically quantize
        qinput, x_act_scale = per_tensor_quantize(x)

        # Update scale if needed.
        if self.act_scale is None:
            self.act_scale = torch.nn.Parameter(x_act_scale)
        elif x_act_scale > self.act_scale:
            self.act_scale = torch.nn.Parameter(x_act_scale)

        # Pass quantized to next layer so it has realistic data.
        output = fp8_gemm(
            A=qinput,
            A_scale=self.act_scale,
            B=self.weight,
            B_scale=self.weight_scale,
            bias=None,
            out_dtype=x.dtype,
        )
        return output


class FP8StaticLinear(torch.nn.Module):
    def __init__(self, qweight, weight_scale, act_scale=0.0):
        super().__init__()
        self.weight = torch.nn.Parameter(qweight, requires_grad=False)
        self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
        self.act_scale = torch.nn.Parameter(act_scale, requires_grad=False)

    def per_tensor_quantize(
        self, tensor: torch.Tensor, inv_scale: float
    ) -> torch.Tensor:
        # Scale and clamp the tensor to bring it into
        # the representable range of the float8 data type
        # (as the default cast is unsaturated).
        finfo = torch.finfo(torch.float8_e4m3fn)
        qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
        return qweight.to(torch.float8_e4m3fn)

    def forward(self, x):
        qinput = self.per_tensor_quantize(x, inv_scale=self.act_scale)
        output = fp8_gemm(
            A=qinput,
            A_scale=self.act_scale,
            B=self.weight,
            B_scale=self.weight_scale,
            bias=None,
            out_dtype=x.dtype,
        )
        return output


class FP8DynamicLinear(torch.nn.Module):
    def __init__(self, qweight, scale):
        super().__init__()
        self.weight = torch.nn.Parameter(qweight, requires_grad=False)
        self.weight_scale = torch.nn.Parameter(scale, requires_grad=False)

    def forward(self, x):
        qinput, x_scale = per_tensor_quantize(x)
        output = fp8_gemm(
            A=qinput,
            A_scale=x_scale,
            B=self.weight,
            B_scale=self.weight_scale,
            bias=None,
            out_dtype=x.dtype,
        )
        return output


def replace_module(model, name, new_module):
    if "." in name:
        parent_name = name.rsplit(".", 1)[0]
        child_name = name[len(parent_name) + 1 :]
        parent = model.model.get_submodule(parent_name)
    else:
        parent_name = ""
        parent = model.model
        child_name = name
    setattr(parent, child_name, new_module)


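# Quantization passes. Modules whose name contains "gate" (the MoE router in
# Mixtral) are skipped, leaving expert routing in its original precision.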
def quantize_weights(model):
    for name, linear in model.model.named_modules():
        if "gate" in name or not isinstance(linear, torch.nn.Linear):
            continue
        quant_weight, quant_scale = per_tensor_quantize(linear.weight)
        quant_linear = FP8DynamicLinear(quant_weight, quant_scale)
        replace_module(model, name, quant_linear)
        del linear
    cleanup_memory()


def quantize_activations(model, calibration_tokens):
    # Replace layers with quantizer.
    for name, dynamic_quant_linear in model.model.named_modules():
        if "gate" in name or not isinstance(dynamic_quant_linear, FP8DynamicLinear):
            continue
        quantizer = FP8StaticLinearQuantizer(
            dynamic_quant_linear.weight, dynamic_quant_linear.weight_scale
        )
        replace_module(model, name, quantizer)
        del dynamic_quant_linear
    cleanup_memory()

    # Calibration.
    for row_idx in range(calibration_tokens.shape[0]):
        _ = model(calibration_tokens[row_idx].reshape(1, -1))

    # Replace quantizer with StaticLayer.
    for name, quantizer in model.model.named_modules():
        if "gate" in name or not isinstance(quantizer, FP8StaticLinearQuantizer):
            continue
        static_proj = FP8StaticLinear(
            quantizer.weight, quantizer.weight_scale, quantizer.act_scale
        )
        replace_module(model, name, static_proj)
        del quantizer
    cleanup_memory()


def save_quantized_model(model, activation_scheme, save_dir):
    print(f"Saving the model to {save_dir}")
    static_q_dict = {
        "quantization_config": {
            "quant_method": "fp8",
            "activation_scheme": activation_scheme,
        }
    }
    model.config.update(static_q_dict)
    model.save_pretrained(save_dir)
    # Relies on the module-level `tokenizer` created in the __main__ block below.
    tokenizer.save_pretrained(save_dir)


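# End-to-end flow: load the model, sanity-check generation, quantize the weights,
# then (for the default "static" scheme) calibrate and freeze activation scales
# before saving the checkpoint with its quantization_config.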
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-id", type=str)
    parser.add_argument("--save-dir", type=str)
    parser.add_argument(
        "--activation-scheme", type=str, default="static", choices=["static", "dynamic"]
    )
    parser.add_argument("--num-samples", type=int, default=512)
    parser.add_argument("--max-seq-len", type=int, default=512)
    args = parser.parse_args()

    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
    sample_input_tokens = tokenizer.apply_chat_template(
        [{"role": "user", "content": "What is your name?"}],
        add_generation_prompt=True,
        return_tensors="pt",
    ).to("cuda")

    ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
    ds = ds.shuffle(seed=42).select(range(args.num_samples))
    ds = ds.map(
        lambda batch: {
            "text": tokenizer.apply_chat_template(batch["messages"], tokenize=False)
        }
    )
    tokenizer.pad_token_id = tokenizer.eos_token_id
    calibration_tokens = tokenizer(
        ds["text"],
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=args.max_seq_len,
        add_special_tokens=False,
    ).input_ids.to("cuda")
    print("Calibration tokens:", calibration_tokens.shape)

    # Load and test the model.
    model = AutoModelForCausalLM.from_pretrained(
        args.model_id, torch_dtype="auto", device_map="auto"
    )
    print(model)
    output = model.generate(input_ids=sample_input_tokens, max_new_tokens=20)
    print("ORIGINAL:\n", tokenizer.decode(output[0]), "\n\n")

    # Quantize weights.
    quantize_weights(model)
    print(model)
    output = model.generate(input_ids=sample_input_tokens, max_new_tokens=20)
    print("WEIGHT QUANT:\n", tokenizer.decode(output[0]), "\n\n")

    if args.activation_scheme == "dynamic":
        print("Exporting model with static weights and dynamic activations")
        save_quantized_model(model, args.activation_scheme, args.save_dir)
    else:
        assert args.activation_scheme == "static"
        # Quantize activations.
        quantize_activations(model, calibration_tokens=calibration_tokens)
        output = model.generate(input_ids=sample_input_tokens, max_new_tokens=20)
        print("ACT QUANT:\n", tokenizer.decode(output[0]), "\n\n")

        print("Exporting model with static weights and static activations")
        save_quantized_model(model, args.activation_scheme, args.save_dir)
```
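
For intuition about the scale convention used above, here is a minimal standalone sketch (not part of the quantization script; tensor names are illustrative) that reproduces the per-tensor FP8 round trip from `per_tensor_quantize` and dequantizes with the returned inverse scale:

```python
import torch

# Per-tensor FP8 (E4M3) quantize/dequantize round trip, mirroring per_tensor_quantize.
finfo = torch.finfo(torch.float8_e4m3fn)
w = torch.randn(16, 32, dtype=torch.float16)

amax = w.abs().max()
scale = finfo.max / amax.clamp(min=1e-12)            # quantization scale = fp8_max / absmax
qw = (w * scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)
inv_scale = scale.float().reciprocal()               # what the script stores as weight_scale

w_hat = qw.to(torch.float32) * inv_scale             # dequantized approximation
print("max abs error:", (w.float() - w_hat).abs().max().item())
```

The same inverse scale is what `fp8_gemm` passes to `torch._scaled_mm` as `scale_a`/`scale_b` on compute capability 9.0+ GPUs.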