Audio-Deepfake-Detection
/
fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1
/tests
/test_reproducibility.py
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import json | |
import os | |
import tempfile | |
import unittest | |
import torch | |
from . import test_binaries | |
class TestReproducibility(unittest.TestCase): | |
def _test_reproducibility( | |
self, | |
name, | |
extra_flags=None, | |
delta=0.0001, | |
resume_checkpoint="checkpoint1.pt", | |
max_epoch=3, | |
): | |
def get_last_log_stats_containing_string(log_records, search_string): | |
for log_record in logs.records[::-1]: | |
if isinstance(log_record.msg, str) and search_string in log_record.msg: | |
return json.loads(log_record.msg) | |
if extra_flags is None: | |
extra_flags = [] | |
with tempfile.TemporaryDirectory(name) as data_dir: | |
with self.assertLogs() as logs: | |
test_binaries.create_dummy_data(data_dir) | |
test_binaries.preprocess_translation_data(data_dir) | |
# train epochs 1 and 2 together | |
with self.assertLogs() as logs: | |
test_binaries.train_translation_model( | |
data_dir, | |
"fconv_iwslt_de_en", | |
[ | |
"--dropout", | |
"0.0", | |
"--log-format", | |
"json", | |
"--log-interval", | |
"1", | |
"--max-epoch", | |
str(max_epoch), | |
] | |
+ extra_flags, | |
) | |
train_log = get_last_log_stats_containing_string(logs.records, "train_loss") | |
valid_log = get_last_log_stats_containing_string(logs.records, "valid_loss") | |
# train epoch 2, resuming from previous checkpoint 1 | |
os.rename( | |
os.path.join(data_dir, resume_checkpoint), | |
os.path.join(data_dir, "checkpoint_last.pt"), | |
) | |
with self.assertLogs() as logs: | |
test_binaries.train_translation_model( | |
data_dir, | |
"fconv_iwslt_de_en", | |
[ | |
"--dropout", | |
"0.0", | |
"--log-format", | |
"json", | |
"--log-interval", | |
"1", | |
"--max-epoch", | |
str(max_epoch), | |
] | |
+ extra_flags, | |
) | |
train_res_log = get_last_log_stats_containing_string( | |
logs.records, "train_loss" | |
) | |
valid_res_log = get_last_log_stats_containing_string( | |
logs.records, "valid_loss" | |
) | |
for k in ["train_loss", "train_ppl", "train_num_updates", "train_gnorm"]: | |
self.assertAlmostEqual( | |
float(train_log[k]), float(train_res_log[k]), delta=delta | |
) | |
for k in [ | |
"valid_loss", | |
"valid_ppl", | |
"valid_num_updates", | |
"valid_best_loss", | |
]: | |
self.assertAlmostEqual( | |
float(valid_log[k]), float(valid_res_log[k]), delta=delta | |
) | |
def test_reproducibility(self): | |
self._test_reproducibility("test_reproducibility") | |
def test_reproducibility_fp16(self): | |
self._test_reproducibility( | |
"test_reproducibility_fp16", | |
[ | |
"--fp16", | |
"--fp16-init-scale", | |
"4096", | |
], | |
delta=0.011, | |
) | |
def test_reproducibility_memory_efficient_fp16(self): | |
self._test_reproducibility( | |
"test_reproducibility_memory_efficient_fp16", | |
[ | |
"--memory-efficient-fp16", | |
"--fp16-init-scale", | |
"4096", | |
], | |
) | |
def test_reproducibility_amp(self): | |
self._test_reproducibility( | |
"test_reproducibility_amp", | |
[ | |
"--amp", | |
"--fp16-init-scale", | |
"4096", | |
], | |
delta=0.011, | |
) | |
def test_mid_epoch_reproducibility(self): | |
self._test_reproducibility( | |
"test_mid_epoch_reproducibility", | |
["--save-interval-updates", "3"], | |
resume_checkpoint="checkpoint_1_3.pt", | |
max_epoch=1, | |
) | |
if __name__ == "__main__": | |
unittest.main() | |