Spaces:
Sleeping
Sleeping
File size: 15,006 Bytes
e45d058 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 |
# Adapted from https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_clm.py
from itertools import chain
from pathlib import Path
import pickle
from typing import Any, List, Union
import subprocess
import mmap
from multiprocessing.shared_memory import SharedMemory
import numpy as np
import torch
from torch.utils.data.dataloader import DataLoader, Dataset
from transformers import AutoTokenizer
from datasets import load_dataset
from pytorch_lightning import LightningDataModule
from src.datamodules.datasets.lm_dataset import LMDataset
from src.datamodules.fault_tolerant_sampler import RandomFaultTolerantSampler
from src.datamodules.fault_tolerant_sampler import FaultTolerantDistributedSampler
from src.datamodules.datasets.detokenizer import DATASET_TOKENIZATION_REGISTRY
from src.utils.utils import get_logger
# Module-level logger obtained from the project's get_logger helper.
logger = get_logger()
# A NumPy array built on a SharedMemory buffer does not keep that SharedMemory object alive;
# if the SharedMemory is garbage-collected its memory is gone, hence the SHMArray wrapper.
# https://github.com/numpy/numpy/issues/18294
class SHMArray(np.ndarray):
    """ndarray subclass carrying an extra ``shm`` attribute.

    The attribute is used to hold a reference to the ``SharedMemory`` object whose buffer
    backs the array, so that object is not garbage-collected while the array is in use.
    Views and slices of an ``SHMArray`` inherit the same ``shm``.

    Adapted from
    https://numpy.org/doc/stable/user/basics.subclassing.html#slightly-more-realistic-example-attribute-added-to-existing-array
    """

    def __new__(cls, input_array, shm=None):
        # View (not copy) the input as this subclass, then attach the handle.
        base = np.asarray(input_array)
        instance = base.view(cls)
        instance.shm = shm
        return instance

    def __array_finalize__(self, obj):
        # obj is None only when the subclass is constructed explicitly; otherwise it is the
        # array this one was derived from (view, slice, ...), whose handle we propagate.
        if obj is not None:
            self.shm = getattr(obj, 'shm', None)
