File size: 2,781 Bytes
824afbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from dataclasses import dataclass
from typing import TypeVar, Generic, Type, Optional
from functools import wraps
import time
import random

import torch as T
import torch.nn as nn

# @TODO: remove si_module from codebase
# we use this in our research codebase to make modules from callable configs
si_module_TpV = TypeVar('si_module_TpV')
def si_module(cls: Type[si_module_TpV]) -> Type[si_module_TpV]:
    """Class decorator that pairs a module class with a callable dataclass config.

    The decorator mutates and returns ``cls``:

    * ``cls.Config`` (an empty class is created when it is missing or is not a
      class) is turned into a dataclass and then replaced by a subclass whose
      instances are callable: ``config(*args, **overrides)`` constructs ``cls``.
    * ``cls.__init__`` is wrapped so that ``self.c`` holds the first
      ``cls.Config`` instance found among the constructor arguments (``None``
      when there is none) before the original ``__init__`` runs. After it
      returns, an empty non-persistent buffer named ``_device_tracker`` is
      registered on the instance.
    * ``device`` and ``dtype`` properties are installed on ``cls``; they report
      the device and dtype of that buffer.

    Because ``register_buffer`` is called on every instance, ``cls`` is
    expected to be an ``nn.Module`` subclass whose ``__init__`` calls
    ``super().__init__()``.

    Args:
        cls: The class to decorate.

    Returns:
        The same class object, modified in place.
    """
    # Function-scope import: the module header only imports `dataclass`.
    from dataclasses import replace

    if not hasattr(cls, 'Config') or not isinstance(cls.Config, type):
        class Config:
            pass
        cls.Config = Config

    cls.Config = dataclass(cls.Config)

    class ConfigWrapper(cls.Config, Generic[si_module_TpV]):
        def __call__(self, *args, **kwargs) -> si_module_TpV:
            """Build ``cls`` from this config.

            Keyword arguments are per-call field overrides applied to a copy of
            the config (the receiver itself is never modified). Positional
            arguments are forwarded to ``cls`` after the config, with or
            without overrides.
            """
            # `replace` copies only real init fields, so configs declaring
            # ClassVar or init=False members can still be overridden.
            config = replace(self, **kwargs) if kwargs else self
            return cls(config, *args)

    ConfigWrapper.__module__ = cls.__module__
    ConfigWrapper.__name__ = f"{cls.__name__}Config"
    ConfigWrapper.__qualname__ = f"{cls.__qualname__}.Config"

    cls.Config = ConfigWrapper

    original_init = cls.__init__
    def new_init(self, *args, **kwargs):
        # Set self.c before the wrapped __init__ runs so that it can read the
        # config. Positional arguments are scanned first, then keyword values.
        self.c = next(
            (arg for arg in (*args, *kwargs.values()) if isinstance(arg, cls.Config)),
            None,
        )
        original_init(self, *args, **kwargs)
        # Registered after __init__ because register_buffer needs the module's
        # internal state to exist; persistent=False keeps it out of state_dict.
        self.register_buffer('_device_tracker', T.Tensor(), persistent=False)

    cls.__init__ = new_init

    @property
    def device(self):
        return self._device_tracker.device

    @property
    def dtype(self):
        return self._device_tracker.dtype

    cls.device = device
    cls.dtype = dtype

    return cls


def get_activation(nonlinear_activation, nonlinear_activation_params=None):
    """Instantiate a ``torch.nn`` layer by class name.

    Args:
        nonlinear_activation: Name of an ``nn.Module`` subclass exposed on
            ``torch.nn``, e.g. ``"ReLU"`` or ``"LeakyReLU"``.
        nonlinear_activation_params: Optional mapping of keyword arguments
            forwarded to the class constructor. ``None`` (the default) means
            no keyword arguments.

    Returns:
        A new instance of the requested class.

    Raises:
        NotImplementedError: If ``torch.nn`` has no module class by that name.
    """
    # Default is None rather than a dict literal shared across calls.
    params = {} if nonlinear_activation_params is None else nonlinear_activation_params

    # Only accept real module classes: names such as "functional" exist on
    # torch.nn but are not constructible layers.
    layer_cls = getattr(nn, nonlinear_activation, None)
    if not (isinstance(layer_cls, type) and issubclass(layer_cls, nn.Module)):
        raise NotImplementedError(f"Activation {nonlinear_activation} not found in torch.nn")
    return layer_cls(**params)


def exists(v):
    """Return True for any value other than ``None`` (falsy values included)."""
    if v is None:
        return False
    return True

def isnt(v):
    """Return True only when ``v`` is ``None``; the negation of ``exists``."""
    return v is None

def truthyexists(v):
    """Return True unless ``v`` is ``None`` or the ``False`` singleton.

    The check is by identity, so other falsy values such as ``0`` or ``""``
    still count as present.
    """
    if not exists(v):
        return False
    return v is not False

def truthyattr(obj, attr):
    """Return True when ``obj`` has attribute ``attr`` and its value passes ``truthyexists``.

    A missing attribute yields False without calling ``truthyexists``.
    """
    if not hasattr(obj, attr):
        return False
    value = getattr(obj, attr)
    return truthyexists(value)

defaultT = TypeVar('defaultT')

def default(*args: Optional[defaultT]) -> Optional[defaultT]:
    """Return the first argument that is not ``None``.

    Arguments are examined left to right. ``None`` is returned when no
    arguments are given or all of them are ``None``.
    """
    return next((candidate for candidate in args if exists(candidate)), None)

def maybe(fn):
    """Wrap ``fn`` so that a ``None`` first argument short-circuits the call.

    The returned wrapper hands its first argument back unchanged when it is
    ``None``; otherwise it calls ``fn`` with all arguments and returns the
    result. ``functools.wraps`` copies ``fn``'s metadata onto the wrapper.
    """
    @wraps(fn)
    def inner(x, *args, **kwargs):
        return fn(x, *args, **kwargs) if exists(x) else x
    return inner