Spaces:
Running
Running
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import BaseWrapperDataset | |
class ReplaceDataset(BaseWrapperDataset):
    """Replaces tokens found in the dataset by a specified replacement token

    Replacement happens in place on the items returned by the wrapped
    dataset: each item is sliced and the slice is updated with
    ``masked_fill_``. Items are therefore expected to support slicing,
    ``==`` comparison and ``masked_fill_`` (e.g. torch tensors, for which
    a slice is a view of the original storage).

    All entries of ``replace_map`` are matched against the values the item
    had *before* any replacement, so entries never chain: with
    ``{1: 2, 2: 3}`` an original ``1`` becomes ``2``, not ``3``.

    Args:
        dataset (~torch.utils.data.Dataset): dataset to replace tokens in
        replace_map (Dict[int, int]): map of token to replace -> replacement
            token. Must not be empty.
        offsets (List[int]): do not replace tokens before (from left if pos,
            right if neg) this offset. A non-negative offset ``o`` restricts
            replacement to ``src[o:]``; a negative offset restricts it to
            ``src[:o]``. Should be as many as the number of objects returned
            by the underlying dataset ``__getitem__`` method; offsets and
            objects are paired with ``zip``, so surplus entries on either
            side are silently ignored.
    """

    def __init__(self, dataset, replace_map, offsets):
        super().__init__(dataset)
        assert len(replace_map) > 0, "replace_map must not be empty"
        self.replace_map = replace_map
        self.offsets = offsets

    def __getitem__(self, index):
        """Return item ``index`` of the wrapped dataset with tokens replaced.

        A tuple item is processed element by element (paired with
        ``self.offsets``); any other item is treated as a single source.
        The object returned is the one produced by the wrapped dataset,
        modified in place.
        """
        item = self.dataset[index]
        is_tuple = isinstance(item, tuple)
        srcs = item if is_tuple else [item]

        for offset, src in zip(self.offsets, srcs):
            # The region to edit does not depend on the map entry, so slice
            # once per source rather than once per (source, entry) pair.
            window = src[offset:] if offset >= 0 else src[:offset]

            # Build every mask against the untouched values first; filling
            # while still comparing would let a later entry re-replace tokens
            # written by an earlier one.
            pending = [
                (window == old_token, new_token)
                for old_token, new_token in self.replace_map.items()
            ]
            for mask, new_token in pending:
                window.masked_fill_(mask, new_token)

        # `srcs` is either `item` itself (tuple case) or a one-element list
        # wrapping it, so `item` already refers to the updated object(s).
        return item