# Uploaded by zhengr in commit c02bdcd ("init"); original file size 329 bytes.
import torch
class TorchSeedContext:
    """Context manager that temporarily seeds torch's random number generators.

    On entry the current RNG state is saved and ``torch.manual_seed(seed)`` is
    called. On exit the saved state is restored, so code outside the ``with``
    block is unaffected by both the seeding and any random draws made inside it.

    ``torch.manual_seed`` seeds the CPU generator *and* every CUDA device, so
    when CUDA is available the per-device CUDA RNG states are saved and
    restored as well. Otherwise the seed would leak into GPU randomness after
    the block exits.

    Example::

        with TorchSeedContext(0):
            x = torch.rand(3)  # reproducible across runs
    """

    def __init__(self, seed):
        # Value passed to torch.manual_seed on entry.
        self.seed = seed
        # CPU RNG state captured on entry; None until the context is entered.
        self.state = None
        # Per-device CUDA RNG states captured on entry, or None when CUDA is
        # not available.
        self.cuda_states = None

    def __enter__(self):
        self.state = torch.random.get_rng_state()
        # get_rng_state_all() initializes CUDA if needed; this mirrors what
        # torch.random.fork_rng does so the CUDA seeding can be undone later.
        self.cuda_states = (
            torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
        )
        torch.manual_seed(self.seed)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        torch.random.set_rng_state(self.state)
        if self.cuda_states is not None:
            torch.cuda.set_rng_state_all(self.cuda_states)
        # Do not suppress exceptions raised inside the with-block.
        return False