# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Integration tests for CQAT, PCQAT cases.""" from absl.testing import parameterized import numpy as np import tensorflow as tf from tensorflow_model_optimization.python.core.clustering.keras import cluster from tensorflow_model_optimization.python.core.clustering.keras import cluster_config from tensorflow_model_optimization.python.core.clustering.keras.experimental import cluster as experimental_cluster from tensorflow_model_optimization.python.core.keras.compat import keras from tensorflow_model_optimization.python.core.quantization.keras import quantize from tensorflow_model_optimization.python.core.quantization.keras.collab_opts.cluster_preserve import ( default_8bit_cluster_preserve_quantize_scheme,) from tensorflow_model_optimization.python.core.quantization.keras.collab_opts.cluster_preserve.cluster_utils import ( strip_clustering_cqat,) layers = keras.layers class ClusterPreserveIntegrationTest(tf.test.TestCase, parameterized.TestCase): def setUp(self): super(ClusterPreserveIntegrationTest, self).setUp() self.cluster_params = { 'number_of_clusters': 4, 'cluster_centroids_init': cluster_config.CentroidInitialization.LINEAR } def compile_and_fit(self, model): """Here we compile and fit the model.""" model.compile( loss=keras.losses.categorical_crossentropy, optimizer='adam', metrics=['accuracy'], ) model.fit( np.random.rand(20, 10), keras.utils.to_categorical(np.random.randint(5, size=(20, 1)), 5), batch_size=20, ) def _get_number_of_unique_weights(self, stripped_model, layer_nr, weight_name): layer = stripped_model.layers[layer_nr] if isinstance(layer, quantize.quantize_wrapper.QuantizeWrapper): for weight_item in layer.trainable_weights: if weight_name in weight_item.name: weight = weight_item else: weight = getattr(layer, weight_name) weights_as_list = weight.numpy().flatten() nr_of_unique_weights = len(set(weights_as_list)) return nr_of_unique_weights def _get_sparsity(self, model): sparsity_list = [] for layer in model.layers: for weights in layer.trainable_weights: if 'kernel' in weights.name: np_weights = keras.backend.get_value(weights) sparsity = 1.0 - np.count_nonzero(np_weights) / float( np_weights.size) sparsity_list.append(sparsity) return sparsity_list def _get_clustered_model(self, preserve_sparsity): """Cluster the (sparse) model and return clustered_model.""" tf.random.set_seed(1) original_model = keras.Sequential([ layers.Dense(5, activation='softmax', input_shape=(10,)), layers.Flatten(), ]) # Manually set sparsity in the Dense layer if preserve_sparsity is on if preserve_sparsity: first_layer_weights = original_model.layers[0].get_weights() first_layer_weights[0][:][0:2] = 0.0 original_model.layers[0].set_weights(first_layer_weights) # Start the sparsity-aware clustering clustering_params = { 'number_of_clusters': 4, 'cluster_centroids_init': cluster_config.CentroidInitialization.LINEAR, 'preserve_sparsity': True } clustered_model = experimental_cluster.cluster_weights( original_model, **clustering_params) return clustered_model def _get_conv_model(self, nr_of_channels, data_format=None, kernel_size=(3, 3)): """Returns functional model with Conv2D layer.""" inp = keras.layers.Input(shape=(32, 32), batch_size=100) shape = (1, 32, 32) if data_format == 'channels_first' else (32, 32, 1) x = keras.layers.Reshape(shape)(inp) x = keras.layers.Conv2D( filters=nr_of_channels, kernel_size=kernel_size, data_format=data_format, activation='relu', )(x) x = keras.layers.MaxPool2D(2, 2)(x) out = keras.layers.Flatten()(x) model = keras.Model(inputs=inp, outputs=out) return model def _compile_and_fit_conv_model(self, model, nr_epochs=1): """Compile and fit conv model from _get_conv_model.""" x_train = np.random.uniform(size=(500, 32, 32)) y_train = np.random.randint(low=0, high=1024, size=(500,)) model.compile( optimizer=keras.optimizers.Adam(learning_rate=1e-4), loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=[keras.metrics.SparseCategoricalAccuracy(name='accuracy')], ) model.fit(x_train, y_train, epochs=nr_epochs, batch_size=100, verbose=1) return model def _get_conv_clustered_model(self, nr_of_channels, nr_of_clusters, data_format, preserve_sparsity, kernel_size=(3, 3)): """Returns clustered per channel model with Conv2D layer.""" tf.random.set_seed(42) model = self._get_conv_model(nr_of_channels, data_format, kernel_size) if preserve_sparsity: # Make the convolutional layer sparse by nullifying half of weights assert model.layers[2].name == 'conv2d' conv_layer_weights = model.layers[2].get_weights() shape = conv_layer_weights[0].shape conv_layer_weights_flatten = conv_layer_weights[0].flatten() nr_elems = len(conv_layer_weights_flatten) conv_layer_weights_flatten[0:1 + nr_elems // 2] = 0.0 pruned_conv_layer_weights = tf.reshape(conv_layer_weights_flatten, shape) conv_layer_weights[0] = pruned_conv_layer_weights model.layers[2].set_weights(conv_layer_weights) clustering_params = { 'number_of_clusters': nr_of_clusters, 'cluster_centroids_init': cluster_config.CentroidInitialization.KMEANS_PLUS_PLUS, 'cluster_per_channel': True, 'preserve_sparsity': preserve_sparsity } clustered_model = experimental_cluster.cluster_weights(model, **clustering_params) clustered_model = self._compile_and_fit_conv_model(clustered_model) # Returns un-stripped model return clustered_model def _pcqat_training(self, preserve_sparsity, quant_aware_annotate_model): """PCQAT training on the input model.""" quant_aware_model = quantize.quantize_apply( quant_aware_annotate_model, scheme=default_8bit_cluster_preserve_quantize_scheme .Default8BitClusterPreserveQuantizeScheme(preserve_sparsity)) self.compile_and_fit(quant_aware_model) stripped_pcqat_model = strip_clustering_cqat(quant_aware_model) # Check the unique weights of clustered_model and pcqat_model # layer 0 is the quantize_layer num_of_unique_weights_pcqat = self._get_number_of_unique_weights( stripped_pcqat_model, 1, 'kernel') sparsity_pcqat = self._get_sparsity(stripped_pcqat_model) return sparsity_pcqat, num_of_unique_weights_pcqat def testEndToEndClusterPreserve(self): """Runs CQAT end to end and whole model is quantized.""" original_model = keras.Sequential( [layers.Dense(5, activation='softmax', input_shape=(10,))] ) clustered_model = cluster.cluster_weights( original_model, **self.cluster_params) self.compile_and_fit(clustered_model) clustered_model = cluster.strip_clustering(clustered_model) num_of_unique_weights_clustering = self._get_number_of_unique_weights( clustered_model, 0, 'kernel') quant_aware_annotate_model = ( quantize.quantize_annotate_model(clustered_model)) quant_aware_model = quantize.quantize_apply( quant_aware_annotate_model, scheme=default_8bit_cluster_preserve_quantize_scheme .Default8BitClusterPreserveQuantizeScheme()) self.compile_and_fit(quant_aware_model) stripped_cqat_model = strip_clustering_cqat(quant_aware_model) # Check the unique weights of a certain layer of # clustered_model and pcqat_model num_of_unique_weights_cqat = self._get_number_of_unique_weights( stripped_cqat_model, 1, 'kernel') self.assertAllEqual(num_of_unique_weights_clustering, num_of_unique_weights_cqat) def testEndToEndClusterPreservePerLayer(self): """Runs CQAT end to end and model is quantized per layers.""" original_model = keras.Sequential([ layers.Dense(5, activation='relu', input_shape=(10,)), layers.Dense(5, activation='softmax', input_shape=(10,)), ]) clustered_model = cluster.cluster_weights( original_model, **self.cluster_params) self.compile_and_fit(clustered_model) clustered_model = cluster.strip_clustering(clustered_model) num_of_unique_weights_clustering = self._get_number_of_unique_weights( clustered_model, 1, 'kernel') def apply_quantization_to_dense(layer): if isinstance(layer, keras.layers.Dense): return quantize.quantize_annotate_layer(layer) return layer quant_aware_annotate_model = keras.models.clone_model( clustered_model, clone_function=apply_quantization_to_dense, ) quant_aware_model = quantize.quantize_apply( quant_aware_annotate_model, scheme=default_8bit_cluster_preserve_quantize_scheme .Default8BitClusterPreserveQuantizeScheme()) self.compile_and_fit(quant_aware_model) stripped_cqat_model = strip_clustering_cqat( quant_aware_model) # Check the unique weights of a certain layer of # clustered_model and pcqat_model num_of_unique_weights_cqat = self._get_number_of_unique_weights( stripped_cqat_model, 2, 'kernel') self.assertAllEqual(num_of_unique_weights_clustering, num_of_unique_weights_cqat) def testEndToEndClusterPreserveOneLayer(self): """Runs CQAT end to end and model is quantized only for a single layer.""" original_model = keras.Sequential([ layers.Dense(5, activation='relu', input_shape=(10,)), layers.Dense(5, activation='softmax', input_shape=(10,), name='qat'), ]) clustered_model = cluster.cluster_weights( original_model, **self.cluster_params) self.compile_and_fit(clustered_model) clustered_model = cluster.strip_clustering(clustered_model) num_of_unique_weights_clustering = self._get_number_of_unique_weights( clustered_model, 1, 'kernel') def apply_quantization_to_dense(layer): if isinstance(layer, keras.layers.Dense): if layer.name == 'qat': return quantize.quantize_annotate_layer(layer) return layer quant_aware_annotate_model = keras.models.clone_model( clustered_model, clone_function=apply_quantization_to_dense, ) quant_aware_model = quantize.quantize_apply( quant_aware_annotate_model, scheme=default_8bit_cluster_preserve_quantize_scheme .Default8BitClusterPreserveQuantizeScheme()) self.compile_and_fit(quant_aware_model) stripped_cqat_model = strip_clustering_cqat( quant_aware_model) # Check the unique weights of a certain layer of # clustered_model and pcqat_model num_of_unique_weights_cqat = self._get_number_of_unique_weights( stripped_cqat_model, 1, 'kernel') self.assertAllEqual(num_of_unique_weights_clustering, num_of_unique_weights_cqat) def testEndToEndPruneClusterPreserveQAT(self): """Runs PCQAT end to end when we quantize the whole model.""" preserve_sparsity = True clustered_model = self._get_clustered_model(preserve_sparsity) # Save the kernel weights first_layer_weights = clustered_model.layers[0].weights[1] stripped_model_before_tuning = cluster.strip_clustering( clustered_model) nr_of_unique_weights_before = self._get_number_of_unique_weights( stripped_model_before_tuning, 0, 'kernel') self.compile_and_fit(clustered_model) stripped_model_clustered = cluster.strip_clustering(clustered_model) weights_after_tuning = stripped_model_clustered.layers[0].kernel nr_of_unique_weights_after = self._get_number_of_unique_weights( stripped_model_clustered, 0, 'kernel') # Check after sparsity-aware clustering, despite zero centroid can drift, # the final number of unique weights remains the same self.assertEqual(nr_of_unique_weights_before, nr_of_unique_weights_after) # Check that the zero weights stayed the same before and after tuning. # There might be new weights that become zeros but sparsity-aware # clustering preserves the original zero weights in the original positions # of the weight array self.assertTrue( np.array_equal(first_layer_weights[:][0:2], weights_after_tuning[:][0:2])) # Check sparsity before the input of PCQAT sparsity_pruning = self._get_sparsity(stripped_model_clustered) # PCQAT: when the preserve_sparsity flag is True, the PCQAT should work quant_aware_annotate_model = ( quantize.quantize_annotate_model(stripped_model_clustered) ) # When preserve_sparsity is True in PCQAT, the final sparsity of # the layer stays the same or larger than that of the input layer preserve_sparsity = True sparsity_pcqat, unique_weights_pcqat = self._pcqat_training( preserve_sparsity, quant_aware_annotate_model) self.assertAllGreaterEqual(np.array(sparsity_pcqat), sparsity_pruning[0]) self.assertAllEqual(nr_of_unique_weights_after, unique_weights_pcqat) def testEndToEndClusterPreserveQATClusteredPerChannel( self, data_format='channels_last'): """Runs CQAT end to end for the model that is clustered per channel.""" nr_of_channels = 12 nr_of_clusters = 4 clustered_model = self._get_conv_clustered_model( nr_of_channels, nr_of_clusters, data_format, preserve_sparsity=False) stripped_model = cluster.strip_clustering(clustered_model) # Save the kernel weights conv2d_layer = stripped_model.layers[2] self.assertEqual(conv2d_layer.name, 'conv2d') # should be nr_of_channels * nr_of_clusters nr_unique_weights = -1 for weight in conv2d_layer.weights: if 'kernel' in weight.name: nr_unique_weights = len(np.unique(weight.numpy())) self.assertLessEqual(nr_unique_weights, nr_of_clusters*nr_of_channels) quant_aware_annotate_model = ( quantize.quantize_annotate_model(stripped_model) ) quant_aware_model = quantize.quantize_apply( quant_aware_annotate_model, scheme=default_8bit_cluster_preserve_quantize_scheme .Default8BitClusterPreserveQuantizeScheme()) # Lets train for more epochs to have a chance to scatter clusters model = self._compile_and_fit_conv_model(quant_aware_model, 3) stripped_cqat_model = strip_clustering_cqat(model) # Check the unique weights of a certain layer of # clustered_model and pcqat_model layer_nr = 3 num_of_unique_weights_cqat = self._get_number_of_unique_weights( stripped_cqat_model, layer_nr, 'kernel') self.assertLessEqual(num_of_unique_weights_cqat, nr_unique_weights) # We need to do tighter check: we check that the number of unique # weights per channel is less than the given nr_of_channels layer = stripped_cqat_model.layers[layer_nr] weight_to_check = None if isinstance(layer, quantize.quantize_wrapper.QuantizeWrapper): for weight_item in layer.trainable_weights: if 'kernel' in weight_item.name: weight_to_check = weight_item assert weight_to_check is not None for i in range(nr_of_channels): nr_unique_weights_per_channel = len( np.unique(weight_to_check[:, :, :, i])) assert nr_unique_weights_per_channel == nr_of_clusters def testEndToEndPCQATClusteredPerChannel(self, data_format='channels_last'): """Runs PCQAT end to end for the model that is clustered per channel.""" nr_of_channels = 12 nr_of_clusters = 4 clustered_model = self._get_conv_clustered_model( nr_of_channels, nr_of_clusters, data_format, preserve_sparsity=True) stripped_model = cluster.strip_clustering(clustered_model) # Save the kernel weights conv2d_layer = stripped_model.layers[2] self.assertEqual(conv2d_layer.name, 'conv2d') # should be nr_of_channels * nr_of_clusters nr_unique_weights = -1 for weight in conv2d_layer.weights: if 'kernel' in weight.name: nr_unique_weights = len(np.unique(weight.numpy())) self.assertLessEqual(nr_unique_weights, nr_of_clusters*nr_of_channels) # get sparsity before PCQAT training # we expect that only one value will be returned control_sparsity = self._get_sparsity(stripped_model) self.assertGreater(control_sparsity[0], 0.5) quant_aware_annotate_model = ( quantize.quantize_annotate_model(stripped_model) ) quant_aware_model = quantize.quantize_apply( quant_aware_annotate_model, scheme=default_8bit_cluster_preserve_quantize_scheme .Default8BitClusterPreserveQuantizeScheme()) # Lets train for more epochs to have a chance to scatter clusters model = self._compile_and_fit_conv_model(quant_aware_model, 3) stripped_cqat_model = strip_clustering_cqat(model) # Check the unique weights of a certain layer of # clustered_model and cqat_model layer_nr = 3 num_of_unique_weights_cqat = self._get_number_of_unique_weights( stripped_cqat_model, layer_nr, 'kernel') self.assertLessEqual(num_of_unique_weights_cqat, nr_unique_weights) # We need to do tighter check: we check that the number of unique # weights per channel is less than the given nr_of_channels layer = stripped_cqat_model.layers[layer_nr] weight_to_check = None if isinstance(layer, quantize.quantize_wrapper.QuantizeWrapper): for weight_item in layer.trainable_weights: if 'kernel' in weight_item.name: weight_to_check = weight_item assert weight_to_check is not None for i in range(nr_of_channels): nr_unique_weights_per_channel = len( np.unique(weight_to_check[:, :, :, i])) assert nr_unique_weights_per_channel == nr_of_clusters cqat_sparsity = self._get_sparsity(stripped_cqat_model) self.assertLessEqual(cqat_sparsity[0], control_sparsity[0]) def testEndToEndPCQATClusteredPerChannelConv2d1x1(self, data_format='channels_last' ): """Runs PCQAT for model containing a 1x1 Conv2D. (with insufficient number of weights per channel). Args: data_format: Format of input data. """ nr_of_channels = 12 nr_of_clusters = 4 # Ensure a warning is given to the user that # clustering is not implemented for this layer with self.assertWarnsRegex(Warning, r'Layer conv2d does not have enough weights'): clustered_model = self._get_conv_clustered_model( nr_of_channels, nr_of_clusters, data_format, preserve_sparsity=True, kernel_size=(1, 1)) stripped_model = cluster.strip_clustering(clustered_model) # Save the kernel weights conv2d_layer = stripped_model.layers[2] self.assertEqual(conv2d_layer.name, 'conv2d') for weight in conv2d_layer.weights: if 'kernel' in weight.name: # Original number of unique weights nr_original_weights = len(np.unique(weight.numpy())) self.assertLess(nr_original_weights, nr_of_channels * nr_of_clusters) # Demonstrate unmodified test layer has less weights # than requested clusters for channel in range(nr_of_channels): channel_weights = ( weight[:, channel, :, :] if data_format == 'channels_first' else weight[:, :, :, channel]) nr_channel_weights = len(channel_weights) self.assertGreater(nr_channel_weights, 0) self.assertLessEqual(nr_channel_weights, nr_of_clusters) # get sparsity before PCQAT training # we expect that only one value will be returned control_sparsity = self._get_sparsity(stripped_model) self.assertGreater(control_sparsity[0], 0.5) quant_aware_annotate_model = ( quantize.quantize_annotate_model(stripped_model)) with self.assertWarnsRegex( Warning, r'No clustering performed on layer quant_conv2d'): quant_aware_model = quantize.quantize_apply( quant_aware_annotate_model, scheme=default_8bit_cluster_preserve_quantize_scheme .Default8BitClusterPreserveQuantizeScheme(preserve_sparsity=True)) # Lets train for more epochs to have a chance to scatter clusters model = self._compile_and_fit_conv_model(quant_aware_model, 3) stripped_cqat_model = strip_clustering_cqat(model) # Check the unique weights of a certain layer of # clustered_model and cqat_model, ensuring unchanged layer_nr = 3 num_of_unique_weights_cqat = self._get_number_of_unique_weights( stripped_cqat_model, layer_nr, 'kernel') self.assertEqual(num_of_unique_weights_cqat, nr_original_weights) cqat_sparsity = self._get_sparsity(stripped_cqat_model) self.assertLessEqual(cqat_sparsity[0], control_sparsity[0]) def testPassingNonPrunedModelToPCQAT(self): """Runs PCQAT as CQAT if the input model is not pruned.""" preserve_sparsity = False clustered_model = self._get_clustered_model(preserve_sparsity) clustered_model = cluster.strip_clustering(clustered_model) nr_of_unique_weights_after = self._get_number_of_unique_weights( clustered_model, 0, 'kernel') # Check after plain clustering, if there are no zero weights, # PCQAT falls back to CQAT quant_aware_annotate_model = ( quantize.quantize_annotate_model(clustered_model) ) quant_aware_model = quantize.quantize_apply( quant_aware_annotate_model, scheme=default_8bit_cluster_preserve_quantize_scheme .Default8BitClusterPreserveQuantizeScheme(True)) self.compile_and_fit(quant_aware_model) stripped_pcqat_model = strip_clustering_cqat( quant_aware_model) # Check the unique weights of clustered_model and pcqat_model num_of_unique_weights_pcqat = self._get_number_of_unique_weights( stripped_pcqat_model, 1, 'kernel') self.assertAllEqual(nr_of_unique_weights_after, num_of_unique_weights_pcqat) @parameterized.parameters((0.), (2.)) def testPassingModelWithUniformWeightsToPCQAT(self, uniform_weights): """If pruned_clustered_model has uniform weights, it won't break PCQAT.""" preserve_sparsity = True original_model = keras.Sequential([ layers.Dense(5, activation='softmax', input_shape=(10,)), layers.Flatten(), ]) # Manually set all weights to the same value in the Dense layer first_layer_weights = original_model.layers[0].get_weights() first_layer_weights[0][:] = uniform_weights original_model.layers[0].set_weights(first_layer_weights) # Start the sparsity-aware clustering clustering_params = { 'number_of_clusters': 4, 'cluster_centroids_init': cluster_config.CentroidInitialization.LINEAR, 'preserve_sparsity': True } clustered_model = experimental_cluster.cluster_weights( original_model, **clustering_params) clustered_model = cluster.strip_clustering(clustered_model) nr_of_unique_weights_after = self._get_number_of_unique_weights( clustered_model, 0, 'kernel') sparsity_pruning = self._get_sparsity(clustered_model) quant_aware_annotate_model = ( quantize.quantize_annotate_model(clustered_model) ) sparsity_pcqat, unique_weights_pcqat = self._pcqat_training( preserve_sparsity, quant_aware_annotate_model) self.assertAllGreaterEqual(np.array(sparsity_pcqat), sparsity_pruning[0]) self.assertAllEqual(nr_of_unique_weights_after, unique_weights_pcqat) def testTrainableWeightsBehaveCorrectlyDuringPCQAT(self): """PCQAT zero centroid masks stay the same and trainable variables are updating between epochs.""" preserve_sparsity = True clustered_model = self._get_clustered_model(preserve_sparsity) clustered_model = cluster.strip_clustering(clustered_model) # Apply PCQAT quant_aware_annotate_model = ( quantize.quantize_annotate_model(clustered_model) ) quant_aware_model = quantize.quantize_apply( quant_aware_annotate_model, scheme=default_8bit_cluster_preserve_quantize_scheme .Default8BitClusterPreserveQuantizeScheme(True)) quant_aware_model.compile( loss=keras.losses.categorical_crossentropy, optimizer='adam', metrics=['accuracy'], ) class CheckCentroidsAndTrainableVarsCallback(keras.callbacks.Callback): """Check the updates of trainable variables and centroid masks.""" def on_epoch_begin(self, batch, logs=None): # Check cluster centroids have the zero in the right position vars_dictionary = self.model.layers[1]._weight_vars[0][2] self.centroid_mask = vars_dictionary['centroids_mask'] self.zero_centroid_index_begin = np.where( self.centroid_mask == 0)[0] # Check trainable weights before training self.layer_kernel = ( self.model.layers[1].weights[3].numpy() ) self.original_weight = vars_dictionary['ori_weights_vars_tf'].numpy() self.centroids = vars_dictionary['cluster_centroids_tf'].numpy() def on_epoch_end(self, batch, logs=None): # Check the index of the zero centroids are not changed after training vars_dictionary = self.model.layers[1]._weight_vars[0][2] self.zero_centroid_index_end = np.where( vars_dictionary['centroids_mask'] == 0)[0] assert np.array_equal( self.zero_centroid_index_begin, self.zero_centroid_index_end ) # Check trainable variables after training are updated assert not np.array_equal( self.layer_kernel, self.model.layers[1].weights[3].numpy() ) assert not np.array_equal( self.original_weight, vars_dictionary['ori_weights_vars_tf'].numpy() ) assert not np.array_equal( self.centroids, vars_dictionary['cluster_centroids_tf'].numpy() ) # Use many epochs to verify layer's kernel weights are updating because # they can stay the same after being trained using only the first batch # of data for instance quant_aware_model.fit( np.random.rand(20, 10), keras.utils.to_categorical(np.random.randint(5, size=(20, 1)), 5), steps_per_epoch=5, epochs=3, callbacks=[CheckCentroidsAndTrainableVarsCallback()], ) if __name__ == '__main__': tf.test.main()