File size: 5,507 Bytes
d5175d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from fairseq.modules.quantization import pq, quantization_options, scalar
from omegaconf import DictConfig
logger = logging.getLogger(__name__)
def quantize_model_scalar(model, model_cfg: DictConfig):
    """Apply scalar quantization noise to ``model`` if the config requests it.

    ``quant_noise_scalar`` is read from ``model_cfg``; a missing attribute or
    a falsy value (e.g. ``None`` or ``0``) is treated as ``0``. When the
    resulting value is positive, ``scalar.quantize_model_`` is invoked with
    that value as ``p``, ``bits=8`` and ``update_step=1000``.

    Returns:
        The same ``model`` object that was passed in.
    """
    noise_level = getattr(model_cfg, "quant_noise_scalar", 0)
    if not noise_level:
        noise_level = 0

    if not (noise_level > 0):
        # Nothing requested: hand the model back untouched.
        return model

    # quantize_model_ edits the model in place, so its result is not needed
    scalar.quantize_model_(model, p=noise_level, bits=8, update_step=1000)
    return model
class Quantizer(object):
    """Drive iterative product quantization (PQ) of a trainer's model.

    The quantization configuration (``n_centroids``, ``block_sizes`` and
    ``layers_to_quantize``) is parsed once in the constructor. Training is
    then split into ``len(layers_to_quantize)`` equally long phases, measured
    either in epochs or in updates, and ``step()`` is invoked at the start of
    each phase through ``begin_epoch()`` / ``step_update()``.
    """

    def __init__(self, config_path, max_epoch, max_update):
        """Parse the quantization config and derive the stepping schedule.

        Args:
            config_path: path to a YAML config file; when falsy, an empty
                mapping is handed to
                ``quantization_options.parse_config_yaml`` instead.
            max_epoch: total number of training epochs (``<= 0`` means
                "not set").
            max_update: total number of training updates (``<= 0`` means
                "not set").

        Raises:
            ImportError: if the ``yaml`` module cannot be imported.
            AssertionError: if ``layers_to_quantize`` is empty, if the given
                limit is not evenly divisible by
                ``len(layers_to_quantize)``, or if not exactly one of
                ``max_epoch`` / ``max_update`` is positive.
        """
        try:
            import yaml
        except ImportError as err:
            # The ``yaml`` module is shipped by the "pyyaml" distribution;
            # there is no installable package simply called "yaml".
            raise ImportError(
                "Please install yaml with: pip install pyyaml"
            ) from err

        # parse config
        if config_path:
            with open(config_path) as config_file:
                config = quantization_options.parse_config_yaml(
                    yaml.safe_load(config_file)
                )
        else:
            config = quantization_options.parse_config_yaml({})

        self.n_centroids_config = config["n_centroids"]
        self.block_sizes_config = config["block_sizes"]
        self.layers_to_quantize = config["layers_to_quantize"]

        # We assume that training will run for a fixed number of epochs
        # (or updates) and that we should train for equal durations
        # between iterations of PQ.
        self.epoch_schedule, self.update_schedule = self._compute_schedules(
            len(self.layers_to_quantize), max_epoch, max_update
        )

        # 0 is a special value for quantization step, which will force
        # the first call to begin_epoch() to call step()
        self.quantization_step = 0

    @staticmethod
    def _interval(total, num_iterations, flag):
        """Return ``total // num_iterations``, or ``None`` if ``total <= 0``.

        Raises ``AssertionError`` when ``total`` is positive but not evenly
        divisible by ``num_iterations``; ``flag`` is the command-line option
        name used in the error message.
        """
        if not (total > 0):
            return None
        if total % num_iterations != 0:
            raise AssertionError(
                "for iterative PQ, {} (={}) must be evenly divisible by "
                "len(layers_to_quantize) (={})".format(flag, total, num_iterations)
            )
        return total // num_iterations

    @staticmethod
    def _compute_schedules(num_iterations, max_epoch, max_update):
        """Validate the training limits and return the two phase lengths.

        Returns:
            ``(epoch_schedule, update_schedule)``; exactly one of the two is
            an ``int`` and the other is ``None``.

        Raises:
            AssertionError: on any invalid combination. The exceptions are
                raised explicitly (rather than via ``assert``) so that the
                validation is still performed under ``python -O``.
        """
        if num_iterations <= 0:
            # Without this guard the modulo below fails with an opaque
            # ZeroDivisionError.
            raise AssertionError(
                "for iterative PQ, layers_to_quantize must contain at least "
                "one entry"
            )

        epoch_schedule = Quantizer._interval(max_epoch, num_iterations, "--max-epoch")
        update_schedule = Quantizer._interval(
            max_update, num_iterations, "--max-update"
        )

        if epoch_schedule is not None and update_schedule is not None:
            raise AssertionError(
                "for iterative PQ, cannot specify both --max-update and --max-epoch"
            )
        if epoch_schedule is None and update_schedule is None:
            raise AssertionError(
                "for iterative PQ, must specify one of --max-update or --max-epoch"
            )
        return epoch_schedule, update_schedule

    def set_trainer(self, trainer):
        """Attach ``trainer`` and create a ``pq.SizeTracker`` for its model."""
        self.trainer = trainer
        self.size_tracker = pq.SizeTracker(self.trainer.get_model())

    def step(self):
        """Move to the next stage of quantization."""
        if self.quantization_step >= len(self.layers_to_quantize):
            # Maybe we just finished the last training step or we loaded
            # a checkpoint for an iterative PQ model which previously
            # finished training. Either way, don't quantize again.
            return

        logger.info(
            "quantizing model (step=%s; layers_to_quantize[step]=%s)",
            self.quantization_step,
            self.layers_to_quantize[self.quantization_step],
        )
        quantized_layers = pq.quantize_model_(
            self.trainer.get_model(),
            self.size_tracker,
            self.layers_to_quantize,
            self.block_sizes_config,
            self.n_centroids_config,
            step=self.quantization_step,
        )
        logger.info("quantized layers: %s", quantized_layers)
        logger.info(self.size_tracker)

        self.quantization_step += 1

        # reinitialize the Trainer since model parameters have changed
        self.trainer.reinitialize()

    def begin_epoch(self, epoch):
        """Called at the beginning of each epoch (epochs start at 1)."""
        on_epoch_boundary = (
            self.epoch_schedule is not None
            and epoch > 0
            and (epoch - 1) % self.epoch_schedule == 0
        )
        # we always step once in the beginning, even if using
        # update-based quantization
        if on_epoch_boundary or self.quantization_step == 0:
            self.step()

    def step_update(self, num_updates):
        """Called at the end of each step."""
        if (
            self.update_schedule is not None
            and num_updates > 0
            and num_updates % self.update_schedule == 0
        ):
            self.step()

    def state_dict(self):
        """Return the config, schedules and current step as a plain dict."""
        return {
            "n_centroids_config": self.n_centroids_config,
            "block_sizes_config": self.block_sizes_config,
            "layers_to_quantize": self.layers_to_quantize,
            "epoch_schedule": self.epoch_schedule,
            "update_schedule": self.update_schedule,
            "quantization_step": self.quantization_step,
        }

    def load_state_dict(self, state_dict):
        """Restore every field written by ``state_dict()``."""
        self.n_centroids_config = state_dict["n_centroids_config"]
        self.block_sizes_config = state_dict["block_sizes_config"]
        self.layers_to_quantize = state_dict["layers_to_quantize"]
        self.epoch_schedule = state_dict["epoch_schedule"]
        self.update_schedule = state_dict["update_schedule"]
        self.quantization_step = state_dict["quantization_step"]
|