import logging
import tempfile
from typing import (TYPE_CHECKING, Callable, Dict, Iterable, List, Literal,
                    Optional, Tuple)

import numpy as np
import torch

if TYPE_CHECKING:
    # `PreTrainedTokenizerBase` only appears in (quoted) annotations, so it is
    # imported for type checkers only; `transformers` is not loaded at runtime
    # by this module.
    from transformers import PreTrainedTokenizerBase

# NOTE(review): `DictConfig`, `dist`, `get_device` and `reproducibility` are
# used below but are not imported in this file as seen. Presumably they are
# `omegaconf.DictConfig` and `composer.utils.{dist, get_device,
# reproducibility}` -- TODO confirm and add the imports. `DictConfig` only
# appears in quoted annotations, so it no longer has to resolve at import time;
# the other three are looked up when `auto_packing_ratio` runs.

log = logging.getLogger(__name__)


class BinPackCollator:
    """Utility collator for packing to reduce padding.

    Wraps a base collator, trims the padding off every example it produces,
    packs the trimmed examples into ``target_batch_size`` rows of at most
    ``max_seq_len`` tokens, and repads the result. Bins that do not make it
    into an output batch are carried over to later calls, up to
    ``max_leftover_bins_to_keep`` of them.

    Args:
        collator (Callable): Base collator producing a padded batch dict with
            ``input_ids`` and ``attention_mask`` (optionally ``labels`` and
            ``sequence_id``).
        target_batch_size (int): Number of packed rows emitted per call.
        max_seq_len (int): Maximum number of tokens in one packed row.
        pad_token_id (int): Value used to pad ``input_ids``.
        padding_side (Literal['left', 'right']): Side on which rows are padded.
        max_leftover_bins_to_keep (Optional[int]): Maximum number of unused
            bins to carry over between calls; ``None`` keeps all of them.

    Raises:
        ValueError: If a numeric argument is out of range.
    """

    def __init__(self,
                 collator: Callable,
                 target_batch_size: int,
                 max_seq_len: int,
                 pad_token_id: int,
                 padding_side: Literal['left', 'right'],
                 max_leftover_bins_to_keep: Optional[int] = None):
        self.base_collator = collator
        self.out_size = int(target_batch_size)
        self.max_seq_len = int(max_seq_len)
        self.pad_token_id = int(pad_token_id)
        self.padding_side = padding_side

        if self.out_size <= 0:
            raise ValueError(f'target_batch_size={target_batch_size!r} must be >0.')
        if self.max_seq_len <= 0:
            raise ValueError(f'max_seq_len={max_seq_len!r} must be >0.')
        if self.pad_token_id < 0:
            raise ValueError(f'pad_token_id={pad_token_id!r} must be >=0.')
        if max_leftover_bins_to_keep is not None and max_leftover_bins_to_keep < 0:
            raise ValueError(
                f'max_leftover_bins_to_keep={max_leftover_bins_to_keep!r} must be >=0 or None.')
        self.max_leftover_bins_to_keep = max_leftover_bins_to_keep

        # Running totals read by `waste` and `efficiency`.
        self.n_packed_tokens = 0
        self.n_total_tokens = 0
        self.n_packed_examples = 0

        # (bin_size, packed_example) pairs that did not fit in earlier batches.
        self._leftover_bins: List[Tuple[int, Dict[str, torch.Tensor]]] = []

    @property
    def waste(self) -> float:
        """Fraction of input tokens seen so far that were not emitted.

        Raises:
            ZeroDivisionError: If no tokens have been seen yet.
        """
        return 1 - self.n_packed_tokens / self.n_total_tokens

    @property
    def efficiency(self) -> float:
        """Fraction of emitted token slots (rows * max_seq_len) holding real tokens.

        Raises:
            ZeroDivisionError: If no rows have been emitted yet.
        """
        return self.n_packed_tokens / (self.max_seq_len * self.n_packed_examples)

    def __call__(self, examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """Collates ``examples`` with the base collator, then packs the batch."""
        batch = self.base_collator(examples)
        return self.pack(batch)

    def pack(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Trims the padding off ``batch`` and packs it into fixed-size rows."""
        assert 'attention_mask' in batch
        assert 'input_ids' in batch

        for key in batch.keys():
            assert key in ['input_ids', 'labels', 'attention_mask', 'sequence_id']

        sizes, trimmed_examples = _trim_batch(batch)
        return self._pack_trimmed_examples(trimmed_examples, sizes)

    def _pack_trimmed_examples(self, trimmed_examples: List[Dict[str, torch.Tensor]],
                               sizes: List[int]) -> Dict[str, torch.Tensor]:
        """Packs trimmed examples into fixed-size bins and repads them.

        Args:
            trimmed_examples (List[Dict[str, torch.Tensor]]): A list of trimmed examples.
            sizes (List[int]): The sizes of the trimmed examples.

        Returns:
            Dict[str, torch.Tensor]: A batch of repadded examples ready for processing.
        """
        packed_examples, n_packed_tokens, n_total_tokens, leftover_bins = _first_fit_bin_packing(
            sizes=sizes,
            examples=trimmed_examples,
            num_bins=self.out_size,
            max_bin_size=self.max_seq_len,
            existing_bins=self._leftover_bins,
        )
        self.n_packed_tokens += n_packed_tokens
        self.n_total_tokens += n_total_tokens
        # Count the rows actually emitted: with too few examples the packer
        # returns fewer than `out_size` rows, and counting `out_size` would
        # understate `efficiency`.
        self.n_packed_examples += len(packed_examples)
        # Slicing with None keeps every leftover bin.
        self._leftover_bins = leftover_bins[:self.max_leftover_bins_to_keep]

        return _repad(packed_examples,
                      max_seq_len=self.max_seq_len,
                      pad_token_id=self.pad_token_id,
                      padding_side=self.padding_side)


def _trim_batch(batch: Dict[str, torch.Tensor]) -> Tuple[List[int], List[Dict[str, torch.Tensor]]]:
    """Trims padding off all examples in batch.

    Args:
        batch (Dict[str, torch.Tensor]): Batch of padded data with tensors as values.

    Returns:
        A tuple with unpadded lengths of examples and a list of each trimmed
        example from the batch.
    """
    sizes, trimmed_examples = [], []
    for idx in range(batch['attention_mask'].shape[0]):
        size, trimmed_example = _extract_trim_batch_idx(batch, idx)
        sizes.append(size)
        trimmed_examples.append(trimmed_example)
    return sizes, trimmed_examples


def _extract_trim_batch_idx(batch: Dict[str, torch.Tensor],
                            idx: int) -> Tuple[int, Dict[str, torch.Tensor]]:
    """Returns the unpadded size and trimmed tensors of row ``idx`` of ``batch``.

    Positions whose ``attention_mask`` is not 1 are dropped from every tensor,
    and ``sequence_id`` is reset to all zeros (any incoming value is replaced).
    """
    example = {k: v[idx] for k, v in batch.items()}

    keep = example['attention_mask'] == 1
    size = int(keep.sum())
    trim_example = {k: v[keep] for k, v in example.items()}
    trim_example['sequence_id'] = torch.zeros_like(trim_example['input_ids'])

    return size, trim_example


def _combine_in_place(example: Dict[str, torch.Tensor],
                      add_on: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Appends ``add_on`` to the end of ``example``, updating ``example`` in place.

    In the combined tensors the first label of ``add_on`` is set to -100 and
    its ``sequence_id`` values are shifted past the largest id already in
    ``example``.

    Args:
        example (Dict[str, torch.Tensor]): The bin being extended. Each of its
            values is replaced by the concatenated tensor.
        add_on (Dict[str, torch.Tensor]): The trimmed example to append. Its
            tensors are not modified.

    Returns:
        Dict[str, torch.Tensor]: ``example`` itself.
    """
    if 'labels' in add_on and add_on['labels'].numel() > 0:
        # Presumably masks the label at the sequence boundary so the loss does
        # not train across packed sequences. Work on a copy: trimmed examples
        # can share tensors with the caller (profile_packing reuses them
        # across runs via shallow dict copies).
        add_on_labels = add_on['labels'].clone()
        add_on_labels[0] = -100
        add_on = {**add_on, 'labels': add_on_labels}

    for key in example.keys():
        if key == 'sequence_id':
            # An empty bin has no ids yet, so the appended sequence keeps its own ids.
            current = example[key]
            offset = int(torch.max(current)) + 1 if current.numel() > 0 else 0
            example[key] = torch.cat([current, add_on[key] + offset])
        else:
            example[key] = torch.cat([example[key], add_on[key]])
    return example


def _finalize_bins(
    bins: List[Tuple[int, Dict[str, torch.Tensor]]], sizes: List[int],
    starting_total_bin_sizes: int, num_bins: int
) -> Tuple[List[Dict[str, torch.Tensor]], int, int, List[Tuple[int, Dict[str, torch.Tensor]]]]:
    """Checks that no tokens were lost and splits bins into output and leftovers."""
    total_new_bin_sizes = sum(bin_size for bin_size, _ in bins) - starting_total_bin_sizes
    total_example_sizes = sum(sizes)
    if total_new_bin_sizes != total_example_sizes:
        raise AssertionError(
            f'Error in packing. total_example_sizes={total_example_sizes!r} does not equal total_new_bin_sizes={total_new_bin_sizes!r}.'
        )

    sorted_bins = sorted(bins, key=lambda x: x[0], reverse=True)
    output_bins = sorted_bins[:num_bins]
    packed_examples = [packed_example for _, packed_example in output_bins]
    n_packed_tokens = sum(bin_size for bin_size, _ in output_bins)
    return packed_examples, n_packed_tokens, total_example_sizes, sorted_bins[num_bins:]


def _first_fit_bin_packing(
    sizes: List[int], examples: List[Dict[str, torch.Tensor]], num_bins: int,
    max_bin_size: int, existing_bins: List[Tuple[int, Dict[str, torch.Tensor]]]
) -> Tuple[List[Dict[str, torch.Tensor]], int, int, List[Tuple[int, Dict[str, torch.Tensor]]]]:
    """Packs examples into bins with a first-fit heuristic.

    Examples are visited from largest to smallest. Each one is appended to the
    first bin (existing bins included) that still has room for it. Once only
    just enough examples remain to reach ``num_bins`` bins, each remaining
    example opens a new bin instead.

    Args:
        sizes (List[int]): Token count of each example.
        examples (List[Dict[str, torch.Tensor]]): Trimmed examples to pack.
        num_bins (int): Number of bins to return as output.
        max_bin_size (int): Maximum number of tokens per bin.
        existing_bins (List[Tuple[int, Dict[str, torch.Tensor]]]): Bins carried
            over from earlier calls. The list itself is not modified, but its
            bin dicts may be extended in place.

    Returns:
        A tuple of (the ``num_bins`` largest packed examples, the number of
        tokens in them, the total number of tokens in ``examples``, the
        remaining bins sorted by decreasing size).

    Raises:
        AssertionError: If the packed token count does not match the input.
    """
    bins = list(existing_bins)
    starting_total_bin_sizes = sum(bin_size for bin_size, _ in bins)

    sorted_sizes_and_examples = sorted(zip(sizes, examples), key=lambda x: x[0], reverse=True)
    num_examples = len(sizes)

    if num_examples < max(0, num_bins - len(bins)):
        # Too few examples to fill every output row: each gets its own bin.
        bins.extend(sorted_sizes_and_examples)
        return _finalize_bins(bins, sizes, starting_total_bin_sizes, num_bins)

    for i, (size, example) in enumerate(sorted_sizes_and_examples):
        required_num_examples = max(0, num_bins - len(bins))
        n_remaining = num_examples - i
        assert n_remaining >= required_num_examples
        if n_remaining == required_num_examples:
            # Every remaining example is needed to open a new bin.
            bins.append((size, example))
            continue

        for bidx, (bin_size, packed_example) in enumerate(bins):
            if bin_size + size <= max_bin_size:
                # The grown bin is moved to the end of the bin list.
                bins.pop(bidx)
                bins.append((bin_size + size, _combine_in_place(packed_example, example)))
                break
        else:
            bins.append((size, example))

    return _finalize_bins(bins, sizes, starting_total_bin_sizes, num_bins)


def _repad(packed_examples: List[Dict[str, torch.Tensor]], max_seq_len: int,
           pad_token_id: int, padding_side: str) -> Dict[str, torch.Tensor]:
    """Pads every packed example to ``max_seq_len`` and stacks them into a batch.

    Raises:
        ValueError: If ``padding_side`` is neither 'left' nor 'right' and some
            tensor needs padding.
    """

    def pad_tensor(tensor: torch.Tensor, pad_value: int):
        if len(tensor) == max_seq_len:
            return tensor
        t = torch.full((max_seq_len,), pad_value, dtype=tensor.dtype, device=tensor.device)
        if padding_side == 'left':
            # Use an explicit start index: `t[-len(tensor):]` would select the
            # whole row for an empty tensor and the assignment would fail.
            t[max_seq_len - len(tensor):] = tensor
        elif padding_side == 'right':
            t[:len(tensor)] = tensor
        else:
            raise ValueError(f'Unknown padding_side={padding_side!r}')
        return t

    pad_vals = {
        'input_ids': pad_token_id,
        'labels': -100,
        'attention_mask': 0,
        'sequence_id': -1,
    }
    keys = packed_examples[0].keys()
    batch = {}
    for key in keys:
        batch[key] = torch.stack([pad_tensor(example[key], pad_vals[key]) for example in packed_examples])
    return batch


def auto_packing_ratio(dataloader_cfg: 'DictConfig',
                       tokenizer: 'PreTrainedTokenizerBase',
                       device_batch_size: int,
                       num_packing_ratios: int = 20) -> float:
    """Find a packing ratio that minimizes padding with zero waste.

    By packing examples, we can increase training efficiency, training on more
    data with less batches. However, in practice, the selected packing_ratio may
    produce some waste because profiling is done on only a subset of the
    dataset.

    We select a min_ratio of 1 and a max_ratio that is the max_seq_len / 100,
    and profile up to num_packing_ratios packing ratios between min_ratio and
    max_ratio, inclusive. When a packing_ratio with non-zero waste is found, we
    stop and select the previous ratio, which has zero waste.

    Profiling runs with all RNGs seeded to 0. The caller's RNG state is always
    restored afterwards, including when profiling raises.

    Args:
        dataloader_cfg (DictConfig): The dataloader configuration for profiling.
        tokenizer (PreTrainedTokenizerBase): The tokenizer for profiling.
        device_batch_size (int): The size of the batches (number of examples) per device.
        num_packing_ratios (int): The number of packing ratios to try.

    Returns:
        A packing ratio that minimizes padding while maintaining zero waste.
    """
    max_seq_len = dataloader_cfg.dataset.max_seq_len
    # With max_seq_len <= 100 the largest candidate (max_seq_len / 100) would
    # not exceed 1, so skip profiling -- before the global RNG is touched.
    if max_seq_len <= 100:
        return 1

    rng_state = reproducibility.get_rng_state()
    reproducibility.seed_all(0)
    try:
        min_ratio = 1
        max_ratio = max_seq_len / 100
        profiling_results = profile_packing(dataloader_cfg, tokenizer, min_ratio, max_ratio,
                                            num_packing_ratios, device_batch_size)

        # Obtain the maximum packing_ratio/minimum padding that has no waste.
        packing_ratio = 1
        for packing_ratio_candidate, _, waste in profiling_results:
            if waste is None or waste > 0:
                break
            packing_ratio = packing_ratio_candidate

        # Select the minimum packing ratio across all ranks.
        if dist.is_available() and dist.is_initialized():
            device = get_device(None)
            packing_ratio_tensor = device.tensor_to_device(torch.tensor(packing_ratio))
            dist.all_reduce(packing_ratio_tensor, reduce_operation='MIN')
            packing_ratio = packing_ratio_tensor.item()
    finally:
        reproducibility.load_rng_state(rng_state)
    return packing_ratio


def profile_packing(
    dataloader_cfg: 'DictConfig', tokenizer: 'PreTrainedTokenizerBase',
    min_ratio: float, max_ratio: float, num_packing_ratios: int,
    device_batch_size: int
) -> Iterable[Tuple[float, Optional[float], Optional[float]]]:
    """Generator function that profiles example packing across packing ratios.

    Args:
        dataloader_cfg (DictConfig): The dataloader configuration for profiling.
        tokenizer (PreTrainedTokenizerBase): The tokenizer for profiling.
        min_ratio (float): Smallest packing_ratio to test. Must be >=1.
        max_ratio (float): Largest packing_ratio to test. Must be larger than `min_ratio`.
        num_packing_ratios (int): Number of packing_ratio values (spaced between
            `min_ratio` and `max_ratio`) to try.
        device_batch_size (int): The size of the batches (number of examples) per device.

    Returns:
        An iterable of tuples of packing ratio, padding, and waste, sorted by
        smallest to largest packing ratio. Padding and waste are percentages,
        or None when no batch could be packed.
    """
    import copy

    from .dataloader import build_dataloader

    max_seq_len = dataloader_cfg.dataset.get('max_seq_len')
    max_leftovers_to_keep = dataloader_cfg.dataset.get('max_leftovers_to_keep', None)

    # Build the profiling dataloader without packing and without worker processes.
    dataloader_cfg = copy.deepcopy(dataloader_cfg)
    dataloader_cfg.dataset.packing_ratio = 1.0
    dataloader_cfg.drop_last = False
    dataloader_cfg.num_workers = 0
    dataloader_cfg.prefetch_factor = None
    dataloader_cfg.persistent_workers = False

    # Point remote datasets at a fresh local path. The TemporaryDirectory object
    # is not kept, so the directory itself may be removed as soon as the object
    # is garbage-collected; only its unique name is used here.
    if dataloader_cfg.dataset.get('remote') is not None:
        dataloader_cfg.dataset.local = tempfile.TemporaryDirectory().name

    # Determine the packing_ratio values to try, dropping those that map to an
    # already-seen raw batch size.
    packing_ratios, raw_batch_sizes = [], []
    for packing_ratio in np.linspace(min_ratio, max_ratio, num_packing_ratios, endpoint=True):
        packing_ratio = np.round(10 * packing_ratio) / 10
        raw_batch_size = int(packing_ratio * device_batch_size)
        if raw_batch_size not in raw_batch_sizes:
            packing_ratios.append(packing_ratio)
            raw_batch_sizes.append(raw_batch_size)

    n_profile_examples = max(raw_batch_sizes) * 100

    train_dataspec = build_dataloader(dataloader_cfg, tokenizer, n_profile_examples)
    train_dataloader = train_dataspec.dataloader

    # Fetch one large batch of raw examples and trim off their padding.
    big_batch = next(iter(train_dataloader))
    sizes, trimmed_examples = _trim_batch(big_batch)

    def profile(raw_batch_size: int) -> Tuple[Optional[float], Optional[float]]:
        # Packing rebinds the tensors of the example dicts, so give each run its own dicts.
        trimmed_examples_copy = [te.copy() for te in trimmed_examples]
        packer = BinPackCollator(
            collator=lambda x: x,
            target_batch_size=device_batch_size,
            max_seq_len=max_seq_len,
            pad_token_id=0,  # <-- Doesn't need to be correct for profiling
            padding_side='left',  # <-- Doesn't need to be correct for profiling
            max_leftover_bins_to_keep=max_leftovers_to_keep)

        # Simulate feeding the packing collator a bunch of data.
        for idx in range(0, len(trimmed_examples_copy), raw_batch_size):
            batch = trimmed_examples_copy[idx:idx + raw_batch_size]
            if len(batch) < device_batch_size:
                continue
            packer._pack_trimmed_examples(batch, sizes[idx:idx + raw_batch_size])

        if packer.n_packed_examples == 0:
            log.debug('No examples packed during profiling. Dataset is smaller than device batch size.')
            return None, None

        # Report padding and waste stats.
        padding_percent = 100 * (1 - packer.efficiency)
        waste_percent = 100 * packer.waste
        return padding_percent, waste_percent

    for packing_ratio, raw_batch_size in zip(packing_ratios, raw_batch_sizes):
        padding, waste = profile(raw_batch_size)
        yield (packing_ratio, padding, waste)