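"""Recompute prompt hashes for dataset configs and rename stale files.

Dataset config filenames under configs/datasets carry a 6-character prompt
hash suffix (<name>_<gen|ppl|ll>_<hash>.py). This script recomputes the hash
from each config, renames files whose suffix no longer matches, and rewrites
references to the renamed modules in every python file under configs/.
"""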
import argparse
import glob
import hashlib
import json
import os
import re
from multiprocessing import Pool
from typing import List, Union

from mmengine.config import Config, ConfigDict


def get_prompt_hash(dataset_cfg: Union[ConfigDict, List[ConfigDict]]) -> str:
    """Get the hash of the prompt configuration.

    Args:
        dataset_cfg (ConfigDict or list[ConfigDict]): The dataset
            configuration.

    Returns:
        str: The hash of the prompt configuration.
    """
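    # A list of configs: a single entry is unwrapped; otherwise hash the
    # joined per-config digests.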
    if isinstance(dataset_cfg, list):
        if len(dataset_cfg) == 1:
            dataset_cfg = dataset_cfg[0]
        else:
            hashes = ','.join([get_prompt_hash(cfg) for cfg in dataset_cfg])
            hash_object = hashlib.sha256(hashes.encode())
            return hash_object.hexdigest()
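
    # Configs without an infer_cfg: hash the whole config, dropping the
    # abbr/path fields so renames and path changes do not alter the hash.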
    if 'infer_cfg' not in dataset_cfg:
        dataset_cfg.pop('abbr', '')
        dataset_cfg.pop('path', '')
        d_json = json.dumps(dataset_cfg.to_dict(), sort_keys=True)
        hash_object = hashlib.sha256(d_json.encode())
        return hash_object.hexdigest()
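
    # Configs whose infer_cfg carries a reader_cfg: rebuild a minimal reader
    # dict from the dataset-level reader_cfg, copy the reader splits onto the
    # retriever, and keep only class names in 'type' fields so module paths
    # do not affect the hash.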
    if 'reader_cfg' in dataset_cfg.infer_cfg:
        reader_cfg = dict(type='DatasetReader',
                          input_columns=dataset_cfg.reader_cfg.input_columns,
                          output_column=dataset_cfg.reader_cfg.output_column)
        dataset_cfg.infer_cfg.reader = reader_cfg
        if 'train_split' in dataset_cfg.infer_cfg.reader_cfg:
            dataset_cfg.infer_cfg.retriever[
                'index_split'] = dataset_cfg.infer_cfg['reader_cfg'][
                    'train_split']
        if 'test_split' in dataset_cfg.infer_cfg.reader_cfg:
            dataset_cfg.infer_cfg.retriever[
                'test_split'] = dataset_cfg.infer_cfg.reader_cfg.test_split
        for k, v in dataset_cfg.infer_cfg.items():
            dataset_cfg.infer_cfg[k]['type'] = v['type'].split('.')[-1]
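
    # Some configs keep fix_id_list on the retriever; move it onto the
    # inferencer before hashing so both layouts produce the same hash.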
    if 'fix_id_list' in dataset_cfg.infer_cfg.retriever:
        fix_id_list = dataset_cfg.infer_cfg.retriever.pop('fix_id_list')
        dataset_cfg.infer_cfg.inferencer['fix_id_list'] = fix_id_list
    d_json = json.dumps(dataset_cfg.infer_cfg.to_dict(), sort_keys=True)
    hash_object = hashlib.sha256(d_json.encode())
    return hash_object.hexdigest()


def get_hash(path):
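    """Return the first 6 hex characters of the prompt hash for a config
    file, or None if the file defines no *_datasets variable."""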
    cfg = Config.fromfile(path)
    for k in cfg.keys():
        if k.endswith('_datasets'):
            return get_prompt_hash(cfg[k])[:6]
    print(f'Could not find *_datasets in {path}')
    return None


def check_and_rename(filepath):
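    """Return (old_path, new_path) if the file's hash suffix is stale.

    Returns (None, None) when the name does not follow the
    <name>_<gen|ppl|ll>_<hash>.py pattern, the hash cannot be computed, or
    the suffix is already up to date.
    """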
    base_name = os.path.basename(filepath)
    match = re.match(r'(.*)_(gen|ppl|ll)_(.*)\.py', base_name)
    if match:
        dataset, mode, old_hash = match.groups()
        new_hash = get_hash(filepath)
        if not new_hash:
            return None, None
        if old_hash != new_hash:
            new_name = f'{dataset}_{mode}_{new_hash}.py'
            new_file = os.path.join(os.path.dirname(filepath), new_name)
            print(f'Rename {filepath} to {new_file}')
            return filepath, new_file
    return None, None


def update_imports(data):
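    """Rewrite references to renamed config modules inside one python file.

    data is a (python_file, name_pairs) tuple, where name_pairs holds the
    (old_path, new_path) pairs produced by check_and_rename; every occurrence
    of an old module basename is replaced with the new one.
    """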
    python_file, name_pairs = data
    for filepath, new_file in name_pairs:
        old_name = os.path.basename(filepath)[:-3]
        new_name = os.path.basename(new_file)[:-3]
        if not os.path.exists(python_file):
            return
        with open(python_file, 'r') as file:
            filedata = file.read()

        new_data = filedata.replace(old_name, new_name)
        if filedata != new_data:
            with open(python_file, 'w') as file:
                file.write(new_data)


def main():
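    """Rename stale dataset configs, then update references across configs/."""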
    parser = argparse.ArgumentParser()
    parser.add_argument('python_files', nargs='*')
    args = parser.parse_args()

    root_folder = 'configs/datasets'
    if args.python_files:
        python_files = [
            i for i in args.python_files if i.startswith(root_folder)
        ]
    else:
        python_files = glob.glob(f'{root_folder}/**/*.py', recursive=True)
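
    # Check every candidate config in parallel; keep only the stale
    # (old_path, new_path) pairs.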
    with Pool(16) as p:
        name_pairs = p.map(check_and_rename, python_files)
    name_pairs = [pair for pair in name_pairs if pair[0] is not None]
    if not name_pairs:
        return
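
    # Apply the renames, then rewrite references to the old names in every
    # python file under configs/.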
    with Pool(16) as p:
        p.starmap(os.rename, name_pairs)
    root_folder = 'configs'
    python_files = glob.glob(f'{root_folder}/**/*.py', recursive=True)
    update_data = [(python_file, name_pairs) for python_file in python_files]
    with Pool(16) as p:
        p.map(update_imports, update_data)


if __name__ == '__main__':
    main()