# Mixtral_ether / mnist_e2e_sparsity2x4.py
# jeduardogruiz's picture
# Upload 22 files
# 516a027 verified
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=missing-docstring,protected-access
"""Train a simple convnet on the MNIST dataset with sparsity 2x4.
It is based on mnist_e2e.py
"""
from __future__ import print_function
from absl import app as absl_app
import tensorflow as tf
from tensorflow_model_optimization.python.core.keras import test_utils as keras_test_utils
from tensorflow_model_optimization.python.core.keras.compat import keras
from tensorflow_model_optimization.python.core.sparsity.keras import prune
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_callbacks
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_utils
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper
# Short alias for the constant-sparsity pruning schedule class used in `main`.
ConstantSparsity = pruning_schedule.ConstantSparsity
# Short alias for the Keras layers namespace used when building the model.
l = keras.layers
# Fix the TensorFlow random seed at import time.
tf.random.set_seed(42)
# `batch_size` and `epochs` are passed to `model.fit` in `train`;
# `num_classes` is the unit count of the final Dense layer.
batch_size = 128
num_classes = 10
epochs = 1
# Wrapped layer types that `check_model_sparsity_2x4` inspects.
PRUNABLE_2x4_LAYERS = (keras.layers.Conv2D, keras.layers.Dense)
def check_model_sparsity_2x4(model):
  """Report whether the pruned layers of `model` pass the m-by-n sparsity check.

  Only layers that are `PruneLowMagnitude` wrappers around a layer whose type
  is listed in `PRUNABLE_2x4_LAYERS` are inspected; every other layer is
  skipped. Each prunable weight of an inspected layer is handed to
  `pruning_utils.is_pruned_m_by_n` with that helper's default arguments.

  Args:
    model: Object exposing a `layers` iterable (a Keras model in this script).

  Returns:
    False as soon as one weight fails the check, True otherwise (including
    the case where no layer qualifies for inspection).
  """
  for candidate in model.layers:
    # Guard clauses: ignore anything that is not a pruning wrapper around one
    # of the supported layer types.
    if not isinstance(candidate, pruning_wrapper.PruneLowMagnitude):
      continue
    wrapped = candidate.layer
    if not isinstance(wrapped, PRUNABLE_2x4_LAYERS):
      continue
    # `all` stops at the first failing weight, like an early return would.
    if not all(
        pruning_utils.is_pruned_m_by_n(kernel)
        for kernel in wrapped.get_prunable_weights()
    ):
      return False
  return True
def build_layerwise_model(input_shape, **pruning_params):
  """Assemble the sequential convnet with pruning wrappers on selected layers.

  Both Conv2D layers and the 1024-unit Dense layer are wrapped with
  `prune.prune_low_magnitude`; the final classification Dense layer is left
  unwrapped.

  Args:
    input_shape: Forwarded as `input_shape` to the first Conv2D layer.
    **pruning_params: Keyword arguments passed unchanged to every
      `prune.prune_low_magnitude` call.

  Returns:
    The resulting `keras.Sequential` model. It is not compiled here.
  """

  def _with_pruning(inner_layer):
    # Single place where the pruning configuration is applied to a layer.
    return prune.prune_low_magnitude(inner_layer, **pruning_params)

  # The list is evaluated top to bottom, so layers are created in network order.
  stack = [
      _with_pruning(
          l.Conv2D(
              filters=32,
              kernel_size=5,
              padding='same',
              activation='relu',
              input_shape=input_shape,
          )
      ),
      l.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same'),
      _with_pruning(l.Conv2D(filters=64, kernel_size=5, padding='same')),
      l.BatchNormalization(),
      l.ReLU(),
      l.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same'),
      l.Flatten(),
      _with_pruning(l.Dense(units=1024, activation='relu')),
      l.Dropout(rate=0.4),
      l.Dense(units=num_classes, activation='softmax'),
  ]
  return keras.Sequential(stack)
