#!/usr/bin/env python3
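# Recompute the prompt hash of every dataset config under configs/datasets,
# rename files whose `<dataset>_<mode>_<hash>.py` suffix is stale, and update
# any configs that import the renamed modules.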
import argparse
import glob
import hashlib
import json
import os
import re
from multiprocessing import Pool
from typing import List, Union

from mmengine.config import Config, ConfigDict

# from opencompass.utils import get_prompt_hash

# copied from opencompass.utils.get_prompt_hash, for easy use in ci
def get_prompt_hash(dataset_cfg: Union[ConfigDict, List[ConfigDict]]) -> str:
    """Get the hash of the prompt configuration.

    Args:
        dataset_cfg (ConfigDict or list[ConfigDict]): The dataset
            configuration.

    Returns:
        str: The hash of the prompt configuration.
    """
    if isinstance(dataset_cfg, list):
        if len(dataset_cfg) == 1:
            dataset_cfg = dataset_cfg[0]
        else:
            hashes = ','.join([get_prompt_hash(cfg) for cfg in dataset_cfg])
            hash_object = hashlib.sha256(hashes.encode())
            return hash_object.hexdigest()
    # for custom datasets
    if 'infer_cfg' not in dataset_cfg:
        dataset_cfg.pop('abbr', '')
        dataset_cfg.pop('path', '')
        d_json = json.dumps(dataset_cfg.to_dict(), sort_keys=True)
        hash_object = hashlib.sha256(d_json.encode())
        return hash_object.hexdigest()
    # for regular datasets
    if 'reader_cfg' in dataset_cfg.infer_cfg:
        # new config
        reader_cfg = dict(type='DatasetReader',
                          input_columns=dataset_cfg.reader_cfg.input_columns,
                          output_column=dataset_cfg.reader_cfg.output_column)
        dataset_cfg.infer_cfg.reader = reader_cfg
        if 'train_split' in dataset_cfg.infer_cfg.reader_cfg:
            dataset_cfg.infer_cfg.retriever[
                'index_split'] = dataset_cfg.infer_cfg['reader_cfg'][
                    'train_split']
        if 'test_split' in dataset_cfg.infer_cfg.reader_cfg:
            dataset_cfg.infer_cfg.retriever[
                'test_split'] = dataset_cfg.infer_cfg.reader_cfg.test_split
        for k, v in dataset_cfg.infer_cfg.items():
            dataset_cfg.infer_cfg[k]['type'] = v['type'].split('.')[-1]
    # A compromise for the hash consistency
    if 'fix_id_list' in dataset_cfg.infer_cfg.retriever:
        fix_id_list = dataset_cfg.infer_cfg.retriever.pop('fix_id_list')
        dataset_cfg.infer_cfg.inferencer['fix_id_list'] = fix_id_list
    d_json = json.dumps(dataset_cfg.infer_cfg.to_dict(), sort_keys=True)
    hash_object = hashlib.sha256(d_json.encode())
    return hash_object.hexdigest()

# Assuming get_hash is a function that computes the hash of a file
# from get_hash import get_hash
def get_hash(path):
    cfg = Config.fromfile(path)
    for k in cfg.keys():
        if k.endswith('_datasets'):
            return get_prompt_hash(cfg[k])[:6]
    print(f'Could not find *_datasets in {path}')
    return None

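# Return (old_path, new_path) if the hash suffix in the filename no longer
# matches the freshly computed prompt hash; otherwise return (None, None).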
def check_and_rename(filepath):
    base_name = os.path.basename(filepath)
    match = re.match(r'(.*)_(gen|ppl|ll)_(.*).py', base_name)
    if match:
        dataset, mode, old_hash = match.groups()
        new_hash = get_hash(filepath)
        if not new_hash:
            return None, None
        if old_hash != new_hash:
            new_name = f'{dataset}_{mode}_{new_hash}.py'
            new_file = os.path.join(os.path.dirname(filepath), new_name)
            print(f'Rename {filepath} to {new_file}')
            return filepath, new_file
    return None, None

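# Rewrite references in a single config file, replacing each renamed module
# name with its new hash-suffixed name.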
def update_imports(data):
    python_file, name_pairs = data
    for filepath, new_file in name_pairs:
        old_name = os.path.basename(filepath)[:-3]
        new_name = os.path.basename(new_file)[:-3]
        if not os.path.exists(python_file):
            return
        with open(python_file, 'r') as file:
            filedata = file.read()
        # Replace the old name with new name
        new_data = filedata.replace(old_name, new_name)
        if filedata != new_data:
            with open(python_file, 'w') as file:
                file.write(new_data)
            # print(f"Updated imports in {python_file}")

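# Entry point: optionally takes specific config paths under configs/datasets;
# with no arguments, every dataset config is scanned.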
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('python_files', nargs='*')
    args = parser.parse_args()

    root_folder = 'configs/datasets'
    if args.python_files:
        python_files = [
            i for i in args.python_files if i.startswith(root_folder)
        ]
    else:
        python_files = glob.glob(f'{root_folder}/**/*.py', recursive=True)

    # Use multiprocessing to speed up the check and rename process
    with Pool(16) as p:
        name_pairs = p.map(check_and_rename, python_files)
    name_pairs = [pair for pair in name_pairs if pair[0] is not None]
    if not name_pairs:
        return
    with Pool(16) as p:
        p.starmap(os.rename, name_pairs)

    root_folder = 'configs'
    python_files = glob.glob(f'{root_folder}/**/*.py', recursive=True)
    update_data = [(python_file, name_pairs) for python_file in python_files]
    with Pool(16) as p:
        p.map(update_imports, update_data)


if __name__ == '__main__':
    main()