|
from datasets import DatasetDict, load_dataset |
|
|
|
from opencompass.registry import LOAD_DATASET |
|
|
|
from .base import BaseDataset |
|
|
|
|
|
@LOAD_DATASET.register_module() |
|
class CivilCommentsDataset(BaseDataset): |
|
|
|
@staticmethod |
|
def load(**kwargs): |
|
train_dataset = load_dataset(**kwargs, split='train') |
|
test_dataset = load_dataset(**kwargs, split='test') |
|
|
|
def pre_process(example): |
|
example['label'] = int(example['toxicity'] >= 0.5) |
|
example['choices'] = ['no', 'yes'] |
|
return example |
|
|
|
def remove_columns(dataset): |
|
return dataset.remove_columns([ |
|
'severe_toxicity', 'obscene', 'threat', 'insult', |
|
'identity_attack', 'sexual_explicit' |
|
]) |
|
|
|
train_dataset = remove_columns(train_dataset) |
|
test_dataset = remove_columns(test_dataset) |
|
test_dataset = test_dataset.shuffle(seed=42) |
|
test_dataset = test_dataset.select(list(range(10000))) |
|
test_dataset = test_dataset.map(pre_process) |
|
|
|
return DatasetDict({ |
|
'train': train_dataset, |
|
'test': test_dataset, |
|
}) |
|
|