Mixtral_ether / cluster_preserve_quantize_registry.py
jeduardogruiz's picture
Upload 22 files
516a027 verified
raw
history blame
19.2 kB
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Registry responsible for built-in keras classes."""
import logging
import warnings
import tensorflow as tf
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
from tensorflow_model_optimization.python.core.clustering.keras import clustering_registry
from tensorflow_model_optimization.python.core.keras.compat import keras
from tensorflow_model_optimization.python.core.quantization.keras import quant_ops
from tensorflow_model_optimization.python.core.quantization.keras import quantizers
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantizers
# Short aliases for the keras submodules referenced in this file.
layers = keras.layers
K = keras.backend
# Names given to the variables that the cluster-preserving quantizer adds to a
# layer; the same strings are used as keys in the dictionary the quantizer's
# `build` returns.
CLUSTER_CENTROIDS = 'cluster_centroids_tf'
PULLING_INDICES = 'pulling_indices_tf'
ORIGINAL_WEIGHTS = 'ori_weights_vars_tf'
# Further keys of that dictionary: the name of the clustered weight attribute,
# the clustering implementation object, and the two masks that are only
# populated when sparsity is being preserved.
WEIGHT_NAME = 'weight_name'
CLUSTERING_IMPL = 'clst_impl'
CENTROIDS_MASK = 'centroids_mask'
SPARSITY_MASK = 'sparsity_mask'
def get_unique(t):
  """Return the distinct values of a tensor and a same-shaped index map.

  The tensor is flattened, de-duplicated with `tf.unique`, and the resulting
  per-element indices are reshaped back to the shape of the input.

  Args:
    t: N-D tensor.

  Returns:
    A pair `(centroids, lookup)`: `centroids` holds the unique values and
    `lookup` has the shape of `t`, each entry being the position in
    `centroids` of the corresponding element of `t`.

  Example:
    t:
      [[1.0, 2.0],
       [2.0, 3.0],
       [3.0, 3.0],
       [1.0, 2.0]]
    centroids:
      [1.0, 2.0, 3.0]
    lookup:
      [[0, 1],
       [1, 2],
       [2, 2],
       [0, 1]]
  """
  flat_values = tf.reshape(t, shape=(-1,))
  centroids, flat_lookup = tf.unique(flat_values)
  lookup = tf.reshape(flat_lookup, shape=tf.shape(t))
  return centroids, lookup
def get_centroids(layer, weight, data_format):
  """Gets centroid infos from the weights of a layer.

  Per-channel extraction is attempted only when `layer.layer` is a keras
  `Conv2D`; any other layer gets one set of centroids computed over the whole
  weight tensor. The per-channel attempt is also abandoned, in favour of
  whole-tensor centroids, when the per-channel centroid counts differ by more
  than one.

  Args:
    layer: Wrapper object whose `layer` attribute is the Keras layer the
      weight belongs to.
    weight: The weight tensor to get the centroids info from. The per-channel
      path indexes it with four axes.
    data_format: string to indicate format: "channels_first" or
      "channels_last". It selects the channel axis of `weight` (axis 1 for
      "channels_first", the last axis otherwise).

  Returns:
    A 4-tuple of centroids (unique values), number of centroids, lookup index,
    whether to cluster per channel (boolean).
  """
  cluster_per_channel = layer.layer and isinstance(
      layer.layer, keras.layers.Conv2D
  )

  if not cluster_per_channel:
    centroids, index = get_unique(weight)
    return centroids, tf.size(centroids), index, False

  # In case of cluster_per_channel we need to extract
  # unique values (centroids) for each channel.
  channels_first = data_format == 'channels_first'
  num_channels = weight.shape[1 if channels_first else -1]
  channel_centroids = []
  channel_indices = []
  num_centroids = []

  for channel in range(num_channels):
    # Slice along the same axis that `num_channels` was read from. The
    # transpose at the end of this function moves the stacked channel axis
    # back to that position, so the lookup only has the shape of `weight`
    # when the slices are taken from that axis.
    if channels_first:
      channel_weights = weight[:, channel, :, :]
    else:
      channel_weights = weight[:, :, :, channel]
    centroids, indices = get_unique(channel_weights)

    channel_centroids.append(centroids)
    channel_indices.append(indices)
    num_centroids.append(tf.size(centroids))

  max_centroid = max(num_centroids)
  max_diff = max_centroid - min(num_centroids)

  # Channels that disagree by more than one centroid are not padded; fall
  # back to one set of centroids for the whole tensor instead.
  if max_diff > 1:
    centroids, index = get_unique(weight)
    return centroids, tf.size(centroids), index, False

  # Pad the shorter channels with ones so that all rows have equal length and
  # can be stacked into one tensor. The padding takes the centroids' dtype
  # because `tf.concat` needs matching dtypes and `tf.ones` would otherwise
  # produce float32.
  for i, centroid in enumerate(channel_centroids):
    if num_centroids[i] != max_centroid:
      one_padding = tf.ones(
          [max_centroid - num_centroids[i]], dtype=centroid.dtype)
      channel_centroids[i] = tf.concat([centroid, one_padding], 0)

  centroids = tf.convert_to_tensor(channel_centroids)
  lookup = tf.convert_to_tensor(channel_indices)

  # Stacking put the channel axis first; move it back to where it sits in
  # `weight`.
  lookup = tf.transpose(
      lookup,
      perm=(1, 0, 2, 3) if channels_first else (1, 2, 3, 0))

  return centroids, max_centroid, lookup, True
class _ClusterPreserveInfo:
  """Pairs a layer type's clusterable weights with its supported configs."""

  def __init__(self, weight_attrs, quantize_config_attrs):
    """Stores both lists unchanged on the instance.

    Args:
      weight_attrs: list of cluster preservable weight attributes of layer.
      quantize_config_attrs: list of quantization configuration class names.
    """
    self.weight_attrs, self.quantize_config_attrs = (
        weight_attrs,
        quantize_config_attrs,
    )
