""" | |
Implementation of time conditioned Transformer. | |
""" | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |


class PositionalEncoding(nn.Module):
    def __init__(self, d_hid, n_position=200):
        super(PositionalEncoding, self).__init__()
        # Not a parameter
        self.register_buffer(
            "pos_table", self._get_sinusoid_encoding_table(n_position, d_hid)
        )

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        """Sinusoid position encoding table"""
        # TODO: make it with torch instead of numpy
        def get_position_angle_vec(position):
            return [
                position / np.power(10000, 2 * (hid_j // 2) / d_hid)
                for hid_j in range(d_hid)
            ]

        sinusoid_table = np.array(
            [get_position_angle_vec(pos_i) for pos_i in range(n_position)]
        )
        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
        return torch.FloatTensor(sinusoid_table).unsqueeze(0)

    def forward(self, x):
        """
        Input:
            x: [B, N, D]
        """
        return x + self.pos_table[:, : x.size(1)].clone().detach()


class ConcatSquashLinear(nn.Module):
    """Context-conditioned linear layer: out = Linear(x) * sigmoid(W_g ctx) + W_b ctx."""

    def __init__(self, dim_in, dim_out, dim_ctx):
        super(ConcatSquashLinear, self).__init__()
        self._layer = nn.Linear(dim_in, dim_out)
        self._hyper_bias = nn.Linear(dim_ctx, dim_out, bias=False)
        self._hyper_gate = nn.Linear(dim_ctx, dim_out)

    def forward(self, ctx, x):
        # ctx must have the same number of dimensions as x so it can broadcast.
        assert ctx.dim() == x.dim()
        gate = torch.sigmoid(self._hyper_gate(ctx))
        bias = self._hyper_bias(ctx)
        ret = self._layer(x) * gate + bias
        return ret
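
# Illustrative usage sketch (hypothetical sizes): a 32-dim timestep embedding `ctx`
# of shape [B, 1, 32] conditions features `x` of shape [B, N, 128] via broadcasting.
#
#   csl = ConcatSquashLinear(dim_in=128, dim_out=256, dim_ctx=32)
#   out = csl(ctx=torch.randn(2, 1, 32), x=torch.randn(2, 100, 128))  # [2, 100, 256]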


class TimeMLP(nn.Module):
    """Two-layer MLP; when `use_time` is True, both layers are conditioned on a
    time/context embedding through ConcatSquashLinear."""

    def __init__(
        self,
        dim_in,
        dim_h,
        dim_out,
        dim_ctx=None,
        act=F.relu,
        dropout=0.0,
        use_time=False,
    ):
        super().__init__()
        self.act = act
        self.use_time = use_time
        dim_h = int(dim_h)
        if use_time:
            self.fc1 = ConcatSquashLinear(dim_in, dim_h, dim_ctx)
            self.fc2 = ConcatSquashLinear(dim_h, dim_out, dim_ctx)
        else:
            self.fc1 = nn.Linear(dim_in, dim_h)
            self.fc2 = nn.Linear(dim_h, dim_out)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, ctx=None):
        if self.use_time:
            x = self.fc1(x=x, ctx=ctx)
        else:
            x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        if self.use_time:
            x = self.fc2(x=x, ctx=ctx)
        else:
            x = self.fc2(x)
        x = self.dropout(x)
        return x
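
# Illustrative usage sketch (hypothetical sizes):
#
#   mlp = TimeMLP(dim_in=64, dim_h=128, dim_out=64, dim_ctx=32, use_time=True)
#   out = mlp(torch.randn(2, 16, 64), ctx=torch.randn(2, 1, 32))  # [2, 16, 64]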


class MultiHeadAttention(nn.Module):
    """Multi-head attention; `x` provides the queries, `y` the keys/values
    (self-attention when `y` is None, cross-attention otherwise)."""

    def __init__(self, dim_self, dim_ref, num_heads, dropout=0.0):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim_self // num_heads
        self.scale = head_dim**-0.5
        self.to_queries = nn.Linear(dim_self, dim_self)
        self.to_keys_values = nn.Linear(dim_ref, dim_self * 2)
        self.project = nn.Linear(dim_self, dim_self)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x,
        y=None,
        mask=None,
        alpha=None,
    ):
        y = y if y is not None else x
        b_a, n, c = x.shape
        b, m, d = y.shape
        # b n h dh
        queries = self.to_queries(x).reshape(
            b_a, n, self.num_heads, c // self.num_heads
        )
        # b m 2 h dh
        keys_values = self.to_keys_values(y).reshape(
            b, m, 2, self.num_heads, c // self.num_heads
        )
        keys, values = keys_values[:, :, 0], keys_values[:, :, 1]
        if alpha is not None:
            # NOTE: `forward_interpolation` is referenced but not defined in this file;
            # it must be provided elsewhere before calling with `alpha`.
            out, attention = self.forward_interpolation(
                queries, keys, values, alpha, mask
            )
        else:
            attention = torch.einsum("bnhd,bmhd->bnmh", queries, keys) * self.scale
            if mask is not None:
                if mask.dim() == 2:
                    mask = mask.unsqueeze(1)
                attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
            attention = attention.softmax(dim=2)
            attention = self.dropout(attention)
            out = torch.einsum("bnmh,bmhd->bnhd", attention, values).reshape(b, n, c)
        out = self.project(out)
        return out, attention
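
# Illustrative usage sketch (hypothetical shapes; attention weights are [B, N, M, H]):
#
#   attn = MultiHeadAttention(dim_self=64, dim_ref=64, num_heads=4)
#   x, y = torch.randn(2, 16, 64), torch.randn(2, 8, 64)
#   out, w = attn(x)       # self-attention:  out [2, 16, 64], w [2, 16, 16, 4]
#   out, w = attn(x, y)    # cross-attention: out [2, 16, 64], w [2, 16, 8, 4]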


class TimeTransformerEncoderLayer(nn.Module):
    """Post-norm self-attention block; its MLP is conditioned on `ctx` when `use_time`."""

    def __init__(
        self,
        dim_self,
        dim_ctx=None,
        num_heads=1,
        mlp_ratio=2.0,
        act=F.leaky_relu,
        dropout=0.0,
        use_time=True,
    ):
        super().__init__()
        self.use_time = use_time
        self.act = act
        self.attn = MultiHeadAttention(dim_self, dim_self, num_heads, dropout)
        self.attn_norm = nn.LayerNorm(dim_self)
        mlp_ratio = int(mlp_ratio)
        self.mlp = TimeMLP(
            dim_self, dim_self * mlp_ratio, dim_self, dim_ctx, use_time=use_time
        )
        self.norm = nn.LayerNorm(dim_self)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, ctx=None):
        res = x
        x, attn = self.attn(x)
        x = self.attn_norm(x + res)
        res = x
        x = self.mlp(x, ctx=ctx)
        x = self.norm(x + res)
        return x, attn


class TimeTransformerDecoderLayer(TimeTransformerEncoderLayer):
    """Decoder block: self-attention, cross-attention to `y`, then the (time-)conditioned MLP."""

    def __init__(
        self,
        dim_self,
        dim_ref,
        dim_ctx=None,
        num_heads=1,
        mlp_ratio=2,
        act=F.leaky_relu,
        dropout=0.0,
        use_time=True,
    ):
        super().__init__(
            dim_self=dim_self,
            dim_ctx=dim_ctx,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            act=act,
            dropout=dropout,
            use_time=use_time,
        )
        self.cross_attn = MultiHeadAttention(dim_self, dim_ref, num_heads, dropout)
        self.cross_attn_norm = nn.LayerNorm(dim_self)

    def forward(self, x, y, ctx=None):
        res = x
        x, attn = self.attn(x)
        x = self.attn_norm(x + res)
        res = x
        # The returned `attn` is the cross-attention map (it overwrites the self-attention one).
        x, attn = self.cross_attn(x, y)
        x = self.cross_attn_norm(x + res)
        res = x
        x = self.mlp(x, ctx=ctx)
        x = self.norm(x + res)
        return x, attn


class TimeTransformerEncoder(nn.Module):
    def __init__(
        self,
        dim_self,
        dim_ctx=None,
        num_heads=1,
        mlp_ratio=2.0,
        act=F.leaky_relu,
        dropout=0.0,
        use_time=True,
        num_layers=3,
        last_fc=False,
        last_fc_dim_out=None,
    ):
        super().__init__()
        self.last_fc = last_fc
        if last_fc:
            self.fc = nn.Linear(dim_self, last_fc_dim_out)
        self.layers = nn.ModuleList(
            [
                TimeTransformerEncoderLayer(
                    dim_self,
                    dim_ctx=dim_ctx,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    act=act,
                    dropout=dropout,
                    use_time=use_time,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(self, x, ctx=None):
        for layer in self.layers:
            x, _ = layer(x, ctx=ctx)
        if self.last_fc:
            x = self.fc(x)
        return x


class TimeTransformerDecoder(nn.Module):
    def __init__(
        self,
        dim_self,
        dim_ref,
        dim_ctx=None,
        num_heads=1,
        mlp_ratio=2.0,
        act=F.leaky_relu,
        dropout=0.0,
        use_time=True,
        num_layers=3,
        last_fc=True,
        last_fc_dim_out=None,
    ):
        super().__init__()
        self.last_fc = last_fc
        if last_fc:
            self.fc = nn.Linear(dim_self, last_fc_dim_out)
        self.layers = nn.ModuleList(
            [
                TimeTransformerDecoderLayer(
                    dim_self,
                    dim_ref,
                    dim_ctx,
                    num_heads,
                    mlp_ratio,
                    act,
                    dropout,
                    use_time,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(self, x, y, ctx=None):
        for layer in self.layers:
            x, _ = layer(x, y=y, ctx=ctx)
        if self.last_fc:
            x = self.fc(x)
        return x
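

if __name__ == "__main__":
    # Minimal smoke test sketching how these modules compose. The sizes below and the
    # decoder output width (3) are illustrative assumptions, not values from this module.
    B, N, M, D, D_CTX = 2, 16, 8, 64, 32
    x = torch.randn(B, N, D)        # query tokens
    y = torch.randn(B, M, D)        # reference tokens attended to by the decoder
    ctx = torch.randn(B, 1, D_CTX)  # e.g. an embedded diffusion timestep, broadcast over N

    x = PositionalEncoding(D)(x)
    encoder = TimeTransformerEncoder(D, dim_ctx=D_CTX, num_heads=4, use_time=True)
    decoder = TimeTransformerDecoder(
        D, D, dim_ctx=D_CTX, num_heads=4, use_time=True, last_fc=True, last_fc_dim_out=3
    )
    h = encoder(x, ctx=ctx)     # [B, N, D]
    out = decoder(h, y, ctx=ctx)  # [B, N, 3]
    print(h.shape, out.shape)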