nickfraser
commited on
Commit
•
34b0078
1
Parent(s):
62e0e7b
Feat (script): Added option to validate on MLPerf validation set & to load a pre-quantized checkpoint.
Browse files- quant_sdxl/quant_sdxl.py +62 -24
quant_sdxl/quant_sdxl.py
CHANGED
@@ -32,11 +32,12 @@ from brevitas.graph.quantize import layerwise_quantize
|
|
32 |
from brevitas.inject.enum import StatsOp
|
33 |
from brevitas.nn.equalized_layer import EqualizedModule
|
34 |
from brevitas.utils.torch_utils import KwargsForwardHook
|
|
|
35 |
|
36 |
from brevitas_examples.common.parse_utils import add_bool_arg
|
37 |
from brevitas_examples.stable_diffusion.sd_quant.export import export_quant_params
|
38 |
from brevitas_examples.stable_diffusion.sd_quant.nn import QuantAttention
|
39 |
-
|
40 |
|
41 |
TEST_SEED = 123456
|
42 |
torch.manual_seed(TEST_SEED)
|
@@ -125,6 +126,20 @@ def main(args):
|
|
125 |
raise RuntimeError("LoRA layers should be fused in before calling into quantization.")
|
126 |
|
127 |
pipe.set_progress_bar_config(disable=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
with activation_equalization_mode(
|
129 |
pipe.unet,
|
130 |
alpha=args.act_eq_alpha,
|
@@ -138,7 +153,7 @@ def main(args):
|
|
138 |
total_steps = args.calibration_steps
|
139 |
run_val_inference(
|
140 |
pipe,
|
141 |
-
calibration_prompts,
|
142 |
total_steps=total_steps,
|
143 |
test_latents=latents,
|
144 |
guidance_scale=args.guidance_scale)
|
@@ -186,26 +201,32 @@ def main(args):
|
|
186 |
|
187 |
pipe.set_progress_bar_config(disable=True)
|
188 |
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
209 |
|
210 |
if args.export_target:
|
211 |
pipe.unet.to('cpu').to(dtype)
|
@@ -229,6 +250,18 @@ if __name__ == "__main__":
|
|
229 |
type=int,
|
230 |
default=500,
|
231 |
help='Number of prompts to use for calibration. Default: %(default)s')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
232 |
parser.add_argument(
|
233 |
'--checkpoint-name',
|
234 |
type=str,
|
@@ -237,11 +270,16 @@ if __name__ == "__main__":
|
|
237 |
'Name to use to store the checkpoint in the output dir. If not provided, no checkpoint is saved.'
|
238 |
)
|
239 |
parser.add_argument(
|
240 |
-
'--
|
241 |
type=str,
|
242 |
default=None,
|
|
|
|
|
|
|
|
|
|
|
243 |
help=
|
244 |
-
'
|
245 |
parser.add_argument('--guidance-scale', type=float, default=8., help='Guidance scale.')
|
246 |
parser.add_argument(
|
247 |
'--calibration-steps', type=float, default=8, help='Steps used during calibration')
|
|
|
32 |
from brevitas.inject.enum import StatsOp
|
33 |
from brevitas.nn.equalized_layer import EqualizedModule
|
34 |
from brevitas.utils.torch_utils import KwargsForwardHook
|
35 |
+
import brevitas.config as config
|
36 |
|
37 |
from brevitas_examples.common.parse_utils import add_bool_arg
|
38 |
from brevitas_examples.stable_diffusion.sd_quant.export import export_quant_params
|
39 |
from brevitas_examples.stable_diffusion.sd_quant.nn import QuantAttention
|
40 |
+
from brevitas_examples.stable_diffusion.mlperf_evaluation.accuracy import compute_mlperf_fid
|
41 |
|
42 |
TEST_SEED = 123456
|
43 |
torch.manual_seed(TEST_SEED)
|
|
|
126 |
raise RuntimeError("LoRA layers should be fused in before calling into quantization.")
|
127 |
|
128 |
pipe.set_progress_bar_config(disable=True)
|
129 |
+
|
130 |
+
if args.load_checkpoint is not None:
|
131 |
+
with load_quant_model_mode(pipe.unet):
|
132 |
+
pipe = pipe.to('cpu')
|
133 |
+
print(f"Loading checkpoint: {args.load_checkpoint}... ", end="")
|
134 |
+
pipe.unet.load_state_dict(torch.load(args.load_checkpoint, map_location='cpu'))
|
135 |
+
print(f"Checkpoint loaded!")
|
136 |
+
pipe = pipe.to(args.device)
|
137 |
+
|
138 |
+
if args.load_checkpoint is not None:
|
139 |
+
# Don't run full activation equalization if we're loading a quantized checkpoint
|
140 |
+
num_ae_prompts = 2
|
141 |
+
else:
|
142 |
+
num_ae_prompts = len(calibration_prompts)
|
143 |
with activation_equalization_mode(
|
144 |
pipe.unet,
|
145 |
alpha=args.act_eq_alpha,
|
|
|
153 |
total_steps = args.calibration_steps
|
154 |
run_val_inference(
|
155 |
pipe,
|
156 |
+
calibration_prompts[:num_ae_prompts],
|
157 |
total_steps=total_steps,
|
158 |
test_latents=latents,
|
159 |
guidance_scale=args.guidance_scale)
|
|
|
201 |
|
202 |
pipe.set_progress_bar_config(disable=True)
|
203 |
|
204 |
+
if args.load_checkpoint is None:
|
205 |
+
print("Applying activation calibration")
|
206 |
+
with torch.no_grad(), calibration_mode(pipe.unet):
|
207 |
+
run_val_inference(
|
208 |
+
pipe,
|
209 |
+
calibration_prompts,
|
210 |
+
total_steps=args.calibration_steps,
|
211 |
+
test_latents=latents,
|
212 |
+
guidance_scale=args.guidance_scale)
|
213 |
+
|
214 |
+
print("Applying bias correction")
|
215 |
+
with torch.no_grad(), bias_correction_mode(pipe.unet):
|
216 |
+
run_val_inference(
|
217 |
+
pipe,
|
218 |
+
calibration_prompts,
|
219 |
+
total_steps=args.calibration_steps,
|
220 |
+
test_latents=latents,
|
221 |
+
guidance_scale=args.guidance_scale)
|
222 |
+
|
223 |
+
if args.checkpoint_name is not None:
|
224 |
+
torch.save(pipe.unet.state_dict(), os.path.join(output_dir, args.checkpoint_name))
|
225 |
+
|
226 |
+
# Perform inference
|
227 |
+
if args.validation_prompts > 0:
|
228 |
+
print(f"Computing validation accuracy")
|
229 |
+
compute_mlperf_fid(args.model, args.path_to_coco, pipe, args.validation_prompts, output_dir)
|
230 |
|
231 |
if args.export_target:
|
232 |
pipe.unet.to('cpu').to(dtype)
|
|
|
250 |
type=int,
|
251 |
default=500,
|
252 |
help='Number of prompts to use for calibration. Default: %(default)s')
|
253 |
+
parser.add_argument(
|
254 |
+
'--validation-prompts',
|
255 |
+
type=int,
|
256 |
+
default=0,
|
257 |
+
help='Number of prompt to use for validation. Default: %(default)s')
|
258 |
+
parser.add_argument(
|
259 |
+
'--path-to-coco',
|
260 |
+
type=str,
|
261 |
+
default=None,
|
262 |
+
help=
|
263 |
+
'Path to MLPerf compliant Coco dataset. Required when the --validation-prompts > 0 flag is set. Default: None'
|
264 |
+
)
|
265 |
parser.add_argument(
|
266 |
'--checkpoint-name',
|
267 |
type=str,
|
|
|
270 |
'Name to use to store the checkpoint in the output dir. If not provided, no checkpoint is saved.'
|
271 |
)
|
272 |
parser.add_argument(
|
273 |
+
'--load-checkpoint',
|
274 |
type=str,
|
275 |
default=None,
|
276 |
+
help='Path to checkpoint to load. If provided, PTQ techniques are skipped.')
|
277 |
+
parser.add_argument(
|
278 |
+
'--path-to-latents',
|
279 |
+
type=str,
|
280 |
+
required=True,
|
281 |
help=
|
282 |
+
'Path to pre-defined latents.')
|
283 |
parser.add_argument('--guidance-scale', type=float, default=8., help='Guidance scale.')
|
284 |
parser.add_argument(
|
285 |
'--calibration-steps', type=float, default=8, help='Steps used during calibration')
|