|
import numpy as np |
|
from fairseq.data import FairseqDataset |
|
|
|
|
|
class DummyDataset(FairseqDataset):
    """Synthetic dataset whose items are their own indices.

    Every item reports the same size (``item_size``). ``collater`` ignores
    the samples it is given and always returns the pre-built ``batch``
    object, so batching does no real work (presumably intended for
    benchmarking/testing — confirm against callers).

    Args:
        batch: object returned verbatim by :meth:`collater` for every batch.
        num_items: number of items in the dataset.
        item_size: size / token count reported for every item.
    """

    def __init__(self, batch, num_items, item_size):
        super().__init__()
        self.batch = batch
        self.num_items = num_items
        self.item_size = item_size

    def __getitem__(self, index):
        """Return ``index`` itself as the item.

        Raises:
            IndexError: if ``index`` is outside ``[0, num_items)``. Besides
                rejecting invalid indices, this lets Python's fallback
                sequence-iteration protocol (used when no class in the MRO
                defines ``__iter__``) terminate instead of looping forever.
        """
        if not 0 <= index < self.num_items:
            raise IndexError(
                f"index {index} out of range for dataset of size {self.num_items}"
            )
        return index

    def __len__(self):
        """Return the number of items, ``num_items``."""
        return self.num_items

    def collater(self, samples):
        """Return the fixed ``batch`` object; ``samples`` is ignored."""
        return self.batch

    @property
    def sizes(self):
        """1-D array of length ``num_items`` with every entry ``item_size``."""
        # np.full avoids materialising a Python list on every access and,
        # unlike np.array([item_size] * 0), does not fall back to a float64
        # dtype when the dataset is empty.
        return np.full(self.num_items, self.item_size)

    def num_tokens(self, index):
        """Return ``item_size``; every item has the same token count."""
        return self.item_size

    def size(self, index):
        """Return ``item_size``; every item has the same size."""
        return self.item_size

    def ordered_indices(self):
        """Return the indices ``0 .. num_items - 1`` in natural order."""
        return np.arange(self.num_items)

    @property
    def supports_prefetch(self):
        """Always ``False``: this dataset has nothing to prefetch."""
        return False
|
|