|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Registry responsible for built-in keras classes.""" |
|
|
|
import logging |
|
import warnings |
|
|
|
import tensorflow as tf |
|
|
|
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config |
|
from tensorflow_model_optimization.python.core.clustering.keras import clustering_registry |
|
from tensorflow_model_optimization.python.core.keras.compat import keras |
|
from tensorflow_model_optimization.python.core.quantization.keras import quant_ops |
|
from tensorflow_model_optimization.python.core.quantization.keras import quantizers |
|
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry |
|
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantizers |
|
|
|
|
|
# Convenience aliases for the Keras submodules used throughout this module.
layers = keras.layers
K = keras.backend

# Keys/variable names for the extra state a cluster-preserving quantizer
# stores in the `weights` dict it creates in `build`.
CLUSTER_CENTROIDS = 'cluster_centroids_tf'  # trainable centroid values
PULLING_INDICES = 'pulling_indices_tf'  # per-weight index of nearest centroid
ORIGINAL_WEIGHTS = 'ori_weights_vars_tf'  # copy of the latent weights
WEIGHT_NAME = 'weight_name'  # attribute name of the quantized weight
CLUSTERING_IMPL = 'clst_impl'  # clustering algorithm object
CENTROIDS_MASK = 'centroids_mask'  # masks the zero centroid (sparsity case)
SPARSITY_MASK = 'sparsity_mask'  # 0/1 mask of non-zero weights
|
|
|
|
|
def get_unique(t):
  """Get unique values and lookup index from N-D tensor.

  Args:
    t: tensor
  Returns:
    centroids (unique values), lookup index (same shape as input tensor)
  Example:
    t:
    ([[1.0, 2.0],
      [2.0, 3.0],
      [3.0, 3.0],
      [1.0, 2.0]]
    )
    centroids(unique values):
    ([1.0, 2.0, 3.0])
    output final index:
    ([[0, 1],
      [1, 2],
      [2, 2],
      [0, 1]]
    )
  """
  # tf.unique only accepts 1-D input, so flatten first and restore the
  # original shape on the index tensor afterwards.
  flattened = tf.reshape(t, shape=(-1,))
  centroids, flat_index = tf.unique(flattened)
  lookup = tf.reshape(flat_index, shape=tf.shape(t))
  return centroids, lookup
|
|
|
|
|
def get_centroids(layer, weight, data_format):
  """Gets centroid infos from the weights of a layer.

  Args:
    layer: The Keras layer from which the weight belong.
    weight: The weight tensor to get the centroids info from.
    data_format: string to indicate format: "channels_first" or "channels_last".
  Returns:
    A 4-tuple of centroids (unique values), number of centroids, lookup index,
    whether to cluster per channel (boolean).
  """
  # Per-channel clustering is only attempted for Conv2D; `layer` is a
  # quantize wrapper whose `.layer` attribute holds the wrapped Keras layer.
  cluster_per_channel = layer.layer and isinstance(
      layer.layer, keras.layers.Conv2D
  )

  if not cluster_per_channel:
    # One shared set of centroids for the whole weight tensor.
    centroids, index = get_unique(weight)
    return centroids, tf.size(centroids), index, False

  # NOTE(review): the channel axis follows `data_format` here, but the
  # per-channel slice below always uses the last axis — presumably relying
  # on the Conv2D kernel layout; confirm behavior for 'channels_first'.
  num_channels = weight.shape[1 if data_format == 'channels_first' else -1]
  channel_centroids = []
  channel_indices = []
  num_centroids = []

  # Collect the unique values (centroids) of each channel separately.
  for channel in range(num_channels):
    channel_weights = weight[:, :, :, channel]
    centroids, indices = get_unique(channel_weights)

    channel_centroids.append(centroids)
    channel_indices.append(indices)
    num_centroids.append(tf.size(centroids))

  max_centroid = max(num_centroids)
  max_diff = max_centroid - min(num_centroids)

  # If channels disagree on the centroid count by more than one, fall back
  # to a single shared set of centroids for the whole tensor.
  if max_diff > 1:
    centroids, index = get_unique(weight)
    return centroids, tf.size(centroids), index, False

  # Pad shorter per-channel centroid lists with ones so every channel has
  # `max_centroid` entries and the lists can be stacked into one tensor.
  for i, centroid in enumerate(channel_centroids):
    if num_centroids[i] != max_centroid:
      one_padding = tf.ones([max_centroid - num_centroids[i]])
      channel_centroids[i] = tf.concat([centroid, one_padding], 0)

  centroids = tf.convert_to_tensor(channel_centroids)
  lookup = tf.convert_to_tensor(channel_indices)

  # Stacking put the channel axis first; move it back to the channel
  # position implied by `data_format`.
  lookup = tf.transpose(
      lookup,
      perm=(1, 0, 2, 3) if data_format == 'channels_first' else (1, 2, 3, 0))

  return centroids, max_centroid, lookup, True
|
|
|
|
|
class _ClusterPreserveInfo(object): |
|
"""ClusterPreserveInfo.""" |
|
|
|
def __init__(self, weight_attrs, quantize_config_attrs): |
|
"""ClusterPreserveInfo. |
|
|
|
Args: |
|
weight_attrs: list of cluster preservable weight attributes of layer. |
|
quantize_config_attrs: list of quantization configuration class name. |
|
""" |
|
self.weight_attrs = weight_attrs |
|
self.quantize_config_attrs = quantize_config_attrs |
|
|
|
|
|
class ClusterPreserveQuantizeRegistry(object):
  """ClusterPreserveQuantizeRegistry is for built-in keras layers."""

  # Maps each supported layer class to the weight attributes that carry
  # clusters and the quantize-config class names that may be patched for it.
  _LAYERS_CONFIG_MAP = {
      layers.Conv2D:
          _ClusterPreserveInfo(['kernel'], ['Default8BitConvQuantizeConfig']),
      layers.Dense:
          _ClusterPreserveInfo(['kernel'], ['Default8BitQuantizeConfig']),
      layers.DepthwiseConv2D:
          _ClusterPreserveInfo(['depthwise_kernel'],
                               ['Default8BitQuantizeConfig']),
  }

  # Layer types that are registered but never get a cluster-preserving
  # weight quantizer attached.
  _DISABLE_CLUSTER_PRESERVE = frozenset({
      layers.DepthwiseConv2D,
  })

  def __init__(self, preserve_sparsity):
    # One quantizer instance per supported quantize-config class name.
    self._config_quantizer_map = {
        'Default8BitQuantizeConfig':
            ClusterPreserveDefault8BitWeightsQuantizer(preserve_sparsity),
        'Default8BitConvQuantizeConfig':
            ClusterPreserveDefault8BitConvWeightsQuantizer(preserve_sparsity),
    }

  @classmethod
  def _no_trainable_weights(cls, layer):
    """Returns whether this layer has trainable weights.

    Args:
      layer: The layer to check for trainable weights.
    Returns:
      True/False whether the layer has trainable weights.
    """
    return not layer.trainable_weights

  @classmethod
  def _disable_cluster_preserve(cls, layer):
    """Returns whether to disable this layer for preserving clusters.

    Args:
      layer: The layer to check for disabling.
    Returns:
      True/False whether disabling this layer for preserving clusters.
    """
    return layer.__class__ in cls._DISABLE_CLUSTER_PRESERVE

  @classmethod
  def supports(cls, layer):
    """Returns whether the registry supports this layer type.

    Args:
      layer: The layer to check for support.
    Returns:
      True/False whether the layer type is supported.
    """
    # Weight-free layers are trivially supported: there is nothing to
    # preserve; otherwise the layer class must be registered.
    return (cls._no_trainable_weights(layer)
            or layer.__class__ in cls._LAYERS_CONFIG_MAP)

  @classmethod
  def _weight_names(cls, layer):
    """Returns cluster-preservable weight attribute names for `layer`."""
    if cls._no_trainable_weights(layer):
      return []
    return cls._LAYERS_CONFIG_MAP[layer.__class__].weight_attrs

  def apply_cluster_preserve_quantize_config(self, layer, quantize_config):
    """Applies cluster-preserve weight quantizer.

    Args:
      layer: The layer to check for support.
      quantize_config: quantization config for supporting cluster preservation
        on clustered weights
    Returns:
      The quantize_config with addon cluster preserve weight_quantizer.
    """
    if not self.supports(layer):
      raise ValueError('Layer ' + str(layer.__class__) + ' is not supported.')

    # Nothing to patch for weight-free or explicitly disabled layers.
    if self._no_trainable_weights(layer) or self._disable_cluster_preserve(
        layer):
      return quantize_config

    config_name = quantize_config.__class__.__name__
    supported_configs = self._LAYERS_CONFIG_MAP[
        layer.__class__].quantize_config_attrs
    if config_name not in supported_configs:
      raise ValueError('Configuration ' +
                       str(config_name) +
                       ' is not supported for Layer ' + str(layer.__class__) +
                       '.')

    # Swap the default weight quantizer for the cluster-preserving one.
    quantize_config.weight_quantizer = self._config_quantizer_map[config_name]
    return quantize_config
|
|
|
|
|
class Default8bitClusterPreserveQuantizeRegistry(
    ClusterPreserveQuantizeRegistry):
  """Default 8 bit ClusterPreserveQuantizeRegistry."""

  def get_quantize_config(self, layer):
    """Returns the quantization config with weight_quantizer for a given layer.

    Args:
      layer: input layer to return quantize config for.
    Returns:
      Returns the quantization config for cluster preserve weight_quantizer.
    """
    # Start from the plain default-8-bit config for this layer, then swap in
    # the cluster-preserving weight quantizer.
    base_registry = default_8bit_quantize_registry.Default8BitQuantizeRegistry()
    base_config = base_registry.get_quantize_config(layer)
    return super(
        Default8bitClusterPreserveQuantizeRegistry,
        self).apply_cluster_preserve_quantize_config(layer, base_config)
|
|
|
|
|
class ClusterPreserveDefaultWeightsQuantizer(quantizers.LastValueQuantizer):
  """Quantize weights while preserving clusters."""

  def __init__(
      self, num_bits, per_axis, symmetric, narrow_range, preserve_sparsity):
    """ClusterPreserveDefaultWeightsQuantizer.

    Args:
      num_bits: Number of bits for quantization
      per_axis: Whether to apply per_axis quantization. The last dimension is
        used as the axis.
      symmetric: If true, use symmetric quantization limits instead of training
        the minimum and maximum of each quantization range separately.
      narrow_range: In case of 8 bits, narrow_range nudges the quantized range
        to be [-127, 127] instead of [-128, 127]. This ensures symmetric
        range has 0 as the centre.
      preserve_sparsity: Whether to apply prune-cluster-preserving quantization
        aware training.
    """
    super(ClusterPreserveDefaultWeightsQuantizer, self).__init__(
        num_bits=num_bits,
        per_axis=per_axis,
        symmetric=symmetric,
        narrow_range=narrow_range,
    )
    self.preserve_sparsity = preserve_sparsity

  def _build_clusters(self, name, layer):
    """Extracts the cluster centroids and cluster indices.

    Extracts cluster centroids and cluster indices from the pretrained
    clustered model when the input layer is clustered.

    Args:
      name: Name of weights in layer.
      layer: Quantization wrapped keras layer.
    Returns:
      A dictionary of the initial values of the
      cluster centroids, cluster indices, original weights,
      the pretrained flag for marking the first training
      epoch, and weight name.
    """
    result = {}
    weights = getattr(layer.layer, name)
    # Sparsity preservation (PCQAT) only makes sense when the layer actually
    # contains pruned (zero) weights; otherwise fall back to plain CQAT.
    if self.preserve_sparsity and not tf.reduce_any(weights == 0):
      self.preserve_sparsity = False
      logging.warning(
          'Input layer does not contain zero weights, so apply CQAT instead.')
    centroids_mask = None

    # `data_format` only exists on convolutional layers; None otherwise.
    data_format = getattr(layer.layer, 'data_format', None)
    centroids, num_centroids, lookup, cluster_per_channel = get_centroids(
        layer, weights, data_format)

    if self.preserve_sparsity:
      # 0/1 mask of the non-zero weights (x/x is 1 for x != 0, else 0).
      sparsity_mask = tf.math.divide_no_nan(weights, weights)
      # Mask out the centroid nearest to zero so it stays pinned at zero.
      zero_idx = tf.argmin(tf.abs(centroids), axis=-1)
      centroids_mask = 1.0 - tf.one_hot(zero_idx, num_centroids)
      result = {SPARSITY_MASK: sparsity_mask}

    # Guard against layers that were not (usefully) clustered: too many
    # distinct values means the layer was never clustered; <= 1 distinct
    # value leaves nothing to cluster. In both cases no clustering
    # variables are created and only `result` (possibly with the sparsity
    # mask) is returned.
    if num_centroids > 1024:
      warnings.warn(f'No clustering performed on layer {layer.name}.\n'
                    f'Too many centroids to cluster.')
      return result

    elif num_centroids <= 1:
      warnings.warn(f'No clustering performed on layer {layer.name}.\n'
                    f'Perhaps too many clusters requested for this layer?')
      return result
    else:
      # Trainable centroid values, initialized from the pretrained model.
      clst_centroids_tf = layer.add_weight(
          CLUSTER_CENTROIDS,
          shape=centroids.shape,
          initializer=keras.initializers.Constant(
              value=K.batch_get_value([centroids])[0]
          ),
          dtype=centroids.dtype,
          trainable=True,
      )

      # Copy of the original (latent) weights, updated during training.
      ori_weights_tf = layer.add_weight(
          ORIGINAL_WEIGHTS,
          shape=weights.shape,
          initializer=keras.initializers.Constant(
              value=K.batch_get_value([weights])[0]
          ),
          dtype=weights.dtype,
          trainable=True,
      )

      # Look up the clustering algorithm implementation for this layer type
      # and weight attribute.
      clustering_impl_cls = clustering_registry.ClusteringLookupRegistry(
      ).get_clustering_impl(
          layer.layer, name, cluster_per_channel=cluster_per_channel)
      clustering_impl = clustering_impl_cls(
          clst_centroids_tf, cluster_config.GradientAggregation.SUM,
          data_format)

      # Per-weight index of the nearest centroid, cast to match `lookup`.
      pulling_indices = tf.dtypes.cast(
          clustering_impl.get_pulling_indices(ori_weights_tf),
          lookup.dtype
      )

      # Non-trainable: the indices are recomputed on every training call.
      pulling_indices_tf = layer.add_weight(
          PULLING_INDICES,
          shape=lookup.shape,
          initializer=keras.initializers.Constant(
              value=K.batch_get_value([pulling_indices])[0]
          ),
          dtype=lookup.dtype,
          trainable=False,
      )

      result_clst = {
          CLUSTER_CENTROIDS: clst_centroids_tf,
          PULLING_INDICES: pulling_indices_tf,
          ORIGINAL_WEIGHTS: ori_weights_tf,
          WEIGHT_NAME: name,
          CLUSTERING_IMPL: clustering_impl,
          CENTROIDS_MASK: centroids_mask,
      }
      result.update(result_clst)
      return result

  def build(self, tensor_shape, name, layer):
    """Build (P)CQAT wrapper.

    When preserve_sparsity is true and the input is clustered.

    Args:
      tensor_shape: Shape of weights which needs to be quantized.
      name: Name of weights in layer.
      layer: Quantization wrapped keras layer.
    Returns:
      Dictionary of centroids, indices and
      quantization params, the dictionary will be passed
      to __call__ function.
    """
    # Clustering variables first, then the standard min/max quantization
    # variables created by LastValueQuantizer.
    result = self._build_clusters(name, layer)

    result.update(
        super(ClusterPreserveDefaultWeightsQuantizer,
              self).build(tensor_shape, name, layer))

    return result

  def __call__(self, inputs, training, weights, **kwargs):
    """Apply cluster preserved quantization to the input tensor.

    Args:
      inputs: Input tensor (layer's weights) to be quantized.
      training: Whether the graph is currently training.
      weights: Dictionary of weights (params) the quantizer can use to
        quantize the tensor (layer's weights). This contains the weights
        created in the `build` function.
      **kwargs: Additional variables which may be passed to the quantizer.
    Returns:
      quantized tensor.
    """
    if training:
      # CLUSTER_CENTROIDS is only present when _build_clusters succeeded,
      # i.e. the layer was actually clustered.
      if CLUSTER_CENTROIDS in weights:
        if self.preserve_sparsity:
          # Re-zero the pruned weights and the zero centroid so sparsity
          # survives the gradient updates applied since the last call.
          weights[ORIGINAL_WEIGHTS].assign(
              tf.multiply(weights[ORIGINAL_WEIGHTS],
                          weights[SPARSITY_MASK]))
          weights[CLUSTERING_IMPL].cluster_centroids.assign(
              weights[CLUSTERING_IMPL].
              cluster_centroids * weights[CENTROIDS_MASK]
          )
          weights[CLUSTER_CENTROIDS].assign(
              weights[CLUSTERING_IMPL].cluster_centroids
          )

        # Re-assign each weight to its nearest (possibly updated) centroid.
        weights[PULLING_INDICES].assign(tf.dtypes.cast(
            weights[CLUSTERING_IMPL].get_pulling_indices(
                weights[ORIGINAL_WEIGHTS]),
            weights[PULLING_INDICES].dtype
        ))

        # Materialize the clustered weights and write them back in place.
        output = weights[CLUSTERING_IMPL].get_clustered_weight(
            weights[PULLING_INDICES], weights[ORIGINAL_WEIGHTS])
        inputs.assign(output)
      else:
        # Not clustered: at most enforce sparsity before quantizing.
        if self.preserve_sparsity:
          inputs = tf.multiply(inputs, weights[SPARSITY_MASK])
        output = inputs
    else:
      # Inference path: quantize the weights as they are.
      output = inputs

    # Standard last-value quantization on top of the (clustered) weights.
    return quant_ops.LastValueQuantize(
        output,
        weights['min_var'],
        weights['max_var'],
        is_training=training,
        num_bits=self.num_bits,
        per_channel=self.per_axis,
        symmetric=self.symmetric,
        narrow_range=self.narrow_range
    )
|
|
|
|
|
class ClusterPreserveDefault8BitWeightsQuantizer(
    ClusterPreserveDefaultWeightsQuantizer):
  """ClusterPreserveWeightsQuantizer for default 8bit weights."""

  def __init__(self, preserve_sparsity):
    """Initializes the quantizer with the default 8-bit settings.

    Args:
      preserve_sparsity: Whether to apply prune-cluster-preserving
        quantization aware training.
    """
    # The parent constructor already stores `preserve_sparsity` on self, so
    # the duplicate assignment the original had after this call is removed.
    super(ClusterPreserveDefault8BitWeightsQuantizer,
          self).__init__(num_bits=8,
                         per_axis=False,
                         symmetric=True,
                         narrow_range=True,
                         preserve_sparsity=preserve_sparsity)
|
|
|
|
|
class ClusterPreserveDefault8BitConvWeightsQuantizer(
    ClusterPreserveDefaultWeightsQuantizer,
    default_8bit_quantizers.Default8BitConvWeightsQuantizer):
  """ClusterPreserveWeightsQuantizer for default 8bit Conv2D weights."""

  def __init__(self, preserve_sparsity):
    """Initializes the conv weights quantizer.

    Args:
      preserve_sparsity: Whether to apply prune-cluster-preserving
        quantization aware training.
    """
    # Deliberately bypasses ClusterPreserveDefaultWeightsQuantizer.__init__:
    # the quantization parameters come from the conv weights quantizer base,
    # and only the sparsity flag is taken from the cluster-preserving side.
    default_8bit_quantizers.Default8BitConvWeightsQuantizer.__init__(self)
    self.preserve_sparsity = preserve_sparsity

  def build(self, tensor_shape, name, layer):
    """Builds both clustering and conv quantization variables.

    Args:
      tensor_shape: Shape of weights which needs to be quantized.
      name: Name of weights in layer.
      layer: Quantization wrapped keras layer.
    Returns:
      Dictionary combining the clustering variables with the variables
      created by the default 8-bit conv weights quantizer.
    """
    # Explicit base-class calls (instead of super()) pick one side of the
    # diamond for each half of the build.
    result = ClusterPreserveDefaultWeightsQuantizer._build_clusters(
        self, name, layer)
    result.update(
        default_8bit_quantizers.Default8BitConvWeightsQuantizer.build(
            self, tensor_shape, name, layer))
    return result
|
|