class LMDataModule(LightningDataModule):
    """LightningDataModule for causal language modeling on a HuggingFace ``datasets`` corpus.

    Every split is tokenized, flattened into a single 1-D array of token ids and wrapped in an
    ``LMDataset`` with ``seq_len=max_length``. When ``cache_dir`` is set, the flattened arrays
    and the tokenizer are cached in a sub-directory whose name encodes the preprocessing options.
    """

    def __init__(self, dataset_name, tokenizer_name, dataset_config_name=None, max_length=1024,
                 cache_dir=None, val_ratio=0.0005, val_split_seed=2357, add_eos=True,
                 detokenize=False, val_only=False, batch_size=32, batch_size_eval=None,
                 num_workers=1, shuffle=False, pin_memory=False, drop_last=False,
                 fault_tolerant=False, ddp=False, fast_forward_epochs=None,
                 fast_forward_batches=None, use_shmem=True):
        """
        Args:
            dataset_name, dataset_config_name: passed to ``datasets.load_dataset``.
            tokenizer_name: passed to ``AutoTokenizer.from_pretrained``.
            max_length: ``seq_len`` of every ``LMDataset``.
            cache_dir: root of the processed-data cache (``~`` is expanded); None disables it.
            val_ratio, val_split_seed: used to carve a validation split out of ``train`` when
                the dataset has no ``validation`` split.
            add_eos: append ``tokenizer.eos_token`` to every non-empty text.
            detokenize: run the dataset's detokenizer from ``DATASET_TOKENIZATION_REGISTRY``.
            val_only: replace ``train`` with ``validation`` (evaluation only, not training).
            batch_size: train batch size; ``batch_size_eval`` defaults to it.
            num_workers: ``num_proc`` for dataset processing (DataLoaders always use 1 worker).
            shuffle, pin_memory, drop_last: DataLoader options.
            fault_tolerant: use the fault-tolerant samplers; requires ``shuffle``.
            ddp: use the distributed fault-tolerant sampler; requires ``fault_tolerant``.
            fast_forward_epochs, fast_forward_batches: resume position for the train sampler;
                require ``ddp`` and ``fault_tolerant``.
            use_shmem: concatenate token ids in shared memory instead of scratch files on
                disk; requires ``cache_dir``.
        """
        super().__init__()
        self.dataset_name = dataset_name
        self.dataset_config_name = dataset_config_name
        self.tokenizer_name = tokenizer_name
        self.cache_dir = None if cache_dir is None else Path(cache_dir).expanduser()
        self.max_length = max_length
        self.val_ratio = val_ratio
        self.val_split_seed = val_split_seed
        self.val_only = val_only
        self.add_eos = add_eos
        self.detokenize = detokenize
        self.batch_size = batch_size
        self.batch_size_eval = batch_size_eval if batch_size_eval is not None else self.batch_size
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        if fault_tolerant:
            assert self.shuffle
        self.fault_tolerant = fault_tolerant
        if ddp:
            assert fault_tolerant
        self.ddp = ddp
        self.fast_forward_epochs = fast_forward_epochs
        self.fast_forward_batches = fast_forward_batches
        if self.fast_forward_epochs is not None or self.fast_forward_batches is not None:
            assert ddp and fault_tolerant
        self.use_shmem = use_shmem
        if self.use_shmem:
            assert cache_dir is not None

    def prepare_data(self):
        """Download the dataset, or fully process and cache it when ``cache_dir`` is set."""
        if self.cache_dir is None:  # Just download the dataset
            load_dataset(self.dataset_name, self.dataset_config_name)
        else:  # Process the dataset and save it
            self.process_dataset()

    def setup(self, stage=None):
        """Build the train/validation/test ``LMDataset``s (no-op for 'test' once built)."""
        if stage == 'test' and hasattr(self, 'dataset_test'):
            return
        concat_ids, self.tokenizer = self.process_dataset()
        self.vocab_size = len(self.tokenizer)
        # Create all splits
        self.dataset_train, self.dataset_val, self.dataset_test = [
            LMDataset(concat_ids[split], seq_len=self.max_length)
            for split in ['train', 'validation', 'test']
        ]

    def process_dataset(self):
        """Tokenize and flatten every split, reusing or populating the on-disk cache.

        Returns:
            ``(concat_ids, tokenizer)``: ``concat_ids`` maps each split name to a 1-D array of
            token ids (``SHMArray`` / ``np.memmap`` when freshly processed, an ``np.load``
            read-only memmap when read from cache); ``tokenizer`` is the tokenizer used.
        """
        cache_dir = None if self.cache_dir is None else self.cache_dir / self._cache_dir_name
        if cache_dir is not None and cache_dir.is_dir():
            return self._load_from_cache(cache_dir)
        # The disk path writes scratch files under cache_dir: fail before the expensive
        # tokenization rather than after it.
        if not self.use_shmem:
            assert cache_dir is not None

        raw_datasets = self._load_raw_datasets()
        tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=True)
        # Store ids as uint16 when they all fit. len(tokenizer) can exceed tokenizer.vocab_size
        # (e.g. added special tokens), so check both to avoid silently overflowing uint16.
        id_upper_bound = max(tokenizer.vocab_size, len(tokenizer))
        dtype = np.uint16 if id_upper_bound < 64 * 1024 else np.int32
        tokenized_datasets = self._tokenize(raw_datasets, tokenizer, dtype)

        if self.use_shmem:
            concat_ids = self._concat_ids_in_shmem(tokenized_datasets, dtype)
        else:
            concat_ids = self._concat_ids_on_disk(tokenized_datasets, dtype, cache_dir)

        if cache_dir is not None:
            self._save_to_cache(concat_ids, tokenizer, cache_dir)
            if not self.use_shmem:
                # The .npy cache now holds the data; the scratch .bin files are not needed.
                for name in concat_ids:
                    (cache_dir / f'{name}.bin').unlink()
        return concat_ids, tokenizer

    def _load_raw_datasets(self):
        """Load the raw dataset, add a validation split if missing, optionally detokenize."""
        raw_datasets = load_dataset(self.dataset_name, self.dataset_config_name)
        # https://github.com/stanford-crfm/mistral/blob/main/src/corpora/auto.py
        if 'validation' not in raw_datasets:
            assert "train" in raw_datasets, "You must have train in raw_datasets to make a validation raw_datasets"
            raw_datasets = raw_datasets["train"].train_test_split(
                test_size=self.val_ratio, seed=self.val_split_seed,
                shuffle=True  # Otherwise test will be at the end of the dataset
            )
            raw_datasets['validation'] = raw_datasets['test']
        if self.val_only:  # Should only be used for evaluation, not for training
            raw_datasets['train'] = raw_datasets['validation']
        # [2021-12-25] TD: Running the detokenizer on wikitext-103 makes ppl worse
        # (GPT2-small val ppl after 10 epochs ~22 -> ~25)
        # However, it's useful for zero-shot transfer from Openwebtext,
        # as after detokenization it's closer to Openwebtext's format.
        # https://github.com/stanford-crfm/mistral/issues/12
        if self.detokenize and self.dataset_name in DATASET_TOKENIZATION_REGISTRY:
            detokenizer = DATASET_TOKENIZATION_REGISTRY[self.dataset_name]
            raw_datasets = raw_datasets.map(
                lambda example: {'text': detokenizer(example['text'])},
                num_proc=max(self.num_workers, 1),
                desc='Running detokenizer on dataset'
            )
        return raw_datasets

    def _tokenize(self, raw_datasets, tokenizer, dtype):
        """Tokenize every split so each batch of examples becomes one flat row of ids.

        The returned splits have two columns: ``input_ids`` (the batch's token ids,
        concatenated in order, as ``dtype``) and ``len`` (how many ids that row holds).
        """
        column_names = raw_datasets["train"].column_names
        text_column_name = "text" if "text" in column_names else column_names[0]
        # [2021-12-25] TD: For wikitext, don't need to add the EOS since each example already ends
        # with '\n', and there are no other '\n' in the examples.
        # Add EOS token to the end of the text if the text is not empty
        # https://github.com/stanford-crfm/mistral/issues/91
        # https://github.com/stanford-crfm/mistral/pull/98
        append_eos = self.add_eos  # bound locally so the mapped function doesn't capture self

        def tokenize_concat(examples):
            texts = examples[text_column_name]
            if append_eos:
                texts = [seq + tokenizer.eos_token if seq else seq for seq in texts]
            # We just need 'input_ids', not 'attention_mask' (since it's all 1)
            input_ids = np.fromiter(chain.from_iterable(tokenizer(texts)['input_ids']),
                                    dtype=dtype)
            # Need to return a list since we're doing batched processing
            return {'input_ids': [input_ids], 'len': [len(input_ids)]}

        return raw_datasets.map(
            tokenize_concat,
            batched=True,
            num_proc=max(self.num_workers, 1),
            remove_columns=column_names,
            desc="Running tokenizer on dataset",
        )

    def _concat_ids_in_shmem(self, tokenized_datasets, dtype):
        """Concatenate each split's ``input_ids`` rows into one shared-memory array.

        Returns a dict mapping split name -> ``SHMArray`` holding its ``SharedMemory``.
        """
        def write_ids_to_shm(example, shm_name, array_len):
            # Runs in the datasets.map workers: attach by name and copy this row's ids into
            # the slice ending at its cumulative offset.
            shm = SharedMemory(name=shm_name)
            shm_arr = np.ndarray((array_len,), dtype=dtype, buffer=shm.buf)
            start_idx = example['len_offset'] - len(example['input_ids'])
            shm_arr[start_idx:example['len_offset']] = example['input_ids']
            shm.close()

        concat_ids = {}
        for name, ds in tokenized_datasets.items():
            len_offset = np.cumsum(ds['len'])
            indexed = ds.add_column('len_offset', len_offset)
            array_len = int(len_offset[-1])
            shm = SharedMemory(create=True, size=array_len * np.dtype(dtype).itemsize)
            try:
                indexed.map(
                    write_ids_to_shm,
                    fn_kwargs={'shm_name': shm.name, 'array_len': array_len},
                    batched=False,
                    num_proc=max(self.num_workers, 1),
                    desc="Concatenating examples",
                )
            finally:
                # Workers only need the name while map() runs. Unlinking removes the name so
                # the segment is released once every mapping is closed, rather than staying
                # in shared memory after the array is dropped; our own mapping stays valid.
                shm.unlink()
            shm_arr = np.ndarray((array_len,), dtype=dtype, buffer=shm.buf)
            # We need to keep a reference to the shared memory, otherwise it gets garbage-collected
            # when it goes out of scope, and that memory is gone.
            # https://github.com/numpy/numpy/issues/18294
            concat_ids[name] = SHMArray(shm_arr, shm=shm)
        return concat_ids

    def _concat_ids_on_disk(self, tokenized_datasets, dtype, cache_dir):
        """Concatenate each split's ``input_ids`` rows into ``cache_dir/{split}.bin``.

        Returns a dict mapping split name -> read-only ``np.memmap`` over that file.
        """
        cache_dir.mkdir(parents=True, exist_ok=True)
        itemsize = np.dtype(dtype).itemsize

        def write_ids_to_disk(example, filename):
            # Runs in the datasets.map workers: copy this row's ids into its byte range of the
            # pre-sized file. Writing raw bytes (no ndarray view on the mmap) lets the context
            # managers close the mapping and the file deterministically.
            end_idx = example['len_offset']
            start_idx = end_idx - len(example['input_ids'])
            payload = np.asarray(example['input_ids'], dtype=dtype).tobytes()
            with open(filename, 'r+b') as f, mmap.mmap(f.fileno(), 0) as mm:
                mm[start_idx * itemsize:end_idx * itemsize] = payload
                mm.flush()

        concat_ids = {}
        for name, ds in tokenized_datasets.items():
            len_offset = np.cumsum(ds['len'])
            indexed = ds.add_column('len_offset', len_offset)
            array_len = int(len_offset[-1])
            filename = cache_dir / f'{name}.bin'
            # The file must already have its final size so every worker can mmap it and write
            # its own slice. file.truncate() replaces the external `truncate -s` command,
            # which is not available on every platform.
            with open(filename, 'wb') as f:
                f.truncate(array_len * itemsize)
            indexed.map(
                write_ids_to_disk,
                fn_kwargs={'filename': filename},
                batched=False,
                num_proc=max(self.num_workers, 1),
                desc="Concatenating examples",
            )
            concat_ids[name] = np.memmap(filename, dtype=dtype, mode='r', shape=(array_len,))
        return concat_ids

    def _save_to_cache(self, concat_ids, tokenizer, cache_dir):
        """Write each split to ``{split}.npy`` and the pickled tokenizer to ``tokenizer.pkl``."""
        cache_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f'Saving to cache at {str(cache_dir)}')
        for k, v in concat_ids.items():
            np.save(cache_dir / f'{k}.npy', v)
        with open(cache_dir / 'tokenizer.pkl', 'wb') as f:
            pickle.dump(tokenizer, f)

    def _load_from_cache(self, cache_dir):
        """Load the train/validation/test arrays (read-only memmaps) and the tokenizer."""
        assert cache_dir.is_dir()
        logger.info(f'Load from cache at {str(cache_dir)}')
        concat_ids = {split: np.load(cache_dir / f'{split}.npy', mmap_mode='r')
                      for split in ['train', 'validation', 'test']}
        # NOTE: unpickling executes code from the file; only point cache_dir at trusted caches.
        with open(cache_dir / 'tokenizer.pkl', 'rb') as f:
            tokenizer = pickle.load(f)
        return concat_ids, tokenizer

    @property
    def _cache_dir_name(self):
        """Cache sub-directory name encoding every option that changes the processed data."""
        name = (f'tokenizer_name-{self.tokenizer_name}-val_ratio-{self.val_ratio}-'
                f'val_split_seed-{self.val_split_seed}-add_eos-{self.add_eos}-'
                f'detokenize-{self.detokenize}')
        # val_only replaces the train split with validation, so it must be part of the key.
        # It is only appended when set, so existing (val_only=False) caches are still found.
        if self.val_only:
            name = f'{name}-val_only-{self.val_only}'
        return name

    def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader:
        """ The train dataloader """
        if self.shuffle and self.fault_tolerant:
            # Ordering is delegated to the fault-tolerant sampler, so the loader must not shuffle.
            shuffle = False
            sampler = (FaultTolerantDistributedSampler(self.dataset_train) if self.ddp
                       else RandomFaultTolerantSampler(self.dataset_train))
            # TD [2022-08-06]: Only the DDP sampler supports fast-forwarding for now
            # We assume that it's being resumed with the same number of GPUs
            if (self.ddp and self.fast_forward_epochs is not None
                    and self.fast_forward_batches is not None):
                sampler.load_state_dict({
                    'epoch': self.fast_forward_epochs,
                    'counter': self.fast_forward_batches * self.batch_size
                })
        else:
            shuffle = self.shuffle
            sampler = None
        return self._data_loader(self.dataset_train, batch_size=self.batch_size,
                                 shuffle=shuffle, sampler=sampler)

    def val_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:
        """ The val dataloader """
        return self._data_loader(self.dataset_val, batch_size=self.batch_size_eval)

    def test_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:
        """ The test dataloader """
        return self._data_loader(self.dataset_test, batch_size=self.batch_size_eval)

    def _data_loader(self, dataset: Dataset, batch_size: int, shuffle: bool = False,
                     sampler=None) -> DataLoader:
        """Build a DataLoader sharing this module's ``drop_last``/``pin_memory`` settings."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=1,  # Data is already in memory, we don't need many workers
            shuffle=shuffle,
            sampler=sampler,
            drop_last=self.drop_last,
            pin_memory=self.pin_memory,
        )

    def load_state_dict(self, checkpoint):
        """Read the resume position from a Lightning checkpoint for fast-forwarding."""
        if self.fault_tolerant:
            self.fast_forward_epochs = checkpoint['loops']['fit_loop']['epoch_progress']['current']['completed']
            # TD [2022-08-07] ['epoch_loop.batch_progress']['total']['completed'] is 1 iteration
            # behind, so we're using the optimizer's progress. This is set correctly in seq.py.
            self.fast_forward_batches = checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['completed']
        # At this point the train loader hasn't been constructed yet
|