|
|
|
|
|
|
|
|
|
|
|
import contextlib |
|
import json |
|
import os |
|
import tempfile |
|
import unittest |
|
from io import StringIO |
|
|
|
import torch |
|
|
|
from . import test_binaries |
|
|
|
|
|
class TestReproducibility(unittest.TestCase):
    """Compare the final logged stats of two training runs over the same data.

    Each test trains a small translation model, renames one of the saved
    checkpoints to ``checkpoint_last.pt``, trains again with identical
    flags, and asserts that the last JSON-formatted train/valid stats of
    both runs agree within a tolerance.
    """

    def _test_reproducibility(
        self,
        name,
        extra_flags=None,
        delta=0.0001,
        resume_checkpoint="checkpoint1.pt",
        max_epoch=3,
    ):
        """Train twice and assert that the final logged stats match.

        Args:
            name: passed to ``tempfile.TemporaryDirectory`` as the suffix of
                the scratch data directory.
            extra_flags: additional command-line flags appended to the
                training flags of both runs (``None`` means no extra flags).
            delta: maximum absolute difference tolerated for each stat.
            resume_checkpoint: file name (inside the data directory) of the
                checkpoint that is renamed to ``checkpoint_last.pt`` before
                the second run.
            max_epoch: value passed as ``--max-epoch`` to both runs.
        """

        def get_last_log_stats_containing_string(log_records, search_string):
            """Return the parsed JSON of the newest record mentioning *search_string*."""
            # Walk the records newest-first so the most recent stats win.
            for log_record in reversed(log_records):
                if isinstance(log_record.msg, str) and search_string in log_record.msg:
                    return json.loads(log_record.msg)
            # Fail with a readable message instead of letting a ``None``
            # result blow up later as a TypeError in the comparisons.
            self.fail(
                "no log record containing {!r} was captured".format(search_string)
            )

        def train_and_collect_stats(data_dir):
            """Run one training pass and return its last (train, valid) stats."""
            with self.assertLogs() as logs:
                test_binaries.train_translation_model(
                    data_dir,
                    "fconv_iwslt_de_en",
                    # Hand over a fresh list on every call so both runs see
                    # identical flags even if the callee mutates its input.
                    list(train_flags),
                )
            return (
                get_last_log_stats_containing_string(logs.records, "train_loss"),
                get_last_log_stats_containing_string(logs.records, "valid_loss"),
            )

        if extra_flags is None:
            extra_flags = []

        train_flags = [
            "--dropout",
            "0.0",
            "--log-format",
            "json",
            "--log-interval",
            "1",
            "--max-epoch",
            str(max_epoch),
        ] + extra_flags

        # NOTE: the first positional argument of TemporaryDirectory is the
        # *suffix*, so ``name`` ends up at the end of the directory name.
        with tempfile.TemporaryDirectory(suffix=name) as data_dir:
            # assertLogs() also requires that at least one message is logged
            # while the dummy data is created and preprocessed.
            with self.assertLogs():
                test_binaries.create_dummy_data(data_dir)
                test_binaries.preprocess_translation_data(data_dir)

            # First run: train up to ``max_epoch`` in one go.
            train_log, valid_log = train_and_collect_stats(data_dir)

            # Second run: same flags, after promoting the chosen checkpoint
            # to ``checkpoint_last.pt`` (presumably the file the trainer
            # resumes from -- confirm against test_binaries / the trainer).
            os.rename(
                os.path.join(data_dir, resume_checkpoint),
                os.path.join(data_dir, "checkpoint_last.pt"),
            )
            train_res_log, valid_res_log = train_and_collect_stats(data_dir)

            for k in ["train_loss", "train_ppl", "train_num_updates", "train_gnorm"]:
                self.assertAlmostEqual(
                    float(train_log[k]), float(train_res_log[k]), delta=delta
                )
            for k in [
                "valid_loss",
                "valid_ppl",
                "valid_num_updates",
                "valid_best_loss",
            ]:
                self.assertAlmostEqual(
                    float(valid_log[k]), float(valid_res_log[k]), delta=delta
                )

    def test_reproducibility(self):
        """Default flags and default tolerance."""
        self._test_reproducibility("test_reproducibility")

    @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
    def test_reproducibility_fp16(self):
        """Same check with ``--fp16`` and a looser tolerance."""
        self._test_reproducibility(
            "test_reproducibility_fp16",
            [
                "--fp16",
                "--fp16-init-scale",
                "4096",
            ],
            delta=0.011,
        )

    @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
    def test_reproducibility_memory_efficient_fp16(self):
        """Same check with ``--memory-efficient-fp16`` (default tolerance)."""
        self._test_reproducibility(
            "test_reproducibility_memory_efficient_fp16",
            [
                "--memory-efficient-fp16",
                "--fp16-init-scale",
                "4096",
            ],
        )

    @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
    def test_reproducibility_amp(self):
        """Same check with ``--amp`` and a looser tolerance."""
        self._test_reproducibility(
            "test_reproducibility_amp",
            [
                "--amp",
                "--fp16-init-scale",
                "4096",
            ],
            delta=0.011,
        )

    def test_mid_epoch_reproducibility(self):
        """Second run starts from ``checkpoint_1_3.pt`` within a single epoch."""
        self._test_reproducibility(
            "test_mid_epoch_reproducibility",
            ["--save-interval-updates", "3"],
            resume_checkpoint="checkpoint_1_3.pt",
            max_epoch=1,
        )
|
|
|
|
|
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|
|