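# NOTE: The following lines are page metadata captured from the hosting site
# when this file was copied; they are not part of the module.
# Mixtral_ether / misc_test.py
# jeduardogruiz's picture
# Upload 22 files
# 516a027 verified
# raw
# history blame
# 7.6 kB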
# Copyright 2019, The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow_model_optimization.python.core.internal.tensor_encoding.stages.research import misc
from tensorflow_model_optimization.python.core.internal.tensor_encoding.testing import test_utils
# The tests in this module run in graph mode: if the runtime currently has
# eager execution enabled, switch it off before any test builds tensors.
if tf.executing_eagerly():
  tf.compat.v1.disable_eager_execution()
class SplitBySmallValueEncodingStageTest(test_utils.BaseEncodingStageTest):
  """Tests for `misc.SplitBySmallValueEncodingStage`."""

  def default_encoding_stage(self):
    """See base class."""
    return misc.SplitBySmallValueEncodingStage()

  def default_input(self):
    """See base class."""
    return tf.random.uniform([50], minval=-1.0, maxval=1.0)

  @property
  def is_lossless(self):
    """See base class."""
    # Values that are small relative to the threshold decode back as zeros
    # (exercised in test_input_types), so the round trip is not exact.
    return False

  def common_asserts_for_test_data(self, data):
    """See base class."""
    self._assert_is_integer(
        data.encoded_x[misc.SplitBySmallValueEncodingStage.ENCODED_INDICES_KEY])

  def _assert_is_integer(self, indices):
    """Asserts that the evaluated `indices` array has dtype `np.int32`."""
    # A TestCase assertion (unlike a bare `assert`) is not stripped under
    # `python -O` and reports the actual dtype on failure.
    self.assertEqual(indices.dtype, np.int32)

  @parameterized.parameters([tf.float32, tf.float64])
  def test_input_types(self, x_dtype):
    # Tests different input dtypes.
    x = tf.constant([1.0, 0.1, 0.01, 0.001, 0.0001], dtype=x_dtype)
    threshold = 0.05
    stage = misc.SplitBySmallValueEncodingStage(threshold=threshold)
    encode_params, decode_params = stage.get_params()
    encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
                                                decode_params)
    test_data = test_utils.TestData(x, encoded_x, decoded_x)
    test_data = self.evaluate_test_data(test_data)

    self._assert_is_integer(test_data.encoded_x[
        misc.SplitBySmallValueEncodingStage.ENCODED_INDICES_KEY])
    # The numpy arrays must have the same dtype as the arrays from test_data.
    np_dtype = x_dtype.as_numpy_dtype
    expected_encoded_values = np.array([1.0, 0.1], dtype=np_dtype)
    expected_encoded_indices = np.array([0, 1], dtype=np.int32)
    expected_decoded_x = np.array([1.0, 0.1, 0., 0., 0.], dtype=np_dtype)
    self.assertAllEqual(test_data.encoded_x[stage.ENCODED_VALUES_KEY],
                        expected_encoded_values)
    self.assertAllEqual(test_data.encoded_x[stage.ENCODED_INDICES_KEY],
                        expected_encoded_indices)
    self.assertAllEqual(test_data.decoded_x, expected_decoded_x)

  def test_all_zero_input_works(self):
    # Tests that encoding does not blow up with all-zero input. With all-zero
    # input, both of the encoded values will be empty arrays.
    stage = misc.SplitBySmallValueEncodingStage()
    test_data = self.run_one_to_many_encode_decode(stage,
                                                   lambda: tf.zeros([50]))
    # Check the claim above explicitly rather than only checking the decode.
    self.assertAllEqual(test_data.encoded_x[stage.ENCODED_VALUES_KEY], [])
    self.assertAllEqual(test_data.encoded_x[stage.ENCODED_INDICES_KEY],
                        np.array([], dtype=np.int32))
    self.assertAllEqual(np.zeros([50], dtype=np.float32), test_data.decoded_x)

  def test_all_below_threshold_works(self):
    # Tests that encoding does not blow up with all-below-threshold input. In
    # this case, both of the encoded values will be empty arrays.
    stage = misc.SplitBySmallValueEncodingStage(threshold=0.1)
    x = tf.random.uniform([50], minval=-0.01, maxval=0.01)
    encode_params, decode_params = stage.get_params()
    encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
                                                decode_params)
    test_data = test_utils.TestData(x, encoded_x, decoded_x)
    test_data = self.evaluate_test_data(test_data)

    expected_encoded_indices = np.array([], dtype=np.int32).reshape([0])
    self.assertAllEqual(test_data.encoded_x[stage.ENCODED_VALUES_KEY], [])
    self.assertAllEqual(test_data.encoded_x[stage.ENCODED_INDICES_KEY],
                        expected_encoded_indices)
    self.assertAllEqual(test_data.decoded_x,
                        np.zeros([50], dtype=x.dtype.as_numpy_dtype))
class DifferenceBetweenIntegersEncodingStageTest(
    test_utils.BaseEncodingStageTest):
  """Tests for `misc.DifferenceBetweenIntegersEncodingStage`."""

  def default_encoding_stage(self):
    """See base class."""
    return misc.DifferenceBetweenIntegersEncodingStage()

  def default_input(self):
    """See base class."""
    return tf.random.uniform([10], minval=0, maxval=10, dtype=tf.int64)

  @property
  def is_lossless(self):
    """See base class."""
    return True

  def common_asserts_for_test_data(self, data):
    """See base class."""
    # Lossless stage: decoding must reproduce the input exactly.
    self.assertAllEqual(data.x, data.decoded_x)

  @parameterized.parameters(
      itertools.product([[1], [2], [10]], [tf.int32, tf.int64]))
  def test_with_multiple_input_shapes(self, input_dims, dtype):

    def x_fn():
      return tf.random.uniform(input_dims, minval=0, maxval=10, dtype=dtype)

    test_data = self.run_one_to_many_encode_decode(
        self.default_encoding_stage(), x_fn)
    self.common_asserts_for_test_data(test_data)

  def test_empty_input_static(self):
    # Tests that the encoding works when the input shape is [0].
    x = tf.convert_to_tensor([], dtype=tf.int32)
    self.assertEqual(x.shape.as_list(), [0])
    stage = self.default_encoding_stage()
    encode_params, decode_params = stage.get_params()
    encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
                                                decode_params)
    test_data = self.evaluate_test_data(
        test_utils.TestData(x, encoded_x, decoded_x))
    self.common_asserts_for_test_data(test_data)

  def test_empty_input_dynamic(self):
    # Tests that the encoding works when the input shape is [0], but not
    # statically known.
    y = tf.zeros((10,))
    indices = tf.compat.v2.where(tf.abs(y) > 1e-8)
    x = tf.gather_nd(y, indices)
    x = tf.cast(x, tf.int32)  # Empty tensor.
    # The gather over a data-dependent `where` leaves the length unknown.
    self.assertEqual(x.shape.as_list(), [None])
    stage = self.default_encoding_stage()
    encode_params, decode_params = stage.get_params()
    encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
                                                decode_params)
    test_data = self.evaluate_test_data(
        test_utils.TestData(x, encoded_x, decoded_x))
    self.assertEqual(test_data.x.shape, (0,))
    self.assertEqual(test_data.encoded_x[stage.ENCODED_VALUES_KEY].shape, (0,))
    self.assertEqual(test_data.decoded_x.shape, (0,))

  @parameterized.parameters([tf.bool, tf.float32])
  def test_encode_unsupported_type_raises(self, dtype):
    stage = self.default_encoding_stage()
    # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12.
    with self.assertRaisesRegex(TypeError, 'Unsupported input type'):
      self.run_one_to_many_encode_decode(
          stage, lambda: tf.cast(self.default_input(), dtype))

  def test_encode_unsupported_input_shape_raises(self):
    x = tf.random.uniform((3, 4), maxval=10, dtype=tf.int32)
    stage = self.default_encoding_stage()
    params, _ = stage.get_params()
    with self.assertRaisesRegex(ValueError, 'Number of dimensions must be 1'):
      stage.encode(x, params)
# Script entry point: hand control to the TensorFlow test runner.
if __name__ == '__main__':
  tf.test.main()