|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Train a simple convnet on the MNIST dataset with sparsity 2x4. |
|
|
|
It is based on mnist_e2e.py |
|
""" |
|
from __future__ import print_function |
|
|
|
from absl import app as absl_app |
|
import tensorflow as tf |
|
|
|
from tensorflow_model_optimization.python.core.keras import test_utils as keras_test_utils |
|
from tensorflow_model_optimization.python.core.keras.compat import keras |
|
from tensorflow_model_optimization.python.core.sparsity.keras import prune |
|
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_callbacks |
|
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule |
|
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_utils |
|
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper |
|
|
|
|
|
# Short aliases used throughout this script.
ConstantSparsity = pruning_schedule.ConstantSparsity
l = keras.layers

# Fix the global RNG seed so training runs are reproducible.
tf.random.set_seed(42)

# Training hyperparameters.
batch_size = 128
num_classes = 10
epochs = 1

# Layer types whose weights are checked for 2-by-4 structured sparsity.
PRUNABLE_2x4_LAYERS = (keras.layers.Conv2D, keras.layers.Dense)
|
|
|
|
|
def check_model_sparsity_2x4(model):
  """Returns True iff every prunable Conv2D/Dense weight is 2-by-4 sparse.

  Only layers wrapped in `PruneLowMagnitude` whose inner layer is one of
  `PRUNABLE_2x4_LAYERS` are inspected; all other layers are ignored.

  Args:
    model: A keras model, possibly containing pruning wrappers.

  Returns:
    True when all inspected prunable weights pass
    `pruning_utils.is_pruned_m_by_n`, False otherwise.
  """
  wrapped_prunable = (
      layer for layer in model.layers
      if isinstance(layer, pruning_wrapper.PruneLowMagnitude)
      and isinstance(layer.layer, PRUNABLE_2x4_LAYERS)
  )
  return all(
      pruning_utils.is_pruned_m_by_n(weight)
      for layer in wrapped_prunable
      for weight in layer.layer.get_prunable_weights()
  )
|
|
|
|
|
def build_layerwise_model(input_shape, **pruning_params):
  """Builds a Sequential MNIST convnet with per-layer pruning wrappers.

  Both Conv2D layers and the first Dense layer are wrapped with
  `prune_low_magnitude`; the final softmax Dense layer is left unpruned.

  Args:
    input_shape: Shape of a single input image, passed to the first Conv2D.
    **pruning_params: Keyword arguments forwarded to
      `prune.prune_low_magnitude` (e.g. pruning_schedule, sparsity_m_by_n).

  Returns:
    An uncompiled `keras.Sequential` model.
  """
  # Name the prunable layers up front, then assemble the network.
  first_conv = l.Conv2D(
      32, 5, padding='same', activation='relu', input_shape=input_shape)
  second_conv = l.Conv2D(64, 5, padding='same')
  hidden_dense = l.Dense(1024, activation='relu')

  return keras.Sequential([
      prune.prune_low_magnitude(first_conv, **pruning_params),
      l.MaxPooling2D((2, 2), (2, 2), padding='same'),
      prune.prune_low_magnitude(second_conv, **pruning_params),
      l.BatchNormalization(),
      l.ReLU(),
      l.MaxPooling2D((2, 2), (2, 2), padding='same'),
      l.Flatten(),
      prune.prune_low_magnitude(hidden_dense, **pruning_params),
      l.Dropout(0.4),
      l.Dense(num_classes, activation='softmax'),
  ])
|
|
|
|
|
def train(model, x_train, y_train, x_test, y_test, log_dir='/tmp/logs'):
  """Compiles and trains `model` with pruning callbacks, then strips it.

  Args:
    model: A keras model containing `prune_low_magnitude` wrappers.
    x_train: Training images.
    y_train: One-hot training labels (categorical crossentropy is used).
    x_test: Test images, also used as validation data during fit.
    y_test: One-hot test labels.
    log_dir: Directory where `PruningSummaries` writes TensorBoard logs.
      Defaults to '/tmp/logs', preserving the original behavior.

  Returns:
    The trained model with pruning wrappers stripped via
    `prune.strip_pruning`.
  """
  model.compile(
      loss=keras.losses.categorical_crossentropy,
      optimizer='adam',
      metrics=['accuracy'],
  )
  # NOTE(review): eager execution is slower than graph mode; presumably
  # required for the 2x4 pruning mask updates here -- confirm before removing.
  model.run_eagerly = True

  model.summary()

  # UpdatePruningStep keeps the pruning step counter in sync with training;
  # PruningSummaries records sparsity/threshold summaries for TensorBoard.
  callbacks = [
      pruning_callbacks.UpdatePruningStep(),
      pruning_callbacks.PruningSummaries(log_dir=log_dir)
  ]

  model.fit(
      x_train,
      y_train,
      batch_size=batch_size,
      epochs=epochs,
      verbose=1,
      callbacks=callbacks,
      validation_data=(x_test, y_test))
  score = model.evaluate(x_test, y_test, verbose=0)
  print('Test loss:', score[0])
  print('Test accuracy:', score[1])

  # Verify every prunable Conv2D/Dense weight ended up 2-by-4 sparse.
  is_pruned_2x4 = check_model_sparsity_2x4(model)
  print('Pass the check for sparsity 2x4: ', is_pruned_2x4)

  # Remove pruning wrappers so the returned model can be converted/exported.
  model = prune.strip_pruning(model)
  return model
|
|
|
|
|
def main(unused_argv):
  """Trains a 2x4-sparse MNIST model, converts it to TFLite and evaluates it."""
  mnist = keras_test_utils.get_preprocessed_mnist_data()
  (x_train, y_train), (x_test, y_test), input_shape = mnist

  pruning_params = {
      'pruning_schedule': ConstantSparsity(0.5, begin_step=0, frequency=100),
      'sparsity_m_by_n': (2, 4),
  }

  model = build_layerwise_model(input_shape, **pruning_params)
  pruned_model = train(model, x_train, y_train, x_test, y_test)

  # Convert the stripped model to TFLite flatbuffer bytes.
  converter = tf.lite.TFLiteConverter.from_keras_model(pruned_model)
  tflite_model = converter.convert()

  tflite_model_path = '/tmp/mnist_2x4.tflite'
  print('model is saved to {}'.format(tflite_model_path))
  with open(tflite_model_path, 'wb') as f:
    f.write(tflite_model)

  print('evaluate pruned model: ')
  print(keras_test_utils.eval_mnist_tflite(model_content=tflite_model))
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
absl_app.run(main) |
|
|