lolcats / src /dataloaders /alpaca_clean.py
ariG23498's picture
ariG23498 HF staff
chore: adding lolcats configs scrc and src
ae81e0f
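"""
Alpaca training dataloaders.

We adopt the original Alpaca prompt template, which (for samples without an
`input` field) looks roughly like:
```
Below is an instruction that describes a task.
Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Response:
{response}
```
See `PROMPT_DICT` for the exact templates, including the variant with an
`### Input:` section used when a sample provides extra context.
"""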
from functools import partial
from os.path import join
from datasets import load_metric, load_dataset
from .utils import (
get_lm_loader, get_seq2seq_loader,
convert_to_hf_dataset,
get_tokenizer_from_config,
download_scrolls_metric as download_metric
)
from .utils.packing import ConcatDataset
# Alpaca prompt templates, filled in with `str.format_map(sample)`.
# 'prompt_input' is for samples that provide an `input` context;
# 'prompt_no_input' is for samples that do not.
PROMPT_DICT = dict(
    prompt_input="".join([
        "Below is an instruction that describes a task, paired with an input that provides further context. ",
        "Write a response that appropriately completes the request.\n\n",
        "### Instruction:\n{instruction}\n\n",
        "### Input:\n{input}\n\n",
        "### Response:\n",
    ]),
    prompt_no_input="".join([
        "Below is an instruction that describes a task. ",
        "Write a response that appropriately completes the request.\n\n",
        "### Instruction:\n{instruction}\n\n",
        "### Response:\n",
    ]),
)
def load_data(name: str, dataset_config: dict, pretrained_model_config: dict,
              preprocess_config: dict, **loader_kwargs: any):
    """
    Shared function to load dataset from experiment config
    -> e.g., see configs/experiments/distill_alpaca_clean_lr1e-2.yaml

    Args:
        name: Dataset name from the experiment config (unused by this loader).
        dataset_config: Keyword arguments for `datasets.load_dataset`, plus the
            loader-only keys 'chunk_size' and 'concat_data' (and optionally
            'pose_kwargs'), which are stripped before calling `load_dataset`.
            Must also contain 'cache_dir', which is forwarded to both
            `load_dataset` and `convert_to_hf_dataset`.
        pretrained_model_config: Passed to `get_tokenizer_from_config`.
        preprocess_config: Unused by this loader.
        **loader_kwargs: Forwarded to `get_lm_loader` / `get_seq2seq_loader`.

    Returns:
        Dict with 'train', 'validation' and 'test' dataloaders. Each loader's
        `.dataset` gets `tokenizer` and `metric` attributes attached (`metric`
        is None if the ROUGE metric could not be loaded).
    """
    # Misc. setup
    cache_dir = dataset_config['cache_dir']
    input_len = dataset_config['chunk_size']
    concat_data = dataset_config['concat_data']

    # Setup tokenizer
    tokenizer = get_tokenizer_from_config(pretrained_model_config)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        print(f'Setting tokenizer.pad_token to {tokenizer.pad_token}')
    tokenizer.padding_side = 'left'  # for decoder-only generation

    # Get initial (untokenized) splits
    train_set, val_set, test_set, raw_columns = _load_raw_splits(dataset_config, cache_dir)

    # Convert to dicts of {input_ids, attention_mask, labels}. Train/val append
    # the answer to the prompt; test keeps the prompt only and stores the
    # reference answer in `labels` (see `template_and_tokenize`).
    train_set = _tokenize_split(train_set, tokenizer, raw_columns, include_label=True)
    val_set = _tokenize_split(val_set, tokenizer, raw_columns, include_label=True)
    test_set = _tokenize_split(test_set, tokenizer, raw_columns, include_label=False)

    # Chunk together train and val sets
    if concat_data:
        train_set = ConcatDataset(train_set, chunk_size=input_len)
        val_set = ConcatDataset(val_set, chunk_size=input_len)

    # Get dataloaders
    dataloaders = {
        'train': get_lm_loader(train_set, tokenizer, 'train', input_len, **loader_kwargs),
        'validation': get_lm_loader(val_set, tokenizer, 'validation', input_len, **loader_kwargs),
        'test': get_seq2seq_loader(test_set, tokenizer, 'test', **loader_kwargs),
    }

    # Evaluation metric (best-effort: on failure we log and continue with None)
    try:
        metric = load_metric(download_metric(), 'gov_report')  # hack but we want rouge
    except Exception as e:
        print(f'Error loading metric: {e}')
        metric = None

    # Finishing touches: make tokenizer and metric accessible from each dataset
    for loader in dataloaders.values():
        loader.dataset.tokenizer = tokenizer
        loader.dataset.metric = metric
    return dataloaders


def _load_raw_splits(dataset_config: dict, cache_dir, num_heldout: int = 200):
    """
    Load the raw (untokenized) train / validation / test splits.

    Returns:
        (train_set, val_set, test_set, raw_columns), where `raw_columns` lists
        the raw feature names to drop once the splits are tokenized.
    """
    # Keys that configure this loader rather than `load_dataset`
    ignore_kwargs = ['concat_data', 'chunk_size', 'pose_kwargs']
    dataset = load_dataset(
        **{k: v for k, v in dataset_config.items() if k not in ignore_kwargs}
    )

    if dataset_config.get('name') == 'samsum':  # hack
        # Map SAMSum onto the Alpaca {instruction, input, output} schema
        dataset = dataset.rename_column('dialogue', 'input')
        dataset = dataset.rename_column('summary', 'output')
        instruction = 'Summarize this dialogue.'
        for split in dataset.keys():
            dataset[split] = dataset[split].add_column(
                'instruction', [instruction] * len(dataset[split])
            )
        train_set = dataset['train']
        return train_set, dataset['validation'], dataset['test'], list(train_set.features)

    # Otherwise only the 'train' split is used: its first `num_heldout` rows are
    # held out and serve as BOTH the validation and the test set.
    dataset = dataset['train']
    train_set = convert_to_hf_dataset(
        [dataset[ix] for ix in range(num_heldout, len(dataset))], cache_dir)
    val_set = convert_to_hf_dataset([dataset[ix] for ix in range(num_heldout)], cache_dir)
    test_set = convert_to_hf_dataset([dataset[ix] for ix in range(num_heldout)], cache_dir)
    return train_set, val_set, test_set, list(dataset.features)


def _tokenize_split(split, tokenizer, raw_columns, include_label: bool):
    """Apply `template_and_tokenize` to every row of `split`, dropping `raw_columns`."""
    return split.map(
        partial(template_and_tokenize, tokenizer=tokenizer, include_label=include_label),
        remove_columns=raw_columns,
    )
def template_and_tokenize(sample, tokenizer, include_label: bool = True):
    """
    Format dataset context and answers into single-sequence prompts

    Args:
        sample: Row dict with 'instruction' and 'output' keys and an optional
            'input' key. A missing, empty, or None 'input' selects the
            no-input template (so None is never rendered as the text "None").
        tokenizer: Object providing `encode(text, add_special_tokens=...)`
            and an `eos_token` attribute.
        include_label: If True, append the tokenized answer (+ EOS) to the
            prompt and mask prompt positions with -100 in `labels`. If False,
            `input_ids` is the prompt only and `labels` holds the tokenized
            answer (+ EOS).

    Returns:
        Dict with 'input_ids', 'attention_mask' (all ones, same length as
        'input_ids'), and 'labels'.
    """
    if sample.get('input'):
        template = PROMPT_DICT["prompt_input"]
    else:
        template = PROMPT_DICT["prompt_no_input"]
    prompt_ids = tokenizer.encode(template.format_map(sample), add_special_tokens=True)
    # The answer is terminated with EOS; no BOS/special tokens since it follows the prompt
    answer_ids = tokenizer.encode(f'{sample["output"]}{tokenizer.eos_token}',
                                  add_special_tokens=False)
    if include_label:
        input_ids = prompt_ids + answer_ids
        labels = [-100] * len(prompt_ids) + answer_ids  # -100: ignored by the loss
    else:
        input_ids = prompt_ids
        labels = answer_ids
    return {
        "input_ids": input_ids,
        "attention_mask": [1] * len(input_ids),
        "labels": labels,
    }