File size: 12,394 Bytes
ecec5b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72eb84b
ecec5b7
 
 
 
72eb84b
ecec5b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72eb84b
 
ecec5b7
 
 
 
 
 
 
72eb84b
ecec5b7
 
 
 
 
 
 
7f81513
 
 
 
ecec5b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72eb84b
 
 
 
 
 
 
 
 
 
 
 
 
 
ecec5b7
 
 
 
 
 
 
 
 
 
 
 
 
72eb84b
ecec5b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72eb84b
ecec5b7
 
 
 
 
 
 
 
72eb84b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ecec5b7
 
72eb84b
ecec5b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72eb84b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ecec5b7
 
 
 
 
 
 
 
72eb84b
ecec5b7
 
72eb84b
 
 
 
 
ecec5b7
72eb84b
ecec5b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
"""
Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
SPDX-License-Identifier: MIT
"""

import argparse
import copy
from datetime import datetime
import json
import os
import time

from diffusers import DiffusionPipeline
from diffusers import EulerDiscreteScheduler
from diffusers.models.attention_processor import Attention
from diffusers.models.attention_processor import AttnProcessor
import pandas as pd
import torch
from torch import nn
from tqdm import tqdm

from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint
from brevitas.graph.base import ModuleToModuleByClass
from brevitas.graph.calibrate import bias_correction_mode
from brevitas.graph.calibrate import calibration_mode
from brevitas.graph.calibrate import load_quant_model_mode
from brevitas.graph.equalize import activation_equalization_mode
from brevitas.graph.quantize import layerwise_quantize
from brevitas.inject.enum import StatsOp
from brevitas.nn.equalized_layer import EqualizedModule
import brevitas.nn as qnn
from brevitas.quant.experimental.float_quant_fnuz import Fp8e4m3FNUZActPerTensorFloat
from brevitas.quant.scaled_int import Int8ActPerTensorFloat
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloat
from brevitas.utils.torch_utils import KwargsForwardHook
import brevitas.config as config

from brevitas_examples.common.generative.nn import LoRACompatibleQuantConv2d, LoRACompatibleQuantLinear
from brevitas_examples.common.parse_utils import add_bool_arg
from brevitas_examples.stable_diffusion.mlperf_evaluation.accuracy import compute_mlperf_fid
from brevitas_examples.stable_diffusion.sd_quant.export import export_quant_params
from brevitas_examples.stable_diffusion.sd_quant.nn import QuantAttention

# Fixed seed for torch's global RNG, applied at import time so that any
# torch-side randomness in this script is reproducible across runs.
TEST_SEED = 123456
torch.manual_seed(TEST_SEED)

# Weight quantizer passed as `weight_quant` to every quantized Linear/Conv2d
# built in main(). It only overrides a few attributes of
# ShiftedUint8WeightPerChannelFloat (judging by the name, a per-channel
# unsigned 8-bit quantizer with a zero-point — confirm in the brevitas docs).
class WeightQuant(ShiftedUint8WeightPerChannelFloat):
    narrow_range = False  # disable brevitas' narrow-range option
    scaling_min_val = 1e-4  # lower bound applied to the scale factor
    quantize_zero_point = True
    # Presumably makes the scale a learnable parameter initialised from
    # weight statistics (per the option name) — TODO confirm.
    scaling_impl_type = 'parameter_from_stats'
    zero_point_impl = ParameterFromStatsFromParameterZeroPoint

# Input-activation quantizer passed as `input_quant` to every quantized
# Linear/Conv2d built in main(); derived from Int8ActPerTensorFloat with the
# scale statistic switched to StatsOp.MAX.
class InputQuant(Int8ActPerTensorFloat):
    scaling_stats_op = StatsOp.MAX

# FP8 (e4m3 FNUZ) output quantizer, only used with --quantize-sdp: it
# quantizes the softmax output inside QuantAttention and the outputs of the
# to_q / to_k / to_v projections. Scale statistic switched to StatsOp.MAX.
class OutputQuant(Fp8e4m3FNUZActPerTensorFloat):
    scaling_stats_op = StatsOp.MAX

# Negative prompt passed to every pipeline call in run_val_inference
# (only the first entry is used).
NEGATIVE_PROMPTS = ["normal quality, low quality, worst quality, low res, blurry, nsfw, nude"]

