# Adapted from https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_clm.py
from itertools import chain
from pathlib import Path
import pickle
from typing import Any, List, Union
import subprocess
import mmap

from multiprocessing.shared_memory import SharedMemory

import numpy as np

import torch
from torch.utils.data.dataloader import DataLoader, Dataset
from transformers import AutoTokenizer
from datasets import load_dataset

from pytorch_lightning import LightningDataModule

from src.datamodules.datasets.lm_dataset import LMDataset
from src.datamodules.fault_tolerant_sampler import RandomFaultTolerantSampler
from src.datamodules.fault_tolerant_sampler import FaultTolerantDistributedSampler
from src.datamodules.datasets.detokenizer import DATASET_TOKENIZATION_REGISTRY
from src.utils.utils import get_logger

logger = get_logger()


# https://github.com/numpy/numpy/issues/18294
class SHMArray(np.ndarray):
    """ndarray subclass that carries a reference to its backing ``SharedMemory``.

    Holding the ``shm`` object on the array keeps the shared-memory handle
    referenced for as long as the array (or any view derived from it) is alive.

    Copied from
    https://numpy.org/doc/stable/user/basics.subclassing.html#slightly-more-realistic-example-attribute-added-to-existing-array
    """

    def __new__(cls, input_array, shm=None):
        """View ``input_array`` as an ``SHMArray`` and attach ``shm`` to it."""
        obj = np.asarray(input_array).view(cls)
        obj.shm = shm
        return obj

    def __array_finalize__(self, obj):
        """Propagate the ``shm`` attribute to views and arrays created from templates."""
        if obj is None:
            return
        self.shm = getattr(obj, 'shm', None)


