FAPM / esm /modules.py
wenkai's picture
Upload 31 files
7409b0a verified
raw
history blame
13.9 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from .multihead_attention import MultiheadAttention # noqa
from .axial_attention import ColumnSelfAttention, RowSelfAttention
def gelu(x):
    """Apply the exact (erf-based) GELU activation element-wise.

    Computes ``x * 0.5 * (1 + erf(x / sqrt(2)))``.

    Note that OpenAI GPT uses a tanh approximation instead, which gives
    slightly different results:
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    half_x = x * 0.5
    one_plus_erf = 1.0 + torch.erf(x / math.sqrt(2.0))
    return half_x * one_plus_erf
def symmetrize(x):
    """Return ``x`` plus its transpose over the last two dimensions.

    Used for contact prediction, where the result must be symmetric in the
    final two axes.
    """
    transposed = x.transpose(-2, -1)
    return x + transposed
def apc(x):
    """Apply the average product correction (APC), used for contact prediction.

    For every matrix spanned by the last two dimensions, subtracts
    ``row_sum * col_sum / total_sum`` from each entry. The input tensor is
    not modified.
    """
    row_sums = x.sum(-1, keepdim=True)
    col_sums = x.sum(-2, keepdim=True)
    total = x.sum((-1, -2), keepdim=True)

    # The product is a fresh tensor, so dividing it in place saves one
    # allocation without touching the caller's data.
    correction = (row_sums * col_sums).div_(total)
    return x - correction
class ESM1LayerNorm(nn.Module):
    """Layer normalization in the TF style (eps inside the sqrt)."""

    def __init__(self, hidden_size, eps=1e-12, affine=True):
        """Construct a layernorm layer in the TF style (eps inside the sqrt).

        Args:
            hidden_size: an int, or an iterable of ints, giving the shape of
                the trailing dimensions to normalize over.
            eps: constant added to the variance before taking the square root.
            affine: when truthy, learn an element-wise scale and bias.
        """
        super().__init__()
        if isinstance(hidden_size, int):
            self.hidden_size = (hidden_size,)
        else:
            self.hidden_size = tuple(hidden_size)
        self.eps = eps
        self.affine = bool(affine)
        if not self.affine:
            self.weight, self.bias = None, None
        else:
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.bias = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x):
        """Normalize ``x`` over its trailing ``len(self.hidden_size)`` dims."""
        # Trailing axes to reduce over: (-1, -2, ..., -len(hidden_size)).
        reduce_dims = tuple(range(-1, -len(self.hidden_size) - 1, -1))
        centered = x - x.mean(reduce_dims, keepdim=True)
        variance = centered.pow(2).mean(reduce_dims, keepdim=True)
        normed = centered / torch.sqrt(variance + self.eps)
        if not self.affine:
            return normed
        return (self.weight * normed) + self.bias
# Use apex's fused layer norm when it is installed; otherwise fall back to the
# stock torch.nn.LayerNorm under the same name.
try:
    from apex.normalization import FusedLayerNorm as _FusedLayerNorm
except ImportError:
    from torch.nn import LayerNorm as ESM1bLayerNorm
else:

    class ESM1bLayerNorm(_FusedLayerNorm):
        """FusedLayerNorm that runs CUDA inputs under the input's own device."""

        @torch.jit.unused
        def forward(self, x):
            if x.is_cuda:
                # Make the input's device current for the duration of the call.
                with torch.cuda.device(x.device):
                    return super().forward(x)
            return super().forward(x)
