Spaces:
Running
on
T4
Running
on
T4
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" | |
Generic utilities | |
""" | |
import inspect | |
import tempfile | |
from collections import OrderedDict, UserDict | |
from collections.abc import MutableMapping | |
from contextlib import ExitStack, contextmanager | |
from dataclasses import fields, is_dataclass | |
from enum import Enum | |
from typing import Any, ContextManager, List, Tuple | |
import numpy as np | |
from .import_utils import is_flax_available, is_tf_available, is_torch_available, is_torch_fx_proxy | |
if is_flax_available(): | |
import jax.numpy as jnp | |
class cached_property(property):
    """
    Descriptor that mimics @property but caches output in member variable.

    From tensorflow_datasets

    Built-in in functools from Python 3.8.

    The computed value is stored on the instance under the attribute name `"__cached_" + <getter name>` and returned
    from there on every later access. Any value is cached, including `None`.
    """

    # Private marker meaning "nothing cached yet". Using a dedicated object (rather than `None`) lets getters that
    # legitimately return `None` be cached instead of being re-run on every access.
    _NOT_CACHED = object()

    def __get__(self, obj, objtype=None):
        """
        Return the cached value for `obj`, computing and storing it on first access.

        Raises:
            AttributeError: If the property was created without a getter.
        """
        # See docs.python.org/3/howto/descriptor.html#properties
        if obj is None:
            # Accessed on the class itself: hand back the descriptor.
            return self
        if self.fget is None:
            raise AttributeError("unreadable attribute")
        attr = "__cached_" + self.fget.__name__
        cached = getattr(obj, attr, self._NOT_CACHED)
        if cached is self._NOT_CACHED:
            cached = self.fget(obj)
            setattr(obj, attr, cached)
        return cached
# vendored from distutils.util | |
def strtobool(val):
    """Map a truth-like string onto the integer 1 (true) or 0 (false).

    The comparison is case-insensitive. Accepted true spellings are 'y', 'yes', 't', 'true', 'on' and '1'; accepted
    false spellings are 'n', 'no', 'f', 'false', 'off' and '0'.

    Raises:
        ValueError: If `val` is none of the accepted spellings.
    """
    truth_table = dict.fromkeys(("y", "yes", "t", "true", "on", "1"), 1)
    truth_table.update(dict.fromkeys(("n", "no", "f", "false", "off", "0"), 0))

    val = val.lower()
    result = truth_table.get(val)
    if result is None:
        raise ValueError(f"invalid truth value {val!r}")
    return result
def infer_framework_from_repr(x):
    """
    Guess which framework `x` belongs to by looking only at the string form of its type.

    This is brittle, but it needs no framework import and lets `is_tensor` try the most likely framework first.

    Returns:
        `"pt"`, `"tf"`, `"jax"` or `"np"` when the type string starts with the matching prefix, otherwise `None`.
    """
    prefix_to_framework = (
        ("<class 'torch.", "pt"),
        ("<class 'tensorflow.", "tf"),
        ("<class 'jax", "jax"),
        ("<class 'numpy.", "np"),
    )
    representation = str(type(x))
    for prefix, framework in prefix_to_framework:
        if representation.startswith(prefix):
            return framework
    return None
def _get_frameworks_and_test_func(x):
    """
    Build an ordered mapping from framework name to the predicate that recognises its tensors.

    The framework guessed from the type string of `x` (if any) comes first, then NumPy, then the remaining
    frameworks. Ordering relies on dicts preserving insertion order (Python 3.7+).
    """
    test_by_framework = {
        "pt": is_torch_tensor,
        "tf": is_tf_tensor,
        "jax": is_jax_tensor,
        "np": is_numpy_array,
    }
    guessed = infer_framework_from_repr(x)

    ordered = []
    if guessed is not None:
        ordered.append(guessed)
    # NumPy is always tried right after the guess (unless the guess already was NumPy).
    if guessed != "np":
        ordered.append("np")
    for name in test_by_framework:
        if name != guessed and name != "np":
            ordered.append(name)

    return {name: test_by_framework[name] for name in ordered}
def is_tensor(x):
    """
    Tests if `x` is a `torch.Tensor`, `tf.Tensor`, `jaxlib.xla_extension.DeviceArray` or `np.ndarray`, trying the
    frameworks in the order suggested by `infer_framework_from_repr`. Torch FX proxies and Jax tracers also count
    as tensors.
    """
    # Framework predicates, most likely framework first.
    if any(matches(x) for matches in _get_frameworks_and_test_func(x).values()):
        return True

    # Tracers
    if is_torch_fx_proxy(x):
        return True

    if is_flax_available():
        from jax.core import Tracer

        return isinstance(x, Tracer)

    return False
def _is_numpy(x):
    """Return whether `x` is an instance of `np.ndarray`."""
    ndarray_type = np.ndarray
    return isinstance(x, ndarray_type)


def is_numpy_array(x):
    """
    Tests if `x` is a numpy array or not.
    """
    result = _is_numpy(x)
    return result
def _is_torch(x):
    """Return whether `x` is a `torch.Tensor`. Imports torch, so only call it when torch is installed."""
    from torch import Tensor

    return isinstance(x, Tensor)


def is_torch_tensor(x):
    """
    Tests if `x` is a torch tensor or not. Safe to call even if torch is not installed.
    """
    if not is_torch_available():
        return False
    return _is_torch(x)
def _is_torch_device(x):
    """Return whether `x` is a `torch.device`. Imports torch, so only call it when torch is installed."""
    from torch import device as torch_device

    return isinstance(x, torch_device)


def is_torch_device(x):
    """
    Tests if `x` is a torch device or not. Safe to call even if torch is not installed.
    """
    if not is_torch_available():
        return False
    return _is_torch_device(x)
def _is_torch_dtype(x):
    """
    Return whether `x` is a `torch.dtype`, or a string naming an attribute of the `torch` module that is one.

    Imports torch, so only call it when torch is installed.
    """
    import torch

    # A string is resolved as an attribute of the torch module; an unknown name resolves to None, which is not a
    # dtype and therefore yields False below.
    candidate = getattr(torch, x, None) if isinstance(x, str) else x
    return isinstance(candidate, torch.dtype)


def is_torch_dtype(x):
    """
    Tests if `x` is a torch dtype or not. Safe to call even if torch is not installed.
    """
    if not is_torch_available():
        return False
    return _is_torch_dtype(x)
def _is_tensorflow(x):
    """Return whether `x` is a `tf.Tensor`. Imports tensorflow, so only call it when tensorflow is installed."""
    from tensorflow import Tensor

    return isinstance(x, Tensor)


def is_tf_tensor(x):
    """
    Tests if `x` is a tensorflow tensor or not. Safe to call even if tensorflow is not installed.
    """
    if not is_tf_available():
        return False
    return _is_tensorflow(x)
def _is_tf_symbolic_tensor(x):
    """
    Return whether `x` is a symbolic tensorflow tensor. Imports tensorflow, so only call it when it is installed.
    """
    import tensorflow as tf

    # the `is_symbolic_tensor` predicate is only available starting with TF 2.14
    if not hasattr(tf, "is_symbolic_tensor"):
        # Fallback: an exact type match on `tf.Tensor` (subclasses deliberately do not count).
        return type(x) == tf.Tensor
    return tf.is_symbolic_tensor(x)


def is_tf_symbolic_tensor(x):
    """
    Tests if `x` is a tensorflow symbolic tensor or not (ie. not eager). Safe to call even if tensorflow is not
    installed.
    """
    if not is_tf_available():
        return False
    return _is_tf_symbolic_tensor(x)
def _is_jax(x):
    """Return whether `x` is a `jax.numpy.ndarray`. Imports jax, so only call it when jax is installed."""
    import jax.numpy as jnp  # noqa: F811

    jax_array_type = jnp.ndarray
    return isinstance(x, jax_array_type)


def is_jax_tensor(x):
    """
    Tests if `x` is a Jax tensor or not. Safe to call even if jax is not installed.
    """
    if not is_flax_available():
        return False
    return _is_jax(x)
def to_py_obj(obj):
    """
    Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a python list.

    Dicts (and `UserDict`s) are converted value by value into a plain dict; lists and tuples are converted element
    by element into a list. Numpy scalars become python scalars. Anything else is returned unchanged.
    """
    # Containers are handled recursively before any framework detection.
    if isinstance(obj, (dict, UserDict)):
        return {key: to_py_obj(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_py_obj(item) for item in obj]

    converters = {
        "pt": lambda obj: obj.detach().cpu().tolist(),
        "tf": lambda obj: obj.numpy().tolist(),
        "jax": lambda obj: np.asarray(obj).tolist(),
        "np": lambda obj: obj.tolist(),
    }

    # This gives us a smart order to test the frameworks with the corresponding tests.
    for framework, matches in _get_frameworks_and_test_func(obj).items():
        if matches(obj):
            return converters[framework](obj)

    # tolist also works on 0d np arrays
    if isinstance(obj, np.number):
        return obj.tolist()
    return obj
def to_numpy(obj):
    """
    Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a Numpy array.

    Dicts (and `UserDict`s) are converted value by value into a plain dict; lists and tuples are passed to
    `np.array` as a whole. Objects that match no framework are returned unchanged.
    """
    if isinstance(obj, (dict, UserDict)):
        return {key: to_numpy(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return np.array(obj)

    converters = {
        "pt": lambda obj: obj.detach().cpu().numpy(),
        "tf": lambda obj: obj.numpy(),
        "jax": lambda obj: np.asarray(obj),
        "np": lambda obj: obj,
    }

    # This gives us a smart order to test the frameworks with the corresponding tests.
    for framework, matches in _get_frameworks_and_test_func(obj).items():
        if matches(obj):
            return converters[framework](obj)

    return obj
class ModelOutput(OrderedDict):
    """
    Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a
    tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
    python dictionary.

    <Tip warning={true}>

    You can't unpack a `ModelOutput` directly. Use the [`~utils.ModelOutput.to_tuple`] method to convert it to a tuple
    before.

    </Tip>
    """

    def __init_subclass__(cls) -> None:
        """Register subclasses as pytree nodes.

        This is necessary to synchronize gradients when using `torch.nn.parallel.DistributedDataParallel` with
        `static_graph=True` with modules that output `ModelOutput` subclasses.
        """
        if is_torch_available():
            import torch.utils._pytree

            torch.utils._pytree._register_pytree_node(
                cls,
                torch.utils._pytree._dict_flatten,
                lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)),
            )

    def __init__(self, *args, **kwargs):
        """
        Initialise the underlying ordered dict, then verify that subclasses are dataclasses.

        Raises:
            TypeError: If the instance's class is a `ModelOutput` subclass that is not a dataclass.
        """
        super().__init__(*args, **kwargs)

        # Subclasses of ModelOutput must use the @dataclass decorator
        # This check is done in __init__ because the @dataclass decorator operates after __init_subclass__
        # issubclass() would return True for issubclass(ModelOutput, ModelOutput) when False is needed
        # Just need to check that the current class is not ModelOutput
        is_modeloutput_subclass = self.__class__ != ModelOutput

        if is_modeloutput_subclass and not is_dataclass(self):
            raise TypeError(
                f"{self.__module__}.{self.__class__.__name__} is not a dataclass."
                " This is a subclass of ModelOutput and so must use the @dataclass decorator."
            )

    def __post_init__(self):
        """Check the ModelOutput dataclass.

        Only occurs if @dataclass decorator has been used.

        Copies the non-`None` dataclass fields into the dict. When only the first field is set and it is not a
        tensor, it may instead be an iterable of `(key, value)` pairs (or a dict) that is expanded into fields.

        Raises:
            ValueError: If the dataclass has no fields, if any field after the first has a default other than
                `None`, or if the first field mixes `(key, value)` pairs with other elements.
        """
        class_fields = fields(self)

        # Safety and consistency checks
        if not len(class_fields):
            raise ValueError(f"{self.__class__.__name__} has no fields.")
        if not all(field.default is None for field in class_fields[1:]):
            raise ValueError(f"{self.__class__.__name__} should not have more than one required field.")

        first_field = getattr(self, class_fields[0].name)
        other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])

        if other_fields_are_none and not is_tensor(first_field):
            if isinstance(first_field, dict):
                iterator = first_field.items()
                first_field_iterator = True
            else:
                try:
                    iterator = iter(first_field)
                    first_field_iterator = True
                except TypeError:
                    first_field_iterator = False

            # if we provided an iterator as first field and the iterator is a (key, value) iterator
            # set the associated fields
            if first_field_iterator:
                for idx, element in enumerate(iterator):
                    if (
                        not isinstance(element, (list, tuple))
                        or not len(element) == 2
                        or not isinstance(element[0], str)
                    ):
                        if idx == 0:
                            # If we do not have an iterator of key/values, set it as attribute
                            self[class_fields[0].name] = first_field
                        else:
                            # If we have a mixed iterator, raise an error
                            raise ValueError(
                                f"Cannot set key/value for {element}. It needs to be a tuple (key, value)."
                            )
                        break
                    setattr(self, element[0], element[1])
                    if element[1] is not None:
                        self[element[0]] = element[1]
            elif first_field is not None:
                self[class_fields[0].name] = first_field
        else:
            for field in class_fields:
                v = getattr(self, field.name)
                if v is not None:
                    self[field.name] = v

    def __delitem__(self, *args, **kwargs):
        """Deleting entries is not supported; always raises."""
        raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")

    def setdefault(self, *args, **kwargs):
        """`setdefault` is not supported; always raises."""
        raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")

    def pop(self, *args, **kwargs):
        """`pop` is not supported; always raises."""
        raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")

    def update(self, *args, **kwargs):
        """`update` is not supported; always raises."""
        raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")

    def __getitem__(self, k):
        """
        Look up by string key (dictionary style) or by integer/slice (tuple style, via `to_tuple`).

        Raises:
            KeyError: If `k` is a string that is not a key of this output.
        """
        if isinstance(k, str):
            # Direct lookup in the underlying dict; copying every item for each access is unnecessary and made
            # `to_tuple` quadratic in the number of keys.
            return super().__getitem__(k)
        else:
            return self.to_tuple()[k]

    def __setattr__(self, name, value):
        """Set an attribute and, when `name` is already a key and `value` is not `None`, the dict entry as well."""
        if name in self.keys() and value is not None:
            # Don't call self.__setitem__ to avoid recursion errors
            super().__setitem__(name, value)
        super().__setattr__(name, value)

    def __setitem__(self, key, value):
        """Set a dict entry and mirror it as an attribute of the same name."""
        # Will raise a KeyException if needed
        super().__setitem__(key, value)
        # Don't call self.__setattr__ to avoid recursion errors
        super().__setattr__(key, value)

    def __reduce__(self):
        """
        Pickling support: for dataclass instances, rebuild from the dataclass field values (in field order) rather
        than from the constructor arguments reported by the parent class.
        """
        if not is_dataclass(self):
            return super().__reduce__()
        constructor, _args, *remaining = super().__reduce__()
        args = tuple(getattr(self, field.name) for field in fields(self))
        return (constructor, args, *remaining)

    def to_tuple(self) -> Tuple[Any, ...]:
        """
        Convert self to a tuple containing all the attributes/keys that are not `None`.
        """
        return tuple(self[k] for k in self.keys())
class ExplicitEnum(str, Enum):
    """
    Enum with more explicit error message for missing values.
    """

    # `Enum` invokes `_missing_` on the class when a value lookup fails, so it has to be a classmethod; without
    # the decorator the call fails with a `TypeError` instead of the intended `ValueError`.
    @classmethod
    def _missing_(cls, value):
        """
        Called by `Enum` when `value` matches no member.

        Raises:
            ValueError: Always, with a message listing the valid values of this enum.
        """
        raise ValueError(
            f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
        )
class PaddingStrategy(ExplicitEnum):
    """
    String-valued choices accepted by the `padding` argument of [`PreTrainedTokenizerBase.__call__`]. Having them
    as an enum makes them discoverable through tab-completion in an IDE.
    """

    # Member values are the literal strings listed here.
    LONGEST = "longest"
    MAX_LENGTH = "max_length"
    DO_NOT_PAD = "do_not_pad"
class TensorType(ExplicitEnum):
    """
    String-valued choices accepted by the `return_tensors` argument of [`PreTrainedTokenizerBase.__call__`]. Having
    them as an enum makes them discoverable through tab-completion in an IDE.
    """

    # The values are the same short framework codes ("pt", "tf", "np", "jax") used elsewhere in this module.
    PYTORCH = "pt"
    TENSORFLOW = "tf"
    NUMPY = "np"
    JAX = "jax"
class ContextManagers:
    """
    Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers`
    in the `fastcore` library.

    The managers are entered in the given order and exited in reverse order by the underlying `ExitStack`.
    """

    def __init__(self, context_managers: List[ContextManager]):
        """
        Args:
            context_managers (`List[ContextManager]`): The context managers to enter, in order.
        """
        self.context_managers = context_managers
        self.stack = ExitStack()

    def __enter__(self):
        """
        Enter every context manager in order.

        If one of them fails while being entered, the ones already entered are exited (with the failure's exception
        details) before the exception is re-raised, so nothing is left open.
        """
        try:
            for context_manager in self.context_managers:
                self.stack.enter_context(context_manager)
        except BaseException as exc:
            # `__exit__` is never called when `__enter__` raises, so unwind here to avoid leaking the managers
            # that were already entered.
            self.stack.__exit__(type(exc), exc, exc.__traceback__)
            raise

    def __exit__(self, *args, **kwargs):
        """Exit all entered context managers; the return value tells whether one of them suppressed the exception."""
        return self.stack.__exit__(*args, **kwargs)
def can_return_loss(model_class):
    """
    Tell whether a model class returns a loss by default, i.e. whether its forward method has a `return_loss`
    parameter whose default is `True`.

    Args:
        model_class (`type`): The class of the model.
    """
    framework = infer_framework(model_class)
    # TensorFlow models expose `call`, PyTorch models `forward`, everything else (Flax) `__call__`.
    method_name = {"tf": "call", "pt": "forward"}.get(framework, "__call__")
    signature = inspect.signature(getattr(model_class, method_name))

    return_loss = signature.parameters.get("return_loss")
    return return_loss is not None and return_loss.default is True
def find_labels(model_class):
    """
    List the names of the label arguments of a model: every parameter of its forward method whose name contains
    "label", plus `start_positions` and `end_positions` when the class name contains "QuestionAnswering".

    Args:
        model_class (`type`): The class of the model.
    """
    model_name = model_class.__name__
    framework = infer_framework(model_class)
    # TensorFlow models expose `call`, PyTorch models `forward`, everything else (Flax) `__call__`.
    method_name = {"tf": "call", "pt": "forward"}.get(framework, "__call__")
    signature = inspect.signature(getattr(model_class, method_name))

    is_question_answering = "QuestionAnswering" in model_name
    labels = []
    for p in signature.parameters:
        if "label" in p or (is_question_answering and p in ("start_positions", "end_positions")):
            labels.append(p)
    return labels
def flatten_dict(d: MutableMapping, parent_key: str = "", delimiter: str = "."):
    """Collapse a nested mapping into a one-level dict whose keys are the nested keys joined by `delimiter`.

    Empty nested mappings are kept as values. A key is only prefixed when its parent key is truthy; top-level keys
    (with the default empty `parent_key`) are kept as they are.
    """

    def _iter_flat(mapping, prefix):
        for name, value in mapping.items():
            flat_key = str(prefix) + delimiter + str(name) if prefix else name
            if value and isinstance(value, MutableMapping):
                # Descend into non-empty nested mappings.
                yield from _iter_flat(value, flat_key)
            else:
                yield flat_key, value

    return dict(_iter_flat(d, parent_key))
# Without `@contextmanager` this is a plain generator function and cannot be used in a `with` statement.
@contextmanager
def working_or_temp_dir(working_dir, use_temp_dir: bool = False):
    """
    Context manager yielding the directory to work in.

    Args:
        working_dir: The directory yielded when `use_temp_dir` is `False`.
        use_temp_dir (`bool`, *optional*, defaults to `False`):
            When `True`, a fresh temporary directory is created and yielded instead of `working_dir`, and it is
            removed when the `with` block exits.
    """
    if use_temp_dir:
        with tempfile.TemporaryDirectory() as tmp_dir:
            yield tmp_dir
    else:
        yield working_dir
def transpose(array, axes=None):
    """
    Framework-agnostic version of `numpy.transpose` that will work on torch/TensorFlow/Jax tensors as well as NumPy
    arrays.

    Raises:
        ValueError: If `array` is not recognised as a NumPy array or a torch/TensorFlow/Jax tensor.
    """
    if is_numpy_array(array):
        return np.transpose(array, axes=axes)

    if is_torch_tensor(array):
        if axes is None:
            return array.T
        return array.permute(*axes)

    if is_tf_tensor(array):
        import tensorflow as tf

        return tf.transpose(array, perm=axes)

    if is_jax_tensor(array):
        # `jnp` is bound at module level only when flax is available.
        return jnp.transpose(array, axes=axes)

    raise ValueError(f"Type not supported for transpose: {type(array)}.")
def reshape(array, newshape):
    """
    Framework-agnostic version of `numpy.reshape` that will work on torch/TensorFlow/Jax tensors as well as NumPy
    arrays.

    Raises:
        ValueError: If `array` is not recognised as a NumPy array or a torch/TensorFlow/Jax tensor.
    """
    if is_numpy_array(array):
        return np.reshape(array, newshape)

    if is_torch_tensor(array):
        # `newshape` is unpacked into separate positional arguments for torch.
        return array.reshape(*newshape)

    if is_tf_tensor(array):
        import tensorflow as tf

        return tf.reshape(array, newshape)

    if is_jax_tensor(array):
        # `jnp` is bound at module level only when flax is available.
        return jnp.reshape(array, newshape)

    raise ValueError(f"Type not supported for reshape: {type(array)}.")
def squeeze(array, axis=None):
    """
    Framework-agnostic version of `numpy.squeeze` that will work on torch/TensorFlow/Jax tensors as well as NumPy
    arrays.

    Raises:
        ValueError: If `array` is not recognised as a NumPy array or a torch/TensorFlow/Jax tensor.
    """
    if is_numpy_array(array):
        return np.squeeze(array, axis=axis)

    if is_torch_tensor(array):
        if axis is None:
            return array.squeeze()
        return array.squeeze(dim=axis)

    if is_tf_tensor(array):
        import tensorflow as tf

        return tf.squeeze(array, axis=axis)

    if is_jax_tensor(array):
        # `jnp` is bound at module level only when flax is available.
        return jnp.squeeze(array, axis=axis)

    raise ValueError(f"Type not supported for squeeze: {type(array)}.")
def expand_dims(array, axis):
    """
    Framework-agnostic version of `numpy.expand_dims` that will work on torch/TensorFlow/Jax tensors as well as NumPy
    arrays.

    Raises:
        ValueError: If `array` is not recognised as a NumPy array or a torch/TensorFlow/Jax tensor.
    """
    if is_numpy_array(array):
        return np.expand_dims(array, axis)

    if is_torch_tensor(array):
        return array.unsqueeze(dim=axis)

    if is_tf_tensor(array):
        import tensorflow as tf

        return tf.expand_dims(array, axis=axis)

    if is_jax_tensor(array):
        # `jnp` is bound at module level only when flax is available.
        return jnp.expand_dims(array, axis=axis)

    raise ValueError(f"Type not supported for expand_dims: {type(array)}.")
def tensor_size(array):
    """
    Framework-agnostic version of `numpy.size` that will work on torch/TensorFlow/Jax tensors as well as NumPy arrays.

    Returns:
        The result of `np.size(array)` for NumPy arrays, `array.numel()` for torch tensors, `tf.size(array)` for
        TensorFlow tensors and `array.size` for Jax tensors.

    Raises:
        ValueError: If `array` is not recognised as a NumPy array or a torch/TensorFlow/Jax tensor.
    """
    if is_numpy_array(array):
        return np.size(array)
    elif is_torch_tensor(array):
        return array.numel()
    elif is_tf_tensor(array):
        import tensorflow as tf

        return tf.size(array)
    elif is_jax_tensor(array):
        return array.size
    else:
        # The message previously named `expand_dims` (copy-paste slip), which pointed at the wrong function.
        raise ValueError(f"Type not supported for tensor_size: {type(array)}.")
def add_model_info_to_auto_map(auto_map, repo_id):
    """
    Prefix the entries of `auto_map` with `"{repo_id}--"`, in place, and return the same mapping.

    Entries that are `None` or already contain `"--"` are left alone. Tuple or list values are replaced by a list
    whose elements are prefixed individually under the same rule.
    """

    def _qualify(reference):
        # Leave missing references and already-qualified ones untouched.
        if reference is None or "--" in reference:
            return reference
        return f"{repo_id}--{reference}"

    for key, value in auto_map.items():
        if isinstance(value, (tuple, list)):
            auto_map[key] = [_qualify(item) for item in value]
            continue
        qualified = _qualify(value)
        if qualified is not value:
            auto_map[key] = qualified

    return auto_map
def infer_framework(model_class):
    """
    Work out which framework a model class belongs to from the module and class names found along its MRO. No
    `isinstance()` is used, because the relevant framework classes may not be imported or even installed.

    Returns:
        `"tf"`, `"pt"` or `"flax"`, decided by the first class in the MRO that matches.

    Raises:
        TypeError: If no class in the MRO matches any framework.
    """
    for ancestor in inspect.getmro(model_class):
        module_name = ancestor.__module__
        class_name = ancestor.__name__
        if module_name.startswith(("tensorflow", "keras")) or class_name == "TFPreTrainedModel":
            return "tf"
        if module_name.startswith("torch") or class_name == "PreTrainedModel":
            return "pt"
        if module_name.startswith(("flax", "jax")) or class_name == "FlaxPreTrainedModel":
            return "flax"

    raise TypeError(f"Could not infer framework from class {model_class}.")