import argparse
import json
import os
import random
import sys
from io import StringIO

import torch
import torch.nn.functional as F
from fairseq import options, utils
from fairseq.data import Dictionary
from fairseq.data.language_pair_dataset import collate
from fairseq.models import (
    FairseqEncoder,
    FairseqEncoderDecoderModel,
    FairseqIncrementalDecoder,
)
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.tasks import LegacyFairseqTask
from fairseq_cli import generate, interactive, preprocess, train, validate
import fairseq.distributed.utils as distributed_utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
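
# Helpers for building tiny synthetic datasets, dictionaries, and models so
# that the fairseq CLI entry points can be exercised end to end in tests.
# A typical (hypothetical) test would chain them like this:
#
#     with tempfile.TemporaryDirectory() as data_dir:
#         create_dummy_data(data_dir)
#         preprocess_translation_data(data_dir)
#         train_translation_model(data_dir, "fconv_iwslt_de_en")
#         generate_main(data_dir)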


def dummy_dictionary(vocab_size, prefix="token_"):
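    """Build a Dictionary of `vocab_size` made-up symbols."""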
    d = Dictionary()
    for i in range(vocab_size):
        token = prefix + str(i)
        d.add_symbol(token)
    d.finalize(padding_factor=1)  # keep the exact size; don't pad to a multiple of 8
    return d


def dummy_dataloader(samples, padding_idx=1, eos_idx=2, batch_size=None):
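    """Return an iterator over `samples`, batched with fairseq's language-pair collater."""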
    if batch_size is None:
        batch_size = len(samples)

    # add any missing example ids, in increasing order
    for i, sample in enumerate(samples):
        if "id" not in sample:
            sample["id"] = i

    dataset = TestDataset(samples)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=(lambda samples: collate(samples, padding_idx, eos_idx)),
    )
    return iter(dataloader)


def sequence_generator_setup():
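    """Build a tiny model whose decoder emits the scripted distributions in
    `args.beam_probs` (see TestIncrementalDecoder below)."""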
    # construct dummy dictionary
    d = dummy_dictionary(vocab_size=2)

    eos = d.eos()
    w1 = 4
    w2 = 5

    # construct source data
    src_tokens = torch.LongTensor([[w1, w2, eos], [w1, w2, eos]])
    src_lengths = torch.LongTensor([2, 2])

    args = argparse.Namespace()
    unk = 0.0
    # one tensor per decoding step; each row is one hypothesis (two sentences
    # x two beams), with columns [eos, unk, w1, w2]
    args.beam_probs = [
        # step 0:
        torch.FloatTensor(
            [
                # eos   unk  w1   w2
                # sentence 1:
                [0.0, unk, 0.9, 0.1],
                [0.0, unk, 0.9, 0.1],
                # sentence 2:
                [0.0, unk, 0.7, 0.3],
                [0.0, unk, 0.7, 0.3],
            ]
        ),
        # step 1:
        torch.FloatTensor(
            [
                # eos    unk  w1    w2
                # sentence 1:
                [1.0, unk, 0.0, 0.0],
                [0.0, unk, 0.9, 0.1],
                # sentence 2:
                [0.25, unk, 0.35, 0.4],
                [0.00, unk, 0.10, 0.9],
            ]
        ),
        # step 2:
        torch.FloatTensor(
            [
                # eos    unk  w1   w2
                # sentence 1:
                [0.0, unk, 0.1, 0.9],
                [0.6, unk, 0.2, 0.2],
                # sentence 2:
                [0.60, unk, 0.4, 0.00],
                [0.01, unk, 0.0, 0.99],
            ]
        ),
        # step 3:
        torch.FloatTensor(
            [
                # eos   unk  w1   w2
                # sentence 1:
                [1.0, unk, 0.0, 0.0],
                [1.0, unk, 0.0, 0.0],
                # sentence 2:
                [0.1, unk, 0.5, 0.4],
                [1.0, unk, 0.0, 0.0],
            ]
        ),
    ]

    task = TestTranslationTask.setup_task(args, d, d)
    model = task.build_model(args)
    tgt_dict = task.target_dictionary

    return tgt_dict, w1, w2, src_tokens, src_lengths, model


def create_dummy_data(data_dir, num_examples=100, maxlen=20, alignment=False):
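    """Write random train/valid/test '.in'/'.out' text files (and, optionally,
    dummy alignment files) into `data_dir`."""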
    def _create_dummy_data(filename):
        data = torch.rand(num_examples * maxlen)
        # map floats in [0, 1) to the character codes for 'a'..'z'
        data = 97 + torch.floor(26 * data).int()
        with open(os.path.join(data_dir, filename), "w") as h:
            offset = 0
            for _ in range(num_examples):
                ex_len = random.randint(1, maxlen)
                ex_str = " ".join(map(chr, data[offset : offset + ex_len]))
                print(ex_str, file=h)
                offset += ex_len

    def _create_dummy_alignment_data(filename_src, filename_tgt, filename):
        with open(os.path.join(data_dir, filename_src), "r") as src_f, open(
            os.path.join(data_dir, filename_tgt), "r"
        ) as tgt_f, open(os.path.join(data_dir, filename), "w") as h:
            for src, tgt in zip(src_f, tgt_f):
                src_len = len(src.split())
                tgt_len = len(tgt.split())
                avg_len = (src_len + tgt_len) // 2
                num_alignments = random.randint(avg_len // 2, 2 * avg_len)
                src_indices = torch.floor(torch.rand(num_alignments) * src_len).int()
                tgt_indices = torch.floor(torch.rand(num_alignments) * tgt_len).int()
                ex_str = " ".join(
                    [
                        "{}-{}".format(src_idx, tgt_idx)
                        for src_idx, tgt_idx in zip(src_indices, tgt_indices)
                    ]
                )
                print(ex_str, file=h)

    _create_dummy_data("train.in")
    _create_dummy_data("train.out")
    _create_dummy_data("valid.in")
    _create_dummy_data("valid.out")
    _create_dummy_data("test.in")
    _create_dummy_data("test.out")

    if alignment:
        _create_dummy_alignment_data("train.in", "train.out", "train.align")
        _create_dummy_alignment_data("valid.in", "valid.out", "valid.align")
        _create_dummy_alignment_data("test.in", "test.out", "test.align")


def preprocess_lm_data(data_dir):
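    """Binarize the dummy target-side data in `data_dir` with fairseq-preprocess."""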
    preprocess_parser = options.get_preprocessing_parser()
    preprocess_args = preprocess_parser.parse_args(
        [
            "--only-source",
            "--trainpref",
            os.path.join(data_dir, "train.out"),
            "--validpref",
            os.path.join(data_dir, "valid.out"),
            "--testpref",
            os.path.join(data_dir, "test.out"),
            "--destdir",
            data_dir,
        ]
    )
    preprocess.main(preprocess_args)


def preprocess_translation_data(data_dir, extra_flags=None):
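    """Binarize the dummy parallel data in `data_dir` with fairseq-preprocess."""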
    preprocess_parser = options.get_preprocessing_parser()
    preprocess_args = preprocess_parser.parse_args(
        [
            "--source-lang",
            "in",
            "--target-lang",
            "out",
            "--trainpref",
            os.path.join(data_dir, "train"),
            "--validpref",
            os.path.join(data_dir, "valid"),
            "--testpref",
            os.path.join(data_dir, "test"),
            "--thresholdtgt",
            "0",
            "--thresholdsrc",
            "0",
            "--destdir",
            data_dir,
        ]
        + (extra_flags or []),
    )
    preprocess.main(preprocess_args)


def preprocess_summarization_data(data_dir, extra_flags=None):
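    """Binarize the dummy parallel data with a joined source/target dictionary."""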
    preprocess_parser = options.get_preprocessing_parser()
    preprocess_args = preprocess_parser.parse_args(
        [
            "--source-lang",
            "in",
            "--target-lang",
            "out",
            "--trainpref",
            os.path.join(data_dir, "train"),
            "--validpref",
            os.path.join(data_dir, "valid"),
            "--testpref",
            os.path.join(data_dir, "test"),
            "--thresholdtgt",
            "0",
            "--thresholdsrc",
            "0",
            "--joined-dictionary",
            "--destdir",
            data_dir,
        ]
        + (extra_flags or []),
    )
    preprocess.main(preprocess_args)


def create_laser_data_and_config_json(data_dir):
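    """Create dummy data and a training config for several language pairs.

    Returns the (closed) JSON config file handle; its .name attribute holds
    the path to the written laserconfig.json.
    """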
    src_langs = ["de", "fr", "ru", "tr", "zh"]
    tgt_langs = ["en", "es"]
    config_json = {}
    config_train_json = []
    src_vocab = None
    tgt_vocab = None

    for src_lang in src_langs:
        for tgt_lang in tgt_langs:
            langpair_folder = f"{src_lang}-{tgt_lang}"

            langpair_path = os.path.join(data_dir, langpair_folder)
            os.mkdir(langpair_path)
            create_dummy_data(langpair_path)
            preprocess_translation_data(langpair_path, ["--dataset-impl", "cached"])

            src_vocab = os.path.join(langpair_path, "dict.in.txt")
            tgt_vocab = os.path.join(langpair_path, "dict.out.txt")
            config_train_json.append(
                {
                    "id": 0 if tgt_lang == "en" else 1,
                    "src": os.path.join(langpair_path, "train.in-out.in"),
                    "tgt": os.path.join(langpair_path, "train.in-out.out"),
                }
            )

    config_json["src_vocab"] = src_vocab
    config_json["tgt_vocab"] = tgt_vocab
    config_json["train"] = config_train_json

    with open(os.path.join(data_dir, "laserconfig.json"), "w") as config_file:
        json.dump(config_json, config_file)

    return config_file


def train_translation_model(
    data_dir,
    arch,
    extra_flags=None,
    task="translation",
    run_validation=False,
    lang_flags=None,
    extra_valid_flags=None,
    world_size=1,
):
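    """Train `arch` on the dummy data in `data_dir` for one epoch and save
    checkpoints there; optionally validate the last checkpoint afterwards."""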
    if lang_flags is None:
        lang_flags = [
            "--source-lang",
            "in",
            "--target-lang",
            "out",
        ]
    train_parser = options.get_training_parser()
    train_args = options.parse_args_and_arch(
        train_parser,
        [
            "--task",
            task,
            data_dir,
            "--save-dir",
            data_dir,
            "--arch",
            arch,
            "--optimizer",
            "nag",
            "--lr",
            "0.05",
            "--max-tokens",
            "500",
            "--max-epoch",
            "1",
            "--no-progress-bar",
            "--distributed-world-size",
            str(world_size),
            "--num-workers",
            "0",
        ]
        + lang_flags
        + (extra_flags or []),
    )

    cfg = convert_namespace_to_omegaconf(train_args)
    distributed_utils.call_main(cfg, train.main)

    if run_validation:
        # sanity-check the saved checkpoint with fairseq-validate
        validate_parser = options.get_validation_parser()
        validate_args = options.parse_args_and_arch(
            validate_parser,
            [
                "--task",
                task,
                data_dir,
                "--path",
                os.path.join(data_dir, "checkpoint_last.pt"),
                "--valid-subset",
                "valid",
                "--max-tokens",
                "500",
                "--no-progress-bar",
                "--num-workers",
                "0",
            ]
            + lang_flags
            + (extra_valid_flags or []),
        )
        validate.main(validate_args)


def generate_main(data_dir, extra_flags=None, path=None):
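    """Generate from a trained checkpoint, first in batch mode and then
    interactively with input read from a fake stdin."""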
    if extra_flags is None:
        extra_flags = [
            "--print-alignment",
        ]
    if path is None:
        path = os.path.join(data_dir, "checkpoint_last.pt")
    generate_parser = options.get_generation_parser()
    generate_args = options.parse_args_and_arch(
        generate_parser,
        [
            data_dir,
            "--path",
            path,
            "--beam",
            "3",
            "--batch-size",
            "64",
            "--max-len-b",
            "5",
            "--gen-subset",
            "valid",
            "--no-progress-bar",
            "--num-workers",
            "0",
        ]
        + (extra_flags or []),
    )

    # evaluate model in batch mode
    generate.main(generate_args)

    # evaluate model interactively, feeding a single sentence via stdin
    generate_args.buffer_size = 0
    generate_args.input = "-"
    generate_args.batch_size = None
    orig_stdin = sys.stdin
    sys.stdin = StringIO("h e l l o\n")
    interactive.main(generate_args)
    sys.stdin = orig_stdin


class TestDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        super().__init__()
        self.data = data
        self.sizes = None

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


class TestTranslationTask(LegacyFairseqTask):
    def __init__(self, args, src_dict, tgt_dict, model):
        super().__init__(args)
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict
        self.model = model

    @classmethod
    def setup_task(cls, args, src_dict=None, tgt_dict=None, model=None):
        return cls(args, src_dict, tgt_dict, model)

    def build_model(self, args):
        return TestModel.build_model(args, self)

    @property
    def source_dictionary(self):
        return self.src_dict

    @property
    def target_dictionary(self):
        return self.tgt_dict


class TestModel(FairseqEncoderDecoderModel):
    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @classmethod
    def build_model(cls, args, task):
        encoder = TestEncoder(args, task.source_dictionary)
        decoder = TestIncrementalDecoder(args, task.target_dictionary)
        return cls(encoder, decoder)


class TestEncoder(FairseqEncoder):
    def __init__(self, args, dictionary):
        super().__init__(dictionary)
        self.args = args

    def forward(self, src_tokens, src_lengths=None, **kwargs):
        # identity encoder: just pass the source tokens through
        return EncoderOut(
            encoder_out=src_tokens,
            encoder_padding_mask=None,
            encoder_embedding=None,
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )

    def reorder_encoder_out(self, encoder_out, new_order):
        return EncoderOut(
            encoder_out=encoder_out.encoder_out.index_select(0, new_order),
            encoder_padding_mask=None,
            encoder_embedding=None,
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )


class TestIncrementalDecoder(FairseqIncrementalDecoder):
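    """Decoder that returns scripted probabilities instead of learned ones.

    Distributions come either from `args.probs` (bsz x steps x vocab) or from
    `args.beam_probs` (one tensor per step, covering the vocab from eos
    onwards); see sequence_generator_setup above.
    """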

    def __init__(self, args, dictionary):
        super().__init__(dictionary)
        assert hasattr(args, "beam_probs") or hasattr(args, "probs")
        args.max_decoder_positions = getattr(args, "max_decoder_positions", 100)
        self.args = args

    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
        bbsz = prev_output_tokens.size(0)
        vocab = len(self.dictionary)
        src_len = encoder_out.encoder_out.size(1)
        tgt_len = prev_output_tokens.size(1)

        # determine number of steps
        if incremental_state is not None:
            # cache step number
            step = utils.get_incremental_state(self, incremental_state, "step")
            if step is None:
                step = 0
            utils.set_incremental_state(self, incremental_state, "step", step + 1)
            steps = [step]
        else:
            steps = list(range(tgt_len))

        # define output in terms of raw probs
        if hasattr(self.args, "probs"):
            assert (
                self.args.probs.dim() == 3
            ), "expected probs to have size bsz*steps*vocab"
            probs = self.args.probs.index_select(1, torch.LongTensor(steps))
        else:
            probs = torch.FloatTensor(bbsz, len(steps), vocab).zero_()
            for i, step in enumerate(steps):
                # args.beam_probs gives the probability for every vocab element,
                # starting with eos, then unknown, and then the rest of the vocab
                if step < len(self.args.beam_probs):
                    probs[:, i, self.dictionary.eos() :] = self.args.beam_probs[step]
                else:
                    # past the scripted steps, just force EOS
                    probs[:, i, self.dictionary.eos()] = 1.0

        # random attention
        attn = torch.rand(bbsz, tgt_len, src_len)

        dev = prev_output_tokens.device
        return probs.to(dev), {"attn": [attn.to(dev)]}

    def get_normalized_probs(self, net_output, log_probs, _):
        # the decoder already returns probabilities, so no softmax is needed
        probs = net_output[0]
        if log_probs:
            return probs.log()
        else:
            return probs

    def max_positions(self):
        return self.args.max_decoder_positions


class TestReshapingEncoder(FairseqEncoder):
    def __init__(self, args, dictionary):
        super().__init__(dictionary)
        self.args = args

    def forward(self, src_tokens, src_lengths=None, **kwargs):
        # pad the time dimension to an even length, then fold adjacent
        # timesteps together so the output shape differs from the input
        b_sz, t_sz = src_tokens.shape
        padding_needed = t_sz % 2
        x = src_tokens
        if padding_needed > 0:
            padding_needed = 2 - padding_needed
            x = F.pad(x, (0, padding_needed))

        return EncoderOut(
            encoder_out=x.view(b_sz, -1, 2),
            encoder_padding_mask=None,
            encoder_embedding=None,
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )

    def reorder_encoder_out(self, encoder_out, new_order):
        return EncoderOut(
            encoder_out=encoder_out.encoder_out.index_select(0, new_order),
            encoder_padding_mask=None,
            encoder_embedding=None,
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )


class TestReshapingModel(FairseqEncoderDecoderModel):
    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @classmethod
    def build_model(cls, args, task):
        encoder = TestReshapingEncoder(args, task.source_dictionary)
        decoder = TestIncrementalDecoder(args, task.target_dictionary)
        return cls(encoder, decoder)


class TestAdditionalInputEncoder(FairseqEncoder):
    def __init__(self, args, dictionary):
        super().__init__(dictionary)
        self.args = args

    def forward(self, src_tokens, src_lengths=None, **kwargs):
        # this encoder requires an extra keyword input to be threaded through
        assert "fancy_other_input" in kwargs
        assert kwargs["fancy_other_input"] is not None
        return EncoderOut(
            encoder_out=src_tokens,
            encoder_padding_mask=None,
            encoder_embedding=None,
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )

    def reorder_encoder_out(self, encoder_out, new_order):
        return EncoderOut(
            encoder_out=encoder_out.encoder_out.index_select(0, new_order),
            encoder_padding_mask=None,
            encoder_embedding=None,
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )


class TestAdditionalInputModel(FairseqEncoderDecoderModel):
    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @classmethod
    def build_model(cls, args, task):
        encoder = TestAdditionalInputEncoder(args, task.source_dictionary)
        decoder = TestIncrementalDecoder(args, task.target_dictionary)
        return cls(encoder, decoder)

    def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        # only the encoder consumes the extra kwargs; TestIncrementalDecoder's
        # forward does not accept them
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
        decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out)
        return decoder_out


def train_language_model(
    data_dir,
    arch,
    extra_flags=None,
    run_validation=False,
    extra_valid_flags=None,
    task="language_modeling",
    world_size=1,
):
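    """Train language model `arch` on the dummy data in `data_dir` for one
    epoch; optionally validate the last checkpoint afterwards."""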
    train_parser = options.get_training_parser()
    train_args = options.parse_args_and_arch(
        train_parser,
        [
            "--task",
            task,
            data_dir,
            "--arch",
            arch,
            "--optimizer",
            "adam",
            "--lr",
            "0.0001",
            "--max-tokens",
            "500",
            "--tokens-per-sample",
            "500",
            "--save-dir",
            data_dir,
            "--max-epoch",
            "1",
            "--no-progress-bar",
            "--distributed-world-size",
            str(world_size),
            "--ddp-backend",
            "no_c10d",
            "--num-workers",
            "0",
        ]
        + (extra_flags or []),
    )
    cfg = convert_namespace_to_omegaconf(train_args)
    distributed_utils.call_main(cfg, train.main)

    if run_validation:
        # sanity-check the saved checkpoint with fairseq-validate
        validate_parser = options.get_validation_parser()
        validate_args = options.parse_args_and_arch(
            validate_parser,
            [
                "--task",
                task,
                data_dir,
                "--path",
                os.path.join(data_dir, "checkpoint_last.pt"),
                "--valid-subset",
                "valid",
                "--max-tokens",
                "500",
                "--no-progress-bar",
                "--num-workers",
                "0",
            ]
            + (extra_valid_flags or []),
        )
        validate.main(validate_args)