mpt-7b-8k-instruct / curriculum_learning_callback.py
irenedea's picture
LLM-foundry update March 26, 2024 23:50:31
2cc518e verified
raw
history blame
3.07 kB
"""Enable curriculum learning by resuming with a different dataset.
This callback is currently experimental. The API may change without warning in
the future.
"""
import logging
from typing import Any, Dict
from streaming import StreamingDataset
from torch.utils.data import DataLoader
from .interfaces import CallbackWithConfig
from .warnings import experimental_class
# Logger for this module, named after the module itself.
log: logging.Logger = logging.getLogger(name=__name__)
@experimental_class('CurriculumLearning callback')
class CurriculumLearning(CallbackWithConfig):
    """Starts an epoch with a different dataset when resuming from a checkpoint.

    The callback keeps a history of the dataloader configs used so far and the
    index of the current dataset. When the index restored from a checkpoint is
    lower than ``dataset_index``, ``after_load`` re-applies the dataset state
    snapshotted in ``before_load`` and advances the timestamp to the next epoch.

    Args:
        dataset_index (int): The index of the dataset currently being used.
        train_config (Dict): The training configuration. Its ``'dataloader'``
            entry is stored as the configuration of the dataset currently
            being used.
    """

    def __init__(self, dataset_index: int, train_config: Dict):
        self.dataset_index = dataset_index
        # Dataset index restored by ``load_state_dict``; 0 if nothing was loaded.
        self.saved_dataset_index = 0
        # Dataloader configs of every dataset used so far, in order.
        self.all_dataset_configs = []
        # Dataset state snapshotted in ``before_load``.
        self.current_dataset_state: Dict[str, Any] = {}
        self.current_dataset_config = train_config['dataloader']

    # ``State`` and ``Logger`` are not imported in this module, so the
    # annotations below are quoted; unquoted they raise NameError when the
    # methods are defined.
    def before_load(self, state: 'State', logger: 'Logger') -> None:
        """Validate the train dataloader and snapshot its dataset state.

        Args:
            state: The trainer state; its train dataloader is inspected.
            logger: Unused.

        Raises:
            ValueError: If the train dataloader is not a ``DataLoader`` or its
                dataset is not a ``StreamingDataset``.
        """
        del logger
        train_loader = state.train_dataloader
        if not isinstance(train_loader, DataLoader):
            # Implicit concatenation: one message string, not a tuple of parts.
            raise ValueError(
                'CurriculumLearning callback can only be used with a train '
                f'dataloader of type DataLoader, but got {type(train_loader)}.')
        dataset = train_loader.dataset
        if not isinstance(dataset, StreamingDataset):
            raise ValueError(
                'CurriculumLearning callback only supports StreamingDataset '
                'because it requires loading and saving dataset state. '
                f'Instead, got a dataset of type {type(dataset)}')
        # Snapshot taken before any checkpoint state is loaded; ``after_load``
        # may load it back into the dataset.
        self.current_dataset_state = dataset.state_dict(num_samples=0,
                                                        from_beginning=False)

    def after_load(self, state: 'State', logger: 'Logger') -> None:
        """Switch to the current dataset if the checkpoint used an earlier one.

        If the restored dataset index is lower than ``dataset_index``, the
        snapshot from ``before_load`` is loaded into the dataset (a negative
        epoch is clamped to 0), the timestamp is advanced to the next epoch,
        and the current dataloader config is appended to the history. If
        ``dataset_index`` is 0 and the history is empty, the current config is
        recorded as its first entry.

        Args:
            state: The trainer state; its train dataloader is read and its
                timestamp may be replaced.
            logger: Unused.
        """
        del logger
        # Public accessor, consistent with ``before_load``.
        train_loader = state.train_dataloader
        assert isinstance(train_loader, DataLoader), 'CurriculumLearning callback requires a DataLoader.'
        dataset = train_loader.dataset
        assert isinstance(dataset, StreamingDataset), 'CurriculumLearning callback requires a StreamingDataset.'
        if self.saved_dataset_index < self.dataset_index:
            # NOTE(review): presumably a negative epoch means the dataset had
            # not started an epoch yet; it is clamped so loading starts at 0.
            if self.current_dataset_state['epoch'] < 0:
                self.current_dataset_state['epoch'] = 0
            dataset.load_state_dict(self.current_dataset_state)
            state.timestamp = state.timestamp.to_next_epoch()
            self.all_dataset_configs.append(self.current_dataset_config)
        elif self.dataset_index == 0 and not self.all_dataset_configs:
            self.all_dataset_configs.append(self.current_dataset_config)

    def state_dict(self) -> Dict[str, Any]:
        """Return the current dataset index and the config history."""
        return {
            'dataset_index': self.dataset_index,
            'all_dataset_configs': self.all_dataset_configs,
        }

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        """Restore the saved dataset index and config history.

        Missing keys default to index 0 and an empty history.
        """
        self.saved_dataset_index = state.get('dataset_index', 0)
        self.all_dataset_configs = state.get('all_dataset_configs', [])