"""Registry responsible for built-in keras classes.""" |
|
|
|
import tensorflow as tf |
|
|
|
from tensorflow_model_optimization.python.core.keras.compat import keras |
|
from tensorflow_model_optimization.python.core.quantization.keras import quant_ops |
|
from tensorflow_model_optimization.python.core.quantization.keras import quantizers |
|
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import ( |
|
default_8bit_quantize_registry,) |
|
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import ( |
|
default_8bit_quantizers,) |
|
|
|
|
|
layers = keras.layers |
|
|
|
|
|


class _PrunePreserveInfo(object):
  """PrunePreserveInfo."""

  def __init__(self, weight_attrs, quantize_config_attrs):
    """Initializes PrunePreserveInfo.

    Args:
      weight_attrs: list of sparsity preservable weight attributes of layer.
      quantize_config_attrs: list of quantization configuration class names.
    """
    self.weight_attrs = weight_attrs
    self.quantize_config_attrs = quantize_config_attrs
|

class PrunePreserveQuantizeRegistry():
  """PrunePreserveQuantizeRegistry responsible for built-in keras layers."""
|
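  # Maps each supported layer class to its sparsity-preservable weight
  # attributes and to the class names of the quantize configs that this
  # registry can combine with prune preservation.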
  _LAYERS_CONFIG_MAP = {
      layers.Conv2D:
          _PrunePreserveInfo(['kernel'], ['Default8BitConvQuantizeConfig']),
      layers.Dense:
          _PrunePreserveInfo(['kernel'], ['Default8BitQuantizeConfig']),
      layers.DepthwiseConv2D:
          _PrunePreserveInfo(['depthwise_kernel'],
                             ['Default8BitQuantizeConfig']),
  }

  _DISABLE_PRUNE_PRESERVE = frozenset({
      layers.DepthwiseConv2D,
  })
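
  # DepthwiseConv2D appears in both _LAYERS_CONFIG_MAP (so `supports`
  # accepts it for 8-bit QAT) and in _DISABLE_PRUNE_PRESERVE, so
  # apply_sparsity_preserve_quantize_config returns its quantize config
  # unchanged rather than attaching a sparsity-preserving weight quantizer.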

  def __init__(self):
    self._config_quantizer_map = {
        'Default8BitQuantizeConfig':
            PrunePreserveDefault8BitWeightsQuantizer(),
        'Default8BitConvQuantizeConfig':
            PrunePreserveDefault8BitConvWeightsQuantizer(),
    }

  @classmethod
  def _no_trainable_weights(cls, layer):
    """Returns whether this layer has no trainable weights.

    Args:
      layer: The layer to check for trainable weights.

    Returns:
      True/False whether the layer has no trainable weights.
    """
    return not layer.trainable_weights

  @classmethod
  def _disable_prune_preserve(cls, layer):
    """Returns whether prune preserve is disabled for this layer.

    Args:
      layer: The layer to check.

    Returns:
      True/False whether prune preserve is disabled for this layer.
    """
    return layer.__class__ in cls._DISABLE_PRUNE_PRESERVE

  @classmethod
  def supports(cls, layer):
    """Returns whether the registry supports this layer type.

    Args:
      layer: The layer to check for support.

    Returns:
      True/False whether the layer type is supported.
    """
    if cls._no_trainable_weights(layer):
      return True

    if layer.__class__ in cls._LAYERS_CONFIG_MAP:
      return True

    return False

  @classmethod
  def _weight_names(cls, layer):
    """Gets the weight names."""
    if cls._no_trainable_weights(layer):
      return []

    return cls._LAYERS_CONFIG_MAP[layer.__class__].weight_attrs

  @classmethod
  def get_sparsity_preservable_weights(cls, layer):
    """Gets sparsity preservable weights from a keras layer.

    Args:
      layer: instance of keras layer.

    Returns:
      List of sparsity preservable weights.
    """
    return [getattr(layer, weight) for weight in cls._weight_names(layer)]

  @classmethod
  def get_support_quantize_config_names(cls, layer):
    """Gets class names of the supported quantize configs for a layer.

    Args:
      layer: instance of keras layer.

    Returns:
      List of supported quantize config class names.
    """
    if cls._no_trainable_weights(layer):
      return []

    return cls._LAYERS_CONFIG_MAP[layer.__class__].quantize_config_attrs

  def apply_sparsity_preserve_quantize_config(self, layer, quantize_config):
    """Applies weight sparsity preservation to a layer's quantize config.

    Args:
      layer: The layer to check for support.
      quantize_config: quantization config to check for support and to
        augment with sparsity preservation for the pruned weights.

    Raises:
      ValueError: when the layer or its quantize_config is not supported.

    Returns:
      quantize_config with an add-on sparsity-preserving weight_quantizer.
    """
    if self.supports(layer):
      if (self._no_trainable_weights(layer) or
          self._disable_prune_preserve(layer)):
        return quantize_config
      if (quantize_config.__class__.__name__
          in self._LAYERS_CONFIG_MAP[layer.__class__].quantize_config_attrs):
        quantize_config.weight_quantizer = self._config_quantizer_map[
            quantize_config.__class__.__name__]
      else:
        raise ValueError('Configuration {} is not supported for Layer {}.'
                         .format(str(quantize_config.__class__.__name__),
                                 str(layer.__class__.__name__)))
    else:
      raise ValueError('Layer {} is not supported.'.format(
          str(layer.__class__.__name__)))

    return quantize_config


class Default8bitPrunePreserveQuantizeRegistry(PrunePreserveQuantizeRegistry):
  """Default 8 bit PrunePreserveQuantizeRegistry."""

  def get_quantize_config(self, layer):
    """Returns the quantization config with add-on sparsity preservation.

    Args:
      layer: input layer to return quantize config for.

    Returns:
      The quantization config with a sparsity-preserving weight_quantizer.
    """
    quantize_config = (default_8bit_quantize_registry
                       .Default8BitQuantizeRegistry()
                       .get_quantize_config(layer))
    prune_aware_quantize_config = self.apply_sparsity_preserve_quantize_config(
        layer, quantize_config)

    return prune_aware_quantize_config
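
# Illustrative usage (a sketch, not an exercised code path in this module):
# for an already-pruned Conv2D layer, the registry returns the default 8-bit
# quantize config with its weight quantizer swapped for a sparsity-preserving
# one. The layer below is a hypothetical example.
#
#   registry = Default8bitPrunePreserveQuantizeRegistry()
#   conv = layers.Conv2D(32, 3)  # assumed built and pruned beforehand
#   quantize_config = registry.get_quantize_config(conv)
#   # quantize_config.weight_quantizer is now a
#   # PrunePreserveDefault8BitConvWeightsQuantizer.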


class PrunePreserveDefaultWeightsQuantizer(quantizers.LastValueQuantizer):
  """Quantizes weights while preserving their sparsity."""

  def __init__(self, num_bits, per_axis, symmetric, narrow_range):
    """Initializes PrunePreserveDefaultWeightsQuantizer.

    Args:
      num_bits: Number of bits for quantization.
      per_axis: Whether to apply per_axis quantization. The last dimension
        is used as the axis.
      symmetric: If true, use symmetric quantization limits instead of
        training the minimum and maximum of each quantization range
        separately.
      narrow_range: In case of 8 bits, narrow_range nudges the quantized
        range to be [-127, 127] instead of [-128, 127]. This ensures the
        symmetric range has 0 as its centre.
    """
    quantizers.LastValueQuantizer.__init__(self, num_bits, per_axis,
                                           symmetric, narrow_range)

  def _build_sparsity_mask(self, name, layer):
    weights = getattr(layer.layer, name)
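    # divide_no_nan(w, w) evaluates to 1.0 where w != 0 and to 0.0 where
    # w == 0, yielding a binary mask over the surviving (unpruned) weights.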
    sparsity_mask = tf.math.divide_no_nan(weights, weights)

    return {'sparsity_mask': sparsity_mask}

  def build(self, tensor_shape, name, layer):
    """Constructs the mask to preserve weight sparsity.

    Args:
      tensor_shape: Shape of the weights to be quantized.
      name: Name of the weights in the layer.
      layer: Quantization-wrapped keras layer.

    Returns:
      Dictionary of the constructed sparsity mask and the quantization
      params; the dictionary is passed to the __call__ function.
    """
    result = self._build_sparsity_mask(name, layer)
    result.update(
        super(PrunePreserveDefaultWeightsQuantizer,
              self).build(tensor_shape, name, layer))
    return result

  def __call__(self, inputs, training, weights, **kwargs):
    """Applies sparsity-preserved quantization to the input tensor.

    Args:
      inputs: Input tensor (the layer's weights) to be quantized.
      training: Whether the graph is currently training.
      weights: Dictionary of weights (params) the quantizer can use to
        quantize the tensor (the layer's weights). This contains the
        weights created in the `build` function.
      **kwargs: Additional variables which may be passed to the quantizer.

    Returns:
      Quantized tensor.
    """
    prune_preserve_inputs = tf.multiply(inputs, weights['sparsity_mask'])

    return quant_ops.LastValueQuantize(
        prune_preserve_inputs,
        weights['min_var'],
        weights['max_var'],
        is_training=training,
        num_bits=self.num_bits,
        per_channel=self.per_axis,
        symmetric=self.symmetric,
        narrow_range=self.narrow_range,
    )


class PrunePreserveDefault8BitWeightsQuantizer(
    PrunePreserveDefaultWeightsQuantizer):
  """PrunePreserveWeightsQuantizer for default 8bit weights."""

  def __init__(self):
    super(PrunePreserveDefault8BitWeightsQuantizer,
          self).__init__(num_bits=8,
                         per_axis=False,
                         symmetric=True,
                         narrow_range=True)


class PrunePreserveDefault8BitConvWeightsQuantizer(
    PrunePreserveDefaultWeightsQuantizer,
    default_8bit_quantizers.Default8BitConvWeightsQuantizer,):
  """PrunePreserveWeightsQuantizer for default 8bit Conv2D/DepthwiseConv2D weights."""
  def __init__(self):
    default_8bit_quantizers.Default8BitConvWeightsQuantizer.__init__(self)

  def build(self, tensor_shape, name, layer):
    result = PrunePreserveDefaultWeightsQuantizer._build_sparsity_mask(
        self, name, layer)
    result.update(
        default_8bit_quantizers.Default8BitConvWeightsQuantizer.build(
            self, tensor_shape, name, layer))
    return result