def train(model, x_train, y_train, x_test, y_test, log_dir='/tmp/logs'):
  """Compile, train, evaluate and strip a model built with pruning wrappers.

  The model is compiled with categorical cross-entropy, the Adam optimizer and
  an accuracy metric, then fitted using the module-level `batch_size` and
  `epochs`. After evaluation, the 2x4 sparsity check is run on the still
  wrapped model and its result printed, and finally the pruning wrappers are
  removed.

  Args:
    model: Keras model to train; it is compiled and fitted in place.
    x_train: Training inputs passed to `model.fit`.
    y_train: Training targets passed to `model.fit`.
    x_test: Inputs used both as validation data and for the final evaluation.
    y_test: Targets used both as validation data and for the final evaluation.
    log_dir: Directory handed to the `PruningSummaries` callback. Defaults to
      the previously hard-coded '/tmp/logs'.

  Returns:
    The model returned by `prune.strip_pruning`.
  """
  model.compile(
      loss=keras.losses.categorical_crossentropy,
      optimizer='adam',
      metrics=['accuracy'],
  )
  # Set after compile() so the flag applies to the compiled model.
  model.run_eagerly = True

  # Print the model summary.
  model.summary()

  # Add a pruning step callback to peg the pruning step to the optimizer's
  # step. Also add a callback to add pruning summaries to tensorboard.
  callbacks = [
      pruning_callbacks.UpdatePruningStep(),
      pruning_callbacks.PruningSummaries(log_dir=log_dir),
  ]

  model.fit(
      x_train,
      y_train,
      batch_size=batch_size,
      epochs=epochs,
      verbose=1,
      callbacks=callbacks,
      validation_data=(x_test, y_test),
  )

  # With a single 'accuracy' metric, evaluate() yields [loss, accuracy].
  score = model.evaluate(x_test, y_test, verbose=0)
  print('Test loss:', score[0])
  print('Test accuracy:', score[1])

  # Check sparsity 2x4 type before stripping pruning: the check looks for the
  # pruning wrappers, which strip_pruning removes.
  is_pruned_2x4 = check_model_sparsity_2x4(model)
  print('Pass the check for sparsity 2x4: ', is_pruned_2x4)

  return prune.strip_pruning(model)
def main(unused_argv):
  """Train the 2x4-sparse MNIST model, export it to TFLite and evaluate it.

  Args:
    unused_argv: Positional arguments supplied by `absl_app.run`; ignored.
  """
  ##############################################################################
  # Prepare training and testing data
  ##############################################################################
  (x_train, y_train), (x_test, y_test), input_shape = (
      keras_test_utils.get_preprocessed_mnist_data()
  )

  ##############################################################################
  # Train a model with sparsity 2x4.
  ##############################################################################
  pruning_params = {
      'pruning_schedule': ConstantSparsity(0.5, begin_step=0, frequency=100),
      'sparsity_m_by_n': (2, 4),
  }

  model = build_layerwise_model(input_shape, **pruning_params)
  pruned_model = train(model, x_train, y_train, x_test, y_test)

  # Write a model that has been pruned with 2x4 sparsity.
  converter = tf.lite.TFLiteConverter.from_keras_model(pruned_model)
  tflite_model = converter.convert()

  tflite_model_path = '/tmp/mnist_2x4.tflite'
  with open(tflite_model_path, 'wb') as f:
    f.write(tflite_model)
  # Report the path only after the write has completed, so the message is
  # never shown for a file that failed to be written.
  print('model is saved to {}'.format(tflite_model_path))

  print('evaluate pruned model: ')
  print(keras_test_utils.eval_mnist_tflite(model_content=tflite_model))
  # the accuracy of 2:4 pruning model is 0.9866
  # the accuracy of unstructured model with 50% is 0.9863
if __name__ == '__main__':
  # Script entry point: hand control to absl, which invokes `main`.
  absl_app.run(main)