|
import logging |
|
import os |
|
from typing import Tuple, Union |
|
import torch |
|
from torch.utils.data import DataLoader |
|
from composer.core.data_spec import DataSpec
from composer.utils import dist, get_file, parse_uri
from omegaconf import DictConfig
from omegaconf import OmegaConf as om
from transformers import PreTrainedTokenizerBase
|
from .collator import Seq2SeqFinetuningCollator, validate_target_settings |
|
from .tasks import DOWNLOADED_FT_DATASETS_DIRPATH, SUPPORTED_EXTENSIONS, dataset_constructor |
|
from .packing import BinPackCollator, auto_packing_ratio |
|
from .text_data import build_streams, get_tokens_per_batch_func |
|
from .exceptions import MissingHuggingFaceURLSplitError, NotEnoughDatasetSamplesError |
|
log = logging.getLogger(__name__) |
|
_HF_IGNORE_INDEX = -100 |
|
_DEFAULT_TARGET_RESPONSES = 'last' |
|
_DEFAULT_TARGET_PROMPTS = 'none' |
|
|
|
def build_finetuning_dataloader(
    cfg: DictConfig,
    tokenizer: PreTrainedTokenizerBase,
    device_batch_size: int,
) -> DataSpec:
|
"""Builds a finetuning dataloader for training or evaluating. |
|
|
|
The underlying dataset can be built through one of two code paths: |
|
1. As a HuggingFace dataset, via `datasets.load_dataset(...)` |
|
2. As a streaming dataset |
|
You will need to set slightly different dataset config fields depending |
|
on which you intend to use, as explained below. |
|
|
|
Args: |
|
cfg (DictConfig): An omegaconf dictionary used to configure the loader: |
|
            cfg.name (str): The type of dataloader to build. Must be "finetuning".
|
--- |
|
*** HuggingFace dataset config fields *** |
|
cfg.dataset.hf_name (str, optional): The name of the HuggingFace dataset |
|
to use. Can also be a remote http(s) directory or object store bucket |
|
containing the file {split}.jsonl in the format (prompt, response), |
|
in which case the builder will create a HuggingFace dataset. |
|
cfg.dataset.hf_kwargs (DictConfig, optional): Additional kwargs to |
|
pass to `datasets.load_dataset`, which can be used to load |
|
a dataset from local files. |
|
cfg.dataset.preprocessing_fn (str, optional): The name/import path of |
|
the preprocessing function to use for formatting the data examples. |
|
If ``None`` (default), the builder will use the preprocessing function |
|
registered under `hf_name` (see `tasks.py`), if one exists, |
|
otherwise it will skip preprocessing. |
|
If `preprocessing_fn` corresponds to a registered preprocessing |
|
function in `tasks.py`, the builder will use that. |
|
Otherwise, it will interpret `preprocessing_fn` as a |
|
"import.path:function_name" import path; e.g., it will call |
|
`from import.path import function_name` and use the imported |
|
function as the preprocessing function. |
|
*** Streaming dataset config fields *** |
|
            cfg.dataset.remote (str, optional): Location of an MDS-formatted
|
streaming dataset to use. Setting this will tell the builder |
|
to create a streaming dataset rather than a HuggingFace dataset. |
|
cfg.dataset.local (str, optional): Local path where remote data |
|
will be streamed to. Only valid if `cfg.dataset.remote` has |
|
also been set. |
|
            *** Shared dataset config fields ***
|
cfg.dataset.max_seq_len (int): The maximum length of sequences |
|
in the batch. See :class:`Seq2SeqFinetuningCollator` docstring |
|
for details. |
|
cfg.dataset.decoder_only_format (bool): Whether to format the |
|
examples for a decoder-only model. See :class:`Seq2SeqFinetuningCollator` |
|
docstring for details. |
|
cfg.dataset.target_responses (str): Which responses are used as training targets. |
|
Defaults to "last", meaning only the final response in multi-turn examples |
|
will serve as training targets. See :class:`Seq2SeqFinetuningCollator` docstring for |
|
details. |
|
cfg.dataset.target_prompts (str): Which prompts are used as training targets. |
|
Defaults to "none", meaning prompts are never used as training targets. |
|
See :class:`Seq2SeqFinetuningCollator` docstring for details. |
|
cfg.dataset.allow_pad_trimming (bool, optional): Whether to allow |
|
the collator to trim padding. See :class:`Seq2SeqFinetuningCollator` |
|
docstring for details. Default: ``False``. |
|
            cfg.dataset.packing_ratio (Optional[Union[float, Literal['auto']]]): If provided, this invokes
|
a collator wrapper that packs device_batch_size*packing_ratio |
|
raw examples into device_batch_size packed examples. This helps |
|
minimize padding while preserving sequence integrity. |
|
This adds `sequence_id` to the batch, which indicates which unique |
|
sequence each token belongs to. |
|
|
|
If set to 'auto', packing_ratio is profiled and the highest observed packing ratio with |
|
zero waste is selected. |
|
In practice, this may result in > 0 waste because profiling is done on only a portion |
|
of the dataset. |
|
|
|
Note: Using this feature will not change device_batch_size but it |
|
will determine the number of raw examples consumed by the dataloader |
|
per batch. Some examples may be discarded if they do not fit when |
|
packing. |
|
Select packing_ratio **carefully** based on the dataset |
|
statistics, max_seq_len, and tolerance for discarding samples! |
|
The script `scripts/misc/profile_packing.py` can help |
|
you choose the best packing_ratio. |
|
cfg.dataset.shuffle (bool): Whether to shuffle the dataset. |
|
            ---
|
See :class:`StreamingFinetuningDataset` for info on other standard config |
|
options within `cfg.dataset` that will be passed as kwargs if |
|
using the streaming codepath. |
|
--- |
|
See :class:`DataLoader` for standard argument options to the pytorch |
|
dataloader, such as `cfg.drop_last`, `cfg.num_workers`, etc. |
|
        tokenizer (transformers.PreTrainedTokenizerBase): The tokenizer used to
|
prepare the data from raw text. Any missing sentinel tokens will |
|
be added by the collator. |
|
device_batch_size (int): The size of the batches (number of examples) |
|
that the dataloader will produce. |
|
|
|
Returns: |
|
        A :class:`DataSpec` wrapping the pytorch dataloader, along with a
        function for counting the tokens in each batch.
|
|
|
Note: |
|
        You can use the script `scripts/misc/profile_packing.py` to quickly test the
|
padding/waste rates for different `cfg.dataset.packing_ratio` choices, |
|
given a starting workload YAML. |
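    Example:
        A minimal sketch of the HuggingFace codepath (the dataset name, batch
        size, and ``tokenizer`` are illustrative placeholders; ``om`` is
        ``omegaconf.OmegaConf``)::

            cfg = om.create({
                'name': 'finetuning',
                'dataset': {
                    'hf_name': 'tatsu-lab/alpaca',
                    'split': 'train',
                    'max_seq_len': 2048,
                    'decoder_only_format': True,
                    'shuffle': True,
                },
                'drop_last': False,
                'num_workers': 0,
            })
            dataspec = build_finetuning_dataloader(cfg, tokenizer, device_batch_size=8)
            dataloader = dataspec.dataloader

        For the streaming codepath, replace ``hf_name`` with ``remote`` (and
        ``local``) pointing at an MDS-formatted dataset.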
|
""" |
|
_validate_config(cfg.dataset) |
|
if tokenizer.pad_token is None: |
|
tokenizer.pad_token = tokenizer.eos_token |
|
collate_fn, dataloader_batch_size = _build_collate_fn(cfg, tokenizer, device_batch_size) |
|
dataset = None |
|
sampler = None |
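    # Choose the dataset codepath: build a streaming (MDS) dataset when `remote`
    # or `streams` is configured, otherwise build a HuggingFace dataset from `hf_name`.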
|
if cfg.dataset.get('remote') is not None or cfg.dataset.get('streams') is not None: |
|
streams = build_streams(cfg.dataset) |
|
        dataset = dataset_constructor.build_from_streaming(
            tokenizer=tokenizer,
            streams=streams,
            local=cfg.dataset.get('local', None),
            remote=cfg.dataset.get('remote', None),
            split=cfg.dataset.get('split', None),
            download_retry=cfg.dataset.get('download_retry', 2),
            download_timeout=cfg.dataset.get('download_timeout', 60),
            validate_hash=cfg.dataset.get('validate_hash', None),
            keep_zip=cfg.dataset.get('keep_zip', False),
            epoch_size=cfg.dataset.get('epoch_size', None),
            predownload=cfg.dataset.get('predownload', None),
            cache_limit=cfg.dataset.get('cache_limit', None),
            partition_algo=cfg.dataset.get('partition_algo', 'relaxed'),
            num_canonical_nodes=cfg.dataset.get('num_canonical_nodes', None),
            batch_size=device_batch_size,
            shuffle=cfg.dataset.get('shuffle', False),
            shuffle_algo=cfg.dataset.get('shuffle_algo', 'py1e'),
            shuffle_seed=cfg.dataset.get('shuffle_seed', 9176),
            shuffle_block_size=cfg.dataset.get('shuffle_block_size', None),
            sampling_method=cfg.dataset.get('sampling_method', 'balanced'),
            sampling_granularity=cfg.dataset.get('sampling_granularity', 1),
            batching_method=cfg.dataset.get('batching_method', 'random'),
            max_seq_len=cfg.dataset.max_seq_len,
        )
|
else: |
|
dataset_name_or_path = cfg.dataset.hf_name |
|
split = cfg.dataset.get('split') |
|
if split is None: |
|
raise MissingHuggingFaceURLSplitError() |
|
backend, _, _ = parse_uri(dataset_name_or_path) |
|
if backend not in ['', None]: |
|
dataset_name_or_path = _download_remote_hf_dataset(remote_path=dataset_name_or_path, split=split) |
|
split = split.replace('-', '_') |
|
proto_preprocessing_fn = cfg.dataset.get('preprocessing_fn') |
|
if isinstance(proto_preprocessing_fn, (dict, DictConfig)): |
|
preprocessing_fn = dataset_constructor.get_preprocessing_fn_from_dict(dict(proto_preprocessing_fn)) |
|
else: |
|
preprocessing_fn = dataset_constructor.get_preprocessing_fn_from_str(proto_preprocessing_fn, dataset_name_or_path) |
|
        dataset = dataset_constructor.build_from_hf(
            dataset_name=dataset_name_or_path,
            split=split,
            safe_load=cfg.dataset.get('safe_load', False),
            max_seq_len=cfg.dataset.max_seq_len,
            preprocessing_fn=preprocessing_fn,
            tokenizer=tokenizer,
            target_prompts=cfg.dataset.get('target_prompts', _DEFAULT_TARGET_PROMPTS),
            target_responses=cfg.dataset.get('target_responses', _DEFAULT_TARGET_RESPONSES),
            decoder_only_format=cfg.dataset.decoder_only_format,
            hf_kwargs=cfg.dataset.get('hf_kwargs', {}),
        )
|
if cfg.drop_last: |
|
world_size = dist.get_world_size() |
|
minimum_dataset_size = world_size * dataloader_batch_size |
|
if hasattr(dataset, '__len__'): |
|
full_dataset_size = len(dataset) |
|
if full_dataset_size < minimum_dataset_size: |
|
                    raise NotEnoughDatasetSamplesError(
                        dataset_name=cfg.dataset.hf_name,
                        split=split,
                        dataloader_batch_size=dataloader_batch_size,
                        world_size=world_size,
                        full_dataset_size=full_dataset_size,
                        minimum_dataset_size=minimum_dataset_size,
                    )
|
sampler = dist.get_sampler(dataset, drop_last=cfg.drop_last, shuffle=cfg.dataset.shuffle) |
|
assert dataset is not None |
|
    dl = DataLoader(
        dataset,
        collate_fn=collate_fn,
        batch_size=dataloader_batch_size,
        drop_last=cfg.drop_last,
        sampler=sampler,
        num_workers=cfg.num_workers,
        pin_memory=cfg.get('pin_memory', True),
        prefetch_factor=cfg.get('prefetch_factor', 2),
        persistent_workers=cfg.get('persistent_workers', True),
        timeout=cfg.get('timeout', 0),
    )
|
token_counting_func = get_tokens_per_batch_func() |
|
return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func) |
|
|
|
def _validate_config(dataset_cfg: DictConfig) -> None: |
|
"""Validates the dataset configuration. |
|
|
|
Makes sure that the dataset is properly configured for either |
|
a HuggingFace dataset or a streaming dataset. Must be valid for one or |
|
the other. |
|
|
|
Args: |
|
dataset_cfg (DictConfig): The dataset configuration to be validated. |
|
|
|
Raises: |
|
ValueError: If the dataset configuration does not meet the requirements. |
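    Example:
        A sketch of the two mutually exclusive shapes this check accepts (the
        dataset name and object store paths are placeholders)::

            # HuggingFace codepath: set `hf_name`; leave `remote`/`local`/`streams` unset.
            hf_cfg = {'hf_name': 'tatsu-lab/alpaca', 'split': 'train',
                      'max_seq_len': 2048, 'decoder_only_format': True}

            # Streaming codepath: set both `remote` and `local`; leave `hf_name` unset.
            streaming_cfg = {'remote': 's3://my-bucket/ft-data', 'local': '/tmp/ft-data',
                             'split': 'train', 'max_seq_len': 2048,
                             'decoder_only_format': True}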
|
""" |
|
if dataset_cfg.get('hf_name') is not None: |
|
illegal_keys = ['local', 'remote'] |
|
discovered_illegal_keys = [] |
|
for key in illegal_keys: |
|
if dataset_cfg.get(key) is not None: |
|
discovered_illegal_keys.append('`' + key + '`') |
|
if discovered_illegal_keys: |
|
            raise ValueError(
                'The dataset config sets a value for `hf_name` as well as the ' +
                f"following keys: {', '.join(discovered_illegal_keys)}.\n" +
                'Those keys are used when building from a streaming dataset, but ' +
                'setting `hf_name` instructs the dataset to build from a HuggingFace dataset.')
|
elif dataset_cfg.get('remote') is not None: |
|
illegal_keys = ['hf_name', 'hf_kwargs', 'preprocessing_fn', 'safe_load'] |
|
discovered_illegal_keys = [] |
|
for key in illegal_keys: |
|
if dataset_cfg.get(key) is not None: |
|
discovered_illegal_keys.append('`' + key + '`') |
|
if discovered_illegal_keys: |
|
            raise ValueError(
                'The dataset config sets a value for `remote` as well as the ' +
                f"following keys: {', '.join(discovered_illegal_keys)}.\n" +
                'Those keys are used when building from a HuggingFace dataset, but ' +
                'setting `remote` instructs the dataset to build from a streaming dataset.')
|
if dataset_cfg.get('local') is None: |
|
raise ValueError('Using a streaming dataset requires setting both `remote` and `local`, ' + 'but dataset.local is None.') |
|
elif dataset_cfg.get('streams') is not None: |
|
illegal_keys = ['hf_name', 'hf_kwargs', 'preprocessing_fn', 'safe_load'] |
|
discovered_illegal_keys = [] |
|
for key in illegal_keys: |
|
if dataset_cfg.get(key) is not None: |
|
discovered_illegal_keys.append('`' + key + '`') |
|
if discovered_illegal_keys: |
|
            raise ValueError(
                'The dataset config sets a value for `streams` as well as the ' +
                f"following keys: {', '.join(discovered_illegal_keys)}.\n" +
                'Those keys are used when building from a HuggingFace dataset, but ' +
                'setting `streams` instructs the dataset to build from a streaming dataset.')
|
illegal_keys = ['remote', 'local'] |
|
discovered_illegal_keys = [] |
|
for key in illegal_keys: |
|
if dataset_cfg.get(key) is not None: |
|
discovered_illegal_keys.append('`' + key + '`') |
|
if discovered_illegal_keys: |
|
            raise ValueError(
                'The dataset config sets a value for `streams` as well as the ' +
                f"following keys: {', '.join(discovered_illegal_keys)}.\n" +
                'Please either use a single stream (set only `remote`/`local`) ' +
                'or put `remote`/`local` under `streams`.')
|
else: |
|
        raise ValueError(
            'In the dataset config, you must set `hf_name` to use a HuggingFace ' +
            'dataset, or set `remote` to use a streaming dataset, or set ' +
            '`streams` to use multiple streaming datasets, but all were None.')
|
if dataset_cfg.get('max_seq_len') is None: |
|
raise ValueError('In the dataset config, you must set the `max_seq_len`') |
|
target_responses = str(dataset_cfg.get('target_responses', _DEFAULT_TARGET_RESPONSES)).lower() |
|
target_prompts = str(dataset_cfg.get('target_prompts', _DEFAULT_TARGET_PROMPTS)).lower() |
|
decoder_only_format = dataset_cfg.decoder_only_format |
|
validate_target_settings(target_prompts, target_responses, decoder_only_format) |
|
|
|
def _download_remote_hf_dataset(remote_path: str, split: str) -> str: |
|
"""Downloads a dataset from a remote object store. |
|
|
|
This function supports 'jsonl', 'csv', and 'parquet' file formats for the dataset. It will attempt to download |
|
the dataset, then once it is downloaded, convert it into HuggingFace ``datasets`` format, and then return this |
|
dataset. |
|
|
|
The function also ensures synchronicity across multiple processes during the file download. It creates a signal |
|
file that is used to synchronize the start of the download across different processes. Once the download is |
|
completed, the function removes the signal file. |
|
|
|
Args: |
|
        remote_path (str): The remote object store path of the dataset to download.
|
split (str): The dataset split to download (e.g., 'train', 'validation', 'test'). |
|
|
|
Returns: |
|
A local directory path where the dataset files are stored. |
|
|
|
Raises: |
|
FileNotFoundError: Raised if the dataset file cannot be found with any of the supported extensions. |
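    Example:
        A sketch of the candidate files this function looks for, one per
        extension in ``SUPPORTED_EXTENSIONS`` (the bucket path is a
        placeholder); it stops at the first successful download::

            # remote_path='s3://my-bucket/my-data', split='train'
            #   s3://my-bucket/my-data/train.jsonl
            #   s3://my-bucket/my-data/train.csv
            #   s3://my-bucket/my-data/train.parquet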
|
""" |
|
hf_formatted_split = split.replace('-', '_') |
|
finetune_dir = os.path.join(DOWNLOADED_FT_DATASETS_DIRPATH, hf_formatted_split if hf_formatted_split != 'data' else 'data_not') |
|
os.makedirs(finetune_dir, exist_ok=True) |
|
for extension in SUPPORTED_EXTENSIONS: |
|
name = f"{remote_path.strip('/')}/{split}{extension}" |
|
destination = str(os.path.abspath(os.path.join(finetune_dir, 'data', f'{hf_formatted_split}-00000-of-00001{extension}'))) |
|
signal_file_path = os.path.join(finetune_dir, f'.node_{dist.get_node_rank()}_local_rank0_completed') |
|
if dist.get_local_rank() == 0: |
|
try: |
|
get_file(path=name, destination=destination, overwrite=True) |
|
except FileNotFoundError as e: |
|
if extension == SUPPORTED_EXTENSIONS[-1]: |
|
                    files_searched = [
                        f"{remote_path.strip('/')}/{split}{ext}"
                        for ext in SUPPORTED_EXTENSIONS
                    ]
                    raise FileNotFoundError(
                        f'Could not find a file with any of ' +
                        f'the supported extensions: {SUPPORTED_EXTENSIONS}\n' +
                        f'at {files_searched}') from e
|
else: |
|
log.debug(f'Could not find {name}, looking for another extension') |
|
continue |
|
os.makedirs(os.path.dirname(signal_file_path), exist_ok=True) |
|
with open(signal_file_path, 'wb') as f: |
|
f.write(b'local_rank0_completed_download') |
|
with dist.local_rank_zero_download_and_wait(signal_file_path): |
|
dist.barrier() |
|
if dist.get_local_rank() == 0: |
|
os.remove(signal_file_path) |
|
dist.barrier() |
|
break |
|
return finetune_dir |
|
|
|
def _build_collate_fn(
    dataloader_cfg: DictConfig,
    tokenizer: PreTrainedTokenizerBase,
    device_batch_size: int,
) -> Tuple[Union[Seq2SeqFinetuningCollator, BinPackCollator], int]:
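    """Builds the collate function and the per-iteration dataloader batch size.

    Returns the base :class:`Seq2SeqFinetuningCollator` or, when
    ``dataset.packing_ratio`` is set, a :class:`BinPackCollator` that wraps it.
    With packing enabled, the returned batch size is
    ``device_batch_size * packing_ratio`` raw examples, which the packer
    consolidates back into ``device_batch_size`` packed examples.
    """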
|
dataset_cfg = dataloader_cfg.dataset |
|
max_seq_len = dataset_cfg.max_seq_len |
|
    collate_fn = Seq2SeqFinetuningCollator(
        tokenizer=tokenizer,
        max_seq_len=max_seq_len,
        decoder_only_format=dataset_cfg.decoder_only_format,
        target_responses=dataset_cfg.get('target_responses', _DEFAULT_TARGET_RESPONSES),
        target_prompts=dataset_cfg.get('target_prompts', _DEFAULT_TARGET_PROMPTS),
        allow_pad_trimming=dataset_cfg.get('allow_pad_trimming', False),
    )
|
packing_ratio = dataset_cfg.get('packing_ratio') |
|
max_leftover_bins_to_keep = dataset_cfg.get('max_leftover_bins_to_keep') |
|
if packing_ratio is None: |
|
if max_leftover_bins_to_keep is not None: |
|
            raise ValueError(
                'dataset.max_leftover_bins_to_keep has been defined, ' +
                'but dataset.packing_ratio has not been set. Please set ' +
                'the latter to turn on packing or remove the former from the config.')
|
return (collate_fn, device_batch_size) |
|
if packing_ratio == 'auto': |
|
packing_ratio = auto_packing_ratio(dataloader_cfg, tokenizer, device_batch_size) |
|
if isinstance(packing_ratio, str): |
|
raise ValueError('dataset.packing_ratio must be a float or "auto", but it was set to ' + f'{packing_ratio}.') |
|
log.info(f'Using packing ratio {packing_ratio}') |
|
if packing_ratio == 1.0: |
|
return (collate_fn, device_batch_size) |
|
elif packing_ratio < 1.0: |
|
raise ValueError('packing_ratio must be >= 1, if supplied') |
|
if not dataset_cfg.decoder_only_format: |
|
raise NotImplementedError('On-the-fly packing is currently only supported for decoder-only formats.') |
|
    collate_fn = BinPackCollator(
        collator=collate_fn,
        target_batch_size=device_batch_size,
        max_seq_len=max_seq_len,
        pad_token_id=tokenizer.pad_token_id,
        padding_side=tokenizer.padding_side,
        max_leftover_bins_to_keep=max_leftover_bins_to_keep,
    )
|
n_examples_to_pack = int(device_batch_size * packing_ratio) |
|
return (collate_fn, n_examples_to_pack) |
|
if __name__ == '__main__': |
|
import torch |
|
    from llmfoundry.utils.builders import build_tokenizer
|
    cfg = om.create({
        'dataset': {
            'hf_name': 'tatsu-lab/alpaca',
            'preprocessing_fn': 'llmfoundry.data.finetuning.tasks:alpaca_preprocessing_function',
            'split': 'train',
            'packing_ratio': 18.0,
            'max_seq_len': 2048,
            'decoder_only_format': True,
            'allow_pad_trimming': False,
            'num_canonical_nodes': 472,
            'shuffle': True,
            'target_responses': 'last',
            'target_prompts': 'none',
        },
        'drop_last': False,
        'num_workers': 0,
        'pin_memory': False,
        'prefetch_factor': None,
        'persistent_workers': False,
        'timeout': 0,
    })
|
tokenizer_name = 'EleutherAI/gpt-neox-20b' |
|
tokenizer_kwargs = {'model_max_length': cfg.dataset.max_seq_len} |
|
tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) |
|
device_batch_size = 1 |
|
dataloader = build_finetuning_dataloader(cfg, tokenizer, device_batch_size).dataloader |
|
packing = cfg.dataset.get('packing_ratio') is not None |
|
for i, batch in enumerate(dataloader): |
|
if i >= 5: |
|
break |
|
print(f'-----Batch {i}-----') |
|
for k, v in batch.items(): |
|
if isinstance(v, torch.Tensor): |
|
print(k, v.shape) |
|
else: |
|
print(k, v) |
|
for j in range(device_batch_size): |
|
print(f'--- Sample {j} ---') |
|
if cfg.dataset.decoder_only_format: |
|
if packing: |
|
for subseq in range(int(batch['sequence_id'][j].max()) + 1): |
|
is_subseq = batch['sequence_id'][j] == subseq |
|
print('\x1b[93m{}\x1b[00m\n'.format('INPUT IDS:'), tokenizer.decode(batch['input_ids'][j, torch.logical_and(is_subseq, batch['attention_mask'][j] == 1)], skip_special_tokens=False, clean_up_tokenization_spaces=True)) |
|
print('\x1b[91m{}\x1b[00m\n'.format('TARGET: '), tokenizer.decode(batch['input_ids'][j, torch.logical_and(is_subseq, batch['labels'][j] != _HF_IGNORE_INDEX)], skip_special_tokens=False, clean_up_tokenization_spaces=True)) |
|
else: |
|
print('\x1b[93m{}\x1b[00m\n'.format('INPUT IDS:'), tokenizer.decode(batch['input_ids'][j, batch['attention_mask'][j] == 1], skip_special_tokens=False, clean_up_tokenization_spaces=True)) |
|
print('\x1b[91m{}\x1b[00m\n'.format('TARGET: '), tokenizer.decode(batch['input_ids'][j, batch['labels'][j] != _HF_IGNORE_INDEX], skip_special_tokens=False, clean_up_tokenization_spaces=True)) |
|
else: |
|
print('\x1b[92m{}\x1b[00m\n'.format('CONTEXT: '), tokenizer.decode(batch['input_ids'][j, batch['attention_mask'][j] == 1], skip_special_tokens=False, clean_up_tokenization_spaces=True)) |
|
print('\x1b[91m{}\x1b[00m\n'.format('TARGET: '), tokenizer.decode(batch['labels'][j, batch['decoder_attention_mask'][j] == 1], skip_special_tokens=False, clean_up_tokenization_spaces=True)) |
|
print(' ') |