def load_calib_prompts(calib_data_path, sep="\t"):
    """Return the ``caption`` column of a delimited text file as a list.

    Args:
        calib_data_path: path (or anything ``pandas.read_csv`` accepts) of a
            delimited file with a header row containing a ``caption`` column.
        sep: field delimiter; tab by default.

    Returns:
        list of caption values, in file order.
    """
    captions = pd.read_csv(calib_data_path, sep=sep)["caption"]
    return captions.tolist()

def run_val_inference(
        pipe,
        prompts,
        guidance_scale,
        total_steps,
        test_latents=None):
    """Call ``pipe`` once per prompt, stopping at the latent output.

    The pipeline is invoked with ``output_type='latent'``, so no image is
    produced; callers use this for its side effects on ``pipe`` (e.g. under a
    calibration or equalization context manager). Nothing is returned.

    Args:
        pipe: diffusers pipeline to invoke.
        prompts: iterable of prompt strings; progress is shown with tqdm.
        guidance_scale: forwarded to ``pipe`` as ``guidance_scale``.
        total_steps: forwarded to ``pipe`` as ``num_inference_steps``.
        test_latents: optional fixed starting latents, forwarded as ``latents``.
    """
    # Arguments shared by every call; only the prompt varies.
    shared_kwargs = {
        'negative_prompt': NEGATIVE_PROMPTS[0],
        'latents': test_latents,
        'output_type': 'latent',
        'guidance_scale': guidance_scale,
        'num_inference_steps': total_steps,
    }
    with torch.no_grad():
        for text in tqdm(prompts):
            pipe(text, **shared_kwargs)