class LMDataModule(LightningDataModule):
    """Lightning data module that tokenizes a text dataset into one flat id array per split.

    Each split ('train', 'validation', 'test') is tokenized, concatenated into a
    single 1-D array of token ids (held in shared memory or in a disk-backed
    memmap), optionally cached under ``cache_dir``, and wrapped in ``LMDataset``.
    """

    def __init__(self, dataset_name, tokenizer_name, dataset_config_name=None, max_length=1024,
                 cache_dir=None, val_ratio=0.0005, val_split_seed=2357, add_eos=True,
                 detokenize=False, val_only=False, batch_size=32, batch_size_eval=None,
                 num_workers=1, shuffle=False, pin_memory=False, drop_last=False,
                 fault_tolerant=False, ddp=False, fast_forward_epochs=None,
                 fast_forward_batches=None, use_shmem=True):
        """Store the configuration and validate option combinations.

        Args:
            dataset_name: Name passed to ``datasets.load_dataset``.
            tokenizer_name: Name passed to ``AutoTokenizer.from_pretrained``.
            dataset_config_name: Config name passed to ``datasets.load_dataset``.
            max_length: Passed to ``LMDataset`` as ``seq_len``.
            cache_dir: Root directory for the processed-dataset cache (``~`` is
                expanded). ``None`` disables caching.
            val_ratio: ``test_size`` of the train/test split performed when the
                raw dataset has no 'validation' split.
            val_split_seed: Seed for that split.
            add_eos: If True, append the tokenizer's EOS token to every
                non-empty text before tokenizing.
            detokenize: If True and ``dataset_name`` is in
                ``DATASET_TOKENIZATION_REGISTRY``, run the registered
                detokenizer over the text first.
            val_only: If True, the 'train' split is replaced by 'validation'.
            batch_size: Batch size of the train dataloader.
            batch_size_eval: Batch size of the val/test dataloaders; defaults
                to ``batch_size``.
            num_workers: Number of processes used by the ``datasets.map``
                preprocessing calls (the dataloaders use a fixed worker count).
            shuffle: Whether the train dataloader shuffles.
            pin_memory: Forwarded to ``DataLoader``.
            drop_last: Forwarded to ``DataLoader``.
            fault_tolerant: Use the fault-tolerant samplers; requires ``shuffle``.
            ddp: Use ``FaultTolerantDistributedSampler``; requires
                ``fault_tolerant``.
            fast_forward_epochs: Epoch value loaded into the sampler state when
                resuming; requires ``ddp`` and ``fault_tolerant``.
            fast_forward_batches: Batch count loaded into the sampler state when
                resuming; requires ``ddp`` and ``fault_tolerant``.
            use_shmem: Concatenate token ids into shared memory instead of a
                file on disk; requires ``cache_dir``.

        Raises:
            AssertionError: If one of the option requirements above is violated.
        """
        super().__init__()
        self.dataset_name = dataset_name
        self.dataset_config_name = dataset_config_name
        self.tokenizer_name = tokenizer_name
        self.cache_dir = None if cache_dir is None else Path(cache_dir).expanduser()
        self.max_length = max_length
        self.val_ratio = val_ratio
        self.val_split_seed = val_split_seed
        self.val_only = val_only
        self.add_eos = add_eos
        self.detokenize = detokenize
        self.batch_size = batch_size
        self.batch_size_eval = batch_size_eval if batch_size_eval is not None else self.batch_size
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        if fault_tolerant:
            assert self.shuffle
        self.fault_tolerant = fault_tolerant
        if ddp:
            assert fault_tolerant
        self.ddp = ddp
        self.fast_forward_epochs = fast_forward_epochs
        self.fast_forward_batches = fast_forward_batches
        if self.fast_forward_epochs is not None or self.fast_forward_batches is not None:
            assert ddp and fault_tolerant

        self.use_shmem = use_shmem
        if self.use_shmem:
            assert cache_dir is not None

    def prepare_data(self):
        """Download the dataset, and also process and cache it when ``cache_dir`` is set."""
        if self.cache_dir is None:  # Just download the dataset
            load_dataset(self.dataset_name, self.dataset_config_name)
        else:  # Process the dataset and save it
            self.process_dataset()

    def setup(self, stage=None):
        """Build ``dataset_train``, ``dataset_val`` and ``dataset_test``.

        Also sets ``self.tokenizer`` and ``self.vocab_size``. Returns early when
        ``stage == 'test'`` and the test dataset already exists.
        """
        if stage == 'test' and hasattr(self, 'dataset_test'):
            return
        concat_ids, self.tokenizer = self.process_dataset()
        self.vocab_size = len(self.tokenizer)
        # Create all splits
        self.dataset_train, self.dataset_val, self.dataset_test = [
            LMDataset(concat_ids[split], seq_len=self.max_length)
            for split in ['train', 'validation', 'test']
        ]

    def process_dataset(self):
        """Tokenize the dataset and concatenate each split into one 1-D id array.

        If a cache directory for the current settings already exists, the
        result is loaded from it instead of being recomputed. Otherwise the raw
        dataset is loaded, a validation split is created if missing, the text
        is optionally detokenized, tokenized, and written either to shared
        memory (``use_shmem``) or to a ``.bin`` file that is memory-mapped.
        When caching is enabled the result is saved and the temporary ``.bin``
        files are removed.

        Returns:
            A tuple ``(concat_ids, tokenizer)`` where ``concat_ids`` maps split
            name to its array of token ids.
        """
        cache_dir = None if self.cache_dir is None else self.cache_dir / self._cache_dir_name
        if cache_dir is not None:
            if cache_dir.is_dir():
                return self._load_from_cache(cache_dir)

        raw_datasets = load_dataset(self.dataset_name, self.dataset_config_name)
        # https://github.com/stanford-crfm/mistral/blob/main/src/corpora/auto.py
        if 'validation' not in raw_datasets:
            assert "train" in raw_datasets, "You must have train in raw_datasets to make a validation raw_datasets"
            raw_datasets = raw_datasets["train"].train_test_split(
                test_size=self.val_ratio, seed=self.val_split_seed,
                shuffle=True  # Otherwise test will be at the end of the dataset
            )
            raw_datasets['validation'] = raw_datasets['test']

        if self.val_only:  # Should only be used for evaluation, not for training
            raw_datasets['train'] = raw_datasets['validation']

        # [2021-12-25] TD: Running the detokenizer on wikitext-103 makes ppl worse
        # (GPT2-small val ppl after 10 epochs ~22 -> ~25)
        # However, it's useful for zero-shot transfer from Openwebtext,
        # as after detokenization it's closer to Openwebtext's format.
        # https://github.com/stanford-crfm/mistral/issues/12
        if self.detokenize:
            if self.dataset_name in DATASET_TOKENIZATION_REGISTRY:
                detokenizer = DATASET_TOKENIZATION_REGISTRY[self.dataset_name]
                raw_datasets = raw_datasets.map(
                    lambda example: {'text': detokenizer(example['text'])},
                    num_proc=max(self.num_workers, 1),
                    desc='Running detokenizer on dataset'
                )

        tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=True)
        # Preprocessing the datasets.
        # First we tokenize all the texts.
        column_names = raw_datasets["train"].column_names
        text_column_name = "text" if "text" in column_names else column_names[0]
        # [2021-12-25] TD: For wikitext, don't need to add the EOS since each example already ends
        # with '\n', and there are no other '\n' in the examples.
        # assert all([t.count('\n') == 1 for t in raw_datasets['train']['text'] if t])
        # Add EOS token to the end of the text if the text is not empty
        # https://github.com/stanford-crfm/mistral/issues/91
        # https://github.com/stanford-crfm/mistral/pull/98
        if self.add_eos:
            add_eos = lambda seq: (seq + tokenizer.eos_token) if seq else seq
            add_eos_batched = lambda seqs: [add_eos(seq) for seq in seqs]
            tokenize = lambda example: tokenizer(add_eos_batched(example[text_column_name]))
        else:
            tokenize = lambda example: tokenizer(example[text_column_name])
        # tokenized_datasets = raw_datasets.map(
        #     tokenize,
        #     batched=True,
        #     num_proc=max(self.num_workers, 1),
        #     remove_columns=column_names,
        #     desc="Running tokenizer on dataset",
        # )
        # Size the storage dtype from the full vocabulary (len(tokenizer), which
        # includes added tokens) rather than tokenizer.vocab_size (base vocabulary
        # only), so that every id the tokenizer can emit fits in the chosen dtype.
        # This also matches how setup() computes self.vocab_size.
        dtype = np.uint16 if len(tokenizer) < 64 * 1024 else np.int32

        def tokenize_concat(examples):
            # We just need 'input_ids', not 'attention_mask' (since it's all 1)
            input_ids = np.fromiter(chain(*tokenize(examples)['input_ids']), dtype=dtype)
            # Need to return a list since we're doing batched processing
            return {'input_ids': [input_ids], 'len': [len(input_ids)]}

        tokenized_datasets = raw_datasets.map(
            tokenize_concat,
            batched=True,
            num_proc=max(self.num_workers, 1),
            remove_columns=column_names,
            desc="Running tokenizer on dataset",
        )

        if self.use_shmem:
            # Concatenate all input_ids into an array in shared memory
            def write_ids_to_shm(example, shm_name, array_len):
                # Attach to the existing segment by name and copy this example's
                # ids into its slot; 'len_offset' is the exclusive end index.
                shm = SharedMemory(name=shm_name)
                shm_arr = np.ndarray((array_len,), dtype=dtype, buffer=shm.buf)
                start_idx = example['len_offset'] - len(example['input_ids'])
                shm_arr[start_idx:example['len_offset']] = example['input_ids']
                shm.close()

            concat_ids = {}
            for name, ds in tokenized_datasets.items():
                # Running total of lengths gives each example's end position
                tokenized_datasets[name] = ds.add_column('len_offset', np.cumsum(ds['len']))
                array_len = tokenized_datasets[name][-1]['len_offset']
                shm = SharedMemory(create=True, size=array_len * np.dtype(dtype).itemsize)
                shm_name = shm.name
                tokenized_datasets[name].map(
                    write_ids_to_shm,
                    fn_kwargs={'shm_name': shm_name, 'array_len': array_len},
                    batched=False,
                    num_proc=max(self.num_workers, 1),
                    desc="Concatenating examples",
                )
                shm_arr = np.ndarray((array_len,), dtype=dtype, buffer=shm.buf)
                # We need to keep a reference to the shared memory, otherwise it gets garbage-collected
                # when it goes out of scope, and that memory is gone.
                # https://github.com/numpy/numpy/issues/18294
                concat_ids[name] = SHMArray(shm_arr, shm=shm)
        else:
            # Use disk
            concat_ids = {}
            assert cache_dir is not None
            cache_dir.mkdir(parents=True, exist_ok=True)

            def write_ids_to_disk(example, filename):
                # Map the pre-sized file and copy this example's ids into its slot
                with open(filename, 'r+b') as f:
                    mm = mmap.mmap(f.fileno(), 0)
                    start_idx = example['len_offset'] - len(example['input_ids'])
                    array_len = len(example['input_ids'])
                    arr = np.ndarray((array_len,), dtype=dtype, buffer=mm,
                                     offset=np.dtype(dtype).itemsize * start_idx)
                    arr[:] = example['input_ids']
                    mm.flush()

            for name, ds in tokenized_datasets.items():
                # Running total of lengths gives each example's end position
                tokenized_datasets[name] = ds.add_column('len_offset', np.cumsum(ds['len']))
                array_len = tokenized_datasets[name][-1]['len_offset']
                filename = cache_dir / f'{name}.bin'
                # Need to create the file with this specific size first
                # https://ostechnix.com/create-files-certain-size-linux/
                subprocess.run(['truncate', '-s', str(array_len * np.dtype(dtype).itemsize),
                                str(filename)], check=True)
                tokenized_datasets[name].map(
                    write_ids_to_disk,
                    fn_kwargs={'filename': filename},
                    batched=False,
                    num_proc=max(self.num_workers, 1),
                    desc="Concatenating examples",
                )
                concat_ids[name] = np.memmap(filename, dtype=dtype, mode='r', shape=(array_len,))

        if cache_dir is not None:
            self._save_to_cache(concat_ids, tokenizer, cache_dir)
            if not self.use_shmem:
                # The .npy copies written by _save_to_cache are the cache;
                # the intermediate .bin files are removed here.
                for name in concat_ids:
                    Path(cache_dir / f'{name}.bin').unlink()
        return concat_ids, tokenizer

    def _save_to_cache(self, concat_ids, tokenizer, cache_dir):
        """Write each split as ``<split>.npy`` and the tokenizer as ``tokenizer.pkl``."""
        cache_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f'Saving to cache at {str(cache_dir)}')
        for k, v in concat_ids.items():
            np.save(cache_dir / f'{k}.npy', v)
        with open(cache_dir / 'tokenizer.pkl', 'wb') as f:
            pickle.dump(tokenizer, f)

    def _load_from_cache(self, cache_dir):
        """Load the memory-mapped split arrays and the pickled tokenizer from ``cache_dir``.

        Returns:
            A tuple ``(concat_ids, tokenizer)`` in the same form as
            ``process_dataset``.
        """
        assert cache_dir.is_dir()
        logger.info(f'Load from cache at {str(cache_dir)}')
        concat_ids = {split: np.load(cache_dir / f'{split}.npy', mmap_mode='r')
                      for split in ['train', 'validation', 'test']}
        # NOTE(review): pickle.load executes arbitrary code from the file; the
        # cache directory must only contain files written by _save_to_cache.
        with open(cache_dir / 'tokenizer.pkl', 'rb') as f:
            tokenizer = pickle.load(f)
        return concat_ids, tokenizer

    @property
    def _cache_dir_name(self):
        """Sub-directory name encoding the settings that affect the processed data."""
        return f'tokenizer_name-{self.tokenizer_name}-val_ratio-{self.val_ratio}-val_split_seed-{self.val_split_seed}-add_eos-{self.add_eos}-detokenize-{self.detokenize}'

    def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader:
        """ The train dataloader """
        if self.shuffle and self.fault_tolerant:
            # Shuffling is delegated to the sampler, so the DataLoader itself
            # must not shuffle.
            shuffle = False
            sampler = (FaultTolerantDistributedSampler(self.dataset_train) if self.ddp
                       else RandomFaultTolerantSampler(self.dataset_train))
            # TD [2022-08-06]: Only the DDP sampler supports fast-forwarding for now
            # We assume that it's being resumed with the same number of GPUs
            if self.ddp and self.fast_forward_epochs is not None and self.fast_forward_batches is not None:
                sampler.load_state_dict({
                    'epoch': self.fast_forward_epochs,
                    'counter': self.fast_forward_batches * self.batch_size
                })
        else:
            shuffle = self.shuffle
            sampler = None
        return self._data_loader(self.dataset_train, batch_size=self.batch_size,
                                 shuffle=shuffle, sampler=sampler)

    def val_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:
        """ The val dataloader """
        return self._data_loader(self.dataset_val, batch_size=self.batch_size_eval)

    def test_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:
        """ The test dataloader """
        return self._data_loader(self.dataset_test, batch_size=self.batch_size_eval)

    def _data_loader(self, dataset: Dataset, batch_size: int, shuffle: bool = False,
                     sampler=None) -> DataLoader:
        """Build a ``DataLoader`` over ``dataset`` using this module's shared settings."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=1,  # Data is already in memory, we don't need many workers
            shuffle=shuffle,
            sampler=sampler,
            drop_last=self.drop_last,
            pin_memory=self.pin_memory,
            # persistent_workers=True
        )

    def load_state_dict(self, checkpoint):
        """Read the fast-forward counters from ``checkpoint`` when fault tolerance is on.

        Sets ``fast_forward_epochs`` and ``fast_forward_batches`` from the
        fit-loop progress entries of the checkpoint; ``train_dataloader`` later
        loads them into the sampler. Does nothing when ``fault_tolerant`` is
        False.
        """
        if self.fault_tolerant:
            self.fast_forward_epochs = checkpoint['loops']['fit_loop']['epoch_progress']['current']['completed']
            # TD [2022-08-07] ['epoch_loop.batch_progress']['total']['completed'] is 1 iteration
            # behind, so we're using the optimizer's progress. This is set correctly in seq.py.
            self.fast_forward_batches = checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['completed']
        # At this point the train loader hasn't been constructed yet