Spaces:
Running
Running
# Copyright 2021 AlQuraishi Laboratory | |
# Copyright 2021 DeepMind Technologies Limited | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import torch | |
import torch.nn as nn | |
from torch.nn import Parameter | |
from dockformer.model.primitives import Linear, LayerNorm | |
from dockformer.utils.loss import ( | |
compute_plddt, | |
compute_tm, | |
compute_predicted_aligned_error, | |
) | |
from dockformer.utils.precision_utils import is_fp16_enabled | |
class AuxiliaryHeads(nn.Module):
    """
    Bundles the auxiliary prediction heads and runs each of them on the trunk
    outputs.

    ``config`` maps each head key ("lddt", "distogram", "affinity_2d",
    "affinity_1d", "affinity_cls", "binding_site", "inter_contact") to the
    keyword arguments of the corresponding head's constructor.
    """

    def __init__(self, config):
        super(AuxiliaryHeads, self).__init__()
        # Registration order fixes the order in which the head constructors
        # run (relevant if any initialiser draws from the global RNG) and the
        # state_dict key order; keep it stable.
        self.plddt = PerResidueLDDTCaPredictor(**config["lddt"])
        self.distogram = DistogramHead(**config["distogram"])
        self.affinity_2d = Affinity2DPredictor(**config["affinity_2d"])
        self.affinity_1d = Affinity1DPredictor(**config["affinity_1d"])
        self.affinity_cls = AffinityClsTokenPredictor(**config["affinity_cls"])
        self.binding_site = BindingSitePredictor(**config["binding_site"])
        self.inter_contact = InterContactHead(**config["inter_contact"])
        self.config = config

    def forward(self, outputs, inter_mask, affinity_mask):
        """
        Args:
            outputs:
                Trunk outputs; this reads outputs["sm"]["single"],
                outputs["pair"] and outputs["single"]
            inter_mask:
                Pair mask handed to the 2D affinity head
            affinity_mask:
                Token mask handed to the CLS-token affinity head
        Returns:
            Dict with every head's logits plus "plddt", derived from the lDDT
            logits via compute_plddt
        """
        lddt_logits = self.plddt(outputs["sm"]["single"])
        # Required for relaxation later on
        plddt = compute_plddt(lddt_logits)

        pair = outputs["pair"]
        distogram_logits = self.distogram(pair)

        single = outputs["single"]
        inter_contact_logits = self.inter_contact(single, pair)

        # The 2D affinity head consumes the inter-contact logits computed above.
        return {
            "lddt_logits": lddt_logits,
            "plddt": plddt,
            "distogram_logits": distogram_logits,
            "inter_contact_logits": inter_contact_logits,
            "affinity_2d_logits": self.affinity_2d(pair, inter_contact_logits, inter_mask),
            "affinity_1d_logits": self.affinity_1d(single),
            "affinity_cls_logits": self.affinity_cls(single, affinity_mask),
            "binding_site_logits": self.binding_site(single),
        }
