|
|
|
|
|
|
|
|
|
|
|
import logging |
|
|
|
from fairseq.modules.quantization import pq, quantization_options, scalar |
|
from omegaconf import DictConfig |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
def quantize_model_scalar(model, model_cfg: DictConfig):
    """Apply scalar quantization noise to ``model`` if the config asks for it.

    Reads ``quant_noise_scalar`` from ``model_cfg``, treating a missing or
    falsy value as 0.  When the value is positive, the model is quantized in
    place with 8-bit scalar quantization noise (``bits=8``,
    ``update_step=1000``); otherwise the model is left untouched.

    Returns the (possibly modified) model in both cases.
    """
    noise = getattr(model_cfg, "quant_noise_scalar", 0) or 0
    if noise <= 0:
        # nothing requested -- hand the model back unchanged
        return model
    scalar.quantize_model_(model, p=noise, bits=8, update_step=1000)
    return model
|
|
|
|
|
class Quantizer(object):
    """Drives iterative Product Quantization (iPQ) of a model during training.

    The quantization config (YAML) defines ``layers_to_quantize``: one entry
    per quantization stage.  Exactly one of ``max_epoch`` / ``max_update``
    must be positive, and it must divide evenly into the number of stages;
    the quotient is the per-stage schedule.  :meth:`begin_epoch` /
    :meth:`step_update` call :meth:`step` when a stage boundary is reached.
    """

    def __init__(self, config_path, max_epoch, max_update):
        """Parse the quantization config and compute the stage schedule.

        Args:
            config_path: path to a YAML quantization config, or falsy to use
                the library defaults.
            max_epoch: total training epochs (> 0 to schedule by epoch).
            max_update: total training updates (> 0 to schedule by update).

        Raises:
            ImportError: if PyYAML is not installed.
            AssertionError: if the schedule options are inconsistent.
        """
        try:
            import yaml
        except ImportError:
            # the PyYAML package is published on PyPI as "pyyaml", not "yaml"
            raise ImportError("Please install yaml with: pip install pyyaml")

        # an empty/None config_path falls back to the default options
        if config_path:
            with open(config_path) as config_file:
                config = quantization_options.parse_config_yaml(
                    yaml.safe_load(config_file)
                )
        else:
            config = quantization_options.parse_config_yaml({})

        self.n_centroids_config = config["n_centroids"]
        self.block_sizes_config = config["block_sizes"]
        self.layers_to_quantize = config["layers_to_quantize"]

        # split the training budget (epochs or updates) evenly across the
        # quantization stages; exactly one kind of schedule may be active
        num_iterations = len(self.layers_to_quantize)
        if max_epoch > 0:
            assert max_epoch % num_iterations == 0, (
                "for iterative PQ, --max-epoch (={}) must be evenly divisible by "
                "len(layers_to_quantize) (={})".format(max_epoch, num_iterations)
            )
            self.epoch_schedule = max_epoch // num_iterations
        else:
            self.epoch_schedule = None
        if max_update > 0:
            assert max_update % num_iterations == 0, (
                "for iterative PQ, --max-update (={}) must be evenly divisible by "
                "len(layers_to_quantize) (={})".format(max_update, num_iterations)
            )
            self.update_schedule = max_update // num_iterations
        else:
            self.update_schedule = None
        # this also fires when neither option is given, so the message says
        # "exactly one" rather than "cannot specify both"
        assert (self.epoch_schedule is not None) ^ (
            self.update_schedule is not None
        ), "for iterative PQ, must specify exactly one of --max-epoch and --max-update"

        # index of the next entry in layers_to_quantize to be processed
        self.quantization_step = 0

    def set_trainer(self, trainer):
        """Attach the trainer whose model will be quantized step by step."""
        self.trainer = trainer
        self.size_tracker = pq.SizeTracker(self.trainer.get_model())

    def step(self):
        """Move to the next stage of quantization."""
        if self.quantization_step >= len(self.layers_to_quantize):
            # all stages done; nothing left to quantize
            return

        logger.info(
            "quantizing model (step={}; layers_to_quantize[step]={})".format(
                self.quantization_step, self.layers_to_quantize[self.quantization_step]
            )
        )
        quantized_layers = pq.quantize_model_(
            self.trainer.get_model(),
            self.size_tracker,
            self.layers_to_quantize,
            self.block_sizes_config,
            self.n_centroids_config,
            step=self.quantization_step,
        )
        logger.info("quantized layers: {}".format(quantized_layers))
        logger.info(self.size_tracker)

        self.quantization_step += 1

        # the model changed in place; the trainer must rebuild its
        # optimizer/state around the new parameters
        self.trainer.reinitialize()

    def begin_epoch(self, epoch):
        """Called at the beginning of each epoch (epochs start at 1)."""
        if (
            (
                self.epoch_schedule is not None
                and epoch > 0
                and (epoch - 1) % self.epoch_schedule == 0
            )
            # also run the very first stage even on an update schedule
            or self.quantization_step == 0
        ):
            self.step()

    def step_update(self, num_updates):
        """Called at the end of each step."""
        if (
            self.update_schedule is not None
            and num_updates > 0
            and num_updates % self.update_schedule == 0
        ):
            self.step()

    def state_dict(self):
        """Return the quantizer state for checkpointing."""
        return {
            "n_centroids_config": self.n_centroids_config,
            "block_sizes_config": self.block_sizes_config,
            "layers_to_quantize": self.layers_to_quantize,
            "epoch_schedule": self.epoch_schedule,
            "update_schedule": self.update_schedule,
            "quantization_step": self.quantization_step,
        }

    def load_state_dict(self, state_dict):
        """Restore the quantizer state from a checkpoint."""
        self.n_centroids_config = state_dict["n_centroids_config"]
        self.block_sizes_config = state_dict["block_sizes_config"]
        self.layers_to_quantize = state_dict["layers_to_quantize"]
        self.epoch_schedule = state_dict["epoch_schedule"]
        self.update_schedule = state_dict["update_schedule"]
        self.quantization_step = state_dict["quantization_step"]
|
|