import numpy as np
from torch.utils.data.sampler import RandomSampler


class FixedLenRandomSampler(RandomSampler):
    """Random sampler that yields a fixed number of batches per epoch,
    drawing without replacement across epochs until the dataset is exhausted.

    Code from mnpinto - Miguel
    https://forums.fast.ai/t/epochs-of-arbitrary-length/27777/10
    """
    def __init__(self, data_source, bs, epoch_size, *args, **kwargs):
        super().__init__(data_source)
        self.epoch_size = epoch_size  # number of batches per epoch
        self.bs = bs                  # batch size
        # Mask of items not yet drawn in the current pass over the dataset.
        self.not_sampled = np.array([True] * len(data_source))
        self.size_to_sample = self.epoch_size * self.bs

    def _reset_state(self):
        """Mark every item as available again."""
        self.not_sampled[:] = True

    def __iter__(self):
        ns = int(self.not_sampled.sum())
        idx_last = []
        if ns >= self.size_to_sample:
            # Enough unseen items remain: draw the whole epoch from them.
            idx = np.random.choice(np.where(self.not_sampled)[0],
                                   size=self.size_to_sample,
                                   replace=False).tolist()
            self.not_sampled[idx] = False
            if ns == self.size_to_sample:
                # Pool exactly exhausted: start a fresh pass next epoch.
                self._reset_state()
        else:
            # Not enough left: take the leftovers, reset, and top up from
            # the fresh pool (so a leftover item may be drawn again within
            # this epoch).
            idx_last = np.where(self.not_sampled)[0].tolist()
            self._reset_state()
            idx = np.random.choice(np.where(self.not_sampled)[0],
                                   size=self.size_to_sample - len(idx_last),
                                   replace=False).tolist()
            self.not_sampled[idx] = False
        idx = [*idx_last, *idx]
        # Group the flat index list into epoch_size batches of bs indices each.
        out = [idx[i * self.bs:(i + 1) * self.bs]
               for i in range(self.epoch_size)]
        return iter(out)

    def __len__(self):
        # Length is the number of batches yielded per epoch.
        return self.epoch_size
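
# Minimal usage sketch (not part of the original snippet; the toy
# TensorDataset and sizes below are hypothetical). Because __iter__ yields
# whole batches of indices rather than single indices, the sampler is
# passed to DataLoader via the `batch_sampler` argument, and `not_sampled`
# persists across epochs so the without-replacement pool carries over.
if __name__ == "__main__":
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    ds = TensorDataset(torch.arange(100).float().unsqueeze(1),
                       torch.zeros(100))
    sampler = FixedLenRandomSampler(ds, bs=8, epoch_size=5)
    dl = DataLoader(ds, batch_sampler=sampler)  # len(dl) == epoch_size

    for epoch in range(3):
        for xb, yb in dl:
            pass  # 5 batches of 8 per epoch; pool resets only on exhaustion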