import csv
import json
import os.path as osp

from datasets import Dataset, DatasetDict

from opencompass.registry import LOAD_DATASET

from .base import BaseDataset


@LOAD_DATASET.register_module()
class CEvalDataset(BaseDataset):

    @staticmethod
    def load(path: str, name: str):
        dataset = {}
        for split in ['dev', 'val', 'test']:
            filename = osp.join(path, split, f'{name}_{split}.csv')
            with open(filename, encoding='utf-8') as f:
                reader = csv.reader(f)
                header = next(reader)
                for row in reader:
                    item = dict(zip(header, row))
                    # Some rows (e.g. the test split) lack these columns;
                    # fill empty strings so all splits share the same schema.
                    item.setdefault('explanation', '')
                    item.setdefault('answer', '')
                    dataset.setdefault(split, []).append(item)
        dataset = {i: Dataset.from_list(dataset[i]) for i in dataset}
        return DatasetDict(dataset)


class CEvalDatasetClean(BaseDataset):

    # Load the contamination annotations of C-Eval from
    # https://github.com/liyucheng09/Contamination_Detector
    @staticmethod
    def load_contamination_annotations(path, split='val'):
        import requests

        assert split == 'val', 'Now we only have annotations for val set'
        annotation_cache_path = osp.join(
            path, split, 'ceval_contamination_annotations.json')
        # Reuse a previously downloaded copy if it exists.
        if osp.exists(annotation_cache_path):
            with open(annotation_cache_path, 'r') as f:
                annotations = json.load(f)
            return annotations
        link_of_annotations = 'https://github.com/liyucheng09/Contamination_Detector/releases/download/v0.1.1rc/ceval_annotations.json'  # noqa
        annotations = json.loads(requests.get(link_of_annotations).text)
        # Cache the annotations next to the val split for later runs.
        with open(annotation_cache_path, 'w') as f:
            json.dump(annotations, f)
        return annotations

    @staticmethod
    def load(path: str, name: str):
        dataset = {}
        for split in ['dev', 'val', 'test']:
            # Contamination annotations are only available for the val split.
            if split == 'val':
                annotations = CEvalDatasetClean.load_contamination_annotations(
                    path, split)
            filename = osp.join(path, split, f'{name}_{split}.csv')
            with open(filename, encoding='utf-8') as f:
                reader = csv.reader(f)
                header = next(reader)
                for row_index, row in enumerate(reader):
                    item = dict(zip(header, row))
                    item.setdefault('explanation', '')
                    item.setdefault('answer', '')
                    if split == 'val':
                        row_id = f'{name}-{row_index}'
                        # Attach the contamination label, or mark the row as
                        # 'not labeled' when no annotation is available.
                        if row_id in annotations:
                            item['is_clean'] = annotations[row_id][0]
                        else:
                            item['is_clean'] = 'not labeled'
                    dataset.setdefault(split, []).append(item)
        dataset = {i: Dataset.from_list(dataset[i]) for i in dataset}
        return DatasetDict(dataset)
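

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes the C-Eval CSVs have already been downloaded locally and laid out
# as <data_root>/{dev,val,test}/<subject>_<split>.csv, which is the layout the
# loaders above expect; the data_root path and subject name below are
# placeholder values. The __main__ guard keeps this inert on import; running
# the file directly requires package context because of the relative import
# of BaseDataset.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    data_root = './data/ceval/formal_ceval'  # hypothetical local path
    subject = 'computer_network'             # any C-Eval subject name

    ds = CEvalDataset.load(path=data_root, name=subject)
    print(ds)            # DatasetDict with 'dev', 'val' and 'test' splits
    print(ds['val'][0])  # one question dict with choices and 'answer'

    # The "clean" variant additionally tags val items with an 'is_clean' label
    # taken from the Contamination_Detector annotations.
    clean_ds = CEvalDatasetClean.load(path=data_root, name=subject)
    print(clean_ds['val'][0]['is_clean'])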