#!/usr/bin/env python3
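"""Recompute the prompt-hash suffix of dataset config files.

For every ``<dataset>_<gen|ppl|ll>_<hash>.py`` config under ``configs/datasets``
(or only the files passed on the command line), the prompt hash is recomputed;
if it no longer matches the suffix in the file name, the file is renamed and
imports under ``configs`` are rewritten to point at the new name.
"""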
import argparse
import glob
import hashlib
import json
import os
import re
from multiprocessing import Pool
from typing import List, Union

from mmengine.config import Config, ConfigDict


# Copied from opencompass.utils.get_prompt_hash (instead of importing it) so
# this script is easy to run standalone in CI.
def get_prompt_hash(dataset_cfg: Union[ConfigDict, List[ConfigDict]]) -> str:
    """Get the hash of the prompt configuration.

    Args:
        dataset_cfg (ConfigDict or list[ConfigDict]): The dataset
            configuration.

    Returns:
        str: The hash of the prompt configuration.
    """
    if isinstance(dataset_cfg, list):
        if len(dataset_cfg) == 1:
            dataset_cfg = dataset_cfg[0]
        else:
            hashes = ','.join([get_prompt_hash(cfg) for cfg in dataset_cfg])
            hash_object = hashlib.sha256(hashes.encode())
            return hash_object.hexdigest()
    # for custom datasets
    if 'infer_cfg' not in dataset_cfg:
        dataset_cfg.pop('abbr', '')
        dataset_cfg.pop('path', '')
        d_json = json.dumps(dataset_cfg.to_dict(), sort_keys=True)
        hash_object = hashlib.sha256(d_json.encode())
        return hash_object.hexdigest()
    # for regular datasets
    if 'reader_cfg' in dataset_cfg.infer_cfg:
        # new config
        reader_cfg = dict(type='DatasetReader',
                          input_columns=dataset_cfg.reader_cfg.input_columns,
                          output_column=dataset_cfg.reader_cfg.output_column)
        dataset_cfg.infer_cfg.reader = reader_cfg
        if 'train_split' in dataset_cfg.infer_cfg.reader_cfg:
            dataset_cfg.infer_cfg.retriever[
                'index_split'] = dataset_cfg.infer_cfg['reader_cfg'][
                    'train_split']
        if 'test_split' in dataset_cfg.infer_cfg.reader_cfg:
            dataset_cfg.infer_cfg.retriever[
                'test_split'] = dataset_cfg.infer_cfg.reader_cfg.test_split
        for k, v in dataset_cfg.infer_cfg.items():
            dataset_cfg.infer_cfg[k]['type'] = v['type'].split('.')[-1]
    # A compromise for the hash consistency
    if 'fix_id_list' in dataset_cfg.infer_cfg.retriever:
        fix_id_list = dataset_cfg.infer_cfg.retriever.pop('fix_id_list')
        dataset_cfg.infer_cfg.inferencer['fix_id_list'] = fix_id_list
    d_json = json.dumps(dataset_cfg.infer_cfg.to_dict(), sort_keys=True)
    hash_object = hashlib.sha256(d_json.encode())
    return hash_object.hexdigest()


# Compute the 6-character prompt hash of the first ``*_datasets`` variable
# found in a config file.
def get_hash(path):
    cfg = Config.fromfile(path)
    for k in cfg.keys():
        if k.endswith('_datasets'):
            return get_prompt_hash(cfg[k])[:6]
    print(f'Could not find *_datasets in {path}')
    return None


def check_and_rename(filepath):
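    """Recompute the hash suffix for a ``*_{gen,ppl,ll}_*.py`` config file.

    Returns ``(old_path, new_path)`` if the hash suffix is stale and the file
    should be renamed, otherwise ``(None, None)``.
    """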
    base_name = os.path.basename(filepath)
    match = re.match(r'(.*)_(gen|ppl|ll)_(.*)\.py$', base_name)
    if match:
        dataset, mode, old_hash = match.groups()
        new_hash = get_hash(filepath)
        if not new_hash:
            return None, None
        if old_hash != new_hash:
            new_name = f'{dataset}_{mode}_{new_hash}.py'
            new_file = os.path.join(os.path.dirname(filepath), new_name)
            print(f'Rename {filepath} to {new_file}')
            return filepath, new_file
    return None, None


def update_imports(data):
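    """Rewrite references to renamed configs inside a single python file.

    ``data`` is a ``(python_file, name_pairs)`` tuple so the function can be
    mapped over a multiprocessing pool.
    """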
    python_file, name_pairs = data
    for filepath, new_file in name_pairs:
        old_name = os.path.basename(filepath)[:-3]
        new_name = os.path.basename(new_file)[:-3]
        if not os.path.exists(python_file):
            return
        with open(python_file, 'r') as file:
            filedata = file.read()
        # Replace the old name with new name
        new_data = filedata.replace(old_name, new_name)
        if filedata != new_data:
            with open(python_file, 'w') as file:
                file.write(new_data)
            # print(f"Updated imports in {python_file}")


def main():
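    """Check the given config files (or all of ``configs/datasets``), rename
    the ones whose hash suffix is stale, then update imports across all
    configs."""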
    parser = argparse.ArgumentParser()
    parser.add_argument('python_files', nargs='*')
    args = parser.parse_args()

    root_folder = 'configs/datasets'
    if args.python_files:
        python_files = [
            i for i in args.python_files if i.startswith(root_folder)
        ]
    else:
        python_files = glob.glob(f'{root_folder}/**/*.py', recursive=True)

    # Use multiprocessing to speed up the check and rename process
    with Pool(16) as p:
        name_pairs = p.map(check_and_rename, python_files)
    name_pairs = [pair for pair in name_pairs if pair[0] is not None]
    if not name_pairs:
        return
    with Pool(16) as p:
        p.starmap(os.rename, name_pairs)
    root_folder = 'configs'
    python_files = glob.glob(f'{root_folder}/**/*.py', recursive=True)
    update_data = [(python_file, name_pairs) for python_file in python_files]
    with Pool(16) as p:
        p.map(update_imports, update_data)


if __name__ == '__main__':
    main()