--- |
|
tags: |
|
- fp8 |
|
- vllm |
|
--- |
|
|
|
# Mixtral-8x7B-Instruct-v0.1-FP8 |
|
|
|
## Model Overview |
|
Mixtral-8x7B-Instruct-v0.1 with weights and activations quantized to FP8, ready for inference with vLLM >= 0.5.0.
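
This checkpoint can be loaded directly with vLLM. Below is a minimal sketch; the repository id is assumed from this card's title, so substitute the actual hub path or a local directory if yours differs:

```python
from vllm import LLM, SamplingParams

# Assumed repository id (adjust to the actual path of the quantized checkpoint).
model_id = "neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8"

# The fp8 quantization_config written by the script below is detected automatically.
# For multi-GPU serving, pass tensor_parallel_size=<number of GPUs>.
llm = LLM(model=model_id)

sampling_params = SamplingParams(temperature=0.7, max_tokens=128)
outputs = llm.generate(["What is your name?"], sampling_params)
print(outputs[0].outputs[0].text)
```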
|
|
|
## Usage and Creation |
|
Produced using [AutoFP8 with calibration samples from ultrachat](https://github.com/neuralmagic/AutoFP8/blob/147fa4d9e1a90ef8a93f96fc7d9c33056ddc017a/example_dataset.py). |
|
|
|
The command and quantization script used to create it are shown below.
|
|
|
Command: |
|
```bash |
|
python quantize.py --model-id mistralai/Mixtral-8x7B-Instruct-v0.1 --save-dir Mixtral-8x7B-Instruct-v0.1-FP8 --num-samples 512 |
|
``` |
|
|
|
Script: |
|
```python |
|
import argparse |
|
import gc |
|
import re |
|
from typing import Tuple |
|
|
|
import torch |
|
import torch.nn.functional as F
|
import transformers |
|
from datasets import load_dataset |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
|
|
# HACK: override the dtype_byte_size function in transformers to support float8 types |
|
# Fix is posted upstream https://github.com/huggingface/transformers/pull/30488 |
|
def new_dtype_byte_size(dtype): |
|
if dtype == torch.bool: |
|
return 1 / 8 |
|
bit_search = re.search(r"[^\d](\d+)_?", str(dtype)) |
|
if bit_search is None: |
|
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.") |
|
bit_size = int(bit_search.groups()[0]) |
|
return bit_size // 8 |
|
|
|
|
|
transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size |
|
|
|
|
|
def cleanup_memory(): |
|
gc.collect() |
|
torch.cuda.empty_cache() |
|
|
|
|
|
def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]: |
|
"""Quantize a tensor using per-tensor static scaling factor. |
|
|
|
Args: |
|
tensor: The input tensor. |
|
""" |
|
finfo = torch.finfo(torch.float8_e4m3fn) |
|
# Calculate the scale as dtype max divided by absmax. |
|
# Since .abs() creates a new tensor, we use aminmax to get |
|
# the min and max first and then calculate the absmax. |
|
if tensor.numel() == 0: |
|
# Deal with empty tensors (triggered by empty MoE experts) |
|
min_val, max_val = ( |
|
torch.tensor(0.0, dtype=tensor.dtype), |
|
torch.tensor(1.0, dtype=tensor.dtype), |
|
) |
|
else: |
|
min_val, max_val = tensor.aminmax() |
|
amax = min_val.abs().max(max_val.abs()) |
|
scale = finfo.max / amax.clamp(min=1e-12) |
|
# scale and clamp the tensor to bring it to |
|
    # the representable range of the float8 data type
|
# (as default cast is unsaturated) |
|
qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max) |
|
# Return both float8 data and the inverse scale (as float), |
|
    # since both are required as inputs to torch._scaled_mm.
|
qweight = qweight.to(torch.float8_e4m3fn) |
|
scale = scale.float().reciprocal() |
|
return qweight, scale |
|
|
|
|
|
def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype): |
|
cuda_compute_capability = torch.cuda.get_device_capability() |
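
    # torch._scaled_mm is only used on compute capability >= 9.0 (Hopper);
    # older GPUs fall back to an emulated matmul on dequantized tensors.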
|
if cuda_compute_capability >= (9, 0): |
|
output, _ = torch._scaled_mm( |
|
A, |
|
B.t(), |
|
out_dtype=out_dtype, |
|
scale_a=A_scale, |
|
scale_b=B_scale, |
|
bias=bias, |
|
) |
|
else: |
|
output = torch.nn.functional.linear( |
|
A.to(out_dtype) * A_scale, |
|
B.to(out_dtype) * B_scale.to(out_dtype), |
|
bias=bias, |
|
) |
|
return output |
|
|
|
|
|
class FP8StaticLinearQuantizer(torch.nn.Module): |
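
    # Calibration-time stand-in for a quantized linear layer: activations are quantized
    # on the fly and the largest observed activation scale is recorded for later static use.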
|
def __init__(self, qweight, weight_scale): |
|
super().__init__() |
|
self.weight = torch.nn.Parameter(qweight, requires_grad=False) |
|
self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) |
|
self.act_scale = None |
|
|
|
def forward(self, x): |
|
# Dynamically quantize |
|
qinput, x_act_scale = per_tensor_quantize(x) |
|
|
|
# Update scale if needed. |
|
if self.act_scale is None: |
|
self.act_scale = torch.nn.Parameter(x_act_scale) |
|
elif x_act_scale > self.act_scale: |
|
self.act_scale = torch.nn.Parameter(x_act_scale) |
|
|
|
        # Run the quantized GEMM so downstream layers see realistic quantized activations during calibration.
|
output = fp8_gemm( |
|
A=qinput, |
|
A_scale=self.act_scale, |
|
B=self.weight, |
|
B_scale=self.weight_scale, |
|
bias=None, |
|
out_dtype=x.dtype, |
|
) |
|
return output |
|
|
|
|
|
class FP8StaticLinear(torch.nn.Module): |
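
    # Inference-time linear layer with FP8 weights and a fixed (static) activation scale taken from calibration.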
|
def __init__(self, qweight, weight_scale, act_scale=0.0): |
|
super().__init__() |
|
self.weight = torch.nn.Parameter(qweight, requires_grad=False) |
|
self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) |
|
self.act_scale = torch.nn.Parameter(act_scale, requires_grad=False) |
|
|
|
def per_tensor_quantize( |
|
self, tensor: torch.Tensor, inv_scale: float |
|
) -> torch.Tensor: |
|
# Scale and clamp the tensor to bring it to |
|
        # the representable range of the float8 data type
|
# (as default cast is unsaturated) |
|
finfo = torch.finfo(torch.float8_e4m3fn) |
|
qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max) |
|
return qweight.to(torch.float8_e4m3fn) |
|
|
|
def forward(self, x): |
|
qinput = self.per_tensor_quantize(x, inv_scale=self.act_scale) |
|
output = fp8_gemm( |
|
A=qinput, |
|
A_scale=self.act_scale, |
|
B=self.weight, |
|
B_scale=self.weight_scale, |
|
bias=None, |
|
out_dtype=x.dtype, |
|
) |
|
return output |
|
|
|
|
|
class FP8DynamicLinear(torch.nn.Module): |
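
    # Linear layer with FP8 weights; activations are quantized dynamically on each forward pass, so no calibration is needed.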
|
def __init__(self, qweight, scale): |
|
super().__init__() |
|
self.weight = torch.nn.Parameter(qweight, requires_grad=False) |
|
self.weight_scale = torch.nn.Parameter(scale, requires_grad=False) |
|
|
|
def forward(self, x): |
|
qinput, x_scale = per_tensor_quantize(x) |
|
output = fp8_gemm( |
|
A=qinput, |
|
A_scale=x_scale, |
|
B=self.weight, |
|
B_scale=self.weight_scale, |
|
bias=None, |
|
out_dtype=x.dtype, |
|
) |
|
return output |
|
|
|
|
|
def replace_module(model, name, new_module): |
|
if "." in name: |
|
parent_name = name.rsplit(".", 1)[0] |
|
child_name = name[len(parent_name) + 1 :] |
|
parent = model.model.get_submodule(parent_name) |
|
else: |
|
parent_name = "" |
|
parent = model.model |
|
child_name = name |
|
setattr(parent, child_name, new_module) |
|
|
|
|
|
def quantize_weights(model): |
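
    # Swap every nn.Linear (except the MoE router "gate" layers) for an FP8DynamicLinear
    # holding per-tensor FP8 weights and their scales.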
|
for name, linear in model.model.named_modules(): |
|
if "gate" in name or not isinstance(linear, torch.nn.Linear): |
|
continue |
|
quant_weight, quant_scale = per_tensor_quantize(linear.weight) |
|
quant_linear = FP8DynamicLinear(quant_weight, quant_scale) |
|
replace_module(model, name, quant_linear) |
|
del linear |
|
cleanup_memory() |
|
|
|
|
|
def quantize_activations(model, calibration_tokens): |
|
# Replace layers with quantizer. |
|
for name, dynamic_quant_linear in model.model.named_modules(): |
|
if "gate" in name or not isinstance(dynamic_quant_linear, FP8DynamicLinear): |
|
continue |
|
quantizer = FP8StaticLinearQuantizer( |
|
dynamic_quant_linear.weight, dynamic_quant_linear.weight_scale |
|
) |
|
replace_module(model, name, quantizer) |
|
del dynamic_quant_linear |
|
cleanup_memory() |
|
|
|
    # Calibration: run the calibration tokens through the model so each quantizer observes its maximum activation scale.
|
for row_idx in range(calibration_tokens.shape[0]): |
|
_ = model(calibration_tokens[row_idx].reshape(1, -1)) |
|
|
|
    # Replace each quantizer with an FP8StaticLinear that freezes the calibrated activation scale.
|
for name, quantizer in model.model.named_modules(): |
|
if "gate" in name or not isinstance(quantizer, FP8StaticLinearQuantizer): |
|
continue |
|
static_proj = FP8StaticLinear( |
|
quantizer.weight, quantizer.weight_scale, quantizer.act_scale |
|
) |
|
replace_module(model, name, static_proj) |
|
del quantizer |
|
cleanup_memory() |
|
|
|
|
|
def save_quantized_model(model, activation_scheme, save_dir): |
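
    # Record the fp8 quantization scheme in the HF config so inference engines such as
    # vLLM can detect the checkpoint format, then save the model and tokenizer.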
|
print(f"Saving the model to {save_dir}") |
|
static_q_dict = { |
|
"quantization_config": { |
|
"quant_method": "fp8", |
|
"activation_scheme": activation_scheme, |
|
} |
|
} |
|
model.config.update(static_q_dict) |
|
model.save_pretrained(save_dir) |
|
tokenizer.save_pretrained(save_dir) |
|
|
|
|
|
if __name__ == "__main__": |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument("--model-id", type=str) |
|
parser.add_argument("--save-dir", type=str) |
|
parser.add_argument( |
|
"--activation-scheme", type=str, default="static", choices=["static", "dynamic"] |
|
) |
|
parser.add_argument("--num-samples", type=int, default=512) |
|
parser.add_argument("--max-seq-len", type=int, default=512) |
|
args = parser.parse_args() |
|
|
|
tokenizer = AutoTokenizer.from_pretrained(args.model_id) |
|
sample_input_tokens = tokenizer.apply_chat_template( |
|
[{"role": "user", "content": "What is your name?"}], |
|
add_generation_prompt=True, |
|
return_tensors="pt", |
|
).to("cuda") |
|
|
|
ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft") |
|
ds = ds.shuffle(seed=42).select(range(args.num_samples)) |
|
ds = ds.map( |
|
lambda batch: { |
|
"text": tokenizer.apply_chat_template(batch["messages"], tokenize=False) |
|
} |
|
) |
|
tokenizer.pad_token_id = tokenizer.eos_token_id |
|
calibration_tokens = tokenizer( |
|
ds["text"], |
|
return_tensors="pt", |
|
truncation=True, |
|
padding="max_length", |
|
max_length=args.max_seq_len, |
|
add_special_tokens=False, |
|
).input_ids.to("cuda") |
|
print("Calibration tokens:", calibration_tokens.shape) |
|
|
|
# Load and test the model |
|
model = AutoModelForCausalLM.from_pretrained( |
|
args.model_id, torch_dtype="auto", device_map="auto" |
|
) |
|
print(model) |
|
output = model.generate(input_ids=sample_input_tokens, max_new_tokens=20) |
|
print("ORIGINAL:\n", tokenizer.decode(output[0]), "\n\n") |
|
|
|
# Quantize weights. |
|
quantize_weights(model) |
|
print(model) |
|
output = model.generate(input_ids=sample_input_tokens, max_new_tokens=20) |
|
print("WEIGHT QUANT:\n", tokenizer.decode(output[0]), "\n\n") |
|
|
|
    if args.activation_scheme == "dynamic":
|
print("Exporting model with static weights and dynamic activations") |
|
save_quantized_model(model, args.activation_scheme, args.save_dir) |
|
else: |
|
        assert args.activation_scheme == "static"
|
# Quantize activations. |
|
quantize_activations(model, calibration_tokens=calibration_tokens) |
|
output = model.generate(input_ids=sample_input_tokens, max_new_tokens=20) |
|
print("ACT QUANT:\n", tokenizer.decode(output[0]), "\n\n") |
|
|
|
print("Exporting model with static weights and static activations") |
|
save_quantized_model(model, args.activation_scheme, args.save_dir) |
|
``` |
|
|
|
## Evaluation |
|
|
|
### Open LLM Leaderboard evaluation scores |
|
| | Mixtral-8x7B-Instruct-v0.1 | Mixtral-8x7B-Instruct-v0.1-FP8<br>(this model) | |
|
| :------------------: | :----------------------: | :------------------------------------------------: | |
|
| arc-c<br>25-shot | 71.50 | 70.05 | |
|
| hellaswag<br>10-shot | 87.53 | 86.30 | |
|
| mmlu<br>5-shot | 70.33 | 68.81 | |
|
| truthfulqa<br>0-shot | 64.79 | 63.69 | |
|
| winogrande<br>5-shot | 82.40 | 81.69 | |
|
| gsm8k<br>5-shot | 64.36 | 59.82 | |
|
| **Average<br>Accuracy** | **73.48** | **71.72** | |
|
| **Recovery** | **100%** | **97.60%** | |
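
Recovery is the quantized model's average accuracy relative to the unquantized baseline: 71.72 / 73.48 ≈ 97.6%.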
|
|
|
|
|
|
|
|
|
|
|
|