File size: 14,362 Bytes
fdb2891
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
import logging
import tempfile
from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple
import numpy as np
import torch
from transformers import PreTrainedTokenizerBase
log = logging.getLogger(__name__)

class BinPackCollator:
    """Utility collator for packing to reduce padding.

    Each batch produced by the base collator is trimmed of padding, its
    examples are packed first-fit (largest first) into ``target_batch_size``
    bins of at most ``max_seq_len`` tokens, and the bins are re-padded into a
    fixed-shape batch. Bins that are not emitted in the current batch can be
    carried over to the next call (see ``max_leftover_bins_to_keep``).

    Running token counters back the ``waste`` and ``efficiency`` statistics.
    """

    def __init__(self,
                 collator: Callable,
                 target_batch_size: int,
                 max_seq_len: int,
                 pad_token_id: int,
                 padding_side: Literal['left', 'right'],
                 max_leftover_bins_to_keep: Optional[int] = None):
        """Creates the packing collator.

        Args:
            collator: Base collator. Its output must be a dict of padded tensors
                containing ``input_ids`` and ``attention_mask`` and optionally
                ``labels`` / ``sequence_id`` (no other keys).
            target_batch_size: Number of packed bins per output batch.
            max_seq_len: Maximum number of tokens per bin; also the padded length.
            pad_token_id: Token id used to pad ``input_ids``.
            padding_side: ``'left'`` or ``'right'``.
            max_leftover_bins_to_keep: Maximum number of unemitted bins carried
                into the next call; ``None`` keeps all of them.

        Raises:
            ValueError: If any argument is out of range.
        """
        self.base_collator = collator
        self.out_size = int(target_batch_size)
        self.max_seq_len = int(max_seq_len)
        self.pad_token_id = int(pad_token_id)
        self.padding_side = padding_side
        if self.out_size <= 0:
            raise ValueError(f'target_batch_size={target_batch_size!r} must be >0.')
        if self.max_seq_len <= 0:
            raise ValueError(f'max_seq_len={max_seq_len!r} must be >0.')
        if self.pad_token_id < 0:
            raise ValueError(f'pad_token_id={pad_token_id!r} must be >=0.')
        # Validate eagerly like the other arguments; otherwise a bad value only
        # surfaces later, and only when some bin actually needs padding.
        if padding_side not in ('left', 'right'):
            raise ValueError(f"padding_side={padding_side!r} must be 'left' or 'right'.")
        if max_leftover_bins_to_keep is not None and max_leftover_bins_to_keep < 0:
            raise ValueError(f'max_leftover_bins_to_keep={max_leftover_bins_to_keep!r} must be >=0 or None.')
        self.max_leftover_bins_to_keep = max_leftover_bins_to_keep
        # Tokens emitted in packed output bins (including carried-over bins).
        self.n_packed_tokens = 0
        # Tokens of all new examples received so far.
        self.n_total_tokens = 0
        # Output bins requested so far (target_batch_size per call).
        self.n_packed_examples = 0
        self._leftover_bins: List[Tuple[int, Dict[str, torch.Tensor]]] = []

    @property
    def waste(self) -> float:
        """Fraction of received tokens not (yet) emitted in a packed batch.

        Tokens in dropped leftover bins and in bins still held as leftovers both
        count as waste. Returns 0.0 before any tokens have been received.
        """
        if self.n_total_tokens == 0:
            return 0.0
        return 1 - self.n_packed_tokens / self.n_total_tokens

    @property
    def efficiency(self) -> float:
        """Fraction of output token slots (``max_seq_len`` per bin) filled with real tokens.

        Returns 0.0 before any batch has been packed.
        """
        if self.n_packed_examples == 0:
            return 0.0
        return self.n_packed_tokens / (self.max_seq_len * self.n_packed_examples)

    def __call__(self, examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """Collates ``examples`` with the base collator, then packs the result."""
        batch = self.base_collator(examples)
        return self.pack(batch)

    def pack(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Trims, bin-packs and re-pads an already collated, padded batch.

        Args:
            batch: Padded batch with ``input_ids`` and ``attention_mask`` and
                optionally ``labels`` / ``sequence_id``.

        Returns:
            The packed batch; each tensor has ``max_seq_len`` columns.
        """
        assert 'attention_mask' in batch
        assert 'input_ids' in batch
        for key in batch.keys():
            assert key in ['input_ids', 'labels', 'attention_mask', 'sequence_id']
        sizes, trimmed_examples = _trim_batch(batch)
        return self._pack_trimmed_examples(trimmed_examples, sizes)

    def _pack_trimmed_examples(self, trimmed_examples: List[Dict[str, torch.Tensor]], sizes: List[int]) -> Dict[str, torch.Tensor]:
        """Packs trimmed examples into fixed-size bins and repads them.

        Updates the running token counters and keeps at most
        ``max_leftover_bins_to_keep`` unemitted bins for the next call.

        Args:
            trimmed_examples (List[Dict[str, torch.Tensor]]): A list of trimmed examples.
            sizes (List[int]): The sizes of the trimmed examples.

        Returns:
            Dict[str, torch.Tensor]: A batch of repadded examples ready for processing
        """
        packed_examples, n_packed_tokens, n_total_tokens, leftover_bins = _first_fit_bin_packing(sizes=sizes, examples=trimmed_examples, num_bins=self.out_size, max_bin_size=self.max_seq_len, existing_bins=self._leftover_bins)
        self.n_packed_tokens += n_packed_tokens
        self.n_total_tokens += n_total_tokens
        self.n_packed_examples += self.out_size
        # Slicing with [:None] keeps every leftover bin.
        self._leftover_bins = leftover_bins[:self.max_leftover_bins_to_keep]
        batch = _repad(packed_examples, max_seq_len=self.max_seq_len, pad_token_id=self.pad_token_id, padding_side=self.padding_side)
        return batch

def _trim_batch(batch: Dict[str, torch.Tensor]) -> Tuple[List[int], List[Dict[str, torch.Tensor]]]:
    """Strips padding from every row of a padded batch.

    Rows are enumerated along the first dimension of ``batch['attention_mask']``
    and each one is trimmed by ``_extract_trim_batch_idx``.

    Args:
        batch (Dict[str, torch.Tensor]): Padded batch whose values are tensors
            indexed by row along their first dimension.

    Returns:
        A ``(sizes, trimmed_examples)`` tuple of parallel lists, in batch order:
        the unpadded length of each row and the trimmed row itself.
    """
    n_rows = batch['attention_mask'].shape[0]
    per_row = [_extract_trim_batch_idx(batch, row) for row in range(n_rows)]
    sizes = [size for size, _ in per_row]
    trimmed_examples = [example for _, example in per_row]
    return sizes, trimmed_examples

def _extract_trim_batch_idx(batch: Dict[str, torch.Tensor], idx: int) -> Tuple[int, Dict[str, torch.Tensor]]:
    """Returns row ``idx`` of ``batch`` with its padding positions removed.

    A position is kept where ``attention_mask == 1``; every tensor in the row is
    filtered with that same mask. A fresh all-zero ``sequence_id`` (same shape
    and dtype as the trimmed ``input_ids``) is then set, replacing any existing one.

    Args:
        batch: Padded batch; each value is indexed by row on its first dimension.
        idx: Row to extract.

    Returns:
        ``(size, trimmed_example)`` where ``size`` is the number of kept positions.
    """
    row = {name: tensor[idx] for name, tensor in batch.items()}
    keep_mask = row['attention_mask'].eq(1)
    trimmed = {name: tensor[keep_mask] for name, tensor in row.items()}
    trimmed['sequence_id'] = torch.zeros_like(trimmed['input_ids'])
    return int(keep_mask.sum()), trimmed

def _combine_in_place(example: Dict[str, torch.Tensor], add_on: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Appends ``add_on`` to the end of the packed bin ``example``.

    Both dicts are modified:
    - every entry of ``example`` is replaced by its concatenation with the
      matching ``add_on`` entry;
    - the first label of ``add_on`` (if it has any) is overwritten with -100
      in place.

    ``sequence_id`` values of ``add_on`` are shifted past the largest id already
    in ``example`` so each packed sequence keeps a distinct id. Empty tensors on
    either side are handled; an empty ``example`` starts ids at ``add_on``'s own
    values.

    Args:
        example: The bin being grown; its keys drive which entries are combined.
        add_on: The example to append; must contain every key of ``example``.

    Returns:
        ``example`` itself, after modification.
    """
    # Guard against zero-length examples (e.g. all-padding rows): indexing [0]
    # on an empty tensor would raise.
    if 'labels' in add_on and add_on['labels'].numel() > 0:
        # Mask the first label of the appended sequence; presumably so no loss is
        # attributed across the sequence boundary. TODO confirm intent.
        add_on['labels'][0] = -100
    for k in example.keys():
        if k == 'sequence_id':
            existing = example[k]
            # max() of an empty tensor raises, so an empty bin uses offset 0.
            offset = existing.max() + 1 if existing.numel() > 0 else 0
            example[k] = torch.cat([existing, add_on[k] + offset])
        else:
            example[k] = torch.cat([example[k], add_on[k]])
    return example

def _first_fit_bin_packing(sizes: List[int], examples: List[Dict[str, torch.Tensor]], num_bins: int, max_bin_size: int, existing_bins: List[Tuple[int, Dict[str, torch.Tensor]]]) -> Tuple[List[Dict[str, torch.Tensor]], int, int, List[Tuple[int, Dict[str, torch.Tensor]]]]:
    """Packs examples into bins using first-fit, largest example first.

    Examples are visited from largest to smallest (ties keep input order).
    - When the examples still unvisited are no more than the bins missing to
      reach ``num_bins``, each of them opens its own bin. Packing therefore
      never leaves the output short of bins when enough examples exist.
    - Otherwise the example is merged into the first bin (in current list
      order) with room for it, and that bin moves to the end of the list.
    - If no bin has room, the example opens a new bin.

    Args:
        sizes: Token count of each example (parallel to ``examples``).
        examples: Trimmed examples to pack.
        num_bins: Number of bins to return as the packed batch.
        max_bin_size: Maximum total size of a bin.
        existing_bins: ``(size, example)`` bins carried over from a previous
            call. The list itself is not modified, but example dicts inside it
            are extended in place when new examples are merged into them.

    Returns:
        A tuple ``(packed_examples, n_packed_tokens, n_total_tokens, leftover_bins)``:
        - ``packed_examples``: the (up to) ``num_bins`` largest bins;
        - ``n_packed_tokens``: their total size;
        - ``n_total_tokens``: the total size of the new ``examples``;
        - ``leftover_bins``: the remaining bins, largest first.

    Raises:
        AssertionError: If the new bin contents do not account for every new token.
    """
    # Work on a copy so the caller's list is not mutated while bins are moved around.
    bins: List[Tuple[int, Dict[str, torch.Tensor]]] = list(existing_bins)
    starting_total_bin_sizes = sum(bin_size for bin_size, _ in bins)
    num_examples = len(sizes)
    # sorted() is stable, so equally sized examples keep their input order.
    sorted_sizes_and_examples = sorted(zip(sizes, examples), key=lambda pair: pair[0], reverse=True)
    for i, (size, example) in enumerate(sorted_sizes_and_examples):
        required_num_examples = max(0, num_bins - len(bins))
        n_remaining = num_examples - i
        if n_remaining <= required_num_examples:
            # Every remaining example is needed to reach num_bins bins.
            bins.append((size, example))
            continue
        for bidx, (bin_size, packed_example) in enumerate(bins):
            if bin_size + size <= max_bin_size:
                # Safe to mutate `bins` here: we break out of the scan immediately.
                bins.pop(bidx)
                bins.append((bin_size + size, _combine_in_place(packed_example, example)))
                break
        else:
            bins.append((size, example))

    total_new_bin_sizes = sum(bin_size for bin_size, _ in bins) - starting_total_bin_sizes
    total_example_sizes = sum(sizes)
    if total_new_bin_sizes != total_example_sizes:
        raise AssertionError(f'Error in packing. total_example_sizes={total_example_sizes!r} does not equal total_new_bin_sizes={total_new_bin_sizes!r}.')

    sorted_bins = sorted(bins, key=lambda b: b[0], reverse=True)
    emitted_bins = sorted_bins[:num_bins]
    packed_examples = [packed_example for _, packed_example in emitted_bins]
    n_packed_tokens = sum(bin_size for bin_size, _ in emitted_bins)
    return (packed_examples, n_packed_tokens, total_example_sizes, sorted_bins[num_bins:])

def _repad(packed_examples: List[Dict[str, torch.Tensor]], max_seq_len: int, pad_token_id: int, padding_side: str) -> Dict[str, torch.Tensor]:
    """Pads every packed example to ``max_seq_len`` and stacks them into a batch.

    Pad values per key:

    | key              | pad value      |
    |------------------|----------------|
    | `input_ids`      | `pad_token_id` |
    | `labels`         | -100           |
    | `attention_mask` | 0              |
    | `sequence_id`    | -1             |

    Tensors already of length ``max_seq_len`` are used as-is.

    Args:
        packed_examples: Examples to pad. All share the keys of the first one;
            each tensor is assumed to be 1-D and no longer than ``max_seq_len``
            (TODO confirm at call sites).
        max_seq_len: Target length.
        pad_token_id: Pad value for ``input_ids``.
        padding_side: ``'left'`` or ``'right'``.

    Returns:
        A dict mapping each key to a stacked tensor with one row per example.

    Raises:
        ValueError: If ``padding_side`` is not ``'left'`` or ``'right'``.
    """
    # Validate once, up front, so a bad value is reported even when no tensor
    # happens to need padding.
    if padding_side not in ('left', 'right'):
        raise ValueError(f'Unknown padding_side={padding_side!r}')

    def pad_tensor(tensor: torch.Tensor, pad_value: int) -> torch.Tensor:
        n = len(tensor)
        if n == max_seq_len:
            return tensor
        padded = torch.full((max_seq_len,), pad_value, dtype=tensor.dtype, device=tensor.device)
        if padding_side == 'left':
            # Use an explicit start index: `padded[-n:]` with n == 0 would select
            # the whole buffer and fail for an empty tensor.
            padded[max_seq_len - n:] = tensor
        else:
            padded[:n] = tensor
        return padded

    pad_vals = {'input_ids': pad_token_id, 'labels': -100, 'attention_mask': 0, 'sequence_id': -1}
    keys = packed_examples[0].keys()
    return {key: torch.stack([pad_tensor(example[key], pad_vals[key]) for example in packed_examples]) for key in keys}

def auto_packing_ratio(dataloader_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, device_batch_size: int, num_packing_ratios: int=20) -> float:
    """Find a packing ratio that minimizes padding with zero waste.

    By packing examples, we can increase training efficiency, training on more data with fewer batches.
    However, in practice, the selected packing_ratio may produce some waste because profiling is done on only
    a subset of the dataset.

    We select a min_ratio of 1 and a max_ratio of max_seq_len / 100, and profile up to
    num_packing_ratios packing ratios between min_ratio and max_ratio, inclusive.
    When a packing_ratio with non-zero (or unknown) waste is found, we stop and select the previous ratio,
    which has zero waste. If max_seq_len <= 100, 1.0 is returned without profiling.

    Profiling runs with the RNG seeded to 0. The previous RNG state is restored afterwards, even if profiling
    raises. When distributed execution is initialized (`dist`), the minimum ratio across ranks is returned
    so that all ranks agree.

    Args:
        dataloader_cfg (DictConfig): The dataloader configuration for profiling.
        tokenizer (PreTrainedTokenizerBase): The tokenizer for profiling.
        device_batch_size (int): The size of the batches (number of examples) per device.
        num_packing_ratios (int): The number of packing ratios to try.

    Returns:
        A packing ratio that minimizes padding while maintaining zero waste.
    """
    # NOTE(review): `reproducibility`, `dist`, `get_device` and `DictConfig` are not
    # imported in the visible part of this file; confirm they are in scope.
    max_seq_len = dataloader_cfg.dataset.max_seq_len
    # Return before touching the RNG so the early exit cannot leak a reseeded state.
    if max_seq_len <= 100:
        return 1.0
    rng_state = reproducibility.get_rng_state()
    reproducibility.seed_all(0)
    try:
        min_ratio = 1
        max_ratio = max_seq_len / 100
        profiling_results = profile_packing(dataloader_cfg, tokenizer, min_ratio, max_ratio, num_packing_ratios, device_batch_size)
        packing_ratio = 1.0
        for packing_ratio_candidate, _, waste in profiling_results:
            if waste is None or waste > 0:
                break
            packing_ratio = packing_ratio_candidate
        if dist.is_available() and dist.is_initialized():
            device = get_device(None)
            # Always build a float64 tensor so every rank contributes the same dtype
            # to the collective; a plain int 1 would otherwise become int64 on some ranks.
            packing_ratio_tensor = device.tensor_to_device(torch.tensor(float(packing_ratio), dtype=torch.float64))
            dist.all_reduce(packing_ratio_tensor, reduce_operation='MIN')
            packing_ratio = packing_ratio_tensor.item()
    finally:
        reproducibility.load_rng_state(rng_state)
    return packing_ratio

def profile_packing(dataloader_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, min_ratio: float, max_ratio: float, num_packing_ratios: int, device_batch_size: int) -> Iterable[Tuple[float, Optional[float], Optional[float]]]:
    """Lazily profiles example packing across a range of packing ratios.

    This is a generator: nothing is loaded until the first result is requested.
    One large unpacked batch is loaded once and reused for every candidate ratio.

    Args:
        dataloader_cfg (DictConfig): The dataloader configuration for profiling.
        tokenizer (PreTrainedTokenizerBase): The tokenizer for profiling.
        min_ratio (float): Smallest packing_ratio to test. Must be >=1.
        max_ratio (float): Largest packing_ratio to test. Must be larger than `min_ratio`.
        num_packing_ratios (int): Number of packing_ratio values (spaced between `min_ratio` and `max_ratio`) to try.
        device_batch_size (int): The size of the batches (number of examples) per device.

    Yields:
        ``(packing_ratio, padding_percent, waste_percent)`` tuples in the order the
        candidate ratios were generated. Candidates are rounded to one decimal, and
        ratios that map to an already-seen raw batch size are skipped.
        ``padding_percent`` and ``waste_percent`` are ``None`` when no batch could
        be packed.
    """
    import copy
    from .dataloader import build_dataloader

    # Read these from the caller's config before the profiling copy is altered.
    max_seq_len = dataloader_cfg.dataset.get('max_seq_len')
    max_leftovers_to_keep = dataloader_cfg.dataset.get('max_leftovers_to_keep', None)

    # Profile with an unpacked, single-process loader that keeps partial batches.
    cfg = copy.deepcopy(dataloader_cfg)
    cfg.dataset.packing_ratio = 1.0
    cfg.drop_last = False
    cfg.num_workers = 0
    cfg.prefetch_factor = None
    cfg.persistent_workers = False
    if cfg.dataset.get('remote') is not None:
        # NOTE(review): the TemporaryDirectory object is not kept, so only its unique
        # path is reused (the directory itself is presumably finalized right away);
        # anything downloaded there is not cleaned up by this function. Confirm intent.
        cfg.dataset.local = tempfile.TemporaryDirectory().name

    # Round each candidate to one decimal; keep the first ratio per raw batch size.
    candidates: List[Tuple[float, int]] = []
    seen_batch_sizes = set()
    for raw_ratio in np.linspace(min_ratio, max_ratio, num_packing_ratios, endpoint=True):
        ratio = np.round(10 * raw_ratio) / 10
        batch_size = int(ratio * device_batch_size)
        if batch_size in seen_batch_sizes:
            continue
        seen_batch_sizes.add(batch_size)
        candidates.append((ratio, batch_size))

    n_profile_examples = max(seen_batch_sizes) * 100
    loader = build_dataloader(cfg, tokenizer, n_profile_examples).dataloader
    big_batch = next(iter(loader))
    sizes, trimmed_examples = _trim_batch(big_batch)

    def profile(raw_batch_size: int) -> Tuple[Optional[float], Optional[float]]:
        # Shallow-copy each example dict: packing replaces dict entries, and the
        # shared trimmed examples are reused for every candidate ratio.
        examples = [te.copy() for te in trimmed_examples]
        packer = BinPackCollator(collator=lambda x: x, target_batch_size=device_batch_size, max_seq_len=max_seq_len, pad_token_id=0, padding_side='left', max_leftover_bins_to_keep=max_leftovers_to_keep)
        for start in range(0, len(examples), raw_batch_size):
            stop = start + raw_batch_size
            chunk = examples[start:stop]
            # Chunks that cannot fill one output batch are not packed.
            if len(chunk) >= device_batch_size:
                packer._pack_trimmed_examples(chunk, sizes[start:stop])
        if packer.n_packed_examples == 0:
            log.debug('No examples packed during profiling. Dataset is smaller than device batch size.')
            return (None, None)
        return (100 * (1 - packer.efficiency), 100 * packer.waste)

    for ratio, batch_size in candidates:
        padding, waste = profile(batch_size)
        yield (ratio, padding, waste)