Spaces:
Running
Running
File size: 5,243 Bytes
6a62ffb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from fairseq import utils
from fairseq.data import LanguagePairDataset
from . import register_task
from .translation import TranslationTask, load_langpair_dataset
@register_task("translation_from_pretrained_bart")
class TranslationFromPretrainedBARTTask(TranslationTask):
    """
    Translate from source language to target language with a model initialized with a multilingual pretrain.

    On construction, one ``[lang]`` symbol per entry of ``--langs`` and a
    ``<mask>`` symbol are added to both the source and the target dictionary.
    The ``[target_lang]`` symbol is used as the end-of-sentence marker when
    generating or scoring, and the ``[source_lang]`` symbol is appended to
    every source sentence at inference time.

    Args:
        args: parsed command-line arguments; must provide ``langs``
        src_dict (~fairseq.data.Dictionary): dictionary for the source language
        tgt_dict (~fairseq.data.Dictionary): dictionary for the target language

    .. note::

        The translation task is compatible with :mod:`fairseq-train`,
        :mod:`fairseq-generate` and :mod:`fairseq-interactive`.

    The translation task provides the following additional command-line
    arguments:

    .. argparse::
        :ref: fairseq.tasks.translation_parser
        :prog:
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        # fmt: off
        TranslationTask.add_args(parser)
        parser.add_argument('--langs', type=str, metavar='LANG',
                            help='comma-separated list of monolingual language, '
                                 'for example, "en,de,fr". These should match the '
                                 'langs from pretraining (and be in the same order). '
                                 'You should always add all pretraining language idx '
                                 'during finetuning.')
        parser.add_argument('--prepend-bos', action='store_true',
                            help='prepend bos token to each sentence, which matches '
                                 'mBART pretraining')
        # fmt: on

    def __init__(self, args, src_dict, tgt_dict):
        """Set up the task and register the language and mask symbols.

        Args:
            args: parsed command-line arguments; ``args.langs`` is the
                comma-separated language list given via ``--langs``
            src_dict: dictionary for the source language
            tgt_dict: dictionary for the target language

        Raises:
            ValueError: if ``--langs`` was not provided
        """
        super().__init__(args, src_dict, tgt_dict)
        # --langs is not marked as required by the parser, so it may be None;
        # fail with an actionable message instead of an AttributeError.
        if args.langs is None:
            raise ValueError(
                "--langs is required for the translation_from_pretrained_bart "
                'task (comma-separated list, for example "en,de,fr")'
            )
        self.langs = args.langs.split(",")
        # Symbols are added in the order given by --langs, followed by
        # <mask>; the help text of --langs asks for the pretraining order.
        for dictionary in (src_dict, tgt_dict):
            for lang in self.langs:
                dictionary.add_symbol("[{}]".format(lang))
            dictionary.add_symbol("<mask>")

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        The data directory is picked from the paths in ``args.data`` by
        cycling through them with the (1-based) epoch number. The loaded
        language-pair dataset is stored in ``self.datasets[split]``.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): 1-based epoch number used to select the data path
            combine (bool): forwarded to ``load_langpair_dataset``
        """
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        # infer langcode
        src, tgt = self.args.source_lang, self.args.target_lang

        self.datasets[split] = load_langpair_dataset(
            data_path,
            split,
            src,
            self.src_dict,
            tgt,
            self.tgt_dict,
            combine=combine,
            dataset_impl=self.args.dataset_impl,
            upsample_primary=self.args.upsample_primary,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=getattr(self.args, "max_source_positions", 1024),
            max_target_positions=getattr(self.args, "max_target_positions", 1024),
            load_alignments=self.args.load_alignments,
            prepend_bos=getattr(self.args, "prepend_bos", False),
            append_source_id=True,
        )

    def build_generator(self, models, args, **unused):
        """Build the object used to generate or score translations.

        Args:
            models: models handed to the ``SequenceGenerator`` (not used when
                scoring references)
            args: generation arguments; ``score_reference`` selects scoring
                instead of generation, the remaining options fall back to
                the defaults spelled out below when absent

        Returns:
            A ``SequenceScorer`` when ``args.score_reference`` is set,
            otherwise a ``SequenceGenerator``. Both use the index of the
            ``[target_lang]`` symbol as end-of-sentence.
        """
        # The target language id, not the dictionary's regular EOS, ends a
        # hypothesis for this task.
        tgt_lang_eos = self.tgt_dict.index("[{}]".format(self.args.target_lang))

        if getattr(args, "score_reference", False):
            from fairseq.sequence_scorer import SequenceScorer

            return SequenceScorer(
                self.target_dictionary,
                eos=tgt_lang_eos,
            )

        from fairseq.sequence_generator import SequenceGenerator

        return SequenceGenerator(
            models,
            self.target_dictionary,
            beam_size=getattr(args, "beam", 5),
            max_len_a=getattr(args, "max_len_a", 0),
            max_len_b=getattr(args, "max_len_b", 200),
            min_len=getattr(args, "min_len", 1),
            normalize_scores=(not getattr(args, "unnormalized", False)),
            len_penalty=getattr(args, "lenpen", 1),
            unk_penalty=getattr(args, "unkpen", 0),
            temperature=getattr(args, "temperature", 1.0),
            match_source_len=getattr(args, "match_source_len", False),
            no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
            eos=tgt_lang_eos,
        )

    def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
        """Wrap raw source sentences in a dataset for inference.

        The index of the ``[source_lang]`` symbol is appended to every
        source tensor before the dataset is built.

        Args:
            src_tokens: iterable of 1-D token tensors, one per sentence
            src_lengths: lengths of the sentences *before* the language id
                is appended; accepted for interface compatibility, the
                lengths handed to the dataset are recomputed from the
                extended tensors
            constraints: forwarded unchanged to ``LanguagePairDataset``

        Returns:
            LanguagePairDataset: source-only dataset over the extended
            sentences
        """
        src_lang_id = self.source_dictionary.index("[{}]".format(self.args.source_lang))
        # new_full keeps the dtype and device of each sentence tensor.
        source_tokens = [
            torch.cat([tokens, tokens.new_full((1,), src_lang_id)])
            for tokens in src_tokens
        ]
        # Every sentence grew by one token, so the caller-provided lengths
        # are stale; report the real sizes so the dataset does not see
        # each sample as one token shorter than it is.
        extended_lengths = [tokens.numel() for tokens in source_tokens]
        dataset = LanguagePairDataset(
            source_tokens,
            extended_lengths,
            self.source_dictionary,
            tgt_dict=self.target_dictionary,
            constraints=constraints,
        )
        return dataset
|