TwT-6's picture
Upload 2667 files
256a159 verified
raw
history blame
1.64 kB
import json
from datasets import Dataset, DatasetDict
from opencompass.registry import LOAD_DATASET
from .base import BaseDataset
@LOAD_DATASET.register_module()
class dropDataset(BaseDataset):
    """Loader that turns a DROP-style JSON file into a ``DatasetDict``.

    The JSON file is expected to be a mapping whose values each carry a
    ``'passage'`` string and a ``'qa_pairs'`` list; every QA pair carries a
    ``'question'`` and a ``'validated_answers'`` list.
    """

    @staticmethod
    def get_answers(validated_answers):
        """Collect the unique reference answers of one question.

        For every answer item the first non-empty representation wins, in
        this order: ``'number'``, then ``'date'`` (its ``'day'``, ``'month'``
        and ``'year'`` fields joined by single spaces), then every entry of
        ``'spans'``.

        Args:
            validated_answers: Iterable of answer dicts with the keys
                ``'number'``, ``'date'`` and ``'spans'``.

        Returns:
            list: De-duplicated answers in first-seen order.
        """
        date_fields = ('day', 'month', 'year')
        answers = []
        for answer_item in validated_answers:
            if answer_item['number']:
                answers.append(answer_item['number'])
                continue

            date_parts = [answer_item['date'][key] for key in date_fields]
            if any(date_parts):
                # Join only the filled-in fields so that a missing middle
                # field does not leave a double space in the reference.
                filled = [part for part in date_parts if part]
                answers.append(' '.join(filled).strip())
            else:
                answers.extend(answer_item['spans'])

        # dict.fromkeys de-duplicates while keeping insertion order; a set
        # would make the answer order vary from run to run because string
        # hashing is randomized per process.
        return list(dict.fromkeys(answers))

    @staticmethod
    def load(path, only_number=True):
        """Read the JSON file at ``path`` and build the validation split.

        Args:
            path: Path of the UTF-8 encoded JSON file.
            only_number: When true, skip every question for which none of
                the validated answers has a non-empty ``'number'``.

        Returns:
            DatasetDict: A single ``'validation'`` split whose rows have the
            fields ``'prompt'`` (the passage), ``'question'`` and
            ``'answers'``.
        """
        with open(path, 'r', encoding='utf-8') as f:
            lines = json.load(f)

        dataset_list = []
        for line in lines.values():
            for qa_pair in line['qa_pairs']:
                validated_answers = qa_pair['validated_answers']
                has_number = any(i['number'] for i in validated_answers)
                if only_number and not has_number:
                    continue
                item = {
                    'prompt': line['passage'],
                    'question': qa_pair['question'],
                    'answers': dropDataset.get_answers(validated_answers),
                }
                dataset_list.append(item)

        dataset_list = Dataset.from_list(dataset_list)
        return DatasetDict({'validation': dataset_list})