# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import datetime
import logging
import time
import torch
from fairseq.data import (
FairseqDataset,
LanguagePairDataset,
ListDataset,
data_utils,
iterators,
)
from fairseq.data.multilingual.multilingual_data_manager import (
MultilingualDatasetManager,
)
from fairseq.data.multilingual.sampling_method import SamplingMethod
from fairseq.tasks import LegacyFairseqTask, register_task
from fairseq.utils import FileContentsAction
###
def get_time_gap(s, e):
    # format the wall-clock time elapsed between two time.time() stamps
    # as a human-readable "H:MM:SS[.ffffff]" string
    return str(
        datetime.datetime.fromtimestamp(e) - datetime.datetime.fromtimestamp(s)
    )
###
logger = logging.getLogger(__name__)
@register_task("translation_multi_simple_epoch")
class TranslationMultiSimpleEpochTask(LegacyFairseqTask):
"""
Translate from one (source) language to another (target) language.
Args:
langs (List[str]): a list of languages that are being supported
dicts (Dict[str, fairseq.data.Dictionary]): mapping from supported languages to their dictionaries
training (bool): whether the task should be configured for training or not
.. note::
The translation task is compatible with :mod:`fairseq-train`,
:mod:`fairseq-generate` and :mod:`fairseq-interactive`.
The translation task provides the following additional command-line
arguments:
.. argparse::
:ref: fairseq.tasks.translation_parser
:prog:
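
    Example (illustrative; the task name and language flags below are defined
    in this file, while the data path and ``--arch`` value are placeholders)::

        fairseq-train data-bin/multilingual \
            --task translation_multi_simple_epoch \
            --lang-pairs en-de,en-fr \
            --arch transformer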
"""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
# fmt: off
parser.add_argument('-s', '--source-lang', default=None, metavar='SRC',
help='inference source language')
parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
help='inference target language')
parser.add_argument('--lang-pairs', default=None, metavar='PAIRS',
help='comma-separated list of language pairs (in training order): en-de,en-fr,de-fr',
action=FileContentsAction)
parser.add_argument('--keep-inference-langtok', action='store_true',
help='keep language tokens in inference output (e.g. for analysis or debugging)')
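        # most multilingual options (e.g. sampling method/temperature and
        # language-token handling) are contributed by the two helpers below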
SamplingMethod.add_arguments(parser)
MultilingualDatasetManager.add_args(parser)
# fmt: on
def __init__(self, args, langs, dicts, training):
super().__init__(args)
self.langs = langs
self.dicts = dicts
self.training = training
if training:
self.lang_pairs = args.lang_pairs
else:
self.lang_pairs = ["{}-{}".format(args.source_lang, args.target_lang)]
        # eval_lang_pairs for multilingual translation is usually all of the
        # lang_pairs. However, for other multitask settings, or when we want to
        # optimize for certain languages, we may want to use a different subset.
        # The eval_lang_pairs attribute is therefore provided for classes that
        # extend this class.
self.eval_lang_pairs = self.lang_pairs
        # model_lang_pairs will be used to build encoder-decoder model pairs in
        # models.build_model(). This allows multitask-style subclasses to
        # build models for pairs other than the input lang_pairs.
self.model_lang_pairs = self.lang_pairs
self.source_langs = [d.split("-")[0] for d in self.lang_pairs]
self.target_langs = [d.split("-")[1] for d in self.lang_pairs]
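        # e.g. lang_pairs ["en-de", "en-fr"] gives
        # source_langs ["en", "en"] and target_langs ["de", "fr"]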
self.check_dicts(self.dicts, self.source_langs, self.target_langs)
self.sampling_method = SamplingMethod.build_sampler(args, self)
self.data_manager = MultilingualDatasetManager.setup_data_manager(
args, self.lang_pairs, langs, dicts, self.sampling_method
)
def check_dicts(self, dicts, source_langs, target_langs):
if self.args.source_dict is not None or self.args.target_dict is not None:
# no need to check whether the source side and target side are sharing dictionaries
return
src_dict = dicts[source_langs[0]]
tgt_dict = dicts[target_langs[0]]
        for src_lang in source_langs:
            assert src_dict == dicts[src_lang], (
                "Different dictionaries are specified for different source languages; "
                "TranslationMultiSimpleEpochTask only supports one shared dictionary across all source languages"
            )
        for tgt_lang in target_langs:
            assert tgt_dict == dicts[tgt_lang], (
                "Different dictionaries are specified for different target languages; "
                "TranslationMultiSimpleEpochTask only supports one shared dictionary across all target languages"
            )
@classmethod
def setup_task(cls, args, **kwargs):
langs, dicts, training = MultilingualDatasetManager.prepare(
cls.load_dictionary, args, **kwargs
)
return cls(args, langs, dicts, training)
def has_sharded_data(self, split):
return self.data_manager.has_sharded_data(split)
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
        shard_epoch = None
        if split in self.datasets:
            dataset = self.datasets[split]
            if self.has_sharded_data(split):
                if self.args.virtual_epoch_size is not None:
                    if dataset.load_next_shard:
                        shard_epoch = dataset.shard_epoch
                    else:
                        # no need to load the next shard, so skip loading;
                        # this also avoids always reloading from the beginning of the data
                        return
                else:
                    shard_epoch = epoch
        else:
            # estimate the shard epoch from the virtual data size and virtual epoch size
            shard_epoch = self.data_manager.estimate_global_pass_epoch(epoch)
logger.info(f"loading data for {split} epoch={epoch}/{shard_epoch}")
logger.info(f"mem usage: {data_utils.get_mem_usage()}")
if split in self.datasets:
del self.datasets[split]
logger.info("old dataset deleted manually")
logger.info(f"mem usage: {data_utils.get_mem_usage()}")
self.datasets[split] = self.data_manager.load_dataset(
split,
self.training,
epoch=epoch,
combine=combine,
shard_epoch=shard_epoch,
**kwargs,
)
def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
if constraints is not None:
            raise NotImplementedError(
                "Constrained decoding with the translation_multi_simple_epoch task is not supported"
            )
src_data = ListDataset(src_tokens, src_lengths)
dataset = LanguagePairDataset(src_data, src_lengths, self.source_dictionary)
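        # "main" holds the language-token spec used for regular translation data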
src_langtok_spec, tgt_langtok_spec = self.args.langtoks["main"]
if self.args.lang_tok_replacing_bos_eos:
dataset = self.data_manager.alter_dataset_langtok(
dataset,
src_eos=self.source_dictionary.eos(),
src_lang=self.args.source_lang,
tgt_eos=self.target_dictionary.eos(),
tgt_lang=self.args.target_lang,
src_langtok_spec=src_langtok_spec,
tgt_langtok_spec=tgt_langtok_spec,
)
else:
dataset.src = self.data_manager.src_dataset_tranform_func(
self.args.source_lang,
self.args.target_lang,
dataset=dataset.src,
spec=src_langtok_spec,
)
return dataset
def build_generator(
self,
models,
args,
seq_gen_cls=None,
extra_gen_cls_kwargs=None,
):
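        # by default, strip the decoder language token from generated output;
        # --keep-inference-langtok retains it (e.g. for analysis or debugging)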
if not getattr(args, "keep_inference_langtok", False):
_, tgt_langtok_spec = self.args.langtoks["main"]
if tgt_langtok_spec:
tgt_lang_tok = self.data_manager.get_decoder_langtok(
self.args.target_lang, tgt_langtok_spec
)
extra_gen_cls_kwargs = extra_gen_cls_kwargs or {}
extra_gen_cls_kwargs["symbols_to_strip_from_output"] = {tgt_lang_tok}
        return super().build_generator(
            models, args, seq_gen_cls=seq_gen_cls, extra_gen_cls_kwargs=extra_gen_cls_kwargs
        )
def build_model(self, args, from_checkpoint=False):
return super().build_model(args, from_checkpoint)
def valid_step(self, sample, model, criterion):
loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
return loss, sample_size, logging_output
def inference_step(
self, generator, models, sample, prefix_tokens=None, constraints=None
):
with torch.no_grad():
_, tgt_langtok_spec = self.args.langtoks["main"]
if not self.args.lang_tok_replacing_bos_eos:
if prefix_tokens is None and tgt_langtok_spec:
tgt_lang_tok = self.data_manager.get_decoder_langtok(
self.args.target_lang, tgt_langtok_spec
)
src_tokens = sample["net_input"]["src_tokens"]
bsz = src_tokens.size(0)
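                    # force-decode the target language token as the first output
                    # symbol: shape (bsz, 1), matching src_tokens' device/dtype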
prefix_tokens = (
torch.LongTensor([[tgt_lang_tok]]).expand(bsz, 1).to(src_tokens)
)
return generator.generate(
models,
sample,
prefix_tokens=prefix_tokens,
constraints=constraints,
)
else:
return generator.generate(
models,
sample,
prefix_tokens=prefix_tokens,
bos_token=self.data_manager.get_decoder_langtok(
self.args.target_lang, tgt_langtok_spec
)
if tgt_langtok_spec
else self.target_dictionary.eos(),
)
def reduce_metrics(self, logging_outputs, criterion):
super().reduce_metrics(logging_outputs, criterion)
def max_positions(self):
"""Return the max sentence length allowed by the task."""
return (self.args.max_source_positions, self.args.max_target_positions)
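    # check_dicts enforces a single shared dictionary per side, so the first
    # language's dictionary stands in for all source/target languages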
@property
def source_dictionary(self):
return self.data_manager.get_source_dictionary(self.source_langs[0])
@property
def target_dictionary(self):
return self.data_manager.get_target_dictionary(self.target_langs[0])
def create_batch_sampler_func(
self,
max_positions,
ignore_invalid_inputs,
max_tokens,
max_sentences,
required_batch_size_multiple=1,
seed=1,
):
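        # the returned closure is called by EpochBatchIterator at the start of
        # each epoch, so batches are rebuilt (and reshuffled) every epoch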
def construct_batch_sampler(dataset, epoch):
            # recover the split name for logging, if this dataset is registered
            splits = [s for s, d in self.datasets.items() if d == dataset]
            split = splits[0] if len(splits) > 0 else None
# NEW implementation
if epoch is not None:
# initialize the dataset with the correct starting epoch
dataset.set_epoch(epoch)
# get indices ordered by example size
start_time = time.time()
logger.info(f"start batch sampler: mem usage: {data_utils.get_mem_usage()}")
with data_utils.numpy_seed(seed):
indices = dataset.ordered_indices()
logger.info(
f"[{split}] @batch_sampler order indices time: {get_time_gap(start_time, time.time())}"
)
logger.info(f"mem usage: {data_utils.get_mem_usage()}")
# filter examples that are too large
if max_positions is not None:
my_time = time.time()
indices = self.filter_indices_by_size(
indices, dataset, max_positions, ignore_invalid_inputs
)
logger.info(
f"[{split}] @batch_sampler filter_by_size time: {get_time_gap(my_time, time.time())}"
)
logger.info(f"mem usage: {data_utils.get_mem_usage()}")
# create mini-batches with given size constraints
my_time = time.time()
batch_sampler = dataset.batch_by_size(
indices,
max_tokens=max_tokens,
max_sentences=max_sentences,
required_batch_size_multiple=required_batch_size_multiple,
)
logger.info(
f"[{split}] @batch_sampler batch_by_size time: {get_time_gap(my_time, time.time())}"
)
logger.info(
f"[{split}] per epoch batch_sampler set-up time: {get_time_gap(start_time, time.time())}"
)
logger.info(f"mem usage: {data_utils.get_mem_usage()}")
return batch_sampler
return construct_batch_sampler
# we need to override get_batch_iterator because we want to reset the epoch iterator each time
def get_batch_iterator(
self,
dataset,
max_tokens=None,
max_sentences=None,
max_positions=None,
ignore_invalid_inputs=False,
required_batch_size_multiple=1,
seed=1,
num_shards=1,
shard_id=0,
num_workers=0,
epoch=1,
data_buffer_size=0,
disable_iterator_cache=False,
skip_remainder_batch=False,
grouped_shuffling=False,
update_epoch_batch_itr=False,
):
"""
Get an iterator that yields batches of data from the given dataset.
Args:
dataset (~fairseq.data.FairseqDataset): dataset to batch
max_tokens (int, optional): max number of tokens in each batch
(default: None).
max_sentences (int, optional): max number of sentences in each
batch (default: None).
max_positions (optional): max sentence length supported by the
model (default: None).
ignore_invalid_inputs (bool, optional): don't raise Exception for
sentences that are too long (default: False).
required_batch_size_multiple (int, optional): require batch size to
be a multiple of N (default: 1).
seed (int, optional): seed for random number generator for
reproducibility (default: 1).
num_shards (int, optional): shard the data iterator into N
shards (default: 1).
shard_id (int, optional): which shard of the data iterator to
return (default: 0).
num_workers (int, optional): how many subprocesses to use for data
loading. 0 means the data will be loaded in the main process
(default: 0).
            epoch (int, optional): the epoch to start the iterator from
                (default: 1).
            data_buffer_size (int, optional): number of batches to
                preload (default: 0).
            disable_iterator_cache (bool, optional): don't cache the
                EpochBatchIterator (ignores `FairseqTask::can_reuse_epoch_itr`)
                (default: False).
            skip_remainder_batch (bool, optional): if set, discard the last
                batch in each training epoch, as it is often smaller than the
                other batches (default: False).
            grouped_shuffling (bool, optional): group batches with each group
                containing num_shards batches and shuffle groups. Reduces
                difference between sequence lengths among workers for batches
                sorted by length.
            update_epoch_batch_itr (bool, optional): if true then do not use
                the cached batch iterator for the epoch
Returns:
~fairseq.iterators.EpochBatchIterator: a batched iterator over the
given dataset split
"""
# initialize the dataset with the correct starting epoch
assert isinstance(dataset, FairseqDataset)
if dataset in self.dataset_to_epoch_iter:
return self.dataset_to_epoch_iter[dataset]
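        # RoundRobin sampling can reuse the default cached iterator; other
        # sampling methods rebuild batches per epoch via construct_batch_sampler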
if self.args.sampling_method == "RoundRobin":
batch_iter = super().get_batch_iterator(
dataset,
max_tokens=max_tokens,
max_sentences=max_sentences,
max_positions=max_positions,
ignore_invalid_inputs=ignore_invalid_inputs,
required_batch_size_multiple=required_batch_size_multiple,
seed=seed,
num_shards=num_shards,
shard_id=shard_id,
num_workers=num_workers,
epoch=epoch,
data_buffer_size=data_buffer_size,
disable_iterator_cache=disable_iterator_cache,
skip_remainder_batch=skip_remainder_batch,
update_epoch_batch_itr=update_epoch_batch_itr,
)
self.dataset_to_epoch_iter[dataset] = batch_iter
return batch_iter
construct_batch_sampler = self.create_batch_sampler_func(
max_positions,
ignore_invalid_inputs,
max_tokens,
max_sentences,
required_batch_size_multiple=required_batch_size_multiple,
seed=seed,
)
epoch_iter = iterators.EpochBatchIterator(
dataset=dataset,
collate_fn=dataset.collater,
batch_sampler=construct_batch_sampler,
seed=seed,
num_shards=num_shards,
shard_id=shard_id,
num_workers=num_workers,
epoch=epoch,
)
return epoch_iter