# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import logging
import os
from collections import OrderedDict
from argparse import ArgumentError

import torch
from fairseq import metrics, options, utils
from fairseq.data import (
    Dictionary,
    LanguagePairDataset,
    RoundRobinZipDatasets,
    TransformEosLangPairDataset,
)
from fairseq.models import FairseqMultiModel
from fairseq.tasks.translation import load_langpair_dataset

from . import LegacyFairseqTask, register_task


logger = logging.getLogger(__name__)


def _lang_token(lang: str):
    return "__{}__".format(lang)


def _lang_token_index(dic: Dictionary, lang: str):
    """Return language token index."""
    idx = dic.index(_lang_token(lang))
    assert idx != dic.unk_index, "cannot find language token for lang {}".format(lang)
    return idx


@register_task("multilingual_translation")
class MultilingualTranslationTask(LegacyFairseqTask):
    """A task for training multiple translation models simultaneously.

    We iterate round-robin over batches from multiple language pairs, ordered
    according to the `--lang-pairs` argument.

    The training loop is roughly:

        for i in range(len(epoch)):
            for lang_pair in args.lang_pairs:
                batch = next_batch_for_lang_pair(lang_pair)
                loss = criterion(model_for_lang_pair(lang_pair), batch)
                loss.backward()
            optimizer.step()

    In practice, `next_batch_for_lang_pair` is abstracted in a FairseqDataset
    (e.g., `RoundRobinZipDatasets`) and `model_for_lang_pair` is a model that
    implements the `FairseqMultiModel` interface.

    During inference it is required to specify a single `--source-lang` and
    `--target-lang`, which indicate the inference language direction.
    `--lang-pairs`, `--encoder-langtok` and `--decoder-langtok` have to be set
    to the same values as during training.
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('data', metavar='DIR', help='path to data directory')
        parser.add_argument('--lang-pairs', default=None, metavar='PAIRS',
                            help='comma-separated list of language pairs (in training order): en-de,en-fr,de-fr')
        parser.add_argument('-s', '--source-lang', default=None, metavar='SRC',
                            help='source language (only needed for inference)')
        parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
                            help='target language (only needed for inference)')
        parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL',
                            help='pad the source on the left (default: True)')
        parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL',
                            help='pad the target on the left (default: False)')
        try:
            parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
                                help='max number of tokens in the source sequence')
            parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
                                help='max number of tokens in the target sequence')
        except ArgumentError:
            # this might have already been defined. Once we transition this to hydra it should be fine to add it here.
            pass
        parser.add_argument('--upsample-primary', default=1, type=int,
                            help='amount to upsample primary dataset')
        parser.add_argument('--encoder-langtok', default=None, type=str, choices=['src', 'tgt'],
                            metavar='SRCTGT',
                            help='replace beginning-of-sentence in source sentence with source or target '
                                 'language token. (src/tgt)')
        parser.add_argument('--decoder-langtok', action='store_true',
                            help='replace beginning-of-sentence in target sentence with target language token')
        # fmt: on

    def __init__(self, args, dicts, training):
        super().__init__(args)
        self.dicts = dicts
        self.training = training
        if training:
            self.lang_pairs = args.lang_pairs
        else:
            self.lang_pairs = ["{}-{}".format(args.source_lang, args.target_lang)]
        # eval_lang_pairs for multilingual translation is usually all of the
        # lang_pairs. However, for other multitask settings, or when we want to
        # optimize for certain languages, we want to use a different subset;
        # the eval_lang_pairs class variable is provided for classes that
        # extend this class.
        self.eval_lang_pairs = self.lang_pairs
        # model_lang_pairs is used to build encoder-decoder model pairs in
        # models.build_model(). This allows multitask-style subclasses to build
        # models for language pairs other than the input lang_pairs.
        self.model_lang_pairs = self.lang_pairs
        self.langs = list(dicts.keys())

    @classmethod
    def setup_task(cls, args, **kwargs):
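        """Set up the task: load dictionaries and determine whether we are in
        training mode (no --source-lang/--target-lang given) or inference mode.
        """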
        dicts, training = cls.prepare(args, **kwargs)
        return cls(args, dicts, training)

    @classmethod
    def update_args(cls, args):
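        """Normalize shared arguments: evaluate the left-pad flags as booleans
        and split --lang-pairs into a list.
        """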
        args.left_pad_source = utils.eval_bool(args.left_pad_source)
        args.left_pad_target = utils.eval_bool(args.left_pad_target)

        if args.lang_pairs is None:
            raise ValueError(
                "--lang-pairs is required. List all the language pairs in the training objective."
            )
        if isinstance(args.lang_pairs, str):
            args.lang_pairs = args.lang_pairs.split(",")

    @classmethod
    def prepare(cls, args, **kwargs):
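        """Load one dictionary per language and return ``(dicts, training)``.

        ``training`` is False when an explicit --source-lang/--target-lang pair
        is given (inference) and True otherwise. When language tokens are
        enabled, a ``__lang__`` symbol is added to every dictionary for every
        language in --lang-pairs.
        """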
        cls.update_args(args)
        sorted_langs = sorted(
            list({x for lang_pair in args.lang_pairs for x in lang_pair.split("-")})
        )
        if args.source_lang is not None or args.target_lang is not None:
            training = False
        else:
            training = True

        # load dictionaries
        dicts = OrderedDict()
        for lang in sorted_langs:
            paths = utils.split_paths(args.data)
            assert len(paths) > 0
            dicts[lang] = cls.load_dictionary(
                os.path.join(paths[0], "dict.{}.txt".format(lang))
            )
            if len(dicts) > 0:
                assert dicts[lang].pad() == dicts[sorted_langs[0]].pad()
                assert dicts[lang].eos() == dicts[sorted_langs[0]].eos()
                assert dicts[lang].unk() == dicts[sorted_langs[0]].unk()
            if args.encoder_langtok is not None or args.decoder_langtok:
                for lang_to_add in sorted_langs:
                    dicts[lang].add_symbol(_lang_token(lang_to_add))
            logger.info("[{}] dictionary: {} types".format(lang, len(dicts[lang])))
        return dicts, training

    def get_encoder_langtok(self, src_lang, tgt_lang):
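        """Return the encoder-side language token index (the source or target
        language token, depending on --encoder-langtok), or the source EOS
        index if --encoder-langtok is not set.
        """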
        if self.args.encoder_langtok is None:
            return self.dicts[src_lang].eos()
        if self.args.encoder_langtok == "src":
            return _lang_token_index(self.dicts[src_lang], src_lang)
        else:
            return _lang_token_index(self.dicts[src_lang], tgt_lang)

    def get_decoder_langtok(self, tgt_lang):
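        """Return the target language token index, or the target EOS index if
        --decoder-langtok is not set.
        """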
        if not self.args.decoder_langtok:
            return self.dicts[tgt_lang].eos()
        return _lang_token_index(self.dicts[tgt_lang], tgt_lang)

    def alter_dataset_langtok(
        self,
        lang_pair_dataset,
        src_eos=None,
        src_lang=None,
        tgt_eos=None,
        tgt_lang=None,
    ):
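        """Wrap `lang_pair_dataset` so that language tokens replace the source
        EOS and/or the target beginning-of-sentence token, as configured by
        --encoder-langtok and --decoder-langtok. Returns the dataset unchanged
        if neither option is set.
        """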
        if self.args.encoder_langtok is None and not self.args.decoder_langtok:
            return lang_pair_dataset

        new_src_eos = None
        if (
            self.args.encoder_langtok is not None
            and src_eos is not None
            and src_lang is not None
            and tgt_lang is not None
        ):
            new_src_eos = self.get_encoder_langtok(src_lang, tgt_lang)
        else:
            src_eos = None

        new_tgt_bos = None
        if self.args.decoder_langtok and tgt_eos is not None and tgt_lang is not None:
            new_tgt_bos = self.get_decoder_langtok(tgt_lang)
        else:
            tgt_eos = None

        return TransformEosLangPairDataset(
            lang_pair_dataset,
            src_eos=src_eos,
            new_src_eos=new_src_eos,
            tgt_bos=tgt_eos,
            new_tgt_bos=new_tgt_bos,
        )

    def load_dataset(self, split, epoch=1, **kwargs):
        """Load a dataset split."""
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        def language_pair_dataset(lang_pair):
            src, tgt = lang_pair.split("-")
            langpair_dataset = load_langpair_dataset(
                data_path,
                split,
                src,
                self.dicts[src],
                tgt,
                self.dicts[tgt],
                combine=True,
                dataset_impl=self.args.dataset_impl,
                upsample_primary=self.args.upsample_primary,
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
                max_source_positions=self.args.max_source_positions,
                max_target_positions=self.args.max_target_positions,
            )
            return self.alter_dataset_langtok(
                langpair_dataset,
                src_eos=self.dicts[src].eos(),
                src_lang=src,
                tgt_eos=self.dicts[tgt].eos(),
                tgt_lang=tgt,
            )

        self.datasets[split] = RoundRobinZipDatasets(
            OrderedDict(
                [
                    (lang_pair, language_pair_dataset(lang_pair))
                    for lang_pair in self.lang_pairs
                ]
            ),
            eval_key=None
            if self.training
            else "%s-%s" % (self.args.source_lang, self.args.target_lang),
        )

    def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
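        """Build a dataset for the single --source-lang/--target-lang direction,
        applying the same language-token transformation as during training.
        """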
        if constraints is not None:
            raise NotImplementedError(
                "Constrained decoding with the multilingual_translation task is not supported"
            )

        lang_pair = "%s-%s" % (self.args.source_lang, self.args.target_lang)
        return RoundRobinZipDatasets(
            OrderedDict(
                [
                    (
                        lang_pair,
                        self.alter_dataset_langtok(
                            LanguagePairDataset(
                                src_tokens, src_lengths, self.source_dictionary
                            ),
                            src_eos=self.source_dictionary.eos(),
                            src_lang=self.args.source_lang,
                            tgt_eos=self.target_dictionary.eos(),
                            tgt_lang=self.args.target_lang,
                        ),
                    )
                ]
            ),
            eval_key=lang_pair,
        )

    def build_model(self, args, from_checkpoint=False):
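        """Build a :class:`FairseqMultiModel`, after checking that the model
        args are consistent with the task args (--lang-pairs,
        --encoder-langtok, --decoder-langtok).
        """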
        def check_args():
            messages = []
            if (
                len(set(self.args.lang_pairs).symmetric_difference(args.lang_pairs))
                != 0
            ):
                messages.append(
                    "--lang-pairs should include all the language pairs {}.".format(
                        args.lang_pairs
                    )
                )
            if self.args.encoder_langtok != args.encoder_langtok:
                messages.append(
                    "--encoder-langtok should be {}.".format(args.encoder_langtok)
                )
            if self.args.decoder_langtok != args.decoder_langtok:
                messages.append(
                    "--decoder-langtok should {} be set.".format(
                        "" if args.decoder_langtok else "not"
                    )
                )

            if len(messages) > 0:
                raise ValueError(" ".join(messages))

        # Update args -> the fact that the constructor here
        # changes the args object doesn't mean you get the same one here
        self.update_args(args)

        # Check if task args are consistent with model args
        check_args()

        from fairseq import models

        model = models.build_model(args, self, from_checkpoint)
        if not isinstance(model, FairseqMultiModel):
            raise ValueError(
                "MultilingualTranslationTask requires a FairseqMultiModel architecture"
            )
        return model

    def _per_lang_pair_train_loss(
        self, lang_pair, model, update_num, criterion, sample, optimizer, ignore_grad
    ):
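        """Compute the loss for one language pair and backpropagate it
        (zeroing the loss first when `ignore_grad` is set).
        """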
        loss, sample_size, logging_output = criterion(
            model.models[lang_pair], sample[lang_pair]
        )
        if ignore_grad:
            loss *= 0
        optimizer.backward(loss)
        return loss, sample_size, logging_output

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
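        """Run one training step, accumulating losses and logging outputs over
        all language pairs that have a non-empty batch in `sample`. Gradient
        synchronization is deferred to the last language pair when possible.
        """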
        model.train()
        from collections import defaultdict

        agg_loss, agg_sample_size, agg_logging_output = 0.0, 0.0, defaultdict(float)
        curr_lang_pairs = [
            lang_pair
            for lang_pair in self.model_lang_pairs
            if sample[lang_pair] is not None and len(sample[lang_pair]) != 0
        ]

        for idx, lang_pair in enumerate(curr_lang_pairs):

            def maybe_no_sync():
                if (
                    self.args.distributed_world_size > 1
                    and hasattr(model, "no_sync")
                    and idx < len(curr_lang_pairs) - 1
                ):
                    return model.no_sync()
                else:
                    return contextlib.ExitStack()  # dummy contextmanager

            with maybe_no_sync():
                loss, sample_size, logging_output = self._per_lang_pair_train_loss(
                    lang_pair,
                    model,
                    update_num,
                    criterion,
                    sample,
                    optimizer,
                    ignore_grad,
                )
            agg_loss += loss.detach().item()
            # TODO make summing of the sample sizes configurable
            agg_sample_size += sample_size
            for k in logging_output:
                agg_logging_output[k] += logging_output[k]
                agg_logging_output[f"{lang_pair}:{k}"] += logging_output[k]
        return agg_loss, agg_sample_size, agg_logging_output

    def _per_lang_pair_valid_loss(self, lang_pair, model, criterion, sample):
        return criterion(model.models[lang_pair], sample[lang_pair])

    def valid_step(self, sample, model, criterion):
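        """Run one validation step, aggregating losses and logging outputs over
        `self.eval_lang_pairs`.
        """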
        model.eval()
        with torch.no_grad():
            from collections import defaultdict

            agg_loss, agg_sample_size, agg_logging_output = (
                0.0,
                0.0,
                defaultdict(float),
            )
            for lang_pair in self.eval_lang_pairs:
                if (
                    lang_pair not in sample
                    or sample[lang_pair] is None
                    or len(sample[lang_pair]) == 0
                ):
                    continue
                loss, sample_size, logging_output = self._per_lang_pair_valid_loss(
                    lang_pair, model, criterion, sample
                )
                agg_loss += loss.data.item()
                # TODO make summing of the sample sizes configurable
                agg_sample_size += sample_size
                for k in logging_output:
                    agg_logging_output[k] += logging_output[k]
                    agg_logging_output[f"{lang_pair}:{k}"] += logging_output[k]
        return agg_loss, agg_sample_size, agg_logging_output

    def inference_step(
        self, generator, models, sample, prefix_tokens=None, constraints=None
    ):
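        """Generate translations, seeding the decoder with the target language
        token when --decoder-langtok is set and with EOS otherwise.
        """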
        with torch.no_grad():
            if self.args.decoder_langtok:
                bos_token = _lang_token_index(
                    self.target_dictionary, self.args.target_lang
                )
            else:
                bos_token = self.target_dictionary.eos()
            return generator.generate(
                models,
                sample,
                prefix_tokens=prefix_tokens,
                constraints=constraints,
                bos_token=bos_token,
            )

    def reduce_metrics(self, logging_outputs, criterion):
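        """Aggregate logging outputs and additionally log the summed
        sample_size, nsentences and ntokens.
        """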
        with metrics.aggregate():
            # pass 'sample_size', 'nsentences', 'ntokens' stats to fairseq_task
            super().reduce_metrics(logging_outputs, criterion)
            for k in ["sample_size", "nsentences", "ntokens"]:
                metrics.log_scalar(k, sum(l[k] for l in logging_outputs))

    @property
    def source_dictionary(self):
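        """Return the source :class:`~fairseq.data.Dictionary`: the first
        language's dictionary during training, the --source-lang dictionary
        during inference.
        """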
        if self.training:
            return next(iter(self.dicts.values()))
        else:
            return self.dicts[self.args.source_lang]

    @property
    def target_dictionary(self):
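        """Return the target :class:`~fairseq.data.Dictionary`: the first
        language's dictionary during training, the --target-lang dictionary
        during inference.
        """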
        if self.training:
            return next(iter(self.dicts.values()))
        else:
            return self.dicts[self.args.target_lang]

    def max_positions(self):
        """Return the max sentence length allowed by the task."""
        if len(self.datasets.values()) == 0:
            return {
                "%s-%s"
                % (self.args.source_lang, self.args.target_lang): (
                    self.args.max_source_positions,
                    self.args.max_target_positions,
                )
            }
        return OrderedDict(
            [
                (key, (self.args.max_source_positions, self.args.max_target_positions))
                for split in self.datasets.keys()
                for key in self.datasets[split].datasets.keys()
            ]
        )