nickfraser commited on
Commit
34b0078
1 Parent(s): 62e0e7b

Feat (script): Added option to validate on MLPerf validation set & to load a pre-quantized checkpoint.

Browse files
Files changed (1) hide show
  1. quant_sdxl/quant_sdxl.py +62 -24
quant_sdxl/quant_sdxl.py CHANGED
@@ -32,11 +32,12 @@ from brevitas.graph.quantize import layerwise_quantize
32
  from brevitas.inject.enum import StatsOp
33
  from brevitas.nn.equalized_layer import EqualizedModule
34
  from brevitas.utils.torch_utils import KwargsForwardHook
 
35
 
36
  from brevitas_examples.common.parse_utils import add_bool_arg
37
  from brevitas_examples.stable_diffusion.sd_quant.export import export_quant_params
38
  from brevitas_examples.stable_diffusion.sd_quant.nn import QuantAttention
39
- import brevitas.config as config
40
 
41
  TEST_SEED = 123456
42
  torch.manual_seed(TEST_SEED)
@@ -125,6 +126,20 @@ def main(args):
125
  raise RuntimeError("LoRA layers should be fused in before calling into quantization.")
126
 
127
  pipe.set_progress_bar_config(disable=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  with activation_equalization_mode(
129
  pipe.unet,
130
  alpha=args.act_eq_alpha,
@@ -138,7 +153,7 @@ def main(args):
138
  total_steps = args.calibration_steps
139
  run_val_inference(
140
  pipe,
141
- calibration_prompts,
142
  total_steps=total_steps,
143
  test_latents=latents,
144
  guidance_scale=args.guidance_scale)
@@ -186,26 +201,32 @@ def main(args):
186
 
187
  pipe.set_progress_bar_config(disable=True)
188
 
189
- print("Applying activation calibration")
190
- with torch.no_grad(), calibration_mode(pipe.unet):
191
- run_val_inference(
192
- pipe,
193
- calibration_prompts,
194
- total_steps=args.calibration_steps,
195
- test_latents=latents,
196
- guidance_scale=args.guidance_scale)
197
-
198
- print("Applying bias correction")
199
- with torch.no_grad(), bias_correction_mode(pipe.unet):
200
- run_val_inference(
201
- pipe,
202
- calibration_prompts,
203
- total_steps=args.calibration_steps,
204
- test_latents=latents,
205
- guidance_scale=args.guidance_scale)
206
-
207
- if args.checkpoint_name is not None:
208
- torch.save(pipe.unet.state_dict(), os.path.join(output_dir, args.checkpoint_name))
 
 
 
 
 
 
209
 
210
  if args.export_target:
211
  pipe.unet.to('cpu').to(dtype)
@@ -229,6 +250,18 @@ if __name__ == "__main__":
229
  type=int,
230
  default=500,
231
  help='Number of prompts to use for calibration. Default: %(default)s')
 
 
 
 
 
 
 
 
 
 
 
 
232
  parser.add_argument(
233
  '--checkpoint-name',
234
  type=str,
@@ -237,11 +270,16 @@ if __name__ == "__main__":
237
  'Name to use to store the checkpoint in the output dir. If not provided, no checkpoint is saved.'
238
  )
239
  parser.add_argument(
240
- '--path-to-latents',
241
  type=str,
242
  default=None,
 
 
 
 
 
243
  help=
244
- 'Load pre-defined latents. If not provided, they are generated based on an internal seed.')
245
  parser.add_argument('--guidance-scale', type=float, default=8., help='Guidance scale.')
246
  parser.add_argument(
247
  '--calibration-steps', type=float, default=8, help='Steps used during calibration')
 
32
  from brevitas.inject.enum import StatsOp
33
  from brevitas.nn.equalized_layer import EqualizedModule
34
  from brevitas.utils.torch_utils import KwargsForwardHook
35
+ import brevitas.config as config
36
 
37
  from brevitas_examples.common.parse_utils import add_bool_arg
38
  from brevitas_examples.stable_diffusion.sd_quant.export import export_quant_params
39
  from brevitas_examples.stable_diffusion.sd_quant.nn import QuantAttention
40
+ from brevitas_examples.stable_diffusion.mlperf_evaluation.accuracy import compute_mlperf_fid
41
 
42
  TEST_SEED = 123456
43
  torch.manual_seed(TEST_SEED)
 
126
  raise RuntimeError("LoRA layers should be fused in before calling into quantization.")
127
 
128
  pipe.set_progress_bar_config(disable=True)
129
+
130
+ if args.load_checkpoint is not None:
131
+ with load_quant_model_mode(pipe.unet):
132
+ pipe = pipe.to('cpu')
133
+ print(f"Loading checkpoint: {args.load_checkpoint}... ", end="")
134
+ pipe.unet.load_state_dict(torch.load(args.load_checkpoint, map_location='cpu'))
135
+ print(f"Checkpoint loaded!")
136
+ pipe = pipe.to(args.device)
137
+
138
+ if args.load_checkpoint is not None:
139
+ # Don't run full activation equalization if we're loading a quantized checkpoint
140
+ num_ae_prompts = 2
141
+ else:
142
+ num_ae_prompts = len(calibration_prompts)
143
  with activation_equalization_mode(
144
  pipe.unet,
145
  alpha=args.act_eq_alpha,
 
153
  total_steps = args.calibration_steps
154
  run_val_inference(
155
  pipe,
156
+ calibration_prompts[:num_ae_prompts],
157
  total_steps=total_steps,
158
  test_latents=latents,
159
  guidance_scale=args.guidance_scale)
 
201
 
202
  pipe.set_progress_bar_config(disable=True)
203
 
204
+ if args.load_checkpoint is None:
205
+ print("Applying activation calibration")
206
+ with torch.no_grad(), calibration_mode(pipe.unet):
207
+ run_val_inference(
208
+ pipe,
209
+ calibration_prompts,
210
+ total_steps=args.calibration_steps,
211
+ test_latents=latents,
212
+ guidance_scale=args.guidance_scale)
213
+
214
+ print("Applying bias correction")
215
+ with torch.no_grad(), bias_correction_mode(pipe.unet):
216
+ run_val_inference(
217
+ pipe,
218
+ calibration_prompts,
219
+ total_steps=args.calibration_steps,
220
+ test_latents=latents,
221
+ guidance_scale=args.guidance_scale)
222
+
223
+ if args.checkpoint_name is not None:
224
+ torch.save(pipe.unet.state_dict(), os.path.join(output_dir, args.checkpoint_name))
225
+
226
+ # Perform inference
227
+ if args.validation_prompts > 0:
228
+ print(f"Computing validation accuracy")
229
+ compute_mlperf_fid(args.model, args.path_to_coco, pipe, args.validation_prompts, output_dir)
230
 
231
  if args.export_target:
232
  pipe.unet.to('cpu').to(dtype)
 
250
  type=int,
251
  default=500,
252
  help='Number of prompts to use for calibration. Default: %(default)s')
253
+ parser.add_argument(
254
+ '--validation-prompts',
255
+ type=int,
256
+ default=0,
257
+ help='Number of prompt to use for validation. Default: %(default)s')
258
+ parser.add_argument(
259
+ '--path-to-coco',
260
+ type=str,
261
+ default=None,
262
+ help=
263
+ 'Path to MLPerf compliant Coco dataset. Required when the --validation-prompts > 0 flag is set. Default: None'
264
+ )
265
  parser.add_argument(
266
  '--checkpoint-name',
267
  type=str,
 
270
  'Name to use to store the checkpoint in the output dir. If not provided, no checkpoint is saved.'
271
  )
272
  parser.add_argument(
273
+ '--load-checkpoint',
274
  type=str,
275
  default=None,
276
+ help='Path to checkpoint to load. If provided, PTQ techniques are skipped.')
277
+ parser.add_argument(
278
+ '--path-to-latents',
279
+ type=str,
280
+ required=True,
281
  help=
282
+ 'Path to pre-defined latents.')
283
  parser.add_argument('--guidance-scale', type=float, default=8., help='Guidance scale.')
284
  parser.add_argument(
285
  '--calibration-steps', type=float, default=8, help='Steps used during calibration')