bshor committed
Commit: bca3a49
Parent: 4008e46
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. Dockerfile +7 -5
  2. dockformer/__init__.py +6 -0
  3. dockformer/config.py +358 -0
  4. dockformer/data/data_modules.py +643 -0
  5. dockformer/data/data_pipeline.py +503 -0
  6. dockformer/data/data_transforms.py +731 -0
  7. dockformer/data/errors.py +22 -0
  8. dockformer/data/ligand_features.py +66 -0
  9. dockformer/data/parsers.py +53 -0
  10. dockformer/data/protein_features.py +71 -0
  11. dockformer/data/utils.py +54 -0
  12. dockformer/model/__init__.py +0 -0
  13. dockformer/model/dropout.py +69 -0
  14. dockformer/model/embedders.py +346 -0
  15. dockformer/model/evoformer.py +468 -0
  16. dockformer/model/heads.py +260 -0
  17. dockformer/model/model.py +318 -0
  18. dockformer/model/pair_transition.py +81 -0
  19. dockformer/model/primitives.py +598 -0
  20. dockformer/model/single_attention.py +184 -0
  21. dockformer/model/structure_module.py +837 -0
  22. dockformer/model/torchscript.py +171 -0
  23. dockformer/model/triangular_attention.py +104 -0
  24. dockformer/model/triangular_multiplicative_update.py +173 -0
  25. dockformer/resources/__init__.py +0 -0
  26. dockformer/resources/stereo_chemical_props.txt +345 -0
  27. dockformer/utils/__init__.py +0 -0
  28. dockformer/utils/callbacks.py +15 -0
  29. dockformer/utils/checkpointing.py +78 -0
  30. dockformer/utils/config_tools.py +32 -0
  31. dockformer/utils/consts.py +25 -0
  32. dockformer/utils/exponential_moving_average.py +71 -0
  33. dockformer/utils/feats.py +174 -0
  34. dockformer/utils/geometry/__init__.py +28 -0
  35. dockformer/utils/geometry/quat_rigid.py +38 -0
  36. dockformer/utils/geometry/rigid_matrix_vector.py +181 -0
  37. dockformer/utils/geometry/rotation_matrix.py +208 -0
  38. dockformer/utils/geometry/test_utils.py +97 -0
  39. dockformer/utils/geometry/utils.py +22 -0
  40. dockformer/utils/geometry/vector.py +261 -0
  41. dockformer/utils/kernel/__init__.py +0 -0
  42. dockformer/utils/kernel/attention_core.py +107 -0
  43. dockformer/utils/kernel/csrc/compat.h +11 -0
  44. dockformer/utils/kernel/csrc/softmax_cuda.cpp +44 -0
  45. dockformer/utils/kernel/csrc/softmax_cuda_kernel.cu +241 -0
  46. dockformer/utils/kernel/csrc/softmax_cuda_stub.cpp +36 -0
  47. dockformer/utils/logger.py +80 -0
  48. dockformer/utils/loss.py +1171 -0
  49. dockformer/utils/lr_schedulers.py +82 -0
  50. dockformer/utils/precision_utils.py +23 -0
Dockerfile CHANGED
@@ -6,17 +6,19 @@ WORKDIR /usr/src/app
COPY --link --chown=1000 ./ /usr/src/app
COPY . .

-
# install dependencies
- RUN conda install -y pandas numpy scikit-learn
- RUN pip install --no-cache-dir -r requirements.txt
+ # RUN conda install -y pandas numpy scikit-learn
+ # RUN pip install --no-cache-dir -r requirements.txt

- #if you need to download executables and run them switch to the default non-root user
+ # Create Conda environment from env.yml
+ RUN conda env create -f env.yml

+ #if you need to download executables and run them switch to the default non-root user
USER user

#do not modify below
EXPOSE 7860
ENV GRADIO_SERVER_NAME="0.0.0.0"

- CMD ["python", "inference_app.py"]
+ # CMD ["python", "inference_app.py"]
+ CMD ["conda", "run", "-n", "dockformer-venv", "python", "inference_app.py"]
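Aside: the env.yml consumed by `conda env create` above is not part of this view, so the following is only an illustrative sketch of the shape such a file needs. The contents are hypothetical except for one hard requirement: the `name` field must match the `dockformer-venv` environment referenced in CMD.

name: dockformer-venv   # must match the env name used in CMD above
channels:
  - conda-forge
dependencies:
  - python=3.10         # placeholder version, not confirmed by this diff
  - pandas
  - numpy
  - scikit-learn
  - pip
  - pip:
      - -r requirements.txt   # reuse the repo's pip requirements

Running the app through `conda run -n dockformer-venv` ensures the CMD executes inside this environment rather than the base one.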
dockformer/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from . import model
+ from . import utils
+ from . import data
+ from . import resources
+
+ __all__ = ["model", "utils", "data", "resources"]
dockformer/config.py ADDED
@@ -0,0 +1,358 @@
+ import copy
+ import ml_collections as mlc
+
+ from dockformer.utils.config_tools import set_inf, enforce_config_constraints
+
+
+ def model_config(
+     name,
+     train=False,
+     low_prec=False,
+     long_sequence_inference=False
+ ):
+     c = copy.deepcopy(config)
+     # TRAINING PRESETS
+     if name == "initial_training":
+         # AF2 Suppl. Table 4, "initial training" setting
+         pass
+     elif name == "finetune_affinity":
+         c.loss.affinity2d.weight = 0.5
+         c.loss.affinity1d.weight = 0.5
+         c.loss.binding_site.weight = 0.5
+         c.loss.positions_inter_distogram.weight = 0.5  # this is not essential given fape?
+     else:
+         raise ValueError("Invalid model name")
+
+     c.globals.use_lma = False
+
+     if long_sequence_inference:
+         assert not train
+         c.globals.use_lma = True
+
+     if train:
+         c.globals.blocks_per_ckpt = 1
+         c.globals.use_lma = False
+
+     if low_prec:
+         c.globals.eps = 1e-4
+         # If we want exact numerical parity with the original, inf can't be
+         # a global constant
+         set_inf(c, 1e4)
+
+     enforce_config_constraints(c)
+
+     return c
+
+
+ c_z = mlc.FieldReference(128, field_type=int)
+ c_m = mlc.FieldReference(256, field_type=int)
+ c_t = mlc.FieldReference(64, field_type=int)
+ c_e = mlc.FieldReference(64, field_type=int)
+ c_s = mlc.FieldReference(384, field_type=int)
+
+ blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
+ aux_distogram_bins = mlc.FieldReference(64, field_type=int)
+ aux_affinity_bins = mlc.FieldReference(32, field_type=int)
+ eps = mlc.FieldReference(1e-8, field_type=float)
+
+ NUM_RES = "num residues placeholder"
+ NUM_LIG_ATOMS = "num ligand atoms placeholder"
+ NUM_TOKEN = "num tokens placeholder"
+
+
+ config = mlc.ConfigDict(
+     {
+         "data": {
+             "common": {
+                 "feat": {
+                     "aatype": [NUM_TOKEN],
+                     "all_atom_mask": [NUM_TOKEN, None],
+                     "all_atom_positions": [NUM_TOKEN, None, None],
+                     "atom14_alt_gt_exists": [NUM_TOKEN, None],
+                     "atom14_alt_gt_positions": [NUM_TOKEN, None, None],
+                     "atom14_atom_exists": [NUM_TOKEN, None],
+                     "atom14_atom_is_ambiguous": [NUM_TOKEN, None],
+                     "atom14_gt_exists": [NUM_TOKEN, None],
+                     "atom14_gt_positions": [NUM_TOKEN, None, None],
+                     "atom37_atom_exists": [NUM_TOKEN, None],
+                     "backbone_rigid_mask": [NUM_TOKEN],
+                     "backbone_rigid_tensor": [NUM_TOKEN, None, None],
+                     "chi_angles_sin_cos": [NUM_TOKEN, None, None],
+                     "chi_mask": [NUM_TOKEN, None],
+                     "no_recycling_iters": [],
+                     "pseudo_beta": [NUM_TOKEN, None],
+                     "pseudo_beta_mask": [NUM_TOKEN],
+                     "residue_index": [NUM_TOKEN],
+                     "in_chain_residue_index": [NUM_TOKEN],
+                     "chain_index": [NUM_TOKEN],
+                     "residx_atom14_to_atom37": [NUM_TOKEN, None],
+                     "residx_atom37_to_atom14": [NUM_TOKEN, None],
+                     "resolution": [],
+                     "rigidgroups_alt_gt_frames": [NUM_TOKEN, None, None, None],
+                     "rigidgroups_group_exists": [NUM_TOKEN, None],
+                     "rigidgroups_group_is_ambiguous": [NUM_TOKEN, None],
+                     "rigidgroups_gt_exists": [NUM_TOKEN, None],
+                     "rigidgroups_gt_frames": [NUM_TOKEN, None, None, None],
+                     "seq_length": [],
+                     "token_mask": [NUM_TOKEN],
+                     "target_feat": [NUM_TOKEN, None],
+                     "use_clamped_fape": [],
+                 },
+                 "max_recycling_iters": 1,
+                 "unsupervised_features": [
+                     "aatype",
+                     "residue_index",
+                     "in_chain_residue_index",
+                     "chain_index",
+                     "seq_length",
+                     "no_recycling_iters",
+                     "all_atom_mask",
+                     "all_atom_positions",
+                 ],
+             },
+             "supervised": {
+                 "clamp_prob": 0.9,
+                 "supervised_features": [
+                     "resolution",
+                     "use_clamped_fape",
+                 ],
+             },
+             "predict": {
+                 "fixed_size": True,
+                 "crop": False,
+                 "crop_size": None,
+                 "supervised": False,
+                 "uniform_recycling": False,
+             },
+             "eval": {
+                 "fixed_size": True,
+                 "crop": False,
+                 "crop_size": None,
+                 "supervised": True,
+                 "uniform_recycling": False,
+             },
+             "train": {
+                 "fixed_size": True,
+                 "crop": True,
+                 "crop_size": 355,
+                 "supervised": True,
+                 "clamp_prob": 0.9,
+                 "uniform_recycling": True,
+                 "protein_distogram_mask_prob": 0.1,
+             },
+             "data_module": {
+                 "data_loaders": {
+                     "batch_size": 1,
+                     # "batch_size": 2,
+                     "num_workers": 16,
+                     "pin_memory": True,
+                     "should_verify": False,
+                 },
+             },
+         },
+         # Recurring FieldReferences that can be changed globally here
+         "globals": {
+             "blocks_per_ckpt": blocks_per_ckpt,
+             # Use Staats & Rabe's low-memory attention algorithm.
+             "use_lma": False,
+             "max_lr": 1e-3,
+             "c_z": c_z,
+             "c_m": c_m,
+             "c_t": c_t,
+             "c_e": c_e,
+             "c_s": c_s,
+             "eps": eps,
+         },
+         "model": {
+             "_mask_trans": False,
+             "structure_input_embedder": {
+                 "protein_tf_dim": 20,
+                 # len(POSSIBLE_ATOM_TYPES) + len(POSSIBLE_CHARGES) + len(POSSIBLE_CHIRALITIES)
+                 "ligand_tf_dim": 34,
+                 "additional_tf_dim": 3,  # number of classes (prot, lig, aff)
+                 "ligand_bond_dim": 6,
+                 "c_z": c_z,
+                 "c_m": c_m,
+                 "relpos_k": 32,
+                 "prot_min_bin": 3.25,
+                 "prot_max_bin": 20.75,
+                 "prot_no_bins": 15,
+                 "lig_min_bin": 0.75,
+                 "lig_max_bin": 9.75,
+                 "lig_no_bins": 10,
+                 "inf": 1e8,
+             },
+             "recycling_embedder": {
+                 "c_z": c_z,
+                 "c_m": c_m,
+                 "min_bin": 3.25,
+                 "max_bin": 20.75,
+                 "no_bins": 15,
+                 "inf": 1e8,
+             },
+             "evoformer_stack": {
+                 "c_m": c_m,
+                 "c_z": c_z,
+                 "c_hidden_single_att": 32,
+                 "c_hidden_mul": 128,
+                 "c_hidden_pair_att": 32,
+                 "c_s": c_s,
+                 "no_heads_single": 8,
+                 "no_heads_pair": 4,
+                 # "no_blocks": 48,
+                 "no_blocks": 2,
+                 "transition_n": 4,
+                 "single_dropout": 0.15,
+                 "pair_dropout": 0.25,
+                 "blocks_per_ckpt": blocks_per_ckpt,
+                 "clear_cache_between_blocks": False,
+                 "inf": 1e9,
+                 "eps": eps,  # 1e-10,
+             },
+             "structure_module": {
+                 "c_s": c_s,
+                 "c_z": c_z,
+                 "c_ipa": 16,
+                 "c_resnet": 128,
+                 "no_heads_ipa": 12,
+                 "no_qk_points": 4,
+                 "no_v_points": 8,
+                 "dropout_rate": 0.1,
+                 "no_blocks": 8,
+                 "no_transition_layers": 1,
+                 "no_resnet_blocks": 2,
+                 "no_angles": 7,
+                 "trans_scale_factor": 10,
+                 "epsilon": eps,  # 1e-12,
+                 "inf": 1e5,
+             },
+             "heads": {
+                 "lddt": {
+                     "no_bins": 50,
+                     "c_in": c_s,
+                     "c_hidden": 128,
+                 },
+                 "distogram": {
+                     "c_z": c_z,
+                     "no_bins": aux_distogram_bins,
+                 },
+                 "affinity_2d": {
+                     "c_z": c_z,
+                     "num_bins": aux_affinity_bins,
+                 },
+                 "affinity_1d": {
+                     "c_s": c_s,
+                     "num_bins": aux_affinity_bins,
+                 },
+                 "affinity_cls": {
+                     "c_s": c_s,
+                     "num_bins": aux_affinity_bins,
+                 },
+                 "binding_site": {
+                     "c_s": c_s,
+                     "c_out": 1,
+                 },
+                 "inter_contact": {
+                     "c_s": c_s,
+                     "c_z": c_z,
+                     "c_out": 1,
+                 },
+             },
+             # A negative value indicates that no early stopping will occur, i.e.
+             # the model will always run `max_recycling_iters` number of recycling
+             # iterations. A positive value will enable early stopping if the
+             # difference in pairwise distances is less than the tolerance between
+             # recycling steps.
+             "recycle_early_stop_tolerance": -1.
+         },
+         "relax": {
+             "max_iterations": 0,  # no max
+             "tolerance": 2.39,
+             "stiffness": 10.0,
+             "max_outer_iterations": 20,
+             "exclude_residues": [],
+         },
+         "loss": {
+             "distogram": {
+                 "min_bin": 2.3125,
+                 "max_bin": 21.6875,
+                 "no_bins": 64,
+                 "eps": eps,  # 1e-6,
+                 "weight": 0.3,
+             },
+             "positions_inter_distogram": {
+                 "max_dist": 20.0,
+                 "weight": 0.0,
+             },
+             "positions_intra_distogram": {
+                 "max_dist": 10.0,
+                 "weight": 0.0,
+             },
+             "binding_site": {
+                 "weight": 0.0,
+                 "pos_class_weight": 20.0,
+             },
+             "inter_contact": {
+                 "weight": 0.0,
+                 "pos_class_weight": 200.0,
+             },
+             "affinity2d": {
+                 "min_bin": 0,
+                 "max_bin": 15,
+                 "no_bins": aux_affinity_bins,
+                 "weight": 0.0,
+             },
+             "affinity1d": {
+                 "min_bin": 0,
+                 "max_bin": 15,
+                 "no_bins": aux_affinity_bins,
+                 "weight": 0.0,
+             },
+             "affinity_cls": {
+                 "min_bin": 0,
+                 "max_bin": 15,
+                 "no_bins": aux_affinity_bins,
+                 "weight": 0.0,
+             },
+             "fape_backbone": {
+                 "clamp_distance": 10.0,
+                 "loss_unit_distance": 10.0,
+                 "weight": 0.5,
+             },
+             "fape_sidechain": {
+                 "clamp_distance": 10.0,
+                 "length_scale": 10.0,
+                 "weight": 0.5,
+             },
+             "fape_interface": {
+                 "clamp_distance": 10.0,
+                 "length_scale": 10.0,
+                 "weight": 0.0,
+             },
+             "plddt_loss": {
+                 "min_resolution": 0.1,
+                 "max_resolution": 3.0,
+                 "cutoff": 15.0,
+                 "no_bins": 50,
+                 "eps": eps,  # 1e-10,
+                 "weight": 0.01,
+             },
+             "supervised_chi": {
+                 "chi_weight": 0.5,
+                 "angle_norm_weight": 0.01,
+                 "eps": eps,  # 1e-6,
+                 "weight": 1.0,
+             },
+             "chain_center_of_mass": {
+                 "clamp_distance": -4.0,
+                 "weight": 0.,
+                 "eps": eps,
+                 "enabled": False,
+             },
+             "eps": eps,
+         },
+         "ema": {"decay": 0.999},
+     }
+ )
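A minimal usage sketch (not part of the commit) of how the presets above are consumed. Because `c_z`, `c_m`, `blocks_per_ckpt`, `eps`, etc. are `ml_collections.FieldReference` objects, the same value is threaded through every sub-config that names them, so channel widths stay consistent by construction; the propagation asserts below should hold under that shared-reference behavior.

from dockformer.config import model_config

# Training preset; low_prec=True shrinks eps to 1e-4 and replaces the
# per-module "inf" constants with 1e4 via set_inf, as in the code above.
cfg = model_config("initial_training", train=True, low_prec=True)

# The shared c_z reference resolves to the same value everywhere.
assert cfg.model.evoformer_stack.c_z == cfg.model.heads.distogram.c_z == 128

# train=True writes through the blocks_per_ckpt reference, so the
# evoformer stack sees the update too.
assert cfg.globals.blocks_per_ckpt == 1
assert cfg.model.evoformer_stack.blocks_per_ckpt == 1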
dockformer/data/data_modules.py ADDED
@@ -0,0 +1,643 @@
+ import copy
+ import itertools
+ import time
+ import traceback
+ from collections import Counter
+ from functools import partial
+ import json
+ import os
+ import pickle
+ from typing import Optional, Sequence, Any
+
+ import ml_collections as mlc
+ import lightning as L
+ import torch
+ from torch.utils.data import RandomSampler
+
+ from dockformer.data.data_pipeline import parse_input_json
+ from dockformer.data import data_pipeline
+ from dockformer.utils.tensor_utils import dict_multimap
+ from dockformer.utils.tensor_utils import (
+     tensor_tree_map,
+ )
+
+
+ class OpenFoldSingleDataset(torch.utils.data.Dataset):
+     def __init__(self,
+                  data_dir: str,
+                  config: mlc.ConfigDict,
+                  mode: str = "train",
+                  ):
+         """
+         Args:
+             data_dir:
+                 A path to a directory of per-sample JSON input files
+                 (every *.json file in the directory is listed).
+             config:
+                 A dataset config object. See dockformer.config
+             mode:
+                 "train", "eval", or "predict"
+         """
+         super(OpenFoldSingleDataset, self).__init__()
+         self.data_dir = data_dir
+
+         self.config = config
+         self.mode = mode
+
+         valid_modes = ["train", "eval", "predict"]
+         if mode not in valid_modes:
+             raise ValueError(f'mode must be one of {valid_modes}')
+
+         self._all_input_files = [i for i in os.listdir(data_dir) if i.endswith(".json")]
+         if self.config.data_module.data_loaders.should_verify:
+             self._all_input_files = [i for i in self._all_input_files if self._verify_json_input_file(i)]
+
+         self.data_pipeline = data_pipeline.DataPipeline(config, mode)
+
+     def _verify_json_input_file(self, file_name: str) -> bool:
+         with open(os.path.join(self.data_dir, file_name), "r") as f:
+             try:
+                 loaded = json.load(f)
+                 for i in ["input_structure"]:
+                     if i not in loaded:
+                         return False
+                 if self.mode != "predict":
+                     for i in ["gt_structure", "resolution"]:
+                         if i not in loaded:
+                             return False
+             except json.JSONDecodeError:
+                 return False
+         return True
+
+     def get_metadata_for_idx(self, idx: int) -> dict:
+         input_path = os.path.join(self.data_dir, self._all_input_files[idx])
+         input_data = json.load(open(input_path, "r"))
+         metadata = {
+             "resolution": input_data.get("resolution", 99.0),
+             "input_path": input_path,
+             "input_name": os.path.basename(input_path).split(".json")[0],
+         }
+         return metadata
+
+     def __getitem__(self, idx):
+         return parse_input_json(
+             input_path=os.path.join(self.data_dir, self._all_input_files[idx]),
+             mode=self.mode,
+             config=self.config,
+             data_pipeline=self.data_pipeline,
+             data_dir=os.path.dirname(self.data_dir),
+             idx=idx,
+         )
+
+     def __len__(self):
+         return len(self._all_input_files)
+
+
+ def resolution_filter(resolution: int, max_resolution: float) -> bool:
+     """Check that the resolution is <= the max_resolution permitted"""
+     return resolution is not None and resolution <= max_resolution
+
+
+ def all_seq_len_filter(seqs: list, minimum_number_of_residues: int) -> bool:
+     """Check that the total combined sequence length is >= minimum_number_of_residues"""
+     total_len = sum([len(i) for i in seqs])
+     return total_len >= minimum_number_of_residues
+
+
+ class OpenFoldDataset(torch.utils.data.Dataset):
+     """
+     Implements the stochastic filters applied during AlphaFold's training.
+     Because samples are selected from constituent datasets randomly, the
+     length of an OpenFoldDataset is arbitrary. Samples are selected
+     and filtered once at initialization.
+     """
+
+     def __init__(self,
+                  datasets: Sequence[OpenFoldSingleDataset],
+                  probabilities: Sequence[float],
+                  epoch_len: int,
+                  generator: torch.Generator = None,
+                  _roll_at_init: bool = True,
+                  ):
+         self.datasets = datasets
+         self.probabilities = probabilities
+         self.epoch_len = epoch_len
+         self.generator = generator
+
+         self._samples = [self.looped_samples(i) for i in range(len(self.datasets))]
+         if _roll_at_init:
+             self.reroll()
+
+     @staticmethod
+     def deterministic_train_filter(
+         cache_entry: Any,
+         max_resolution: float = 9.,
+         max_single_aa_prop: float = 0.8,
+         *args, **kwargs
+     ) -> bool:
+         # Hard filters
+         resolution = cache_entry["resolution"]
+
+         return all([
+             resolution_filter(resolution=resolution,
+                               max_resolution=max_resolution)
+         ])
+
+     @staticmethod
+     def get_stochastic_train_filter_prob(
+         cache_entry: Any,
+         *args, **kwargs
+     ) -> float:
+         # Stochastic filters
+         probabilities = []
+
+         cluster_size = cache_entry.get("cluster_size", None)
+         if cluster_size is not None and cluster_size > 0:
+             probabilities.append(1 / cluster_size)
+
+         # Risk of underflow here?
+         out = 1
+         for p in probabilities:
+             out *= p
+
+         return out
+
+     def looped_shuffled_dataset_idx(self, dataset_len):
+         while True:
+             # Uniformly shuffle each dataset's indices
+             weights = [1. for _ in range(dataset_len)]
+             shuf = torch.multinomial(
+                 torch.tensor(weights),
+                 num_samples=dataset_len,
+                 replacement=False,
+                 generator=self.generator,
+             )
+             for idx in shuf:
+                 yield idx
+
+     def looped_samples(self, dataset_idx):
+         max_cache_len = int(self.epoch_len * self.probabilities[dataset_idx])
+         dataset = self.datasets[dataset_idx]
+         idx_iter = self.looped_shuffled_dataset_idx(len(dataset))
+         while True:
+             weights = []
+             idx = []
+             for _ in range(max_cache_len):
+                 candidate_idx = next(idx_iter)
+                 entry_metadata_for_filter = dataset.get_metadata_for_idx(candidate_idx.item())
+                 if not self.deterministic_train_filter(entry_metadata_for_filter):
+                     continue
+
+                 p = self.get_stochastic_train_filter_prob(
+                     entry_metadata_for_filter,
+                 )
+                 weights.append([1. - p, p])
+                 idx.append(candidate_idx)
+
+             samples = torch.multinomial(
+                 torch.tensor(weights),
+                 num_samples=1,
+                 generator=self.generator,
+             )
+             samples = samples.squeeze()
+
+             cache = [i for i, s in zip(idx, samples) if s]
+
+             for datapoint_idx in cache:
+                 yield datapoint_idx
+
+     def __getitem__(self, idx):
+         dataset_idx, datapoint_idx = self.datapoints[idx]
+         return self.datasets[dataset_idx][datapoint_idx]
+
+     def __len__(self):
+         return self.epoch_len
+
+     def reroll(self):
+         # TODO bshor: I have removed support for filters (currently done in preprocess) and for weighting clusters
+         # now it is much faster, because it doesn't call looped_samples
+         dataset_choices = torch.multinomial(
+             torch.tensor(self.probabilities),
+             num_samples=self.epoch_len,
+             replacement=True,
+             generator=self.generator,
+         )
+         self.datapoints = []
+         counter_datasets = Counter(dataset_choices.tolist())
+         for dataset_idx, num_samples in counter_datasets.items():
+             dataset = self.datasets[dataset_idx]
+             sample_choices = torch.randint(0, len(dataset), (num_samples,), generator=self.generator)
+             for datapoint_idx in sample_choices:
+                 self.datapoints.append((dataset_idx, datapoint_idx))
+
+
+ class OpenFoldBatchCollator:
+     def __call__(self, prots):
+         stack_fn = partial(torch.stack, dim=0)
+         return dict_multimap(stack_fn, prots)
+
+
+ class OpenFoldDataLoader(torch.utils.data.DataLoader):
+     def __init__(self, *args, config, stage="train", generator=None, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.config = config
+         self.stage = stage
+         self.generator = generator
+         self._prep_batch_properties_probs()
+
+     def _prep_batch_properties_probs(self):
+         keyed_probs = []
+         stage_cfg = self.config[self.stage]
+
+         max_iters = self.config.common.max_recycling_iters
+
+         if stage_cfg.uniform_recycling:
+             recycling_probs = [
+                 1. / (max_iters + 1) for _ in range(max_iters + 1)
+             ]
+         else:
+             recycling_probs = [
+                 0. for _ in range(max_iters + 1)
+             ]
+             recycling_probs[-1] = 1.
+
+         keyed_probs.append(
+             ("no_recycling_iters", recycling_probs)
+         )
+
+         keys, probs = zip(*keyed_probs)
+         max_len = max([len(p) for p in probs])
+         padding = [[0.] * (max_len - len(p)) for p in probs]
+
+         self.prop_keys = keys
+         self.prop_probs_tensor = torch.tensor(
+             [p + pad for p, pad in zip(probs, padding)],
+             dtype=torch.float32,
+         )
+
+     def _add_batch_properties(self, batch):
+         samples = torch.multinomial(
+             self.prop_probs_tensor,
+             num_samples=1,  # 1 per row
+             replacement=True,
+             generator=self.generator
+         )
+
+         aatype = batch["aatype"]
+         batch_dims = aatype.shape[:-2]
+         recycling_dim = aatype.shape[-1]
+         no_recycling = recycling_dim
+         for i, key in enumerate(self.prop_keys):
+             sample = int(samples[i][0])
+             sample_tensor = torch.tensor(
+                 sample,
+                 device=aatype.device,
+                 requires_grad=False
+             )
+             orig_shape = sample_tensor.shape
+             sample_tensor = sample_tensor.view(
+                 (1,) * len(batch_dims) + sample_tensor.shape + (1,)
+             )
+             sample_tensor = sample_tensor.expand(
+                 batch_dims + orig_shape + (recycling_dim,)
+             )
+             batch[key] = sample_tensor
+
+             if key == "no_recycling_iters":
+                 no_recycling = sample
+
+         resample_recycling = lambda t: t[..., :no_recycling + 1]
+         batch = tensor_tree_map(resample_recycling, batch)
+
+         return batch
+
+     def __iter__(self):
+         it = super().__iter__()
+
+         def _batch_prop_gen(iterator):
+             for batch in iterator:
+                 yield self._add_batch_properties(batch)
+
+         return _batch_prop_gen(it)
+
+
+ class OpenFoldDataModule(L.LightningDataModule):
+     def __init__(self,
+                  config: mlc.ConfigDict,
+                  train_data_dir: Optional[str] = None,
+                  val_data_dir: Optional[str] = None,
+                  predict_data_dir: Optional[str] = None,
+                  batch_seed: Optional[int] = None,
+                  train_epoch_len: int = 50000,
+                  **kwargs
+                  ):
+         super(OpenFoldDataModule, self).__init__()
+
+         self.config = config
+         self.train_data_dir = train_data_dir
+         self.val_data_dir = val_data_dir
+         self.predict_data_dir = predict_data_dir
+         self.batch_seed = batch_seed
+         self.train_epoch_len = train_epoch_len
+
+         if self.train_data_dir is None and self.predict_data_dir is None:
+             raise ValueError(
+                 'At least one of train_data_dir or predict_data_dir must be '
+                 'specified'
+             )
+
+         self.training_mode = self.train_data_dir is not None
+
+     def setup(self, stage):
+         # Most of the arguments are the same for the three datasets
+         dataset_gen = partial(OpenFoldSingleDataset,
+                               config=self.config)
+
+         if self.training_mode:
+             train_dataset = dataset_gen(
+                 data_dir=self.train_data_dir,
+                 mode="train",
+             )
+
+             datasets = [train_dataset]
+             probabilities = [1.]
+
+             generator = None
+             if self.batch_seed is not None:
+                 generator = torch.Generator()
+                 generator = generator.manual_seed(self.batch_seed + 1)
+
+             self.train_dataset = OpenFoldDataset(
+                 datasets=datasets,
+                 probabilities=probabilities,
+                 epoch_len=self.train_epoch_len,
+                 generator=generator,
+                 _roll_at_init=False,
+             )
+
+             if self.val_data_dir is not None:
+                 self.eval_dataset = dataset_gen(
+                     data_dir=self.val_data_dir,
+                     mode="eval",
+                 )
+             else:
+                 self.eval_dataset = None
+         else:
+             self.predict_dataset = dataset_gen(
+                 data_dir=self.predict_data_dir,
+                 mode="predict",
+             )
+
+     def _gen_dataloader(self, stage):
+         generator = None
+         if self.batch_seed is not None:
+             generator = torch.Generator()
+             generator = generator.manual_seed(self.batch_seed)
+
+         if stage == "train":
+             dataset = self.train_dataset
+             # Resample the dataset, if necessary
+             dataset.reroll()
+         elif stage == "eval":
+             dataset = self.eval_dataset
+         elif stage == "predict":
+             dataset = self.predict_dataset
+         else:
+             raise ValueError("Invalid stage")
+
+         batch_collator = OpenFoldBatchCollator()
+
+         dl = OpenFoldDataLoader(
+             dataset,
+             config=self.config,
+             stage=stage,
+             generator=generator,
+             batch_size=self.config.data_module.data_loaders.batch_size,
+             # num_workers=self.config.data_module.data_loaders.num_workers,
+             num_workers=0,  # TODO bshor: solve generator pickling issue and then bring back num_workers, or just remove generator
+             collate_fn=batch_collator,
+         )
+
+         return dl
+
+     def train_dataloader(self):
+         return self._gen_dataloader("train")
+
+     def val_dataloader(self):
+         if self.eval_dataset is not None:
+             return self._gen_dataloader("eval")
+         return None
+
+     def predict_dataloader(self):
+         return self._gen_dataloader("predict")
+
+
+ class DummyDataset(torch.utils.data.Dataset):
+     def __init__(self, batch_path):
+         with open(batch_path, "rb") as f:
+             self.batch = pickle.load(f)
+
+     def __getitem__(self, idx):
+         return copy.deepcopy(self.batch)
+
+     def __len__(self):
+         return 1000
+
+
+ class DummyDataLoader(L.LightningDataModule):
+     def __init__(self, batch_path):
+         super().__init__()
+         self.dataset = DummyDataset(batch_path)
+
+     def train_dataloader(self):
+         return torch.utils.data.DataLoader(self.dataset)
+
+
+ class DockFormerSimpleDataset(torch.utils.data.Dataset):
+     def __init__(self, clusters_json: str, config: mlc.ConfigDict, mode: str = "train"):
+         clusters = json.load(open(clusters_json, "r"))
+         self.config = config
+         self.mode = mode
+         self._data_dir = os.path.dirname(clusters_json)
+         print("Data dir", self._data_dir)
+         self._clusters = clusters
+         self._all_input_files = sum(clusters.values(), [])
+         self.data_pipeline = data_pipeline.DataPipeline(config, mode)
+
+     def __getitem__(self, idx):
+         return parse_input_json(
+             input_path=os.path.join(self._data_dir, self._all_input_files[idx]),
+             mode=self.mode,
+             config=self.config,
+             data_pipeline=self.data_pipeline,
+             data_dir=self._data_dir,
+             idx=idx,
+         )
+
+     def __len__(self):
+         return len(self._all_input_files)
+
+
+ class DockFormerClusteredDataset(torch.utils.data.Dataset):
+     def __init__(self, clusters_json: str, config: mlc.ConfigDict, mode: str = "train", generator=None):
+         clusters = json.load(open(clusters_json, "r"))
+         self.config = config
+         self.mode = mode
+         self._data_dir = os.path.dirname(clusters_json)
+         self._clusters = list(clusters.values())
+         self.data_pipeline = data_pipeline.DataPipeline(config, mode)
+         self._generator = generator
+
+     def __getitem__(self, idx):
+         try:
+             cluster = self._clusters[idx]
+             # choose a random member of the cluster
+             input_file = cluster[torch.randint(0, len(cluster), (1,), generator=self._generator).item()]
+
+             return parse_input_json(
+                 input_path=os.path.join(self._data_dir, input_file),
+                 mode=self.mode,
+                 config=self.config,
+                 data_pipeline=self.data_pipeline,
+                 data_dir=self._data_dir,
+                 idx=idx,
+             )
+         except Exception as e:
+             # fall back to the first sample of the first cluster on failure
+             print("ERROR in loading", e)
+             traceback.print_exc()
+             return parse_input_json(
+                 input_path=os.path.join(self._data_dir, self._clusters[0][0]),
+                 mode=self.mode,
+                 config=self.config,
+                 data_pipeline=self.data_pipeline,
+                 data_dir=self._data_dir,
+                 idx=idx,
+             )
+
+     def __len__(self):
+         return len(self._clusters)
+
+
+ class DockFormerDataLoader(torch.utils.data.DataLoader):
+     def __init__(self, *args, config, stage="train", generator=None, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.config = config
+         self.stage = stage
+
+     def _add_batch_properties(self, batch):
+         if self.config[self.stage].uniform_recycling:
+             aatype = batch["aatype"]
+             max_recycling_dim = aatype.shape[-1]
+
+             num_recycles = torch.randint(0, max_recycling_dim, (1,)).item()
+
+             resample_recycling = lambda t: t[..., :num_recycles + 1]
+             batch = tensor_tree_map(resample_recycling, batch)
+
+         return batch
+
+     def __iter__(self):
+         it = super().__iter__()
+
+         def _batch_prop_gen(iterator):
+             for batch in iterator:
+                 yield self._add_batch_properties(batch)
+
+         return _batch_prop_gen(it)
+
+
+ class DockFormerDataModule(L.LightningDataModule):
+     def __init__(self,
+                  config: mlc.ConfigDict,
+                  train_data_file: Optional[str] = None,
+                  val_data_file: Optional[str] = None,
+                  batch_seed: Optional[int] = None,
+                  **kwargs
+                  ):
+         super(DockFormerDataModule, self).__init__()
+
+         self.config = config
+         self.train_data_file = train_data_file
+         self.val_data_file = val_data_file
+         self.batch_seed = batch_seed
+
+         assert self.train_data_file is not None, "train_data_file must be specified"
+         assert self.val_data_file is not None, "val_data_file must be specified"
+
+         self.train_dataset = None
+         self.val_dataset = None
+
+     def setup(self, stage):
+         generator = None
+         if self.batch_seed is not None:
+             generator = torch.Generator()
+             generator = generator.manual_seed(self.batch_seed + 1)
+
+         self.train_dataset = DockFormerClusteredDataset(
+             clusters_json=self.train_data_file,
+             config=self.config,
+             mode="train",
+             generator=generator,
+         )
+
+         self.val_dataset = DockFormerSimpleDataset(
+             clusters_json=self.val_data_file,
+             config=self.config,
+             mode="eval",
+         )
+
+     def _gen_dataloader(self, stage):
+         generator = None
+         if self.batch_seed is not None:
+             generator = torch.Generator()
+             generator = generator.manual_seed(self.batch_seed)
+
+         should_shuffle = stage == "train"
+         if stage == "train":
+             dataset = self.train_dataset
+         elif stage == "eval":
+             dataset = self.val_dataset
+         else:
+             raise ValueError("Invalid stage")
+
+         batch_collator = OpenFoldBatchCollator()
+
+         dl = DockFormerDataLoader(
+             dataset,
+             config=self.config,
+             stage=stage,
+             batch_size=self.config.data_module.data_loaders.batch_size,
+             # num_workers=self.config.data_module.data_loaders.num_workers,
+             num_workers=0,  # TODO bshor: solve generator pickling issue and then bring back num_workers, or just remove generator
+             collate_fn=batch_collator,
+             shuffle=should_shuffle,
+         )
+
+         return dl
+
+     def train_dataloader(self):
+         return self._gen_dataloader("train")
+
+     def val_dataloader(self):
+         if self.val_dataset is not None:
+             return self._gen_dataloader("eval")
+         return None
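A hedged wiring sketch (not in this commit; the file paths are placeholders) of how DockFormerDataModule is presumably driven. setup builds a cluster-sampling train set and a flat validation set from the two cluster JSON files, and the loaders then apply the uniform-recycling trimming from DockFormerDataLoader:

import lightning as L
from dockformer.config import model_config
from dockformer.data.data_modules import DockFormerDataModule

cfg = model_config("initial_training", train=True)

dm = DockFormerDataModule(
    config=cfg.data,                             # the "data" sub-config
    train_data_file="data/train_clusters.json",  # placeholder path
    val_data_file="data/val_clusters.json",      # placeholder path
    batch_seed=42,
)
dm.setup(stage="fit")
# Each batch is a dict of stacked tensors; every feature carries a trailing
# recycling dimension prepared by parse_input_json and trimmed in the loader
# when uniform_recycling is enabled for the stage.
batch = next(iter(dm.train_dataloader()))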
dockformer/data/data_pipeline.py ADDED
@@ -0,0 +1,503 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import json
16
+ import os
17
+ import time
18
+ from typing import List
19
+
20
+ import numpy as np
21
+ import torch
22
+ import ml_collections as mlc
23
+ from rdkit import Chem
24
+
25
+ from dockformer.data import data_transforms
26
+ from dockformer.data.data_transforms import get_restype_atom37_mask, get_restypes
27
+ from dockformer.data.ligand_features import make_ligand_features
28
+ from dockformer.data.protein_features import make_protein_features
29
+ from dockformer.data.utils import FeatureTensorDict, FeatureDict
30
+ from dockformer.utils import protein
31
+
32
+
33
+ def _np_filter_and_to_tensor_dict(np_example: FeatureDict, features_to_keep: List[str]) -> FeatureTensorDict:
34
+ """Creates dict of tensors from a dict of NumPy arrays.
35
+
36
+ Args:
37
+ np_example: A dict of NumPy feature arrays.
38
+ features: A list of strings of feature names to be returned in the dataset.
39
+
40
+ Returns:
41
+ A dictionary of features mapping feature names to features. Only the given
42
+ features are returned, all other ones are filtered out.
43
+ """
44
+ # torch generates warnings if feature is already a torch Tensor
45
+ to_tensor = lambda t: torch.tensor(t) if type(t) != torch.Tensor else t.clone().detach()
46
+ tensor_dict = {
47
+ k: to_tensor(v) for k, v in np_example.items() if k in features_to_keep
48
+ }
49
+
50
+ return tensor_dict
51
+
52
+
53
+ def _add_protein_probablistic_features(features: FeatureDict, cfg: mlc.ConfigDict, mode: str) -> FeatureDict:
54
+ if mode == "train":
55
+ p = torch.rand(1).item()
56
+ use_clamped_fape_value = float(p < cfg.supervised.clamp_prob)
57
+ features["use_clamped_fape"] = np.float32(use_clamped_fape_value)
58
+ else:
59
+ features["use_clamped_fape"] = np.float32(0.0)
60
+ return features
61
+
62
+
63
+ @data_transforms.curry1
64
+ def compose(x, fs):
65
+ for f in fs:
66
+ x = f(x)
67
+ return x
68
+
69
+
70
+ def _apply_protein_transforms(tensors: FeatureTensorDict) -> FeatureTensorDict:
71
+ transforms = [
72
+ data_transforms.cast_to_64bit_ints,
73
+ data_transforms.squeeze_features,
74
+ data_transforms.make_atom14_masks,
75
+ data_transforms.make_atom14_positions,
76
+ data_transforms.atom37_to_frames,
77
+ data_transforms.atom37_to_torsion_angles(""),
78
+ data_transforms.make_pseudo_beta(),
79
+ data_transforms.get_backbone_frames,
80
+ data_transforms.get_chi_angles,
81
+ ]
82
+
83
+ tensors = compose(transforms)(tensors)
84
+
85
+ return tensors
86
+
87
+
88
+ def _apply_protein_probablistic_transforms(tensors: FeatureTensorDict, cfg: mlc.ConfigDict, mode: str) \
89
+ -> FeatureTensorDict:
90
+ transforms = [data_transforms.make_target_feat()]
91
+
92
+ crop_feats = dict(cfg.common.feat)
93
+
94
+ if cfg[mode].fixed_size:
95
+ transforms.append(data_transforms.select_feat(list(crop_feats)))
96
+ # TODO bshor: restore transforms for training on cropped proteins, need to handle pocket somehow
97
+ # if so, look for random_crop_to_size and make_fixed_size in data_transforms.py
98
+
99
+ compose(transforms)(tensors)
100
+
101
+ return tensors
102
+
103
+
104
+ class DataPipeline:
105
+ """Assembles input features."""
106
+ def __init__(self, config: mlc.ConfigDict, mode: str):
107
+ self.config = config
108
+ self.mode = mode
109
+
110
+ self.feature_names = config.common.unsupervised_features
111
+ if config[mode].supervised:
112
+ self.feature_names += config.supervised.supervised_features
113
+
114
+ def process_pdb(self, pdb_path: str) -> FeatureTensorDict:
115
+ """
116
+ Assembles features for a protein in a PDB file.
117
+ """
118
+ with open(pdb_path, 'r') as f:
119
+ pdb_str = f.read()
120
+
121
+ protein_object = protein.from_pdb_string(pdb_str)
122
+ description = os.path.splitext(os.path.basename(pdb_path))[0].upper()
123
+ pdb_feats = make_protein_features(protein_object, description)
124
+ pdb_feats = _add_protein_probablistic_features(pdb_feats, self.config, self.mode)
125
+
126
+ tensor_feats = _np_filter_and_to_tensor_dict(pdb_feats, self.feature_names)
127
+
128
+ tensor_feats = _apply_protein_transforms(tensor_feats)
129
+ tensor_feats = _apply_protein_probablistic_transforms(tensor_feats, self.config, self.mode)
130
+
131
+ return tensor_feats
132
+
133
+ def process_smiles(self, smiles: str) -> FeatureTensorDict:
134
+ ligand = Chem.MolFromSmiles(smiles)
135
+ return make_ligand_features(ligand)
136
+
137
+ def process_mol2(self, mol2_path: str) -> FeatureTensorDict:
138
+ """
139
+ Assembles features for a ligand in a mol2 file.
140
+ """
141
+ ligand = Chem.MolFromMol2File(mol2_path)
142
+ assert ligand is not None, f"Failed to parse ligand from {mol2_path}"
143
+
144
+ conf = ligand.GetConformer()
145
+ positions = torch.tensor(conf.GetPositions())
146
+
147
+ return {
148
+ **make_ligand_features(ligand),
149
+ "gt_ligand_positions": positions.float()
150
+ }
151
+
152
+ def process_sdf(self, sdf_path: str) -> FeatureTensorDict:
153
+ """
154
+ Assembles features for a ligand in a mol2 file.
155
+ """
156
+ ligand = Chem.MolFromMolFile(sdf_path)
157
+ assert ligand is not None, f"Failed to parse ligand from {sdf_path}"
158
+
159
+ conf = ligand.GetConformer(0)
160
+ positions = torch.tensor(conf.GetPositions())
161
+
162
+ return {
163
+ **make_ligand_features(ligand),
164
+ "ligand_positions": positions.float()
165
+ }
166
+
167
+ def process_sdf_list(self, sdf_path_list: List[str]) -> FeatureTensorDict:
168
+ all_sdf_feats = [self.process_sdf(sdf_path) for sdf_path in sdf_path_list]
169
+
170
+ all_sizes = [sdf_feats["ligand_target_feat"].shape[0] for sdf_feats in all_sdf_feats]
171
+
172
+ joined_ligand_feats = {}
173
+ for k in all_sdf_feats[0].keys():
174
+ if k == "ligand_positions":
175
+ joined_positions = all_sdf_feats[0][k]
176
+ prev_offset = joined_positions.max(dim=0).values + 100
177
+
178
+ for i, sdf_feats in enumerate(all_sdf_feats[1:]):
179
+ offset = prev_offset - sdf_feats[k].min(dim=0).values
180
+ joined_positions = torch.cat([joined_positions, sdf_feats[k] + offset], dim=0)
181
+ prev_offset = joined_positions.max(dim=0).values + 100
182
+ joined_ligand_feats[k] = joined_positions
183
+ elif k in ["ligand_target_feat", "ligand_atype", "ligand_charge", "ligand_chirality", "ligand_bonds"]:
184
+ joined_ligand_feats[k] = torch.cat([sdf_feats[k] for sdf_feats in all_sdf_feats], dim=0)
185
+ if k == "ligand_target_feat":
186
+ joined_ligand_feats["ligand_idx"] = torch.cat([torch.full((sdf_feats[k].shape[0],), i)
187
+ for i, sdf_feats in enumerate(all_sdf_feats)], dim=0)
188
+ elif k == "ligand_bonds":
189
+ joined_ligand_feats["ligand_bonds_idx"] = torch.cat([torch.full((sdf_feats[k].shape[0],), i)
190
+ for i, sdf_feats in enumerate(all_sdf_feats)],
191
+ dim=0)
192
+ elif k == "ligand_bonds_feat":
193
+ joined_feature = torch.zeros((sum(all_sizes), sum(all_sizes), all_sdf_feats[0][k].shape[2]))
194
+ for i, sdf_feats in enumerate(all_sdf_feats):
195
+ start_idx = sum(all_sizes[:i])
196
+ end_idx = sum(all_sizes[:i + 1])
197
+ joined_feature[start_idx:end_idx, start_idx:end_idx, :] = sdf_feats[k]
198
+ joined_ligand_feats[k] = joined_feature
199
+ else:
200
+ raise ValueError(f"Unknown key in sdf list features {k}")
201
+ return joined_ligand_feats
202
+
203
+ def get_matching_positions_list(self, ref_path_list: List[str], gt_path_list: List[str]):
204
+ joined_gt_positions = []
205
+
206
+ for ref_ligand_path, gt_ligand_path in zip(ref_path_list, gt_path_list):
207
+ ref_ligand = Chem.MolFromMolFile(ref_ligand_path)
208
+ gt_ligand = Chem.MolFromMolFile(gt_ligand_path)
209
+
210
+ gt_original_positions = gt_ligand.GetConformer(0).GetPositions()
211
+
212
+ gt_positions = [gt_original_positions[idx] for idx in gt_ligand.GetSubstructMatch(ref_ligand)]
213
+
214
+ joined_gt_positions.extend(gt_positions)
215
+
216
+ return torch.tensor(np.array(joined_gt_positions)).float()
217
+
218
+ def get_matching_positions(self, ref_ligand_path: str, gt_ligand_path: str):
219
+ ref_ligand = Chem.MolFromMolFile(ref_ligand_path)
220
+ gt_ligand = Chem.MolFromMolFile(gt_ligand_path)
221
+
222
+ gt_original_positions = gt_ligand.GetConformer(0).GetPositions()
223
+
224
+ gt_positions = [gt_original_positions[idx] for idx in gt_ligand.GetSubstructMatch(ref_ligand)]
225
+
226
+ # ref_positions = ref_ligand.GetConformer(0).GetPositions()
227
+ # for i in range(len(ref_positions)):
228
+ # for j in range(i + 1, len(ref_positions)):
229
+ # dist_ref = np.linalg.norm(ref_positions[i] - ref_positions[j])
230
+ # dist_gt = np.linalg.norm(gt_positions[i] - gt_positions[j])
231
+ # dist_gt = np.linalg.norm(gt_original_positions[i] - gt_original_positions[j])
232
+ # if abs(dist_ref - dist_gt) > 1.0:
233
+ # print(f"Distance mismatch {i} {j} {dist_ref} {dist_gt}")
234
+
235
+ return torch.tensor(np.array(gt_positions)) .float()
236
+
237
+
238
+ def _prepare_recycles(feat: torch.Tensor, num_recycles: int) -> torch.Tensor:
239
+ return feat.unsqueeze(-1).repeat(*([1] * len(feat.shape)), num_recycles)
240
+
241
+
242
+ def _fit_to_crop(target_tensor: torch.Tensor, crop_size: int, start_ind: int) -> torch.Tensor:
243
+ if len(target_tensor.shape) == 1:
244
+ ret = torch.zeros((crop_size, ), dtype=target_tensor.dtype)
245
+ ret[start_ind:start_ind + target_tensor.shape[0]] = target_tensor
246
+ return ret
247
+ elif len(target_tensor.shape) == 2:
248
+ ret = torch.zeros((crop_size, target_tensor.shape[-1]), dtype=target_tensor.dtype)
249
+ ret[start_ind:start_ind + target_tensor.shape[0], :] = target_tensor
250
+ return ret
251
+ else:
252
+ ret = torch.zeros((crop_size, *target_tensor.shape[1:]), dtype=target_tensor.dtype)
253
+ ret[start_ind:start_ind + target_tensor.shape[0], ...] = target_tensor
254
+ return ret
255
+
256
+
257
+ def parse_input_json(input_path: str, mode: str, config: mlc.ConfigDict, data_pipeline: DataPipeline,
258
+ data_dir: str, idx: int) -> FeatureTensorDict:
259
+ start_load_time = time.time()
260
+ input_data = json.load(open(input_path, "r"))
261
+ if mode == "train" or mode == "eval":
262
+ print("loading", input_data["pdb_id"], end=" ")
263
+
264
+ num_recycles = config.common.max_recycling_iters + 1
265
+
266
+ input_pdb_path = os.path.join(data_dir, input_data["input_structure"])
267
+ input_protein_feats = data_pipeline.process_pdb(pdb_path=input_pdb_path)
268
+
269
+ # load ref sdf
270
+ if "ref_sdf" in input_data:
271
+ ref_sdf_path = os.path.join(data_dir, input_data["ref_sdf"])
272
+ ref_ligand_feats = data_pipeline.process_sdf(sdf_path=ref_sdf_path)
273
+ ref_ligand_feats["ligand_idx"] = torch.zeros((ref_ligand_feats["ligand_target_feat"].shape[0],))
274
+ ref_ligand_feats["ligand_bonds_idx"] = torch.zeros((ref_ligand_feats["ligand_bonds"].shape[0],))
275
+ elif "ref_sdf_list" in input_data:
276
+ sdf_path_list = [os.path.join(data_dir, i) for i in input_data["ref_sdf_list"]]
277
+ ref_ligand_feats = data_pipeline.process_sdf_list(sdf_path_list=sdf_path_list)
278
+ else:
279
+ raise ValueError("ref_sdf or ref_sdf_list must be in input_data")
280
+
281
+ n_res = input_protein_feats["protein_target_feat"].shape[0]
282
+ n_lig = ref_ligand_feats["ligand_target_feat"].shape[0]
283
+ n_affinity = 1
284
+
285
+ # add 1 for affinity token
286
+ crop_size = n_res + n_lig + n_affinity
287
+ if (mode == "train" or mode == "eval") and config.train.fixed_size:
288
+ crop_size = config.train.crop_size
289
+
290
+ assert crop_size >= n_res + n_lig + n_affinity, f"crop_size: {crop_size}, n_res: {n_res}, n_lig: {n_lig}"
291
+
292
+ token_mask = torch.zeros((crop_size,), dtype=torch.float32)
293
+ token_mask[:n_res + n_lig + n_affinity] = 1
294
+
295
+ protein_mask = torch.zeros((crop_size,), dtype=torch.float32)
296
+ protein_mask[:n_res] = 1
297
+
298
+ ligand_mask = torch.zeros((crop_size,), dtype=torch.float32)
299
+ ligand_mask[n_res:n_res + n_lig] = 1
300
+
301
+ affinity_mask = torch.zeros((crop_size,), dtype=torch.float32)
302
+ affinity_mask[n_res + n_lig] = 1
303
+
304
+ structural_mask = torch.zeros((crop_size,), dtype=torch.float32)
305
+ structural_mask[:n_res + n_lig] = 1
306
+
307
+ inter_pair_mask = torch.zeros((crop_size, crop_size), dtype=torch.float32)
308
+ inter_pair_mask[:n_res, n_res:n_res + n_lig] = 1
309
+ inter_pair_mask[n_res:n_res + n_lig, :n_res] = 1
310
+
311
+ protein_tf_dim = input_protein_feats["protein_target_feat"].shape[-1]
312
+ ligand_tf_dim = ref_ligand_feats["ligand_target_feat"].shape[-1]
313
+ joined_tf_dim = protein_tf_dim + ligand_tf_dim
314
+
315
+ target_feat = torch.zeros((crop_size, joined_tf_dim + 3), dtype=torch.float32)
316
+ target_feat[:n_res, :protein_tf_dim] = input_protein_feats["protein_target_feat"]
317
+ target_feat[n_res:n_res + n_lig, protein_tf_dim:joined_tf_dim] = ref_ligand_feats["ligand_target_feat"]
318
+
319
+ target_feat[:n_res, joined_tf_dim] = 1 # Set "is_protein" flag for protein rows
320
+ target_feat[n_res:n_res + n_lig, joined_tf_dim + 1] = 1 # Set "is_ligand" flag for ligand rows
321
+ target_feat[n_res + n_lig, joined_tf_dim + 2] = 1 # Set "is_affinity" flag for affinity row
322
+
323
+ ligand_bonds_feat = torch.zeros((crop_size, crop_size, ref_ligand_feats["ligand_bonds_feat"].shape[-1]),
324
+ dtype=torch.float32)
325
+ ligand_bonds_feat[n_res:n_res + n_lig, n_res:n_res + n_lig] = ref_ligand_feats["ligand_bonds_feat"]
326
+
327
+ input_positions = torch.zeros((crop_size, 3), dtype=torch.float32)
328
+ input_positions[:n_res] = input_protein_feats["pseudo_beta"]
329
+ input_positions[n_res:n_res + n_lig] = ref_ligand_feats["ligand_positions"]
330
+
331
+ protein_distogram_mask = torch.zeros(crop_size)
332
+ if mode == "train":
333
+ ones_indices = torch.randperm(n_res)[:int(n_res * config.train.protein_distogram_mask_prob)]
334
+ # print(ones_indices)
335
+ protein_distogram_mask[ones_indices] = 1
336
+ input_positions = input_positions * (1 - protein_distogram_mask).unsqueeze(-1)
337
+ elif mode == "predict":
338
+ # ignore all positions where pseudo_beta is 0, 0, 0
339
+ protein_distogram_mask = (input_positions == 0).all(dim=-1).float()
340
+ # print("Ignoring residues", torch.nonzero(distogram_mask).flatten())
341
+
342
+ # Implement ligand as amino acid type 20
343
+ ligand_aatype = 20 * torch.ones((n_lig,), dtype=input_protein_feats["aatype"].dtype)
344
+ aatype = torch.cat([input_protein_feats["aatype"], ligand_aatype], dim=0)
345
+
346
+ restype_atom14_to_atom37, restype_atom37_to_atom14, restype_atom14_mask = get_restypes(target_feat.device)
347
+ lig_residx_atom37_to_atom14 = restype_atom37_to_atom14[20].repeat(n_lig, 1)
348
+ residx_atom37_to_atom14 = torch.cat([input_protein_feats["residx_atom37_to_atom14"], lig_residx_atom37_to_atom14],
349
+ dim=0)
350
+
351
+ restype_atom37_mask = get_restype_atom37_mask(target_feat.device)
352
+ lig_atom37_atom_exists = restype_atom37_mask[20].repeat(n_lig, 1)
353
+ atom37_atom_exists = torch.cat([input_protein_feats["atom37_atom_exists"], lig_atom37_atom_exists], dim=0)
354
+
355
+ feats = {
356
+ "token_mask": token_mask,
357
+ "protein_mask": protein_mask,
358
+ "ligand_mask": ligand_mask,
359
+ "affinity_mask": affinity_mask,
360
+ "structural_mask": structural_mask,
361
+ "inter_pair_mask": inter_pair_mask,
362
+
363
+ "target_feat": target_feat,
364
+ "ligand_bonds_feat": ligand_bonds_feat,
365
+ "input_positions": input_positions,
366
+ "protein_distogram_mask": protein_distogram_mask,
367
+ "protein_residue_index": _fit_to_crop(input_protein_feats["residue_index"], crop_size, 0),
368
+ "aatype": _fit_to_crop(aatype, crop_size, 0),
369
+ "residx_atom37_to_atom14": _fit_to_crop(residx_atom37_to_atom14, crop_size, 0),
370
+ "atom37_atom_exists": _fit_to_crop(atom37_atom_exists, crop_size, 0),
371
+ }
372
+
373
+ if mode == "predict":
374
+ feats.update({
375
+ "in_chain_residue_index": input_protein_feats["in_chain_residue_index"],
376
+ "chain_index": input_protein_feats["chain_index"],
377
+ "ligand_atype": ref_ligand_feats["ligand_atype"],
378
+ "ligand_chirality": ref_ligand_feats["ligand_chirality"],
379
+ "ligand_charge": ref_ligand_feats["ligand_charge"],
380
+ "ligand_bonds": ref_ligand_feats["ligand_bonds"],
381
+ "ligand_idx": ref_ligand_feats["ligand_idx"],
382
+ "ligand_bonds_idx": ref_ligand_feats["ligand_bonds_idx"],
383
+ })
384
+
385
+ if mode == 'train' or mode == 'eval':
386
+ gt_pdb_path = os.path.join(data_dir, input_data["gt_structure"])
387
+ gt_protein_feats = data_pipeline.process_pdb(pdb_path=gt_pdb_path)
388
+
389
+ if "gt_sdf" in input_data:
390
+ gt_ligand_positions = data_pipeline.get_matching_positions(
391
+ os.path.join(data_dir, input_data["ref_sdf"]),
392
+ os.path.join(data_dir, input_data["gt_sdf"]),
393
+ )
394
+ elif "gt_sdf_list" in input_data:
395
+ gt_ligand_positions = data_pipeline.get_matching_positions_list(
396
+ [os.path.join(data_dir, i) for i in input_data["ref_sdf_list"]],
397
+ [os.path.join(data_dir, i) for i in input_data["gt_sdf_list"]],
398
+ )
399
+ else:
400
+ raise ValueError("gt_sdf or gt_sdf_list must be in input_data")
401
+
402
+ affinity_loss_factor = torch.tensor([1.0], dtype=torch.float32)
403
+ if input_data["affinity"] is None:
404
+ eps = 1e-6
405
+ affinity_loss_factor = torch.tensor([eps], dtype=torch.float32)
406
+ affinity = torch.tensor([0.0], dtype=torch.float32)
407
+ else:
408
+ affinity = torch.tensor([input_data["affinity"]], dtype=torch.float32)
409
+
410
+ resolution = torch.tensor(input_data["resolution"], dtype=torch.float32)
411
+
412
+ # prepare inter_contacts
413
+ expanded_prot_pos = gt_protein_feats["pseudo_beta"].unsqueeze(1) # Shape: (N_prot, 1, 3)
414
+ expanded_lig_pos = gt_ligand_positions.unsqueeze(0) # Shape: (1, N_lig, 3)
415
+ distances = torch.sqrt(torch.sum((expanded_prot_pos - expanded_lig_pos) ** 2, dim=-1))
416
+ inter_contact = (distances < 5.0).float()
417
+ binding_site_mask = inter_contact.any(dim=1).float()
418
+
419
+ inter_contact_reshaped_to_crop = torch.zeros((crop_size, crop_size), dtype=torch.float32)
420
+ inter_contact_reshaped_to_crop[:n_res, n_res:n_res + n_lig] = inter_contact
421
+ inter_contact_reshaped_to_crop[n_res:n_res + n_lig, :n_res] = inter_contact.T
422
+
423
+ # Use CA positions only
424
+ lig_single_res_atom37_mask = torch.zeros((37,), dtype=torch.float32)
425
+ lig_single_res_atom37_mask[1] = 1
426
+ lig_atom37_mask = lig_single_res_atom37_mask.unsqueeze(0).expand(n_lig, -1)
427
+ lig_single_res_atom14_mask = torch.zeros((14,), dtype=torch.float32)
428
+ lig_single_res_atom14_mask[1] = 1
429
+ lig_atom14_mask = lig_single_res_atom14_mask.unsqueeze(0).expand(n_lig, -1)
430
+
431
+ lig_atom37_positions = gt_ligand_positions.unsqueeze(1).expand(-1, 37, -1)
432
+ lig_atom37_positions = lig_atom37_positions * lig_single_res_atom37_mask.view(1, 37, 1).expand(n_lig, -1, 3)
433
+
434
+ lig_atom14_positions = gt_ligand_positions.unsqueeze(1).expand(-1, 14, -1)
435
+ lig_atom14_positions = lig_atom14_positions * lig_single_res_atom14_mask.view(1, 14, 1).expand(n_lig, -1, 3)
436
+
437
+ atom37_gt_positions = torch.cat([gt_protein_feats["all_atom_positions"], lig_atom37_positions], dim=0)
438
+ atom37_atom_exists_in_res = torch.cat([gt_protein_feats["atom37_atom_exists"], lig_atom37_mask], dim=0)
439
+ atom37_atom_exists_in_gt = torch.cat([gt_protein_feats["all_atom_mask"], lig_atom37_mask], dim=0)
440
+
441
+ atom14_gt_positions = torch.cat([gt_protein_feats["atom14_gt_positions"], lig_atom14_positions], dim=0)
442
+ atom14_atom_exists_in_res = torch.cat([gt_protein_feats["atom14_atom_exists"], lig_atom14_mask], dim=0)
443
+ atom14_atom_exists_in_gt = torch.cat([gt_protein_feats["atom14_gt_exists"], lig_atom14_mask], dim=0)
444
+
445
+ gt_pseudo_beta_with_lig = torch.cat([gt_protein_feats["pseudo_beta"], gt_ligand_positions], dim=0)
446
+ gt_pseudo_beta_with_lig_mask = torch.cat(
447
+ [gt_protein_feats["pseudo_beta_mask"],
448
+ torch.ones((n_lig,), dtype=gt_protein_feats["pseudo_beta_mask"].dtype)],
449
+ dim=0)
450
+
451
+ # IGNORES: residx_atom14_to_atom37, rigidgroups_group_exists,
452
+ # rigidgroups_group_is_ambiguous, pseudo_beta_mask, backbone_rigid_mask, protein_target_feat
453
+ gt_protein_feats = {
454
+ "atom37_gt_positions": atom37_gt_positions, # torch.Size([n_struct, 37, 3])
455
+ "atom37_atom_exists_in_res": atom37_atom_exists_in_res, # torch.Size([n_struct, 37])
456
+ "atom37_atom_exists_in_gt": atom37_atom_exists_in_gt, # torch.Size([n_struct, 37])
457
+
458
+ "atom14_gt_positions": atom14_gt_positions, # torch.Size([n_struct, 14, 3])
459
+ "atom14_atom_exists_in_res": atom14_atom_exists_in_res, # torch.Size([n_struct, 14])
460
+ "atom14_atom_exists_in_gt": atom14_atom_exists_in_gt, # torch.Size([n_struct, 14])
461
+
462
+ "gt_pseudo_beta_with_lig": gt_pseudo_beta_with_lig, # torch.Size([n_struct, 3])
463
+ "gt_pseudo_beta_with_lig_mask": gt_pseudo_beta_with_lig_mask, # torch.Size([n_struct])
464
+
465
+ # These we don't need to add the ligand to, because padding is sufficient (everything should be 0)
466
+ "atom14_alt_gt_positions": gt_protein_feats["atom14_alt_gt_positions"], # torch.Size([n_res, 14, 3])
467
+ "atom14_alt_gt_exists": gt_protein_feats["atom14_alt_gt_exists"], # torch.Size([n_res, 14])
468
+ "atom14_atom_is_ambiguous": gt_protein_feats["atom14_atom_is_ambiguous"], # torch.Size([n_res, 14])
469
+ "rigidgroups_gt_frames": gt_protein_feats["rigidgroups_gt_frames"], # torch.Size([n_res, 8, 4, 4])
470
+ "rigidgroups_gt_exists": gt_protein_feats["rigidgroups_gt_exists"], # torch.Size([n_res, 8])
471
+ "rigidgroups_alt_gt_frames": gt_protein_feats["rigidgroups_alt_gt_frames"], # torch.Size([n_res, 8, 4, 4])
472
+ "backbone_rigid_tensor": gt_protein_feats["backbone_rigid_tensor"], # torch.Size([n_res, 4, 4])
473
+ "backbone_rigid_mask": gt_protein_feats["backbone_rigid_mask"], # torch.Size([n_res])
474
+ "chi_angles_sin_cos": gt_protein_feats["chi_angles_sin_cos"],
475
+ "chi_mask": gt_protein_feats["chi_mask"],
476
+ }
477
+
478
+ for k, v in gt_protein_feats.items():
479
+ gt_protein_feats[k] = _fit_to_crop(v, crop_size, 0)
480
+
481
+ feats = {
482
+ **feats,
483
+ **gt_protein_feats,
484
+ "gt_ligand_positions": _fit_to_crop(gt_ligand_positions, crop_size, n_res),
485
+ "resolution": resolution,
486
+ "affinity": affinity,
487
+ "affinity_loss_factor": affinity_loss_factor,
488
+ "seq_length": torch.tensor(n_res + n_lig),
489
+ "binding_site_mask": _fit_to_crop(binding_site_mask, crop_size, 0),
490
+ "gt_inter_contacts": inter_contact_reshaped_to_crop,
491
+ }
492
+
493
+ for k, v in feats.items():
494
+ # print(k, v.shape)
495
+ feats[k] = _prepare_recycles(v, num_recycles)
496
+
497
+ feats["batch_idx"] = torch.tensor(
498
+ [idx for _ in range(crop_size)], dtype=torch.int64, device=feats["aatype"].device
499
+ )
500
+
501
+ print("load time", round(time.time() - start_load_time, 4))
502
+
503
+ return feats
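Editor's note: the helpers `_fit_to_crop` and `_prepare_recycles` used above are defined earlier in data_modules.py and are not shown in this diff view. A plausible minimal sketch of the behavior the calls above rely on (zero-padding a tensor into a fixed-size crop at a given offset, and replicating a feature along a trailing recycle dimension); the exact signatures and the recycle-dimension placement are assumptions:

    import torch

    def _fit_to_crop(t: torch.Tensor, crop_size: int, start: int) -> torch.Tensor:
        # Zero-pad t into a crop_size-long tensor along dim 0, starting at offset start.
        out = t.new_zeros((crop_size, *t.shape[1:]))
        out[start:start + t.shape[0]] = t
        return out

    def _prepare_recycles(t: torch.Tensor, num_recycles: int) -> torch.Tensor:
        # Replicate the feature along a new trailing recycle dimension.
        return t.unsqueeze(-1).expand(*t.shape, num_recycles).contiguous()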
dockformer/data/data_transforms.py ADDED
@@ -0,0 +1,731 @@
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import itertools
17
+ from functools import reduce, wraps
18
+ from operator import add
19
+
20
+ import numpy as np
21
+ import torch
22
+
23
+ from dockformer.config import NUM_RES
24
+ from dockformer.utils import residue_constants as rc
25
+ from dockformer.utils.residue_constants import restypes
26
+ from dockformer.utils.rigid_utils import Rotation, Rigid
27
+ from dockformer.utils.geometry.rigid_matrix_vector import Rigid3Array
28
+ from dockformer.utils.geometry.rotation_matrix import Rot3Array
29
+ from dockformer.utils.geometry.vector import Vec3Array
30
+ from dockformer.utils.tensor_utils import (
31
+ tree_map,
32
+ tensor_tree_map,
33
+ batched_gather,
34
+ )
35
+
36
+
37
+ def cast_to_64bit_ints(protein):
38
+ # We keep all ints as int64
39
+ for k, v in protein.items():
40
+ if v.dtype == torch.int32:
41
+ protein[k] = v.type(torch.int64)
42
+
43
+ return protein
44
+
45
+
46
+ def make_one_hot(x, num_classes):
47
+ x_one_hot = torch.zeros(*x.shape, num_classes, device=x.device)
48
+ x_one_hot.scatter_(-1, x.unsqueeze(-1), 1)
49
+ return x_one_hot
50
+
51
+
52
+ def curry1(f):
53
+ """Supply all arguments but the first."""
54
+ @wraps(f)
55
+ def fc(*args, **kwargs):
56
+ return lambda x: f(x, *args, **kwargs)
57
+
58
+ return fc
59
+
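Editor's note: `curry1` lets a transform be configured with its arguments first and applied to the feature dict later, which is how transform pipelines are assembled in this file. A small illustration with a hypothetical transform:

    @curry1
    def scale_field(protein, key, factor):
        protein[key] = protein[key] * factor
        return protein

    transform = scale_field("resolution", 2.0)                # configure now...
    protein = transform({"resolution": torch.tensor([1.5])})  # ...apply later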
60
+
61
+ def squeeze_features(protein):
62
+ """Remove singleton and repeated dimensions in protein features."""
63
+ protein["aatype"] = torch.argmax(protein["aatype"], dim=-1)
64
+ for k in [
65
+ "domain_name",
66
+ "seq_length",
67
+ "sequence",
68
+ "resolution",
69
+ "residue_index",
70
+ ]:
71
+ if k in protein:
72
+ final_dim = protein[k].shape[-1]
73
+ if isinstance(final_dim, int) and final_dim == 1:
74
+ if torch.is_tensor(protein[k]):
75
+ protein[k] = torch.squeeze(protein[k], dim=-1)
76
+ else:
77
+ protein[k] = np.squeeze(protein[k], axis=-1)
78
+
79
+ for k in ["seq_length"]:
80
+ if k in protein:
81
+ protein[k] = protein[k][0]
82
+
83
+ return protein
84
+
85
+
86
+ def pseudo_beta_fn(aatype, all_atom_positions, all_atom_mask):
87
+ """Create pseudo beta features."""
88
+ is_gly = torch.eq(aatype, rc.restype_order["G"])
89
+ ca_idx = rc.atom_order["CA"]
90
+ cb_idx = rc.atom_order["CB"]
91
+ pseudo_beta = torch.where(
92
+ torch.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3]),
93
+ all_atom_positions[..., ca_idx, :],
94
+ all_atom_positions[..., cb_idx, :],
95
+ )
96
+
97
+ if all_atom_mask is not None:
98
+ pseudo_beta_mask = torch.where(
99
+ is_gly, all_atom_mask[..., ca_idx], all_atom_mask[..., cb_idx]
100
+ )
101
+ return pseudo_beta, pseudo_beta_mask
102
+ else:
103
+ return pseudo_beta
104
+
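Editor's note: a toy check of the glycine fallback in `pseudo_beta_fn` (a sketch run inside this module; it assumes the standard residue and atom orderings in `rc`):

    aatype = torch.tensor([rc.restype_order["G"], rc.restype_order["A"]])
    pos = torch.zeros(2, 37, 3)
    pos[:, rc.atom_order["CA"]] = 1.0
    pos[:, rc.atom_order["CB"]] = 2.0
    pb = pseudo_beta_fn(aatype, pos, None)
    # pb[0] is the CA position (glycine has no CB); pb[1] is the CB position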
105
+
106
+ @curry1
107
+ def make_pseudo_beta(protein):
108
+ """Create pseudo-beta (alpha for glycine) position and mask."""
109
+ (protein["pseudo_beta"], protein["pseudo_beta_mask"]) = pseudo_beta_fn(
110
+ protein["aatype"],
111
+ protein["all_atom_positions"],
112
+ protein["all_atom_mask"],
113
+ )
114
+ return protein
115
+
116
+
117
+ @curry1
118
+ def make_target_feat(protein):
119
+ """Create and concatenate protein features."""
120
+ # One-hot encode the amino-acid type; this is currently the only
121
+ # per-residue target feature produced here.
122
+ aatype_1hot = make_one_hot(protein["aatype"], 20)
123
+
124
+ protein["protein_target_feat"] = aatype_1hot
125
+
126
+ return protein
127
+
128
+
129
+
130
+ @curry1
131
+ def select_feat(protein, feature_list):
132
+ return {k: v for k, v in protein.items() if k in feature_list}
133
+
134
+
135
+ def get_restypes(device):
136
+ restype_atom14_to_atom37 = []
137
+ restype_atom37_to_atom14 = []
138
+ restype_atom14_mask = []
139
+
140
+ for rt in rc.restypes:
141
+ atom_names = rc.restype_name_to_atom14_names[rc.restype_1to3[rt]]
142
+ restype_atom14_to_atom37.append(
143
+ [(rc.atom_order[name] if name else 0) for name in atom_names]
144
+ )
145
+ atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
146
+ restype_atom37_to_atom14.append(
147
+ [
148
+ (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
149
+ for name in rc.atom_types
150
+ ]
151
+ )
152
+
153
+ restype_atom14_mask.append(
154
+ [(1.0 if name else 0.0) for name in atom_names]
155
+ )
156
+
157
+ # Add dummy mapping for restype 'UNK'
158
+ restype_atom14_to_atom37.append([0] * 14)
159
+ restype_atom37_to_atom14.append([0] * 37)
160
+ restype_atom14_mask.append([0.0] * 14)
161
+
162
+ restype_atom14_to_atom37 = torch.tensor(
163
+ restype_atom14_to_atom37,
164
+ dtype=torch.int32,
165
+ device=device,
166
+ )
167
+ restype_atom37_to_atom14 = torch.tensor(
168
+ restype_atom37_to_atom14,
169
+ dtype=torch.int32,
170
+ device=device,
171
+ )
172
+ restype_atom14_mask = torch.tensor(
173
+ restype_atom14_mask,
174
+ dtype=torch.float32,
175
+ device=device,
176
+ )
177
+
178
+ return restype_atom14_to_atom37, restype_atom37_to_atom14, restype_atom14_mask
179
+
180
+
181
+ def get_restype_atom37_mask(device):
182
+ # create the corresponding mask
183
+ restype_atom37_mask = torch.zeros(
184
+ [len(restypes) + 1, 37], dtype=torch.float32, device=device
185
+ )
186
+ for restype, restype_letter in enumerate(rc.restypes):
187
+ restype_name = rc.restype_1to3[restype_letter]
188
+ atom_names = rc.residue_atoms[restype_name]
189
+ for atom_name in atom_names:
190
+ atom_type = rc.atom_order[atom_name]
191
+ restype_atom37_mask[restype, atom_type] = 1
192
+ return restype_atom37_mask
193
+
194
+
195
+ def make_atom14_masks(protein):
196
+ """Construct denser atom positions (14 dimensions instead of 37)."""
197
+ restype_atom14_to_atom37, restype_atom37_to_atom14, restype_atom14_mask = get_restypes(protein["aatype"].device)
198
+
199
+ protein_aatype = protein['aatype'].to(torch.long)
200
+
201
+ # create the mapping for (residx, atom14) --> atom37, i.e. an array
202
+ # with shape (num_res, 14) containing the atom37 indices for this protein
203
+ residx_atom14_to_atom37 = restype_atom14_to_atom37[protein_aatype]
204
+ residx_atom14_mask = restype_atom14_mask[protein_aatype]
205
+
206
+ protein["atom14_atom_exists"] = residx_atom14_mask
207
+ protein["residx_atom14_to_atom37"] = residx_atom14_to_atom37.long()
208
+
209
+ # create the gather indices for mapping back
210
+ residx_atom37_to_atom14 = restype_atom37_to_atom14[protein_aatype]
211
+ protein["residx_atom37_to_atom14"] = residx_atom37_to_atom14.long()
212
+
213
+ restype_atom37_mask = get_restype_atom37_mask(protein["aatype"].device)
214
+
215
+ residx_atom37_mask = restype_atom37_mask[protein_aatype]
216
+ protein["atom37_atom_exists"] = residx_atom37_mask
217
+
218
+ return protein
219
+
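Editor's note: a minimal usage sketch of `make_atom14_masks` (indices 0 and 7 are ALA and GLY in `rc.restypes`):

    batch = {"aatype": torch.tensor([0, 7])}
    batch = make_atom14_masks(batch)
    # batch["residx_atom14_to_atom37"]  -> (2, 14) gather indices into atom37
    # batch["atom14_atom_exists"]       -> (2, 14) per-residue-type atom mask
    # batch["atom37_atom_exists"]       -> (2, 37) same mask in atom37 indexing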
220
+
221
+ def make_atom14_positions(protein):
222
+ """Constructs denser atom positions (14 dimensions instead of 37)."""
223
+ residx_atom14_mask = protein["atom14_atom_exists"]
224
+ residx_atom14_to_atom37 = protein["residx_atom14_to_atom37"]
225
+
226
+ # Create a mask for known ground truth positions.
227
+ residx_atom14_gt_mask = residx_atom14_mask * batched_gather(
228
+ protein["all_atom_mask"],
229
+ residx_atom14_to_atom37,
230
+ dim=-1,
231
+ no_batch_dims=len(protein["all_atom_mask"].shape[:-1]),
232
+ )
233
+
234
+ # Gather the ground truth positions.
235
+ residx_atom14_gt_positions = residx_atom14_gt_mask[..., None] * (
236
+ batched_gather(
237
+ protein["all_atom_positions"],
238
+ residx_atom14_to_atom37,
239
+ dim=-2,
240
+ no_batch_dims=len(protein["all_atom_positions"].shape[:-2]),
241
+ )
242
+ )
243
+
244
+ protein["atom14_atom_exists"] = residx_atom14_mask
245
+ protein["atom14_gt_exists"] = residx_atom14_gt_mask
246
+ protein["atom14_gt_positions"] = residx_atom14_gt_positions
247
+
248
+ # As the atom naming is ambiguous for 7 of the 20 amino acids, provide
249
+ # alternative ground truth coordinates where the naming is swapped
250
+ restype_3 = [rc.restype_1to3[res] for res in rc.restypes]
251
+ restype_3 += ["UNK"]
252
+
253
+ # Matrices for renaming ambiguous atoms.
254
+ all_matrices = {
255
+ res: torch.eye(
256
+ 14,
257
+ dtype=protein["all_atom_mask"].dtype,
258
+ device=protein["all_atom_mask"].device,
259
+ )
260
+ for res in restype_3
261
+ }
262
+ for resname, swap in rc.residue_atom_renaming_swaps.items():
263
+ correspondences = torch.arange(
264
+ 14, device=protein["all_atom_mask"].device
265
+ )
266
+ for source_atom_swap, target_atom_swap in swap.items():
267
+ source_index = rc.restype_name_to_atom14_names[resname].index(
268
+ source_atom_swap
269
+ )
270
+ target_index = rc.restype_name_to_atom14_names[resname].index(
271
+ target_atom_swap
272
+ )
273
+ correspondences[source_index] = target_index
274
+ correspondences[target_index] = source_index
275
+ renaming_matrix = protein["all_atom_mask"].new_zeros((14, 14))
276
+ for index, correspondence in enumerate(correspondences):
277
+ renaming_matrix[index, correspondence] = 1.0
278
+ all_matrices[resname] = renaming_matrix
279
+
280
+ renaming_matrices = torch.stack(
281
+ [all_matrices[restype] for restype in restype_3]
282
+ )
283
+
284
+ # Pick the transformation matrices for the given residue sequence
285
+ # shape (num_res, 14, 14).
286
+ renaming_transform = renaming_matrices[protein["aatype"]]
287
+
288
+ # Apply it to the ground truth positions. shape (num_res, 14, 3).
289
+ alternative_gt_positions = torch.einsum(
290
+ "...rac,...rab->...rbc", residx_atom14_gt_positions, renaming_transform
291
+ )
292
+ protein["atom14_alt_gt_positions"] = alternative_gt_positions
293
+
294
+ # Create the mask for the alternative ground truth (differs from the
295
+ # ground truth mask, if only one of the atoms in an ambiguous pair has a
296
+ # ground truth position).
297
+ alternative_gt_mask = torch.einsum(
298
+ "...ra,...rab->...rb", residx_atom14_gt_mask, renaming_transform
299
+ )
300
+ protein["atom14_alt_gt_exists"] = alternative_gt_mask
301
+
302
+ # Create an ambiguous atoms mask. shape: (21, 14).
303
+ restype_atom14_is_ambiguous = protein["all_atom_mask"].new_zeros((21, 14))
304
+ for resname, swap in rc.residue_atom_renaming_swaps.items():
305
+ for atom_name1, atom_name2 in swap.items():
306
+ restype = rc.restype_order[rc.restype_3to1[resname]]
307
+ atom_idx1 = rc.restype_name_to_atom14_names[resname].index(
308
+ atom_name1
309
+ )
310
+ atom_idx2 = rc.restype_name_to_atom14_names[resname].index(
311
+ atom_name2
312
+ )
313
+ restype_atom14_is_ambiguous[restype, atom_idx1] = 1
314
+ restype_atom14_is_ambiguous[restype, atom_idx2] = 1
315
+
316
+ # From this create an ambiguous_mask for the given sequence.
317
+ protein["atom14_atom_is_ambiguous"] = restype_atom14_is_ambiguous[
318
+ protein["aatype"]
319
+ ]
320
+
321
+ return protein
322
+
323
+
324
+ def atom37_to_frames(protein, eps=1e-8):
325
+ aatype = protein["aatype"]
326
+ all_atom_positions = protein["all_atom_positions"]
327
+ all_atom_mask = protein["all_atom_mask"]
328
+
329
+ batch_dims = len(aatype.shape[:-1])
330
+
331
+ restype_rigidgroup_base_atom_names = np.full([21, 8, 3], "", dtype=object)
332
+ restype_rigidgroup_base_atom_names[:, 0, :] = ["C", "CA", "N"]
333
+ restype_rigidgroup_base_atom_names[:, 3, :] = ["CA", "C", "O"]
334
+
335
+ for restype, restype_letter in enumerate(rc.restypes):
336
+ resname = rc.restype_1to3[restype_letter]
337
+ for chi_idx in range(4):
338
+ if rc.chi_angles_mask[restype][chi_idx]:
339
+ names = rc.chi_angles_atoms[resname][chi_idx]
340
+ restype_rigidgroup_base_atom_names[
341
+ restype, chi_idx + 4, :
342
+ ] = names[1:]
343
+
344
+ restype_rigidgroup_mask = all_atom_mask.new_zeros(
345
+ (*aatype.shape[:-1], 21, 8),
346
+ )
347
+ restype_rigidgroup_mask[..., 0] = 1
348
+ restype_rigidgroup_mask[..., 3] = 1
349
+ restype_rigidgroup_mask[..., :len(restypes), 4:] = all_atom_mask.new_tensor(
350
+ rc.chi_angles_mask
351
+ )
352
+
353
+ lookuptable = rc.atom_order.copy()
354
+ lookuptable[""] = 0
355
+ lookup = np.vectorize(lambda x: lookuptable[x])
356
+ restype_rigidgroup_base_atom37_idx = lookup(
357
+ restype_rigidgroup_base_atom_names,
358
+ )
359
+ restype_rigidgroup_base_atom37_idx = aatype.new_tensor(
360
+ restype_rigidgroup_base_atom37_idx,
361
+ )
362
+ restype_rigidgroup_base_atom37_idx = (
363
+ restype_rigidgroup_base_atom37_idx.view(
364
+ *((1,) * batch_dims), *restype_rigidgroup_base_atom37_idx.shape
365
+ )
366
+ )
367
+
368
+ residx_rigidgroup_base_atom37_idx = batched_gather(
369
+ restype_rigidgroup_base_atom37_idx,
370
+ aatype,
371
+ dim=-3,
372
+ no_batch_dims=batch_dims,
373
+ )
374
+
375
+ base_atom_pos = batched_gather(
376
+ all_atom_positions,
377
+ residx_rigidgroup_base_atom37_idx,
378
+ dim=-2,
379
+ no_batch_dims=len(all_atom_positions.shape[:-2]),
380
+ )
381
+
382
+ gt_frames = Rigid.from_3_points(
383
+ p_neg_x_axis=base_atom_pos[..., 0, :],
384
+ origin=base_atom_pos[..., 1, :],
385
+ p_xy_plane=base_atom_pos[..., 2, :],
386
+ eps=eps,
387
+ )
388
+
389
+ group_exists = batched_gather(
390
+ restype_rigidgroup_mask,
391
+ aatype,
392
+ dim=-2,
393
+ no_batch_dims=batch_dims,
394
+ )
395
+
396
+ gt_atoms_exist = batched_gather(
397
+ all_atom_mask,
398
+ residx_rigidgroup_base_atom37_idx,
399
+ dim=-1,
400
+ no_batch_dims=len(all_atom_mask.shape[:-1]),
401
+ )
402
+ gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists
403
+
404
+ rots = torch.eye(3, dtype=all_atom_mask.dtype, device=aatype.device)
405
+ rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
406
+ rots[..., 0, 0, 0] = -1
407
+ rots[..., 0, 2, 2] = -1
408
+
409
+ rots = Rotation(rot_mats=rots)
410
+ gt_frames = gt_frames.compose(Rigid(rots, None))
411
+
412
+ restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
413
+ *((1,) * batch_dims), 21, 8
414
+ )
415
+ restype_rigidgroup_rots = torch.eye(
416
+ 3, dtype=all_atom_mask.dtype, device=aatype.device
417
+ )
418
+ restype_rigidgroup_rots = torch.tile(
419
+ restype_rigidgroup_rots,
420
+ (*((1,) * batch_dims), 21, 8, 1, 1),
421
+ )
422
+
423
+ for resname, _ in rc.residue_atom_renaming_swaps.items():
424
+ restype = rc.restype_order[rc.restype_3to1[resname]]
425
+ chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1)
426
+ restype_rigidgroup_is_ambiguous[..., restype, chi_idx + 4] = 1
427
+ restype_rigidgroup_rots[..., restype, chi_idx + 4, 1, 1] = -1
428
+ restype_rigidgroup_rots[..., restype, chi_idx + 4, 2, 2] = -1
429
+
430
+ residx_rigidgroup_is_ambiguous = batched_gather(
431
+ restype_rigidgroup_is_ambiguous,
432
+ aatype,
433
+ dim=-2,
434
+ no_batch_dims=batch_dims,
435
+ )
436
+
437
+ residx_rigidgroup_ambiguity_rot = batched_gather(
438
+ restype_rigidgroup_rots,
439
+ aatype,
440
+ dim=-4,
441
+ no_batch_dims=batch_dims,
442
+ )
443
+
444
+ residx_rigidgroup_ambiguity_rot = Rotation(
445
+ rot_mats=residx_rigidgroup_ambiguity_rot
446
+ )
447
+ alt_gt_frames = gt_frames.compose(
448
+ Rigid(residx_rigidgroup_ambiguity_rot, None)
449
+ )
450
+
451
+ gt_frames_tensor = gt_frames.to_tensor_4x4()
452
+ alt_gt_frames_tensor = alt_gt_frames.to_tensor_4x4()
453
+
454
+ protein["rigidgroups_gt_frames"] = gt_frames_tensor
455
+ protein["rigidgroups_gt_exists"] = gt_exists
456
+ protein["rigidgroups_group_exists"] = group_exists
457
+ protein["rigidgroups_group_is_ambiguous"] = residx_rigidgroup_is_ambiguous
458
+ protein["rigidgroups_alt_gt_frames"] = alt_gt_frames_tensor
459
+
460
+ return protein
461
+
462
+
463
+ def get_chi_atom_indices():
464
+ """Returns atom indices needed to compute chi angles for all residue types.
465
+
466
+ Returns:
467
+ A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are
468
+ in the order specified in rc.restypes + unknown residue type
469
+ at the end. For chi angles that are not defined on the residue, the
470
+ position indices default to 0.
471
+ """
472
+ chi_atom_indices = []
473
+ for residue_name in rc.restypes:
474
+ residue_name = rc.restype_1to3[residue_name]
475
+ residue_chi_angles = rc.chi_angles_atoms[residue_name]
476
+ atom_indices = []
477
+ for chi_angle in residue_chi_angles:
478
+ atom_indices.append([rc.atom_order[atom] for atom in chi_angle])
479
+ for _ in range(4 - len(atom_indices)):
480
+ atom_indices.append(
481
+ [0, 0, 0, 0]
482
+ ) # For chi angles not defined on the AA.
483
+ chi_atom_indices.append(atom_indices)
484
+
485
+ chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue.
486
+
487
+ return chi_atom_indices
488
+
489
+
490
+ @curry1
491
+ def atom37_to_torsion_angles(
492
+ protein,
493
+ prefix="",
494
+ ):
495
+ """
496
+ Convert coordinates to torsion angles.
497
+
498
+ This function is extremely sensitive to floating point imprecisions
499
+ and should be run with double precision whenever possible.
500
+
501
+ Args:
502
+ Dict containing:
503
+ * (prefix)aatype:
504
+ [*, N_res] residue indices
505
+ * (prefix)all_atom_positions:
506
+ [*, N_res, 37, 3] atom positions (in atom37
507
+ format)
508
+ * (prefix)all_atom_mask:
509
+ [*, N_res, 37] atom position mask
510
+ Returns:
511
+ The same dictionary updated with the following features:
512
+
513
+ "(prefix)torsion_angles_sin_cos" ([*, N_res, 7, 2])
514
+ Torsion angles
515
+ "(prefix)alt_torsion_angles_sin_cos" ([*, N_res, 7, 2])
516
+ Alternate torsion angles (accounting for 180-degree symmetry)
517
+ "(prefix)torsion_angles_mask" ([*, N_res, 7])
518
+ Torsion angles mask
519
+ """
520
+ aatype = protein[prefix + "aatype"]
521
+ all_atom_positions = protein[prefix + "all_atom_positions"]
522
+ all_atom_mask = protein[prefix + "all_atom_mask"]
523
+
524
+ aatype = torch.clamp(aatype, max=20)
525
+
526
+ pad = all_atom_positions.new_zeros(
527
+ [*all_atom_positions.shape[:-3], 1, 37, 3]
528
+ )
529
+ prev_all_atom_positions = torch.cat(
530
+ [pad, all_atom_positions[..., :-1, :, :]], dim=-3
531
+ )
532
+
533
+ pad = all_atom_mask.new_zeros([*all_atom_mask.shape[:-2], 1, 37])
534
+ prev_all_atom_mask = torch.cat([pad, all_atom_mask[..., :-1, :]], dim=-2)
535
+
536
+ pre_omega_atom_pos = torch.cat(
537
+ [prev_all_atom_positions[..., 1:3, :], all_atom_positions[..., :2, :]],
538
+ dim=-2,
539
+ )
540
+ phi_atom_pos = torch.cat(
541
+ [prev_all_atom_positions[..., 2:3, :], all_atom_positions[..., :3, :]],
542
+ dim=-2,
543
+ )
544
+ psi_atom_pos = torch.cat(
545
+ [all_atom_positions[..., :3, :], all_atom_positions[..., 4:5, :]],
546
+ dim=-2,
547
+ )
548
+
549
+ pre_omega_mask = torch.prod(
550
+ prev_all_atom_mask[..., 1:3], dim=-1
551
+ ) * torch.prod(all_atom_mask[..., :2], dim=-1)
552
+ phi_mask = prev_all_atom_mask[..., 2] * torch.prod(
553
+ all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype
554
+ )
555
+ psi_mask = (
556
+ torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype)
557
+ * all_atom_mask[..., 4]
558
+ )
559
+
560
+ chi_atom_indices = torch.as_tensor(
561
+ get_chi_atom_indices(), device=aatype.device
562
+ )
563
+
564
+ atom_indices = chi_atom_indices[..., aatype, :, :]
565
+ chis_atom_pos = batched_gather(
566
+ all_atom_positions, atom_indices, -2, len(atom_indices.shape[:-2])
567
+ )
568
+
569
+ chi_angles_mask = list(rc.chi_angles_mask)
570
+ chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])
571
+ chi_angles_mask = all_atom_mask.new_tensor(chi_angles_mask)
572
+
573
+ chis_mask = chi_angles_mask[aatype, :]
574
+
575
+ chi_angle_atoms_mask = batched_gather(
576
+ all_atom_mask,
577
+ atom_indices,
578
+ dim=-1,
579
+ no_batch_dims=len(atom_indices.shape[:-2]),
580
+ )
581
+ chi_angle_atoms_mask = torch.prod(
582
+ chi_angle_atoms_mask, dim=-1, dtype=chi_angle_atoms_mask.dtype
583
+ )
584
+ chis_mask = chis_mask * chi_angle_atoms_mask
585
+
586
+ torsions_atom_pos = torch.cat(
587
+ [
588
+ pre_omega_atom_pos[..., None, :, :],
589
+ phi_atom_pos[..., None, :, :],
590
+ psi_atom_pos[..., None, :, :],
591
+ chis_atom_pos,
592
+ ],
593
+ dim=-3,
594
+ )
595
+
596
+ torsion_angles_mask = torch.cat(
597
+ [
598
+ pre_omega_mask[..., None],
599
+ phi_mask[..., None],
600
+ psi_mask[..., None],
601
+ chis_mask,
602
+ ],
603
+ dim=-1,
604
+ )
605
+
606
+ torsion_frames = Rigid.from_3_points(
607
+ torsions_atom_pos[..., 1, :],
608
+ torsions_atom_pos[..., 2, :],
609
+ torsions_atom_pos[..., 0, :],
610
+ eps=1e-8,
611
+ )
612
+
613
+ fourth_atom_rel_pos = torsion_frames.invert().apply(
614
+ torsions_atom_pos[..., 3, :]
615
+ )
616
+
617
+ torsion_angles_sin_cos = torch.stack(
618
+ [fourth_atom_rel_pos[..., 2], fourth_atom_rel_pos[..., 1]], dim=-1
619
+ )
620
+
621
+ denom = torch.sqrt(
622
+ torch.sum(
623
+ torch.square(torsion_angles_sin_cos),
624
+ dim=-1,
625
+ dtype=torsion_angles_sin_cos.dtype,
626
+ keepdims=True,
627
+ )
628
+ + 1e-8
629
+ )
630
+ torsion_angles_sin_cos = torsion_angles_sin_cos / denom
631
+
632
+ torsion_angles_sin_cos = torsion_angles_sin_cos * all_atom_mask.new_tensor(
633
+ [1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0],
634
+ )[((None,) * len(torsion_angles_sin_cos.shape[:-2])) + (slice(None), None)]
635
+
636
+ chi_is_ambiguous = torsion_angles_sin_cos.new_tensor(
637
+ rc.chi_pi_periodic,
638
+ )[aatype, ...]
639
+
640
+ mirror_torsion_angles = torch.cat(
641
+ [
642
+ all_atom_mask.new_ones(*aatype.shape, 3),
643
+ 1.0 - 2.0 * chi_is_ambiguous,
644
+ ],
645
+ dim=-1,
646
+ )
647
+
648
+ alt_torsion_angles_sin_cos = (
649
+ torsion_angles_sin_cos * mirror_torsion_angles[..., None]
650
+ )
651
+
652
+ protein[prefix + "torsion_angles_sin_cos"] = torsion_angles_sin_cos
653
+ protein[prefix + "alt_torsion_angles_sin_cos"] = alt_torsion_angles_sin_cos
654
+ protein[prefix + "torsion_angles_mask"] = torsion_angles_mask
655
+
656
+ return protein
657
+
658
+
659
+ def get_backbone_frames(protein):
660
+ # DISCREPANCY: AlphaFold uses tensor_7s here. I don't know why.
661
+ protein["backbone_rigid_tensor"] = protein["rigidgroups_gt_frames"][
662
+ ..., 0, :, :
663
+ ]
664
+ protein["backbone_rigid_mask"] = protein["rigidgroups_gt_exists"][..., 0]
665
+
666
+ return protein
667
+
668
+
669
+ def get_chi_angles(protein):
670
+ dtype = protein["all_atom_mask"].dtype
671
+ protein["chi_angles_sin_cos"] = (
672
+ protein["torsion_angles_sin_cos"][..., 3:, :]
673
+ ).to(dtype)
674
+ protein["chi_mask"] = protein["torsion_angles_mask"][..., 3:].to(dtype)
675
+
676
+ return protein
677
+
678
+
679
+ @curry1
680
+ def random_crop_to_size(
681
+ protein,
682
+ crop_size,
683
+ shape_schema,
684
+ seed=None,
685
+ ):
686
+ """Crop randomly to `crop_size`, or keep as is if shorter than that."""
687
+ # We want each ensemble to be cropped the same way
688
+
689
+ g = None
690
+ if seed is not None:
691
+ g = torch.Generator(device=protein["seq_length"].device)
692
+ g.manual_seed(seed)
693
+
694
+ seq_length = protein["seq_length"]
695
+
696
+ num_res_crop_size = min(int(seq_length), crop_size)
697
+
698
+ def _randint(lower, upper):
699
+ return int(torch.randint(
700
+ lower,
701
+ upper + 1,
702
+ (1,),
703
+ device=protein["seq_length"].device,
704
+ generator=g,
705
+ )[0])
706
+
707
+ n = seq_length - num_res_crop_size
708
+ if "use_clamped_fape" in protein and protein["use_clamped_fape"] == 1.:
709
+ right_anchor = n
710
+ else:
711
+ x = _randint(0, n)
712
+ right_anchor = n - x
713
+
714
+ num_res_crop_start = _randint(0, right_anchor)
715
+
716
+ for k, v in protein.items():
717
+ if k not in shape_schema or (NUM_RES not in shape_schema[k]):
718
+ continue
719
+
720
+ slices = []
721
+ for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)):
722
+ is_num_res = dim_size == NUM_RES
723
+ crop_start = num_res_crop_start if is_num_res else 0
724
+ crop_size = num_res_crop_size if is_num_res else dim
725
+ slices.append(slice(crop_start, crop_start + crop_size))
726
+ protein[k] = v[slices]
727
+
728
+ protein["seq_length"] = protein["seq_length"].new_tensor(num_res_crop_size)
729
+
730
+ return protein
731
+
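Editor's note: because of `curry1`, the crop is configured once and then applied to the feature dict. A usage sketch with an assumed two-entry schema:

    schema = {"aatype": [NUM_RES], "all_atom_positions": [NUM_RES, None, None]}
    protein = {
        "seq_length": torch.tensor(300),
        "aatype": torch.zeros(300, dtype=torch.long),
        "all_atom_positions": torch.zeros(300, 37, 3),
    }
    crop = random_crop_to_size(crop_size=256, shape_schema=schema, seed=0)
    protein = crop(protein)  # aatype -> (256,), all_atom_positions -> (256, 37, 3)

Every feature whose schema contains NUM_RES is sliced with the same random start, so the crop stays consistent across features.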
dockformer/data/errors.py ADDED
@@ -0,0 +1,22 @@
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """General-purpose errors used throughout the data pipeline"""
17
+ class Error(Exception):
18
+ """Base class for exceptions."""
19
+
20
+
21
+ class MultipleChainsError(Error):
22
+ """An error indicating that multiple chains were found for a given ID."""
dockformer/data/ligand_features.py ADDED
@@ -0,0 +1,66 @@
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+ from rdkit import Chem
6
+
7
+ from dockformer.data.utils import FeatureTensorDict
8
+ from dockformer.utils.consts import POSSIBLE_BOND_TYPES, POSSIBLE_ATOM_TYPES, POSSIBLE_CHARGES, POSSIBLE_CHIRALITIES
9
+
10
+
11
+ def get_atom_features(atom: Chem.Atom):
12
+ # TODO: this is temporary, we need to add more features, for example for Zn
13
+ if atom.GetSymbol() not in POSSIBLE_ATOM_TYPES:
14
+ print(f"********Unknown atom type {atom.GetSymbol()}")
15
+ atom_type = POSSIBLE_ATOM_TYPES.index("Ni")
16
+ else:
17
+ atom_type = POSSIBLE_ATOM_TYPES.index(atom.GetSymbol())
18
+ atom_charge = POSSIBLE_CHARGES.index(max(min(atom.GetFormalCharge(), 1), -1))
19
+ atom_chirality = POSSIBLE_CHIRALITIES.index(atom.GetChiralTag())
20
+
21
+ return {"atom_type": atom_type, "atom_charge": atom_charge, "atom_chirality": atom_chirality}
22
+
23
+
24
+ def get_bond_features(bond: Chem.Bond):
25
+ bond_type = POSSIBLE_BOND_TYPES.index(bond.GetBondType())
26
+ return {"bond_type": bond_type}
27
+
28
+
29
+ def make_ligand_features(ligand: Chem.Mol) -> FeatureTensorDict:
30
+ atoms_features = []
31
+ atom_idx_to_atom_pos_idx = {}
32
+ for atom in ligand.GetAtoms():
33
+ atom_idx_to_atom_pos_idx[atom.GetIdx()] = len(atoms_features)
34
+ atoms_features.append(get_atom_features(atom))
35
+
36
+ atom_types = torch.tensor(np.array([atom["atom_type"] for atom in atoms_features], dtype=np.int64))
37
+ atom_types_one_hot = nn.functional.one_hot(atom_types, num_classes=len(POSSIBLE_ATOM_TYPES))
38
+ atom_charges = torch.tensor(np.array([atom["atom_charge"] for atom in atoms_features], dtype=np.int64))
39
+ atom_charges_one_hot = nn.functional.one_hot(atom_charges, num_classes=len(POSSIBLE_CHARGES))
40
+ atom_chiralities = torch.tensor(np.array([atom["atom_chirality"] for atom in atoms_features], dtype=np.int64))
41
+ atom_chiralities_one_hot = nn.functional.one_hot(atom_chiralities, num_classes=len(POSSIBLE_CHIRALITIES))
42
+
43
+ ligand_target_feat = torch.cat([atom_types_one_hot.float(), atom_charges_one_hot.float(),
44
+ atom_chiralities_one_hot.float()], dim=1)
45
+
46
+ # create one-hot matrix encoding for bonds
47
+ ligand_bonds_feat = torch.zeros((len(atoms_features), len(atoms_features), len(POSSIBLE_BOND_TYPES)))
48
+ ligand_bonds = []
49
+ for bond in ligand.GetBonds():
50
+ atom1_idx = atom_idx_to_atom_pos_idx[bond.GetBeginAtomIdx()]
51
+ atom2_idx = atom_idx_to_atom_pos_idx[bond.GetEndAtomIdx()]
52
+ bond_features = get_bond_features(bond)
53
+ ligand_bonds.append((atom1_idx, atom2_idx, bond_features["bond_type"]))
54
+ ligand_bonds_feat[atom1_idx, atom2_idx, bond_features["bond_type"]] = 1
55
+
56
+ return {
57
+ # These are used for reconstruction at the end of the pipeline
58
+ "ligand_atype": atom_types,
59
+ "ligand_charge": atom_charges,
60
+ "ligand_chirality": atom_chiralities,
61
+ "ligand_bonds": torch.tensor(ligand_bonds, dtype=torch.int64),
62
+ # these are the actual features
63
+ "ligand_target_feat": ligand_target_feat.float(),
64
+ "ligand_bonds_feat": ligand_bonds_feat.float(),
65
+ }
66
+
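Editor's note: a usage sketch of `make_ligand_features` with RDKit (it assumes C and O appear in POSSIBLE_ATOM_TYPES and that RDKit's default chiral tag is in POSSIBLE_CHIRALITIES):

    from rdkit import Chem

    mol = Chem.MolFromSmiles("CCO")  # ethanol: 3 heavy atoms, 2 single bonds
    feats = make_ligand_features(mol)
    # feats["ligand_target_feat"] -> (3, n_atom_classes) one-hot atom features
    # feats["ligand_bonds_feat"]  -> (3, 3, len(POSSIBLE_BOND_TYPES)) bond one-hots
    # feats["ligand_bonds"]       -> (2, 3) rows of (atom_i, atom_j, bond_type)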
dockformer/data/parsers.py ADDED
@@ -0,0 +1,53 @@
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Functions for parsing various file formats."""
17
+ import collections
18
+ import dataclasses
19
+ import itertools
20
+ import re
21
+ import string
22
+ from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Set
23
+
24
+
25
+ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
26
+ """Parses FASTA string and returns list of strings with amino-acid sequences.
27
+
28
+ Arguments:
29
+ fasta_string: The string contents of a FASTA file.
30
+
31
+ Returns:
32
+ A tuple of two lists:
33
+ * A list of sequences.
34
+ * A list of sequence descriptions taken from the comment lines. In the
35
+ same order as the sequences.
36
+ """
37
+ sequences = []
38
+ descriptions = []
39
+ index = -1
40
+ for line in fasta_string.splitlines():
41
+ line = line.strip()
42
+ if line.startswith(">"):
43
+ index += 1
44
+ descriptions.append(line[1:]) # Remove the '>' at the beginning.
45
+ sequences.append("")
46
+ continue
47
+ elif line.startswith("#"):
48
+ continue
49
+ elif not line:
50
+ continue # Skip blank lines.
51
+ sequences[index] += line
52
+
53
+ return sequences, descriptions
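Editor's note: for example, multi-line records are concatenated and comment lines skipped:

    fasta = """>seq1 first sequence
    MKTAYIAK
    QRQISFVK
    >seq2
    GSHMGS"""
    seqs, descs = parse_fasta(fasta)
    # seqs  == ["MKTAYIAKQRQISFVK", "GSHMGS"]
    # descs == ["seq1 first sequence", "seq2"]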
dockformer/data/protein_features.py ADDED
@@ -0,0 +1,71 @@
1
+ import numpy as np
2
+
3
+ from dockformer.data.utils import FeatureDict
4
+ from dockformer.utils import residue_constants, protein
5
+
6
+
7
+ def _make_sequence_features(sequence: str, description: str, num_res: int) -> FeatureDict:
8
+ """Construct a feature dict of sequence features."""
9
+ features = {}
10
+ features["aatype"] = residue_constants.sequence_to_onehot(
11
+ sequence=sequence,
12
+ mapping=residue_constants.restype_order_with_x,
13
+ map_unknown_to_x=True,
14
+ )
15
+ features["domain_name"] = np.array(
16
+ [description.encode("utf-8")], dtype=object
17
+ )
18
+ # features["residue_index"] = np.array(range(num_res), dtype=np.int32)
19
+ features["seq_length"] = np.array([num_res] * num_res, dtype=np.int32)
20
+ features["sequence"] = np.array(
21
+ [sequence.encode("utf-8")], dtype=object
22
+ )
23
+ return features
24
+
25
+
26
+ def _aatype_to_str_sequence(aatype):
27
+ return ''.join([
28
+ residue_constants.restypes_with_x[aatype[i]]
29
+ for i in range(len(aatype))
30
+ ])
31
+
32
+
33
+ def _make_protein_structure_features(protein_object: protein.Protein) -> FeatureDict:
34
+ pdb_feats = {}
35
+
36
+ all_atom_positions = protein_object.atom_positions
37
+ all_atom_mask = protein_object.atom_mask
38
+
39
+ pdb_feats["all_atom_positions"] = all_atom_positions.astype(np.float32)
40
+ pdb_feats["all_atom_mask"] = all_atom_mask.astype(np.float32)
41
+ pdb_feats["in_chain_residue_index"] = protein_object.residue_index.astype(np.int32)
42
+
43
+ gapped_res_indexes = []
44
+ prev_chain_index = protein_object.chain_index[0]
45
+ chain_start_res_ind = 0
46
+ for relative_res_ind, chain_index in zip(protein_object.residue_index, protein_object.chain_index):
47
+ if chain_index != prev_chain_index:
48
+ chain_start_res_ind = gapped_res_indexes[-1] + 50
49
+ prev_chain_index = chain_index
50
+ gapped_res_indexes.append(relative_res_ind + chain_start_res_ind)
51
+
52
+ pdb_feats["residue_index"] = np.array(gapped_res_indexes).astype(np.int32)
53
+ pdb_feats["chain_index"] = np.array(protein_object.chain_index).astype(np.int32)
54
+ pdb_feats["resolution"] = np.array([0.]).astype(np.float32)
55
+
56
+ return pdb_feats
57
+
58
+
59
+ def make_protein_features(protein_object: protein.Protein, description: str) -> FeatureDict:
60
+ feats = {}
61
+ aatype = protein_object.aatype
62
+ sequence = _aatype_to_str_sequence(aatype)
63
+ feats.update(
64
+ _make_sequence_features(sequence=sequence, description=description, num_res=len(protein_object.aatype))
65
+ )
66
+
67
+ feats.update(
68
+ _make_protein_structure_features(protein_object=protein_object)
69
+ )
70
+
71
+ return feats
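Editor's note: the gapped residue index inserts a 50-position jump between chains so that downstream relative-position features treat residues of different chains as sequentially distant. A toy trace of the loop above:

    # residue_index = [1, 2, 3, 1, 2], chain_index = [0, 0, 0, 1, 1]
    # chain 0: offset 0       -> gapped [1, 2, 3]
    # chain 1: offset 3 + 50  -> gapped [54, 55]
    # "residue_index" feature -> [1, 2, 3, 54, 55]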
dockformer/data/utils.py ADDED
@@ -0,0 +1,54 @@
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Common utilities for data pipeline tools."""
17
+ import contextlib
18
+ import datetime
19
+ import logging
20
+ import shutil
21
+ import tempfile
22
+ import time
23
+ from typing import Optional, Mapping, Dict
24
+
25
+ import numpy as np
26
+ import torch
27
+
28
+ FeatureDict = Dict[str, np.ndarray]
29
+ FeatureTensorDict = Dict[str, torch.Tensor]
30
+
31
+
32
+ @contextlib.contextmanager
33
+ def tmpdir_manager(base_dir: Optional[str] = None):
34
+ """Context manager that deletes a temporary directory on exit."""
35
+ tmpdir = tempfile.mkdtemp(dir=base_dir)
36
+ try:
37
+ yield tmpdir
38
+ finally:
39
+ shutil.rmtree(tmpdir, ignore_errors=True)
40
+
41
+
42
+ @contextlib.contextmanager
43
+ def timing(msg: str):
44
+ logging.info("Started %s", msg)
45
+ tic = time.perf_counter()
46
+ yield
47
+ toc = time.perf_counter()
48
+ logging.info("Finished %s in %.3f seconds", msg, toc - tic)
49
+
50
+
51
+ def to_date(s: str):
52
+ return datetime.datetime(
53
+ year=int(s[:4]), month=int(s[5:7]), day=int(s[8:10])
54
+ )
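Editor's note: a usage sketch for the two context managers defined here:

    with tmpdir_manager(base_dir="/tmp") as tmp_dir:
        ...  # tmp_dir is deleted on exit, even on exceptions

    with timing("featurization"):
        ...  # logs the start and the elapsed wall-clock seconds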
dockformer/model/__init__.py ADDED
File without changes
dockformer/model/dropout.py ADDED
@@ -0,0 +1,69 @@
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ from functools import partialmethod
19
+ from typing import Union, List
20
+
21
+
22
+ class Dropout(nn.Module):
23
+ """
24
+ Implementation of dropout with the ability to share the dropout mask
25
+ along a particular dimension.
26
+
27
+ If not in training mode, this module computes the identity function.
28
+ """
29
+
30
+ def __init__(self, r: float, batch_dim: Union[int, List[int]]):
31
+ """
32
+ Args:
33
+ r:
34
+ Dropout rate
35
+ batch_dim:
36
+ Dimension(s) along which the dropout mask is shared
37
+ """
38
+ super(Dropout, self).__init__()
39
+
40
+ self.r = r
41
+ if type(batch_dim) == int:
42
+ batch_dim = [batch_dim]
43
+ self.batch_dim = batch_dim
44
+ self.dropout = nn.Dropout(self.r)
45
+
46
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
47
+ """
48
+ Args:
49
+ x:
50
+ Tensor to which dropout is applied. Can have any shape
51
+ compatible with self.batch_dim
52
+ """
53
+ shape = list(x.shape)
54
+ if self.batch_dim is not None:
55
+ for bd in self.batch_dim:
56
+ shape[bd] = 1
57
+ mask = x.new_ones(shape)
58
+ mask = self.dropout(mask)
59
+ x *= mask
60
+ return x
61
+
62
+
63
+ class DropoutRowwise(Dropout):
64
+ """
65
+ Convenience class for rowwise dropout as described in subsection
66
+ 1.11.6.
67
+ """
68
+
69
+ __init__ = partialmethod(Dropout.__init__, batch_dim=-3)
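Editor's note: the shared mask means whole slices are kept or dropped together. A quick illustration of `DropoutRowwise` on a pair-shaped tensor:

    import torch

    drop = DropoutRowwise(r=0.25)
    drop.train()
    z = torch.ones(1, 8, 8, 16)  # [*, N_res, N_res, C_z]
    z = drop(z)
    # The mask has size 1 along dim -3 and is broadcast across it, so every
    # row of the pair representation sees the same dropout pattern.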
dockformer/model/embedders.py ADDED
@@ -0,0 +1,346 @@
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import partial
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from typing import Tuple, Optional
21
+
22
+ from dockformer.model.primitives import Linear, LayerNorm
23
+ from dockformer.utils.tensor_utils import add
24
+
25
+
26
+ class StructureInputEmbedder(nn.Module):
27
+ """
28
+ Embeds a subset of the input features.
29
+
30
+ Implements a merge of Algorithm 3 and Algorithm 32.
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ protein_tf_dim: int,
36
+ ligand_tf_dim: int,
37
+ additional_tf_dim: int,
38
+ ligand_bond_dim: int,
39
+ c_z: int,
40
+ c_m: int,
41
+ relpos_k: int,
42
+ prot_min_bin: float,
43
+ prot_max_bin: float,
44
+ prot_no_bins: int,
45
+ lig_min_bin: float,
46
+ lig_max_bin: float,
47
+ lig_no_bins: int,
48
+ inf: float = 1e8,
49
+ **kwargs,
50
+ ):
51
+ """
52
+ Args:
53
+ tf_dim:
54
+ Final dimension of the target features
55
+ c_z:
56
+ Pair embedding dimension
57
+ c_m:
58
+ Single embedding dimension
59
+ relpos_k:
60
+ Window size used in relative positional encoding
61
+ """
62
+ super(StructureInputEmbedder, self).__init__()
63
+
64
+ self.tf_dim = protein_tf_dim + ligand_tf_dim + additional_tf_dim
65
+ self.pair_tf_dim = ligand_bond_dim
66
+
67
+ self.c_z = c_z
68
+ self.c_m = c_m
69
+
70
+ self.linear_tf_z_i = Linear(self.tf_dim, c_z)
71
+ self.linear_tf_z_j = Linear(self.tf_dim, c_z)
72
+ self.linear_tf_m = Linear(self.tf_dim, c_m)
73
+
74
+ self.ligand_linear_bond_z = Linear(ligand_bond_dim, c_z)
75
+
76
+ # RPE stuff
77
+ self.relpos_k = relpos_k
78
+ self.no_bins = 2 * relpos_k + 1
79
+ self.linear_relpos = Linear(self.no_bins, c_z)
80
+
81
+ # Recycling stuff
82
+ self.prot_min_bin = prot_min_bin
83
+ self.prot_max_bin = prot_max_bin
84
+ self.prot_no_bins = prot_no_bins
85
+ self.lig_min_bin = lig_min_bin
86
+ self.lig_max_bin = lig_max_bin
87
+ self.lig_no_bins = lig_no_bins
88
+ self.inf = inf
89
+
90
+ self.prot_recycling_linear = Linear(self.prot_no_bins + 1, self.c_z)
91
+ self.lig_recycling_linear = Linear(self.lig_no_bins, self.c_z)
92
+ self.layer_norm_m = LayerNorm(self.c_m)
93
+ self.layer_norm_z = LayerNorm(self.c_z)
94
+
95
+ def relpos(self, ri: torch.Tensor):
96
+ """
97
+ Computes relative positional encodings
98
+
99
+ Implements Algorithm 4.
100
+
101
+ Args:
102
+ ri:
103
+ "residue_index" features of shape [*, N]
104
+ """
105
+ d = ri[..., None] - ri[..., None, :]
106
+ boundaries = torch.arange(
107
+ start=-self.relpos_k, end=self.relpos_k + 1, device=d.device
108
+ )
109
+ reshaped_bins = boundaries.view(((1,) * len(d.shape)) + (len(boundaries),))
110
+ d = d[..., None] - reshaped_bins
111
+ d = torch.abs(d)
112
+ d = torch.argmin(d, dim=-1)
113
+ d = nn.functional.one_hot(d, num_classes=len(boundaries)).float()
114
+ d = d.to(ri.dtype)
115
+ return self.linear_relpos(d)
116
+
117
+ def _get_binned_distogram(self, x, min_bin, max_bin, no_bins, recycling_linear, prot_distogram_mask=None):
118
+ # This squared method might become problematic in FP16 mode.
119
+ bins = torch.linspace(
120
+ min_bin,
121
+ max_bin,
122
+ no_bins,
123
+ dtype=x.dtype,
124
+ device=x.device,
125
+ requires_grad=False,
126
+ )
127
+ squared_bins = bins ** 2
128
+ upper = torch.cat(
129
+ [squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
130
+ )
131
+ d = torch.sum((x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdims=True)
132
+
133
+ # [*, N, N, no_bins]
134
+ d = ((d > squared_bins) * (d < upper)).type(x.dtype)
135
+ # print("d shape", d.shape, d[0][0][:10])
136
+
137
+ if prot_distogram_mask is not None:
138
+ expanded_d = torch.cat([d, torch.zeros(*d.shape[:-1], 1, device=d.device)], dim=-1)
139
+
140
+ # Step 2: Create a mask marking positions whose input coordinates are
141
+ # masked out (prot_distogram_mask == 1), using broadcasting only
142
+ input_positions_mask = (prot_distogram_mask == 1).float() # Shape [N, crop_size]
143
+ mask_i = input_positions_mask.unsqueeze(2) # Shape [N, crop_size, 1]
144
+ mask_j = input_positions_mask.unsqueeze(1) # Shape [N, 1, crop_size]
145
+
146
+ # Step 3: Combine masks for both [N, :, i, :] and [N, i, :, :]
147
+ combined_mask = mask_i + mask_j # Shape [N, crop_size, crop_size]
148
+ combined_mask = combined_mask.clamp(max=1) # Ensure binary mask
149
+
150
+ # Step 4: Apply the mask
151
+ # a. Set all but the last position in the `no_bins + 1` dimension to 0 where the mask is 1
152
+ expanded_d[..., :-1] *= (1 - combined_mask).unsqueeze(-1) # Shape [N, crop_size, crop_size, no_bins]
153
+
154
+ # print("expanded_d shape1", expanded_d.shape, expanded_d[0][0][:10])
155
+
156
+ # b. Set the last position in the `no_bins + 1` dimension to 1 where the mask is 1
157
+ expanded_d[..., -1] += combined_mask # Shape [N, crop_size, crop_size, 1]
158
+ d = expanded_d
159
+ # print("expanded_d shape2", d.shape, d[0][0][:10])
160
+
161
+ return recycling_linear(d)
162
+
163
+ def forward(
164
+ self,
165
+ token_mask: torch.Tensor,
166
+ protein_mask: torch.Tensor,
167
+ ligand_mask: torch.Tensor,
168
+ target_feat: torch.Tensor,
169
+ ligand_bonds_feat: torch.Tensor,
170
+ input_positions: torch.Tensor,
171
+ protein_residue_index: torch.Tensor,
172
+ protein_distogram_mask: torch.Tensor,
173
+ inplace_safe: bool = False,
174
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
175
+ """
176
+ Args:
177
+ target_feat:
178
+ [*, N_res + N_lig_atoms, tf_dim] concatenated protein/ligand target features
179
+ protein_residue_index:
180
+ [*, N_res] residue indices, used for relative positional encoding
181
+ input_positions:
182
+ [*, N_res + N_lig_atoms, 3] input coordinates (predicted C_beta / ligand atoms)
183
+ ligand_bonds_feat:
184
+ [*, N_lig_atoms, N_lig_atoms, ligand_bond_dim] ligand bond features
185
+ token_mask / protein_mask / ligand_mask: [*, N_res + N_lig_atoms] masks
186
+ Returns:
187
+ single_emb:
188
+ [*, N_res + N_lig_atoms, C_m] single embedding
189
+ pair_emb:
190
+ [*, N_res + N_lig_atoms, N_res + N_lig_atoms, C_z] pair embedding
191
+
192
+ """
193
+ device = token_mask.device
194
+ pair_protein_mask = protein_mask[..., None] * protein_mask[..., None, :]
195
+ pair_ligand_mask = ligand_mask[..., None] * ligand_mask[..., None, :]
196
+
197
+ # Single representation embedding - Algorithm 3
198
+ tf_m = self.linear_tf_m(target_feat)
199
+ tf_m = self.layer_norm_m(tf_m)  # previously this happened in the do_recycle function
200
+
201
+ # Pair representation
202
+ # protein pair embedding - Algorithm 3
203
+ # [*, N_res, c_z]
204
+ tf_emb_i = self.linear_tf_z_i(target_feat)
205
+ tf_emb_j = self.linear_tf_z_j(target_feat)
206
+
207
+ pair_emb = torch.zeros(*pair_protein_mask.shape, self.c_z, device=device)
208
+ pair_emb = add(pair_emb, tf_emb_i[..., None, :], inplace=inplace_safe)
209
+ pair_emb = add(pair_emb, tf_emb_j[..., None, :, :], inplace=inplace_safe)
210
+
211
+ # Apply relpos
212
+ relpos = self.relpos(protein_residue_index.type(tf_emb_i.dtype))
213
+ pair_emb += pair_protein_mask[..., None] * relpos
214
+
215
+ del relpos
216
+
217
+ # apply ligand bonds
218
+ ligand_bonds = self.ligand_linear_bond_z(ligand_bonds_feat)
219
+ pair_emb += pair_ligand_mask[..., None] * ligand_bonds
220
+
221
+ del ligand_bonds
222
+
223
+ # normalize the pair embedding (z) before recycling; this was previously part of the recycling step
224
+ pair_emb = self.layer_norm_z(pair_emb)
225
+
226
+ # apply protein recycle
227
+ prot_distogram_embed = self._get_binned_distogram(input_positions, self.prot_min_bin, self.prot_max_bin,
228
+ self.prot_no_bins, self.prot_recycling_linear,
229
+ protein_distogram_mask)
230
+
231
+
232
+ pair_emb = add(pair_emb, prot_distogram_embed * pair_protein_mask.unsqueeze(-1), inplace_safe)
233
+
234
+ del prot_distogram_embed
235
+
236
+ # apply ligand recycle
237
+ lig_distogram_embed = self._get_binned_distogram(input_positions, self.lig_min_bin, self.lig_max_bin,
238
+ self.lig_no_bins, self.lig_recycling_linear)
239
+ pair_emb = add(pair_emb, lig_distogram_embed * pair_ligand_mask.unsqueeze(-1), inplace_safe)
240
+
241
+ del lig_distogram_embed
242
+
243
+ return tf_m, pair_emb
244
+
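Editor's note: the relative-position encoding in `relpos` clips pairwise sequence offsets into 2 * relpos_k + 1 bins before the linear projection. The binning step in isolation (a sketch, here with relpos_k = 2):

    import torch

    ri = torch.arange(5).float()
    d = ri[..., None] - ri[..., None, :]   # pairwise offsets in [-4, 4]
    boundaries = torch.arange(-2, 3)       # 2 * relpos_k + 1 = 5 bins
    binned = torch.argmin((d[..., None] - boundaries).abs(), dim=-1)
    one_hot = torch.nn.functional.one_hot(binned, num_classes=5).float()
    # offsets beyond +/-2 are clipped into the outermost bins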
245
+
246
+ class RecyclingEmbedder(nn.Module):
247
+ """
248
+ Embeds the output of an iteration of the model for recycling.
249
+
250
+ Implements Algorithm 32.
251
+ """
252
+ def __init__(
253
+ self,
254
+ c_m: int,
255
+ c_z: int,
256
+ min_bin: float,
257
+ max_bin: float,
258
+ no_bins: int,
259
+ inf: float = 1e8,
260
+ **kwargs,
261
+ ):
262
+ """
263
+ Args:
264
+ c_m:
265
+ Single channel dimension
266
+ c_z:
267
+ Pair embedding channel dimension
268
+ min_bin:
269
+ Smallest distogram bin (Angstroms)
270
+ max_bin:
271
+ Largest distogram bin (Angstroms)
272
+ no_bins:
273
+ Number of distogram bins
274
+ """
275
+ super(RecyclingEmbedder, self).__init__()
276
+
277
+ self.c_m = c_m
278
+ self.c_z = c_z
279
+ self.min_bin = min_bin
280
+ self.max_bin = max_bin
281
+ self.no_bins = no_bins
282
+ self.inf = inf
283
+
284
+ self.linear = Linear(self.no_bins, self.c_z)
285
+ self.layer_norm_m = LayerNorm(self.c_m)
286
+ self.layer_norm_z = LayerNorm(self.c_z)
287
+
288
+ def forward(
289
+ self,
290
+ m: torch.Tensor,
291
+ z: torch.Tensor,
292
+ x: torch.Tensor,
293
+ inplace_safe: bool = False,
294
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
295
+ """
296
+ Args:
297
+ m:
298
+ First row of the single embedding. [*, N_res, C_m]
299
+ z:
300
+ [*, N_res, N_res, C_z] pair embedding
301
+ x:
302
+ [*, N_res, 3] predicted C_beta coordinates
303
+ Returns:
304
+ m:
305
+ [*, N_res, C_m] single embedding update
306
+ z:
307
+ [*, N_res, N_res, C_z] pair embedding update
308
+ """
309
+ # [*, N, C_m]
310
+ m_update = self.layer_norm_m(m)
311
+ if(inplace_safe):
312
+ m.copy_(m_update)
313
+ m_update = m
314
+
315
+ # [*, N, N, C_z]
316
+ z_update = self.layer_norm_z(z)
317
+ if(inplace_safe):
318
+ z.copy_(z_update)
319
+ z_update = z
320
+
321
+ # This squared method might become problematic in FP16 mode.
322
+ bins = torch.linspace(
323
+ self.min_bin,
324
+ self.max_bin,
325
+ self.no_bins,
326
+ dtype=x.dtype,
327
+ device=x.device,
328
+ requires_grad=False,
329
+ )
330
+ squared_bins = bins ** 2
331
+ upper = torch.cat(
332
+ [squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
333
+ )
334
+ d = torch.sum(
335
+ (x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdims=True
336
+ )
337
+
338
+ # [*, N, N, no_bins]
339
+ d = ((d > squared_bins) * (d < upper)).type(x.dtype)
340
+
341
+ # [*, N, N, C_z]
342
+ d = self.linear(d)
343
+ z_update = add(z_update, d, inplace_safe)
344
+
345
+ return m_update, z_update
346
+
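Editor's note: the `(d > squared_bins) * (d < upper)` trick one-hot-encodes squared pairwise distances without a sort, and working in squared space avoids a square root. A standalone sketch of the binning used in both embedders:

    import torch

    x = torch.randn(10, 3)  # 10 points in 3-D
    bins = torch.linspace(3.25, 20.75, 15)
    squared_bins = bins ** 2
    upper = torch.cat([squared_bins[1:], squared_bins.new_tensor([1e8])])
    d = torch.sum((x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdim=True)
    one_hot = ((d > squared_bins) * (d < upper)).float()  # [10, 10, 15]
    # each pair activates at most one bin (squared distances below
    # squared_bins[0] activate none)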
dockformer/model/evoformer.py ADDED
@@ -0,0 +1,468 @@
 
1
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+import sys
+import torch
+import torch.nn as nn
+from typing import Tuple, Sequence, Optional
+from functools import partial
+from abc import ABC, abstractmethod
+
+from dockformer.model.primitives import Linear, LayerNorm
+from dockformer.model.dropout import DropoutRowwise
+from dockformer.model.single_attention import SingleRowAttentionWithPairBias
+
+from dockformer.model.pair_transition import PairTransition
+from dockformer.model.triangular_attention import (
+    TriangleAttention,
+)
+from dockformer.model.triangular_multiplicative_update import (
+    TriangleMultiplicationOutgoing,
+    TriangleMultiplicationIncoming,
+)
+from dockformer.utils.checkpointing import checkpoint_blocks
+from dockformer.utils.tensor_utils import add
+
+
+class SingleRepTransition(nn.Module):
+    """
+    Feed-forward network applied to single representation activations after attention.
+
+    Implements Algorithm 9
+    """
+    def __init__(self, c_m, n):
+        """
+        Args:
+            c_m:
+                channel dimension
+            n:
+                Factor multiplied to c_m to obtain the hidden channel dimension
+        """
+        super(SingleRepTransition, self).__init__()
+
+        self.c_m = c_m
+        self.n = n
+
+        self.layer_norm = LayerNorm(self.c_m)
+        self.linear_1 = Linear(self.c_m, self.n * self.c_m, init="relu")
+        self.relu = nn.ReLU()
+        self.linear_2 = Linear(self.n * self.c_m, self.c_m, init="final")
+
+    def _transition(self, m, mask):
+        m = self.layer_norm(m)
+        m = self.linear_1(m)
+        m = self.relu(m)
+        m = self.linear_2(m) * mask
+        return m
+
+    def forward(
+        self,
+        m: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """
+        Args:
+            m:
+                [*, N_res, C_m] activation after attention
+            mask:
+                [*, N_res, C_m] mask
+        Returns:
+            m:
+                [*, N_res, C_m] activation update
+        """
+        # DISCREPANCY: DeepMind forgets to apply the mask here.
+        if mask is None:
+            mask = m.new_ones(m.shape[:-1])
+
+        mask = mask.unsqueeze(-1)
+
+        m = self._transition(m, mask)
+
+        return m
+
+
+class PairStack(nn.Module):
+    def __init__(
+        self,
+        c_z: int,
+        c_hidden_mul: int,
+        c_hidden_pair_att: int,
+        no_heads_pair: int,
+        transition_n: int,
+        pair_dropout: float,
+        inf: float,
+        eps: float
+    ):
+        super(PairStack, self).__init__()
+
+        self.tri_mul_out = TriangleMultiplicationOutgoing(
+            c_z,
+            c_hidden_mul,
+        )
+        self.tri_mul_in = TriangleMultiplicationIncoming(
+            c_z,
+            c_hidden_mul,
+        )
+
+        self.tri_att_start = TriangleAttention(
+            c_z,
+            c_hidden_pair_att,
+            no_heads_pair,
+            inf=inf,
+        )
+        self.tri_att_end = TriangleAttention(
+            c_z,
+            c_hidden_pair_att,
+            no_heads_pair,
+            inf=inf,
+        )
+
+        self.pair_transition = PairTransition(
+            c_z,
+            transition_n,
+        )
+
+        self.ps_dropout_row_layer = DropoutRowwise(pair_dropout)
+
+    def forward(self,
+        z: torch.Tensor,
+        pair_mask: torch.Tensor,
+        use_lma: bool = False,
+        inplace_safe: bool = False,
+        _mask_trans: bool = True,
+    ) -> torch.Tensor:
+        # DeepMind doesn't mask these transitions in the source, so _mask_trans
+        # should be disabled to better approximate the exact activations of
+        # the original.
+        pair_trans_mask = pair_mask if _mask_trans else None
+
+        tmu_update = self.tri_mul_out(
+            z,
+            mask=pair_mask,
+            inplace_safe=inplace_safe,
+            _add_with_inplace=True,
+        )
+        if (not inplace_safe):
+            z = z + self.ps_dropout_row_layer(tmu_update)
+        else:
+            z = tmu_update
+
+        del tmu_update
+
+        tmu_update = self.tri_mul_in(
+            z,
+            mask=pair_mask,
+            inplace_safe=inplace_safe,
+            _add_with_inplace=True,
+        )
+        if (not inplace_safe):
+            z = z + self.ps_dropout_row_layer(tmu_update)
+        else:
+            z = tmu_update
+
+        del tmu_update
+
+        z = add(z,
+            self.ps_dropout_row_layer(
+                self.tri_att_start(
+                    z,
+                    mask=pair_mask,
+                    use_memory_efficient_kernel=False,
+                    use_lma=use_lma,
+                )
+            ),
+            inplace=inplace_safe,
+        )
+
+        z = z.transpose(-2, -3)
+        if (inplace_safe):
+            z = z.contiguous()
+
+        z = add(z,
+            self.ps_dropout_row_layer(
+                self.tri_att_end(
+                    z,
+                    mask=pair_mask.transpose(-1, -2),
+                    use_memory_efficient_kernel=False,
+                    use_lma=use_lma,
+                )
+            ),
+            inplace=inplace_safe,
+        )
+
+        z = z.transpose(-2, -3)
+        if (inplace_safe):
+            z = z.contiguous()
+
+        z = add(z,
+            self.pair_transition(
+                z, mask=pair_trans_mask,
+            ),
+            inplace=inplace_safe,
+        )
+
+        return z
+
+
+class EvoformerBlock(nn.Module, ABC):
+    def __init__(self,
+        c_m: int,
+        c_z: int,
+        c_hidden_single_att: int,
+        c_hidden_mul: int,
+        c_hidden_pair_att: int,
+        no_heads_single: int,
+        no_heads_pair: int,
+        transition_n: int,
+        single_dropout: float,
+        pair_dropout: float,
+        inf: float,
+        eps: float,
+    ):
+        super(EvoformerBlock, self).__init__()
+
+        self.single_att_row = SingleRowAttentionWithPairBias(
+            c_m=c_m,
+            c_z=c_z,
+            c_hidden=c_hidden_single_att,
+            no_heads=no_heads_single,
+            inf=inf,
+        )
+
+        self.single_dropout_layer = DropoutRowwise(single_dropout)
+
+        self.single_transition = SingleRepTransition(
+            c_m=c_m,
+            n=transition_n,
+        )
+
+        self.pair_stack = PairStack(
+            c_z=c_z,
+            c_hidden_mul=c_hidden_mul,
+            c_hidden_pair_att=c_hidden_pair_att,
+            no_heads_pair=no_heads_pair,
+            transition_n=transition_n,
+            pair_dropout=pair_dropout,
+            inf=inf,
+            eps=eps
+        )
+
+    def forward(self,
+        m: Optional[torch.Tensor],
+        z: Optional[torch.Tensor],
+        single_mask: torch.Tensor,
+        pair_mask: torch.Tensor,
+        use_lma: bool = False,
+        inplace_safe: bool = False,
+        _mask_trans: bool = True,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+
+        single_trans_mask = single_mask if _mask_trans else None
+
+        input_tensors = [m, z]
+
+        m, z = input_tensors
+
+        z = self.pair_stack(
+            z=z,
+            pair_mask=pair_mask,
+            use_lma=use_lma,
+            inplace_safe=inplace_safe,
+            _mask_trans=_mask_trans,
+        )
+
+        m = add(m,
+            self.single_dropout_layer(
+                self.single_att_row(
+                    m,
+                    z=z,
+                    mask=single_mask,
+                    use_memory_efficient_kernel=False,
+                    use_lma=use_lma,
+                )
+            ),
+            inplace=inplace_safe,
+        )
+
+        m = add(m, self.single_transition(m, mask=single_mask), inplace=inplace_safe)
+
+        return m, z
+
+
+class EvoformerStack(nn.Module):
+    """
+    Main Evoformer trunk.
+
+    Implements Algorithm 6.
+    """
+
+    def __init__(
+        self,
+        c_m: int,
+        c_z: int,
+        c_hidden_single_att: int,
+        c_hidden_mul: int,
+        c_hidden_pair_att: int,
+        c_s: int,
+        no_heads_single: int,
+        no_heads_pair: int,
+        no_blocks: int,
+        transition_n: int,
+        single_dropout: float,
+        pair_dropout: float,
+        blocks_per_ckpt: int,
+        inf: float,
+        eps: float,
+        clear_cache_between_blocks: bool = False,
+        **kwargs,
+    ):
+        """
+        Args:
+            c_m:
+                single channel dimension
+            c_z:
+                Pair channel dimension
+            c_hidden_single_att:
+                Hidden dimension in single representation attention
+            c_hidden_mul:
+                Hidden dimension in multiplicative updates
+            c_hidden_pair_att:
+                Hidden dimension in triangular attention
+            c_s:
+                Channel dimension of the output "single" embedding
+            no_heads_single:
+                Number of heads used for single attention
+            no_heads_pair:
+                Number of heads used for pair attention
+            no_blocks:
+                Number of Evoformer blocks in the stack
+            transition_n:
+                Factor by which to multiply c_m to obtain the SingleTransition
+                hidden dimension
+            single_dropout:
+                Dropout rate for single activations
+            pair_dropout:
+                Dropout used for pair activations
+            blocks_per_ckpt:
+                Number of Evoformer blocks in each activation checkpoint
+            clear_cache_between_blocks:
+                Whether to clear CUDA's GPU memory cache between blocks of the
+                stack. Slows down each block but can reduce fragmentation
+        """
+        super(EvoformerStack, self).__init__()
+
+        self.blocks_per_ckpt = blocks_per_ckpt
+        self.clear_cache_between_blocks = clear_cache_between_blocks
+
+        self.blocks = nn.ModuleList()
+
+        for _ in range(no_blocks):
+            block = EvoformerBlock(
+                c_m=c_m,
+                c_z=c_z,
+                c_hidden_single_att=c_hidden_single_att,
+                c_hidden_mul=c_hidden_mul,
+                c_hidden_pair_att=c_hidden_pair_att,
+                no_heads_single=no_heads_single,
+                no_heads_pair=no_heads_pair,
+                transition_n=transition_n,
+                single_dropout=single_dropout,
+                pair_dropout=pair_dropout,
+                inf=inf,
+                eps=eps,
+            )
+            self.blocks.append(block)
+
+        self.linear = Linear(c_m, c_s)
+
+    def _prep_blocks(self,
+        use_lma: bool,
+        single_mask: Optional[torch.Tensor],
+        pair_mask: Optional[torch.Tensor],
+        inplace_safe: bool,
+        _mask_trans: bool,
+    ):
+        blocks = [
+            partial(
+                b,
+                single_mask=single_mask,
+                pair_mask=pair_mask,
+                use_lma=use_lma,
+                inplace_safe=inplace_safe,
+                _mask_trans=_mask_trans,
+            )
+            for b in self.blocks
+        ]
+
+        if self.clear_cache_between_blocks:
+            def block_with_cache_clear(block, *args, **kwargs):
+                torch.cuda.empty_cache()
+                return block(*args, **kwargs)
+
+            blocks = [partial(block_with_cache_clear, b) for b in blocks]
+
+        return blocks
+
+    def forward(self,
+        m: torch.Tensor,
+        z: torch.Tensor,
+        single_mask: torch.Tensor,
+        pair_mask: torch.Tensor,
+        use_lma: bool = False,
+        inplace_safe: bool = False,
+        _mask_trans: bool = True,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Args:
+            m:
+                [*, N_res, C_m] single embedding
+            z:
+                [*, N_res, N_res, C_z] pair embedding
+            single_mask:
+                [*, N_res] single mask
+            pair_mask:
+                [*, N_res, N_res] pair mask
+            use_lma:
+                Whether to use low-memory attention during inference.
+
+        Returns:
+            m:
+                [*, N_res, C_m] single embedding
+            z:
+                [*, N_res, N_res, C_z] pair embedding
+            s:
+                [*, N_res, C_s] single embedding after linear layer
+        """
+        blocks = self._prep_blocks(
+            use_lma=use_lma,
+            single_mask=single_mask,
+            pair_mask=pair_mask,
+            inplace_safe=inplace_safe,
+            _mask_trans=_mask_trans,
+        )
+
+        blocks_per_ckpt = self.blocks_per_ckpt
+        if(not torch.is_grad_enabled()):
+            blocks_per_ckpt = None
+
+        m, z = checkpoint_blocks(
+            blocks,
+            args=(m, z),
+            blocks_per_ckpt=blocks_per_ckpt,
+        )
+
+        s = self.linear(m)
+
+        return m, z, s
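
A hedged usage sketch of the EvoformerStack defined above (not part of the commit; all dimensions below are illustrative assumptions, not the repo's config values):

import torch
from dockformer.model.evoformer import EvoformerStack

stack = EvoformerStack(
    c_m=64, c_z=32, c_hidden_single_att=16, c_hidden_mul=32,
    c_hidden_pair_att=16, c_s=96, no_heads_single=4, no_heads_pair=4,
    no_blocks=2, transition_n=2, single_dropout=0.1, pair_dropout=0.25,
    blocks_per_ckpt=1, inf=1e9, eps=1e-8,
).eval()

n = 16
m = torch.randn(1, n, 64)         # single representation
z = torch.randn(1, n, n, 32)      # pair representation
mask = torch.ones(1, n)
pair_mask = mask[..., None] * mask[..., None, :]

with torch.no_grad():             # also disables activation checkpointing
    m, z, s = stack(m, z, single_mask=mask, pair_mask=pair_mask)
print(m.shape, z.shape, s.shape)  # (1, 16, 64) (1, 16, 16, 32) (1, 16, 96)
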
dockformer/model/heads.py ADDED
@@ -0,0 +1,260 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+from torch.nn import Parameter
+
+from dockformer.model.primitives import Linear, LayerNorm
+from dockformer.utils.loss import (
+    compute_plddt,
+    compute_tm,
+    compute_predicted_aligned_error,
+)
+from dockformer.utils.precision_utils import is_fp16_enabled
+
+
+class AuxiliaryHeads(nn.Module):
+    def __init__(self, config):
+        super(AuxiliaryHeads, self).__init__()
+
+        self.plddt = PerResidueLDDTCaPredictor(
+            **config["lddt"],
+        )
+
+        self.distogram = DistogramHead(
+            **config["distogram"],
+        )
+
+        self.affinity_2d = Affinity2DPredictor(
+            **config["affinity_2d"],
+        )
+
+        self.affinity_1d = Affinity1DPredictor(
+            **config["affinity_1d"],
+        )
+
+        self.affinity_cls = AffinityClsTokenPredictor(
+            **config["affinity_cls"],
+        )
+
+        self.binding_site = BindingSitePredictor(
+            **config["binding_site"],
+        )
+
+        self.inter_contact = InterContactHead(
+            **config["inter_contact"],
+        )
+
+        self.config = config
+
+    def forward(self, outputs, inter_mask, affinity_mask):
+        aux_out = {}
+        lddt_logits = self.plddt(outputs["sm"]["single"])
+        aux_out["lddt_logits"] = lddt_logits
+
+        # Required for relaxation later on
+        aux_out["plddt"] = compute_plddt(lddt_logits)
+
+        distogram_logits = self.distogram(outputs["pair"])
+        aux_out["distogram_logits"] = distogram_logits
+
+        aux_out["inter_contact_logits"] = self.inter_contact(outputs["single"], outputs["pair"])
+
+        aux_out["affinity_2d_logits"] = self.affinity_2d(outputs["pair"], aux_out["inter_contact_logits"], inter_mask)
+
+        aux_out["affinity_1d_logits"] = self.affinity_1d(outputs["single"])
+
+        aux_out["affinity_cls_logits"] = self.affinity_cls(outputs["single"], affinity_mask)
+
+        aux_out["binding_site_logits"] = self.binding_site(outputs["single"])
+
+        return aux_out
+
+
+class Affinity2DPredictor(nn.Module):
+    def __init__(self, c_z, num_bins):
+        super(Affinity2DPredictor, self).__init__()
+
+        self.c_z = c_z
+
+        self.weight_linear = Linear(self.c_z + 1, 1)
+        self.embed_linear = Linear(self.c_z, self.c_z)
+        self.bins_linear = Linear(self.c_z, num_bins)
+
+    def forward(self, z, inter_contacts_logits, inter_pair_mask):
+        z_with_inter_contacts = torch.cat((z, inter_contacts_logits), dim=-1)  # [*, N, N, c_z + 1]
+        weights = self.weight_linear(z_with_inter_contacts)  # [*, N, N, 1]
+
+        x = self.embed_linear(z)  # [*, N, N, c_z]
+        batch_size, N, M, _ = x.shape
+
+        flat_weights = weights.reshape(batch_size, N*M, -1)  # [*, N*M, 1]
+        flat_x = x.reshape(batch_size, N*M, -1)  # [*, N*M, c_z]
+        flat_inter_pair_mask = inter_pair_mask.reshape(batch_size, N*M, 1)
+
+        flat_weights = flat_weights.masked_fill(~(flat_inter_pair_mask.bool()), float('-inf'))  # [*, N*M, 1]
+        flat_weights = torch.nn.functional.softmax(flat_weights, dim=1)  # [*, N*M, 1]
+        flat_weights = torch.nan_to_num(flat_weights, nan=0.0)  # [*, N*M, 1]
+        weighted_sum = torch.sum((flat_weights * flat_x).reshape(batch_size, N*M, -1), dim=1)  # [*, c_z]
+
+        return self.bins_linear(weighted_sum)
+
+
+class Affinity1DPredictor(nn.Module):
+    def __init__(self, c_s, num_bins, **kwargs):
+        super(Affinity1DPredictor, self).__init__()
+
+        self.c_s = c_s
+
+        self.linear1 = Linear(self.c_s, self.c_s, init="final")
+
+        self.linear2 = Linear(self.c_s, num_bins, init="final")
+
+    def forward(self, s):
+        # [*, N, C_s]
+        s = self.linear1(s)
+
+        # get an average over the sequence
+        s = torch.mean(s, dim=1)
+
+        logits = self.linear2(s)
+        return logits
+
+
+class AffinityClsTokenPredictor(nn.Module):
+    def __init__(self, c_s, num_bins, **kwargs):
+        super(AffinityClsTokenPredictor, self).__init__()
+
+        self.c_s = c_s
+        self.linear = Linear(self.c_s, num_bins, init="final")
+
+    def forward(self, s, affinity_mask):
+        affinity_tokens = (s * affinity_mask.unsqueeze(-1)).sum(dim=1)
+        return self.linear(affinity_tokens)
+
+
+class BindingSitePredictor(nn.Module):
+    def __init__(self, c_s, c_out, **kwargs):
+        super(BindingSitePredictor, self).__init__()
+
+        self.c_s = c_s
+        self.c_out = c_out
+
+        self.linear = Linear(self.c_s, self.c_out, init="final")
+
+    def forward(self, s):
+        # [*, N, C_out]
+        return self.linear(s)
+
+
+class InterContactHead(nn.Module):
+    def __init__(self, c_s, c_z, c_out, **kwargs):
+        """
+        Args:
+            c_s:
+                Single embedding channel dimension
+            c_z:
+                Pair embedding channel dimension
+            c_out:
+                Number of output bins (1 here, since contact is a boolean)
+        """
+        super(InterContactHead, self).__init__()
+
+        self.c_s = c_s
+        self.c_z = c_z
+        self.c_out = c_out
+
+        self.linear = Linear(2 * self.c_s + self.c_z, self.c_out, init="final")
+
+    def forward(self, s, z):  # [*, N, N, C_z]
+        # [*, N, N, no_bins]
+        batch_size, n, s_dim = s.shape
+
+        s_i = s.unsqueeze(2).expand(batch_size, n, n, s_dim)
+        s_j = s.unsqueeze(1).expand(batch_size, n, n, s_dim)
+        joined = torch.cat((s_i, s_j, z), dim=-1)
+
+        logits = self.linear(joined)
+
+        return logits
+
+
+class PerResidueLDDTCaPredictor(nn.Module):
+    def __init__(self, no_bins, c_in, c_hidden):
+        super(PerResidueLDDTCaPredictor, self).__init__()
+
+        self.no_bins = no_bins
+        self.c_in = c_in
+        self.c_hidden = c_hidden
+
+        self.layer_norm = LayerNorm(self.c_in)
+
+        self.linear_1 = Linear(self.c_in, self.c_hidden, init="relu")
+        self.linear_2 = Linear(self.c_hidden, self.c_hidden, init="relu")
+        self.linear_3 = Linear(self.c_hidden, self.no_bins, init="final")
+
+        self.relu = nn.ReLU()
+
+    def forward(self, s):
+        s = self.layer_norm(s)
+        s = self.linear_1(s)
+        s = self.relu(s)
+        s = self.linear_2(s)
+        s = self.relu(s)
+        s = self.linear_3(s)
+
+        return s
+
+
+class DistogramHead(nn.Module):
+    """
+    Computes a distogram probability distribution.
+
+    For use in computation of distogram loss, subsection 1.9.8
+    """
+
+    def __init__(self, c_z, no_bins, **kwargs):
+        """
+        Args:
+            c_z:
+                Input channel dimension
+            no_bins:
+                Number of distogram bins
+        """
+        super(DistogramHead, self).__init__()
+
+        self.c_z = c_z
+        self.no_bins = no_bins
+
+        self.linear = Linear(self.c_z, self.no_bins, init="final")
+
+    def _forward(self, z):  # [*, N, N, C_z]
+        """
+        Args:
+            z:
+                [*, N_res, N_res, C_z] pair embedding
+        Returns:
+            [*, N, N, no_bins] distogram probability distribution
+        """
+        # [*, N, N, no_bins]
+        logits = self.linear(z)
+        logits = logits + logits.transpose(-2, -3)
+        return logits
+
+    def forward(self, z):
+        if(is_fp16_enabled()):
+            with torch.cuda.amp.autocast(enabled=False):
+                return self._forward(z.float())
+        else:
+            return self._forward(z)
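
A small self-contained sketch (not part of the commit) of the masked, softmax-weighted pooling that Affinity2DPredictor applies over flattened pair positions; all shapes below are hypothetical:

import torch

batch, n, c_z = 1, 6, 8
weights = torch.randn(batch, n * n, 1)          # per-pair pooling logits
x = torch.randn(batch, n * n, c_z)              # per-pair embeddings
inter_pair_mask = torch.zeros(batch, n * n, 1)
inter_pair_mask[:, : 2 * n] = 1.0               # pretend these are protein-ligand pairs

w = weights.masked_fill(~inter_pair_mask.bool(), float("-inf"))
w = torch.softmax(w, dim=1)                     # attention over unmasked pairs only
w = torch.nan_to_num(w, nan=0.0)                # fully-masked rows would otherwise be NaN
pooled = torch.sum(w * x, dim=1)                # [batch, c_z]
print(pooled.shape)                             # torch.Size([1, 8])
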
dockformer/model/model.py ADDED
@@ -0,0 +1,318 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import partial
+import weakref
+
+import torch
+import torch.nn as nn
+
+from dockformer.utils.tensor_utils import masked_mean
+from dockformer.model.embedders import (
+    StructureInputEmbedder,
+    RecyclingEmbedder,
+)
+from dockformer.model.evoformer import EvoformerStack
+from dockformer.model.heads import AuxiliaryHeads
+from dockformer.model.structure_module import StructureModule
+import dockformer.utils.residue_constants as residue_constants
+from dockformer.utils.feats import (
+    pseudo_beta_fn,
+    atom14_to_atom37,
+)
+from dockformer.utils.tensor_utils import (
+    add,
+    tensor_tree_map,
+)
+
+
+class AlphaFold(nn.Module):
+    """
+    AlphaFold 2.
+
+    Implements Algorithm 2 (but with training).
+    """
+
+    def __init__(self, config):
+        """
+        Args:
+            config:
+                A dict-like config object (like the one in config.py)
+        """
+        super(AlphaFold, self).__init__()
+
+        self.globals = config.globals
+        self.config = config.model
+
+        # Main trunk + structure module
+        self.input_embedder = StructureInputEmbedder(
+            **self.config["structure_input_embedder"],
+        )
+
+        self.recycling_embedder = RecyclingEmbedder(
+            **self.config["recycling_embedder"],
+        )
+
+        self.evoformer = EvoformerStack(
+            **self.config["evoformer_stack"],
+        )
+
+        self.structure_module = StructureModule(
+            **self.config["structure_module"],
+        )
+        self.aux_heads = AuxiliaryHeads(
+            self.config["heads"],
+        )
+
+    def tolerance_reached(self, prev_pos, next_pos, mask, eps=1e-8) -> bool:
+        """
+        Early stopping criteria based on criteria used in
+        AF2Complex: https://www.nature.com/articles/s41467-022-29394-2
+        Args:
+            prev_pos: Previous atom positions in atom37/14 representation
+            next_pos: Current atom positions in atom37/14 representation
+            mask: 1-D sequence mask
+            eps: Epsilon used in square root calculation
+        Returns:
+            Whether to stop recycling early based on the desired tolerance.
+        """
+
+        def distances(points):
+            """Compute all pairwise distances for a set of points."""
+            d = points[..., None, :] - points[..., None, :, :]
+            return torch.sqrt(torch.sum(d ** 2, dim=-1))
+
+        if self.config.recycle_early_stop_tolerance < 0:
+            return False
+
+        ca_idx = residue_constants.atom_order['CA']
+        sq_diff = (distances(prev_pos[..., ca_idx, :]) - distances(next_pos[..., ca_idx, :])) ** 2
+        mask = mask[..., None] * mask[..., None, :]
+        sq_diff = masked_mean(mask=mask, value=sq_diff, dim=list(range(len(mask.shape))))
+        diff = torch.sqrt(sq_diff + eps).item()
+        return diff <= self.config.recycle_early_stop_tolerance
+
+    def iteration(self, feats, prevs, _recycle=True):
+        # Primary output dictionary
+        outputs = {}
+
+        # This needs to be done manually for DeepSpeed's sake
+        dtype = next(self.parameters()).dtype
+        for k in feats:
+            if feats[k].dtype == torch.float32:
+                feats[k] = feats[k].to(dtype=dtype)
+
+        # Grab some data about the input
+        batch_dims, n_total = feats["token_mask"].shape
+        device = feats["token_mask"].device
+
+        print("doing sample of size", feats["token_mask"].shape,
+              feats["protein_mask"].sum(dim=1), feats["ligand_mask"].sum(dim=1))
+
+        # Controls whether the model uses in-place operations throughout
+        # The dual condition accounts for activation checkpoints
+        # inplace_safe = not (self.training or torch.is_grad_enabled())
+        inplace_safe = False  # so we don't need attn_core_inplace_cuda
+
+        # Prep some features
+        token_mask = feats["token_mask"]
+        pair_mask = token_mask[..., None] * token_mask[..., None, :]
+
+        # Initialize the single and pair representations
+        # m: [*, 1, n_total, C_m]
+        # z: [*, n_total, n_total, C_z]
+        m, z = self.input_embedder(
+            feats["token_mask"],
+            feats["protein_mask"],
+            feats["ligand_mask"],
+            feats["target_feat"],
+            feats["ligand_bonds_feat"],
+            feats["input_positions"],
+            feats["protein_residue_index"],
+            feats["protein_distogram_mask"],
+            inplace_safe=inplace_safe,
+        )
+
+        # Unpack the recycling embeddings. Removing them from the list allows
+        # them to be freed further down in this function, saving memory
+        m_1_prev, z_prev, x_prev = reversed([prevs.pop() for _ in range(3)])
+
+        # Initialize the recycling embeddings, if need be
+        if None in [m_1_prev, z_prev, x_prev]:
+            # [*, N, C_m]
+            m_1_prev = m.new_zeros(
+                (batch_dims, n_total, self.config.structure_input_embedder.c_m),
+                requires_grad=False,
+            )
+
+            # [*, N, N, C_z]
+            z_prev = z.new_zeros(
+                (batch_dims, n_total, n_total, self.config.structure_input_embedder.c_z),
+                requires_grad=False,
+            )
+
+            # [*, N, 37, 3]
+            x_prev = z.new_zeros(
+                (batch_dims, n_total, residue_constants.atom_type_num, 3),
+                requires_grad=False,
+            )
+
+        # shape == [1, n_total, 37, 3]
+        pseudo_beta_or_lig_x_prev = pseudo_beta_fn(feats["aatype"], x_prev, None).to(dtype=z.dtype)
+
+        # m_1_prev_emb: [*, N, C_m]
+        # z_prev_emb: [*, N, N, C_z]
+        m_1_prev_emb, z_prev_emb = self.recycling_embedder(
+            m_1_prev,
+            z_prev,
+            pseudo_beta_or_lig_x_prev,
+            inplace_safe=inplace_safe,
+        )
+
+        del pseudo_beta_or_lig_x_prev
+
+        # [*, N, C_m]
+        m += m_1_prev_emb
+
+        # [*, N, N, C_z]
+        z = add(z, z_prev_emb, inplace=inplace_safe)
+
+        # Deletions like these become significant for inference with large N,
+        # where they free unused tensors and remove references to others such
+        # that they can be offloaded later
+        del m_1_prev, z_prev, m_1_prev_emb, z_prev_emb
+
+        # Run single + pair embeddings through the trunk of the network
+        # m: [*, N, C_m]
+        # z: [*, N, N, C_z]
+        # s: [*, N, C_s]
+        m, z, s = self.evoformer(
+            m,
+            z,
+            single_mask=token_mask.to(dtype=m.dtype),
+            pair_mask=pair_mask.to(dtype=z.dtype),
+            use_lma=self.globals.use_lma,
+            inplace_safe=inplace_safe,
+            _mask_trans=self.config._mask_trans,
+        )
+
+        outputs["pair"] = z
+        outputs["single"] = s
+
+        del z
+
+        # Predict 3D structure
+        outputs["sm"] = self.structure_module(
+            outputs,
+            feats["aatype"],
+            mask=token_mask.to(dtype=s.dtype),
+            inplace_safe=inplace_safe,
+        )
+        outputs["final_atom_positions"] = atom14_to_atom37(
+            outputs["sm"]["positions"][-1], feats
+        )
+        outputs["final_atom_mask"] = feats["atom37_atom_exists"]
+
+        # Save embeddings for use during the next recycling iteration
+
+        # [*, N, C_m]
+        m_1_prev = m[..., 0, :, :]
+
+        # [*, N, N, C_z]
+        z_prev = outputs["pair"]
+
+        # TODO bshor: early stop depends on is_multimer, but I don't think it must
+        early_stop = False
+        # if self.globals.is_multimer:
+        #     early_stop = self.tolerance_reached(x_prev, outputs["final_atom_positions"], seq_mask)
+
+        del x_prev
+
+        # [*, N, 37, 3]
+        x_prev = outputs["final_atom_positions"]
+
+        return outputs, m_1_prev, z_prev, x_prev, early_stop
+
+    def forward(self, batch):
+        """
+        Args:
+            batch:
+                Dictionary of arguments outlined in Algorithm 2. Keys must
+                include the official names of the features in the
+                supplement subsection 1.2.9.
+
+                The final dimension of each input must have length equal to
+                the number of recycling iterations.
+
+                Features (without the recycling dimension):
+
+                    "aatype" ([*, N_res]):
+                        Contrary to the supplement, this tensor of residue
+                        indices is not one-hot.
+                    "protein_target_feat" ([*, N_res, C_tf])
+                        One-hot encoding of the target sequence. C_tf is
+                        config.model.input_embedder.tf_dim.
+                    "residue_index" ([*, N_res])
+                        Tensor whose final dimension consists of
+                        consecutive indices from 0 to N_res.
+                    "token_mask" ([*, N_token])
+                        1-D token mask
+                    "pair_mask" ([*, N_token, N_token])
+                        2-D pair mask
+        """
+        # Initialize recycling embeddings
+        m_1_prev, z_prev, x_prev = None, None, None
+        prevs = [m_1_prev, z_prev, x_prev]
+
+        is_grad_enabled = torch.is_grad_enabled()
+
+        # Main recycling loop
+        num_iters = batch["aatype"].shape[-1]
+        early_stop = False
+        num_recycles = 0
+        for cycle_no in range(num_iters):
+            # Select the features for the current recycling cycle
+            fetch_cur_batch = lambda t: t[..., cycle_no]
+            feats = tensor_tree_map(fetch_cur_batch, batch)
+
+            # Enable grad iff we're training and it's the final recycling layer
+            is_final_iter = cycle_no == (num_iters - 1) or early_stop
+            with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
+                if is_final_iter:
+                    # Sidestep AMP bug (PyTorch issue #65766)
+                    if torch.is_autocast_enabled():
+                        torch.clear_autocast_cache()
+
+                # Run the next iteration of the model
+                outputs, m_1_prev, z_prev, x_prev, early_stop = self.iteration(
+                    feats,
+                    prevs,
+                    _recycle=(num_iters > 1)
+                )
+
+                num_recycles += 1
+
+                if not is_final_iter:
+                    del outputs
+                    prevs = [m_1_prev, z_prev, x_prev]
+                    del m_1_prev, z_prev, x_prev
+                else:
+                    break
+
+        outputs["num_recycles"] = torch.tensor(num_recycles, device=feats["aatype"].device)
+
+        # Run auxiliary heads, removing the recycling dimension from batch properties
+        outputs.update(self.aux_heads(outputs, batch["inter_pair_mask"][..., 0], batch["affinity_mask"][..., 0]))
+
+        return outputs
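
A schematic sketch (not part of the commit) of the recycling control flow above: features carry a trailing recycling dimension, and gradients flow only through the final iteration. Here step_fn is a hypothetical stand-in for AlphaFold.iteration:

import torch

def recycle(batch, step_fn, num_iters):
    # Assumes num_iters >= 1; prevs carries (m_1_prev, z_prev, x_prev)
    prevs, outputs = None, None
    grad_enabled = torch.is_grad_enabled()
    for cycle_no in range(num_iters):
        # Slice out this cycle's features along the trailing dimension
        feats = {k: t[..., cycle_no] for k, t in batch.items()}
        is_final = cycle_no == num_iters - 1
        with torch.set_grad_enabled(grad_enabled and is_final):
            outputs, prevs = step_fn(feats, prevs)
    return outputs
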
dockformer/model/pair_transition.py ADDED
@@ -0,0 +1,81 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from dockformer.model.primitives import Linear, LayerNorm
+
+
+class PairTransition(nn.Module):
+    """
+    Implements Algorithm 15.
+    """
+
+    def __init__(self, c_z, n):
+        """
+        Args:
+            c_z:
+                Pair transition channel dimension
+            n:
+                Factor by which c_z is multiplied to obtain hidden channel
+                dimension
+        """
+        super(PairTransition, self).__init__()
+
+        self.c_z = c_z
+        self.n = n
+
+        self.layer_norm = LayerNorm(self.c_z)
+        self.linear_1 = Linear(self.c_z, self.n * self.c_z, init="relu")
+        self.relu = nn.ReLU()
+        self.linear_2 = Linear(self.n * self.c_z, c_z, init="final")
+
+    def _transition(self, z, mask):
+        # [*, N_res, N_res, C_z]
+        z = self.layer_norm(z)
+
+        # [*, N_res, N_res, C_hidden]
+        z = self.linear_1(z)
+        z = self.relu(z)
+
+        # [*, N_res, N_res, C_z]
+        z = self.linear_2(z)
+        z = z * mask
+
+        return z
+
+    def forward(self,
+        z: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """
+        Args:
+            z:
+                [*, N_res, N_res, C_z] pair embedding
+        Returns:
+            [*, N_res, N_res, C_z] pair embedding update
+        """
+        # DISCREPANCY: DeepMind forgets to apply the mask in this module.
+        if mask is None:
+            mask = z.new_ones(z.shape[:-1])
+
+        # [*, N_res, N_res, 1]
+        mask = mask.unsqueeze(-1)
+
+        z = self._transition(z=z, mask=mask)
+
+        return z
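
For reference, a toy functional equivalent (not part of the commit) of the transition above: LayerNorm, widen by the factor n, ReLU, project back to c_z, then mask. The real module differs only in its custom Linear/LayerNorm initializations:

import torch
import torch.nn as nn

c_z, n = 8, 4
net = nn.Sequential(
    nn.LayerNorm(c_z),
    nn.Linear(c_z, n * c_z),
    nn.ReLU(),
    nn.Linear(n * c_z, c_z),
)
z = torch.randn(2, 5, 5, c_z)
mask = torch.ones(2, 5, 5, 1)
print((net(z) * mask).shape)  # torch.Size([2, 5, 5, 8])
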
dockformer/model/primitives.py ADDED
@@ -0,0 +1,598 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import importlib
+import math
+from typing import Optional, Callable, List, Tuple
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from scipy.stats import truncnorm
+
+from dockformer.utils.kernel.attention_core import attention_core
+from dockformer.utils.precision_utils import is_fp16_enabled
+from dockformer.utils.tensor_utils import (
+    permute_final_dims,
+    flatten_final_dims,
+)
+
+
+# Suited for 40gb GPU
+# DEFAULT_LMA_Q_CHUNK_SIZE = 1024
+# DEFAULT_LMA_KV_CHUNK_SIZE = 4096
+# Suited for 10gb GPU
+DEFAULT_LMA_Q_CHUNK_SIZE = 64
+DEFAULT_LMA_KV_CHUNK_SIZE = 256
+
+
+def _prod(nums):
+    out = 1
+    for n in nums:
+        out = out * n
+    return out
+
+
+def _calculate_fan(linear_weight_shape, fan="fan_in"):
+    fan_out, fan_in = linear_weight_shape
+
+    if fan == "fan_in":
+        f = fan_in
+    elif fan == "fan_out":
+        f = fan_out
+    elif fan == "fan_avg":
+        f = (fan_in + fan_out) / 2
+    else:
+        raise ValueError("Invalid fan option")
+
+    return f
+
+
+def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
+    shape = weights.shape
+    f = _calculate_fan(shape, fan)
+    scale = scale / max(1, f)
+    a = -2
+    b = 2
+    std = math.sqrt(scale) / truncnorm.std(a=a, b=b, loc=0, scale=1)
+    size = _prod(shape)
+    samples = truncnorm.rvs(a=a, b=b, loc=0, scale=std, size=size)
+    samples = np.reshape(samples, shape)
+    with torch.no_grad():
+        weights.copy_(torch.tensor(samples, device=weights.device))
+
+
+def lecun_normal_init_(weights):
+    trunc_normal_init_(weights, scale=1.0)
+
+
+def he_normal_init_(weights):
+    trunc_normal_init_(weights, scale=2.0)
+
+
+def glorot_uniform_init_(weights):
+    nn.init.xavier_uniform_(weights, gain=1)
+
+
+def final_init_(weights):
+    with torch.no_grad():
+        weights.fill_(0.0)
+
+
+def gating_init_(weights):
+    with torch.no_grad():
+        weights.fill_(0.0)
+
+
+def normal_init_(weights):
+    torch.nn.init.kaiming_normal_(weights, nonlinearity="linear")
+
+
+def ipa_point_weights_init_(weights):
+    with torch.no_grad():
+        softplus_inverse_1 = 0.541324854612918
+        weights.fill_(softplus_inverse_1)
+
+
+class Linear(nn.Linear):
+    """
+    A Linear layer with built-in nonstandard initializations. Called just
+    like torch.nn.Linear.
+
+    Implements the initializers in 1.11.4, plus some additional ones found
+    in the code.
+    """
+
+    def __init__(
+        self,
+        in_dim: int,
+        out_dim: int,
+        bias: bool = True,
+        init: str = "default",
+        init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
+        precision=None
+    ):
+        """
+        Args:
+            in_dim:
+                The final dimension of inputs to the layer
+            out_dim:
+                The final dimension of layer outputs
+            bias:
+                Whether to learn an additive bias. True by default
+            init:
+                The initializer to use. Choose from:
+
+                "default": LeCun fan-in truncated normal initialization
+                "relu": He initialization w/ truncated normal distribution
+                "glorot": Fan-average Glorot uniform initialization
+                "gating": Weights=0, Bias=1
+                "normal": Normal initialization with std=1/sqrt(fan_in)
+                "final": Weights=0, Bias=0
+
+                Overridden by init_fn if the latter is not None.
+            init_fn:
+                A custom initializer taking weight and bias as inputs.
+                Overrides init if not None.
+        """
+        super(Linear, self).__init__(in_dim, out_dim, bias=bias)
+
+        if bias:
+            with torch.no_grad():
+                self.bias.fill_(0)
+
+        with torch.no_grad():
+            if init_fn is not None:
+                init_fn(self.weight, self.bias)
+            else:
+                if init == "default":
+                    lecun_normal_init_(self.weight)
+                elif init == "relu":
+                    he_normal_init_(self.weight)
+                elif init == "glorot":
+                    glorot_uniform_init_(self.weight)
+                elif init == "gating":
+                    gating_init_(self.weight)
+                    if bias:
+                        self.bias.fill_(1.0)
+                elif init == "normal":
+                    normal_init_(self.weight)
+                elif init == "final":
+                    final_init_(self.weight)
+                else:
+                    raise ValueError("Invalid init string.")
+
+        self.precision = precision
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        d = input.dtype
+        if self.precision is not None:
+            with torch.cuda.amp.autocast(enabled=False):
+                bias = self.bias.to(dtype=self.precision) if self.bias is not None else None
+                return nn.functional.linear(input.to(dtype=self.precision),
+                                            self.weight.to(dtype=self.precision),
+                                            bias).to(dtype=d)
+
+        if d is torch.bfloat16:
+            with torch.cuda.amp.autocast(enabled=False):
+                bias = self.bias.to(dtype=d) if self.bias is not None else None
+                return nn.functional.linear(input, self.weight.to(dtype=d), bias)
+
+        return nn.functional.linear(input, self.weight, self.bias)
+
+
+class LayerNorm(nn.Module):
+    def __init__(self, c_in, eps=1e-5):
+        super(LayerNorm, self).__init__()
+
+        self.c_in = (c_in,)
+        self.eps = eps
+
+        self.weight = nn.Parameter(torch.ones(c_in))
+        self.bias = nn.Parameter(torch.zeros(c_in))
+
+    def forward(self, x):
+        d = x.dtype
+        if d is torch.bfloat16:
+            with torch.cuda.amp.autocast(enabled=False):
+                out = nn.functional.layer_norm(
+                    x,
+                    self.c_in,
+                    self.weight.to(dtype=d),
+                    self.bias.to(dtype=d),
+                    self.eps
+                )
+        else:
+            out = nn.functional.layer_norm(
+                x,
+                self.c_in,
+                self.weight,
+                self.bias,
+                self.eps,
+            )
+
+        return out
+
+
+@torch.jit.ignore
+def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
+    """
+    Softmax, but without automatic casting to fp32 when the input is of
+    type bfloat16
+    """
+    d = t.dtype
+    if d is torch.bfloat16:
+        with torch.cuda.amp.autocast(enabled=False):
+            s = torch.nn.functional.softmax(t, dim=dim)
+    else:
+        s = torch.nn.functional.softmax(t, dim=dim)
+
+    return s
+
+
+#@torch.jit.script
+def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, biases: List[torch.Tensor]) -> torch.Tensor:
+    # [*, H, C_hidden, K]
+    key = permute_final_dims(key, (1, 0))
+
+    # [*, H, Q, K]
+    a = torch.matmul(query, key)
+
+    for b in biases:
+        a += b
+
+    a = softmax_no_cast(a, -1)
+
+    # [*, H, Q, C_hidden]
+    a = torch.matmul(a, value)
+
+    return a
+
+
+class Attention(nn.Module):
+    """
+    Standard multi-head attention using AlphaFold's default layer
+    initialization. Allows multiple bias vectors.
+    """
+    def __init__(
+        self,
+        c_q: int,
+        c_k: int,
+        c_v: int,
+        c_hidden: int,
+        no_heads: int,
+        gating: bool = True,
+    ):
+        """
+        Args:
+            c_q:
+                Input dimension of query data
+            c_k:
+                Input dimension of key data
+            c_v:
+                Input dimension of value data
+            c_hidden:
+                Per-head hidden dimension
+            no_heads:
+                Number of attention heads
+            gating:
+                Whether the output should be gated using query data
+        """
+        super(Attention, self).__init__()
+
+        self.c_q = c_q
+        self.c_k = c_k
+        self.c_v = c_v
+        self.c_hidden = c_hidden
+        self.no_heads = no_heads
+        self.gating = gating
+
+        # DISCREPANCY: c_hidden is not the per-head channel dimension, as
+        # stated in the supplement, but the overall channel dimension.
+
+        self.linear_q = Linear(
+            self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot"
+        )
+        self.linear_k = Linear(
+            self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot"
+        )
+        self.linear_v = Linear(
+            self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot"
+        )
+        self.linear_o = Linear(
+            self.c_hidden * self.no_heads, self.c_q, init="final"
+        )
+
+        self.linear_g = None
+        if self.gating:
+            self.linear_g = Linear(
+                self.c_q, self.c_hidden * self.no_heads, init="gating"
+            )
+
+        self.sigmoid = nn.Sigmoid()
+
+    def _prep_qkv(self,
+        q_x: torch.Tensor,
+        kv_x: torch.Tensor,
+        apply_scale: bool = True
+    ) -> Tuple[
+        torch.Tensor, torch.Tensor, torch.Tensor
+    ]:
+        # [*, Q/K/V, H * C_hidden]
+        q = self.linear_q(q_x)
+        k = self.linear_k(kv_x)
+        v = self.linear_v(kv_x)
+
+        # [*, Q/K, H, C_hidden]
+        q = q.view(q.shape[:-1] + (self.no_heads, -1))
+        k = k.view(k.shape[:-1] + (self.no_heads, -1))
+        v = v.view(v.shape[:-1] + (self.no_heads, -1))
+
+        # [*, H, Q/K, C_hidden]
+        q = q.transpose(-2, -3)
+        k = k.transpose(-2, -3)
+        v = v.transpose(-2, -3)
+
+        if apply_scale:
+            q /= math.sqrt(self.c_hidden)
+
+        return q, k, v
+
+    def _wrap_up(self,
+        o: torch.Tensor,
+        q_x: torch.Tensor
+    ) -> torch.Tensor:
+        if self.linear_g is not None:
+            g = self.sigmoid(self.linear_g(q_x))
+
+            # [*, Q, H, C_hidden]
+            g = g.view(g.shape[:-1] + (self.no_heads, -1))
+            o = o * g
+
+        # [*, Q, H * C_hidden]
+        o = flatten_final_dims(o, 2)
+
+        # [*, Q, C_q]
+        o = self.linear_o(o)
+
+        return o
+
+    def forward(
+        self,
+        q_x: torch.Tensor,
+        kv_x: torch.Tensor,
+        biases: Optional[List[torch.Tensor]] = None,
+        use_memory_efficient_kernel: bool = False,
+        use_lma: bool = False,
+        lma_q_chunk_size: int = DEFAULT_LMA_Q_CHUNK_SIZE,
+        lma_kv_chunk_size: int = DEFAULT_LMA_KV_CHUNK_SIZE,
+    ) -> torch.Tensor:
+        """
+        Args:
+            q_x:
+                [*, Q, C_q] query data
+            kv_x:
+                [*, K, C_k] key data
+            biases:
+                List of biases that broadcast to [*, H, Q, K]
+            use_memory_efficient_kernel:
+                Whether to use a custom memory-efficient attention kernel.
+                This should be the default choice for most. If none of the
+                "use_<...>" flags are True, a stock PyTorch implementation
+                is used instead
+            use_lma:
+                Whether to use low-memory attention (Staats & Rabe 2021). If
+                none of the "use_<...>" flags are True, a stock PyTorch
+                implementation is used instead
+            lma_q_chunk_size:
+                Query chunk size (for LMA)
+            lma_kv_chunk_size:
+                Key/Value chunk size (for LMA)
+        Returns:
+            [*, Q, C_q] attention update
+        """
+        if use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None):
+            raise ValueError(
+                "If use_lma is specified, lma_q_chunk_size and "
+                "lma_kv_chunk_size must be provided"
+            )
+
+        attn_options = [use_memory_efficient_kernel, use_lma]
+        if sum(attn_options) > 1:
+            raise ValueError(
+                "Choose at most one alternative attention algorithm"
+            )
+
+        if biases is None:
+            biases = []
+
+        q, k, v = self._prep_qkv(q_x, kv_x, apply_scale=True)
+
+        if is_fp16_enabled():
+            use_memory_efficient_kernel = False
+
+        if use_memory_efficient_kernel:
+            if len(biases) > 2:
+                raise ValueError(
+                    "If use_memory_efficient_kernel is True, you may only "
+                    "provide up to two bias terms"
+                )
+            o = attention_core(q, k, v, *((biases + [None] * 2)[:2]))
+            o = o.transpose(-2, -3)
+        elif use_lma:
+            biases = [
+                b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
+                for b in biases
+            ]
+            o = _lma(q, k, v, biases, lma_q_chunk_size, lma_kv_chunk_size)
+            o = o.transpose(-2, -3)
+        else:
+            o = _attention(q, k, v, biases)
+            o = o.transpose(-2, -3)
+
+        o = self._wrap_up(o, q_x)
+
+        return o
+
+
+class GlobalAttention(nn.Module):
+    def __init__(self, c_in, c_hidden, no_heads, inf, eps):
+        super(GlobalAttention, self).__init__()
+
+        self.c_in = c_in
+        self.c_hidden = c_hidden
+        self.no_heads = no_heads
+        self.inf = inf
+        self.eps = eps
+
+        self.linear_q = Linear(
+            c_in, c_hidden * no_heads, bias=False, init="glorot"
+        )
+
+        self.linear_k = Linear(
+            c_in, c_hidden, bias=False, init="glorot",
+        )
+        self.linear_v = Linear(
+            c_in, c_hidden, bias=False, init="glorot",
+        )
+        self.linear_g = Linear(c_in, c_hidden * no_heads, init="gating")
+        self.linear_o = Linear(c_hidden * no_heads, c_in, init="final")
+
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self,
+        m: torch.Tensor,
+        mask: torch.Tensor,
+        use_lma: bool = False,
+    ) -> torch.Tensor:
+        # [*, N_res, C_in]
+        q = torch.sum(m * mask.unsqueeze(-1), dim=-2) / (
+            torch.sum(mask, dim=-1)[..., None] + self.eps
+        )
+
+        # [*, N_res, H * C_hidden]
+        q = self.linear_q(q)
+        q *= (self.c_hidden ** (-0.5))
+
+        # [*, N_res, H, C_hidden]
+        q = q.view(q.shape[:-1] + (self.no_heads, -1))
+
+        # [*, N_res, C_hidden]
+        k = self.linear_k(m)
+        v = self.linear_v(m)
+
+        bias = (self.inf * (mask - 1))[..., :, None, :]
+        if not use_lma:
+            # [*, N_res, H, N_seq]
+            a = torch.matmul(
+                q,
+                k.transpose(-1, -2),  # [*, N_res, C_hidden, N_seq]
+            )
+            a += bias
+            a = softmax_no_cast(a)
+
+            # [*, N_res, H, C_hidden]
+            o = torch.matmul(
+                a,
+                v,
+            )
+        else:
+            o = _lma(
+                q,
+                k,
+                v,
+                [bias],
+                DEFAULT_LMA_Q_CHUNK_SIZE,
+                DEFAULT_LMA_KV_CHUNK_SIZE
+            )
+
+        # [*, N_res, C_hidden]
+        g = self.sigmoid(self.linear_g(m))
+
+        # [*, N_res, H, C_hidden]
+        g = g.view(g.shape[:-1] + (self.no_heads, -1))
+
+        # [*, N_res, H, C_hidden]
+        o = o.unsqueeze(-3) * g
+
+        # [*, N_res, H * C_hidden]
+        o = o.reshape(o.shape[:-2] + (-1,))
+
+        # [*, N_res, C_in]
+        m = self.linear_o(o)
+
+        return m
+
+
+def _lma(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    biases: List[torch.Tensor],
+    q_chunk_size: int,
+    kv_chunk_size: int,
+):
+    no_q, no_kv = q.shape[-2], k.shape[-2]
+
+    # [*, H, Q, C_hidden]
+    o = q.new_zeros(q.shape)
+    for q_s in range(0, no_q, q_chunk_size):
+        q_chunk = q[..., q_s: q_s + q_chunk_size, :]
+        large_bias_chunks = [
+            b[..., q_s: q_s + q_chunk_size, :] for b in biases
+        ]
+
+        maxes = []
+        weights = []
+        values = []
+        for kv_s in range(0, no_kv, kv_chunk_size):
+            k_chunk = k[..., kv_s: kv_s + kv_chunk_size, :]
+            v_chunk = v[..., kv_s: kv_s + kv_chunk_size, :]
+            small_bias_chunks = [
+                b[..., kv_s: kv_s + kv_chunk_size] for b in large_bias_chunks
+            ]
+
+            a = torch.einsum(
+                "...hqd,...hkd->...hqk", q_chunk, k_chunk,
+            )
+
+            for b in small_bias_chunks:
+                a += b
+
+            max_a = torch.max(a, dim=-1, keepdim=True)[0]
+            exp_a = torch.exp(a - max_a)
+            exp_v = torch.einsum("...hvf,...hqv->...hqf", v_chunk, exp_a)
+
+            maxes.append(max_a.detach().squeeze(-1))
+            weights.append(torch.sum(exp_a, dim=-1))
+            values.append(exp_v)
+
+        chunk_max = torch.stack(maxes, dim=-3)
+        chunk_weights = torch.stack(weights, dim=-3)
+        chunk_values = torch.stack(values, dim=-4)
+
+        global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0]
+        max_diffs = torch.exp(chunk_max - global_max)
+        chunk_values = chunk_values * max_diffs.unsqueeze(-1)
+        chunk_weights = chunk_weights * max_diffs
+
+        all_values = torch.sum(chunk_values, dim=-4)
+        all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4)
+
+        q_chunk_out = all_values / all_weights
+
+        o[..., q_s: q_s + q_chunk_size, :] = q_chunk_out
+
+    return o
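
A quick numerical check (not part of the commit) that the chunked low-memory attention above agrees with the naive path. Tensors follow the [*, H, Q/K, C_hidden] layout used by these functions, and the zero bias is just a placeholder:

import torch
from dockformer.model.primitives import _attention, _lma

q = torch.randn(1, 4, 128, 16) / 16 ** 0.5   # queries pre-scaled, as in _prep_qkv
k = torch.randn(1, 4, 128, 16)
v = torch.randn(1, 4, 128, 16)
bias = torch.zeros(1, 4, 128, 128)

ref = _attention(q, k, v, [bias])
lma = _lma(q, k, v, [bias], q_chunk_size=32, kv_chunk_size=32)
print(torch.allclose(ref, lma, atol=1e-5))   # True, up to float error
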
dockformer/model/single_attention.py ADDED
@@ -0,0 +1,184 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import partial
+import math
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from typing import Optional, List, Tuple
+
+from dockformer.model.primitives import (
+    Linear,
+    LayerNorm,
+    Attention,
+)
+from dockformer.utils.tensor_utils import permute_final_dims
+
+
+class SingleAttention(nn.Module):
+    def __init__(
+        self,
+        c_in,
+        c_hidden,
+        no_heads,
+        pair_bias=False,
+        c_z=None,
+        inf=1e9,
+    ):
+        """
+        Args:
+            c_in:
+                Input channel dimension
+            c_hidden:
+                Per-head hidden channel dimension
+            no_heads:
+                Number of attention heads
+            pair_bias:
+                Whether to use pair embedding bias
+            c_z:
+                Pair embedding channel dimension. Ignored unless pair_bias
+                is true
+            inf:
+                A large number to be used in computing the attention mask
+        """
+        super(SingleAttention, self).__init__()
+
+        self.c_in = c_in
+        self.c_hidden = c_hidden
+        self.no_heads = no_heads
+        self.pair_bias = pair_bias
+        self.c_z = c_z
+        self.inf = inf
+
+        self.layer_norm_m = LayerNorm(self.c_in)
+
+        self.layer_norm_z = None
+        self.linear_z = None
+        if self.pair_bias:
+            self.layer_norm_z = LayerNorm(self.c_z)
+            self.linear_z = Linear(
+                self.c_z, self.no_heads, bias=False, init="normal"
+            )
+
+        self.mha = Attention(
+            self.c_in,
+            self.c_in,
+            self.c_in,
+            self.c_hidden,
+            self.no_heads,
+        )
+
+    def _prep_inputs(self,
+        m: torch.Tensor,
+        z: Optional[torch.Tensor],
+        mask: Optional[torch.Tensor],
+        inplace_safe: bool = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        if mask is None:
+            # [*, N_res]
+            mask = m.new_ones(m.shape[:-1])
+
+        # [*, 1, 1, N_res]
+        mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
+
+        if (self.pair_bias and
+            z is not None and                       # For the
+            self.layer_norm_z is not None and       # benefit of
+            self.linear_z is not None               # TorchScript
+        ):
+            chunks = []
+
+            for i in range(0, z.shape[-3], 256):
+                z_chunk = z[..., i: i + 256, :, :]
+
+                # [*, N_res, N_res, C_z]
+                z_chunk = self.layer_norm_z(z_chunk)
+
+                # [*, N_res, N_res, no_heads]
+                z_chunk = self.linear_z(z_chunk)
+
+                chunks.append(z_chunk)
+
+            z = torch.cat(chunks, dim=-3)
+
+            # [*, no_heads, N_res, N_res]
+            z = permute_final_dims(z, (2, 0, 1))
+
+        return m, mask_bias, z
+
+    def forward(self,
+        m: torch.Tensor,
+        z: Optional[torch.Tensor] = None,
+        mask: Optional[torch.Tensor] = None,
+        use_memory_efficient_kernel: bool = False,
+        use_lma: bool = False,
+        inplace_safe: bool = False,
+    ) -> torch.Tensor:
+        """
+        Args:
+            m:
+                [*, N_res, C_m] single embedding
+            z:
+                [*, N_res, N_res, C_z] pair embedding. Required only if
+                pair_bias is True
+            mask:
+                [*, N_res] single mask
+        """
+        m, mask_bias, z = self._prep_inputs(
+            m, z, mask, inplace_safe=inplace_safe
+        )
+
+        biases = [mask_bias]
+        if(z is not None):
+            biases.append(z)
+
+        m = self.layer_norm_m(m)
+        m = self.mha(
+            q_x=m,
+            kv_x=m,
+            biases=biases,
+            use_memory_efficient_kernel=use_memory_efficient_kernel,
+            use_lma=use_lma,
+        )
+
+        return m
+
+
+class SingleRowAttentionWithPairBias(SingleAttention):
+    """
+    Implements Algorithm 7.
+    """
+
+    def __init__(self, c_m, c_z, c_hidden, no_heads, inf=1e9):
+        """
+        Args:
+            c_m:
+                Input channel dimension
+            c_z:
+                Pair embedding channel dimension
+            c_hidden:
+                Per-head hidden channel dimension
+            no_heads:
+                Number of attention heads
+            inf:
+                Large number used to construct attention masks
+        """
+        super(SingleRowAttentionWithPairBias, self).__init__(
+            c_m,
+            c_hidden,
+            no_heads,
+            pair_bias=True,
+            c_z=c_z,
+            inf=inf,
+        )
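
A shape-level usage sketch (not part of the commit) for the pair-biased single attention defined above; all dimensions are illustrative:

import torch
from dockformer.model.single_attention import SingleRowAttentionWithPairBias

attn = SingleRowAttentionWithPairBias(c_m=32, c_z=16, c_hidden=8, no_heads=4)
n = 10
m = torch.randn(1, n, 32)     # single representation
z = torch.randn(1, n, n, 16)  # pair representation, projected to a per-head bias
mask = torch.ones(1, n)
out = attn(m, z=z, mask=mask)
print(out.shape)              # torch.Size([1, 10, 32])
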
dockformer/model/structure_module.py ADDED
@@ -0,0 +1,837 @@
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from functools import reduce
16
+ import importlib
17
+ import math
18
+ import sys
19
+ from operator import mul
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ from typing import Optional, Tuple, Sequence, Union
24
+
25
+ from dockformer.model.primitives import Linear, LayerNorm, ipa_point_weights_init_
26
+ from dockformer.utils.residue_constants import (
27
+ restype_rigid_group_default_frame,
28
+ restype_atom14_to_rigid_group,
29
+ restype_atom14_mask,
30
+ restype_atom14_rigid_group_positions,
31
+ )
32
+ from dockformer.utils.geometry.quat_rigid import QuatRigid
33
+ from dockformer.utils.geometry.rigid_matrix_vector import Rigid3Array
34
+ from dockformer.utils.geometry.vector import Vec3Array, square_euclidean_distance
35
+ from dockformer.utils.feats import (
36
+ frames_and_literature_positions_to_atom14_pos,
37
+ torsion_angles_to_frames,
38
+ )
39
+ from dockformer.utils.precision_utils import is_fp16_enabled
40
+ from dockformer.utils.rigid_utils import Rotation, Rigid
41
+ from dockformer.utils.tensor_utils import (
42
+ dict_multimap,
43
+ permute_final_dims,
44
+ flatten_final_dims,
45
+ )
46
+
47
+ import importlib.util
48
+ attn_core_is_installed = importlib.util.find_spec("attn_core_inplace_cuda") is not None
49
+ attn_core_inplace_cuda = None
50
+ if attn_core_is_installed:
51
+ attn_core_inplace_cuda = importlib.import_module("attn_core_inplace_cuda")
52
+
53
+
54
+ class AngleResnetBlock(nn.Module):
55
+ def __init__(self, c_hidden):
56
+ """
57
+ Args:
58
+ c_hidden:
59
+ Hidden channel dimension
60
+ """
61
+ super(AngleResnetBlock, self).__init__()
62
+
63
+ self.c_hidden = c_hidden
64
+
65
+ self.linear_1 = Linear(self.c_hidden, self.c_hidden, init="relu")
66
+ self.linear_2 = Linear(self.c_hidden, self.c_hidden, init="final")
67
+
68
+ self.relu = nn.ReLU()
69
+
70
+ def forward(self, a: torch.Tensor) -> torch.Tensor:
71
+
72
+ s_initial = a
73
+
74
+ a = self.relu(a)
75
+ a = self.linear_1(a)
76
+ a = self.relu(a)
77
+ a = self.linear_2(a)
78
+
79
+ return a + s_initial
80
+
81
+
82
+ class AngleResnet(nn.Module):
83
+ """
84
+ Implements Algorithm 20, lines 11-14
85
+ """
86
+
87
+ def __init__(self, c_in, c_hidden, no_blocks, no_angles, epsilon):
88
+ """
89
+ Args:
90
+ c_in:
91
+ Input channel dimension
92
+ c_hidden:
93
+ Hidden channel dimension
94
+ no_blocks:
95
+ Number of resnet blocks
96
+ no_angles:
97
+ Number of torsion angles to generate
98
+ epsilon:
99
+ Small constant for normalization
100
+ """
101
+ super(AngleResnet, self).__init__()
102
+
103
+ self.c_in = c_in
104
+ self.c_hidden = c_hidden
105
+ self.no_blocks = no_blocks
106
+ self.no_angles = no_angles
107
+ self.eps = epsilon
108
+
109
+ self.linear_in = Linear(self.c_in, self.c_hidden)
110
+ self.linear_initial = Linear(self.c_in, self.c_hidden)
111
+
112
+ self.layers = nn.ModuleList()
113
+ for _ in range(self.no_blocks):
114
+ layer = AngleResnetBlock(c_hidden=self.c_hidden)
115
+ self.layers.append(layer)
116
+
117
+ self.linear_out = Linear(self.c_hidden, self.no_angles * 2)
118
+
119
+ self.relu = nn.ReLU()
120
+
121
+ def forward(
122
+ self, s: torch.Tensor, s_initial: torch.Tensor
123
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
124
+ """
125
+ Args:
126
+ s:
127
+ [*, C_hidden] single embedding
128
+ s_initial:
129
+ [*, C_hidden] single embedding as of the start of the
130
+ StructureModule
131
+ Returns:
132
+ [*, no_angles, 2] predicted angles
133
+ """
134
+ # NOTE: The ReLU's applied to the inputs are absent from the supplement
135
+ # pseudocode but present in the source. For maximal compatibility with
136
+ # the pretrained weights, I'm going with the source.
137
+
138
+ # [*, C_hidden]
139
+ s_initial = self.relu(s_initial)
140
+ s_initial = self.linear_initial(s_initial)
141
+ s = self.relu(s)
142
+ s = self.linear_in(s)
143
+ s = s + s_initial
144
+
145
+ for l in self.layers:
146
+ s = l(s)
147
+
148
+ s = self.relu(s)
149
+
150
+ # [*, no_angles * 2]
151
+ s = self.linear_out(s)
152
+
153
+ # [*, no_angles, 2]
154
+ s = s.view(s.shape[:-1] + (-1, 2))
155
+
156
+ unnormalized_s = s
157
+ norm_denom = torch.sqrt(
158
+ torch.clamp(
159
+ torch.sum(s ** 2, dim=-1, keepdim=True),
160
+ min=self.eps,
161
+ )
162
+ )
163
+ s = s / norm_denom
164
+
165
+ return unnormalized_s, s
166
+
167
+
168
+ class PointProjection(nn.Module):
169
+ def __init__(self,
170
+ c_hidden: int,
171
+ num_points: int,
172
+ no_heads: int,
173
+ return_local_points: bool = False,
174
+ ):
175
+ super().__init__()
176
+ self.return_local_points = return_local_points
177
+ self.no_heads = no_heads
178
+ self.num_points = num_points
179
+
180
+ # Multimer requires this to be run with fp32 precision during training
181
+ precision = None
182
+ self.linear = Linear(c_hidden, no_heads * 3 * num_points, precision=precision)
183
+
184
+ def forward(self,
185
+ activations: torch.Tensor,
186
+ rigids: Union[Rigid, Rigid3Array],
187
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
188
+ # TODO: Needs to run in high precision during training
189
+ points_local = self.linear(activations)
190
+ out_shape = points_local.shape[:-1] + (self.no_heads, self.num_points, 3)
191
+
192
+ points_local = torch.split(
193
+ points_local, points_local.shape[-1] // 3, dim=-1
194
+ )
195
+
196
+ points_local = torch.stack(points_local, dim=-1).view(out_shape)
197
+
198
+ points_global = rigids[..., None, None].apply(points_local)
199
+
200
+ if(self.return_local_points):
201
+ return points_global, points_local
202
+
203
+ return points_global
204
+
205
+
206
+ class InvariantPointAttention(nn.Module):
207
+ """
208
+ Implements Algorithm 22.
209
+ """
210
+ def __init__(
211
+ self,
212
+ c_s: int,
213
+ c_z: int,
214
+ c_hidden: int,
215
+ no_heads: int,
216
+ no_qk_points: int,
217
+ no_v_points: int,
218
+ inf: float = 1e5,
219
+ eps: float = 1e-8,
220
+ ):
221
+ """
222
+ Args:
223
+ c_s:
224
+ Single representation channel dimension
225
+ c_z:
226
+ Pair representation channel dimension
227
+ c_hidden:
228
+ Hidden channel dimension
229
+ no_heads:
230
+ Number of attention heads
231
+ no_qk_points:
232
+ Number of query/key points to generate
233
+ no_v_points:
234
+ Number of value points to generate
235
+ """
236
+ super(InvariantPointAttention, self).__init__()
237
+
238
+ self.c_s = c_s
239
+ self.c_z = c_z
240
+ self.c_hidden = c_hidden
241
+ self.no_heads = no_heads
242
+ self.no_qk_points = no_qk_points
243
+ self.no_v_points = no_v_points
244
+ self.inf = inf
245
+ self.eps = eps
246
+
247
+ # These linear layers differ from their specifications in the
248
+ # supplement. There, they lack bias and use Glorot initialization.
249
+ # Here as in the official source, they have bias and use the default
250
+ # Lecun initialization.
251
+ hc = self.c_hidden * self.no_heads
252
+ self.linear_q = Linear(self.c_s, hc, bias=True)
253
+
254
+ self.linear_q_points = PointProjection(
255
+ self.c_s,
256
+ self.no_qk_points,
257
+ self.no_heads,
258
+ )
259
+
260
+
261
+ self.linear_kv = Linear(self.c_s, 2 * hc)
262
+ self.linear_kv_points = PointProjection(
263
+ self.c_s,
264
+ self.no_qk_points + self.no_v_points,
265
+ self.no_heads,
266
+ )
267
+
268
+ self.linear_b = Linear(self.c_z, self.no_heads)
269
+
270
+ self.head_weights = nn.Parameter(torch.zeros((no_heads)))
271
+ ipa_point_weights_init_(self.head_weights)
272
+
273
+ concat_out_dim = self.no_heads * (
274
+ self.c_z + self.c_hidden + self.no_v_points * 4
275
+ )
276
+ self.linear_out = Linear(concat_out_dim, self.c_s, init="final")
277
+
278
+ self.softmax = nn.Softmax(dim=-1)
279
+ self.softplus = nn.Softplus()
280
+
281
+ def forward(
282
+ self,
283
+ s: torch.Tensor,
284
+ z: torch.Tensor,
285
+ r: Union[Rigid, Rigid3Array],
286
+ mask: torch.Tensor,
287
+ inplace_safe: bool = False,
288
+ ) -> torch.Tensor:
289
+ """
290
+ Args:
291
+ s:
292
+ [*, N_res, C_s] single representation
293
+ z:
294
+ [*, N_res, N_res, C_z] pair representation
295
+ r:
296
+ [*, N_res] transformation object
297
+ mask:
298
+ [*, N_res] mask
299
+ Returns:
300
+ [*, N_res, C_s] single representation update
301
+ """
302
+ z = [z]
303
+
304
+ #######################################
305
+ # Generate scalar and point activations
306
+ #######################################
307
+ # [*, N_res, H * C_hidden]
308
+ q = self.linear_q(s)
309
+
310
+ # [*, N_res, H, C_hidden]
311
+ q = q.view(q.shape[:-1] + (self.no_heads, -1))
312
+
313
+ # [*, N_res, H, P_qk, 3]
314
+ q_pts = self.linear_q_points(s, r)
315
+
316
+ # The following two blocks are equivalent
317
+ # They're separated only to preserve compatibility with old AF weights
318
+
319
+ # [*, N_res, H * 2 * C_hidden]
320
+ kv = self.linear_kv(s)
321
+
322
+ # [*, N_res, H, 2 * C_hidden]
323
+ kv = kv.view(kv.shape[:-1] + (self.no_heads, -1))
324
+
325
+ # [*, N_res, H, C_hidden]
326
+ k, v = torch.split(kv, self.c_hidden, dim=-1)
327
+
328
+ kv_pts = self.linear_kv_points(s, r)
329
+
330
+ # [*, N_res, H, P_q/P_v, 3]
331
+ k_pts, v_pts = torch.split(
332
+ kv_pts, [self.no_qk_points, self.no_v_points], dim=-2
333
+ )
334
+
335
+ ##########################
336
+ # Compute attention scores
337
+ ##########################
338
+ # [*, N_res, N_res, H]
339
+ b = self.linear_b(z[0])
340
+
341
+ # [*, H, N_res, N_res]
342
+ if (is_fp16_enabled()):
343
+ with torch.cuda.amp.autocast(enabled=False):
344
+ a = torch.matmul(
345
+ permute_final_dims(q.float(), (1, 0, 2)), # [*, H, N_res, C_hidden]
346
+ permute_final_dims(k.float(), (1, 2, 0)), # [*, H, C_hidden, N_res]
347
+ )
348
+ else:
349
+ a = torch.matmul(
350
+ permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden]
351
+ permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res]
352
+ )
353
+
354
+ a *= math.sqrt(1.0 / (3 * self.c_hidden))
355
+ a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))
356
+
357
+ # [*, N_res, N_res, H, P_q, 3]
358
+ pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
359
+
360
+ if (inplace_safe):
361
+ pt_att *= pt_att
362
+ else:
363
+ pt_att = pt_att ** 2
364
+
365
+ pt_att = sum(torch.unbind(pt_att, dim=-1))
366
+
367
+ head_weights = self.softplus(self.head_weights).view(
368
+ *((1,) * len(pt_att.shape[:-2]) + (-1, 1))
369
+ )
370
+ head_weights = head_weights * math.sqrt(
371
+ 1.0 / (3 * (self.no_qk_points * 9.0 / 2))
372
+ )
373
+
374
+ if (inplace_safe):
375
+ pt_att *= head_weights
376
+ else:
377
+ pt_att = pt_att * head_weights
378
+
379
+ # [*, N_res, N_res, H]
380
+ pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
381
+
382
+ # [*, N_res, N_res]
383
+ square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
384
+ square_mask = self.inf * (square_mask - 1)
385
+
386
+ # [*, H, N_res, N_res]
387
+ pt_att = permute_final_dims(pt_att, (2, 0, 1))
388
+
389
+ if (inplace_safe):
390
+ a += pt_att
391
+ del pt_att
392
+ a += square_mask.unsqueeze(-3)
393
+ # in-place softmax
394
+ attn_core_inplace_cuda.forward_(
395
+ a,
396
+ reduce(mul, a.shape[:-1]),
397
+ a.shape[-1],
398
+ )
399
+ else:
400
+ a = a + pt_att
401
+ a = a + square_mask.unsqueeze(-3)
402
+ a = self.softmax(a)
403
+
404
+ ################
405
+ # Compute output
406
+ ################
407
+ # [*, N_res, H, C_hidden]
408
+ o = torch.matmul(
409
+ a, v.transpose(-2, -3).to(dtype=a.dtype)
410
+ ).transpose(-2, -3)
411
+
412
+ # [*, N_res, H * C_hidden]
413
+ o = flatten_final_dims(o, 2)
414
+
415
+ # [*, H, 3, N_res, P_v]
416
+ if (inplace_safe):
417
+ v_pts = permute_final_dims(v_pts, (1, 3, 0, 2))
418
+ o_pt = [
419
+ torch.matmul(a, v.to(a.dtype))
420
+ for v in torch.unbind(v_pts, dim=-3)
421
+ ]
422
+ o_pt = torch.stack(o_pt, dim=-3)
423
+ else:
424
+ o_pt = torch.sum(
425
+ (
426
+ a[..., None, :, :, None]
427
+ * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
428
+ ),
429
+ dim=-2,
430
+ )
431
+
432
+ # [*, N_res, H, P_v, 3]
433
+ o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
434
+ o_pt = r[..., None, None].invert_apply(o_pt)
435
+
436
+ # [*, N_res, H * P_v]
437
+ o_pt_norm = flatten_final_dims(
438
+ torch.sqrt(torch.sum(o_pt ** 2, dim=-1) + self.eps), 2
439
+ )
440
+
441
+ # [*, N_res, H * P_v, 3]
442
+ o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
443
+ o_pt = torch.unbind(o_pt, dim=-1)
444
+
445
+ # [*, N_res, H, C_z]
446
+ o_pair = torch.matmul(a.transpose(-2, -3), z[0].to(dtype=a.dtype))
447
+
448
+ # [*, N_res, H * C_z]
449
+ o_pair = flatten_final_dims(o_pair, 2)
450
+
451
+ # [*, N_res, C_s]
452
+ s = self.linear_out(
453
+ torch.cat(
454
+ (o, *o_pt, o_pt_norm, o_pair), dim=-1
455
+ ).to(dtype=z[0].dtype)
456
+ )
457
+
458
+ return s
459
+
460
+
461
+ class BackboneUpdate(nn.Module):
462
+ """
463
+ Implements part of Algorithm 23.
464
+ """
465
+
466
+ def __init__(self, c_s):
467
+ """
468
+ Args:
469
+ c_s:
470
+ Single representation channel dimension
471
+ """
472
+ super(BackboneUpdate, self).__init__()
473
+
474
+ self.c_s = c_s
475
+
476
+ self.linear = Linear(self.c_s, 6, init="final")
477
+
478
+ def forward(self, s: torch.Tensor) -> torch.Tensor:
479
+ """
480
+ Args:
481
+ s: [*, N_res, C_s] single representation
482
+ Returns:
483
+ [*, N_res, 6] update vector
484
+ """
485
+ # [*, 6]
486
+ update = self.linear(s)
487
+
488
+ return update
489
+
490
+
491
+ class StructureModuleTransitionLayer(nn.Module):
492
+ def __init__(self, c):
493
+ super(StructureModuleTransitionLayer, self).__init__()
494
+
495
+ self.c = c
496
+
497
+ self.linear_1 = Linear(self.c, self.c, init="relu")
498
+ self.linear_2 = Linear(self.c, self.c, init="relu")
499
+ self.linear_3 = Linear(self.c, self.c, init="final")
500
+
501
+ self.relu = nn.ReLU()
502
+
503
+ def forward(self, s):
504
+ s_initial = s
505
+ s = self.linear_1(s)
506
+ s = self.relu(s)
507
+ s = self.linear_2(s)
508
+ s = self.relu(s)
509
+ s = self.linear_3(s)
510
+
511
+ s = s + s_initial
512
+
513
+ return s
514
+
515
+
516
+ class StructureModuleTransition(nn.Module):
517
+ def __init__(self, c, num_layers, dropout_rate):
518
+ super(StructureModuleTransition, self).__init__()
519
+
520
+ self.c = c
521
+ self.num_layers = num_layers
522
+ self.dropout_rate = dropout_rate
523
+
524
+ self.layers = nn.ModuleList()
525
+ for _ in range(self.num_layers):
526
+ l = StructureModuleTransitionLayer(self.c)
527
+ self.layers.append(l)
528
+
529
+ self.dropout = nn.Dropout(self.dropout_rate)
530
+ self.layer_norm = LayerNorm(self.c)
531
+
532
+ def forward(self, s):
533
+ for l in self.layers:
534
+ s = l(s)
535
+
536
+ s = self.dropout(s)
537
+ s = self.layer_norm(s)
538
+
539
+ return s
540
+
541
+
542
+ class StructureModule(nn.Module):
543
+ def __init__(
544
+ self,
545
+ c_s,
546
+ c_z,
547
+ c_ipa,
548
+ c_resnet,
549
+ no_heads_ipa,
550
+ no_qk_points,
551
+ no_v_points,
552
+ dropout_rate,
553
+ no_blocks,
554
+ no_transition_layers,
555
+ no_resnet_blocks,
556
+ no_angles,
557
+ trans_scale_factor,
558
+ epsilon,
559
+ inf,
560
+ **kwargs,
561
+ ):
562
+ """
563
+ Args:
564
+ c_s:
565
+ Single representation channel dimension
566
+ c_z:
567
+ Pair representation channel dimension
568
+ c_ipa:
569
+ IPA hidden channel dimension
570
+ c_resnet:
571
+ Angle resnet (Alg. 23 lines 11-14) hidden channel dimension
572
+ no_heads_ipa:
573
+ Number of IPA heads
574
+ no_qk_points:
575
+ Number of query/key points to generate during IPA
576
+ no_v_points:
577
+ Number of value points to generate during IPA
578
+ dropout_rate:
579
+ Dropout rate used throughout the layer
580
+ no_blocks:
581
+ Number of structure module blocks
582
+ no_transition_layers:
583
+ Number of layers in the single representation transition
584
+ (Alg. 23 lines 8-9)
585
+ no_resnet_blocks:
586
+ Number of blocks in the angle resnet
587
+ no_angles:
588
+ Number of angles to generate in the angle resnet
589
+ trans_scale_factor:
590
+ Factor by which predicted backbone translations are scaled
591
+ epsilon:
592
+ Small number used in angle resnet normalization
593
+ inf:
594
+ Large number used for attention masking
595
+ """
596
+ super(StructureModule, self).__init__()
597
+
598
+ self.c_s = c_s
599
+ self.c_z = c_z
600
+ self.c_ipa = c_ipa
601
+ self.c_resnet = c_resnet
602
+ self.no_heads_ipa = no_heads_ipa
603
+ self.no_qk_points = no_qk_points
604
+ self.no_v_points = no_v_points
605
+ self.dropout_rate = dropout_rate
606
+ self.no_blocks = no_blocks
607
+ self.no_transition_layers = no_transition_layers
608
+ self.no_resnet_blocks = no_resnet_blocks
609
+ self.no_angles = no_angles
610
+ self.trans_scale_factor = trans_scale_factor
611
+ self.epsilon = epsilon
612
+ self.inf = inf
613
+
614
+ # Buffers to be lazily initialized later
615
+ # self.default_frames
616
+ # self.group_idx
617
+ # self.atom_mask
618
+ # self.lit_positions
619
+
620
+ self.layer_norm_s = LayerNorm(self.c_s)
621
+ self.layer_norm_z = LayerNorm(self.c_z)
622
+
623
+ self.linear_in = Linear(self.c_s, self.c_s)
624
+
625
+ self.ipa = InvariantPointAttention(
626
+ self.c_s,
627
+ self.c_z,
628
+ self.c_ipa,
629
+ self.no_heads_ipa,
630
+ self.no_qk_points,
631
+ self.no_v_points,
632
+ inf=self.inf,
633
+ eps=self.epsilon,
634
+ )
635
+
636
+ self.ipa_dropout = nn.Dropout(self.dropout_rate)
637
+ self.layer_norm_ipa = LayerNorm(self.c_s)
638
+
639
+ self.transition = StructureModuleTransition(
640
+ self.c_s,
641
+ self.no_transition_layers,
642
+ self.dropout_rate,
643
+ )
644
+
645
+ self.bb_update = BackboneUpdate(self.c_s)
646
+
647
+ self.angle_resnet = AngleResnet(
648
+ self.c_s,
649
+ self.c_resnet,
650
+ self.no_resnet_blocks,
651
+ self.no_angles,
652
+ self.epsilon,
653
+ )
654
+
655
+ def forward(
656
+ self,
657
+ evoformer_output_dict,
658
+ aatype,
659
+ mask=None,
660
+ inplace_safe=False,
661
+ ):
662
+ """
663
+ Args:
664
+ evoformer_output_dict:
665
+ Dictionary containing:
666
+ "single":
667
+ [*, N_res, C_s] single representation
668
+ "pair":
669
+ [*, N_res, N_res, C_z] pair representation
670
+ aatype:
671
+ [*, N_res] amino acid indices
672
+ mask:
673
+ Optional [*, N_res] sequence mask
674
+ Returns:
675
+ A dictionary of outputs
676
+ """
677
+ s = evoformer_output_dict["single"]
678
+
679
+ if mask is None:
680
+ # [*, N]
681
+ mask = s.new_ones(s.shape[:-1])
682
+
683
+ # [*, N, C_s]
684
+ s = self.layer_norm_s(s)
685
+
686
+ # [*, N, N, C_z]
687
+ z = self.layer_norm_z(evoformer_output_dict["pair"])
688
+
689
+ # [*, N, C_s]
690
+ s_initial = s
691
+ s = self.linear_in(s)
692
+
693
+ # [*, N]
694
+ rigids = Rigid.identity(
695
+ s.shape[:-1],
696
+ s.dtype,
697
+ s.device,
698
+ self.training,
699
+ fmt="quat",
700
+ )
701
+ outputs = []
702
+ for i in range(self.no_blocks):
703
+ # [*, N, C_s]
704
+ s = s + self.ipa(
705
+ s,
706
+ z,
707
+ rigids,
708
+ mask,
709
+ inplace_safe=inplace_safe,
710
+ )
711
+ s = self.ipa_dropout(s)
712
+ s = self.layer_norm_ipa(s)
713
+ s = self.transition(s)
714
+
715
+
716
+
717
+ # [*, N_res, 6] vector of translations and rotations
718
+ bb_update_output = self.bb_update(s)
719
+
720
+ rigids = rigids.compose_q_update_vec(bb_update_output)
721
+
722
+
723
+ # To hew as closely as possible to AlphaFold, we convert our
724
+ # quaternion-based transformations to rotation-matrix ones
725
+ # here
726
+ backb_to_global = Rigid(
727
+ Rotation(
728
+ rot_mats=rigids.get_rots().get_rot_mats(),
729
+ quats=None
730
+ ),
731
+ rigids.get_trans(),
732
+ )
733
+
734
+ backb_to_global = backb_to_global.scale_translation(
735
+ self.trans_scale_factor
736
+ )
737
+
738
+ # [*, N, 7, 2]
739
+ unnormalized_angles, angles = self.angle_resnet(s, s_initial)
740
+
741
+ all_frames_to_global = self.torsion_angles_to_frames(
742
+ backb_to_global,
743
+ angles,
744
+ aatype,
745
+ )
746
+
747
+ pred_xyz = self.frames_and_literature_positions_to_atom14_pos(
748
+ all_frames_to_global,
749
+ aatype,
750
+ )
751
+
752
+ scaled_rigids = rigids.scale_translation(self.trans_scale_factor)
753
+
754
+ preds = {
755
+ "frames": scaled_rigids.to_tensor_7(),
756
+ "sidechain_frames": all_frames_to_global.to_tensor_4x4(),
757
+ "unnormalized_angles": unnormalized_angles,
758
+ "angles": angles,
759
+ "positions": pred_xyz,
760
+ "states": s,
761
+ }
762
+
763
+ outputs.append(preds)
764
+
765
+ rigids = rigids.stop_rot_gradient()
766
+
767
+ del z
768
+
769
+ outputs = dict_multimap(torch.stack, outputs)
770
+ outputs["single"] = s
771
+
772
+ return outputs
773
+
774
+ def _init_residue_constants(self, float_dtype, device):
775
+ if not hasattr(self, "default_frames"):
776
+ self.register_buffer(
777
+ "default_frames",
778
+ torch.tensor(
779
+ restype_rigid_group_default_frame,
780
+ dtype=float_dtype,
781
+ device=device,
782
+ requires_grad=False,
783
+ ),
784
+ persistent=False,
785
+ )
786
+ if not hasattr(self, "group_idx"):
787
+ self.register_buffer(
788
+ "group_idx",
789
+ torch.tensor(
790
+ restype_atom14_to_rigid_group,
791
+ device=device,
792
+ requires_grad=False,
793
+ ),
794
+ persistent=False,
795
+ )
796
+ if not hasattr(self, "atom_mask"):
797
+ self.register_buffer(
798
+ "atom_mask",
799
+ torch.tensor(
800
+ restype_atom14_mask,
801
+ dtype=float_dtype,
802
+ device=device,
803
+ requires_grad=False,
804
+ ),
805
+ persistent=False,
806
+ )
807
+ if not hasattr(self, "lit_positions"):
808
+ self.register_buffer(
809
+ "lit_positions",
810
+ torch.tensor(
811
+ restype_atom14_rigid_group_positions,
812
+ dtype=float_dtype,
813
+ device=device,
814
+ requires_grad=False,
815
+ ),
816
+ persistent=False,
817
+ )
818
+
819
+ def torsion_angles_to_frames(self, r, alpha, f):
820
+ # Lazily initialize the residue constants on the correct device
821
+ self._init_residue_constants(alpha.dtype, alpha.device)
822
+ # Separated purely to make testing less annoying
823
+ return torsion_angles_to_frames(r, alpha, f, self.default_frames)
824
+
825
+ def frames_and_literature_positions_to_atom14_pos(
826
+ self, r, f # [*, N, 8] # [*, N]
827
+ ):
828
+ # Lazily initialize the residue constants on the correct device
829
+ self._init_residue_constants(r.dtype, r.device)
830
+ return frames_and_literature_positions_to_atom14_pos(
831
+ r,
832
+ f,
833
+ self.default_frames,
834
+ self.group_idx,
835
+ self.atom_mask,
836
+ self.lit_positions,
837
+ )
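To make the per-block control flow above concrete, here is a hedged sketch of driving the `StructureModule` end to end. Every hyperparameter below is an illustrative placeholder (the real values live in `dockformer/config.py`); the input dict mirrors the `forward` docstring.

```python
import torch
from dockformer.model.structure_module import StructureModule

sm = StructureModule(
    c_s=64, c_z=32, c_ipa=16, c_resnet=128,
    no_heads_ipa=4, no_qk_points=4, no_v_points=8,
    dropout_rate=0.1, no_blocks=2, no_transition_layers=1,
    no_resnet_blocks=2, no_angles=7, trans_scale_factor=10,
    epsilon=1e-8, inf=1e5,
)

batch, n_res = 1, 8
outputs = sm(
    {"single": torch.rand(batch, n_res, 64),        # [*, N_res, C_s]
     "pair": torch.rand(batch, n_res, n_res, 32)},  # [*, N_res, N_res, C_z]
    aatype=torch.zeros(batch, n_res, dtype=torch.long),  # all-ALA toy sequence
    mask=torch.ones(batch, n_res),
)

# One prediction per block, stacked along a new leading dimension:
print(outputs["frames"].shape)     # [no_blocks, batch, n_res, 7]
print(outputs["positions"].shape)  # [no_blocks, batch, n_res, 14, 3]
```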
dockformer/model/torchscript.py ADDED
@@ -0,0 +1,171 @@
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional, Sequence, Tuple
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from dockformer.model.evoformer import (
21
+ EvoformerBlock,
22
+ EvoformerStack,
23
+ )
24
+ from dockformer.model.single_attention import SingleRowAttentionWithPairBias
25
+ from dockformer.model.primitives import Attention, GlobalAttention
26
+
27
+
28
+ def script_preset_(model: torch.nn.Module):
29
+ """
30
+ TorchScript a handful of low-level but frequently used submodule types
31
+ that are known to be scriptable.
32
+
33
+ Args:
34
+ model:
35
+ A torch.nn.Module. It should contain at least some modules from
36
+ this repository, or this function won't do anything.
37
+ """
38
+ script_submodules_(
39
+ model,
40
+ [
41
+ nn.Dropout,
42
+ Attention,
43
+ GlobalAttention,
44
+ EvoformerBlock,
45
+ ],
46
+ attempt_trace=False,
47
+ batch_dims=None,
48
+ )
49
+
50
+
51
+ def _get_module_device(module: torch.nn.Module) -> torch.device:
52
+ """
53
+ Fetches the device of a module, assuming that all of the module's
54
+ parameters reside on a single device
55
+
56
+ Args:
57
+ module: A torch.nn.Module
58
+ Returns:
59
+ The module's device
60
+ """
61
+ return next(module.parameters()).device
62
+
63
+
64
+ def _trace_module(module, batch_dims=None):
65
+ if(batch_dims is None):
66
+ batch_dims = ()
67
+
68
+ # Stand-in values
69
+ n_seq = 10
70
+ n_res = 10
71
+
72
+ device = _get_module_device(module)
73
+
74
+ def msa(channel_dim):
75
+ return torch.rand(
76
+ (*batch_dims, n_seq, n_res, channel_dim),
77
+ device=device,
78
+ )
79
+
80
+ def pair(channel_dim):
81
+ return torch.rand(
82
+ (*batch_dims, n_res, n_res, channel_dim),
83
+ device=device,
84
+ )
85
+
86
+ if(isinstance(module, SingleRowAttentionWithPairBias)):
87
+ inputs = {
88
+ "forward": (
89
+ msa(module.c_in), # m
90
+ pair(module.c_z), # z
91
+ torch.randint(
92
+ 0, 2,
93
+ (*batch_dims, n_seq, n_res)
94
+ ), # mask
95
+ ),
96
+ }
97
+ else:
98
+ raise TypeError(
99
+ f"tracing is not supported for modules of type {type(module)}"
100
+ )
101
+
102
+ return torch.jit.trace_module(module, inputs)
103
+
104
+
105
+ def _script_submodules_helper_(
106
+ model,
107
+ types,
108
+ attempt_trace,
109
+ to_trace,
110
+ ):
111
+ for name, child in model.named_children():
112
+ if(types is None or any(isinstance(child, t) for t in types)):
113
+ try:
114
+ scripted = torch.jit.script(child)
115
+ setattr(model, name, scripted)
116
+ continue
117
+ except (RuntimeError, torch.jit.frontend.NotSupportedError) as e:
118
+ if(attempt_trace):
119
+ to_trace.add(type(child))
120
+ else:
121
+ raise e
122
+
123
+ _script_submodules_helper_(child, types, attempt_trace, to_trace)
124
+
125
+
126
+ def _trace_submodules_(
127
+ model,
128
+ types,
129
+ batch_dims=None,
130
+ ):
131
+ for name, child in model.named_children():
132
+ if(any(isinstance(child, t) for t in types)):
133
+ traced = _trace_module(child, batch_dims=batch_dims)
134
+ setattr(model, name, traced)
135
+ else:
136
+ _trace_submodules_(child, types, batch_dims=batch_dims)
137
+
138
+
139
+ def script_submodules_(
140
+ model: nn.Module,
141
+ types: Optional[Sequence[type]] = None,
142
+ attempt_trace: Optional[bool] = True,
143
+ batch_dims: Optional[Tuple[int]] = None,
144
+ ):
145
+ """
146
+ Convert all submodules whose types match one of those in the input
147
+ list to recursively scripted equivalents in place. To script the entire
148
+ model, just call torch.jit.script on it directly.
149
+
150
+ When types is None, all submodules are scripted.
151
+
152
+ Args:
153
+ model:
154
+ A torch.nn.Module
155
+ types:
156
+ A list of types of submodules to script
157
+ attempt_trace:
158
+ Whether to attempt to trace specified modules if scripting
159
+ fails. Recall that tracing eliminates all conditional
160
+ logic---with great tracing comes the mild responsibility of
161
+ having to remember to ensure that the modules in question
162
+ perform the same computations no matter what.
163
+ """
164
+ to_trace = set()
165
+
166
+ # Aggressively script as much as possible first...
167
+ _script_submodules_helper_(model, types, attempt_trace, to_trace)
168
+
169
+ # ... and then trace stragglers.
170
+ if(attempt_trace and len(to_trace) > 0):
171
+ _trace_submodules_(model, to_trace, batch_dims=batch_dims)
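A small sketch of the in-place scripting helper above on a toy module; `Toy` is hypothetical and exists only for this example.

```python
import torch
import torch.nn as nn
from dockformer.model.torchscript import script_submodules_

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.drop = nn.Dropout(0.1)
        self.lin = nn.Linear(8, 8)

    def forward(self, x):
        return self.lin(self.drop(x))

toy = Toy()
script_submodules_(toy, types=[nn.Dropout], attempt_trace=False)

assert isinstance(toy.drop, torch.jit.ScriptModule)      # replaced in place
assert not isinstance(toy.lin, torch.jit.ScriptModule)   # type not listed, left alone
```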
dockformer/model/triangular_attention.py ADDED
@@ -0,0 +1,104 @@
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import partialmethod, partial
17
+ import math
18
+ from typing import Optional, List
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+
23
+ from dockformer.model.primitives import Linear, LayerNorm, Attention
24
+ from dockformer.utils.tensor_utils import permute_final_dims
25
+
26
+
27
+ class TriangleAttention(nn.Module):
28
+ def __init__(
29
+ self, c_in, c_hidden, no_heads, starting=True, inf=1e9
30
+ ):
31
+ """
32
+ Args:
33
+ c_in:
34
+ Input channel dimension
35
+ c_hidden:
36
+ Overall hidden channel dimension (not per-head)
37
+ no_heads:
38
+ Number of attention heads
39
+ """
40
+ super(TriangleAttention, self).__init__()
41
+
42
+ self.c_in = c_in
43
+ self.c_hidden = c_hidden
44
+ self.no_heads = no_heads
45
+ self.starting = starting
46
+ self.inf = inf
47
+
48
+ self.layer_norm = LayerNorm(self.c_in)
49
+
50
+ self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")
51
+
52
+ self.mha = Attention(
53
+ self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
54
+ )
55
+
56
+ def forward(self,
57
+ x: torch.Tensor,
58
+ mask: Optional[torch.Tensor] = None,
59
+ use_memory_efficient_kernel: bool = False,
60
+ use_lma: bool = False,
61
+ ) -> torch.Tensor:
62
+ """
63
+ Args:
64
+ x:
65
+ [*, I, J, C_in] input tensor (e.g. the pair representation)
66
+ Returns:
67
+ [*, I, J, C_in] output tensor
68
+ """
69
+ if mask is None:
70
+ # [*, I, J]
71
+ mask = x.new_ones(
72
+ x.shape[:-1],
73
+ )
74
+
75
+ if(not self.starting):
76
+ x = x.transpose(-2, -3)
77
+ mask = mask.transpose(-1, -2)
78
+
79
+ # [*, I, J, C_in]
80
+ x = self.layer_norm(x)
81
+
82
+ # [*, I, 1, 1, J]
83
+ mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
84
+
85
+ # [*, H, I, J]
86
+ triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))
87
+
88
+ # [*, 1, H, I, J]
89
+ triangle_bias = triangle_bias.unsqueeze(-4)
90
+
91
+ biases = [mask_bias, triangle_bias]
92
+
93
+ x = self.mha(
94
+ q_x=x,
95
+ kv_x=x,
96
+ biases=biases,
97
+ use_memory_efficient_kernel=use_memory_efficient_kernel,
98
+ use_lma=use_lma
99
+ )
100
+
101
+ if(not self.starting):
102
+ x = x.transpose(-2, -3)
103
+
104
+ return x
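A shape-level sketch of `TriangleAttention`, assuming the constructor above; with `starting=False` the same module attends around the ending node by transposing the pair representation internally.

```python
import torch
from dockformer.model.triangular_attention import TriangleAttention

batch, n_res, c_z = 1, 12, 32
tri = TriangleAttention(c_in=c_z, c_hidden=8, no_heads=4, starting=True)

x = torch.rand(batch, n_res, n_res, c_z)  # [*, I, J, C_in] pair representation
mask = torch.ones(batch, n_res, n_res)    # [*, I, J]

out = tri(x, mask=mask)
assert out.shape == x.shape
```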
dockformer/model/triangular_multiplicative_update.py ADDED
@@ -0,0 +1,173 @@
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import partialmethod
17
+ from typing import Optional
18
+ from abc import ABC, abstractmethod
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+
23
+ from dockformer.model.primitives import Linear, LayerNorm
24
+ from dockformer.utils.precision_utils import is_fp16_enabled
25
+ from dockformer.utils.tensor_utils import permute_final_dims
26
+
27
+
28
+ class BaseTriangleMultiplicativeUpdate(nn.Module, ABC):
29
+ """
30
+ Implements Algorithms 11 and 12.
31
+ """
32
+ @abstractmethod
33
+ def __init__(self, c_z, c_hidden, _outgoing):
34
+ """
35
+ Args:
36
+ c_z:
37
+ Input channel dimension
38
+ c_hidden:
39
+ Hidden channel dimension
40
+ """
41
+ super(BaseTriangleMultiplicativeUpdate, self).__init__()
42
+ self.c_z = c_z
43
+ self.c_hidden = c_hidden
44
+ self._outgoing = _outgoing
45
+
46
+ self.linear_g = Linear(self.c_z, self.c_z, init="gating")
47
+ self.linear_z = Linear(self.c_hidden, self.c_z, init="final")
48
+
49
+ self.layer_norm_in = LayerNorm(self.c_z)
50
+ self.layer_norm_out = LayerNorm(self.c_hidden)
51
+
52
+ self.sigmoid = nn.Sigmoid()
53
+
54
+ def _combine_projections(self,
55
+ a: torch.Tensor,
56
+ b: torch.Tensor,
57
+ ) -> torch.Tensor:
58
+ if(self._outgoing):
59
+ a = permute_final_dims(a, (2, 0, 1))
60
+ b = permute_final_dims(b, (2, 1, 0))
61
+ else:
62
+ a = permute_final_dims(a, (2, 1, 0))
63
+ b = permute_final_dims(b, (2, 0, 1))
64
+
65
+ p = torch.matmul(a, b)
66
+
67
+ return permute_final_dims(p, (1, 2, 0))
68
+
69
+ @abstractmethod
70
+ def forward(self,
71
+ z: torch.Tensor,
72
+ mask: Optional[torch.Tensor] = None,
73
+ inplace_safe: bool = False,
74
+ _add_with_inplace: bool = False
75
+ ) -> torch.Tensor:
76
+ """
77
+ Args:
78
+ z:
79
+ [*, N_res, N_res, C_z] input tensor
80
+ mask:
81
+ [*, N_res, N_res] input mask
82
+ Returns:
83
+ [*, N_res, N_res, C_z] output tensor
84
+ """
85
+ pass
86
+
87
+
88
+ class TriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
89
+ """
90
+ Implements Algorithms 11 and 12.
91
+ """
92
+ def __init__(self, c_z, c_hidden, _outgoing=True):
93
+ """
94
+ Args:
95
+ c_z:
96
+ Input channel dimension
97
+ c_hidden:
98
+ Hidden channel dimension
99
+ """
100
+ super(TriangleMultiplicativeUpdate, self).__init__(c_z=c_z,
101
+ c_hidden=c_hidden,
102
+ _outgoing=_outgoing)
103
+
104
+ self.linear_a_p = Linear(self.c_z, self.c_hidden)
105
+ self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating")
106
+ self.linear_b_p = Linear(self.c_z, self.c_hidden)
107
+ self.linear_b_g = Linear(self.c_z, self.c_hidden, init="gating")
108
+
109
+ def forward(self,
110
+ z: torch.Tensor,
111
+ mask: Optional[torch.Tensor] = None,
112
+ inplace_safe: bool = False,
113
+ _add_with_inplace: bool = False,
114
+ ) -> torch.Tensor:
115
+ """
116
+ Args:
117
+ x:
118
+ [*, N_res, N_res, C_z] input tensor
119
+ mask:
120
+ [*, N_res, N_res] input mask
121
+ Returns:
122
+ [*, N_res, N_res, C_z] output tensor
123
+ """
124
+
125
+ if mask is None:
126
+ mask = z.new_ones(z.shape[:-1])
127
+
128
+ mask = mask.unsqueeze(-1)
129
+
130
+ z = self.layer_norm_in(z)
131
+ a = mask
132
+ a = a * self.sigmoid(self.linear_a_g(z))
133
+ a = a * self.linear_a_p(z)
134
+ b = mask
135
+ b = b * self.sigmoid(self.linear_b_g(z))
136
+ b = b * self.linear_b_p(z)
137
+
138
+ # Prevents overflow of torch.matmul in combine projections in
139
+ # reduced-precision modes
140
+ a_std = a.std()
141
+ b_std = b.std()
142
+ if(is_fp16_enabled() and a_std != 0. and b_std != 0.):
143
+ a = a / a_std
144
+ b = b / b_std
145
+
146
+ if(is_fp16_enabled()):
147
+ with torch.cuda.amp.autocast(enabled=False):
148
+ x = self._combine_projections(a.float(), b.float())
149
+ else:
150
+ x = self._combine_projections(a, b)
151
+
152
+ del a, b
153
+ x = self.layer_norm_out(x)
154
+ x = self.linear_z(x)
155
+ g = self.sigmoid(self.linear_g(z))
156
+ x = x * g
157
+
158
+ return x
159
+
160
+
161
+ class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate):
162
+ """
163
+ Implements Algorithm 11.
164
+ """
165
+ __init__ = partialmethod(TriangleMultiplicativeUpdate.__init__, _outgoing=True)
166
+
167
+
168
+ class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
169
+ """
170
+ Implements Algorithm 12.
171
+ """
172
+ __init__ = partialmethod(TriangleMultiplicativeUpdate.__init__, _outgoing=False)
173
+
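Both concrete subclasses above share one forward pass and differ only in the direction of the edge contraction in `_combine_projections`. A hedged shape-level sketch:

```python
import torch
from dockformer.model.triangular_multiplicative_update import (
    TriangleMultiplicationIncoming,
    TriangleMultiplicationOutgoing,
)

batch, n_res, c_z, c_hidden = 1, 12, 32, 16
outgoing = TriangleMultiplicationOutgoing(c_z, c_hidden)
incoming = TriangleMultiplicationIncoming(c_z, c_hidden)

z = torch.rand(batch, n_res, n_res, c_z)  # [*, N_res, N_res, C_z]
mask = torch.ones(batch, n_res, n_res)

# Both variants preserve the pair-representation shape.
assert outgoing(z, mask=mask).shape == z.shape
assert incoming(z, mask=mask).shape == z.shape
```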
dockformer/resources/__init__.py ADDED
File without changes
dockformer/resources/stereo_chemical_props.txt ADDED
@@ -0,0 +1,345 @@
1
+ Bond Residue Mean StdDev
2
+ CA-CB ALA 1.520 0.021
3
+ N-CA ALA 1.459 0.020
4
+ CA-C ALA 1.525 0.026
5
+ C-O ALA 1.229 0.019
6
+ CA-CB ARG 1.535 0.022
7
+ CB-CG ARG 1.521 0.027
8
+ CG-CD ARG 1.515 0.025
9
+ CD-NE ARG 1.460 0.017
10
+ NE-CZ ARG 1.326 0.013
11
+ CZ-NH1 ARG 1.326 0.013
12
+ CZ-NH2 ARG 1.326 0.013
13
+ N-CA ARG 1.459 0.020
14
+ CA-C ARG 1.525 0.026
15
+ C-O ARG 1.229 0.019
16
+ CA-CB ASN 1.527 0.026
17
+ CB-CG ASN 1.506 0.023
18
+ CG-OD1 ASN 1.235 0.022
19
+ CG-ND2 ASN 1.324 0.025
20
+ N-CA ASN 1.459 0.020
21
+ CA-C ASN 1.525 0.026
22
+ C-O ASN 1.229 0.019
23
+ CA-CB ASP 1.535 0.022
24
+ CB-CG ASP 1.513 0.021
25
+ CG-OD1 ASP 1.249 0.023
26
+ CG-OD2 ASP 1.249 0.023
27
+ N-CA ASP 1.459 0.020
28
+ CA-C ASP 1.525 0.026
29
+ C-O ASP 1.229 0.019
30
+ CA-CB CYS 1.526 0.013
31
+ CB-SG CYS 1.812 0.016
32
+ N-CA CYS 1.459 0.020
33
+ CA-C CYS 1.525 0.026
34
+ C-O CYS 1.229 0.019
35
+ CA-CB GLU 1.535 0.022
36
+ CB-CG GLU 1.517 0.019
37
+ CG-CD GLU 1.515 0.015
38
+ CD-OE1 GLU 1.252 0.011
39
+ CD-OE2 GLU 1.252 0.011
40
+ N-CA GLU 1.459 0.020
41
+ CA-C GLU 1.525 0.026
42
+ C-O GLU 1.229 0.019
43
+ CA-CB GLN 1.535 0.022
44
+ CB-CG GLN 1.521 0.027
45
+ CG-CD GLN 1.506 0.023
46
+ CD-OE1 GLN 1.235 0.022
47
+ CD-NE2 GLN 1.324 0.025
48
+ N-CA GLN 1.459 0.020
49
+ CA-C GLN 1.525 0.026
50
+ C-O GLN 1.229 0.019
51
+ N-CA GLY 1.456 0.015
52
+ CA-C GLY 1.514 0.016
53
+ C-O GLY 1.232 0.016
54
+ CA-CB HIS 1.535 0.022
55
+ CB-CG HIS 1.492 0.016
56
+ CG-ND1 HIS 1.369 0.015
57
+ CG-CD2 HIS 1.353 0.017
58
+ ND1-CE1 HIS 1.343 0.025
59
+ CD2-NE2 HIS 1.415 0.021
60
+ CE1-NE2 HIS 1.322 0.023
61
+ N-CA HIS 1.459 0.020
62
+ CA-C HIS 1.525 0.026
63
+ C-O HIS 1.229 0.019
64
+ CA-CB ILE 1.544 0.023
65
+ CB-CG1 ILE 1.536 0.028
66
+ CB-CG2 ILE 1.524 0.031
67
+ CG1-CD1 ILE 1.500 0.069
68
+ N-CA ILE 1.459 0.020
69
+ CA-C ILE 1.525 0.026
70
+ C-O ILE 1.229 0.019
71
+ CA-CB LEU 1.533 0.023
72
+ CB-CG LEU 1.521 0.029
73
+ CG-CD1 LEU 1.514 0.037
74
+ CG-CD2 LEU 1.514 0.037
75
+ N-CA LEU 1.459 0.020
76
+ CA-C LEU 1.525 0.026
77
+ C-O LEU 1.229 0.019
78
+ CA-CB LYS 1.535 0.022
79
+ CB-CG LYS 1.521 0.027
80
+ CG-CD LYS 1.520 0.034
81
+ CD-CE LYS 1.508 0.025
82
+ CE-NZ LYS 1.486 0.025
83
+ N-CA LYS 1.459 0.020
84
+ CA-C LYS 1.525 0.026
85
+ C-O LYS 1.229 0.019
86
+ CA-CB MET 1.535 0.022
87
+ CB-CG MET 1.509 0.032
88
+ CG-SD MET 1.807 0.026
89
+ SD-CE MET 1.774 0.056
90
+ N-CA MET 1.459 0.020
91
+ CA-C MET 1.525 0.026
92
+ C-O MET 1.229 0.019
93
+ CA-CB PHE 1.535 0.022
94
+ CB-CG PHE 1.509 0.017
95
+ CG-CD1 PHE 1.383 0.015
96
+ CG-CD2 PHE 1.383 0.015
97
+ CD1-CE1 PHE 1.388 0.020
98
+ CD2-CE2 PHE 1.388 0.020
99
+ CE1-CZ PHE 1.369 0.019
100
+ CE2-CZ PHE 1.369 0.019
101
+ N-CA PHE 1.459 0.020
102
+ CA-C PHE 1.525 0.026
103
+ C-O PHE 1.229 0.019
104
+ CA-CB PRO 1.531 0.020
105
+ CB-CG PRO 1.495 0.050
106
+ CG-CD PRO 1.502 0.033
107
+ CD-N PRO 1.474 0.014
108
+ N-CA PRO 1.468 0.017
109
+ CA-C PRO 1.524 0.020
110
+ C-O PRO 1.228 0.020
111
+ CA-CB SER 1.525 0.015
112
+ CB-OG SER 1.418 0.013
113
+ N-CA SER 1.459 0.020
114
+ CA-C SER 1.525 0.026
115
+ C-O SER 1.229 0.019
116
+ CA-CB THR 1.529 0.026
117
+ CB-OG1 THR 1.428 0.020
118
+ CB-CG2 THR 1.519 0.033
119
+ N-CA THR 1.459 0.020
120
+ CA-C THR 1.525 0.026
121
+ C-O THR 1.229 0.019
122
+ CA-CB TRP 1.535 0.022
123
+ CB-CG TRP 1.498 0.018
124
+ CG-CD1 TRP 1.363 0.014
125
+ CG-CD2 TRP 1.432 0.017
126
+ CD1-NE1 TRP 1.375 0.017
127
+ NE1-CE2 TRP 1.371 0.013
128
+ CD2-CE2 TRP 1.409 0.012
129
+ CD2-CE3 TRP 1.399 0.015
130
+ CE2-CZ2 TRP 1.393 0.017
131
+ CE3-CZ3 TRP 1.380 0.017
132
+ CZ2-CH2 TRP 1.369 0.019
133
+ CZ3-CH2 TRP 1.396 0.016
134
+ N-CA TRP 1.459 0.020
135
+ CA-C TRP 1.525 0.026
136
+ C-O TRP 1.229 0.019
137
+ CA-CB TYR 1.535 0.022
138
+ CB-CG TYR 1.512 0.015
139
+ CG-CD1 TYR 1.387 0.013
140
+ CG-CD2 TYR 1.387 0.013
141
+ CD1-CE1 TYR 1.389 0.015
142
+ CD2-CE2 TYR 1.389 0.015
143
+ CE1-CZ TYR 1.381 0.013
144
+ CE2-CZ TYR 1.381 0.013
145
+ CZ-OH TYR 1.374 0.017
146
+ N-CA TYR 1.459 0.020
147
+ CA-C TYR 1.525 0.026
148
+ C-O TYR 1.229 0.019
149
+ CA-CB VAL 1.543 0.021
150
+ CB-CG1 VAL 1.524 0.021
151
+ CB-CG2 VAL 1.524 0.021
152
+ N-CA VAL 1.459 0.020
153
+ CA-C VAL 1.525 0.026
154
+ C-O VAL 1.229 0.019
155
+ -
156
+
157
+ Angle Residue Mean StdDev
158
+ N-CA-CB ALA 110.1 1.4
159
+ CB-CA-C ALA 110.1 1.5
160
+ N-CA-C ALA 111.0 2.7
161
+ CA-C-O ALA 120.1 2.1
162
+ N-CA-CB ARG 110.6 1.8
163
+ CB-CA-C ARG 110.4 2.0
164
+ CA-CB-CG ARG 113.4 2.2
165
+ CB-CG-CD ARG 111.6 2.6
166
+ CG-CD-NE ARG 111.8 2.1
167
+ CD-NE-CZ ARG 123.6 1.4
168
+ NE-CZ-NH1 ARG 120.3 0.5
169
+ NE-CZ-NH2 ARG 120.3 0.5
170
+ NH1-CZ-NH2 ARG 119.4 1.1
171
+ N-CA-C ARG 111.0 2.7
172
+ CA-C-O ARG 120.1 2.1
173
+ N-CA-CB ASN 110.6 1.8
174
+ CB-CA-C ASN 110.4 2.0
175
+ CA-CB-CG ASN 113.4 2.2
176
+ CB-CG-ND2 ASN 116.7 2.4
177
+ CB-CG-OD1 ASN 121.6 2.0
178
+ ND2-CG-OD1 ASN 121.9 2.3
179
+ N-CA-C ASN 111.0 2.7
180
+ CA-C-O ASN 120.1 2.1
181
+ N-CA-CB ASP 110.6 1.8
182
+ CB-CA-C ASP 110.4 2.0
183
+ CA-CB-CG ASP 113.4 2.2
184
+ CB-CG-OD1 ASP 118.3 0.9
185
+ CB-CG-OD2 ASP 118.3 0.9
186
+ OD1-CG-OD2 ASP 123.3 1.9
187
+ N-CA-C ASP 111.0 2.7
188
+ CA-C-O ASP 120.1 2.1
189
+ N-CA-CB CYS 110.8 1.5
190
+ CB-CA-C CYS 111.5 1.2
191
+ CA-CB-SG CYS 114.2 1.1
192
+ N-CA-C CYS 111.0 2.7
193
+ CA-C-O CYS 120.1 2.1
194
+ N-CA-CB GLU 110.6 1.8
195
+ CB-CA-C GLU 110.4 2.0
196
+ CA-CB-CG GLU 113.4 2.2
197
+ CB-CG-CD GLU 114.2 2.7
198
+ CG-CD-OE1 GLU 118.3 2.0
199
+ CG-CD-OE2 GLU 118.3 2.0
200
+ OE1-CD-OE2 GLU 123.3 1.2
201
+ N-CA-C GLU 111.0 2.7
202
+ CA-C-O GLU 120.1 2.1
203
+ N-CA-CB GLN 110.6 1.8
204
+ CB-CA-C GLN 110.4 2.0
205
+ CA-CB-CG GLN 113.4 2.2
206
+ CB-CG-CD GLN 111.6 2.6
207
+ CG-CD-OE1 GLN 121.6 2.0
208
+ CG-CD-NE2 GLN 116.7 2.4
209
+ OE1-CD-NE2 GLN 121.9 2.3
210
+ N-CA-C GLN 111.0 2.7
211
+ CA-C-O GLN 120.1 2.1
212
+ N-CA-C GLY 113.1 2.5
213
+ CA-C-O GLY 120.6 1.8
214
+ N-CA-CB HIS 110.6 1.8
215
+ CB-CA-C HIS 110.4 2.0
216
+ CA-CB-CG HIS 113.6 1.7
217
+ CB-CG-ND1 HIS 123.2 2.5
218
+ CB-CG-CD2 HIS 130.8 3.1
219
+ CG-ND1-CE1 HIS 108.2 1.4
220
+ ND1-CE1-NE2 HIS 109.9 2.2
221
+ CE1-NE2-CD2 HIS 106.6 2.5
222
+ NE2-CD2-CG HIS 109.2 1.9
223
+ CD2-CG-ND1 HIS 106.0 1.4
224
+ N-CA-C HIS 111.0 2.7
225
+ CA-C-O HIS 120.1 2.1
226
+ N-CA-CB ILE 110.8 2.3
227
+ CB-CA-C ILE 111.6 2.0
228
+ CA-CB-CG1 ILE 111.0 1.9
229
+ CB-CG1-CD1 ILE 113.9 2.8
230
+ CA-CB-CG2 ILE 110.9 2.0
231
+ CG1-CB-CG2 ILE 111.4 2.2
232
+ N-CA-C ILE 111.0 2.7
233
+ CA-C-O ILE 120.1 2.1
234
+ N-CA-CB LEU 110.4 2.0
235
+ CB-CA-C LEU 110.2 1.9
236
+ CA-CB-CG LEU 115.3 2.3
237
+ CB-CG-CD1 LEU 111.0 1.7
238
+ CB-CG-CD2 LEU 111.0 1.7
239
+ CD1-CG-CD2 LEU 110.5 3.0
240
+ N-CA-C LEU 111.0 2.7
241
+ CA-C-O LEU 120.1 2.1
242
+ N-CA-CB LYS 110.6 1.8
243
+ CB-CA-C LYS 110.4 2.0
244
+ CA-CB-CG LYS 113.4 2.2
245
+ CB-CG-CD LYS 111.6 2.6
246
+ CG-CD-CE LYS 111.9 3.0
247
+ CD-CE-NZ LYS 111.7 2.3
248
+ N-CA-C LYS 111.0 2.7
249
+ CA-C-O LYS 120.1 2.1
250
+ N-CA-CB MET 110.6 1.8
251
+ CB-CA-C MET 110.4 2.0
252
+ CA-CB-CG MET 113.3 1.7
253
+ CB-CG-SD MET 112.4 3.0
254
+ CG-SD-CE MET 100.2 1.6
255
+ N-CA-C MET 111.0 2.7
256
+ CA-C-O MET 120.1 2.1
257
+ N-CA-CB PHE 110.6 1.8
258
+ CB-CA-C PHE 110.4 2.0
259
+ CA-CB-CG PHE 113.9 2.4
260
+ CB-CG-CD1 PHE 120.8 0.7
261
+ CB-CG-CD2 PHE 120.8 0.7
262
+ CD1-CG-CD2 PHE 118.3 1.3
263
+ CG-CD1-CE1 PHE 120.8 1.1
264
+ CG-CD2-CE2 PHE 120.8 1.1
265
+ CD1-CE1-CZ PHE 120.1 1.2
266
+ CD2-CE2-CZ PHE 120.1 1.2
267
+ CE1-CZ-CE2 PHE 120.0 1.8
268
+ N-CA-C PHE 111.0 2.7
269
+ CA-C-O PHE 120.1 2.1
270
+ N-CA-CB PRO 103.3 1.2
271
+ CB-CA-C PRO 111.7 2.1
272
+ CA-CB-CG PRO 104.8 1.9
273
+ CB-CG-CD PRO 106.5 3.9
274
+ CG-CD-N PRO 103.2 1.5
275
+ CA-N-CD PRO 111.7 1.4
276
+ N-CA-C PRO 112.1 2.6
277
+ CA-C-O PRO 120.2 2.4
278
+ N-CA-CB SER 110.5 1.5
279
+ CB-CA-C SER 110.1 1.9
280
+ CA-CB-OG SER 111.2 2.7
281
+ N-CA-C SER 111.0 2.7
282
+ CA-C-O SER 120.1 2.1
283
+ N-CA-CB THR 110.3 1.9
284
+ CB-CA-C THR 111.6 2.7
285
+ CA-CB-OG1 THR 109.0 2.1
286
+ CA-CB-CG2 THR 112.4 1.4
287
+ OG1-CB-CG2 THR 110.0 2.3
288
+ N-CA-C THR 111.0 2.7
289
+ CA-C-O THR 120.1 2.1
290
+ N-CA-CB TRP 110.6 1.8
291
+ CB-CA-C TRP 110.4 2.0
292
+ CA-CB-CG TRP 113.7 1.9
293
+ CB-CG-CD1 TRP 127.0 1.3
294
+ CB-CG-CD2 TRP 126.6 1.3
295
+ CD1-CG-CD2 TRP 106.3 0.8
296
+ CG-CD1-NE1 TRP 110.1 1.0
297
+ CD1-NE1-CE2 TRP 109.0 0.9
298
+ NE1-CE2-CD2 TRP 107.3 1.0
299
+ CE2-CD2-CG TRP 107.3 0.8
300
+ CG-CD2-CE3 TRP 133.9 0.9
301
+ NE1-CE2-CZ2 TRP 130.4 1.1
302
+ CE3-CD2-CE2 TRP 118.7 1.2
303
+ CD2-CE2-CZ2 TRP 122.3 1.2
304
+ CE2-CZ2-CH2 TRP 117.4 1.0
305
+ CZ2-CH2-CZ3 TRP 121.6 1.2
306
+ CH2-CZ3-CE3 TRP 121.2 1.1
307
+ CZ3-CE3-CD2 TRP 118.8 1.3
308
+ N-CA-C TRP 111.0 2.7
309
+ CA-C-O TRP 120.1 2.1
310
+ N-CA-CB TYR 110.6 1.8
311
+ CB-CA-C TYR 110.4 2.0
312
+ CA-CB-CG TYR 113.4 1.9
313
+ CB-CG-CD1 TYR 121.0 0.6
314
+ CB-CG-CD2 TYR 121.0 0.6
315
+ CD1-CG-CD2 TYR 117.9 1.1
316
+ CG-CD1-CE1 TYR 121.3 0.8
317
+ CG-CD2-CE2 TYR 121.3 0.8
318
+ CD1-CE1-CZ TYR 119.8 0.9
319
+ CD2-CE2-CZ TYR 119.8 0.9
320
+ CE1-CZ-CE2 TYR 119.8 1.6
321
+ CE1-CZ-OH TYR 120.1 2.7
322
+ CE2-CZ-OH TYR 120.1 2.7
323
+ N-CA-C TYR 111.0 2.7
324
+ CA-C-O TYR 120.1 2.1
325
+ N-CA-CB VAL 111.5 2.2
326
+ CB-CA-C VAL 111.4 1.9
327
+ CA-CB-CG1 VAL 110.9 1.5
328
+ CA-CB-CG2 VAL 110.9 1.5
329
+ CG1-CB-CG2 VAL 110.9 1.6
330
+ N-CA-C VAL 111.0 2.7
331
+ CA-C-O VAL 120.1 2.1
332
+ -
333
+
334
+ Non-bonded distance Minimum Dist Tolerance
335
+ C-C 3.4 1.5
336
+ C-N 3.25 1.5
337
+ C-S 3.5 1.5
338
+ C-O 3.22 1.5
339
+ N-N 3.1 1.5
340
+ N-S 3.35 1.5
341
+ N-O 3.07 1.5
342
+ O-S 3.32 1.5
343
+ O-O 3.04 1.5
344
+ S-S 2.03 1.0
345
+ -
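The file above holds three whitespace-delimited sections (bond lengths, bond angles, non-bonded minimum distances), each closed by a lone `-` line, with the first row of each section acting as its header. A hypothetical parsing sketch (`parse_stereo_chemical_props` is not repository code):

```python
def parse_stereo_chemical_props(text: str):
    """Split the three '-'-terminated sections and drop their header rows."""
    sections, current = [], []
    for line in text.splitlines():
        line = line.strip()
        if line == "-":          # section terminator
            sections.append(current)
            current = []
        elif line:               # skip blank separator lines
            current.append(line.split())
    bonds, angles, nonbonded = sections[:3]
    return bonds[1:], angles[1:], nonbonded[1:]  # skip header rows

bonds, angles, nonbonded = parse_stereo_chemical_props(
    open("dockformer/resources/stereo_chemical_props.txt").read()
)
# e.g. bonds[0] == ["CA-CB", "ALA", "1.520", "0.021"]
```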
dockformer/utils/__init__.py ADDED
File without changes
dockformer/utils/callbacks.py ADDED
@@ -0,0 +1,15 @@
1
+ from lightning.pytorch.callbacks import EarlyStopping
2
+ from lightning_utilities.core.rank_zero import rank_zero_info
3
+
4
+
5
+ class EarlyStoppingVerbose(EarlyStopping):
6
+ """
7
+ The default EarlyStopping callback's verbose mode is too verbose.
8
+ This class outputs a message only when it's getting ready to stop.
9
+ """
10
+ def _evaluate_stopping_criteria(self, *args, **kwargs):
11
+ should_stop, reason = super()._evaluate_stopping_criteria(*args, **kwargs)
12
+ if(should_stop):
13
+ rank_zero_info(f"{reason}\n")
14
+
15
+ return should_stop, reason
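Illustrative wiring of the callback into a Lightning `Trainer`; the monitored metric name is an assumption. Note that the class overrides a private Lightning hook, so it is tied to the installed Lightning version.

```python
import lightning.pytorch as pl
from dockformer.utils.callbacks import EarlyStoppingVerbose

early_stopping = EarlyStoppingVerbose(
    monitor="val/loss",  # hypothetical metric name
    patience=10,
    mode="min",
)
trainer = pl.Trainer(callbacks=[early_stopping])
```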
dockformer/utils/checkpointing.py ADDED
@@ -0,0 +1,78 @@
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import importlib
15
+ from typing import Any, Tuple, List, Callable, Optional
16
+
17
+
18
+ import torch
19
+ import torch.utils.checkpoint
20
+
21
+
22
+ BLOCK_ARG = Any
23
+ BLOCK_ARGS = List[BLOCK_ARG]
24
+
25
+
26
+ @torch.jit.ignore
27
+ def checkpoint_blocks(
28
+ blocks: List[Callable],
29
+ args: BLOCK_ARGS,
30
+ blocks_per_ckpt: Optional[int],
31
+ ) -> BLOCK_ARGS:
32
+ """
33
+ Chunk a list of blocks and run each chunk with activation
34
+ checkpointing. We define a "block" as a callable whose only inputs are
35
+ the outputs of the previous block.
36
+
37
+ Implements Subsection 1.11.8
38
+
39
+ Args:
40
+ blocks:
41
+ List of blocks
42
+ args:
43
+ Tuple of arguments for the first block.
44
+ blocks_per_ckpt:
45
+ Size of each chunk. A higher value corresponds to fewer
46
+ checkpoints, and trades memory for speed. If None, no checkpointing
47
+ is performed.
48
+ Returns:
49
+ The output of the final block
50
+ """
51
+ def wrap(a):
52
+ return (a,) if type(a) is not tuple else a
53
+
54
+ def exec(b, a):
55
+ for block in b:
56
+ a = wrap(block(*a))
57
+ return a
58
+
59
+ def chunker(s, e):
60
+ def exec_sliced(*a):
61
+ return exec(blocks[s:e], a)
62
+
63
+ return exec_sliced
64
+
65
+ # Avoids mishaps when the blocks take just one argument
66
+ args = wrap(args)
67
+
68
+ if blocks_per_ckpt is None or not torch.is_grad_enabled():
69
+ return exec(blocks, args)
70
+ elif blocks_per_ckpt < 1 or blocks_per_ckpt > len(blocks):
71
+ raise ValueError("blocks_per_ckpt must be between 1 and len(blocks)")
72
+
73
+ for s in range(0, len(blocks), blocks_per_ckpt):
74
+ e = s + blocks_per_ckpt
75
+ args = torch.utils.checkpoint.checkpoint(chunker(s, e), *args, use_reentrant=True)
76
+ args = wrap(args)
77
+
78
+ return args
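A self-contained sketch of `checkpoint_blocks`: three toy blocks, checkpointed two at a time, so only each chunk's inputs are stored and activations inside a chunk are recomputed on the backward pass.

```python
import torch
import torch.nn as nn
from dockformer.utils.checkpointing import checkpoint_blocks

layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])
# Each block's only input is the previous block's output.
blocks = [lambda x, l=l: torch.relu(l(x)) for l in layers]

x = torch.rand(4, 8, requires_grad=True)
(out,) = checkpoint_blocks(blocks, (x,), blocks_per_ckpt=2)
out.sum().backward()  # chunk activations are recomputed here
```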
dockformer/utils/config_tools.py ADDED
@@ -0,0 +1,32 @@
1
+ import importlib.util
2
+
3
+ import ml_collections as mlc
4
+
5
+
6
+ def set_inf(c, inf):
7
+ for k, v in c.items():
8
+ if isinstance(v, mlc.ConfigDict):
9
+ set_inf(v, inf)
10
+ elif k == "inf":
11
+ c[k] = inf
12
+
13
+
14
+ def enforce_config_constraints(config):
15
+ def string_to_setting(s):
16
+ path = s.split('.')
17
+ setting = config
18
+ for p in path:
19
+ setting = setting.get(p)
20
+
21
+ return setting
22
+
23
+ mutually_exclusive_bools = [
24
+ (
25
+ "globals.use_lma",
26
+ ),
27
+ ]
28
+
29
+ for options in mutually_exclusive_bools:
30
+ option_settings = [string_to_setting(o) for o in options]
31
+ if sum(option_settings) > 1:
32
+ raise ValueError(f"Only one of {', '.join(options)} may be set at a time")
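A toy run of the two helpers above on a hand-built `ConfigDict`; the keys are illustrative except for `globals.use_lma`, which the constraint list actually references.

```python
import ml_collections as mlc
from dockformer.utils.config_tools import enforce_config_constraints, set_inf

config = mlc.ConfigDict({
    "globals": {"use_lma": False},
    "model": {"inf": 1e9, "evoformer": {"inf": 1e9}},
})

set_inf(config, 1e4)  # rewrites every nested "inf" entry in place
assert config.model.inf == 1e4 and config.model.evoformer.inf == 1e4

enforce_config_constraints(config)  # passes: no exclusive flags conflict
```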
dockformer/utils/consts.py ADDED
@@ -0,0 +1,25 @@
1
+ from rdkit.Chem.rdchem import ChiralType, BondType
2
+
3
+ # Survey of atom types in the PDBBind
4
+ # {'C': 403253, 'O': 101283, 'N': 81325, 'S': 6262, 'F': 5256, 'P': 3378, 'Cl': 2920, 'Br': 552, 'B': 237, 'I': 185,
5
+ # 'H': 181, 'Fe': 19, 'Se': 15, 'Ru': 10, 'Si': 5, 'Co': 4, 'Ir': 4, 'As': 2, 'Pt': 2, 'V': 1, 'Mg': 1, 'Be': 1,
6
+ # 'Rh': 1, 'Cu': 1, 'Re': 1}
7
+ # I have changed the uncommon types to common ions for the plinder dataset
8
+ # {'As': "Zn", 'Pt': "Mn", 'V': "Ca", 'Mg': "Mg", 'Be': "Na", 'Rh': "Al", 'Cu': "K", 'Re': "Ni"}
9
+
10
+ POSSIBLE_ATOM_TYPES = ['C', 'O', 'N', 'S', 'F', 'P', 'Cl', 'Br', 'B', 'I', 'H', 'Fe', 'Se', 'Ru', 'Si', 'Co', 'Ir',
11
+ 'Zn', 'Mn', 'Ca', 'Mg', 'Na', 'Al', 'K', 'Ni']
12
+
13
+ # bonds Counter({BondType.SINGLE: 366857, BondType.AROMATIC: 214238, BondType.DOUBLE: 59725, BondType.TRIPLE: 866,
14
+ # BondType.UNSPECIFIED: 18, BondType.DATIVE: 8})
15
+ POSSIBLE_BOND_TYPES = [BondType.SINGLE, BondType.DOUBLE, BondType.TRIPLE, BondType.AROMATIC, BondType.UNSPECIFIED,
16
+ BondType.DATIVE]
17
+
18
+ # {0: 580061, 1: 13273, -1: 11473, 2: 44, 7: 17, -2: 8, 9: 7, 10: 7, 5: 3, 3: 3, 4: 1, 6: 1, 8: 1}
19
+ POSSIBLE_CHARGES = [-1, 0, 1]
20
+
21
+ # {ChiralType.CHI_UNSPECIFIED: 551374, ChiralType.CHI_TETRAHEDRAL_CCW: 27328, ChiralType.CHI_TETRAHEDRAL_CW: 26178,
22
+ # ChiralType.CHI_OCTAHEDRAL: 13, ChiralType.CHI_SQUAREPLANAR: 3, ChiralType.CHI_TRIGONALBIPYRAMIDAL: 3}
23
+ POSSIBLE_CHIRALITIES = [ChiralType.CHI_UNSPECIFIED, ChiralType.CHI_TETRAHEDRAL_CCW, ChiralType.CHI_TETRAHEDRAL_CW,
24
+ ChiralType.CHI_OCTAHEDRAL, ChiralType.CHI_SQUAREPLANAR, ChiralType.CHI_TRIGONALBIPYRAMIDAL]
25
+
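A hedged sketch of how these vocabularies might be used to index-encode an RDKit molecule; `safe_index` is an illustrative helper, not repository code.

```python
from rdkit import Chem

from dockformer.utils.consts import (
    POSSIBLE_ATOM_TYPES,
    POSSIBLE_BOND_TYPES,
    POSSIBLE_CHARGES,
    POSSIBLE_CHIRALITIES,
)

def safe_index(options, value):
    # Fall back to the last vocabulary entry for out-of-vocabulary values.
    return options.index(value) if value in options else len(options) - 1

mol = Chem.MolFromSmiles("CC(=O)O")  # acetic acid
atom, bond = mol.GetAtomWithIdx(0), mol.GetBondWithIdx(0)

atom_type = safe_index(POSSIBLE_ATOM_TYPES, atom.GetSymbol())
charge = safe_index(POSSIBLE_CHARGES, atom.GetFormalCharge())
chirality = safe_index(POSSIBLE_CHIRALITIES, atom.GetChiralTag())
bond_type = safe_index(POSSIBLE_BOND_TYPES, bond.GetBondType())
```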
dockformer/utils/exponential_moving_average.py ADDED
@@ -0,0 +1,71 @@
1
+ from collections import OrderedDict
2
+ import copy
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from dockformer.utils.tensor_utils import tensor_tree_map
7
+
8
+
9
+ class ExponentialMovingAverage:
10
+ """
11
+ Maintains moving averages of parameters with exponential decay
12
+
13
+ At each step, the stored copy `copy` of each parameter `param` is
14
+ updated as follows:
15
+
16
+ `copy = decay * copy + (1 - decay) * param`
17
+
18
+ where `decay` is an attribute of the ExponentialMovingAverage object.
19
+ """
20
+
21
+ def __init__(self, model: nn.Module, decay: float):
22
+ """
23
+ Args:
24
+ model:
25
+ A torch.nn.Module whose parameters are to be tracked
26
+ decay:
27
+ A value (usually close to 1.) by which updates are
28
+ weighted as part of the above formula
29
+ """
30
+ super(ExponentialMovingAverage, self).__init__()
31
+
32
+ clone_param = lambda t: t.clone().detach()
33
+ self.params = tensor_tree_map(clone_param, model.state_dict())
34
+ self.decay = decay
35
+ self.device = next(model.parameters()).device
36
+
37
+ def to(self, device):
38
+ self.params = tensor_tree_map(lambda t: t.to(device), self.params)
39
+ self.device = device
40
+
41
+ def _update_state_dict_(self, update, state_dict):
42
+ with torch.no_grad():
43
+ for k, v in update.items():
44
+ stored = state_dict[k]
45
+ if not isinstance(v, torch.Tensor):
46
+ self._update_state_dict_(v, stored)
47
+ else:
48
+ diff = stored - v
49
+ diff *= 1 - self.decay
50
+ stored -= diff
51
+
52
+ def update(self, model: torch.nn.Module) -> None:
53
+ """
54
+ Updates the stored parameters using the state dict of the provided
55
+ module. The module should have the same structure as that used to
56
+ initialize the ExponentialMovingAverage object.
57
+ """
58
+ self._update_state_dict_(model.state_dict(), self.params)
59
+
60
+ def load_state_dict(self, state_dict: OrderedDict) -> None:
61
+ for k in state_dict["params"].keys():
62
+ self.params[k] = state_dict["params"][k].clone()
63
+ self.decay = state_dict["decay"]
64
+
65
+ def state_dict(self) -> OrderedDict:
66
+ return OrderedDict(
67
+ {
68
+ "params": self.params,
69
+ "decay": self.decay,
70
+ }
71
+ )
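The update rule in the docstring above, `copy = decay * copy + (1 - decay) * param`, in a minimal training-loop sketch:

```python
import torch.nn as nn
from dockformer.utils.exponential_moving_average import ExponentialMovingAverage

model = nn.Linear(8, 8)
ema = ExponentialMovingAverage(model, decay=0.999)

# After every optimizer step:
ema.update(model)

# For evaluation, the smoothed weights can be loaded into a fresh copy:
eval_model = nn.Linear(8, 8)
eval_model.load_state_dict(ema.params)
```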
dockformer/utils/feats.py ADDED
@@ -0,0 +1,174 @@
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn as nn
21
+ from typing import Dict, Union
22
+
23
+ from dockformer.utils import protein
24
+ import dockformer.utils.residue_constants as rc
25
+ from dockformer.utils.geometry import rigid_matrix_vector, rotation_matrix, vector
26
+ from dockformer.utils.rigid_utils import Rotation, Rigid
27
+ from dockformer.utils.tensor_utils import (
28
+ batched_gather,
29
+ one_hot,
30
+ tree_map,
31
+ tensor_tree_map,
32
+ )
33
+
34
+
35
+ def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
36
+ # rc.restype_order["X"] defines a ligand, and the atom position used is the CA
37
+ is_gly_or_lig = (aatype == rc.restype_order["G"]) | (aatype == rc.restype_order["Z"])
38
+ ca_idx = rc.atom_order["CA"]
39
+ cb_idx = rc.atom_order["CB"]
40
+ pseudo_beta = torch.where(
41
+ is_gly_or_lig[..., None].expand(*((-1,) * len(is_gly_or_lig.shape)), 3),
42
+ all_atom_positions[..., ca_idx, :],
43
+ all_atom_positions[..., cb_idx, :],
44
+ )
45
+
46
+ if all_atom_masks is not None:
47
+ pseudo_beta_mask = torch.where(
48
+ is_gly_or_lig,
49
+ all_atom_masks[..., ca_idx],
50
+ all_atom_masks[..., cb_idx],
51
+ )
52
+ return pseudo_beta, pseudo_beta_mask
53
+ else:
54
+ return pseudo_beta
55
+
56
+
57
+ def atom14_to_atom37(atom14, batch):
58
+ atom37_data = batched_gather(
59
+ atom14,
60
+ batch["residx_atom37_to_atom14"],
61
+ dim=-2,
62
+ no_batch_dims=len(atom14.shape[:-2]),
63
+ )
64
+
65
+ atom37_data = atom37_data * batch["atom37_atom_exists"][..., None]
66
+
67
+ return atom37_data
68
+
69
+
70
+ def torsion_angles_to_frames(
71
+ r: Union[Rigid, rigid_matrix_vector.Rigid3Array],
72
+ alpha: torch.Tensor,
73
+ aatype: torch.Tensor,
74
+ rrgdf: torch.Tensor,
75
+ ):
76
+
77
+ rigid_type = type(r)
78
+
79
+ # [*, N, 8, 4, 4]
80
+ default_4x4 = rrgdf[aatype, ...]
81
+
82
+ # [*, N, 8] transformations, i.e.
83
+ # One [*, N, 8, 3, 3] rotation matrix and
84
+ # One [*, N, 8, 3] translation matrix
85
+ default_r = rigid_type.from_tensor_4x4(default_4x4)
86
+
87
+ bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
88
+ bb_rot[..., 1] = 1
89
+
90
+ # [*, N, 8, 2]
91
+ alpha = torch.cat(
92
+ [bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2
93
+ )
94
+
95
+ # [*, N, 8, 3, 3]
96
+ # Produces rotation matrices of the form:
97
+ # [
98
+ # [1, 0 , 0 ],
99
+ # [0, a_2,-a_1],
100
+ # [0, a_1, a_2]
101
+ # ]
102
+ # This follows the original code rather than the supplement, which uses
103
+ # different indices.
104
+
105
+ all_rots = alpha.new_zeros(default_r.shape + (4, 4))
106
+ all_rots[..., 0, 0] = 1
107
+ all_rots[..., 1, 1] = alpha[..., 1]
108
+ all_rots[..., 1, 2] = -alpha[..., 0]
109
+ all_rots[..., 2, 1:3] = alpha
110
+
111
+ all_rots = rigid_type.from_tensor_4x4(all_rots)
112
+ all_frames = default_r.compose(all_rots)
113
+
114
+ chi2_frame_to_frame = all_frames[..., 5]
115
+ chi3_frame_to_frame = all_frames[..., 6]
116
+ chi4_frame_to_frame = all_frames[..., 7]
117
+
118
+ chi1_frame_to_bb = all_frames[..., 4]
119
+ chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame)
120
+ chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
121
+ chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)
122
+
123
+ all_frames_to_bb = rigid_type.cat(
124
+ [
125
+ all_frames[..., :5],
126
+ chi2_frame_to_bb.unsqueeze(-1),
127
+ chi3_frame_to_bb.unsqueeze(-1),
128
+ chi4_frame_to_bb.unsqueeze(-1),
129
+ ],
130
+ dim=-1,
131
+ )
132
+
133
+ all_frames_to_global = r[..., None].compose(all_frames_to_bb)
134
+
135
+ return all_frames_to_global
136
+
137
+
138
+ def frames_and_literature_positions_to_atom14_pos(
139
+ r: Union[Rigid, rigid_matrix_vector.Rigid3Array],
140
+ aatype: torch.Tensor,
141
+ default_frames,
142
+ group_idx,
143
+ atom_mask,
144
+ lit_positions,
145
+ ):
146
+ # [*, N, 14, 4, 4]
147
+ default_4x4 = default_frames[aatype, ...]
148
+
149
+ # [*, N, 14]
150
+ group_mask = group_idx[aatype, ...]
151
+
152
+ # [*, N, 14, 8]
153
+ group_mask = nn.functional.one_hot(
154
+ group_mask,
155
+ num_classes=default_frames.shape[-3],
156
+ )
157
+
158
+ # [*, N, 14, 8]
159
+ t_atoms_to_global = r[..., None, :] * group_mask
160
+
161
+ # [*, N, 14]
162
+ t_atoms_to_global = t_atoms_to_global.map_tensor_fn(
163
+ lambda x: torch.sum(x, dim=-1)
164
+ )
165
+
166
+ # [*, N, 14]
167
+ atom_mask = atom_mask[aatype, ...].unsqueeze(-1)
168
+
169
+ # [*, N, 14, 3]
170
+ lit_positions = lit_positions[aatype, ...]
171
+ pred_positions = t_atoms_to_global.apply(lit_positions)
172
+ pred_positions = pred_positions * atom_mask
173
+
174
+ return pred_positions
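To make the `torch.where` selection in `pseudo_beta_fn` concrete, here is a toy-shaped sketch (two residues, three atom slots; the slot indices are stand-ins, not the real atom37 layout):

```python
import torch

# Toy positions: 2 residues x 3 atom slots x 3 coords; slot 0 = "CA", slot 1 = "CB".
all_atom_positions = torch.arange(2 * 3 * 3, dtype=torch.float32).reshape(2, 3, 3)
is_gly_or_lig = torch.tensor([True, False])
ca_idx, cb_idx = 0, 1

pseudo_beta = torch.where(
    is_gly_or_lig[..., None].expand(*((-1,) * len(is_gly_or_lig.shape)), 3),
    all_atom_positions[..., ca_idx, :],  # glycine / ligand: fall back to CA
    all_atom_positions[..., cb_idx, :],  # everything else: use CB
)
print(pseudo_beta)  # row 0 holds the CA coords, row 1 the CB coords
```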
dockformer/utils/geometry/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright 2021 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Geometry Module."""
15
+
16
+ from dockformer.utils.geometry import rigid_matrix_vector
17
+ from dockformer.utils.geometry import rotation_matrix
18
+ from dockformer.utils.geometry import vector
19
+
20
+ Rot3Array = rotation_matrix.Rot3Array
21
+ Rigid3Array = rigid_matrix_vector.Rigid3Array
22
+
23
+ Vec3Array = vector.Vec3Array
24
+ square_euclidean_distance = vector.square_euclidean_distance
25
+ euclidean_distance = vector.euclidean_distance
26
+ dihedral_angle = vector.dihedral_angle
27
+ dot = vector.dot
28
+ cross = vector.cross
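A quick usage sketch of the re-exported names (this assumes the dockformer package is importable):

```python
import torch
from dockformer.utils.geometry import Vec3Array, euclidean_distance

a = Vec3Array.zeros((5,))
b = Vec3Array(torch.ones(5), torch.zeros(5), torch.zeros(5))
print(euclidean_distance(a, b))  # ~1.0 for every entry
```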
dockformer/utils/geometry/quat_rigid.py ADDED
@@ -0,0 +1,38 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from dockformer.model.primitives import Linear
5
+ from dockformer.utils.geometry.rigid_matrix_vector import Rigid3Array
6
+ from dockformer.utils.geometry.rotation_matrix import Rot3Array
7
+ from dockformer.utils.geometry.vector import Vec3Array
8
+
9
+
10
+ class QuatRigid(nn.Module):
11
+ def __init__(self, c_hidden, full_quat):
12
+ super().__init__()
13
+ self.full_quat = full_quat
14
+ if self.full_quat:
15
+ rigid_dim = 7
16
+ else:
17
+ rigid_dim = 6
18
+
19
+ self.linear = Linear(c_hidden, rigid_dim, init="final", precision=torch.float32)
20
+
21
+ def forward(self, activations: torch.Tensor) -> Rigid3Array:
22
+ # NOTE: During training, this needs to be run in higher precision
23
+ rigid_flat = self.linear(activations)
24
+
25
+ rigid_flat = torch.unbind(rigid_flat, dim=-1)
26
+ if(self.full_quat):
27
+ qw, qx, qy, qz = rigid_flat[:4]
28
+ translation = rigid_flat[4:]
29
+ else:
30
+ qx, qy, qz = rigid_flat[:3]
31
+ qw = torch.ones_like(qx)
32
+ translation = rigid_flat[3:]
33
+
34
+ rotation = Rot3Array.from_quaternion(
35
+ qw, qx, qy, qz, normalize=True,
36
+ )
37
+ translation = Vec3Array(*translation)
38
+ return Rigid3Array(rotation, translation)
dockformer/utils/geometry/rigid_matrix_vector.py ADDED
@@ -0,0 +1,181 @@
1
+ # Copyright 2021 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Rigid3Array Transformations represented by a Matrix and a Vector."""
15
+
16
+ from __future__ import annotations
17
+ import dataclasses
18
+ from typing import Union, List
19
+
20
+ import torch
21
+
22
+ from dockformer.utils.geometry import rotation_matrix
23
+ from dockformer.utils.geometry import vector
24
+
25
+
26
+ Float = Union[float, torch.Tensor]
27
+
28
+
29
+ @dataclasses.dataclass(frozen=True)
30
+ class Rigid3Array:
31
+ """Rigid Transformation, i.e. element of special euclidean group."""
32
+
33
+ rotation: rotation_matrix.Rot3Array
34
+ translation: vector.Vec3Array
35
+
36
+ def __matmul__(self, other: Rigid3Array) -> Rigid3Array:
37
+ new_rotation = self.rotation @ other.rotation # __matmul__
38
+ new_translation = self.apply_to_point(other.translation)
39
+ return Rigid3Array(new_rotation, new_translation)
40
+
41
+ def __getitem__(self, index) -> Rigid3Array:
42
+ return Rigid3Array(
43
+ self.rotation[index],
44
+ self.translation[index],
45
+ )
46
+
47
+ def __mul__(self, other: torch.Tensor) -> Rigid3Array:
48
+ return Rigid3Array(
49
+ self.rotation * other,
50
+ self.translation * other,
51
+ )
52
+
53
+ def map_tensor_fn(self, fn) -> Rigid3Array:
54
+ return Rigid3Array(
55
+ self.rotation.map_tensor_fn(fn),
56
+ self.translation.map_tensor_fn(fn),
57
+ )
58
+
59
+ def inverse(self) -> Rigid3Array:
60
+ """Return Rigid3Array corresponding to inverse transform."""
61
+ inv_rotation = self.rotation.inverse()
62
+ inv_translation = inv_rotation.apply_to_point(-self.translation)
63
+ return Rigid3Array(inv_rotation, inv_translation)
64
+
65
+ def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
66
+ """Apply Rigid3Array transform to point."""
67
+ return self.rotation.apply_to_point(point) + self.translation
68
+
69
+ def apply(self, point: torch.Tensor) -> torch.Tensor:
70
+ return self.apply_to_point(vector.Vec3Array.from_array(point)).to_tensor()
71
+
72
+ def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
73
+ """Apply inverse Rigid3Array transform to point."""
74
+ new_point = point - self.translation
75
+ return self.rotation.apply_inverse_to_point(new_point)
76
+
77
+ def invert_apply(self, point: torch.Tensor) -> torch.Tensor:
78
+ return self.apply_inverse_to_point(vector.Vec3Array.from_array(point)).to_tensor()
79
+
80
+ def compose_rotation(self, other_rotation):
81
+ rot = self.rotation @ other_rotation
82
+ return Rigid3Array(rot, self.translation.clone())
83
+
84
+ def compose(self, other_rigid):
85
+ return self @ other_rigid
86
+
87
+ def unsqueeze(self, dim: int):
88
+ return Rigid3Array(
89
+ self.rotation.unsqueeze(dim),
90
+ self.translation.unsqueeze(dim),
91
+ )
92
+
93
+ @property
94
+ def shape(self) -> torch.Size:
95
+ return self.rotation.xx.shape
96
+
97
+ @property
98
+ def dtype(self) -> torch.dtype:
99
+ return self.rotation.xx.dtype
100
+
101
+ @property
102
+ def device(self) -> torch.device:
103
+ return self.rotation.xx.device
104
+
105
+ @classmethod
106
+ def identity(cls, shape, device) -> Rigid3Array:
107
+ """Return identity Rigid3Array of given shape."""
108
+ return cls(
109
+ rotation_matrix.Rot3Array.identity(shape, device),
110
+ vector.Vec3Array.zeros(shape, device)
111
+ )
112
+
113
+ @classmethod
114
+ def cat(cls, rigids: List[Rigid3Array], dim: int) -> Rigid3Array:
115
+ return cls(
116
+ rotation_matrix.Rot3Array.cat(
117
+ [r.rotation for r in rigids], dim=dim
118
+ ),
119
+ vector.Vec3Array.cat(
120
+ [r.translation for r in rigids], dim=dim
121
+ ),
122
+ )
123
+
124
+ def scale_translation(self, factor: Float) -> Rigid3Array:
125
+ """Scale translation in Rigid3Array by 'factor'."""
126
+ return Rigid3Array(self.rotation, self.translation * factor)
127
+
128
+ def to_tensor(self) -> torch.Tensor:
129
+ rot_array = self.rotation.to_tensor()
130
+ vec_array = self.translation.to_tensor()
131
+ array = torch.zeros(
132
+ rot_array.shape[:-2] + (4, 4),
133
+ device=rot_array.device,
134
+ dtype=rot_array.dtype
135
+ )
136
+ array[..., :3, :3] = rot_array
137
+ array[..., :3, 3] = vec_array
138
+ array[..., 3, 3] = 1.
139
+ return array
140
+
141
+ def to_tensor_4x4(self) -> torch.Tensor:
142
+ return self.to_tensor()
143
+
144
+ def reshape(self, new_shape) -> Rigid3Array:
145
+ rots = self.rotation.reshape(new_shape)
146
+ trans = self.translation.reshape(new_shape)
147
+ return Rigid3Array(rots, trans)
148
+
149
+ def stop_rot_gradient(self) -> Rigid3Array:
150
+ return Rigid3Array(
151
+ self.rotation.stop_gradient(),
152
+ self.translation,
153
+ )
154
+
155
+ @classmethod
156
+ def from_array(cls, array):
157
+ rot = rotation_matrix.Rot3Array.from_array(
158
+ array[..., :3, :3],
159
+ )
160
+ vec = vector.Vec3Array.from_array(array[..., :3, 3])
161
+ return cls(rot, vec)
162
+
163
+ @classmethod
164
+ def from_tensor_4x4(cls, array):
165
+ return cls.from_array(array)
166
+
167
+ @classmethod
168
+ def from_array4x4(cls, array: torch.tensor) -> Rigid3Array:
169
+ """Construct Rigid3Array from homogeneous 4x4 array."""
170
+ rotation = rotation_matrix.Rot3Array(
171
+ array[..., 0, 0], array[..., 0, 1], array[..., 0, 2],
172
+ array[..., 1, 0], array[..., 1, 1], array[..., 1, 2],
173
+ array[..., 2, 0], array[..., 2, 1], array[..., 2, 2]
174
+ )
175
+ translation = vector.Vec3Array(
176
+ array[..., 0, 3], array[..., 1, 3], array[..., 2, 3]
177
+ )
178
+ return cls(rotation, translation)
179
+
180
+ def cuda(self) -> Rigid3Array:
181
+ return Rigid3Array.from_tensor_4x4(self.to_tensor_4x4().cuda())
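In plain tensor terms, `apply` and `invert_apply` amount to multiplication by the homogeneous 4x4 matrix that `to_tensor` builds and by its inverse. A self-contained round-trip sketch (hand-built matrix, not the class itself):

```python
import torch

# Homogeneous 4x4: 90-degree rotation about z plus a translation.
T = torch.eye(4)
T[:3, :3] = torch.tensor([[0., -1., 0.],
                          [1.,  0., 0.],
                          [0.,  0., 1.]])
T[:3, 3] = torch.tensor([1., 2., 3.])

p = torch.tensor([1., 0., 0., 1.])    # a point in homogeneous coordinates
q = T @ p                             # "apply": rotate, then translate
back = torch.linalg.inv(T) @ q        # "invert_apply" recovers the input
print(q[:3], back[:3])                # tensor([1., 3., 3.]) tensor([1., 0., 0.])
```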
dockformer/utils/geometry/rotation_matrix.py ADDED
@@ -0,0 +1,208 @@
1
+ # Copyright 2021 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Rot3Array Matrix Class."""
15
+
16
+ from __future__ import annotations
17
+ import dataclasses
18
+ from typing import List
19
+
20
+ import torch
21
+
22
+ from dockformer.utils.geometry import utils
23
+ from dockformer.utils.geometry import vector
24
+ from dockformer.utils.tensor_utils import tensor_tree_map
25
+
26
+
27
+ COMPONENTS = ['xx', 'xy', 'xz', 'yx', 'yy', 'yz', 'zx', 'zy', 'zz']
28
+
29
+ @dataclasses.dataclass(frozen=True)
30
+ class Rot3Array:
31
+ """Rot3Array Matrix in 3 dimensional Space implemented as struct of arrays."""
32
+ xx: torch.Tensor = dataclasses.field(metadata={'dtype': torch.float32})
33
+ xy: torch.Tensor
34
+ xz: torch.Tensor
35
+ yx: torch.Tensor
36
+ yy: torch.Tensor
37
+ yz: torch.Tensor
38
+ zx: torch.Tensor
39
+ zy: torch.Tensor
40
+ zz: torch.Tensor
41
+
42
+ __array_ufunc__ = None
43
+
44
+ def __getitem__(self, index):
45
+ field_names = utils.get_field_names(Rot3Array)
46
+ return Rot3Array(
47
+ **{
48
+ name: getattr(self, name)[index]
49
+ for name in field_names
50
+ }
51
+ )
52
+
53
+ def __mul__(self, other: torch.Tensor):
54
+ field_names = utils.get_field_names(Rot3Array)
55
+ return Rot3Array(
56
+ **{
57
+ name: getattr(self, name) * other
58
+ for name in field_names
59
+ }
60
+ )
61
+
62
+ def __matmul__(self, other: Rot3Array) -> Rot3Array:
63
+ """Composes two Rot3Arrays."""
64
+ c0 = self.apply_to_point(vector.Vec3Array(other.xx, other.yx, other.zx))
65
+ c1 = self.apply_to_point(vector.Vec3Array(other.xy, other.yy, other.zy))
66
+ c2 = self.apply_to_point(vector.Vec3Array(other.xz, other.yz, other.zz))
67
+ return Rot3Array(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z)
68
+
69
+ def map_tensor_fn(self, fn) -> Rot3Array:
70
+ field_names = utils.get_field_names(Rot3Array)
71
+ return Rot3Array(
72
+ **{
73
+ name: fn(getattr(self, name))
74
+ for name in field_names
75
+ }
76
+ )
77
+
78
+ def inverse(self) -> Rot3Array:
79
+ """Returns inverse of Rot3Array."""
80
+ return Rot3Array(
81
+ self.xx, self.yx, self.zx,
82
+ self.xy, self.yy, self.zy,
83
+ self.xz, self.yz, self.zz
84
+ )
85
+
86
+ def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
87
+ """Applies Rot3Array to point."""
88
+ return vector.Vec3Array(
89
+ self.xx * point.x + self.xy * point.y + self.xz * point.z,
90
+ self.yx * point.x + self.yy * point.y + self.yz * point.z,
91
+ self.zx * point.x + self.zy * point.y + self.zz * point.z
92
+ )
93
+
94
+ def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
95
+ """Applies inverse Rot3Array to point."""
96
+ return self.inverse().apply_to_point(point)
97
+
98
+
99
+ def unsqueeze(self, dim: int):
100
+ return Rot3Array(
101
+ *tensor_tree_map(
102
+ lambda t: t.unsqueeze(dim),
103
+ [getattr(self, c) for c in COMPONENTS]
104
+ )
105
+ )
106
+
107
+ def stop_gradient(self) -> Rot3Array:
108
+ return Rot3Array(
109
+ *[getattr(self, c).detach() for c in COMPONENTS]
110
+ )
111
+
112
+ @classmethod
113
+ def identity(cls, shape, device) -> Rot3Array:
114
+ """Returns identity of given shape."""
115
+ ones = torch.ones(shape, dtype=torch.float32, device=device)
116
+ zeros = torch.zeros(shape, dtype=torch.float32, device=device)
117
+ return cls(ones, zeros, zeros, zeros, ones, zeros, zeros, zeros, ones)
118
+
119
+ @classmethod
120
+ def from_two_vectors(
121
+ cls, e0: vector.Vec3Array,
122
+ e1: vector.Vec3Array
123
+ ) -> Rot3Array:
124
+ """Construct Rot3Array from two Vectors.
125
+
126
+ Rot3Array is constructed such that in the corresponding frame 'e0' lies on
127
+ the positive x-Axis and 'e1' lies in the xy plane with positive sign of y.
128
+
129
+ Args:
130
+ e0: Vector
131
+ e1: Vector
132
+ Returns:
133
+ Rot3Array
134
+ """
135
+ # Normalize the unit vector for the x-axis, e0.
136
+ e0 = e0.normalized()
137
+ # make e1 perpendicular to e0.
138
+ c = e1.dot(e0)
139
+ e1 = (e1 - c * e0).normalized()
140
+ # Compute e2 as cross product of e0 and e1.
141
+ e2 = e0.cross(e1)
142
+ return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z)
143
+
144
+ @classmethod
145
+ def from_array(cls, array: torch.Tensor) -> Rot3Array:
146
+ """Construct Rot3Array Matrix from array of shape. [..., 3, 3]."""
147
+ rows = torch.unbind(array, dim=-2)
148
+ rc = [torch.unbind(e, dim=-1) for e in rows]
149
+ return cls(*[e for row in rc for e in row])
150
+
151
+ def to_tensor(self) -> torch.Tensor:
152
+ """Convert Rot3Array to array of shape [..., 3, 3]."""
153
+ return torch.stack(
154
+ [
155
+ torch.stack([self.xx, self.xy, self.xz], dim=-1),
156
+ torch.stack([self.yx, self.yy, self.yz], dim=-1),
157
+ torch.stack([self.zx, self.zy, self.zz], dim=-1)
158
+ ],
159
+ dim=-2
160
+ )
161
+
162
+ @classmethod
163
+ def from_quaternion(cls,
164
+ w: torch.Tensor,
165
+ x: torch.Tensor,
166
+ y: torch.Tensor,
167
+ z: torch.Tensor,
168
+ normalize: bool = True,
169
+ eps: float = 1e-6
170
+ ) -> Rot3Array:
171
+ """Construct Rot3Array from components of quaternion."""
172
+ if normalize:
173
+ inv_norm = torch.rsqrt(torch.clamp(w**2 + x**2 + y**2 + z**2, min=eps))
174
+ w = w * inv_norm
175
+ x = x * inv_norm
176
+ y = y * inv_norm
177
+ z = z * inv_norm
178
+ xx = 1.0 - 2.0 * (y ** 2 + z ** 2)
179
+ xy = 2.0 * (x * y - w * z)
180
+ xz = 2.0 * (x * z + w * y)
181
+ yx = 2.0 * (x * y + w * z)
182
+ yy = 1.0 - 2.0 * (x ** 2 + z ** 2)
183
+ yz = 2.0 * (y * z - w * x)
184
+ zx = 2.0 * (x * z - w * y)
185
+ zy = 2.0 * (y * z + w * x)
186
+ zz = 1.0 - 2.0 * (x ** 2 + y ** 2)
187
+ return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz)
188
+
189
+ def reshape(self, new_shape):
190
+ field_names = utils.get_field_names(Rot3Array)
191
+ reshape_fn = lambda t: t.reshape(new_shape)
192
+ return Rot3Array(
193
+ **{
194
+ name: reshape_fn(getattr(self, name))
195
+ for name in field_names
196
+ }
197
+ )
198
+
199
+ @classmethod
200
+ def cat(cls, rots: List[Rot3Array], dim: int) -> Rot3Array:
201
+ field_names = utils.get_field_names(Rot3Array)
202
+ cat_fn = lambda l: torch.cat(l, dim=dim)
203
+ return cls(
204
+ **{
205
+ name: cat_fn([getattr(r, name) for r in rots])
206
+ for name in field_names
207
+ }
208
+ )
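A standalone sanity check of the `from_quaternion` arithmetic: any normalized quaternion should yield an orthonormal matrix, and the identity quaternion the identity matrix (the component formulas below repeat those above):

```python
import torch

def quat_to_mat(w, x, y, z):
    # Same component formulas as Rot3Array.from_quaternion.
    n = torch.rsqrt(w * w + x * x + y * y + z * z)
    w, x, y, z = w * n, x * n, y * n, z * n
    return torch.stack([
        torch.stack([1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)]),
        torch.stack([2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)]),
        torch.stack([2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)]),
    ])

R = quat_to_mat(*torch.randn(4))
print(torch.allclose(R @ R.T, torch.eye(3), atol=1e-6))  # True
print(quat_to_mat(*torch.tensor([1., 0., 0., 0.])))      # identity matrix
```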
dockformer/utils/geometry/test_utils.py ADDED
@@ -0,0 +1,97 @@
1
+ # Copyright 2021 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Shared utils for tests."""
15
+
16
+ import dataclasses
17
+ import torch
18
+
19
+ from dockformer.utils.geometry import rigid_matrix_vector
20
+ from dockformer.utils.geometry import rotation_matrix
21
+ from dockformer.utils.geometry import vector
22
+
23
+
24
+ def assert_rotation_matrix_equal(matrix1: rotation_matrix.Rot3Array,
25
+ matrix2: rotation_matrix.Rot3Array):
26
+ for field in dataclasses.fields(rotation_matrix.Rot3Array):
27
+ field = field.name
28
+ assert torch.equal(
29
+ getattr(matrix1, field), getattr(matrix2, field))
30
+
31
+
32
+ def assert_rotation_matrix_close(mat1: rotation_matrix.Rot3Array,
33
+ mat2: rotation_matrix.Rot3Array):
34
+ assert torch.allclose(mat1.to_tensor(), mat2.to_tensor(), atol=1e-6)
35
+
36
+
37
+ def assert_array_equal_to_rotation_matrix(array: torch.Tensor,
38
+ matrix: rotation_matrix.Rot3Array):
39
+ """Check that array and Matrix match."""
40
+ assert torch.equal(matrix.xx, array[..., 0, 0])
41
+ assert torch.equal(matrix.xy, array[..., 0, 1])
42
+ assert torch.equal(matrix.xz, array[..., 0, 2])
43
+ assert torch.equal(matrix.yx, array[..., 1, 0])
44
+ assert torch.equal(matrix.yy, array[..., 1, 1])
45
+ assert torch.equal(matrix.yz, array[..., 1, 2])
46
+ assert torch.equal(matrix.zx, array[..., 2, 0])
47
+ assert torch.equal(matrix.zy, array[..., 2, 1])
48
+ assert torch.equal(matrix.zz, array[..., 2, 2])
49
+
50
+
51
+ def assert_array_close_to_rotation_matrix(array: torch.Tensor,
52
+ matrix: rotation_matrix.Rot3Array):
53
+ assert torch.allclose(matrix.to_tensor(), array, atol=1e-6)
54
+
55
+
56
+ def assert_vectors_equal(vec1: vector.Vec3Array, vec2: vector.Vec3Array):
57
+ assert torch.equal(vec1.x, vec2.x)
58
+ assert torch.equal(vec1.y, vec2.y)
59
+ assert torch.equal(vec1.z, vec2.z)
60
+
61
+
62
+ def assert_vectors_close(vec1: vector.Vec3Array, vec2: vector.Vec3Array):
63
+ assert torch.allclose(vec1.x, vec2.x, atol=1e-6, rtol=0.)
64
+ assert torch.allclose(vec1.y, vec2.y, atol=1e-6, rtol=0.)
65
+ assert torch.allclose(vec1.z, vec2.z, atol=1e-6, rtol=0.)
66
+
67
+
68
+ def assert_array_close_to_vector(array: torch.Tensor, vec: vector.Vec3Array):
69
+ assert torch.allclose(vec.to_tensor(), array, atol=1e-6, rtol=0.)
70
+
71
+
72
+ def assert_array_equal_to_vector(array: torch.Tensor, vec: vector.Vec3Array):
73
+ assert torch.equal(vec.to_tensor(), array)
74
+
75
+
76
+ def assert_rigid_equal_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array,
77
+ rigid2: rigid_matrix_vector.Rigid3Array):
78
+ assert_rot_trans_equal_to_rigid(rigid1.rotation, rigid1.translation, rigid2)
79
+
80
+
81
+ def assert_rigid_close_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array,
82
+ rigid2: rigid_matrix_vector.Rigid3Array):
83
+ assert_rot_trans_close_to_rigid(rigid1.rotation, rigid1.translation, rigid2)
84
+
85
+
86
+ def assert_rot_trans_equal_to_rigid(rot: rotation_matrix.Rot3Array,
87
+ trans: vector.Vec3Array,
88
+ rigid: rigid_matrix_vector.Rigid3Array):
89
+ assert_rotation_matrix_equal(rot, rigid.rotation)
90
+ assert_vectors_equal(trans, rigid.translation)
91
+
92
+
93
+ def assert_rot_trans_close_to_rigid(rot: rotation_matrix.Rot3Array,
94
+ trans: vector.Vec3Array,
95
+ rigid: rigid_matrix_vector.Rigid3Array):
96
+ assert_rotation_matrix_close(rot, rigid.rotation)
97
+ assert_vectors_close(trans, rigid.translation)
dockformer/utils/geometry/utils.py ADDED
@@ -0,0 +1,22 @@
1
+ # Copyright 2021 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Utils for geometry library."""
15
+
16
+ import dataclasses
17
+
18
+
19
+ def get_field_names(cls):
20
+ fields = dataclasses.fields(cls)
21
+ field_names = [f.name for f in fields]
22
+ return field_names
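A one-line illustration with a hypothetical dataclass (assumes the dockformer package is importable):

```python
import dataclasses
from dockformer.utils.geometry.utils import get_field_names

@dataclasses.dataclass
class Point:
    x: float
    y: float

print(get_field_names(Point))  # ['x', 'y']
```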
dockformer/utils/geometry/vector.py ADDED
@@ -0,0 +1,261 @@
1
+ # Copyright 2021 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Vec3Array Class."""
15
+
16
+ from __future__ import annotations
17
+ import dataclasses
18
+ from typing import Union, List
19
+
20
+ import torch
21
+
22
+ Float = Union[float, torch.Tensor]
23
+
24
+ @dataclasses.dataclass(frozen=True)
25
+ class Vec3Array:
26
+ x: torch.Tensor = dataclasses.field(metadata={'dtype': torch.float32})
27
+ y: torch.Tensor
28
+ z: torch.Tensor
29
+
30
+ def __post_init__(self):
31
+ if hasattr(self.x, 'dtype'):
32
+ assert self.x.dtype == self.y.dtype
33
+ assert self.x.dtype == self.z.dtype
34
+ assert all([x == y for x, y in zip(self.x.shape, self.y.shape)])
35
+ assert all([x == z for x, z in zip(self.x.shape, self.z.shape)])
36
+
37
+ def __add__(self, other: Vec3Array) -> Vec3Array:
38
+ return Vec3Array(
39
+ self.x + other.x,
40
+ self.y + other.y,
41
+ self.z + other.z,
42
+ )
43
+
44
+ def __sub__(self, other: Vec3Array) -> Vec3Array:
45
+ return Vec3Array(
46
+ self.x - other.x,
47
+ self.y - other.y,
48
+ self.z - other.z,
49
+ )
50
+
51
+ def __mul__(self, other: Float) -> Vec3Array:
52
+ return Vec3Array(
53
+ self.x * other,
54
+ self.y * other,
55
+ self.z * other,
56
+ )
57
+
58
+ def __rmul__(self, other: Float) -> Vec3Array:
59
+ return self * other
60
+
61
+ def __truediv__(self, other: Float) -> Vec3Array:
62
+ return Vec3Array(
63
+ self.x / other,
64
+ self.y / other,
65
+ self.z / other,
66
+ )
67
+
68
+ def __neg__(self) -> Vec3Array:
69
+ return self * -1
70
+
71
+ def __pos__(self) -> Vec3Array:
72
+ return self * 1
73
+
74
+ def __getitem__(self, index) -> Vec3Array:
75
+ return Vec3Array(
76
+ self.x[index],
77
+ self.y[index],
78
+ self.z[index],
79
+ )
80
+
81
+ def __iter__(self):
82
+ return iter((self.x, self.y, self.z))
83
+
84
+ @property
85
+ def shape(self):
86
+ return self.x.shape
87
+
88
+ def map_tensor_fn(self, fn) -> Vec3Array:
89
+ return Vec3Array(
90
+ fn(self.x),
91
+ fn(self.y),
92
+ fn(self.z),
93
+ )
94
+
95
+ def cross(self, other: Vec3Array) -> Vec3Array:
96
+ """Compute cross product between 'self' and 'other'."""
97
+ new_x = self.y * other.z - self.z * other.y
98
+ new_y = self.z * other.x - self.x * other.z
99
+ new_z = self.x * other.y - self.y * other.x
100
+ return Vec3Array(new_x, new_y, new_z)
101
+
102
+ def dot(self, other: Vec3Array) -> Float:
103
+ """Compute dot product between 'self' and 'other'."""
104
+ return self.x * other.x + self.y * other.y + self.z * other.z
105
+
106
+ def norm(self, epsilon: float = 1e-6) -> Float:
107
+ """Compute Norm of Vec3Array, clipped to epsilon."""
108
+ # To avoid NaN on the backward pass, we must use maximum before the sqrt
109
+ norm2 = self.dot(self)
110
+ if epsilon:
111
+ norm2 = torch.clamp(norm2, min=epsilon**2)
112
+ return torch.sqrt(norm2)
113
+
114
+ def norm2(self):
115
+ return self.dot(self)
116
+
117
+ def normalized(self, epsilon: float = 1e-6) -> Vec3Array:
118
+ """Return unit vector with optional clipping."""
119
+ return self / self.norm(epsilon)
120
+
121
+ def clone(self) -> Vec3Array:
122
+ return Vec3Array(
123
+ self.x.clone(),
124
+ self.y.clone(),
125
+ self.z.clone(),
126
+ )
127
+
128
+ def reshape(self, new_shape) -> Vec3Array:
129
+ x = self.x.reshape(new_shape)
130
+ y = self.y.reshape(new_shape)
131
+ z = self.z.reshape(new_shape)
132
+
133
+ return Vec3Array(x, y, z)
134
+
135
+ def sum(self, dim: int) -> Vec3Array:
136
+ return Vec3Array(
137
+ torch.sum(self.x, dim=dim),
138
+ torch.sum(self.y, dim=dim),
139
+ torch.sum(self.z, dim=dim),
140
+ )
141
+
142
+ def unsqueeze(self, dim: int):
143
+ return Vec3Array(
144
+ self.x.unsqueeze(dim),
145
+ self.y.unsqueeze(dim),
146
+ self.z.unsqueeze(dim),
147
+ )
148
+
149
+ @classmethod
150
+ def zeros(cls, shape, device="cpu"):
151
+ """Return Vec3Array corresponding to zeros of given shape."""
152
+ return cls(
153
+ torch.zeros(shape, dtype=torch.float32, device=device),
154
+ torch.zeros(shape, dtype=torch.float32, device=device),
155
+ torch.zeros(shape, dtype=torch.float32, device=device)
156
+ )
157
+
158
+ def to_tensor(self) -> torch.Tensor:
159
+ return torch.stack([self.x, self.y, self.z], dim=-1)
160
+
161
+ @classmethod
162
+ def from_array(cls, tensor):
163
+ return cls(*torch.unbind(tensor, dim=-1))
164
+
165
+ @classmethod
166
+ def cat(cls, vecs: List[Vec3Array], dim: int) -> Vec3Array:
167
+ return cls(
168
+ torch.cat([v.x for v in vecs], dim=dim),
169
+ torch.cat([v.y for v in vecs], dim=dim),
170
+ torch.cat([v.z for v in vecs], dim=dim),
171
+ )
172
+
173
+
174
+ def square_euclidean_distance(
175
+ vec1: Vec3Array,
176
+ vec2: Vec3Array,
177
+ epsilon: float = 1e-6
178
+ ) -> Float:
179
+ """Computes square of euclidean distance between 'vec1' and 'vec2'.
180
+
181
+ Args:
182
+ vec1: Vec3Array to compute distance to
183
+ vec2: Vec3Array to compute distance from, should be
184
+ broadcast compatible with 'vec1'
185
+ epsilon: distance is clipped from below to be at least epsilon
186
+
187
+ Returns:
188
+ Array of square euclidean distances;
189
+ shape will be result of broadcasting 'vec1' and 'vec2'
190
+ """
191
+ difference = vec1 - vec2
192
+ distance = difference.dot(difference)
193
+ if epsilon:
194
+ distance = torch.clamp(distance, min=epsilon)
195
+ return distance
196
+
197
+
198
+ def dot(vector1: Vec3Array, vector2: Vec3Array) -> Float:
199
+ return vector1.dot(vector2)
200
+
201
+
202
+ def cross(vector1: Vec3Array, vector2: Vec3Array) -> Vec3Array:
203
+ return vector1.cross(vector2)
204
+
205
+
206
+ def norm(vector: Vec3Array, epsilon: float = 1e-6) -> Float:
207
+ return vector.norm(epsilon)
208
+
209
+
210
+ def normalized(vector: Vec3Array, epsilon: float = 1e-6) -> Vec3Array:
211
+ return vector.normalized(epsilon)
212
+
213
+
214
+ def euclidean_distance(
215
+ vec1: Vec3Array,
216
+ vec2: Vec3Array,
217
+ epsilon: float = 1e-6
218
+ ) -> Float:
219
+ """Computes euclidean distance between 'vec1' and 'vec2'.
220
+
221
+ Args:
222
+ vec1: Vec3Array to compute euclidean distance to
223
+ vec2: Vec3Array to compute euclidean distance from, should be
224
+ broadcast compatible with 'vec1'
225
+ epsilon: distance is clipped from below to be at least epsilon
226
+
227
+ Returns:
228
+ Array of euclidean distances;
229
+ shape will be result of broadcasting 'vec1' and 'vec2'
230
+ """
231
+ distance_sq = square_euclidean_distance(vec1, vec2, epsilon**2)
232
+ distance = torch.sqrt(distance_sq)
233
+ return distance
234
+
235
+
236
+ def dihedral_angle(a: Vec3Array, b: Vec3Array, c: Vec3Array,
237
+ d: Vec3Array) -> Float:
238
+ """Computes torsion angle for a quadruple of points.
239
+
240
+ For points (a, b, c, d), this is the angle between the planes defined by
241
+ points (a, b, c) and (b, c, d). It is also known as the dihedral angle.
242
+
243
+ Arguments:
244
+ a: A Vec3Array of coordinates.
245
+ b: A Vec3Array of coordinates.
246
+ c: A Vec3Array of coordinates.
247
+ d: A Vec3Array of coordinates.
248
+
249
+ Returns:
250
+ A tensor of angles in radians: [-pi, pi].
251
+ """
252
+ v1 = a - b
253
+ v2 = b - c
254
+ v3 = d - c
255
+
256
+ c1 = v1.cross(v2)
257
+ c2 = v3.cross(v2)
258
+ c3 = c2.cross(c1)
259
+
260
+ v2_mag = v2.norm()
261
+ return torch.atan2(c3.dot(v2), v2_mag * c1.dot(c2))
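A numeric sanity check of `dihedral_angle` (assuming the dockformer package is importable): with the central bond b-c along x and a, d in the xy plane, an eclipsed (cis) arrangement should give ~0 and an anti (trans) arrangement ~pi:

```python
import torch
from dockformer.utils.geometry.vector import Vec3Array, dihedral_angle

def v(x, y, z):
    return Vec3Array(torch.tensor(x), torch.tensor(y), torch.tensor(z))

a, b, c = v(0., 1., 0.), v(0., 0., 0.), v(1., 0., 0.)
print(dihedral_angle(a, b, c, v(1., 1., 0.)))   # ~0   (cis / eclipsed)
print(dihedral_angle(a, b, c, v(1., -1., 0.)))  # ~pi  (trans / anti)
```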
dockformer/utils/kernel/__init__.py ADDED
File without changes
dockformer/utils/kernel/attention_core.py ADDED
@@ -0,0 +1,107 @@
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import importlib
15
+ from functools import reduce
16
+ from operator import mul
17
+
18
+ import torch
19
+
20
+ # TODO bshor: resolve attn_core_is_installed detection on macOS
21
+ attn_core_is_installed = importlib.util.find_spec("attn_core_inplace_cuda") is not None
22
+ attn_core_inplace_cuda = None
23
+ if attn_core_is_installed:
24
+ attn_core_inplace_cuda = importlib.import_module("attn_core_inplace_cuda")
25
+
26
+
27
+ SUPPORTED_DTYPES = [torch.float32, torch.bfloat16]
28
+
29
+
30
+ class AttentionCoreFunction(torch.autograd.Function):
31
+ @staticmethod
32
+ def forward(ctx, q, k, v, bias_1=None, bias_2=None):
33
+ if(bias_1 is None and bias_2 is not None):
34
+ raise ValueError("bias_1 must be specified before bias_2")
35
+ if(q.dtype not in SUPPORTED_DTYPES):
36
+ raise ValueError("Unsupported datatype")
37
+
38
+ q = q.contiguous()
39
+ k = k.contiguous()
40
+
41
+ # [*, H, Q, K]
42
+ attention_logits = torch.matmul(
43
+ q, k.transpose(-1, -2),
44
+ )
45
+
46
+ if(bias_1 is not None):
47
+ attention_logits += bias_1
48
+ if(bias_2 is not None):
49
+ attention_logits += bias_2
50
+
51
+ attn_core_inplace_cuda.forward_(
52
+ attention_logits,
53
+ reduce(mul, attention_logits.shape[:-1]),
54
+ attention_logits.shape[-1],
55
+ )
56
+
57
+ o = torch.matmul(attention_logits, v)
58
+
59
+ ctx.bias_1_shape = bias_1.shape if bias_1 is not None else None
60
+ ctx.bias_2_shape = bias_2.shape if bias_2 is not None else None
61
+ ctx.save_for_backward(q, k, v, attention_logits)
62
+
63
+ return o
64
+
65
+ @staticmethod
66
+ def backward(ctx, grad_output):
67
+ q, k, v, attention_logits = ctx.saved_tensors
68
+ grad_q = grad_k = grad_v = grad_bias_1 = grad_bias_2 = None
69
+
70
+ grad_v = torch.matmul(
71
+ attention_logits.transpose(-1, -2),
72
+ grad_output
73
+ )
74
+
75
+ attn_core_inplace_cuda.backward_(
76
+ attention_logits,
77
+ grad_output.contiguous(),
78
+ v.contiguous(), # v is implicitly transposed in the kernel
79
+ reduce(mul, attention_logits.shape[:-1]),
80
+ attention_logits.shape[-1],
81
+ grad_output.shape[-1],
82
+ )
83
+
84
+ if(ctx.bias_1_shape is not None):
85
+ grad_bias_1 = torch.sum(
86
+ attention_logits,
87
+ dim=tuple(i for i,d in enumerate(ctx.bias_1_shape) if d == 1),
88
+ keepdim=True,
89
+ )
90
+
91
+ if(ctx.bias_2_shape is not None):
92
+ grad_bias_2 = torch.sum(
93
+ attention_logits,
94
+ dim=tuple(i for i,d in enumerate(ctx.bias_2_shape) if d == 1),
95
+ keepdim=True,
96
+ )
97
+
98
+ grad_q = torch.matmul(
99
+ attention_logits, k
100
+ )
101
+ grad_k = torch.matmul(
102
+ q.transpose(-1, -2), attention_logits,
103
+ ).transpose(-1, -2)
104
+
105
+ return grad_q, grad_k, grad_v, grad_bias_1, grad_bias_2
106
+
107
+ attention_core = AttentionCoreFunction.apply
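For machines where the CUDA extension is unavailable, the forward semantics of `attention_core` reduce to ordinary biased attention. A pure-torch reference sketch (note the kernel itself applies no 1/sqrt(d) scaling; that is left to the caller):

```python
import torch

def attention_core_reference(q, k, v, bias_1=None, bias_2=None):
    # softmax(q @ k^T + biases) @ v, mirroring AttentionCoreFunction.forward.
    logits = torch.matmul(q, k.transpose(-1, -2))
    if bias_1 is not None:
        logits = logits + bias_1
    if bias_2 is not None:
        logits = logits + bias_2
    return torch.matmul(torch.softmax(logits, dim=-1), v)

q = torch.randn(1, 4, 16, 8)  # [*, H, Q, C_hidden]
k = torch.randn(1, 4, 16, 8)  # [*, H, K, C_hidden]
v = torch.randn(1, 4, 16, 8)  # [*, H, K, C_hidden]
print(attention_core_reference(q, k, v).shape)  # torch.Size([1, 4, 16, 8])
```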
dockformer/utils/kernel/csrc/compat.h ADDED
@@ -0,0 +1,11 @@
1
+ // modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h
2
+
3
+ #ifndef TORCH_CHECK
4
+ #define TORCH_CHECK AT_CHECK
5
+ #endif
6
+
7
+ #ifdef VERSION_GE_1_3
8
+ #define DATA_PTR data_ptr
9
+ #else
10
+ #define DATA_PTR data
11
+ #endif
dockformer/utils/kernel/csrc/softmax_cuda.cpp ADDED
@@ -0,0 +1,44 @@
1
+ // Copyright 2021 AlQuraishi Laboratory
2
+ //
3
+ // Licensed under the Apache License, Version 2.0 (the "License");
4
+ // you may not use this file except in compliance with the License.
5
+ // You may obtain a copy of the License at
6
+ //
7
+ // http://www.apache.org/licenses/LICENSE-2.0
8
+ //
9
+ // Unless required by applicable law or agreed to in writing, software
10
+ // distributed under the License is distributed on an "AS IS" BASIS,
11
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ // See the License for the specific language governing permissions and
13
+ // limitations under the License.
14
+
15
+ // modified from fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda.cpp
16
+
17
+ #include <torch/extension.h>
18
+
19
+ void attn_softmax_inplace_forward_(
20
+ at::Tensor input,
21
+ long long rows, int cols
22
+ );
23
+ void attn_softmax_inplace_backward_(
24
+ at::Tensor output,
25
+ at::Tensor d_ov,
26
+ at::Tensor values,
27
+ long long rows,
28
+ int cols_output,
29
+ int cols_values
30
+ );
31
+
32
+
33
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
34
+ m.def(
35
+ "forward_",
36
+ &attn_softmax_inplace_forward_,
37
+ "Softmax forward (CUDA)"
38
+ );
39
+ m.def(
40
+ "backward_",
41
+ &attn_softmax_inplace_backward_,
42
+ "Softmax backward (CUDA)"
43
+ );
44
+ }
dockformer/utils/kernel/csrc/softmax_cuda_kernel.cu ADDED
@@ -0,0 +1,241 @@
1
+ // Copyright 2021 AlQuraishi Laboratory
2
+ //
3
+ // Licensed under the Apache License, Version 2.0 (the "License");
4
+ // you may not use this file except in compliance with the License.
5
+ // You may obtain a copy of the License at
6
+ //
7
+ // http://www.apache.org/licenses/LICENSE-2.0
8
+ //
9
+ // Unless required by applicable law or agreed to in writing, software
10
+ // distributed under the License is distributed on an "AS IS" BASIS,
11
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ // See the License for the specific language governing permissions and
13
+ // limitations under the License.
14
+
15
+ // modified from fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda_kernel.cu
16
+
17
+ #include <math_constants.h>
18
+ #include <torch/extension.h>
19
+ #include <c10/cuda/CUDAGuard.h>
20
+
21
+ #include <iostream>
22
+
23
+ #include "ATen/ATen.h"
24
+ #include "ATen/cuda/CUDAContext.h"
25
+ #include "compat.h"
26
+
27
+ #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
28
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
29
+ #define CHECK_INPUT(x) \
30
+ CHECK_CUDA(x); \
31
+ CHECK_CONTIGUOUS(x)
32
+
33
+ __inline__ __device__ float WarpAllReduceMax(float val) {
34
+ for (int mask = 1; mask < 32; mask *= 2) {
35
+ val = max(val, __shfl_xor_sync(0xffffffff, val, mask));
36
+ }
37
+ return val;
38
+ }
39
+
40
+ __inline__ __device__ float WarpAllReduceSum(float val) {
41
+ for (int mask = 1; mask < 32; mask *= 2) {
42
+ val += __shfl_xor_sync(0xffffffff, val, mask);
43
+ }
44
+ return val;
45
+ }
46
+
47
+
48
+ template<typename T>
49
+ __global__ void attn_softmax_inplace_(
50
+ T *input,
51
+ long long rows, int cols
52
+ ) {
53
+ int threadidx_x = threadIdx.x / 32;
54
+ int threadidx_y = threadIdx.x % 32;
55
+ long long row_offset = (long long)(blockIdx.x * 4 + threadidx_x);
56
+ int cols_per_thread = (cols + 31) / 32;
57
+ int cols_this_thread = cols_per_thread;
58
+
59
+ int last_y = (cols / cols_per_thread);
60
+
61
+ if (threadidx_y == last_y) {
62
+ cols_this_thread = cols - cols_per_thread * last_y;
63
+ }
64
+ else if (threadidx_y > last_y) {
65
+ cols_this_thread = 0;
66
+ }
67
+
68
+ float buf[32];
69
+
70
+ int lane_id = threadidx_y;
71
+
72
+ if (row_offset < rows) {
73
+ T *row_input = input + row_offset * cols;
74
+ T *row_output = row_input;
75
+
76
+ #pragma unroll
77
+ for (int i = 0; i < cols_this_thread; i++) {
78
+ int idx = lane_id * cols_per_thread + i;
79
+ buf[i] = static_cast<float>(row_input[idx]);
80
+ }
81
+
82
+ float thread_max = -1 * CUDART_INF_F;
83
+ #pragma unroll
84
+ for (int i = 0; i < cols_this_thread; i++) {
85
+ thread_max = max(thread_max, buf[i]);
86
+ }
87
+
88
+ float warp_max = WarpAllReduceMax(thread_max);
89
+
90
+ float thread_sum = 0.f;
91
+ #pragma unroll
92
+ for (int i = 0; i < cols_this_thread; i++) {
93
+ buf[i] = __expf(buf[i] - warp_max);
94
+ thread_sum += buf[i];
95
+ }
96
+
97
+ float warp_sum = WarpAllReduceSum(thread_sum);
98
+ #pragma unroll
99
+ for (int i = 0; i < cols_this_thread; i++) {
100
+ row_output[lane_id * cols_per_thread + i] =
101
+ static_cast<T>(__fdividef(buf[i], warp_sum));
102
+ }
103
+ }
104
+ }
105
+
106
+
107
+ void attn_softmax_inplace_forward_(
108
+ at::Tensor input,
109
+ long long rows, int cols
110
+ ) {
111
+ CHECK_INPUT(input);
112
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
113
+
114
+ int grid = (rows + 3) / 4;
115
+ dim3 block(128);
116
+
117
+ if (input.dtype() == torch::kFloat32) {
118
+ attn_softmax_inplace_<float><<<grid, block>>>(
119
+ (float *)input.data_ptr(),
120
+ rows, cols
121
+ );
122
+ }
123
+ else {
124
+ attn_softmax_inplace_<at::BFloat16><<<grid, block>>>(
125
+ (at::BFloat16 *)input.data_ptr(),
126
+ rows, cols
127
+ );
128
+ }
129
+ }
130
+
131
+
132
+ template<typename T>
133
+ __global__ void attn_softmax_inplace_grad_(
134
+ T *output,
135
+ T *d_ov,
136
+ T *values,
137
+ long long rows,
138
+ int cols_output,
139
+ int cols_values
140
+ ) {
141
+ int threadidx_x = threadIdx.x / 32;
142
+ int threadidx_y = threadIdx.x % 32;
143
+ long long row_offset = (long long)(blockIdx.x * 4 + threadidx_x);
144
+ int cols_per_thread = (cols_output + 31) / 32;
145
+ int cols_this_thread = cols_per_thread;
146
+ int rows_values = cols_output;
147
+ // values are set to the beginning of the current
148
+ // rows_values x cols_values leaf matrix
149
+ long long value_row_offset = row_offset - row_offset % rows_values;
150
+ int last_y = (cols_output / cols_per_thread);
151
+
152
+ if (threadidx_y == last_y) {
153
+ cols_this_thread = cols_output - cols_per_thread * last_y;
154
+ }
155
+ else if (threadidx_y > last_y) {
156
+ cols_this_thread = 0;
157
+ }
158
+
159
+ float y_buf[32];
160
+ float dy_buf[32];
161
+
162
+ int lane_id = threadidx_y;
163
+
164
+ if (row_offset < rows) {
165
+ T *row_output = output + row_offset * cols_output;
166
+ T *row_d_ov = d_ov + row_offset * cols_values;
167
+ T *row_values = values + value_row_offset * cols_values;
168
+
169
+ float thread_max = -1 * CUDART_INF_F;
170
+
171
+ // Compute a chunk of the output gradient on the fly
172
+ int value_row_idx = 0;
173
+ int value_idx = 0;
174
+ #pragma unroll
175
+ for (int i = 0; i < cols_this_thread; i++) {
176
+ T sum = 0.;
177
+ #pragma unroll
178
+ for (int j = 0; j < cols_values; j++) {
179
+ value_row_idx = ((lane_id * cols_per_thread) + i);
180
+ value_idx = value_row_idx * cols_values + j;
181
+ sum += row_d_ov[j] * row_values[value_idx];
182
+ }
183
+ dy_buf[i] = static_cast<float>(sum);
184
+ }
185
+
186
+ #pragma unroll
187
+ for (int i = 0; i < cols_this_thread; i++) {
188
+ y_buf[i] = static_cast<float>(row_output[lane_id * cols_per_thread + i]);
189
+ }
190
+
191
+ float thread_sum = 0.;
192
+
193
+ #pragma unroll
194
+ for (int i = 0; i < cols_this_thread; i++) {
195
+ thread_sum += y_buf[i] * dy_buf[i];
196
+ }
197
+
198
+ float warp_sum = WarpAllReduceSum(thread_sum);
199
+
200
+ #pragma unroll
201
+ for (int i = 0; i < cols_this_thread; i++) {
202
+ row_output[lane_id * cols_per_thread + i] = static_cast<T>(
203
+ (dy_buf[i] - warp_sum) * y_buf[i]
204
+ );
205
+ }
206
+ }
207
+ }
208
+
209
+
210
+ void attn_softmax_inplace_backward_(
211
+ at::Tensor output,
212
+ at::Tensor d_ov,
213
+ at::Tensor values,
214
+ long long rows,
215
+ int cols_output,
216
+ int cols_values
217
+ ) {
218
+ CHECK_INPUT(output);
219
+ CHECK_INPUT(d_ov);
220
+ CHECK_INPUT(values);
221
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
222
+
223
+ int grid = (rows + 3) / 4;
224
+ dim3 block(128);
225
+
226
+ if (output.dtype() == torch::kFloat32) {
227
+ attn_softmax_inplace_grad_<float><<<grid, block>>>(
228
+ (float *)output.data_ptr(),
229
+ (float *)d_ov.data_ptr(),
230
+ (float *)values.data_ptr(),
231
+ rows, cols_output, cols_values
232
+ );
233
+ } else {
234
+ attn_softmax_inplace_grad_<at::BFloat16><<<grid, block>>>(
235
+ (at::BFloat16 *)output.data_ptr(),
236
+ (at::BFloat16 *)d_ov.data_ptr(),
237
+ (at::BFloat16 *)values.data_ptr(),
238
+ rows, cols_output, cols_values
239
+ );
240
+ }
241
+ }
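The backward kernel applies the standard softmax gradient identity row by row, writing the result in place over the stored probabilities. A quick autograd cross-check of that identity in torch:

```python
import torch

x = torch.randn(4, 8, requires_grad=True)
y = torch.softmax(x, dim=-1)
dy = torch.randn_like(y)  # upstream gradient
y.backward(dy)

# Identity used in attn_softmax_inplace_grad_:
#   dL/dx = y * (dy - sum_j y_j * dy_j)
manual = y.detach() * (dy - (y.detach() * dy).sum(dim=-1, keepdim=True))
print(torch.allclose(x.grad, manual, atol=1e-6))  # True
```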
dockformer/utils/kernel/csrc/softmax_cuda_stub.cpp ADDED
@@ -0,0 +1,36 @@
1
+ // Copyright 2021 AlQuraishi Laboratory
2
+ //
3
+ // Licensed under the Apache License, Version 2.0 (the "License");
4
+ // you may not use this file except in compliance with the License.
5
+ // You may obtain a copy of the License at
6
+ //
7
+ // http://www.apache.org/licenses/LICENSE-2.0
8
+ //
9
+ // Unless required by applicable law or agreed to in writing, software
10
+ // distributed under the License is distributed on an "AS IS" BASIS,
11
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ // See the License for the specific language governing permissions and
13
+ // limitations under the License.
14
+
15
+ // modified from fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda.cpp
16
+
17
+ #include <torch/extension.h>
18
+
19
+ void attn_softmax_inplace_forward_(
20
+ at::Tensor input,
21
+ long long rows, int cols
22
+ )
23
+ {
24
+ throw std::runtime_error("attn_softmax_inplace_forward_ not implemented on CPU");
25
+ };
26
+ void attn_softmax_inplace_backward_(
27
+ at::Tensor output,
28
+ at::Tensor d_ov,
29
+ at::Tensor values,
30
+ long long rows,
31
+ int cols_output,
32
+ int cols_values
33
+ )
34
+ {
35
+ throw std::runtime_error("attn_softmax_inplace_backward_ not implemented on CPU");
36
+ };
dockformer/utils/logger.py ADDED
@@ -0,0 +1,80 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import operator
16
+ import time
17
+
18
+ import dllogger as logger
19
+ from dllogger import JSONStreamBackend, StdOutBackend, Verbosity
20
+ import numpy as np
21
+ from lightning import Callback
22
+ import torch.cuda.profiler as profiler
23
+
24
+
25
+ def is_main_process():
26
+ return int(os.getenv("LOCAL_RANK", "0")) == 0
27
+
28
+
29
+ class PerformanceLoggingCallback(Callback):
30
+ def __init__(self, log_file, global_batch_size, warmup_steps: int = 0, profile: bool = False):
31
+ logger.init(backends=[JSONStreamBackend(Verbosity.VERBOSE, log_file), StdOutBackend(Verbosity.VERBOSE)])
32
+ self.warmup_steps = warmup_steps
33
+ self.global_batch_size = global_batch_size
34
+ self.step = 0
35
+ self.profile = profile
36
+ self.timestamps = []
37
+
38
+ def do_step(self):
39
+ self.step += 1
40
+ if self.profile and self.step == self.warmup_steps:
41
+ profiler.start()
42
+ if self.step > self.warmup_steps:
43
+ self.timestamps.append(time.time())
44
+
45
+ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
46
+ self.do_step()
47
+
48
+ def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx: int = 0):
49
+ self.do_step()
50
+
51
+ def process_performance_stats(self, deltas):
52
+ def _round3(val):
53
+ return round(val, 3)
54
+
55
+ throughput_imgps = _round3(self.global_batch_size / np.mean(deltas))
56
+ timestamps_ms = 1000 * deltas
57
+ stats = {
58
+ f"throughput": throughput_imgps,
59
+ f"latency_mean": _round3(timestamps_ms.mean()),
60
+ }
61
+ for level in [90, 95, 99]:
62
+ stats.update({f"latency_{level}": _round3(np.percentile(timestamps_ms, level))})
63
+
64
+ return stats
65
+
66
+ def _log(self):
67
+ if is_main_process():
68
+ diffs = list(map(operator.sub, self.timestamps[1:], self.timestamps[:-1]))
69
+ deltas = np.array(diffs)
70
+ stats = self.process_performance_stats(deltas)
71
+ logger.log(step=(), data=stats)
72
+ logger.flush()
73
+
74
+ def on_train_end(self, trainer, pl_module):
75
+ if self.profile:
76
+ profiler.stop()
77
+ self._log()
78
+
79
+ def on_epoch_end(self, trainer, pl_module):
80
+ self._log()
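A hypothetical wiring sketch for the callback (the model and datamodule names are placeholders; assumes lightning and dllogger are installed):

```python
import lightning as L
from dockformer.utils.logger import PerformanceLoggingCallback

perf = PerformanceLoggingCallback(
    log_file="perf.json", global_batch_size=8, warmup_steps=10
)
trainer = L.Trainer(max_epochs=1, callbacks=[perf])
# trainer.fit(model, datamodule=dm)  # model / dm defined elsewhere
```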
dockformer/utils/loss.py ADDED
@@ -0,0 +1,1171 @@
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import time
16
+
17
+ import ml_collections
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn as nn
21
+ from typing import Dict, Optional, Tuple
22
+
23
+ from dockformer.utils import residue_constants
24
+ from dockformer.utils.feats import pseudo_beta_fn
25
+ from dockformer.utils.rigid_utils import Rotation, Rigid
26
+ from dockformer.utils.geometry.vector import Vec3Array, euclidean_distance
27
+ from dockformer.utils.tensor_utils import (
28
+ tree_map,
29
+ masked_mean,
30
+ permute_final_dims,
31
+ )
32
+ import logging
33
+ from dockformer.utils.tensor_utils import tensor_tree_map
34
+
35
+ logger = logging.getLogger(__name__)
36
+
37
+
38
+ def softmax_cross_entropy(logits, labels):
39
+ loss = -1 * torch.sum(
40
+ labels * torch.nn.functional.log_softmax(logits, dim=-1),
41
+ dim=-1,
42
+ )
43
+ return loss
44
+
45
+
46
+ def sigmoid_cross_entropy(logits, labels):
47
+ logits_dtype = logits.dtype
48
+ try:
49
+ logits = logits.double()
50
+ labels = labels.double()
51
+ except Exception:
52
+ logits = logits.to(dtype=torch.float32)
53
+ labels = labels.to(dtype=torch.float32)
54
+
55
+ log_p = torch.nn.functional.logsigmoid(logits)
56
+ # log_p = torch.log(torch.sigmoid(logits))
57
+ log_not_p = torch.nn.functional.logsigmoid(-1 * logits)
58
+ # log_not_p = torch.log(torch.sigmoid(-logits))
59
+ loss = (-1. * labels) * log_p - (1. - labels) * log_not_p
60
+ loss = loss.to(dtype=logits_dtype)
61
+ return loss
62
+
63
+
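A minimal sanity-check sketch for the two helpers above: on one-hot labels, softmax_cross_entropy should agree with torch's built-in cross-entropy (the tensor names here are illustrative only).

import torch
logits = torch.randn(4, 10)
targets = torch.arange(4)
labels = torch.nn.functional.one_hot(targets, 10).float()
ours = softmax_cross_entropy(logits, labels)                               # [4]
ref = torch.nn.functional.cross_entropy(logits, targets, reduction="none")
assert torch.allclose(ours, ref, atol=1e-6)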
64
+ def torsion_angle_loss(
65
+ a, # [*, N, 7, 2]
66
+ a_gt, # [*, N, 7, 2]
67
+ a_alt_gt, # [*, N, 7, 2]
68
+ ):
69
+ # [*, N, 7]
70
+ norm = torch.norm(a, dim=-1)
71
+
72
+ # [*, N, 7, 2]
73
+ a = a / norm.unsqueeze(-1)
74
+
75
+ # [*, N, 7]
76
+ diff_norm_gt = torch.norm(a - a_gt, dim=-1)
77
+ diff_norm_alt_gt = torch.norm(a - a_alt_gt, dim=-1)
78
+ min_diff = torch.minimum(diff_norm_gt ** 2, diff_norm_alt_gt ** 2)
79
+
80
+ # [*]
81
+ l_torsion = torch.mean(min_diff, dim=(-1, -2))
82
+ l_angle_norm = torch.mean(torch.abs(norm - 1), dim=(-1, -2))
83
+
84
+ an_weight = 0.02
85
+ return l_torsion + an_weight * l_angle_norm
86
+
87
+
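A small hedged example of the behavior above: unit-norm predictions matching either the ground truth or the alternative naming incur ~zero torsion loss, while the 0.02-weighted term penalizes pre-normalization magnitudes away from 1.

import torch
a_gt = torch.nn.functional.normalize(torch.randn(2, 5, 7, 2), dim=-1)
a_alt_gt = -a_gt  # illustrative stand-in for a pi-flipped alternative naming
assert torsion_angle_loss(a_gt, a_gt, a_alt_gt).abs().max() < 1e-5
assert torsion_angle_loss(a_alt_gt, a_gt, a_alt_gt).abs().max() < 1e-5
scaled = torsion_angle_loss(2 * a_gt, a_gt, a_alt_gt)  # ~0.02, from the norm penalty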
88
+ def compute_fape(
89
+ pred_frames: Rigid,
90
+ target_frames: Rigid,
91
+ frames_mask: torch.Tensor,
92
+ pred_positions: torch.Tensor,
93
+ target_positions: torch.Tensor,
94
+ positions_mask: torch.Tensor,
95
+ length_scale: float,
96
+ pair_mask: Optional[torch.Tensor] = None,
97
+ l1_clamp_distance: Optional[float] = None,
98
+ eps=1e-8,
99
+ ) -> torch.Tensor:
100
+ """
101
+ Computes FAPE loss.
102
+
103
+ Args:
104
+ pred_frames:
105
+ [*, N_frames] Rigid object of predicted frames
106
+ target_frames:
107
+ [*, N_frames] Rigid object of ground truth frames
108
+ frames_mask:
109
+ [*, N_frames] binary mask for the frames
110
+ pred_positions:
111
+ [*, N_pts, 3] predicted atom positions
112
+ target_positions:
113
+ [*, N_pts, 3] ground truth positions
114
+ positions_mask:
115
+ [*, N_pts] positions mask
116
+ length_scale:
117
+ Length scale by which the loss is divided
118
+ pair_mask:
119
+ [*, N_frames, N_pts] mask to use for
120
+ separating intra- from inter-chain losses.
121
+ l1_clamp_distance:
122
+ Cutoff above which distance errors are disregarded
123
+ eps:
124
+ Small value used to regularize denominators
125
+ Returns:
126
+ [*] loss tensor
127
+ """
128
+ # [*, N_frames, N_pts, 3]
129
+ local_pred_pos = pred_frames.invert()[..., None].apply(
130
+ pred_positions[..., None, :, :],
131
+ )
132
+ local_target_pos = target_frames.invert()[..., None].apply(
133
+ target_positions[..., None, :, :],
134
+ )
135
+
136
+ error_dist = torch.sqrt(
137
+ torch.sum((local_pred_pos - local_target_pos) ** 2, dim=-1) + eps
138
+ )
139
+
140
+ if l1_clamp_distance is not None:
141
+ error_dist = torch.clamp(error_dist, min=0, max=l1_clamp_distance)
142
+
143
+ normed_error = error_dist / length_scale
144
+ normed_error = normed_error * frames_mask[..., None]
145
+ normed_error = normed_error * positions_mask[..., None, :]
146
+
147
+ if pair_mask is not None:
148
+ normed_error = normed_error * pair_mask
149
+ normed_error = torch.sum(normed_error, dim=(-1, -2))
150
+
151
+ mask = frames_mask[..., None] * positions_mask[..., None, :] * pair_mask
152
+ norm_factor = torch.sum(mask, dim=(-2, -1))
153
+
154
+ normed_error = normed_error / (eps + norm_factor)
155
+ else:
156
+ # FP16-friendly averaging. Roughly equivalent to:
157
+ #
158
+ # norm_factor = (
159
+ # torch.sum(frames_mask, dim=-1) *
160
+ # torch.sum(positions_mask, dim=-1)
161
+ # )
162
+ # normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor)
163
+ #
164
+ # ("roughly" because eps is necessarily duplicated in the latter)
165
+ normed_error = torch.sum(normed_error, dim=-1)
166
+ normed_error = (
167
+ normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None]
168
+ )
169
+ normed_error = torch.sum(normed_error, dim=-1)
170
+ normed_error = normed_error / (eps + torch.sum(positions_mask, dim=-1))
171
+
172
+ return normed_error
173
+
174
+
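A hedged call sketch, assuming Rigid.identity is available with its OpenFold-style signature (shape, dtype, device, requires_grad, fmt); identical frames and points drive the loss to ~0, with only the eps regularizer contributing.

import torch
frames = Rigid.identity((2, 8), fmt="rot_mat")  # assumed OpenFold-style factory
pts = torch.randn(2, 16, 3)
loss = compute_fape(
    frames, frames,
    torch.ones(2, 8),    # frames_mask
    pts, pts,
    torch.ones(2, 16),   # positions_mask
    length_scale=10.0,
)  # ~0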
175
+ def backbone_loss(
176
+ backbone_rigid_tensor: torch.Tensor,
177
+ backbone_rigid_mask: torch.Tensor,
178
+ traj: torch.Tensor,
179
+ pair_mask: Optional[torch.Tensor] = None,
180
+ use_clamped_fape: Optional[torch.Tensor] = None,
181
+ clamp_distance: float = 10.0,
182
+ loss_unit_distance: float = 10.0,
183
+ eps: float = 1e-4,
184
+ **kwargs,
185
+ ) -> torch.Tensor:
186
+ # traj holds either tensor_7 frames (quaternion + translation) or 4x4 transform matrices
187
+ if traj.shape[-1] == 7:
188
+ pred_aff = Rigid.from_tensor_7(traj)
189
+ elif traj.shape[-1] == 4:
190
+ pred_aff = Rigid.from_tensor_4x4(traj)
+ else:
+ raise ValueError(f"Unexpected trajectory shape: {traj.shape}")
191
+
192
+ pred_aff = Rigid(
193
+ Rotation(rot_mats=pred_aff.get_rots().get_rot_mats(), quats=None),
194
+ pred_aff.get_trans(),
195
+ )
196
+
197
+ # DISCREPANCY: DeepMind somehow gets a hold of a tensor_7 version of
198
+ # backbone tensor, normalizes it, and then turns it back to a rotation
199
+ # matrix. To avoid a potentially numerically unstable rotation matrix
200
+ # to quaternion conversion, we just use the original rotation matrix
201
+ # outright. This one hasn't been composed a bunch of times, though, so
202
+ # it might be fine.
203
+ gt_aff = Rigid.from_tensor_4x4(backbone_rigid_tensor)
204
+
205
+ fape_loss = compute_fape(
206
+ pred_aff,
207
+ gt_aff[None],
208
+ backbone_rigid_mask[None],
209
+ pred_aff.get_trans(),
210
+ gt_aff[None].get_trans(),
211
+ backbone_rigid_mask[None],
212
+ pair_mask=pair_mask,
213
+ l1_clamp_distance=clamp_distance,
214
+ length_scale=loss_unit_distance,
215
+ eps=eps,
216
+ )
217
+ if use_clamped_fape is not None:
218
+ unclamped_fape_loss = compute_fape(
219
+ pred_aff,
220
+ gt_aff[None],
221
+ backbone_rigid_mask[None],
222
+ pred_aff.get_trans(),
223
+ gt_aff[None].get_trans(),
224
+ backbone_rigid_mask[None],
225
+ pair_mask=pair_mask,
226
+ l1_clamp_distance=None,
227
+ length_scale=loss_unit_distance,
228
+ eps=eps,
229
+ )
230
+
231
+ fape_loss = fape_loss * use_clamped_fape + unclamped_fape_loss * (
232
+ 1 - use_clamped_fape
233
+ )
234
+
235
+ # Average over the batch dimension
236
+ fape_loss = torch.mean(fape_loss)
237
+
238
+ return fape_loss
239
+
240
+
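The use_clamped_fape branch above blends the clamped and unclamped losses per example; a hedged numeric sketch of that convex combination:

import torch
use_clamped_fape = torch.tensor([1.0, 0.0, 0.9])
clamped = torch.tensor([0.5, 0.5, 0.5])
unclamped = torch.tensor([2.0, 2.0, 2.0])
blended = clamped * use_clamped_fape + unclamped * (1 - use_clamped_fape)
# -> tensor([0.5000, 2.0000, 0.6500])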
241
+ def sidechain_loss(
242
+ pred_sidechain_frames: torch.Tensor,
243
+ pred_sidechain_atom_pos: torch.Tensor,
244
+ rigidgroups_gt_frames: torch.Tensor,
245
+ rigidgroups_alt_gt_frames: torch.Tensor,
246
+ rigidgroups_gt_exists: torch.Tensor,
247
+ renamed_atom14_gt_positions: torch.Tensor,
248
+ renamed_atom14_gt_exists: torch.Tensor,
249
+ alt_naming_is_better: torch.Tensor,
250
+ ligand_mask: torch.Tensor,
251
+ clamp_distance: float = 10.0,
252
+ length_scale: float = 10.0,
253
+ eps: float = 1e-4,
254
+ only_include_ligand_atoms: bool = False,
255
+ **kwargs,
256
+ ) -> torch.Tensor:
257
+ renamed_gt_frames = (
258
+ 1.0 - alt_naming_is_better[..., None, None, None]
259
+ ) * rigidgroups_gt_frames + alt_naming_is_better[
260
+ ..., None, None, None
261
+ ] * rigidgroups_alt_gt_frames
262
+
263
+ # Steamroll the inputs
264
+ pred_sidechain_frames = pred_sidechain_frames[-1]  # keep only the last layer of the structure module
265
+ batch_dims = pred_sidechain_frames.shape[:-4]
266
+ pred_sidechain_frames = pred_sidechain_frames.view(*batch_dims, -1, 4, 4)
267
+ pred_sidechain_frames = Rigid.from_tensor_4x4(pred_sidechain_frames)
268
+ renamed_gt_frames = renamed_gt_frames.view(*batch_dims, -1, 4, 4)
269
+ renamed_gt_frames = Rigid.from_tensor_4x4(renamed_gt_frames)
270
+ rigidgroups_gt_exists = rigidgroups_gt_exists.reshape(*batch_dims, -1)
271
+ pred_sidechain_atom_pos = pred_sidechain_atom_pos[-1]
272
+ pred_sidechain_atom_pos = pred_sidechain_atom_pos.view(*batch_dims, -1, 3)
273
+ renamed_atom14_gt_positions = renamed_atom14_gt_positions.view(
274
+ *batch_dims, -1, 3
275
+ )
276
+ renamed_atom14_gt_exists = renamed_atom14_gt_exists.view(*batch_dims, -1)
277
+
278
+ atom_mask_to_apply = renamed_atom14_gt_exists
279
+ if only_include_ligand_atoms:
280
+ ligand_atom14_mask = torch.repeat_interleave(ligand_mask, 14, dim=-1)
281
+ atom_mask_to_apply = atom_mask_to_apply * ligand_atom14_mask
282
+
283
+ fape = compute_fape(
284
+ pred_sidechain_frames,
285
+ renamed_gt_frames,
286
+ rigidgroups_gt_exists,
287
+ pred_sidechain_atom_pos,
288
+ renamed_atom14_gt_positions,
289
+ atom_mask_to_apply,
290
+ pair_mask=None,
291
+ l1_clamp_distance=clamp_distance,
292
+ length_scale=length_scale,
293
+ eps=eps,
294
+ )
295
+
296
+ return fape
297
+
298
+
299
+ def fape_bb(
300
+ out: Dict[str, torch.Tensor],
301
+ batch: Dict[str, torch.Tensor],
302
+ config: ml_collections.ConfigDict,
303
+ ) -> torch.Tensor:
304
+ traj = out["sm"]["frames"]
305
+ bb_loss = backbone_loss(
306
+ traj=traj,
307
+ **{**batch, **config},
308
+ )
309
+ loss = torch.mean(bb_loss)
310
+ return loss
311
+
312
+
313
+ def fape_sidechain(
314
+ out: Dict[str, torch.Tensor],
315
+ batch: Dict[str, torch.Tensor],
316
+ config: ml_collections.ConfigDict,
317
+ ) -> torch.Tensor:
318
+ sc_loss = sidechain_loss(
319
+ out["sm"]["sidechain_frames"],
320
+ out["sm"]["positions"],
321
+ **{**batch, **config},
322
+ )
323
+ loss = torch.mean(sc_loss)
324
+ return loss
325
+
326
+
327
+ def fape_interface(
328
+ out: Dict[str, torch.Tensor],
329
+ batch: Dict[str, torch.Tensor],
330
+ config: ml_collections.ConfigDict,
331
+ ) -> torch.Tensor:
332
+ sc_loss = sidechain_loss(
333
+ out["sm"]["sidechain_frames"],
334
+ out["sm"]["positions"],
335
+ only_include_ligand_atoms=True,
336
+ **{**batch, **config},
337
+ )
338
+ loss = torch.mean(sc_loss)
339
+ return loss
340
+
341
+
342
+ def supervised_chi_loss(
343
+ angles_sin_cos: torch.Tensor,
344
+ unnormalized_angles_sin_cos: torch.Tensor,
345
+ aatype: torch.Tensor,
346
+ protein_mask: torch.Tensor,
347
+ chi_mask: torch.Tensor,
348
+ chi_angles_sin_cos: torch.Tensor,
349
+ chi_weight: float,
350
+ angle_norm_weight: float,
351
+ eps=1e-6,
352
+ **kwargs,
353
+ ) -> torch.Tensor:
354
+ """
355
+ Implements Algorithm 27 (torsionAngleLoss)
356
+
357
+ Args:
358
+ angles_sin_cos:
359
+ [*, N, 7, 2] predicted angles
360
+ unnormalized_angles_sin_cos:
361
+ The same angles, but unnormalized
362
+ aatype:
363
+ [*, N] residue indices
364
+ protein_mask:
365
+ [*, N] protein mask
366
+ chi_mask:
367
+ [*, N, 7] angle mask
368
+ chi_angles_sin_cos:
369
+ [*, N, 7, 2] ground truth angles
370
+ chi_weight:
371
+ Weight for the angle component of the loss
372
+ angle_norm_weight:
373
+ Weight for the normalization component of the loss
374
+ Returns:
375
+ [*] loss tensor
376
+ """
377
+ pred_angles = angles_sin_cos[..., 3:, :]
378
+ residue_type_one_hot = torch.nn.functional.one_hot(
379
+ aatype,
380
+ residue_constants.restype_num + 1,
381
+ )
382
+ chi_pi_periodic = torch.einsum(
383
+ "...ij,jk->ik",
384
+ residue_type_one_hot.type(angles_sin_cos.dtype),
385
+ angles_sin_cos.new_tensor(residue_constants.chi_pi_periodic),
386
+ )
387
+
388
+ true_chi = chi_angles_sin_cos[None]
389
+
390
+ shifted_mask = (1 - 2 * chi_pi_periodic).unsqueeze(-1)
391
+ true_chi_shifted = shifted_mask * true_chi
392
+ sq_chi_error = torch.sum((true_chi - pred_angles) ** 2, dim=-1)
393
+ sq_chi_error_shifted = torch.sum(
394
+ (true_chi_shifted - pred_angles) ** 2, dim=-1
395
+ )
396
+ sq_chi_error = torch.minimum(sq_chi_error, sq_chi_error_shifted)
397
+
398
+ # The ol' switcheroo
399
+ sq_chi_error = sq_chi_error.permute(
400
+ *range(len(sq_chi_error.shape))[1:-2], 0, -2, -1
401
+ )
402
+
403
+ sq_chi_loss = masked_mean(
404
+ chi_mask[..., None, :, :], sq_chi_error, dim=(-1, -2, -3)
405
+ )
406
+
407
+ loss = chi_weight * sq_chi_loss
408
+
409
+ angle_norm = torch.sqrt(
410
+ torch.sum(unnormalized_angles_sin_cos ** 2, dim=-1) + eps
411
+ )
412
+ norm_error = torch.abs(angle_norm - 1.0)
413
+ norm_error = norm_error.permute(
414
+ *range(len(norm_error.shape))[1:-2], 0, -2, -1
415
+ )
416
+ angle_norm_loss = masked_mean(
417
+ protein_mask[..., None, :, None], norm_error, dim=(-1, -2, -3)
418
+ )
419
+
420
+ loss = loss + angle_norm_weight * angle_norm_loss
421
+
422
+ # Average over the batch dimension
423
+ loss = torch.mean(loss)
424
+
425
+ return loss
426
+
427
+
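A hedged numeric check of the pi-periodicity trick above: for a pi-periodic chi angle, the (sin, cos) pair of chi + pi is the negation of that of chi, which is exactly what the 1 - 2 * chi_pi_periodic sign flip encodes.

import math
import torch
chi = torch.tensor(0.7)
sc = torch.stack([torch.sin(chi), torch.cos(chi)])
sc_pi = torch.stack([torch.sin(chi + math.pi), torch.cos(chi + math.pi)])
assert torch.allclose(sc_pi, -sc, atol=1e-6)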
428
+ def compute_plddt(logits: torch.Tensor) -> torch.Tensor:
429
+ num_bins = logits.shape[-1]
430
+ bin_width = 1.0 / num_bins
431
+ bounds = torch.arange(
432
+ start=0.5 * bin_width, end=1.0, step=bin_width, device=logits.device
433
+ )
434
+ probs = torch.nn.functional.softmax(logits, dim=-1)
435
+ pred_lddt_ca = torch.sum(
436
+ probs * bounds.view(*((1,) * len(probs.shape[:-1])), *bounds.shape),
437
+ dim=-1,
438
+ )
439
+ return pred_lddt_ca * 100
440
+
441
+
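A quick hedged sanity check: uniform logits put equal mass on all 50 bins, so the expected pLDDT is the mean bin center (0.5) scaled to 50.

import torch
plddt = compute_plddt(torch.zeros(1, 10, 50))
assert torch.allclose(plddt, torch.full((1, 10), 50.0), atol=1e-4)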
442
+ def lddt(
443
+ all_atom_pred_pos: torch.Tensor,
444
+ all_atom_positions: torch.Tensor,
445
+ all_atom_mask: torch.Tensor,
446
+ cutoff: float = 15.0,
447
+ eps: float = 1e-10,
448
+ per_residue: bool = True,
449
+ ) -> torch.Tensor:
450
+ n = all_atom_mask.shape[-2]
451
+ dmat_true = torch.sqrt(
452
+ eps
453
+ + torch.sum(
454
+ (
455
+ all_atom_positions[..., None, :]
456
+ - all_atom_positions[..., None, :, :]
457
+ )
458
+ ** 2,
459
+ dim=-1,
460
+ )
461
+ )
462
+
463
+ dmat_pred = torch.sqrt(
464
+ eps
465
+ + torch.sum(
466
+ (
467
+ all_atom_pred_pos[..., None, :]
468
+ - all_atom_pred_pos[..., None, :, :]
469
+ )
470
+ ** 2,
471
+ dim=-1,
472
+ )
473
+ )
474
+ dists_to_score = (
475
+ (dmat_true < cutoff)
476
+ * all_atom_mask
477
+ * permute_final_dims(all_atom_mask, (1, 0))
478
+ * (1.0 - torch.eye(n, device=all_atom_mask.device))
479
+ )
480
+
481
+ dist_l1 = torch.abs(dmat_true - dmat_pred)
482
+
483
+ score = (
484
+ (dist_l1 < 0.5).type(dist_l1.dtype)
485
+ + (dist_l1 < 1.0).type(dist_l1.dtype)
486
+ + (dist_l1 < 2.0).type(dist_l1.dtype)
487
+ + (dist_l1 < 4.0).type(dist_l1.dtype)
488
+ )
489
+ score = score * 0.25
490
+
491
+ dims = (-1,) if per_residue else (-2, -1)
492
+ norm = 1.0 / (eps + torch.sum(dists_to_score, dim=dims))
493
+ score = norm * (eps + torch.sum(dists_to_score * score, dim=dims))
494
+
495
+ return score
496
+
497
+
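A hedged check of the scoring above: a prediction identical to the ground truth lands inside all four tolerance brackets (0.5, 1, 2, 4 A) and scores 1.0 for every residue.

import torch
pos = 5.0 * torch.randn(1, 12, 3)
mask = torch.ones(1, 12, 1)
score = lddt(pos, pos, mask)  # [1, 12]
assert torch.allclose(score, torch.ones(1, 12), atol=1e-4)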
498
+ def lddt_ca(
499
+ all_atom_pred_pos: torch.Tensor,
500
+ all_atom_positions: torch.Tensor,
501
+ all_atom_mask: torch.Tensor,
502
+ cutoff: float = 15.0,
503
+ eps: float = 1e-10,
504
+ per_residue: bool = True,
505
+ ) -> torch.Tensor:
506
+ ca_pos = residue_constants.atom_order["CA"]
507
+ all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
508
+ all_atom_positions = all_atom_positions[..., ca_pos, :]
509
+ all_atom_mask = all_atom_mask[..., ca_pos: (ca_pos + 1)] # keep dim
510
+
511
+ return lddt(
512
+ all_atom_pred_pos,
513
+ all_atom_positions,
514
+ all_atom_mask,
515
+ cutoff=cutoff,
516
+ eps=eps,
517
+ per_residue=per_residue,
518
+ )
519
+
520
+
521
+ def lddt_loss(
522
+ logits: torch.Tensor,
523
+ all_atom_pred_pos: torch.Tensor,
524
+ atom37_gt_positions: torch.Tensor,
525
+ atom37_atom_exists_in_gt: torch.Tensor,
526
+ resolution: torch.Tensor,
527
+ cutoff: float = 15.0,
528
+ no_bins: int = 50,
529
+ min_resolution: float = 0.1,
530
+ max_resolution: float = 3.0,
531
+ eps: float = 1e-10,
532
+ **kwargs,
533
+ ) -> torch.Tensor:
534
+ # remove ligand entries; lDDT here is supervised on protein residues only
535
+ logits = logits[:, :atom37_atom_exists_in_gt.shape[1], :]
536
+
537
+ ca_pos = residue_constants.atom_order["CA"]
538
+ all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
539
+ atom37_gt_positions = atom37_gt_positions[..., ca_pos, :]
540
+ atom37_atom_exists_in_gt = atom37_atom_exists_in_gt[..., ca_pos: (ca_pos + 1)] # keep dim
541
+
542
+ score = lddt(
543
+ all_atom_pred_pos,
544
+ atom37_gt_positions,
545
+ atom37_atom_exists_in_gt,
546
+ cutoff=cutoff,
547
+ eps=eps
548
+ )
549
+
550
+ # TODO: Remove after initial pipeline testing
551
+ score = torch.nan_to_num(score, nan=torch.nanmean(score))
552
+ score[score < 0] = 0
553
+
554
+ score = score.detach()
555
+ bin_index = torch.floor(score * no_bins).long()
556
+ bin_index = torch.clamp(bin_index, max=(no_bins - 1))
557
+ lddt_ca_one_hot = torch.nn.functional.one_hot(
558
+ bin_index, num_classes=no_bins
559
+ )
560
+
561
+ errors = softmax_cross_entropy(logits, lddt_ca_one_hot)
562
+ atom37_atom_exists_in_gt = atom37_atom_exists_in_gt.squeeze(-1)
563
+ loss = torch.sum(errors * atom37_atom_exists_in_gt, dim=-1) / (
564
+ eps + torch.sum(atom37_atom_exists_in_gt, dim=-1)
565
+ )
566
+
567
+ loss = loss * (
568
+ (resolution >= min_resolution) & (resolution <= max_resolution)
569
+ )
570
+
571
+ # Average over the batch dimension
572
+ loss = torch.mean(loss)
573
+
574
+ return loss
575
+
576
+
577
+ def distogram_loss(
578
+ logits,
579
+ gt_pseudo_beta_with_lig,
580
+ gt_pseudo_beta_with_lig_mask,
581
+ min_bin=2.3125,
582
+ max_bin=21.6875,
583
+ no_bins=64,
584
+ eps=1e-6,
585
+ **kwargs,
586
+ ):
587
+ boundaries = torch.linspace(
588
+ min_bin,
589
+ max_bin,
590
+ no_bins - 1,
591
+ device=logits.device,
592
+ )
593
+ boundaries = boundaries ** 2
594
+
595
+ dists = torch.sum(
596
+ (gt_pseudo_beta_with_lig[..., None, :] - gt_pseudo_beta_with_lig[..., None, :, :]) ** 2,
597
+ dim=-1,
598
+ keepdims=True,
599
+ )
600
+
601
+ true_bins = torch.sum(dists > boundaries, dim=-1)
602
+ errors = softmax_cross_entropy(
603
+ logits,
604
+ torch.nn.functional.one_hot(true_bins, no_bins),
605
+ )
606
+
607
+ square_mask = gt_pseudo_beta_with_lig_mask[..., None] * gt_pseudo_beta_with_lig_mask[..., None, :]
608
+
609
+ # FP16-friendly sum. Equivalent to:
610
+ # mean = (torch.sum(errors * square_mask, dim=(-1, -2)) /
611
+ # (eps + torch.sum(square_mask, dim=(-1, -2))))
612
+ denom = eps + torch.sum(square_mask, dim=(-1, -2))
613
+ mean = errors * square_mask
614
+ mean = torch.sum(mean, dim=-1)
615
+ mean = mean / denom[..., None]
616
+ mean = torch.sum(mean, dim=-1)
617
+
618
+ # Average over the batch dimensions
619
+ mean = torch.mean(mean)
620
+
621
+ return mean
622
+
623
+
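Squaring the boundaries above lets the binning run on squared distances and avoids a sqrt; a hedged illustration of the resulting bin indices:

import torch
boundaries = torch.linspace(2.3125, 21.6875, 63) ** 2
sq_dists = torch.tensor([1.0, 10.0, 30.0]) ** 2
bins = torch.sum(sq_dists[..., None] > boundaries, dim=-1)
# -> tensor([ 0, 25, 63]): below the first edge, mid-range, past the last edge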
624
+ def inter_contact_loss(
625
+ logits: torch.Tensor,
626
+ gt_inter_contacts: torch.Tensor,
627
+ inter_pair_mask: torch.Tensor,
628
+ pos_class_weight: float = 200.0,
629
+ contact_distance: float = 5.0,
630
+ **kwargs,
631
+ ):
632
+ logits = logits.squeeze(-1)
633
+ bce_loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, gt_inter_contacts, reduction='none',
634
+ pos_weight=logits.new_tensor([pos_class_weight]))
635
+ masked_loss = bce_loss * inter_pair_mask
636
+ final_loss = masked_loss.sum() / inter_pair_mask.sum()
637
+
638
+ return final_loss
639
+
640
+
641
+ def affinity_loss(
642
+ logits,
643
+ affinity,
644
+ affinity_loss_factor,
645
+ min_bin=0,
646
+ max_bin=15,
647
+ no_bins=32,
648
+ **kwargs,
649
+ ):
650
+ boundaries = torch.linspace(
651
+ min_bin,
652
+ max_bin,
653
+ no_bins - 1,
654
+ device=logits.device,
655
+ )
656
+
657
+ true_bins = torch.sum(affinity > boundaries, dim=-1)
658
+ errors = softmax_cross_entropy(
659
+ logits,
660
+ torch.nn.functional.one_hot(true_bins, no_bins),
661
+ )
662
+
663
+ # print("errors dim", errors.shape, affinity_loss_factor.shape, errors)
664
+ after_factor = errors * affinity_loss_factor.squeeze()
665
+ if affinity_loss_factor.sum() > 0.1:
666
+ mean_val = after_factor.sum() / affinity_loss_factor.sum()
667
+ else:
668
+ # If there is no affinity in the batch, return a very small loss; the factor also keeps the loss small
669
+ mean_val = after_factor.sum() * 1e-3
670
+ # print("after factor", after_factor.shape, after_factor, affinity_loss_factor.sum(), mean_val)
671
+ return mean_val
672
+
673
+
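A hedged sketch of the fallback logic above: with at least one labeled example the factor acts as a weighted mean; with none, the 1e-3-scaled sum yields a negligible loss while keeping the graph differentiable.

import torch
errors = torch.tensor([2.0, 3.0, 4.0])
factor = torch.tensor([1.0, 0.0, 0.0])  # only the first example carries a label
labeled_mean = (errors * factor).sum() / factor.sum()  # -> 2.0
fallback = (errors * torch.zeros(3)).sum() * 1e-3      # -> 0.0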
674
+ def positions_inter_distogram_loss(
675
+ out,
676
+ aatype: torch.Tensor,
677
+ inter_pair_mask: torch.Tensor,
678
+ gt_pseudo_beta_with_lig: torch.Tensor,
679
+ max_dist=20.,
680
+ length_scale=10.,
681
+ eps: float = 1e-10,
682
+ **kwargs,
683
+ ):
684
+
685
+ predicted_atoms = pseudo_beta_fn(aatype, out["final_atom_positions"], None)
686
+ pred_dists = torch.sum(
687
+ (predicted_atoms[..., None, :] - predicted_atoms[..., None, :, :]) ** 2,
688
+ dim=-1,
689
+ keepdims=True,
690
+ )
691
+
692
+ gt_dists = torch.sum(
693
+ (gt_pseudo_beta_with_lig[..., None, :] - gt_pseudo_beta_with_lig[..., None, :, :]) ** 2,
694
+ dim=-1,
695
+ keepdims=True,
696
+ )
697
+
698
+ pred_dists = pred_dists.clamp(max=max_dist ** 2)
699
+ gt_dists = gt_dists.clamp(max=max_dist ** 2)
700
+
701
+ dists_diff = torch.abs(pred_dists - gt_dists) / (length_scale ** 2)
702
+ dists_diff = dists_diff * inter_pair_mask.unsqueeze(-1)
703
+
704
+ dists_diff_sum_per_batch = torch.sum(torch.sqrt(eps + dists_diff), dim=(-1, -2, -3))
705
+ mask_size_per_batch = torch.sum(inter_pair_mask, dim=(-1, -2))
706
+ inter_loss = torch.mean(dists_diff_sum_per_batch / (eps + mask_size_per_batch))
707
+
708
+ return inter_loss
709
+
710
+
711
+ def positions_intra_ligand_distogram_loss(
712
+ out,
713
+ aatype: torch.Tensor,
714
+ ligand_mask: torch.Tensor,
715
+ gt_pseudo_beta_with_lig: torch.Tensor,
716
+ max_dist=20.,
717
+ length_scale=4.,  # similar to RoseTTAFold All-Atom
718
+ eps=1e-10,
719
+ **kwargs,
720
+ ):
721
+ intra_ligand_pair_mask = ligand_mask[..., None] * ligand_mask[..., None, :]
722
+ predicted_atoms = pseudo_beta_fn(aatype, out["final_atom_positions"], None)
723
+ pred_dists = torch.sum(
724
+ (predicted_atoms[..., None, :] - predicted_atoms[..., None, :, :]) ** 2,
725
+ dim=-1,
726
+ keepdims=True,
727
+ )
728
+
729
+ gt_dists = torch.sum(
730
+ (gt_pseudo_beta_with_lig[..., None, :] - gt_pseudo_beta_with_lig[..., None, :, :]) ** 2,
731
+ dim=-1,
732
+ keepdims=True,
733
+ )
734
+
735
+ pred_dists = torch.sqrt(eps + pred_dists.clamp(max=max_dist ** 2)) / length_scale
736
+ gt_dists = torch.sqrt(eps + gt_dists.clamp(max=max_dist ** 2)) / length_scale
737
+
738
+ # Apply L2 loss
739
+ dists_diff = (pred_dists - gt_dists) ** 2
740
+
741
+ dists_diff = dists_diff * intra_ligand_pair_mask.unsqueeze(-1)
742
+
743
+ dists_diff_sum_per_batch = torch.sum(dists_diff, dim=(-1, -2, -3))
744
+ mask_size_per_batch = torch.sum(intra_ligand_pair_mask, dim=(-1, -2))
745
+ intra_ligand_loss = torch.mean(dists_diff_sum_per_batch / (eps + mask_size_per_batch))
746
+
747
+ return intra_ligand_loss
748
+
749
+
750
+ def _calculate_bin_centers(boundaries: torch.Tensor):
751
+ step = boundaries[1] - boundaries[0]
752
+ bin_centers = boundaries + step / 2
753
+ bin_centers = torch.cat(
754
+ [bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0
755
+ )
756
+ return bin_centers
757
+
758
+
759
+ def _calculate_expected_aligned_error(
760
+ alignment_confidence_breaks: torch.Tensor,
761
+ aligned_distance_error_probs: torch.Tensor,
762
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
763
+ bin_centers = _calculate_bin_centers(alignment_confidence_breaks)
764
+ return (
765
+ torch.sum(aligned_distance_error_probs * bin_centers, dim=-1),
766
+ bin_centers[-1],
767
+ )
768
+
769
+
770
+ def compute_predicted_aligned_error(
771
+ logits: torch.Tensor,
772
+ max_bin: int = 31,
773
+ no_bins: int = 64,
774
+ **kwargs,
775
+ ) -> Dict[str, torch.Tensor]:
776
+ """Computes aligned confidence metrics from logits.
777
+
778
+ Args:
779
+ logits: [*, num_res, num_res, num_bins] the logits output from
780
+ PredictedAlignedErrorHead.
781
+ max_bin: Maximum bin value
782
+ no_bins: Number of bins
783
+ Returns:
784
+ aligned_confidence_probs: [*, num_res, num_res, num_bins] the predicted
785
+ aligned error probabilities over bins for each residue pair.
786
+ predicted_aligned_error: [*, num_res, num_res] the expected aligned distance
787
+ error for each pair of residues.
788
+ max_predicted_aligned_error: [*] the maximum predicted error possible.
789
+ """
790
+ boundaries = torch.linspace(
791
+ 0, max_bin, steps=(no_bins - 1), device=logits.device
792
+ )
793
+
794
+ aligned_confidence_probs = torch.nn.functional.softmax(logits, dim=-1)
795
+ (
796
+ predicted_aligned_error,
797
+ max_predicted_aligned_error,
798
+ ) = _calculate_expected_aligned_error(
799
+ alignment_confidence_breaks=boundaries,
800
+ aligned_distance_error_probs=aligned_confidence_probs,
801
+ )
802
+
803
+ return {
804
+ "aligned_confidence_probs": aligned_confidence_probs,
805
+ "predicted_aligned_error": predicted_aligned_error,
806
+ "max_predicted_aligned_error": max_predicted_aligned_error,
807
+ }
808
+
809
+
810
+ def compute_tm(
811
+ logits: torch.Tensor,
812
+ residue_weights: Optional[torch.Tensor] = None,
813
+ asym_id: Optional[torch.Tensor] = None,
814
+ interface: bool = False,
815
+ max_bin: int = 31,
816
+ no_bins: int = 64,
817
+ eps: float = 1e-8,
818
+ **kwargs,
819
+ ) -> torch.Tensor:
820
+ if residue_weights is None:
821
+ residue_weights = logits.new_ones(logits.shape[-2])
822
+
823
+ boundaries = torch.linspace(
824
+ 0, max_bin, steps=(no_bins - 1), device=logits.device
825
+ )
826
+
827
+ bin_centers = _calculate_bin_centers(boundaries)
828
+ clipped_n = max(torch.sum(residue_weights), 19)
829
+
830
+ d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8
831
+
832
+ probs = torch.nn.functional.softmax(logits, dim=-1)
833
+
834
+ tm_per_bin = 1.0 / (1 + (bin_centers ** 2) / (d0 ** 2))
835
+ predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1)
836
+
837
+ n = residue_weights.shape[-1]
838
+ pair_mask = residue_weights.new_ones((n, n), dtype=torch.int32)
839
+ if interface and (asym_id is not None):
840
+ if len(asym_id.shape) > 1:
841
+ assert len(asym_id.shape) <= 2
842
+ batch_size = asym_id.shape[0]
843
+ pair_mask = residue_weights.new_ones((batch_size, n, n), dtype=torch.int32)
844
+ pair_mask *= (asym_id[..., None] != asym_id[..., None, :]).to(dtype=pair_mask.dtype)
845
+
846
+ predicted_tm_term *= pair_mask
847
+
848
+ pair_residue_weights = pair_mask * (
849
+ residue_weights[..., None, :] * residue_weights[..., :, None]
850
+ )
851
+ denom = eps + torch.sum(pair_residue_weights, dim=-1, keepdims=True)
852
+ normed_residue_mask = pair_residue_weights / denom
853
+ per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)
854
+
855
+ weighted = per_alignment * residue_weights
856
+
857
+ argmax = (weighted == torch.max(weighted)).nonzero()[0]
858
+ return per_alignment[tuple(argmax)]
859
+
860
+
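A hedged numeric check of the TM normalization above (the standard TM-score d0): with 100 effective residues, d0 = 1.24 * (100 - 15)^(1/3) - 1.8 is roughly 3.65 A, and a residue pair whose expected error equals d0 contributes exactly 0.5.

import torch
clipped_n = torch.tensor(100.0)
d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8  # ~3.65
tm_at_d0 = 1.0 / (1 + (d0 ** 2) / (d0 ** 2))     # -> 0.5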
861
+ def compute_renamed_ground_truth(
862
+ batch: Dict[str, torch.Tensor],
863
+ atom14_pred_positions: torch.Tensor,
864
+ eps=1e-10,
865
+ ) -> Dict[str, torch.Tensor]:
866
+ """
867
+ Find optimal renaming of ground truth based on the predicted positions.
868
+
869
+ Alg. 26 "renameSymmetricGroundTruthAtoms"
870
+
871
+ This renamed ground truth is then used for all losses,
872
+ such that each loss moves the atoms in the same direction.
873
+
874
+ Args:
875
+ batch: Dictionary containing:
876
+ * atom14_gt_positions: Ground truth positions.
877
+ * atom14_alt_gt_positions: Ground truth positions with renaming swaps.
878
+ * atom14_atom_is_ambiguous: 1.0 for atoms that are affected by
879
+ renaming swaps.
880
+ * atom14_gt_exists: Mask for which atoms exist in ground truth.
881
+ * atom14_alt_gt_exists: Mask for which atoms exist in ground truth
882
+ after renaming.
883
+ * atom14_atom_exists: Mask for whether each atom is part of the given
884
+ amino acid type.
885
+ atom14_pred_positions: Array of atom positions in global frame, shape [*, N, 14, 3].
886
+ Returns:
887
+ Dictionary containing:
888
+ alt_naming_is_better: Array with 1.0 where alternative swap is better.
889
+ renamed_atom14_gt_positions: Array of optimal ground truth positions
890
+ after renaming swaps are performed.
891
+ renamed_atom14_gt_exists: Mask after renaming swap is performed.
892
+ """
893
+
894
+ pred_dists = torch.sqrt(
895
+ eps
896
+ + torch.sum(
897
+ (
898
+ atom14_pred_positions[..., None, :, None, :]
899
+ - atom14_pred_positions[..., None, :, None, :, :]
900
+ )
901
+ ** 2,
902
+ dim=-1,
903
+ )
904
+ )
905
+
906
+ atom14_gt_positions = batch["atom14_gt_positions"]
907
+ gt_dists = torch.sqrt(
908
+ eps
909
+ + torch.sum(
910
+ (
911
+ atom14_gt_positions[..., None, :, None, :]
912
+ - atom14_gt_positions[..., None, :, None, :, :]
913
+ )
914
+ ** 2,
915
+ dim=-1,
916
+ )
917
+ )
918
+
919
+ atom14_alt_gt_positions = batch["atom14_alt_gt_positions"]
920
+ alt_gt_dists = torch.sqrt(
921
+ eps
922
+ + torch.sum(
923
+ (
924
+ atom14_alt_gt_positions[..., None, :, None, :]
925
+ - atom14_alt_gt_positions[..., None, :, None, :, :]
926
+ )
927
+ ** 2,
928
+ dim=-1,
929
+ )
930
+ )
931
+
932
+ lddt = torch.sqrt(eps + (pred_dists - gt_dists) ** 2)
933
+ alt_lddt = torch.sqrt(eps + (pred_dists - alt_gt_dists) ** 2)
934
+
935
+ atom14_gt_exists = batch["atom14_atom_exists_in_gt"]
936
+ atom14_atom_is_ambiguous = batch["atom14_atom_is_ambiguous"]
937
+ mask = (
938
+ atom14_gt_exists[..., None, :, None]
939
+ * atom14_atom_is_ambiguous[..., None, :, None]
940
+ * atom14_gt_exists[..., None, :, None, :]
941
+ * (1.0 - atom14_atom_is_ambiguous[..., None, :, None, :])
942
+ )
943
+
944
+ per_res_lddt = torch.sum(mask * lddt, dim=(-1, -2, -3))
945
+ alt_per_res_lddt = torch.sum(mask * alt_lddt, dim=(-1, -2, -3))
946
+
947
+ fp_type = atom14_pred_positions.dtype
948
+ alt_naming_is_better = (alt_per_res_lddt < per_res_lddt).type(fp_type)
949
+
950
+ renamed_atom14_gt_positions = (
951
+ 1.0 - alt_naming_is_better[..., None, None]
952
+ ) * atom14_gt_positions + alt_naming_is_better[
953
+ ..., None, None
954
+ ] * atom14_alt_gt_positions
955
+
956
+ renamed_atom14_gt_mask = (
957
+ 1.0 - alt_naming_is_better[..., None]
958
+ ) * atom14_gt_exists + alt_naming_is_better[..., None] * batch[
959
+ "atom14_alt_gt_exists"
960
+ ]
961
+
962
+ return {
963
+ "alt_naming_is_better": alt_naming_is_better,
964
+ "renamed_atom14_gt_positions": renamed_atom14_gt_positions,
965
+ "renamed_atom14_gt_exists": renamed_atom14_gt_mask,
966
+ }
967
+
968
+
969
+ def binding_site_loss(
970
+ logits: torch.Tensor,
971
+ binding_site_mask: torch.Tensor,
972
+ protein_mask: torch.Tensor,
973
+ pos_class_weight: float,
974
+ **kwargs,
975
+ ) -> torch.Tensor:
976
+ logits = logits.squeeze(-1)
977
+ bce_loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, binding_site_mask, reduction='none',
978
+ pos_weight=logits.new_tensor([pos_class_weight]))
979
+ masked_loss = bce_loss * protein_mask
980
+ final_loss = masked_loss.sum() / protein_mask.sum()
981
+
982
+ return final_loss
983
+
984
+
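A hedged check of the pos_weight semantics above: binary_cross_entropy_with_logits scales only the positive term, so a positive example under pos_weight w costs w times the unweighted loss.

import torch
logits = torch.tensor([0.0])
target = torch.tensor([1.0])
plain = torch.nn.functional.binary_cross_entropy_with_logits(logits, target)
weighted = torch.nn.functional.binary_cross_entropy_with_logits(
    logits, target, pos_weight=torch.tensor([200.0]))
assert torch.allclose(weighted, 200.0 * plain)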
985
+ def chain_center_of_mass_loss(
986
+ all_atom_pred_pos: torch.Tensor,
987
+ all_atom_positions: torch.Tensor,
988
+ all_atom_mask: torch.Tensor,
989
+ asym_id: torch.Tensor,
990
+ clamp_distance: float = -4.0,
991
+ weight: float = 0.05,
992
+ eps: float = 1e-10, **kwargs
993
+ ) -> torch.Tensor:
994
+ """
995
+ Computes chain centre-of-mass loss. Implements section 2.5, eqn 1 in the Multimer paper.
996
+
997
+ Args:
998
+ all_atom_pred_pos:
999
+ [*, N_pts, 37, 3] All-atom predicted atom positions
1000
+ all_atom_positions:
1001
+ [*, N_pts, 37, 3] Ground truth all-atom positions
1002
+ all_atom_mask:
1003
+ [*, N_pts, 37] All-atom positions mask
1004
+ asym_id:
1005
+ [*, N_pts] Chain asym IDs
1006
+ clamp_distance:
1007
+ Cutoff above which distance errors are disregarded
1008
+ weight:
1009
+ Weight for loss
1010
+ eps:
1011
+ Small value used to regularize denominators
1012
+ Returns:
1013
+ [*] loss tensor
1014
+ """
1015
+ ca_pos = residue_constants.atom_order["CA"]
1016
+ all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
1017
+ all_atom_positions = all_atom_positions[..., ca_pos, :]
1018
+ all_atom_mask = all_atom_mask[..., ca_pos: (ca_pos + 1)] # keep dim
1019
+
1020
+ one_hot = torch.nn.functional.one_hot(asym_id.long()).to(dtype=all_atom_mask.dtype)
1021
+ one_hot = one_hot * all_atom_mask
1022
+ chain_pos_mask = one_hot.transpose(-2, -1)
1023
+ chain_exists = torch.any(chain_pos_mask, dim=-1).to(dtype=all_atom_positions.dtype)
1024
+
1025
+ def get_chain_center_of_mass(pos):
1026
+ center_sum = (chain_pos_mask[..., None] * pos[..., None, :, :]).sum(dim=-2)
1027
+ centers = center_sum / (torch.sum(chain_pos_mask, dim=-1, keepdim=True) + eps)
1028
+ return Vec3Array.from_array(centers)
1029
+
1030
+ pred_centers = get_chain_center_of_mass(all_atom_pred_pos) # [B, NC, 3]
1031
+ true_centers = get_chain_center_of_mass(all_atom_positions) # [B, NC, 3]
1032
+
1033
+ pred_dists = euclidean_distance(pred_centers[..., None, :], pred_centers[..., :, None], epsilon=eps)
1034
+ true_dists = euclidean_distance(true_centers[..., None, :], true_centers[..., :, None], epsilon=eps)
1035
+ losses = torch.clamp((weight * (pred_dists - true_dists - clamp_distance)), max=0) ** 2
1036
+ loss_mask = chain_exists[..., :, None] * chain_exists[..., None, :]
1037
+
1038
+ loss = masked_mean(loss_mask, losses, dim=(-1, -2))
1039
+ return loss
1040
+
1041
+
1042
+ class AlphaFoldLoss(nn.Module):
1043
+ """Aggregation of the various losses described in the supplement"""
1044
+
1045
+ def __init__(self, config):
1046
+ super(AlphaFoldLoss, self).__init__()
1047
+ self.config = config
1048
+
1049
+ def loss(self, out, batch, _return_breakdown=False):
1050
+ """
1051
+ Renames the previous forward() as loss()
1052
+ so that it can be reused in subclasses
1053
+ """
1054
+ if "renamed_atom14_gt_positions" not in out.keys():
1055
+ batch.update(
1056
+ compute_renamed_ground_truth(
1057
+ batch,
1058
+ out["sm"]["positions"][-1],
1059
+ )
1060
+ )
1061
+
1062
+ loss_fns = {
1063
+ "distogram": lambda: distogram_loss(
1064
+ logits=out["distogram_logits"],
1065
+ **{**batch, **self.config.distogram},
1066
+ ),
1067
+ "positions_inter_distogram": lambda: positions_inter_distogram_loss(
1068
+ out,
1069
+ **{**batch, **self.config.positions_inter_distogram},
1070
+ ),
1071
+ "positions_intra_distogram": lambda: positions_intra_ligand_distogram_loss(
1072
+ out,
1073
+ **{**batch, **self.config.positions_intra_distogram},
1074
+ ),
1075
+
1076
+ "affinity1d": lambda: affinity_loss(
1077
+ logits=out["affinity_1d_logits"],
1078
+ **{**batch, **self.config.affinity1d},
1079
+ ),
1080
+ "affinity2d": lambda: affinity_loss(
1081
+ logits=out["affinity_2d_logits"],
1082
+ **{**batch, **self.config.affinity2d},
1083
+ ),
1084
+ "affinity_cls": lambda: affinity_loss(
1085
+ logits=out["affinity_cls_logits"],
1086
+ **{**batch, **self.config.affinity_cls},
1087
+ ),
1088
+ "binding_site": lambda: binding_site_loss(
1089
+ logits=out["binding_site_logits"],
1090
+ **{**batch, **self.config.binding_site},
1091
+ ),
1092
+ "inter_contact": lambda: inter_contact_loss(
1093
+ logits=out["inter_contact_logits"],
1094
+ **{**batch, **self.config.inter_contact},
1095
+ ),
1096
+ # backbone is based on frames so only works on protein
1097
+ "fape_backbone": lambda: fape_bb(
1098
+ out,
1099
+ batch,
1100
+ self.config.fape_backbone,
1101
+ ),
1102
+ "fape_sidechain": lambda: fape_sidechain(
1103
+ out,
1104
+ batch,
1105
+ self.config.fape_sidechain,
1106
+ ),
1107
+ "fape_interface": lambda: fape_interface(
1108
+ out,
1109
+ batch,
1110
+ self.config.fape_interface,
1111
+ ),
1112
+ "plddt_loss": lambda: lddt_loss(
1113
+ logits=out["lddt_logits"],
1114
+ all_atom_pred_pos=out["final_atom_positions"],
1115
+ **{**batch, **self.config.plddt_loss},
1116
+ ),
1117
+ "supervised_chi": lambda: supervised_chi_loss(
1118
+ out["sm"]["angles"],
1119
+ out["sm"]["unnormalized_angles"],
1120
+ **{**batch, **self.config.supervised_chi},
1121
+ ),
1122
+ }
1123
+
1124
+ if self.config.chain_center_of_mass.enabled:
1125
+ loss_fns["chain_center_of_mass"] = lambda: chain_center_of_mass_loss(
1126
+ all_atom_pred_pos=out["final_atom_positions"],
1127
+ **{**batch, **self.config.chain_center_of_mass},
1128
+ )
1129
+
1130
+ cum_loss = 0.
1131
+ losses = {}
1132
+ loss_time_took = {}
1133
+ for loss_name, loss_fn in loss_fns.items():
1134
+ start_time = time.time()
1135
+ weight = self.config[loss_name].weight
1136
+ loss = loss_fn()
1137
+ if torch.isnan(loss) or torch.isinf(loss):
1138
+ # for k,v in batch.items():
1139
+ # if torch.any(torch.isnan(v)) or torch.any(torch.isinf(v)):
1140
+ # logging.warning(f"{k}: is nan")
1141
+ # logging.warning(f"{loss_name}: {loss}")
1142
+ logging.warning(f"{loss_name} loss is NaN. Skipping...")
1143
+ loss = loss.new_tensor(0., requires_grad=True)
1144
+ # else:
1145
+ cum_loss = cum_loss + weight * loss
1146
+ losses[loss_name] = loss.detach().clone()
1147
+ loss_time_took[loss_name] = time.time() - start_time
1148
+ losses["unscaled_loss"] = cum_loss.detach().clone()
1149
+ # print("loss took: ", round(time.time() % 10000, 3),
1150
+ # sorted(loss_time_took.items(), key=lambda x: x[1], reverse=True))
1151
+
1152
+ # Scale the loss by the square root of the minimum of the crop size and
1153
+ # the (average) sequence length. See subsection 1.9.
1154
+ seq_len = torch.mean(batch["seq_length"].float())
1155
+ crop_len = batch["aatype"].shape[-1]
1156
+ cum_loss = cum_loss * torch.sqrt(min(seq_len, crop_len))
1157
+
1158
+ losses["loss"] = cum_loss.detach().clone()
1159
+
1160
+ if not _return_breakdown:
1161
+ return cum_loss
1162
+
1163
+ return cum_loss, losses
1164
+
1165
+ def forward(self, out, batch, _return_breakdown=False):
1166
+ if not _return_breakdown:
1167
+ cum_loss = self.loss(out, batch, _return_breakdown)
1168
+ return cum_loss
1169
+ else:
1170
+ cum_loss, losses = self.loss(out, batch, _return_breakdown)
1171
+ return cum_loss, losses
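A hedged, self-contained sketch of the aggregation pattern used by AlphaFoldLoss.loss above: lazily computed terms, NaN/Inf terms zeroed out, a weighted sum, then scaling by sqrt(min(seq_len, crop_len)) per subsection 1.9 of the AlphaFold 2 supplement (the term names and weights here are illustrative).

import torch
terms = {"distogram": torch.tensor(1.2), "fape_backbone": torch.tensor(float("nan"))}
weights = {"distogram": 0.3, "fape_backbone": 0.5}
cum_loss = torch.tensor(0.0)
for name, value in terms.items():
    if torch.isnan(value) or torch.isinf(value):
        value = value.new_tensor(0.0, requires_grad=True)  # skip the bad term
    cum_loss = cum_loss + weights[name] * value
seq_len, crop_len = torch.tensor(300.0), torch.tensor(256.0)
cum_loss = cum_loss * torch.sqrt(torch.minimum(seq_len, crop_len))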
dockformer/utils/lr_schedulers.py ADDED
@@ -0,0 +1,82 @@
1
+ import torch
2
+
3
+
4
+ class AlphaFoldLRScheduler(torch.optim.lr_scheduler._LRScheduler):
5
+ """ Implements the learning rate schedule defined in the AlphaFold 2
6
+ supplement. A linear warmup is followed by a plateau at the maximum
7
+ learning rate and then exponential decay.
8
+
9
+ Note that the initial learning rate of the optimizer in question is
10
+ ignored; use this class' base_lr parameter to specify the starting
11
+ point of the warmup.
12
+ """
13
+ def __init__(self,
14
+ optimizer,
15
+ last_epoch: int = -1,
16
+ verbose: bool = False,
17
+ base_lr: float = 0.,
18
+ max_lr: float = 0.001,
19
+ warmup_no_steps: int = 1000,
20
+ start_decay_after_n_steps: int = 50000,
21
+ decay_every_n_steps: int = 50000,
22
+ decay_factor: float = 0.95,
23
+ ):
24
+ step_counts = {
25
+ "warmup_no_steps": warmup_no_steps,
26
+ "start_decay_after_n_steps": start_decay_after_n_steps,
27
+ }
28
+
29
+ for k,v in step_counts.items():
30
+ if(v < 0):
31
+ raise ValueError(f"{k} must be nonnegative")
32
+
33
+ if(warmup_no_steps > start_decay_after_n_steps):
34
+ raise ValueError(
35
+ "warmup_no_steps must not exceed start_decay_after_n_steps"
36
+ )
37
+
38
+ self.optimizer = optimizer
39
+ self.last_epoch = last_epoch
40
+ self.verbose = verbose
41
+ self.base_lr = base_lr
42
+ self.max_lr = max_lr
43
+ self.warmup_no_steps = warmup_no_steps
44
+ self.start_decay_after_n_steps = start_decay_after_n_steps
45
+ self.decay_every_n_steps = decay_every_n_steps
46
+ self.decay_factor = decay_factor
47
+
48
+ super(AlphaFoldLRScheduler, self).__init__(
49
+ optimizer,
50
+ last_epoch=last_epoch,
51
+ verbose=verbose,
52
+ )
53
+
54
+ def state_dict(self):
55
+ state_dict = {
56
+ k:v for k,v in self.__dict__.items() if k not in ["optimizer"]
57
+ }
58
+
59
+ return state_dict
60
+
61
+ def load_state_dict(self, state_dict):
62
+ self.__dict__.update(state_dict)
63
+
64
+ def get_lr(self):
65
+ if(not self._get_lr_called_within_step):
66
+ raise RuntimeError(
67
+ "To get the last learning rate computed by the scheduler, use "
68
+ "get_last_lr()"
69
+ )
70
+
71
+ step_no = self.last_epoch
72
+
73
+ if(step_no <= self.warmup_no_steps):
74
+ lr = self.base_lr + (step_no / self.warmup_no_steps) * self.max_lr
75
+ elif(step_no > self.start_decay_after_n_steps):
76
+ steps_since_decay = step_no - self.start_decay_after_n_steps
77
+ exp = (steps_since_decay // self.decay_every_n_steps) + 1
78
+ lr = self.max_lr * (self.decay_factor ** exp)
79
+ else: # plateau
80
+ lr = self.max_lr
81
+
82
+ return [lr for group in self.optimizer.param_groups]
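A hedged usage sketch; per the docstring, the optimizer's own initial lr is ignored, and the schedule ramps to max_lr over the warmup, plateaus, then decays by decay_factor every decay_every_n_steps.

import torch
model = torch.nn.Linear(4, 4)
opt = torch.optim.Adam(model.parameters(), lr=0.0)  # this lr is ignored
sched = AlphaFoldLRScheduler(
    opt, max_lr=0.001, warmup_no_steps=10,
    start_decay_after_n_steps=100, decay_every_n_steps=50,
)
for _ in range(120):
    opt.step()
    sched.step()
# lr reaches 1e-3 by step 10, holds through step 100, then shrinks by 0.95 per 50 steps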
dockformer/utils/precision_utils.py ADDED
@@ -0,0 +1,23 @@
1
+ # Copyright 2022 AlQuraishi Laboratory
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import importlib
15
+
16
+ import torch
17
+
18
+ def is_fp16_enabled():
19
+ # Autocast world
20
+ fp16_enabled = torch.get_autocast_gpu_dtype() == torch.float16
21
+ fp16_enabled = fp16_enabled and torch.is_autocast_enabled()
22
+
23
+ return fp16_enabled
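A hedged usage sketch: the check reflects the active autocast context, so it is False outside autocast and True inside a float16 CUDA autocast region.

import torch
assert not is_fp16_enabled()
if torch.cuda.is_available():
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        assert is_fp16_enabled()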