Feat (script): Added option to validate on MLPerf validation set & to load a pre-quantized checkpoint.
34b0078
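Example invocation exercising the two new options (the script name, model id and paths below are illustrative, not taken from the commit):

python main.py -m stabilityai/stable-diffusion-xl-base-1.0 --calibration-prompt-path calibration_prompts.tsv --path-to-latents latents.pt --load-checkpoint quant_unet.pt --validation-prompts 500 --path-to-coco /datasets/mlperf_coco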
""" | |
Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. | |
SPDX-License-Identifier: MIT | |
""" | |
import argparse
import copy
from datetime import datetime
import json
import os
import time

from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint
from brevitas.quant.experimental.float_quant_fnuz import Fp8e4m3FNUZActPerTensorFloat
from brevitas.quant.scaled_int import Int8ActPerTensorFloat
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloat
from brevitas_examples.common.generative.nn import LoRACompatibleQuantConv2d, LoRACompatibleQuantLinear
from diffusers import DiffusionPipeline
from diffusers.models.attention_processor import Attention
from diffusers.models.attention_processor import AttnProcessor
import pandas as pd
import torch
from torch import nn
from tqdm import tqdm

import brevitas.nn as qnn
from brevitas.graph.base import ModuleToModuleByClass
from brevitas.graph.calibrate import bias_correction_mode
from brevitas.graph.calibrate import calibration_mode
# Assumption: load_quant_model_mode (used below when --load-checkpoint is set) ships with the
# other calibration context managers in brevitas.graph.calibrate.
from brevitas.graph.calibrate import load_quant_model_mode
from brevitas.graph.equalize import activation_equalization_mode
from brevitas.graph.quantize import layerwise_quantize
from brevitas.inject.enum import StatsOp
from brevitas.nn.equalized_layer import EqualizedModule
from brevitas.utils.torch_utils import KwargsForwardHook
import brevitas.config as config
from brevitas_examples.common.parse_utils import add_bool_arg
from brevitas_examples.stable_diffusion.sd_quant.export import export_quant_params
from brevitas_examples.stable_diffusion.sd_quant.nn import QuantAttention
from brevitas_examples.stable_diffusion.mlperf_evaluation.accuracy import compute_mlperf_fid
TEST_SEED = 123456
torch.manual_seed(TEST_SEED)


class WeightQuant(ShiftedUint8WeightPerChannelFloat):
    # Asymmetric uint8 per-channel weight quantizer with a learned (parameter-from-stats)
    # scale and a quantized zero-point.
    narrow_range = False
    scaling_min_val = 1e-4
    quantize_zero_point = True
    scaling_impl_type = 'parameter_from_stats'
    zero_point_impl = ParameterFromStatsFromParameterZeroPoint


class InputQuant(Int8ActPerTensorFloat):
    # Symmetric int8 per-tensor input quantizer; scale computed from the absolute max.
    scaling_stats_op = StatsOp.MAX


class OutputQuant(Fp8e4m3FNUZActPerTensorFloat):
    # FP8 (e4m3 FNUZ) per-tensor output quantizer; scale computed from the absolute max.
    scaling_stats_op = StatsOp.MAX


NEGATIVE_PROMPTS = ["normal quality, low quality, worst quality, low res, blurry, nsfw, nude"]


def load_calib_prompts(calib_data_path, sep="\t"):
    df = pd.read_csv(calib_data_path, sep=sep)
    lst = df["caption"].tolist()
    return lst
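

# Run the pipeline over a list of prompts without decoding images: activation equalization,
# calibration and bias correction only need UNet forward passes, so generation stops at the
# latent output (output_type='latent').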
def run_val_inference(
        pipe,
        prompts,
        guidance_scale,
        total_steps,
        test_latents=None):
    with torch.no_grad():
        for prompt in tqdm(prompts):
            # We don't want to generate any image, so we return only the latent encoding pre VAE
            pipe(
                prompt,
                negative_prompt=NEGATIVE_PROMPTS[0],
                latents=test_latents,
                output_type='latent',
                guidance_scale=guidance_scale,
                num_inference_steps=total_steps)


def main(args):
    dtype = getattr(torch, args.dtype)

    calibration_prompts = load_calib_prompts(args.calibration_prompt_path)
    assert args.calibration_prompts <= len(calibration_prompts), \
        f"--calibration-prompts must be <= {len(calibration_prompts)}"
    calibration_prompts = calibration_prompts[:args.calibration_prompts]

    latents = torch.load(args.path_to_latents).to(torch.float16)

    # Create a timestamped output dir under args.output_path
    ts = datetime.fromtimestamp(time.time())
    str_ts = ts.strftime("%Y%m%d_%H%M%S")
    output_dir = os.path.join(args.output_path, f'{str_ts}')
    os.mkdir(output_dir)
    print(f"Saving results in {output_dir}")

    # Dump args to json
    with open(os.path.join(output_dir, 'args.json'), 'w') as fp:
        json.dump(vars(args), fp)
    # Load model from float checkpoint
    print(f"Loading model from {args.model}...")
    pipe = DiffusionPipeline.from_pretrained(args.model, torch_dtype=dtype)
    print(f"Model loaded from {args.model}.")

    # Move model to target device
    print(f"Moving model to {args.device}...")
    pipe = pipe.to(args.device)

    # Enable attention slicing
    if args.attention_slicing:
        pipe.enable_attention_slicing()

    # Extract list of layers to avoid quantizing (time embeddings)
    blacklist = []
    for name, _ in pipe.unet.named_modules():
        if 'time_emb' in name:
            blacklist.append(name.split('.')[-1])
    print(f"Blacklisted layers: {blacklist}")

    # Make sure that all LoRA layers are fused first, otherwise raise an error
    for m in pipe.unet.modules():
        if hasattr(m, 'lora_layer') and m.lora_layer is not None:
            raise RuntimeError("LoRA layers should be fused in before calling into quantization.")

    pipe.set_progress_bar_config(disable=True)
    if args.load_checkpoint is not None:
        with load_quant_model_mode(pipe.unet):
            pipe = pipe.to('cpu')
            print(f"Loading checkpoint: {args.load_checkpoint}... ", end="")
            pipe.unet.load_state_dict(torch.load(args.load_checkpoint, map_location='cpu'))
            print("Checkpoint loaded!")
        pipe = pipe.to(args.device)

    if args.load_checkpoint is not None:
        # Don't run full activation equalization if we're loading a pre-quantized checkpoint
        num_ae_prompts = 2
    else:
        num_ae_prompts = len(calibration_prompts)
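
    # Activation equalization (SmoothQuant-style) rebalances activation and weight ranges by
    # inserting scaling (mul) nodes ahead of the layers to be quantized. When a pre-quantized
    # checkpoint is loaded, only a couple of prompts are used: enough to apply the graph
    # rewrites without a full equalization pass.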
    with activation_equalization_mode(
            pipe.unet,
            alpha=args.act_eq_alpha,
            layerwise=True,
            blacklist_layers=blacklist if args.exclude_blacklist_act_eq else None,
            add_mul_node=True):
        # Workaround to expose `in_features` attribute from the Hook Wrapper
        for m in pipe.unet.modules():
            if isinstance(m, KwargsForwardHook) and hasattr(m.module, 'in_features'):
                m.in_features = m.module.in_features
        total_steps = args.calibration_steps
        run_val_inference(
            pipe,
            calibration_prompts[:num_ae_prompts],
            total_steps=total_steps,
            test_latents=latents,
            guidance_scale=args.guidance_scale)

    # Workaround to expose `in_features` attribute from the EqualizedModule Wrapper
    for m in pipe.unet.modules():
        if isinstance(m, EqualizedModule) and hasattr(m.layer, 'in_features'):
            m.in_features = m.layer.in_features
    quant_layer_kwargs = {
        'input_quant': InputQuant,
        'weight_quant': WeightQuant,
        'dtype': dtype,
        'device': args.device,
        'input_dtype': dtype,
        'input_device': args.device}
    quant_linear_kwargs = copy.deepcopy(quant_layer_kwargs)
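
    # Optionally quantize the scaled dot-product attention: diffusers Attention modules are
    # swapped for QuantAttention so the softmax output can be quantized to FP8, and the
    # to_q/to_k/to_v projections additionally receive an FP8 output quantizer.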
    if args.quantize_sdp:
        output_quant = OutputQuant
        rewriter = ModuleToModuleByClass(
            Attention,
            QuantAttention,
            softmax_output_quant=output_quant,
            query_dim=lambda module: module.to_q.in_features,
            dim_head=lambda module: int(1 / (module.scale ** 2)),
            processor=AttnProcessor(),
            is_equalized=True)
        config.IGNORE_MISSING_KEYS = True
        pipe.unet = rewriter.apply(pipe.unet)
        config.IGNORE_MISSING_KEYS = False
        pipe.unet = pipe.unet.to(args.device)
        pipe.unet = pipe.unet.to(dtype)
        what_to_quantize = ['to_q', 'to_k', 'to_v']
        quant_linear_kwargs['output_quant'] = lambda module, name: output_quant if any(
            ending in name for ending in what_to_quantize) else None
        quant_linear_kwargs['output_dtype'] = dtype
        quant_linear_kwargs['output_device'] = args.device
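
    # Map each float layer type (including the LoRA-compatible diffusers variants) to its
    # quantized counterpart; layers whose names are in `blacklist` (time embeddings) stay in
    # floating point.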
    layer_map = {
        nn.Linear: (qnn.QuantLinear, quant_linear_kwargs),
        nn.Conv2d: (qnn.QuantConv2d, quant_layer_kwargs),
        'diffusers.models.lora.LoRACompatibleLinear':
            (LoRACompatibleQuantLinear, quant_linear_kwargs),
        'diffusers.models.lora.LoRACompatibleConv': (LoRACompatibleQuantConv2d, quant_layer_kwargs)}

    pipe.unet = layerwise_quantize(
        model=pipe.unet, compute_layer_map=layer_map, name_blacklist=blacklist)
    print("Model quantization applied.")

    pipe.set_progress_bar_config(disable=True)
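
    # Unless a pre-quantized checkpoint was loaded, run activation calibration to collect
    # statistics for the activation quantizers, then bias correction to compensate for the
    # output shift introduced by quantization.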
    if args.load_checkpoint is None:
        print("Applying activation calibration")
        with torch.no_grad(), calibration_mode(pipe.unet):
            run_val_inference(
                pipe,
                calibration_prompts,
                total_steps=args.calibration_steps,
                test_latents=latents,
                guidance_scale=args.guidance_scale)

        print("Applying bias correction")
        with torch.no_grad(), bias_correction_mode(pipe.unet):
            run_val_inference(
                pipe,
                calibration_prompts,
                total_steps=args.calibration_steps,
                test_latents=latents,
                guidance_scale=args.guidance_scale)

    if args.checkpoint_name is not None:
        torch.save(pipe.unet.state_dict(), os.path.join(output_dir, args.checkpoint_name))
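
    # When --validation-prompts > 0, score the quantized pipeline with the MLPerf Stable
    # Diffusion accuracy flow (FID over the MLPerf-compliant COCO data in --path-to-coco).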
    if args.validation_prompts > 0:
        print("Computing validation accuracy")
        compute_mlperf_fid(
            args.model, args.path_to_coco, pipe, args.validation_prompts, output_dir)

    if args.export_target:
        pipe.unet.to('cpu').to(dtype)
        export_quant_params(pipe, output_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Stable Diffusion quantization')
    parser.add_argument(
        '-m',
        '--model',
        type=str,
        default=None,
        help='Path or name of the model.')
    parser.add_argument(
        '-d', '--device', type=str, default='cuda:0', help='Target device for quantized model.')
    parser.add_argument(
        '--calibration-prompt-path',
        type=str,
        default=None,
        help='Path to calibration prompts (TSV file with a "caption" column).')
    parser.add_argument(
        '--calibration-prompts',
        type=int,
        default=500,
        help='Number of prompts to use for calibration. Default: %(default)s')
    parser.add_argument(
        '--validation-prompts',
        type=int,
        default=0,
        help='Number of prompts to use for validation. Default: %(default)s')
    parser.add_argument(
        '--path-to-coco',
        type=str,
        default=None,
        help=
        'Path to MLPerf-compliant COCO dataset. Required when --validation-prompts > 0. Default: None'
    )
    parser.add_argument(
        '--checkpoint-name',
        type=str,
        default=None,
        help=
        'Name to use to store the checkpoint in the output dir. If not provided, no checkpoint is saved.'
    )
    parser.add_argument(
        '--load-checkpoint',
        type=str,
        default=None,
        help='Path to checkpoint to load. If provided, PTQ techniques are skipped.')
    parser.add_argument(
        '--path-to-latents',
        type=str,
        required=True,
        help='Path to pre-defined latents.')
    parser.add_argument('--guidance-scale', type=float, default=8., help='Guidance scale.')
    parser.add_argument(
        '--calibration-steps',
        type=int,
        default=8,
        help='Denoising steps used during calibration. Default: %(default)s')
    add_bool_arg(
        parser,
        'output-path',
        str_true=True,
        default='.',
        help='Path where the output folder is generated. Default: %(default)s')
    parser.add_argument(
        '--dtype',
        default='float16',
        choices=['float32', 'float16', 'bfloat16'],
        help='Model dtype, choices are float32, float16, bfloat16. Default: %(default)s')
    add_bool_arg(
        parser,
        'attention-slicing',
        default=False,
        help='Enable attention slicing. Default: Disabled')
    add_bool_arg(
        parser,
        'export-target',
        default=True,
        help='Enable the export flow (dump quantization parameters). Default: Enabled')
    parser.add_argument(
        '--act-eq-alpha',
        type=float,
        default=0.9,
        help='Alpha for activation equalization. Default: %(default)s')
    add_bool_arg(parser, 'quantize-sdp', default=False, help='Quantize SDP. Default: Disabled')
    add_bool_arg(
        parser,
        'exclude-blacklist-act-eq',
        default=False,
        help='Exclude unquantized layers from activation equalization. Default: Disabled')
    args = parser.parse_args()
    print("Args: " + str(vars(args)))
    main(args)