class Affinity2DPredictor(nn.Module):
    """
    Predicts binned affinity logits from the pair representation.

    Every pair position gets a scalar score computed from its pair embedding
    concatenated with its inter-contact logit. The scores are softmaxed over
    the positions allowed by ``inter_pair_mask`` and used to pool a linearly
    embedded copy of the pair representation into one vector per batch entry.
    That vector is then projected to ``num_bins`` logits.
    """

    def __init__(self, c_z, num_bins, **kwargs):
        """
        Args:
            c_z:
                Pair representation channel dimension
            num_bins:
                Number of output affinity bins
            **kwargs:
                Ignored. Accepted so that extra config keys do not raise,
                matching the other auxiliary heads in this file.
        """
        super(Affinity2DPredictor, self).__init__()
        self.c_z = c_z
        self.weight_linear = Linear(self.c_z + 1, 1)
        self.embed_linear = Linear(self.c_z, self.c_z)
        self.bins_linear = Linear(self.c_z, num_bins)

    def forward(self, z, inter_contacts_logits, inter_pair_mask):
        """
        Args:
            z:
                [*, N, M, c_z] pair embedding
            inter_contacts_logits:
                [*, N, M, 1] inter-contact logits (concatenated onto z, so the
                result must have c_z + 1 channels)
            inter_pair_mask:
                Mask with N*M entries per batch element (e.g. [*, N, M]);
                nonzero entries are the positions pooled over
        Returns:
            [*, num_bins] affinity logits. A batch element whose mask is empty
            pools to a zero vector, so its logits equal the bins_linear bias.
        """
        z_with_inter_contacts = torch.cat((z, inter_contacts_logits), dim=-1)  # [*, N, M, c_z + 1]
        weights = self.weight_linear(z_with_inter_contacts)  # [*, N, M, 1]
        x = self.embed_linear(z)  # [*, N, M, c_z]

        # Merge the two pair axes into one pooling axis; works for any number
        # of leading batch dims.
        flat_weights = weights.flatten(-3, -2)  # [*, N*M, 1]
        flat_x = x.flatten(-3, -2)  # [*, N*M, c_z]
        flat_mask = inter_pair_mask.reshape(flat_weights.shape).bool()  # [*, N*M, 1]

        flat_weights = flat_weights.masked_fill(~flat_mask, float("-inf"))
        flat_weights = torch.nn.functional.softmax(flat_weights, dim=-2)  # [*, N*M, 1]

        # A fully-masked element is all -inf and softmaxes to NaN. Zero exactly
        # those elements instead of applying a blanket nan_to_num, which would
        # also silently hide NaNs coming from z itself.
        has_valid = flat_mask.any(dim=-2, keepdim=True)  # [*, 1, 1]
        flat_weights = torch.where(has_valid, flat_weights, torch.zeros_like(flat_weights))

        weighted_sum = torch.sum(flat_weights * flat_x, dim=-2)  # [*, c_z]
        return self.bins_linear(weighted_sum)
class Affinity1DPredictor(nn.Module):
    """
    Predicts binned affinity logits from the single representation.

    Each token is projected, the projections are averaged over dim 1, and the
    average is mapped to ``num_bins`` logits.
    """

    def __init__(self, c_s, num_bins, **kwargs):
        super(Affinity1DPredictor, self).__init__()
        self.c_s = c_s
        self.linear1 = Linear(self.c_s, self.c_s, init="final")
        self.linear2 = Linear(self.c_s, num_bins, init="final")

    def forward(self, s):
        """
        Args:
            s:
                [B, N, c_s] single representation (assumed layout; dim 1 is
                the one averaged over)
        Returns:
            [B, num_bins] logits
        """
        projected = self.linear1(s)
        # NOTE(review): unmasked mean over dim 1, so any padded positions in s
        # contribute to the average.
        pooled = projected.mean(dim=1)
        return self.linear2(pooled)
class AffinityClsTokenPredictor(nn.Module):
    """
    Predicts binned affinity logits from the masked sum of token embeddings.
    """

    def __init__(self, c_s, num_bins, **kwargs):
        super(AffinityClsTokenPredictor, self).__init__()
        self.c_s = c_s
        self.linear = Linear(self.c_s, num_bins, init="final")

    def forward(self, s, affinity_mask):
        """
        Args:
            s:
                [B, N, c_s] single representation
            affinity_mask:
                [B, N] mask; presumably marks the affinity (CLS) token(s) --
                TODO confirm against the caller
        Returns:
            [B, num_bins] logits
        """
        # Zero unselected tokens, then sum over dim 1. With a one-hot mask this
        # is exactly the selected token's embedding.
        selected = s * affinity_mask.unsqueeze(-1)
        pooled = selected.sum(dim=1)
        return self.linear(pooled)
