File size: 1,351 Bytes
d5175d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import signal
import time
import unittest
import torch
from torch import nn
from fairseq.distributed import DistributedTimeoutWrapper
class ModuleWithDelay(nn.Module):
    """Identity module whose forward pass blocks for a fixed time.

    Used to simulate a slow forward pass so the timeout wrapper can be
    exercised against it.

    Args:
        delay: number of seconds to sleep on every call to :meth:`forward`.
    """

    def __init__(self, delay):
        super(ModuleWithDelay, self).__init__()
        self.delay = delay

    def forward(self, x):
        """Sleep for ``self.delay`` seconds, then return ``x`` unchanged."""
        pause_seconds = self.delay
        time.sleep(pause_seconds)
        return x
class TestDistributedTimeoutWrapper(unittest.TestCase):
    """Tests for the forward-pass timeout of ``DistributedTimeoutWrapper``.

    Each test wraps a ``ModuleWithDelay`` whose forward pass sleeps for a
    fixed number of seconds. It then checks whether the wrapper interrupts
    the process with ``signal.SIGINT`` for the given timeout.
    """

    def setUp(self):
        # Silence all log output while a test runs; tearDown re-enables it.
        logging.disable(logging.CRITICAL)

    def tearDown(self):
        logging.disable(logging.NOTSET)

    def test_no_timeout(self):
        """With timeout 0, a 1-second forward pass must not be interrupted."""
        module = DistributedTimeoutWrapper(ModuleWithDelay(1), 0, signal.SIGINT)
        try:
            module(torch.rand(5))
        finally:
            # Stop the wrapper even if the forward pass fails. A wrapper left
            # running can deliver SIGINT later (see test_timeout_killed) and
            # would then interrupt an unrelated test.
            module.stop_timeout()

    def test_timeout_safe(self):
        """A 1-second forward pass finishes within a 10-second timeout."""
        module = DistributedTimeoutWrapper(ModuleWithDelay(1), 10, signal.SIGINT)
        try:
            module(torch.rand(5))
        finally:
            # Always stop the watchdog so it cannot fire into a later test.
            module.stop_timeout()

    @unittest.removeHandler
    def test_timeout_killed(self):
        """A 5-second forward pass exceeds a 1-second timeout and is interrupted.

        The assertion relies on SIGINT surfacing as ``KeyboardInterrupt``,
        which is Python's default SIGINT behaviour. ``unittest.removeHandler``
        temporarily removes unittest's own Control-C handler, installed by the
        ``-c``/``--catch`` option, for the duration of this test. That handler
        would otherwise swallow the first SIGINT instead of raising.
        """
        with self.assertRaises(KeyboardInterrupt):
            module = DistributedTimeoutWrapper(ModuleWithDelay(5), 1, signal.SIGINT)
            module(torch.rand(5))
            # Only reached if no interrupt arrived. Stop the watchdog so the
            # failing assertion is reported cleanly.
            module.stop_timeout()
if __name__ == "__main__":
    # Run the tests in this module when the file is executed directly.
    unittest.main(module="__main__")
|