"""
Implementation of a time-conditioned Transformer.
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class PositionalEncoding(nn.Module):
def __init__(self, d_hid, n_position=200):
super(PositionalEncoding, self).__init__()
# Not a parameter
self.register_buffer(
"pos_table", self._get_sinusoid_encoding_table(n_position, d_hid)
)
def _get_sinusoid_encoding_table(self, n_position, d_hid):
"""Sinusoid position encoding table"""
# TODO: make it with torch instead of numpy
def get_position_angle_vec(position):
return [
position / np.power(10000, 2 * (hid_j // 2) / d_hid)
for hid_j in range(d_hid)
]
sinusoid_table = np.array(
[get_position_angle_vec(pos_i) for pos_i in range(n_position)]
)
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
def forward(self, x):
"""
Input:
x: [B,N,D]
"""
return x + self.pos_table[:, : x.size(1)].clone().detach()
class ConcatSquashLinear(nn.Module):
    """Context-conditioned linear layer: a plain linear map of the input is
    gated and shifted by a (time) context vector,
    out = Linear(x) * sigmoid(W_gate(ctx)) + W_bias(ctx).
    """
    def __init__(self, dim_in, dim_out, dim_ctx):
        super(ConcatSquashLinear, self).__init__()
        self._layer = nn.Linear(dim_in, dim_out)
        self._hyper_bias = nn.Linear(dim_ctx, dim_out, bias=False)
        self._hyper_gate = nn.Linear(dim_ctx, dim_out)
def forward(self, ctx, x):
assert ctx.dim() == x.dim()
gate = torch.sigmoid(self._hyper_gate(ctx))
bias = self._hyper_bias(ctx)
ret = self._layer(x) * gate + bias
return ret
class TimeMLP(nn.Module):
def __init__(
self,
dim_in,
dim_h,
dim_out,
dim_ctx=None,
act=F.relu,
dropout=0.0,
use_time=False,
):
super().__init__()
self.act = act
self.use_time = use_time
dim_h = int(dim_h)
if use_time:
self.fc1 = ConcatSquashLinear(dim_in, dim_h, dim_ctx)
self.fc2 = ConcatSquashLinear(dim_h, dim_out, dim_ctx)
else:
self.fc1 = nn.Linear(dim_in, dim_h)
self.fc2 = nn.Linear(dim_h, dim_out)
self.dropout = nn.Dropout(dropout)
def forward(self, x, ctx=None):
if self.use_time:
x = self.fc1(x=x, ctx=ctx)
else:
x = self.fc1(x)
x = self.act(x)
x = self.dropout(x)
if self.use_time:
x = self.fc2(x=x, ctx=ctx)
else:
x = self.fc2(x)
x = self.dropout(x)
return x
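# Illustrative usage of TimeMLP (a sketch; the shapes and values below are
# assumptions for illustration, not taken from the original module):
#   mlp = TimeMLP(dim_in=64, dim_h=128, dim_out=64, dim_ctx=3, use_time=True)
#   x = torch.randn(8, 16, 64)   # token features [B, N, D]
#   ctx = torch.randn(8, 1, 3)   # time embedding [B, 1, dim_ctx], broadcast over N
#   out = mlp(x, ctx=ctx)        # -> [8, 16, 64]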
class MultiHeadAttention(nn.Module):
def __init__(self, dim_self, dim_ref, num_heads, dropout=0.0):
super().__init__()
self.num_heads = num_heads
head_dim = dim_self // num_heads
self.scale = head_dim**-0.5
self.to_queries = nn.Linear(dim_self, dim_self)
self.to_keys_values = nn.Linear(dim_ref, dim_self * 2)
self.project = nn.Linear(dim_self, dim_self)
self.dropout = nn.Dropout(dropout)
def forward(
self,
x,
y=None,
mask=None,
alpha=None,
):
y = y if y is not None else x
b_a, n, c = x.shape
b, m, d = y.shape
# b n h dh
queries = self.to_queries(x).reshape(
b_a, n, self.num_heads, c // self.num_heads
)
# b m 2 h dh
keys_values = self.to_keys_values(y).reshape(
b, m, 2, self.num_heads, c // self.num_heads
)
keys, values = keys_values[:, :, 0], keys_values[:, :, 1]
if alpha is not None:
out, attention = self.forward_interpolation(
queries, keys, values, alpha, mask
)
else:
attention = torch.einsum("bnhd,bmhd->bnmh", queries, keys) * self.scale
if mask is not None:
if mask.dim() == 2:
mask = mask.unsqueeze(1)
attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
attention = attention.softmax(dim=2)
attention = self.dropout(attention)
out = torch.einsum("bnmh,bmhd->bnhd", attention, values).reshape(b, n, c)
out = self.project(out)
return out, attention
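    def forward_interpolation(self, queries, keys, values, alpha, mask=None):
        # `forward` above calls this method, but its body is not present in
        # this file. The version below is a hedged sketch, assuming `alpha`
        # holds per-batch blend weights (shape [B]) and that a single query
        # set (batch index 0) attends to every reference in the batch, with
        # the resulting attention maps mixed by `alpha`.
        attention = torch.einsum("nhd,bmhd->bnmh", queries[0], keys) * self.scale
        if mask is not None:
            if mask.dim() == 2:
                mask = mask.unsqueeze(1)
            attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
        attention = attention.softmax(dim=2)
        # Mix attention across the batch with the interpolation weights.
        attention = attention * alpha[:, None, None, None]
        # Sum over both the batch and reference-token axes, then merge heads.
        out = torch.einsum("bnmh,bmhd->nhd", attention, values).reshape(
            1, queries.shape[1], -1
        )
        return out, attention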
class TimeTransformerEncoderLayer(nn.Module):
def __init__(
self,
dim_self,
dim_ctx=None,
num_heads=1,
mlp_ratio=2.0,
act=F.leaky_relu,
dropout=0.0,
use_time=True,
):
super().__init__()
self.use_time = use_time
self.act = act
self.attn = MultiHeadAttention(dim_self, dim_self, num_heads, dropout)
self.attn_norm = nn.LayerNorm(dim_self)
mlp_ratio = int(mlp_ratio)
        self.mlp = TimeMLP(
            dim_self,
            dim_self * mlp_ratio,
            dim_self,
            dim_ctx,
            act=act,
            dropout=dropout,
            use_time=use_time,
        )
self.norm = nn.LayerNorm(dim_self)
self.dropout = nn.Dropout(dropout)
def forward(self, x, ctx=None):
res = x
x, attn = self.attn(x)
x = self.attn_norm(x + res)
res = x
x = self.mlp(x, ctx=ctx)
x = self.norm(x + res)
return x, attn
class TimeTransformerDecoderLayer(TimeTransformerEncoderLayer):
def __init__(
self,
dim_self,
dim_ref,
dim_ctx=None,
num_heads=1,
mlp_ratio=2,
act=F.leaky_relu,
dropout=0.0,
use_time=True,
):
super().__init__(
dim_self=dim_self,
dim_ctx=dim_ctx,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
act=act,
dropout=dropout,
use_time=use_time,
)
self.cross_attn = MultiHeadAttention(dim_self, dim_ref, num_heads, dropout)
self.cross_attn_norm = nn.LayerNorm(dim_self)
def forward(self, x, y, ctx=None):
res = x
x, attn = self.attn(x)
x = self.attn_norm(x + res)
res = x
x, attn = self.cross_attn(x, y)
x = self.cross_attn_norm(x + res)
res = x
x = self.mlp(x, ctx=ctx)
x = self.norm(x + res)
return x, attn
class TimeTransformerEncoder(nn.Module):
def __init__(
self,
dim_self,
dim_ctx=None,
num_heads=1,
mlp_ratio=2.0,
act=F.leaky_relu,
dropout=0.0,
use_time=True,
num_layers=3,
last_fc=False,
last_fc_dim_out=None,
):
super().__init__()
self.last_fc = last_fc
if last_fc:
self.fc = nn.Linear(dim_self, last_fc_dim_out)
self.layers = nn.ModuleList(
[
TimeTransformerEncoderLayer(
dim_self,
dim_ctx=dim_ctx,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
act=act,
dropout=dropout,
use_time=use_time,
)
for _ in range(num_layers)
]
)
def forward(self, x, ctx=None):
for i, layer in enumerate(self.layers):
x, attn = layer(x, ctx=ctx)
if self.last_fc:
x = self.fc(x)
return x
class TimeTransformerDecoder(nn.Module):
def __init__(
self,
dim_self,
dim_ref,
dim_ctx=None,
num_heads=1,
mlp_ratio=2.0,
act=F.leaky_relu,
dropout=0.0,
use_time=True,
num_layers=3,
last_fc=True,
last_fc_dim_out=None,
):
super().__init__()
        self.last_fc = last_fc
        if last_fc:
            assert (
                last_fc_dim_out is not None
            ), "last_fc_dim_out must be specified when last_fc is True"
            self.fc = nn.Linear(dim_self, last_fc_dim_out)
self.layers = nn.ModuleList(
[
TimeTransformerDecoderLayer(
dim_self,
dim_ref,
dim_ctx,
num_heads,
mlp_ratio,
act,
dropout,
use_time,
)
for _ in range(num_layers)
]
)
def forward(self, x, y, ctx=None):
for i, layer in enumerate(self.layers):
x, attn = layer(x, y=y, ctx=ctx)
if self.last_fc:
x = self.fc(x)
return x
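# ---------------------------------------------------------------------------
# Minimal usage sketch (an assumption for illustration, not part of the
# original training code): build a small time-conditioned encoder and decoder
# and push random tensors through them. `t` stands in for whatever time
# embedding the caller provides; its shape [B, 1, dim_ctx] broadcasts over the
# N tokens inside ConcatSquashLinear.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    B, N, M, D, D_CTX = 2, 16, 32, 64, 3
    encoder = TimeTransformerEncoder(
        dim_self=D, dim_ctx=D_CTX, num_heads=4, num_layers=2
    )
    decoder = TimeTransformerDecoder(
        dim_self=D,
        dim_ref=D,
        dim_ctx=D_CTX,
        num_heads=4,
        num_layers=2,
        last_fc=True,
        last_fc_dim_out=D,
    )
    x = torch.randn(B, N, D)  # query tokens
    y = torch.randn(B, M, D)  # reference tokens for cross-attention
    t = torch.randn(B, 1, D_CTX)  # time embedding, e.g. (t, sin t, cos t)
    print(encoder(x, ctx=t).shape)  # torch.Size([2, 16, 64])
    print(decoder(x, y, ctx=t).shape)  # torch.Size([2, 16, 64])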