# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import math
from typing import Optional, Callable, List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.utils.checkpoint
from scipy.stats import truncnorm

from dockformer.utils.kernel.attention_core import attention_core
from dockformer.utils.precision_utils import is_fp16_enabled
from dockformer.utils.tensor_utils import (
    permute_final_dims,
    flatten_final_dims,
)

# Suited for a 40 GB GPU
# DEFAULT_LMA_Q_CHUNK_SIZE = 1024
# DEFAULT_LMA_KV_CHUNK_SIZE = 4096

# Suited for a 10 GB GPU
DEFAULT_LMA_Q_CHUNK_SIZE = 64
DEFAULT_LMA_KV_CHUNK_SIZE = 256


def _prod(nums):
    out = 1
    for n in nums:
        out = out * n
    return out


def _calculate_fan(linear_weight_shape, fan="fan_in"):
    fan_out, fan_in = linear_weight_shape

    if fan == "fan_in":
        f = fan_in
    elif fan == "fan_out":
        f = fan_out
    elif fan == "fan_avg":
        f = (fan_in + fan_out) / 2
    else:
        raise ValueError("Invalid fan option")

    return f
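
# Illustrative sketch, not part of the original module: nn.Linear stores its weight
# as (out_features, in_features), so for a weight of shape (16, 32) fan_in is 32 and
# fan_avg is the mean of the two entries. The helper below is hypothetical example code.
def _example_calculate_fan():
    shape = torch.empty(16, 32).shape  # (fan_out, fan_in)
    return _calculate_fan(shape, "fan_in"), _calculate_fan(shape, "fan_avg")  # 32, 24.0
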
def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
    shape = weights.shape
    f = _calculate_fan(shape, fan)
    scale = scale / max(1, f)
    a = -2
    b = 2
    std = math.sqrt(scale) / truncnorm.std(a=a, b=b, loc=0, scale=1)
    size = _prod(shape)
    samples = truncnorm.rvs(a=a, b=b, loc=0, scale=std, size=size)
    samples = np.reshape(samples, shape)
    with torch.no_grad():
        weights.copy_(torch.tensor(samples, device=weights.device))
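
# Illustrative sketch, not part of the original module: applying the truncated-normal
# initializer above to a raw (fan_out, fan_in) weight tensor. The division by
# truncnorm.std corrects for the variance lost to truncation at +/-2, so the sample
# standard deviation comes out close to sqrt(scale / fan_in).
def _example_trunc_normal_init():
    w = torch.empty(16, 32)  # (fan_out, fan_in)
    trunc_normal_init_(w, scale=1.0, fan="fan_in")
    return w.std()  # roughly sqrt(1 / 32)
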
def lecun_normal_init_(weights):
    trunc_normal_init_(weights, scale=1.0)


def he_normal_init_(weights):
    trunc_normal_init_(weights, scale=2.0)


def glorot_uniform_init_(weights):
    nn.init.xavier_uniform_(weights, gain=1)


def final_init_(weights):
    with torch.no_grad():
        weights.fill_(0.0)


def gating_init_(weights):
    with torch.no_grad():
        weights.fill_(0.0)


def normal_init_(weights):
    torch.nn.init.kaiming_normal_(weights, nonlinearity="linear")


def ipa_point_weights_init_(weights):
    with torch.no_grad():
        softplus_inverse_1 = 0.541324854612918
        weights.fill_(softplus_inverse_1)
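
# Illustrative sketch, not part of the original module: the constant used in
# ipa_point_weights_init_ is the inverse softplus of 1, i.e. log(e - 1), so that
# the softplus of the weight equals exactly 1 at initialization.
def _example_ipa_point_weights_init():
    w = torch.empty(4)
    ipa_point_weights_init_(w)
    return torch.nn.functional.softplus(w)  # a tensor of ones, up to float error
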
class Linear(nn.Linear):
    """
    A Linear layer with built-in nonstandard initializations. Called just
    like torch.nn.Linear.

    Implements the initializers in 1.11.4 of the AlphaFold 2 supplement, plus
    some additional ones found in the code.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        bias: bool = True,
        init: str = "default",
        init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
        precision=None,
    ):
        """
        Args:
            in_dim:
                The final dimension of inputs to the layer
            out_dim:
                The final dimension of layer outputs
            bias:
                Whether to learn an additive bias. True by default
            init:
                The initializer to use. Choose from:

                "default": LeCun fan-in truncated normal initialization
                "relu": He initialization w/ truncated normal distribution
                "glorot": Fan-average Glorot uniform initialization
                "gating": Weights=0, Bias=1
                "normal": Normal initialization with std=1/sqrt(fan_in)
                "final": Weights=0, Bias=0

                Overridden by init_fn if the latter is not None.
            init_fn:
                A custom initializer taking weight and bias as inputs.
                Overrides init if not None.
        """
        super(Linear, self).__init__(in_dim, out_dim, bias=bias)

        if bias:
            with torch.no_grad():
                self.bias.fill_(0)

        with torch.no_grad():
            if init_fn is not None:
                init_fn(self.weight, self.bias)
            else:
                if init == "default":
                    lecun_normal_init_(self.weight)
                elif init == "relu":
                    he_normal_init_(self.weight)
                elif init == "glorot":
                    glorot_uniform_init_(self.weight)
                elif init == "gating":
                    gating_init_(self.weight)
                    if bias:
                        self.bias.fill_(1.0)
                elif init == "normal":
                    normal_init_(self.weight)
                elif init == "final":
                    final_init_(self.weight)
                else:
                    raise ValueError("Invalid init string.")

        self.precision = precision

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        d = input.dtype
        if self.precision is not None:
            with torch.cuda.amp.autocast(enabled=False):
                bias = self.bias.to(dtype=self.precision) if self.bias is not None else None
                return nn.functional.linear(
                    input.to(dtype=self.precision),
                    self.weight.to(dtype=self.precision),
                    bias,
                ).to(dtype=d)

        if d is torch.bfloat16:
            with torch.cuda.amp.autocast(enabled=False):
                bias = self.bias.to(dtype=d) if self.bias is not None else None
                return nn.functional.linear(input, self.weight.to(dtype=d), bias)

        return nn.functional.linear(input, self.weight, self.bias)
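
# Illustrative sketch, not part of the original module: the initialization schemes
# above as they are typically combined elsewhere in the model. All dimensions are
# arbitrary placeholders.
def _example_linear_usage():
    proj = Linear(128, 256, bias=False, init="glorot")  # Q/K/V-style projection
    gate = Linear(128, 256, init="gating")              # weights 0, bias 1
    out = Linear(256, 128, init="final")                # zero-initialized output
    x = torch.randn(2, 10, 128)
    return out(torch.sigmoid(gate(x)) * proj(x))        # [2, 10, 128]
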
class LayerNorm(nn.Module):
    def __init__(self, c_in, eps=1e-5):
        super(LayerNorm, self).__init__()

        self.c_in = (c_in,)
        self.eps = eps

        self.weight = nn.Parameter(torch.ones(c_in))
        self.bias = nn.Parameter(torch.zeros(c_in))

    def forward(self, x):
        d = x.dtype
        if d is torch.bfloat16:
            with torch.cuda.amp.autocast(enabled=False):
                out = nn.functional.layer_norm(
                    x,
                    self.c_in,
                    self.weight.to(dtype=d),
                    self.bias.to(dtype=d),
                    self.eps,
                )
        else:
            out = nn.functional.layer_norm(
                x,
                self.c_in,
                self.weight,
                self.bias,
                self.eps,
            )

        return out
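
# Illustrative sketch, not part of the original module: this LayerNorm normalizes
# over the final channel dimension only, like nn.LayerNorm(c_in), while keeping
# bfloat16 inputs out of autocast's fp32 upcast.
def _example_layer_norm_usage():
    ln = LayerNorm(64)
    x = torch.randn(2, 10, 64)
    return ln(x)  # same shape; roughly zero mean and unit variance over the last dim
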
def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    Softmax, but without automatic casting to fp32 when the input is of
    type bfloat16
    """
    d = t.dtype
    if d is torch.bfloat16:
        with torch.cuda.amp.autocast(enabled=False):
            s = torch.nn.functional.softmax(t, dim=dim)
    else:
        s = torch.nn.functional.softmax(t, dim=dim)

    return s
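
# Illustrative sketch, not part of the original module: under autocast, a plain
# softmax on bfloat16 activations would be upcast to fp32; softmax_no_cast keeps
# the result in the input dtype.
def _example_softmax_no_cast():
    logits = torch.randn(2, 8, 16, dtype=torch.bfloat16)
    return softmax_no_cast(logits, dim=-1).dtype  # torch.bfloat16
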
# @torch.jit.script
def _attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    biases: List[torch.Tensor],
) -> torch.Tensor:
    # [*, H, C_hidden, K]
    key = permute_final_dims(key, (1, 0))

    # [*, H, Q, K]
    a = torch.matmul(query, key)

    for b in biases:
        a += b

    a = softmax_no_cast(a, -1)

    # [*, H, Q, C_hidden]
    a = torch.matmul(a, value)

    return a
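
# Illustrative sketch, not part of the original module: calling _attention directly
# with heads already split out, i.e. [*, H, Q/K, C_hidden] inputs. Scaling by
# 1/sqrt(C_hidden) is the caller's responsibility (see Attention._prep_qkv below).
def _example_attention_fn():
    batch, heads, n_q, n_k, c = 2, 4, 16, 24, 8
    q = torch.randn(batch, heads, n_q, c) / math.sqrt(c)
    k = torch.randn(batch, heads, n_k, c)
    v = torch.randn(batch, heads, n_k, c)
    mask_bias = torch.zeros(batch, 1, 1, n_k)  # broadcasts to [*, H, Q, K]
    return _attention(q, k, v, [mask_bias])    # [batch, heads, n_q, c]
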
class Attention(nn.Module):
    """
    Standard multi-head attention using AlphaFold's default layer
    initialization. Allows multiple bias vectors.
    """
    def __init__(
        self,
        c_q: int,
        c_k: int,
        c_v: int,
        c_hidden: int,
        no_heads: int,
        gating: bool = True,
    ):
        """
        Args:
            c_q:
                Input dimension of query data
            c_k:
                Input dimension of key data
            c_v:
                Input dimension of value data
            c_hidden:
                Per-head hidden dimension
            no_heads:
                Number of attention heads
            gating:
                Whether the output should be gated using query data
        """
        super(Attention, self).__init__()

        self.c_q = c_q
        self.c_k = c_k
        self.c_v = c_v
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.gating = gating

        # DISCREPANCY: c_hidden is not the per-head channel dimension, as
        # stated in the supplement, but the overall channel dimension.

        self.linear_q = Linear(
            self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot"
        )
        self.linear_k = Linear(
            self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot"
        )
        self.linear_v = Linear(
            self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot"
        )
        self.linear_o = Linear(
            self.c_hidden * self.no_heads, self.c_q, init="final"
        )

        self.linear_g = None
        if self.gating:
            self.linear_g = Linear(
                self.c_q, self.c_hidden * self.no_heads, init="gating"
            )

        self.sigmoid = nn.Sigmoid()

    def _prep_qkv(self,
        q_x: torch.Tensor,
        kv_x: torch.Tensor,
        apply_scale: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # [*, Q/K/V, H * C_hidden]
        q = self.linear_q(q_x)
        k = self.linear_k(kv_x)
        v = self.linear_v(kv_x)

        # [*, Q/K, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.no_heads, -1))
        k = k.view(k.shape[:-1] + (self.no_heads, -1))
        v = v.view(v.shape[:-1] + (self.no_heads, -1))

        # [*, H, Q/K, C_hidden]
        q = q.transpose(-2, -3)
        k = k.transpose(-2, -3)
        v = v.transpose(-2, -3)

        if apply_scale:
            q /= math.sqrt(self.c_hidden)

        return q, k, v

    def _wrap_up(self,
        o: torch.Tensor,
        q_x: torch.Tensor,
    ) -> torch.Tensor:
        if self.linear_g is not None:
            g = self.sigmoid(self.linear_g(q_x))

            # [*, Q, H, C_hidden]
            g = g.view(g.shape[:-1] + (self.no_heads, -1))
            o = o * g

        # [*, Q, H * C_hidden]
        o = flatten_final_dims(o, 2)

        # [*, Q, C_q]
        o = self.linear_o(o)

        return o

    def forward(
        self,
        q_x: torch.Tensor,
        kv_x: torch.Tensor,
        biases: Optional[List[torch.Tensor]] = None,
        use_memory_efficient_kernel: bool = False,
        use_lma: bool = False,
        lma_q_chunk_size: int = DEFAULT_LMA_Q_CHUNK_SIZE,
        lma_kv_chunk_size: int = DEFAULT_LMA_KV_CHUNK_SIZE,
    ) -> torch.Tensor:
        """
        Args:
            q_x:
                [*, Q, C_q] query data
            kv_x:
                [*, K, C_k] key data
            biases:
                List of biases that broadcast to [*, H, Q, K]
            use_memory_efficient_kernel:
                Whether to use a custom memory-efficient attention kernel.
                This should be the default choice for most users. If none of
                the "use_<...>" flags are True, a stock PyTorch implementation
                is used instead
            use_lma:
                Whether to use low-memory attention (Rabe & Staats 2021). If
                none of the "use_<...>" flags are True, a stock PyTorch
                implementation is used instead
            lma_q_chunk_size:
                Query chunk size (for LMA)
            lma_kv_chunk_size:
                Key/Value chunk size (for LMA)
        Returns:
            [*, Q, C_q] attention update
        """
        if use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None):
            raise ValueError(
                "If use_lma is specified, lma_q_chunk_size and "
                "lma_kv_chunk_size must be provided"
            )

        attn_options = [use_memory_efficient_kernel, use_lma]
        if sum(attn_options) > 1:
            raise ValueError(
                "Choose at most one alternative attention algorithm"
            )

        if biases is None:
            biases = []

        q, k, v = self._prep_qkv(q_x, kv_x, apply_scale=True)

        # The custom attention kernel is not used under fp16 autocast.
        if is_fp16_enabled():
            use_memory_efficient_kernel = False

        if use_memory_efficient_kernel:
            if len(biases) > 2:
                raise ValueError(
                    "If use_memory_efficient_kernel is True, you may only "
                    "provide up to two bias terms"
                )
            o = attention_core(q, k, v, *((biases + [None] * 2)[:2]))
            o = o.transpose(-2, -3)
        elif use_lma:
            biases = [
                b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
                for b in biases
            ]
            o = _lma(q, k, v, biases, lma_q_chunk_size, lma_kv_chunk_size)
            o = o.transpose(-2, -3)
        else:
            o = _attention(q, k, v, biases)
            o = o.transpose(-2, -3)

        o = self._wrap_up(o, q_x)

        return o
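
# Illustrative sketch, not part of the original module: a typical call to Attention
# with a single mask-derived bias. The dimensions and the 1e9 "inf" value are
# placeholders chosen for the example.
def _example_attention_module():
    attn = Attention(c_q=64, c_k=64, c_v=64, c_hidden=16, no_heads=4)
    q_x = torch.randn(2, 32, 64)
    kv_x = torch.randn(2, 48, 64)
    mask = torch.ones(2, 48)
    mask_bias = (1e9 * (mask - 1))[..., None, None, :]  # [2, 1, 1, 48]; 0 where attended
    return attn(q_x, kv_x, biases=[mask_bias])          # [2, 32, 64]
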
class GlobalAttention(nn.Module):
    def __init__(self, c_in, c_hidden, no_heads, inf, eps):
        super(GlobalAttention, self).__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.inf = inf
        self.eps = eps

        self.linear_q = Linear(
            c_in, c_hidden * no_heads, bias=False, init="glorot"
        )
        self.linear_k = Linear(
            c_in, c_hidden, bias=False, init="glorot",
        )
        self.linear_v = Linear(
            c_in, c_hidden, bias=False, init="glorot",
        )
        self.linear_g = Linear(c_in, c_hidden * no_heads, init="gating")
        self.linear_o = Linear(c_hidden * no_heads, c_in, init="final")

        self.sigmoid = nn.Sigmoid()
    def forward(self,
        m: torch.Tensor,
        mask: torch.Tensor,
        use_lma: bool = False,
    ) -> torch.Tensor:
        # m: [*, N_res, N_seq, C_in], mask: [*, N_res, N_seq]

        # [*, N_res, C_in]
        q = torch.sum(m * mask.unsqueeze(-1), dim=-2) / (
            torch.sum(mask, dim=-1)[..., None] + self.eps
        )

        # [*, N_res, H * C_hidden]
        q = self.linear_q(q)
        q *= (self.c_hidden ** (-0.5))

        # [*, N_res, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.no_heads, -1))

        # [*, N_res, N_seq, C_hidden]
        k = self.linear_k(m)
        v = self.linear_v(m)

        # [*, N_res, 1, N_seq]
        bias = (self.inf * (mask - 1))[..., :, None, :]
        if not use_lma:
            # [*, N_res, H, N_seq]
            a = torch.matmul(
                q,
                k.transpose(-1, -2),  # [*, N_res, C_hidden, N_seq]
            )
            a += bias
            a = softmax_no_cast(a)

            # [*, N_res, H, C_hidden]
            o = torch.matmul(
                a,
                v,
            )
        else:
            o = _lma(
                q,
                k,
                v,
                [bias],
                DEFAULT_LMA_Q_CHUNK_SIZE,
                DEFAULT_LMA_KV_CHUNK_SIZE,
            )

        # [*, N_res, N_seq, H * C_hidden]
        g = self.sigmoid(self.linear_g(m))

        # [*, N_res, N_seq, H, C_hidden]
        g = g.view(g.shape[:-1] + (self.no_heads, -1))

        # [*, N_res, N_seq, H, C_hidden]
        o = o.unsqueeze(-3) * g

        # [*, N_res, N_seq, H * C_hidden]
        o = o.reshape(o.shape[:-2] + (-1,))

        # [*, N_res, N_seq, C_in]
        m = self.linear_o(o)

        return m
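
# Illustrative sketch, not part of the original module: GlobalAttention forms a single
# mask-averaged query per group, so attention over the N_seq dimension is linear
# rather than quadratic. Shapes follow the [*, N_res, N_seq, C_in] convention assumed
# in the comments above; all sizes are placeholders.
def _example_global_attention():
    ga = GlobalAttention(c_in=64, c_hidden=16, no_heads=4, inf=1e9, eps=1e-10)
    m = torch.randn(2, 32, 8, 64)   # [*, N_res, N_seq, C_in]
    mask = torch.ones(2, 32, 8)     # [*, N_res, N_seq]
    return ga(m, mask)              # [2, 32, 8, 64]
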
def _lma(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    biases: List[torch.Tensor],
    q_chunk_size: int,
    kv_chunk_size: int,
):
    no_q, no_kv = q.shape[-2], k.shape[-2]

    # [*, H, Q, C_hidden]
    o = q.new_zeros(q.shape)
    for q_s in range(0, no_q, q_chunk_size):
        q_chunk = q[..., q_s: q_s + q_chunk_size, :]
        large_bias_chunks = [
            b[..., q_s: q_s + q_chunk_size, :] for b in biases
        ]

        maxes = []
        weights = []
        values = []
        for kv_s in range(0, no_kv, kv_chunk_size):
            k_chunk = k[..., kv_s: kv_s + kv_chunk_size, :]
            v_chunk = v[..., kv_s: kv_s + kv_chunk_size, :]
            small_bias_chunks = [
                b[..., kv_s: kv_s + kv_chunk_size] for b in large_bias_chunks
            ]

            a = torch.einsum(
                "...hqd,...hkd->...hqk", q_chunk, k_chunk,
            )

            for b in small_bias_chunks:
                a += b

            max_a = torch.max(a, dim=-1, keepdim=True)[0]
            exp_a = torch.exp(a - max_a)
            exp_v = torch.einsum("...hvf,...hqv->...hqf", v_chunk, exp_a)

            maxes.append(max_a.detach().squeeze(-1))
            weights.append(torch.sum(exp_a, dim=-1))
            values.append(exp_v)

        chunk_max = torch.stack(maxes, dim=-3)
        chunk_weights = torch.stack(weights, dim=-3)
        chunk_values = torch.stack(values, dim=-4)

        global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0]
        max_diffs = torch.exp(chunk_max - global_max)
        chunk_values = chunk_values * max_diffs.unsqueeze(-1)
        chunk_weights = chunk_weights * max_diffs

        all_values = torch.sum(chunk_values, dim=-4)
        all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4)

        q_chunk_out = all_values / all_weights

        o[..., q_s: q_s + q_chunk_size, :] = q_chunk_out

    return o
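
# Illustrative sketch, not part of the original module: _lma processes queries and
# keys/values in chunks and merges the partial softmaxes with a running max
# (Rabe & Staats, 2021), so it should match _attention up to floating-point error
# while using much less peak memory. All sizes below are placeholders.
def _example_lma_vs_attention():
    q = torch.randn(1, 4, 128, 16) / math.sqrt(16)  # pre-scaled, as in Attention.forward
    k = torch.randn(1, 4, 256, 16)
    v = torch.randn(1, 4, 256, 16)
    bias = torch.zeros(1, 4, 128, 256)  # LMA biases must already be expanded to [*, H, Q, K]
    ref = _attention(q, k, v, [bias])
    lma = _lma(q, k, v, [bias], q_chunk_size=64, kv_chunk_size=64)
    return torch.allclose(ref, lma, atol=1e-5)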