|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Utility that checks the custom inits of Transformers are well-defined: Transformers uses init files that delay the |
|
import of an object to when it's actually needed. This is to avoid the main init importing all models, which would |
|
make the line `import transformers` very slow when the user has all optional dependencies installed. The inits with |
|
delayed imports have two halves: one defining a dictionary `_import_structure` which maps modules to the name of the |
|
objects in each module, and one in `TYPE_CHECKING` which looks like a normal init for type-checkers. The goal of this |
|
script is to check the objects defined in both halves are the same. |
|
|
|
This also checks the main init properly references all submodules, even if it doesn't import anything from them: every |
|
submodule should be defined as a key of `_import_structure`, with an empty list as value potentially, or the submodule |
|
won't be importable. |
|
|
|
Use from the root of the repo with: |
|
|
|
```bash |
|
python utils/check_inits.py |
|
``` |
|
|
|
for a check that will error in case of inconsistencies (used by `make repo-consistency`). |
|
|
|
There is no auto-fix possible here sadly :-( |
|
""" |
|
|
|
import collections |
|
import os |
|
import re |
|
from pathlib import Path |
|
from typing import Dict, List, Optional, Tuple |
|
|
|
|
|
|
|
# Location of the package sources, relative to the repository root (the script is meant to be run from there).
PATH_TO_TRANSFORMERS = "src/transformers"


# Captures the backend name `xxx` in `is_xxx_available`.
# NOTE: the trailing `()` is an empty capture group, not literal parentheses, so `findall` yields 2-tuples
# `(name, "")`; `find_backend` reads element 0 of each tuple.
_re_backend = re.compile(r"is\_([a-z_]*)_available()")

# Matches a one-line `_import_structure = {xxx}` and captures `xxx`.
_re_one_line_import_struct = re.compile(r"^_import_structure\s+=\s+\{([^\}]+)\}")

# Matches an indented `"key": ["foo", "bar"]` entry and captures what is between the square brackets.
_re_import_struct_key_value = re.compile(r'\s+"\S*":\s+\[([^\]]*)\]')

# Matches a (possibly indented) `if not is_xxx_available()` test.
_re_test_backend = re.compile(r"^\s*if\s+not\s+is\_[a-z_]*\_available\(\)")

# Matches `_import_structure["xxx"].append("foo")` and captures `foo`.
_re_import_struct_add_one = re.compile(r'^\s*_import_structure\["\S*"\]\.append\("(\S*)"\)')

# Matches `_import_structure[xxx].extend([...])` or `_import_structure[xxx] = [...]` and captures what is between
# the square brackets of the list.
_re_import_struct_add_many = re.compile(r"^\s*_import_structure\[\S*\](?:\.extend\(|\s*=\s+)\[([^\]]*)\]")

# Matches an indented `"foo",` line and captures `foo`.
_re_quote_object = re.compile(r'^\s+"([^"]+)",')

# Matches an indented `[...]` and captures what is between the square brackets.
_re_between_brackets = re.compile(r"^\s+\[([^\]]+)\]")

# Matches `from xxx import yyy` preceded by whitespace and captures `yyy` up to the end of the line. The capture must
# not start with `(` or whitespace, so the opening line of a parenthesized multi-line import does not match.
_re_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")

# Matches a (possibly indented) `try:` line.
_re_try = re.compile(r"^\s*try:")

# Matches a (possibly indented) `else:` line.
_re_else = re.compile(r"^\s*else:")
|
|
|
|
|
def find_backend(line: str) -> Optional[str]:
    """
    Extract the backend name(s) tested on a line of an init file.

    Args:
        line (`str`): A single line of code taken from an init.

    Returns:
        Optional[`str`]: `None` when the line is not a backend test of the form `if not is_xxx_available()`.
        Otherwise the backend name. When several `is_xxx_available` checks appear on the line, their names are
        sorted alphabetically and joined with `_and_` (for instance `xxx_and_yyy`).
    """
    if not _re_test_backend.search(line):
        return None
    # `_re_backend` holds two capture groups, so `findall` returns tuples; the backend name is the first item.
    names = sorted(found[0] for found in _re_backend.findall(line))
    return "_and_".join(names)
|
|
|
|
|
def parse_init(init_file) -> Optional[Tuple[Dict[str, List[str]], Dict[str, List[str]]]]:
    """
    Read an init_file and parse (per backend) the `_import_structure` objects defined and the `TYPE_CHECKING` objects
    defined.

    The file is scanned once, top to bottom, with a single line cursor: first the `_import_structure` half (objects
    without a backend, then one section per backend test), then the `TYPE_CHECKING` half in the same order.

    Args:
        init_file (`str`): Path to the init file to inspect.

    Returns:
        `Optional[Tuple[Dict[str, List[str]], Dict[str, List[str]]]]`: A tuple of two dictionaries mapping backends to list of
        imported objects, one for the `_import_structure` part of the init and one for the `TYPE_CHECKING` part of the
        init. Objects that do not depend on a backend are stored under the key `"none"`. Returns `None` if the init is
        not a custom init (no line starts with `_import_structure = {`).
    """
    with open(init_file, "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()

    # Move the cursor to the line opening the `_import_structure` definition.
    line_index = 0
    while line_index < len(lines) and not lines[line_index].startswith("_import_structure = {"):
        line_index += 1

    # No `_import_structure` in this file: this is a traditional init, nothing to check.
    if line_index >= len(lines):
        return None

    # First half, part 1: objects listed before the first backend test (or before `if TYPE_CHECKING`).
    # NOTE(review): there is no bound check on `line_index` here, so this assumes the file contains an
    # `if TYPE_CHECKING` line or a backend test after `_import_structure`.
    objects = []
    while not lines[line_index].startswith("if TYPE_CHECKING") and find_backend(lines[line_index]) is None:
        line = lines[line_index]
        # Whole `_import_structure = {...}` on a single line: take the content of every `[...]` list in it.
        if _re_one_line_import_struct.search(line):
            content = _re_one_line_import_struct.search(line).groups()[0]
            imports = re.findall(r"\[([^\]]+)\]", content)
            for imp in imports:
                # `[1:-1]` strips the surrounding quotes of each item.
                objects.extend([obj[1:-1] for obj in imp.split(", ")])
            line_index += 1
            continue
        # `"key": ["a", "b"]` entry written on one line.
        single_line_import_search = _re_import_struct_key_value.search(line)
        if single_line_import_search is not None:
            imports = [obj[1:-1] for obj in single_line_import_search.groups()[0].split(", ") if len(obj) > 0]
            objects.extend(imports)
        # One quoted object per line at 8 spaces of indentation: drop the 8 spaces + opening quote and the
        # last three characters (presumably `",` + newline).
        elif line.startswith(" " * 8 + '"'):
            objects.append(line[9:-3])
        line_index += 1

    # Objects that do not depend on a backend are stored under the key "none".
    import_dict_objects = {"none": objects}

    # First half, part 2: backend-specific objects, until the `TYPE_CHECKING` half starts.
    while not lines[line_index].startswith("if TYPE_CHECKING"):
        # A line of the form `if not is_xxx_available()` opens a backend section.
        backend = find_backend(lines[line_index])
        # Only accept the backend test when the previous line is a `try:`.
        if _re_try.search(lines[line_index - 1]) is None:
            backend = None

        if backend is not None:
            line_index += 1

            # Skip forward to the `else:` line; the objects of the backend are read after it.
            while _re_else.search(lines[line_index]) is None:
                line_index += 1

            line_index += 1

            objects = []
            # Consume lines as long as they are empty or indented by at least 4 spaces.
            while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 4):
                line = lines[line_index]
                # `_import_structure["xxx"].append("foo")`
                if _re_import_struct_add_one.search(line) is not None:
                    objects.append(_re_import_struct_add_one.search(line).groups()[0])
                # `_import_structure[xxx].extend([...])` or `_import_structure[xxx] = [...]` on one line
                elif _re_import_struct_add_many.search(line) is not None:
                    imports = _re_import_struct_add_many.search(line).groups()[0].split(", ")
                    imports = [obj[1:-1] for obj in imports if len(obj) > 0]
                    objects.extend(imports)
                # An indented `[...]` list on its own line
                elif _re_between_brackets.search(line) is not None:
                    imports = _re_between_brackets.search(line).groups()[0].split(", ")
                    imports = [obj[1:-1] for obj in imports if len(obj) > 0]
                    objects.extend(imports)
                # An indented `"foo",` line
                elif _re_quote_object.search(line) is not None:
                    objects.append(_re_quote_object.search(line).groups()[0])
                # Fallbacks on the raw line for a quoted object at 8 or 12 spaces of indentation.
                elif line.startswith(" " * 8 + '"'):
                    objects.append(line[9:-3])
                elif line.startswith(" " * 12 + '"'):
                    objects.append(line[13:-3])
                line_index += 1

            import_dict_objects[backend] = objects
        else:
            line_index += 1

    # Second half, part 1: the cursor is on `if TYPE_CHECKING`; collect the imports that precede the first backend
    # test (or a line starting with `else`).
    objects = []
    while (
        line_index < len(lines)
        and find_backend(lines[line_index]) is None
        and not lines[line_index].startswith("else")
    ):
        line = lines[line_index]
        # `from xxx import a, b` written on one line
        single_line_import_search = _re_import.search(line)
        if single_line_import_search is not None:
            objects.extend(single_line_import_search.groups()[0].split(", "))
        # One object per line at 8 spaces: drop the indentation and the last two characters
        # (presumably `,` + newline).
        elif line.startswith(" " * 8):
            objects.append(line[8:-2])
        line_index += 1

    type_hint_objects = {"none": objects}

    # Second half, part 2: backend-specific imports, until the end of the file.
    while line_index < len(lines):
        # A line of the form `if not is_xxx_available()` opens a backend section.
        backend = find_backend(lines[line_index])
        # Only accept the backend test when the previous line is a `try:`.
        if _re_try.search(lines[line_index - 1]) is None:
            backend = None

        if backend is not None:
            line_index += 1

            # Skip forward to the `else:` line; the imports of the backend are read after it.
            while _re_else.search(lines[line_index]) is None:
                line_index += 1

            line_index += 1

            objects = []
            # Consume lines as long as they are empty or indented by at least 8 spaces.
            # NOTE(review): no bound check on `line_index`; assumes a less-indented line follows the section.
            while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8):
                line = lines[line_index]
                single_line_import_search = _re_import.search(line)
                if single_line_import_search is not None:
                    objects.extend(single_line_import_search.groups()[0].split(", "))
                # One object per line at 12 spaces of indentation.
                elif line.startswith(" " * 12):
                    objects.append(line[12:-2])
                line_index += 1

            type_hint_objects[backend] = objects
        else:
            line_index += 1

    return import_dict_objects, type_hint_objects
|
|
|
|
|
def analyze_results(import_dict_objects: Dict[str, List[str]], type_hint_objects: Dict[str, List[str]]) -> List[str]:
    """
    Analyze the differences between _import_structure objects and TYPE_CHECKING objects found in an init.

    Args:
        import_dict_objects (`Dict[str, List[str]]`):
            A dictionary mapping backend names (`"none"` for the objects independent of any specific backend) to
            list of imported objects.
        type_hint_objects (`Dict[str, List[str]]`):
            A dictionary mapping backend names (`"none"` for the objects independent of any specific backend) to
            list of imported objects.

    Returns:
        `List[str]`: The list of errors corresponding to mismatches. Empty when both sides agree. When the two
        dictionaries do not list the same backends in the same order, a single error is returned and no
        per-backend comparison is done.
    """

    def find_duplicates(seq):
        # Names appearing more than once, in order of first appearance.
        return [k for k, v in collections.Counter(seq).items() if v > 1]

    # Lists (not sets) are compared, so the order of the backends has to match too.
    if list(import_dict_objects.keys()) != list(type_hint_objects.keys()):
        return ["Both sides of the init do not have the same backends!"]

    errors = []
    for key, import_objects in import_dict_objects.items():
        hint_objects = type_hint_objects[key]

        duplicate_imports = find_duplicates(import_objects)
        if duplicate_imports:
            errors.append(f"Duplicate _import_structure definitions for: {duplicate_imports}")
        duplicate_type_hints = find_duplicates(hint_objects)
        if duplicate_type_hints:
            errors.append(f"Duplicate TYPE_CHECKING objects for: {duplicate_type_hints}")

        # Build each set once: it serves both the equality check and the membership tests below, instead of
        # scanning the lists again for every object.
        import_set = set(import_objects)
        hint_set = set(hint_objects)
        if import_set != hint_set:
            name = "base imports" if key == "none" else f"{key} backend"
            errors.append(f"Differences for {name}:")
            # Iterate over the original lists so the reports keep their order (and repeat for duplicated names).
            errors.extend(
                f" {a} in TYPE_HINT but not in _import_structure." for a in hint_objects if a not in import_set
            )
            errors.extend(
                f" {a} in _import_structure but not in TYPE_HINT." for a in import_objects if a not in hint_set
            )
    return errors
|
|
|
|
|
def check_all_inits():
    """
    Walk `PATH_TO_TRANSFORMERS`, parse every `__init__.py` found and compare its two halves.

    Inits for which `parse_init` returns `None` are skipped.

    Raises:
        ValueError: If at least one init has mismatches. The message concatenates the report of every failing
        init, separated by blank lines.
    """
    failures = []
    for root, _, files in os.walk(PATH_TO_TRANSFORMERS):
        if "__init__.py" not in files:
            continue
        fname = os.path.join(root, "__init__.py")
        parsed = parse_init(fname)
        if parsed is None:
            continue
        import_side, type_hint_side = parsed
        errors = analyze_results(import_side, type_hint_side)
        if not errors:
            continue
        # Prefix the first error with the file it was found in.
        errors[0] = f"Problem in {fname}, both halves do not define the same objects.\n{errors[0]}"
        failures.append("\n".join(errors))
    if failures:
        raise ValueError("\n\n".join(failures))
|
|
|
|
|
def get_transformers_submodules() -> List[str]:
    """
    Returns the list of Transformers submodules.

    A submodule is either a folder (at any depth) under `PATH_TO_TRANSFORMERS` that contains at least one `.py`
    file, reported as a dotted path, or a file sitting directly in `PATH_TO_TRANSFORMERS` (`__init__.py`
    excluded). Folders whose name starts with `_` are ignored and not explored.

    Returns:
        `List[str]`: The dotted names of the submodules, relative to `PATH_TO_TRANSFORMERS`.
    """
    submodules = []
    for path, directories, files in os.walk(PATH_TO_TRANSFORMERS):
        # Prune private folders in place, before iterating. Removing entries from `directories` while looping
        # over it made the loop skip the folder following each removed one. The slice assignment keeps the same
        # list object, which is what `os.walk` looks at to decide where to descend.
        directories[:] = [folder for folder in directories if not folder.startswith("_")]
        for folder in directories:
            folder_path = Path(path) / folder
            # Ignore folders that contain no Python file.
            if not any(folder_path.glob("*.py")):
                continue
            short_path = str(folder_path.relative_to(PATH_TO_TRANSFORMERS))
            submodules.append(short_path.replace(os.path.sep, "."))
        for fname in files:
            if fname == "__init__.py":
                continue
            short_path = str((Path(path) / fname).relative_to(PATH_TO_TRANSFORMERS))
            submodule = short_path.replace(".py", "").replace(os.path.sep, ".")
            # Only names without any dot left are kept, i.e. files located directly in the root folder.
            if len(submodule.split(".")) == 1:
                submodules.append(submodule)
    return submodules
|
|
|
|
|
# Submodules that `check_submodules` does not require to be registered in the main init.
IGNORE_SUBMODULES = [
    "convert_pytorch_checkpoint_to_tf2",
    "modeling_flax_pytorch_utils",
    "models.esm.openfold_utils",
    "modeling_attn_mask_utils",
    "safetensors_conversion",
]
|
|
|
|
|
def check_submodules():
    """
    Check all submodules of Transformers are properly registered in the main init. Error otherwise.

    Raises:
        ValueError: If a submodule returned by `get_transformers_submodules` is neither listed in
        `IGNORE_SUBMODULES` nor found among the keys of `_import_structure`.
    """
    # Function-scope import: `transformers` is only needed by this check.
    from transformers.utils import direct_transformers_import

    transformers = direct_transformers_import(PATH_TO_TRANSFORMERS)

    import_structure_keys = set(transformers._import_structure.keys())

    # Also collect every `import_structure["xxx"]` subscript written in the source of the main init.
    # NOTE(review): presumably this catches keys that are only added when an optional backend is installed and
    # are therefore missing from the imported `_import_structure` - confirm.
    # The encoding is explicit so the read does not depend on the platform default.
    with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8") as f:
        init_content = f.read()
    import_structure_keys.update(re.findall(r"import_structure\[\"([^\"]*)\"\]", init_content))

    module_not_registered = [
        module
        for module in get_transformers_submodules()
        if module not in IGNORE_SUBMODULES and module not in import_structure_keys
    ]

    if module_not_registered:
        list_of_modules = "\n".join(f"- {module}" for module in module_not_registered)
        raise ValueError(
            "The following submodules are not properly registered in the main init of Transformers:\n"
            f"{list_of_modules}\n"
            "Make sure they appear somewhere in the keys of `_import_structure` with an empty list as value."
        )
|
|
|
|
|
# When run as a script, execute both consistency checks; each one raises a `ValueError` on failure.
if __name__ == "__main__":
    check_all_inits()
    check_submodules()
|
|