class ClusterPreserveQuantizeRegistry:
  """Registry of the built-in keras layers that can preserve clusters."""

  # Keys are built-in keras layer classes. Each value holds, first, the names
  # of the layer attributes containing the kernel weights and, second, the
  # class names of the quantization configurations accepted for that layer.
  # Together they decide for which layer/configuration pairs the weights are
  # cluster preservable.
  _LAYERS_CONFIG_MAP = {
      layers.Conv2D: _ClusterPreserveInfo(
          ['kernel'], ['Default8BitConvQuantizeConfig']),
      layers.Dense: _ClusterPreserveInfo(
          ['kernel'], ['Default8BitQuantizeConfig']),

      # DepthwiseConv2D is supported with 8bit qat, but not with
      # clustering, thus for DepthwiseConv2D CQAT,
      # preserving clustered weights is disabled.
      layers.DepthwiseConv2D: _ClusterPreserveInfo(
          ['depthwise_kernel'], ['Default8BitQuantizeConfig']),

      # layers that are supported with clustering, but not yet with qat
      # layers.Conv1D:
      # _ClusterPreserveInfo(['kernel'], []),
      # layers.Conv2DTranspose:
      # _ClusterPreserveInfo(['kernel'], []),
      # layers.Conv3D:
      # _ClusterPreserveInfo(['kernel'], []),
      # layers.Conv3DTranspose:
      # _ClusterPreserveInfo(['kernel'], []),
      # layers.LocallyConnected1D:
      # _ClusterPreserveInfo(['kernel'], ['Default8BitQuantizeConfig']),
      # layers.LocallyConnected2D:
      # _ClusterPreserveInfo(['kernel'], ['Default8BitQuantizeConfig']),

      # SeparableConv need verify from 8bit qat
      # layers.SeparableConv1D:
      # _ClusterPreserveInfo(['pointwise_kernel'],
      # ['Default8BitConvQuantizeConfig']),
      # layers.SeparableConv2D:
      # _ClusterPreserveInfo(['pointwise_kernel'],
      # ['Default8BitConvQuantizeConfig']),

      # Embedding need verify from 8bit qat
      # layers.Embedding: _ClusterPreserveInfo(['embeddings'], []),
  }

  # Layer classes that are registered above but for which the quantize config
  # is handed back without a cluster-preserving weight quantizer.
  _DISABLE_CLUSTER_PRESERVE = frozenset({
      layers.DepthwiseConv2D,
  })

  def __init__(self, preserve_sparsity):
    """Creates one weight quantizer per supported configuration class name.

    Args:
      preserve_sparsity: forwarded to both cluster-preserving quantizers.
    """
    dense_quantizer = ClusterPreserveDefault8BitWeightsQuantizer(
        preserve_sparsity)
    conv_quantizer = ClusterPreserveDefault8BitConvWeightsQuantizer(
        preserve_sparsity)
    self._config_quantizer_map = {
        'Default8BitQuantizeConfig': dense_quantizer,
        'Default8BitConvQuantizeConfig': conv_quantizer,
    }

  @classmethod
  def _no_trainable_weights(cls, layer):
    """Returns whether this layer has no trainable weights.

    Args:
      layer: The layer to check for trainable weights.

    Returns:
      True when `layer.trainable_weights` is empty, False otherwise.
    """
    return not layer.trainable_weights

  @classmethod
  def _disable_cluster_preserve(cls, layer):
    """Returns whether cluster preservation is switched off for this layer.

    Args:
      layer: The layer to check for disabling.

    Returns:
      True/False whether the layer's class is listed as disabled.
    """
    return layer.__class__ in cls._DISABLE_CLUSTER_PRESERVE

  @classmethod
  def supports(cls, layer):
    """Returns whether the registry supports this layer type.

    Args:
      layer: The layer to check for support.

    Returns:
      True/False whether the layer type is supported.
    """
    # Layers without trainable weights (e.g., ReLU, Softmax, and
    # AveragePooling2D) count as supported; anything else must be registered.
    return (cls._no_trainable_weights(layer)
            or layer.__class__ in cls._LAYERS_CONFIG_MAP)

  @classmethod
  def _weight_names(cls, layer):
    """Returns the registered weight attribute names for the layer.

    Args:
      layer: The layer to look up.

    Returns:
      An empty list for a layer without trainable weights, otherwise the
      `weight_attrs` registered for the layer's class.
    """
    return ([] if cls._no_trainable_weights(layer)
            else cls._LAYERS_CONFIG_MAP[layer.__class__].weight_attrs)

  def apply_cluster_preserve_quantize_config(self, layer, quantize_config):
    """Applies cluster-preserve weight quantizer.

    Args:
      layer: The layer the quantize config belongs to.
      quantize_config: quantization config for supporting cluster preservation
        on clustered weights

    Returns:
      The quantize_config, with its weight_quantizer replaced by the
      cluster-preserving one unless the layer has no trainable weights or has
      cluster preservation disabled.

    Raises:
      ValueError: if the layer is not supported, or if the class of
        quantize_config is not registered for the layer's class.
    """
    layer_cls = layer.__class__
    if not self.supports(layer):
      raise ValueError(f'Layer {layer_cls} is not supported.')

    # Example: ReLU, Softmax, and AveragePooling2D (without trainable weights)
    # DepthwiseConv2D (cluster_preserve is disabled)
    untouched = (self._no_trainable_weights(layer)
                 or self._disable_cluster_preserve(layer))
    if untouched:
      return quantize_config

    # Example: Conv2D, Dense layers
    config_name = quantize_config.__class__.__name__
    accepted_names = self._LAYERS_CONFIG_MAP[layer_cls].quantize_config_attrs
    if config_name not in accepted_names:
      raise ValueError(
          f'Configuration {config_name} is not supported for Layer '
          f'{layer_cls}.')

    quantize_config.weight_quantizer = self._config_quantizer_map[config_name]
    return quantize_config
