# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import datetime
import logging
import time

import torch

from fairseq.data import (
    FairseqDataset,
    LanguagePairDataset,
    ListDataset,
    data_utils,
    iterators,
)
from fairseq.data.multilingual.multilingual_data_manager import (
    MultilingualDatasetManager,
)
from fairseq.data.multilingual.sampling_method import SamplingMethod
from fairseq.tasks import LegacyFairseqTask, register_task
from fairseq.utils import FileContentsAction
###
def get_time_gap(s, e):
    return (
        datetime.datetime.fromtimestamp(e) - datetime.datetime.fromtimestamp(s)
    ).__str__()
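# Illustrative usage (not part of the original module): get_time_gap expects two
# POSIX timestamps and returns the elapsed time formatted as a timedelta string:
#   get_time_gap(0.0, 90.5)  ->  "0:01:30.500000"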
###


logger = logging.getLogger(__name__)

@register_task("translation_multi_simple_epoch")
class TranslationMultiSimpleEpochTask(LegacyFairseqTask):
    """
    Translate from one (source) language to another (target) language.

    Args:
        langs (List[str]): a list of languages that are being supported
        dicts (Dict[str, fairseq.data.Dictionary]): mapping from supported languages to their dictionaries
        training (bool): whether the task should be configured for training or not

    .. note::

        The translation task is compatible with :mod:`fairseq-train`,
        :mod:`fairseq-generate` and :mod:`fairseq-interactive`.

    The translation task provides the following additional command-line
    arguments:

    .. argparse::
        :ref: fairseq.tasks.translation_parser
        :prog:
    """
    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('-s', '--source-lang', default=None, metavar='SRC',
                            help='inference source language')
        parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
                            help='inference target language')
        parser.add_argument('--lang-pairs', default=None, metavar='PAIRS',
                            help='comma-separated list of language pairs (in training order): en-de,en-fr,de-fr',
                            action=FileContentsAction)
        parser.add_argument('--keep-inference-langtok', action='store_true',
                            help='keep language tokens in inference output (e.g. for analysis or debugging)')

        SamplingMethod.add_arguments(parser)
        MultilingualDatasetManager.add_args(parser)
        # fmt: on
    def __init__(self, args, langs, dicts, training):
        super().__init__(args)
        self.langs = langs
        self.dicts = dicts
        self.training = training
        if training:
            self.lang_pairs = args.lang_pairs
        else:
            self.lang_pairs = ["{}-{}".format(args.source_lang, args.target_lang)]
        # eval_lang_pairs for multilingual translation is usually all of the
        # lang_pairs. However, for other multitask settings, or when we want to
        # optimize for certain languages, we want to use a different subset. Thus
        # the eval_lang_pairs class variable is provided for classes that extend
        # this class.
        self.eval_lang_pairs = self.lang_pairs
        # model_lang_pairs will be used to build encoder-decoder model pairs in
        # models.build_model(). This allows multitask sub-classes to build models
        # for language pairs other than the input lang_pairs.
        self.model_lang_pairs = self.lang_pairs
        self.source_langs = [d.split("-")[0] for d in self.lang_pairs]
        self.target_langs = [d.split("-")[1] for d in self.lang_pairs]
        self.check_dicts(self.dicts, self.source_langs, self.target_langs)
        self.sampling_method = SamplingMethod.build_sampler(args, self)
        self.data_manager = MultilingualDatasetManager.setup_data_manager(
            args, self.lang_pairs, langs, dicts, self.sampling_method
        )
    def check_dicts(self, dicts, source_langs, target_langs):
        if self.args.source_dict is not None or self.args.target_dict is not None:
            # no need to check whether the source side and target side are sharing dictionaries
            return
        src_dict = dicts[source_langs[0]]
        tgt_dict = dicts[target_langs[0]]
        for src_lang in source_langs:
            assert src_dict == dicts[src_lang], (
                "Different dictionaries are specified for different source languages; "
                "TranslationMultiSimpleEpochTask only supports one shared dictionary across all source languages"
            )
        for tgt_lang in target_langs:
            assert tgt_dict == dicts[tgt_lang], (
                "Different dictionaries are specified for different target languages; "
                "TranslationMultiSimpleEpochTask only supports one shared dictionary across all target languages"
            )
    @classmethod
    def setup_task(cls, args, **kwargs):
        langs, dicts, training = MultilingualDatasetManager.prepare(
            cls.load_dictionary, args, **kwargs
        )
        return cls(args, langs, dicts, training)
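    # Hedged sketch (not in the original source): with the legacy task API this
    # class is usually instantiated indirectly, roughly as
    #   from fairseq import tasks
    #   task = tasks.setup_task(args)   # dispatches to setup_task() via --task
    #   task.load_dataset("valid")
    # where `args` is the argparse namespace produced by fairseq's CLI parser.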
    def has_sharded_data(self, split):
        return self.data_manager.has_sharded_data(split)
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        shard_epoch = None
        if split in self.datasets:
            dataset = self.datasets[split]
            if self.has_sharded_data(split):
                if self.args.virtual_epoch_size is not None:
                    if dataset.load_next_shard:
                        shard_epoch = dataset.shard_epoch
                    else:
                        # no need to load the next shard, so skip loading;
                        # this also avoids always reloading from the beginning of the data
                        return
                else:
                    shard_epoch = epoch
        else:
            # estimate the shard epoch from the virtual data size and virtual epoch size
            shard_epoch = self.data_manager.estimate_global_pass_epoch(epoch)
        logger.info(f"loading data for {split} epoch={epoch}/{shard_epoch}")
        logger.info(f"mem usage: {data_utils.get_mem_usage()}")
        if split in self.datasets:
            del self.datasets[split]
            logger.info("old dataset deleted manually")
            logger.info(f"mem usage: {data_utils.get_mem_usage()}")
        self.datasets[split] = self.data_manager.load_dataset(
            split,
            self.training,
            epoch=epoch,
            combine=combine,
            shard_epoch=shard_epoch,
            **kwargs,
        )
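    # Illustrative note (assumption, not from the original file): with sharded data
    # and --virtual-epoch-size set, one "epoch" seen by the trainer covers only a
    # slice of the virtual dataset. For example, if the virtual data size is 10M
    # samples and the virtual epoch size is 1M, roughly 10 training epochs make up
    # one full pass, and estimate_global_pass_epoch() maps the training epoch to
    # that global pass so the correct shard is loaded.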
    def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
        if constraints is not None:
            raise NotImplementedError(
                "Constrained decoding with the translation_multi_simple_epoch task is not supported"
            )
        src_data = ListDataset(src_tokens, src_lengths)
        dataset = LanguagePairDataset(src_data, src_lengths, self.source_dictionary)
        src_langtok_spec, tgt_langtok_spec = self.args.langtoks["main"]
        if self.args.lang_tok_replacing_bos_eos:
            dataset = self.data_manager.alter_dataset_langtok(
                dataset,
                src_eos=self.source_dictionary.eos(),
                src_lang=self.args.source_lang,
                tgt_eos=self.target_dictionary.eos(),
                tgt_lang=self.args.target_lang,
                src_langtok_spec=src_langtok_spec,
                tgt_langtok_spec=tgt_langtok_spec,
            )
        else:
            dataset.src = self.data_manager.src_dataset_tranform_func(
                self.args.source_lang,
                self.args.target_lang,
                dataset=dataset.src,
                spec=src_langtok_spec,
            )
        return dataset
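    # Hedged sketch (identifiers below are assumptions, not part of this module):
    # fairseq-interactive builds a one-off dataset from tokenized input, roughly
    #   src_tokens = [src_dict.encode_line("Hello world", add_if_not_exist=False)]
    #   src_lengths = [t.numel() for t in src_tokens]
    #   ds = task.build_dataset_for_inference(src_tokens, src_lengths)
    # and, unless --lang-tok-replacing-bos-eos is set, the source-side language
    # token is prepended according to the "main" langtok spec configured above.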
    def build_generator(
        self,
        models,
        args,
        seq_gen_cls=None,
        extra_gen_cls_kwargs=None,
    ):
        if not getattr(args, "keep_inference_langtok", False):
            _, tgt_langtok_spec = self.args.langtoks["main"]
            if tgt_langtok_spec:
                tgt_lang_tok = self.data_manager.get_decoder_langtok(
                    self.args.target_lang, tgt_langtok_spec
                )
                extra_gen_cls_kwargs = extra_gen_cls_kwargs or {}
                extra_gen_cls_kwargs["symbols_to_strip_from_output"] = {tgt_lang_tok}
        return super().build_generator(
            # pass the caller-provided generator class through instead of dropping it
            models, args, seq_gen_cls=seq_gen_cls, extra_gen_cls_kwargs=extra_gen_cls_kwargs
        )
    def build_model(self, args, from_checkpoint=False):
        return super().build_model(args, from_checkpoint)

    def valid_step(self, sample, model, criterion):
        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
        return loss, sample_size, logging_output
    def inference_step(
        self, generator, models, sample, prefix_tokens=None, constraints=None
    ):
        with torch.no_grad():
            _, tgt_langtok_spec = self.args.langtoks["main"]
            if not self.args.lang_tok_replacing_bos_eos:
                if prefix_tokens is None and tgt_langtok_spec:
                    tgt_lang_tok = self.data_manager.get_decoder_langtok(
                        self.args.target_lang, tgt_langtok_spec
                    )
                    src_tokens = sample["net_input"]["src_tokens"]
                    bsz = src_tokens.size(0)
                    prefix_tokens = (
                        torch.LongTensor([[tgt_lang_tok]]).expand(bsz, 1).to(src_tokens)
                    )
                return generator.generate(
                    models,
                    sample,
                    prefix_tokens=prefix_tokens,
                    constraints=constraints,
                )
            else:
                return generator.generate(
                    models,
                    sample,
                    prefix_tokens=prefix_tokens,
                    bos_token=self.data_manager.get_decoder_langtok(
                        self.args.target_lang, tgt_langtok_spec
                    )
                    if tgt_langtok_spec
                    else self.target_dictionary.eos(),
                )
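    # Illustrative note (assumption): when the decoder langtok spec is active and no
    # prefix is supplied, generation is forced to start with the target-language
    # token. For example, for a batch of 3 sentences and a language-token id of 25004:
    #   prefix_tokens = torch.LongTensor([[25004]]).expand(3, 1)   # shape (bsz, 1)
    # so every hypothesis begins with the target language token before real words.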
    def reduce_metrics(self, logging_outputs, criterion):
        super().reduce_metrics(logging_outputs, criterion)

    def max_positions(self):
        """Return the max sentence length allowed by the task."""
        return (self.args.max_source_positions, self.args.max_target_positions)
    @property
    def source_dictionary(self):
        return self.data_manager.get_source_dictionary(self.source_langs[0])

    @property
    def target_dictionary(self):
        return self.data_manager.get_target_dictionary(self.target_langs[0])
    def create_batch_sampler_func(
        self,
        max_positions,
        ignore_invalid_inputs,
        max_tokens,
        max_sentences,
        required_batch_size_multiple=1,
        seed=1,
    ):
        def construct_batch_sampler(dataset, epoch):
            splits = [
                s for s, _ in self.datasets.items() if self.datasets[s] == dataset
            ]
            split = splits[0] if len(splits) > 0 else None
            # NEW implementation
            if epoch is not None:
                # initialize the dataset with the correct starting epoch
                dataset.set_epoch(epoch)

            # get indices ordered by example size
            start_time = time.time()
            logger.info(f"start batch sampler: mem usage: {data_utils.get_mem_usage()}")
            with data_utils.numpy_seed(seed):
                indices = dataset.ordered_indices()
            logger.info(
                f"[{split}] @batch_sampler order indices time: {get_time_gap(start_time, time.time())}"
            )
            logger.info(f"mem usage: {data_utils.get_mem_usage()}")

            # filter examples that are too large
            if max_positions is not None:
                my_time = time.time()
                indices = self.filter_indices_by_size(
                    indices, dataset, max_positions, ignore_invalid_inputs
                )
                logger.info(
                    f"[{split}] @batch_sampler filter_by_size time: {get_time_gap(my_time, time.time())}"
                )
                logger.info(f"mem usage: {data_utils.get_mem_usage()}")

            # create mini-batches with given size constraints
            my_time = time.time()
            batch_sampler = dataset.batch_by_size(
                indices,
                max_tokens=max_tokens,
                max_sentences=max_sentences,
                required_batch_size_multiple=required_batch_size_multiple,
            )
            logger.info(
                f"[{split}] @batch_sampler batch_by_size time: {get_time_gap(my_time, time.time())}"
            )
            logger.info(
                f"[{split}] per epoch batch_sampler set-up time: {get_time_gap(start_time, time.time())}"
            )
            logger.info(f"mem usage: {data_utils.get_mem_usage()}")

            return batch_sampler

        return construct_batch_sampler
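    # Hedged note (assumption): the closure returned above is handed to
    # EpochBatchIterator as its batch_sampler, so batches are re-derived for every
    # epoch instead of being cached, e.g.
    #   sampler_fn = task.create_batch_sampler_func(max_positions, False, 4096, None)
    #   batches = sampler_fn(dataset, epoch=1)   # list of arrays of example indices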
    # we need to override get_batch_iterator because we want to reset the epoch iterator each time
    def get_batch_iterator(
        self,
        dataset,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
        data_buffer_size=0,
        disable_iterator_cache=False,
        skip_remainder_batch=False,
        grouped_shuffling=False,
        update_epoch_batch_itr=False,
    ):
""" | |
Get an iterator that yields batches of data from the given dataset. | |
Args: | |
dataset (~fairseq.data.FairseqDataset): dataset to batch | |
max_tokens (int, optional): max number of tokens in each batch | |
(default: None). | |
max_sentences (int, optional): max number of sentences in each | |
batch (default: None). | |
max_positions (optional): max sentence length supported by the | |
model (default: None). | |
ignore_invalid_inputs (bool, optional): don't raise Exception for | |
sentences that are too long (default: False). | |
required_batch_size_multiple (int, optional): require batch size to | |
be a multiple of N (default: 1). | |
seed (int, optional): seed for random number generator for | |
reproducibility (default: 1). | |
num_shards (int, optional): shard the data iterator into N | |
shards (default: 1). | |
shard_id (int, optional): which shard of the data iterator to | |
return (default: 0). | |
num_workers (int, optional): how many subprocesses to use for data | |
loading. 0 means the data will be loaded in the main process | |
(default: 0). | |
epoch (int, optional): the epoch to start the iterator from | |
(default: 0). | |
data_buffer_size (int, optional): number of batches to | |
preload (default: 0). | |
disable_iterator_cache (bool, optional): don't cache the | |
EpochBatchIterator (ignores `FairseqTask::can_reuse_epoch_itr`) | |
(default: False). | |
grouped_shuffling (bool, optional): group batches with each groups | |
containing num_shards batches and shuffle groups. Reduces difference | |
between sequence lengths among workers for batches sorted by length. | |
update_epoch_batch_itr (bool optional): if true then donot use the cached | |
batch iterator for the epoch | |
Returns: | |
~fairseq.iterators.EpochBatchIterator: a batched iterator over the | |
given dataset split | |
""" | |
        # initialize the dataset with the correct starting epoch
        assert isinstance(dataset, FairseqDataset)
        if dataset in self.dataset_to_epoch_iter:
            return self.dataset_to_epoch_iter[dataset]
        if self.args.sampling_method == "RoundRobin":
            batch_iter = super().get_batch_iterator(
                dataset,
                max_tokens=max_tokens,
                max_sentences=max_sentences,
                max_positions=max_positions,
                ignore_invalid_inputs=ignore_invalid_inputs,
                required_batch_size_multiple=required_batch_size_multiple,
                seed=seed,
                num_shards=num_shards,
                shard_id=shard_id,
                num_workers=num_workers,
                epoch=epoch,
                data_buffer_size=data_buffer_size,
                disable_iterator_cache=disable_iterator_cache,
                skip_remainder_batch=skip_remainder_batch,
                update_epoch_batch_itr=update_epoch_batch_itr,
            )
            self.dataset_to_epoch_iter[dataset] = batch_iter
            return batch_iter

        construct_batch_sampler = self.create_batch_sampler_func(
            max_positions,
            ignore_invalid_inputs,
            max_tokens,
            max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
            seed=seed,
        )

        epoch_iter = iterators.EpochBatchIterator(
            dataset=dataset,
            collate_fn=dataset.collater,
            batch_sampler=construct_batch_sampler,
            seed=seed,
            num_shards=num_shards,
            shard_id=shard_id,
            num_workers=num_workers,
            epoch=epoch,
        )
        return epoch_iter
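# Hedged usage sketch (assumption, not part of the original module): the iterator
# returned by get_batch_iterator is consumed like any fairseq EpochBatchIterator,
# e.g. inside a training loop:
#   itr = task.get_batch_iterator(task.dataset("train"), max_tokens=4096, epoch=1)
#   for sample in itr.next_epoch_itr(shuffle=True):
#       ...  # forward/backward on the collated sample
# Because the batch sampler here is a callable, batches are regenerated each epoch
# rather than cached across epochs.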