# Scraped-page header removed (non-code artifact: "Spaces: Runtime error" banner text).
from torch.utils.data.sampler import RandomSampler, Sampler
import numpy as np
class FixedLenRandomSampler(RandomSampler):
    """Random sampler that yields a fixed number of batches per epoch.

    Indices are drawn without replacement *across* epochs: an index already
    served is excluded until the whole data source has been seen, at which
    point the bookkeeping mask resets and sampling starts over.

    Code adapted from mnpinto - Miguel
    https://forums.fast.ai/t/epochs-of-arbitrary-length/27777/10

    Args:
        data_source: sized dataset to draw indices from.
        bs: batch size (number of indices per batch).
        epoch_size: number of batches yielded per epoch (also ``len(self)``).
    """
    def __init__(self, data_source, bs, epoch_size, *args, **kwargs):
        super().__init__(data_source)
        self.epoch_size = epoch_size
        self.bs = bs
        # Boolean mask: True means the index has NOT yet been sampled in the
        # current pass over the data source.
        self.not_sampled = np.array([True] * len(data_source))
        # Total number of indices consumed per epoch.
        self.size_to_sample = self.epoch_size * self.bs

    def _reset_state(self):
        """Mark every index as available again (start a fresh pass)."""
        self.not_sampled[:] = True

    def __iter__(self):
        ns = int(np.sum(self.not_sampled))  # indices still unseen
        idx_last = []
        if ns >= self.size_to_sample:
            idx = np.random.choice(
                np.where(self.not_sampled)[0],
                size=self.size_to_sample,
                replace=False,
            ).tolist()
            if ns == self.size_to_sample:
                # Pool exactly exhausted: reset before marking below so the
                # next epoch starts a fresh pass.
                # BUG FIX: the original referenced `self._reset_state` without
                # calling it (missing parentheses), so the mask was never
                # reset and the next epoch crashed in np.random.choice with
                # an empty population.
                self._reset_state()
        else:
            # Not enough unseen indices left: take them all, reset the mask,
            # and top up the epoch from the freshly reset pool.
            # NOTE(review): after the reset the top-up can re-draw indices
            # already in idx_last, so duplicates within one epoch are
            # possible here — behavior kept as in the original.
            idx_last = np.where(self.not_sampled)[0].tolist()
            self._reset_state()
            idx = np.random.choice(
                np.where(self.not_sampled)[0],
                size=self.size_to_sample - len(idx_last),
                replace=False,
            ).tolist()
        self.not_sampled[idx] = False
        idx = [*idx_last, *idx]
        # Chunk the flat index list into epoch_size batches of bs indices.
        out = [idx[i * self.bs:(i + 1) * self.bs] for i in range(self.epoch_size)]
        return iter(out)

    def __len__(self):
        # Length is measured in batches (not samples), matching DataLoader
        # semantics for a batch sampler.
        return self.epoch_size