# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=missing-docstring,protected-access
"""Train a simple convnet on the MNIST dataset with sparsity 2x4.

It is based on mnist_e2e.py. The convolutional and first dense layers are
wrapped with `prune.prune_low_magnitude` using `sparsity_m_by_n=(2, 4)`.
After training, the pruning wrappers are stripped. The model is then
converted to TFLite, written to /tmp/mnist_2x4.tflite, and evaluated.
"""
from __future__ import print_function

from absl import app as absl_app
import tensorflow as tf

from tensorflow_model_optimization.python.core.keras import test_utils as keras_test_utils
from tensorflow_model_optimization.python.core.keras.compat import keras
from tensorflow_model_optimization.python.core.sparsity.keras import prune
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_callbacks
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_utils
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper

ConstantSparsity = pruning_schedule.ConstantSparsity
l = keras.layers

tf.random.set_seed(42)

batch_size = 128
num_classes = 10
epochs = 1

# Inner layer types whose prunable weights check_model_sparsity_2x4 inspects.
PRUNABLE_2x4_LAYERS = (keras.layers.Conv2D, keras.layers.Dense)


def check_model_sparsity_2x4(model):
  """Return whether every inspected pruned weight passes the 2x4 check.

  Only layers that are `pruning_wrapper.PruneLowMagnitude` wrappers around
  one of `PRUNABLE_2x4_LAYERS` are inspected; every other layer is ignored.
  A model with no such layers therefore passes trivially.

  Args:
    model: A Keras model that still contains its pruning wrappers, i.e. one
      that has not yet been passed through `prune.strip_pruning`.

  Returns:
    True if `pruning_utils.is_pruned_m_by_n` holds for every inspected
    weight. Returns False otherwise; checking stops at the first failing
    weight.
  """
  # all() short-circuits, so this stops at the first non-conforming weight.
  return all(
      pruning_utils.is_pruned_m_by_n(weight)
      for layer in model.layers
      if isinstance(layer, pruning_wrapper.PruneLowMagnitude)
      and isinstance(layer.layer, PRUNABLE_2x4_LAYERS)
      for weight in layer.layer.get_prunable_weights())


def build_layerwise_model(input_shape, **pruning_params):
  """Build the MNIST convnet with selected layers wrapped for pruning.

  The two Conv2D layers and the 1024-unit Dense layer are wrapped with
  `prune.prune_low_magnitude`. The final softmax Dense layer is left
  unpruned.

  Args:
    input_shape: Shape of a single input example, passed to the first
      Conv2D layer.
    **pruning_params: Keyword arguments forwarded unchanged to every
      `prune.prune_low_magnitude` call, for example `pruning_schedule` and
      `sparsity_m_by_n`.

  Returns:
    An uncompiled `keras.Sequential` model.
  """
  return keras.Sequential([
      prune.prune_low_magnitude(
          l.Conv2D(
              32, 5, padding='same', activation='relu', input_shape=input_shape
          ),
          **pruning_params
      ),
      l.MaxPooling2D((2, 2), (2, 2), padding='same'),
      prune.prune_low_magnitude(
          l.Conv2D(64, 5, padding='same'), **pruning_params
      ),
      l.BatchNormalization(),
      l.ReLU(),
      l.MaxPooling2D((2, 2), (2, 2), padding='same'),
      l.Flatten(),
      prune.prune_low_magnitude(
          l.Dense(1024, activation='relu'), **pruning_params
      ),
      l.Dropout(0.4),
      l.Dense(num_classes, activation='softmax'),
  ])


def train(model, x_train, y_train, x_test, y_test, log_dir='/tmp/logs'):
  """Compile, fit and evaluate `model`, then strip its pruning wrappers.

  The loss and accuracy on the test data are printed. The 2x4 sparsity
  check is run and its result printed while the pruning wrappers are still
  present, i.e. before stripping.

  Args:
    model: A model built with pruning wrappers, e.g. by
      `build_layerwise_model`.
    x_train: Training inputs.
    y_train: Training labels, used with categorical cross-entropy.
    x_test: Test inputs, also used as validation data during `fit`.
    y_test: Test labels.
    log_dir: Directory passed to `pruning_callbacks.PruningSummaries` for
      TensorBoard summaries. Defaults to '/tmp/logs', the value this script
      previously hard-coded.

  Returns:
    The trained model after `prune.strip_pruning`.
  """
  model.compile(
      loss=keras.losses.categorical_crossentropy,
      optimizer='adam',
      metrics=['accuracy'],
  )
  # NOTE(review): eager execution is slower than graph mode; presumably it is
  # enabled on purpose for the pruning callbacks/summaries — confirm.
  model.run_eagerly = True

  # Print the model summary.
  model.summary()

  # Add a pruning step callback to peg the pruning step to the optimizer's
  # step. Also add a callback to add pruning summaries to tensorboard.
  callbacks = [
      pruning_callbacks.UpdatePruningStep(),
      pruning_callbacks.PruningSummaries(log_dir=log_dir),
  ]

  model.fit(
      x_train,
      y_train,
      batch_size=batch_size,
      epochs=epochs,
      verbose=1,
      callbacks=callbacks,
      validation_data=(x_test, y_test))
  score = model.evaluate(x_test, y_test, verbose=0)
  print('Test loss:', score[0])
  print('Test accuracy:', score[1])

  # Check sparsity 2x4 type before stripping pruning; after stripping, the
  # PruneLowMagnitude wrappers the check looks for no longer exist.
  is_pruned_2x4 = check_model_sparsity_2x4(model)
  print('Pass the check for sparsity 2x4: ', is_pruned_2x4)

  model = prune.strip_pruning(model)
  return model


def main(unused_argv):
  """Train a 2x4-sparse MNIST model, export it to TFLite and evaluate it."""
  ##############################################################################
  # Prepare training and testing data
  ##############################################################################
  (x_train, y_train), (
      x_test, y_test), input_shape = keras_test_utils.get_preprocessed_mnist_data()

  ##############################################################################
  # Train a model with sparsity 2x4.
  ##############################################################################
  pruning_params = {
      'pruning_schedule': ConstantSparsity(0.5, begin_step=0, frequency=100),
      'sparsity_m_by_n': (2, 4),
  }

  model = build_layerwise_model(input_shape, **pruning_params)
  pruned_model = train(model, x_train, y_train, x_test, y_test)

  # Write a model that has been pruned with 2x4 sparsity.
  converter = tf.lite.TFLiteConverter.from_keras_model(pruned_model)
  tflite_model = converter.convert()

  tflite_model_path = '/tmp/mnist_2x4.tflite'
  with open(tflite_model_path, 'wb') as f:
    f.write(tflite_model)
  # Report the path only once the write has actually succeeded.
  print('model is saved to {}'.format(tflite_model_path))

  print('evaluate pruned model: ')
  print(keras_test_utils.eval_mnist_tflite(model_content=tflite_model))
  # the accuracy of 2:4 pruning model is 0.9866
  # the accuracy of unstructured model with 50% is 0.9863


if __name__ == '__main__':
  absl_app.run(main)