def main(args):
    """Quantize the UNet of a diffusers pipeline, then optionally validate and export it.

    Sequence of operations:

    1. Load the calibration prompts and the fixed starting latents.
    2. Create a timestamped folder under ``args.output_path`` and dump the CLI
       arguments to ``args.json`` inside it.
    3. Load the float pipeline, switch it to ``EulerDiscreteScheduler`` and
       move it to ``args.device``.
    4. Run layer-wise activation equalization on the UNet (on only two prompts
       when ``--load-checkpoint`` is given).
    5. With ``--quantize-sdp``: replace diffusers ``Attention`` with
       ``QuantAttention`` and quantize the to_q/to_k/to_v outputs.
    6. Replace Linear/Conv2d layers (and their LoRA-compatible variants) with
       quantized layers, skipping blacklisted names.
    7. Either load a quantized checkpoint into the quantized UNet, or run
       activation calibration plus bias correction (saving a checkpoint when
       ``--checkpoint-name`` is set).
    8. Optionally compute MLPerf FID and export the quantization parameters.

    Args:
        args: ``argparse.Namespace`` produced by this script's CLI parser.

    Raises:
        ValueError: ``--calibration-prompts`` exceeds the number of available
            prompts, or ``--validation-prompts > 0`` without ``--path-to-coco``.
        RuntimeError: the UNet still contains un-fused LoRA layers.
    """
    dtype = getattr(torch, args.dtype)

    # Validate argument combinations up front; otherwise a missing COCO path
    # would only surface after the (long) calibration phase.
    if args.validation_prompts > 0 and args.path_to_coco is None:
        raise ValueError("--path-to-coco is required when --validation-prompts > 0")

    calibration_prompts = load_calib_prompts(args.calibration_prompt_path)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if args.calibration_prompts > len(calibration_prompts):
        raise ValueError(f"--calibration-prompts must be <= {len(calibration_prompts)}")
    calibration_prompts = calibration_prompts[:args.calibration_prompts]
    # NOTE: torch.load unpickles its input; only load latents/checkpoints from
    # trusted sources.
    # Cast to the pipeline dtype: user-supplied latents are not re-cast by the
    # pipeline, so a hard-coded float16 would clash with float32/bfloat16 weights.
    latents = torch.load(args.path_to_latents).to(dtype)

    # Timestamped output folder under --output-path. makedirs also creates a
    # missing --output-path; an existing folder with the same name is an error.
    str_ts = datetime.fromtimestamp(time.time()).strftime("%Y%m%d_%H%M%S")
    output_dir = os.path.join(args.output_path, str_ts)
    os.makedirs(output_dir)
    print(f"Saving results in {output_dir}")

    # Dump args to json
    with open(os.path.join(output_dir, 'args.json'), 'w') as fp:
        json.dump(vars(args), fp)

    # Load model from float checkpoint
    print(f"Loading model from {args.model}...")
    variant = 'fp16' if dtype == torch.float16 else None
    pipe = DiffusionPipeline.from_pretrained(
        args.model, torch_dtype=dtype, variant=variant, use_safetensors=True)
    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
    pipe.vae.config.force_upcast = True
    print(f"Model loaded from {args.model}.")

    print(f"Moving model to {args.device}...")
    pipe = pipe.to(args.device)

    if args.attention_slicing:
        pipe.enable_attention_slicing()

    # Leaf names (last dotted component) of all UNet modules whose qualified
    # name contains 'time_emb'. Used as the quantization name blacklist and,
    # optionally, excluded from activation equalization.
    blacklist = [
        name.split('.')[-1] for name, _ in pipe.unet.named_modules() if 'time_emb' in name]
    print(f"Blacklisted layers: {blacklist}")

    # Quantization requires every LoRA layer to be fused beforehand.
    for m in pipe.unet.modules():
        if hasattr(m, 'lora_layer') and m.lora_layer is not None:
            raise RuntimeError("LoRA layers should be fused in before calling into quantization.")

    pipe.set_progress_bar_config(disable=True)

    # Don't run full activation equalization if we're loading a quantized
    # checkpoint: presumably two prompts are enough to build the equalization
    # structure whose values the checkpoint then overwrites — TODO confirm.
    num_ae_prompts = 2 if args.load_checkpoint is not None else len(calibration_prompts)
    with activation_equalization_mode(
            pipe.unet,
            alpha=args.act_eq_alpha,
            layerwise=True,
            blacklist_layers=blacklist if args.exclude_blacklist_act_eq else None,
            add_mul_node=True):
        # Workaround to expose `in_features` attribute from the Hook Wrapper
        for m in pipe.unet.modules():
            if isinstance(m, KwargsForwardHook) and hasattr(m.module, 'in_features'):
                m.in_features = m.module.in_features
        run_val_inference(
            pipe,
            calibration_prompts[:num_ae_prompts],
            total_steps=args.calibration_steps,
            test_latents=latents,
            guidance_scale=args.guidance_scale)

    # Workaround to expose `in_features` attribute from the EqualizedModule Wrapper
    for m in pipe.unet.modules():
        if isinstance(m, EqualizedModule) and hasattr(m.layer, 'in_features'):
            m.in_features = m.layer.in_features

    quant_layer_kwargs = {
        'input_quant': InputQuant,
        'weight_quant': WeightQuant,
        'dtype': dtype,
        'device': args.device,
        'input_dtype': dtype,
        'input_device': args.device}
    quant_linear_kwargs = copy.deepcopy(quant_layer_kwargs)
    if args.quantize_sdp:
        output_quant = OutputQuant
        rewriter = ModuleToModuleByClass(
            Attention,
            QuantAttention,
            softmax_output_quant=output_quant,
            query_dim=lambda module: module.to_q.in_features,
            # scale == dim_head ** -0.5, so 1 / scale**2 recovers dim_head.
            # round() instead of int(): float error can yield e.g. 39.999...
            dim_head=lambda module: round(1 / (module.scale ** 2)),
            processor=AttnProcessor(),
            is_equalized=True)
        config.IGNORE_MISSING_KEYS = True
        try:
            pipe.unet = rewriter.apply(pipe.unet)
        finally:
            # Restore the global flag even if the rewrite fails.
            config.IGNORE_MISSING_KEYS = False
        pipe.unet = pipe.unet.to(args.device)
        pipe.unet = pipe.unet.to(dtype)
        what_to_quantize = ['to_q', 'to_k', 'to_v']
        # Substring (not suffix) match on the module name, so wrapped children
        # such as `...to_q.layer` (EqualizedModule.layer) also match.
        quant_linear_kwargs['output_quant'] = lambda module, name: output_quant if any(
            ending in name for ending in what_to_quantize) else None
        quant_linear_kwargs['output_dtype'] = dtype
        quant_linear_kwargs['output_device'] = args.device

    layer_map = {
        nn.Linear: (qnn.QuantLinear, quant_linear_kwargs),
        nn.Conv2d: (qnn.QuantConv2d, quant_layer_kwargs),
        'diffusers.models.lora.LoRACompatibleLinear':
            (LoRACompatibleQuantLinear, quant_linear_kwargs),
        'diffusers.models.lora.LoRACompatibleConv':
            (LoRACompatibleQuantConv2d, quant_layer_kwargs)}

    pipe.unet = layerwise_quantize(
        model=pipe.unet, compute_layer_map=layer_map, name_blacklist=blacklist)
    print("Model quantization applied.")

    if args.load_checkpoint is not None:
        # The checkpoint holds quantizer/equalization state, so it can only be
        # loaded once the UNet has been equalized and quantized above
        # (load_quant_model_mode only affects quantized layers).
        with load_quant_model_mode(pipe.unet):
            pipe = pipe.to('cpu')
            print(f"Loading checkpoint: {args.load_checkpoint}... ", end="")
            pipe.unet.load_state_dict(torch.load(args.load_checkpoint, map_location='cpu'))
            print("Checkpoint loaded!")
        pipe = pipe.to(args.device)
    else:
        print("Applying activation calibration")
        with torch.no_grad(), calibration_mode(pipe.unet):
            run_val_inference(
                pipe,
                calibration_prompts,
                total_steps=args.calibration_steps,
                test_latents=latents,
                guidance_scale=args.guidance_scale)

        print("Applying bias correction")
        with torch.no_grad(), bias_correction_mode(pipe.unet):
            run_val_inference(
                pipe,
                calibration_prompts,
                total_steps=args.calibration_steps,
                test_latents=latents,
                guidance_scale=args.guidance_scale)

        if args.checkpoint_name is not None:
            torch.save(pipe.unet.state_dict(), os.path.join(output_dir, args.checkpoint_name))

    # Perform inference
    if args.validation_prompts > 0:
        print("Computing validation accuracy")
        compute_mlperf_fid(args.model, args.path_to_coco, pipe, args.validation_prompts, output_dir)

    if args.export_target:
        pipe.unet.to('cpu').to(dtype)
        export_quant_params(pipe, output_dir)


