Spaces:
Running
Running
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.data import data_utils | |
from . import BaseWrapperDataset | |
class PadDataset(BaseWrapperDataset):
    """Wrapper dataset whose collater delegates to ``data_utils.collate_tokens``.

    The class only stores the padding configuration and hands it to
    ``collate_tokens`` when a batch is collated.

    Args:
        dataset: dataset to wrap; passed unchanged to the base class.
        pad_idx: forwarded to ``collate_tokens`` as its second positional
            argument (presumably the padding symbol index -- confirm against
            ``collate_tokens``).
        left_pad: forwarded to ``collate_tokens`` as ``left_pad``.
        pad_length: forwarded to ``collate_tokens`` as ``pad_to_length``.
            Defaults to ``None``.
    """

    def __init__(self, dataset, pad_idx, left_pad, pad_length=None):
        super().__init__(dataset)
        # Plain attribute storage; nothing is validated here.
        self.pad_length = pad_length
        self.left_pad = left_pad
        self.pad_idx = pad_idx

    def collater(self, samples):
        """Collate ``samples`` with the stored padding configuration.

        Returns whatever ``data_utils.collate_tokens`` returns.
        """
        padding_options = {
            "left_pad": self.left_pad,
            "pad_to_length": self.pad_length,
        }
        return data_utils.collate_tokens(samples, self.pad_idx, **padding_options)
class LeftPadDataset(PadDataset):
    """``PadDataset`` preset that always passes ``left_pad=True``.

    Args:
        dataset: dataset to wrap.
        pad_idx: forwarded to ``PadDataset`` as ``pad_idx``.
        pad_length: optional value forwarded to ``PadDataset`` as
            ``pad_length``. Defaults to ``None``, the same default
            ``PadDataset`` itself uses, so existing two-argument callers are
            unaffected.
    """

    def __init__(self, dataset, pad_idx, pad_length=None):
        # Expose the parent's pad_length option instead of silently fixing it
        # to None for every user of this convenience class.
        super().__init__(dataset, pad_idx, left_pad=True, pad_length=pad_length)
class RightPadDataset(PadDataset):
    """``PadDataset`` preset that always passes ``left_pad=False``.

    Args:
        dataset: dataset to wrap.
        pad_idx: forwarded to ``PadDataset`` as ``pad_idx``.
        pad_length: optional value forwarded to ``PadDataset`` as
            ``pad_length``. Defaults to ``None``, the same default
            ``PadDataset`` itself uses, so existing two-argument callers are
            unaffected.
    """

    def __init__(self, dataset, pad_idx, pad_length=None):
        # Expose the parent's pad_length option instead of silently fixing it
        # to None for every user of this convenience class.
        super().__init__(dataset, pad_idx, left_pad=False, pad_length=pad_length)