AoiKazama committed
Commit b10cc28
Parent(s): ab9e7f0

Upload import_utils.py

Files changed (1): import_utils.py (+1583, -0)
import_utils.py ADDED
@@ -0,0 +1,1583 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Import utilities: Utilities related to imports and our lazy inits.
"""

import importlib.metadata
import importlib.util
import json
import os
import shutil
import subprocess
import sys
import warnings
from collections import OrderedDict
from functools import lru_cache
from itertools import chain
from types import ModuleType
from typing import Any, Tuple, Union

from packaging import version

from transformers import logging


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# TODO: This doesn't work for all packages (`bs4`, `faiss`, etc.). Talk to Sylvain to see how to handle it better.
def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
    # Check if the package spec exists and grab its version to avoid importing a local directory
    package_exists = importlib.util.find_spec(pkg_name) is not None
    package_version = "N/A"
    if package_exists:
        try:
            # Primary method to get the package version
            package_version = importlib.metadata.version(pkg_name)
        except importlib.metadata.PackageNotFoundError:
            # Fallback method: only for "torch" and versions containing "dev"
            if pkg_name == "torch":
                try:
                    package = importlib.import_module(pkg_name)
                    temp_version = getattr(package, "__version__", "N/A")
                    # Check if the version contains "dev"
                    if "dev" in temp_version:
                        package_version = temp_version
                        package_exists = True
                    else:
                        package_exists = False
                except ImportError:
                    # If the package can't be imported, it's not available
                    package_exists = False
            else:
                # For packages other than "torch", don't attempt the fallback and set as not available
                package_exists = False
        logger.debug(f"Detected {pkg_name} version: {package_version}")
    if return_version:
        return package_exists, package_version
    else:
        return package_exists
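
# A hypothetical interactive session (not part of the upstream file) illustrating the
# helper above: it only reads installed-distribution metadata, so it stays cheap and
# never imports the package itself. The numpy version shown is illustrative and
# assumes numpy is installed.
#
#     >>> _is_package_available("os")  # stdlib modules ship no dist metadata
#     False
#     >>> _is_package_available("numpy", return_version=True)
#     (True, '1.26.4')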


ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})

USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()

# Try to run a native pytorch job in an environment with TorchXLA installed by setting this value to 0.
USE_TORCH_XLA = os.environ.get("USE_TORCH_XLA", "1").upper()

FORCE_TF_AVAILABLE = os.environ.get("FORCE_TF_AVAILABLE", "AUTO").upper()

# `transformers` requires `torch>=1.11` but this variable is exposed publicly, and we can't simply remove it.
# This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs.
TORCH_FX_REQUIRED_VERSION = version.parse("1.10")

ACCELERATE_MIN_VERSION = "0.21.0"
FSDP_MIN_VERSION = "1.12.0"
XLA_FSDPV2_MIN_VERSION = "2.2.0"


_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
_apex_available = _is_package_available("apex")
_aqlm_available = _is_package_available("aqlm")
_av_available = importlib.util.find_spec("av") is not None
_bitsandbytes_available = _is_package_available("bitsandbytes")
_eetq_available = _is_package_available("eetq")
_galore_torch_available = _is_package_available("galore_torch")
_lomo_available = _is_package_available("lomo_optim")
# `importlib.metadata.version` works with `beautifulsoup4` but not `bs4`; for `importlib.util.find_spec` it is the reverse.
_bs4_available = importlib.util.find_spec("bs4") is not None
_coloredlogs_available = _is_package_available("coloredlogs")
# `importlib.metadata.version` doesn't work with `opencv-python-headless`.
_cv2_available = importlib.util.find_spec("cv2") is not None
_datasets_available = _is_package_available("datasets")
_decord_available = importlib.util.find_spec("decord") is not None
_detectron2_available = _is_package_available("detectron2")
# We need to check both `faiss` and `faiss-cpu`.
_faiss_available = importlib.util.find_spec("faiss") is not None
try:
    _faiss_version = importlib.metadata.version("faiss")
    logger.debug(f"Successfully imported faiss version {_faiss_version}")
except importlib.metadata.PackageNotFoundError:
    try:
        _faiss_version = importlib.metadata.version("faiss-cpu")
        logger.debug(f"Successfully imported faiss version {_faiss_version}")
    except importlib.metadata.PackageNotFoundError:
        _faiss_available = False
_ftfy_available = _is_package_available("ftfy")
_g2p_en_available = _is_package_available("g2p_en")
_ipex_available, _ipex_version = _is_package_available("intel_extension_for_pytorch", return_version=True)
_jieba_available = _is_package_available("jieba")
_jinja_available = _is_package_available("jinja2")
_kenlm_available = _is_package_available("kenlm")
_keras_nlp_available = _is_package_available("keras_nlp")
_levenshtein_available = _is_package_available("Levenshtein")
_librosa_available = _is_package_available("librosa")
_natten_available = _is_package_available("natten")
_nltk_available = _is_package_available("nltk")
_onnx_available = _is_package_available("onnx")
_openai_available = _is_package_available("openai")
_optimum_available = _is_package_available("optimum")
_auto_gptq_available = _is_package_available("auto_gptq")
# `importlib.metadata.version` doesn't work with `awq`
_auto_awq_available = importlib.util.find_spec("awq") is not None
_quanto_available = _is_package_available("quanto")
_pandas_available = _is_package_available("pandas")
_peft_available = _is_package_available("peft")
_phonemizer_available = _is_package_available("phonemizer")
_psutil_available = _is_package_available("psutil")
_py3nvml_available = _is_package_available("py3nvml")
_pyctcdecode_available = _is_package_available("pyctcdecode")
_pygments_available = _is_package_available("pygments")
_pytesseract_available = _is_package_available("pytesseract")
_pytest_available = _is_package_available("pytest")
_pytorch_quantization_available = _is_package_available("pytorch_quantization")
_rjieba_available = _is_package_available("rjieba")
_sacremoses_available = _is_package_available("sacremoses")
_safetensors_available = _is_package_available("safetensors")
_scipy_available = _is_package_available("scipy")
_sentencepiece_available = _is_package_available("sentencepiece")
_is_seqio_available = _is_package_available("seqio")
_is_gguf_available = _is_package_available("gguf")
_sklearn_available = importlib.util.find_spec("sklearn") is not None
if _sklearn_available:
    try:
        importlib.metadata.version("scikit-learn")
    except importlib.metadata.PackageNotFoundError:
        _sklearn_available = False
_smdistributed_available = importlib.util.find_spec("smdistributed") is not None
_soundfile_available = _is_package_available("soundfile")
_spacy_available = _is_package_available("spacy")
_sudachipy_available, _sudachipy_version = _is_package_available("sudachipy", return_version=True)
_tensorflow_probability_available = _is_package_available("tensorflow_probability")
_tensorflow_text_available = _is_package_available("tensorflow_text")
_tf2onnx_available = _is_package_available("tf2onnx")
_timm_available = _is_package_available("timm")
_tokenizers_available = _is_package_available("tokenizers")
_torchaudio_available = _is_package_available("torchaudio")
_torchdistx_available = _is_package_available("torchdistx")
_torchvision_available = _is_package_available("torchvision")
_mlx_available = _is_package_available("mlx")
_hqq_available = _is_package_available("hqq")


_torch_version = "N/A"
_torch_available = False
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
    _torch_available, _torch_version = _is_package_available("torch", return_version=True)
else:
    logger.info("Disabling PyTorch because USE_TF is set")
    _torch_available = False


_tf_version = "N/A"
_tf_available = False
if FORCE_TF_AVAILABLE in ENV_VARS_TRUE_VALUES:
    _tf_available = True
else:
    if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
        # Note: _is_package_available("tensorflow") fails for tensorflow-cpu. Please test any changes to the line
        # below with tensorflow-cpu to make sure it still works!
        _tf_available = importlib.util.find_spec("tensorflow") is not None
        if _tf_available:
            candidates = (
                "tensorflow",
                "tensorflow-cpu",
                "tensorflow-gpu",
                "tf-nightly",
                "tf-nightly-cpu",
                "tf-nightly-gpu",
                "tf-nightly-rocm",
                "intel-tensorflow",
                "intel-tensorflow-avx512",
                "tensorflow-rocm",
                "tensorflow-macos",
                "tensorflow-aarch64",
            )
            _tf_version = None
            # For the metadata, we have to look for both tensorflow and tensorflow-cpu
            for pkg in candidates:
                try:
                    _tf_version = importlib.metadata.version(pkg)
                    break
                except importlib.metadata.PackageNotFoundError:
                    pass
            _tf_available = _tf_version is not None
        if _tf_available:
            if version.parse(_tf_version) < version.parse("2"):
                logger.info(
                    f"TensorFlow found but with version {_tf_version}. Transformers requires version 2 minimum."
                )
                _tf_available = False
    else:
        logger.info("Disabling Tensorflow because USE_TORCH is set")


_essentia_available = importlib.util.find_spec("essentia") is not None
try:
    _essentia_version = importlib.metadata.version("essentia")
    logger.debug(f"Successfully imported essentia version {_essentia_version}")
except importlib.metadata.PackageNotFoundError:
    _essentia_available = False


_pretty_midi_available = importlib.util.find_spec("pretty_midi") is not None
try:
    _pretty_midi_version = importlib.metadata.version("pretty_midi")
    logger.debug(f"Successfully imported pretty_midi version {_pretty_midi_version}")
except importlib.metadata.PackageNotFoundError:
    _pretty_midi_available = False


ccl_version = "N/A"
_is_ccl_available = (
    importlib.util.find_spec("torch_ccl") is not None
    or importlib.util.find_spec("oneccl_bindings_for_pytorch") is not None
)
try:
    ccl_version = importlib.metadata.version("oneccl_bind_pt")
    logger.debug(f"Detected oneccl_bind_pt version {ccl_version}")
except importlib.metadata.PackageNotFoundError:
    _is_ccl_available = False


_flax_available = False
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
    _flax_available, _flax_version = _is_package_available("flax", return_version=True)
    if _flax_available:
        _jax_available, _jax_version = _is_package_available("jax", return_version=True)
        if _jax_available:
            logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
        else:
            _flax_available = _jax_available = False
            _jax_version = _flax_version = "N/A"


_torch_fx_available = False
if _torch_available:
    torch_version = version.parse(_torch_version)
    _torch_fx_available = (torch_version.major, torch_version.minor) >= (
        TORCH_FX_REQUIRED_VERSION.major,
        TORCH_FX_REQUIRED_VERSION.minor,
    )


_torch_xla_available = False
if USE_TORCH_XLA in ENV_VARS_TRUE_VALUES:
    _torch_xla_available, _torch_xla_version = _is_package_available("torch_xla", return_version=True)
    if _torch_xla_available:
        logger.info(f"Torch XLA version {_torch_xla_version} available.")


def is_kenlm_available():
    return _kenlm_available


def is_cv2_available():
    return _cv2_available


def is_torch_available():
    return _torch_available


def is_torch_deterministic():
    """
    Check whether PyTorch uses deterministic algorithms by checking whether
    `torch.get_deterministic_debug_mode()` returns 1 or 2.
    """
    import torch

    if torch.get_deterministic_debug_mode() == 0:
        return False
    else:
        return True


def is_hqq_available():
    return _hqq_available


def is_pygments_available():
    return _pygments_available


def get_torch_version():
    return _torch_version


def is_torch_sdpa_available():
    if not is_torch_available():
        return False
    elif _torch_version == "N/A":
        return False

    # NOTE: We require torch>=2.1 (and not torch>=2.0) to use SDPA in Transformers for two reasons:
    # - Allow the global use of the `scale` argument introduced in https://github.com/pytorch/pytorch/pull/95259
    # - Memory-efficient attention supports arbitrary attention_mask: https://github.com/pytorch/pytorch/pull/104310
    # NOTE: We require torch>=2.1.1 to avoid a numerical issue in SDPA with non-contiguous inputs: https://github.com/pytorch/pytorch/issues/112577
    return version.parse(_torch_version) >= version.parse("2.1.1")


def is_torchvision_available():
    return _torchvision_available


def is_galore_torch_available():
    return _galore_torch_available


def is_lomo_available():
    return _lomo_available


def is_pyctcdecode_available():
    return _pyctcdecode_available


def is_librosa_available():
    return _librosa_available


def is_essentia_available():
    return _essentia_available


def is_pretty_midi_available():
    return _pretty_midi_available


def is_torch_cuda_available():
    if is_torch_available():
        import torch

        return torch.cuda.is_available()
    else:
        return False


def is_mamba_ssm_available():
    if is_torch_available():
        import torch

        if not torch.cuda.is_available():
            return False
        else:
            return _is_package_available("mamba_ssm")
    return False


def is_causal_conv1d_available():
    if is_torch_available():
        import torch

        if not torch.cuda.is_available():
            return False
        return _is_package_available("causal_conv1d")
    return False


def is_torch_mps_available():
    if is_torch_available():
        import torch

        if hasattr(torch.backends, "mps"):
            return torch.backends.mps.is_available() and torch.backends.mps.is_built()
    return False


def is_torch_bf16_gpu_available():
    if not is_torch_available():
        return False

    import torch

    return torch.cuda.is_available() and torch.cuda.is_bf16_supported()


def is_torch_bf16_cpu_available():
    if not is_torch_available():
        return False

    import torch

    try:
        # multiple levels of AttributeError depending on the pytorch version so do them all in one check
        _ = torch.cpu.amp.autocast
    except AttributeError:
        return False

    return True


def is_torch_bf16_available():
    # the original bf16 check was for gpu only, but later a cpu/bf16 combo has emerged so this util
    # has become ambiguous and therefore deprecated
    warnings.warn(
        "The util is_torch_bf16_available is deprecated, please use is_torch_bf16_gpu_available "
        "or is_torch_bf16_cpu_available instead according to whether it's used with cpu or gpu",
        FutureWarning,
    )
    return is_torch_bf16_gpu_available()


@lru_cache()
def is_torch_fp16_available_on_device(device):
    if not is_torch_available():
        return False

    import torch

    try:
        x = torch.zeros(2, 2, dtype=torch.float16).to(device)
        _ = x @ x

        # At this moment, let's be strict of the check: check if `LayerNorm` is also supported on device, because many
        # models use this layer.
        batch, sentence_length, embedding_dim = 3, 4, 5
        embedding = torch.randn(batch, sentence_length, embedding_dim, dtype=torch.float16, device=device)
        layer_norm = torch.nn.LayerNorm(embedding_dim, dtype=torch.float16, device=device)
        _ = layer_norm(embedding)

    except:  # noqa: E722
        # TODO: more precise exception matching, if possible.
        # most backends should return `RuntimeError` however this is not guaranteed.
        return False

    return True
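
# Hypothetical usage (not part of the upstream file): the probe actually runs a tiny
# fp16 matmul and LayerNorm on the target device, so the first call pays that cost
# and `lru_cache` makes later calls free. The exact results depend on the torch
# build; many CPU builds lack fp16 kernels for these ops while CUDA GPUs have them.
#
#     >>> is_torch_fp16_available_on_device("cpu")
#     False
#     >>> is_torch_fp16_available_on_device("cuda:0")  # assuming a CUDA GPU is present
#     True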


@lru_cache()
def is_torch_bf16_available_on_device(device):
    if not is_torch_available():
        return False

    import torch

    if device == "cuda":
        return is_torch_bf16_gpu_available()

    try:
        x = torch.zeros(2, 2, dtype=torch.bfloat16).to(device)
        _ = x @ x
    except:  # noqa: E722
        # TODO: more precise exception matching, if possible.
        # most backends should return `RuntimeError` however this is not guaranteed.
        return False

    return True


def is_torch_tf32_available():
    if not is_torch_available():
        return False

    import torch

    if not torch.cuda.is_available() or torch.version.cuda is None:
        return False
    if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
        return False
    if int(torch.version.cuda.split(".")[0]) < 11:
        return False
    if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.7"):
        return False

    return True


def is_torch_fx_available():
    return _torch_fx_available


def is_peft_available():
    return _peft_available


def is_bs4_available():
    return _bs4_available


def is_tf_available():
    return _tf_available


def is_coloredlogs_available():
    return _coloredlogs_available


def is_tf2onnx_available():
    return _tf2onnx_available


def is_onnx_available():
    return _onnx_available


def is_openai_available():
    return _openai_available


def is_flax_available():
    return _flax_available


def is_ftfy_available():
    return _ftfy_available


def is_g2p_en_available():
    return _g2p_en_available


@lru_cache()
def is_torch_tpu_available(check_device=True):
    "Checks if `torch_xla` is installed and potentially if a TPU is in the environment"
    warnings.warn(
        "`is_torch_tpu_available` is deprecated and will be removed in 4.41.0. "
        "Please use `is_torch_xla_available` instead.",
        FutureWarning,
    )

    if not _torch_available:
        return False
    if importlib.util.find_spec("torch_xla") is not None:
        if check_device:
            # We need to check if `xla_device` can be found, will raise a RuntimeError if not
            try:
                import torch_xla.core.xla_model as xm

                _ = xm.xla_device()
                return True
            except RuntimeError:
                return False
        return True
    return False


@lru_cache
def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False):
    """
    Check if `torch_xla` is available. To train a native pytorch job in an environment with torch xla installed, set
    the USE_TORCH_XLA environment variable to 0.
    """
    assert not (check_is_tpu and check_is_gpu), "The check_is_tpu and check_is_gpu cannot both be true."

    if not _torch_xla_available:
        return False

    import torch_xla

    if check_is_gpu:
        return torch_xla.runtime.device_type() in ["GPU", "CUDA"]
    elif check_is_tpu:
        return torch_xla.runtime.device_type() == "TPU"

    return True
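
# Hypothetical usage (not part of the upstream file): on a TPU VM with torch_xla
# installed, the accelerator type can be narrowed like this; without either flag,
# the function only reports that some XLA backend is usable.
#
#     >>> is_torch_xla_available()                   # any XLA backend
#     True
#     >>> is_torch_xla_available(check_is_tpu=True)
#     True
#     >>> is_torch_xla_available(check_is_gpu=True)
#     False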


@lru_cache()
def is_torch_neuroncore_available(check_device=True):
    if importlib.util.find_spec("torch_neuronx") is not None:
        return is_torch_xla_available()
    return False


@lru_cache()
def is_torch_npu_available(check_device=False):
    "Checks if `torch_npu` is installed and potentially if an NPU is in the environment"
    if not _torch_available or importlib.util.find_spec("torch_npu") is None:
        return False

    import torch
    import torch_npu  # noqa: F401

    if check_device:
        try:
            # Will raise a RuntimeError if no NPU is found
            _ = torch.npu.device_count()
            return torch.npu.is_available()
        except RuntimeError:
            return False
    return hasattr(torch, "npu") and torch.npu.is_available()


@lru_cache()
def is_torch_mlu_available(check_device=False):
    "Checks if `torch_mlu` is installed and potentially if an MLU is in the environment"
    if not _torch_available or importlib.util.find_spec("torch_mlu") is None:
        return False

    import torch
    import torch_mlu  # noqa: F401

    from ..dependency_versions_table import deps

    deps["deepspeed"] = "deepspeed-mlu>=0.10.1"

    if check_device:
        try:
            # Will raise a RuntimeError if no MLU is found
            _ = torch.mlu.device_count()
            return torch.mlu.is_available()
        except RuntimeError:
            return False
    return hasattr(torch, "mlu") and torch.mlu.is_available()


def is_torchdynamo_available():
    if not is_torch_available():
        return False
    try:
        import torch._dynamo as dynamo  # noqa: F401

        return True
    except Exception:
        return False


def is_torch_compile_available():
    if not is_torch_available():
        return False

    import torch

    # We don't do any version check here to support nightlies marked as 1.14. Ultimately needs to check version against
    # 2.0 but let's do it later.
    return hasattr(torch, "compile")


def is_torchdynamo_compiling():
    if not is_torch_available():
        return False
    try:
        import torch._dynamo as dynamo  # noqa: F401

        return dynamo.is_compiling()
    except Exception:
        return False


def is_torch_tensorrt_fx_available():
    if importlib.util.find_spec("torch_tensorrt") is None:
        return False
    return importlib.util.find_spec("torch_tensorrt.fx") is not None


def is_datasets_available():
    return _datasets_available


def is_detectron2_available():
    return _detectron2_available


def is_rjieba_available():
    return _rjieba_available


def is_psutil_available():
    return _psutil_available


def is_py3nvml_available():
    return _py3nvml_available


def is_sacremoses_available():
    return _sacremoses_available


def is_apex_available():
    return _apex_available


def is_aqlm_available():
    return _aqlm_available


def is_av_available():
    return _av_available


def is_ninja_available():
    r"""
    Code comes from *torch.utils.cpp_extension.is_ninja_available()*. Returns `True` if the
    [ninja](https://ninja-build.org/) build system is available on the system, `False` otherwise.
    """
    try:
        subprocess.check_output("ninja --version".split())
    except Exception:
        return False
    else:
        return True


def is_ipex_available():
    def get_major_and_minor_from_version(full_version):
        return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor)

    if not is_torch_available() or not _ipex_available:
        return False

    torch_major_and_minor = get_major_and_minor_from_version(_torch_version)
    ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version)
    if torch_major_and_minor != ipex_major_and_minor:
        logger.warning(
            f"Intel Extension for PyTorch {ipex_major_and_minor} needs to work with PyTorch {ipex_major_and_minor}.*,"
            f" but PyTorch {_torch_version} is found. Please switch to the matching version and run again."
        )
        return False
    return True


@lru_cache
def is_torch_xpu_available(check_device=False):
    "Checks if `intel_extension_for_pytorch` is installed and potentially if an XPU is in the environment"
    if not is_ipex_available():
        return False

    import intel_extension_for_pytorch  # noqa: F401
    import torch

    if check_device:
        try:
            # Will raise a RuntimeError if no XPU is found
            _ = torch.xpu.device_count()
            return torch.xpu.is_available()
        except RuntimeError:
            return False
    return hasattr(torch, "xpu") and torch.xpu.is_available()


def is_bitsandbytes_available():
    if not is_torch_available():
        return False

    # bitsandbytes throws an error if cuda is not available
    # let's avoid that by adding a simple check
    import torch

    return _bitsandbytes_available and torch.cuda.is_available()


def is_flash_attn_2_available():
    if not is_torch_available():
        return False

    if not _is_package_available("flash_attn"):
        return False

    # Let's add an extra check to see if cuda is available
    import torch

    if not torch.cuda.is_available():
        return False

    if torch.version.cuda:
        return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0")
    elif torch.version.hip:
        # TODO: Bump the requirement to 2.1.0 once released in https://github.com/ROCmSoftwarePlatform/flash-attention
        return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.0.4")
    else:
        return False


def is_flash_attn_greater_or_equal_2_10():
    if not _is_package_available("flash_attn"):
        return False

    return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0")


def is_torchdistx_available():
    return _torchdistx_available


def is_faiss_available():
    return _faiss_available


def is_scipy_available():
    return _scipy_available


def is_sklearn_available():
    return _sklearn_available


def is_sentencepiece_available():
    return _sentencepiece_available


def is_seqio_available():
    return _is_seqio_available


def is_gguf_available():
    return _is_gguf_available


def is_protobuf_available():
    if importlib.util.find_spec("google") is None:
        return False
    return importlib.util.find_spec("google.protobuf") is not None


def is_accelerate_available(min_version: str = ACCELERATE_MIN_VERSION):
    return _accelerate_available and version.parse(_accelerate_version) >= version.parse(min_version)


def is_fsdp_available(min_version: str = FSDP_MIN_VERSION):
    return is_torch_available() and version.parse(_torch_version) >= version.parse(min_version)


def is_optimum_available():
    return _optimum_available


def is_auto_awq_available():
    return _auto_awq_available


def is_quanto_available():
    return _quanto_available


def is_auto_gptq_available():
    return _auto_gptq_available


def is_eetq_available():
    return _eetq_available


def is_levenshtein_available():
    return _levenshtein_available


def is_optimum_neuron_available():
    return _optimum_available and _is_package_available("optimum.neuron")


def is_safetensors_available():
    return _safetensors_available


def is_tokenizers_available():
    return _tokenizers_available


@lru_cache
def is_vision_available():
    _pil_available = importlib.util.find_spec("PIL") is not None
    if _pil_available:
        try:
            package_version = importlib.metadata.version("Pillow")
        except importlib.metadata.PackageNotFoundError:
            try:
                package_version = importlib.metadata.version("Pillow-SIMD")
            except importlib.metadata.PackageNotFoundError:
                return False
        logger.debug(f"Detected PIL version {package_version}")
    return _pil_available


def is_pytesseract_available():
    return _pytesseract_available


def is_pytest_available():
    return _pytest_available


def is_spacy_available():
    return _spacy_available


def is_tensorflow_text_available():
    return is_tf_available() and _tensorflow_text_available


def is_keras_nlp_available():
    return is_tensorflow_text_available() and _keras_nlp_available


def is_in_notebook():
    try:
        # Test adapted from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py
        get_ipython = sys.modules["IPython"].get_ipython
        if "IPKernelApp" not in get_ipython().config:
            raise ImportError("console")
        if "VSCODE_PID" in os.environ:
            raise ImportError("vscode")
        if "DATABRICKS_RUNTIME_VERSION" in os.environ and os.environ["DATABRICKS_RUNTIME_VERSION"] < "11.0":
            # Databricks Runtime 11.0 and above uses the IPython kernel by default so it should be compatible with Jupyter notebooks
            # https://docs.microsoft.com/en-us/azure/databricks/notebooks/ipython-kernel
            raise ImportError("databricks")

        return importlib.util.find_spec("IPython") is not None
    except (AttributeError, ImportError, KeyError):
        return False


def is_pytorch_quantization_available():
    return _pytorch_quantization_available


def is_tensorflow_probability_available():
    return _tensorflow_probability_available


def is_pandas_available():
    return _pandas_available


def is_sagemaker_dp_enabled():
    # Get the sagemaker specific env variable.
    sagemaker_params = os.getenv("SM_FRAMEWORK_PARAMS", "{}")
    try:
        # Parse it and check the field "sagemaker_distributed_dataparallel_enabled".
        sagemaker_params = json.loads(sagemaker_params)
        if not sagemaker_params.get("sagemaker_distributed_dataparallel_enabled", False):
            return False
    except json.JSONDecodeError:
        return False
    # Lastly, check if the `smdistributed` module is present.
    return _smdistributed_available


def is_sagemaker_mp_enabled():
    # Get the sagemaker specific mp parameters from the smp_options variable.
    smp_options = os.getenv("SM_HP_MP_PARAMETERS", "{}")
    try:
        # Parse it and check that the field "partitions" is included; it is required for model parallel.
        smp_options = json.loads(smp_options)
        if "partitions" not in smp_options:
            return False
    except json.JSONDecodeError:
        return False

    # Get the sagemaker specific framework parameters from the mpi_options variable.
    mpi_options = os.getenv("SM_FRAMEWORK_PARAMS", "{}")
    try:
        # Parse it and check the field "sagemaker_mpi_enabled".
        mpi_options = json.loads(mpi_options)
        if not mpi_options.get("sagemaker_mpi_enabled", False):
            return False
    except json.JSONDecodeError:
        return False
    # Lastly, check if the `smdistributed` module is present.
    return _smdistributed_available
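
# A hypothetical environment (not part of the upstream file) under which
# is_sagemaker_mp_enabled() returns True, provided the `smdistributed` module is
# also installed. The JSON values are illustrative:
#
#     SM_HP_MP_PARAMETERS='{"partitions": 2, "microbatches": 4}'
#     SM_FRAMEWORK_PARAMS='{"sagemaker_mpi_enabled": true}'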


def is_training_run_on_sagemaker():
    return "SAGEMAKER_JOB_NAME" in os.environ


def is_soundfile_availble():
    return _soundfile_available


def is_timm_available():
    return _timm_available


def is_natten_available():
    return _natten_available


def is_nltk_available():
    return _nltk_available


def is_torchaudio_available():
    return _torchaudio_available


def is_speech_available():
    # For now this depends on torchaudio but the exact dependency might evolve in the future.
    return _torchaudio_available


def is_phonemizer_available():
    return _phonemizer_available


def torch_only_method(fn):
    def wrapper(*args, **kwargs):
        if not _torch_available:
            raise ImportError(
                "You need to install pytorch to use this method or class, "
                "or activate it with environment variables USE_TORCH=1 and USE_TF=0."
            )
        else:
            return fn(*args, **kwargs)

    return wrapper
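
# Hypothetical usage (not part of the upstream file): decorating a torch-dependent
# helper makes it fail with a clear ImportError when PyTorch is unavailable, instead
# of an obscure error deep inside the call.
#
#     @torch_only_method
#     def count_parameters(model):
#         return sum(p.numel() for p in model.parameters())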


def is_ccl_available():
    return _is_ccl_available


def is_decord_available():
    return _decord_available


def is_sudachi_available():
    return _sudachipy_available


def get_sudachi_version():
    return _sudachipy_version


def is_sudachi_projection_available():
    if not is_sudachi_available():
        return False

    # NOTE: We require sudachipy>=0.6.8 to use the projection option in sudachi_kwargs for the constructor of BertJapaneseTokenizer.
    # - the `projection` option is not supported in sudachipy<0.6.8, see https://github.com/WorksApplications/sudachi.rs/issues/230
    return version.parse(_sudachipy_version) >= version.parse("0.6.8")


def is_jumanpp_available():
    return (importlib.util.find_spec("rhoknp") is not None) and (shutil.which("jumanpp") is not None)


def is_cython_available():
    return importlib.util.find_spec("pyximport") is not None


def is_jieba_available():
    return _jieba_available


def is_jinja_available():
    return _jinja_available


def is_mlx_available():
    return _mlx_available


# docstyle-ignore
AV_IMPORT_ERROR = """
{0} requires the PyAV library but it was not found in your environment. You can install it with:
```
pip install av
```
Please note that you may need to restart your runtime after installation.
"""


# docstyle-ignore
CV2_IMPORT_ERROR = """
{0} requires the OpenCV library but it was not found in your environment. You can install it with:
```
pip install opencv-python
```
Please note that you may need to restart your runtime after installation.
"""


# docstyle-ignore
DATASETS_IMPORT_ERROR = """
{0} requires the 🤗 Datasets library but it was not found in your environment. You can install it with:
```
pip install datasets
```
In a notebook or a colab, you can install it by executing a cell with
```
!pip install datasets
```
then restarting your kernel.

Note that if you have a local folder named `datasets` or a local python file named `datasets.py` in your current
working directory, python may try to import this instead of the 🤗 Datasets library. You should rename this folder or
that python file if that's the case. Please note that you may need to restart your runtime after installation.
"""


# docstyle-ignore
TOKENIZERS_IMPORT_ERROR = """
{0} requires the 🤗 Tokenizers library but it was not found in your environment. You can install it with:
```
pip install tokenizers
```
In a notebook or a colab, you can install it by executing a cell with
```
!pip install tokenizers
```
Please note that you may need to restart your runtime after installation.
"""


# docstyle-ignore
SENTENCEPIECE_IMPORT_ERROR = """
{0} requires the SentencePiece library but it was not found in your environment. Check out the instructions on the
installation page of its repo: https://github.com/google/sentencepiece#installation and follow the ones
that match your environment. Please note that you may need to restart your runtime after installation.
"""


# docstyle-ignore
PROTOBUF_IMPORT_ERROR = """
{0} requires the protobuf library but it was not found in your environment. Check out the instructions on the
installation page of its repo: https://github.com/protocolbuffers/protobuf/tree/master/python#installation and follow the ones
that match your environment. Please note that you may need to restart your runtime after installation.
"""


# docstyle-ignore
FAISS_IMPORT_ERROR = """
{0} requires the faiss library but it was not found in your environment. Check out the instructions on the
installation page of its repo: https://github.com/facebookresearch/faiss/blob/master/INSTALL.md and follow the ones
that match your environment. Please note that you may need to restart your runtime after installation.
"""


# docstyle-ignore
PYTORCH_IMPORT_ERROR = """
{0} requires the PyTorch library but it was not found in your environment. Check out the instructions on the
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
Please note that you may need to restart your runtime after installation.
"""


# docstyle-ignore
TORCHVISION_IMPORT_ERROR = """
{0} requires the Torchvision library but it was not found in your environment. Check out the instructions on the
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
Please note that you may need to restart your runtime after installation.
"""

# docstyle-ignore
PYTORCH_IMPORT_ERROR_WITH_TF = """
{0} requires the PyTorch library but it was not found in your environment.
However, we were able to find a TensorFlow installation. TensorFlow classes begin
with "TF", but are otherwise identically named to our PyTorch classes. This
means that the TF equivalent of the class you tried to import would be "TF{0}".
If you want to use TensorFlow, please use TF classes instead!

If you really do want to use PyTorch please go to
https://pytorch.org/get-started/locally/ and follow the instructions that
match your environment.
"""

# docstyle-ignore
TF_IMPORT_ERROR_WITH_PYTORCH = """
{0} requires the TensorFlow library but it was not found in your environment.
However, we were able to find a PyTorch installation. PyTorch classes do not begin
with "TF", but are otherwise identically named to our TF classes.
If you want to use PyTorch, please use those classes instead!

If you really do want to use TensorFlow, please follow the instructions on the
installation page https://www.tensorflow.org/install that match your environment.
"""

# docstyle-ignore
BS4_IMPORT_ERROR = """
{0} requires the Beautiful Soup library but it was not found in your environment. You can install it with pip:
`pip install beautifulsoup4`. Please note that you may need to restart your runtime after installation.
"""


# docstyle-ignore
SKLEARN_IMPORT_ERROR = """
{0} requires the scikit-learn library but it was not found in your environment. You can install it with:
```
pip install -U scikit-learn
```
In a notebook or a colab, you can install it by executing a cell with
```
!pip install -U scikit-learn
```
Please note that you may need to restart your runtime after installation.
"""


# docstyle-ignore
TENSORFLOW_IMPORT_ERROR = """
{0} requires the TensorFlow library but it was not found in your environment. Check out the instructions on the
installation page: https://www.tensorflow.org/install and follow the ones that match your environment.
Please note that you may need to restart your runtime after installation.
"""


# docstyle-ignore
DETECTRON2_IMPORT_ERROR = """
{0} requires the detectron2 library but it was not found in your environment. Check out the instructions on the
installation page: https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md and follow the ones
that match your environment. Please note that you may need to restart your runtime after installation.
"""


# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Check out the instructions on the
installation page: https://github.com/google/flax and follow the ones that match your environment.
Please note that you may need to restart your runtime after installation.
"""

# docstyle-ignore
FTFY_IMPORT_ERROR = """
{0} requires the ftfy library but it was not found in your environment. Check out the instructions in the
installation section: https://github.com/rspeer/python-ftfy/tree/master#installing and follow the ones
that match your environment. Please note that you may need to restart your runtime after installation.
"""

LEVENSHTEIN_IMPORT_ERROR = """
{0} requires the python-Levenshtein library but it was not found in your environment. You can install it with pip:
`pip install python-Levenshtein`. Please note that you may need to restart your runtime after installation.
"""

# docstyle-ignore
G2P_EN_IMPORT_ERROR = """
{0} requires the g2p-en library but it was not found in your environment. You can install it with pip:
`pip install g2p-en`. Please note that you may need to restart your runtime after installation.
"""

# docstyle-ignore
PYTORCH_QUANTIZATION_IMPORT_ERROR = """
{0} requires the pytorch-quantization library but it was not found in your environment. You can install it with pip:
`pip install pytorch-quantization --extra-index-url https://pypi.ngc.nvidia.com`
Please note that you may need to restart your runtime after installation.
"""

# docstyle-ignore
TENSORFLOW_PROBABILITY_IMPORT_ERROR = """
{0} requires the tensorflow_probability library but it was not found in your environment. You can install it with pip as
explained here: https://github.com/tensorflow/probability. Please note that you may need to restart your runtime after installation.
"""

# docstyle-ignore
TENSORFLOW_TEXT_IMPORT_ERROR = """
{0} requires the tensorflow_text library but it was not found in your environment. You can install it with pip as
explained here: https://www.tensorflow.org/text/guide/tf_text_intro.
Please note that you may need to restart your runtime after installation.
"""


# docstyle-ignore
PANDAS_IMPORT_ERROR = """
{0} requires the pandas library but it was not found in your environment. You can install it with pip as
explained here: https://pandas.pydata.org/pandas-docs/stable/getting_started/install.html.
Please note that you may need to restart your runtime after installation.
"""


# docstyle-ignore
PHONEMIZER_IMPORT_ERROR = """
{0} requires the phonemizer library but it was not found in your environment. You can install it with pip:
`pip install phonemizer`. Please note that you may need to restart your runtime after installation.
"""


# docstyle-ignore
SACREMOSES_IMPORT_ERROR = """
{0} requires the sacremoses library but it was not found in your environment. You can install it with pip:
`pip install sacremoses`. Please note that you may need to restart your runtime after installation.
"""

# docstyle-ignore
SCIPY_IMPORT_ERROR = """
{0} requires the scipy library but it was not found in your environment. You can install it with pip:
`pip install scipy`. Please note that you may need to restart your runtime after installation.
"""


# docstyle-ignore
SPEECH_IMPORT_ERROR = """
{0} requires the torchaudio library but it was not found in your environment. You can install it with pip:
`pip install torchaudio`. Please note that you may need to restart your runtime after installation.
"""

# docstyle-ignore
TIMM_IMPORT_ERROR = """
{0} requires the timm library but it was not found in your environment. You can install it with pip:
`pip install timm`. Please note that you may need to restart your runtime after installation.
"""

# docstyle-ignore
NATTEN_IMPORT_ERROR = """
{0} requires the natten library but it was not found in your environment. You can install it by referring to:
shi-labs.com/natten. You can also install it with pip (may take longer to build):
`pip install natten`. Please note that you may need to restart your runtime after installation.
"""

NUMEXPR_IMPORT_ERROR = """
{0} requires the numexpr library but it was not found in your environment. You can install it by referring to:
https://numexpr.readthedocs.io/en/latest/index.html.
"""


# docstyle-ignore
NLTK_IMPORT_ERROR = """
{0} requires the NLTK library but it was not found in your environment. You can install it by referring to:
https://www.nltk.org/install.html. Please note that you may need to restart your runtime after installation.
"""


# docstyle-ignore
VISION_IMPORT_ERROR = """
{0} requires the PIL library but it was not found in your environment. You can install it with pip:
`pip install pillow`. Please note that you may need to restart your runtime after installation.
"""


# docstyle-ignore
PYTESSERACT_IMPORT_ERROR = """
{0} requires the PyTesseract library but it was not found in your environment. You can install it with pip:
`pip install pytesseract`. Please note that you may need to restart your runtime after installation.
"""

# docstyle-ignore
PYCTCDECODE_IMPORT_ERROR = """
{0} requires the pyctcdecode library but it was not found in your environment. You can install it with pip:
`pip install pyctcdecode`. Please note that you may need to restart your runtime after installation.
"""

# docstyle-ignore
ACCELERATE_IMPORT_ERROR = """
{0} requires the accelerate library >= {ACCELERATE_MIN_VERSION} but it was not found in your environment.
You can install or update it with pip: `pip install --upgrade accelerate`. Please note that you may need to restart your
runtime after installation.
"""

# docstyle-ignore
CCL_IMPORT_ERROR = """
{0} requires the torch ccl library but it was not found in your environment. You can install it with pip:
`pip install oneccl_bind_pt -f https://developer.intel.com/ipex-whl-stable`
Please note that you may need to restart your runtime after installation.
"""

# docstyle-ignore
ESSENTIA_IMPORT_ERROR = """
{0} requires the essentia library but it was not found in your environment. You can install it with pip:
`pip install essentia==2.1b6.dev1034`
Please note that you may need to restart your runtime after installation.
"""

# docstyle-ignore
LIBROSA_IMPORT_ERROR = """
{0} requires the librosa library but it was not found in your environment. You can install it with pip:
`pip install librosa`
Please note that you may need to restart your runtime after installation.
"""

# docstyle-ignore
PRETTY_MIDI_IMPORT_ERROR = """
{0} requires the pretty_midi library but it was not found in your environment. You can install it with pip:
`pip install pretty_midi`
Please note that you may need to restart your runtime after installation.
"""

DECORD_IMPORT_ERROR = """
{0} requires the decord library but it was not found in your environment. You can install it with pip:
`pip install decord`. Please note that you may need to restart your runtime after installation.
"""

CYTHON_IMPORT_ERROR = """
{0} requires the Cython library but it was not found in your environment. You can install it with pip:
`pip install Cython`. Please note that you may need to restart your runtime after installation.
"""

JIEBA_IMPORT_ERROR = """
{0} requires the jieba library but it was not found in your environment. You can install it with pip:
`pip install jieba`. Please note that you may need to restart your runtime after installation.
"""

PEFT_IMPORT_ERROR = """
{0} requires the peft library but it was not found in your environment. You can install it with pip:
`pip install peft`. Please note that you may need to restart your runtime after installation.
"""

JINJA_IMPORT_ERROR = """
{0} requires the jinja library but it was not found in your environment. You can install it with pip:
`pip install jinja2`. Please note that you may need to restart your runtime after installation.
"""

BACKENDS_MAPPING = OrderedDict(
    [
        ("av", (is_av_available, AV_IMPORT_ERROR)),
        ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
        ("cv2", (is_cv2_available, CV2_IMPORT_ERROR)),
        ("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)),
        ("detectron2", (is_detectron2_available, DETECTRON2_IMPORT_ERROR)),
        ("essentia", (is_essentia_available, ESSENTIA_IMPORT_ERROR)),
        ("faiss", (is_faiss_available, FAISS_IMPORT_ERROR)),
        ("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
        ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)),
        ("g2p_en", (is_g2p_en_available, G2P_EN_IMPORT_ERROR)),
        ("pandas", (is_pandas_available, PANDAS_IMPORT_ERROR)),
        ("phonemizer", (is_phonemizer_available, PHONEMIZER_IMPORT_ERROR)),
        ("pretty_midi", (is_pretty_midi_available, PRETTY_MIDI_IMPORT_ERROR)),
        ("levenshtein", (is_levenshtein_available, LEVENSHTEIN_IMPORT_ERROR)),
        ("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)),
        ("protobuf", (is_protobuf_available, PROTOBUF_IMPORT_ERROR)),
        ("pyctcdecode", (is_pyctcdecode_available, PYCTCDECODE_IMPORT_ERROR)),
        ("pytesseract", (is_pytesseract_available, PYTESSERACT_IMPORT_ERROR)),
        ("sacremoses", (is_sacremoses_available, SACREMOSES_IMPORT_ERROR)),
        ("pytorch_quantization", (is_pytorch_quantization_available, PYTORCH_QUANTIZATION_IMPORT_ERROR)),
        ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)),
        ("sklearn", (is_sklearn_available, SKLEARN_IMPORT_ERROR)),
        ("speech", (is_speech_available, SPEECH_IMPORT_ERROR)),
        ("tensorflow_probability", (is_tensorflow_probability_available, TENSORFLOW_PROBABILITY_IMPORT_ERROR)),
        ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)),
        ("tensorflow_text", (is_tensorflow_text_available, TENSORFLOW_TEXT_IMPORT_ERROR)),
        ("timm", (is_timm_available, TIMM_IMPORT_ERROR)),
        ("natten", (is_natten_available, NATTEN_IMPORT_ERROR)),
        ("nltk", (is_nltk_available, NLTK_IMPORT_ERROR)),
        ("tokenizers", (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)),
        ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
        ("torchvision", (is_torchvision_available, TORCHVISION_IMPORT_ERROR)),
        ("vision", (is_vision_available, VISION_IMPORT_ERROR)),
        ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
        ("accelerate", (is_accelerate_available, ACCELERATE_IMPORT_ERROR)),
        ("oneccl_bind_pt", (is_ccl_available, CCL_IMPORT_ERROR)),
        ("decord", (is_decord_available, DECORD_IMPORT_ERROR)),
        ("cython", (is_cython_available, CYTHON_IMPORT_ERROR)),
        ("jieba", (is_jieba_available, JIEBA_IMPORT_ERROR)),
        ("peft", (is_peft_available, PEFT_IMPORT_ERROR)),
        ("jinja", (is_jinja_available, JINJA_IMPORT_ERROR)),
    ]
)


def requires_backends(obj, backends):
    if not isinstance(backends, (list, tuple)):
        backends = [backends]

    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__

    # Raise an error for users who might not realize that classes without "TF" are torch-only
    if "torch" in backends and "tf" not in backends and not is_torch_available() and is_tf_available():
        raise ImportError(PYTORCH_IMPORT_ERROR_WITH_TF.format(name))

    # Raise the inverse error for PyTorch users trying to load TF classes
    if "tf" in backends and "torch" not in backends and is_torch_available() and not is_tf_available():
        raise ImportError(TF_IMPORT_ERROR_WITH_PYTORCH.format(name))

    checks = (BACKENDS_MAPPING[backend] for backend in backends)
    failed = [msg.format(name) for available, msg in checks if not available()]
    if failed:
        raise ImportError("".join(failed))
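
# Hypothetical usage (not part of the upstream file): guard a function on its soft
# dependencies. The keys must exist in BACKENDS_MAPPING; the raised ImportError
# concatenates one install hint per missing backend.
#
#     def load_image(path):
#         requires_backends(load_image, ["vision", "torch"])
#         ...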


class DummyObject(type):
    """
    Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by
    `requires_backends` each time a user tries to access any method of that class.
    """

    def __getattribute__(cls, key):
        if key.startswith("_") and key != "_from_config":
            return super().__getattribute__(key)
        requires_backends(cls, cls._backends)
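
# Hypothetical sketch (not part of the upstream file) of how the generated
# dummy-object placeholders typically use this metaclass: touching any public
# attribute of the class raises the PyTorch install hint instead of an opaque
# AttributeError.
#
#     class BertModel(metaclass=DummyObject):
#         _backends = ["torch"]
#
#         def __init__(self, *args, **kwargs):
#             requires_backends(self, ["torch"])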


def is_torch_fx_proxy(x):
    if is_torch_fx_available():
        import torch.fx

        return isinstance(x, torch.fx.Proxy)
    return False


class _LazyModule(ModuleType):
    """
    Module class that surfaces all objects but only performs associated imports when the objects are requested.
    """

    # Very heavily inspired by optuna.integration._IntegrationModule
    # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
    def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None):
        super().__init__(name)
        self._modules = set(import_structure.keys())
        self._class_to_module = {}
        for key, values in import_structure.items():
            for value in values:
                self._class_to_module[value] = key
        # Needed for autocompletion in an IDE
        self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values()))
        self.__file__ = module_file
        self.__spec__ = module_spec
        self.__path__ = [os.path.dirname(module_file)]
        self._objects = {} if extra_objects is None else extra_objects
        self._name = name
        self._import_structure = import_structure

    # Needed for autocompletion in an IDE
    def __dir__(self):
        result = super().__dir__()
        # The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether
        # they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir.
        for attr in self.__all__:
            if attr not in result:
                result.append(attr)
        return result

    def __getattr__(self, name: str) -> Any:
        if name in self._objects:
            return self._objects[name]
        if name in self._modules:
            value = self._get_module(name)
        elif name in self._class_to_module.keys():
            module = self._get_module(self._class_to_module[name])
            value = getattr(module, name)
        else:
            raise AttributeError(f"module {self.__name__} has no attribute {name}")

        setattr(self, name, value)
        return value

    def _get_module(self, module_name: str):
        try:
            return importlib.import_module("." + module_name, self.__name__)
        except Exception as e:
            raise RuntimeError(
                f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its"
                f" traceback):\n{e}"
            ) from e

    def __reduce__(self):
        return (self.__class__, (self._name, self.__file__, self._import_structure))
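
# Hypothetical sketch (not part of the upstream file) of how a package __init__.py
# typically wires this up: submodules and their classes are only imported on first
# attribute access, keeping `import mypackage` fast.
#
#     import sys
#
#     _import_structure = {"modeling": ["MyModel"], "tokenization": ["MyTokenizer"]}
#     sys.modules[__name__] = _LazyModule(
#         __name__, globals()["__file__"], _import_structure, module_spec=__spec__
#     )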


class OptionalDependencyNotAvailable(BaseException):
    """Internally used error class for signalling an optional dependency was not found."""


def direct_transformers_import(path: str, file="__init__.py") -> ModuleType:
    """Imports transformers directly

    Args:
        path (`str`): The path to the source file
        file (`str`, optional): The file to join with the path. Defaults to "__init__.py".

    Returns:
        `ModuleType`: The resulting imported module
    """
    name = "transformers"
    location = os.path.join(path, file)
    spec = importlib.util.spec_from_file_location(name, location, submodule_search_locations=[path])
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    module = sys.modules[name]
    return module
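
# Hypothetical usage (not part of the upstream file): load the library straight from
# a source checkout, bypassing any pip-installed copy, e.g. in repository tooling
# scripts. The path below assumes the standard src-layout checkout:
#
#     transformers_module = direct_transformers_import("src/transformers")
#     print(transformers_module.__version__)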