|
import ast |
|
import importlib |
|
import os |
|
from typing import Optional, Sequence |
|
|
|
class DeleteSpecificNodes(ast.NodeTransformer):
    """AST transformer that deletes an explicit set of nodes from a tree.

    Nodes are matched by identity (the exact node objects passed in), not by
    structural equality, so two textually identical statements are told apart.

    Deleting statements can leave a compound statement (``if``, ``try``,
    ``for``, ``with``, ``def``, ``class``, an ``except`` handler, ...) with an
    empty ``body``. That is not valid Python, and ``ast.unparse`` would emit
    broken source. In that case a ``pass`` statement is inserted.
    """

    def __init__(self, nodes_to_remove: list[ast.AST]):
        self.nodes_to_remove = nodes_to_remove
        # id()-keyed set gives an O(1) check per visited node instead of a
        # list scan. This is safe because self.nodes_to_remove keeps the
        # nodes alive, so their ids cannot be reused while this transformer
        # exists.
        self._ids_to_remove = {id(n) for n in nodes_to_remove}

    def visit(self, node: ast.AST) -> Optional[ast.AST]:
        """Return None (i.e. delete ``node``) if it was marked, else recurse."""
        if id(node) in self._ids_to_remove:
            return None
        return super().visit(node)

    def generic_visit(self, node: ast.AST) -> ast.AST:
        """Visit children, then patch a ``body`` that deletion left empty."""
        node = super().generic_visit(node)
        body = getattr(node, 'body', None)
        # An empty Module is valid (empty file). Any other statement list
        # named ``body`` must contain at least one statement.
        if isinstance(body, list) and not body and not isinstance(node, ast.Module):
            body.append(ast.Pass())
        return node
|
|
|
def convert_to_relative_import(module_name: str, original_parent_module_name: Optional[str]) -> str:
    """Turn an absolute dotted module name into a one-level relative import.

    Only the final dotted component is kept, because every rewritten file is
    written into the same flat output folder. If that component equals
    ``original_parent_module_name`` (the package whose ``__init__.py`` is being
    rewritten), the result is the bare ``'.'``.

    Args:
        module_name (str): Absolute module name, e.g. ``'pkg.sub.mod'``.
        original_parent_module_name (Optional[str]): Package name to collapse
            to ``'.'``, or None.

    Returns:
        str: ``'.'`` or ``'.' + <last component>``.
    """
    _, _, leaf = module_name.rpartition('.')
    tail = '' if leaf == original_parent_module_name else leaf
    return '.' + tail
|
|
|
def find_module_file(module_name: str) -> str:
    """Import ``module_name`` and return the path of the file that defines it.

    Args:
        module_name (str): Absolute dotted module name, e.g. ``'pkg.sub'``.

    Returns:
        str: The module's ``__file__`` (for a package, its ``__init__.py``).

    Raises:
        ValueError: If ``module_name`` is empty, or the imported module has no
            ``__file__`` (missing entirely, as for builtin modules like
            ``sys``, or set to None).
        ImportError: Propagated from ``importlib.import_module`` when the
            module cannot be imported.
    """
    if not module_name:
        raise ValueError(f'Invalid input: module_name={module_name!r}')
    module = importlib.import_module(module_name)
    # getattr: builtin modules have no __file__ attribute at all. A plain
    # attribute access would raise AttributeError instead of the ValueError
    # below.
    module_file = getattr(module, '__file__', None)
    if module_file is None:
        raise ValueError(f'Could not find file for module: {module_name}')
    return module_file
