from math import floor, log, pi
from typing import Any, List, Optional, Sequence, Tuple, Union
from .utils import *
import torch
import torch.nn as nn
import torch.nn.functional as F  # needed for F.layer_norm in AdaLayerNorm below
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange
from einops_exts import rearrange_many
from torch import Tensor, einsum
"""
Utils
"""
class AdaLayerNorm(nn.Module):
def __init__(self, style_dim, channels, eps=1e-5):
super().__init__()
self.channels = channels
self.eps = eps
self.fc = nn.Linear(style_dim, channels*2)
def forward(self, x, s):
x = x.transpose(-1, -2)
x = x.transpose(1, -1)
h = self.fc(s)
h = h.view(h.size(0), h.size(1), 1)
gamma, beta = torch.chunk(h, chunks=2, dim=1)
gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
x = F.layer_norm(x, (self.channels,), eps=self.eps)
x = (1 + gamma) * x + beta
return x.transpose(1, -1).transpose(-1, -2)
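# Usage sketch (illustrative, not part of the original module): AdaLayerNorm expects a
# channels-last input of shape (batch, time, channels) plus a style vector of shape
# (batch, style_dim); for 3-D inputs the two leading transposes cancel out, so the layer
# norm is applied over the channel dimension and the style-predicted gamma/beta rescale
# it per sample. The sizes below are hypothetical.
#
#   norm = AdaLayerNorm(style_dim=128, channels=512)
#   x = torch.randn(2, 100, 512)   # (batch, time, channels)
#   s = torch.randn(2, 128)        # (batch, style_dim)
#   y = norm(x, s)                 # same shape as x: (2, 100, 512)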
class StyleTransformer1d(nn.Module):
def __init__(
self,
num_layers: int,
channels: int,
num_heads: int,
head_features: int,
multiplier: int,
use_context_time: bool = True,
use_rel_pos: bool = False,
context_features_multiplier: int = 1,
rel_pos_num_buckets: Optional[int] = None,
rel_pos_max_distance: Optional[int] = None,
context_features: Optional[int] = None,
context_embedding_features: Optional[int] = None,
embedding_max_length: int = 512,
):
super().__init__()
self.blocks = nn.ModuleList(
[
StyleTransformerBlock(
features=channels + context_embedding_features,
head_features=head_features,
num_heads=num_heads,
multiplier=multiplier,
style_dim=context_features,
use_rel_pos=use_rel_pos,
rel_pos_num_buckets=rel_pos_num_buckets,
rel_pos_max_distance=rel_pos_max_distance,
)
for i in range(num_layers)
]
)
self.to_out = nn.Sequential(
Rearrange("b t c -> b c t"),
nn.Conv1d(
in_channels=channels + context_embedding_features,
out_channels=channels,
kernel_size=1,
),
)
use_context_features = exists(context_features)
self.use_context_features = use_context_features
self.use_context_time = use_context_time
if use_context_time or use_context_features:
context_mapping_features = channels + context_embedding_features
self.to_mapping = nn.Sequential(
nn.Linear(context_mapping_features, context_mapping_features),
nn.GELU(),
nn.Linear(context_mapping_features, context_mapping_features),
nn.GELU(),
)
if use_context_time:
assert exists(context_mapping_features)
self.to_time = nn.Sequential(
TimePositionalEmbedding(
dim=channels, out_features=context_mapping_features
),
nn.GELU(),
)
if use_context_features:
assert exists(context_features) and exists(context_mapping_features)
self.to_features = nn.Sequential(
nn.Linear(
in_features=context_features, out_features=context_mapping_features
),
nn.GELU(),
)
self.fixed_embedding = FixedEmbedding(
max_length=embedding_max_length, features=context_embedding_features
)
def get_mapping(
self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
) -> Optional[Tensor]:
"""Combines context time features and features into mapping"""
items, mapping = [], None
# Compute time features
if self.use_context_time:
assert_message = "use_context_time=True but no time features provided"
assert exists(time), assert_message
items += [self.to_time(time)]
# Compute features
if self.use_context_features:
assert_message = "context_features exists but no features provided"
assert exists(features), assert_message
items += [self.to_features(features)]
# Compute joint mapping
if self.use_context_time or self.use_context_features:
mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
mapping = self.to_mapping(mapping)
return mapping
    def run(self, x, time, embedding, features):
        # Build the conditioning mapping from the diffusion time and style features
        mapping = self.get_mapping(time, features)
        # Broadcast x over the embedding length and concatenate along the channel dim
        x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], dim=-1)
        mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1)
        for block in self.blocks:
            x = x + mapping
            x = block(x, features)
        # Pool over the sequence dimension and project back to `channels`
        x = x.mean(dim=1, keepdim=True)
        x = self.to_out(x)
        x = x.transpose(-1, -2)
        return x
def forward(self, x: Tensor,
time: Tensor,
embedding_mask_proba: float = 0.0,
embedding: Optional[Tensor] = None,
features: Optional[Tensor] = None,
embedding_scale: float = 1.0) -> Tensor:
b, device = embedding.shape[0], embedding.device
fixed_embedding = self.fixed_embedding(embedding)
if embedding_mask_proba > 0.0:
# Randomly mask embedding
batch_mask = rand_bool(
shape=(b, 1, 1), proba=embedding_mask_proba, device=device
)
embedding = torch.where(batch_mask, fixed_embedding, embedding)
if embedding_scale != 1.0:
# Compute both normal and fixed embedding outputs
out = self.run(x, time, embedding=embedding, features=features)
out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
# Scale conditional output using classifier-free guidance
return out_masked + (out - out_masked) * embedding_scale
else:
return self.run(x, time, embedding=embedding, features=features)
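# Usage sketch (hypothetical sizes, not taken from the original repo): conditioning a
# (batch, 1, channels) latent on a token embedding and a style vector, with
# classifier-free guidance enabled via embedding_scale > 1.
#
#   model = StyleTransformer1d(
#       num_layers=4, channels=256, num_heads=8, head_features=64, multiplier=2,
#       context_features=128,              # style vector width (feeds AdaLayerNorm)
#       context_embedding_features=768,    # e.g. a text-encoder hidden size
#       embedding_max_length=512,
#   )
#   x = torch.randn(2, 1, 256)             # (batch, 1, channels)
#   t = torch.rand(2)                      # diffusion times, one scalar per batch item
#   emb = torch.randn(2, 120, 768)         # (batch, tokens, context_embedding_features)
#   feat = torch.randn(2, 128)             # (batch, context_features) style vector
#   out = model(x, t, embedding=emb, features=feat, embedding_scale=1.5)
#   # out: (2, 1, 256); embedding_scale != 1 blends conditional and fixed-embedding runs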
class StyleTransformerBlock(nn.Module):
def __init__(
self,
features: int,
num_heads: int,
head_features: int,
style_dim: int,
multiplier: int,
use_rel_pos: bool,
rel_pos_num_buckets: Optional[int] = None,
rel_pos_max_distance: Optional[int] = None,
context_features: Optional[int] = None,
):
super().__init__()
self.use_cross_attention = exists(context_features) and context_features > 0
self.attention = StyleAttention(
features=features,
style_dim=style_dim,
num_heads=num_heads,
head_features=head_features,
use_rel_pos=use_rel_pos,
rel_pos_num_buckets=rel_pos_num_buckets,
rel_pos_max_distance=rel_pos_max_distance,
)
if self.use_cross_attention:
self.cross_attention = StyleAttention(
features=features,
style_dim=style_dim,
num_heads=num_heads,
head_features=head_features,
context_features=context_features,
use_rel_pos=use_rel_pos,
rel_pos_num_buckets=rel_pos_num_buckets,
rel_pos_max_distance=rel_pos_max_distance,
)
self.feed_forward = FeedForward(features=features, multiplier=multiplier)
def forward(self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
x = self.attention(x, s) + x
if self.use_cross_attention:
x = self.cross_attention(x, s, context=context) + x
x = self.feed_forward(x) + x
return x
class StyleAttention(nn.Module):
def __init__(
self,
features: int,
*,
style_dim: int,
head_features: int,
num_heads: int,
context_features: Optional[int] = None,
use_rel_pos: bool,
rel_pos_num_buckets: Optional[int] = None,
rel_pos_max_distance: Optional[int] = None,
):
super().__init__()
self.context_features = context_features
mid_features = head_features * num_heads
context_features = default(context_features, features)
self.norm = AdaLayerNorm(style_dim, features)
self.norm_context = AdaLayerNorm(style_dim, context_features)
self.to_q = nn.Linear(
in_features=features, out_features=mid_features, bias=False
)
self.to_kv = nn.Linear(
in_features=context_features, out_features=mid_features * 2, bias=False
)
self.attention = AttentionBase(
features,
num_heads=num_heads,
head_features=head_features,
use_rel_pos=use_rel_pos,
rel_pos_num_buckets=rel_pos_num_buckets,
rel_pos_max_distance=rel_pos_max_distance,
)
def forward(self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
assert_message = "You must provide a context when using context_features"
assert not self.context_features or exists(context), assert_message
# Use context if provided
context = default(context, x)
# Normalize then compute q from input and k,v from context
x, context = self.norm(x, s), self.norm_context(context, s)
q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
# Compute and return attention
return self.attention(q, k, v)
class Transformer1d(nn.Module):
def __init__(
self,
num_layers: int,
channels: int,
num_heads: int,
head_features: int,
multiplier: int,
use_context_time: bool = True,
use_rel_pos: bool = False,
context_features_multiplier: int = 1,
rel_pos_num_buckets: Optional[int] = None,
rel_pos_max_distance: Optional[int] = None,
context_features: Optional[int] = None,
context_embedding_features: Optional[int] = None,
embedding_max_length: int = 512,
):
super().__init__()
self.blocks = nn.ModuleList(
[
TransformerBlock(
features=channels + context_embedding_features,
head_features=head_features,
num_heads=num_heads,
multiplier=multiplier,
use_rel_pos=use_rel_pos,
rel_pos_num_buckets=rel_pos_num_buckets,
rel_pos_max_distance=rel_pos_max_distance,
)
for i in range(num_layers)
]
)
self.to_out = nn.Sequential(
Rearrange("b t c -> b c t"),
nn.Conv1d(
in_channels=channels + context_embedding_features,
out_channels=channels,
kernel_size=1,
),
)
use_context_features = exists(context_features)
self.use_context_features = use_context_features
self.use_context_time = use_context_time
if use_context_time or use_context_features:
context_mapping_features = channels + context_embedding_features
self.to_mapping = nn.Sequential(
nn.Linear(context_mapping_features, context_mapping_features),
nn.GELU(),
nn.Linear(context_mapping_features, context_mapping_features),
nn.GELU(),
)
if use_context_time:
assert exists(context_mapping_features)
self.to_time = nn.Sequential(
TimePositionalEmbedding(
dim=channels, out_features=context_mapping_features
),
nn.GELU(),
)
if use_context_features:
assert exists(context_features) and exists(context_mapping_features)
self.to_features = nn.Sequential(
nn.Linear(
in_features=context_features, out_features=context_mapping_features
),
nn.GELU(),
)
self.fixed_embedding = FixedEmbedding(
max_length=embedding_max_length, features=context_embedding_features
)
def get_mapping(
self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
) -> Optional[Tensor]:
"""Combines context time features and features into mapping"""
items, mapping = [], None
# Compute time features
if self.use_context_time:
assert_message = "use_context_time=True but no time features provided"
assert exists(time), assert_message
items += [self.to_time(time)]
# Compute features
if self.use_context_features:
assert_message = "context_features exists but no features provided"
assert exists(features), assert_message
items += [self.to_features(features)]
# Compute joint mapping
if self.use_context_time or self.use_context_features:
mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
mapping = self.to_mapping(mapping)
return mapping
    def run(self, x, time, embedding, features):
        # Build the conditioning mapping from the diffusion time and context features
        mapping = self.get_mapping(time, features)
        # Broadcast x over the embedding length and concatenate along the channel dim
        x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], dim=-1)
        mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1)
        for block in self.blocks:
            x = x + mapping
            x = block(x)
        # Pool over the sequence dimension and project back to `channels`
        x = x.mean(dim=1, keepdim=True)
        x = self.to_out(x)
        x = x.transpose(-1, -2)
        return x
def forward(self, x: Tensor,
time: Tensor,
embedding_mask_proba: float = 0.0,
embedding: Optional[Tensor] = None,
features: Optional[Tensor] = None,
embedding_scale: float = 1.0) -> Tensor:
b, device = embedding.shape[0], embedding.device
fixed_embedding = self.fixed_embedding(embedding)
if embedding_mask_proba > 0.0:
# Randomly mask embedding
batch_mask = rand_bool(
shape=(b, 1, 1), proba=embedding_mask_proba, device=device
)
embedding = torch.where(batch_mask, fixed_embedding, embedding)
if embedding_scale != 1.0:
# Compute both normal and fixed embedding outputs
out = self.run(x, time, embedding=embedding, features=features)
out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
# Scale conditional output using classifier-free guidance
return out_masked + (out - out_masked) * embedding_scale
else:
return self.run(x, time, embedding=embedding, features=features)
"""
Attention Components
"""
class RelativePositionBias(nn.Module):
def __init__(self, num_buckets: int, max_distance: int, num_heads: int):
super().__init__()
self.num_buckets = num_buckets
self.max_distance = max_distance
self.num_heads = num_heads
self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
@staticmethod
def _relative_position_bucket(
relative_position: Tensor, num_buckets: int, max_distance: int
):
num_buckets //= 2
ret = (relative_position >= 0).to(torch.long) * num_buckets
n = torch.abs(relative_position)
max_exact = num_buckets // 2
is_small = n < max_exact
val_if_large = (
max_exact
+ (
torch.log(n.float() / max_exact)
/ log(max_distance / max_exact)
* (num_buckets - max_exact)
).long()
)
val_if_large = torch.min(
val_if_large, torch.full_like(val_if_large, num_buckets - 1)
)
ret += torch.where(is_small, n, val_if_large)
return ret
def forward(self, num_queries: int, num_keys: int) -> Tensor:
i, j, device = num_queries, num_keys, self.relative_attention_bias.weight.device
q_pos = torch.arange(j - i, j, dtype=torch.long, device=device)
k_pos = torch.arange(j, dtype=torch.long, device=device)
rel_pos = rearrange(k_pos, "j -> 1 j") - rearrange(q_pos, "i -> i 1")
relative_position_bucket = self._relative_position_bucket(
rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance
)
bias = self.relative_attention_bias(relative_position_bucket)
bias = rearrange(bias, "m n h -> 1 h m n")
return bias
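# Shape note (illustrative): the bias is added to the pre-softmax similarity matrix in
# AttentionBase when use_rel_pos=True, and uses the T5-style log-spaced bucketing above.
#
#   rel_pos = RelativePositionBias(num_buckets=32, max_distance=128, num_heads=8)
#   bias = rel_pos(num_queries=100, num_keys=100)   # (1, 8, 100, 100)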
def FeedForward(features: int, multiplier: int) -> nn.Module:
mid_features = features * multiplier
return nn.Sequential(
nn.Linear(in_features=features, out_features=mid_features),
nn.GELU(),
nn.Linear(in_features=mid_features, out_features=features),
)
class AttentionBase(nn.Module):
def __init__(
self,
features: int,
*,
head_features: int,
num_heads: int,
use_rel_pos: bool,
out_features: Optional[int] = None,
rel_pos_num_buckets: Optional[int] = None,
rel_pos_max_distance: Optional[int] = None,
):
super().__init__()
self.scale = head_features ** -0.5
self.num_heads = num_heads
self.use_rel_pos = use_rel_pos
mid_features = head_features * num_heads
if use_rel_pos:
assert exists(rel_pos_num_buckets) and exists(rel_pos_max_distance)
self.rel_pos = RelativePositionBias(
num_buckets=rel_pos_num_buckets,
max_distance=rel_pos_max_distance,
num_heads=num_heads,
)
if out_features is None:
out_features = features
self.to_out = nn.Linear(in_features=mid_features, out_features=out_features)
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
# Split heads
q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)
# Compute similarity matrix
sim = einsum("... n d, ... m d -> ... n m", q, k)
sim = (sim + self.rel_pos(*sim.shape[-2:])) if self.use_rel_pos else sim
sim = sim * self.scale
# Get attention matrix with softmax
attn = sim.softmax(dim=-1)
# Compute values
out = einsum("... n m, ... m d -> ... n d", attn, v)
out = rearrange(out, "b h n d -> b n (h d)")
return self.to_out(out)
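# Usage sketch (hypothetical sizes): AttentionBase consumes already-projected q/k/v whose
# last dimension equals num_heads * head_features; queries and keys may differ in length.
#
#   attn = AttentionBase(features=512, num_heads=8, head_features=64, use_rel_pos=False)
#   q = torch.randn(2, 100, 512)
#   k = torch.randn(2, 120, 512)
#   v = torch.randn(2, 120, 512)
#   out = attn(q, k, v)            # (2, 100, 512) after the final output projection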
class Attention(nn.Module):
def __init__(
self,
features: int,
*,
head_features: int,
num_heads: int,
out_features: Optional[int] = None,
context_features: Optional[int] = None,
use_rel_pos: bool,
rel_pos_num_buckets: Optional[int] = None,
rel_pos_max_distance: Optional[int] = None,
):
super().__init__()
self.context_features = context_features
mid_features = head_features * num_heads
context_features = default(context_features, features)
self.norm = nn.LayerNorm(features)
self.norm_context = nn.LayerNorm(context_features)
self.to_q = nn.Linear(
in_features=features, out_features=mid_features, bias=False
)
self.to_kv = nn.Linear(
in_features=context_features, out_features=mid_features * 2, bias=False
)
self.attention = AttentionBase(
features,
out_features=out_features,
num_heads=num_heads,
head_features=head_features,
use_rel_pos=use_rel_pos,
rel_pos_num_buckets=rel_pos_num_buckets,
rel_pos_max_distance=rel_pos_max_distance,
)
def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
assert_message = "You must provide a context when using context_features"
assert not self.context_features or exists(context), assert_message
# Use context if provided
context = default(context, x)
# Normalize then compute q from input and k,v from context
x, context = self.norm(x), self.norm_context(context)
q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
# Compute and return attention
return self.attention(q, k, v)
"""
Transformer Blocks
"""
class TransformerBlock(nn.Module):
def __init__(
self,
features: int,
num_heads: int,
head_features: int,
multiplier: int,
use_rel_pos: bool,
rel_pos_num_buckets: Optional[int] = None,
rel_pos_max_distance: Optional[int] = None,
context_features: Optional[int] = None,
):
super().__init__()
self.use_cross_attention = exists(context_features) and context_features > 0
self.attention = Attention(
features=features,
num_heads=num_heads,
head_features=head_features,
use_rel_pos=use_rel_pos,
rel_pos_num_buckets=rel_pos_num_buckets,
rel_pos_max_distance=rel_pos_max_distance,
)
if self.use_cross_attention:
self.cross_attention = Attention(
features=features,
num_heads=num_heads,
head_features=head_features,
context_features=context_features,
use_rel_pos=use_rel_pos,
rel_pos_num_buckets=rel_pos_num_buckets,
rel_pos_max_distance=rel_pos_max_distance,
)
self.feed_forward = FeedForward(features=features, multiplier=multiplier)
def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
x = self.attention(x) + x
if self.use_cross_attention:
x = self.cross_attention(x, context=context) + x
x = self.feed_forward(x) + x
return x
"""
Time Embeddings
"""
class SinusoidalEmbedding(nn.Module):
def __init__(self, dim: int):
super().__init__()
self.dim = dim
def forward(self, x: Tensor) -> Tensor:
device, half_dim = x.device, self.dim // 2
emb = torch.tensor(log(10000) / (half_dim - 1), device=device)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j")
return torch.cat((emb.sin(), emb.cos()), dim=-1)
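# Shape note (illustrative): SinusoidalEmbedding maps a 1-D tensor of positions or
# timesteps to fixed sin/cos features.
#
#   emb = SinusoidalEmbedding(dim=128)
#   t = torch.arange(10).float()
#   e = emb(t)                     # (10, 128): first half sine, second half cosine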
class LearnedPositionalEmbedding(nn.Module):
"""Used for continuous time"""
def __init__(self, dim: int):
super().__init__()
assert (dim % 2) == 0
half_dim = dim // 2
self.weights = nn.Parameter(torch.randn(half_dim))
def forward(self, x: Tensor) -> Tensor:
x = rearrange(x, "b -> b 1")
freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
fouriered = torch.cat((x, fouriered), dim=-1)
return fouriered
def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
return nn.Sequential(
LearnedPositionalEmbedding(dim),
nn.Linear(in_features=dim + 1, out_features=out_features),
)
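# Usage sketch (hypothetical sizes): TimePositionalEmbedding maps a batch of scalar
# diffusion times to the context-mapping width used by the transformers above.
#
#   time_emb = TimePositionalEmbedding(dim=256, out_features=1024)
#   t = torch.rand(4)              # (batch,)
#   m = time_emb(t)                # (4, 1024)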
class FixedEmbedding(nn.Module):
def __init__(self, max_length: int, features: int):
super().__init__()
self.max_length = max_length
self.embedding = nn.Embedding(max_length, features)
def forward(self, x: Tensor) -> Tensor:
batch_size, length, device = *x.shape[0:2], x.device
assert_message = "Input sequence length must be <= max_length"
assert length <= self.max_length, assert_message
position = torch.arange(length, device=device)
fixed_embedding = self.embedding(position)
fixed_embedding = repeat(fixed_embedding, "n d -> b n d", b=batch_size)
return fixed_embedding
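# Usage sketch (hypothetical sizes): FixedEmbedding ignores the values of its input and
# returns a learned positional table broadcast over the batch; the transformers above use
# it as the "unconditional" embedding for random masking and classifier-free guidance.
#
#   fixed = FixedEmbedding(max_length=512, features=768)
#   emb = torch.randn(2, 120, 768)
#   uncond = fixed(emb)            # (2, 120, 768), independent of emb's values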