import numpy as np
from fairseq.data import FairseqDataset


class DummyDataset(FairseqDataset):
    """Synthetic dataset that hands back one fixed, pre-built batch.

    Every item reports the same size (``item_size``) and the collater
    ignores the samples it is given, always returning the ``batch`` object
    supplied at construction time.

    Args:
        batch: the object returned verbatim by :meth:`collater`.
        num_items: number of items the dataset reports via ``len()``.
        item_size: size reported for every item (assumed to be a scalar,
            since it is returned as-is by :meth:`num_tokens` and
            :meth:`size`).
    """

    def __init__(self, batch, num_items, item_size):
        super().__init__()
        self.batch = batch
        self.num_items = num_items
        self.item_size = item_size

    def __getitem__(self, index):
        """Return ``index`` itself; no per-item data is stored."""
        return index

    def __len__(self):
        """Return the configured number of items."""
        return self.num_items

    def collater(self, samples):
        """Return the fixed batch, ignoring ``samples`` entirely."""
        return self.batch

    @property
    def sizes(self):
        """Array of length ``num_items`` with every entry ``item_size``."""
        # np.full allocates the array directly instead of first building a
        # num_items-long Python list on every access, and it keeps the
        # dtype of ``item_size`` even when the dataset is empty.
        return np.full(self.num_items, self.item_size)

    def num_tokens(self, index):
        """Return ``item_size`` regardless of ``index``."""
        return self.item_size

    def size(self, index):
        """Return ``item_size`` regardless of ``index``."""
        return self.item_size

    def ordered_indices(self):
        """Return the indices ``0 .. num_items - 1`` in ascending order."""
        return np.arange(self.num_items)

    @property
    def supports_prefetch(self):
        """Always ``False``."""
        return False