"""
Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
SPDX-License-Identifier: MIT
"""

import argparse
import copy
from datetime import datetime
import json
import os
import time

from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint
from brevitas.quant.experimental.float_quant_fnuz import Fp8e4m3FNUZActPerTensorFloat
from brevitas.quant.scaled_int import Int8ActPerTensorFloat
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloat
from brevitas_examples.common.generative.nn import LoRACompatibleQuantConv2d, LoRACompatibleQuantLinear
from diffusers import DiffusionPipeline
from diffusers.models.attention_processor import Attention
from diffusers.models.attention_processor import AttnProcessor
import pandas as pd
import torch
from torch import nn
from tqdm import tqdm

import brevitas.config as config
from brevitas.graph.base import ModuleToModuleByClass
from brevitas.graph.calibrate import bias_correction_mode
from brevitas.graph.calibrate import calibration_mode
from brevitas.graph.equalize import activation_equalization_mode
from brevitas.graph.quantize import layerwise_quantize
from brevitas.inject.enum import StatsOp
import brevitas.nn as qnn
from brevitas.nn.equalized_layer import EqualizedModule
from brevitas.utils.torch_utils import KwargsForwardHook

from brevitas_examples.common.parse_utils import add_bool_arg
from brevitas_examples.stable_diffusion.sd_quant.export import export_quant_params
from brevitas_examples.stable_diffusion.sd_quant.nn import QuantAttention

TEST_SEED = 123456
torch.manual_seed(TEST_SEED)

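# Asymmetric (shifted) uint8 weight quantizer: per-channel scales and
# zero-points are learned parameters initialized from weight statistics, and
# the zero-point itself is quantized.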
class WeightQuant(ShiftedUint8WeightPerChannelFloat):
    narrow_range = False
    scaling_min_val = 1e-4
    quantize_zero_point = True
    scaling_impl_type = 'parameter_from_stats'
    zero_point_impl = ParameterFromStatsFromParameterZeroPoint

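# Per-tensor activation quantizers, both scaled from absolute-max statistics:
# int8 for layer inputs, FP8 e4m3 (FNUZ) for selected attention outputs.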
class InputQuant(Int8ActPerTensorFloat):
    scaling_stats_op = StatsOp.MAX


class OutputQuant(Fp8e4m3FNUZActPerTensorFloat):
    scaling_stats_op = StatsOp.MAX


NEGATIVE_PROMPTS = ["normal quality, low quality, worst quality, low res, blurry, nsfw, nude"]

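# Load calibration prompts from a tab-separated file with a 'caption' column.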
def load_calib_prompts(calib_data_path, sep="\t"):
    df = pd.read_csv(calib_data_path, sep=sep)
    return df["caption"].tolist()

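# Run the pipeline over the given prompts without tracking gradients.
# output_type='latent' skips the VAE decode, which is not needed to exercise
# the UNet during calibration.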
def run_val_inference(pipe, prompts, guidance_scale, total_steps, test_latents=None):
    with torch.no_grad():
        for prompt in tqdm(prompts):
            pipe(
                prompt,
                negative_prompt=NEGATIVE_PROMPTS[0],
                latents=test_latents,
                output_type='latent',
                guidance_scale=guidance_scale,
                num_inference_steps=total_steps)


def main(args):
    dtype = getattr(torch, args.dtype)

    calibration_prompts = load_calib_prompts(args.calibration_prompt_path)
    latents = torch.load(args.path_to_latents).to(torch.float16)

    ts = datetime.fromtimestamp(time.time())
    str_ts = ts.strftime("%Y%m%d_%H%M%S")
    output_dir = os.path.join(args.output_path, f'{str_ts}')
    os.mkdir(output_dir)

    with open(os.path.join(output_dir, 'args.json'), 'w') as fp:
        json.dump(vars(args), fp)

    print(f"Loading model from {args.model}...")
    pipe = DiffusionPipeline.from_pretrained(args.model, torch_dtype=dtype)
    print(f"Model loaded from {args.model}.")

    print(f"Moving model to {args.device}...")
    pipe = pipe.to(args.device)

    if args.attention_slicing:
        pipe.enable_attention_slicing()

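    # Keep the time-embedding projections at full precision: collect their
    # names so they can be excluded from quantization below.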
    blacklist = []
    for name, _ in pipe.unet.named_modules():
        if 'time_emb' in name:
            blacklist.append(name.split('.')[-1])
    print(f"Blacklisted layers: {blacklist}")

    for m in pipe.unet.modules():
        if hasattr(m, 'lora_layer') and m.lora_layer is not None:
            raise RuntimeError("LoRA layers must be fused into the base weights before quantization.")

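    # SmoothQuant-style activation equalization: migrate quantization
    # difficulty from activations to weights with strength alpha, using the
    # statistics gathered while running the calibration prompts in this context.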
    pipe.set_progress_bar_config(disable=True)
    with activation_equalization_mode(
            pipe.unet,
            alpha=args.act_eq_alpha,
            layerwise=True,
            blacklist_layers=blacklist if args.exclude_blacklist_act_eq else None,
            add_mul_node=True):
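        # The equalization hooks wrap the original modules; re-expose
        # in_features on the wrappers so downstream shape checks keep working.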
        for m in pipe.unet.modules():
            if isinstance(m, KwargsForwardHook) and hasattr(m.module, 'in_features'):
                m.in_features = m.module.in_features
        total_steps = args.calibration_steps
        run_val_inference(
            pipe,
            calibration_prompts,
            total_steps=total_steps,
            test_latents=latents,
            guidance_scale=args.guidance_scale)

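    # Same fix-up for the EqualizedModule wrappers that remain in the graph
    # after equalization has been applied.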
    for m in pipe.unet.modules():
        if isinstance(m, EqualizedModule) and hasattr(m.layer, 'in_features'):
            m.in_features = m.layer.in_features

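    # Shared quantization config: asymmetric uint8 weights and int8 inputs.
    # Linear layers get their own copy so it can be specialized for attention.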
    quant_layer_kwargs = {
        'input_quant': InputQuant,
        'weight_quant': WeightQuant,
        'dtype': dtype,
        'device': args.device,
        'input_dtype': dtype,
        'input_device': args.device}
    quant_linear_kwargs = copy.deepcopy(quant_layer_kwargs)
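    # Optionally quantize the scaled dot-product attention: swap diffusers'
    # Attention for Brevitas' QuantAttention (quantizing the softmax output in
    # FP8) and attach an FP8 output quantizer to the q/k/v projections.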
    if args.quantize_sdp:
        output_quant = OutputQuant
        rewriter = ModuleToModuleByClass(
            Attention,
            QuantAttention,
            softmax_output_quant=output_quant,
            query_dim=lambda module: module.to_q.in_features,
            # Attention stores scale = dim_head ** -0.5; invert it to recover dim_head
            dim_head=lambda module: int(1 / (module.scale ** 2)),
            processor=AttnProcessor(),
            is_equalized=True)
        # The rewrite re-initializes modules, so tolerate missing state_dict keys
        config.IGNORE_MISSING_KEYS = True
        pipe.unet = rewriter.apply(pipe.unet)
        config.IGNORE_MISSING_KEYS = False
        pipe.unet = pipe.unet.to(args.device)
        pipe.unet = pipe.unet.to(dtype)

        what_to_quantize = ['to_q', 'to_k', 'to_v']
        quant_linear_kwargs['output_quant'] = lambda module, name: output_quant if any(
            ending in name for ending in what_to_quantize) else None
        quant_linear_kwargs['output_dtype'] = dtype
        quant_linear_kwargs['output_device'] = args.device

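    # Map float layer types (including diffusers' LoRA-compatible variants,
    # matched by fully-qualified class name) to quantized Brevitas equivalents.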
    layer_map = {
        nn.Linear: (qnn.QuantLinear, quant_linear_kwargs),
        nn.Conv2d: (qnn.QuantConv2d, quant_layer_kwargs),
        'diffusers.models.lora.LoRACompatibleLinear':
            (LoRACompatibleQuantLinear, quant_linear_kwargs),
        'diffusers.models.lora.LoRACompatibleConv': (LoRACompatibleQuantConv2d, quant_layer_kwargs)}

    pipe.unet = layerwise_quantize(
        model=pipe.unet, compute_layer_map=layer_map, name_blacklist=blacklist)
    print("Model quantization applied.")

    pipe.set_progress_bar_config(disable=True)

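    # Calibration: run inference so activation quantizers can collect
    # statistics and initialize their scale factors and zero-points.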
print("Applying activation calibration") |
|
with torch.no_grad(), calibration_mode(pipe.unet): |
|
run_val_inference( |
|
pipe, |
|
calibration_prompts, |
|
total_steps=args.calibration_steps, |
|
test_latents=latents, |
|
guidance_scale=args.guidance_scale) |
|
|
|
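    # Bias correction: compare float and quantized layer outputs on the same
    # inputs and fold the mean difference into the layer biases.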
print("Applying bias correction") |
|
with torch.no_grad(), bias_correction_mode(pipe.unet): |
|
run_val_inference( |
|
pipe, |
|
calibration_prompts, |
|
total_steps=args.calibration_steps, |
|
test_latents=latents, |
|
guidance_scale=args.guidance_scale) |
|
|
|
if args.checkpoint_name is not None: |
|
torch.save(pipe.unet.state_dict(), os.path.join(output_dir, args.checkpoint_name)) |
|
|
|
if args.export_target: |
|
pipe.unet.to('cpu').to(dtype) |
|
export_quant_params(pipe, output_dir) |
|
|
|
|
|
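# Example invocation (illustrative only; the model name and file paths are
# placeholders for your own checkpoint and data):
#   python main.py -m stabilityai/stable-diffusion-xl-base-1.0 \
#       --calibration-prompt-path calibration_prompts.tsv \
#       --path-to-latents latents.pt --quantize-sdp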
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Stable Diffusion quantization')
    parser.add_argument('-m', '--model', type=str, default=None, help='Path or name of the model.')
    parser.add_argument(
        '-d', '--device', type=str, default='cuda:0', help='Target device for quantized model.')
    parser.add_argument(
        '--calibration-prompt-path',
        type=str,
        default=None,
        help='Path to the TSV file of calibration prompts.')
    parser.add_argument(
        '--checkpoint-name',
        type=str,
        default=None,
        help='Name used to store the checkpoint in the output dir. If not provided, no checkpoint is saved.')
    parser.add_argument(
        '--path-to-latents',
        type=str,
        default=None,
        help='Load pre-defined latents. If not provided, they are generated based on an internal seed.')
    parser.add_argument('--guidance-scale', type=float, default=8., help='Guidance scale.')
    parser.add_argument(
        '--calibration-steps',
        type=int,
        default=8,
        help='Number of inference steps used during calibration.')
    add_bool_arg(
        parser,
        'output-path',
        str_true=True,
        default='.',
        help='Path where the output folder is generated.')
    parser.add_argument(
        '--dtype',
        default='float16',
        choices=['float32', 'float16', 'bfloat16'],
        help='Model dtype. Default: float16')
    add_bool_arg(
        parser,
        'attention-slicing',
        default=False,
        help='Enable attention slicing. Default: Disabled')
    add_bool_arg(
        parser,
        'export-target',
        default=True,
        help='Export the quantization parameters after calibration. Default: Enabled')
    parser.add_argument(
        '--act-eq-alpha',
        type=float,
        default=0.9,
        help='Alpha for activation equalization. Default: 0.9')
    add_bool_arg(parser, 'quantize-sdp', default=False, help='Quantize SDP. Default: Disabled')
    add_bool_arg(
        parser,
        'exclude-blacklist-act-eq',
        default=False,
        help='Exclude blacklisted (unquantized) layers from activation equalization. Default: Disabled')
    args = parser.parse_args()
    print("Args: " + str(vars(args)))
    main(args)