|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Train a simple convnet on the MNIST dataset with magnitude-based pruning.

Three variants of the same convnet are trained:
  * a Sequential model whose Conv2D/Dense layers are wrapped individually with
    `prune_low_magnitude`;
  * a plain Sequential model pruned as a whole;
  * a plain functional-API model pruned as a whole.

Each trained model is evaluated, saved in SavedModel format, reloaded and
evaluated again.
"""
|
from __future__ import print_function

import os

from absl import app as absl_app
from absl import flags
import tensorflow as tf

from tensorflow_model_optimization.python.core.keras.compat import keras
from tensorflow_model_optimization.python.core.sparsity.keras import prune
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_callbacks
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
|
|
|
|
|
FLAGS = flags.FLAGS

# Short aliases used throughout the model builders and main().
PolynomialDecay = pruning_schedule.PolynomialDecay
l = keras.layers

# Training hyper-parameters shared by every model variant.
batch_size = 128
epochs = 12
# Number of MNIST digit classes: width of the softmax output layer and of the
# one-hot label vectors.
num_classes = 10

flags.DEFINE_string(
    name='output_dir',
    default='/tmp/mnist_train/',
    help='Output directory to hold tensorboard events')
|
|
|
|
|
def build_sequential_model(input_shape):
  """Build the unpruned convnet with the Sequential API.

  Args:
    input_shape: per-example image shape (no batch axis), passed to the first
      convolution.

  Returns:
    A `keras.Sequential` model whose last layer is a `num_classes`-way softmax.
  """
  # Two conv/pool stages; only the first layer needs to know the input shape.
  feature_extractor = [
      l.Conv2D(
          filters=32,
          kernel_size=5,
          padding='same',
          activation='relu',
          input_shape=input_shape),
      l.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same'),
      l.BatchNormalization(),
      l.Conv2D(filters=64, kernel_size=5, padding='same', activation='relu'),
      l.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same'),
  ]
  # Fully connected head producing class probabilities.
  classifier = [
      l.Flatten(),
      l.Dense(units=1024, activation='relu'),
      l.Dropout(rate=0.4),
      l.Dense(units=num_classes, activation='softmax'),
  ]
  return keras.Sequential(feature_extractor + classifier)
|
|
|
|
|
def build_functional_model(input_shape):
  """Build the unpruned convnet with the Keras functional API.

  The architecture matches `build_sequential_model`; only the construction
  API differs.

  Args:
    input_shape: per-example image shape (no batch axis) for the `keras.Input`.

  Returns:
    A `keras.models.Model` mapping `[input]` to `[softmax over num_classes]`.
  """
  inp = keras.Input(shape=input_shape)

  trunk = (
      l.Conv2D(32, 5, padding='same', activation='relu'),
      l.MaxPooling2D((2, 2), (2, 2), padding='same'),
      l.BatchNormalization(),
      l.Conv2D(64, 5, padding='same', activation='relu'),
      l.MaxPooling2D((2, 2), (2, 2), padding='same'),
      l.Flatten(),
      l.Dense(1024, activation='relu'),
      l.Dropout(0.4),
  )
  # Thread the symbolic tensor through each layer in order.
  features = inp
  for layer in trunk:
    features = layer(features)

  out = l.Dense(num_classes, activation='softmax')(features)
  return keras.models.Model([inp], [out])
|
|
|
|
|
def build_layerwise_model(input_shape, **pruning_params):
  """Build the convnet with each Conv2D/Dense layer wrapped for pruning.

  Unlike pruning a whole model after construction, this wraps individual
  layers with `prune.prune_low_magnitude`; pooling, normalization, flatten and
  dropout layers are left unwrapped.

  Args:
    input_shape: per-example image shape (no batch axis), given to the first
      pruned convolution wrapper.
    **pruning_params: keyword arguments forwarded to every
      `prune.prune_low_magnitude` call (e.g. `pruning_schedule`).

  Returns:
    A `keras.Sequential` model.
  """

  def pruned(layer, **extra_kwargs):
    # extra_kwargs come first, mirroring `f(layer, kw=..., **pruning_params)`;
    # a key present in both still raises TypeError, as a direct call would.
    return prune.prune_low_magnitude(layer, **extra_kwargs, **pruning_params)

  return keras.Sequential([
      pruned(
          l.Conv2D(32, 5, padding='same', activation='relu'),
          input_shape=input_shape),
      l.MaxPooling2D((2, 2), (2, 2), padding='same'),
      l.BatchNormalization(),
      pruned(l.Conv2D(64, 5, padding='same', activation='relu')),
      l.MaxPooling2D((2, 2), (2, 2), padding='same'),
      l.Flatten(),
      pruned(l.Dense(1024, activation='relu')),
      l.Dropout(0.4),
      pruned(l.Dense(num_classes, activation='softmax')),
  ])
|
|
|
|
|
def _print_score(model, x_test, y_test):
  """Silently evaluate `model` on the test split and print loss and accuracy.

  Assumes the model was compiled with loss plus a single `accuracy` metric, so
  `evaluate` returns `[loss, accuracy]`.
  """
  score = model.evaluate(x_test, y_test, verbose=0)
  print('Test loss:', score[0])
  print('Test accuracy:', score[1])


def train_and_save(models, x_train, y_train, x_test, y_test,
                   saved_model_root='/tmp/saved_model'):
  """Train, evaluate, export and re-import each pruning-wrapped model.

  Args:
    models: iterable of Keras models containing `prune_low_magnitude`
      wrappers.
    x_train: training images.
    y_train: one-hot training labels.
    x_test: test images; also used as validation data during `fit`.
    y_test: one-hot test labels.
    saved_model_root: parent directory for the SavedModel exports. Model `i`
      is written to `<saved_model_root>/model_<i>`, so later models no longer
      overwrite earlier exports.

  TensorBoard pruning summaries for model `i` go to
  `<FLAGS.output_dir>/model_<i>`, so runs do not interleave in one event
  stream.
  """
  for index, model in enumerate(models):
    model_tag = 'model_%d' % index

    model.compile(
        loss=keras.losses.categorical_crossentropy,
        optimizer='adam',
        metrics=['accuracy'],
    )
    model.summary()

    # UpdatePruningStep keeps the pruning wrappers' step counter in sync with
    # training. PruningSummaries logs pruning statistics for TensorBoard, in a
    # per-model subdirectory.
    callbacks = [
        pruning_callbacks.UpdatePruningStep(),
        pruning_callbacks.PruningSummaries(
            log_dir=os.path.join(FLAGS.output_dir, model_tag)),
    ]

    model.fit(
        x_train,
        y_train,
        batch_size=batch_size,
        epochs=epochs,
        verbose=1,
        callbacks=callbacks,
        validation_data=(x_test, y_test))
    _print_score(model, x_test, y_test)

    # Round-trip through SavedModel. The reloaded model's score is printed so it
    # can be compared with the in-memory score above.
    saved_model_dir = os.path.join(saved_model_root, model_tag)
    print('Saving model to: ', saved_model_dir)
    keras.models.save_model(model, saved_model_dir, save_format='tf')
    print('Loading model from: ', saved_model_dir)
    loaded_model = keras.models.load_model(saved_model_dir)
    _print_score(loaded_model, x_test, y_test)
|
|
|
|
|
def main(unused_argv):
  """Load and preprocess MNIST, build three pruned models, and train them.

  Args:
    unused_argv: leftover positional command-line arguments; ignored.
  """
  img_rows, img_cols = 28, 28

  (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

  # Give the images an explicit single channel axis, placed where the
  # backend's configured image data format expects it.
  if keras.backend.image_data_format() == 'channels_first':
    input_shape = (1, img_rows, img_cols)
  else:
    input_shape = (img_rows, img_cols, 1)
  x_train = x_train.reshape((x_train.shape[0],) + input_shape)
  x_test = x_test.reshape((x_test.shape[0],) + input_shape)

  # Convert to float32 and scale pixel values from [0, 255] into [0, 1].
  x_train = x_train.astype('float32')
  x_train /= 255
  x_test = x_test.astype('float32')
  x_test /= 255
  print('x_train shape:', x_train.shape)
  print(x_train.shape[0], 'train samples')
  print(x_test.shape[0], 'test samples')

  # One-hot encode labels to match the categorical cross-entropy loss.
  y_train = keras.utils.to_categorical(y_train, num_classes)
  y_test = keras.utils.to_categorical(y_test, num_classes)

  # Sparsity schedule from initial 0.1 to final 0.75 between steps 1000 and
  # 5000, with frequency 100; shared by all three model variants.
  schedule = PolynomialDecay(
      initial_sparsity=0.1,
      final_sparsity=0.75,
      begin_step=1000,
      end_step=5000,
      frequency=100)
  pruning_params = {'pruning_schedule': schedule}

  # Layer-wise wrapping first, then whole-model pruning of the Sequential and
  # functional variants (list items are evaluated left to right).
  models = [
      build_layerwise_model(input_shape, **pruning_params),
      prune.prune_low_magnitude(
          build_sequential_model(input_shape), **pruning_params),
      prune.prune_low_magnitude(
          build_functional_model(input_shape), **pruning_params),
  ]
  train_and_save(models, x_train, y_train, x_test, y_test)
|
|
|
|
|
if __name__ == '__main__':
  # absl parses command-line flags (such as --output_dir) before invoking main.
  absl_app.run(main=main)
|
|