File size: 3,098 Bytes
d2b7e94
01e655b
d2b7e94
01e655b
d2b7e94
 
ec6a7d0
374f426
 
01e655b
 
f83b1b7
01e655b
 
ec6a7d0
 
01e655b
ec6a7d0
f83b1b7
 
 
 
01e655b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f83b1b7
01e655b
 
 
b44532e
01e655b
b44532e
01e655b
 
f83b1b7
01e655b
 
 
 
 
 
 
 
 
 
 
 
f83b1b7
 
 
 
 
 
 
01e655b
0129fb6
f83b1b7
0129fb6
374f426
 
 
 
 
0129fb6
01e655b
 
 
 
 
f83b1b7
 
01e655b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import logging
import random

import numpy as np
import torch

from modules.utils import rng

logger = logging.getLogger(__name__)


def deterministic(seed=0, cudnn_deterministic=False):
    """Seed Python's ``random``, NumPy and PyTorch from one seed value.

    The stdlib and NumPy generators are seeded with ``seed`` directly; the
    value handed to PyTorch is first passed through
    ``rng.convert_np_to_torch`` (project helper, its mapping is not visible
    from this module).

    Args:
        seed: Seed value applied to every generator.
        cudnn_deterministic: When true *and* CUDA is available, switch cuDNN
            to deterministic mode and turn its benchmark mode off.
    """
    # Order matters: a failure in an earlier call leaves the later
    # generators unseeded, exactly as callers currently observe.
    random.seed(seed)
    np.random.seed(seed)

    torch_seed = rng.convert_np_to_torch(seed)
    torch.manual_seed(torch_seed)

    # Nothing CUDA/cuDNN related is touched on CPU-only setups.
    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(torch_seed)
    if cudnn_deterministic:
        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False


def is_numeric(obj):
    """Return whether ``obj`` is a number or a string that parses as one.

    Strings count as numeric when ``float(obj)`` succeeds. Any other object
    is numeric only if it is a Python ``int``/``float`` or a NumPy integer or
    floating scalar.
    """
    if not isinstance(obj, str):
        # np.integer already covers the signed and unsigned scalar types.
        return isinstance(obj, (np.integer, np.floating, int, float))

    try:
        float(obj)
    except ValueError:
        return False
    return True


class SeedContext:
    """Context manager that seeds the global RNGs and restores them on exit.

    On entry it saves the current ``torch`` (``torch.get_rng_state``),
    ``random`` and ``numpy`` generator states together with the two cuDNN
    backend flags, then calls ``deterministic`` with the normalized seed.
    On exit every saved value is put back.

    NOTE(review): no CUDA generator state is saved here, so CUDA seeding
    performed by ``deterministic`` is not undone on exit.

    Args:
        seed: A number or numeric string. It is converted to ``int`` and
            clipped to ``[-1, 2**32 - 1]``; a resulting ``-1`` (which
            includes any more-negative input) selects a random seed.
        cudnn_deterministic: Forwarded unchanged to ``deterministic``.

    Raises:
        ValueError: If ``seed`` is not numeric or cannot be converted to an
            integer (e.g. ``"12.34"``, ``nan``, ``inf``).
    """

    # Upper bound of the seed range; matches the limit of ``np.random.seed``.
    _MAX_SEED = 2**32 - 1

    def __init__(self, seed, cudnn_deterministic=False):
        # Explicit raise instead of ``assert`` so validation survives ``-O``.
        if not is_numeric(seed):
            raise ValueError(f"Seed must be a number, but: {type(seed)}")

        value = self._normalize_seed(seed)
        if value == -1:
            # -1 is the "pick one for me" sentinel.
            value = random.randint(0, self._MAX_SEED)

        self.seed = value
        self.cudnn_deterministic = cudnn_deterministic
        self.state = None

    @staticmethod
    def _normalize_seed(seed):
        """Convert ``seed`` to an ``int`` clipped to ``[-1, 2**32 - 1]``."""
        try:
            value = int(seed)
        except (TypeError, ValueError, OverflowError) as err:
            raise ValueError(
                f"Seed must be an integer, but: {type(seed)}"
            ) from err
        # Plain-Python clip: keeps the result a builtin int of any magnitude.
        return max(-1, min(value, SeedContext._MAX_SEED))

    def __enter__(self):
        self.state = (
            torch.get_rng_state(),
            random.getstate(),
            np.random.get_state(),
            torch.backends.cudnn.deterministic,
            torch.backends.cudnn.benchmark,
        )

        try:
            deterministic(self.seed, cudnn_deterministic=self.cudnn_deterministic)
        except Exception:
            # Seeding is best-effort: keep running, but record the cause.
            logger.warning(
                f"Deterministic field, with: <{type(self.seed)}> {self.seed}",
                exc_info=True,
            )

        # Allow ``with SeedContext(...) as ctx`` to expose the chosen seed.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        saved_torch, saved_random, saved_numpy, saved_det, saved_bench = self.state
        torch.set_rng_state(saved_torch)
        random.setstate(saved_random)
        np.random.set_state(saved_numpy)
        torch.backends.cudnn.deterministic = saved_det
        torch.backends.cudnn.benchmark = saved_bench


if __name__ == "__main__":
    # Manual smoke check for is_numeric; the expected result is noted
    # next to each sample and the outputs are printed in this order.
    demo_inputs = (
        "1234",  # True
        "12.34",  # True
        "-1234",  # True
        "abc123",  # False
        np.int32(10),  # True
        np.float64(10.5),  # True
        10,  # True
        10.5,  # True
        np.int8(10),  # True
        np.uint64(10),  # True
        np.float16(10.5),  # True
        [1, 2, 3],  # False
    )
    for sample in demo_inputs:
        print(is_numeric(sample))