# OFA-OCR / fairseq / examples / laser / laser_src / multitask_data_utils.py
# JustinLin610's picture
# first commit
# ee21b96
# raw
# history blame
# 4.33 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
import numpy as np
from fairseq.data import BaseWrapperDataset, FairseqDataset, iterators
class MultiItr(object):
    """Interleave several sized iterators in proportion to their lengths.

    On every step the sub-iterator that has been drawn from the least,
    relative to its own length, is advanced. Ties go to the earliest
    sub-iterator in the sequence.

    Args:
        itr: an indexable sequence of iterators, each supporting ``len()``
            and ``next()``.
    """

    def __init__(self, itr):
        self.itr = itr
        # How many times each sub-iterator has been selected so far.
        self._counts = [0 for _ in itr]

    def __len__(self):
        """Return the combined length of all sub-iterators."""
        return sum(len(itr) for itr in self.itr)

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next item from the least-consumed sub-iterator.

        Raises:
            StopIteration: if there are no sub-iterators, or if the selected
                sub-iterator is exhausted.
        """
        ratios = []
        for count, itr in zip(self._counts, self.itr):
            size = len(itr)
            # An empty sub-iterator has nothing to offer: rank it last
            # instead of dividing by zero.
            ratios.append(count / size if size else float("inf"))
        if not ratios:
            # No sub-iterators at all: behave like an exhausted iterator.
            raise StopIteration
        idx = ratios.index(min(ratios))
        self._counts[idx] += 1
        return next(self.itr[idx])
class MultidatasetEpochBatchIterator(iterators.EpochBatchIterating):
    """Present one ``EpochBatchIterator`` per dataset as a single iterator.

    Args:
        dataset: non-empty ``OrderedDict`` mapping a key to a dataset; the
            first value is checked to be a ``FairseqDataset``.
        batch_sampler: mapping indexed by the same keys as ``dataset``; the
            entry for each key is handed to that dataset's iterator.
        seed, num_shards, shard_id, epoch: forwarded unchanged to every
            underlying ``EpochBatchIterator``.
        num_workers: accepted but not forwarded (see note in ``__init__``).
    """

    def __init__(
        self,
        dataset,
        batch_sampler,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
    ):
        assert isinstance(dataset, OrderedDict)
        assert len(dataset)
        first_dataset = next(iter(dataset.values()))
        assert isinstance(first_dataset, FairseqDataset)

        self.epoch = epoch
        # NOTE(review): the ``num_workers`` argument is ignored here and every
        # sub-iterator is built with 0 workers -- confirm this is intentional.
        self.iterators = [
            iterators.EpochBatchIterator(
                dataset=sub_dataset,
                collate_fn=sub_dataset.collater,
                batch_sampler=batch_sampler[name],
                seed=seed,
                num_shards=num_shards,
                shard_id=shard_id,
                num_workers=0,
                epoch=epoch,
            )
            for name, sub_dataset in dataset.items()
        ]

    def __len__(self):
        """Return the summed length of all underlying iterators."""
        return sum(map(len, self.iterators))

    def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False):
        """Return a ``MultiItr`` over each underlying iterator's next epoch."""
        # `self.epoch += 1` should be handled by underlying `EpochBatchIterator`s.
        per_dataset = []
        for epoch_iter in self.iterators:
            per_dataset.append(
                epoch_iter.next_epoch_itr(
                    shuffle=shuffle, fix_batches_to_gpus=fix_batches_to_gpus
                )
            )
        return MultiItr(per_dataset)

    def end_of_epoch(self):
        """Return True only if every underlying iterator reports end of epoch."""
        for epoch_iter in self.iterators:
            if not epoch_iter.end_of_epoch():
                return False
        return True

    @property
    def next_epoch_idx(self):
        """Return the epoch index after *next_epoch_itr* is called.

        Also stores that index in ``self.epoch`` and asserts that all
        underlying iterators agree on it.
        """
        upcoming = [epoch_iter.next_epoch_idx for epoch_iter in self.iterators]
        self.epoch = upcoming[0]
        assert all(value == self.epoch for value in upcoming)
        return self.epoch

    @property
    def iterations_in_epoch(self):
        """Return the summed ``iterations_in_epoch`` of all underlying iterators."""
        total = 0
        for epoch_iter in self.iterators:
            total += epoch_iter.iterations_in_epoch
        return total

    def state_dict(self):
        """Return the state of every underlying iterator plus the epoch."""
        iterator_states = [epoch_iter.state_dict() for epoch_iter in self.iterators]
        return {"iterators": iterator_states, "epoch": self.epoch}

    def load_state_dict(self, state_dict):
        """Restore the epoch and each underlying iterator's state, in order."""
        self.epoch = state_dict["epoch"]
        for epoch_iter, saved in zip(self.iterators, state_dict["iterators"]):
            epoch_iter.load_state_dict(saved)
class MultitaskDatasetWrapper(BaseWrapperDataset):
    """Wrap a dataset and tag its batches with a target language and a name.

    Args:
        dataset: the dataset being wrapped.
        target_language_id: value written to ``net_input["target_language_id"]``
            of every collated batch.
        sample: fraction of the ordered indices kept by ``ordered_indices``.
        name: value written to ``net_input["dataset_name"]`` of every batch.
    """

    def __init__(self, dataset, target_language_id, sample=1.0, name=""):
        super().__init__(dataset)
        self.name = name
        self.sample = sample
        self.target_language_id = target_language_id

    def collater(self, *args, **kwargs):
        """Collate with the wrapped dataset, then annotate ``net_input``."""
        batch = self.dataset.collater(*args, **kwargs)
        if "net_input" not in batch:
            # Nothing to annotate; hand the batch back untouched.
            return batch
        net_input = batch["net_input"]
        net_input["target_language_id"] = self.target_language_id
        net_input["dataset_name"] = self.name
        return batch

    def num_tokens(self, *args, **kwargs):
        """Delegate to the wrapped dataset."""
        return self.dataset.num_tokens(*args, **kwargs)

    def ordered_indices(self, *args, **kwargs):
        """Return a random subset of the wrapped dataset's ordered indices.

        ``int(self.sample * n)`` positions are drawn from a random permutation
        and sorted, so the kept indices stay in their original relative order.
        """
        base_order = self.dataset.ordered_indices(*args, **kwargs)
        total = base_order.shape[0]
        # Hacky solution for sampling
        keep = int(self.sample * total)
        positions = np.random.permutation(total)[:keep]
        return base_order.take(np.sort(positions))

    def size(self, index: int):
        """Delegate to the wrapped dataset."""
        return self.dataset.size(index)

    @property
    def supports_prefetch(self):
        """Whether this dataset supports prefetching."""
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        """Delegate to the wrapped dataset."""
        return self.dataset.prefetch(indices)