File size: 5,233 Bytes
516a027 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=missing-docstring,protected-access
"""Train a simple convnet on the MNIST dataset with sparsity 2x4.
It is based on mnist_e2e.py
"""
from __future__ import print_function
from absl import app as absl_app
import tensorflow as tf
from tensorflow_model_optimization.python.core.keras import test_utils as keras_test_utils
from tensorflow_model_optimization.python.core.keras.compat import keras
from tensorflow_model_optimization.python.core.sparsity.keras import prune
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_callbacks
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_utils
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper
# Short aliases used by the functions below.
ConstantSparsity = pruning_schedule.ConstantSparsity
l = keras.layers
# Fix TensorFlow's global random seed at import time.
tf.random.set_seed(42)
# Values passed to `model.fit` (batch_size, epochs) and used as the width of
# the final Dense layer (num_classes).
batch_size = 128
num_classes = 10
epochs = 1
# Inner layer types that `check_model_sparsity_2x4` inspects; pruned wrappers
# around any other layer type are skipped by that check.
PRUNABLE_2x4_LAYERS = (keras.layers.Conv2D, keras.layers.Dense)
def check_model_sparsity_2x4(model):
  """Report whether every inspected pruned weight passes the m-by-n check.

  Only layers that are `PruneLowMagnitude` wrappers around an inner layer whose
  type is listed in `PRUNABLE_2x4_LAYERS` are inspected; every other layer is
  ignored.

  Args:
    model: Object exposing an iterable `layers` attribute.

  Returns:
    False as soon as one prunable weight fails
    `pruning_utils.is_pruned_m_by_n`; True otherwise, including when no layer
    qualifies for inspection.
  """
  # Lazy generator: weights are fetched layer by layer, so evaluation stops at
  # the first failing weight exactly like an early `return False` would.
  inspected_weights = (
      weight
      for wrapper in model.layers
      if isinstance(wrapper, pruning_wrapper.PruneLowMagnitude)
      and isinstance(wrapper.layer, PRUNABLE_2x4_LAYERS)
      for weight in wrapper.layer.get_prunable_weights()
  )
  return all(
      pruning_utils.is_pruned_m_by_n(weight) for weight in inspected_weights
  )
def build_layerwise_model(input_shape, **pruning_params):
  """Assemble the MNIST convnet with selected layers wrapped for pruning.

  The two Conv2D layers and the 1024-unit Dense layer are wrapped with
  `prune.prune_low_magnitude`; the final softmax Dense layer and all pooling,
  normalization, activation, flatten and dropout layers are left unwrapped.

  Args:
    input_shape: Forwarded as `input_shape` to the first Conv2D layer.
    **pruning_params: Keyword arguments forwarded unchanged to every
      `prune.prune_low_magnitude` call.

  Returns:
    A `keras.Sequential` model built from the layers in network order.
  """

  def with_pruning(layer):
    # Apply the same pruning configuration to each wrapped layer.
    return prune.prune_low_magnitude(layer, **pruning_params)

  # Layers are instantiated top-to-bottom in network order.
  conv_in = with_pruning(
      l.Conv2D(
          32, 5, padding='same', activation='relu', input_shape=input_shape
      )
  )
  pool_in = l.MaxPooling2D((2, 2), (2, 2), padding='same')
  conv_mid = with_pruning(l.Conv2D(64, 5, padding='same'))
  norm_mid = l.BatchNormalization()
  relu_mid = l.ReLU()
  pool_mid = l.MaxPooling2D((2, 2), (2, 2), padding='same')
  flatten = l.Flatten()
  dense_hidden = with_pruning(l.Dense(1024, activation='relu'))
  dropout = l.Dropout(0.4)
  classifier = l.Dense(num_classes, activation='softmax')

  return keras.Sequential([
      conv_in,
      pool_in,
      conv_mid,
      norm_mid,
      relu_mid,
      pool_mid,
      flatten,
      dense_hidden,
      dropout,
      classifier,
  ])
def train(model, x_train, y_train, x_test, y_test):
  """Compile, fit and evaluate `model`, then return it with pruning stripped.

  Prints the model summary, the test loss and accuracy, and the result of the
  2x4 sparsity check (performed before the pruning wrappers are stripped).
  `PruningSummaries` is configured with `log_dir='/tmp/logs'`.

  Args:
    model: Model to train; it is compiled here with the 'adam' optimizer and
      categorical cross-entropy loss.
    x_train: Training inputs passed to `model.fit`.
    y_train: Training targets passed to `model.fit`.
    x_test: Inputs used both as validation data and for the final evaluation.
    y_test: Targets used both as validation data and for the final evaluation.

  Returns:
    The result of `prune.strip_pruning(model)`.
  """
  model.compile(
      optimizer='adam',
      loss=keras.losses.categorical_crossentropy,
      metrics=['accuracy'],
  )
  model.run_eagerly = True

  model.summary()

  # UpdatePruningStep pegs the pruning step to the optimizer's step;
  # PruningSummaries adds pruning summaries to tensorboard.
  step_callback = pruning_callbacks.UpdatePruningStep()
  summary_callback = pruning_callbacks.PruningSummaries(log_dir='/tmp/logs')

  model.fit(
      x_train,
      y_train,
      batch_size=batch_size,
      epochs=epochs,
      verbose=1,
      callbacks=[step_callback, summary_callback],
      validation_data=(x_test, y_test),
  )

  eval_results = model.evaluate(x_test, y_test, verbose=0)
  print('Test loss:', eval_results[0])
  print('Test accuracy:', eval_results[1])

  # The sparsity check looks for PruneLowMagnitude wrappers, so it has to run
  # before strip_pruning is applied.
  passes_2x4 = check_model_sparsity_2x4(model)
  print('Pass the check for sparsity 2x4: ', passes_2x4)

  return prune.strip_pruning(model)
def main(unused_argv):
  """Train a 2x4-sparse MNIST model, export it to TFLite and evaluate it.

  Args:
    unused_argv: Ignored; present because `absl_app.run` passes it.
  """
  # ---- Prepare training and testing data. ----
  train_split, test_split, input_shape = (
      keras_test_utils.get_preprocessed_mnist_data()
  )
  x_train, y_train = train_split
  x_test, y_test = test_split

  # ---- Train a model with sparsity 2x4. ----
  model = build_layerwise_model(
      input_shape,
      pruning_schedule=ConstantSparsity(0.5, begin_step=0, frequency=100),
      sparsity_m_by_n=(2, 4),
  )
  pruned_model = train(model, x_train, y_train, x_test, y_test)

  # ---- Write out the model that has been pruned with 2x4 sparsity. ----
  tflite_model = tf.lite.TFLiteConverter.from_keras_model(
      pruned_model
  ).convert()

  tflite_model_path = '/tmp/mnist_2x4.tflite'
  print('model is saved to {}'.format(tflite_model_path))
  with open(tflite_model_path, 'wb') as tflite_file:
    tflite_file.write(tflite_model)

  print('evaluate pruned model: ')
  print(keras_test_utils.eval_mnist_tflite(model_content=tflite_model))
  # Reference accuracies noted by the original authors:
  #   2:4 pruning model: 0.9866
  #   unstructured model with 50% sparsity: 0.9863
# Script entry point: hand `main` to absl's application runner.
if __name__ == '__main__':
  absl_app.run(main)
|