Spaces:
Running
Running
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional | |
import torch | |
import torch.nn as nn | |
from dockformer.model.primitives import Linear, LayerNorm | |
class PairTransition(nn.Module):
    """
    Implements Algorithm 15.

    Normalizes the pair embedding, expands it to a hidden width of
    ``n * c_z`` channels, applies a ReLU, projects back to ``c_z`` channels
    and multiplies the result by the pair mask.
    """

    def __init__(self, c_z, n):
        """
        Args:
            c_z:
                Pair transition channel dimension
            n:
                Factor by which c_z is multiplied to obtain hidden channel
                dimension
        """
        super().__init__()

        self.c_z = c_z
        self.n = n

        # Width of the intermediate (expanded) representation.
        c_hidden = n * c_z

        # Submodules are registered in the same order as before so that
        # parameter / state-dict ordering is unchanged.
        self.layer_norm = LayerNorm(c_z)
        self.linear_1 = Linear(c_z, c_hidden, init="relu")
        self.relu = nn.ReLU()
        self.linear_2 = Linear(c_hidden, c_z, init="final")

    def _transition(self, z, mask):
        """
        Run the normalize -> expand -> ReLU -> project pipeline and mask it.

        Args:
            z:
                [*, N_res, N_res, C_z] pair embedding
            mask:
                Mask multiplied element-wise into the projected output;
                forward() passes it with shape [*, N_res, N_res, 1]
        Returns:
            [*, N_res, N_res, C_z] masked pair embedding update
        """
        # [*, N_res, N_res, C_z] -> [*, N_res, N_res, C_hidden]
        hidden = self.relu(self.linear_1(self.layer_norm(z)))

        # [*, N_res, N_res, C_hidden] -> [*, N_res, N_res, C_z]
        update = self.linear_2(hidden)

        return update * mask

    def forward(self,
        z: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            z:
                [*, N_res, N_res, C_z] pair embedding
            mask:
                Optional [*, N_res, N_res] pair mask. When omitted, a mask
                of ones (same dtype and device as z) is used.
        Returns:
            [*, N_res, N_res, C_z] pair embedding update
        """
        # DISCREPANCY: DeepMind forgets to apply the mask in this module.
        pair_mask = z.new_ones(z.shape[:-1]) if mask is None else mask

        # The mask gains a trailing singleton axis -> [*, N_res, N_res, 1]
        # so that it broadcasts over the channel dimension.
        return self._transition(z=z, mask=pair_mask.unsqueeze(-1))