|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse
import os
import tempfile
import unittest

import torch

from fairseq.data.dictionary import Dictionary
from fairseq.models.transformer import TransformerModel
from fairseq.modules import multihead_attention, sinusoidal_positional_embedding
from fairseq.tasks.fairseq_task import LegacyFairseqTask
|
|
|
|
|
# Default number of numeric symbols ("0", "1", ...) that get_dummy_dictionary
# adds to the test dictionary.
DEFAULT_TEST_VOCAB_SIZE = 100
|
|
|
|
|
class DummyTask(LegacyFairseqTask):
    """Minimal legacy fairseq task backed by a synthetic vocabulary.

    Source and target share a single ``Dictionary`` built by
    ``get_dummy_dictionary``. When ``args.ctc`` is truthy, a ``<ctc_blank>``
    symbol is added to that dictionary.
    """

    def __init__(self, args):
        super().__init__(args)
        self.dictionary = get_dummy_dictionary()
        # ``ctc`` may be absent from args (the test parser suppresses defaults),
        # hence getattr with a False fallback.
        if getattr(self.args, "ctc", False):
            self.dictionary.add_symbol("<ctc_blank>")
        # Both sides alias the same Dictionary object.
        self.src_dict = self.dictionary
        self.tgt_dict = self.dictionary

    @property
    def source_dictionary(self):
        """Return the dictionary used for source-side tokens."""
        return self.src_dict

    @property
    def target_dictionary(self):
        """Return the dictionary used for target-side tokens."""
        # Mirror source_dictionary: read the dedicated target attribute so that
        # reassigning ``tgt_dict`` is reflected here.
        return self.tgt_dict
|
|
|
|
|
def get_dummy_dictionary(vocab_size=DEFAULT_TEST_VOCAB_SIZE):
    """Build a ``Dictionary`` holding the symbols ``"0"`` .. ``str(vocab_size - 1)``.

    The numeric symbols are added after whatever entries ``Dictionary()``
    starts with. Each is added with ``1000`` as the second ``add_symbol``
    argument (presumably its frequency count — confirm against ``Dictionary``).

    Args:
        vocab_size: number of numeric symbols to add.

    Returns:
        The populated ``Dictionary``.
    """
    dummy_dict = Dictionary()
    # A plain range suffices; the old enumerate(range(...)) was redundant and
    # its loop variable shadowed the builtin ``id``.
    for index in range(vocab_size):
        dummy_dict.add_symbol(str(index), 1000)
    return dummy_dict
|
|
|
|
|
def get_dummy_task_and_parser():
    """Create a ``DummyTask`` together with the parser it registered options on.

    The parser suppresses argument defaults, so options only appear on the
    parsed namespace when given explicitly. Callers can extend the returned
    parser (e.g. with a model's ``add_args``) and re-parse it to build a
    model or criterion against the task.

    Returns:
        A ``(task, parser)`` tuple.
    """
    arg_parser = argparse.ArgumentParser(
        description="test_dummy_s2s_task",
        argument_default=argparse.SUPPRESS,
    )
    DummyTask.add_args(arg_parser)
    dummy_task = DummyTask.setup_task(arg_parser.parse_args([]))
    return dummy_task, arg_parser
|
|
|
|
|
def _test_save_and_load(scripted_module):
    """Round-trip a TorchScript module through a file on disk.

    Saves ``scripted_module`` into a fresh temporary directory and reloads it
    with ``torch.jit.load``. Any serialization or deserialization error
    propagates to the caller.

    Args:
        scripted_module: a scripted module exposing ``.save(path)``.

    Returns:
        The module returned by ``torch.jit.load``; existing callers may ignore it.
    """
    # NamedTemporaryFile keeps its file open, and on Windows an open
    # NamedTemporaryFile cannot be reopened by name. A private temporary
    # directory gives a plain path that works on every platform and is removed
    # on exit.
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "scripted_module.pt")
        scripted_module.save(path)
        return torch.jit.load(path)
|
|
|
|
|
class TestExportModels(unittest.TestCase): |
|
def test_export_multihead_attention(self): |
|
module = multihead_attention.MultiheadAttention(embed_dim=8, num_heads=2) |
|
scripted = torch.jit.script(module) |
|
_test_save_and_load(scripted) |
|
|
|
def test_incremental_state_multihead_attention(self): |
|
module1 = multihead_attention.MultiheadAttention(embed_dim=8, num_heads=2) |
|
module1 = torch.jit.script(module1) |
|
module2 = multihead_attention.MultiheadAttention(embed_dim=8, num_heads=2) |
|
module2 = torch.jit.script(module2) |
|
|
|
state = {} |
|
state = module1.set_incremental_state(state, "key", {"a": torch.tensor([1])}) |
|
state = module2.set_incremental_state(state, "key", {"a": torch.tensor([2])}) |
|
v1 = module1.get_incremental_state(state, "key")["a"] |
|
v2 = module2.get_incremental_state(state, "key")["a"] |
|
|
|
self.assertEqual(v1, 1) |
|
self.assertEqual(v2, 2) |
|
|
|
def test_positional_embedding(self): |
|
module = sinusoidal_positional_embedding.SinusoidalPositionalEmbedding( |
|
embedding_dim=8, padding_idx=1 |
|
) |
|
scripted = torch.jit.script(module) |
|
_test_save_and_load(scripted) |
|
|
|
@unittest.skipIf( |
|
torch.__version__ < "1.6.0", "Targeting OSS scriptability for the 1.6 release" |
|
) |
|
def test_export_transformer(self): |
|
task, parser = get_dummy_task_and_parser() |
|
TransformerModel.add_args(parser) |
|
args = parser.parse_args([]) |
|
model = TransformerModel.build_model(args, task) |
|
scripted = torch.jit.script(model) |
|
_test_save_and_load(scripted) |
|
|
|
@unittest.skipIf( |
|
torch.__version__ < "1.6.0", "Targeting OSS scriptability for the 1.6 release" |
|
) |
|
def test_export_transformer_no_token_pos_emb(self): |
|
task, parser = get_dummy_task_and_parser() |
|
TransformerModel.add_args(parser) |
|
args = parser.parse_args([]) |
|
args.no_token_positional_embeddings = True |
|
model = TransformerModel.build_model(args, task) |
|
scripted = torch.jit.script(model) |
|
_test_save_and_load(scripted) |
|
|
|
|
|
|
|
# Allow running this test module directly as a script via unittest's runner.
if __name__ == "__main__":
    unittest.main()
|
|