TwT-6's picture
Upload 2667 files
256a159 verified
raw
history blame
692 Bytes
from abc import abstractstaticmethod
from typing import Dict, Optional, Union
from datasets import Dataset, DatasetDict
from opencompass.openicl import DatasetReader
class BaseDataset:
    """Base class that pairs a loaded dataset with a ``DatasetReader``.

    The constructor calls :meth:`load` with the keyword arguments it was
    given, stores the result on ``self.dataset``, and wraps it in a
    ``DatasetReader`` stored on ``self.reader``. Subclasses are expected to
    provide :meth:`load`.
    """

    def __init__(self, reader_cfg: Optional[Dict] = None, **kwargs):
        """Load the dataset and build its reader.

        Args:
            reader_cfg: Keyword arguments forwarded to ``DatasetReader``.
                ``None`` (the default) is treated as an empty mapping.
            **kwargs: Forwarded unchanged to :meth:`load`.
        """
        # Default to ``None`` rather than ``{}``: a dict literal default is a
        # single object shared by every call, and an explicit ``None`` (which
        # the ``Optional`` annotation permits) could not be ``**``-unpacked.
        if reader_cfg is None:
            reader_cfg = {}
        self.dataset = self.load(**kwargs)
        self._init_reader(**reader_cfg)

    def _init_reader(self, **kwargs):
        """Create ``self.reader`` from ``self.dataset`` and reader options."""
        self.reader = DatasetReader(self.dataset, **kwargs)

    @property
    def train(self):
        """Return the ``'train'`` entry of the reader's dataset."""
        return self.reader.dataset['train']

    @property
    def test(self):
        """Return the ``'test'`` entry of the reader's dataset."""
        return self.reader.dataset['test']

    # NOTE(review): this class does not use ``ABCMeta``, so the abstract
    # marker is not enforced at instantiation time; the base implementation
    # simply returns ``None``.
    @abstractstaticmethod
    def load(**kwargs) -> Union[Dataset, DatasetDict]:
        """Load and return the dataset; to be implemented by subclasses.

        Args:
            **kwargs: The extra keyword arguments passed to ``__init__``.

        Returns:
            A ``Dataset`` or ``DatasetDict``.
        """
        pass