if __name__ == "__main__":
    # CLI entry point: parse arguments, echo them, and run main().
    parser = argparse.ArgumentParser(description='Stable Diffusion quantization')
    # main() always loads the model, so it must be provided.
    parser.add_argument(
        '-m',
        '--model',
        type=str,
        required=True,
        help='Path or name of the model.')
    parser.add_argument(
        '-d', '--device', type=str, default='cuda:0', help='Target device for quantized model.')
    # main() always reads the calibration prompts, so the path must be provided.
    parser.add_argument(
        '--calibration-prompt-path', type=str, required=True, help='Path to calibration prompt')
    parser.add_argument(
        '--calibration-prompts',
        type=int,
        default=500,
        help='Number of prompts to use for calibration. Default: %(default)s')
    parser.add_argument(
        '--validation-prompts',
        type=int,
        default=0,
        help='Number of prompt to use for validation. Default: %(default)s')
    parser.add_argument(
        '--path-to-coco',
        type=str,
        default=None,
        help=
        'Path to MLPerf compliant Coco dataset. Required when the --validation-prompts > 0 flag is set. Default: None'
    )
    parser.add_argument(
        '--checkpoint-name',
        type=str,
        default=None,
        help=
        'Name to use to store the checkpoint in the output dir. If not provided, no checkpoint is saved.'
    )
    parser.add_argument(
        '--load-checkpoint',
        type=str,
        default=None,
        help='Path to checkpoint to load. If provided, PTQ techniques are skipped.')
    parser.add_argument(
        '--path-to-latents',
        type=str,
        required=True,
        help=
        'Path to pre-defined latents.')
    parser.add_argument('--guidance-scale', type=float, default=8., help='Guidance scale.')
    # Forwarded as the pipeline's num_inference_steps, which must be an integer.
    parser.add_argument(
        '--calibration-steps', type=int, default=8, help='Steps used during calibration')
    add_bool_arg(
        parser,
        'output-path',
        str_true=True,
        default='.',
        help='Path where to generate output folder.')
    parser.add_argument(
        '--dtype',
        default='float16',
        choices=['float32', 'float16', 'bfloat16'],
        help='Model Dtype, choices are float32, float16, bfloat16. Default: float16')
    add_bool_arg(
        parser,
        'attention-slicing',
        default=False,
        help='Enable attention slicing. Default: Disabled')
    add_bool_arg(
        parser,
        'export-target',
        default=True,
        help='Export flow.')
    parser.add_argument(
        '--act-eq-alpha',
        type=float,
        default=0.9,
        help='Alpha for activation equalization. Default: 0.9')
    add_bool_arg(parser, 'quantize-sdp', default=False, help='Quantize SDP. Default: Disabled')
    add_bool_arg(
        parser,
        'exclude-blacklist-act-eq',
        default=False,
        help='Exclude unquantized layers from activation equalization. Default: Disabled')
    args = parser.parse_args()
    print("Args: " + str(vars(args)))
    main(args)