import torch
class TorchSeedContext:
    """Context manager that temporarily seeds torch's global RNGs.

    On entry the current RNG state is saved and ``torch.manual_seed(seed)``
    is applied. On exit the saved state is restored, so code after the
    ``with`` block continues the random stream it had before entering.

    ``torch.manual_seed`` seeds the CPU generator *and* every CUDA device,
    so when CUDA is available the per-device CUDA generator states are
    saved and restored as well. Otherwise the forced seed would leak into
    GPU randomness after the block ends.

    Args:
        seed: Seed passed to ``torch.manual_seed`` on entry.
    """

    def __init__(self, seed):
        self.seed = seed
        # CPU generator state captured on entry; None until __enter__ runs.
        self.state = None
        # Per-device CUDA generator states captured on entry; None when CUDA
        # is unavailable or before __enter__ runs.
        self.cuda_states = None

    def __enter__(self):
        """Save the current RNG state(s), apply the seed, and return self."""
        self.state = torch.random.get_rng_state()
        if torch.cuda.is_available():
            self.cuda_states = torch.cuda.get_rng_state_all()
        else:
            self.cuda_states = None
        torch.manual_seed(self.seed)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Restore the RNG state(s) saved by ``__enter__``.

        Returns None, so an exception raised inside the block propagates
        after the state has been restored.
        """
        torch.random.set_rng_state(self.state)
        if self.cuda_states is not None:
            torch.cuda.set_rng_state_all(self.cuda_states)