|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import itertools |
|
|
|
from absl.testing import parameterized |
|
import numpy as np |
|
import tensorflow as tf |
|
|
|
from tensorflow_model_optimization.python.core.internal.tensor_encoding.stages.research import clipping |
|
from tensorflow_model_optimization.python.core.internal.tensor_encoding.testing import test_utils |
|
|
|
|
|
# Run this module in graph mode: if TensorFlow starts in eager mode (the TF2
# default), turn eager execution off at import time, before any test builds
# tensors.
# NOTE(review): presumably the test_utils.BaseEncodingStageTest helpers
# evaluate tensors through a graph/session — confirm against test_utils.
if tf.executing_eagerly():
  tf.compat.v1.disable_eager_execution()
|
|
|
|
|
class ClipByNormEncodingStageTest(test_utils.BaseEncodingStageTest):
  """Tests for `clipping.ClipByNormEncodingStage`."""

  def default_encoding_stage(self):
    """See base class."""
    return clipping.ClipByNormEncodingStage(1.0)

  def default_input(self):
    """See base class."""
    return tf.random.normal([20])

  @property
  def is_lossless(self):
    """See base class."""
    # Rescaling an input whose norm exceeds the threshold discards its
    # magnitude, so the stage is not lossless in general.
    return False

  def common_asserts_for_test_data(self, data):
    """See base class."""
    encoded_x = data.encoded_x[
        clipping.ClipByNormEncodingStage.ENCODED_VALUES_KEY]
    # Clipping must not change the shape of the input.
    self.assertAllEqual(data.x.shape, encoded_x.shape)
    # Decoding is expected to hand back the clipped values unchanged: all of
    # the work happens in the encode step.
    self.assertAllEqual(encoded_x, data.decoded_x)

  def test_clipping_effective(self):
    """Checks that an input with norm above the threshold is rescaled."""
    stage = clipping.ClipByNormEncodingStage(1.0)
    test_data = self.run_one_to_many_encode_decode(
        stage, lambda: tf.constant([1.0, 1.0, 1.0, 1.0]))
    self.common_asserts_for_test_data(test_data)
    self.assertAllEqual([1.0, 1.0, 1.0, 1.0], test_data.x)
    # The input has L2 norm 2.0, so clipping to norm 1.0 halves every entry.
    self.assertAllClose([0.5, 0.5, 0.5, 0.5], test_data.decoded_x)

  def test_clipping_large_norm_identity(self):
    """Checks that an input with norm below the threshold is untouched."""
    stage = clipping.ClipByNormEncodingStage(1000.0)
    test_data = self.run_one_to_many_encode_decode(
        stage, lambda: tf.constant([1.0, 1.0, 1.0, 1.0]))
    self.common_asserts_for_test_data(test_data)
    # The input norm (2.0) is far below 1000.0, so nothing is clipped.
    self.assertAllEqual(test_data.x, test_data.decoded_x)

  @parameterized.parameters(([2,],), ([2, 3],), ([2, 3, 4],))
  def test_different_shapes(self, shape):
    """Checks norm clipping on inputs of rank 1 to 3."""
    stage = clipping.ClipByNormEncodingStage(1.0)
    # Shifting uniform [0, 1) samples up by 1.0 makes every entry >= 1, so the
    # input norm is always > 1 and clipping is always triggered.
    test_data = self.run_one_to_many_encode_decode(
        stage, lambda: tf.random.uniform(shape) + 1.0)
    self.common_asserts_for_test_data(test_data)
    self.assertAllClose(1.0, np.linalg.norm(test_data.decoded_x))
    # Clipping by norm only rescales. The output must point in the same
    # direction as the input; the norm check alone would accept any unit
    # vector.
    self.assertAllClose(test_data.x / np.linalg.norm(test_data.x),
                        test_data.decoded_x)

  @parameterized.parameters(
      itertools.product([tf.float32, tf.float64], [tf.float32, tf.float64]))
  def test_input_types(self, x_dtype, clip_norm_dtype):
    """Checks every combination of input dtype and clip-norm dtype."""
    stage = clipping.ClipByNormEncodingStage(
        tf.constant(1.0, clip_norm_dtype))
    x = tf.constant([1.0, 1.0, 1.0, 1.0], dtype=x_dtype)
    encode_params, decode_params = stage.get_params()
    encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
                                                decode_params)
    test_data = test_utils.TestData(x, encoded_x, decoded_x)
    test_data = self.evaluate_test_data(test_data)

    # Apply the same structural checks (shape, encoded == decoded) as every
    # other test in this class; they were previously skipped here.
    self.common_asserts_for_test_data(test_data)
    self.assertAllEqual([1.0, 1.0, 1.0, 1.0], test_data.x)
    # The decoded values should have norm 1.
    self.assertAllClose([0.5, 0.5, 0.5, 0.5], test_data.decoded_x)
