import torch


class TorchSeedContext:
    """Temporarily seed torch's RNGs and restore the previous state on exit.

    Inside the ``with`` block ``torch.manual_seed(seed)`` is in effect, so random
    draws are reproducible.

    On exit, normal or via an exception, two states are restored:

    * the CPU generator state captured on entry;
    * the per-device CUDA generator states, if CUDA was already initialized on
      entry.

    Exceptions raised inside the block are never suppressed.

    Limitation: if CUDA is not initialized on entry, the CUDA part of
    ``torch.manual_seed`` is deferred by torch and is not undone on exit.

    Args:
        seed: Value passed to ``torch.manual_seed``.

    Example:
        >>> with TorchSeedContext(0) as ctx:
        ...     x = torch.rand(3)
    """

    def __init__(self, seed):
        self.seed = seed
        # CPU RNG state captured by __enter__; None until the context is entered.
        self.state = None
        # CUDA RNG states (one per device) captured by __enter__, or None when
        # CUDA was not initialized at entry.
        self.cuda_state = None

    def __enter__(self):
        self.state = torch.random.get_rng_state()
        # torch.manual_seed also seeds every CUDA device, so those states must
        # be saved as well. Only query them when CUDA is already initialized:
        # get_rng_state_all() would otherwise force CUDA initialization.
        if torch.cuda.is_available() and torch.cuda.is_initialized():
            self.cuda_state = torch.cuda.get_rng_state_all()
        else:
            self.cuda_state = None
        torch.manual_seed(self.seed)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        torch.random.set_rng_state(self.state)
        if self.cuda_state is not None:
            torch.cuda.set_rng_state_all(self.cuda_state)
        # A falsy return lets any exception raised in the block propagate.
        return False