# Spaces:
# Runtime error
# Runtime error
import torch | |
import numpy as np | |
import random | |
def worker_init_fn(wid):
    """Seed torch, NumPy's global RNG and ``random`` for one worker.

    A ``numpy.random.SeedSequence`` is built from this process's
    ``torch.initial_seed()`` together with ``wid``. Three children are then
    spawned from it, one per library, in a fixed order: torch first, then
    NumPy's legacy global RNG, then Python's ``random`` module.

    Args:
        wid: worker index. Presumably the id that
            ``torch.utils.data.DataLoader`` passes to ``worker_init_fn``;
            verify against the caller.
    """
    root = np.random.SeedSequence([torch.initial_seed(), wid])

    # (seeding function, form of seed it expects). spawn_get spawns a fresh
    # child of `root` on every call, so each library gets a distinct seed.
    # The tuple order fixes the spawn order, which reproducibility relies on.
    seeders = (
        (torch.random.manual_seed, int),
        (np.random.seed, np.ndarray),
        (random.seed, int),
    )
    for apply_seed, seed_kind in seeders:
        apply_seed(spawn_get(root, 2, dtype=seed_kind))
def spawn_get(seedseq, n_entropy, dtype):
    """Spawn one child of ``seedseq`` and return its state as ``dtype``.

    Every call spawns a *new* child, because ``SeedSequence.spawn`` advances
    the parent's spawn counter. Repeated calls on the same parent therefore
    yield different values.

    Args:
        seedseq: parent ``numpy.random.SeedSequence``. It is mutated by the
            spawn.
        n_entropy: number of 32-bit words to draw from the child.
        dtype: either ``np.ndarray``, to get the raw ``uint32`` words, or
            ``int``, to get those words packed little-endian into a single
            non-negative Python int in ``[0, 2 ** (32 * n_entropy))``.

    Returns:
        A ``numpy.ndarray`` of ``uint32`` words, or a Python ``int``,
        depending on ``dtype``.

    Raises:
        ValueError: if ``dtype`` is neither ``np.ndarray`` nor ``int``.
    """
    child = seedseq.spawn(1)[0]
    state = child.generate_state(n_entropy, dtype=np.uint32)
    if dtype == np.ndarray:
        return state
    elif dtype == int:
        # Convert each word to an arbitrary-precision Python int *before*
        # shifting. Multiplying a NumPy uint32 scalar by 2**32 runs in
        # fixed-width NumPy arithmetic: NumPy >= 2 raises OverflowError, and
        # older versions can wrap around and silently corrupt the seed.
        return sum(int(word) << (32 * shift) for shift, word in enumerate(state))
    else:
        raise ValueError(f'not a valid dtype "{dtype}"')