Spaces:
Runtime error
Runtime error
from tempfile import TemporaryDirectory | |
from unittest import TestCase | |
from unittest.mock import MagicMock, patch | |
from transformers import AutoModel, TFAutoModel | |
from transformers.onnx import FeaturesManager | |
from transformers.testing_utils import SMALL_MODEL_IDENTIFIER, require_tf, require_torch | |
@require_torch
@require_tf
class DetermineFrameworkTest(TestCase):
    """
    Test `FeaturesManager.determine_framework`.

    Both PyTorch and TensorFlow must be installed: TensorFlow checkpoints are
    built from PyTorch weights (``from_pt=True``), and the environment tests
    patch only one backend's availability check at a time, relying on the real
    check for the other. The class is therefore skipped unless both are present.
    """

    def setUp(self):
        self.test_model = SMALL_MODEL_IDENTIFIER
        self.framework_pt = "pt"
        self.framework_tf = "tf"

    def _setup_pt_ckpt(self, save_dir):
        """Save a PyTorch checkpoint of ``self.test_model`` into ``save_dir``."""
        model_pt = AutoModel.from_pretrained(self.test_model)
        model_pt.save_pretrained(save_dir)

    def _setup_tf_ckpt(self, save_dir):
        """
        Save a TensorFlow checkpoint of ``self.test_model`` into ``save_dir``.

        The TF model is loaded from PyTorch weights (``from_pt=True``), so this
        helper needs both frameworks.
        """
        model_tf = TFAutoModel.from_pretrained(self.test_model, from_pt=True)
        model_tf.save_pretrained(save_dir)

    def test_framework_provided(self):
        """
        Ensure that the provided framework is returned.
        """
        mock_framework = "mock_framework"

        # Framework provided - return whatever the user provides
        result = FeaturesManager.determine_framework(self.test_model, mock_framework)
        self.assertEqual(result, mock_framework)

        # Local checkpoint and framework provided - return provided framework
        # PyTorch checkpoint
        with TemporaryDirectory() as local_pt_ckpt:
            self._setup_pt_ckpt(local_pt_ckpt)
            result = FeaturesManager.determine_framework(local_pt_ckpt, mock_framework)
            self.assertEqual(result, mock_framework)

        # TensorFlow checkpoint
        with TemporaryDirectory() as local_tf_ckpt:
            self._setup_tf_ckpt(local_tf_ckpt)
            result = FeaturesManager.determine_framework(local_tf_ckpt, mock_framework)
            self.assertEqual(result, mock_framework)

    def test_checkpoint_provided(self):
        """
        Ensure that the determined framework is the one used for the local checkpoint.

        For the functionality to execute, local checkpoints are provided but framework is not.
        """
        # PyTorch checkpoint
        with TemporaryDirectory() as local_pt_ckpt:
            self._setup_pt_ckpt(local_pt_ckpt)
            result = FeaturesManager.determine_framework(local_pt_ckpt)
            self.assertEqual(result, self.framework_pt)

        # TensorFlow checkpoint
        with TemporaryDirectory() as local_tf_ckpt:
            self._setup_tf_ckpt(local_tf_ckpt)
            result = FeaturesManager.determine_framework(local_tf_ckpt)
            self.assertEqual(result, self.framework_tf)

        # Invalid local checkpoint: an empty directory holds no weights of either kind
        with TemporaryDirectory() as local_invalid_ckpt:
            with self.assertRaises(FileNotFoundError):
                FeaturesManager.determine_framework(local_invalid_ckpt)

    def test_from_environment(self):
        """
        Ensure that the determined framework is the one available in the environment.

        For the functionality to execute, framework and local checkpoints are not provided.
        """
        # Framework not provided, hub model is used (no local checkpoint directory)

        # TensorFlow not in environment -> use PyTorch
        mock_tf_available = MagicMock(return_value=False)
        with patch("transformers.onnx.features.is_tf_available", mock_tf_available):
            result = FeaturesManager.determine_framework(self.test_model)
        self.assertEqual(result, self.framework_pt)

        # PyTorch not in environment -> use TensorFlow
        # (relies on the real, unpatched is_tf_available() returning True,
        # which the class-level @require_tf guarantees)
        mock_torch_available = MagicMock(return_value=False)
        with patch("transformers.onnx.features.is_torch_available", mock_torch_available):
            result = FeaturesManager.determine_framework(self.test_model)
        self.assertEqual(result, self.framework_tf)

        # Both in environment -> use PyTorch
        mock_tf_available = MagicMock(return_value=True)
        mock_torch_available = MagicMock(return_value=True)
        with patch("transformers.onnx.features.is_tf_available", mock_tf_available), patch(
            "transformers.onnx.features.is_torch_available", mock_torch_available
        ):
            result = FeaturesManager.determine_framework(self.test_model)
        self.assertEqual(result, self.framework_pt)

        # Both not in environment -> raise error
        mock_tf_available = MagicMock(return_value=False)
        mock_torch_available = MagicMock(return_value=False)
        with patch("transformers.onnx.features.is_tf_available", mock_tf_available), patch(
            "transformers.onnx.features.is_torch_available", mock_torch_available
        ):
            with self.assertRaises(EnvironmentError):
                FeaturesManager.determine_framework(self.test_model)