import json

from datasets import Dataset, DatasetDict

from opencompass.registry import LOAD_DATASET

from .base import BaseDataset


@LOAD_DATASET.register_module()
class dropDataset(BaseDataset):
    """Dataset loader that flattens a DROP-style JSON file into QA records.

    The input file is a JSON object whose values each hold a ``passage``
    and a list of ``qa_pairs``; every kept QA pair becomes one record with
    the keys ``prompt``, ``question`` and ``answers``.
    """

    # Order in which the date fields are joined into a single answer string.
    _DATE_KEYS = ('day', 'month', 'year')

    @staticmethod
    def get_answers(validated_answers):
        """Collect the unique reference answers of one question.

        For each entry, the first non-empty field wins, in this priority:
        ``number``, then ``date`` (day, month and year joined by single
        spaces and stripped), then every item of ``spans``.

        Args:
            validated_answers: Iterable of dicts with the keys ``number``,
                ``date`` (a dict with ``day``/``month``/``year``) and
                ``spans``.

        Returns:
            list: The distinct answers in first-seen order.
        """
        answers = []
        for answer_item in validated_answers:
            if answer_item['number']:
                answers.append(answer_item['number'])
                continue

            date = answer_item['date']
            date_parts = [date[key] for key in dropDataset._DATE_KEYS]
            if any(date_parts):
                # Only the outer whitespace is stripped, so an empty middle
                # field still leaves a double space (unchanged behaviour).
                answers.append(' '.join(date_parts).strip())
            else:
                answers.extend(answer_item['spans'])

        # De-duplicate while keeping first-seen order. A plain ``set`` would
        # yield an order that changes between runs because string hashing is
        # randomised per process, making the dataset non-reproducible.
        return list(dict.fromkeys(answers))

    @staticmethod
    def load(path, only_number=True):
        """Read the JSON file at ``path`` and build the validation split.

        Args:
            path: Path to a UTF-8 encoded JSON file whose top level is an
                object (only its values are used).
            only_number: When true, skip QA pairs in which none of the
                validated answers has a non-empty ``number`` field.

        Returns:
            DatasetDict: A single ``'validation'`` split whose rows have the
            keys ``prompt`` (the passage), ``question`` and ``answers``.
        """
        with open(path, 'r', encoding='utf-8') as f:
            lines = json.load(f)

        dataset_list = []
        for line in lines.values():
            passage = line['passage']
            for qa_pair in line['qa_pairs']:
                validated_answers = qa_pair['validated_answers']
                has_number = any(i['number'] for i in validated_answers)
                if only_number and not has_number:
                    continue
                dataset_list.append({
                    'prompt': passage,
                    'question': qa_pair['question'],
                    'answers': dropDataset.get_answers(validated_answers),
                })

        return DatasetDict({'validation': Dataset.from_list(dataset_list)})