Spaces:
Running
Running
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import logging | |
import signal | |
import time | |
import unittest | |
import torch | |
from torch import nn | |
from fairseq.distributed import DistributedTimeoutWrapper | |
class ModuleWithDelay(nn.Module): | |
def __init__(self, delay): | |
super().__init__() | |
self.delay = delay | |
def forward(self, x): | |
time.sleep(self.delay) | |
return x | |
class TestDistributedTimeoutWrapper(unittest.TestCase): | |
def setUp(self): | |
logging.disable(logging.CRITICAL) | |
def tearDown(self): | |
logging.disable(logging.NOTSET) | |
def test_no_timeout(self): | |
module = DistributedTimeoutWrapper(ModuleWithDelay(1), 0, signal.SIGINT) | |
module(torch.rand(5)) | |
module.stop_timeout() | |
def test_timeout_safe(self): | |
module = DistributedTimeoutWrapper(ModuleWithDelay(1), 10, signal.SIGINT) | |
module(torch.rand(5)) | |
module.stop_timeout() | |
def test_timeout_killed(self): | |
with self.assertRaises(KeyboardInterrupt): | |
module = DistributedTimeoutWrapper(ModuleWithDelay(5), 1, signal.SIGINT) | |
module(torch.rand(5)) | |
module.stop_timeout() | |
if __name__ == "__main__": | |
unittest.main() | |