class Default8bitClusterPreserveQuantizeRegistry(
    ClusterPreserveQuantizeRegistry):
  """Default 8 bit ClusterPreserveQuantizeRegistry."""

  def get_quantize_config(self, layer):
    """Returns the quantization config with weight_quantizer for a given layer.

    The config is first obtained from a fresh `Default8BitQuantizeRegistry`
    and then passed through the parent class's
    `apply_cluster_preserve_quantize_config`.

    Args:
      layer: input layer to return quantize config for.

    Returns:
      The quantization config for cluster preserve weight_quantizer.
    """
    default_registry = (
        default_8bit_quantize_registry.Default8BitQuantizeRegistry())
    default_config = default_registry.get_quantize_config(layer)
    return super().apply_cluster_preserve_quantize_config(
        layer, default_config)
class ClusterPreserveDefaultWeightsQuantizer(quantizers.LastValueQuantizer):
  """Quantize weights while preserving clusters.

  `preserve_sparsity` is the request made at construction time. Whether
  sparsity is really preserved is decided separately for each layer in
  `_build_clusters` and recorded by the presence of `SPARSITY_MASK` in the
  dictionary returned from `build`. The instance itself is never modified by
  building a layer, so one quantizer can serve several layers without the
  outcome for one layer changing the behavior for another.
  """

  def __init__(
      self, num_bits, per_axis, symmetric, narrow_range, preserve_sparsity):
    """ClusterPreserveDefaultWeightsQuantizer.

    Args:
      num_bits: Number of bits for quantization
      per_axis: Whether to apply per_axis quantization. The last dimension is
        used as the axis.
      symmetric: If true, use symmetric quantization limits instead of training
        the minimum and maximum of each quantization range separately.
      narrow_range: In case of 8 bits, narrow_range nudges the quantized range
        to be [-127, 127] instead of [-128, 127]. This ensures symmetric
        range has 0 as the centre.
      preserve_sparsity: Whether to apply prune-cluster-preserving quantization
        aware training.
    """
    super(ClusterPreserveDefaultWeightsQuantizer, self).__init__(
        num_bits=num_bits,
        per_axis=per_axis,
        symmetric=symmetric,
        narrow_range=narrow_range,
    )
    self.preserve_sparsity = preserve_sparsity

  def _build_clusters(self, name, layer):
    """Extracts the cluster centroids and cluster indices.

    Extracts cluster centroids and cluster indices from the pretrained
    clustered model when the input layer is clustered.

    Args:
      name: Name of weights in layer.
      layer: Quantization wrapped keras layer.

    Returns:
      A dictionary. It contains the sparsity mask when sparsity is preserved
      for this layer. When the number of centroids is greater than 1 and at
      most 1024 it also contains the cluster centroids, pulling indices,
      original weights, weight name, clustering implementation and centroids
      mask (None unless sparsity is preserved).
    """
    result = {}
    weights = getattr(layer.layer, name)

    # The decision is kept in a local. Writing it back to
    # `self.preserve_sparsity` would switch sparsity preservation off for
    # every other layer that uses this same quantizer instance.
    preserve_sparsity = self.preserve_sparsity
    if preserve_sparsity and not tf.reduce_any(weights == 0):
      preserve_sparsity = False
      logging.warning(
          'Input layer does not contain zero weights, so apply CQAT instead.')
    centroids_mask = None

    # Detects whether layer is convolutional and is clustered per channel
    data_format = getattr(layer.layer, 'data_format', None)
    centroids, num_centroids, lookup, cluster_per_channel = get_centroids(
        layer, weights, data_format)

    if preserve_sparsity:
      # divide_no_nan(w, w) is 1 where the weight is non-zero and 0 where it
      # is zero.
      sparsity_mask = tf.math.divide_no_nan(weights, weights)
      # Mask out the centroid with the smallest magnitude.
      zero_idx = tf.argmin(tf.abs(centroids), axis=-1)
      centroids_mask = 1.0 - tf.one_hot(zero_idx, num_centroids)
      result = {SPARSITY_MASK: sparsity_mask}

    # Prepare clustering variables for the Keras graph when clusters
    # exist, assuming we do not use number_of_clusters larger than 1024
    if num_centroids > 1024:
      warnings.warn(f'No clustering performed on layer {layer.name}.\n'
                    f'Too many centroids to cluster.')
      return result
    # If not enough clusters, we do not preserve clustering
    elif num_centroids <= 1:
      warnings.warn(f'No clustering performed on layer {layer.name}.\n'
                    f'Perhaps too many clusters requested for this layer?')
      return result
    else:
      clst_centroids_tf = layer.add_weight(
          CLUSTER_CENTROIDS,
          shape=centroids.shape,
          initializer=keras.initializers.Constant(
              value=K.batch_get_value([centroids])[0]
          ),
          dtype=centroids.dtype,
          trainable=True,
      )

      ori_weights_tf = layer.add_weight(
          ORIGINAL_WEIGHTS,
          shape=weights.shape,
          initializer=keras.initializers.Constant(
              value=K.batch_get_value([weights])[0]
          ),
          dtype=weights.dtype,
          trainable=True,
      )

      # Get clustering implementation according to layer type
      clustering_impl_cls = clustering_registry.ClusteringLookupRegistry(
      ).get_clustering_impl(
          layer.layer, name, cluster_per_channel=cluster_per_channel)
      clustering_impl = clustering_impl_cls(
          clst_centroids_tf, cluster_config.GradientAggregation.SUM,
          data_format)

      pulling_indices = tf.dtypes.cast(
          clustering_impl.get_pulling_indices(ori_weights_tf),
          lookup.dtype
      )

      pulling_indices_tf = layer.add_weight(
          PULLING_INDICES,
          shape=lookup.shape,
          initializer=keras.initializers.Constant(
              value=K.batch_get_value([pulling_indices])[0]
          ),
          dtype=lookup.dtype,
          trainable=False,
      )

      result_clst = {
          CLUSTER_CENTROIDS: clst_centroids_tf,
          PULLING_INDICES: pulling_indices_tf,
          ORIGINAL_WEIGHTS: ori_weights_tf,
          WEIGHT_NAME: name,
          CLUSTERING_IMPL: clustering_impl,
          CENTROIDS_MASK: centroids_mask,
      }
      result.update(result_clst)
      return result

  def build(self, tensor_shape, name, layer):
    """Build (P)CQAT wrapper.

    When preserve_sparsity is true and the input is clustered.

    Args:
      tensor_shape: Shape of weights which needs to be quantized.
      name: Name of weights in layer.
      layer: Quantization wrapped keras layer.

    Returns:
      Dictionary of centroids, indices and
      quantization params, the dictionary will be passed
      to __call__ function.
    """
    # To get all the initial values from pretrained clustered model
    result = self._build_clusters(name, layer)
    # Result can have clustering nodes, then this is CQAT
    # Result can have both clustering nodes and sparsity mask, then
    # this will be PCQAT
    result.update(
        super(ClusterPreserveDefaultWeightsQuantizer,
              self).build(tensor_shape, name, layer))

    return result

  def __call__(self, inputs, training, weights, **kwargs):
    """Apply cluster preserved quantization to the input tensor.

    Args:
      inputs: Input tensor (layer's weights) to be quantized.
      training: Whether the graph is currently training.
      weights: Dictionary of weights (params) the quantizer can use to
        quantize the tensor (layer's weights). This contains the weights
        created in the `build` function.
      **kwargs: Additional variables which may be passed to the quantizer.

    Returns:
      quantized tensor.
    """
    if training:
      # Sparsity is preserved only for layers whose `build` produced a
      # sparsity mask, which makes the decision a per-layer one even when
      # this quantizer instance is shared.
      preserve_sparsity = self.preserve_sparsity and SPARSITY_MASK in weights
      if CLUSTER_CENTROIDS in weights:
        if preserve_sparsity:
          weights[ORIGINAL_WEIGHTS].assign(
              tf.multiply(weights[ORIGINAL_WEIGHTS],
                          weights[SPARSITY_MASK]))
          weights[CLUSTERING_IMPL].cluster_centroids.assign(
              weights[CLUSTERING_IMPL].
              cluster_centroids * weights[CENTROIDS_MASK]
          )
          weights[CLUSTER_CENTROIDS].assign(
              weights[CLUSTERING_IMPL].cluster_centroids
          )
        # Insert clustering variables
        weights[PULLING_INDICES].assign(tf.dtypes.cast(
            weights[CLUSTERING_IMPL].get_pulling_indices(
                weights[ORIGINAL_WEIGHTS]),
            weights[PULLING_INDICES].dtype
        ))

        output = weights[CLUSTERING_IMPL].get_clustered_weight(
            weights[PULLING_INDICES], weights[ORIGINAL_WEIGHTS])
        inputs.assign(output)
      else:
        if preserve_sparsity:
          inputs = tf.multiply(inputs, weights[SPARSITY_MASK])
        output = inputs
    else:
      output = inputs

    return quant_ops.LastValueQuantize(
        output,
        weights['min_var'],
        weights['max_var'],
        is_training=training,
        num_bits=self.num_bits,
        per_channel=self.per_axis,
        symmetric=self.symmetric,
        narrow_range=self.narrow_range
    )
