# salad/spaghetti/models/models_utils.py
# The wildcard import is expected to provide torch, nn, np, T (= torch.Tensor),
# and the typing aliases (Union, List, Tuple, Callable, ...) used below.
from ..custom_types import *
from abc import ABC
import math


def torch_no_grad(func):
    """Decorator that evaluates the wrapped function inside ``torch.no_grad()``."""

    def wrapper(*args, **kwargs):
        with torch.no_grad():
            result = func(*args, **kwargs)
        return result

    return wrapper
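

# Illustrative usage sketch (not part of the original module); ``Encoder`` and
# ``encode`` are hypothetical names used only to show where the decorator fits:
#
#     class Encoder(nn.Module):
#         @torch_no_grad
#         def encode(self, x):
#             # gradients are disabled for everything inside this call
#             return self.net(x)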


class Model(nn.Module, ABC):
    def __init__(self):
        super(Model, self).__init__()
        # Optional hook set externally; called with the model itself plus any
        # keyword arguments passed to ``save``.
        self.save_model: Union[None, Callable[..., None]] = None

    def save(self, **kwargs):
        self.save_model(self, **kwargs)
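

# Illustrative sketch (not part of the original module): ``save_model`` is meant
# to be injected from outside (e.g. by a trainer); ``SomeModel`` and the lambda
# below are hypothetical.
#
#     net = SomeModel()                                  # a Model subclass
#     net.save_model = lambda model, **kw: torch.save(model.state_dict(), kw["path"])
#     net.save(path="checkpoint.pt")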


class Concatenate(nn.Module):
    """Concatenates a list/tuple of tensors along a fixed dimension."""

    def __init__(self, dim):
        super(Concatenate, self).__init__()
        self.dim = dim

    def forward(self, x):
        return torch.cat(x, dim=self.dim)


class View(nn.Module):
    """Reshapes the input to a fixed target shape via ``Tensor.view``."""

    def __init__(self, *shape):
        super(View, self).__init__()
        self.shape = shape

    def forward(self, x):
        return x.view(*self.shape)


class Transpose(nn.Module):
    """Swaps two fixed dimensions of the input tensor."""

    def __init__(self, dim0, dim1):
        super(Transpose, self).__init__()
        self.dim0, self.dim1 = dim0, dim1

    def forward(self, x):
        return x.transpose(self.dim0, self.dim1)


class Dummy(nn.Module):
    """Identity placeholder: ignores extra arguments and returns the first one."""

    def __init__(self, *args):
        super(Dummy, self).__init__()

    def forward(self, *args):
        return args[0]
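

# Illustrative sketch (not part of the original module): these small wrappers let
# tensor ops be composed inside nn.Sequential; the sizes below are arbitrary.
#
#     head = nn.Sequential(
#         nn.Linear(128, 64),
#         View(-1, 8, 8),          # (B, 64) -> (B, 8, 8)
#         Transpose(1, 2),         # swap the last two axes
#     )
#     y = head(torch.rand(10, 128))   # -> shape (10, 8, 8)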


class SineLayer(nn.Module):
    """
    From the siren repository
    https://colab.research.google.com/github/vsitzmann/siren/blob/master/explore_siren.ipynb
    """

    def __init__(self, in_features, out_features, bias=True,
                 is_first=False, omega_0=30):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        self.in_features = in_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.output_channels = out_features
        self.init_weights()

    def init_weights(self):
        # SIREN initialization: a wider uniform range for the first layer,
        # a 1/omega_0-scaled range for the hidden layers.
        with torch.no_grad():
            if self.is_first:
                self.linear.weight.uniform_(-1 / self.in_features,
                                            1 / self.in_features)
            else:
                self.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0,
                                            np.sqrt(6 / self.in_features) / self.omega_0)

    def forward(self, input):
        return torch.sin(self.omega_0 * self.linear(input))
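

# Illustrative sketch (not part of the original module): SineLayer blocks can be
# chained into a small SIREN-style network; the layer sizes here are arbitrary.
#
#     siren = nn.Sequential(
#         SineLayer(3, 256, is_first=True),
#         SineLayer(256, 256),
#         nn.Linear(256, 1),
#     )
#     values = siren(torch.rand(1024, 3))   # -> shape (1024, 1)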


class MLP(nn.Module):
    def __init__(self, ch: Union[List[int], Tuple[int, ...]], act: nn.Module = nn.ReLU,
                 weight_norm=False):
        super(MLP, self).__init__()
        layers = []
        for i in range(len(ch) - 1):
            layers.append(nn.Linear(ch[i], ch[i + 1]))
            if weight_norm:
                layers[-1] = nn.utils.weight_norm(layers[-1])
            if i < len(ch) - 2:
                # ``act`` is an activation class; True is passed as its ``inplace``
                # flag (valid for nn.ReLU and similar activations).
                layers.append(act(True))
        self.net = nn.Sequential(*layers)

    def forward(self, x, *_):
        return self.net(x)
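

# Illustrative usage sketch (not part of the original module); channel sizes
# below are arbitrary:
#
#     mlp = MLP([512, 256, 128], weight_norm=True)
#     h = mlp(torch.rand(8, 512))   # -> shape (8, 128)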


class GMAttend(nn.Module):
    """Self-attention across the token axis of a (batch, groups, tokens, hidden_dim)
    input, with a zero-initialized residual gate."""

    def __init__(self, hidden_dim: int):
        super(GMAttend, self).__init__()
        self.key_dim = hidden_dim // 8
        self.query_w = nn.Linear(hidden_dim, self.key_dim)
        self.key_w = nn.Linear(hidden_dim, self.key_dim)
        self.value_w = nn.Linear(hidden_dim, hidden_dim)
        self.softmax = nn.Softmax(dim=3)
        # Learnable residual gate; zero init means the block starts as the identity.
        self.gamma = nn.Parameter(torch.zeros(1))
        # Note: defined but not applied in forward (attention logits are unscaled).
        self.scale = 1 / torch.sqrt(torch.tensor(self.key_dim, dtype=torch.float32))

    def forward(self, x):
        queries = self.query_w(x)
        keys = self.key_w(x)
        vals = self.value_w(x)
        attention = self.softmax(torch.einsum('bgqf,bgkf->bgqk', queries, keys))
        out = torch.einsum('bgvf,bgqv->bgqf', vals, attention)
        out = self.gamma * out + x
        return out
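

# Illustrative usage sketch (not part of the original module); shapes are arbitrary:
#
#     attend = GMAttend(hidden_dim=512)
#     z = torch.rand(4, 2, 16, 512)      # (batch, groups, tokens, hidden_dim)
#     z = attend(z)                      # -> same shape, (4, 2, 16, 512)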


def recursive_to(item, device):
    """Moves tensors nested inside lists/tuples to ``device``; other objects pass
    through unchanged. Note that tuples come back as lists, mirroring the original
    behavior.
    """
    if type(item) is T:
        return item.to(device)
    elif type(item) is tuple or type(item) is list:
        return [recursive_to(item[i], device) for i in range(len(item))]
    return item
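

if __name__ == "__main__":
    # Illustrative smoke test (not part of the original module); all shapes and
    # sizes below are arbitrary assumptions chosen only to exercise the helpers.
    mlp = MLP([16, 32, 8])
    print("MLP:", mlp(torch.rand(4, 16)).shape)               # torch.Size([4, 8])

    attend = GMAttend(hidden_dim=16)
    tokens = torch.rand(2, 3, 5, 16)                          # (batch, groups, tokens, hidden)
    print("GMAttend:", attend(tokens).shape)                  # torch.Size([2, 3, 5, 16])

    nested = [torch.zeros(2), (torch.ones(3), "untouched")]
    print("recursive_to:", recursive_to(nested, "cpu"))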