"""Simple Dataset Reader."""
import random
from typing import Dict, List, Optional, Union

import torch
from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer

from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.registry import ICL_DATASET_READERS
from opencompass.utils.types import (_check_dataset, _check_str,
                                     _check_type_list)


@ICL_DATASET_READERS.register_module()
class DatasetReader:
    """In-context Learning Dataset Reader Class. Generates a DatasetReader
    instance from a given dataset.

    Attributes:
        dataset (:obj:`Dataset` or :obj:`DatasetDict`): The dataset to be
            read.
        input_columns (:obj:`List[str]` or :obj:`str`): A list of column
            names (or a single column name) in the dataset that represent(s)
            the input field.
        output_column (:obj:`str`): A column name in the dataset that
            represents the prediction field.
        input_template (:obj:`PromptTemplate`, optional): An instance of the
            :obj:`PromptTemplate` class, used to format the input field
            contents during the retrieval process (used by some retrieval
            methods).
        output_template (:obj:`PromptTemplate`, optional): An instance of the
            :obj:`PromptTemplate` class, used to format the output field
            contents during the retrieval process (used by some learnable
            retrieval methods).
        train_split (str): The name of the training split. Defaults to
            'train'.
        train_range (int or float or str, optional): The size of the partial
            training dataset to load.
            If None, the entire training dataset will be loaded.
            If int or float, a random subset of the specified size will be
            loaded.
            If str, the subset will be selected with the specified index
            slice (e.g. "[:100]" for the first 100 examples, "[100:200]"
            for the second 100 examples, etc.). Defaults to None.
        test_split (str): The name of the test split. Defaults to 'test'.
        test_range (int or float or str, optional): The size of the partial
            test dataset to load.
            If None, the entire test dataset will be loaded.
            If int or float, a random subset of the specified size will be
            loaded.
            If str, the subset will be selected with the specified index
            slice (e.g. "[:100]" for the first 100 examples, "[100:200]"
            for the second 100 examples, etc.). Defaults to None.
    """
    dataset = None
    input_template = None
    output_template = None
    # Referenced by generate_input_output_field_prompt() but previously never
    # defined; default it to None so the plain-concatenation fallback is used
    # when no combined template is configured.
    input_output_template = None
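
    # A minimal usage sketch (the dataset contents and column names below
    # are hypothetical, not part of this module):
    #
    #   from datasets import Dataset
    #   data = Dataset.from_dict({'question': ['1+1=?', '2+2=?'],
    #                             'answer': ['2', '4']})
    #   reader = DatasetReader(data, input_columns=['question'],
    #                          output_column='answer')
    #   reader.generate_input_field_prompt(reader.dataset['train'][0])
    #   # -> '1+1=?'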

    def __init__(self,
                 dataset: Union[Dataset, DatasetDict, str],
                 input_columns: Union[List[str], str],
                 output_column: Optional[str],
                 input_template: Optional[PromptTemplate] = None,
                 output_template: Optional[PromptTemplate] = None,
                 train_split: str = 'train',
                 train_range: Optional[Union[int, float, str]] = None,
                 test_split: str = 'test',
                 test_range: Optional[Union[int, float, str]] = None) -> None:
        self.input_columns = _check_type_list(input_columns, [List, str])
        if isinstance(self.input_columns, str):
            self.input_columns = self.input_columns.split()
        self.output_column = None
        if output_column:
            self.output_column = _check_str(output_column)

        train_range = _check_type_list(train_range, [None, int, float, str])
        test_range = _check_type_list(test_range, [None, int, float, str])

        if input_template is not None:
            self.input_template = PromptTemplate._check_prompt_template(
                input_template)
        if output_template is not None:
            self.output_template = PromptTemplate._check_prompt_template(
                output_template)

        self.dataset = _check_dataset(dataset)
        if isinstance(self.dataset, Dataset):
            self.dataset = DatasetDict({
                'train': self.dataset,
                'test': self.dataset
            })

        for origin_split, mapped_split, split_range in [[
                train_split, 'train', train_range
        ], [test_split, 'test', test_range]]:
            self.dataset[mapped_split] = load_partial_dataset(
                self.dataset[origin_split], size=split_range)

    def generate_input_field_prompt(self, entry: Dict) -> str:
        """Generate a prompt for the input field based on the provided
        :obj:`entry` data.

        Args:
            entry (:obj:`Dict`): A piece of data to be used for generating
                the prompt.

        Returns:
            :obj:`str`: The generated prompt.
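
        Example:
            A doctest-style sketch, assuming a hypothetical ``reader``
            configured with ``input_columns=['question', 'choice']`` and no
            ``input_template``:

            >>> reader.generate_input_field_prompt(
            ...     {'question': '1+1=?', 'choice': 'A. 1  B. 2'})
            '1+1=? A. 1  B. 2'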
        """
        prompt = None
        if self.input_template is None:
            prompt = ' '.join([str(entry[ctx]) for ctx in self.input_columns])
        else:
            prompt = self.input_template.generate_item(entry)
        return prompt

    def generate_input_field_corpus(self,
                                    dataset: Union[Dataset, DatasetDict],
                                    split: Optional[str] = None) -> List[str]:
        """Generate corpus for input field.

        Args:
            dataset (:obj:`Dataset` or :obj:`DatasetDict`): A
                :obj:`datasets.Dataset` or :obj:`datasets.DatasetDict`
                instance.
            split (:obj:`str`, optional): The split of the dataset to use.
                If :obj:`None`, the entire dataset will be used. Defaults to
                ``None``.

        Returns:
            :obj:`List[str]`: A list of generated input field prompts.
        """
        if split is not None:
            dataset = dataset[split]
        corpus = []
        for entry in dataset:
            corpus.append(self.generate_input_field_prompt(entry))
        return corpus

    def generate_output_field_prompt(self, entry: Dict) -> str:
        """Generate a prompt for the output field based on the provided
        :obj:`entry` data.

        Args:
            entry (:obj:`Dict`): A piece of data to be used for generating
                the prompt.

        Returns:
            :obj:`str`: The generated prompt.
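
        Example:
            A doctest-style sketch, assuming a hypothetical ``reader``
            configured with ``output_column='answer'`` and no
            ``output_template``:

            >>> reader.generate_output_field_prompt({'answer': 2})
            '2'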
        """
        prompt = None
        if self.output_template is None:
            prompt = str(entry[self.output_column])
        else:
            prompt = self.output_template.generate_item(entry)
        return prompt

    def generate_output_field_corpus(self,
                                     dataset: Union[Dataset, DatasetDict],
                                     split: Optional[str] = None) -> List[str]:
        """Generate corpus for output field.

        Args:
            dataset (:obj:`Dataset` or :obj:`DatasetDict`): A
                :obj:`datasets.Dataset` or :obj:`datasets.DatasetDict`
                instance.
            split (:obj:`str`, optional): The split of the dataset to use.
                If :obj:`None`, the entire dataset will be used. Defaults to
                ``None``.

        Returns:
            :obj:`List[str]`: A list of generated output field prompts.
        """
        if split is not None:
            dataset = dataset[split]
        corpus = []
        for entry in dataset:
            corpus.append(self.generate_output_field_prompt(entry))
        return corpus

    def generate_input_output_field_prompt(self, entry: Dict) -> str:
        """Generate a prompt for the input-output field based on the
        provided :obj:`entry` data.

        Args:
            entry (:obj:`Dict`): A piece of data to be used for generating
                the prompt.

        Returns:
            :obj:`str`: The generated prompt.
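
        Example:
            A doctest-style sketch, assuming a hypothetical ``reader``
            configured with ``input_columns=['question']``,
            ``output_column='answer'`` and no ``input_output_template``:

            >>> reader.generate_input_output_field_prompt(
            ...     {'question': '1+1=?', 'answer': 2})
            '1+1=? 2'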
        """
        prompt = None
        if self.input_output_template is None:
            # Wrap every field in str() for consistency with
            # generate_input_field_prompt, so non-string columns join safely.
            prompt = ' '.join([str(entry[ctx]) for ctx in self.input_columns]
                              + [str(entry[self.output_column])])
        else:
            prompt = self.input_output_template.generate_item(entry)
        return prompt
    @staticmethod
    def _check_dataset_reader(obj) -> 'DatasetReader':
        if isinstance(obj, DatasetReader):
            return obj
        else:
            raise TypeError(f'Expected a DatasetReader object, but got {obj}')

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def __repr__(self):
        return (f'DatasetReader({{\n    dataset: {self.dataset},'
                f'\n    input_columns: {self.input_columns},\n'
                f'    output_column: {self.output_column}\n}})')

def load_partial_dataset(
        dataset: Dataset,
        size: Optional[Union[int, float, str]] = None) -> Dataset:
    """Load a partial dataset.

    Args:
        dataset (Dataset): A :obj:`datasets.Dataset` instance.
        size (int or float or str, optional): The size of the partial
            dataset to load. If None, the entire dataset will be loaded.
            If int or float, a random subset of the specified size will be
            loaded. If str, the subset will be selected with the specified
            index slice (e.g. "[:100]" for the first 100 examples,
            "[100:200]" for the second 100 examples, etc.). Defaults to
            None.
    """
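    # Illustration (hypothetical sizes): size=100 keeps a random subset of
    # 100 examples, size=0.1 keeps a random 10% of the dataset, and
    # size='[:100]' keeps the first 100 examples in their original order.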
    total_size = len(dataset)
    index_list = list(range(total_size))
    if isinstance(size, (int, float)):
        if size >= total_size or size <= 0:
            return dataset
        if size > 0 and size < 1:
            size = int(size * total_size)
        # Seed with the (now integer) size so the same subset is drawn
        # across runs for a given size.
        rand = random.Random(x=size)
        rand.shuffle(index_list)
        dataset = dataset.select(index_list[:size])
    elif isinstance(size, str):
        # ``size`` is a slice expression such as "[:100]", applied to the
        # index list via eval; only trusted input should be passed here.
        dataset = dataset.select(eval(f'index_list{size}'))
    return dataset


class DatasetEncoder(torch.utils.data.Dataset):
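    """Tokenize a list of texts into a map-style PyTorch dataset.

    Each item carries ``input_ids``, ``attention_mask`` and a ``metadata``
    dict with the example's index, token length and original text. Either a
    ``model_name`` (loaded via :obj:`AutoTokenizer`) or a ready-made
    ``tokenizer`` must be provided.
    """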

    def __init__(self,
                 datalist: List,
                 model_name=None,
                 tokenizer=None) -> None:
        self.datalist = datalist
        if model_name is None and tokenizer is None:
            raise ValueError('model_name and tokenizer cannot both be None')
        if tokenizer is not None:
            self.tokenizer = tokenizer
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            # Decoder-only models usually ship without a pad token; reuse
            # the EOS token and pad on the left so generation is unaffected.
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
            self.tokenizer.padding_side = 'left'
        self.encode_dataset = []
        self.init_dataset()
        self.datalist_length = len(self.encode_dataset)

    def init_dataset(self):
        for idx, data in enumerate(self.datalist):
            # Tokenize one text at a time; truncation keeps each sequence
            # within the tokenizer's model_max_length.
            tokenized_data = self.tokenizer.encode_plus(data,
                                                        truncation=True,
                                                        return_tensors='pt',
                                                        verbose=False)
            self.encode_dataset.append({
                'input_ids': tokenized_data.input_ids[0],
                'attention_mask': tokenized_data.attention_mask[0],
                'metadata': {
                    'id': idx,
                    'len': len(tokenized_data.input_ids[0]),
                    'text': data
                }
            })

    def __len__(self):
        return self.datalist_length

    def __getitem__(self, idx):
        return self.encode_dataset[idx]
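

if __name__ == '__main__':
    # Minimal smoke-test sketch (not part of the library API): the texts and
    # the 'gpt2' checkpoint below are arbitrary choices for a quick local
    # check and require downloading the tokenizer on first run.
    encoder = DatasetEncoder(['hello world', 'foo bar'], model_name='gpt2')
    print(len(encoder), encoder[0]['metadata'])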