# Author: Cédric Colas — initial commit (e775f6d)
from torch.utils.data.sampler import RandomSampler, Sampler
import numpy as np
class FixedLenRandomSampler(RandomSampler):
    """Random batch sampler yielding a fixed number of batches per epoch.

    Samples without replacement *across* epochs: indices not served in one
    epoch are carried over and yielded first in the next epoch, so every
    dataset index is visited before any is repeated (with one caveat noted
    in ``__iter__``).

    Code from mnpinto - Miguel
    https://forums.fast.ai/t/epochs-of-arbitrary-length/27777/10

    Args:
        data_source: dataset to sample from (only its length is used).
        bs: batch size — number of indices per yielded batch.
        epoch_size: number of batches yielded per epoch.
    """
    def __init__(self, data_source, bs, epoch_size, *args, **kwargs):
        super().__init__(data_source)
        self.epoch_size = epoch_size
        self.bs = bs
        # not_sampled[i] is True while dataset index i has not yet been
        # served in the current pass over the data.
        self.not_sampled = np.array([True] * len(data_source))
        # Total number of indices consumed per epoch.
        self.size_to_sample = self.epoch_size * self.bs

    def _reset_state(self):
        """Mark every index as available again (start a new pass).

        Was previously a misused ``@property`` invoked for its side effect
        via a bare attribute access; now a regular method.
        """
        self.not_sampled[:] = True

    def __iter__(self):
        """Yield ``epoch_size`` batches of ``bs`` indices each."""
        remaining = int(sum(self.not_sampled))
        carried_over = []
        if remaining >= self.size_to_sample:
            # Enough unseen indices: draw the whole epoch from them.
            idx = np.random.choice(
                np.where(self.not_sampled)[0],
                size=self.size_to_sample,
                replace=False,
            ).tolist()
            if remaining == self.size_to_sample:
                self._reset_state()
        else:
            # Not enough unseen indices left: serve all of them first,
            # reset, then top up from the full pool.
            # NOTE: the top-up draw may repeat an index already in
            # `carried_over` within this same epoch.
            carried_over = np.where(self.not_sampled)[0].tolist()
            self._reset_state()
            idx = np.random.choice(
                np.where(self.not_sampled)[0],
                size=self.size_to_sample - len(carried_over),
                replace=False,
            ).tolist()
        self.not_sampled[idx] = False
        idx = [*carried_over, *idx]
        # Split the flat index list into epoch_size batches of bs each.
        batches = [idx[i * self.bs:(i + 1) * self.bs]
                   for i in range(self.epoch_size)]
        return iter(batches)

    def __len__(self):
        # Length is measured in batches, matching what __iter__ yields.
        return self.epoch_size