videocrafter / lvdm /models /modules /attention_temporal.py
RamAnanth1's picture
Upload 14 files
b6b5d48
raw
history blame
No virus
15.1 kB
from typing import Optional, Any
import torch
import torch as th
from torch import nn, einsum
from einops import rearrange, repeat
try:
import xformers
import xformers.ops
XFORMERS_IS_AVAILBLE = True
except:
XFORMERS_IS_AVAILBLE = False
from lvdm.models.modules.util import (
GEGLU,
exists,
default,
Normalize,
checkpoint,
zero_module,
)
# ---------------------------------------------------------------------------------------------------
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
nn.Linear(dim, inner_dim),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.net(x)
# ---------------------------------------------------------------------------------------------------
class RelativePosition(nn.Module):
""" https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py """
def __init__(self, num_units, max_relative_position):
super().__init__()
self.num_units = num_units
self.max_relative_position = max_relative_position
self.embeddings_table = nn.Parameter(th.Tensor(max_relative_position * 2 + 1, num_units))
nn.init.xavier_uniform_(self.embeddings_table)
def forward(self, length_q, length_k):
device = self.embeddings_table.device
range_vec_q = th.arange(length_q, device=device)
range_vec_k = th.arange(length_k, device=device)
distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
distance_mat_clipped = th.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)
final_mat = distance_mat_clipped + self.max_relative_position
final_mat = final_mat.long()
embeddings = self.embeddings_table[final_mat]
return embeddings
# ---------------------------------------------------------------------------------------------------
class TemporalCrossAttention(nn.Module):
def __init__(self,
query_dim,
context_dim=None,
heads=8,
dim_head=64,
dropout=0.,
use_relative_position=False, # whether use relative positional representation in temporal attention.
temporal_length=None, # relative positional representation
**kwargs,
):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.context_dim = context_dim
self.scale = dim_head ** -0.5
self.heads = heads
self.temporal_length = temporal_length
self.use_relative_position = use_relative_position
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout)
)
if use_relative_position:
assert(temporal_length is not None)
self.relative_position_k = RelativePosition(num_units=dim_head, max_relative_position=temporal_length)
self.relative_position_v = RelativePosition(num_units=dim_head, max_relative_position=temporal_length)
nn.init.constant_(self.to_q.weight, 0)
nn.init.constant_(self.to_k.weight, 0)
nn.init.constant_(self.to_v.weight, 0)
nn.init.constant_(self.to_out[0].weight, 0)
nn.init.constant_(self.to_out[0].bias, 0)
def forward(self, x, context=None, mask=None):
nh = self.heads
out = x
# cal qkv
q = self.to_q(out)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=nh), (q, k, v))
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
# relative positional embedding
if self.use_relative_position:
len_q, len_k, len_v = q.shape[1], k.shape[1], v.shape[1]
k2 = self.relative_position_k(len_q, len_k)
sim2 = einsum('b t d, t s d -> b t s', q, k2) * self.scale
sim += sim2
# mask attention
if mask is not None:
max_neg_value = -1e9
sim = sim + (1-mask.float()) * max_neg_value # 1=masking,0=no masking
# attend to values
attn = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', attn, v)
# relative positional embedding
if self.use_relative_position:
v2 = self.relative_position_v(len_q, len_v)
out2 = einsum('b t s, t s d -> b t d', attn, v2)
out += out2
# merge head
out = rearrange(out, '(b h) n d -> b n (h d)', h=nh)
return self.to_out(out)
# ---------------------------------------------------------------------------------------------------
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.,
**kwargs,):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head ** -0.5
self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout)
)
def forward(self, x, context=None, mask=None):
h = self.heads
b = x.shape[0]
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
attn = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
return self.to_out(out)
# ---------------------------------------------------------------------------------------------------
class MemoryEfficientCrossAttention(nn.Module):
"""https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
"""
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0,
**kwargs,):
super().__init__()
print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
f"{heads} heads."
)
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.heads = heads
self.dim_head = dim_head
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
self.attention_op: Optional[Any] = None
def forward(self, x, context=None, mask=None):
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
b, _, _ = q.shape
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, t.shape[1], self.heads, self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b * self.heads, t.shape[1], self.dim_head)
.contiguous(),
(q, k, v),
)
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
if exists(mask):
raise NotImplementedError
out = (
out.unsqueeze(0)
.reshape(b, self.heads, out.shape[1], self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b, out.shape[1], self.heads * self.dim_head)
)
return self.to_out(out)
# ---------------------------------------------------------------------------------------------------
class BasicTransformerBlockST(nn.Module):
"""
if no context is given to forward function, cross-attention defaults to self-attention
"""
def __init__(self,
# Spatial
dim,
n_heads,
d_head,
dropout=0.,
context_dim=None,
gated_ff=True,
checkpoint=True,
# Temporal
temporal_length=None,
use_relative_position=True,
**kwargs,
):
super().__init__()
# spatial self attention (if context_dim is None) and spatial cross attention
if XFORMERS_IS_AVAILBLE:
self.attn1 = MemoryEfficientCrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, **kwargs,)
self.attn2 = MemoryEfficientCrossAttention(query_dim=dim, context_dim=context_dim,
heads=n_heads, dim_head=d_head, dropout=dropout, **kwargs,)
else:
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, **kwargs,)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
heads=n_heads, dim_head=d_head, dropout=dropout, **kwargs,)
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
# Temporal attention
self.attn1_tmp = TemporalCrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
temporal_length=temporal_length,
use_relative_position=use_relative_position,
**kwargs,
)
self.attn2_tmp = TemporalCrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
# cross attn
context_dim=None,
# temporal attn
temporal_length=temporal_length,
use_relative_position=use_relative_position,
**kwargs,
)
self.norm4 = nn.LayerNorm(dim)
self.norm5 = nn.LayerNorm(dim)
def forward(self, x, context=None, **kwargs):
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
def _forward(self, x, context=None, mask=None,):
assert(x.dim() == 5), f"x shape = {x.shape}"
b, c, t, h, w = x.shape
# spatial self attention
x = rearrange(x, 'b c t h w -> (b t) (h w) c')
x = self.attn1(self.norm1(x)) + x
x = rearrange(x, '(b t) (h w) c -> b c t h w', b=b,h=h)
# temporal self attention
x = rearrange(x, 'b c t h w -> (b h w) t c')
x = self.attn1_tmp(self.norm4(x), mask=mask) + x
x = rearrange(x, '(b h w) t c -> b c t h w', b=b,h=h,w=w) # 3d -> 5d
# spatial cross attention
x = rearrange(x, 'b c t h w -> (b t) (h w) c')
if context is not None:
context_ = []
for i in range(context.shape[0]):
context_.append(context[i].unsqueeze(0).repeat(t, 1, 1))
context_ = torch.cat(context_,dim=0)
else:
context_ = None
x = self.attn2(self.norm2(x), context=context_) + x
x = rearrange(x, '(b t) (h w) c -> b c t h w', b=b,h=h)
# temporal cross attention
x = rearrange(x, 'b c t h w -> (b h w) t c')
x = self.attn2_tmp(self.norm5(x), context=None, mask=mask) + x
# feedforward
x = self.ff(self.norm3(x)) + x
x = rearrange(x, '(b h w) t c -> b c t h w', b=b,h=h,w=w) # 3d -> 5d
return x
# ---------------------------------------------------------------------------------------------------
class SpatialTemporalTransformer(nn.Module):
"""
Transformer block for video-like data (5D tensor).
First, project the input (aka embedding) with NO reshape.
Then apply standard transformer action.
The 5D -> 3D reshape operation will be done in the specific attention module.
"""
def __init__(
self,
in_channels, n_heads, d_head,
depth=1, dropout=0.,
context_dim=None,
# Temporal
temporal_length=None,
use_relative_position=True,
**kwargs,
):
super().__init__()
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)
self.proj_in = nn.Conv3d(in_channels,
inner_dim,
kernel_size=1,
stride=1,
padding=0)
self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlockST(
inner_dim, n_heads, d_head, dropout=dropout,
# cross attn
context_dim=context_dim,
# temporal attn
temporal_length=temporal_length,
use_relative_position=use_relative_position,
**kwargs
) for d in range(depth)]
)
self.proj_out = zero_module(nn.Conv3d(inner_dim,
in_channels,
kernel_size=1,
stride=1,
padding=0))
def forward(self, x, context=None, **kwargs):
assert(x.dim() == 5), f"x shape = {x.shape}"
x_in = x
x = self.norm(x)
x = self.proj_in(x)
for block in self.transformer_blocks:
x = block(x, context=context, **kwargs)
x = self.proj_out(x)
return x + x_in