Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import random | |
import numpy as np | |
from modules.utils import rng | |
import logging | |
logger = logging.getLogger(__name__) | |
def deterministic(seed=0):
    """Seed the Python, NumPy and PyTorch random number generators.

    ``seed`` is handed to ``random.seed`` and ``np.random.seed`` unchanged.
    The value used for PyTorch is whatever ``rng.convert_np_to_torch(seed)``
    returns.  When CUDA is available, the CUDA generators are seeded with
    that same value and cuDNN is put into deterministic, non-benchmark mode.
    """
    random.seed(seed)
    np.random.seed(seed)

    torch_seed = rng.convert_np_to_torch(seed)
    torch.manual_seed(torch_seed)

    if not torch.cuda.is_available():
        return

    # NOTE(review): the indentation of this file was lost in extraction; the
    # cuDNN flags are assumed to belong to the CUDA-only branch - confirm.
    torch.cuda.manual_seed_all(torch_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def is_numeric(obj):
    """Tell whether ``obj`` is a number or a string that parses as one.

    Returns ``True`` for Python ``int``/``float`` values, NumPy integer and
    floating scalars, and strings accepted by ``float()``; ``False`` for
    everything else.
    """
    if not isinstance(obj, str):
        # np.signedinteger and np.unsignedinteger are both subclasses of
        # np.integer, so a single np.integer check covers all three.
        return isinstance(obj, (int, float, np.integer, np.floating))

    try:
        float(obj)
    except ValueError:
        return False
    return True
class SeedContext:
    """Context manager that seeds the RNGs on entry and restores them on exit.

    On ``__enter__`` the current ``torch``, ``random`` and NumPy generator
    states are captured and ``deterministic(self.seed)`` is called.  On
    ``__exit__`` exactly those three captured states are put back; anything
    else ``deterministic`` changes is left as it is.

    Args:
        seed: A number, or a string holding an integer.  Floats are truncated
            with ``int()``.  The value is clamped to ``[-1, 2**32 - 1]``, and
            ``-1`` (including anything clamped to it) means "pick a random
            seed".

    Raises:
        ValueError: If ``seed`` is not numeric or cannot be converted to an
            integer (for example ``"12.5"``, ``nan`` or ``inf``).
    """

    # Inclusive upper bound for seeds; the same bound the random "-1" seed is
    # drawn from.
    _MAX_SEED = 2**32 - 1

    def __init__(self, seed):
        # Explicit raise instead of ``assert``: assertions are removed when
        # Python runs with -O, which would silently disable this check.
        if not is_numeric(seed):
            raise ValueError("Seed must be a number.")

        # Keep the normalised integer.  It must not be overwritten with the
        # raw input afterwards, otherwise the clamping has no effect.
        self.seed = self._coerce_seed(seed)
        self.state = None

        if self.seed == -1:
            self.seed = random.randint(0, self._MAX_SEED)

    @staticmethod
    def _coerce_seed(seed):
        """Convert ``seed`` to a Python ``int`` clamped to ``[-1, _MAX_SEED]``."""
        try:
            as_int = int(seed)
        except (TypeError, ValueError, OverflowError) as err:
            raise ValueError(
                f"Seed must be an integer, but: {type(seed)}"
            ) from err
        # Pure-Python clamp: works for arbitrarily large ints, unlike a
        # fixed-width integer clip.
        return max(-1, min(as_int, SeedContext._MAX_SEED))

    def __enter__(self):
        self.state = (
            torch.get_rng_state(),
            random.getstate(),
            np.random.get_state(),
        )
        try:
            deterministic(self.seed)
        except Exception:
            # Deliberately best-effort: a seeding failure must not abort the
            # caller's work, but the traceback is logged so it is visible.
            logger.warning(
                "Deterministic failed, with: <%s> %s",
                type(self.seed),
                self.seed,
                exc_info=True,
            )
        # Returning self makes ``with SeedContext(x) as ctx`` usable.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        saved_torch, saved_python, saved_numpy = self.state
        torch.set_rng_state(saved_torch)
        random.setstate(saved_python)
        np.random.set_state(saved_numpy)
if __name__ == "__main__":
    # Manual smoke check: print the is_numeric() verdict for each sample, in
    # order.  The trailing comment on each sample is the expected output.
    samples = (
        "1234",  # True
        "12.34",  # True
        "-1234",  # True
        "abc123",  # False
        np.int32(10),  # True
        np.float64(10.5),  # True
        10,  # True
        10.5,  # True
        np.int8(10),  # True
        np.uint64(10),  # True
        np.float16(10.5),  # True
        [1, 2, 3],  # False
    )
    for sample in samples:
        print(is_numeric(sample))