Commit: add code

This view is limited to 50 files because the commit contains too many changes.
- Dockerfile +7 -5
- dockformer/__init__.py +6 -0
- dockformer/config.py +358 -0
- dockformer/data/data_modules.py +643 -0
- dockformer/data/data_pipeline.py +503 -0
- dockformer/data/data_transforms.py +731 -0
- dockformer/data/errors.py +22 -0
- dockformer/data/ligand_features.py +66 -0
- dockformer/data/parsers.py +53 -0
- dockformer/data/protein_features.py +71 -0
- dockformer/data/utils.py +54 -0
- dockformer/model/__init__.py +0 -0
- dockformer/model/dropout.py +69 -0
- dockformer/model/embedders.py +346 -0
- dockformer/model/evoformer.py +468 -0
- dockformer/model/heads.py +260 -0
- dockformer/model/model.py +318 -0
- dockformer/model/pair_transition.py +81 -0
- dockformer/model/primitives.py +598 -0
- dockformer/model/single_attention.py +184 -0
- dockformer/model/structure_module.py +837 -0
- dockformer/model/torchscript.py +171 -0
- dockformer/model/triangular_attention.py +104 -0
- dockformer/model/triangular_multiplicative_update.py +173 -0
- dockformer/resources/__init__.py +0 -0
- dockformer/resources/stereo_chemical_props.txt +345 -0
- dockformer/utils/__init__.py +0 -0
- dockformer/utils/callbacks.py +15 -0
- dockformer/utils/checkpointing.py +78 -0
- dockformer/utils/config_tools.py +32 -0
- dockformer/utils/consts.py +25 -0
- dockformer/utils/exponential_moving_average.py +71 -0
- dockformer/utils/feats.py +174 -0
- dockformer/utils/geometry/__init__.py +28 -0
- dockformer/utils/geometry/quat_rigid.py +38 -0
- dockformer/utils/geometry/rigid_matrix_vector.py +181 -0
- dockformer/utils/geometry/rotation_matrix.py +208 -0
- dockformer/utils/geometry/test_utils.py +97 -0
- dockformer/utils/geometry/utils.py +22 -0
- dockformer/utils/geometry/vector.py +261 -0
- dockformer/utils/kernel/__init__.py +0 -0
- dockformer/utils/kernel/attention_core.py +107 -0
- dockformer/utils/kernel/csrc/compat.h +11 -0
- dockformer/utils/kernel/csrc/softmax_cuda.cpp +44 -0
- dockformer/utils/kernel/csrc/softmax_cuda_kernel.cu +241 -0
- dockformer/utils/kernel/csrc/softmax_cuda_stub.cpp +36 -0
- dockformer/utils/logger.py +80 -0
- dockformer/utils/loss.py +1171 -0
- dockformer/utils/lr_schedulers.py +82 -0
- dockformer/utils/precision_utils.py +23 -0
Dockerfile
CHANGED

```diff
@@ -6,17 +6,19 @@ WORKDIR /usr/src/app
 COPY --link --chown=1000 ./ /usr/src/app
 COPY . .
 
-
 # install dependcies
-RUN conda install -y pandas numpy scikit-learn
-RUN pip install --no-cache-dir -r requirements.txt
+# RUN conda install -y pandas numpy scikit-learn
+# RUN pip install --no-cache-dir -r requirements.txt
 
-#
+# Create Conda environment from env.yaml
+RUN conda env create -f env.yml
 
+#if you need to download executable and run them switch to the default non-root user
 USER user
 
 #do not modify below
 EXPOSE 7860
 ENV GRADIO_SERVER_NAME="0.0.0.0"
 
-CMD ["python", "inference_app.py"]
+# CMD ["python", "inference_app.py"]
+CMD ["conda", "run", "-n", "dockformer-venv", "python", "inference_app.py"]
```
dockformer/__init__.py
ADDED

```python
from . import model
from . import utils
from . import data
from . import resources

__all__ = ["model", "utils", "data", "resources"]
```
dockformer/config.py
ADDED

```python
import copy
import ml_collections as mlc

from dockformer.utils.config_tools import set_inf, enforce_config_constraints


def model_config(
    name,
    train=False,
    low_prec=False,
    long_sequence_inference=False
):
    c = copy.deepcopy(config)
    # TRAINING PRESETS
    if name == "initial_training":
        # AF2 Suppl. Table 4, "initial training" setting

        pass
    elif name == "finetune_affinity":
        c.loss.affinity2d.weight = 0.5
        c.loss.affinity1d.weight = 0.5
        c.loss.binding_site.weight = 0.5
        c.loss.positions_inter_distogram.weight = 0.5  # this is not essential given fape?
    else:
        raise ValueError("Invalid model name")

    c.globals.use_lma = False

    if long_sequence_inference:
        assert(not train)
        c.globals.use_lma = True

    if train:
        c.globals.blocks_per_ckpt = 1
        c.globals.use_lma = False

    if low_prec:
        c.globals.eps = 1e-4
        # If we want exact numerical parity with the original, inf can't be
        # a global constant
        set_inf(c, 1e4)

    enforce_config_constraints(c)

    return c


c_z = mlc.FieldReference(128, field_type=int)
c_m = mlc.FieldReference(256, field_type=int)
c_t = mlc.FieldReference(64, field_type=int)
c_e = mlc.FieldReference(64, field_type=int)
c_s = mlc.FieldReference(384, field_type=int)


blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
aux_distogram_bins = mlc.FieldReference(64, field_type=int)
aux_affinity_bins = mlc.FieldReference(32, field_type=int)
eps = mlc.FieldReference(1e-8, field_type=float)

NUM_RES = "num residues placeholder"
NUM_LIG_ATOMS = "num ligand atoms placeholder"
NUM_TOKEN = "num tokens placeholder"


config = mlc.ConfigDict(
    {
        "data": {
            "common": {
                "feat": {
                    "aatype": [NUM_TOKEN],
                    "all_atom_mask": [NUM_TOKEN, None],
                    "all_atom_positions": [NUM_TOKEN, None, None],
                    "atom14_alt_gt_exists": [NUM_TOKEN, None],
                    "atom14_alt_gt_positions": [NUM_TOKEN, None, None],
                    "atom14_atom_exists": [NUM_TOKEN, None],
                    "atom14_atom_is_ambiguous": [NUM_TOKEN, None],
                    "atom14_gt_exists": [NUM_TOKEN, None],
                    "atom14_gt_positions": [NUM_TOKEN, None, None],
                    "atom37_atom_exists": [NUM_TOKEN, None],
                    "backbone_rigid_mask": [NUM_TOKEN],
                    "backbone_rigid_tensor": [NUM_TOKEN, None, None],
                    "chi_angles_sin_cos": [NUM_TOKEN, None, None],
                    "chi_mask": [NUM_TOKEN, None],
                    "no_recycling_iters": [],
                    "pseudo_beta": [NUM_TOKEN, None],
                    "pseudo_beta_mask": [NUM_TOKEN],
                    "residue_index": [NUM_TOKEN],
                    "in_chain_residue_index": [NUM_TOKEN],
                    "chain_index": [NUM_TOKEN],
                    "residx_atom14_to_atom37": [NUM_TOKEN, None],
                    "residx_atom37_to_atom14": [NUM_TOKEN, None],
                    "resolution": [],
                    "rigidgroups_alt_gt_frames": [NUM_TOKEN, None, None, None],
                    "rigidgroups_group_exists": [NUM_TOKEN, None],
                    "rigidgroups_group_is_ambiguous": [NUM_TOKEN, None],
                    "rigidgroups_gt_exists": [NUM_TOKEN, None],
                    "rigidgroups_gt_frames": [NUM_TOKEN, None, None, None],
                    "seq_length": [],
                    "token_mask": [NUM_TOKEN],
                    "target_feat": [NUM_TOKEN, None],
                    "use_clamped_fape": [],
                },
                "max_recycling_iters": 1,
                "unsupervised_features": [
                    "aatype",
                    "residue_index",
                    "in_chain_residue_index",
                    "chain_index",
                    "seq_length",
                    "no_recycling_iters",
                    "all_atom_mask",
                    "all_atom_positions",
                ],
            },
            "supervised": {
                "clamp_prob": 0.9,
                "supervised_features": [
                    "resolution",
                    "use_clamped_fape",
                ],
            },
            "predict": {
                "fixed_size": True,
                "crop": False,
                "crop_size": None,
                "supervised": False,
                "uniform_recycling": False,
            },
            "eval": {
                "fixed_size": True,
                "crop": False,
                "crop_size": None,
                "supervised": True,
                "uniform_recycling": False,
            },
            "train": {
                "fixed_size": True,
                "crop": True,
                "crop_size": 355,
                "supervised": True,
                "clamp_prob": 0.9,
                "uniform_recycling": True,
                "protein_distogram_mask_prob": 0.1,
            },
            "data_module": {
                "data_loaders": {
                    "batch_size": 1,
                    # "batch_size": 2,
                    "num_workers": 16,
                    "pin_memory": True,
                    "should_verify": False,
                },
            },
        },
        # Recurring FieldReferences that can be changed globally here
        "globals": {
            "blocks_per_ckpt": blocks_per_ckpt,
            # Use Staats & Rabe's low-memory attention algorithm.
            "use_lma": False,
            "max_lr": 1e-3,
            "c_z": c_z,
            "c_m": c_m,
            "c_t": c_t,
            "c_e": c_e,
            "c_s": c_s,
            "eps": eps,
        },
        "model": {
            "_mask_trans": False,
            "structure_input_embedder": {
                "protein_tf_dim": 20,
                # len(POSSIBLE_ATOM_TYPES) + len(POSSIBLE_CHARGES) + len(POSSIBLE_CHIRALITIES)
                "ligand_tf_dim": 34,
                "additional_tf_dim": 3,  # number of classes (prot, lig, aff)
                "ligand_bond_dim": 6,
                "c_z": c_z,
                "c_m": c_m,
                "relpos_k": 32,
                "prot_min_bin": 3.25,
                "prot_max_bin": 20.75,
                "prot_no_bins": 15,
                "lig_min_bin": 0.75,
                "lig_max_bin": 9.75,
                "lig_no_bins": 10,
                "inf": 1e8,
            },
            "recycling_embedder": {
                "c_z": c_z,
                "c_m": c_m,
                "min_bin": 3.25,
                "max_bin": 20.75,
                "no_bins": 15,
                "inf": 1e8,
            },
            "evoformer_stack": {
                "c_m": c_m,
                "c_z": c_z,
                "c_hidden_single_att": 32,
                "c_hidden_mul": 128,
                "c_hidden_pair_att": 32,
                "c_s": c_s,
                "no_heads_single": 8,
                "no_heads_pair": 4,
                # "no_blocks": 48,
                "no_blocks": 2,
                "transition_n": 4,
                "single_dropout": 0.15,
                "pair_dropout": 0.25,
                "blocks_per_ckpt": blocks_per_ckpt,
                "clear_cache_between_blocks": False,
                "inf": 1e9,
                "eps": eps,  # 1e-10,
            },
            "structure_module": {
                "c_s": c_s,
                "c_z": c_z,
                "c_ipa": 16,
                "c_resnet": 128,
                "no_heads_ipa": 12,
                "no_qk_points": 4,
                "no_v_points": 8,
                "dropout_rate": 0.1,
                "no_blocks": 8,
                "no_transition_layers": 1,
                "no_resnet_blocks": 2,
                "no_angles": 7,
                "trans_scale_factor": 10,
                "epsilon": eps,  # 1e-12,
                "inf": 1e5,
            },
            "heads": {
                "lddt": {
                    "no_bins": 50,
                    "c_in": c_s,
                    "c_hidden": 128,
                },
                "distogram": {
                    "c_z": c_z,
                    "no_bins": aux_distogram_bins,
                },
                "affinity_2d": {
                    "c_z": c_z,
                    "num_bins": aux_affinity_bins,
                },
                "affinity_1d": {
                    "c_s": c_s,
                    "num_bins": aux_affinity_bins,
                },
                "affinity_cls": {
                    "c_s": c_s,
                    "num_bins": aux_affinity_bins,
                },
                "binding_site": {
                    "c_s": c_s,
                    "c_out": 1,
                },
                "inter_contact": {
                    "c_s": c_s,
                    "c_z": c_z,
                    "c_out": 1,
                },
            },
            # A negative value indicates that no early stopping will occur, i.e.
            # the model will always run `max_recycling_iters` number of recycling
            # iterations. A positive value will enable early stopping if the
            # difference in pairwise distances is less than the tolerance between
            # recycling steps.
            "recycle_early_stop_tolerance": -1.
        },
        "relax": {
            "max_iterations": 0,  # no max
            "tolerance": 2.39,
            "stiffness": 10.0,
            "max_outer_iterations": 20,
            "exclude_residues": [],
        },
        "loss": {
            "distogram": {
                "min_bin": 2.3125,
                "max_bin": 21.6875,
                "no_bins": 64,
                "eps": eps,  # 1e-6,
                "weight": 0.3,
            },
            "positions_inter_distogram": {
                "max_dist": 20.0,
                "weight": 0.0,
            },
            "positions_intra_distogram": {
                "max_dist": 10.0,
                "weight": 0.0,
            },
            "binding_site": {
                "weight": 0.0,
                "pos_class_weight": 20.0,
            },
            "inter_contact": {
                "weight": 0.0,
                "pos_class_weight": 200.0,
            },
            "affinity2d": {
                "min_bin": 0,
                "max_bin": 15,
                "no_bins": aux_affinity_bins,
                "weight": 0.0,
            },
            "affinity1d": {
                "min_bin": 0,
                "max_bin": 15,
                "no_bins": aux_affinity_bins,
                "weight": 0.0,
            },
            "affinity_cls": {
                "min_bin": 0,
                "max_bin": 15,
                "no_bins": aux_affinity_bins,
                "weight": 0.0,
            },
            "fape_backbone": {
                "clamp_distance": 10.0,
                "loss_unit_distance": 10.0,
                "weight": 0.5,
            },
            "fape_sidechain": {
                "clamp_distance": 10.0,
                "length_scale": 10.0,
                "weight": 0.5,
            },
            "fape_interface": {
                "clamp_distance": 10.0,
                "length_scale": 10.0,
                "weight": 0.0,
            },
            "plddt_loss": {
                "min_resolution": 0.1,
                "max_resolution": 3.0,
                "cutoff": 15.0,
                "no_bins": 50,
                "eps": eps,  # 1e-10,
                "weight": 0.01,
            },
            "supervised_chi": {
                "chi_weight": 0.5,
                "angle_norm_weight": 0.01,
                "eps": eps,  # 1e-6,
                "weight": 1.0,
            },
            "chain_center_of_mass": {
                "clamp_distance": -4.0,
                "weight": 0.,
                "eps": eps,
                "enabled": False,
            },
            "eps": eps,
        },
        "ema": {"decay": 0.999},
    }
)
```
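Note (not part of the commit): a minimal usage sketch of the config API above. It assumes `enforce_config_constraints` accepts these presets unchanged; everything else follows directly from the file.

```python
from dockformer.config import model_config

# Presets are selected by name; "finetune_affinity" turns on the affinity losses.
c = model_config("finetune_affinity", train=True)
assert c.loss.affinity2d.weight == 0.5

# train=True enables per-block activation checkpointing.
assert c.globals.blocks_per_ckpt == 1

# c_z is a shared ml_collections FieldReference, so the pair-embedding width
# is consistent everywhere it is referenced.
assert c.globals.c_z == c.model.evoformer_stack.c_z == 128
```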
dockformer/data/data_modules.py
ADDED

```python
import copy
import itertools
import time
import traceback
from collections import Counter
from functools import partial
import json
import os
import pickle
from typing import Optional, Sequence, Any

import ml_collections as mlc
import lightning as L
import torch
from torch.utils.data import RandomSampler

from dockformer.data.data_pipeline import parse_input_json
from dockformer.data import data_pipeline
from dockformer.utils.tensor_utils import dict_multimap
from dockformer.utils.tensor_utils import (
    tensor_tree_map,
)


class OpenFoldSingleDataset(torch.utils.data.Dataset):
    def __init__(self,
                 data_dir: str,
                 config: mlc.ConfigDict,
                 mode: str = "train",
                 ):
        """
        Args:
            data_dir:
                A path to a directory containing mmCIF files (in train
                mode) or FASTA files (in inference mode).
            config:
                A dataset config object. See openfold.config
            mode:
                "train", "val", or "predict"
        """
        super(OpenFoldSingleDataset, self).__init__()
        self.data_dir = data_dir

        self.config = config
        self.mode = mode

        valid_modes = ["train", "eval", "predict"]
        if mode not in valid_modes:
            raise ValueError(f'mode must be one of {valid_modes}')

        self._all_input_files = [i for i in os.listdir(data_dir) if i.endswith(".json")]
        if self.config.data_module.data_loaders.should_verify:
            self._all_input_files = [i for i in self._all_input_files if self._verify_json_input_file(i)]

        self.data_pipeline = data_pipeline.DataPipeline(config, mode)

    def _verify_json_input_file(self, file_name: str) -> bool:
        with open(os.path.join(self.data_dir, file_name), "r") as f:
            try:
                loaded = json.load(f)
                for i in ["input_structure"]:
                    if i not in loaded:
                        return False
                if self.mode != "predict":
                    for i in ["gt_structure", "resolution"]:
                        if i not in loaded:
                            return False
            except json.JSONDecodeError:
                return False
        return True

    def get_metadata_for_idx(self, idx: int) -> dict:
        input_path = os.path.join(self.data_dir, self._all_input_files[idx])
        input_data = json.load(open(input_path, "r"))
        metadata = {
            "resolution": input_data.get("resolution", 99.0),
            "input_path": input_path,
            "input_name": os.path.basename(input_path).split(".json")[0],
        }
        return metadata

    def __getitem__(self, idx):
        return parse_input_json(
            input_path=os.path.join(self.data_dir, self._all_input_files[idx]),
            mode=self.mode,
            config=self.config,
            data_pipeline=self.data_pipeline,
            data_dir=os.path.dirname(self.data_dir),
            idx=idx,
        )

    def __len__(self):
        return len(self._all_input_files)


def resolution_filter(resolution: int, max_resolution: float) -> bool:
    """Check that the resolution is <= max_resolution permitted"""
    return resolution is not None and resolution <= max_resolution


def all_seq_len_filter(seqs: list, minimum_number_of_residues: int) -> bool:
    """Check if the total combined sequence lengths are >= minimum_number_of_residues"""
    total_len = sum([len(i) for i in seqs])
    return total_len >= minimum_number_of_residues


class OpenFoldDataset(torch.utils.data.Dataset):
    """
    Implements the stochastic filters applied during AlphaFold's training.
    Because samples are selected from constituent datasets randomly, the
    length of an OpenFoldFilteredDataset is arbitrary. Samples are selected
    and filtered once at initialization.
    """

    def __init__(self,
                 datasets: Sequence[OpenFoldSingleDataset],
                 probabilities: Sequence[float],
                 epoch_len: int,
                 generator: torch.Generator = None,
                 _roll_at_init: bool = True,
                 ):
        self.datasets = datasets
        self.probabilities = probabilities
        self.epoch_len = epoch_len
        self.generator = generator

        self._samples = [self.looped_samples(i) for i in range(len(self.datasets))]
        if _roll_at_init:
            self.reroll()

    @staticmethod
    def deterministic_train_filter(
        cache_entry: Any,
        max_resolution: float = 9.,
        max_single_aa_prop: float = 0.8,
        *args, **kwargs
    ) -> bool:
        # Hard filters
        resolution = cache_entry["resolution"]

        return all([
            resolution_filter(resolution=resolution,
                              max_resolution=max_resolution)
        ])

    @staticmethod
    def get_stochastic_train_filter_prob(
        cache_entry: Any,
        *args, **kwargs
    ) -> float:
        # Stochastic filters
        probabilities = []

        cluster_size = cache_entry.get("cluster_size", None)
        if cluster_size is not None and cluster_size > 0:
            probabilities.append(1 / cluster_size)

        # Risk of underflow here?
        out = 1
        for p in probabilities:
            out *= p

        return out

    def looped_shuffled_dataset_idx(self, dataset_len):
        while True:
            # Uniformly shuffle each dataset's indices
            weights = [1. for _ in range(dataset_len)]
            shuf = torch.multinomial(
                torch.tensor(weights),
                num_samples=dataset_len,
                replacement=False,
                generator=self.generator,
            )
            for idx in shuf:
                yield idx

    def looped_samples(self, dataset_idx):
        max_cache_len = int(self.epoch_len * self.probabilities[dataset_idx])
        dataset = self.datasets[dataset_idx]
        idx_iter = self.looped_shuffled_dataset_idx(len(dataset))
        while True:
            weights = []
            idx = []
            for _ in range(max_cache_len):
                candidate_idx = next(idx_iter)
                # chain_id = dataset.idx_to_chain_id(candidate_idx)
                # chain_data_cache_entry = chain_data_cache[chain_id]
                # data_entry = dataset[candidate_idx.item()]
                entry_metadata_for_filter = dataset.get_metadata_for_idx(candidate_idx.item())
                if not self.deterministic_train_filter(entry_metadata_for_filter):
                    continue

                p = self.get_stochastic_train_filter_prob(
                    entry_metadata_for_filter,
                )
                weights.append([1. - p, p])
                idx.append(candidate_idx)

            samples = torch.multinomial(
                torch.tensor(weights),
                num_samples=1,
                generator=self.generator,
            )
            samples = samples.squeeze()

            cache = [i for i, s in zip(idx, samples) if s]

            for datapoint_idx in cache:
                yield datapoint_idx

    def __getitem__(self, idx):
        dataset_idx, datapoint_idx = self.datapoints[idx]
        return self.datasets[dataset_idx][datapoint_idx]

    def __len__(self):
        return self.epoch_len

    def reroll(self):
        # TODO bshor: I have removed support for filters (currently done in preprocess) and to weighting clusters
        #  now it is much faster, because it doesn't call looped_samples
        dataset_choices = torch.multinomial(
            torch.tensor(self.probabilities),
            num_samples=self.epoch_len,
            replacement=True,
            generator=self.generator,
        )
        self.datapoints = []
        counter_datasets = Counter(dataset_choices.tolist())
        for dataset_idx, num_samples in counter_datasets.items():
            dataset = self.datasets[dataset_idx]
            sample_choices = torch.randint(0, len(dataset), (num_samples,), generator=self.generator)
            for datapoint_idx in sample_choices:
                self.datapoints.append((dataset_idx, datapoint_idx))


class OpenFoldBatchCollator:
    def __call__(self, prots):
        stack_fn = partial(torch.stack, dim=0)
        return dict_multimap(stack_fn, prots)


class OpenFoldDataLoader(torch.utils.data.DataLoader):
    def __init__(self, *args, config, stage="train", generator=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.config = config
        self.stage = stage
        self.generator = generator
        self._prep_batch_properties_probs()

    def _prep_batch_properties_probs(self):
        keyed_probs = []
        stage_cfg = self.config[self.stage]

        max_iters = self.config.common.max_recycling_iters

        if stage_cfg.uniform_recycling:
            recycling_probs = [
                1. / (max_iters + 1) for _ in range(max_iters + 1)
            ]
        else:
            recycling_probs = [
                0. for _ in range(max_iters + 1)
            ]
            recycling_probs[-1] = 1.

        keyed_probs.append(
            ("no_recycling_iters", recycling_probs)
        )

        keys, probs = zip(*keyed_probs)
        max_len = max([len(p) for p in probs])
        padding = [[0.] * (max_len - len(p)) for p in probs]

        self.prop_keys = keys
        self.prop_probs_tensor = torch.tensor(
            [p + pad for p, pad in zip(probs, padding)],
            dtype=torch.float32,
        )

    def _add_batch_properties(self, batch):
        # gt_features = batch.pop('gt_features', None)
        samples = torch.multinomial(
            self.prop_probs_tensor,
            num_samples=1,  # 1 per row
            replacement=True,
            generator=self.generator
        )

        aatype = batch["aatype"]
        batch_dims = aatype.shape[:-2]
        recycling_dim = aatype.shape[-1]
        no_recycling = recycling_dim
        for i, key in enumerate(self.prop_keys):
            sample = int(samples[i][0])
            sample_tensor = torch.tensor(
                sample,
                device=aatype.device,
                requires_grad=False
            )
            orig_shape = sample_tensor.shape
            sample_tensor = sample_tensor.view(
                (1,) * len(batch_dims) + sample_tensor.shape + (1,)
            )
            sample_tensor = sample_tensor.expand(
                batch_dims + orig_shape + (recycling_dim,)
            )
            batch[key] = sample_tensor

            if key == "no_recycling_iters":
                no_recycling = sample

        resample_recycling = lambda t: t[..., :no_recycling + 1]
        batch = tensor_tree_map(resample_recycling, batch)
        # batch['gt_features'] = gt_features

        return batch

    def __iter__(self):
        it = super().__iter__()

        def _batch_prop_gen(iterator):
            for batch in iterator:
                yield self._add_batch_properties(batch)

        return _batch_prop_gen(it)


class OpenFoldDataModule(L.LightningDataModule):
    def __init__(self,
                 config: mlc.ConfigDict,
                 train_data_dir: Optional[str] = None,
                 val_data_dir: Optional[str] = None,
                 predict_data_dir: Optional[str] = None,
                 batch_seed: Optional[int] = None,
                 train_epoch_len: int = 50000,
                 **kwargs
                 ):
        super(OpenFoldDataModule, self).__init__()

        self.config = config
        self.train_data_dir = train_data_dir
        self.val_data_dir = val_data_dir
        self.predict_data_dir = predict_data_dir
        self.batch_seed = batch_seed
        self.train_epoch_len = train_epoch_len

        if self.train_data_dir is None and self.predict_data_dir is None:
            raise ValueError(
                'At least one of train_data_dir or predict_data_dir must be '
                'specified'
            )

        self.training_mode = self.train_data_dir is not None

        # if not self.training_mode and predict_alignment_dir is None:
        #     raise ValueError(
        #         'In inference mode, predict_alignment_dir must be specified'
        #     )
        # elif val_data_dir is not None and val_alignment_dir is None:
        #     raise ValueError(
        #         'If val_data_dir is specified, val_alignment_dir must '
        #         'be specified as well'
        #     )

    def setup(self, stage):
        # Most of the arguments are the same for the three datasets
        dataset_gen = partial(OpenFoldSingleDataset,
                              config=self.config)

        if self.training_mode:
            train_dataset = dataset_gen(
                data_dir=self.train_data_dir,
                mode="train",
            )

            datasets = [train_dataset]
            probabilities = [1.]

            generator = None
            if self.batch_seed is not None:
                generator = torch.Generator()
                generator = generator.manual_seed(self.batch_seed + 1)

            self.train_dataset = OpenFoldDataset(
                datasets=datasets,
                probabilities=probabilities,
                epoch_len=self.train_epoch_len,
                generator=generator,
                _roll_at_init=False,
            )

            if self.val_data_dir is not None:
                self.eval_dataset = dataset_gen(
                    data_dir=self.val_data_dir,
                    mode="eval",
                )
            else:
                self.eval_dataset = None
        else:
            self.predict_dataset = dataset_gen(
                data_dir=self.predict_data_dir,
                mode="predict",
            )

    def _gen_dataloader(self, stage):
        generator = None
        if self.batch_seed is not None:
            generator = torch.Generator()
            generator = generator.manual_seed(self.batch_seed)

        if stage == "train":
            dataset = self.train_dataset
            # Filter the dataset, if necessary
            dataset.reroll()
        elif stage == "eval":
            dataset = self.eval_dataset
        elif stage == "predict":
            dataset = self.predict_dataset
        else:
            raise ValueError("Invalid stage")

        batch_collator = OpenFoldBatchCollator()

        dl = OpenFoldDataLoader(
            dataset,
            config=self.config,
            stage=stage,
            generator=generator,
            batch_size=self.config.data_module.data_loaders.batch_size,
            # num_workers=self.config.data_module.data_loaders.num_workers,
            num_workers=0,  # TODO bshor: solve generator pickling issue and then bring back num_workers, or just remove generator
            collate_fn=batch_collator,
        )

        return dl

    def train_dataloader(self):
        return self._gen_dataloader("train")

    def val_dataloader(self):
        if self.eval_dataset is not None:
            return self._gen_dataloader("eval")
        return None

    def predict_dataloader(self):
        return self._gen_dataloader("predict")


class DummyDataset(torch.utils.data.Dataset):
    def __init__(self, batch_path):
        with open(batch_path, "rb") as f:
            self.batch = pickle.load(f)

    def __getitem__(self, idx):
        return copy.deepcopy(self.batch)

    def __len__(self):
        return 1000


class DummyDataLoader(L.LightningDataModule):
    def __init__(self, batch_path):
        super().__init__()
        self.dataset = DummyDataset(batch_path)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.dataset)


class DockFormerSimpleDataset(torch.utils.data.Dataset):
    def __init__(self, clusters_json: str, config: mlc.ConfigDict, mode: str = "train"):
        clusters = json.load(open(clusters_json, "r"))
        self.config = config
        self.mode = mode
        self._data_dir = os.path.dirname(clusters_json)
        print("Data dir", self._data_dir)
        self._clusters = clusters
        self._all_input_files = sum(clusters.values(), [])
        self.data_pipeline = data_pipeline.DataPipeline(config, mode)

    def __getitem__(self, idx):
        return parse_input_json(
            input_path=os.path.join(self._data_dir, self._all_input_files[idx]),
            mode=self.mode,
            config=self.config,
            data_pipeline=self.data_pipeline,
            data_dir=self._data_dir,
            idx=idx,
        )

    def __len__(self):
        return len(self._all_input_files)


class DockFormerClusteredDataset(torch.utils.data.Dataset):
    def __init__(self, clusters_json: str, config: mlc.ConfigDict, mode: str = "train", generator=None):
        clusters = json.load(open(clusters_json, "r"))
        self.config = config
        self.mode = mode
        self._data_dir = os.path.dirname(clusters_json)
        self._clusters = list(clusters.values())
        self.data_pipeline = data_pipeline.DataPipeline(config, mode)
        self._generator = generator

    def __getitem__(self, idx):
        try:
            cluster = self._clusters[idx]
            # choose random from cluster
            input_file = cluster[torch.randint(0, len(cluster), (1,), generator=self._generator).item()]

            return parse_input_json(
                input_path=os.path.join(self._data_dir, input_file),
                mode=self.mode,
                config=self.config,
                data_pipeline=self.data_pipeline,
                data_dir=self._data_dir,
                idx=idx,
            )
        except Exception as e:
            print("ERROR in loading", e)
            traceback.print_exc()
            return parse_input_json(
                input_path=os.path.join(self._data_dir, self._clusters[0][0]),
                mode=self.mode,
                config=self.config,
                data_pipeline=self.data_pipeline,
                data_dir=self._data_dir,
                idx=idx,
            )

    def __len__(self):
        return len(self._clusters)


class DockFormerDataLoader(torch.utils.data.DataLoader):
    def __init__(self, *args, config, stage="train", generator=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.config = config
        self.stage = stage
        # self.generator = generator

    def _add_batch_properties(self, batch):
        if self.config[self.stage].uniform_recycling:
            aatype = batch["aatype"]
            max_recycling_dim = aatype.shape[-1]

            # num_recycles = torch.randint(0, max_recycling_dim, (1,), generator=self.generator)
            num_recycles = torch.randint(0, max_recycling_dim, (1,)).item()

            resample_recycling = lambda t: t[..., :num_recycles + 1]
            batch = tensor_tree_map(resample_recycling, batch)

        return batch

    def __iter__(self):
        it = super().__iter__()

        def _batch_prop_gen(iterator):
            for batch in iterator:
                yield self._add_batch_properties(batch)

        return _batch_prop_gen(it)


class DockFormerDataModule(L.LightningDataModule):
    def __init__(self,
                 config: mlc.ConfigDict,
                 train_data_file: Optional[str] = None,
                 val_data_file: Optional[str] = None,
                 batch_seed: Optional[int] = None,
                 **kwargs
                 ):
        super(DockFormerDataModule, self).__init__()

        self.config = config
        self.train_data_file = train_data_file
        self.val_data_file = val_data_file
        self.batch_seed = batch_seed

        assert self.train_data_file is not None, "train_data_file must be specified"
        assert self.val_data_file is not None, "val_data_file must be specified"

        self.train_dataset = None
        self.val_dataset = None

    def setup(self, stage):
        generator = None
        if self.batch_seed is not None:
            generator = torch.Generator()
            generator = generator.manual_seed(self.batch_seed + 1)

        self.train_dataset = DockFormerClusteredDataset(
            clusters_json=self.train_data_file,
            config=self.config,
            mode="train",
            generator=generator,
        )

        self.val_dataset = DockFormerSimpleDataset(
            clusters_json=self.val_data_file,
            config=self.config,
            mode="eval",
        )

    def _gen_dataloader(self, stage):
        generator = None
        if self.batch_seed is not None:
            generator = torch.Generator()
            generator = generator.manual_seed(self.batch_seed)

        should_shuffle = stage == "train"
        if stage == "train":
            dataset = self.train_dataset
        elif stage == "eval":
            dataset = self.val_dataset
        else:
            raise ValueError("Invalid stage")

        batch_collator = OpenFoldBatchCollator()

        dl = DockFormerDataLoader(
            dataset,
            config=self.config,
            stage=stage,
            # generator=generator,
            batch_size=self.config.data_module.data_loaders.batch_size,
            # num_workers=self.config.data_module.data_loaders.num_workers,
            num_workers=0,  # TODO bshor: solve generator pickling issue and then bring back num_workers, or just remove generator
            collate_fn=batch_collator,
            shuffle=should_shuffle,
        )

        return dl

    def train_dataloader(self):
        return self._gen_dataloader("train")

    def val_dataloader(self):
        if self.val_dataset is not None:
            return self._gen_dataloader("eval")
        return None
```
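Note (not part of the commit): a minimal sketch of how the Lightning data module above would be driven. The JSON paths are hypothetical; it assumes each clusters JSON maps cluster names to lists of per-complex input files resolved relative to its own directory, and that the `data` subtree of the model config is what these classes expect (they read `config.common`, `config.train`, and `config.data_module`).

```python
from dockformer.config import model_config
from dockformer.data.data_modules import DockFormerDataModule

config = model_config("initial_training", train=True)

data_module = DockFormerDataModule(
    config=config.data,                          # datasets read config.common / config.train / config.data_module
    train_data_file="data/train/clusters.json",  # hypothetical path
    val_data_file="data/val/clusters.json",      # hypothetical path
    batch_seed=42,
)
data_module.setup("fit")

# Batches are dicts of stacked feature tensors; with uniform_recycling enabled,
# the trailing recycling dimension is randomly trimmed per batch.
batch = next(iter(data_module.train_dataloader()))
print(batch["aatype"].shape)  # e.g. (batch_size, crop_size, num_recycles)
```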
dockformer/data/data_pipeline.py
ADDED
@@ -0,0 +1,503 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
import json
|
16 |
+
import os
|
17 |
+
import time
|
18 |
+
from typing import List
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
import torch
|
22 |
+
import ml_collections as mlc
|
23 |
+
from rdkit import Chem
|
24 |
+
|
25 |
+
from dockformer.data import data_transforms
|
26 |
+
from dockformer.data.data_transforms import get_restype_atom37_mask, get_restypes
|
27 |
+
from dockformer.data.ligand_features import make_ligand_features
|
28 |
+
from dockformer.data.protein_features import make_protein_features
|
29 |
+
from dockformer.data.utils import FeatureTensorDict, FeatureDict
|
30 |
+
from dockformer.utils import protein
|
31 |
+
|
32 |
+
|
33 |
+
def _np_filter_and_to_tensor_dict(np_example: FeatureDict, features_to_keep: List[str]) -> FeatureTensorDict:
|
34 |
+
"""Creates dict of tensors from a dict of NumPy arrays.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
np_example: A dict of NumPy feature arrays.
|
38 |
+
features: A list of strings of feature names to be returned in the dataset.
|
39 |
+
|
40 |
+
Returns:
|
41 |
+
A dictionary of features mapping feature names to features. Only the given
|
42 |
+
features are returned, all other ones are filtered out.
|
43 |
+
"""
|
44 |
+
# torch generates warnings if feature is already a torch Tensor
|
45 |
+
to_tensor = lambda t: torch.tensor(t) if type(t) != torch.Tensor else t.clone().detach()
|
46 |
+
tensor_dict = {
|
47 |
+
k: to_tensor(v) for k, v in np_example.items() if k in features_to_keep
|
48 |
+
}
|
49 |
+
|
50 |
+
return tensor_dict
|
51 |
+
|
52 |
+
|
53 |
+
def _add_protein_probablistic_features(features: FeatureDict, cfg: mlc.ConfigDict, mode: str) -> FeatureDict:
|
54 |
+
if mode == "train":
|
55 |
+
p = torch.rand(1).item()
|
56 |
+
use_clamped_fape_value = float(p < cfg.supervised.clamp_prob)
|
57 |
+
features["use_clamped_fape"] = np.float32(use_clamped_fape_value)
|
58 |
+
else:
|
59 |
+
features["use_clamped_fape"] = np.float32(0.0)
|
60 |
+
return features
|
61 |
+
|
62 |
+
|
63 |
+
@data_transforms.curry1
|
64 |
+
def compose(x, fs):
|
65 |
+
for f in fs:
|
66 |
+
x = f(x)
|
67 |
+
return x
|
68 |
+
|
69 |
+
|
70 |
+
def _apply_protein_transforms(tensors: FeatureTensorDict) -> FeatureTensorDict:
|
71 |
+
transforms = [
|
72 |
+
data_transforms.cast_to_64bit_ints,
|
73 |
+
data_transforms.squeeze_features,
|
74 |
+
data_transforms.make_atom14_masks,
|
75 |
+
data_transforms.make_atom14_positions,
|
76 |
+
data_transforms.atom37_to_frames,
|
77 |
+
data_transforms.atom37_to_torsion_angles(""),
|
78 |
+
data_transforms.make_pseudo_beta(),
|
79 |
+
data_transforms.get_backbone_frames,
|
80 |
+
data_transforms.get_chi_angles,
|
81 |
+
]
|
82 |
+
|
83 |
+
tensors = compose(transforms)(tensors)
|
84 |
+
|
85 |
+
return tensors
|
86 |
+
|
87 |
+
|
88 |
+
def _apply_protein_probablistic_transforms(tensors: FeatureTensorDict, cfg: mlc.ConfigDict, mode: str) \
|
89 |
+
-> FeatureTensorDict:
|
90 |
+
transforms = [data_transforms.make_target_feat()]
|
91 |
+
|
92 |
+
crop_feats = dict(cfg.common.feat)
|
93 |
+
|
94 |
+
if cfg[mode].fixed_size:
|
95 |
+
transforms.append(data_transforms.select_feat(list(crop_feats)))
|
96 |
+
# TODO bshor: restore transforms for training on cropped proteins, need to handle pocket somehow
|
97 |
+
# if so, look for random_crop_to_size and make_fixed_size in data_transforms.py
|
98 |
+
|
99 |
+
compose(transforms)(tensors)
|
100 |
+
|
101 |
+
return tensors
|
102 |
+
|
103 |
+
|
104 |
+
class DataPipeline:
|
105 |
+
"""Assembles input features."""
|
106 |
+
def __init__(self, config: mlc.ConfigDict, mode: str):
|
107 |
+
self.config = config
|
108 |
+
self.mode = mode
|
109 |
+
|
110 |
+
self.feature_names = config.common.unsupervised_features
|
111 |
+
if config[mode].supervised:
|
112 |
+
self.feature_names += config.supervised.supervised_features
|
113 |
+
|
114 |
+
def process_pdb(self, pdb_path: str) -> FeatureTensorDict:
|
115 |
+
"""
|
116 |
+
Assembles features for a protein in a PDB file.
|
117 |
+
"""
|
118 |
+
with open(pdb_path, 'r') as f:
|
119 |
+
pdb_str = f.read()
|
120 |
+
|
121 |
+
protein_object = protein.from_pdb_string(pdb_str)
|
122 |
+
description = os.path.splitext(os.path.basename(pdb_path))[0].upper()
|
123 |
+
pdb_feats = make_protein_features(protein_object, description)
|
124 |
+
pdb_feats = _add_protein_probablistic_features(pdb_feats, self.config, self.mode)
|
125 |
+
|
126 |
+
tensor_feats = _np_filter_and_to_tensor_dict(pdb_feats, self.feature_names)
|
127 |
+
|
128 |
+
tensor_feats = _apply_protein_transforms(tensor_feats)
|
129 |
+
tensor_feats = _apply_protein_probablistic_transforms(tensor_feats, self.config, self.mode)
|
130 |
+
|
131 |
+
return tensor_feats
|
132 |
+
|
133 |
+
def process_smiles(self, smiles: str) -> FeatureTensorDict:
|
134 |
+
ligand = Chem.MolFromSmiles(smiles)
|
135 |
+
return make_ligand_features(ligand)
|
136 |
+
|
137 |
+
def process_mol2(self, mol2_path: str) -> FeatureTensorDict:
|
138 |
+
"""
|
139 |
+
Assembles features for a ligand in a mol2 file.
|
140 |
+
"""
|
141 |
+
ligand = Chem.MolFromMol2File(mol2_path)
|
142 |
+
assert ligand is not None, f"Failed to parse ligand from {mol2_path}"
|
143 |
+
|
144 |
+
conf = ligand.GetConformer()
|
145 |
+
positions = torch.tensor(conf.GetPositions())
|
146 |
+
|
147 |
+
return {
|
148 |
+
**make_ligand_features(ligand),
|
149 |
+
"gt_ligand_positions": positions.float()
|
150 |
+
}
|
151 |
+
|
152 |
+
def process_sdf(self, sdf_path: str) -> FeatureTensorDict:
|
153 |
+
"""
|
154 |
+
Assembles features for a ligand in a mol2 file.
|
155 |
+
"""
|
156 |
+
ligand = Chem.MolFromMolFile(sdf_path)
|
157 |
+
assert ligand is not None, f"Failed to parse ligand from {sdf_path}"
|
158 |
+
|
159 |
+
conf = ligand.GetConformer(0)
|
160 |
+
positions = torch.tensor(conf.GetPositions())
|
161 |
+
|
162 |
+
return {
|
163 |
+
**make_ligand_features(ligand),
|
164 |
+
"ligand_positions": positions.float()
|
165 |
+
}
|
166 |
+
|
167 |
+
def process_sdf_list(self, sdf_path_list: List[str]) -> FeatureTensorDict:
|
168 |
+
all_sdf_feats = [self.process_sdf(sdf_path) for sdf_path in sdf_path_list]
|
169 |
+
|
170 |
+
all_sizes = [sdf_feats["ligand_target_feat"].shape[0] for sdf_feats in all_sdf_feats]
|
171 |
+
|
172 |
+
joined_ligand_feats = {}
|
173 |
+
for k in all_sdf_feats[0].keys():
|
174 |
+
if k == "ligand_positions":
|
175 |
+
joined_positions = all_sdf_feats[0][k]
|
176 |
+
prev_offset = joined_positions.max(dim=0).values + 100
|
177 |
+
|
178 |
+
for i, sdf_feats in enumerate(all_sdf_feats[1:]):
|
179 |
+
offset = prev_offset - sdf_feats[k].min(dim=0).values
|
180 |
+
joined_positions = torch.cat([joined_positions, sdf_feats[k] + offset], dim=0)
|
181 |
+
prev_offset = joined_positions.max(dim=0).values + 100
|
182 |
+
joined_ligand_feats[k] = joined_positions
|
183 |
+
elif k in ["ligand_target_feat", "ligand_atype", "ligand_charge", "ligand_chirality", "ligand_bonds"]:
|
184 |
+
joined_ligand_feats[k] = torch.cat([sdf_feats[k] for sdf_feats in all_sdf_feats], dim=0)
|
185 |
+
if k == "ligand_target_feat":
|
186 |
+
joined_ligand_feats["ligand_idx"] = torch.cat([torch.full((sdf_feats[k].shape[0],), i)
|
187 |
+
for i, sdf_feats in enumerate(all_sdf_feats)], dim=0)
|
188 |
+
elif k == "ligand_bonds":
|
189 |
+
joined_ligand_feats["ligand_bonds_idx"] = torch.cat([torch.full((sdf_feats[k].shape[0],), i)
|
190 |
+
for i, sdf_feats in enumerate(all_sdf_feats)],
|
191 |
+
dim=0)
|
192 |
+
elif k == "ligand_bonds_feat":
|
193 |
+
joined_feature = torch.zeros((sum(all_sizes), sum(all_sizes), all_sdf_feats[0][k].shape[2]))
|
194 |
+
for i, sdf_feats in enumerate(all_sdf_feats):
|
195 |
+
start_idx = sum(all_sizes[:i])
|
196 |
+
end_idx = sum(all_sizes[:i + 1])
|
197 |
+
joined_feature[start_idx:end_idx, start_idx:end_idx, :] = sdf_feats[k]
|
198 |
+
joined_ligand_feats[k] = joined_feature
|
199 |
+
else:
|
200 |
+
raise ValueError(f"Unknown key in sdf list features {k}")
|
201 |
+
return joined_ligand_feats
|
202 |
+
|
203 |
+
def get_matching_positions_list(self, ref_path_list: List[str], gt_path_list: List[str]):
|
204 |
+
joined_gt_positions = []
|
205 |
+
|
206 |
+
for ref_ligand_path, gt_ligand_path in zip(ref_path_list, gt_path_list):
|
207 |
+
ref_ligand = Chem.MolFromMolFile(ref_ligand_path)
|
208 |
+
gt_ligand = Chem.MolFromMolFile(gt_ligand_path)
|
209 |
+
|
210 |
+
gt_original_positions = gt_ligand.GetConformer(0).GetPositions()
|
211 |
+
|
212 |
+
gt_positions = [gt_original_positions[idx] for idx in gt_ligand.GetSubstructMatch(ref_ligand)]
|
213 |
+
|
214 |
+
joined_gt_positions.extend(gt_positions)
|
215 |
+
|
216 |
+
return torch.tensor(np.array(joined_gt_positions)).float()
|
217 |
+
|
218 |
+
def get_matching_positions(self, ref_ligand_path: str, gt_ligand_path: str):
|
219 |
+
ref_ligand = Chem.MolFromMolFile(ref_ligand_path)
|
220 |
+
gt_ligand = Chem.MolFromMolFile(gt_ligand_path)
|
221 |
+
|
222 |
+
gt_original_positions = gt_ligand.GetConformer(0).GetPositions()
|
223 |
+
|
224 |
+
gt_positions = [gt_original_positions[idx] for idx in gt_ligand.GetSubstructMatch(ref_ligand)]
|
225 |
+
|
226 |
+
# ref_positions = ref_ligand.GetConformer(0).GetPositions()
|
227 |
+
# for i in range(len(ref_positions)):
|
228 |
+
# for j in range(i + 1, len(ref_positions)):
|
229 |
+
# dist_ref = np.linalg.norm(ref_positions[i] - ref_positions[j])
|
230 |
+
# dist_gt = np.linalg.norm(gt_positions[i] - gt_positions[j])
|
231 |
+
# dist_gt = np.linalg.norm(gt_original_positions[i] - gt_original_positions[j])
|
232 |
+
# if abs(dist_ref - dist_gt) > 1.0:
|
233 |
+
# print(f"Distance mismatch {i} {j} {dist_ref} {dist_gt}")
|
234 |
+
|
235 |
+
return torch.tensor(np.array(gt_positions)) .float()
|
236 |
+
|
237 |
+
|
238 |
+
def _prepare_recycles(feat: torch.Tensor, num_recycles: int) -> torch.Tensor:
|
239 |
+
return feat.unsqueeze(-1).repeat(*([1] * len(feat.shape)), num_recycles)
|
240 |
+
|
241 |
+
|
242 |
+
def _fit_to_crop(target_tensor: torch.Tensor, crop_size: int, start_ind: int) -> torch.Tensor:
|
243 |
+
if len(target_tensor.shape) == 1:
|
244 |
+
ret = torch.zeros((crop_size, ), dtype=target_tensor.dtype)
|
245 |
+
ret[start_ind:start_ind + target_tensor.shape[0]] = target_tensor
|
246 |
+
return ret
|
247 |
+
elif len(target_tensor.shape) == 2:
|
248 |
+
ret = torch.zeros((crop_size, target_tensor.shape[-1]), dtype=target_tensor.dtype)
|
249 |
+
ret[start_ind:start_ind + target_tensor.shape[0], :] = target_tensor
|
250 |
+
return ret
|
251 |
+
else:
|
252 |
+
ret = torch.zeros((crop_size, *target_tensor.shape[1:]), dtype=target_tensor.dtype)
|
253 |
+
ret[start_ind:start_ind + target_tensor.shape[0], ...] = target_tensor
|
254 |
+
return ret
|
255 |
+
|
256 |
+
|
257 |
+
def parse_input_json(input_path: str, mode: str, config: mlc.ConfigDict, data_pipeline: DataPipeline,
                     data_dir: str, idx: int) -> FeatureTensorDict:
    start_load_time = time.time()
    input_data = json.load(open(input_path, "r"))
    if mode == "train" or mode == "eval":
        print("loading", input_data["pdb_id"], end=" ")

    num_recycles = config.common.max_recycling_iters + 1

    input_pdb_path = os.path.join(data_dir, input_data["input_structure"])
    input_protein_feats = data_pipeline.process_pdb(pdb_path=input_pdb_path)

    # load ref sdf
    if "ref_sdf" in input_data:
        ref_sdf_path = os.path.join(data_dir, input_data["ref_sdf"])
        ref_ligand_feats = data_pipeline.process_sdf(sdf_path=ref_sdf_path)
        ref_ligand_feats["ligand_idx"] = torch.zeros((ref_ligand_feats["ligand_target_feat"].shape[0],))
        ref_ligand_feats["ligand_bonds_idx"] = torch.zeros((ref_ligand_feats["ligand_bonds"].shape[0],))
    elif "ref_sdf_list" in input_data:
        sdf_path_list = [os.path.join(data_dir, i) for i in input_data["ref_sdf_list"]]
        ref_ligand_feats = data_pipeline.process_sdf_list(sdf_path_list=sdf_path_list)
    else:
        raise ValueError("ref_sdf or ref_sdf_list must be in input_data")

    n_res = input_protein_feats["protein_target_feat"].shape[0]
    n_lig = ref_ligand_feats["ligand_target_feat"].shape[0]
    n_affinity = 1

    # add 1 for affinity token
    crop_size = n_res + n_lig + n_affinity
    if (mode == "train" or mode == "eval") and config.train.fixed_size:
        crop_size = config.train.crop_size

    assert crop_size >= n_res + n_lig + n_affinity, f"crop_size: {crop_size}, n_res: {n_res}, n_lig: {n_lig}"

    token_mask = torch.zeros((crop_size,), dtype=torch.float32)
    token_mask[:n_res + n_lig + n_affinity] = 1

    protein_mask = torch.zeros((crop_size,), dtype=torch.float32)
    protein_mask[:n_res] = 1

    ligand_mask = torch.zeros((crop_size,), dtype=torch.float32)
    ligand_mask[n_res:n_res + n_lig] = 1

    affinity_mask = torch.zeros((crop_size,), dtype=torch.float32)
    affinity_mask[n_res + n_lig] = 1

    structural_mask = torch.zeros((crop_size,), dtype=torch.float32)
    structural_mask[:n_res + n_lig] = 1

    inter_pair_mask = torch.zeros((crop_size, crop_size), dtype=torch.float32)
    inter_pair_mask[:n_res, n_res:n_res + n_lig] = 1
    inter_pair_mask[n_res:n_res + n_lig, :n_res] = 1

    protein_tf_dim = input_protein_feats["protein_target_feat"].shape[-1]
    ligand_tf_dim = ref_ligand_feats["ligand_target_feat"].shape[-1]
    joined_tf_dim = protein_tf_dim + ligand_tf_dim

    target_feat = torch.zeros((crop_size, joined_tf_dim + 3), dtype=torch.float32)
    target_feat[:n_res, :protein_tf_dim] = input_protein_feats["protein_target_feat"]
    target_feat[n_res:n_res + n_lig, protein_tf_dim:joined_tf_dim] = ref_ligand_feats["ligand_target_feat"]

    target_feat[:n_res, joined_tf_dim] = 1  # Set "is_protein" flag for protein rows
    target_feat[n_res:n_res + n_lig, joined_tf_dim + 1] = 1  # Set "is_ligand" flag for ligand rows
    target_feat[n_res + n_lig, joined_tf_dim + 2] = 1  # Set "is_affinity" flag for affinity row

    ligand_bonds_feat = torch.zeros((crop_size, crop_size, ref_ligand_feats["ligand_bonds_feat"].shape[-1]),
                                    dtype=torch.float32)
    ligand_bonds_feat[n_res:n_res + n_lig, n_res:n_res + n_lig] = ref_ligand_feats["ligand_bonds_feat"]

    input_positions = torch.zeros((crop_size, 3), dtype=torch.float32)
    input_positions[:n_res] = input_protein_feats["pseudo_beta"]
    input_positions[n_res:n_res + n_lig] = ref_ligand_feats["ligand_positions"]

    protein_distogram_mask = torch.zeros(crop_size)
    if mode == "train":
        ones_indices = torch.randperm(n_res)[:int(n_res * config.train.protein_distogram_mask_prob)]
        # print(ones_indices)
        protein_distogram_mask[ones_indices] = 1
        input_positions = input_positions * (1 - protein_distogram_mask).unsqueeze(-1)
    elif mode == "predict":
        # ignore all positions where pseudo_beta is 0, 0, 0
        protein_distogram_mask = (input_positions == 0).all(dim=-1).float()
        # print("Ignoring residues", torch.nonzero(distogram_mask).flatten())

    # Implement ligand as amino acid type 20
    ligand_aatype = 20 * torch.ones((n_lig,), dtype=input_protein_feats["aatype"].dtype)
    aatype = torch.cat([input_protein_feats["aatype"], ligand_aatype], dim=0)

    restype_atom14_to_atom37, restype_atom37_to_atom14, restype_atom14_mask = get_restypes(target_feat.device)
    lig_residx_atom37_to_atom14 = restype_atom37_to_atom14[20].repeat(n_lig, 1)
    residx_atom37_to_atom14 = torch.cat([input_protein_feats["residx_atom37_to_atom14"], lig_residx_atom37_to_atom14],
                                        dim=0)

    restype_atom37_mask = get_restype_atom37_mask(target_feat.device)
    lig_atom37_atom_exists = restype_atom37_mask[20].repeat(n_lig, 1)
    atom37_atom_exists = torch.cat([input_protein_feats["atom37_atom_exists"], lig_atom37_atom_exists], dim=0)

    feats = {
        "token_mask": token_mask,
        "protein_mask": protein_mask,
        "ligand_mask": ligand_mask,
        "affinity_mask": affinity_mask,
        "structural_mask": structural_mask,
        "inter_pair_mask": inter_pair_mask,

        "target_feat": target_feat,
        "ligand_bonds_feat": ligand_bonds_feat,
        "input_positions": input_positions,
        "protein_distogram_mask": protein_distogram_mask,
        "protein_residue_index": _fit_to_crop(input_protein_feats["residue_index"], crop_size, 0),
        "aatype": _fit_to_crop(aatype, crop_size, 0),
        "residx_atom37_to_atom14": _fit_to_crop(residx_atom37_to_atom14, crop_size, 0),
        "atom37_atom_exists": _fit_to_crop(atom37_atom_exists, crop_size, 0),
    }

    if mode == "predict":
        feats.update({
            "in_chain_residue_index": input_protein_feats["in_chain_residue_index"],
            "chain_index": input_protein_feats["chain_index"],
            "ligand_atype": ref_ligand_feats["ligand_atype"],
            "ligand_chirality": ref_ligand_feats["ligand_chirality"],
            "ligand_charge": ref_ligand_feats["ligand_charge"],
            "ligand_bonds": ref_ligand_feats["ligand_bonds"],
            "ligand_idx": ref_ligand_feats["ligand_idx"],
            "ligand_bonds_idx": ref_ligand_feats["ligand_bonds_idx"],
        })

    if mode == 'train' or mode == 'eval':
        gt_pdb_path = os.path.join(data_dir, input_data["gt_structure"])
        gt_protein_feats = data_pipeline.process_pdb(pdb_path=gt_pdb_path)

        if "gt_sdf" in input_data:
            gt_ligand_positions = data_pipeline.get_matching_positions(
                os.path.join(data_dir, input_data["ref_sdf"]),
                os.path.join(data_dir, input_data["gt_sdf"]),
            )
        elif "gt_sdf_list" in input_data:
            gt_ligand_positions = data_pipeline.get_matching_positions_list(
                [os.path.join(data_dir, i) for i in input_data["ref_sdf_list"]],
                [os.path.join(data_dir, i) for i in input_data["gt_sdf_list"]],
            )
        else:
            raise ValueError("gt_sdf or gt_sdf_list must be in input_data")

        affinity_loss_factor = torch.tensor([1.0], dtype=torch.float32)
        if input_data["affinity"] is None:
            eps = 1e-6
            affinity_loss_factor = torch.tensor([eps], dtype=torch.float32)
            affinity = torch.tensor([0.0], dtype=torch.float32)
        else:
            affinity = torch.tensor([input_data["affinity"]], dtype=torch.float32)

        resolution = torch.tensor(input_data["resolution"], dtype=torch.float32)

        # prepare inter_contacts
        expanded_prot_pos = gt_protein_feats["pseudo_beta"].unsqueeze(1)  # Shape: (N_prot, 1, 3)
        expanded_lig_pos = gt_ligand_positions.unsqueeze(0)  # Shape: (1, N_lig, 3)
        distances = torch.sqrt(torch.sum((expanded_prot_pos - expanded_lig_pos) ** 2, dim=-1))
        inter_contact = (distances < 5.0).float()
        binding_site_mask = inter_contact.any(dim=1).float()

        inter_contact_reshaped_to_crop = torch.zeros((crop_size, crop_size), dtype=torch.float32)
        inter_contact_reshaped_to_crop[:n_res, n_res:n_res + n_lig] = inter_contact
        inter_contact_reshaped_to_crop[n_res:n_res + n_lig, :n_res] = inter_contact.T

        # Use CA positions only
        lig_single_res_atom37_mask = torch.zeros((37,), dtype=torch.float32)
        lig_single_res_atom37_mask[1] = 1
        lig_atom37_mask = lig_single_res_atom37_mask.unsqueeze(0).expand(n_lig, -1)
        lig_single_res_atom14_mask = torch.zeros((14,), dtype=torch.float32)
        lig_single_res_atom14_mask[1] = 1
        lig_atom14_mask = lig_single_res_atom14_mask.unsqueeze(0).expand(n_lig, -1)

        lig_atom37_positions = gt_ligand_positions.unsqueeze(1).expand(-1, 37, -1)
        lig_atom37_positions = lig_atom37_positions * lig_single_res_atom37_mask.view(1, 37, 1).expand(n_lig, -1, 3)

        lig_atom14_positions = gt_ligand_positions.unsqueeze(1).expand(-1, 14, -1)
        lig_atom14_positions = lig_atom14_positions * lig_single_res_atom14_mask.view(1, 14, 1).expand(n_lig, -1, 3)

        atom37_gt_positions = torch.cat([gt_protein_feats["all_atom_positions"], lig_atom37_positions], dim=0)
        atom37_atom_exists_in_res = torch.cat([gt_protein_feats["atom37_atom_exists"], lig_atom37_mask], dim=0)
        atom37_atom_exists_in_gt = torch.cat([gt_protein_feats["all_atom_mask"], lig_atom37_mask], dim=0)

        atom14_gt_positions = torch.cat([gt_protein_feats["atom14_gt_positions"], lig_atom14_positions], dim=0)
        atom14_atom_exists_in_res = torch.cat([gt_protein_feats["atom14_atom_exists"], lig_atom14_mask], dim=0)
        atom14_atom_exists_in_gt = torch.cat([gt_protein_feats["atom14_gt_exists"], lig_atom14_mask], dim=0)

        gt_pseudo_beta_with_lig = torch.cat([gt_protein_feats["pseudo_beta"], gt_ligand_positions], dim=0)
        gt_pseudo_beta_with_lig_mask = torch.cat(
            [gt_protein_feats["pseudo_beta_mask"],
             torch.ones((n_lig,), dtype=gt_protein_feats["pseudo_beta_mask"].dtype)],
            dim=0)

        # IGNORES: residx_atom14_to_atom37, rigidgroups_group_exists,
        # rigidgroups_group_is_ambiguous, pseudo_beta_mask, backbone_rigid_mask, protein_target_feat
        gt_protein_feats = {
            "atom37_gt_positions": atom37_gt_positions,  # torch.Size([n_struct, 37, 3])
            "atom37_atom_exists_in_res": atom37_atom_exists_in_res,  # torch.Size([n_struct, 37])
            "atom37_atom_exists_in_gt": atom37_atom_exists_in_gt,  # torch.Size([n_struct, 37])

            "atom14_gt_positions": atom14_gt_positions,  # torch.Size([n_struct, 14, 3])
            "atom14_atom_exists_in_res": atom14_atom_exists_in_res,  # torch.Size([n_struct, 14])
            "atom14_atom_exists_in_gt": atom14_atom_exists_in_gt,  # torch.Size([n_struct, 14])

            "gt_pseudo_beta_with_lig": gt_pseudo_beta_with_lig,  # torch.Size([n_struct, 3])
            "gt_pseudo_beta_with_lig_mask": gt_pseudo_beta_with_lig_mask,  # torch.Size([n_struct])

            # These we don't need to add the ligand to, because padding is sufficient (everything should be 0)
            "atom14_alt_gt_positions": gt_protein_feats["atom14_alt_gt_positions"],  # torch.Size([n_res, 14, 3])
            "atom14_alt_gt_exists": gt_protein_feats["atom14_alt_gt_exists"],  # torch.Size([n_res, 14])
            "atom14_atom_is_ambiguous": gt_protein_feats["atom14_atom_is_ambiguous"],  # torch.Size([n_res, 14])
            "rigidgroups_gt_frames": gt_protein_feats["rigidgroups_gt_frames"],  # torch.Size([n_res, 8, 4, 4])
            "rigidgroups_gt_exists": gt_protein_feats["rigidgroups_gt_exists"],  # torch.Size([n_res, 8])
            "rigidgroups_alt_gt_frames": gt_protein_feats["rigidgroups_alt_gt_frames"],  # torch.Size([n_res, 8, 4, 4])
            "backbone_rigid_tensor": gt_protein_feats["backbone_rigid_tensor"],  # torch.Size([n_res, 4, 4])
            "backbone_rigid_mask": gt_protein_feats["backbone_rigid_mask"],  # torch.Size([n_res])
            "chi_angles_sin_cos": gt_protein_feats["chi_angles_sin_cos"],
            "chi_mask": gt_protein_feats["chi_mask"],
        }

        for k, v in gt_protein_feats.items():
            gt_protein_feats[k] = _fit_to_crop(v, crop_size, 0)

        feats = {
            **feats,
            **gt_protein_feats,
            "gt_ligand_positions": _fit_to_crop(gt_ligand_positions, crop_size, n_res),
            "resolution": resolution,
            "affinity": affinity,
            "affinity_loss_factor": affinity_loss_factor,
            "seq_length": torch.tensor(n_res + n_lig),
            "binding_site_mask": _fit_to_crop(binding_site_mask, crop_size, 0),
            "gt_inter_contacts": inter_contact_reshaped_to_crop,
        }

    for k, v in feats.items():
        # print(k, v.shape)
        feats[k] = _prepare_recycles(v, num_recycles)

    feats["batch_idx"] = torch.tensor(
        [idx for _ in range(crop_size)], dtype=torch.int64, device=feats["aatype"].device
    )

    print("load time", round(time.time() - start_load_time, 4))

    return feats
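For orientation (commentary, not part of the commit): parse_input_json lays tokens out as [protein residues | ligand atoms | 1 affinity token | zero padding up to crop_size], and every feature is finally tiled with a trailing recycling dimension. A hedged sketch of the mask layout, with made-up sizes:

# n_res=3 protein residues, n_lig=2 ligand atoms, 1 affinity token, crop_size=8:
# index:      0  1  2  3  4  5  6  7
# token_mask: 1  1  1  1  1  1  0  0
# protein:    1  1  1  0  0  0  0  0
# ligand:     0  0  0  1  1  0  0  0
# affinity:   0  0  0  0  0  1  0  0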
dockformer/data/data_transforms.py
ADDED
@@ -0,0 +1,731 @@
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
from functools import reduce, wraps
from operator import add

import numpy as np
import torch

from dockformer.config import NUM_RES
from dockformer.utils import residue_constants as rc
from dockformer.utils.residue_constants import restypes
from dockformer.utils.rigid_utils import Rotation, Rigid
from dockformer.utils.geometry.rigid_matrix_vector import Rigid3Array
from dockformer.utils.geometry.rotation_matrix import Rot3Array
from dockformer.utils.geometry.vector import Vec3Array
from dockformer.utils.tensor_utils import (
    tree_map,
    tensor_tree_map,
    batched_gather,
)


def cast_to_64bit_ints(protein):
    # We keep all ints as int64
    for k, v in protein.items():
        if v.dtype == torch.int32:
            protein[k] = v.type(torch.int64)

    return protein


def make_one_hot(x, num_classes):
    x_one_hot = torch.zeros(*x.shape, num_classes, device=x.device)
    x_one_hot.scatter_(-1, x.unsqueeze(-1), 1)
    return x_one_hot

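A quick illustration (commentary, not part of the commit) of make_one_hot, assuming only torch: it scatters 1s into a trailing class dimension.

# Illustrative usage only.
import torch

idx = torch.tensor([0, 2, 1])
one_hot = torch.zeros(*idx.shape, 3)
one_hot.scatter_(-1, idx.unsqueeze(-1), 1)   # same as make_one_hot(idx, 3)
# one_hot -> [[1,0,0], [0,0,1], [0,1,0]]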
def curry1(f):
    """Supply all arguments but the first."""
    @wraps(f)
    def fc(*args, **kwargs):
        return lambda x: f(x, *args, **kwargs)

    return fc

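A hedged usage note (not from the commit): curry1 turns a transform f(protein, *extra) into a factory that pre-binds the extras, which is how the @curry1-decorated transforms in this file are composed into pipelines. A self-contained toy version:

# Illustrative only, with a toy transform.
from functools import wraps

def curry1(f):
    @wraps(f)
    def fc(*args, **kwargs):
        return lambda x: f(x, *args, **kwargs)
    return fc

@curry1
def scale(protein, factor):
    return {k: v * factor for k, v in protein.items()}

double = scale(factor=2)           # binds everything but `protein`
assert double({"a": 3}) == {"a": 6}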
def squeeze_features(protein):
    """Remove singleton and repeated dimensions in protein features."""
    protein["aatype"] = torch.argmax(protein["aatype"], dim=-1)
    for k in [
        "domain_name",
        "seq_length",
        "sequence",
        "resolution",
        "residue_index",
    ]:
        if k in protein:
            final_dim = protein[k].shape[-1]
            if isinstance(final_dim, int) and final_dim == 1:
                if torch.is_tensor(protein[k]):
                    protein[k] = torch.squeeze(protein[k], dim=-1)
                else:
                    protein[k] = np.squeeze(protein[k], axis=-1)

    for k in ["seq_length"]:
        if k in protein:
            protein[k] = protein[k][0]

    return protein


def pseudo_beta_fn(aatype, all_atom_positions, all_atom_mask):
    """Create pseudo beta features."""
    is_gly = torch.eq(aatype, rc.restype_order["G"])
    ca_idx = rc.atom_order["CA"]
    cb_idx = rc.atom_order["CB"]
    pseudo_beta = torch.where(
        torch.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3]),
        all_atom_positions[..., ca_idx, :],
        all_atom_positions[..., cb_idx, :],
    )

    if all_atom_mask is not None:
        pseudo_beta_mask = torch.where(
            is_gly, all_atom_mask[..., ca_idx], all_atom_mask[..., cb_idx]
        )
        return pseudo_beta, pseudo_beta_mask
    else:
        return pseudo_beta


@curry1
def make_pseudo_beta(protein):
    """Create pseudo-beta (alpha for glycine) position and mask."""
    (protein["pseudo_beta"], protein["pseudo_beta_mask"]) = pseudo_beta_fn(
        protein["aatype"],
        protein["all_atom_positions"],
        protein["all_atom_mask"],
    )
    return protein

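In plain terms (commentary, not part of the commit): the pseudo-beta is the CB coordinate for every residue except glycine, which has no CB, so CA is used instead. A minimal torch.where sketch of that selection:

# Illustrative only; coordinates are made up.
import torch

is_gly = torch.tensor([True, False])
ca = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
cb = torch.tensor([[9.0, 9.0, 9.0], [2.0, 2.0, 2.0]])
pseudo_beta = torch.where(is_gly[:, None], ca, cb)
# -> [[0,0,0], [2,2,2]]: CA for the glycine, CB otherwise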
@curry1
def make_target_feat(protein):
    """Create and concatenate protein features."""
    aatype_1hot = make_one_hot(protein["aatype"], 20)

    protein["protein_target_feat"] = aatype_1hot

    return protein


@curry1
def select_feat(protein, feature_list):
    return {k: v for k, v in protein.items() if k in feature_list}


def get_restypes(device):
    restype_atom14_to_atom37 = []
    restype_atom37_to_atom14 = []
    restype_atom14_mask = []

    for rt in rc.restypes:
        atom_names = rc.restype_name_to_atom14_names[rc.restype_1to3[rt]]
        restype_atom14_to_atom37.append(
            [(rc.atom_order[name] if name else 0) for name in atom_names]
        )
        atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
        restype_atom37_to_atom14.append(
            [
                (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
                for name in rc.atom_types
            ]
        )

        restype_atom14_mask.append(
            [(1.0 if name else 0.0) for name in atom_names]
        )

    # Add dummy mapping for restype 'UNK'
    restype_atom14_to_atom37.append([0] * 14)
    restype_atom37_to_atom14.append([0] * 37)
    restype_atom14_mask.append([0.0] * 14)

    restype_atom14_to_atom37 = torch.tensor(
        restype_atom14_to_atom37,
        dtype=torch.int32,
        device=device,
    )
    restype_atom37_to_atom14 = torch.tensor(
        restype_atom37_to_atom14,
        dtype=torch.int32,
        device=device,
    )
    restype_atom14_mask = torch.tensor(
        restype_atom14_mask,
        dtype=torch.float32,
        device=device,
    )

    return restype_atom14_to_atom37, restype_atom37_to_atom14, restype_atom14_mask


def get_restype_atom37_mask(device):
    # create the corresponding mask
    restype_atom37_mask = torch.zeros(
        [len(restypes) + 1, 37], dtype=torch.float32, device=device
    )
    for restype, restype_letter in enumerate(rc.restypes):
        restype_name = rc.restype_1to3[restype_letter]
        atom_names = rc.residue_atoms[restype_name]
        for atom_name in atom_names:
            atom_type = rc.atom_order[atom_name]
            restype_atom37_mask[restype, atom_type] = 1
    return restype_atom37_mask


def make_atom14_masks(protein):
    """Construct denser atom positions (14 dimensions instead of 37)."""
    restype_atom14_to_atom37, restype_atom37_to_atom14, restype_atom14_mask = get_restypes(protein["aatype"].device)

    protein_aatype = protein['aatype'].to(torch.long)

    # create the mapping for (residx, atom14) --> atom37, i.e. an array
    # with shape (num_res, 14) containing the atom37 indices for this protein
    residx_atom14_to_atom37 = restype_atom14_to_atom37[protein_aatype]
    residx_atom14_mask = restype_atom14_mask[protein_aatype]

    protein["atom14_atom_exists"] = residx_atom14_mask
    protein["residx_atom14_to_atom37"] = residx_atom14_to_atom37.long()

    # create the gather indices for mapping back
    residx_atom37_to_atom14 = restype_atom37_to_atom14[protein_aatype]
    protein["residx_atom37_to_atom14"] = residx_atom37_to_atom14.long()

    restype_atom37_mask = get_restype_atom37_mask(protein["aatype"].device)

    residx_atom37_mask = restype_atom37_mask[protein_aatype]
    protein["atom37_atom_exists"] = residx_atom37_mask

    return protein


def make_atom14_positions(protein):
    """Constructs denser atom positions (14 dimensions instead of 37)."""
    residx_atom14_mask = protein["atom14_atom_exists"]
    residx_atom14_to_atom37 = protein["residx_atom14_to_atom37"]

    # Create a mask for known ground truth positions.
    residx_atom14_gt_mask = residx_atom14_mask * batched_gather(
        protein["all_atom_mask"],
        residx_atom14_to_atom37,
        dim=-1,
        no_batch_dims=len(protein["all_atom_mask"].shape[:-1]),
    )

    # Gather the ground truth positions.
    residx_atom14_gt_positions = residx_atom14_gt_mask[..., None] * (
        batched_gather(
            protein["all_atom_positions"],
            residx_atom14_to_atom37,
            dim=-2,
            no_batch_dims=len(protein["all_atom_positions"].shape[:-2]),
        )
    )

    protein["atom14_atom_exists"] = residx_atom14_mask
    protein["atom14_gt_exists"] = residx_atom14_gt_mask
    protein["atom14_gt_positions"] = residx_atom14_gt_positions

    # As the atom naming is ambiguous for 7 of the 20 amino acids, provide
    # alternative ground truth coordinates where the naming is swapped
    restype_3 = [rc.restype_1to3[res] for res in rc.restypes]
    restype_3 += ["UNK"]

    # Matrices for renaming ambiguous atoms.
    all_matrices = {
        res: torch.eye(
            14,
            dtype=protein["all_atom_mask"].dtype,
            device=protein["all_atom_mask"].device,
        )
        for res in restype_3
    }
    for resname, swap in rc.residue_atom_renaming_swaps.items():
        correspondences = torch.arange(
            14, device=protein["all_atom_mask"].device
        )
        for source_atom_swap, target_atom_swap in swap.items():
            source_index = rc.restype_name_to_atom14_names[resname].index(
                source_atom_swap
            )
            target_index = rc.restype_name_to_atom14_names[resname].index(
                target_atom_swap
            )
            correspondences[source_index] = target_index
            correspondences[target_index] = source_index
        renaming_matrix = protein["all_atom_mask"].new_zeros((14, 14))
        for index, correspondence in enumerate(correspondences):
            renaming_matrix[index, correspondence] = 1.0
        all_matrices[resname] = renaming_matrix

    renaming_matrices = torch.stack(
        [all_matrices[restype] for restype in restype_3]
    )

    # Pick the transformation matrices for the given residue sequence
    # shape (num_res, 14, 14).
    renaming_transform = renaming_matrices[protein["aatype"]]

    # Apply it to the ground truth positions. shape (num_res, 14, 3).
    alternative_gt_positions = torch.einsum(
        "...rac,...rab->...rbc", residx_atom14_gt_positions, renaming_transform
    )
    protein["atom14_alt_gt_positions"] = alternative_gt_positions

    # Create the mask for the alternative ground truth (differs from the
    # ground truth mask, if only one of the atoms in an ambiguous pair has a
    # ground truth position).
    alternative_gt_mask = torch.einsum(
        "...ra,...rab->...rb", residx_atom14_gt_mask, renaming_transform
    )
    protein["atom14_alt_gt_exists"] = alternative_gt_mask

    # Create an ambiguous atoms mask. shape: (21, 14).
    restype_atom14_is_ambiguous = protein["all_atom_mask"].new_zeros((21, 14))
    for resname, swap in rc.residue_atom_renaming_swaps.items():
        for atom_name1, atom_name2 in swap.items():
            restype = rc.restype_order[rc.restype_3to1[resname]]
            atom_idx1 = rc.restype_name_to_atom14_names[resname].index(
                atom_name1
            )
            atom_idx2 = rc.restype_name_to_atom14_names[resname].index(
                atom_name2
            )
            restype_atom14_is_ambiguous[restype, atom_idx1] = 1
            restype_atom14_is_ambiguous[restype, atom_idx2] = 1

    # From this create an ambiguous_mask for the given sequence.
    protein["atom14_atom_is_ambiguous"] = restype_atom14_is_ambiguous[
        protein["aatype"]
    ]

    return protein


def atom37_to_frames(protein, eps=1e-8):
    aatype = protein["aatype"]
    all_atom_positions = protein["all_atom_positions"]
    all_atom_mask = protein["all_atom_mask"]

    batch_dims = len(aatype.shape[:-1])

    restype_rigidgroup_base_atom_names = np.full([21, 8, 3], "", dtype=object)
    restype_rigidgroup_base_atom_names[:, 0, :] = ["C", "CA", "N"]
    restype_rigidgroup_base_atom_names[:, 3, :] = ["CA", "C", "O"]

    for restype, restype_letter in enumerate(rc.restypes):
        resname = rc.restype_1to3[restype_letter]
        for chi_idx in range(4):
            if rc.chi_angles_mask[restype][chi_idx]:
                names = rc.chi_angles_atoms[resname][chi_idx]
                restype_rigidgroup_base_atom_names[
                    restype, chi_idx + 4, :
                ] = names[1:]

    restype_rigidgroup_mask = all_atom_mask.new_zeros(
        (*aatype.shape[:-1], 21, 8),
    )
    restype_rigidgroup_mask[..., 0] = 1
    restype_rigidgroup_mask[..., 3] = 1
    restype_rigidgroup_mask[..., :len(restypes), 4:] = all_atom_mask.new_tensor(
        rc.chi_angles_mask
    )

    lookuptable = rc.atom_order.copy()
    lookuptable[""] = 0
    lookup = np.vectorize(lambda x: lookuptable[x])
    restype_rigidgroup_base_atom37_idx = lookup(
        restype_rigidgroup_base_atom_names,
    )
    restype_rigidgroup_base_atom37_idx = aatype.new_tensor(
        restype_rigidgroup_base_atom37_idx,
    )
    restype_rigidgroup_base_atom37_idx = (
        restype_rigidgroup_base_atom37_idx.view(
            *((1,) * batch_dims), *restype_rigidgroup_base_atom37_idx.shape
        )
    )

    residx_rigidgroup_base_atom37_idx = batched_gather(
        restype_rigidgroup_base_atom37_idx,
        aatype,
        dim=-3,
        no_batch_dims=batch_dims,
    )

    base_atom_pos = batched_gather(
        all_atom_positions,
        residx_rigidgroup_base_atom37_idx,
        dim=-2,
        no_batch_dims=len(all_atom_positions.shape[:-2]),
    )

    gt_frames = Rigid.from_3_points(
        p_neg_x_axis=base_atom_pos[..., 0, :],
        origin=base_atom_pos[..., 1, :],
        p_xy_plane=base_atom_pos[..., 2, :],
        eps=eps,
    )

    group_exists = batched_gather(
        restype_rigidgroup_mask,
        aatype,
        dim=-2,
        no_batch_dims=batch_dims,
    )

    gt_atoms_exist = batched_gather(
        all_atom_mask,
        residx_rigidgroup_base_atom37_idx,
        dim=-1,
        no_batch_dims=len(all_atom_mask.shape[:-1]),
    )
    gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists

    rots = torch.eye(3, dtype=all_atom_mask.dtype, device=aatype.device)
    rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
    rots[..., 0, 0, 0] = -1
    rots[..., 0, 2, 2] = -1

    rots = Rotation(rot_mats=rots)
    gt_frames = gt_frames.compose(Rigid(rots, None))

    restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
        *((1,) * batch_dims), 21, 8
    )
    restype_rigidgroup_rots = torch.eye(
        3, dtype=all_atom_mask.dtype, device=aatype.device
    )
    restype_rigidgroup_rots = torch.tile(
        restype_rigidgroup_rots,
        (*((1,) * batch_dims), 21, 8, 1, 1),
    )

    for resname, _ in rc.residue_atom_renaming_swaps.items():
        restype = rc.restype_order[rc.restype_3to1[resname]]
        chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1)
        restype_rigidgroup_is_ambiguous[..., restype, chi_idx + 4] = 1
        restype_rigidgroup_rots[..., restype, chi_idx + 4, 1, 1] = -1
        restype_rigidgroup_rots[..., restype, chi_idx + 4, 2, 2] = -1

    residx_rigidgroup_is_ambiguous = batched_gather(
        restype_rigidgroup_is_ambiguous,
        aatype,
        dim=-2,
        no_batch_dims=batch_dims,
    )

    residx_rigidgroup_ambiguity_rot = batched_gather(
        restype_rigidgroup_rots,
        aatype,
        dim=-4,
        no_batch_dims=batch_dims,
    )

    residx_rigidgroup_ambiguity_rot = Rotation(
        rot_mats=residx_rigidgroup_ambiguity_rot
    )
    alt_gt_frames = gt_frames.compose(
        Rigid(residx_rigidgroup_ambiguity_rot, None)
    )

    gt_frames_tensor = gt_frames.to_tensor_4x4()
    alt_gt_frames_tensor = alt_gt_frames.to_tensor_4x4()

    protein["rigidgroups_gt_frames"] = gt_frames_tensor
    protein["rigidgroups_gt_exists"] = gt_exists
    protein["rigidgroups_group_exists"] = group_exists
    protein["rigidgroups_group_is_ambiguous"] = residx_rigidgroup_is_ambiguous
    protein["rigidgroups_alt_gt_frames"] = alt_gt_frames_tensor

    return protein


def get_chi_atom_indices():
    """Returns atom indices needed to compute chi angles for all residue types.

    Returns:
      A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types
      are in the order specified in rc.restypes + unknown residue type at the
      end. For chi angles which are not defined on the residue, the position
      indices are by default set to 0.
    """
    chi_atom_indices = []
    for residue_name in rc.restypes:
        residue_name = rc.restype_1to3[residue_name]
        residue_chi_angles = rc.chi_angles_atoms[residue_name]
        atom_indices = []
        for chi_angle in residue_chi_angles:
            atom_indices.append([rc.atom_order[atom] for atom in chi_angle])
        for _ in range(4 - len(atom_indices)):
            atom_indices.append(
                [0, 0, 0, 0]
            )  # For chi angles not defined on the AA.
        chi_atom_indices.append(atom_indices)

    chi_atom_indices.append([[0, 0, 0, 0]] * 4)  # For UNKNOWN residue.

    return chi_atom_indices


@curry1
def atom37_to_torsion_angles(
    protein,
    prefix="",
):
    """
    Convert coordinates to torsion angles.

    This function is extremely sensitive to floating point imprecisions
    and should be run with double precision whenever possible.

    Args:
        Dict containing:
            * (prefix)aatype:
                [*, N_res] residue indices
            * (prefix)all_atom_positions:
                [*, N_res, 37, 3] atom positions (in atom37 format)
            * (prefix)all_atom_mask:
                [*, N_res, 37] atom position mask
    Returns:
        The same dictionary updated with the following features:

        "(prefix)torsion_angles_sin_cos" ([*, N_res, 7, 2])
            Torsion angles
        "(prefix)alt_torsion_angles_sin_cos" ([*, N_res, 7, 2])
            Alternate torsion angles (accounting for 180-degree symmetry)
        "(prefix)torsion_angles_mask" ([*, N_res, 7])
            Torsion angles mask
    """
    aatype = protein[prefix + "aatype"]
    all_atom_positions = protein[prefix + "all_atom_positions"]
    all_atom_mask = protein[prefix + "all_atom_mask"]

    aatype = torch.clamp(aatype, max=20)

    pad = all_atom_positions.new_zeros(
        [*all_atom_positions.shape[:-3], 1, 37, 3]
    )
    prev_all_atom_positions = torch.cat(
        [pad, all_atom_positions[..., :-1, :, :]], dim=-3
    )

    pad = all_atom_mask.new_zeros([*all_atom_mask.shape[:-2], 1, 37])
    prev_all_atom_mask = torch.cat([pad, all_atom_mask[..., :-1, :]], dim=-2)

    pre_omega_atom_pos = torch.cat(
        [prev_all_atom_positions[..., 1:3, :], all_atom_positions[..., :2, :]],
        dim=-2,
    )
    phi_atom_pos = torch.cat(
        [prev_all_atom_positions[..., 2:3, :], all_atom_positions[..., :3, :]],
        dim=-2,
    )
    psi_atom_pos = torch.cat(
        [all_atom_positions[..., :3, :], all_atom_positions[..., 4:5, :]],
        dim=-2,
    )

    pre_omega_mask = torch.prod(
        prev_all_atom_mask[..., 1:3], dim=-1
    ) * torch.prod(all_atom_mask[..., :2], dim=-1)
    phi_mask = prev_all_atom_mask[..., 2] * torch.prod(
        all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype
    )
    psi_mask = (
        torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype)
        * all_atom_mask[..., 4]
    )

    chi_atom_indices = torch.as_tensor(
        get_chi_atom_indices(), device=aatype.device
    )

    atom_indices = chi_atom_indices[..., aatype, :, :]
    chis_atom_pos = batched_gather(
        all_atom_positions, atom_indices, -2, len(atom_indices.shape[:-2])
    )

    chi_angles_mask = list(rc.chi_angles_mask)
    chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])
    chi_angles_mask = all_atom_mask.new_tensor(chi_angles_mask)

    chis_mask = chi_angles_mask[aatype, :]

    chi_angle_atoms_mask = batched_gather(
        all_atom_mask,
        atom_indices,
        dim=-1,
        no_batch_dims=len(atom_indices.shape[:-2]),
    )
    chi_angle_atoms_mask = torch.prod(
        chi_angle_atoms_mask, dim=-1, dtype=chi_angle_atoms_mask.dtype
    )
    chis_mask = chis_mask * chi_angle_atoms_mask

    torsions_atom_pos = torch.cat(
        [
            pre_omega_atom_pos[..., None, :, :],
            phi_atom_pos[..., None, :, :],
            psi_atom_pos[..., None, :, :],
            chis_atom_pos,
        ],
        dim=-3,
    )

    torsion_angles_mask = torch.cat(
        [
            pre_omega_mask[..., None],
            phi_mask[..., None],
            psi_mask[..., None],
            chis_mask,
        ],
        dim=-1,
    )

    torsion_frames = Rigid.from_3_points(
        torsions_atom_pos[..., 1, :],
        torsions_atom_pos[..., 2, :],
        torsions_atom_pos[..., 0, :],
        eps=1e-8,
    )

    fourth_atom_rel_pos = torsion_frames.invert().apply(
        torsions_atom_pos[..., 3, :]
    )

    torsion_angles_sin_cos = torch.stack(
        [fourth_atom_rel_pos[..., 2], fourth_atom_rel_pos[..., 1]], dim=-1
    )

    denom = torch.sqrt(
        torch.sum(
            torch.square(torsion_angles_sin_cos),
            dim=-1,
            dtype=torsion_angles_sin_cos.dtype,
            keepdims=True,
        )
        + 1e-8
    )
    torsion_angles_sin_cos = torsion_angles_sin_cos / denom

    torsion_angles_sin_cos = torsion_angles_sin_cos * all_atom_mask.new_tensor(
        [1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0],
    )[((None,) * len(torsion_angles_sin_cos.shape[:-2])) + (slice(None), None)]

    chi_is_ambiguous = torsion_angles_sin_cos.new_tensor(
        rc.chi_pi_periodic,
    )[aatype, ...]

    mirror_torsion_angles = torch.cat(
        [
            all_atom_mask.new_ones(*aatype.shape, 3),
            1.0 - 2.0 * chi_is_ambiguous,
        ],
        dim=-1,
    )

    alt_torsion_angles_sin_cos = (
        torsion_angles_sin_cos * mirror_torsion_angles[..., None]
    )

    protein[prefix + "torsion_angles_sin_cos"] = torsion_angles_sin_cos
    protein[prefix + "alt_torsion_angles_sin_cos"] = alt_torsion_angles_sin_cos
    protein[prefix + "torsion_angles_mask"] = torsion_angles_mask

    return protein


def get_backbone_frames(protein):
    # DISCREPANCY: AlphaFold uses tensor_7s here. I don't know why.
    protein["backbone_rigid_tensor"] = protein["rigidgroups_gt_frames"][
        ..., 0, :, :
    ]
    protein["backbone_rigid_mask"] = protein["rigidgroups_gt_exists"][..., 0]

    return protein


def get_chi_angles(protein):
    dtype = protein["all_atom_mask"].dtype
    protein["chi_angles_sin_cos"] = (
        protein["torsion_angles_sin_cos"][..., 3:, :]
    ).to(dtype)
    protein["chi_mask"] = protein["torsion_angles_mask"][..., 3:].to(dtype)

    return protein


@curry1
def random_crop_to_size(
    protein,
    crop_size,
    shape_schema,
    seed=None,
):
    """Crop randomly to `crop_size`, or keep as is if shorter than that."""
    # We want each ensemble to be cropped the same way

    g = None
    if seed is not None:
        g = torch.Generator(device=protein["seq_length"].device)
        g.manual_seed(seed)

    seq_length = protein["seq_length"]

    num_res_crop_size = min(int(seq_length), crop_size)

    def _randint(lower, upper):
        return int(torch.randint(
            lower,
            upper + 1,
            (1,),
            device=protein["seq_length"].device,
            generator=g,
        )[0])

    n = seq_length - num_res_crop_size
    if "use_clamped_fape" in protein and protein["use_clamped_fape"] == 1.:
        right_anchor = n
    else:
        x = _randint(0, n)
        right_anchor = n - x

    num_res_crop_start = _randint(0, right_anchor)

    for k, v in protein.items():
        if k not in shape_schema or (NUM_RES not in shape_schema[k]):
            continue

        slices = []
        for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)):
            is_num_res = dim_size == NUM_RES
            crop_start = num_res_crop_start if is_num_res else 0
            crop_size = num_res_crop_size if is_num_res else dim
            slices.append(slice(crop_start, crop_start + crop_size))
        protein[k] = v[tuple(slices)]

    protein["seq_length"] = protein["seq_length"].new_tensor(num_res_crop_size)

    return protein

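A hedged usage sketch of random_crop_to_size (commentary, not part of the commit), assuming NUM_RES is the schema placeholder imported at the top of this file; the shape schema names which dimensions are residue dimensions and should be cropped:

# Illustrative only; shapes are made up.
import torch

protein = {
    "aatype": torch.zeros(10, dtype=torch.int64),
    "seq_length": torch.tensor(10),
}
schema = {"aatype": [NUM_RES]}
cropped = random_crop_to_size(crop_size=4, shape_schema=schema, seed=0)(protein)
# cropped["aatype"].shape[0] == 4 and cropped["seq_length"] == 4;
# "seq_length" itself is not in the schema, so it is never sliced.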
dockformer/data/errors.py
ADDED
@@ -0,0 +1,22 @@
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""General-purpose errors used throughout the data pipeline"""


class Error(Exception):
    """Base class for exceptions."""


class MultipleChainsError(Error):
    """An error indicating that multiple chains were found for a given ID."""
dockformer/data/ligand_features.py
ADDED
@@ -0,0 +1,66 @@
import os
import numpy as np
import torch
from torch import nn
from rdkit import Chem

from dockformer.data.utils import FeatureTensorDict
from dockformer.utils.consts import POSSIBLE_BOND_TYPES, POSSIBLE_ATOM_TYPES, POSSIBLE_CHARGES, POSSIBLE_CHIRALITIES


def get_atom_features(atom: Chem.Atom):
    # TODO: this is temporary, we need to add more features, for example for Zn
    if atom.GetSymbol() not in POSSIBLE_ATOM_TYPES:
        print(f"********Unknown atom type {atom.GetSymbol()}")
        atom_type = POSSIBLE_ATOM_TYPES.index("Ni")
    else:
        atom_type = POSSIBLE_ATOM_TYPES.index(atom.GetSymbol())
    atom_charge = POSSIBLE_CHARGES.index(max(min(atom.GetFormalCharge(), 1), -1))
    atom_chirality = POSSIBLE_CHIRALITIES.index(atom.GetChiralTag())

    return {"atom_type": atom_type, "atom_charge": atom_charge, "atom_chirality": atom_chirality}


def get_bond_features(bond: Chem.Bond):
    bond_type = POSSIBLE_BOND_TYPES.index(bond.GetBondType())
    return {"bond_type": bond_type}


def make_ligand_features(ligand: Chem.Mol) -> FeatureTensorDict:
    atoms_features = []
    atom_idx_to_atom_pos_idx = {}
    for atom in ligand.GetAtoms():
        atom_idx_to_atom_pos_idx[atom.GetIdx()] = len(atoms_features)
        atoms_features.append(get_atom_features(atom))

    atom_types = torch.tensor(np.array([atom["atom_type"] for atom in atoms_features], dtype=np.int64))
    atom_types_one_hot = nn.functional.one_hot(atom_types, num_classes=len(POSSIBLE_ATOM_TYPES))
    atom_charges = torch.tensor(np.array([atom["atom_charge"] for atom in atoms_features], dtype=np.int64))
    atom_charges_one_hot = nn.functional.one_hot(atom_charges, num_classes=len(POSSIBLE_CHARGES))
    atom_chiralities = torch.tensor(np.array([atom["atom_chirality"] for atom in atoms_features], dtype=np.int64))
    atom_chiralities_one_hot = nn.functional.one_hot(atom_chiralities, num_classes=len(POSSIBLE_CHIRALITIES))

    ligand_target_feat = torch.cat([atom_types_one_hot.float(), atom_charges_one_hot.float(),
                                    atom_chiralities_one_hot.float()], dim=1)

    # create one-hot matrix encoding for bonds
    ligand_bonds_feat = torch.zeros((len(atoms_features), len(atoms_features), len(POSSIBLE_BOND_TYPES)))
    ligand_bonds = []
    for bond in ligand.GetBonds():
        atom1_idx = atom_idx_to_atom_pos_idx[bond.GetBeginAtomIdx()]
        atom2_idx = atom_idx_to_atom_pos_idx[bond.GetEndAtomIdx()]
        bond_features = get_bond_features(bond)
        ligand_bonds.append((atom1_idx, atom2_idx, bond_features["bond_type"]))
        ligand_bonds_feat[atom1_idx, atom2_idx, bond_features["bond_type"]] = 1

    return {
        # These are used for reconstruction at the end of the pipeline
        "ligand_atype": atom_types,
        "ligand_charge": atom_charges,
        "ligand_chirality": atom_chiralities,
        "ligand_bonds": torch.tensor(ligand_bonds, dtype=torch.int64),
        # these are the actual features
        "ligand_target_feat": ligand_target_feat.float(),
        "ligand_bonds_feat": ligand_bonds_feat.float(),
    }
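A hedged usage sketch (commentary, not part of the commit), assuming RDKit is installed and the consts lists cover the atom and bond types involved:

# Illustrative only.
from rdkit import Chem

mol = Chem.MolFromSmiles("CCO")   # ethanol: 3 heavy atoms, 2 single bonds
feats = make_ligand_features(mol)
# feats["ligand_target_feat"].shape ==
#   (3, len(POSSIBLE_ATOM_TYPES) + len(POSSIBLE_CHARGES) + len(POSSIBLE_CHIRALITIES))
# feats["ligand_bonds"].shape == (2, 3)   # one (atom_i, atom_j, bond_type) row per bond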
dockformer/data/parsers.py
ADDED
@@ -0,0 +1,53 @@
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions for parsing various file formats."""
import collections
import dataclasses
import itertools
import re
import string
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Set


def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
    """Parses FASTA string and returns list of strings with amino-acid sequences.

    Arguments:
      fasta_string: The string contents of a FASTA file.

    Returns:
      A tuple of two lists:
      * A list of sequences.
      * A list of sequence descriptions taken from the comment lines. In the
        same order as the sequences.
    """
    sequences = []
    descriptions = []
    index = -1
    for line in fasta_string.splitlines():
        line = line.strip()
        if line.startswith(">"):
            index += 1
            descriptions.append(line[1:])  # Remove the '>' at the beginning.
            sequences.append("")
            continue
        elif line.startswith("#"):
            continue
        elif not line:
            continue  # Skip blank lines.
        sequences[index] += line

    return sequences, descriptions
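A small usage sketch of parse_fasta (commentary, not part of the commit):

# Illustrative only.
fasta = ">seq1 description\nMKT\nLLV\n>seq2\nGGG\n"
seqs, descs = parse_fasta(fasta)
# seqs  == ["MKTLLV", "GGG"]
# descs == ["seq1 description", "seq2"]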
dockformer/data/protein_features.py
ADDED
@@ -0,0 +1,71 @@
import numpy as np

from dockformer.data.utils import FeatureDict
from dockformer.utils import residue_constants, protein


def _make_sequence_features(sequence: str, description: str, num_res: int) -> FeatureDict:
    """Construct a feature dict of sequence features."""
    features = {}
    features["aatype"] = residue_constants.sequence_to_onehot(
        sequence=sequence,
        mapping=residue_constants.restype_order_with_x,
        map_unknown_to_x=True,
    )
    features["domain_name"] = np.array(
        [description.encode("utf-8")], dtype=object
    )
    # features["residue_index"] = np.array(range(num_res), dtype=np.int32)
    features["seq_length"] = np.array([num_res] * num_res, dtype=np.int32)
    features["sequence"] = np.array(
        [sequence.encode("utf-8")], dtype=object
    )
    return features


def _aatype_to_str_sequence(aatype):
    return ''.join([
        residue_constants.restypes_with_x[aatype[i]]
        for i in range(len(aatype))
    ])


def _make_protein_structure_features(protein_object: protein.Protein) -> FeatureDict:
    pdb_feats = {}

    all_atom_positions = protein_object.atom_positions
    all_atom_mask = protein_object.atom_mask

    pdb_feats["all_atom_positions"] = all_atom_positions.astype(np.float32)
    pdb_feats["all_atom_mask"] = all_atom_mask.astype(np.float32)
    pdb_feats["in_chain_residue_index"] = protein_object.residue_index.astype(np.int32)

    # Offset the residue index by 50 at every chain boundary so that separate
    # chains do not look contiguous in the gapped index.
    gapped_res_indexes = []
    prev_chain_index = protein_object.chain_index[0]
    chain_start_res_ind = 0
    for relative_res_ind, chain_index in zip(protein_object.residue_index, protein_object.chain_index):
        if chain_index != prev_chain_index:
            chain_start_res_ind = gapped_res_indexes[-1] + 50
            prev_chain_index = chain_index
        gapped_res_indexes.append(relative_res_ind + chain_start_res_ind)

    pdb_feats["residue_index"] = np.array(gapped_res_indexes).astype(np.int32)
    pdb_feats["chain_index"] = np.array(protein_object.chain_index).astype(np.int32)
    pdb_feats["resolution"] = np.array([0.]).astype(np.float32)

    return pdb_feats


def make_protein_features(protein_object: protein.Protein, description: str) -> FeatureDict:
    feats = {}
    aatype = protein_object.aatype
    sequence = _aatype_to_str_sequence(aatype)
    feats.update(
        _make_sequence_features(sequence=sequence, description=description, num_res=len(protein_object.aatype))
    )

    feats.update(
        _make_protein_structure_features(protein_object=protein_object)
    )

    return feats
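A small numeric sketch (commentary, not part of the commit) of the chain-gap logic in _make_protein_structure_features:

# Illustrative only: two chains, residue_index [1, 2, 1, 2], chain_index [0, 0, 1, 1].
# Chain 0 keeps its indices: 1, 2.
# At the chain boundary, chain_start_res_ind = 2 + 50 = 52,
# so chain 1 becomes 1 + 52 = 53 and 2 + 52 = 54 -> gapped index [1, 2, 53, 54].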
dockformer/data/utils.py
ADDED
@@ -0,0 +1,54 @@
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common utilities for data pipeline tools."""
import contextlib
import datetime
import logging
import shutil
import tempfile
import time
from typing import Optional, Mapping, Dict

import numpy as np
import torch

FeatureDict = Dict[str, np.ndarray]
FeatureTensorDict = Dict[str, torch.Tensor]


@contextlib.contextmanager
def tmpdir_manager(base_dir: Optional[str] = None):
    """Context manager that deletes a temporary directory on exit."""
    tmpdir = tempfile.mkdtemp(dir=base_dir)
    try:
        yield tmpdir
    finally:
        shutil.rmtree(tmpdir, ignore_errors=True)


@contextlib.contextmanager
def timing(msg: str):
    logging.info("Started %s", msg)
    tic = time.perf_counter()
    yield
    toc = time.perf_counter()
    logging.info("Finished %s in %.3f seconds", msg, toc - tic)


def to_date(s: str):
    return datetime.datetime(
        year=int(s[:4]), month=int(s[5:7]), day=int(s[8:10])
    )
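A short usage sketch of the two context managers (commentary, not part of the commit):

# Illustrative only; assumes `import os` alongside the imports above.
with timing("writing scratch file"):
    with tmpdir_manager() as tmpdir:
        scratch = os.path.join(tmpdir, "scratch.txt")
        with open(scratch, "w") as f:
            f.write("hello")
# The directory (and the file) are removed on exit, even if an exception is raised.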
dockformer/model/__init__.py
ADDED
File without changes
dockformer/model/dropout.py
ADDED
@@ -0,0 +1,69 @@
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import torch.nn as nn
from functools import partialmethod
from typing import Union, List


class Dropout(nn.Module):
    """
    Implementation of dropout with the ability to share the dropout mask
    along a particular dimension.

    If not in training mode, this module computes the identity function.
    """

    def __init__(self, r: float, batch_dim: Union[int, List[int]]):
        """
        Args:
            r:
                Dropout rate
            batch_dim:
                Dimension(s) along which the dropout mask is shared
        """
        super(Dropout, self).__init__()

        self.r = r
        if type(batch_dim) == int:
            batch_dim = [batch_dim]
        self.batch_dim = batch_dim
        self.dropout = nn.Dropout(self.r)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x:
                Tensor to which dropout is applied. Can have any shape
                compatible with self.batch_dim
        """
        shape = list(x.shape)
        if self.batch_dim is not None:
            for bd in self.batch_dim:
                shape[bd] = 1
        mask = x.new_ones(shape)
        mask = self.dropout(mask)
        x *= mask
        return x


class DropoutRowwise(Dropout):
    """
    Convenience class for rowwise dropout as described in subsection
    1.11.6.
    """

    __init__ = partialmethod(Dropout.__init__, batch_dim=-3)
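
A quick sketch of what sharing the mask buys, using illustrative tensor sizes (this snippet is not part of the commit): with batch_dim=-3, a single Bernoulli mask of shape [*, 1, N, C] is drawn and broadcast over dimension -3, so whole rows are kept or zeroed together.

import torch
from dockformer.model.dropout import DropoutRowwise

drop = DropoutRowwise(r=0.25)
drop.train()  # nn.Dropout is the identity in eval mode

z = torch.randn(2, 8, 8, 16)  # [batch, rows, cols, channels], illustrative sizes
out = drop(z)

# Internally the mask had shape [2, 1, 8, 16] and was broadcast over dim -3,
# so every row sees the same keep/zero pattern; note that forward() also
# multiplies x in place.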
dockformer/model/embedders.py
ADDED
@@ -0,0 +1,346 @@
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

import torch
import torch.nn as nn
from typing import Tuple, Optional

from dockformer.model.primitives import Linear, LayerNorm
from dockformer.utils.tensor_utils import add


class StructureInputEmbedder(nn.Module):
    """
    Embeds a subset of the input features.

    Implements a merge of Algorithm 3 and Algorithm 32.
    """

    def __init__(
        self,
        protein_tf_dim: int,
        ligand_tf_dim: int,
        additional_tf_dim: int,
        ligand_bond_dim: int,
        c_z: int,
        c_m: int,
        relpos_k: int,
        prot_min_bin: float,
        prot_max_bin: float,
        prot_no_bins: int,
        lig_min_bin: float,
        lig_max_bin: float,
        lig_no_bins: int,
        inf: float = 1e8,
        **kwargs,
    ):
        """
        Args:
            tf_dim:
                Final dimension of the target features
            c_z:
                Pair embedding dimension
            c_m:
                Single embedding dimension
            relpos_k:
                Window size used in relative positional encoding
        """
        super(StructureInputEmbedder, self).__init__()

        self.tf_dim = protein_tf_dim + ligand_tf_dim + additional_tf_dim
        self.pair_tf_dim = ligand_bond_dim

        self.c_z = c_z
        self.c_m = c_m

        self.linear_tf_z_i = Linear(self.tf_dim, c_z)
        self.linear_tf_z_j = Linear(self.tf_dim, c_z)
        self.linear_tf_m = Linear(self.tf_dim, c_m)

        self.ligand_linear_bond_z = Linear(ligand_bond_dim, c_z)

        # RPE stuff
        self.relpos_k = relpos_k
        self.no_bins = 2 * relpos_k + 1
        self.linear_relpos = Linear(self.no_bins, c_z)

        # Recycling stuff
        self.prot_min_bin = prot_min_bin
        self.prot_max_bin = prot_max_bin
        self.prot_no_bins = prot_no_bins
        self.lig_min_bin = lig_min_bin
        self.lig_max_bin = lig_max_bin
        self.lig_no_bins = lig_no_bins
        self.inf = inf

        self.prot_recycling_linear = Linear(self.prot_no_bins + 1, self.c_z)
        self.lig_recycling_linear = Linear(self.lig_no_bins, self.c_z)
        self.layer_norm_m = LayerNorm(self.c_m)
        self.layer_norm_z = LayerNorm(self.c_z)

    def relpos(self, ri: torch.Tensor):
        """
        Computes relative positional encodings

        Implements Algorithm 4.

        Args:
            ri:
                "residue_index" features of shape [*, N]
        """
        d = ri[..., None] - ri[..., None, :]
        boundaries = torch.arange(
            start=-self.relpos_k, end=self.relpos_k + 1, device=d.device
        )
        reshaped_bins = boundaries.view(((1,) * len(d.shape)) + (len(boundaries),))
        d = d[..., None] - reshaped_bins
        d = torch.abs(d)
        d = torch.argmin(d, dim=-1)
        d = nn.functional.one_hot(d, num_classes=len(boundaries)).float()
        d = d.to(ri.dtype)
        return self.linear_relpos(d)

    def _get_binned_distogram(self, x, min_bin, max_bin, no_bins, recycling_linear, prot_distogram_mask=None):
        # This squared method might become problematic in FP16 mode.
        bins = torch.linspace(
            min_bin,
            max_bin,
            no_bins,
            dtype=x.dtype,
            device=x.device,
            requires_grad=False,
        )
        squared_bins = bins ** 2
        upper = torch.cat(
            [squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
        )
        d = torch.sum((x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdims=True)

        # [*, N, N, no_bins]
        d = ((d > squared_bins) * (d < upper)).type(x.dtype)
        # print("d shape", d.shape, d[0][0][:10])

        if prot_distogram_mask is not None:
            expanded_d = torch.cat([d, torch.zeros(*d.shape[:-1], 1, device=d.device)], dim=-1)

            # Step 1: Create a mask where `prot_distogram_mask` is 1, using
            # broadcasting and tensor operations directly without additional variables
            input_positions_mask = (prot_distogram_mask == 1).float()  # Shape [N, crop_size]
            mask_i = input_positions_mask.unsqueeze(2)  # Shape [N, crop_size, 1]
            mask_j = input_positions_mask.unsqueeze(1)  # Shape [N, 1, crop_size]

            # Step 2: Combine masks for both [N, :, i, :] and [N, i, :, :]
            combined_mask = mask_i + mask_j  # Shape [N, crop_size, crop_size]
            combined_mask = combined_mask.clamp(max=1)  # Ensure binary mask

            # Step 3: Apply the mask
            # a. Set all but the last position in the `no_bins + 1` dimension to 0 where the mask is 1
            expanded_d[..., :-1] *= (1 - combined_mask).unsqueeze(-1)  # Shape [N, crop_size, crop_size, no_bins]
            # print("expanded_d shape1", expanded_d.shape, expanded_d[0][0][:10])

            # b. Set the last position in the `no_bins + 1` dimension to 1 where the mask is 1
            expanded_d[..., -1] += combined_mask  # Shape [N, crop_size, crop_size, 1]
            d = expanded_d
            # print("expanded_d shape2", d.shape, d[0][0][:10])

        return recycling_linear(d)

    def forward(
        self,
        token_mask: torch.Tensor,
        protein_mask: torch.Tensor,
        ligand_mask: torch.Tensor,
        target_feat: torch.Tensor,
        ligand_bonds_feat: torch.Tensor,
        input_positions: torch.Tensor,
        protein_residue_index: torch.Tensor,
        protein_distogram_mask: torch.Tensor,
        inplace_safe: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            target_feat:
                Features of shape [*, N_res + N_lig_atoms, tf_dim]
            protein_residue_index:
                Features of shape [*, N_res]
            input_positions:
                [*, N_res, 3] AF-predicted C_beta coordinates supplied as input
            ligand_bonds_feat:
                [*, N_lig_atoms, N_lig_atoms, tf_dim] ligand bond features
        Returns:
            single_emb:
                [*, N_res + N_lig_atoms, C_m] single embedding
            pair_emb:
                [*, N_res + N_lig_atoms, N_res + N_lig_atoms, C_z] pair embedding
        """
        device = token_mask.device
        pair_protein_mask = protein_mask[..., None] * protein_mask[..., None, :]
        pair_ligand_mask = ligand_mask[..., None] * ligand_mask[..., None, :]

        # Single representation embedding - Algorithm 3
        tf_m = self.linear_tf_m(target_feat)
        tf_m = self.layer_norm_m(tf_m)  # previously this happened in the do_recycle function

        # Pair representation
        # protein pair embedding - Algorithm 3
        # [*, N_res, c_z]
        tf_emb_i = self.linear_tf_z_i(target_feat)
        tf_emb_j = self.linear_tf_z_j(target_feat)

        pair_emb = torch.zeros(*pair_protein_mask.shape, self.c_z, device=device)
        pair_emb = add(pair_emb, tf_emb_i[..., None, :], inplace=inplace_safe)
        pair_emb = add(pair_emb, tf_emb_j[..., None, :, :], inplace=inplace_safe)

        # Apply relpos
        relpos = self.relpos(protein_residue_index.type(tf_emb_i.dtype))
        pair_emb += pair_protein_mask[..., None] * relpos

        del relpos

        # apply ligand bonds
        ligand_bonds = self.ligand_linear_bond_z(ligand_bonds_feat)
        pair_emb += pair_ligand_mask[..., None] * ligand_bonds

        del ligand_bonds

        # before recycles, do z_norm; this previously was a part of the recycles
        pair_emb = self.layer_norm_z(pair_emb)

        # apply protein recycle
        prot_distogram_embed = self._get_binned_distogram(input_positions, self.prot_min_bin, self.prot_max_bin,
                                                          self.prot_no_bins, self.prot_recycling_linear,
                                                          protein_distogram_mask)

        pair_emb = add(pair_emb, prot_distogram_embed * pair_protein_mask.unsqueeze(-1), inplace_safe)

        del prot_distogram_embed

        # apply ligand recycle
        lig_distogram_embed = self._get_binned_distogram(input_positions, self.lig_min_bin, self.lig_max_bin,
                                                         self.lig_no_bins, self.lig_recycling_linear)
        pair_emb = add(pair_emb, lig_distogram_embed * pair_ligand_mask.unsqueeze(-1), inplace_safe)

        del lig_distogram_embed

        return tf_m, pair_emb


class RecyclingEmbedder(nn.Module):
    """
    Embeds the output of an iteration of the model for recycling.

    Implements Algorithm 32.
    """
    def __init__(
        self,
        c_m: int,
        c_z: int,
        min_bin: float,
        max_bin: float,
        no_bins: int,
        inf: float = 1e8,
        **kwargs,
    ):
        """
        Args:
            c_m:
                Single channel dimension
            c_z:
                Pair embedding channel dimension
            min_bin:
                Smallest distogram bin (Angstroms)
            max_bin:
                Largest distogram bin (Angstroms)
            no_bins:
                Number of distogram bins
        """
        super(RecyclingEmbedder, self).__init__()

        self.c_m = c_m
        self.c_z = c_z
        self.min_bin = min_bin
        self.max_bin = max_bin
        self.no_bins = no_bins
        self.inf = inf

        self.linear = Linear(self.no_bins, self.c_z)
        self.layer_norm_m = LayerNorm(self.c_m)
        self.layer_norm_z = LayerNorm(self.c_z)

    def forward(
        self,
        m: torch.Tensor,
        z: torch.Tensor,
        x: torch.Tensor,
        inplace_safe: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            m:
                First row of the single embedding. [*, N_res, C_m]
            z:
                [*, N_res, N_res, C_z] pair embedding
            x:
                [*, N_res, 3] predicted C_beta coordinates
        Returns:
            m:
                [*, N_res, C_m] single embedding update
            z:
                [*, N_res, N_res, C_z] pair embedding update
        """
        # [*, N, C_m]
        m_update = self.layer_norm_m(m)
        if(inplace_safe):
            m.copy_(m_update)
            m_update = m

        # [*, N, N, C_z]
        z_update = self.layer_norm_z(z)
        if(inplace_safe):
            z.copy_(z_update)
            z_update = z

        # This squared method might become problematic in FP16 mode.
        bins = torch.linspace(
            self.min_bin,
            self.max_bin,
            self.no_bins,
            dtype=x.dtype,
            device=x.device,
            requires_grad=False,
        )
        squared_bins = bins ** 2
        upper = torch.cat(
            [squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
        )
        d = torch.sum(
            (x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdims=True
        )

        # [*, N, N, no_bins]
        d = ((d > squared_bins) * (d < upper)).type(x.dtype)

        # [*, N, N, C_z]
        d = self.linear(d)
        z_update = add(z_update, d, inplace_safe)

        return m_update, z_update
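
As a sanity check on the relpos binning above: the pairwise index differences are effectively clipped to the window [-relpos_k, relpos_k] (for integer residue indices, the argmin over |d - boundary| is equivalent to clamping) and one-hot encoded into 2 * relpos_k + 1 bins before the linear projection. A standalone numeric sketch, independent of the module:

import torch
import torch.nn as nn

relpos_k = 2
ri = torch.tensor([0, 1, 5])                       # residue indices
d = ri[:, None] - ri[None, :]                      # pairwise offsets
boundaries = torch.arange(-relpos_k, relpos_k + 1)
bins = torch.argmin((d[..., None] - boundaries).abs(), dim=-1)
# equivalent to: d.clamp(-relpos_k, relpos_k) + relpos_k
one_hot = nn.functional.one_hot(bins, num_classes=2 * relpos_k + 1).float()
print(bins)
# tensor([[2, 1, 0],
#         [3, 2, 0],
#         [4, 4, 2]])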
dockformer/model/evoformer.py
ADDED
@@ -0,0 +1,468 @@
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import sys
import torch
import torch.nn as nn
from typing import Tuple, Sequence, Optional
from functools import partial
from abc import ABC, abstractmethod

from dockformer.model.primitives import Linear, LayerNorm
from dockformer.model.dropout import DropoutRowwise
from dockformer.model.single_attention import SingleRowAttentionWithPairBias

from dockformer.model.pair_transition import PairTransition
from dockformer.model.triangular_attention import (
    TriangleAttention,
)
from dockformer.model.triangular_multiplicative_update import (
    TriangleMultiplicationOutgoing,
    TriangleMultiplicationIncoming,
)
from dockformer.utils.checkpointing import checkpoint_blocks
from dockformer.utils.tensor_utils import add


class SingleRepTransition(nn.Module):
    """
    Feed-forward network applied to single representation activations after attention.

    Implements Algorithm 9
    """
    def __init__(self, c_m, n):
        """
        Args:
            c_m:
                channel dimension
            n:
                Factor by which c_m is multiplied to obtain the hidden channel dimension
        """
        super(SingleRepTransition, self).__init__()

        self.c_m = c_m
        self.n = n

        self.layer_norm = LayerNorm(self.c_m)
        self.linear_1 = Linear(self.c_m, self.n * self.c_m, init="relu")
        self.relu = nn.ReLU()
        self.linear_2 = Linear(self.n * self.c_m, self.c_m, init="final")

    def _transition(self, m, mask):
        m = self.layer_norm(m)
        m = self.linear_1(m)
        m = self.relu(m)
        m = self.linear_2(m) * mask
        return m

    def forward(
        self,
        m: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_res, C_m] activation after attention
            mask:
                [*, N_res] mask
        Returns:
            m:
                [*, N_res, C_m] activation update
        """
        # DISCREPANCY: DeepMind forgets to apply the mask here.
        if mask is None:
            mask = m.new_ones(m.shape[:-1])

        mask = mask.unsqueeze(-1)

        m = self._transition(m, mask)

        return m


class PairStack(nn.Module):
    def __init__(
        self,
        c_z: int,
        c_hidden_mul: int,
        c_hidden_pair_att: int,
        no_heads_pair: int,
        transition_n: int,
        pair_dropout: float,
        inf: float,
        eps: float
    ):
        super(PairStack, self).__init__()

        self.tri_mul_out = TriangleMultiplicationOutgoing(
            c_z,
            c_hidden_mul,
        )
        self.tri_mul_in = TriangleMultiplicationIncoming(
            c_z,
            c_hidden_mul,
        )

        self.tri_att_start = TriangleAttention(
            c_z,
            c_hidden_pair_att,
            no_heads_pair,
            inf=inf,
        )
        self.tri_att_end = TriangleAttention(
            c_z,
            c_hidden_pair_att,
            no_heads_pair,
            inf=inf,
        )

        self.pair_transition = PairTransition(
            c_z,
            transition_n,
        )

        self.ps_dropout_row_layer = DropoutRowwise(pair_dropout)

    def forward(self,
        z: torch.Tensor,
        pair_mask: torch.Tensor,
        use_lma: bool = False,
        inplace_safe: bool = False,
        _mask_trans: bool = True,
    ) -> torch.Tensor:
        # DeepMind doesn't mask these transitions in the source, so _mask_trans
        # should be disabled to better approximate the exact activations of
        # the original.
        pair_trans_mask = pair_mask if _mask_trans else None

        tmu_update = self.tri_mul_out(
            z,
            mask=pair_mask,
            inplace_safe=inplace_safe,
            _add_with_inplace=True,
        )
        if (not inplace_safe):
            z = z + self.ps_dropout_row_layer(tmu_update)
        else:
            z = tmu_update

        del tmu_update

        tmu_update = self.tri_mul_in(
            z,
            mask=pair_mask,
            inplace_safe=inplace_safe,
            _add_with_inplace=True,
        )
        if (not inplace_safe):
            z = z + self.ps_dropout_row_layer(tmu_update)
        else:
            z = tmu_update

        del tmu_update

        z = add(z,
            self.ps_dropout_row_layer(
                self.tri_att_start(
                    z,
                    mask=pair_mask,
                    use_memory_efficient_kernel=False,
                    use_lma=use_lma,
                )
            ),
            inplace=inplace_safe,
        )

        z = z.transpose(-2, -3)
        if (inplace_safe):
            z = z.contiguous()

        z = add(z,
            self.ps_dropout_row_layer(
                self.tri_att_end(
                    z,
                    mask=pair_mask.transpose(-1, -2),
                    use_memory_efficient_kernel=False,
                    use_lma=use_lma,
                )
            ),
            inplace=inplace_safe,
        )

        z = z.transpose(-2, -3)
        if (inplace_safe):
            z = z.contiguous()

        z = add(z,
            self.pair_transition(
                z, mask=pair_trans_mask,
            ),
            inplace=inplace_safe,
        )

        return z


class EvoformerBlock(nn.Module, ABC):
    def __init__(self,
        c_m: int,
        c_z: int,
        c_hidden_single_att: int,
        c_hidden_mul: int,
        c_hidden_pair_att: int,
        no_heads_single: int,
        no_heads_pair: int,
        transition_n: int,
        single_dropout: float,
        pair_dropout: float,
        inf: float,
        eps: float,
    ):
        super(EvoformerBlock, self).__init__()

        self.single_att_row = SingleRowAttentionWithPairBias(
            c_m=c_m,
            c_z=c_z,
            c_hidden=c_hidden_single_att,
            no_heads=no_heads_single,
            inf=inf,
        )

        self.single_dropout_layer = DropoutRowwise(single_dropout)

        self.single_transition = SingleRepTransition(
            c_m=c_m,
            n=transition_n,
        )

        self.pair_stack = PairStack(
            c_z=c_z,
            c_hidden_mul=c_hidden_mul,
            c_hidden_pair_att=c_hidden_pair_att,
            no_heads_pair=no_heads_pair,
            transition_n=transition_n,
            pair_dropout=pair_dropout,
            inf=inf,
            eps=eps
        )

    def forward(self,
        m: Optional[torch.Tensor],
        z: Optional[torch.Tensor],
        single_mask: torch.Tensor,
        pair_mask: torch.Tensor,
        use_lma: bool = False,
        inplace_safe: bool = False,
        _mask_trans: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        single_trans_mask = single_mask if _mask_trans else None

        input_tensors = [m, z]

        m, z = input_tensors

        z = self.pair_stack(
            z=z,
            pair_mask=pair_mask,
            use_lma=use_lma,
            inplace_safe=inplace_safe,
            _mask_trans=_mask_trans,
        )

        m = add(m,
            self.single_dropout_layer(
                self.single_att_row(
                    m,
                    z=z,
                    mask=single_mask,
                    use_memory_efficient_kernel=False,
                    use_lma=use_lma,
                )
            ),
            inplace=inplace_safe,
        )

        m = add(m, self.single_transition(m, mask=single_mask), inplace=inplace_safe)

        return m, z


class EvoformerStack(nn.Module):
    """
    Main Evoformer trunk.

    Implements Algorithm 6.
    """

    def __init__(
        self,
        c_m: int,
        c_z: int,
        c_hidden_single_att: int,
        c_hidden_mul: int,
        c_hidden_pair_att: int,
        c_s: int,
        no_heads_single: int,
        no_heads_pair: int,
        no_blocks: int,
        transition_n: int,
        single_dropout: float,
        pair_dropout: float,
        blocks_per_ckpt: int,
        inf: float,
        eps: float,
        clear_cache_between_blocks: bool = False,
        **kwargs,
    ):
        """
        Args:
            c_m:
                Single channel dimension
            c_z:
                Pair channel dimension
            c_hidden_single_att:
                Hidden dimension in single representation attention
            c_hidden_mul:
                Hidden dimension in multiplicative updates
            c_hidden_pair_att:
                Hidden dimension in triangular attention
            c_s:
                Channel dimension of the output "single" embedding
            no_heads_single:
                Number of heads used for single attention
            no_heads_pair:
                Number of heads used for pair attention
            no_blocks:
                Number of Evoformer blocks in the stack
            transition_n:
                Factor by which to multiply c_m to obtain the SingleTransition
                hidden dimension
            single_dropout:
                Dropout rate for single activations
            pair_dropout:
                Dropout used for pair activations
            blocks_per_ckpt:
                Number of Evoformer blocks in each activation checkpoint
            clear_cache_between_blocks:
                Whether to clear CUDA's GPU memory cache between blocks of the
                stack. Slows down each block but can reduce fragmentation
        """
        super(EvoformerStack, self).__init__()

        self.blocks_per_ckpt = blocks_per_ckpt
        self.clear_cache_between_blocks = clear_cache_between_blocks

        self.blocks = nn.ModuleList()

        for _ in range(no_blocks):
            block = EvoformerBlock(
                c_m=c_m,
                c_z=c_z,
                c_hidden_single_att=c_hidden_single_att,
                c_hidden_mul=c_hidden_mul,
                c_hidden_pair_att=c_hidden_pair_att,
                no_heads_single=no_heads_single,
                no_heads_pair=no_heads_pair,
                transition_n=transition_n,
                single_dropout=single_dropout,
                pair_dropout=pair_dropout,
                inf=inf,
                eps=eps,
            )
            self.blocks.append(block)

        self.linear = Linear(c_m, c_s)

    def _prep_blocks(self,
        use_lma: bool,
        single_mask: Optional[torch.Tensor],
        pair_mask: Optional[torch.Tensor],
        inplace_safe: bool,
        _mask_trans: bool,
    ):
        blocks = [
            partial(
                b,
                single_mask=single_mask,
                pair_mask=pair_mask,
                use_lma=use_lma,
                inplace_safe=inplace_safe,
                _mask_trans=_mask_trans,
            )
            for b in self.blocks
        ]

        if self.clear_cache_between_blocks:
            def block_with_cache_clear(block, *args, **kwargs):
                torch.cuda.empty_cache()
                return block(*args, **kwargs)

            blocks = [partial(block_with_cache_clear, b) for b in blocks]

        return blocks

    def forward(self,
        m: torch.Tensor,
        z: torch.Tensor,
        single_mask: torch.Tensor,
        pair_mask: torch.Tensor,
        use_lma: bool = False,
        inplace_safe: bool = False,
        _mask_trans: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            m:
                [*, N_res, C_m] single embedding
            z:
                [*, N_res, N_res, C_z] pair embedding
            single_mask:
                [*, N_res] single mask
            pair_mask:
                [*, N_res, N_res] pair mask
            use_lma:
                Whether to use low-memory attention during inference.

        Returns:
            m:
                [*, N_res, C_m] single embedding
            z:
                [*, N_res, N_res, C_z] pair embedding
            s:
                [*, N_res, C_s] single embedding after linear layer
        """
        blocks = self._prep_blocks(
            use_lma=use_lma,
            single_mask=single_mask,
            pair_mask=pair_mask,
            inplace_safe=inplace_safe,
            _mask_trans=_mask_trans,
        )

        blocks_per_ckpt = self.blocks_per_ckpt
        if(not torch.is_grad_enabled()):
            blocks_per_ckpt = None

        m, z = checkpoint_blocks(
            blocks,
            args=(m, z),
            blocks_per_ckpt=blocks_per_ckpt,
        )

        s = self.linear(m)

        return m, z, s
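
One detail worth calling out: _prep_blocks freezes the per-call arguments into each block with functools.partial, so checkpoint_blocks only has to thread (m, z) through the stack. The same pattern in miniature, with toy stand-in functions (nothing here is from the repository):

import torch
from functools import partial

def block(m, z, scale):
    # stand-in for an EvoformerBlock: updates both representations
    return m + scale, z * scale

blocks = [partial(block, scale=s) for s in (0.5, 2.0)]

m, z = torch.zeros(3), torch.ones(3)
for b in blocks:
    m, z = b(m, z)  # checkpoint_blocks runs this loop, optionally under activation checkpointing
print(m, z)  # tensor([2.5000, 2.5000, 2.5000]) tensor([1., 1., 1.])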
dockformer/model/heads.py
ADDED
@@ -0,0 +1,260 @@
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
from torch.nn import Parameter

from dockformer.model.primitives import Linear, LayerNorm
from dockformer.utils.loss import (
    compute_plddt,
    compute_tm,
    compute_predicted_aligned_error,
)
from dockformer.utils.precision_utils import is_fp16_enabled


class AuxiliaryHeads(nn.Module):
    def __init__(self, config):
        super(AuxiliaryHeads, self).__init__()

        self.plddt = PerResidueLDDTCaPredictor(
            **config["lddt"],
        )

        self.distogram = DistogramHead(
            **config["distogram"],
        )

        self.affinity_2d = Affinity2DPredictor(
            **config["affinity_2d"],
        )

        self.affinity_1d = Affinity1DPredictor(
            **config["affinity_1d"],
        )

        self.affinity_cls = AffinityClsTokenPredictor(
            **config["affinity_cls"],
        )

        self.binding_site = BindingSitePredictor(
            **config["binding_site"],
        )

        self.inter_contact = InterContactHead(
            **config["inter_contact"],
        )

        self.config = config

    def forward(self, outputs, inter_mask, affinity_mask):
        aux_out = {}
        lddt_logits = self.plddt(outputs["sm"]["single"])
        aux_out["lddt_logits"] = lddt_logits

        # Required for relaxation later on
        aux_out["plddt"] = compute_plddt(lddt_logits)

        distogram_logits = self.distogram(outputs["pair"])
        aux_out["distogram_logits"] = distogram_logits

        aux_out["inter_contact_logits"] = self.inter_contact(outputs["single"], outputs["pair"])

        aux_out["affinity_2d_logits"] = self.affinity_2d(outputs["pair"], aux_out["inter_contact_logits"], inter_mask)

        aux_out["affinity_1d_logits"] = self.affinity_1d(outputs["single"])

        aux_out["affinity_cls_logits"] = self.affinity_cls(outputs["single"], affinity_mask)

        aux_out["binding_site_logits"] = self.binding_site(outputs["single"])

        return aux_out


class Affinity2DPredictor(nn.Module):
    def __init__(self, c_z, num_bins):
        super(Affinity2DPredictor, self).__init__()

        self.c_z = c_z

        self.weight_linear = Linear(self.c_z + 1, 1)
        self.embed_linear = Linear(self.c_z, self.c_z)
        self.bins_linear = Linear(self.c_z, num_bins)

    def forward(self, z, inter_contacts_logits, inter_pair_mask):
        z_with_inter_contacts = torch.cat((z, inter_contacts_logits), dim=-1)  # [*, N, N, c_z + 1]
        weights = self.weight_linear(z_with_inter_contacts)  # [*, N, N, 1]

        x = self.embed_linear(z)  # [*, N, N, c_z]
        batch_size, N, M, _ = x.shape

        flat_weights = weights.reshape(batch_size, N*M, -1)  # [*, N*M, 1]
        flat_x = x.reshape(batch_size, N*M, -1)  # [*, N*M, c_z]
        flat_inter_pair_mask = inter_pair_mask.reshape(batch_size, N*M, 1)

        flat_weights = flat_weights.masked_fill(~(flat_inter_pair_mask.bool()), float('-inf'))  # [*, N*N, 1]
        flat_weights = torch.nn.functional.softmax(flat_weights, dim=1)  # [*, N*N, 1]
        flat_weights = torch.nan_to_num(flat_weights, nan=0.0)  # [*, N*N, 1]
        weighted_sum = torch.sum((flat_weights * flat_x).reshape(batch_size, N*M, -1), dim=1)  # [*, c_z]

        return self.bins_linear(weighted_sum)


class Affinity1DPredictor(nn.Module):
    def __init__(self, c_s, num_bins, **kwargs):
        super(Affinity1DPredictor, self).__init__()

        self.c_s = c_s

        self.linear1 = Linear(self.c_s, self.c_s, init="final")

        self.linear2 = Linear(self.c_s, num_bins, init="final")

    def forward(self, s):
        # [*, N, C_out]
        s = self.linear1(s)

        # get an average over the sequence
        s = torch.mean(s, dim=1)

        logits = self.linear2(s)
        return logits


class AffinityClsTokenPredictor(nn.Module):
    def __init__(self, c_s, num_bins, **kwargs):
        super(AffinityClsTokenPredictor, self).__init__()

        self.c_s = c_s
        self.linear = Linear(self.c_s, num_bins, init="final")

    def forward(self, s, affinity_mask):
        affinity_tokens = (s * affinity_mask.unsqueeze(-1)).sum(dim=1)
        return self.linear(affinity_tokens)


class BindingSitePredictor(nn.Module):
    def __init__(self, c_s, c_out, **kwargs):
        super(BindingSitePredictor, self).__init__()

        self.c_s = c_s
        self.c_out = c_out

        self.linear = Linear(self.c_s, self.c_out, init="final")

    def forward(self, s):
        # [*, N, C_out]
        return self.linear(s)


class InterContactHead(nn.Module):
    def __init__(self, c_s, c_z, c_out, **kwargs):
        """
        Args:
            c_z:
                Input channel dimension
            c_out:
                Number of output bins; 1 here, since contact is a boolean
        """
        super(InterContactHead, self).__init__()

        self.c_s = c_s
        self.c_z = c_z
        self.c_out = c_out

        self.linear = Linear(2 * self.c_s + self.c_z, self.c_out, init="final")

    def forward(self, s, z):  # [*, N, N, C_z]
        # [*, N, N, no_bins]
        batch_size, n, s_dim = s.shape

        s_i = s.unsqueeze(2).expand(batch_size, n, n, s_dim)
        s_j = s.unsqueeze(1).expand(batch_size, n, n, s_dim)
        joined = torch.cat((s_i, s_j, z), dim=-1)

        logits = self.linear(joined)

        return logits


class PerResidueLDDTCaPredictor(nn.Module):
    def __init__(self, no_bins, c_in, c_hidden):
        super(PerResidueLDDTCaPredictor, self).__init__()

        self.no_bins = no_bins
        self.c_in = c_in
        self.c_hidden = c_hidden

        self.layer_norm = LayerNorm(self.c_in)

        self.linear_1 = Linear(self.c_in, self.c_hidden, init="relu")
        self.linear_2 = Linear(self.c_hidden, self.c_hidden, init="relu")
        self.linear_3 = Linear(self.c_hidden, self.no_bins, init="final")

        self.relu = nn.ReLU()

    def forward(self, s):
        s = self.layer_norm(s)
        s = self.linear_1(s)
        s = self.relu(s)
        s = self.linear_2(s)
        s = self.relu(s)
        s = self.linear_3(s)

        return s


class DistogramHead(nn.Module):
    """
    Computes a distogram probability distribution.

    For use in computation of distogram loss, subsection 1.9.8
    """

    def __init__(self, c_z, no_bins, **kwargs):
        """
        Args:
            c_z:
                Input channel dimension
            no_bins:
                Number of distogram bins
        """
        super(DistogramHead, self).__init__()

        self.c_z = c_z
        self.no_bins = no_bins

        self.linear = Linear(self.c_z, self.no_bins, init="final")

    def _forward(self, z):  # [*, N, N, C_z]
        """
        Args:
            z:
                [*, N_res, N_res, C_z] pair embedding
        Returns:
            [*, N, N, no_bins] distogram probability distribution
        """
        # [*, N, N, no_bins]
        logits = self.linear(z)
        logits = logits + logits.transpose(-2, -3)
        return logits

    def forward(self, z):
        if(is_fp16_enabled()):
            with torch.cuda.amp.autocast(enabled=False):
                return self._forward(z.float())
        else:
            return self._forward(z)
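
The core trick in Affinity2DPredictor is a masked softmax over all flattened pairs: positions outside inter_pair_mask are filled with -inf before the softmax, and nan_to_num covers the degenerate all-masked case, where softmax over a row of -inf would produce NaN. The same steps in isolation, with toy shapes (not from the commit):

import torch

weights = torch.randn(1, 4, 1)                       # flattened pair scores [batch, N*M, 1]
mask = torch.tensor([[1, 1, 0, 0]]).unsqueeze(-1)    # inter-pair mask, same layout

w = weights.masked_fill(~mask.bool(), float("-inf"))
w = torch.softmax(w, dim=1)
w = torch.nan_to_num(w, nan=0.0)   # an entirely masked example would otherwise be NaN
print(w.squeeze(-1))               # only the two unmasked entries are nonzero; they sum to 1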
dockformer/model/model.py
ADDED
@@ -0,0 +1,318 @@
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import weakref

import torch
import torch.nn as nn

from dockformer.utils.tensor_utils import masked_mean
from dockformer.model.embedders import (
    StructureInputEmbedder,
    RecyclingEmbedder,
)
from dockformer.model.evoformer import EvoformerStack
from dockformer.model.heads import AuxiliaryHeads
from dockformer.model.structure_module import StructureModule
import dockformer.utils.residue_constants as residue_constants
from dockformer.utils.feats import (
    pseudo_beta_fn,
    atom14_to_atom37,
)
from dockformer.utils.tensor_utils import (
    add,
    tensor_tree_map,
)


class AlphaFold(nn.Module):
    """
    Alphafold 2.

    Implements Algorithm 2 (but with training).
    """

    def __init__(self, config):
        """
        Args:
            config:
                A dict-like config object (like the one in config.py)
        """
        super(AlphaFold, self).__init__()

        self.globals = config.globals
        self.config = config.model

        # Main trunk + structure module
        self.input_embedder = StructureInputEmbedder(
            **self.config["structure_input_embedder"],
        )

        self.recycling_embedder = RecyclingEmbedder(
            **self.config["recycling_embedder"],
        )

        self.evoformer = EvoformerStack(
            **self.config["evoformer_stack"],
        )

        self.structure_module = StructureModule(
            **self.config["structure_module"],
        )
        self.aux_heads = AuxiliaryHeads(
            self.config["heads"],
        )

    def tolerance_reached(self, prev_pos, next_pos, mask, eps=1e-8) -> bool:
        """
        Early stopping criteria based on criteria used in
        AF2Complex: https://www.nature.com/articles/s41467-022-29394-2
        Args:
            prev_pos: Previous atom positions in atom37/14 representation
            next_pos: Current atom positions in atom37/14 representation
            mask: 1-D sequence mask
            eps: Epsilon used in square root calculation
        Returns:
            Whether to stop recycling early based on the desired tolerance.
        """

        def distances(points):
            """Compute all pairwise distances for a set of points."""
            d = points[..., None, :] - points[..., None, :, :]
            return torch.sqrt(torch.sum(d ** 2, dim=-1))

        if self.config.recycle_early_stop_tolerance < 0:
            return False

        ca_idx = residue_constants.atom_order['CA']
        sq_diff = (distances(prev_pos[..., ca_idx, :]) - distances(next_pos[..., ca_idx, :])) ** 2
        mask = mask[..., None] * mask[..., None, :]
        sq_diff = masked_mean(mask=mask, value=sq_diff, dim=list(range(len(mask.shape))))
        diff = torch.sqrt(sq_diff + eps).item()
        return diff <= self.config.recycle_early_stop_tolerance

    def iteration(self, feats, prevs, _recycle=True):
        # Primary output dictionary
        outputs = {}

        # This needs to be done manually for DeepSpeed's sake
        dtype = next(self.parameters()).dtype
        for k in feats:
            if feats[k].dtype == torch.float32:
                feats[k] = feats[k].to(dtype=dtype)

        # Grab some data about the input
        batch_dims, n_total = feats["token_mask"].shape
        device = feats["token_mask"].device

        print("doing sample of size", feats["token_mask"].shape,
              feats["protein_mask"].sum(dim=1), feats["ligand_mask"].sum(dim=1))

        # Controls whether the model uses in-place operations throughout
        # The dual condition accounts for activation checkpoints
        # inplace_safe = not (self.training or torch.is_grad_enabled())
        inplace_safe = False  # so we don't need attn_core_inplace_cuda

        # Prep some features
        token_mask = feats["token_mask"]
        pair_mask = token_mask[..., None] * token_mask[..., None, :]

        # Initialize the single and pair representations
        # m: [*, 1, n_total, C_m]
        # z: [*, n_total, n_total, C_z]
        m, z = self.input_embedder(
            feats["token_mask"],
            feats["protein_mask"],
            feats["ligand_mask"],
            feats["target_feat"],
            feats["ligand_bonds_feat"],
            feats["input_positions"],
            feats["protein_residue_index"],
            feats["protein_distogram_mask"],
            inplace_safe=inplace_safe,
        )

        # Unpack the recycling embeddings. Removing them from the list allows
        # them to be freed further down in this function, saving memory
        m_1_prev, z_prev, x_prev = reversed([prevs.pop() for _ in range(3)])

        # Initialize the recycling embeddings, if need be
        if None in [m_1_prev, z_prev, x_prev]:
            # [*, N, C_m]
            m_1_prev = m.new_zeros(
                (batch_dims, n_total, self.config.structure_input_embedder.c_m),
                requires_grad=False,
            )

            # [*, N, N, C_z]
            z_prev = z.new_zeros(
                (batch_dims, n_total, n_total, self.config.structure_input_embedder.c_z),
                requires_grad=False,
            )

            # [*, N, atom_type_num, 3]
            x_prev = z.new_zeros(
                (batch_dims, n_total, residue_constants.atom_type_num, 3),
                requires_grad=False,
            )

        # shape == [1, n_total, 37, 3]
        pseudo_beta_or_lig_x_prev = pseudo_beta_fn(feats["aatype"], x_prev, None).to(dtype=z.dtype)

        # m_1_prev_emb: [*, N, C_m]
        # z_prev_emb: [*, N, N, C_z]
        m_1_prev_emb, z_prev_emb = self.recycling_embedder(
            m_1_prev,
            z_prev,
            pseudo_beta_or_lig_x_prev,
            inplace_safe=inplace_safe,
        )

        del pseudo_beta_or_lig_x_prev

        # [*, S_c, N, C_m]
        m += m_1_prev_emb

        # [*, N, N, C_z]
        z = add(z, z_prev_emb, inplace=inplace_safe)

        # Deletions like these become significant for inference with large N,
        # where they free unused tensors and remove references to others such
        # that they can be offloaded later
        del m_1_prev, z_prev, m_1_prev_emb, z_prev_emb

        # Run single + pair embeddings through the trunk of the network
        # m: [*, N, C_m]
        # z: [*, N, N, C_z]
        # s: [*, N, C_s]
        m, z, s = self.evoformer(
            m,
            z,
            single_mask=token_mask.to(dtype=m.dtype),
            pair_mask=pair_mask.to(dtype=z.dtype),
            use_lma=self.globals.use_lma,
            inplace_safe=inplace_safe,
            _mask_trans=self.config._mask_trans,
        )

        outputs["pair"] = z
        outputs["single"] = s

        del z

        # Predict 3D structure
        outputs["sm"] = self.structure_module(
            outputs,
            feats["aatype"],
            mask=token_mask.to(dtype=s.dtype),
            inplace_safe=inplace_safe,
        )
        outputs["final_atom_positions"] = atom14_to_atom37(
            outputs["sm"]["positions"][-1], feats
        )
        outputs["final_atom_mask"] = feats["atom37_atom_exists"]

        # Save embeddings for use during the next recycling iteration

        # [*, N, C_m]
        m_1_prev = m[..., 0, :, :]

        # [*, N, N, C_z]
        z_prev = outputs["pair"]

        # TODO bshor: early stop depends on is_multimer, but I don't think it must
        early_stop = False
        # if self.globals.is_multimer:
        #     early_stop = self.tolerance_reached(x_prev, outputs["final_atom_positions"], seq_mask)

        del x_prev

        # [*, N, atom_type_num, 3]
        x_prev = outputs["final_atom_positions"]

        return outputs, m_1_prev, z_prev, x_prev, early_stop

    def forward(self, batch):
        """
        Args:
            batch:
                Dictionary of arguments outlined in Algorithm 2. Keys must
                include the official names of the features in the
                supplement subsection 1.2.9.

                The final dimension of each input must have length equal to
                the number of recycling iterations.

                Features (without the recycling dimension):

                    "aatype" ([*, N_res]):
                        Contrary to the supplement, this tensor of residue
                        indices is not one-hot.
                    "protein_target_feat" ([*, N_res, C_tf])
                        One-hot encoding of the target sequence. C_tf is
                        config.model.input_embedder.tf_dim.
                    "residue_index" ([*, N_res])
                        Tensor whose final dimension consists of
                        consecutive indices from 0 to N_res.
                    "token_mask" ([*, N_token])
                        1-D token mask
                    "pair_mask" ([*, N_token, N_token])
                        2-D pair mask
        """
        # Initialize recycling embeddings
        m_1_prev, z_prev, x_prev = None, None, None
        prevs = [m_1_prev, z_prev, x_prev]

        is_grad_enabled = torch.is_grad_enabled()

        # Main recycling loop
        num_iters = batch["aatype"].shape[-1]
        early_stop = False
        num_recycles = 0
        for cycle_no in range(num_iters):
            # Select the features for the current recycling cycle
            fetch_cur_batch = lambda t: t[..., cycle_no]
            feats = tensor_tree_map(fetch_cur_batch, batch)

            # Enable grad iff we're training and it's the final recycling layer
            is_final_iter = cycle_no == (num_iters - 1) or early_stop
            with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
                if is_final_iter:
                    # Sidestep AMP bug (PyTorch issue #65766)
                    if torch.is_autocast_enabled():
                        torch.clear_autocast_cache()

                # Run the next iteration of the model
                outputs, m_1_prev, z_prev, x_prev, early_stop = self.iteration(
                    feats,
                    prevs,
                    _recycle=(num_iters > 1)
                )

                num_recycles += 1

                if not is_final_iter:
                    del outputs
                    prevs = [m_1_prev, z_prev, x_prev]
                    del m_1_prev, z_prev, x_prev
                else:
                    break

        outputs["num_recycles"] = torch.tensor(num_recycles, device=feats["aatype"].device)

        # Run auxiliary heads, remove the recycling dimension batch properties
        outputs.update(self.aux_heads(outputs, batch["inter_pair_mask"][..., 0], batch["affinity_mask"][..., 0]))

        return outputs
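
The recycling loop depends on every feature carrying a trailing recycling dimension; tensor_tree_map applies the same slice to every tensor in the (possibly nested) batch dict. A standalone sketch of that slicing step, with tensor_tree_map replaced by a plain dict comprehension and made-up feature shapes:

import torch

num_iters = 3
batch = {
    "aatype": torch.zeros(1, 10, num_iters),      # [*, N_res, n_recycle]
    "token_mask": torch.ones(1, 10, num_iters),
}

for cycle_no in range(num_iters):
    feats = {k: t[..., cycle_no] for k, t in batch.items()}
    # each feats[k] now has the per-cycle shape, e.g. [1, 10]
    assert feats["aatype"].shape == (1, 10)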
dockformer/model/pair_transition.py
ADDED
@@ -0,0 +1,81 @@
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

import torch
import torch.nn as nn

from dockformer.model.primitives import Linear, LayerNorm


class PairTransition(nn.Module):
    """
    Implements Algorithm 15.
    """

    def __init__(self, c_z, n):
        """
        Args:
            c_z:
                Pair transition channel dimension
            n:
                Factor by which c_z is multiplied to obtain hidden channel
                dimension
        """
        super(PairTransition, self).__init__()

        self.c_z = c_z
        self.n = n

        self.layer_norm = LayerNorm(self.c_z)
        self.linear_1 = Linear(self.c_z, self.n * self.c_z, init="relu")
        self.relu = nn.ReLU()
        self.linear_2 = Linear(self.n * self.c_z, c_z, init="final")

    def _transition(self, z, mask):
        # [*, N_res, N_res, C_z]
        z = self.layer_norm(z)

        # [*, N_res, N_res, C_hidden]
        z = self.linear_1(z)
        z = self.relu(z)

        # [*, N_res, N_res, C_z]
        z = self.linear_2(z)
        z = z * mask

        return z

    def forward(self,
        z: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            z:
                [*, N_res, N_res, C_z] pair embedding
        Returns:
            [*, N_res, N_res, C_z] pair embedding update
        """
        # DISCREPANCY: DeepMind forgets to apply the mask in this module.
        if mask is None:
            mask = z.new_ones(z.shape[:-1])

        # [*, N_res, N_res, 1]
        mask = mask.unsqueeze(-1)

        z = self._transition(z=z, mask=mask)

        return z
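
PairTransition is the standard position-wise MLP with an n-fold expansion. Functionally it reduces to the three-line form below; this sketch uses plain nn.Linear and nn.LayerNorm instead of the repository's initialized Linear/LayerNorm wrappers, with illustrative dimensions:

import torch
import torch.nn as nn

c_z, n = 8, 4
layer_norm = nn.LayerNorm(c_z)
linear_1, linear_2 = nn.Linear(c_z, n * c_z), nn.Linear(n * c_z, c_z)

z = torch.randn(1, 5, 5, c_z)
mask = torch.ones(1, 5, 5, 1)
update = linear_2(torch.relu(linear_1(layer_norm(z)))) * mask
print(update.shape)  # torch.Size([1, 5, 5, 8])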
dockformer/model/primitives.py
ADDED
@@ -0,0 +1,598 @@
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import math
from typing import Optional, Callable, List, Tuple
import numpy as np

import torch
import torch.nn as nn
import torch.utils.checkpoint
from scipy.stats import truncnorm

from dockformer.utils.kernel.attention_core import attention_core
from dockformer.utils.precision_utils import is_fp16_enabled
from dockformer.utils.tensor_utils import (
    permute_final_dims,
    flatten_final_dims,
)


# Suited for 40gb GPU
# DEFAULT_LMA_Q_CHUNK_SIZE = 1024
# DEFAULT_LMA_KV_CHUNK_SIZE = 4096
# Suited for 10gb GPU
DEFAULT_LMA_Q_CHUNK_SIZE = 64
DEFAULT_LMA_KV_CHUNK_SIZE = 256


def _prod(nums):
    out = 1
    for n in nums:
        out = out * n
    return out


def _calculate_fan(linear_weight_shape, fan="fan_in"):
    fan_out, fan_in = linear_weight_shape

    if fan == "fan_in":
        f = fan_in
    elif fan == "fan_out":
        f = fan_out
    elif fan == "fan_avg":
        f = (fan_in + fan_out) / 2
    else:
        raise ValueError("Invalid fan option")

    return f


def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
    shape = weights.shape
    f = _calculate_fan(shape, fan)
    scale = scale / max(1, f)
    a = -2
    b = 2
    std = math.sqrt(scale) / truncnorm.std(a=a, b=b, loc=0, scale=1)
    size = _prod(shape)
    samples = truncnorm.rvs(a=a, b=b, loc=0, scale=std, size=size)
    samples = np.reshape(samples, shape)
    with torch.no_grad():
        weights.copy_(torch.tensor(samples, device=weights.device))


def lecun_normal_init_(weights):
    trunc_normal_init_(weights, scale=1.0)


def he_normal_init_(weights):
    trunc_normal_init_(weights, scale=2.0)


def glorot_uniform_init_(weights):
    nn.init.xavier_uniform_(weights, gain=1)


def final_init_(weights):
    with torch.no_grad():
        weights.fill_(0.0)


def gating_init_(weights):
    with torch.no_grad():
        weights.fill_(0.0)


def normal_init_(weights):
    torch.nn.init.kaiming_normal_(weights, nonlinearity="linear")


def ipa_point_weights_init_(weights):
    with torch.no_grad():
        softplus_inverse_1 = 0.541324854612918
        weights.fill_(softplus_inverse_1)


class Linear(nn.Linear):
    """
    A Linear layer with built-in nonstandard initializations. Called just
    like torch.nn.Linear.

    Implements the initializers in 1.11.4, plus some additional ones found
    in the code.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        bias: bool = True,
        init: str = "default",
        init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
        precision=None
    ):
        """
        Args:
            in_dim:
                The final dimension of inputs to the layer
            out_dim:
                The final dimension of layer outputs
            bias:
                Whether to learn an additive bias. True by default
            init:
                The initializer to use. Choose from:

                "default": LeCun fan-in truncated normal initialization
                "relu": He initialization w/ truncated normal distribution
                "glorot": Fan-average Glorot uniform initialization
                "gating": Weights=0, Bias=1
                "normal": Normal initialization with std=1/sqrt(fan_in)
                "final": Weights=0, Bias=0

                Overridden by init_fn if the latter is not None.
            init_fn:
                A custom initializer taking weight and bias as inputs.
                Overrides init if not None.
        """
        super(Linear, self).__init__(in_dim, out_dim, bias=bias)

        if bias:
            with torch.no_grad():
                self.bias.fill_(0)

        with torch.no_grad():
            if init_fn is not None:
                init_fn(self.weight, self.bias)
            else:
                if init == "default":
                    lecun_normal_init_(self.weight)
                elif init == "relu":
                    he_normal_init_(self.weight)
                elif init == "glorot":
                    glorot_uniform_init_(self.weight)
                elif init == "gating":
                    gating_init_(self.weight)
                    if bias:
                        self.bias.fill_(1.0)
                elif init == "normal":
                    normal_init_(self.weight)
                elif init == "final":
                    final_init_(self.weight)
                else:
                    raise ValueError("Invalid init string.")

        self.precision = precision

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        d = input.dtype
        if self.precision is not None:
            with torch.cuda.amp.autocast(enabled=False):
                bias = self.bias.to(dtype=self.precision) if self.bias is not None else None
                return nn.functional.linear(input.to(dtype=self.precision),
                                            self.weight.to(dtype=self.precision),
                                            bias).to(dtype=d)

        if d is torch.bfloat16:
            with torch.cuda.amp.autocast(enabled=False):
                bias = self.bias.to(dtype=d) if self.bias is not None else None
                return nn.functional.linear(input, self.weight.to(dtype=d), bias)

        return nn.functional.linear(input, self.weight, self.bias)
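A small sketch of the initializer options documented above (assuming the repository is importable). "gating" starts the layer as a mostly-open gate, since zero weights and unit bias give sigmoid(1) ≈ 0.73 for every input, while "final" zeroes the layer so residual branches start as identities:

import torch
from dockformer.model.primitives import Linear

gate = Linear(8, 8, init="gating")
final = Linear(8, 8, init="final")

print(bool((gate.weight == 0).all()), bool((gate.bias == 1).all()))    # True True
print(bool((final.weight == 0).all()), bool((final.bias == 0).all()))  # True True

x = torch.randn(2, 8)
print(torch.sigmoid(gate(x))[0, 0].item())  # ~0.731, independent of x
print(final(x).abs().max().item())          # 0.0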
class LayerNorm(nn.Module):
    def __init__(self, c_in, eps=1e-5):
        super(LayerNorm, self).__init__()

        self.c_in = (c_in,)
        self.eps = eps

        self.weight = nn.Parameter(torch.ones(c_in))
        self.bias = nn.Parameter(torch.zeros(c_in))

    def forward(self, x):
        d = x.dtype
        if d is torch.bfloat16:
            with torch.cuda.amp.autocast(enabled=False):
                out = nn.functional.layer_norm(
                    x,
                    self.c_in,
                    self.weight.to(dtype=d),
                    self.bias.to(dtype=d),
                    self.eps
                )
        else:
            out = nn.functional.layer_norm(
                x,
                self.c_in,
                self.weight,
                self.bias,
                self.eps,
            )

        return out


@torch.jit.ignore
def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    Softmax, but without automatic casting to fp32 when the input is of
    type bfloat16
    """
    d = t.dtype
    if d is torch.bfloat16:
        with torch.cuda.amp.autocast(enabled=False):
            s = torch.nn.functional.softmax(t, dim=dim)
    else:
        s = torch.nn.functional.softmax(t, dim=dim)

    return s


#@torch.jit.script
def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, biases: List[torch.Tensor]) -> torch.Tensor:
    # [*, H, C_hidden, K]
    key = permute_final_dims(key, (1, 0))

    # [*, H, Q, K]
    a = torch.matmul(query, key)

    for b in biases:
        a += b

    a = softmax_no_cast(a, -1)

    # [*, H, Q, C_hidden]
    a = torch.matmul(a, value)

    return a


class Attention(nn.Module):
    """
    Standard multi-head attention using AlphaFold's default layer
    initialization. Allows multiple bias vectors.
    """
    def __init__(
        self,
        c_q: int,
        c_k: int,
        c_v: int,
        c_hidden: int,
        no_heads: int,
        gating: bool = True,
    ):
        """
        Args:
            c_q:
                Input dimension of query data
            c_k:
                Input dimension of key data
            c_v:
                Input dimension of value data
            c_hidden:
                Per-head hidden dimension
            no_heads:
                Number of attention heads
            gating:
                Whether the output should be gated using query data
        """
        super(Attention, self).__init__()

        self.c_q = c_q
        self.c_k = c_k
        self.c_v = c_v
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.gating = gating

        # DISCREPANCY: c_hidden is not the per-head channel dimension, as
        # stated in the supplement, but the overall channel dimension.

        self.linear_q = Linear(
            self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot"
        )
        self.linear_k = Linear(
            self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot"
        )
        self.linear_v = Linear(
            self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot"
        )
        self.linear_o = Linear(
            self.c_hidden * self.no_heads, self.c_q, init="final"
        )

        self.linear_g = None
        if self.gating:
            self.linear_g = Linear(
                self.c_q, self.c_hidden * self.no_heads, init="gating"
            )

        self.sigmoid = nn.Sigmoid()

    def _prep_qkv(self,
        q_x: torch.Tensor,
        kv_x: torch.Tensor,
        apply_scale: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # [*, Q/K/V, H * C_hidden]
        q = self.linear_q(q_x)
        k = self.linear_k(kv_x)
        v = self.linear_v(kv_x)

        # [*, Q/K, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.no_heads, -1))
        k = k.view(k.shape[:-1] + (self.no_heads, -1))
        v = v.view(v.shape[:-1] + (self.no_heads, -1))

        # [*, H, Q/K, C_hidden]
        q = q.transpose(-2, -3)
        k = k.transpose(-2, -3)
        v = v.transpose(-2, -3)

        if apply_scale:
            q /= math.sqrt(self.c_hidden)

        return q, k, v

    def _wrap_up(self,
        o: torch.Tensor,
        q_x: torch.Tensor
    ) -> torch.Tensor:
        if self.linear_g is not None:
            g = self.sigmoid(self.linear_g(q_x))

            # [*, Q, H, C_hidden]
            g = g.view(g.shape[:-1] + (self.no_heads, -1))
            o = o * g

        # [*, Q, H * C_hidden]
        o = flatten_final_dims(o, 2)

        # [*, Q, C_q]
        o = self.linear_o(o)

        return o

    def forward(
        self,
        q_x: torch.Tensor,
        kv_x: torch.Tensor,
        biases: Optional[List[torch.Tensor]] = None,
        use_memory_efficient_kernel: bool = False,
        use_lma: bool = False,
        lma_q_chunk_size: int = DEFAULT_LMA_Q_CHUNK_SIZE,
        lma_kv_chunk_size: int = DEFAULT_LMA_KV_CHUNK_SIZE,
    ) -> torch.Tensor:
        """
        Args:
            q_x:
                [*, Q, C_q] query data
            kv_x:
                [*, K, C_k] key data
            biases:
                List of biases that broadcast to [*, H, Q, K]
            use_memory_efficient_kernel:
                Whether to use a custom memory-efficient attention kernel.
                This should be the default choice for most. If none of the
                "use_<...>" flags are True, a stock PyTorch implementation
                is used instead
            use_lma:
                Whether to use low-memory attention (Staats & Rabe 2021). If
                none of the "use_<...>" flags are True, a stock PyTorch
                implementation is used instead
            lma_q_chunk_size:
                Query chunk size (for LMA)
            lma_kv_chunk_size:
                Key/Value chunk size (for LMA)
        Returns:
            [*, Q, C_q] attention update
        """
        if use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None):
            raise ValueError(
                "If use_lma is specified, lma_q_chunk_size and "
                "lma_kv_chunk_size must be provided"
            )

        attn_options = [use_memory_efficient_kernel, use_lma]
        if sum(attn_options) > 1:
            raise ValueError(
                "Choose at most one alternative attention algorithm"
            )

        if biases is None:
            biases = []

        q, k, v = self._prep_qkv(q_x, kv_x, apply_scale=True)

        if is_fp16_enabled():
            use_memory_efficient_kernel = False

        if use_memory_efficient_kernel:
            if len(biases) > 2:
                raise ValueError(
                    "If use_memory_efficient_kernel is True, you may only "
                    "provide up to two bias terms"
                )
            o = attention_core(q, k, v, *((biases + [None] * 2)[:2]))
            o = o.transpose(-2, -3)
        elif use_lma:
            biases = [
                b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
                for b in biases
            ]
            o = _lma(q, k, v, biases, lma_q_chunk_size, lma_kv_chunk_size)
            o = o.transpose(-2, -3)
        else:
            o = _attention(q, k, v, biases)
            o = o.transpose(-2, -3)

        o = self._wrap_up(o, q_x)

        return o


class GlobalAttention(nn.Module):
    def __init__(self, c_in, c_hidden, no_heads, inf, eps):
        super(GlobalAttention, self).__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.inf = inf
        self.eps = eps

        self.linear_q = Linear(
            c_in, c_hidden * no_heads, bias=False, init="glorot"
        )

        self.linear_k = Linear(
            c_in, c_hidden, bias=False, init="glorot",
        )
        self.linear_v = Linear(
            c_in, c_hidden, bias=False, init="glorot",
        )
        self.linear_g = Linear(c_in, c_hidden * no_heads, init="gating")
        self.linear_o = Linear(c_hidden * no_heads, c_in, init="final")

        self.sigmoid = nn.Sigmoid()

    def forward(self,
        m: torch.Tensor,
        mask: torch.Tensor,
        use_lma: bool = False,
    ) -> torch.Tensor:
        # [*, N_res, C_in]
        q = torch.sum(m * mask.unsqueeze(-1), dim=-2) / (
            torch.sum(mask, dim=-1)[..., None] + self.eps
        )

        # [*, N_res, H * C_hidden]
        q = self.linear_q(q)
        q *= (self.c_hidden ** (-0.5))

        # [*, N_res, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.no_heads, -1))

        # [*, N_res, C_hidden]
        k = self.linear_k(m)
        v = self.linear_v(m)

        bias = (self.inf * (mask - 1))[..., :, None, :]
        if not use_lma:
            # [*, N_res, H, N_seq]
            a = torch.matmul(
                q,
                k.transpose(-1, -2),  # [*, N_res, C_hidden, N_seq]
            )
            a += bias
            a = softmax_no_cast(a)

            # [*, N_res, H, C_hidden]
            o = torch.matmul(
                a,
                v,
            )
        else:
            o = _lma(
                q,
                k,
                v,
                [bias],
                DEFAULT_LMA_Q_CHUNK_SIZE,
                DEFAULT_LMA_KV_CHUNK_SIZE
            )

        # [*, N_res, C_hidden]
        g = self.sigmoid(self.linear_g(m))

        # [*, N_res, H, C_hidden]
        g = g.view(g.shape[:-1] + (self.no_heads, -1))

        # [*, N_res, H, C_hidden]
        o = o.unsqueeze(-3) * g

        # [*, N_res, H * C_hidden]
        o = o.reshape(o.shape[:-2] + (-1,))

        # [*, N_res, C_in]
        m = self.linear_o(o)

        return m


def _lma(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    biases: List[torch.Tensor],
    q_chunk_size: int,
    kv_chunk_size: int,
):
    no_q, no_kv = q.shape[-2], k.shape[-2]

    # [*, H, Q, C_hidden]
    o = q.new_zeros(q.shape)
    for q_s in range(0, no_q, q_chunk_size):
        q_chunk = q[..., q_s: q_s + q_chunk_size, :]
        large_bias_chunks = [
            b[..., q_s: q_s + q_chunk_size, :] for b in biases
        ]

        maxes = []
        weights = []
        values = []
        for kv_s in range(0, no_kv, kv_chunk_size):
            k_chunk = k[..., kv_s: kv_s + kv_chunk_size, :]
            v_chunk = v[..., kv_s: kv_s + kv_chunk_size, :]
            small_bias_chunks = [
                b[..., kv_s: kv_s + kv_chunk_size] for b in large_bias_chunks
            ]

            a = torch.einsum(
                "...hqd,...hkd->...hqk", q_chunk, k_chunk,
            )

            for b in small_bias_chunks:
                a += b

            max_a = torch.max(a, dim=-1, keepdim=True)[0]
            exp_a = torch.exp(a - max_a)
            exp_v = torch.einsum("...hvf,...hqv->...hqf", v_chunk, exp_a)

            maxes.append(max_a.detach().squeeze(-1))
            weights.append(torch.sum(exp_a, dim=-1))
            values.append(exp_v)

        chunk_max = torch.stack(maxes, dim=-3)
        chunk_weights = torch.stack(weights, dim=-3)
        chunk_values = torch.stack(values, dim=-4)

        global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0]
        max_diffs = torch.exp(chunk_max - global_max)
        chunk_values = chunk_values * max_diffs.unsqueeze(-1)
        chunk_weights = chunk_weights * max_diffs

        all_values = torch.sum(chunk_values, dim=-4)
        all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4)

        q_chunk_out = all_values / all_weights

        o[..., q_s: q_s + q_chunk_size, :] = q_chunk_out

    return o
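Because _lma re-normalizes each query chunk with a running maximum, it should reproduce the dense _attention result exactly up to floating-point error. A hedged sanity-check sketch (these are private helpers, so this is illustration only; shapes follow the [*, H, Q, C_hidden] comments above):

import torch
from dockformer.model.primitives import _attention, _lma

torch.manual_seed(0)
q = torch.randn(1, 4, 128, 16)        # [*, H, Q, C_hidden]
k = torch.randn(1, 4, 128, 16)        # [*, H, K, C_hidden]
v = torch.randn(1, 4, 128, 16)        # [*, H, K, C_hidden]
bias = torch.randn(1, 4, 128, 128)    # broadcasts to [*, H, Q, K]

dense = _attention(q, k, v, [bias])
chunked = _lma(q, k, v, [bias], q_chunk_size=64, kv_chunk_size=32)
print((dense - chunked).abs().max().item())       # tiny, on the order of 1e-7
print(torch.allclose(dense, chunked, atol=1e-5))  # True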
dockformer/model/single_attention.py
ADDED
@@ -0,0 +1,184 @@
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import math
import torch
import torch.nn as nn
import torch.utils.checkpoint
from typing import Optional, List, Tuple

from dockformer.model.primitives import (
    Linear,
    LayerNorm,
    Attention,
)
from dockformer.utils.tensor_utils import permute_final_dims


class SingleAttention(nn.Module):
    def __init__(
        self,
        c_in,
        c_hidden,
        no_heads,
        pair_bias=False,
        c_z=None,
        inf=1e9,
    ):
        """
        Args:
            c_in:
                Input channel dimension
            c_hidden:
                Per-head hidden channel dimension
            no_heads:
                Number of attention heads
            pair_bias:
                Whether to use pair embedding bias
            c_z:
                Pair embedding channel dimension. Ignored unless pair_bias
                is true
            inf:
                A large number to be used in computing the attention mask
        """
        super(SingleAttention, self).__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.pair_bias = pair_bias
        self.c_z = c_z
        self.inf = inf

        self.layer_norm_m = LayerNorm(self.c_in)

        self.layer_norm_z = None
        self.linear_z = None
        if self.pair_bias:
            self.layer_norm_z = LayerNorm(self.c_z)
            self.linear_z = Linear(
                self.c_z, self.no_heads, bias=False, init="normal"
            )

        self.mha = Attention(
            self.c_in,
            self.c_in,
            self.c_in,
            self.c_hidden,
            self.no_heads,
        )

    def _prep_inputs(self,
        m: torch.Tensor,
        z: Optional[torch.Tensor],
        mask: Optional[torch.Tensor],
        inplace_safe: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        if mask is None:
            # [*, N_res]
            mask = m.new_ones(m.shape[:-1])

        # [*, 1, 1, N_res]
        mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]

        if (self.pair_bias and
                z is not None and                  # For the
                self.layer_norm_z is not None and  # benefit of
                self.linear_z is not None          # TorchScript
        ):
            chunks = []

            for i in range(0, z.shape[-3], 256):
                z_chunk = z[..., i: i + 256, :, :]

                # [*, N_res, N_res, C_z]
                z_chunk = self.layer_norm_z(z_chunk)

                # [*, N_res, N_res, no_heads]
                z_chunk = self.linear_z(z_chunk)

                chunks.append(z_chunk)

            z = torch.cat(chunks, dim=-3)

            # [*, no_heads, N_res, N_res]
            z = permute_final_dims(z, (2, 0, 1))

        return m, mask_bias, z

    def forward(self,
        m: torch.Tensor,
        z: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        use_memory_efficient_kernel: bool = False,
        use_lma: bool = False,
        inplace_safe: bool = False,
    ) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_res, C_m] single embedding
            z:
                [*, N_res, N_res, C_z] pair embedding. Required only if
                pair_bias is True
            mask:
                [*, N_res] single mask
        """
        m, mask_bias, z = self._prep_inputs(
            m, z, mask, inplace_safe=inplace_safe
        )

        biases = [mask_bias]
        if z is not None:
            biases.append(z)

        m = self.layer_norm_m(m)
        m = self.mha(
            q_x=m,
            kv_x=m,
            biases=biases,
            use_memory_efficient_kernel=use_memory_efficient_kernel,
            use_lma=use_lma,
        )

        return m


class SingleRowAttentionWithPairBias(SingleAttention):
    """
    Implements Algorithm 7.
    """

    def __init__(self, c_m, c_z, c_hidden, no_heads, inf=1e9):
        """
        Args:
            c_m:
                Input channel dimension
            c_z:
                Pair embedding channel dimension
            c_hidden:
                Per-head hidden channel dimension
            no_heads:
                Number of attention heads
            inf:
                Large number used to construct attention masks
        """
        super(SingleRowAttentionWithPairBias, self).__init__(
            c_m,
            c_hidden,
            no_heads,
            pair_bias=True,
            c_z=c_z,
            inf=inf,
        )
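A minimal usage sketch for the subclass above (arbitrary dimensions, repository assumed importable); the pair embedding is projected to one bias per head and added to the single-representation attention logits:

import torch
from dockformer.model.single_attention import SingleRowAttentionWithPairBias

attn = SingleRowAttentionWithPairBias(c_m=64, c_z=32, c_hidden=8, no_heads=4)
m = torch.randn(1, 16, 64)       # [*, N_res, C_m] single embedding
z = torch.randn(1, 16, 16, 32)   # [*, N_res, N_res, C_z] pair embedding
mask = torch.ones(1, 16)         # [*, N_res]

out = attn(m, z=z, mask=mask)
print(out.shape)                 # torch.Size([1, 16, 64])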
dockformer/model/structure_module.py
ADDED
@@ -0,0 +1,837 @@
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import reduce
import importlib
import math
import sys
from operator import mul

import torch
import torch.nn as nn
from typing import Optional, Tuple, Sequence, Union

from dockformer.model.primitives import Linear, LayerNorm, ipa_point_weights_init_
from dockformer.utils.residue_constants import (
    restype_rigid_group_default_frame,
    restype_atom14_to_rigid_group,
    restype_atom14_mask,
    restype_atom14_rigid_group_positions,
)
from dockformer.utils.geometry.quat_rigid import QuatRigid
from dockformer.utils.geometry.rigid_matrix_vector import Rigid3Array
from dockformer.utils.geometry.vector import Vec3Array, square_euclidean_distance
from dockformer.utils.feats import (
    frames_and_literature_positions_to_atom14_pos,
    torsion_angles_to_frames,
)
from dockformer.utils.precision_utils import is_fp16_enabled
from dockformer.utils.rigid_utils import Rotation, Rigid
from dockformer.utils.tensor_utils import (
    dict_multimap,
    permute_final_dims,
    flatten_final_dims,
)

import importlib.util
attn_core_is_installed = importlib.util.find_spec("attn_core_inplace_cuda") is not None
attn_core_inplace_cuda = None
if attn_core_is_installed:
    attn_core_inplace_cuda = importlib.import_module("attn_core_inplace_cuda")


class AngleResnetBlock(nn.Module):
    def __init__(self, c_hidden):
        """
        Args:
            c_hidden:
                Hidden channel dimension
        """
        super(AngleResnetBlock, self).__init__()

        self.c_hidden = c_hidden

        self.linear_1 = Linear(self.c_hidden, self.c_hidden, init="relu")
        self.linear_2 = Linear(self.c_hidden, self.c_hidden, init="final")

        self.relu = nn.ReLU()

    def forward(self, a: torch.Tensor) -> torch.Tensor:
        s_initial = a

        a = self.relu(a)
        a = self.linear_1(a)
        a = self.relu(a)
        a = self.linear_2(a)

        return a + s_initial


class AngleResnet(nn.Module):
    """
    Implements Algorithm 20, lines 11-14
    """

    def __init__(self, c_in, c_hidden, no_blocks, no_angles, epsilon):
        """
        Args:
            c_in:
                Input channel dimension
            c_hidden:
                Hidden channel dimension
            no_blocks:
                Number of resnet blocks
            no_angles:
                Number of torsion angles to generate
            epsilon:
                Small constant for normalization
        """
        super(AngleResnet, self).__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_blocks = no_blocks
        self.no_angles = no_angles
        self.eps = epsilon

        self.linear_in = Linear(self.c_in, self.c_hidden)
        self.linear_initial = Linear(self.c_in, self.c_hidden)

        self.layers = nn.ModuleList()
        for _ in range(self.no_blocks):
            layer = AngleResnetBlock(c_hidden=self.c_hidden)
            self.layers.append(layer)

        self.linear_out = Linear(self.c_hidden, self.no_angles * 2)

        self.relu = nn.ReLU()

    def forward(
        self, s: torch.Tensor, s_initial: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            s:
                [*, C_hidden] single embedding
            s_initial:
                [*, C_hidden] single embedding as of the start of the
                StructureModule
        Returns:
            [*, no_angles, 2] predicted angles
        """
        # NOTE: The ReLU's applied to the inputs are absent from the supplement
        # pseudocode but present in the source. For maximal compatibility with
        # the pretrained weights, I'm going with the source.

        # [*, C_hidden]
        s_initial = self.relu(s_initial)
        s_initial = self.linear_initial(s_initial)
        s = self.relu(s)
        s = self.linear_in(s)
        s = s + s_initial

        for l in self.layers:
            s = l(s)

        s = self.relu(s)

        # [*, no_angles * 2]
        s = self.linear_out(s)

        # [*, no_angles, 2]
        s = s.view(s.shape[:-1] + (-1, 2))

        unnormalized_s = s
        norm_denom = torch.sqrt(
            torch.clamp(
                torch.sum(s ** 2, dim=-1, keepdim=True),
                min=self.eps,
            )
        )
        s = s / norm_denom

        return unnormalized_s, s
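A short sketch exercising the angle head above (arbitrary dimensions, repository assumed importable); the second return value is normalized so every (sin, cos) pair has unit length:

import torch
from dockformer.model.structure_module import AngleResnet

resnet = AngleResnet(c_in=384, c_hidden=128, no_blocks=2, no_angles=7, epsilon=1e-12)
s = torch.randn(2, 10, 384)
s_initial = torch.randn(2, 10, 384)

unnormalized, angles = resnet(s, s_initial)
print(angles.shape)                                               # torch.Size([2, 10, 7, 2])
print(torch.allclose(angles.norm(dim=-1), torch.ones(2, 10, 7)))  # True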
166 |
+
|
167 |
+
|
168 |
+
class PointProjection(nn.Module):
|
169 |
+
def __init__(self,
|
170 |
+
c_hidden: int,
|
171 |
+
num_points: int,
|
172 |
+
no_heads: int,
|
173 |
+
return_local_points: bool = False,
|
174 |
+
):
|
175 |
+
super().__init__()
|
176 |
+
self.return_local_points = return_local_points
|
177 |
+
self.no_heads = no_heads
|
178 |
+
self.num_points = num_points
|
179 |
+
|
180 |
+
# Multimer requires this to be run with fp32 precision during training
|
181 |
+
precision = None
|
182 |
+
self.linear = Linear(c_hidden, no_heads * 3 * num_points, precision=precision)
|
183 |
+
|
184 |
+
def forward(self,
|
185 |
+
activations: torch.Tensor,
|
186 |
+
rigids: Union[Rigid, Rigid3Array],
|
187 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
188 |
+
# TODO: Needs to run in high precision during training
|
189 |
+
points_local = self.linear(activations)
|
190 |
+
out_shape = points_local.shape[:-1] + (self.no_heads, self.num_points, 3)
|
191 |
+
|
192 |
+
points_local = torch.split(
|
193 |
+
points_local, points_local.shape[-1] // 3, dim=-1
|
194 |
+
)
|
195 |
+
|
196 |
+
points_local = torch.stack(points_local, dim=-1).view(out_shape)
|
197 |
+
|
198 |
+
points_global = rigids[..., None, None].apply(points_local)
|
199 |
+
|
200 |
+
if(self.return_local_points):
|
201 |
+
return points_global, points_local
|
202 |
+
|
203 |
+
return points_global
|
204 |
+
|
205 |
+
|
206 |
+
class InvariantPointAttention(nn.Module):
|
207 |
+
"""
|
208 |
+
Implements Algorithm 22.
|
209 |
+
"""
|
210 |
+
def __init__(
|
211 |
+
self,
|
212 |
+
c_s: int,
|
213 |
+
c_z: int,
|
214 |
+
c_hidden: int,
|
215 |
+
no_heads: int,
|
216 |
+
no_qk_points: int,
|
217 |
+
no_v_points: int,
|
218 |
+
inf: float = 1e5,
|
219 |
+
eps: float = 1e-8,
|
220 |
+
):
|
221 |
+
"""
|
222 |
+
Args:
|
223 |
+
c_s:
|
224 |
+
Single representation channel dimension
|
225 |
+
c_z:
|
226 |
+
Pair representation channel dimension
|
227 |
+
c_hidden:
|
228 |
+
Hidden channel dimension
|
229 |
+
no_heads:
|
230 |
+
Number of attention heads
|
231 |
+
no_qk_points:
|
232 |
+
Number of query/key points to generate
|
233 |
+
no_v_points:
|
234 |
+
Number of value points to generate
|
235 |
+
"""
|
236 |
+
super(InvariantPointAttention, self).__init__()
|
237 |
+
|
238 |
+
self.c_s = c_s
|
239 |
+
self.c_z = c_z
|
240 |
+
self.c_hidden = c_hidden
|
241 |
+
self.no_heads = no_heads
|
242 |
+
self.no_qk_points = no_qk_points
|
243 |
+
self.no_v_points = no_v_points
|
244 |
+
self.inf = inf
|
245 |
+
self.eps = eps
|
246 |
+
|
247 |
+
# These linear layers differ from their specifications in the
|
248 |
+
# supplement. There, they lack bias and use Glorot initialization.
|
249 |
+
# Here as in the official source, they have bias and use the default
|
250 |
+
# Lecun initialization.
|
251 |
+
hc = self.c_hidden * self.no_heads
|
252 |
+
self.linear_q = Linear(self.c_s, hc, bias=True)
|
253 |
+
|
254 |
+
self.linear_q_points = PointProjection(
|
255 |
+
self.c_s,
|
256 |
+
self.no_qk_points,
|
257 |
+
self.no_heads,
|
258 |
+
)
|
259 |
+
|
260 |
+
|
261 |
+
self.linear_kv = Linear(self.c_s, 2 * hc)
|
262 |
+
self.linear_kv_points = PointProjection(
|
263 |
+
self.c_s,
|
264 |
+
self.no_qk_points + self.no_v_points,
|
265 |
+
self.no_heads,
|
266 |
+
)
|
267 |
+
|
268 |
+
self.linear_b = Linear(self.c_z, self.no_heads)
|
269 |
+
|
270 |
+
self.head_weights = nn.Parameter(torch.zeros((no_heads)))
|
271 |
+
ipa_point_weights_init_(self.head_weights)
|
272 |
+
|
273 |
+
concat_out_dim = self.no_heads * (
|
274 |
+
self.c_z + self.c_hidden + self.no_v_points * 4
|
275 |
+
)
|
276 |
+
self.linear_out = Linear(concat_out_dim, self.c_s, init="final")
|
277 |
+
|
278 |
+
self.softmax = nn.Softmax(dim=-1)
|
279 |
+
self.softplus = nn.Softplus()
|
280 |
+
|
281 |
+
def forward(
|
282 |
+
self,
|
283 |
+
s: torch.Tensor,
|
284 |
+
z: torch.Tensor,
|
285 |
+
r: Union[Rigid, Rigid3Array],
|
286 |
+
mask: torch.Tensor,
|
287 |
+
inplace_safe: bool = False,
|
288 |
+
) -> torch.Tensor:
|
289 |
+
"""
|
290 |
+
Args:
|
291 |
+
s:
|
292 |
+
[*, N_res, C_s] single representation
|
293 |
+
z:
|
294 |
+
[*, N_res, N_res, C_z] pair representation
|
295 |
+
r:
|
296 |
+
[*, N_res] transformation object
|
297 |
+
mask:
|
298 |
+
[*, N_res] mask
|
299 |
+
Returns:
|
300 |
+
[*, N_res, C_s] single representation update
|
301 |
+
"""
|
302 |
+
z = [z]
|
303 |
+
|
304 |
+
#######################################
|
305 |
+
# Generate scalar and point activations
|
306 |
+
#######################################
|
307 |
+
# [*, N_res, H * C_hidden]
|
308 |
+
q = self.linear_q(s)
|
309 |
+
|
310 |
+
# [*, N_res, H, C_hidden]
|
311 |
+
q = q.view(q.shape[:-1] + (self.no_heads, -1))
|
312 |
+
|
313 |
+
# [*, N_res, H, P_qk]
|
314 |
+
q_pts = self.linear_q_points(s, r)
|
315 |
+
|
316 |
+
# The following two blocks are equivalent
|
317 |
+
# They're separated only to preserve compatibility with old AF weights
|
318 |
+
|
319 |
+
# [*, N_res, H * 2 * C_hidden]
|
320 |
+
kv = self.linear_kv(s)
|
321 |
+
|
322 |
+
# [*, N_res, H, 2 * C_hidden]
|
323 |
+
kv = kv.view(kv.shape[:-1] + (self.no_heads, -1))
|
324 |
+
|
325 |
+
# [*, N_res, H, C_hidden]
|
326 |
+
k, v = torch.split(kv, self.c_hidden, dim=-1)
|
327 |
+
|
328 |
+
kv_pts = self.linear_kv_points(s, r)
|
329 |
+
|
330 |
+
# [*, N_res, H, P_q/P_v, 3]
|
331 |
+
k_pts, v_pts = torch.split(
|
332 |
+
kv_pts, [self.no_qk_points, self.no_v_points], dim=-2
|
333 |
+
)
|
334 |
+
|
335 |
+
##########################
|
336 |
+
# Compute attention scores
|
337 |
+
##########################
|
338 |
+
# [*, N_res, N_res, H]
|
339 |
+
b = self.linear_b(z[0])
|
340 |
+
|
341 |
+
# [*, H, N_res, N_res]
|
342 |
+
if (is_fp16_enabled()):
|
343 |
+
with torch.cuda.amp.autocast(enabled=False):
|
344 |
+
a = torch.matmul(
|
345 |
+
permute_final_dims(q.float(), (1, 0, 2)), # [*, H, N_res, C_hidden]
|
346 |
+
permute_final_dims(k.float(), (1, 2, 0)), # [*, H, C_hidden, N_res]
|
347 |
+
)
|
348 |
+
else:
|
349 |
+
a = torch.matmul(
|
350 |
+
permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden]
|
351 |
+
permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res]
|
352 |
+
)
|
353 |
+
|
354 |
+
a *= math.sqrt(1.0 / (3 * self.c_hidden))
|
355 |
+
a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))
|
356 |
+
|
357 |
+
# [*, N_res, N_res, H, P_q, 3]
|
358 |
+
pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
|
359 |
+
|
360 |
+
if (inplace_safe):
|
361 |
+
pt_att *= pt_att
|
362 |
+
else:
|
363 |
+
pt_att = pt_att ** 2
|
364 |
+
|
365 |
+
pt_att = sum(torch.unbind(pt_att, dim=-1))
|
366 |
+
|
367 |
+
head_weights = self.softplus(self.head_weights).view(
|
368 |
+
*((1,) * len(pt_att.shape[:-2]) + (-1, 1))
|
369 |
+
)
|
370 |
+
head_weights = head_weights * math.sqrt(
|
371 |
+
1.0 / (3 * (self.no_qk_points * 9.0 / 2))
|
372 |
+
)
|
373 |
+
|
374 |
+
if (inplace_safe):
|
375 |
+
pt_att *= head_weights
|
376 |
+
else:
|
377 |
+
pt_att = pt_att * head_weights
|
378 |
+
|
379 |
+
# [*, N_res, N_res, H]
|
380 |
+
pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
|
381 |
+
|
382 |
+
# [*, N_res, N_res]
|
383 |
+
square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
|
384 |
+
square_mask = self.inf * (square_mask - 1)
|
385 |
+
|
386 |
+
# [*, H, N_res, N_res]
|
387 |
+
pt_att = permute_final_dims(pt_att, (2, 0, 1))
|
388 |
+
|
389 |
+
if (inplace_safe):
|
390 |
+
a += pt_att
|
391 |
+
del pt_att
|
392 |
+
a += square_mask.unsqueeze(-3)
|
393 |
+
# in-place softmax
|
394 |
+
attn_core_inplace_cuda.forward_(
|
395 |
+
a,
|
396 |
+
reduce(mul, a.shape[:-1]),
|
397 |
+
a.shape[-1],
|
398 |
+
)
|
399 |
+
else:
|
400 |
+
a = a + pt_att
|
401 |
+
a = a + square_mask.unsqueeze(-3)
|
402 |
+
a = self.softmax(a)
|
403 |
+
|
404 |
+
################
|
405 |
+
# Compute output
|
406 |
+
################
|
407 |
+
# [*, N_res, H, C_hidden]
|
408 |
+
o = torch.matmul(
|
409 |
+
a, v.transpose(-2, -3).to(dtype=a.dtype)
|
410 |
+
).transpose(-2, -3)
|
411 |
+
|
412 |
+
# [*, N_res, H * C_hidden]
|
413 |
+
o = flatten_final_dims(o, 2)
|
414 |
+
|
415 |
+
# [*, H, 3, N_res, P_v]
|
416 |
+
if (inplace_safe):
|
417 |
+
v_pts = permute_final_dims(v_pts, (1, 3, 0, 2))
|
418 |
+
o_pt = [
|
419 |
+
torch.matmul(a, v.to(a.dtype))
|
420 |
+
for v in torch.unbind(v_pts, dim=-3)
|
421 |
+
]
|
422 |
+
o_pt = torch.stack(o_pt, dim=-3)
|
423 |
+
else:
|
424 |
+
o_pt = torch.sum(
|
425 |
+
(
|
426 |
+
a[..., None, :, :, None]
|
427 |
+
* permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
|
428 |
+
),
|
429 |
+
dim=-2,
|
430 |
+
)
|
431 |
+
|
432 |
+
# [*, N_res, H, P_v, 3]
|
433 |
+
o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
|
434 |
+
o_pt = r[..., None, None].invert_apply(o_pt)
|
435 |
+
|
436 |
+
# [*, N_res, H * P_v]
|
437 |
+
o_pt_norm = flatten_final_dims(
|
438 |
+
torch.sqrt(torch.sum(o_pt ** 2, dim=-1) + self.eps), 2
|
439 |
+
)
|
440 |
+
|
441 |
+
# [*, N_res, H * P_v, 3]
|
442 |
+
o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
|
443 |
+
o_pt = torch.unbind(o_pt, dim=-1)
|
444 |
+
|
445 |
+
# [*, N_res, H, C_z]
|
446 |
+
o_pair = torch.matmul(a.transpose(-2, -3), z[0].to(dtype=a.dtype))
|
447 |
+
|
448 |
+
# [*, N_res, H * C_z]
|
449 |
+
o_pair = flatten_final_dims(o_pair, 2)
|
450 |
+
|
451 |
+
# [*, N_res, C_s]
|
452 |
+
s = self.linear_out(
|
453 |
+
torch.cat(
|
454 |
+
(o, *o_pt, o_pt_norm, o_pair), dim=-1
|
455 |
+
).to(dtype=z[0].dtype)
|
456 |
+
)
|
457 |
+
|
458 |
+
return s
|
459 |
+
|
460 |
+
|
461 |
+
class BackboneUpdate(nn.Module):
|
462 |
+
"""
|
463 |
+
Implements part of Algorithm 23.
|
464 |
+
"""
|
465 |
+
|
466 |
+
def __init__(self, c_s):
|
467 |
+
"""
|
468 |
+
Args:
|
469 |
+
c_s:
|
470 |
+
Single representation channel dimension
|
471 |
+
"""
|
472 |
+
super(BackboneUpdate, self).__init__()
|
473 |
+
|
474 |
+
self.c_s = c_s
|
475 |
+
|
476 |
+
self.linear = Linear(self.c_s, 6, init="final")
|
477 |
+
|
478 |
+
def forward(self, s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
479 |
+
"""
|
480 |
+
Args:
|
481 |
+
[*, N_res, C_s] single representation
|
482 |
+
Returns:
|
483 |
+
[*, N_res, 6] update vector
|
484 |
+
"""
|
485 |
+
# [*, 6]
|
486 |
+
update = self.linear(s)
|
487 |
+
|
488 |
+
return update
|
489 |
+
|
490 |
+
|
491 |
+
class StructureModuleTransitionLayer(nn.Module):
|
492 |
+
def __init__(self, c):
|
493 |
+
super(StructureModuleTransitionLayer, self).__init__()
|
494 |
+
|
495 |
+
self.c = c
|
496 |
+
|
497 |
+
self.linear_1 = Linear(self.c, self.c, init="relu")
|
498 |
+
self.linear_2 = Linear(self.c, self.c, init="relu")
|
499 |
+
self.linear_3 = Linear(self.c, self.c, init="final")
|
500 |
+
|
501 |
+
self.relu = nn.ReLU()
|
502 |
+
|
503 |
+
def forward(self, s):
|
504 |
+
s_initial = s
|
505 |
+
s = self.linear_1(s)
|
506 |
+
s = self.relu(s)
|
507 |
+
s = self.linear_2(s)
|
508 |
+
s = self.relu(s)
|
509 |
+
s = self.linear_3(s)
|
510 |
+
|
511 |
+
s = s + s_initial
|
512 |
+
|
513 |
+
return s
|
514 |
+
|
515 |
+
|
516 |
+
class StructureModuleTransition(nn.Module):
|
517 |
+
def __init__(self, c, num_layers, dropout_rate):
|
518 |
+
super(StructureModuleTransition, self).__init__()
|
519 |
+
|
520 |
+
self.c = c
|
521 |
+
self.num_layers = num_layers
|
522 |
+
self.dropout_rate = dropout_rate
|
523 |
+
|
524 |
+
self.layers = nn.ModuleList()
|
525 |
+
for _ in range(self.num_layers):
|
526 |
+
l = StructureModuleTransitionLayer(self.c)
|
527 |
+
self.layers.append(l)
|
528 |
+
|
529 |
+
self.dropout = nn.Dropout(self.dropout_rate)
|
530 |
+
self.layer_norm = LayerNorm(self.c)
|
531 |
+
|
532 |
+
def forward(self, s):
|
533 |
+
for l in self.layers:
|
534 |
+
s = l(s)
|
535 |
+
|
536 |
+
s = self.dropout(s)
|
537 |
+
s = self.layer_norm(s)
|
538 |
+
|
539 |
+
return s
|
540 |
+
|
541 |
+
|
542 |
+
class StructureModule(nn.Module):
|
543 |
+
def __init__(
|
544 |
+
self,
|
545 |
+
c_s,
|
546 |
+
c_z,
|
547 |
+
c_ipa,
|
548 |
+
c_resnet,
|
549 |
+
no_heads_ipa,
|
550 |
+
no_qk_points,
|
551 |
+
no_v_points,
|
552 |
+
dropout_rate,
|
553 |
+
no_blocks,
|
554 |
+
no_transition_layers,
|
555 |
+
no_resnet_blocks,
|
556 |
+
no_angles,
|
557 |
+
trans_scale_factor,
|
558 |
+
epsilon,
|
559 |
+
inf,
|
560 |
+
**kwargs,
|
561 |
+
):
|
562 |
+
"""
|
563 |
+
Args:
|
564 |
+
c_s:
|
565 |
+
Single representation channel dimension
|
566 |
+
c_z:
|
567 |
+
Pair representation channel dimension
|
568 |
+
c_ipa:
|
569 |
+
IPA hidden channel dimension
|
570 |
+
c_resnet:
|
571 |
+
Angle resnet (Alg. 23 lines 11-14) hidden channel dimension
|
572 |
+
no_heads_ipa:
|
573 |
+
Number of IPA heads
|
574 |
+
no_qk_points:
|
575 |
+
Number of query/key points to generate during IPA
|
576 |
+
no_v_points:
|
577 |
+
Number of value points to generate during IPA
|
578 |
+
dropout_rate:
|
579 |
+
Dropout rate used throughout the layer
|
580 |
+
no_blocks:
|
581 |
+
Number of structure module blocks
|
582 |
+
no_transition_layers:
|
583 |
+
Number of layers in the single representation transition
|
584 |
+
(Alg. 23 lines 8-9)
|
585 |
+
no_resnet_blocks:
|
586 |
+
Number of blocks in the angle resnet
|
587 |
+
no_angles:
|
588 |
+
Number of angles to generate in the angle resnet
|
589 |
+
trans_scale_factor:
|
590 |
+
Scale of single representation transition hidden dimension
|
591 |
+
epsilon:
|
592 |
+
Small number used in angle resnet normalization
|
593 |
+
inf:
|
594 |
+
Large number used for attention masking
|
595 |
+
"""
|
596 |
+
super(StructureModule, self).__init__()
|
597 |
+
|
598 |
+
self.c_s = c_s
|
599 |
+
self.c_z = c_z
|
600 |
+
self.c_ipa = c_ipa
|
601 |
+
self.c_resnet = c_resnet
|
602 |
+
self.no_heads_ipa = no_heads_ipa
|
603 |
+
self.no_qk_points = no_qk_points
|
604 |
+
self.no_v_points = no_v_points
|
605 |
+
self.dropout_rate = dropout_rate
|
606 |
+
self.no_blocks = no_blocks
|
607 |
+
self.no_transition_layers = no_transition_layers
|
608 |
+
self.no_resnet_blocks = no_resnet_blocks
|
609 |
+
self.no_angles = no_angles
|
610 |
+
self.trans_scale_factor = trans_scale_factor
|
611 |
+
self.epsilon = epsilon
|
612 |
+
self.inf = inf
|
613 |
+
|
614 |
+
# Buffers to be lazily initialized later
|
615 |
+
# self.default_frames
|
616 |
+
# self.group_idx
|
617 |
+
# self.atom_mask
|
618 |
+
# self.lit_positions
|
619 |
+
|
620 |
+
self.layer_norm_s = LayerNorm(self.c_s)
|
621 |
+
self.layer_norm_z = LayerNorm(self.c_z)
|
622 |
+
|
623 |
+
self.linear_in = Linear(self.c_s, self.c_s)
|
624 |
+
|
625 |
+
self.ipa = InvariantPointAttention(
|
626 |
+
self.c_s,
|
627 |
+
self.c_z,
|
628 |
+
self.c_ipa,
|
629 |
+
self.no_heads_ipa,
|
630 |
+
self.no_qk_points,
|
631 |
+
self.no_v_points,
|
632 |
+
inf=self.inf,
|
633 |
+
eps=self.epsilon,
|
634 |
+
)
|
635 |
+
|
636 |
+
self.ipa_dropout = nn.Dropout(self.dropout_rate)
|
637 |
+
self.layer_norm_ipa = LayerNorm(self.c_s)
|
638 |
+
|
639 |
+
self.transition = StructureModuleTransition(
|
640 |
+
self.c_s,
|
641 |
+
self.no_transition_layers,
|
642 |
+
self.dropout_rate,
|
643 |
+
)
|
644 |
+
|
645 |
+
self.bb_update = BackboneUpdate(self.c_s)
|
646 |
+
|
647 |
+
self.angle_resnet = AngleResnet(
|
648 |
+
self.c_s,
|
649 |
+
self.c_resnet,
|
650 |
+
self.no_resnet_blocks,
|
651 |
+
self.no_angles,
|
652 |
+
self.epsilon,
|
653 |
+
)
|
654 |
+
|
655 |
+
def forward(
|
656 |
+
self,
|
657 |
+
evoformer_output_dict,
|
658 |
+
aatype,
|
659 |
+
mask=None,
|
660 |
+
inplace_safe=False,
|
661 |
+
):
|
662 |
+
"""
|
663 |
+
Args:
|
664 |
+
evoformer_output_dict:
|
665 |
+
Dictionary containing:
|
666 |
+
"single":
|
667 |
+
[*, N_res, C_s] single representation
|
668 |
+
"pair":
|
669 |
+
[*, N_res, N_res, C_z] pair representation
|
670 |
+
aatype:
|
671 |
+
[*, N_res] amino acid indices
|
672 |
+
mask:
|
673 |
+
Optional [*, N_res] sequence mask
|
674 |
+
Returns:
|
675 |
+
A dictionary of outputs
|
676 |
+
"""
|
677 |
+
s = evoformer_output_dict["single"]
|
678 |
+
|
679 |
+
if mask is None:
|
680 |
+
# [*, N]
|
681 |
+
mask = s.new_ones(s.shape[:-1])
|
682 |
+
|
683 |
+
# [*, N, C_s]
|
684 |
+
s = self.layer_norm_s(s)
|
685 |
+
|
686 |
+
# [*, N, N, C_z]
|
687 |
+
z = self.layer_norm_z(evoformer_output_dict["pair"])
|
688 |
+
|
689 |
+
# [*, N, C_s]
|
690 |
+
s_initial = s
|
691 |
+
s = self.linear_in(s)
|
692 |
+
|
693 |
+
# [*, N]
|
694 |
+
rigids = Rigid.identity(
|
695 |
+
s.shape[:-1],
|
696 |
+
s.dtype,
|
697 |
+
s.device,
|
698 |
+
self.training,
|
699 |
+
fmt="quat",
|
700 |
+
)
|
701 |
+
outputs = []
|
702 |
+
for i in range(self.no_blocks):
|
703 |
+
# [*, N, C_s]
|
704 |
+
s = s + self.ipa(
|
705 |
+
s,
|
706 |
+
z,
|
707 |
+
rigids,
|
708 |
+
mask,
|
709 |
+
inplace_safe=inplace_safe,
|
710 |
+
)
|
711 |
+
s = self.ipa_dropout(s)
|
712 |
+
s = self.layer_norm_ipa(s)
|
713 |
+
s = self.transition(s)
|
714 |
+
|
            # [*, N]

            # [*, N_res, 6] vector of translations and rotations
            bb_update_output = self.bb_update(s)

            rigids = rigids.compose_q_update_vec(bb_update_output)


            # To hew as closely as possible to AlphaFold, we convert our
            # quaternion-based transformations to rotation-matrix ones
            # here
            backb_to_global = Rigid(
                Rotation(
                    rot_mats=rigids.get_rots().get_rot_mats(),
                    quats=None
                ),
                rigids.get_trans(),
            )

            backb_to_global = backb_to_global.scale_translation(
                self.trans_scale_factor
            )

            # [*, N, 7, 2]
            unnormalized_angles, angles = self.angle_resnet(s, s_initial)

            all_frames_to_global = self.torsion_angles_to_frames(
                backb_to_global,
                angles,
                aatype,
            )

            pred_xyz = self.frames_and_literature_positions_to_atom14_pos(
                all_frames_to_global,
                aatype,
            )

            scaled_rigids = rigids.scale_translation(self.trans_scale_factor)

            preds = {
                "frames": scaled_rigids.to_tensor_7(),
                "sidechain_frames": all_frames_to_global.to_tensor_4x4(),
                "unnormalized_angles": unnormalized_angles,
                "angles": angles,
                "positions": pred_xyz,
                "states": s,
            }

            outputs.append(preds)

            rigids = rigids.stop_rot_gradient()

        del z

        outputs = dict_multimap(torch.stack, outputs)
        outputs["single"] = s

        return outputs

    def _init_residue_constants(self, float_dtype, device):
        if not hasattr(self, "default_frames"):
            self.register_buffer(
                "default_frames",
                torch.tensor(
                    restype_rigid_group_default_frame,
                    dtype=float_dtype,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )
        if not hasattr(self, "group_idx"):
            self.register_buffer(
                "group_idx",
                torch.tensor(
                    restype_atom14_to_rigid_group,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )
        if not hasattr(self, "atom_mask"):
            self.register_buffer(
                "atom_mask",
                torch.tensor(
                    restype_atom14_mask,
                    dtype=float_dtype,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )
        if not hasattr(self, "lit_positions"):
            self.register_buffer(
                "lit_positions",
                torch.tensor(
                    restype_atom14_rigid_group_positions,
                    dtype=float_dtype,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )

    def torsion_angles_to_frames(self, r, alpha, f):
        # Lazily initialize the residue constants on the correct device
        self._init_residue_constants(alpha.dtype, alpha.device)
        # Separated purely to make testing less annoying
        return torsion_angles_to_frames(r, alpha, f, self.default_frames)

    def frames_and_literature_positions_to_atom14_pos(
        self, r, f  # [*, N, 8]  # [*, N]
    ):
        # Lazily initialize the residue constants on the correct device
        self._init_residue_constants(r.dtype, r.device)
        return frames_and_literature_positions_to_atom14_pos(
            r,
            f,
            self.default_frames,
            self.group_idx,
            self.atom_mask,
            self.lit_positions,
        )
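For orientation, a minimal consumption sketch, not part of the diff: the forward pass above stacks each per-iteration prediction along dim 0 via dict_multimap(torch.stack, outputs), so the last index is the final refinement step. The `structure_module` call below is illustrative; its exact signature is assumed.

# Illustrative only: `structure_module` stands in for a constructed StructureModule.
outputs = structure_module(evoformer_outputs, aatype)   # hypothetical call
final_frames = outputs["frames"][-1]        # [*, N, 7] rigids from the last iteration
final_positions = outputs["positions"][-1]  # [*, N, 14, 3] atom14 coordinates
single = outputs["single"]                  # [*, N, C_s] final single representation (not stacked)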
dockformer/model/torchscript.py
ADDED
@@ -0,0 +1,171 @@
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Sequence, Tuple

import torch
import torch.nn as nn

from dockformer.model.evoformer import (
    EvoformerBlock,
    EvoformerStack,
)
from dockformer.model.single_attention import SingleRowAttentionWithPairBias
from dockformer.model.primitives import Attention, GlobalAttention


def script_preset_(model: torch.nn.Module):
    """
    TorchScript a handful of low-level but frequently used submodule types
    that are known to be scriptable.

    Args:
        model:
            A torch.nn.Module. It should contain at least some modules from
            this repository, or this function won't do anything.
    """
    script_submodules_(
        model,
        [
            nn.Dropout,
            Attention,
            GlobalAttention,
            EvoformerBlock,
        ],
        attempt_trace=False,
        batch_dims=None,
    )


def _get_module_device(module: torch.nn.Module) -> torch.device:
    """
    Fetches the device of a module, assuming that all of the module's
    parameters reside on a single device

    Args:
        module: A torch.nn.Module
    Returns:
        The module's device
    """
    return next(module.parameters()).device


def _trace_module(module, batch_dims=None):
    if(batch_dims is None):
        batch_dims = ()

    # Stand-in values
    n_seq = 10
    n_res = 10

    device = _get_module_device(module)

    def msa(channel_dim):
        return torch.rand(
            (*batch_dims, n_seq, n_res, channel_dim),
            device=device,
        )

    def pair(channel_dim):
        return torch.rand(
            (*batch_dims, n_res, n_res, channel_dim),
            device=device,
        )

    if(isinstance(module, SingleRowAttentionWithPairBias)):
        inputs = {
            "forward": (
                msa(module.c_in),  # m
                pair(module.c_z),  # z
                torch.randint(
                    0, 2,
                    (*batch_dims, n_seq, n_res)
                ),  # mask
            ),
        }
    else:
        raise TypeError(
            f"tracing is not supported for modules of type {type(module)}"
        )

    return torch.jit.trace_module(module, inputs)


def _script_submodules_helper_(
    model,
    types,
    attempt_trace,
    to_trace,
):
    for name, child in model.named_children():
        if(types is None or any(isinstance(child, t) for t in types)):
            try:
                scripted = torch.jit.script(child)
                setattr(model, name, scripted)
                continue
            except (RuntimeError, torch.jit.frontend.NotSupportedError) as e:
                if(attempt_trace):
                    to_trace.add(type(child))
                else:
                    raise e

        _script_submodules_helper_(child, types, attempt_trace, to_trace)


def _trace_submodules_(
    model,
    types,
    batch_dims=None,
):
    for name, child in model.named_children():
        if(any(isinstance(child, t) for t in types)):
            traced = _trace_module(child, batch_dims=batch_dims)
            setattr(model, name, traced)
        else:
            _trace_submodules_(child, types, batch_dims=batch_dims)


def script_submodules_(
    model: nn.Module,
    types: Optional[Sequence[type]] = None,
    attempt_trace: Optional[bool] = True,
    batch_dims: Optional[Tuple[int]] = None,
):
    """
    Convert all submodules whose types match one of those in the input
    list to recursively scripted equivalents in place. To script the entire
    model, just call torch.jit.script on it directly.

    When types is None, all submodules are scripted.

    Args:
        model:
            A torch.nn.Module
        types:
            A list of types of submodules to script
        attempt_trace:
            Whether to attempt to trace specified modules if scripting
            fails. Recall that tracing eliminates all conditional
            logic---with great tracing comes the mild responsibility of
            having to remember to ensure that the modules in question
            perform the same computations no matter what.
    """
    to_trace = set()

    # Aggressively script as much as possible first...
    _script_submodules_helper_(model, types, attempt_trace, to_trace)

    # ... and then trace stragglers.
    if(attempt_trace and len(to_trace) > 0):
        _trace_submodules_(model, to_trace, batch_dims=batch_dims)
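A minimal usage sketch, not part of the diff; the model constructor is hypothetical:

from dockformer.model.torchscript import script_preset_

model = make_dockformer()  # hypothetical constructor returning an nn.Module
script_preset_(model)      # replaces the known-scriptable submodules with TorchScript versions in place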
dockformer/model/triangular_attention.py
ADDED
@@ -0,0 +1,104 @@
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partialmethod, partial
import math
from typing import Optional, List

import torch
import torch.nn as nn

from dockformer.model.primitives import Linear, LayerNorm, Attention
from dockformer.utils.tensor_utils import permute_final_dims


class TriangleAttention(nn.Module):
    def __init__(
        self, c_in, c_hidden, no_heads, starting=True, inf=1e9
    ):
        """
        Args:
            c_in:
                Input channel dimension
            c_hidden:
                Overall hidden channel dimension (not per-head)
            no_heads:
                Number of attention heads
        """
        super(TriangleAttention, self).__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.starting = starting
        self.inf = inf

        self.layer_norm = LayerNorm(self.c_in)

        self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")

        self.mha = Attention(
            self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
        )

    def forward(self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        use_memory_efficient_kernel: bool = False,
        use_lma: bool = False,
    ) -> torch.Tensor:
        """
        Args:
            x:
                [*, I, J, C_in] input tensor (e.g. the pair representation)
        Returns:
            [*, I, J, C_in] output tensor
        """
        if mask is None:
            # [*, I, J]
            mask = x.new_ones(
                x.shape[:-1],
            )

        if(not self.starting):
            x = x.transpose(-2, -3)
            mask = mask.transpose(-1, -2)

        # [*, I, J, C_in]
        x = self.layer_norm(x)

        # [*, I, 1, 1, J]
        mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]

        # [*, H, I, J]
        triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))

        # [*, 1, H, I, J]
        triangle_bias = triangle_bias.unsqueeze(-4)

        biases = [mask_bias, triangle_bias]

        x = self.mha(
            q_x=x,
            kv_x=x,
            biases=biases,
            use_memory_efficient_kernel=use_memory_efficient_kernel,
            use_lma=use_lma
        )

        if(not self.starting):
            x = x.transpose(-2, -3)

        return x
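A shape-check sketch, not part of the diff; the dimensions are chosen arbitrarily:

import torch
from dockformer.model.triangular_attention import TriangleAttention

attn = TriangleAttention(c_in=32, c_hidden=16, no_heads=4, starting=True)
z = torch.rand(1, 10, 10, 32)   # [*, I, J, C_in] pair representation
out = attn(z)                   # mask defaults to all-ones
assert out.shape == z.shape     # attention preserves the pair shape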
dockformer/model/triangular_multiplicative_update.py
ADDED
@@ -0,0 +1,173 @@
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partialmethod
from typing import Optional
from abc import ABC, abstractmethod

import torch
import torch.nn as nn

from dockformer.model.primitives import Linear, LayerNorm
from dockformer.utils.precision_utils import is_fp16_enabled
from dockformer.utils.tensor_utils import permute_final_dims


class BaseTriangleMultiplicativeUpdate(nn.Module, ABC):
    """
    Implements Algorithms 11 and 12.
    """
    @abstractmethod
    def __init__(self, c_z, c_hidden, _outgoing):
        """
        Args:
            c_z:
                Input channel dimension
            c:
                Hidden channel dimension
        """
        super(BaseTriangleMultiplicativeUpdate, self).__init__()
        self.c_z = c_z
        self.c_hidden = c_hidden
        self._outgoing = _outgoing

        self.linear_g = Linear(self.c_z, self.c_z, init="gating")
        self.linear_z = Linear(self.c_hidden, self.c_z, init="final")

        self.layer_norm_in = LayerNorm(self.c_z)
        self.layer_norm_out = LayerNorm(self.c_hidden)

        self.sigmoid = nn.Sigmoid()

    def _combine_projections(self,
        a: torch.Tensor,
        b: torch.Tensor,
    ) -> torch.Tensor:
        if(self._outgoing):
            a = permute_final_dims(a, (2, 0, 1))
            b = permute_final_dims(b, (2, 1, 0))
        else:
            a = permute_final_dims(a, (2, 1, 0))
            b = permute_final_dims(b, (2, 0, 1))

        p = torch.matmul(a, b)

        return permute_final_dims(p, (1, 2, 0))

    @abstractmethod
    def forward(self,
        z: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        inplace_safe: bool = False,
        _add_with_inplace: bool = False
    ) -> torch.Tensor:
        """
        Args:
            x:
                [*, N_res, N_res, C_z] input tensor
            mask:
                [*, N_res, N_res] input mask
        Returns:
            [*, N_res, N_res, C_z] output tensor
        """
        pass


class TriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
    """
    Implements Algorithms 11 and 12.
    """
    def __init__(self, c_z, c_hidden, _outgoing=True):
        """
        Args:
            c_z:
                Input channel dimension
            c:
                Hidden channel dimension
        """
        super(TriangleMultiplicativeUpdate, self).__init__(c_z=c_z,
                                                           c_hidden=c_hidden,
                                                           _outgoing=_outgoing)

        self.linear_a_p = Linear(self.c_z, self.c_hidden)
        self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating")
        self.linear_b_p = Linear(self.c_z, self.c_hidden)
        self.linear_b_g = Linear(self.c_z, self.c_hidden, init="gating")

    def forward(self,
        z: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        inplace_safe: bool = False,
        _add_with_inplace: bool = False,
    ) -> torch.Tensor:
        """
        Args:
            x:
                [*, N_res, N_res, C_z] input tensor
            mask:
                [*, N_res, N_res] input mask
        Returns:
            [*, N_res, N_res, C_z] output tensor
        """

        if mask is None:
            mask = z.new_ones(z.shape[:-1])

        mask = mask.unsqueeze(-1)

        z = self.layer_norm_in(z)
        a = mask
        a = a * self.sigmoid(self.linear_a_g(z))
        a = a * self.linear_a_p(z)
        b = mask
        b = b * self.sigmoid(self.linear_b_g(z))
        b = b * self.linear_b_p(z)

        # Prevents overflow of torch.matmul in combine projections in
        # reduced-precision modes
        a_std = a.std()
        b_std = b.std()
        if(is_fp16_enabled() and a_std != 0. and b_std != 0.):
            a = a / a.std()
            b = b / b.std()

        if(is_fp16_enabled()):
            with torch.cuda.amp.autocast(enabled=False):
                x = self._combine_projections(a.float(), b.float())
        else:
            x = self._combine_projections(a, b)

        del a, b
        x = self.layer_norm_out(x)
        x = self.linear_z(x)
        g = self.sigmoid(self.linear_g(z))
        x = x * g

        return x


class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate):
    """
    Implements Algorithm 11.
    """
    __init__ = partialmethod(TriangleMultiplicativeUpdate.__init__, _outgoing=True)


class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
    """
    Implements Algorithm 12.
    """
    __init__ = partialmethod(TriangleMultiplicativeUpdate.__init__, _outgoing=False)
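A brief sketch, not part of the diff: the two variants differ only in how the edge projections are contracted (Algorithm 11 vs. 12); both preserve the pair shape. Dimensions below are arbitrary.

import torch
from dockformer.model.triangular_multiplicative_update import (
    TriangleMultiplicationOutgoing,
    TriangleMultiplicationIncoming,
)

z = torch.rand(1, 10, 10, 32)                    # [*, N_res, N_res, C_z]
out = TriangleMultiplicationOutgoing(32, 64)(z)  # "outgoing" edge update
inc = TriangleMultiplicationIncoming(32, 64)(z)  # "incoming" edge update
assert out.shape == inc.shape == z.shape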
dockformer/resources/__init__.py
ADDED
File without changes
dockformer/resources/stereo_chemical_props.txt
ADDED
@@ -0,0 +1,345 @@
Bond Residue Mean StdDev
CA-CB ALA 1.520 0.021
N-CA ALA 1.459 0.020
CA-C ALA 1.525 0.026
C-O ALA 1.229 0.019
CA-CB ARG 1.535 0.022
CB-CG ARG 1.521 0.027
CG-CD ARG 1.515 0.025
CD-NE ARG 1.460 0.017
NE-CZ ARG 1.326 0.013
CZ-NH1 ARG 1.326 0.013
CZ-NH2 ARG 1.326 0.013
N-CA ARG 1.459 0.020
CA-C ARG 1.525 0.026
C-O ARG 1.229 0.019
CA-CB ASN 1.527 0.026
CB-CG ASN 1.506 0.023
CG-OD1 ASN 1.235 0.022
CG-ND2 ASN 1.324 0.025
N-CA ASN 1.459 0.020
CA-C ASN 1.525 0.026
C-O ASN 1.229 0.019
CA-CB ASP 1.535 0.022
CB-CG ASP 1.513 0.021
CG-OD1 ASP 1.249 0.023
CG-OD2 ASP 1.249 0.023
N-CA ASP 1.459 0.020
CA-C ASP 1.525 0.026
C-O ASP 1.229 0.019
CA-CB CYS 1.526 0.013
CB-SG CYS 1.812 0.016
N-CA CYS 1.459 0.020
CA-C CYS 1.525 0.026
C-O CYS 1.229 0.019
CA-CB GLU 1.535 0.022
CB-CG GLU 1.517 0.019
CG-CD GLU 1.515 0.015
CD-OE1 GLU 1.252 0.011
CD-OE2 GLU 1.252 0.011
N-CA GLU 1.459 0.020
CA-C GLU 1.525 0.026
C-O GLU 1.229 0.019
CA-CB GLN 1.535 0.022
CB-CG GLN 1.521 0.027
CG-CD GLN 1.506 0.023
CD-OE1 GLN 1.235 0.022
CD-NE2 GLN 1.324 0.025
N-CA GLN 1.459 0.020
CA-C GLN 1.525 0.026
C-O GLN 1.229 0.019
N-CA GLY 1.456 0.015
CA-C GLY 1.514 0.016
C-O GLY 1.232 0.016
CA-CB HIS 1.535 0.022
CB-CG HIS 1.492 0.016
CG-ND1 HIS 1.369 0.015
CG-CD2 HIS 1.353 0.017
ND1-CE1 HIS 1.343 0.025
CD2-NE2 HIS 1.415 0.021
CE1-NE2 HIS 1.322 0.023
N-CA HIS 1.459 0.020
CA-C HIS 1.525 0.026
C-O HIS 1.229 0.019
CA-CB ILE 1.544 0.023
CB-CG1 ILE 1.536 0.028
CB-CG2 ILE 1.524 0.031
CG1-CD1 ILE 1.500 0.069
N-CA ILE 1.459 0.020
CA-C ILE 1.525 0.026
C-O ILE 1.229 0.019
CA-CB LEU 1.533 0.023
CB-CG LEU 1.521 0.029
CG-CD1 LEU 1.514 0.037
CG-CD2 LEU 1.514 0.037
N-CA LEU 1.459 0.020
CA-C LEU 1.525 0.026
C-O LEU 1.229 0.019
CA-CB LYS 1.535 0.022
CB-CG LYS 1.521 0.027
CG-CD LYS 1.520 0.034
CD-CE LYS 1.508 0.025
CE-NZ LYS 1.486 0.025
N-CA LYS 1.459 0.020
CA-C LYS 1.525 0.026
C-O LYS 1.229 0.019
CA-CB MET 1.535 0.022
CB-CG MET 1.509 0.032
CG-SD MET 1.807 0.026
SD-CE MET 1.774 0.056
N-CA MET 1.459 0.020
CA-C MET 1.525 0.026
C-O MET 1.229 0.019
CA-CB PHE 1.535 0.022
CB-CG PHE 1.509 0.017
CG-CD1 PHE 1.383 0.015
CG-CD2 PHE 1.383 0.015
CD1-CE1 PHE 1.388 0.020
CD2-CE2 PHE 1.388 0.020
CE1-CZ PHE 1.369 0.019
CE2-CZ PHE 1.369 0.019
N-CA PHE 1.459 0.020
CA-C PHE 1.525 0.026
C-O PHE 1.229 0.019
CA-CB PRO 1.531 0.020
CB-CG PRO 1.495 0.050
CG-CD PRO 1.502 0.033
CD-N PRO 1.474 0.014
N-CA PRO 1.468 0.017
CA-C PRO 1.524 0.020
C-O PRO 1.228 0.020
CA-CB SER 1.525 0.015
CB-OG SER 1.418 0.013
N-CA SER 1.459 0.020
CA-C SER 1.525 0.026
C-O SER 1.229 0.019
CA-CB THR 1.529 0.026
CB-OG1 THR 1.428 0.020
CB-CG2 THR 1.519 0.033
N-CA THR 1.459 0.020
CA-C THR 1.525 0.026
C-O THR 1.229 0.019
CA-CB TRP 1.535 0.022
CB-CG TRP 1.498 0.018
CG-CD1 TRP 1.363 0.014
CG-CD2 TRP 1.432 0.017
CD1-NE1 TRP 1.375 0.017
NE1-CE2 TRP 1.371 0.013
CD2-CE2 TRP 1.409 0.012
CD2-CE3 TRP 1.399 0.015
CE2-CZ2 TRP 1.393 0.017
CE3-CZ3 TRP 1.380 0.017
CZ2-CH2 TRP 1.369 0.019
CZ3-CH2 TRP 1.396 0.016
N-CA TRP 1.459 0.020
CA-C TRP 1.525 0.026
C-O TRP 1.229 0.019
CA-CB TYR 1.535 0.022
CB-CG TYR 1.512 0.015
CG-CD1 TYR 1.387 0.013
CG-CD2 TYR 1.387 0.013
CD1-CE1 TYR 1.389 0.015
CD2-CE2 TYR 1.389 0.015
CE1-CZ TYR 1.381 0.013
CE2-CZ TYR 1.381 0.013
CZ-OH TYR 1.374 0.017
N-CA TYR 1.459 0.020
CA-C TYR 1.525 0.026
C-O TYR 1.229 0.019
CA-CB VAL 1.543 0.021
CB-CG1 VAL 1.524 0.021
CB-CG2 VAL 1.524 0.021
N-CA VAL 1.459 0.020
CA-C VAL 1.525 0.026
C-O VAL 1.229 0.019
-

Angle Residue Mean StdDev
N-CA-CB ALA 110.1 1.4
CB-CA-C ALA 110.1 1.5
N-CA-C ALA 111.0 2.7
CA-C-O ALA 120.1 2.1
N-CA-CB ARG 110.6 1.8
CB-CA-C ARG 110.4 2.0
CA-CB-CG ARG 113.4 2.2
CB-CG-CD ARG 111.6 2.6
CG-CD-NE ARG 111.8 2.1
CD-NE-CZ ARG 123.6 1.4
NE-CZ-NH1 ARG 120.3 0.5
NE-CZ-NH2 ARG 120.3 0.5
NH1-CZ-NH2 ARG 119.4 1.1
N-CA-C ARG 111.0 2.7
CA-C-O ARG 120.1 2.1
N-CA-CB ASN 110.6 1.8
CB-CA-C ASN 110.4 2.0
CA-CB-CG ASN 113.4 2.2
CB-CG-ND2 ASN 116.7 2.4
CB-CG-OD1 ASN 121.6 2.0
ND2-CG-OD1 ASN 121.9 2.3
N-CA-C ASN 111.0 2.7
CA-C-O ASN 120.1 2.1
N-CA-CB ASP 110.6 1.8
CB-CA-C ASP 110.4 2.0
CA-CB-CG ASP 113.4 2.2
CB-CG-OD1 ASP 118.3 0.9
CB-CG-OD2 ASP 118.3 0.9
OD1-CG-OD2 ASP 123.3 1.9
N-CA-C ASP 111.0 2.7
CA-C-O ASP 120.1 2.1
N-CA-CB CYS 110.8 1.5
CB-CA-C CYS 111.5 1.2
CA-CB-SG CYS 114.2 1.1
N-CA-C CYS 111.0 2.7
CA-C-O CYS 120.1 2.1
N-CA-CB GLU 110.6 1.8
CB-CA-C GLU 110.4 2.0
CA-CB-CG GLU 113.4 2.2
CB-CG-CD GLU 114.2 2.7
CG-CD-OE1 GLU 118.3 2.0
CG-CD-OE2 GLU 118.3 2.0
OE1-CD-OE2 GLU 123.3 1.2
N-CA-C GLU 111.0 2.7
CA-C-O GLU 120.1 2.1
N-CA-CB GLN 110.6 1.8
CB-CA-C GLN 110.4 2.0
CA-CB-CG GLN 113.4 2.2
CB-CG-CD GLN 111.6 2.6
CG-CD-OE1 GLN 121.6 2.0
CG-CD-NE2 GLN 116.7 2.4
OE1-CD-NE2 GLN 121.9 2.3
N-CA-C GLN 111.0 2.7
CA-C-O GLN 120.1 2.1
N-CA-C GLY 113.1 2.5
CA-C-O GLY 120.6 1.8
N-CA-CB HIS 110.6 1.8
CB-CA-C HIS 110.4 2.0
CA-CB-CG HIS 113.6 1.7
CB-CG-ND1 HIS 123.2 2.5
CB-CG-CD2 HIS 130.8 3.1
CG-ND1-CE1 HIS 108.2 1.4
ND1-CE1-NE2 HIS 109.9 2.2
CE1-NE2-CD2 HIS 106.6 2.5
NE2-CD2-CG HIS 109.2 1.9
CD2-CG-ND1 HIS 106.0 1.4
N-CA-C HIS 111.0 2.7
CA-C-O HIS 120.1 2.1
N-CA-CB ILE 110.8 2.3
CB-CA-C ILE 111.6 2.0
CA-CB-CG1 ILE 111.0 1.9
CB-CG1-CD1 ILE 113.9 2.8
CA-CB-CG2 ILE 110.9 2.0
CG1-CB-CG2 ILE 111.4 2.2
N-CA-C ILE 111.0 2.7
CA-C-O ILE 120.1 2.1
N-CA-CB LEU 110.4 2.0
CB-CA-C LEU 110.2 1.9
CA-CB-CG LEU 115.3 2.3
CB-CG-CD1 LEU 111.0 1.7
CB-CG-CD2 LEU 111.0 1.7
CD1-CG-CD2 LEU 110.5 3.0
N-CA-C LEU 111.0 2.7
CA-C-O LEU 120.1 2.1
N-CA-CB LYS 110.6 1.8
CB-CA-C LYS 110.4 2.0
CA-CB-CG LYS 113.4 2.2
CB-CG-CD LYS 111.6 2.6
CG-CD-CE LYS 111.9 3.0
CD-CE-NZ LYS 111.7 2.3
N-CA-C LYS 111.0 2.7
CA-C-O LYS 120.1 2.1
N-CA-CB MET 110.6 1.8
CB-CA-C MET 110.4 2.0
CA-CB-CG MET 113.3 1.7
CB-CG-SD MET 112.4 3.0
CG-SD-CE MET 100.2 1.6
N-CA-C MET 111.0 2.7
CA-C-O MET 120.1 2.1
N-CA-CB PHE 110.6 1.8
CB-CA-C PHE 110.4 2.0
CA-CB-CG PHE 113.9 2.4
CB-CG-CD1 PHE 120.8 0.7
CB-CG-CD2 PHE 120.8 0.7
CD1-CG-CD2 PHE 118.3 1.3
CG-CD1-CE1 PHE 120.8 1.1
CG-CD2-CE2 PHE 120.8 1.1
CD1-CE1-CZ PHE 120.1 1.2
CD2-CE2-CZ PHE 120.1 1.2
CE1-CZ-CE2 PHE 120.0 1.8
N-CA-C PHE 111.0 2.7
CA-C-O PHE 120.1 2.1
N-CA-CB PRO 103.3 1.2
CB-CA-C PRO 111.7 2.1
CA-CB-CG PRO 104.8 1.9
CB-CG-CD PRO 106.5 3.9
CG-CD-N PRO 103.2 1.5
CA-N-CD PRO 111.7 1.4
N-CA-C PRO 112.1 2.6
CA-C-O PRO 120.2 2.4
N-CA-CB SER 110.5 1.5
CB-CA-C SER 110.1 1.9
CA-CB-OG SER 111.2 2.7
N-CA-C SER 111.0 2.7
CA-C-O SER 120.1 2.1
N-CA-CB THR 110.3 1.9
CB-CA-C THR 111.6 2.7
CA-CB-OG1 THR 109.0 2.1
CA-CB-CG2 THR 112.4 1.4
OG1-CB-CG2 THR 110.0 2.3
N-CA-C THR 111.0 2.7
CA-C-O THR 120.1 2.1
N-CA-CB TRP 110.6 1.8
CB-CA-C TRP 110.4 2.0
CA-CB-CG TRP 113.7 1.9
CB-CG-CD1 TRP 127.0 1.3
CB-CG-CD2 TRP 126.6 1.3
CD1-CG-CD2 TRP 106.3 0.8
CG-CD1-NE1 TRP 110.1 1.0
CD1-NE1-CE2 TRP 109.0 0.9
NE1-CE2-CD2 TRP 107.3 1.0
CE2-CD2-CG TRP 107.3 0.8
CG-CD2-CE3 TRP 133.9 0.9
NE1-CE2-CZ2 TRP 130.4 1.1
CE3-CD2-CE2 TRP 118.7 1.2
CD2-CE2-CZ2 TRP 122.3 1.2
CE2-CZ2-CH2 TRP 117.4 1.0
CZ2-CH2-CZ3 TRP 121.6 1.2
CH2-CZ3-CE3 TRP 121.2 1.1
CZ3-CE3-CD2 TRP 118.8 1.3
N-CA-C TRP 111.0 2.7
CA-C-O TRP 120.1 2.1
N-CA-CB TYR 110.6 1.8
CB-CA-C TYR 110.4 2.0
CA-CB-CG TYR 113.4 1.9
CB-CG-CD1 TYR 121.0 0.6
CB-CG-CD2 TYR 121.0 0.6
CD1-CG-CD2 TYR 117.9 1.1
CG-CD1-CE1 TYR 121.3 0.8
CG-CD2-CE2 TYR 121.3 0.8
CD1-CE1-CZ TYR 119.8 0.9
CD2-CE2-CZ TYR 119.8 0.9
CE1-CZ-CE2 TYR 119.8 1.6
CE1-CZ-OH TYR 120.1 2.7
CE2-CZ-OH TYR 120.1 2.7
N-CA-C TYR 111.0 2.7
CA-C-O TYR 120.1 2.1
N-CA-CB VAL 111.5 2.2
CB-CA-C VAL 111.4 1.9
CA-CB-CG1 VAL 110.9 1.5
CA-CB-CG2 VAL 110.9 1.5
CG1-CB-CG2 VAL 110.9 1.6
N-CA-C VAL 111.0 2.7
CA-C-O VAL 120.1 2.1
-

Non-bonded distance Minimum Dist Tolerance
C-C 3.4 1.5
C-N 3.25 1.5
C-S 3.5 1.5
C-O 3.22 1.5
N-N 3.1 1.5
N-S 3.35 1.5
N-O 3.07 1.5
O-S 3.32 1.5
O-O 3.04 1.5
S-S 2.03 1.0
-
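The resource above has three whitespace-delimited sections (bonds, angles, non-bonded distances), each ended by a lone "-". A minimal parser sketch for this layout, not part of the diff; the function name is illustrative:

def load_stereo_chemical_props(path):
    """Parse the three '-'-terminated sections of stereo_chemical_props.txt."""
    sections, current = [], []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if line == "-":          # section terminator
                sections.append(current)
                current = []
            elif line:               # skip blank separator lines
                current.append(line.split())
    bonds, angles, nonbonded = sections
    # Drop each section's header row, e.g. "Bond Residue Mean StdDev"
    return bonds[1:], angles[1:], nonbonded[1:]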
dockformer/utils/__init__.py
ADDED
File without changes
dockformer/utils/callbacks.py
ADDED
@@ -0,0 +1,15 @@
from lightning.pytorch.callbacks import EarlyStopping
from lightning_utilities.core.rank_zero import rank_zero_info


class EarlyStoppingVerbose(EarlyStopping):
    """
    The default EarlyStopping callback's verbose mode is too verbose.
    This class outputs a message only when it's getting ready to stop.
    """
    # Overrides the EarlyStopping hook of the same name.
    def _evaluate_stopping_criteria(self, *args, **kwargs):
        should_stop, reason = super()._evaluate_stopping_criteria(*args, **kwargs)
        if(should_stop):
            rank_zero_info(f"{reason}\n")

        return should_stop, reason
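A usage sketch, not part of the diff; the monitored metric name is an assumption:

import lightning.pytorch as pl
from dockformer.utils.callbacks import EarlyStoppingVerbose

# Drop-in replacement for the stock EarlyStopping callback.
trainer = pl.Trainer(
    callbacks=[EarlyStoppingVerbose(monitor="val/loss", patience=5, mode="min")],
)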
dockformer/utils/checkpointing.py
ADDED
@@ -0,0 +1,78 @@
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from typing import Any, Tuple, List, Callable, Optional


import torch
import torch.utils.checkpoint


BLOCK_ARG = Any
BLOCK_ARGS = List[BLOCK_ARG]


@torch.jit.ignore
def checkpoint_blocks(
    blocks: List[Callable],
    args: BLOCK_ARGS,
    blocks_per_ckpt: Optional[int],
) -> BLOCK_ARGS:
    """
    Chunk a list of blocks and run each chunk with activation
    checkpointing. We define a "block" as a callable whose only inputs are
    the outputs of the previous block.

    Implements Subsection 1.11.8

    Args:
        blocks:
            List of blocks
        args:
            Tuple of arguments for the first block.
        blocks_per_ckpt:
            Size of each chunk. A higher value corresponds to fewer
            checkpoints, and trades memory for speed. If None, no checkpointing
            is performed.
    Returns:
        The output of the final block
    """
    def wrap(a):
        return (a,) if type(a) is not tuple else a

    def exec(b, a):
        for block in b:
            a = wrap(block(*a))
        return a

    def chunker(s, e):
        def exec_sliced(*a):
            return exec(blocks[s:e], a)

        return exec_sliced

    # Avoids mishaps when the blocks take just one argument
    args = wrap(args)

    if blocks_per_ckpt is None or not torch.is_grad_enabled():
        return exec(blocks, args)
    elif blocks_per_ckpt < 1 or blocks_per_ckpt > len(blocks):
        raise ValueError("blocks_per_ckpt must be between 1 and len(blocks)")

    for s in range(0, len(blocks), blocks_per_ckpt):
        e = s + blocks_per_ckpt
        args = torch.utils.checkpoint.checkpoint(chunker(s, e), *args, use_reentrant=True)
        args = wrap(args)

    return args
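A usage sketch, not part of the diff: run a stack of callables with one checkpoint every two blocks. The layer stack is illustrative.

import torch
import torch.nn as nn
from dockformer.utils.checkpointing import checkpoint_blocks

layers = nn.ModuleList([nn.Linear(64, 64) for _ in range(8)])
# Each block takes the previous block's outputs and returns a tensor.
blocks = [lambda x, l=l: torch.relu(l(x)) for l in layers]

x = torch.rand(4, 64, requires_grad=True)
(out,) = checkpoint_blocks(blocks, (x,), blocks_per_ckpt=2)
out.sum().backward()  # activations inside each chunk are recomputed here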
dockformer/utils/config_tools.py
ADDED
@@ -0,0 +1,32 @@
import importlib.util

import ml_collections as mlc


def set_inf(c, inf):
    for k, v in c.items():
        if isinstance(v, mlc.ConfigDict):
            set_inf(v, inf)
        elif k == "inf":
            c[k] = inf


def enforce_config_constraints(config):
    def string_to_setting(s):
        path = s.split('.')
        setting = config
        for p in path:
            setting = setting.get(p)

        return setting

    mutually_exclusive_bools = [
        (
            "globals.use_lma",
        ),
    ]

    for options in mutually_exclusive_bools:
        option_settings = [string_to_setting(o) for o in options]
        if sum(option_settings) > 1:
            raise ValueError(f"Only one of {', '.join(options)} may be set at a time")
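A brief sketch of set_inf, not part of the diff: it rewrites every "inf" entry in a nested ConfigDict, e.g. to use a smaller sentinel under reduced precision. The config contents are illustrative.

import ml_collections as mlc
from dockformer.utils.config_tools import set_inf

c = mlc.ConfigDict({"model": {"attention": {"inf": 1e9}}, "globals": {"use_lma": False}})
set_inf(c, 1e4)                      # recurses into nested ConfigDicts
assert c.model.attention.inf == 1e4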
dockformer/utils/consts.py
ADDED
@@ -0,0 +1,25 @@
from rdkit.Chem.rdchem import ChiralType, BondType

# Survey of atom types in the PDBBind dataset
# {'C': 403253, 'O': 101283, 'N': 81325, 'S': 6262, 'F': 5256, 'P': 3378, 'Cl': 2920, 'Br': 552, 'B': 237, 'I': 185,
#  'H': 181, 'Fe': 19, 'Se': 15, 'Ru': 10, 'Si': 5, 'Co': 4, 'Ir': 4, 'As': 2, 'Pt': 2, 'V': 1, 'Mg': 1, 'Be': 1,
#  'Rh': 1, 'Cu': 1, 'Re': 1}
# I have changed the uncommon types to common ions for the plinder dataset
# {'As': "Zn", 'Pt': "Mn", 'V': "Ca", 'Mg': "Mg", 'Be': "Na", 'Rh': "Al", 'Cu': "K", 'Re': "Ni"}

POSSIBLE_ATOM_TYPES = ['C', 'O', 'N', 'S', 'F', 'P', 'Cl', 'Br', 'B', 'I', 'H', 'Fe', 'Se', 'Ru', 'Si', 'Co', 'Ir',
                       'Zn', 'Mn', 'Ca', 'Mg', 'Na', 'Al', 'K', 'Ni']

# bonds Counter({BondType.SINGLE: 366857, BondType.AROMATIC: 214238, BondType.DOUBLE: 59725, BondType.TRIPLE: 866,
#                BondType.UNSPECIFIED: 18, BondType.DATIVE: 8})
POSSIBLE_BOND_TYPES = [BondType.SINGLE, BondType.DOUBLE, BondType.TRIPLE, BondType.AROMATIC, BondType.UNSPECIFIED,
                       BondType.DATIVE]

# {0: 580061, 1: 13273, -1: 11473, 2: 44, 7: 17, -2: 8, 9: 7, 10: 7, 5: 3, 3: 3, 4: 1, 6: 1, 8: 1}
POSSIBLE_CHARGES = [-1, 0, 1]

# {ChiralType.CHI_UNSPECIFIED: 551374, ChiralType.CHI_TETRAHEDRAL_CCW: 27328, ChiralType.CHI_TETRAHEDRAL_CW: 26178,
#  ChiralType.CHI_OCTAHEDRAL: 13, ChiralType.CHI_SQUAREPLANAR: 3, ChiralType.CHI_TRIGONALBIPYRAMIDAL: 3}
POSSIBLE_CHIRALITIES = [ChiralType.CHI_UNSPECIFIED, ChiralType.CHI_TETRAHEDRAL_CCW, ChiralType.CHI_TETRAHEDRAL_CW,
                        ChiralType.CHI_OCTAHEDRAL, ChiralType.CHI_SQUAREPLANAR, ChiralType.CHI_TRIGONALBIPYRAMIDAL]
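A featurization sketch, not part of the diff; the helper name is illustrative. It maps an RDKit atom to its index in the vocabulary above:

from rdkit import Chem
from dockformer.utils.consts import POSSIBLE_ATOM_TYPES

def atom_type_index(atom: Chem.Atom) -> int:
    # Raises ValueError for symbols outside the 25-type vocabulary.
    return POSSIBLE_ATOM_TYPES.index(atom.GetSymbol())

mol = Chem.MolFromSmiles("CCO")
print([atom_type_index(a) for a in mol.GetAtoms()])  # [0, 0, 1]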
dockformer/utils/exponential_moving_average.py
ADDED
@@ -0,0 +1,71 @@
from collections import OrderedDict
import copy
import torch
import torch.nn as nn

from dockformer.utils.tensor_utils import tensor_tree_map


class ExponentialMovingAverage:
    """
    Maintains moving averages of parameters with exponential decay

    At each step, the stored copy `copy` of each parameter `param` is
    updated as follows:

        `copy = decay * copy + (1 - decay) * param`

    where `decay` is an attribute of the ExponentialMovingAverage object.
    """

    def __init__(self, model: nn.Module, decay: float):
        """
        Args:
            model:
                A torch.nn.Module whose parameters are to be tracked
            decay:
                A value (usually close to 1.) by which updates are
                weighted as part of the above formula
        """
        super(ExponentialMovingAverage, self).__init__()

        clone_param = lambda t: t.clone().detach()
        self.params = tensor_tree_map(clone_param, model.state_dict())
        self.decay = decay
        self.device = next(model.parameters()).device

    def to(self, device):
        self.params = tensor_tree_map(lambda t: t.to(device), self.params)
        self.device = device

    def _update_state_dict_(self, update, state_dict):
        with torch.no_grad():
            for k, v in update.items():
                stored = state_dict[k]
                if not isinstance(v, torch.Tensor):
                    self._update_state_dict_(v, stored)
                else:
                    diff = stored - v
                    diff *= 1 - self.decay
                    stored -= diff

    def update(self, model: torch.nn.Module) -> None:
        """
        Updates the stored parameters using the state dict of the provided
        module. The module should have the same structure as that used to
        initialize the ExponentialMovingAverage object.
        """
        self._update_state_dict_(model.state_dict(), self.params)

    def load_state_dict(self, state_dict: OrderedDict) -> None:
        for k in state_dict["params"].keys():
            self.params[k] = state_dict["params"][k].clone()
        self.decay = state_dict["decay"]

    def state_dict(self) -> OrderedDict:
        return OrderedDict(
            {
                "params": self.params,
                "decay": self.decay,
            }
        )
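A training-loop sketch, not part of the diff: refresh the shadow weights after each optimizer step, then swap them in for evaluation. The tiny model and loss are illustrative.

import torch
import torch.nn as nn
from dockformer.utils.exponential_moving_average import ExponentialMovingAverage

model = nn.Linear(8, 8)
ema = ExponentialMovingAverage(model, decay=0.999)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

for _ in range(10):
    loss = model(torch.rand(4, 8)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema.update(model)    # copy = decay * copy + (1 - decay) * param

# Evaluate with the smoothed weights.
model.load_state_dict(ema.state_dict()["params"])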
dockformer/utils/feats.py
ADDED
@@ -0,0 +1,174 @@
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import numpy as np
import torch
import torch.nn as nn
from typing import Dict, Union

from dockformer.utils import protein
import dockformer.utils.residue_constants as rc
from dockformer.utils.geometry import rigid_matrix_vector, rotation_matrix, vector
from dockformer.utils.rigid_utils import Rotation, Rigid
from dockformer.utils.tensor_utils import (
    batched_gather,
    one_hot,
    tree_map,
    tensor_tree_map,
)


def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
    # rc.restype_order["Z"] defines a ligand, and the atom position used is the CA
    is_gly_or_lig = (aatype == rc.restype_order["G"]) | (aatype == rc.restype_order["Z"])
    ca_idx = rc.atom_order["CA"]
    cb_idx = rc.atom_order["CB"]
    pseudo_beta = torch.where(
        is_gly_or_lig[..., None].expand(*((-1,) * len(is_gly_or_lig.shape)), 3),
        all_atom_positions[..., ca_idx, :],
        all_atom_positions[..., cb_idx, :],
    )

    if all_atom_masks is not None:
        pseudo_beta_mask = torch.where(
            is_gly_or_lig,
            all_atom_masks[..., ca_idx],
            all_atom_masks[..., cb_idx],
        )
        return pseudo_beta, pseudo_beta_mask
    else:
        return pseudo_beta


def atom14_to_atom37(atom14, batch):
    atom37_data = batched_gather(
        atom14,
        batch["residx_atom37_to_atom14"],
        dim=-2,
        no_batch_dims=len(atom14.shape[:-2]),
    )

    atom37_data = atom37_data * batch["atom37_atom_exists"][..., None]

    return atom37_data


def torsion_angles_to_frames(
    r: Union[Rigid, rigid_matrix_vector.Rigid3Array],
    alpha: torch.Tensor,
    aatype: torch.Tensor,
    rrgdf: torch.Tensor,
):

    rigid_type = type(r)

    # [*, N, 8, 4, 4]
    default_4x4 = rrgdf[aatype, ...]

    # [*, N, 8] transformations, i.e.
    #   One [*, N, 8, 3, 3] rotation matrix and
    #   One [*, N, 8, 3] translation matrix
    default_r = rigid_type.from_tensor_4x4(default_4x4)

    bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
    bb_rot[..., 1] = 1

    # [*, N, 8, 2]
    alpha = torch.cat(
        [bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2
    )

    # [*, N, 8, 3, 3]
    # Produces rotation matrices of the form:
    # [
    #   [1, 0  , 0  ],
    #   [0, a_2,-a_1],
    #   [0, a_1, a_2]
    # ]
    # This follows the original code rather than the supplement, which uses
    # different indices.

    all_rots = alpha.new_zeros(default_r.shape + (4, 4))
    all_rots[..., 0, 0] = 1
    all_rots[..., 1, 1] = alpha[..., 1]
    all_rots[..., 1, 2] = -alpha[..., 0]
    all_rots[..., 2, 1:3] = alpha

    all_rots = rigid_type.from_tensor_4x4(all_rots)
    all_frames = default_r.compose(all_rots)

    chi2_frame_to_frame = all_frames[..., 5]
    chi3_frame_to_frame = all_frames[..., 6]
    chi4_frame_to_frame = all_frames[..., 7]

    chi1_frame_to_bb = all_frames[..., 4]
    chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame)
    chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
    chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)

    all_frames_to_bb = rigid_type.cat(
        [
            all_frames[..., :5],
            chi2_frame_to_bb.unsqueeze(-1),
            chi3_frame_to_bb.unsqueeze(-1),
            chi4_frame_to_bb.unsqueeze(-1),
        ],
        dim=-1,
    )

    all_frames_to_global = r[..., None].compose(all_frames_to_bb)

    return all_frames_to_global


def frames_and_literature_positions_to_atom14_pos(
    r: Union[Rigid, rigid_matrix_vector.Rigid3Array],
    aatype: torch.Tensor,
    default_frames,
    group_idx,
    atom_mask,
    lit_positions,
):
    # [*, N, 14, 4, 4]
    default_4x4 = default_frames[aatype, ...]

    # [*, N, 14]
    group_mask = group_idx[aatype, ...]

    # [*, N, 14, 8]
    group_mask = nn.functional.one_hot(
        group_mask,
        num_classes=default_frames.shape[-3],
    )

    # [*, N, 14, 8]
    t_atoms_to_global = r[..., None, :] * group_mask

    # [*, N, 14]
    t_atoms_to_global = t_atoms_to_global.map_tensor_fn(
        lambda x: torch.sum(x, dim=-1)
    )

    # [*, N, 14, 1]
    atom_mask = atom_mask[aatype, ...].unsqueeze(-1)

    # [*, N, 14, 3]
    lit_positions = lit_positions[aatype, ...]
    pred_positions = t_atoms_to_global.apply(lit_positions)
    pred_positions = pred_positions * atom_mask

    return pred_positions
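A shape sketch for pseudo_beta_fn, not part of the diff; inputs are random and 37 is the atom37 layout:

import torch
from dockformer.utils.feats import pseudo_beta_fn

aatype = torch.randint(0, 20, (1, 16))           # [*, N] residue types
pos = torch.rand(1, 16, 37, 3)                   # [*, N, 37, 3] atom37 positions
mask = torch.ones(1, 16, 37)                     # [*, N, 37] atom mask
pb, pb_mask = pseudo_beta_fn(aatype, pos, mask)  # [*, N, 3] CB (or CA for GLY/ligand), [*, N]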
dockformer/utils/geometry/__init__.py
ADDED
@@ -0,0 +1,28 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Geometry Module."""

from dockformer.utils.geometry import rigid_matrix_vector
from dockformer.utils.geometry import rotation_matrix
from dockformer.utils.geometry import vector

Rot3Array = rotation_matrix.Rot3Array
Rigid3Array = rigid_matrix_vector.Rigid3Array

Vec3Array = vector.Vec3Array
square_euclidean_distance = vector.square_euclidean_distance
euclidean_distance = vector.euclidean_distance
dihedral_angle = vector.dihedral_angle
dot = vector.dot
cross = vector.cross
dockformer/utils/geometry/quat_rigid.py
ADDED
@@ -0,0 +1,38 @@
import torch
import torch.nn as nn

from dockformer.model.primitives import Linear
from dockformer.utils.geometry.rigid_matrix_vector import Rigid3Array
from dockformer.utils.geometry.rotation_matrix import Rot3Array
from dockformer.utils.geometry.vector import Vec3Array


class QuatRigid(nn.Module):
    def __init__(self, c_hidden, full_quat):
        super().__init__()
        self.full_quat = full_quat
        if self.full_quat:
            rigid_dim = 7
        else:
            rigid_dim = 6

        self.linear = Linear(c_hidden, rigid_dim, init="final", precision=torch.float32)

    def forward(self, activations: torch.Tensor) -> Rigid3Array:
        # NOTE: During training, this needs to be run in higher precision
        rigid_flat = self.linear(activations)

        rigid_flat = torch.unbind(rigid_flat, dim=-1)
        if(self.full_quat):
            qw, qx, qy, qz = rigid_flat[:4]
            translation = rigid_flat[4:]
        else:
            qx, qy, qz = rigid_flat[:3]
            qw = torch.ones_like(qx)
            translation = rigid_flat[3:]

        rotation = Rot3Array.from_quaternion(
            qw, qx, qy, qz, normalize=True,
        )
        translation = Vec3Array(*translation)
        return Rigid3Array(rotation, translation)
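A brief sketch, not part of the diff: QuatRigid turns per-residue activations into rigid transforms. Dimensions are arbitrary.

import torch
from dockformer.utils.geometry.quat_rigid import QuatRigid

head = QuatRigid(c_hidden=128, full_quat=False)
s = torch.rand(1, 16, 128)   # [*, N, c_hidden] single representation
rigids = head(s)             # Rigid3Array with batch shape [1, 16]
print(rigids.shape)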
dockformer/utils/geometry/rigid_matrix_vector.py
ADDED
@@ -0,0 +1,181 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rigid3Array Transformations represented by a Matrix and a Vector."""

from __future__ import annotations
import dataclasses
from typing import Union, List

import torch

from dockformer.utils.geometry import rotation_matrix
from dockformer.utils.geometry import vector


Float = Union[float, torch.Tensor]


@dataclasses.dataclass(frozen=True)
class Rigid3Array:
    """Rigid Transformation, i.e. element of special euclidean group."""

    rotation: rotation_matrix.Rot3Array
    translation: vector.Vec3Array

    def __matmul__(self, other: Rigid3Array) -> Rigid3Array:
        new_rotation = self.rotation @ other.rotation  # __matmul__
        new_translation = self.apply_to_point(other.translation)
        return Rigid3Array(new_rotation, new_translation)

    def __getitem__(self, index) -> Rigid3Array:
        return Rigid3Array(
            self.rotation[index],
            self.translation[index],
        )

    def __mul__(self, other: torch.Tensor) -> Rigid3Array:
        return Rigid3Array(
            self.rotation * other,
            self.translation * other,
        )

    def map_tensor_fn(self, fn) -> Rigid3Array:
        return Rigid3Array(
            self.rotation.map_tensor_fn(fn),
            self.translation.map_tensor_fn(fn),
        )

    def inverse(self) -> Rigid3Array:
        """Return Rigid3Array corresponding to inverse transform."""
        inv_rotation = self.rotation.inverse()
        inv_translation = inv_rotation.apply_to_point(-self.translation)
        return Rigid3Array(inv_rotation, inv_translation)

    def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
        """Apply Rigid3Array transform to point."""
        return self.rotation.apply_to_point(point) + self.translation

    def apply(self, point: torch.Tensor) -> torch.Tensor:
        return self.apply_to_point(vector.Vec3Array.from_array(point)).to_tensor()

    def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
        """Apply inverse Rigid3Array transform to point."""
        new_point = point - self.translation
        return self.rotation.apply_inverse_to_point(new_point)

    def invert_apply(self, point: torch.Tensor) -> torch.Tensor:
        return self.apply_inverse_to_point(vector.Vec3Array.from_array(point)).to_tensor()

    def compose_rotation(self, other_rotation):
        rot = self.rotation @ other_rotation
        return Rigid3Array(rot, self.translation.clone())

    def compose(self, other_rigid):
        return self @ other_rigid

    def unsqueeze(self, dim: int):
        return Rigid3Array(
            self.rotation.unsqueeze(dim),
            self.translation.unsqueeze(dim),
        )

    @property
    def shape(self) -> torch.Size:
        return self.rotation.xx.shape

    @property
    def dtype(self) -> torch.dtype:
        return self.rotation.xx.dtype

    @property
    def device(self) -> torch.device:
        return self.rotation.xx.device

    @classmethod
    def identity(cls, shape, device) -> Rigid3Array:
        """Return identity Rigid3Array of given shape."""
        return cls(
            rotation_matrix.Rot3Array.identity(shape, device),
            vector.Vec3Array.zeros(shape, device)
        )

    @classmethod
    def cat(cls, rigids: List[Rigid3Array], dim: int) -> Rigid3Array:
        return cls(
            rotation_matrix.Rot3Array.cat(
                [r.rotation for r in rigids], dim=dim
            ),
            vector.Vec3Array.cat(
                [r.translation for r in rigids], dim=dim
            ),
        )

    def scale_translation(self, factor: Float) -> Rigid3Array:
        """Scale translation in Rigid3Array by 'factor'."""
        return Rigid3Array(self.rotation, self.translation * factor)

    def to_tensor(self) -> torch.Tensor:
        rot_array = self.rotation.to_tensor()
        vec_array = self.translation.to_tensor()
        array = torch.zeros(
            rot_array.shape[:-2] + (4, 4),
            device=rot_array.device,
            dtype=rot_array.dtype
        )
        array[..., :3, :3] = rot_array
        array[..., :3, 3] = vec_array
        array[..., 3, 3] = 1.
        return array

    def to_tensor_4x4(self) -> torch.Tensor:
        return self.to_tensor()

    def reshape(self, new_shape) -> Rigid3Array:
        rots = self.rotation.reshape(new_shape)
        trans = self.translation.reshape(new_shape)
        return Rigid3Array(rots, trans)

    def stop_rot_gradient(self) -> Rigid3Array:
        return Rigid3Array(
            self.rotation.stop_gradient(),
            self.translation,
        )

    @classmethod
    def from_array(cls, array):
        rot = rotation_matrix.Rot3Array.from_array(
            array[..., :3, :3],
        )
        vec = vector.Vec3Array.from_array(array[..., :3, 3])
        return cls(rot, vec)

    @classmethod
    def from_tensor_4x4(cls, array):
        return cls.from_array(array)

    @classmethod
    def from_array4x4(cls, array: torch.Tensor) -> Rigid3Array:
        """Construct Rigid3Array from homogeneous 4x4 array."""
        rotation = rotation_matrix.Rot3Array(
            array[..., 0, 0], array[..., 0, 1], array[..., 0, 2],
            array[..., 1, 0], array[..., 1, 1], array[..., 1, 2],
            array[..., 2, 0], array[..., 2, 1], array[..., 2, 2]
        )
        translation = vector.Vec3Array(
            array[..., 0, 3], array[..., 1, 3], array[..., 2, 3]
        )
        return cls(rotation, translation)

    def cuda(self) -> Rigid3Array:
        return Rigid3Array.from_tensor_4x4(self.to_tensor_4x4().cuda())
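A brief sketch, not part of the diff: composing two transforms and applying the result to points. Composing two identities leaves points unchanged.

import torch
from dockformer.utils.geometry.rigid_matrix_vector import Rigid3Array

r1 = Rigid3Array.identity((5,), device=torch.device("cpu"))
r2 = Rigid3Array.from_tensor_4x4(torch.eye(4).expand(5, 4, 4))
composed = r1 @ r2                 # Rigid3Array composition, still the identity
pts = torch.rand(5, 3)
assert torch.allclose(composed.apply(pts), pts)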
dockformer/utils/geometry/rotation_matrix.py
ADDED
@@ -0,0 +1,208 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Rot3Array Matrix Class."""
+
+from __future__ import annotations
+import dataclasses
+from typing import List
+
+import torch
+
+from dockformer.utils.geometry import utils
+from dockformer.utils.geometry import vector
+from dockformer.utils.tensor_utils import tensor_tree_map
+
+
+COMPONENTS = ['xx', 'xy', 'xz', 'yx', 'yy', 'yz', 'zx', 'zy', 'zz']
+
+@dataclasses.dataclass(frozen=True)
+class Rot3Array:
+    """Rot3Array Matrix in 3 dimensional Space implemented as struct of arrays."""
+    xx: torch.Tensor = dataclasses.field(metadata={'dtype': torch.float32})
+    xy: torch.Tensor
+    xz: torch.Tensor
+    yx: torch.Tensor
+    yy: torch.Tensor
+    yz: torch.Tensor
+    zx: torch.Tensor
+    zy: torch.Tensor
+    zz: torch.Tensor
+
+    __array_ufunc__ = None
+
+    def __getitem__(self, index):
+        field_names = utils.get_field_names(Rot3Array)
+        return Rot3Array(
+            **{
+                name: getattr(self, name)[index]
+                for name in field_names
+            }
+        )
+
+    def __mul__(self, other: torch.Tensor):
+        field_names = utils.get_field_names(Rot3Array)
+        return Rot3Array(
+            **{
+                name: getattr(self, name) * other
+                for name in field_names
+            }
+        )
+
+    def __matmul__(self, other: Rot3Array) -> Rot3Array:
+        """Composes two Rot3Arrays."""
+        c0 = self.apply_to_point(vector.Vec3Array(other.xx, other.yx, other.zx))
+        c1 = self.apply_to_point(vector.Vec3Array(other.xy, other.yy, other.zy))
+        c2 = self.apply_to_point(vector.Vec3Array(other.xz, other.yz, other.zz))
+        return Rot3Array(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z)
+
+    def map_tensor_fn(self, fn) -> Rot3Array:
+        field_names = utils.get_field_names(Rot3Array)
+        return Rot3Array(
+            **{
+                name: fn(getattr(self, name))
+                for name in field_names
+            }
+        )
+
+    def inverse(self) -> Rot3Array:
+        """Returns inverse of Rot3Array."""
+        return Rot3Array(
+            self.xx, self.yx, self.zx,
+            self.xy, self.yy, self.zy,
+            self.xz, self.yz, self.zz
+        )
+
+    def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
+        """Applies Rot3Array to point."""
+        return vector.Vec3Array(
+            self.xx * point.x + self.xy * point.y + self.xz * point.z,
+            self.yx * point.x + self.yy * point.y + self.yz * point.z,
+            self.zx * point.x + self.zy * point.y + self.zz * point.z
+        )
+
+    def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
+        """Applies inverse Rot3Array to point."""
+        return self.inverse().apply_to_point(point)
+
+
+    def unsqueeze(self, dim: int):
+        return Rot3Array(
+            *tensor_tree_map(
+                lambda t: t.unsqueeze(dim),
+                [getattr(self, c) for c in COMPONENTS]
+            )
+        )
+
+    def stop_gradient(self) -> Rot3Array:
+        return Rot3Array(
+            *[getattr(self, c).detach() for c in COMPONENTS]
+        )
+
+    @classmethod
+    def identity(cls, shape, device) -> Rot3Array:
+        """Returns identity of given shape."""
+        ones = torch.ones(shape, dtype=torch.float32, device=device)
+        zeros = torch.zeros(shape, dtype=torch.float32, device=device)
+        return cls(ones, zeros, zeros, zeros, ones, zeros, zeros, zeros, ones)
+
+    @classmethod
+    def from_two_vectors(
+        cls, e0: vector.Vec3Array,
+        e1: vector.Vec3Array
+    ) -> Rot3Array:
+        """Construct Rot3Array from two Vectors.
+
+        Rot3Array is constructed such that in the corresponding frame 'e0' lies on
+        the positive x-Axis and 'e1' lies in the xy plane with positive sign of y.
+
+        Args:
+          e0: Vector
+          e1: Vector
+        Returns:
+          Rot3Array
+        """
+        # Normalize the unit vector for the x-axis, e0.
+        e0 = e0.normalized()
+        # make e1 perpendicular to e0.
+        c = e1.dot(e0)
+        e1 = (e1 - c * e0).normalized()
+        # Compute e2 as cross product of e0 and e1.
+        e2 = e0.cross(e1)
+        return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z)
+
+    @classmethod
+    def from_array(cls, array: torch.Tensor) -> Rot3Array:
+        """Construct Rot3Array Matrix from array of shape [..., 3, 3]."""
+        rows = torch.unbind(array, dim=-2)
+        rc = [torch.unbind(e, dim=-1) for e in rows]
+        return cls(*[e for row in rc for e in row])
+
+    def to_tensor(self) -> torch.Tensor:
+        """Convert Rot3Array to array of shape [..., 3, 3]."""
+        return torch.stack(
+            [
+                torch.stack([self.xx, self.xy, self.xz], dim=-1),
+                torch.stack([self.yx, self.yy, self.yz], dim=-1),
+                torch.stack([self.zx, self.zy, self.zz], dim=-1)
+            ],
+            dim=-2
+        )
+
+    @classmethod
+    def from_quaternion(cls,
+                        w: torch.Tensor,
+                        x: torch.Tensor,
+                        y: torch.Tensor,
+                        z: torch.Tensor,
+                        normalize: bool = True,
+                        eps: float = 1e-6
+    ) -> Rot3Array:
+        """Construct Rot3Array from components of quaternion."""
+        if normalize:
+            inv_norm = torch.rsqrt(torch.clamp(w**2 + x**2 + y**2 + z**2, min=eps))
+            w = w * inv_norm
+            x = x * inv_norm
+            y = y * inv_norm
+            z = z * inv_norm
+        xx = 1.0 - 2.0 * (y ** 2 + z ** 2)
+        xy = 2.0 * (x * y - w * z)
+        xz = 2.0 * (x * z + w * y)
+        yx = 2.0 * (x * y + w * z)
+        yy = 1.0 - 2.0 * (x ** 2 + z ** 2)
+        yz = 2.0 * (y * z - w * x)
+        zx = 2.0 * (x * z - w * y)
+        zy = 2.0 * (y * z + w * x)
+        zz = 1.0 - 2.0 * (x ** 2 + y ** 2)
+        return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz)
+
+    def reshape(self, new_shape):
+        field_names = utils.get_field_names(Rot3Array)
+        reshape_fn = lambda t: t.reshape(new_shape)
+        return Rot3Array(
+            **{
+                name: reshape_fn(getattr(self, name))
+                for name in field_names
+            }
+        )
+
+    @classmethod
+    def cat(cls, rots: List[Rot3Array], dim: int) -> Rot3Array:
+        field_names = utils.get_field_names(Rot3Array)
+        cat_fn = lambda l: torch.cat(l, dim=dim)
+        return cls(
+            **{
+                name: cat_fn([getattr(r, name) for r in rots])
+                for name in field_names
+            }
+        )
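Illustrative usage (not part of this commit): from_quaternion builds the standard rotation matrix from a (normalized) quaternion; applying a rotation and then its inverse recovers the original point. A minimal sketch:

import torch
from dockformer.utils.geometry import rotation_matrix, vector

w, x, y, z = (torch.tensor([v]) for v in (1.0, 0.0, 0.0, 0.0))
rot = rotation_matrix.Rot3Array.from_quaternion(w, x, y, z)  # identity rotation
point = vector.Vec3Array(torch.tensor([1.0]), torch.tensor([2.0]), torch.tensor([3.0]))
moved = rot.apply_to_point(point)
back = rot.apply_inverse_to_point(moved)
assert torch.allclose(back.x, point.x) and torch.allclose(back.z, point.z)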
dockformer/utils/geometry/test_utils.py
ADDED
@@ -0,0 +1,97 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Shared utils for tests."""
+
+import dataclasses
+import torch
+
+from dockformer.utils.geometry import rigid_matrix_vector
+from dockformer.utils.geometry import rotation_matrix
+from dockformer.utils.geometry import vector
+
+
+def assert_rotation_matrix_equal(matrix1: rotation_matrix.Rot3Array,
+                                 matrix2: rotation_matrix.Rot3Array):
+    for field in dataclasses.fields(rotation_matrix.Rot3Array):
+        field = field.name
+        assert torch.equal(
+            getattr(matrix1, field), getattr(matrix2, field))
+
+
+def assert_rotation_matrix_close(mat1: rotation_matrix.Rot3Array,
+                                 mat2: rotation_matrix.Rot3Array):
+    assert torch.allclose(mat1.to_tensor(), mat2.to_tensor(), atol=1e-6)
+
+
+def assert_array_equal_to_rotation_matrix(array: torch.Tensor,
+                                          matrix: rotation_matrix.Rot3Array):
+    """Check that array and Matrix match."""
+    assert torch.equal(matrix.xx, array[..., 0, 0])
+    assert torch.equal(matrix.xy, array[..., 0, 1])
+    assert torch.equal(matrix.xz, array[..., 0, 2])
+    assert torch.equal(matrix.yx, array[..., 1, 0])
+    assert torch.equal(matrix.yy, array[..., 1, 1])
+    assert torch.equal(matrix.yz, array[..., 1, 2])
+    assert torch.equal(matrix.zx, array[..., 2, 0])
+    assert torch.equal(matrix.zy, array[..., 2, 1])
+    assert torch.equal(matrix.zz, array[..., 2, 2])
+
+
+def assert_array_close_to_rotation_matrix(array: torch.Tensor,
+                                          matrix: rotation_matrix.Rot3Array):
+    assert torch.allclose(matrix.to_tensor(), array, atol=1e-6)
+
+
+def assert_vectors_equal(vec1: vector.Vec3Array, vec2: vector.Vec3Array):
+    assert torch.equal(vec1.x, vec2.x)
+    assert torch.equal(vec1.y, vec2.y)
+    assert torch.equal(vec1.z, vec2.z)
+
+
+def assert_vectors_close(vec1: vector.Vec3Array, vec2: vector.Vec3Array):
+    assert torch.allclose(vec1.x, vec2.x, atol=1e-6, rtol=0.)
+    assert torch.allclose(vec1.y, vec2.y, atol=1e-6, rtol=0.)
+    assert torch.allclose(vec1.z, vec2.z, atol=1e-6, rtol=0.)
+
+
+def assert_array_close_to_vector(array: torch.Tensor, vec: vector.Vec3Array):
+    assert torch.allclose(vec.to_tensor(), array, atol=1e-6, rtol=0.)
+
+
+def assert_array_equal_to_vector(array: torch.Tensor, vec: vector.Vec3Array):
+    assert torch.equal(vec.to_tensor(), array)
+
+
+def assert_rigid_equal_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array,
+                                rigid2: rigid_matrix_vector.Rigid3Array):
+    assert_rot_trans_equal_to_rigid(rigid1.rotation, rigid1.translation, rigid2)
+
+
+def assert_rigid_close_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array,
+                                rigid2: rigid_matrix_vector.Rigid3Array):
+    assert_rot_trans_close_to_rigid(rigid1.rotation, rigid1.translation, rigid2)
+
+
+def assert_rot_trans_equal_to_rigid(rot: rotation_matrix.Rot3Array,
+                                    trans: vector.Vec3Array,
+                                    rigid: rigid_matrix_vector.Rigid3Array):
+    assert_rotation_matrix_equal(rot, rigid.rotation)
+    assert_vectors_equal(trans, rigid.translation)
+
+
+def assert_rot_trans_close_to_rigid(rot: rotation_matrix.Rot3Array,
+                                    trans: vector.Vec3Array,
+                                    rigid: rigid_matrix_vector.Rigid3Array):
+    assert_rotation_matrix_close(rot, rigid.rotation)
+    assert_vectors_close(trans, rigid.translation)
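Illustrative usage (not part of this commit): these helpers compare struct-of-arrays objects field by field, for example:

import torch
from dockformer.utils.geometry import rotation_matrix, test_utils

eye = rotation_matrix.Rot3Array.identity((2,), device="cpu")
test_utils.assert_rotation_matrix_close(eye, eye)
test_utils.assert_array_equal_to_rotation_matrix(eye.to_tensor(), eye)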
dockformer/utils/geometry/utils.py
ADDED
@@ -0,0 +1,22 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utils for geometry library."""
+
+import dataclasses
+
+
+def get_field_names(cls):
+    fields = dataclasses.fields(cls)
+    field_names = [f.name for f in fields]
+    return field_names
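Illustrative usage (not part of this commit): get_field_names is the small reflection hook the geometry dataclasses use to iterate their fields:

from dockformer.utils.geometry import utils, vector

print(utils.get_field_names(vector.Vec3Array))  # ['x', 'y', 'z']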
dockformer/utils/geometry/vector.py
ADDED
@@ -0,0 +1,261 @@
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Vec3Array Class."""
+
+from __future__ import annotations
+import dataclasses
+from typing import Union, List
+
+import torch
+
+Float = Union[float, torch.Tensor]
+
+@dataclasses.dataclass(frozen=True)
+class Vec3Array:
+    x: torch.Tensor = dataclasses.field(metadata={'dtype': torch.float32})
+    y: torch.Tensor
+    z: torch.Tensor
+
+    def __post_init__(self):
+        if hasattr(self.x, 'dtype'):
+            assert self.x.dtype == self.y.dtype
+            assert self.x.dtype == self.z.dtype
+            assert all([x == y for x, y in zip(self.x.shape, self.y.shape)])
+            assert all([x == z for x, z in zip(self.x.shape, self.z.shape)])
+
+    def __add__(self, other: Vec3Array) -> Vec3Array:
+        return Vec3Array(
+            self.x + other.x,
+            self.y + other.y,
+            self.z + other.z,
+        )
+
+    def __sub__(self, other: Vec3Array) -> Vec3Array:
+        return Vec3Array(
+            self.x - other.x,
+            self.y - other.y,
+            self.z - other.z,
+        )
+
+    def __mul__(self, other: Float) -> Vec3Array:
+        return Vec3Array(
+            self.x * other,
+            self.y * other,
+            self.z * other,
+        )
+
+    def __rmul__(self, other: Float) -> Vec3Array:
+        return self * other
+
+    def __truediv__(self, other: Float) -> Vec3Array:
+        return Vec3Array(
+            self.x / other,
+            self.y / other,
+            self.z / other,
+        )
+
+    def __neg__(self) -> Vec3Array:
+        return self * -1
+
+    def __pos__(self) -> Vec3Array:
+        return self * 1
+
+    def __getitem__(self, index) -> Vec3Array:
+        return Vec3Array(
+            self.x[index],
+            self.y[index],
+            self.z[index],
+        )
+
+    def __iter__(self):
+        return iter((self.x, self.y, self.z))
+
+    @property
+    def shape(self):
+        return self.x.shape
+
+    def map_tensor_fn(self, fn) -> Vec3Array:
+        return Vec3Array(
+            fn(self.x),
+            fn(self.y),
+            fn(self.z),
+        )
+
+    def cross(self, other: Vec3Array) -> Vec3Array:
+        """Compute cross product between 'self' and 'other'."""
+        new_x = self.y * other.z - self.z * other.y
+        new_y = self.z * other.x - self.x * other.z
+        new_z = self.x * other.y - self.y * other.x
+        return Vec3Array(new_x, new_y, new_z)
+
+    def dot(self, other: Vec3Array) -> Float:
+        """Compute dot product between 'self' and 'other'."""
+        return self.x * other.x + self.y * other.y + self.z * other.z
+
+    def norm(self, epsilon: float = 1e-6) -> Float:
+        """Compute Norm of Vec3Array, clipped to epsilon."""
+        # To avoid NaN on the backward pass, we must use maximum before the sqrt
+        norm2 = self.dot(self)
+        if epsilon:
+            norm2 = torch.clamp(norm2, min=epsilon**2)
+        return torch.sqrt(norm2)
+
+    def norm2(self):
+        return self.dot(self)
+
+    def normalized(self, epsilon: float = 1e-6) -> Vec3Array:
+        """Return unit vector with optional clipping."""
+        return self / self.norm(epsilon)
+
+    def clone(self) -> Vec3Array:
+        return Vec3Array(
+            self.x.clone(),
+            self.y.clone(),
+            self.z.clone(),
+        )
+
+    def reshape(self, new_shape) -> Vec3Array:
+        x = self.x.reshape(new_shape)
+        y = self.y.reshape(new_shape)
+        z = self.z.reshape(new_shape)
+
+        return Vec3Array(x, y, z)
+
+    def sum(self, dim: int) -> Vec3Array:
+        return Vec3Array(
+            torch.sum(self.x, dim=dim),
+            torch.sum(self.y, dim=dim),
+            torch.sum(self.z, dim=dim),
+        )
+
+    def unsqueeze(self, dim: int):
+        return Vec3Array(
+            self.x.unsqueeze(dim),
+            self.y.unsqueeze(dim),
+            self.z.unsqueeze(dim),
+        )
+
+    @classmethod
+    def zeros(cls, shape, device="cpu"):
+        """Return Vec3Array corresponding to zeros of given shape."""
+        return cls(
+            torch.zeros(shape, dtype=torch.float32, device=device),
+            torch.zeros(shape, dtype=torch.float32, device=device),
+            torch.zeros(shape, dtype=torch.float32, device=device)
+        )
+
+    def to_tensor(self) -> torch.Tensor:
+        return torch.stack([self.x, self.y, self.z], dim=-1)
+
+    @classmethod
+    def from_array(cls, tensor):
+        return cls(*torch.unbind(tensor, dim=-1))
+
+    @classmethod
+    def cat(cls, vecs: List[Vec3Array], dim: int) -> Vec3Array:
+        return cls(
+            torch.cat([v.x for v in vecs], dim=dim),
+            torch.cat([v.y for v in vecs], dim=dim),
+            torch.cat([v.z for v in vecs], dim=dim),
+        )
+
+
+def square_euclidean_distance(
+    vec1: Vec3Array,
+    vec2: Vec3Array,
+    epsilon: float = 1e-6
+) -> Float:
+    """Computes square of euclidean distance between 'vec1' and 'vec2'.
+
+    Args:
+        vec1: Vec3Array to compute distance to
+        vec2: Vec3Array to compute distance from, should be
+            broadcast compatible with 'vec1'
+        epsilon: distance is clipped from below to be at least epsilon
+
+    Returns:
+        Array of square euclidean distances;
+        shape will be result of broadcasting 'vec1' and 'vec2'
+    """
+    difference = vec1 - vec2
+    distance = difference.dot(difference)
+    if epsilon:
+        distance = torch.clamp(distance, min=epsilon)
+    return distance
+
+
+def dot(vector1: Vec3Array, vector2: Vec3Array) -> Float:
+    return vector1.dot(vector2)
+
+
+def cross(vector1: Vec3Array, vector2: Vec3Array) -> Vec3Array:
+    return vector1.cross(vector2)
+
+
+def norm(vector: Vec3Array, epsilon: float = 1e-6) -> Float:
+    return vector.norm(epsilon)
+
+
+def normalized(vector: Vec3Array, epsilon: float = 1e-6) -> Vec3Array:
+    return vector.normalized(epsilon)
+
+
+def euclidean_distance(
+    vec1: Vec3Array,
+    vec2: Vec3Array,
+    epsilon: float = 1e-6
+) -> Float:
+    """Computes euclidean distance between 'vec1' and 'vec2'.
+
+    Args:
+        vec1: Vec3Array to compute euclidean distance to
+        vec2: Vec3Array to compute euclidean distance from, should be
+            broadcast compatible with 'vec1'
+        epsilon: distance is clipped from below to be at least epsilon
+
+    Returns:
+        Array of euclidean distances;
+        shape will be result of broadcasting 'vec1' and 'vec2'
+    """
+    distance_sq = square_euclidean_distance(vec1, vec2, epsilon**2)
+    distance = torch.sqrt(distance_sq)
+    return distance
+
+
+def dihedral_angle(a: Vec3Array, b: Vec3Array, c: Vec3Array,
+                   d: Vec3Array) -> Float:
+    """Computes torsion angle for a quadruple of points.
+
+    For points (a, b, c, d), this is the angle between the planes defined by
+    points (a, b, c) and (b, c, d). It is also known as the dihedral angle.
+
+    Arguments:
+        a: A Vec3Array of coordinates.
+        b: A Vec3Array of coordinates.
+        c: A Vec3Array of coordinates.
+        d: A Vec3Array of coordinates.
+
+    Returns:
+        A tensor of angles in radians: [-pi, pi].
+    """
+    v1 = a - b
+    v2 = b - c
+    v3 = d - c
+
+    c1 = v1.cross(v2)
+    c2 = v3.cross(v2)
+    c3 = c2.cross(c1)
+
+    v2_mag = v2.norm()
+    return torch.atan2(c3.dot(v2), v2_mag * c1.dot(c2))
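Illustrative usage (not part of this commit): for the planar zig-zag quadruple below, the two planes coincide and the torsion angle is pi (the anti/trans configuration). A minimal sketch:

import math
import torch
from dockformer.utils.geometry import vector

def v(x, y, z):
    # helper to build a single-element Vec3Array
    return vector.Vec3Array(torch.tensor([x]), torch.tensor([y]), torch.tensor([z]))

angle = vector.dihedral_angle(v(0., 1., 0.), v(0., 0., 0.), v(1., 0., 0.), v(1., -1., 0.))
assert torch.allclose(angle, torch.tensor([math.pi]))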
dockformer/utils/kernel/__init__.py
ADDED
File without changes
dockformer/utils/kernel/attention_core.py
ADDED
@@ -0,0 +1,107 @@
+# Copyright 2021 AlQuraishi Laboratory
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import importlib
+from functools import reduce
+from operator import mul
+
+import torch
+
+# TODO bshor: solve attn_core_is_installed on mac
+attn_core_is_installed = importlib.util.find_spec("attn_core_inplace_cuda") is not None
+attn_core_inplace_cuda = None
+if attn_core_is_installed:
+    attn_core_inplace_cuda = importlib.import_module("attn_core_inplace_cuda")
+
+
+SUPPORTED_DTYPES = [torch.float32, torch.bfloat16]
+
+
+class AttentionCoreFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, q, k, v, bias_1=None, bias_2=None):
+        if(bias_1 is None and bias_2 is not None):
+            raise ValueError("bias_1 must be specified before bias_2")
+        if(q.dtype not in SUPPORTED_DTYPES):
+            raise ValueError("Unsupported datatype")
+
+        q = q.contiguous()
+        k = k.contiguous()
+
+        # [*, H, Q, K]
+        attention_logits = torch.matmul(
+            q, k.transpose(-1, -2),
+        )
+
+        if(bias_1 is not None):
+            attention_logits += bias_1
+        if(bias_2 is not None):
+            attention_logits += bias_2
+
+        attn_core_inplace_cuda.forward_(
+            attention_logits,
+            reduce(mul, attention_logits.shape[:-1]),
+            attention_logits.shape[-1],
+        )
+
+        o = torch.matmul(attention_logits, v)
+
+        ctx.bias_1_shape = bias_1.shape if bias_1 is not None else None
+        ctx.bias_2_shape = bias_2.shape if bias_2 is not None else None
+        ctx.save_for_backward(q, k, v, attention_logits)
+
+        return o
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        q, k, v, attention_logits = ctx.saved_tensors
+        grad_q = grad_k = grad_v = grad_bias_1 = grad_bias_2 = None
+
+        grad_v = torch.matmul(
+            attention_logits.transpose(-1, -2),
+            grad_output
+        )
+
+        attn_core_inplace_cuda.backward_(
+            attention_logits,
+            grad_output.contiguous(),
+            v.contiguous(), # v is implicitly transposed in the kernel
+            reduce(mul, attention_logits.shape[:-1]),
+            attention_logits.shape[-1],
+            grad_output.shape[-1],
+        )
+
+        if(ctx.bias_1_shape is not None):
+            grad_bias_1 = torch.sum(
+                attention_logits,
+                dim=tuple(i for i,d in enumerate(ctx.bias_1_shape) if d == 1),
+                keepdim=True,
+            )
+
+        if(ctx.bias_2_shape is not None):
+            grad_bias_2 = torch.sum(
+                attention_logits,
+                dim=tuple(i for i,d in enumerate(ctx.bias_2_shape) if d == 1),
+                keepdim=True,
+            )
+
+        grad_q = torch.matmul(
+            attention_logits, k
+        )
+        grad_k = torch.matmul(
+            q.transpose(-1, -2), attention_logits,
+        ).transpose(-1, -2)
+
+        return grad_q, grad_k, grad_v, grad_bias_1, grad_bias_2
+
+attention_core = AttentionCoreFunction.apply
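Illustrative usage (not part of this commit): attention_core fuses the softmax over q·kᵀ (plus up to two additive biases) with the value product, and only runs when the attn_core_inplace_cuda extension is installed. A minimal sketch, assuming a CUDA device; shapes are illustrative:

import torch
from dockformer.utils.kernel import attention_core

if attention_core.attn_core_is_installed:
    q = torch.randn(1, 4, 16, 8, device="cuda", requires_grad=True)
    k = torch.randn(1, 4, 16, 8, device="cuda", requires_grad=True)
    v = torch.randn(1, 4, 16, 8, device="cuda", requires_grad=True)
    bias = torch.randn(1, 1, 16, 16, device="cuda")
    out = attention_core.attention_core(q, k, v, bias)  # [1, 4, 16, 8]
    out.sum().backward()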
dockformer/utils/kernel/csrc/compat.h
ADDED
@@ -0,0 +1,11 @@
+// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h
+
+#ifndef TORCH_CHECK
+#define TORCH_CHECK AT_CHECK
+#endif
+
+#ifdef VERSION_GE_1_3
+#define DATA_PTR data_ptr
+#else
+#define DATA_PTR data
+#endif
dockformer/utils/kernel/csrc/softmax_cuda.cpp
ADDED
@@ -0,0 +1,44 @@
+// Copyright 2021 AlQuraishi Laboratory
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// modified from fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda.cpp
+
+#include <torch/extension.h>
+
+void attn_softmax_inplace_forward_(
+    at::Tensor input,
+    long long rows, int cols
+);
+void attn_softmax_inplace_backward_(
+    at::Tensor output,
+    at::Tensor d_ov,
+    at::Tensor values,
+    long long rows,
+    int cols_output,
+    int cols_values
+);
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def(
+        "forward_",
+        &attn_softmax_inplace_forward_,
+        "Softmax forward (CUDA)"
+    );
+    m.def(
+        "backward_",
+        &attn_softmax_inplace_backward_,
+        "Softmax backward (CUDA)"
+    );
+}
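For context (not part of this commit): an extension bound this way is typically compiled ahead of time with setuptools' CUDAExtension, or loaded ad hoc. A hedged sketch using torch.utils.cpp_extension.load; the source paths below are the ones added in this commit, and the module name matches what attention_core.py imports:

from torch.utils.cpp_extension import load

attn_core_inplace_cuda = load(
    name="attn_core_inplace_cuda",
    sources=[
        "dockformer/utils/kernel/csrc/softmax_cuda.cpp",
        "dockformer/utils/kernel/csrc/softmax_cuda_kernel.cu",
    ],
)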
dockformer/utils/kernel/csrc/softmax_cuda_kernel.cu
ADDED
@@ -0,0 +1,241 @@
+// Copyright 2021 AlQuraishi Laboratory
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// modified from fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda_kernel.cu
+
+#include <math_constants.h>
+#include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>
+
+#include <iostream>
+
+#include "ATen/ATen.h"
+#include "ATen/cuda/CUDAContext.h"
+#include "compat.h"
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) \
+    CHECK_CUDA(x);     \
+    CHECK_CONTIGUOUS(x)
+
+__inline__ __device__ float WarpAllReduceMax(float val) {
+    for (int mask = 1; mask < 32; mask *= 2) {
+        val = max(val, __shfl_xor_sync(0xffffffff, val, mask));
+    }
+    return val;
+}
+
+__inline__ __device__ float WarpAllReduceSum(float val) {
+    for (int mask = 1; mask < 32; mask *= 2) {
+        val += __shfl_xor_sync(0xffffffff, val, mask);
+    }
+    return val;
+}
+
+
+template<typename T>
+__global__ void attn_softmax_inplace_(
+    T *input,
+    long long rows, int cols
+) {
+    int threadidx_x = threadIdx.x / 32;
+    int threadidx_y = threadIdx.x % 32;
+    long long row_offset = (long long)(blockIdx.x * 4 + threadidx_x);
+    int cols_per_thread = (cols + 31) / 32;
+    int cols_this_thread = cols_per_thread;
+
+    int last_y = (cols / cols_per_thread);
+
+    if (threadidx_y == last_y) {
+        cols_this_thread = cols - cols_per_thread * last_y;
+    }
+    else if (threadidx_y > last_y) {
+        cols_this_thread = 0;
+    }
+
+    float buf[32];
+
+    int lane_id = threadidx_y;
+
+    if (row_offset < rows) {
+        T *row_input = input + row_offset * cols;
+        T *row_output = row_input;
+
+        #pragma unroll
+        for (int i = 0; i < cols_this_thread; i++) {
+            int idx = lane_id * cols_per_thread + i;
+            buf[i] = static_cast<float>(row_input[idx]);
+        }
+
+        float thread_max = -1 * CUDART_INF_F;
+        #pragma unroll
+        for (int i = 0; i < cols_this_thread; i++) {
+            thread_max = max(thread_max, buf[i]);
+        }
+
+        float warp_max = WarpAllReduceMax(thread_max);
+
+        float thread_sum = 0.f;
+        #pragma unroll
+        for (int i = 0; i < cols_this_thread; i++) {
+            buf[i] = __expf(buf[i] - warp_max);
+            thread_sum += buf[i];
+        }
+
+        float warp_sum = WarpAllReduceSum(thread_sum);
+        #pragma unroll
+        for (int i = 0; i < cols_this_thread; i++) {
+            row_output[lane_id * cols_per_thread + i] =
+                static_cast<T>(__fdividef(buf[i], warp_sum));
+        }
+    }
+}
+
+
+void attn_softmax_inplace_forward_(
+    at::Tensor input,
+    long long rows, int cols
+) {
+    CHECK_INPUT(input);
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+
+    int grid = (rows + 3) / 4;
+    dim3 block(128);
+
+    if (input.dtype() == torch::kFloat32) {
+        attn_softmax_inplace_<float><<<grid, block>>>(
+            (float *)input.data_ptr(),
+            rows, cols
+        );
+    }
+    else {
+        attn_softmax_inplace_<at::BFloat16><<<grid, block>>>(
+            (at::BFloat16 *)input.data_ptr(),
+            rows, cols
+        );
+    }
+}
+
+
+template<typename T>
+__global__ void attn_softmax_inplace_grad_(
+    T *output,
+    T *d_ov,
+    T *values,
+    long long rows,
+    int cols_output,
+    int cols_values
+) {
+    int threadidx_x = threadIdx.x / 32;
+    int threadidx_y = threadIdx.x % 32;
+    long long row_offset = (long long)(blockIdx.x * 4 + threadidx_x);
+    int cols_per_thread = (cols_output + 31) / 32;
+    int cols_this_thread = cols_per_thread;
+    int rows_values = cols_output;
+    // values are set to the beginning of the current
+    // rows_values x cols_values leaf matrix
+    long long value_row_offset = row_offset - row_offset % rows_values;
+    int last_y = (cols_output / cols_per_thread);
+
+    if (threadidx_y == last_y) {
+        cols_this_thread = cols_output - cols_per_thread * last_y;
+    }
+    else if (threadidx_y > last_y) {
+        cols_this_thread = 0;
+    }
+
+    float y_buf[32];
+    float dy_buf[32];
+
+    int lane_id = threadidx_y;
+
+    if (row_offset < rows) {
+        T *row_output = output + row_offset * cols_output;
+        T *row_d_ov = d_ov + row_offset * cols_values;
+        T *row_values = values + value_row_offset * cols_values;
+
+        float thread_max = -1 * CUDART_INF_F;
+
+        // Compute a chunk of the output gradient on the fly
+        int value_row_idx = 0;
+        int value_idx = 0;
+        #pragma unroll
+        for (int i = 0; i < cols_this_thread; i++) {
+            T sum = 0.;
+            #pragma unroll
+            for (int j = 0; j < cols_values; j++) {
+                value_row_idx = ((lane_id * cols_per_thread) + i);
+                value_idx = value_row_idx * cols_values + j;
+                sum += row_d_ov[j] * row_values[value_idx];
+            }
+            dy_buf[i] = static_cast<float>(sum);
+        }
+
+        #pragma unroll
+        for (int i = 0; i < cols_this_thread; i++) {
+            y_buf[i] = static_cast<float>(row_output[lane_id * cols_per_thread + i]);
+        }
+
+        float thread_sum = 0.;
+
+        #pragma unroll
+        for (int i = 0; i < cols_this_thread; i++) {
+            thread_sum += y_buf[i] * dy_buf[i];
+        }
+
+        float warp_sum = WarpAllReduceSum(thread_sum);
+
+        #pragma unroll
+        for (int i = 0; i < cols_this_thread; i++) {
+            row_output[lane_id * cols_per_thread + i] = static_cast<T>(
+                (dy_buf[i] - warp_sum) * y_buf[i]
+            );
+        }
+    }
+}
+
+
+void attn_softmax_inplace_backward_(
+    at::Tensor output,
+    at::Tensor d_ov,
+    at::Tensor values,
+    long long rows,
+    int cols_output,
+    int cols_values
+) {
+    CHECK_INPUT(output);
+    CHECK_INPUT(d_ov);
+    CHECK_INPUT(values);
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
+
+    int grid = (rows + 3) / 4;
+    dim3 block(128);
+
+    if (output.dtype() == torch::kFloat32) {
+        attn_softmax_inplace_grad_<float><<<grid, block>>>(
+            (float *)output.data_ptr(),
+            (float *)d_ov.data_ptr(),
+            (float *)values.data_ptr(),
+            rows, cols_output, cols_values
+        );
+    } else {
+        attn_softmax_inplace_grad_<at::BFloat16><<<grid, block>>>(
+            (at::BFloat16 *)output.data_ptr(),
+            (at::BFloat16 *)d_ov.data_ptr(),
+            (at::BFloat16 *)values.data_ptr(),
+            rows, cols_output, cols_values
+        );
+    }
+}
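For intuition (not part of this commit): the backward kernel recomputes dy = d_ov · valuesᵀ row by row, then applies the softmax Jacobian in place as y * (dy - sum(y * dy)). A pure-PyTorch reference of the same math for one softmax row, verified against autograd:

import torch

def softmax_backward_reference(y, d_ov, values):
    # y: one softmax row [cols_output]; d_ov: [cols_values];
    # values: [cols_output, cols_values] (the V block this row attends over)
    dy = values @ d_ov            # recompute dL/dy for this row
    dot = (y * dy).sum()          # the kernel's warp_sum reduction
    return (dy - dot) * y         # softmax Jacobian applied to dy

x = torch.randn(16, requires_grad=True)
values, d_ov = torch.randn(16, 8), torch.randn(8)
(torch.softmax(x, -1) @ values).backward(d_ov)
ref = softmax_backward_reference(torch.softmax(x, -1).detach(), d_ov, values)
assert torch.allclose(x.grad, ref, atol=1e-6)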
dockformer/utils/kernel/csrc/softmax_cuda_stub.cpp
ADDED
@@ -0,0 +1,36 @@
+// Copyright 2021 AlQuraishi Laboratory
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// modified from fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda.cpp
+
+#include <torch/extension.h>
+
+void attn_softmax_inplace_forward_(
+    at::Tensor input,
+    long long rows, int cols
+)
+{
+    throw std::runtime_error("attn_softmax_inplace_forward_ not implemented on CPU");
+};
+void attn_softmax_inplace_backward_(
+    at::Tensor output,
+    at::Tensor d_ov,
+    at::Tensor values,
+    long long rows,
+    int cols_output,
+    int cols_values
+)
+{
+    throw std::runtime_error("attn_softmax_inplace_backward_ not implemented on CPU");
+};
dockformer/utils/logger.py
ADDED
@@ -0,0 +1,80 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import operator
+import time
+
+import dllogger as logger
+from dllogger import JSONStreamBackend, StdOutBackend, Verbosity
+import numpy as np
+from lightning import Callback
+import torch.cuda.profiler as profiler
+
+
+def is_main_process():
+    return int(os.getenv("LOCAL_RANK", "0")) == 0
+
+
+class PerformanceLoggingCallback(Callback):
+    def __init__(self, log_file, global_batch_size, warmup_steps: int = 0, profile: bool = False):
+        logger.init(backends=[JSONStreamBackend(Verbosity.VERBOSE, log_file), StdOutBackend(Verbosity.VERBOSE)])
+        self.warmup_steps = warmup_steps
+        self.global_batch_size = global_batch_size
+        self.step = 0
+        self.profile = profile
+        self.timestamps = []
+
+    def do_step(self):
+        self.step += 1
+        if self.profile and self.step == self.warmup_steps:
+            profiler.start()
+        if self.step > self.warmup_steps:
+            self.timestamps.append(time.time())
+
+    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
+        self.do_step()
+
+    def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx: int = 0):
+        self.do_step()
+
+    def process_performance_stats(self, deltas):
+        def _round3(val):
+            return round(val, 3)
+
+        throughput_imgps = _round3(self.global_batch_size / np.mean(deltas))
+        timestamps_ms = 1000 * deltas
+        stats = {
+            f"throughput": throughput_imgps,
+            f"latency_mean": _round3(timestamps_ms.mean()),
+        }
+        for level in [90, 95, 99]:
+            stats.update({f"latency_{level}": _round3(np.percentile(timestamps_ms, level))})
+
+        return stats
+
+    def _log(self):
+        if is_main_process():
+            diffs = list(map(operator.sub, self.timestamps[1:], self.timestamps[:-1]))
+            deltas = np.array(diffs)
+            stats = self.process_performance_stats(deltas)
+            logger.log(step=(), data=stats)
+            logger.flush()
+
+    def on_train_end(self, trainer, pl_module):
+        if self.profile:
+            profiler.stop()
+        self._log()
+
+    def on_epoch_end(self, trainer, pl_module):
+        self._log()
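Illustrative usage (not part of this commit): the callback plugs into a Lightning Trainer and reports throughput and latency percentiles via dllogger. A minimal sketch with a hypothetical log path and batch size:

from lightning import Trainer
from dockformer.utils.logger import PerformanceLoggingCallback

callback = PerformanceLoggingCallback(
    log_file="perf.json",   # hypothetical path
    global_batch_size=8,
    warmup_steps=10,
)
trainer = Trainer(callbacks=[callback], max_epochs=1)
# trainer.fit(model, datamodule=...)  # model/datamodule supplied by the user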
dockformer/utils/loss.py
ADDED
@@ -0,0 +1,1171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
import time
|
16 |
+
|
17 |
+
import ml_collections
|
18 |
+
import numpy as np
|
19 |
+
import torch
|
20 |
+
import torch.nn as nn
|
21 |
+
from typing import Dict, Optional, Tuple
|
22 |
+
|
23 |
+
from dockformer.utils import residue_constants
|
24 |
+
from dockformer.utils.feats import pseudo_beta_fn
|
25 |
+
from dockformer.utils.rigid_utils import Rotation, Rigid
|
26 |
+
from dockformer.utils.geometry.vector import Vec3Array, euclidean_distance
|
27 |
+
from dockformer.utils.tensor_utils import (
|
28 |
+
tree_map,
|
29 |
+
masked_mean,
|
30 |
+
permute_final_dims,
|
31 |
+
)
|
32 |
+
import logging
|
33 |
+
from dockformer.utils.tensor_utils import tensor_tree_map
|
34 |
+
|
35 |
+
logger = logging.getLogger(__name__)
|
36 |
+
|
37 |
+
|
38 |
+
def softmax_cross_entropy(logits, labels):
|
39 |
+
loss = -1 * torch.sum(
|
40 |
+
labels * torch.nn.functional.log_softmax(logits, dim=-1),
|
41 |
+
dim=-1,
|
42 |
+
)
|
43 |
+
return loss
|
44 |
+
|
45 |
+
|
46 |
+
def sigmoid_cross_entropy(logits, labels):
|
47 |
+
logits_dtype = logits.dtype
|
48 |
+
try:
|
49 |
+
logits = logits.double()
|
50 |
+
labels = labels.double()
|
51 |
+
except:
|
52 |
+
logits = logits.to(dtype=torch.float32)
|
53 |
+
labels = labels.to(dtype=torch.float32)
|
54 |
+
|
55 |
+
log_p = torch.nn.functional.logsigmoid(logits)
|
56 |
+
# log_p = torch.log(torch.sigmoid(logits))
|
57 |
+
log_not_p = torch.nn.functional.logsigmoid(-1 * logits)
|
58 |
+
# log_not_p = torch.log(torch.sigmoid(-logits))
|
59 |
+
loss = (-1. * labels) * log_p - (1. - labels) * log_not_p
|
60 |
+
loss = loss.to(dtype=logits_dtype)
|
61 |
+
return loss
|
62 |
+
|
63 |
+
|
64 |
+
def torsion_angle_loss(
|
65 |
+
a, # [*, N, 7, 2]
|
66 |
+
a_gt, # [*, N, 7, 2]
|
67 |
+
a_alt_gt, # [*, N, 7, 2]
|
68 |
+
):
|
69 |
+
# [*, N, 7]
|
70 |
+
norm = torch.norm(a, dim=-1)
|
71 |
+
|
72 |
+
# [*, N, 7, 2]
|
73 |
+
a = a / norm.unsqueeze(-1)
|
74 |
+
|
75 |
+
# [*, N, 7]
|
76 |
+
diff_norm_gt = torch.norm(a - a_gt, dim=-1)
|
77 |
+
diff_norm_alt_gt = torch.norm(a - a_alt_gt, dim=-1)
|
78 |
+
min_diff = torch.minimum(diff_norm_gt ** 2, diff_norm_alt_gt ** 2)
|
79 |
+
|
80 |
+
# [*]
|
81 |
+
l_torsion = torch.mean(min_diff, dim=(-1, -2))
|
82 |
+
l_angle_norm = torch.mean(torch.abs(norm - 1), dim=(-1, -2))
|
83 |
+
|
84 |
+
an_weight = 0.02
|
85 |
+
return l_torsion + an_weight * l_angle_norm
|
86 |
+
|
87 |
+
|
88 |
+
def compute_fape(
|
89 |
+
pred_frames: Rigid,
|
90 |
+
target_frames: Rigid,
|
91 |
+
frames_mask: torch.Tensor,
|
92 |
+
pred_positions: torch.Tensor,
|
93 |
+
target_positions: torch.Tensor,
|
94 |
+
positions_mask: torch.Tensor,
|
95 |
+
length_scale: float,
|
96 |
+
pair_mask: Optional[torch.Tensor] = None,
|
97 |
+
l1_clamp_distance: Optional[float] = None,
|
98 |
+
eps=1e-8,
|
99 |
+
) -> torch.Tensor:
|
100 |
+
"""
|
101 |
+
Computes FAPE loss.
|
102 |
+
|
103 |
+
Args:
|
104 |
+
pred_frames:
|
105 |
+
[*, N_frames] Rigid object of predicted frames
|
106 |
+
target_frames:
|
107 |
+
[*, N_frames] Rigid object of ground truth frames
|
108 |
+
frames_mask:
|
109 |
+
[*, N_frames] binary mask for the frames
|
110 |
+
pred_positions:
|
111 |
+
[*, N_pts, 3] predicted atom positions
|
112 |
+
target_positions:
|
113 |
+
[*, N_pts, 3] ground truth positions
|
114 |
+
positions_mask:
|
115 |
+
[*, N_pts] positions mask
|
116 |
+
length_scale:
|
117 |
+
Length scale by which the loss is divided
|
118 |
+
pair_mask:
|
119 |
+
[*, N_frames, N_pts] mask to use for
|
120 |
+
separating intra- from inter-chain losses.
|
121 |
+
l1_clamp_distance:
|
122 |
+
Cutoff above which distance errors are disregarded
|
123 |
+
eps:
|
124 |
+
Small value used to regularize denominators
|
125 |
+
Returns:
|
126 |
+
[*] loss tensor
|
127 |
+
"""
|
128 |
+
# [*, N_frames, N_pts, 3]
|
129 |
+
local_pred_pos = pred_frames.invert()[..., None].apply(
|
130 |
+
pred_positions[..., None, :, :],
|
131 |
+
)
|
132 |
+
local_target_pos = target_frames.invert()[..., None].apply(
|
133 |
+
target_positions[..., None, :, :],
|
134 |
+
)
|
135 |
+
|
136 |
+
error_dist = torch.sqrt(
|
137 |
+
torch.sum((local_pred_pos - local_target_pos) ** 2, dim=-1) + eps
|
138 |
+
)
|
139 |
+
|
140 |
+
if l1_clamp_distance is not None:
|
141 |
+
error_dist = torch.clamp(error_dist, min=0, max=l1_clamp_distance)
|
142 |
+
|
143 |
+
normed_error = error_dist / length_scale
|
144 |
+
normed_error = normed_error * frames_mask[..., None]
|
145 |
+
normed_error = normed_error * positions_mask[..., None, :]
|
146 |
+
|
147 |
+
if pair_mask is not None:
|
148 |
+
normed_error = normed_error * pair_mask
|
149 |
+
normed_error = torch.sum(normed_error, dim=(-1, -2))
|
150 |
+
|
151 |
+
mask = frames_mask[..., None] * positions_mask[..., None, :] * pair_mask
|
152 |
+
norm_factor = torch.sum(mask, dim=(-2, -1))
|
153 |
+
|
154 |
+
normed_error = normed_error / (eps + norm_factor)
|
155 |
+
else:
|
156 |
+
# FP16-friendly averaging. Roughly equivalent to:
|
157 |
+
#
|
158 |
+
# norm_factor = (
|
159 |
+
# torch.sum(frames_mask, dim=-1) *
|
160 |
+
# torch.sum(positions_mask, dim=-1)
|
161 |
+
# )
|
162 |
+
# normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor)
|
163 |
+
#
|
164 |
+
# ("roughly" because eps is necessarily duplicated in the latter)
|
165 |
+
normed_error = torch.sum(normed_error, dim=-1)
|
166 |
+
normed_error = (
|
167 |
+
normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None]
|
168 |
+
)
|
169 |
+
normed_error = torch.sum(normed_error, dim=-1)
|
170 |
+
normed_error = normed_error / (eps + torch.sum(positions_mask, dim=-1))
|
171 |
+
|
172 |
+
return normed_error
|
173 |
+
|
174 |
+
|
175 |
+
def backbone_loss(
    backbone_rigid_tensor: torch.Tensor,
    backbone_rigid_mask: torch.Tensor,
    traj: torch.Tensor,
    pair_mask: Optional[torch.Tensor] = None,
    use_clamped_fape: Optional[torch.Tensor] = None,
    clamp_distance: float = 10.0,
    loss_unit_distance: float = 10.0,
    eps: float = 1e-4,
    **kwargs,
) -> torch.Tensor:
    # Check whether traj holds 4x4 transformation matrices or tensor_7
    # (quaternion + translation) frames.
    if traj.shape[-1] == 7:
        pred_aff = Rigid.from_tensor_7(traj)
    elif traj.shape[-1] == 4:
        pred_aff = Rigid.from_tensor_4x4(traj)

    pred_aff = Rigid(
        Rotation(rot_mats=pred_aff.get_rots().get_rot_mats(), quats=None),
        pred_aff.get_trans(),
    )

    # DISCREPANCY: DeepMind somehow gets a hold of a tensor_7 version of
    # backbone tensor, normalizes it, and then turns it back to a rotation
    # matrix. To avoid a potentially numerically unstable rotation matrix
    # to quaternion conversion, we just use the original rotation matrix
    # outright. This one hasn't been composed a bunch of times, though, so
    # it might be fine.
    gt_aff = Rigid.from_tensor_4x4(backbone_rigid_tensor)

    fape_loss = compute_fape(
        pred_aff,
        gt_aff[None],
        backbone_rigid_mask[None],
        pred_aff.get_trans(),
        gt_aff[None].get_trans(),
        backbone_rigid_mask[None],
        pair_mask=pair_mask,
        l1_clamp_distance=clamp_distance,
        length_scale=loss_unit_distance,
        eps=eps,
    )
    if use_clamped_fape is not None:
        unclamped_fape_loss = compute_fape(
            pred_aff,
            gt_aff[None],
            backbone_rigid_mask[None],
            pred_aff.get_trans(),
            gt_aff[None].get_trans(),
            backbone_rigid_mask[None],
            pair_mask=pair_mask,
            l1_clamp_distance=None,
            length_scale=loss_unit_distance,
            eps=eps,
        )

        fape_loss = fape_loss * use_clamped_fape + unclamped_fape_loss * (
            1 - use_clamped_fape
        )

    # Average over the batch dimension
    fape_loss = torch.mean(fape_loss)

    return fape_loss


def sidechain_loss(
    pred_sidechain_frames: torch.Tensor,
    pred_sidechain_atom_pos: torch.Tensor,
    rigidgroups_gt_frames: torch.Tensor,
    rigidgroups_alt_gt_frames: torch.Tensor,
    rigidgroups_gt_exists: torch.Tensor,
    renamed_atom14_gt_positions: torch.Tensor,
    renamed_atom14_gt_exists: torch.Tensor,
    alt_naming_is_better: torch.Tensor,
    ligand_mask: torch.Tensor,
    clamp_distance: float = 10.0,
    length_scale: float = 10.0,
    eps: float = 1e-4,
    only_include_ligand_atoms: bool = False,
    **kwargs,
) -> torch.Tensor:
    renamed_gt_frames = (
        1.0 - alt_naming_is_better[..., None, None, None]
    ) * rigidgroups_gt_frames + alt_naming_is_better[
        ..., None, None, None
    ] * rigidgroups_alt_gt_frames

    # Steamroll the inputs
    pred_sidechain_frames = pred_sidechain_frames[-1]  # keep only the last layer of the structure module
    batch_dims = pred_sidechain_frames.shape[:-4]
    pred_sidechain_frames = pred_sidechain_frames.view(*batch_dims, -1, 4, 4)
    pred_sidechain_frames = Rigid.from_tensor_4x4(pred_sidechain_frames)
    renamed_gt_frames = renamed_gt_frames.view(*batch_dims, -1, 4, 4)
    renamed_gt_frames = Rigid.from_tensor_4x4(renamed_gt_frames)
    rigidgroups_gt_exists = rigidgroups_gt_exists.reshape(*batch_dims, -1)
    pred_sidechain_atom_pos = pred_sidechain_atom_pos[-1]
    pred_sidechain_atom_pos = pred_sidechain_atom_pos.view(*batch_dims, -1, 3)
    renamed_atom14_gt_positions = renamed_atom14_gt_positions.view(
        *batch_dims, -1, 3
    )
    renamed_atom14_gt_exists = renamed_atom14_gt_exists.view(*batch_dims, -1)

    atom_mask_to_apply = renamed_atom14_gt_exists
    if only_include_ligand_atoms:
        ligand_atom14_mask = torch.repeat_interleave(ligand_mask, 14, dim=-1)
        atom_mask_to_apply = atom_mask_to_apply * ligand_atom14_mask

    fape = compute_fape(
        pred_sidechain_frames,
        renamed_gt_frames,
        rigidgroups_gt_exists,
        pred_sidechain_atom_pos,
        renamed_atom14_gt_positions,
        atom_mask_to_apply,
        pair_mask=None,
        l1_clamp_distance=clamp_distance,
        length_scale=length_scale,
        eps=eps,
    )

    return fape


def fape_bb(
    out: Dict[str, torch.Tensor],
    batch: Dict[str, torch.Tensor],
    config: ml_collections.ConfigDict,
) -> torch.Tensor:
    traj = out["sm"]["frames"]
    bb_loss = backbone_loss(
        traj=traj,
        **{**batch, **config},
    )
    loss = torch.mean(bb_loss)
    return loss


def fape_sidechain(
    out: Dict[str, torch.Tensor],
    batch: Dict[str, torch.Tensor],
    config: ml_collections.ConfigDict,
) -> torch.Tensor:
    sc_loss = sidechain_loss(
        out["sm"]["sidechain_frames"],
        out["sm"]["positions"],
        **{**batch, **config},
    )
    loss = torch.mean(sc_loss)
    return loss


def fape_interface(
    out: Dict[str, torch.Tensor],
    batch: Dict[str, torch.Tensor],
    config: ml_collections.ConfigDict,
) -> torch.Tensor:
    sc_loss = sidechain_loss(
        out["sm"]["sidechain_frames"],
        out["sm"]["positions"],
        only_include_ligand_atoms=True,
        **{**batch, **config},
    )
    loss = torch.mean(sc_loss)
    return loss


def supervised_chi_loss(
    angles_sin_cos: torch.Tensor,
    unnormalized_angles_sin_cos: torch.Tensor,
    aatype: torch.Tensor,
    protein_mask: torch.Tensor,
    chi_mask: torch.Tensor,
    chi_angles_sin_cos: torch.Tensor,
    chi_weight: float,
    angle_norm_weight: float,
    eps=1e-6,
    **kwargs,
) -> torch.Tensor:
    """
    Implements Algorithm 27 (torsionAngleLoss)

    Args:
        angles_sin_cos:
            [*, N, 7, 2] predicted angles
        unnormalized_angles_sin_cos:
            The same angles, but unnormalized
        aatype:
            [*, N] residue indices
        protein_mask:
            [*, N] protein mask
        chi_mask:
            [*, N, 7] angle mask
        chi_angles_sin_cos:
            [*, N, 7, 2] ground truth angles
        chi_weight:
            Weight for the angle component of the loss
        angle_norm_weight:
            Weight for the normalization component of the loss
    Returns:
        [*] loss tensor
    """
    pred_angles = angles_sin_cos[..., 3:, :]
    residue_type_one_hot = torch.nn.functional.one_hot(
        aatype,
        residue_constants.restype_num + 1,
    )
    chi_pi_periodic = torch.einsum(
        "...ij,jk->ik",
        residue_type_one_hot.type(angles_sin_cos.dtype),
        angles_sin_cos.new_tensor(residue_constants.chi_pi_periodic),
    )

    true_chi = chi_angles_sin_cos[None]

    shifted_mask = (1 - 2 * chi_pi_periodic).unsqueeze(-1)
    true_chi_shifted = shifted_mask * true_chi
    sq_chi_error = torch.sum((true_chi - pred_angles) ** 2, dim=-1)
    sq_chi_error_shifted = torch.sum(
        (true_chi_shifted - pred_angles) ** 2, dim=-1
    )
    sq_chi_error = torch.minimum(sq_chi_error, sq_chi_error_shifted)

    # The ol' switcheroo
    sq_chi_error = sq_chi_error.permute(
        *range(len(sq_chi_error.shape))[1:-2], 0, -2, -1
    )

    sq_chi_loss = masked_mean(
        chi_mask[..., None, :, :], sq_chi_error, dim=(-1, -2, -3)
    )

    loss = chi_weight * sq_chi_loss

    angle_norm = torch.sqrt(
        torch.sum(unnormalized_angles_sin_cos ** 2, dim=-1) + eps
    )
    norm_error = torch.abs(angle_norm - 1.0)
    norm_error = norm_error.permute(
        *range(len(norm_error.shape))[1:-2], 0, -2, -1
    )
    angle_norm_loss = masked_mean(
        protein_mask[..., None, :, None], norm_error, dim=(-1, -2, -3)
    )

    loss = loss + angle_norm_weight * angle_norm_loss

    # Average over the batch dimension
    loss = torch.mean(loss)

    return loss

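The pi-periodic handling above is the subtle part: for residues whose chi angle is symmetric under a 180-degree rotation, negating the (sin, cos) pair gives an equally valid ground truth, so the loss takes the minimum of the two errors. A standalone sketch of just that step, with toy values:

import torch

pred = torch.tensor([0.9, -0.1])            # predicted (sin, cos)
true = torch.tensor([-0.9, 0.1])            # ground truth (sin, cos)
pi_periodic = torch.tensor(1.0)             # 1.0 => this chi is pi-periodic

true_shifted = (1 - 2 * pi_periodic) * true  # negate (sin, cos): chi + pi
err = ((true - pred) ** 2).sum()
err_shifted = ((true_shifted - pred) ** 2).sum()
print(torch.minimum(err, err_shifted))       # ~0: the shifted naming matches
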
def compute_plddt(logits: torch.Tensor) -> torch.Tensor:
    num_bins = logits.shape[-1]
    bin_width = 1.0 / num_bins
    bounds = torch.arange(
        start=0.5 * bin_width, end=1.0, step=bin_width, device=logits.device
    )
    probs = torch.nn.functional.softmax(logits, dim=-1)
    pred_lddt_ca = torch.sum(
        probs * bounds.view(*((1,) * len(probs.shape[:-1])), *bounds.shape),
        dim=-1,
    )
    return pred_lddt_ca * 100

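compute_plddt is just an expectation over bin centers: softmax the logits, weight each bin by its midpoint in (0, 1), and rescale to 0-100. A self-contained recomputation with toy logits (50 bins assumed, matching the default elsewhere in this file):

import torch

logits = torch.randn(10, 50)                 # [N_res, no_bins] toy logits
probs = torch.softmax(logits, dim=-1)
centers = (torch.arange(50) + 0.5) / 50      # bin midpoints in (0, 1)
plddt = 100 * (probs * centers).sum(dim=-1)  # expected lDDT per residue
assert ((plddt >= 0) & (plddt <= 100)).all()
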
def lddt(
    all_atom_pred_pos: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    cutoff: float = 15.0,
    eps: float = 1e-10,
    per_residue: bool = True,
) -> torch.Tensor:
    n = all_atom_mask.shape[-2]
    dmat_true = torch.sqrt(
        eps
        + torch.sum(
            (
                all_atom_positions[..., None, :]
                - all_atom_positions[..., None, :, :]
            )
            ** 2,
            dim=-1,
        )
    )

    dmat_pred = torch.sqrt(
        eps
        + torch.sum(
            (
                all_atom_pred_pos[..., None, :]
                - all_atom_pred_pos[..., None, :, :]
            )
            ** 2,
            dim=-1,
        )
    )
    dists_to_score = (
        (dmat_true < cutoff)
        * all_atom_mask
        * permute_final_dims(all_atom_mask, (1, 0))
        * (1.0 - torch.eye(n, device=all_atom_mask.device))
    )

    dist_l1 = torch.abs(dmat_true - dmat_pred)

    score = (
        (dist_l1 < 0.5).type(dist_l1.dtype)
        + (dist_l1 < 1.0).type(dist_l1.dtype)
        + (dist_l1 < 2.0).type(dist_l1.dtype)
        + (dist_l1 < 4.0).type(dist_l1.dtype)
    )
    score = score * 0.25

    dims = (-1,) if per_residue else (-2, -1)
    norm = 1.0 / (eps + torch.sum(dists_to_score, dim=dims))
    score = norm * (eps + torch.sum(dists_to_score * score, dim=dims))

    return score

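lDDT never superposes the two structures: it compares their distance matrices, crediting each pair (within 15 A in the ground truth) 0.25 for every tolerance it satisfies (0.5, 1, 2, 4 A) and averaging. A toy illustration of the thresholded scoring, with made-up errors:

import torch

# Absolute distance-matrix errors for three atom pairs, in Angstroms.
dist_l1 = torch.tensor([0.3, 1.5, 5.0])

# Each pair earns 0.25 credit per tolerance threshold it satisfies.
score = sum((dist_l1 < t).float() for t in (0.5, 1.0, 2.0, 4.0)) * 0.25
print(score)  # tensor([1.0000, 0.5000, 0.0000])
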
def lddt_ca(
    all_atom_pred_pos: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    cutoff: float = 15.0,
    eps: float = 1e-10,
    per_residue: bool = True,
) -> torch.Tensor:
    ca_pos = residue_constants.atom_order["CA"]
    all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
    all_atom_positions = all_atom_positions[..., ca_pos, :]
    all_atom_mask = all_atom_mask[..., ca_pos: (ca_pos + 1)]  # keep dim

    return lddt(
        all_atom_pred_pos,
        all_atom_positions,
        all_atom_mask,
        cutoff=cutoff,
        eps=eps,
        per_residue=per_residue,
    )


def lddt_loss(
    logits: torch.Tensor,
    all_atom_pred_pos: torch.Tensor,
    atom37_gt_positions: torch.Tensor,
    atom37_atom_exists_in_gt: torch.Tensor,
    resolution: torch.Tensor,
    cutoff: float = 15.0,
    no_bins: int = 50,
    min_resolution: float = 0.1,
    max_resolution: float = 3.0,
    eps: float = 1e-10,
    **kwargs,
) -> torch.Tensor:
    # remove ligand
    logits = logits[:, :atom37_atom_exists_in_gt.shape[1], :]

    ca_pos = residue_constants.atom_order["CA"]
    all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
    atom37_gt_positions = atom37_gt_positions[..., ca_pos, :]
    atom37_atom_exists_in_gt = atom37_atom_exists_in_gt[..., ca_pos: (ca_pos + 1)]  # keep dim

    score = lddt(
        all_atom_pred_pos,
        atom37_gt_positions,
        atom37_atom_exists_in_gt,
        cutoff=cutoff,
        eps=eps,
    )

    # TODO: Remove after initial pipeline testing
    score = torch.nan_to_num(score, nan=torch.nanmean(score))
    score[score < 0] = 0

    score = score.detach()
    bin_index = torch.floor(score * no_bins).long()
    bin_index = torch.clamp(bin_index, max=(no_bins - 1))
    lddt_ca_one_hot = torch.nn.functional.one_hot(
        bin_index, num_classes=no_bins
    )

    errors = softmax_cross_entropy(logits, lddt_ca_one_hot)
    atom37_atom_exists_in_gt = atom37_atom_exists_in_gt.squeeze(-1)
    loss = torch.sum(errors * atom37_atom_exists_in_gt, dim=-1) / (
        eps + torch.sum(atom37_atom_exists_in_gt, dim=-1)
    )

    loss = loss * (
        (resolution >= min_resolution) & (resolution <= max_resolution)
    )

    # Average over the batch dimension
    loss = torch.mean(loss)

    return loss

def distogram_loss(
    logits,
    gt_pseudo_beta_with_lig,
    gt_pseudo_beta_with_lig_mask,
    min_bin=2.3125,
    max_bin=21.6875,
    no_bins=64,
    eps=1e-6,
    **kwargs,
):
    boundaries = torch.linspace(
        min_bin,
        max_bin,
        no_bins - 1,
        device=logits.device,
    )
    boundaries = boundaries ** 2

    dists = torch.sum(
        (gt_pseudo_beta_with_lig[..., None, :] - gt_pseudo_beta_with_lig[..., None, :, :]) ** 2,
        dim=-1,
        keepdims=True,
    )

    true_bins = torch.sum(dists > boundaries, dim=-1)
    errors = softmax_cross_entropy(
        logits,
        torch.nn.functional.one_hot(true_bins, no_bins),
    )

    square_mask = gt_pseudo_beta_with_lig_mask[..., None] * gt_pseudo_beta_with_lig_mask[..., None, :]

    # FP16-friendly sum. Equivalent to:
    # mean = (torch.sum(errors * square_mask, dim=(-1, -2)) /
    #         (eps + torch.sum(square_mask, dim=(-1, -2))))
    denom = eps + torch.sum(square_mask, dim=(-1, -2))
    mean = errors * square_mask
    mean = torch.sum(mean, dim=-1)
    mean = mean / denom[..., None]
    mean = torch.sum(mean, dim=-1)

    # Average over the batch dimensions
    mean = torch.mean(mean)

    return mean

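The binning trick above avoids a square root: the bin boundaries are squared once, and squared distances are compared against them, so true_bins is simply a count of boundaries each squared distance exceeds. A standalone sketch with toy values:

import torch

no_bins = 8
boundaries = torch.linspace(2.0, 16.0, no_bins - 1) ** 2   # squared bin edges

sq_dist = torch.tensor([1.0, 30.0, 500.0])                 # squared distances
true_bins = (sq_dist[..., None] > boundaries).sum(dim=-1)  # edges exceeded
print(true_bins)  # tensor([0, 2, 7]): first bin, third bin, overflow bin
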
def inter_contact_loss(
    logits: torch.Tensor,
    gt_inter_contacts: torch.Tensor,
    inter_pair_mask: torch.Tensor,
    pos_class_weight: float = 200.0,
    contact_distance: float = 5.0,
    **kwargs,
):
    logits = logits.squeeze(-1)
    bce_loss = torch.nn.functional.binary_cross_entropy_with_logits(
        logits,
        gt_inter_contacts,
        reduction='none',
        pos_weight=logits.new_tensor([pos_class_weight]),
    )
    masked_loss = bce_loss * inter_pair_mask
    final_loss = masked_loss.sum() / inter_pair_mask.sum()

    return final_loss

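Protein-ligand contacts are rare, so the positive class is up-weighted about 200x in the BCE; without this, the head could trivially predict "no contact" everywhere. A small sketch of what pos_weight does, with hypothetical values:

import torch
import torch.nn.functional as F

logits = torch.tensor([-2.0])   # model leans towards "no contact"
target = torch.tensor([1.0])    # ...but this pair is a true contact

plain = F.binary_cross_entropy_with_logits(logits, target)
weighted = F.binary_cross_entropy_with_logits(
    logits, target, pos_weight=torch.tensor([200.0])
)
print(plain, weighted)  # the miss on the positive class costs ~200x more
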
def affinity_loss(
    logits,
    affinity,
    affinity_loss_factor,
    min_bin=0,
    max_bin=15,
    no_bins=32,
    **kwargs,
):
    boundaries = torch.linspace(
        min_bin,
        max_bin,
        no_bins - 1,
        device=logits.device,
    )

    true_bins = torch.sum(affinity > boundaries, dim=-1)
    errors = softmax_cross_entropy(
        logits,
        torch.nn.functional.one_hot(true_bins, no_bins),
    )

    after_factor = errors * affinity_loss_factor.squeeze()
    if affinity_loss_factor.sum() > 0.1:
        mean_val = after_factor.sum() / affinity_loss_factor.sum()
    else:
        # If there is no affinity data in the batch, return a very small
        # loss; the factor also keeps the loss small.
        mean_val = after_factor.sum() * 1e-3
    return mean_val


def positions_inter_distogram_loss(
    out,
    aatype: torch.Tensor,
    inter_pair_mask: torch.Tensor,
    gt_pseudo_beta_with_lig: torch.Tensor,
    max_dist=20.,
    length_scale=10.,
    eps: float = 1e-10,
    **kwargs,
):
    predicted_atoms = pseudo_beta_fn(aatype, out["final_atom_positions"], None)
    pred_dists = torch.sum(
        (predicted_atoms[..., None, :] - predicted_atoms[..., None, :, :]) ** 2,
        dim=-1,
        keepdims=True,
    )

    gt_dists = torch.sum(
        (gt_pseudo_beta_with_lig[..., None, :] - gt_pseudo_beta_with_lig[..., None, :, :]) ** 2,
        dim=-1,
        keepdims=True,
    )

    pred_dists = pred_dists.clamp(max=max_dist ** 2)
    gt_dists = gt_dists.clamp(max=max_dist ** 2)

    dists_diff = torch.abs(pred_dists - gt_dists) / (length_scale ** 2)
    dists_diff = dists_diff * inter_pair_mask.unsqueeze(-1)

    dists_diff_sum_per_batch = torch.sum(torch.sqrt(eps + dists_diff), dim=(-1, -2, -3))
    mask_size_per_batch = torch.sum(inter_pair_mask, dim=(-1, -2))
    inter_loss = torch.mean(dists_diff_sum_per_batch / (eps + mask_size_per_batch))

    return inter_loss


def positions_intra_ligand_distogram_loss(
    out,
    aatype: torch.Tensor,
    ligand_mask: torch.Tensor,
    gt_pseudo_beta_with_lig: torch.Tensor,
    max_dist=20.,
    length_scale=4.,  # similar to RosettaFoldAA
    eps=1e-10,
    **kwargs,
):
    intra_ligand_pair_mask = ligand_mask[..., None] * ligand_mask[..., None, :]
    predicted_atoms = pseudo_beta_fn(aatype, out["final_atom_positions"], None)
    pred_dists = torch.sum(
        (predicted_atoms[..., None, :] - predicted_atoms[..., None, :, :]) ** 2,
        dim=-1,
        keepdims=True,
    )

    gt_dists = torch.sum(
        (gt_pseudo_beta_with_lig[..., None, :] - gt_pseudo_beta_with_lig[..., None, :, :]) ** 2,
        dim=-1,
        keepdims=True,
    )

    pred_dists = torch.sqrt(eps + pred_dists.clamp(max=max_dist ** 2)) / length_scale
    gt_dists = torch.sqrt(eps + gt_dists.clamp(max=max_dist ** 2)) / length_scale

    # Apply L2 loss
    dists_diff = (pred_dists - gt_dists) ** 2

    dists_diff = dists_diff * intra_ligand_pair_mask.unsqueeze(-1)

    dists_diff_sum_per_batch = torch.sum(dists_diff, dim=(-1, -2, -3))
    mask_size_per_batch = torch.sum(intra_ligand_pair_mask, dim=(-1, -2))
    intra_ligand_loss = torch.mean(dists_diff_sum_per_batch / (eps + mask_size_per_batch))

    return intra_ligand_loss


def _calculate_bin_centers(boundaries: torch.Tensor):
    step = boundaries[1] - boundaries[0]
    bin_centers = boundaries + step / 2
    bin_centers = torch.cat(
        [bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0
    )
    return bin_centers


def _calculate_expected_aligned_error(
    alignment_confidence_breaks: torch.Tensor,
    aligned_distance_error_probs: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    bin_centers = _calculate_bin_centers(alignment_confidence_breaks)
    return (
        torch.sum(aligned_distance_error_probs * bin_centers, dim=-1),
        bin_centers[-1],
    )


def compute_predicted_aligned_error(
    logits: torch.Tensor,
    max_bin: int = 31,
    no_bins: int = 64,
    **kwargs,
) -> Dict[str, torch.Tensor]:
    """Computes aligned confidence metrics from logits.

    Args:
        logits: [*, num_res, num_res, num_bins] the logits output from
            PredictedAlignedErrorHead.
        max_bin: Maximum bin value
        no_bins: Number of bins
    Returns:
        aligned_confidence_probs: [*, num_res, num_res, num_bins] the predicted
            aligned error probabilities over bins for each residue pair.
        predicted_aligned_error: [*, num_res, num_res] the expected aligned distance
            error for each pair of residues.
        max_predicted_aligned_error: [*] the maximum predicted error possible.
    """
    boundaries = torch.linspace(
        0, max_bin, steps=(no_bins - 1), device=logits.device
    )

    aligned_confidence_probs = torch.nn.functional.softmax(logits, dim=-1)
    (
        predicted_aligned_error,
        max_predicted_aligned_error,
    ) = _calculate_expected_aligned_error(
        alignment_confidence_breaks=boundaries,
        aligned_distance_error_probs=aligned_confidence_probs,
    )

    return {
        "aligned_confidence_probs": aligned_confidence_probs,
        "predicted_aligned_error": predicted_aligned_error,
        "max_predicted_aligned_error": max_predicted_aligned_error,
    }

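This is the same expectation-over-bins pattern as pLDDT, but in Angstroms: softmax over 64 error bins spanning [0, 31+], then a probability-weighted sum of bin centers. A self-contained recomputation with toy logits (shapes and values are illustrative):

import torch

no_bins, max_bin = 64, 31
boundaries = torch.linspace(0, max_bin, steps=no_bins - 1)
step = boundaries[1] - boundaries[0]
# Bin centers, plus one overflow center appended past the last boundary.
centers = torch.cat([boundaries + step / 2, (boundaries[-1] + 1.5 * step)[None]])

logits = torch.randn(5, 5, no_bins)                  # toy [N, N, bins] logits
pae = (torch.softmax(logits, -1) * centers).sum(-1)  # expected error, [N, N]
print(pae.shape, float(centers[-1]))                 # max possible PAE
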
def compute_tm(
    logits: torch.Tensor,
    residue_weights: Optional[torch.Tensor] = None,
    asym_id: Optional[torch.Tensor] = None,
    interface: bool = False,
    max_bin: int = 31,
    no_bins: int = 64,
    eps: float = 1e-8,
    **kwargs,
) -> torch.Tensor:
    if residue_weights is None:
        residue_weights = logits.new_ones(logits.shape[-2])

    boundaries = torch.linspace(
        0, max_bin, steps=(no_bins - 1), device=logits.device
    )

    bin_centers = _calculate_bin_centers(boundaries)
    clipped_n = max(torch.sum(residue_weights), 19)

    d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8

    probs = torch.nn.functional.softmax(logits, dim=-1)

    tm_per_bin = 1.0 / (1 + (bin_centers ** 2) / (d0 ** 2))
    predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1)

    n = residue_weights.shape[-1]
    pair_mask = residue_weights.new_ones((n, n), dtype=torch.int32)
    if interface and (asym_id is not None):
        if len(asym_id.shape) > 1:
            assert len(asym_id.shape) <= 2
            batch_size = asym_id.shape[0]
            pair_mask = residue_weights.new_ones((batch_size, n, n), dtype=torch.int32)
        pair_mask *= (asym_id[..., None] != asym_id[..., None, :]).to(dtype=pair_mask.dtype)

    predicted_tm_term *= pair_mask

    pair_residue_weights = pair_mask * (
        residue_weights[..., None, :] * residue_weights[..., :, None]
    )
    denom = eps + torch.sum(pair_residue_weights, dim=-1, keepdims=True)
    normed_residue_mask = pair_residue_weights / denom
    per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)

    weighted = per_alignment * residue_weights

    argmax = (weighted == torch.max(weighted)).nonzero()[0]
    return per_alignment[tuple(argmax)]

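The d0 term is the standard TM-score length normalization: it grows with chain length, so the same absolute error hurts short chains more. A short sketch of the per-bin TM weighting, with a hypothetical chain length:

import torch

n_res = 150
clipped_n = max(n_res, 19)
d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8  # TM-score normalization

errors = torch.tensor([1.0, 5.0, 20.0])          # candidate aligned errors (A)
tm_term = 1.0 / (1 + errors ** 2 / d0 ** 2)      # close pairs ~1, far pairs ~0
print(round(d0, 2), tm_term)
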
def compute_renamed_ground_truth(
    batch: Dict[str, torch.Tensor],
    atom14_pred_positions: torch.Tensor,
    eps=1e-10,
) -> Dict[str, torch.Tensor]:
    """
    Find optimal renaming of ground truth based on the predicted positions.

    Alg. 26 "renameSymmetricGroundTruthAtoms"

    This renamed ground truth is then used for all losses,
    such that each loss moves the atoms in the same direction.

    Args:
        batch: Dictionary containing:
            * atom14_gt_positions: Ground truth positions.
            * atom14_alt_gt_positions: Ground truth positions with renaming swaps.
            * atom14_atom_is_ambiguous: 1.0 for atoms that are affected by
                renaming swaps.
            * atom14_gt_exists: Mask for which atoms exist in ground truth.
            * atom14_alt_gt_exists: Mask for which atoms exist in ground truth
                after renaming.
            * atom14_atom_exists: Mask for whether each atom is part of the given
                amino acid type.
        atom14_pred_positions: Array of atom positions in global frame with shape
    Returns:
        Dictionary containing:
            alt_naming_is_better: Array with 1.0 where alternative swap is better.
            renamed_atom14_gt_positions: Array of optimal ground truth positions
                after renaming swaps are performed.
            renamed_atom14_gt_exists: Mask after renaming swap is performed.
    """

    pred_dists = torch.sqrt(
        eps
        + torch.sum(
            (
                atom14_pred_positions[..., None, :, None, :]
                - atom14_pred_positions[..., None, :, None, :, :]
            )
            ** 2,
            dim=-1,
        )
    )

    atom14_gt_positions = batch["atom14_gt_positions"]
    gt_dists = torch.sqrt(
        eps
        + torch.sum(
            (
                atom14_gt_positions[..., None, :, None, :]
                - atom14_gt_positions[..., None, :, None, :, :]
            )
            ** 2,
            dim=-1,
        )
    )

    atom14_alt_gt_positions = batch["atom14_alt_gt_positions"]
    alt_gt_dists = torch.sqrt(
        eps
        + torch.sum(
            (
                atom14_alt_gt_positions[..., None, :, None, :]
                - atom14_alt_gt_positions[..., None, :, None, :, :]
            )
            ** 2,
            dim=-1,
        )
    )

    lddt = torch.sqrt(eps + (pred_dists - gt_dists) ** 2)
    alt_lddt = torch.sqrt(eps + (pred_dists - alt_gt_dists) ** 2)

    atom14_gt_exists = batch["atom14_atom_exists_in_gt"]
    atom14_atom_is_ambiguous = batch["atom14_atom_is_ambiguous"]
    mask = (
        atom14_gt_exists[..., None, :, None]
        * atom14_atom_is_ambiguous[..., None, :, None]
        * atom14_gt_exists[..., None, :, None, :]
        * (1.0 - atom14_atom_is_ambiguous[..., None, :, None, :])
    )

    per_res_lddt = torch.sum(mask * lddt, dim=(-1, -2, -3))
    alt_per_res_lddt = torch.sum(mask * alt_lddt, dim=(-1, -2, -3))

    fp_type = atom14_pred_positions.dtype
    alt_naming_is_better = (alt_per_res_lddt < per_res_lddt).type(fp_type)

    renamed_atom14_gt_positions = (
        1.0 - alt_naming_is_better[..., None, None]
    ) * atom14_gt_positions + alt_naming_is_better[
        ..., None, None
    ] * atom14_alt_gt_positions

    renamed_atom14_gt_mask = (
        1.0 - alt_naming_is_better[..., None]
    ) * atom14_gt_exists + alt_naming_is_better[..., None] * batch[
        "atom14_alt_gt_exists"
    ]

    return {
        "alt_naming_is_better": alt_naming_is_better,
        "renamed_atom14_gt_positions": renamed_atom14_gt_positions,
        "renamed_atom14_gt_exists": renamed_atom14_gt_mask,
    }


def binding_site_loss(
    logits: torch.Tensor,
    binding_site_mask: torch.Tensor,
    protein_mask: torch.Tensor,
    pos_class_weight: float,
    **kwargs,
) -> torch.Tensor:
    logits = logits.squeeze(-1)
    bce_loss = torch.nn.functional.binary_cross_entropy_with_logits(
        logits,
        binding_site_mask,
        reduction='none',
        pos_weight=logits.new_tensor([pos_class_weight]),
    )
    masked_loss = bce_loss * protein_mask
    final_loss = masked_loss.sum() / protein_mask.sum()

    return final_loss


def chain_center_of_mass_loss(
    all_atom_pred_pos: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    asym_id: torch.Tensor,
    clamp_distance: float = -4.0,
    weight: float = 0.05,
    eps: float = 1e-10,
    **kwargs,
) -> torch.Tensor:
    """
    Computes chain centre-of-mass loss. Implements section 2.5, eqn 1 in the Multimer paper.

    Args:
        all_atom_pred_pos:
            [*, N_pts, 37, 3] All-atom predicted atom positions
        all_atom_positions:
            [*, N_pts, 37, 3] Ground truth all-atom positions
        all_atom_mask:
            [*, N_pts, 37] All-atom positions mask
        asym_id:
            [*, N_pts] Chain asym IDs
        clamp_distance:
            Cutoff above which distance errors are disregarded
        weight:
            Weight for loss
        eps:
            Small value used to regularize denominators
    Returns:
        [*] loss tensor
    """
    ca_pos = residue_constants.atom_order["CA"]
    all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
    all_atom_positions = all_atom_positions[..., ca_pos, :]
    all_atom_mask = all_atom_mask[..., ca_pos: (ca_pos + 1)]  # keep dim

    one_hot = torch.nn.functional.one_hot(asym_id.long()).to(dtype=all_atom_mask.dtype)
    one_hot = one_hot * all_atom_mask
    chain_pos_mask = one_hot.transpose(-2, -1)
    chain_exists = torch.any(chain_pos_mask, dim=-1).to(dtype=all_atom_positions.dtype)

    def get_chain_center_of_mass(pos):
        center_sum = (chain_pos_mask[..., None] * pos[..., None, :, :]).sum(dim=-2)
        centers = center_sum / (torch.sum(chain_pos_mask, dim=-1, keepdim=True) + eps)
        return Vec3Array.from_array(centers)

    pred_centers = get_chain_center_of_mass(all_atom_pred_pos)  # [B, NC, 3]
    true_centers = get_chain_center_of_mass(all_atom_positions)  # [B, NC, 3]

    pred_dists = euclidean_distance(pred_centers[..., None, :], pred_centers[..., :, None], epsilon=eps)
    true_dists = euclidean_distance(true_centers[..., None, :], true_centers[..., :, None], epsilon=eps)
    losses = torch.clamp((weight * (pred_dists - true_dists - clamp_distance)), max=0) ** 2
    loss_mask = chain_exists[..., :, None] * chain_exists[..., None, :]

    loss = masked_mean(loss_mask, losses, dim=(-1, -2))
    return loss


class AlphaFoldLoss(nn.Module):
    """Aggregation of the various losses described in the supplement"""

    def __init__(self, config):
        super(AlphaFoldLoss, self).__init__()
        self.config = config

    def loss(self, out, batch, _return_breakdown=False):
        """
        Rename previous forward() as loss()
        so that it can be reused in the subclass
        """
        if "renamed_atom14_gt_positions" not in out.keys():
            batch.update(
                compute_renamed_ground_truth(
                    batch,
                    out["sm"]["positions"][-1],
                )
            )

        loss_fns = {
            "distogram": lambda: distogram_loss(
                logits=out["distogram_logits"],
                **{**batch, **self.config.distogram},
            ),
            "positions_inter_distogram": lambda: positions_inter_distogram_loss(
                out,
                **{**batch, **self.config.positions_inter_distogram},
            ),
            "positions_intra_distogram": lambda: positions_intra_ligand_distogram_loss(
                out,
                **{**batch, **self.config.positions_intra_distogram},
            ),
            "affinity1d": lambda: affinity_loss(
                logits=out["affinity_1d_logits"],
                **{**batch, **self.config.affinity1d},
            ),
            "affinity2d": lambda: affinity_loss(
                logits=out["affinity_2d_logits"],
                **{**batch, **self.config.affinity2d},
            ),
            "affinity_cls": lambda: affinity_loss(
                logits=out["affinity_cls_logits"],
                **{**batch, **self.config.affinity_cls},
            ),
            "binding_site": lambda: binding_site_loss(
                logits=out["binding_site_logits"],
                **{**batch, **self.config.binding_site},
            ),
            "inter_contact": lambda: inter_contact_loss(
                logits=out["inter_contact_logits"],
                **{**batch, **self.config.inter_contact},
            ),
            # The backbone loss is based on frames, so it only applies to the protein
            "fape_backbone": lambda: fape_bb(
                out,
                batch,
                self.config.fape_backbone,
            ),
            "fape_sidechain": lambda: fape_sidechain(
                out,
                batch,
                self.config.fape_sidechain,
            ),
            "fape_interface": lambda: fape_interface(
                out,
                batch,
                self.config.fape_interface,
            ),
            "plddt_loss": lambda: lddt_loss(
                logits=out["lddt_logits"],
                all_atom_pred_pos=out["final_atom_positions"],
                **{**batch, **self.config.plddt_loss},
            ),
            "supervised_chi": lambda: supervised_chi_loss(
                out["sm"]["angles"],
                out["sm"]["unnormalized_angles"],
                **{**batch, **self.config.supervised_chi},
            ),
        }

        if self.config.chain_center_of_mass.enabled:
            loss_fns["chain_center_of_mass"] = lambda: chain_center_of_mass_loss(
                all_atom_pred_pos=out["final_atom_positions"],
                **{**batch, **self.config.chain_center_of_mass},
            )

        cum_loss = 0.
        losses = {}
        loss_time_took = {}
        for loss_name, loss_fn in loss_fns.items():
            start_time = time.time()
            weight = self.config[loss_name].weight
            loss = loss_fn()
            if torch.isnan(loss) or torch.isinf(loss):
                logging.warning(f"{loss_name} loss is NaN. Skipping...")
                loss = loss.new_tensor(0., requires_grad=True)
            cum_loss = cum_loss + weight * loss
            losses[loss_name] = loss.detach().clone()
            loss_time_took[loss_name] = time.time() - start_time
        losses["unscaled_loss"] = cum_loss.detach().clone()
        # print("loss took: ", round(time.time() % 10000, 3),
        #       sorted(loss_time_took.items(), key=lambda x: x[1], reverse=True))

        # Scale the loss by the square root of the minimum of the crop size and
        # the (average) sequence length. See subsection 1.9.
        seq_len = torch.mean(batch["seq_length"].float())
        crop_len = batch["aatype"].shape[-1]
        cum_loss = cum_loss * torch.sqrt(min(seq_len, crop_len))

        losses["loss"] = cum_loss.detach().clone()

        if not _return_breakdown:
            return cum_loss

        return cum_loss, losses

    def forward(self, out, batch, _return_breakdown=False):
        if not _return_breakdown:
            cum_loss = self.loss(out, batch, _return_breakdown)
            return cum_loss
        else:
            cum_loss, losses = self.loss(out, batch, _return_breakdown)
            return cum_loss, losses
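
The aggregation pattern above is worth noting: every term is computed, NaN/Inf terms are zeroed rather than crashing the step, and the weighted sum is rescaled by sqrt(min(seq_len, crop_len)) per AlphaFold supplement subsection 1.9. A minimal sketch of that control flow, with hypothetical weights and term values:

import math
import torch

terms = {"distogram": torch.tensor(2.3),
         "fape_backbone": torch.tensor(float("nan")),  # a bad step
         "supervised_chi": torch.tensor(0.7)}
weights = {"distogram": 0.3, "fape_backbone": 0.5, "supervised_chi": 1.0}

cum_loss = torch.tensor(0.)
for name, loss in terms.items():
    if torch.isnan(loss) or torch.isinf(loss):
        loss = loss.new_tensor(0.)   # skip the term, keep training
    cum_loss = cum_loss + weights[name] * loss

seq_len, crop_len = 350.0, 256       # average sequence length vs crop size
cum_loss = cum_loss * math.sqrt(min(seq_len, crop_len))
print(cum_loss)
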
dockformer/utils/lr_schedulers.py
ADDED
@@ -0,0 +1,82 @@
import torch


class AlphaFoldLRScheduler(torch.optim.lr_scheduler._LRScheduler):
    """ Implements the learning rate schedule defined in the AlphaFold 2
        supplement. A linear warmup is followed by a plateau at the maximum
        learning rate and then exponential decay.

        Note that the initial learning rate of the optimizer in question is
        ignored; use this class' base_lr parameter to specify the starting
        point of the warmup.
    """
    def __init__(self,
        optimizer,
        last_epoch: int = -1,
        verbose: bool = False,
        base_lr: float = 0.,
        max_lr: float = 0.001,
        warmup_no_steps: int = 1000,
        start_decay_after_n_steps: int = 50000,
        decay_every_n_steps: int = 50000,
        decay_factor: float = 0.95,
    ):
        step_counts = {
            "warmup_no_steps": warmup_no_steps,
            "start_decay_after_n_steps": start_decay_after_n_steps,
        }

        for k, v in step_counts.items():
            if v < 0:
                raise ValueError(f"{k} must be nonnegative")

        if warmup_no_steps > start_decay_after_n_steps:
            raise ValueError(
                "warmup_no_steps must not exceed start_decay_after_n_steps"
            )

        self.optimizer = optimizer
        self.last_epoch = last_epoch
        self.verbose = verbose
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.warmup_no_steps = warmup_no_steps
        self.start_decay_after_n_steps = start_decay_after_n_steps
        self.decay_every_n_steps = decay_every_n_steps
        self.decay_factor = decay_factor

        super(AlphaFoldLRScheduler, self).__init__(
            optimizer,
            last_epoch=last_epoch,
            verbose=verbose,
        )

    def state_dict(self):
        state_dict = {
            k: v for k, v in self.__dict__.items() if k not in ["optimizer"]
        }

        return state_dict

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            raise RuntimeError(
                "To get the last learning rate computed by the scheduler, use "
                "get_last_lr()"
            )

        step_no = self.last_epoch

        if step_no <= self.warmup_no_steps:
            lr = self.base_lr + (step_no / self.warmup_no_steps) * self.max_lr
        elif step_no > self.start_decay_after_n_steps:
            steps_since_decay = step_no - self.start_decay_after_n_steps
            exp = (steps_since_decay // self.decay_every_n_steps) + 1
            lr = self.max_lr * (self.decay_factor ** exp)
        else:  # plateau
            lr = self.max_lr

        return [lr for group in self.optimizer.param_groups]
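
A quick usage sketch: the optimizer and step counts below are hypothetical, shrunk so all three phases (warmup, plateau, decay) are reachable in a few iterations:

import torch

model = torch.nn.Linear(4, 4)
opt = torch.optim.Adam(model.parameters(), lr=0.0)  # initial lr is ignored
sched = AlphaFoldLRScheduler(
    opt, max_lr=0.001, warmup_no_steps=10,
    start_decay_after_n_steps=20, decay_every_n_steps=5, decay_factor=0.5,
)
for step in range(30):
    opt.step()   # normally after loss.backward()
    sched.step()
print(sched.get_last_lr())  # decayed well below max_lr by now
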
dockformer/utils/precision_utils.py
ADDED
@@ -0,0 +1,23 @@
# Copyright 2022 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib

import torch

def is_fp16_enabled():
    # Autocast world
    fp16_enabled = torch.get_autocast_gpu_dtype() == torch.float16
    fp16_enabled = fp16_enabled and torch.is_autocast_enabled()

    return fp16_enabled
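
A small behavioral check of the helper (assumes a CUDA build of PyTorch; torch.get_autocast_gpu_dtype reflects the dtype of the active GPU autocast region):

import torch

print(is_fp16_enabled())  # False outside any autocast region

if torch.cuda.is_available():
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        print(is_fp16_enabled())  # True: fp16 autocast is active
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        print(is_fp16_enabled())  # False: autocast is on, but dtype is bf16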