Spaces:
Runtime error
Runtime error
from __future__ import division | |
import torch | |
from torch.utils.data import DataLoader | |
from torch.utils.data.sampler import Sampler | |
class RandomSampler(Sampler):
    """Sample dataset indices in a random order, resumable from a checkpoint.

    ``dataset_perm`` holds the full permutation for the epoch. ``perm`` is the
    part of it that is actually iterated (all of it on a fresh start, or the
    unconsumed tail when resuming).

    Args:
        data_source: Dataset being sampled; only ``len(data_source)`` is used.
        checkpoint: ``None`` to start a fresh epoch, or a mapping with keys
            ``'dataset_perm'``, ``'batch_size'`` and ``'batch_idx'``. When
            ``checkpoint['dataset_perm']`` is not ``None``, iteration resumes at
            index ``batch_size * batch_idx`` of that saved permutation;
            otherwise a new random permutation is drawn.
    """

    def __init__(self, data_source, checkpoint):
        self.data_source = data_source
        if checkpoint is not None and checkpoint['dataset_perm'] is not None:
            self.dataset_perm = checkpoint['dataset_perm']
            # Skip the samples already consumed by the first batch_idx batches.
            self.perm = self.dataset_perm[checkpoint['batch_size'] * checkpoint['batch_idx']:]
        else:
            self.dataset_perm = torch.randperm(len(self.data_source)).tolist()
            # Iterate the very same permutation exposed as ``dataset_perm``.
            # On resume the order is rebuilt from checkpoint['dataset_perm']
            # (presumably saved from this attribute by the caller -- confirm),
            # so drawing a second, independent permutation here would make a
            # resumed run skip/repeat the wrong samples.
            self.perm = self.dataset_perm

    def __iter__(self):
        return iter(self.perm)

    def __len__(self):
        return len(self.perm)
class SequentialSampler(Sampler):
    """Yield dataset indices in natural order, optionally resuming mid-epoch.

    Args:
        data_source: Dataset being indexed; only its ``len()`` is consulted.
        checkpoint: ``None`` for a fresh pass, or a mapping providing
            ``'dataset_perm'``, ``'batch_size'`` and ``'batch_idx'``. If the
            saved ``'dataset_perm'`` is not ``None``, it replaces the natural
            order and its first ``batch_size * batch_idx`` entries are skipped.

    Attributes:
        dataset_perm: Complete index order for the epoch.
        perm: Portion of ``dataset_perm`` that remains to be yielded.
    """

    def __init__(self, data_source, checkpoint):
        self.data_source = data_source
        fresh_start = checkpoint is None or checkpoint['dataset_perm'] is None
        if fresh_start:
            # Natural order 0..len-1; perm aliases the same list object.
            order = list(range(len(data_source)))
            self.dataset_perm = order
            self.perm = order
            return
        self.dataset_perm = checkpoint['dataset_perm']
        consumed = checkpoint['batch_size'] * checkpoint['batch_idx']
        self.perm = self.dataset_perm[consumed:]

    def __len__(self):
        return len(self.perm)

    def __iter__(self):
        return iter(self.perm)
class CheckpointDataLoader(DataLoader):
    """
    Extends torch.utils.data.DataLoader to handle resuming training from an arbitrary point within an epoch.

    Args:
        dataset: Dataset to load from.
        checkpoint: ``None`` to start at the beginning of an epoch, or a
            mapping with ``'dataset_perm'``, ``'batch_size'`` and
            ``'batch_idx'``; it is handed to the sampler so iteration can
            resume mid-epoch.
        batch_size, num_workers, pin_memory, drop_last, timeout,
        worker_init_fn: Forwarded unchanged to ``DataLoader``.
        shuffle: Selects ``RandomSampler`` (True) or ``SequentialSampler``
            (False). ``DataLoader``'s own ``shuffle`` is always passed as
            False because an explicit sampler is supplied.

    Attributes:
        checkpoint_batch_idx: ``checkpoint['batch_idx']`` when a checkpoint is
            given, otherwise 0.
    """

    def __init__(
        self,
        dataset,
        checkpoint=None,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        pin_memory=False,
        drop_last=True,
        timeout=0,
        worker_init_fn=None
    ):
        if shuffle:
            sampler = RandomSampler(dataset, checkpoint)
        else:
            sampler = SequentialSampler(dataset, checkpoint)
        if checkpoint is not None:
            # NOTE(review): this is set even when checkpoint['dataset_perm'] is
            # None, in which case the sampler starts from the beginning of the
            # epoch -- confirm callers expect that combination.
            self.checkpoint_batch_idx = checkpoint['batch_idx']
        else:
            self.checkpoint_batch_idx = 0
        super(CheckpointDataLoader, self).__init__(
            dataset,
            sampler=sampler,
            # Ordering is fully controlled by the sampler above.
            shuffle=False,
            batch_size=batch_size,
            num_workers=num_workers,
            drop_last=drop_last,
            pin_memory=pin_memory,
            timeout=timeout,
            # Previously hard-coded to None, silently dropping the caller's hook.
            worker_init_fn=worker_init_fn
        )