class TransformerLayer(nn.Module):
    """Transformer layer block.

    Self-attention followed by a two-layer GELU feed-forward network. Each
    sub-layer is preceded by a layer norm and followed by a residual add.
    """

    def __init__(
        self,
        embed_dim,
        ffn_embed_dim,
        attention_heads,
        add_bias_kv=True,
        use_esm1b_layer_norm=False,
        use_rotary_embeddings: bool = False,
    ):
        """Store the layer sizes and build the sub-modules.

        Args:
            embed_dim: width of the layer input and output.
            ffn_embed_dim: width of the hidden feed-forward projection.
            attention_heads: head count handed to ``MultiheadAttention``.
            add_bias_kv: forwarded to ``MultiheadAttention``.
            use_esm1b_layer_norm: pick ``ESM1bLayerNorm`` instead of
                ``ESM1LayerNorm`` for both layer norms.
            use_rotary_embeddings: forwarded to ``MultiheadAttention``.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.ffn_embed_dim = ffn_embed_dim
        self.attention_heads = attention_heads
        self.use_rotary_embeddings = use_rotary_embeddings
        self._init_submodules(add_bias_kv, use_esm1b_layer_norm)

    def _init_submodules(self, add_bias_kv, use_esm1b_layer_norm):
        """Create the attention, layer-norm and feed-forward sub-modules."""
        if use_esm1b_layer_norm:
            norm_cls = ESM1bLayerNorm
        else:
            norm_cls = ESM1LayerNorm

        self.self_attn = MultiheadAttention(
            self.embed_dim,
            self.attention_heads,
            add_bias_kv=add_bias_kv,
            add_zero_attn=False,
            use_rotary_embeddings=self.use_rotary_embeddings,
        )
        self.self_attn_layer_norm = norm_cls(self.embed_dim)

        self.fc1 = nn.Linear(self.embed_dim, self.ffn_embed_dim)
        self.fc2 = nn.Linear(self.ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = norm_cls(self.embed_dim)

    def forward(
        self, x, self_attn_mask=None, self_attn_padding_mask=None, need_head_weights=False
    ):
        """Run the block and return ``(output, attn)``.

        ``attn`` is the second value returned by the attention module, which
        is always called with ``need_weights=True``.
        """
        # Self-attention sub-layer: norm -> attention -> residual add.
        normed = self.self_attn_layer_norm(x)
        attn_out, attn = self.self_attn(
            query=normed,
            key=normed,
            value=normed,
            key_padding_mask=self_attn_padding_mask,
            need_weights=True,
            need_head_weights=need_head_weights,
            attn_mask=self_attn_mask,
        )
        hidden = x + attn_out

        # Feed-forward sub-layer: norm -> fc1 -> gelu -> fc2 -> residual add.
        ffn_out = self.final_layer_norm(hidden)
        ffn_out = self.fc2(gelu(self.fc1(ffn_out)))
        return hidden + ffn_out, attn
class AxialTransformerLayer(nn.Module):
    """Implements an Axial MSA Transformer block.

    The block chains row self-attention, column self-attention and a
    feed-forward network, each wrapped in a ``NormalizedResidualBlock``.
    """

    def __init__(
        self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        max_tokens_per_msa: int = 2**14,
    ) -> None:
        """Build the three wrapped sub-layers.

        Note: ``attention_dropout`` is accepted but not referenced anywhere
        in this class; both attention modules receive ``dropout`` instead.
        """
        super().__init__()

        # build_residual() reads these two attributes, so set them first.
        self.embedding_dim = embedding_dim
        self.dropout_prob = dropout

        attention_kwargs = dict(dropout=dropout, max_tokens_per_msa=max_tokens_per_msa)
        row_attn_layer = RowSelfAttention(
            embedding_dim, num_attention_heads, **attention_kwargs
        )
        column_attn_layer = ColumnSelfAttention(
            embedding_dim, num_attention_heads, **attention_kwargs
        )
        ffn_layer = FeedForwardNetwork(
            embedding_dim,
            ffn_embedding_dim,
            activation_dropout=activation_dropout,
            max_tokens_per_msa=max_tokens_per_msa,
        )

        self.row_self_attention = self.build_residual(row_attn_layer)
        self.column_self_attention = self.build_residual(column_attn_layer)
        self.feed_forward_layer = self.build_residual(ffn_layer)

    def build_residual(self, layer: nn.Module):
        """Wrap ``layer`` in a ``NormalizedResidualBlock`` using this block's settings."""
        return NormalizedResidualBlock(layer, self.embedding_dim, self.dropout_prob)

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_head_weights: bool = False,
    ):
        """Apply row attention, column attention, then the feed-forward net.

        Returns:
            ``(x, column_attn, row_attn)`` when ``need_head_weights`` is true,
            otherwise just ``x``.
        """
        mask_kwargs = {
            "self_attn_mask": self_attn_mask,
            "self_attn_padding_mask": self_attn_padding_mask,
        }
        x, row_attn = self.row_self_attention(x, **mask_kwargs)
        x, column_attn = self.column_self_attention(x, **mask_kwargs)
        x = self.feed_forward_layer(x)

        if not need_head_weights:
            return x
        return x, column_attn, row_attn
class LearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    Padding ids are ignored by either offsetting based on padding_idx
    or by setting padding_idx to None and ensuring that the appropriate
    position ids are passed to the forward function.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int]):
        """Create the embedding table.

        Args:
            num_embeddings: maximum number of positions (stored as
                ``max_positions``).
            embedding_dim: size of each position embedding.
            padding_idx: index of the padding token, or ``None``. When set,
                the table is enlarged by ``padding_idx + 1`` rows because
                real positions are numbered from ``padding_idx + 1``.
        """
        if padding_idx is not None:
            num_embeddings_ = num_embeddings + padding_idx + 1
        else:
            num_embeddings_ = num_embeddings
        super().__init__(num_embeddings_, embedding_dim, padding_idx)
        self.max_positions = num_embeddings

    def forward(self, input: torch.Tensor):
        """Input is expected to be of size [bsz x seqlen].

        With a ``padding_idx``, ``input`` holds token ids: non-padding tokens
        are numbered ``padding_idx + 1, padding_idx + 2, ...`` along dim 1 and
        padding tokens map to ``padding_idx``. With ``padding_idx=None``,
        ``input`` is taken to already contain the position ids and is looked
        up directly.

        Raises:
            ValueError: if the sequence length exceeds ``max_positions``.
        """
        if input.size(1) > self.max_positions:
            raise ValueError(
                f"Sequence length {input.size(1)} above maximum "
                f" sequence length of {self.max_positions}"
            )
        if self.padding_idx is None:
            # No padding token to offset from: the caller supplies position
            # ids. (Comparing against / adding None would raise TypeError.)
            positions = input
        else:
            mask = input.ne(self.padding_idx).int()
            # The running count of non-padding tokens gives 1-based positions;
            # multiplying by the mask zeroes padding slots before the offset.
            positions = (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + self.padding_idx
        return F.embedding(
            positions,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
class SinusoidalPositionalEmbedding(nn.Module):
    """Fixed (non-learned) sinusoidal positional embeddings.

    The embedding table is built lazily and regrown whenever a longer
    sequence is seen. Non-padding tokens are numbered from
    ``padding_idx + 1``; padding tokens map to the all-zero row
    ``padding_idx``.
    """

    def __init__(self, embed_dim, padding_idx, learned=False):
        """
        Args:
            embed_dim: size of each position embedding.
            padding_idx: token id treated as padding.
            learned: accepted for interface compatibility; not used.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.padding_idx = padding_idx
        # Zero-initialised float32 buffer used only so the cached table can
        # follow the module's dtype and device. (An uninitialised tensor here
        # would put arbitrary memory contents into state_dict.)
        self.register_buffer("_float_tensor", torch.zeros(1, dtype=torch.float))
        self.weights = None

    def forward(self, x):
        """Return detached embeddings of shape ``[bsz, seq_len, embed_dim]``
        for the ``[bsz, seq_len]`` token tensor ``x``."""
        bsz, seq_len = x.shape
        # The largest position index used is padding_idx + seq_len.
        max_pos = self.padding_idx + 1 + seq_len
        if self.weights is None or max_pos > self.weights.size(0):
            self.weights = self.get_embedding(max_pos)
        self.weights = self.weights.type_as(self._float_tensor)

        positions = self.make_positions(x)
        return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()

    def make_positions(self, x):
        """Map tokens to position ids; padding tokens get ``padding_idx``."""
        mask = x.ne(self.padding_idx)
        positions = torch.arange(x.size(1), device=x.device).expand_as(x) + self.padding_idx + 1
        return positions * mask.long() + self.padding_idx * (1 - mask.long())

    def get_embedding(self, num_embeddings):
        """Build a ``[num_embeddings, embed_dim]`` table of sin/cos features."""
        half_dim = self.embed_dim // 2
        # For embed_dim 2 or 3, half_dim - 1 is 0 and a plain division would
        # raise ZeroDivisionError. With a single frequency the scale is
        # multiplied by 0 below, so clamping the divisor to 1 is safe.
        emb = math.log(10000) / max(half_dim - 1, 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if self.embed_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if self.padding_idx is not None:
            emb[self.padding_idx, :] = 0
        return emb
class RobertaLMHead(nn.Module):
    """Head for masked language modeling."""

    def __init__(self, embed_dim, output_dim, weight):
        """
        Args:
            embed_dim: size of the incoming features.
            output_dim: length of the learned output bias.
            weight: matrix used for the final projection; it is stored as
                given rather than copied.
        """
        super().__init__()
        self.dense = nn.Linear(embed_dim, embed_dim)
        self.layer_norm = ESM1bLayerNorm(embed_dim)
        self.weight = weight
        self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, features):
        """Transform ``features`` (dense -> gelu -> layer norm) and project them."""
        hidden = self.layer_norm(gelu(self.dense(features)))
        # project back to size of vocabulary with bias
        return F.linear(hidden, self.weight) + self.bias
class ContactPredictionHead(nn.Module):
    """Performs symmetrization, apc, and computes a logistic regression on the output features"""

    def __init__(
        self,
        in_features: int,
        prepend_bos: bool,
        append_eos: bool,
        bias=True,
        eos_idx: Optional[int] = None,
    ):
        """
        Args:
            in_features: number of input channels of the regression
                (``layers * heads`` of the attention maps fed to ``forward``).
            prepend_bos: whether the first position should be stripped.
            append_eos: whether eos positions should be masked out and the
                last position stripped.
            bias: whether the regression layer has a bias term.
            eos_idx: id of the eos token; required when ``append_eos`` is set.

        Raises:
            ValueError: if ``append_eos`` is set but ``eos_idx`` is ``None``.
        """
        super().__init__()
        if append_eos and eos_idx is None:
            raise ValueError("Using an alphabet with eos token, but no eos token was passed in.")
        self.in_features = in_features
        self.prepend_bos = prepend_bos
        self.append_eos = append_eos
        self.eos_idx = eos_idx
        self.regression = nn.Linear(in_features, 1, bias)
        self.activation = nn.Sigmoid()

    def forward(self, tokens, attentions):
        """Turn stacked attention maps into per-pair probabilities.

        ``attentions`` is unpacked as (batch, layers, heads, seqlen, seqlen).
        ``tokens`` is presumably (batch, seqlen) -- confirm against the caller.
        """
        if self.append_eos:
            # remove eos token attentions: zero every (i, j) pair where either
            # token is eos, then drop the last row and column.
            not_eos = tokens.ne(self.eos_idx).to(attentions)
            pair_mask = not_eos.unsqueeze(1) * not_eos.unsqueeze(2)
            attentions = attentions * pair_mask[:, None, None, :, :]
            attentions = attentions[..., :-1, :-1]
        if self.prepend_bos:
            # remove cls token attentions
            attentions = attentions[..., 1:, 1:]

        batch_size, layers, heads, seqlen, _ = attentions.size()
        # features: B x C x T x T
        features = attentions.view(batch_size, layers * heads, seqlen, seqlen)
        # Only the device is matched to the regression weights here.
        features = features.to(self.regression.weight.device)
        features = apc(symmetrize(features))
        # Channels last, so the linear layer sees one feature vector per pair.
        features = features.permute(0, 2, 3, 1)
        logits = self.regression(features).squeeze(3)
        return self.activation(logits)
class NormalizedResidualBlock(nn.Module):
    """Wrap a layer as ``x + dropout(layer(layer_norm(x)))``."""

    def __init__(
        self,
        layer: nn.Module,
        embedding_dim: int,
        dropout: float = 0.1,
    ):
        """
        Args:
            layer: module to wrap; it may return a tensor or a tuple whose
                first element is the tensor to add back to the input.
            embedding_dim: size handed to the layer norm.
            dropout: dropout probability applied to the layer's output.
        """
        super().__init__()
        self.embedding_dim = embedding_dim

        self.layer = layer
        self.dropout_module = nn.Dropout(dropout)
        self.layer_norm = ESM1bLayerNorm(self.embedding_dim)

    def forward(self, x, *args, **kwargs):
        """Apply the wrapped layer with pre-norm, dropout and a residual add.

        Extra positional and keyword arguments go to the wrapped layer. If
        the layer returns a tuple, a tuple is returned with the first element
        replaced by the residual sum and the remaining elements untouched.
        """
        outputs = self.layer(self.layer_norm(x), *args, **kwargs)
        if isinstance(outputs, tuple):
            hidden, *extras = outputs
            return (x + self.dropout_module(hidden),) + tuple(extras)
        return x + self.dropout_module(outputs)
class FeedForwardNetwork(nn.Module):
    """Two-layer feed-forward network: linear -> GELU -> dropout -> linear."""

    def __init__(
        self,
        embedding_dim: int,
        ffn_embedding_dim: int,
        activation_dropout: float = 0.1,
        max_tokens_per_msa: int = 2**14,
    ):
        """
        Args:
            embedding_dim: input and output feature size.
            ffn_embedding_dim: hidden feature size.
            activation_dropout: dropout probability applied after the
                activation.
            max_tokens_per_msa: stored on the instance; not used by
                ``forward``.
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.ffn_embedding_dim = ffn_embedding_dim
        self.max_tokens_per_msa = max_tokens_per_msa

        self.activation_fn = nn.GELU()
        self.activation_dropout_module = nn.Dropout(activation_dropout)
        self.fc1 = nn.Linear(embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, embedding_dim)

    def forward(self, x):
        """Project up, apply GELU and dropout, then project back down."""
        hidden = self.fc1(x)
        hidden = self.activation_fn(hidden)
        hidden = self.activation_dropout_module(hidden)
        return self.fc2(hidden)