class ClusterPreserveDefault8BitWeightsQuantizer(
    ClusterPreserveDefaultWeightsQuantizer):
  """ClusterPreserveWeightsQuantizer for default 8bit weights."""

  def __init__(self, preserve_sparsity):
    """Configures the parent quantizer with fixed 8-bit settings.

    Args:
      preserve_sparsity: Whether to apply prune-cluster-preserving quantization
        aware training.
    """
    fixed_settings = dict(
        num_bits=8,
        per_axis=False,
        symmetric=True,
        narrow_range=True,
    )
    super().__init__(preserve_sparsity=preserve_sparsity, **fixed_settings)
    self.preserve_sparsity = preserve_sparsity
class ClusterPreserveDefault8BitConvWeightsQuantizer(
    ClusterPreserveDefaultWeightsQuantizer,
    default_8bit_quantizers.Default8BitConvWeightsQuantizer):
  """ClusterPreserveWeightsQuantizer for default 8bit Conv2D weights."""

  def __init__(self, preserve_sparsity):  # pylint: disable=super-init-not-called
    """Initializes through Default8BitConvWeightsQuantizer only.

    Args:
      preserve_sparsity: Whether to apply prune-cluster-preserving quantization
        aware training.
    """
    conv_quantizer_cls = default_8bit_quantizers.Default8BitConvWeightsQuantizer
    conv_quantizer_cls.__init__(self)
    self.preserve_sparsity = preserve_sparsity

  def build(self, tensor_shape, name, layer):
    """Builds clustering variables, then the conv quantizer's variables.

    Args:
      tensor_shape: Shape of weights which needs to be quantized.
      name: Name of weights in layer.
      layer: Quantization wrapped keras layer.

    Returns:
      The dictionary from `_build_clusters`, updated with the dictionary
      returned by `Default8BitConvWeightsQuantizer.build`.
    """
    cluster_vars = ClusterPreserveDefaultWeightsQuantizer._build_clusters(
        self, name, layer)
    conv_quantizer_cls = default_8bit_quantizers.Default8BitConvWeightsQuantizer
    quantizer_vars = conv_quantizer_cls.build(self, tensor_shape, name, layer)
    cluster_vars.update(quantizer_vars)
    return cluster_vars