File size: 3,678 Bytes
d5175d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import numpy as np
import torch
from fairseq.data import Dictionary, FairseqDataset
from fairseq.tasks import LegacyFairseqTask, register_task
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@register_task("dummy_mt")
class DummyMTTask(LegacyFairseqTask):
    """Translation task, registered as ``dummy_mt``, backed by synthetic data.

    No corpus is read: the dictionary is filled with generated ``word<i>``
    symbols and every split is a :class:`DummyDataset` wrapping one pre-built
    batch of consecutive token ids. Source and target share one dictionary.
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        # Number of generated "word<i>" symbols added to the dictionary.
        parser.add_argument("--dict-size", default=49996, type=int)
        # Passed to DummyDataset as the number of items it reports.
        parser.add_argument("--dataset-size", default=100000, type=int)
        # The dummy sequences built in __init__ hold ``len + 1`` tokens each.
        parser.add_argument("--src-len", default=30, type=int)
        parser.add_argument("--tgt-len", default=30, type=int)

    def __init__(self, args, dictionary):
        """Store the dictionary and build the source/target token templates.

        Args:
            args: parsed arguments; ``seed``, ``src_len`` and ``tgt_len`` are
                read here.
            dictionary: dictionary used for both source and target;
                ``pad_to_multiple_(8)`` is called on it.
        """
        super().__init__(args)
        self.dictionary = dictionary
        self.seed = args.seed

        dictionary.pad_to_multiple_(8)  # often faster if divisible by 8

        # Consecutive ids starting just above the pad index, so the templates
        # never contain the pad index itself.
        first_id = dictionary.pad() + 1
        self.dummy_src = torch.arange(args.src_len + 1) + first_id
        self.dummy_tgt = torch.arange(args.tgt_len + 1) + first_id

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Build the synthetic dictionary and return a task instance.

        Adds ``args.dict_size`` symbols named ``word0``, ``word1``, ... to a
        fresh :class:`Dictionary` and sets ``args.max_source_positions`` and
        ``args.max_target_positions`` from the configured lengths.
        """
        dictionary = Dictionary()
        for i in range(args.dict_size):
            dictionary.add_symbol("word{}".format(i))
        # Lazy %-args: the message is only formatted if the record is emitted.
        logger.info("dictionary: %d types", len(dictionary))

        # Same offset (pad index + 2) is applied on both sides.
        position_slack = dictionary.pad() + 2
        args.max_source_positions = args.src_len + position_slack
        args.max_target_positions = args.tgt_len + position_slack

        return cls(args, dictionary)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Every split is a :class:`DummyDataset` wrapping one pre-built batch.
        The batch size is ``args.batch_size`` when it is set; otherwise it is
        ``args.max_tokens // max(src_len, tgt_len)``, but at least 1.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch, combine, **kwargs: accepted but not used here.
        """
        item_size = max(self.args.src_len, self.args.tgt_len)
        if self.args.batch_size is not None:
            bsz = self.args.batch_size
        else:
            bsz = max(1, self.args.max_tokens // item_size)

        src = torch.stack([self.dummy_src] * bsz)
        tgt = torch.stack([self.dummy_tgt] * bsz)

        self.datasets[split] = DummyDataset(
            {
                "id": 1,
                "net_input": {
                    "src_tokens": src,
                    # Taken from the tensor itself: the templates hold
                    # ``src_len + 1`` tokens, so reporting ``args.src_len``
                    # would disagree with the width of ``src_tokens``.
                    "src_lengths": torch.full(
                        (bsz,), src.size(1), dtype=torch.long
                    ),
                    "prev_output_tokens": tgt.clone(),
                },
                "target": tgt,
                "nsentences": bsz,
                # Count the target tokens actually present in the batch
                # (``bsz * (tgt_len + 1)``) rather than ``bsz * tgt_len``.
                "ntokens": tgt.numel(),
            },
            num_items=self.args.dataset_size,
            item_size=item_size,
        )

    @property
    def source_dictionary(self):
        """Dictionary used for the source side (shared with the target)."""
        return self.dictionary

    @property
    def target_dictionary(self):
        """Dictionary used for the target side (shared with the source)."""
        return self.dictionary
class DummyDataset(FairseqDataset):
    """Dataset that hands back one fixed, pre-built batch.

    Args:
        batch: object returned unchanged by :meth:`collater`.
        num_items: number of items reported by ``len()``.
        item_size: size reported for every item.
    """

    def __init__(self, batch, num_items, item_size):
        super().__init__()
        self.batch = batch
        self.num_items = num_items
        self.item_size = item_size

    def __getitem__(self, index):
        """Return ``index`` itself; no per-item data is stored."""
        return index

    def __len__(self):
        """Return the configured number of items."""
        return self.num_items

    def collater(self, samples):
        """Ignore ``samples`` and return the stored batch."""
        return self.batch

    @property
    def sizes(self):
        """Array of length ``num_items`` whose entries all equal ``item_size``."""
        # np.full allocates the array directly instead of first building a
        # num_items-long Python list on every access.
        return np.full(self.num_items, self.item_size)

    def num_tokens(self, index):
        """Return ``item_size`` regardless of ``index``."""
        return self.item_size

    def size(self, index):
        """Return ``item_size`` regardless of ``index``."""
        return self.item_size

    def ordered_indices(self):
        """Return the indices ``0 .. num_items - 1`` in ascending order."""
        return np.arange(self.num_items)

    @property
    def supports_prefetch(self):
        """Always ``False``."""
        return False
|