class BindingSitePredictor(nn.Module):
    """
    Per-token binding-site logits: one linear projection of the single
    representation from c_s to c_out channels.
    """

    def __init__(self, c_s, c_out, **kwargs):
        super(BindingSitePredictor, self).__init__()
        self.c_s = c_s
        self.c_out = c_out
        self.linear = Linear(self.c_s, self.c_out, init="final")

    def forward(self, s):
        """
        Args:
            s:
                [*, N, c_s] single representation
        Returns:
            [*, N, c_out] logits
        """
        logits = self.linear(s)
        return logits
class InterContactHead(nn.Module):
    """
    Predicts per-pair logits from the pair representation concatenated with
    the single embeddings of both pair endpoints.
    """

    def __init__(self, c_s, c_z, c_out, **kwargs):
        """
        Args:
            c_s:
                Single representation channel dimension
            c_z:
                Pair representation channel dimension
            c_out:
                Number of output logits per pair (1 for a boolean contact)
        """
        super(InterContactHead, self).__init__()
        self.c_s = c_s
        self.c_z = c_z
        self.c_out = c_out
        self.linear = Linear(2 * self.c_s + self.c_z, self.c_out, init="final")

    def forward(self, s, z):
        """
        Args:
            s:
                [B, N, c_s] single representation
            z:
                [B, N, N, c_z] pair representation
        Returns:
            [B, N, N, c_out] logits. Entry (b, i, j) is computed from
            s[b, i], s[b, j] and z[b, i, j].
        """
        batch_size, n_tokens, width = s.shape
        grid = (batch_size, n_tokens, n_tokens, width)
        row_feats = s[:, :, None, :].expand(*grid)  # row_feats[b, i, j] = s[b, i]
        col_feats = s[:, None, :, :].expand(*grid)  # col_feats[b, i, j] = s[b, j]
        features = torch.cat([row_feats, col_feats, z], dim=-1)
        return self.linear(features)
class PerResidueLDDTCaPredictor(nn.Module):
    """
    Per-residue lDDT logits: LayerNorm followed by a three-layer MLP with ReLU
    between the layers.
    """

    def __init__(self, no_bins, c_in, c_hidden):
        super(PerResidueLDDTCaPredictor, self).__init__()
        self.no_bins = no_bins
        self.c_in = c_in
        self.c_hidden = c_hidden
        # Attribute names are part of the state_dict keys; keep them.
        self.layer_norm = LayerNorm(self.c_in)
        self.linear_1 = Linear(self.c_in, self.c_hidden, init="relu")
        self.linear_2 = Linear(self.c_hidden, self.c_hidden, init="relu")
        self.linear_3 = Linear(self.c_hidden, self.no_bins, init="final")
        self.relu = nn.ReLU()

    def forward(self, s):
        """
        Args:
            s:
                [*, N, c_in] single representation
        Returns:
            [*, N, no_bins] lDDT logits
        """
        hidden = self.relu(self.linear_1(self.layer_norm(s)))
        hidden = self.relu(self.linear_2(hidden))
        return self.linear_3(hidden)
class DistogramHead(nn.Module):
    """
    Computes distogram logits from the pair representation.
    For use in computation of distogram loss, subsection 1.9.8
    """

    def __init__(self, c_z, no_bins, **kwargs):
        """
        Args:
            c_z:
                Pair representation channel dimension
            no_bins:
                Number of distogram bins
        """
        super(DistogramHead, self).__init__()
        self.c_z = c_z
        self.no_bins = no_bins
        self.linear = Linear(self.c_z, self.no_bins, init="final")

    def _forward(self, z):
        """
        Args:
            z:
                [*, N_res, N_res, C_z] pair embedding
        Returns:
            [*, N_res, N_res, no_bins] distogram logits (unnormalised),
            symmetric in the two residue axes
        """
        projected = self.linear(z)
        # Adding the (i, j) <-> (j, i) transpose makes logits[i, j] == logits[j, i].
        return projected + projected.transpose(-3, -2)

    def forward(self, z):
        if not is_fp16_enabled():
            return self._forward(z)
        # In fp16 mode, run this head in fp32 with autocast switched off.
        with torch.cuda.amp.autocast(enabled=False):
            return self._forward(z.float())