import csv
import json
import os.path as osp

from datasets import Dataset, DatasetDict

from opencompass.registry import LOAD_DATASET

from .base import BaseDataset

# Splits read by the loaders below; each one is expected at
# ``<path>/<split>/<name>_<split>.csv``.
_CEVAL_SPLITS = ('dev', 'val', 'test')

# Upper bound (seconds) on the annotation download so that a stalled
# connection cannot block dataset loading forever.
_ANNOTATION_DOWNLOAD_TIMEOUT = 60


def _read_ceval_rows(path, name, split):
    """Read one C-Eval CSV split into a list of row dicts.

    The first CSV row is used as the header. Every following row is zipped
    with the header, and the ``explanation`` and ``answer`` keys are filled
    with an empty string when the file does not provide them.

    Args:
        path: Root directory of the dataset.
        name: Subject name used as the CSV file prefix.
        split: Split name; selects both the sub-directory and file suffix.

    Returns:
        A list with one dict per data row, in file order.
    """
    filename = osp.join(path, split, f'{name}_{split}.csv')
    rows = []
    with open(filename, encoding='utf-8') as f:
        reader = csv.reader(f)
        header = next(reader)
        for row in reader:
            item = dict(zip(header, row))
            item.setdefault('explanation', '')
            item.setdefault('answer', '')
            rows.append(item)
    return rows


def _to_dataset_dict(rows_by_split):
    """Convert ``{split: [row, ...]}`` into a ``DatasetDict``."""
    return DatasetDict({
        split: Dataset.from_list(rows)
        for split, rows in rows_by_split.items()
    })


@LOAD_DATASET.register_module()
class CEvalDataset(BaseDataset):
    """Loader for the C-Eval CSV files of a single subject."""

    @staticmethod
    def load(path: str, name: str):
        """Load the dev/val/test splits of subject ``name`` under ``path``.

        Splits whose CSV holds no data rows are omitted from the result.

        Returns:
            A ``DatasetDict`` keyed by split name.
        """
        dataset = {}
        for split in _CEVAL_SPLITS:
            rows = _read_ceval_rows(path, name, split)
            if rows:
                dataset[split] = rows
        return _to_dataset_dict(dataset)


class CEvalDatasetClean(BaseDataset):
    """C-Eval loader that tags every ``val`` row with an ``is_clean`` label.

    The labels are the contamination annotations of CEval from
    https://github.com/liyucheng09/Contamination_Detector
    """

    @staticmethod
    def load_contamination_annotations(path, split='val'):
        """Return the contamination annotations, downloading them if needed.

        The annotations are cached at
        ``<path>/<split>/ceval_contamination_annotations.json``; when that
        file exists it is read and no network access happens.

        Args:
            path: Root directory of the dataset (also hosts the cache file).
            split: Must be ``'val'``; no other split is annotated.

        Returns:
            The decoded JSON annotations.

        Raises:
            AssertionError: If ``split`` is not ``'val'``.
            requests.HTTPError: If the download answers with an error status.
        """
        assert split == 'val', 'Now we only have annotations for val set'
        annotation_cache_path = osp.join(
            path, split, 'ceval_contamination_annotations.json')
        if osp.exists(annotation_cache_path):
            with open(annotation_cache_path, 'r', encoding='utf-8') as f:
                return json.load(f)

        # Imported only on a cache miss so that a cached run does not need
        # the ``requests`` package at all.
        import requests

        link_of_annotations = 'https://github.com/liyucheng09/Contamination_Detector/releases/download/v0.1.1rc/ceval_annotations.json'  # noqa
        response = requests.get(link_of_annotations,
                                timeout=_ANNOTATION_DOWNLOAD_TIMEOUT)
        # Fail with the HTTP error itself rather than a confusing JSON
        # decode error on an HTML error page.
        response.raise_for_status()
        annotations = json.loads(response.text)
        with open(annotation_cache_path, 'w', encoding='utf-8') as f:
            json.dump(annotations, f)
        return annotations

    @staticmethod
    def load(path: str, name: str):
        """Load the dev/val/test splits and label the ``val`` rows.

        Each ``val`` row gets an ``is_clean`` field: the first element of
        the annotation stored under ``'<name>-<row_index>'`` (0-based index
        among the data rows), or ``'not labeled'`` when no such annotation
        exists. Splits whose CSV holds no data rows are omitted.

        Returns:
            A ``DatasetDict`` keyed by split name.
        """
        dataset = {}
        for split in _CEVAL_SPLITS:
            annotations = None
            if split == 'val':
                # Fetched before the CSV is opened, as annotations are only
                # available for this split.
                annotations = CEvalDatasetClean.load_contamination_annotations(
                    path, split)
            rows = _read_ceval_rows(path, name, split)
            if annotations is not None:
                for row_index, item in enumerate(rows):
                    row_id = f'{name}-{row_index}'
                    if row_id in annotations:
                        item['is_clean'] = annotations[row_id][0]
                    else:
                        item['is_clean'] = 'not labeled'
            if rows:
                dataset[split] = rows
        return _to_dataset_dict(dataset)