|
|
|
def _flatten_import(node: ast.ImportFrom, flatten_imports_prefix: Sequence[str]) -> bool: |
|
"""Returns True if import should be flattened. |
|
|
|
Checks whether the node starts the same as any of the imports in |
|
flatten_imports_prefix. |
|
""" |
|
for import_prefix in flatten_imports_prefix: |
|
if node.module is not None and node.module.startswith(import_prefix): |
|
return True |
|
return False |
|
|
|
def _remove_import(node: ast.ImportFrom, remove_imports_prefix: Sequence[str]) -> bool: |
|
"""Returns True if import should be removed. |
|
|
|
Checks whether the node starts the same as any of the imports in |
|
remove_imports_prefix. |
|
""" |
|
for import_prefix in remove_imports_prefix: |
|
if node.module is not None and node.module.startswith(import_prefix): |
|
return True |
|
return False |
|
|
|
def process_file(file_path: str, folder_path: str, flatten_imports_prefix: Sequence[str], remove_imports_prefix: Sequence[str]) -> list[str]:
    """Rewrite one Python source file into ``folder_path``.

    The file is parsed, and then:

    * ``from X import ...`` statements whose module matches
      ``remove_imports_prefix`` are deleted. Removal is checked first, so it
      takes precedence over flattening.
    * ``from X import ...`` statements whose module matches
      ``flatten_imports_prefix`` are rewritten into one-level relative
      imports, and the file defining ``X`` is returned for later processing.
    * Class definitions whose name starts with ``Composer`` are deleted.
    * Single-target ``__all__ = ...`` assignments are deleted.

    The result is written to ``folder_path/<basename of file_path>``. A
    package ``__init__.py`` is written as ``folder_path/<package dir>.py``.

    Args:
        file_path (str): Path of the source file to read.
        folder_path (str): Directory the rewritten file is written to.
        flatten_imports_prefix (Sequence[str]): Module prefixes to flatten.
        remove_imports_prefix (Sequence[str]): Module prefixes to remove.

    Returns:
        list[str]: Source files of the flattened modules, possibly with
        duplicates, for the caller to process next.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        source = f.read()

    # For a package __init__.py, the directory name is the package name. It
    # is used both for relative-import conversion and for the output file
    # name. os.path (rather than splitting on '/') keeps this correct with
    # Windows separators, repeated separators, or a bare relative
    # '__init__.py'.
    file_name = os.path.basename(file_path)
    parent_module_name: Optional[str] = None
    new_filename = file_name
    if file_name == '__init__.py':
        package_dir_name = os.path.basename(os.path.dirname(os.path.abspath(file_path)))
        parent_module_name = package_dir_name
        new_filename = package_dir_name + '.py'

    # filename= makes any SyntaxError point at the offending file.
    tree = ast.parse(source, filename=file_path)

    new_files_to_process: list[str] = []
    nodes_to_remove: list[ast.AST] = []
    for node in ast.walk(tree):
        if isinstance(node, ast.ImportFrom) and node.module is not None:
            if _remove_import(node, remove_imports_prefix):
                nodes_to_remove.append(node)
            elif _flatten_import(node, flatten_imports_prefix):
                # Resolve the source file before node.module is overwritten
                # with its relative form.
                module_path = find_module_file(node.module)
                node.module = convert_to_relative_import(node.module, parent_module_name)
                new_files_to_process.append(module_path)
        elif isinstance(node, ast.ClassDef) and node.name.startswith('Composer'):
            nodes_to_remove.append(node)
        elif (isinstance(node, ast.Assign) and len(node.targets) == 1
              and isinstance(node.targets[0], ast.Name)
              and node.targets[0].id == '__all__'):
            nodes_to_remove.append(node)

    new_tree = DeleteSpecificNodes(nodes_to_remove).visit(tree)
    # Check before opening the output. open(..., 'w') truncates immediately,
    # and for files listed from folder_path the output path is the input
    # file itself.
    assert new_tree is not None

    new_file_path = os.path.join(folder_path, new_filename)
    with open(new_file_path, 'w', encoding='utf-8') as f:
        f.write(ast.unparse(new_tree))

    return new_files_to_process
|
|
|
def edit_files_for_hf_compatibility(folder: str, flatten_imports_prefix: Sequence[str]=('llmfoundry',), remove_imports_prefix: Sequence[str]=('composer', 'omegaconf', 'llmfoundry.metrics')) -> None:
    """Edit files to be compatible with Hugging Face Hub.

    Every ``.py`` file directly inside ``folder`` is rewritten by
    ``process_file`` (output goes back into ``folder``). Source files of
    modules it flattens are then processed the same way, transitively, until
    no new files are discovered. Each file is processed at most once.

    Args:
        folder (str): The folder to process.
        flatten_imports_prefix (Sequence[str], optional): Sequence of prefixes to flatten. Defaults to ('llmfoundry',).
        remove_imports_prefix (Sequence[str], optional): Sequence of prefixes to remove. Takes precedence over flattening.
            Defaults to ('composer', 'omegaconf', 'llmfoundry.metrics').
    """
    # Sorted so the traversal order does not depend on os.listdir's arbitrary
    # ordering. That order also decides which source wins if two files map
    # to the same output name.
    files_to_process = [
        os.path.join(folder, filename)
        for filename in sorted(os.listdir(folder))
        if filename.endswith('.py')
    ]
    # Keys are absolute paths, so one file reached both via the (possibly
    # relative) folder listing and via a module's __file__ is processed once.
    seen = {os.path.abspath(path) for path in files_to_process}

    while files_to_process:
        to_process = files_to_process.pop()
        # Skip entries that are not regular .py files (e.g. a module whose
        # __file__ points at a compiled extension).
        if not (os.path.isfile(to_process) and to_process.endswith('.py')):
            continue
        discovered = process_file(to_process, folder, flatten_imports_prefix, remove_imports_prefix)
        for module_file in discovered:
            key = os.path.abspath(module_file)
            if key not in seen:
                seen.add(key)
                files_to_process.append(module_file)