|
"""Enable curriculum learning by resuming with a different dataset. |
|
|
|
This callback is currently experimental. The API may change without warning in |
|
the future. |
|
""" |
|
import logging |
|
from typing import Any, Dict |
|
from streaming import StreamingDataset |
|
from torch.utils.data import DataLoader |
|
from .interfaces import CallbackWithConfig |
|
from .warnings import experimental_class |
|
log = logging.getLogger(__name__) |
|
|
|
@experimental_class('CurriculumLearning callback')
class CurriculumLearning(CallbackWithConfig):
    """Starts an epoch with a different dataset when resuming from a checkpoint.

    Args:
        dataset_index (int): The index of the dataset currently being used.
        train_config (Dict): The training configuration. Its ``'dataloader'``
            entry is stored as the configuration of the dataset currently
            being used.
    """

    def __init__(self, dataset_index: int, train_config: Dict):
        # Index of the dataset this run is configured to train on.
        self.dataset_index = dataset_index
        # Index restored by ``load_state_dict``; stays 0 if nothing is loaded.
        self.saved_dataset_index = 0
        # Dataloader configs recorded so far (persisted via ``state_dict``).
        self.all_dataset_configs = []
        # Dataset state captured in ``before_load`` and re-applied in
        # ``after_load`` when a new dataset is being started.
        self.current_dataset_state = {}
        self.current_dataset_config = train_config['dataloader']

    # NOTE(review): ``State`` and ``Logger`` are not imported in the visible
    # part of this module, so the annotations below are quoted to keep class
    # creation from raising ``NameError``. Presumably these are Composer's
    # ``State`` and ``Logger`` -- confirm and import them at the top of the file.
    def before_load(self, state: 'State', logger: 'Logger') -> None:
        """Validate the train dataloader and snapshot its dataset state.

        Args:
            state (State): Run state; ``state.train_dataloader`` is inspected.
            logger (Logger): Unused.

        Raises:
            ValueError: If the train dataloader is not a ``DataLoader`` or its
                dataset is not a ``StreamingDataset``.
        """
        del logger  # unused

        train_loader = state.train_dataloader
        if not isinstance(train_loader, DataLoader):
            # Adjacent literals are concatenated into ONE message; separating
            # them with commas would hand ValueError a tuple of fragments.
            raise ValueError(
                'CurriculumLearning callback can only be used with a train '
                f'dataloader of type DataLoader, but got {type(train_loader)}.'
            )

        dataset = train_loader.dataset
        if not isinstance(dataset, StreamingDataset):
            raise ValueError(
                'CurriculumLearning callback only supports StreamingDataset '
                'because it requires loading and saving dataset state. '
                f'Instead, got a dataset of type {type(dataset)}'
            )

        # Remember the dataset's own state so ``after_load`` can put it back.
        self.current_dataset_state = dataset.state_dict(
            num_samples=0,
            from_beginning=False,
        )

    def after_load(self, state: 'State', logger: 'Logger') -> None:
        """Switch to the new dataset, or record the first dataset config.

        If the dataset index restored by ``load_state_dict`` is lower than the
        configured ``dataset_index``, the dataset state captured in
        ``before_load`` is loaded back into the dataset, the timestamp is
        advanced to the next epoch, and the current dataloader config is
        appended to ``all_dataset_configs``. Otherwise, if this is dataset 0
        and no config has been recorded yet, the current config is recorded.

        Args:
            state (State): Run state; its train dataloader is read and its
                ``timestamp`` may be replaced.
            logger (Logger): Unused.
        """
        del logger  # unused

        # NOTE(review): this reads the private ``_train_dataloader`` while
        # ``before_load`` uses the public ``train_dataloader`` -- confirm
        # whether the difference is intentional.
        train_loader = state._train_dataloader
        assert isinstance(
            train_loader,
            DataLoader,
        ), 'CurriculumLearning callback requires a DataLoader.'
        dataset = train_loader.dataset
        assert isinstance(
            dataset,
            StreamingDataset,
        ), 'CurriculumLearning callback requires a StreamingDataset.'

        if self.saved_dataset_index < self.dataset_index:
            # A new dataset is being started: re-apply the state captured in
            # ``before_load``, clamping a negative epoch to 0 first.
            if self.current_dataset_state['epoch'] < 0:
                self.current_dataset_state['epoch'] = 0
            dataset.load_state_dict(self.current_dataset_state)
            # Begin a fresh epoch for the new dataset.
            state.timestamp = state.timestamp.to_next_epoch()
            self.all_dataset_configs.append(self.current_dataset_config)
        elif self.dataset_index == 0 and not self.all_dataset_configs:
            # First dataset and nothing recorded yet: track its config.
            self.all_dataset_configs.append(self.current_dataset_config)

    def state_dict(self) -> Dict[str, Any]:
        """Return the callback state to persist.

        Returns:
            Dict[str, Any]: ``'dataset_index'`` and ``'all_dataset_configs'``.
        """
        return {
            'dataset_index': self.dataset_index,
            'all_dataset_configs': self.all_dataset_configs,
        }

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        """Restore the saved dataset index and the recorded dataset configs.

        Missing keys fall back to ``0`` and ``[]`` respectively.

        Args:
            state (Dict[str, Any]): A mapping as produced by ``state_dict``.
        """
        self.saved_dataset_index = state.get('dataset_index', 0)
        self.all_dataset_configs = state.get('all_dataset_configs', [])