|
|
|
|
|
class ClipByValueEncodingStageTest(test_utils.BaseEncodingStageTest):
  """Tests for `clipping.ClipByValueEncodingStage`."""

  def default_encoding_stage(self):
    """See base class."""
    return clipping.ClipByValueEncodingStage(-1.0, 1.0)

  def default_input(self):
    """See base class."""
    return tf.random.normal([20])

  @property
  def is_lossless(self):
    """See base class."""
    # Values outside [clip_value_min, clip_value_max] are overwritten, so the
    # stage is not lossless in general.
    return False

  def common_asserts_for_test_data(self, data):
    """See base class."""
    encoded_x = data.encoded_x[
        clipping.ClipByValueEncodingStage.ENCODED_VALUES_KEY]
    # Clipping must not change the shape of the input.
    self.assertAllEqual(data.x.shape, encoded_x.shape)
    # Decoding is expected to hand back the clipped values unchanged.
    self.assertAllEqual(encoded_x, data.decoded_x)

  def test_clipping_effective(self):
    """Checks that values outside [-1, 1] are pinned to the nearest bound."""
    stage = clipping.ClipByValueEncodingStage(-1.0, 1.0)
    test_data = self.run_one_to_many_encode_decode(
        stage, lambda: tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0]))
    self.common_asserts_for_test_data(test_data)
    self.assertAllEqual([-2.0, -1.0, 0.0, 1.0, 2.0], test_data.x)
    self.assertAllClose([-1.0, -1.0, 0.0, 1.0, 1.0], test_data.decoded_x)

  def test_clipping_large_min_max_identity(self):
    """Checks that very wide bounds leave the default input untouched."""
    stage = clipping.ClipByValueEncodingStage(-1000.0, 1000.0)
    # default_input draws from tf.random.normal with default mean/stddev, whose
    # samples practically never reach +/-1000, so nothing should be clipped.
    test_data = self.run_one_to_many_encode_decode(stage, self.default_input)
    self.common_asserts_for_test_data(test_data)
    self.assertAllEqual(test_data.x, test_data.decoded_x)

  @parameterized.parameters(([2,],), ([2, 3],), ([2, 3, 4],))
  def test_different_shapes(self, shape):
    """Checks value clipping on inputs of rank 1 to 3."""
    stage = clipping.ClipByValueEncodingStage(-1.0, 1.0)
    test_data = self.run_one_to_many_encode_decode(
        stage, lambda: tf.random.normal(shape))
    self.common_asserts_for_test_data(test_data)
    # Every output value lies within [-1, 1] ...
    self.assertGreaterEqual(1.0, np.amax(test_data.decoded_x))
    self.assertLessEqual(-1.0, np.amin(test_data.decoded_x))
    # ... and the output is exactly the elementwise clip of the input. The
    # bound checks alone would also accept, e.g., an all-zero output.
    self.assertAllClose(np.clip(test_data.x, -1.0, 1.0), test_data.decoded_x)

  @parameterized.parameters(
      itertools.product([tf.float32, tf.float64], [tf.float32, tf.float64],
                        [tf.float32, tf.float64]))
  def test_input_types(self, x_dtype, clip_value_min_dtype,
                       clip_value_max_dtype):
    """Checks every combination of input, min-bound and max-bound dtypes."""
    stage = clipping.ClipByValueEncodingStage(
        tf.constant(-1.0, clip_value_min_dtype),
        tf.constant(1.0, clip_value_max_dtype))
    x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=x_dtype)
    encode_params, decode_params = stage.get_params()
    encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
                                                decode_params)
    test_data = test_utils.TestData(x, encoded_x, decoded_x)
    test_data = self.evaluate_test_data(test_data)

    self.common_asserts_for_test_data(test_data)
    self.assertAllEqual([-2.0, -1.0, 0.0, 1.0, 2.0], test_data.x)
    self.assertAllClose([-1.0, -1.0, 0.0, 1.0, 1.0], test_data.decoded_x)
|
|
|
|
|
if __name__ == '__main__':
  # Entry point when the module is executed directly: hand control to
  # TensorFlow's test runner.
  tf.test.main()
|
|