jgpeters committed
Commit
1ab0116
1 Parent(s): bfdf7ef

Upload 6 files

Files changed (6)
  1. metl/__init__.py +2 -0
  2. metl/encode.py +58 -0
  3. metl/main.py +139 -0
  4. metl/models.py +1064 -0
  5. metl/relative_attention.py +586 -0
  6. metl/structure.py +184 -0
metl/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .main import *
+ __version__ = "0.1"
metl/encode.py ADDED
@@ -0,0 +1,58 @@
+ """ Encodes data in different formats """
+ from enum import Enum, auto
+
+ import numpy as np
+
+
+ class Encoding(Enum):
+     INT_SEQS = auto()
+     ONE_HOT = auto()
+
+
+ class DataEncoder:
+     chars = ["*", "A", "C", "D", "E", "F", "G", "H", "I", "K", "L", "M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y"]
+     num_chars = len(chars)
+     mapping = {c: i for i, c in enumerate(chars)}
+
+     def __init__(self, encoding: Encoding = Encoding.INT_SEQS):
+         self.encoding = encoding
+
+     def _encode_from_int_seqs(self, seq_ints):
+         if self.encoding == Encoding.INT_SEQS:
+             return seq_ints
+         elif self.encoding == Encoding.ONE_HOT:
+             one_hot = np.eye(self.num_chars)[seq_ints]
+             return one_hot.astype(np.float32)
+
+     def encode_sequences(self, char_seqs):
+         seq_ints = []
+         for char_seq in char_seqs:
+             int_seq = [self.mapping[c] for c in char_seq]
+             seq_ints.append(int_seq)
+         seq_ints = np.array(seq_ints).astype(int)
+         return self._encode_from_int_seqs(seq_ints)
+
+     def encode_variants(self, wt, variants):
+         # convert wild type seq to integer encoding
+         wt_int = np.zeros(len(wt), dtype=np.uint8)
+         for i, c in enumerate(wt):
+             wt_int[i] = self.mapping[c]
+
+         # tile the wild-type seq
+         seq_ints = np.tile(wt_int, (len(variants), 1))
+
+         for i, variant in enumerate(variants):
+             # special handling if we want to encode the wild-type seq (it's already correct!)
+             if variant == "_wt":
+                 continue
+
+             # variants are a list of mutations [mutation1, mutation2, ....]
+             variant = variant.split(",")
+             for mutation in variant:
+                 # mutations are in the form <original char><position><replacement char>
+                 position = int(mutation[1:-1])
+                 replacement = self.mapping[mutation[-1]]
+                 seq_ints[i, position] = replacement
+
+         seq_ints = seq_ints.astype(int)
+         return self._encode_from_int_seqs(seq_ints)
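
As a quick illustration of the encoder above, here is a minimal sketch (not part of the commit) of encoding variants into one-hot arrays. The short wild-type sequence and mutation strings are hypothetical; positions are 0-based indices into the wild-type sequence, matching the parsing in encode_variants.

    from metl.encode import DataEncoder, Encoding

    encoder = DataEncoder(Encoding.ONE_HOT)
    wt = "MKTAYIAK"                  # hypothetical wild-type sequence
    variants = ["_wt", "K1A,A3V"]    # hypothetical mutation strings, 0-based positions
    encoded = encoder.encode_variants(wt, variants)
    print(encoded.shape)             # (2, 8, 21): 2 variants, seq length 8, 21 characters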
metl/main.py ADDED
@@ -0,0 +1,139 @@
+ import torch
+ import torch.hub
+
+ import metl.models as models
+ from metl.encode import DataEncoder, Encoding
+
+ UUID_URL_MAP = {
+     # global source models
+     "D72M9aEp": "https://zenodo.org/records/11051645/files/METL-G-20M-1D-D72M9aEp.pt?download=1",
+     "Nr9zCKpR": "https://zenodo.org/records/11051645/files/METL-G-20M-3D-Nr9zCKpR.pt?download=1",
+     "auKdzzwX": "https://zenodo.org/records/11051645/files/METL-G-50M-1D-auKdzzwX.pt?download=1",
+     "6PSAzdfv": "https://zenodo.org/records/11051645/files/METL-G-50M-3D-6PSAzdfv.pt?download=1",
+
+     # local source models
+     "8gMPQJy4": "https://zenodo.org/records/11051645/files/METL-L-2M-1D-GFP-8gMPQJy4.pt?download=1",
+     "Hr4GNHws": "https://zenodo.org/records/11051645/files/METL-L-2M-3D-GFP-Hr4GNHws.pt?download=1",
+     "8iFoiYw2": "https://zenodo.org/records/11051645/files/METL-L-2M-1D-DLG4_2022-8iFoiYw2.pt?download=1",
+     "kt5DdWTa": "https://zenodo.org/records/11051645/files/METL-L-2M-3D-DLG4_2022-kt5DdWTa.pt?download=1",
+     "DMfkjVzT": "https://zenodo.org/records/11051645/files/METL-L-2M-1D-GB1-DMfkjVzT.pt?download=1",
+     "epegcFiH": "https://zenodo.org/records/11051645/files/METL-L-2M-3D-GB1-epegcFiH.pt?download=1",
+     "kS3rUS7h": "https://zenodo.org/records/11051645/files/METL-L-2M-1D-GRB2-kS3rUS7h.pt?download=1",
+     "X7w83g6S": "https://zenodo.org/records/11051645/files/METL-L-2M-3D-GRB2-X7w83g6S.pt?download=1",
+     "UKebCQGz": "https://zenodo.org/records/11051645/files/METL-L-2M-1D-Pab1-UKebCQGz.pt?download=1",
+     "2rr8V4th": "https://zenodo.org/records/11051645/files/METL-L-2M-3D-Pab1-2rr8V4th.pt?download=1",
+     "PREhfC22": "https://zenodo.org/records/11051645/files/METL-L-2M-1D-TEM-1-PREhfC22.pt?download=1",
+     "9ASvszux": "https://zenodo.org/records/11051645/files/METL-L-2M-3D-TEM-1-9ASvszux.pt?download=1",
+     "HscFFkAb": "https://zenodo.org/records/11051645/files/METL-L-2M-1D-Ube4b-HscFFkAb.pt?download=1",
+     "H48oiNZN": "https://zenodo.org/records/11051645/files/METL-L-2M-3D-Ube4b-H48oiNZN.pt?download=1",
+
+     # metl bind source models
+     "K6mw24Rg": "https://zenodo.org/records/11051645/files/METL-BIND-2M-3D-GB1-STANDARD-K6mw24Rg.pt?download=1",
+     "Bo5wn2SG": "https://zenodo.org/records/11051645/files/METL-BIND-2M-3D-GB1-BINDING-Bo5wn2SG.pt?download=1",
+
+     # finetuned models from GFP design experiment
+     "YoQkzoLD": "https://zenodo.org/records/11051645/files/FT-METL-L-2M-1D-GFP-YoQkzoLD.pt?download=1",
+     "PEkeRuxb": "https://zenodo.org/records/11051645/files/FT-METL-L-2M-3D-GFP-PEkeRuxb.pt?download=1",
+
+ }
+
+ IDENT_UUID_MAP = {
+     # the keys should be all lowercase
+     "metl-g-20m-1d": "D72M9aEp",
+     "metl-g-20m-3d": "Nr9zCKpR",
+     "metl-g-50m-1d": "auKdzzwX",
+     "metl-g-50m-3d": "6PSAzdfv",
+
+     # GFP local source models
+     "metl-l-2m-1d-gfp": "8gMPQJy4",
+     "metl-l-2m-3d-gfp": "Hr4GNHws",
+
+     # DLG4 local source models
+     "metl-l-2m-1d-dlg4": "8iFoiYw2",
+     "metl-l-2m-3d-dlg4": "kt5DdWTa",
+
+     # GB1 local source models
+     "metl-l-2m-1d-gb1": "DMfkjVzT",
+     "metl-l-2m-3d-gb1": "epegcFiH",
+
+     # GRB2 local source models
+     "metl-l-2m-1d-grb2": "kS3rUS7h",
+     "metl-l-2m-3d-grb2": "X7w83g6S",
+
+     # Pab1 local source models
+     "metl-l-2m-1d-pab1": "UKebCQGz",
+     "metl-l-2m-3d-pab1": "2rr8V4th",
+
+     # TEM-1 local source models
+     "metl-l-2m-1d-tem-1": "PREhfC22",
+     "metl-l-2m-3d-tem-1": "9ASvszux",
+
+     # Ube4b local source models
+     "metl-l-2m-1d-ube4b": "HscFFkAb",
+     "metl-l-2m-3d-ube4b": "H48oiNZN",
+
+     # METL-Bind for GB1
+     "metl-bind-2m-3d-gb1-standard": "K6mw24Rg",
+     "metl-bind-2m-3d-gb1-binding": "Bo5wn2SG",
+
+     # GFP design models, giving them an ident
+     "metl-l-2m-1d-gfp-ft-design": "YoQkzoLD",
+     "metl-l-2m-3d-gfp-ft-design": "PEkeRuxb",
+
+ }
+
+
+ def download_checkpoint(uuid):
+     ckpt = torch.hub.load_state_dict_from_url(UUID_URL_MAP[uuid],
+                                               map_location="cpu", file_name=f"{uuid}.pt")
+     state_dict = ckpt["state_dict"]
+     hyper_parameters = ckpt["hyper_parameters"]
+
+     return state_dict, hyper_parameters
+
+
+ def _get_data_encoding(hparams):
+     if "encoding" in hparams and hparams["encoding"] == "int_seqs":
+         encoding = Encoding.INT_SEQS
+     elif "encoding" in hparams and hparams["encoding"] == "one_hot":
+         encoding = Encoding.ONE_HOT
+     elif (("encoding" in hparams and hparams["encoding"] == "auto") or "encoding" not in hparams) and \
+             hparams["model_name"] in ["transformer_encoder"]:
+         encoding = Encoding.INT_SEQS
+     else:
+         raise ValueError("Detected unsupported encoding in hyperparameters")
+
+     return encoding
+
+
+ def load_model_and_data_encoder(state_dict, hparams):
+     model = models.Model[hparams["model_name"]].cls(**hparams)
+     model.load_state_dict(state_dict)
+
+     data_encoder = DataEncoder(_get_data_encoding(hparams))
+
+     return model, data_encoder
+
+
+ def get_from_uuid(uuid):
+     if uuid in UUID_URL_MAP:
+         state_dict, hparams = download_checkpoint(uuid)
+         return load_model_and_data_encoder(state_dict, hparams)
+     else:
+         raise ValueError(f"UUID {uuid} not found in UUID_URL_MAP")
+
+
+ def get_from_ident(ident):
+     ident = ident.lower()
+     if ident in IDENT_UUID_MAP:
+         state_dict, hparams = download_checkpoint(IDENT_UUID_MAP[ident])
+         return load_model_and_data_encoder(state_dict, hparams)
+     else:
+         raise ValueError(f"Identifier {ident} not found in IDENT_UUID_MAP")
+
+
+ def get_from_checkpoint(ckpt_fn):
+     ckpt = torch.load(ckpt_fn, map_location="cpu")
+     state_dict = ckpt["state_dict"]
+     hyper_parameters = ckpt["hyper_parameters"]
+     return load_model_and_data_encoder(state_dict, hyper_parameters)
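
Taken together with metl/encode.py, a hedged usage sketch of these loaders might look like the following. The identifier is one key from IDENT_UUID_MAP; the wild-type sequence and variant string are placeholders, and a 1D model is chosen on the assumption that it does not require a structure file.

    import torch
    import metl

    # download the checkpoint from Zenodo and build the model + matching data encoder
    model, data_encoder = metl.get_from_ident("metl-l-2m-1d-gfp")
    model.eval()

    wt = "..."                # placeholder: full GFP wild-type sequence
    variants = ["E3K,G102S"]  # placeholder variant string, 0-based positions

    encoded = data_encoder.encode_variants(wt, variants)
    with torch.no_grad():
        predictions = model(torch.tensor(encoded))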
metl/models.py ADDED
@@ -0,0 +1,1064 @@
1
+ import collections
2
+ import math
3
+ from argparse import ArgumentParser
4
+ import enum
5
+ from os.path import isfile
6
+ from typing import List, Tuple, Optional
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from torch import Tensor
12
+
13
+ import metl.relative_attention as ra
14
+
15
+
16
+ def reset_parameters_helper(m: nn.Module):
17
+ """ helper function for resetting model parameters, meant to be used with model.apply() """
18
+
19
+ # the PyTorch MultiHeadAttention has a private function _reset_parameters()
20
+ # other layers have a public reset_parameters()... go figure
21
+ reset_parameters = getattr(m, "reset_parameters", None)
22
+ reset_parameters_private = getattr(m, "_reset_parameters", None)
23
+
24
+ if callable(reset_parameters) and callable(reset_parameters_private):
25
+ raise RuntimeError("Module has both public and private methods for resetting parameters. "
26
+ "This is unexpected... probably should just call the public one.")
27
+
28
+ if callable(reset_parameters):
29
+ m.reset_parameters()
30
+
31
+ if callable(reset_parameters_private):
32
+ m._reset_parameters()
33
+
34
+
35
+ class SequentialWithArgs(nn.Sequential):
36
+ def forward(self, x, **kwargs):
37
+ for module in self:
38
+ if isinstance(module, ra.RelativeTransformerEncoder) or isinstance(module, SequentialWithArgs):
39
+ # for relative transformer encoders, pass in kwargs (pdb_fn)
40
+ x = module(x, **kwargs)
41
+ else:
42
+ # for all modules, don't pass in kwargs
43
+ x = module(x)
44
+ return x
45
+
46
+
47
+ class PositionalEncoding(nn.Module):
48
+ # originally from https://pytorch.org/tutorials/beginner/transformer_tutorial.html
49
+ # they have since updated their implementation, but it is functionally equivalent
50
+ def __init__(self, d_model, dropout=0.1, max_len=5000):
51
+ super(PositionalEncoding, self).__init__()
52
+ self.dropout = nn.Dropout(p=dropout)
53
+
54
+ pe = torch.zeros(max_len, d_model)
55
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
56
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
57
+ pe[:, 0::2] = torch.sin(position * div_term)
58
+ pe[:, 1::2] = torch.cos(position * div_term)
59
+ # note the implementation on Pytorch's website expects [seq_len, batch_size, embedding_dim]
60
+ # however our data is in [batch_size, seq_len, embedding_dim] (i.e. batch_first)
61
+ # fixed by changing pe = pe.unsqueeze(0).transpose(0, 1) to pe = pe.unsqueeze(0)
62
+ # also down below, changing our indexing into the position encoding to reflect new dimensions
63
+ # pe = pe.unsqueeze(0).transpose(0, 1)
64
+ pe = pe.unsqueeze(0)
65
+ self.register_buffer('pe', pe)
66
+
67
+ def forward(self, x, **kwargs):
68
+ # note the implementation on Pytorch's website expects [seq_len, batch_size, embedding_dim]
69
+ # however our data is in [batch_size, seq_len, embedding_dim] (i.e. batch_first)
70
+ # fixed by changing x = x + self.pe[:x.size(0)] to x = x + self.pe[:, :x.size(1), :]
71
+ # x = x + self.pe[:x.size(0), :]
72
+ x = x + self.pe[:, :x.size(1), :]
73
+ return self.dropout(x)
74
+
75
+
76
+ class ScaledEmbedding(nn.Module):
77
+ # https://pytorch.org/tutorials/beginner/translation_transformer.html
78
+ # a helper function for embedding that scales by sqrt(d_model) in the forward()
79
+ # makes it so we don't have to do the scaling in the main AttnModel forward()
80
+
81
+ # todo: be aware of embedding scaling factor
82
+ # regarding the scaling factor, it's unclear exactly what the purpose is and whether it is needed
83
+ # there are several theories on why it is used, and it shows up in all the transformer reference implementations
84
+ # https://datascience.stackexchange.com/questions/87906/transformer-model-why-are-word-embeddings-scaled-before-adding-positional-encod
85
+ # 1. Has something to do with weight sharing between the embedding and the decoder output
86
+ # 2. Scales up the embeddings so the signal doesn't get overwhelmed when adding the absolute positional encoding
87
+ # 3. It cancels out with the scaling factor in scaled dot product attention, and helps make the model robust
88
+ # to the choice of embedding_len
89
+ # 4. It's not actually needed
90
+
91
+ # Regarding #1, not really sure about this. In section 3.4 of attention is all you need,
92
+ # that's where they state they multiply the embedding weights by sqrt(d_model), and the context is that they
93
+ # are sharing the same weight matrix between the two embedding layers and the pre-softmax linear transformation.
94
+ # there may be a reason that we want those weights scaled differently for the embedding layers vs. the linear
95
+ # transformation. It might have something to do with the scale at which embedding weights are initialized
96
+ # is more appropriate for the decoder linear transform vs how they are used in the attention function. Might have
97
+ # something to do with computing the correct next-token probabilities. Overall, I'm really not sure about this,
98
+ # but we aren't using a decoder anyway. So if this is the reason, then we don't need to perform the multiply.
99
+
100
+ # Regarding #2, it seems like in one implementation of transformers (fairseq), the sinusoidal positional encoding
101
+ # has a range of (-1.0, 1.0), but the word embedding are initialized with mean 0 and s.d embedding_dim ** -0.5,
102
+ # which for embedding_dim=512, is a range closer to (-0.10, 0.10). Thus, the positional embedding would overwhelm
103
+ # the word embeddings when they are added together. The scaling factor increases the signal of the word embeddings.
104
+ # for embedding_dim=512, it scales word embeddings by 22, increasing range of the word embeddings to (-2.2, 2.2).
105
+ # link to fairseq implementation, search for nn.init to see them do the initialization
106
+ # https://fairseq.readthedocs.io/en/v0.7.1/_modules/fairseq/models/transformer.html
107
+ #
108
+ # For PyTorch, PyTorch initializes nn.Embedding with a standard normal distribution mean 0, variance 1: N(0,1).
109
+ # this puts the range for the word embeddings around (-3, 3). the pytorch implementation for positional encoding
110
+ # also has a range of (-1.0, 1.0). So already, these are much closer in scale, and it doesn't seem like we need
111
+ # to increase the scale of the word embeddings. However, PyTorch example still multiply by the scaling factor
112
+ # unclear whether this is just a carryover that is not actually needed, or if there is a different reason
113
+ #
114
+ # EDIT! I just realized that even though nn.Embedding defaults to a range of around (-3, 3), the PyTorch
115
+ # transformer example actually re-initializes them using a uniform distribution in the range of (-0.1, 0.1)
116
+ # that makes it very similar to the fairseq implementation, so the scaling factor that PyTorch uses actually would
117
+ # bring the word embedding and positional encodings much closer in scale. So this could be the reason why pytorch
118
+ # does it
119
+
120
+ # Regarding #3, I don't think so. Firstly, does it actually cancel there? Secondly, the purpose of the scaling
121
+ # factor in scaled dot product attention, according to attention is all you need, is to counteract dot products
122
+ # that are very high in magnitude due to choice of large embedding length (aka d_k). The problem with high magnitude
123
+ # dot products is that potentially, the softmax is pushed into regions where it has extremely small gradients,
124
+ # making learning difficult. If the scaling factor in the embedding was meant to counteract the scaling factor in
125
+ # scaled dot product attention, then what would be the point of doing all that?
126
+
127
+ # Regarding #4, I don't think the scaling will have any effects in practice, it's probably not needed
128
+
129
+ # Overall, I think #2 is the most likely reason why this scaling is performed. In theory, I think
130
+ # even if the scaling wasn't performed, the network might learn to up-scale the word embedding weights to increase
131
+ # word embedding signal vs. the position signal on its own. Another question I have is why not just initialize
132
+ # the embedding weights to have higher initial values? Why put it in the range (-0.1, 0.1)?
133
+ #
134
+ # The fact that most implementations have this scaling concerns me, makes me think I might be missing something.
135
+ # For our purposes, we can train a couple models to see if scaling has any positive or negative effect.
136
+ # Still need to think about potential effects of this scaling on relative position embeddings.
137
+
138
+ def __init__(self, num_embeddings: int, embedding_dim: int, scale: bool):
139
+ super(ScaledEmbedding, self).__init__()
140
+ self.embedding = nn.Embedding(num_embeddings, embedding_dim)
141
+ self.emb_size = embedding_dim
142
+ self.embed_scale = math.sqrt(self.emb_size)
143
+
144
+ self.scale = scale
145
+
146
+ self.init_weights()
147
+
148
+ def init_weights(self):
149
+ # todo: not sure why PyTorch example initializes weights like this
150
+ # might have something to do with word embedding scaling factor (see above)
151
+ # could also just try the default weight initialization for nn.Embedding()
152
+ init_range = 0.1
153
+ self.embedding.weight.data.uniform_(-init_range, init_range)
154
+
155
+ def forward(self, tokens: Tensor, **kwargs):
156
+ if self.scale:
157
+ return self.embedding(tokens.long()) * self.embed_scale
158
+ else:
159
+ return self.embedding(tokens.long())
160
+
161
+
162
+ class FCBlock(nn.Module):
163
+ """ a fully connected block with options for batchnorm and dropout
164
+ can extend in the future with option for different activation, etc """
165
+
166
+ def __init__(self,
167
+ in_features: int,
168
+ num_hidden_nodes: int = 64,
169
+ use_batchnorm: bool = False,
170
+ use_layernorm: bool = False,
171
+ norm_before_activation: bool = False,
172
+ use_dropout: bool = False,
173
+ dropout_rate: float = 0.2,
174
+ activation: str = "relu"):
175
+
176
+ super().__init__()
177
+
178
+ if use_batchnorm and use_layernorm:
179
+ raise ValueError("Only one of use_batchnorm or use_layernorm can be set to True")
180
+
181
+ self.use_batchnorm = use_batchnorm
182
+ self.use_dropout = use_dropout
183
+ self.use_layernorm = use_layernorm
184
+ self.norm_before_activation = norm_before_activation
185
+
186
+ self.fc = nn.Linear(in_features=in_features, out_features=num_hidden_nodes)
187
+
188
+ self.activation = get_activation_fn(activation, functional=False)
189
+
190
+ if use_batchnorm:
191
+ self.norm = nn.BatchNorm1d(num_hidden_nodes)
192
+
193
+ if use_layernorm:
194
+ self.norm = nn.LayerNorm(num_hidden_nodes)
195
+
196
+ if use_dropout:
197
+ self.dropout = nn.Dropout(p=dropout_rate)
198
+
199
+ def forward(self, x, **kwargs):
200
+ x = self.fc(x)
201
+
202
+ # norm can be before or after activation, using flag
203
+ if (self.use_batchnorm or self.use_layernorm) and self.norm_before_activation:
204
+ x = self.norm(x)
205
+
206
+ x = self.activation(x)
207
+
208
+ # batchnorm being applied after activation, there is some discussion on this online
209
+ if (self.use_batchnorm or self.use_layernorm) and not self.norm_before_activation:
210
+ x = self.norm(x)
211
+
212
+ # dropout being applied last
213
+ if self.use_dropout:
214
+ x = self.dropout(x)
215
+
216
+ return x
217
+
218
+
219
+ class TaskSpecificPredictionLayers(nn.Module):
220
+ """ Constructs num_tasks [dense(num_hidden_nodes)+relu+dense(1)] layers, each independently transforming input
221
+ into a single output node. All num_tasks outputs are then concatenated into a single tensor. """
222
+
223
+ # todo: the independent layers are run in sequence rather than in parallel, causing a slowdown that
224
+ # scales with the number of tasks. might be able to run in parallel by hacking convolution operation
225
+ # https://stackoverflow.com/questions/58374980/run-multiple-models-of-an-ensemble-in-parallel-with-pytorch
226
+ # https://github.com/pytorch/pytorch/issues/54147
227
+ # https://github.com/pytorch/pytorch/issues/36459
228
+
229
+ def __init__(self,
230
+ num_tasks: int,
231
+ in_features: int,
232
+ num_hidden_nodes: int = 64,
233
+ use_batchnorm: bool = False,
234
+ use_dropout: bool = False,
235
+ dropout_rate: float = 0.2,
236
+ activation: str = "relu"):
237
+
238
+ super().__init__()
239
+
240
+ # each task-specific layer outputs a single node,
241
+ # which can be combined with torch.cat into prediction vector
242
+ self.task_specific_pred_layers = nn.ModuleList()
243
+ for i in range(num_tasks):
244
+ layers = [FCBlock(in_features=in_features,
245
+ num_hidden_nodes=num_hidden_nodes,
246
+ use_batchnorm=use_batchnorm,
247
+ use_dropout=use_dropout,
248
+ dropout_rate=dropout_rate,
249
+ activation=activation),
250
+ nn.Linear(in_features=num_hidden_nodes, out_features=1)]
251
+ self.task_specific_pred_layers.append(nn.Sequential(*layers))
252
+
253
+ def forward(self, x, **kwargs):
254
+ # run each task-specific layer and concatenate outputs into a single output vector
255
+ task_specific_outputs = []
256
+ for layer in self.task_specific_pred_layers:
257
+ task_specific_outputs.append(layer(x))
258
+
259
+ output = torch.cat(task_specific_outputs, dim=1)
260
+ return output
261
+
262
+
263
+ class GlobalAveragePooling(nn.Module):
264
+ """ helper class for global average pooling """
265
+
266
+ def __init__(self, dim=1):
267
+ super().__init__()
268
+ # our data is in [batch_size, sequence_length, embedding_length]
269
+ # with global pooling, we want to pool over the sequence dimension (dim=1)
270
+ self.dim = dim
271
+
272
+ def forward(self, x, **kwargs):
273
+ return torch.mean(x, dim=self.dim)
274
+
275
+
276
+ class CLSPooling(nn.Module):
277
+ """ helper class for CLS token extraction """
278
+
279
+ def __init__(self, cls_position=0):
280
+ super().__init__()
281
+
282
+ # the position of the CLS token in the sequence dimension
283
+ # currently, the CLS token is in the first position, but may move it to the last position
284
+ self.cls_position = cls_position
285
+
286
+ def forward(self, x, **kwargs):
287
+ # assumes input is in [batch_size, sequence_len, embedding_len]
288
+ # thus sequence dimension is dimension 1
289
+ return x[:, self.cls_position, :]
290
+
291
+
292
+ class TransformerEncoderWrapper(nn.TransformerEncoder):
293
+ """ wrapper around PyTorch's TransformerEncoder that re-initializes layer parameters,
294
+ so each transformer encoder layer has a different initialization """
295
+
296
+ # todo: PyTorch is changing its transformer API... check up on and see if there is a better way
297
+ def __init__(self, encoder_layer, num_layers, norm=None, reset_params=True):
298
+ super().__init__(encoder_layer, num_layers, norm)
299
+ if reset_params:
300
+ self.apply(reset_parameters_helper)
301
+
302
+
303
+ class AttnModel(nn.Module):
304
+ # https://pytorch.org/tutorials/beginner/transformer_tutorial.html
305
+
306
+ @staticmethod
307
+ def add_model_specific_args(parent_parser):
308
+ parser = ArgumentParser(parents=[parent_parser], add_help=False)
309
+
310
+ parser.add_argument('--pos_encoding', type=str, default="absolute",
311
+ choices=["none", "absolute", "relative", "relative_3D"],
312
+ help="what type of positional encoding to use")
313
+ parser.add_argument('--pos_encoding_dropout', type=float, default=0.1,
314
+ help="out much dropout to use in positional encoding, for pos_encoding==absolute")
315
+ parser.add_argument('--clipping_threshold', type=int, default=3,
316
+ help="clipping threshold for relative position embedding, for relative and relative_3D")
317
+ parser.add_argument('--contact_threshold', type=int, default=7,
318
+ help="threshold, in angstroms, for contact map, for relative_3D")
319
+ parser.add_argument('--embedding_len', type=int, default=128)
320
+ parser.add_argument('--num_heads', type=int, default=2)
321
+ parser.add_argument('--num_hidden', type=int, default=64)
322
+ parser.add_argument('--num_enc_layers', type=int, default=2)
323
+ parser.add_argument('--enc_layer_dropout', type=float, default=0.1)
324
+ parser.add_argument('--use_final_encoder_norm', action="store_true", default=False)
325
+
326
+ parser.add_argument('--global_average_pooling', action="store_true", default=False)
327
+ parser.add_argument('--cls_pooling', action="store_true", default=False)
328
+
329
+ parser.add_argument('--use_task_specific_layers', action="store_true", default=False,
330
+ help="exclusive with use_final_hidden_layer; takes priority over use_final_hidden_layer"
331
+ " if both flags are set")
332
+ parser.add_argument('--task_specific_hidden_nodes', type=int, default=64)
333
+ parser.add_argument('--use_final_hidden_layer', action="store_true", default=False)
334
+ parser.add_argument('--final_hidden_size', type=int, default=64)
335
+ parser.add_argument('--use_final_hidden_layer_norm', action="store_true", default=False)
336
+ parser.add_argument('--final_hidden_layer_norm_before_activation', action="store_true", default=False)
337
+ parser.add_argument('--use_final_hidden_layer_dropout', action="store_true", default=False)
338
+ parser.add_argument('--final_hidden_layer_dropout_rate', type=float, default=0.2)
339
+
340
+ parser.add_argument('--activation', type=str, default="relu",
341
+ help="activation function used for all activations in the network")
342
+ return parser
343
+
344
+ def __init__(self,
345
+ # data args
346
+ num_tasks: int,
347
+ aa_seq_len: int,
348
+ num_tokens: int,
349
+ # transformer encoder model args
350
+ pos_encoding: str = "absolute",
351
+ pos_encoding_dropout: float = 0.1,
352
+ clipping_threshold: int = 3,
353
+ contact_threshold: int = 7,
354
+ pdb_fns: List[str] = None,
355
+ embedding_len: int = 64,
356
+ num_heads: int = 2,
357
+ num_hidden: int = 64,
358
+ num_enc_layers: int = 2,
359
+ enc_layer_dropout: float = 0.1,
360
+ use_final_encoder_norm: bool = False,
361
+ # pooling to fixed-length representation
362
+ global_average_pooling: bool = True,
363
+ cls_pooling: bool = False,
364
+ # prediction layers
365
+ use_task_specific_layers: bool = False,
366
+ task_specific_hidden_nodes: int = 64,
367
+ use_final_hidden_layer: bool = False,
368
+ final_hidden_size: int = 64,
369
+ use_final_hidden_layer_norm: bool = False,
370
+ final_hidden_layer_norm_before_activation: bool = False,
371
+ use_final_hidden_layer_dropout: bool = False,
372
+ final_hidden_layer_dropout_rate: float = 0.2,
373
+ # activation function
374
+ activation: str = "relu",
375
+ *args, **kwargs):
376
+
377
+ super().__init__()
378
+
379
+ # store embedding length for use in the forward function
380
+ self.embedding_len = embedding_len
381
+ self.aa_seq_len = aa_seq_len
382
+
383
+ # build up layers
384
+ layers = collections.OrderedDict()
385
+
386
+ # amino acid embedding
387
+ layers["embedder"] = ScaledEmbedding(num_embeddings=num_tokens, embedding_dim=embedding_len, scale=True)
388
+
389
+ # absolute positional encoding
390
+ if pos_encoding == "absolute":
391
+ layers["pos_encoder"] = PositionalEncoding(embedding_len, dropout=pos_encoding_dropout, max_len=512)
392
+
393
+ # transformer encoder layer for none or absolute positional encoding
394
+ if pos_encoding in ["none", "absolute"]:
395
+ encoder_layer = torch.nn.TransformerEncoderLayer(d_model=embedding_len,
396
+ nhead=num_heads,
397
+ dim_feedforward=num_hidden,
398
+ dropout=enc_layer_dropout,
399
+ activation=get_activation_fn(activation),
400
+ norm_first=True,
401
+ batch_first=True)
402
+
403
+ # layer norm that is used after the transformer encoder layers
404
+ # if the norm_first is False, this is *redundant* and not needed
405
+ # but if norm_first is True, this can be used to normalize outputs from
406
+ # the transformer encoder before inputting to the final fully connected layer
407
+ encoder_norm = None
408
+ if use_final_encoder_norm:
409
+ encoder_norm = nn.LayerNorm(embedding_len)
410
+
411
+ layers["tr_encoder"] = TransformerEncoderWrapper(encoder_layer=encoder_layer,
412
+ num_layers=num_enc_layers,
413
+ norm=encoder_norm)
414
+
415
+ # transformer encoder layer for relative position encoding
416
+ elif pos_encoding in ["relative", "relative_3D"]:
417
+ relative_encoder_layer = ra.RelativeTransformerEncoderLayer(d_model=embedding_len,
418
+ nhead=num_heads,
419
+ pos_encoding=pos_encoding,
420
+ clipping_threshold=clipping_threshold,
421
+ contact_threshold=contact_threshold,
422
+ pdb_fns=pdb_fns,
423
+ dim_feedforward=num_hidden,
424
+ dropout=enc_layer_dropout,
425
+ activation=get_activation_fn(activation),
426
+ norm_first=True)
427
+
428
+ encoder_norm = None
429
+ if use_final_encoder_norm:
430
+ encoder_norm = nn.LayerNorm(embedding_len)
431
+
432
+ layers["tr_encoder"] = ra.RelativeTransformerEncoder(encoder_layer=relative_encoder_layer,
433
+ num_layers=num_enc_layers,
434
+ norm=encoder_norm)
435
+
436
+ # GLOBAL AVERAGE POOLING OR CLS TOKEN
437
+ # set up the layers and output shapes (i.e. input shapes for the pred layer)
438
+ if global_average_pooling:
439
+ # pool over the sequence dimension
440
+ layers["avg_pooling"] = GlobalAveragePooling(dim=1)
441
+ pred_layer_input_features = embedding_len
442
+ elif cls_pooling:
443
+ layers["cls_pooling"] = CLSPooling(cls_position=0)
444
+ pred_layer_input_features = embedding_len
445
+ else:
446
+ # no global average pooling or CLS token
447
+ # sequence dimension is still there, just flattened
448
+ layers["flatten"] = nn.Flatten()
449
+ pred_layer_input_features = embedding_len * aa_seq_len
450
+
451
+ # PREDICTION
452
+ if use_task_specific_layers:
453
+ # task specific prediction layers (nonlinear transform for each task)
454
+ layers["prediction"] = TaskSpecificPredictionLayers(num_tasks=num_tasks,
455
+ in_features=pred_layer_input_features,
456
+ num_hidden_nodes=task_specific_hidden_nodes,
457
+ activation=activation)
458
+ elif use_final_hidden_layer:
459
+ # combined prediction linear (linear transform for each task)
460
+ layers["fc1"] = FCBlock(in_features=pred_layer_input_features,
461
+ num_hidden_nodes=final_hidden_size,
462
+ use_batchnorm=False,
463
+ use_layernorm=use_final_hidden_layer_norm,
464
+ norm_before_activation=final_hidden_layer_norm_before_activation,
465
+ use_dropout=use_final_hidden_layer_dropout,
466
+ dropout_rate=final_hidden_layer_dropout_rate,
467
+ activation=activation)
468
+
469
+ layers["prediction"] = nn.Linear(in_features=final_hidden_size, out_features=num_tasks)
470
+ else:
471
+ layers["prediction"] = nn.Linear(in_features=pred_layer_input_features, out_features=num_tasks)
472
+
473
+ # FINAL MODEL
474
+ self.model = SequentialWithArgs(layers)
475
+
476
+ def forward(self, x, **kwargs):
477
+ return self.model(x, **kwargs)
478
+
479
+
480
+ class Transpose(nn.Module):
481
+ """ helper layer to swap data from (batch, seq, channels) to (batch, channels, seq)
482
+ used as a helper in the convolutional network which pytorch defaults to channels-first """
483
+
484
+ def __init__(self, dims: Tuple[int, ...] = (1, 2)):
485
+ super().__init__()
486
+ self.dims = dims
487
+
488
+ def forward(self, x, **kwargs):
489
+ x = x.transpose(*self.dims).contiguous()
490
+ return x
491
+
492
+
493
+ def conv1d_out_shape(seq_len, kernel_size, stride=1, pad=0, dilation=1):
494
+ return (seq_len + (2 * pad) - (dilation * (kernel_size - 1)) - 1) // stride + 1
495
+
496
+
497
+ class ConvBlock(nn.Module):
498
+ def __init__(self,
499
+ in_channels: int,
500
+ out_channels: int,
501
+ kernel_size: int,
502
+ dilation: int = 1,
503
+ padding: str = "same",
504
+ use_batchnorm: bool = False,
505
+ use_layernorm: bool = False,
506
+ norm_before_activation: bool = False,
507
+ use_dropout: bool = False,
508
+ dropout_rate: float = 0.2,
509
+ activation: str = "relu"):
510
+
511
+ super().__init__()
512
+
513
+ if use_batchnorm and use_layernorm:
514
+ raise ValueError("Only one of use_batchnorm or use_layernorm can be set to True")
515
+
516
+ self.use_batchnorm = use_batchnorm
517
+ self.use_layernorm = use_layernorm
518
+ self.norm_before_activation = norm_before_activation
519
+ self.use_dropout = use_dropout
520
+
521
+ self.conv = nn.Conv1d(in_channels=in_channels,
522
+ out_channels=out_channels,
523
+ kernel_size=kernel_size,
524
+ padding=padding,
525
+ dilation=dilation)
526
+
527
+ self.activation = get_activation_fn(activation, functional=False)
528
+
529
+ if use_batchnorm:
530
+ self.norm = nn.BatchNorm1d(out_channels)
531
+
532
+ if use_layernorm:
533
+ self.norm = nn.LayerNorm(out_channels)
534
+
535
+ if use_dropout:
536
+ self.dropout = nn.Dropout(p=dropout_rate)
537
+
538
+ def forward(self, x, **kwargs):
539
+ x = self.conv(x)
540
+
541
+ # norm can be before or after activation, using flag
542
+ if self.use_batchnorm and self.norm_before_activation:
543
+ x = self.norm(x)
544
+ elif self.use_layernorm and self.norm_before_activation:
545
+ x = self.norm(x.transpose(1, 2)).transpose(1, 2)
546
+
547
+ x = self.activation(x)
548
+
549
+ # batchnorm being applied after activation, there is some discussion on this online
550
+ if self.use_batchnorm and not self.norm_before_activation:
551
+ x = self.norm(x)
552
+ elif self.use_layernorm and not self.norm_before_activation:
553
+ x = self.norm(x.transpose(1, 2)).transpose(1, 2)
554
+
555
+ # dropout being applied after batchnorm, there is some discussion on this online
556
+ if self.use_dropout:
557
+ x = self.dropout(x)
558
+
559
+ return x
560
+
561
+
562
+ class ConvModel2(nn.Module):
563
+ """ convolutional source model that supports padded inputs, pooling, etc """
564
+
565
+ @staticmethod
566
+ def add_model_specific_args(parent_parser):
567
+ parser = ArgumentParser(parents=[parent_parser], add_help=False)
568
+ parser.add_argument('--use_embedding', action="store_true", default=False)
569
+ parser.add_argument('--embedding_len', type=int, default=128)
570
+
571
+ parser.add_argument('--num_conv_layers', type=int, default=1)
572
+ parser.add_argument('--kernel_sizes', type=int, nargs="+", default=[7])
573
+ parser.add_argument('--out_channels', type=int, nargs="+", default=[128])
574
+ parser.add_argument('--dilations', type=int, nargs="+", default=[1])
575
+ parser.add_argument('--padding', type=str, default="valid", choices=["valid", "same"])
576
+ parser.add_argument('--use_conv_layer_norm', action="store_true", default=False)
577
+ parser.add_argument('--conv_layer_norm_before_activation', action="store_true", default=False)
578
+ parser.add_argument('--use_conv_layer_dropout', action="store_true", default=False)
579
+ parser.add_argument('--conv_layer_dropout_rate', type=float, default=0.2)
580
+
581
+ parser.add_argument('--global_average_pooling', action="store_true", default=False)
582
+
583
+ parser.add_argument('--use_task_specific_layers', action="store_true", default=False)
584
+ parser.add_argument('--task_specific_hidden_nodes', type=int, default=64)
585
+ parser.add_argument('--use_final_hidden_layer', action="store_true", default=False)
586
+ parser.add_argument('--final_hidden_size', type=int, default=64)
587
+ parser.add_argument('--use_final_hidden_layer_norm', action="store_true", default=False)
588
+ parser.add_argument('--final_hidden_layer_norm_before_activation', action="store_true", default=False)
589
+ parser.add_argument('--use_final_hidden_layer_dropout', action="store_true", default=False)
590
+ parser.add_argument('--final_hidden_layer_dropout_rate', type=float, default=0.2)
591
+
592
+ parser.add_argument('--activation', type=str, default="relu",
593
+ help="activation function used for all activations in the network")
594
+
595
+ return parser
596
+
597
+ def __init__(self,
598
+ # data
599
+ num_tasks: int,
600
+ aa_seq_len: int,
601
+ aa_encoding_len: int,
602
+ num_tokens: int,
603
+ # convolutional model args
604
+ use_embedding: bool = False,
605
+ embedding_len: int = 64,
606
+ num_conv_layers: int = 1,
607
+ kernel_sizes: List[int] = (7,),
608
+ out_channels: List[int] = (128,),
609
+ dilations: List[int] = (1,),
610
+ padding: str = "valid",
611
+ use_conv_layer_norm: bool = False,
612
+ conv_layer_norm_before_activation: bool = False,
613
+ use_conv_layer_dropout: bool = False,
614
+ conv_layer_dropout_rate: float = 0.2,
615
+ # pooling
616
+ global_average_pooling: bool = True,
617
+ # prediction layers
618
+ use_task_specific_layers: bool = False,
619
+ task_specific_hidden_nodes: int = 64,
620
+ use_final_hidden_layer: bool = False,
621
+ final_hidden_size: int = 64,
622
+ use_final_hidden_layer_norm: bool = False,
623
+ final_hidden_layer_norm_before_activation: bool = False,
624
+ use_final_hidden_layer_dropout: bool = False,
625
+ final_hidden_layer_dropout_rate: float = 0.2,
626
+ # activation function
627
+ activation: str = "relu",
628
+ *args, **kwargs):
629
+
630
+ super(ConvModel2, self).__init__()
631
+
632
+ # build up the layers
633
+ layers = collections.OrderedDict()
634
+
635
+ # amino acid embedding
636
+ if use_embedding:
637
+ layers["embedder"] = ScaledEmbedding(num_embeddings=num_tokens, embedding_dim=embedding_len, scale=False)
638
+
639
+ # transpose the input to match PyTorch's expected format
640
+ layers["transpose"] = Transpose(dims=(1, 2))
641
+
642
+ # build up the convolutional layers
643
+ for layer_num in range(num_conv_layers):
644
+ # determine the number of input channels for the first convolutional layer
645
+ if layer_num == 0 and use_embedding:
646
+ # for the first convolutional layer, the in_channels is the embedding_len
647
+ in_channels = embedding_len
648
+ elif layer_num == 0 and not use_embedding:
649
+ # for the first convolutional layer, the in_channels is the aa_encoding_len
650
+ in_channels = aa_encoding_len
651
+ else:
652
+ in_channels = out_channels[layer_num - 1]
653
+
654
+ layers[f"conv{layer_num}"] = ConvBlock(in_channels=in_channels,
655
+ out_channels=out_channels[layer_num],
656
+ kernel_size=kernel_sizes[layer_num],
657
+ dilation=dilations[layer_num],
658
+ padding=padding,
659
+ use_batchnorm=False,
660
+ use_layernorm=use_conv_layer_norm,
661
+ norm_before_activation=conv_layer_norm_before_activation,
662
+ use_dropout=use_conv_layer_dropout,
663
+ dropout_rate=conv_layer_dropout_rate,
664
+ activation=activation)
665
+
666
+ # handle transition from convolutional layers to fully connected layer
667
+ # either use global average pooling or flatten
668
+ # take into consideration whether we are using valid or same padding
669
+ if global_average_pooling:
670
+ # global average pooling (mean across the seq len dimension)
671
+ # the seq len dimensions is the last dimension (batch_size, num_filters, seq_len)
672
+ layers["avg_pooling"] = GlobalAveragePooling(dim=-1)
673
+ # the prediction layers will take num_filters input features
674
+ pred_layer_input_features = out_channels[-1]
675
+
676
+ else:
677
+ # no global average pooling. flatten instead.
678
+ layers["flatten"] = nn.Flatten()
679
+ # calculate the final output len of the convolutional layers
680
+ # and the number of input features for the prediction layers
681
+ if padding == "valid":
682
+ # valid padding (aka no padding) results in shrinking length in progressive layers
683
+ conv_out_len = conv1d_out_shape(aa_seq_len, kernel_size=kernel_sizes[0], dilation=dilations[0])
684
+ for layer_num in range(1, num_conv_layers):
685
+ conv_out_len = conv1d_out_shape(conv_out_len,
686
+ kernel_size=kernel_sizes[layer_num],
687
+ dilation=dilations[layer_num])
688
+ pred_layer_input_features = conv_out_len * out_channels[-1]
689
+ else:
690
+ # padding == "same"
691
+ pred_layer_input_features = aa_seq_len * out_channels[-1]
692
+
693
+ # prediction layer
694
+ if use_task_specific_layers:
695
+ layers["prediction"] = TaskSpecificPredictionLayers(num_tasks=num_tasks,
696
+ in_features=pred_layer_input_features,
697
+ num_hidden_nodes=task_specific_hidden_nodes,
698
+ activation=activation)
699
+
700
+ # final hidden layer (with potential additional dropout)
701
+ elif use_final_hidden_layer:
702
+ layers["fc1"] = FCBlock(in_features=pred_layer_input_features,
703
+ num_hidden_nodes=final_hidden_size,
704
+ use_batchnorm=False,
705
+ use_layernorm=use_final_hidden_layer_norm,
706
+ norm_before_activation=final_hidden_layer_norm_before_activation,
707
+ use_dropout=use_final_hidden_layer_dropout,
708
+ dropout_rate=final_hidden_layer_dropout_rate,
709
+ activation=activation)
710
+ layers["prediction"] = nn.Linear(in_features=final_hidden_size, out_features=num_tasks)
711
+
712
+ else:
713
+ layers["prediction"] = nn.Linear(in_features=pred_layer_input_features, out_features=num_tasks)
714
+
715
+ self.model = nn.Sequential(layers)
716
+
717
+ def forward(self, x, **kwargs):
718
+ output = self.model(x)
719
+ return output
720
+
721
+
722
+ class ConvModel(nn.Module):
723
+ """ a convolutional network with convolutional layers followed by a fully connected layer """
724
+
725
+ @staticmethod
726
+ def add_model_specific_args(parent_parser):
727
+ parser = ArgumentParser(parents=[parent_parser], add_help=False)
728
+ parser.add_argument('--num_conv_layers', type=int, default=1)
729
+ parser.add_argument('--kernel_sizes', type=int, nargs="+", default=[7])
730
+ parser.add_argument('--out_channels', type=int, nargs="+", default=[128])
731
+ parser.add_argument('--padding', type=str, default="valid", choices=["valid", "same"])
732
+ parser.add_argument('--use_final_hidden_layer', action="store_true",
733
+ help="whether to use a final hidden layer")
734
+ parser.add_argument('--final_hidden_size', type=int, default=128,
735
+ help="number of nodes in the final hidden layer")
736
+ parser.add_argument('--use_dropout', action="store_true",
737
+ help="whether to use dropout in the final hidden layer")
738
+ parser.add_argument('--dropout_rate', type=float, default=0.2,
739
+ help="dropout rate in the final hidden layer")
740
+ parser.add_argument('--use_task_specific_layers', action="store_true", default=False)
741
+ parser.add_argument('--task_specific_hidden_nodes', type=int, default=64)
742
+ return parser
743
+
744
+ def __init__(self,
745
+ num_tasks: int,
746
+ aa_seq_len: int,
747
+ aa_encoding_len: int,
748
+ num_conv_layers: int = 1,
749
+ kernel_sizes: List[int] = (7,),
750
+ out_channels: List[int] = (128,),
751
+ padding: str = "valid",
752
+ use_final_hidden_layer: bool = True,
753
+ final_hidden_size: int = 128,
754
+ use_dropout: bool = False,
755
+ dropout_rate: float = 0.2,
756
+ use_task_specific_layers: bool = False,
757
+ task_specific_hidden_nodes: int = 64,
758
+ *args, **kwargs):
759
+
760
+ super(ConvModel, self).__init__()
761
+
762
+ # set up the model as a Sequential block (less to do in forward())
763
+ layers = collections.OrderedDict()
764
+
765
+ layers["transpose"] = Transpose(dims=(1, 2))
766
+
767
+ for layer_num in range(num_conv_layers):
768
+ # for the first convolutional layer, the in_channels is the feature_len
769
+ in_channels = aa_encoding_len if layer_num == 0 else out_channels[layer_num - 1]
770
+
771
+ layers["conv{}".format(layer_num)] = nn.Sequential(
772
+ nn.Conv1d(in_channels=in_channels,
773
+ out_channels=out_channels[layer_num],
774
+ kernel_size=kernel_sizes[layer_num],
775
+ padding=padding),
776
+ nn.ReLU()
777
+ )
778
+
779
+ layers["flatten"] = nn.Flatten()
780
+
781
+ # calculate the final output len of the convolutional layers
782
+ # and the number of input features for the prediction layers
783
+ if padding == "valid":
784
+ # valid padding (aka no padding) results in shrinking length in progressive layers
785
+ conv_out_len = conv1d_out_shape(aa_seq_len, kernel_size=kernel_sizes[0])
786
+ for layer_num in range(1, num_conv_layers):
787
+ conv_out_len = conv1d_out_shape(conv_out_len, kernel_size=kernel_sizes[layer_num])
788
+ next_dim = conv_out_len * out_channels[-1]
789
+ elif padding == "same":
790
+ next_dim = aa_seq_len * out_channels[-1]
791
+ else:
792
+ raise ValueError("unexpected value for padding: {}".format(padding))
793
+
794
+ # final hidden layer (with potential additional dropout)
795
+ if use_final_hidden_layer:
796
+ layers["fc1"] = FCBlock(in_features=next_dim,
797
+ num_hidden_nodes=final_hidden_size,
798
+ use_batchnorm=False,
799
+ use_dropout=use_dropout,
800
+ dropout_rate=dropout_rate)
801
+ next_dim = final_hidden_size
802
+
803
+ # final prediction layer
804
+ # either task specific nonlinear layers or a single linear layer
805
+ if use_task_specific_layers:
806
+ layers["prediction"] = TaskSpecificPredictionLayers(num_tasks=num_tasks,
807
+ in_features=next_dim,
808
+ num_hidden_nodes=task_specific_hidden_nodes)
809
+ else:
810
+ layers["prediction"] = nn.Linear(in_features=next_dim, out_features=num_tasks)
811
+
812
+ self.model = nn.Sequential(layers)
813
+
814
+ def forward(self, x, **kwargs):
815
+ output = self.model(x)
816
+ return output
817
+
818
+
819
+ class FCModel(nn.Module):
820
+
821
+ @staticmethod
822
+ def add_model_specific_args(parent_parser):
823
+ parser = ArgumentParser(parents=[parent_parser], add_help=False)
824
+ parser.add_argument('--num_layers', type=int, default=1)
825
+ parser.add_argument('--num_hidden', nargs="+", type=int, default=[128])
826
+ parser.add_argument('--use_batchnorm', action="store_true", default=False)
827
+ parser.add_argument('--use_layernorm', action="store_true", default=False)
828
+ parser.add_argument('--norm_before_activation', action="store_true", default=False)
829
+ parser.add_argument('--use_dropout', action="store_true", default=False)
830
+ parser.add_argument('--dropout_rate', type=float, default=0.2)
831
+ return parser
832
+
833
+ def __init__(self,
834
+ num_tasks: int,
835
+ seq_encoding_len: int,
836
+ num_layers: int = 1,
837
+ num_hidden: List[int] = (128,),
838
+ use_batchnorm: bool = False,
839
+ use_layernorm: bool = False,
840
+ norm_before_activation: bool = False,
841
+ use_dropout: bool = False,
842
+ dropout_rate: float = 0.2,
843
+ activation: str = "relu",
844
+ *args, **kwargs):
845
+ super().__init__()
846
+
847
+ # set up the model as a Sequential block (less to do in forward())
848
+ layers = collections.OrderedDict()
849
+
850
+ # flatten inputs as this is all fully connected
851
+ layers["flatten"] = nn.Flatten()
852
+
853
+ # build up the variable number of hidden layers (fully connected + ReLU + dropout (if set))
854
+ for layer_num in range(num_layers):
855
+ # for the first layer (layer_num == 0), in_features is determined by given input
856
+ # for subsequent layers, the in_features is the previous layer's num_hidden
857
+ in_features = seq_encoding_len if layer_num == 0 else num_hidden[layer_num - 1]
858
+
859
+ layers["fc{}".format(layer_num)] = FCBlock(in_features=in_features,
860
+ num_hidden_nodes=num_hidden[layer_num],
861
+ use_batchnorm=use_batchnorm,
862
+ use_layernorm=use_layernorm,
863
+ norm_before_activation=norm_before_activation,
864
+ use_dropout=use_dropout,
865
+ dropout_rate=dropout_rate,
866
+ activation=activation)
867
+
868
+ # finally, the linear output layer
869
+ in_features = num_hidden[-1] if num_layers > 0 else seq_encoding_len
870
+ layers["output"] = nn.Linear(in_features=in_features, out_features=num_tasks)
871
+
872
+ self.model = nn.Sequential(layers)
873
+
874
+ def forward(self, x, **kwargs):
875
+ output = self.model(x)
876
+ return output
877
+
878
+
879
+ class LRModel(nn.Module):
880
+ """ a simple linear model """
881
+
882
+ def __init__(self, num_tasks, seq_encoding_len, *args, **kwargs):
883
+ super().__init__()
884
+
885
+ self.model = nn.Sequential(
886
+ nn.Flatten(),
887
+ nn.Linear(seq_encoding_len, out_features=num_tasks))
888
+
889
+ def forward(self, x, **kwargs):
890
+ output = self.model(x)
891
+ return output
892
+
893
+
894
+ class TransferModel(nn.Module):
895
+ """ transfer learning model """
896
+
897
+ @staticmethod
898
+ def add_model_specific_args(parent_parser):
899
+
900
+ def none_or_int(value: str):
901
+ return None if value.lower() == "none" else int(value)
902
+
903
+ p = ArgumentParser(parents=[parent_parser], add_help=False)
904
+
905
+ # for model set up
906
+ p.add_argument('--pretrained_ckpt_path', type=str, default=None)
907
+
908
+ # where to cut off the backbone
909
+ p.add_argument("--backbone_cutoff", type=none_or_int, default=-1,
910
+ help="where to cut off the backbone. can be a negative int, indexing back from "
911
+ "pretrained_model.model.model. a value of -1 would chop off the backbone prediction head. "
912
+ "a value of -2 chops the prediction head and FC layer. a value of -3 chops"
913
+ "the above, as well as the global average pooling layer. all depends on architecture.")
914
+
915
+ p.add_argument("--pred_layer_input_features", type=int, default=None,
916
+ help="if None, number of features will be determined based on backbone_cutoff and standard "
917
+ "architecture. otherwise, specify the number of input features for the prediction layer")
918
+
919
+ # top net args
920
+ p.add_argument("--top_net_type", type=str, default="linear", choices=["linear", "nonlinear", "sklearn"])
921
+ p.add_argument("--top_net_hidden_nodes", type=int, default=256)
922
+ p.add_argument("--top_net_use_batchnorm", action="store_true")
923
+ p.add_argument("--top_net_use_dropout", action="store_true")
924
+ p.add_argument("--top_net_dropout_rate", type=float, default=0.1)
925
+
926
+ return p
927
+
928
+ def __init__(self,
929
+ # pretrained model
930
+ pretrained_ckpt_path: Optional[str] = None,
931
+ pretrained_hparams: Optional[dict] = None,
932
+ backbone_cutoff: Optional[int] = -1,
933
+ # top net
934
+ pred_layer_input_features: Optional[int] = None,
935
+ top_net_type: str = "linear",
936
+ top_net_hidden_nodes: int = 256,
937
+ top_net_use_batchnorm: bool = False,
938
+ top_net_use_dropout: bool = False,
939
+ top_net_dropout_rate: float = 0.1,
940
+ *args, **kwargs):
941
+
942
+ super().__init__()
943
+
944
+ # error checking: if pretrained_ckpt_path is None, then pretrained_hparams must be specified
945
+ if pretrained_ckpt_path is None and pretrained_hparams is None:
946
+ raise ValueError("Either pretrained_ckpt_path or pretrained_hparams must be specified")
947
+
948
+ # note: pdb_fns is loaded from transfer model arguments rather than original source model hparams
949
+ # if pdb_fns is specified as a kwarg, pass it on for structure-based RPE
950
+ # otherwise, can just set pdb_fns to None, and structure-based RPE will handle new PDBs on the fly
951
+ pdb_fns = kwargs["pdb_fns"] if "pdb_fns" in kwargs else None
952
+
953
+ # generate a fresh backbone using pretrained_hparams if specified
954
+ # otherwise load the backbone from the pretrained checkpoint
955
+ # we prioritize pretrained_hparams over pretrained_ckpt_path because
956
+ # pretrained_hparams will only really be specified if we are loading from a DMSTask checkpoint
957
+ # meaning the TransferModel has already been fine-tuned on DMS data, and we are likely loading
958
+ # weights from that finetuning (including weights for the backbone)
959
+ # whereas if pretrained_hparams is not specified but pretrained_ckpt_path is, then we are
960
+ # likely finetuning the TransferModel for the first time, and we need the pretrained weights for the
961
+ # backbone from the RosettaTask checkpoint
962
+ if pretrained_hparams is not None:
963
+ # pretrained_hparams will only be specified if we are loading from a DMSTask checkpoint
964
+ pretrained_hparams["pdb_fns"] = pdb_fns
965
+ pretrained_model = Model[pretrained_hparams["model_name"]].cls(**pretrained_hparams)
966
+ self.pretrained_hparams = pretrained_hparams
967
+ else:
968
+ # not supported in metl-pretrained
969
+ raise NotImplementedError("Loading pretrained weights from RosettaTask checkpoint not supported")
970
+
971
+ layers = collections.OrderedDict()
972
+
973
+ # set the backbone to all layers except the last layer (the pre-trained prediction layer)
974
+ if backbone_cutoff is None:
975
+ layers["backbone"] = SequentialWithArgs(*list(pretrained_model.model.children()))
976
+ else:
977
+ layers["backbone"] = SequentialWithArgs(*list(pretrained_model.model.children())[0:backbone_cutoff])
978
+
979
+ if top_net_type == "sklearn":
980
+ # sklearn top net doesn't require any more layers, just return model for the repr layer
981
+ self.model = SequentialWithArgs(layers)
982
+ return
983
+
984
+ # figure out dimensions of input into the prediction layer
985
+ if pred_layer_input_features is None:
986
+ # todo: can make this more robust by checking if the pretrained_mode.hparams for use_final_hidden_layer,
987
+ # global_average_pooling, etc. then can determine what the layer will be based on backbone_cutoff.
988
+ # currently, assumes that pretrained_model uses global average pooling and a final_hidden_layer
989
+ if backbone_cutoff is None:
990
+ # no backbone cutoff... use the full network (including tasks) as the backbone
991
+ pred_layer_input_features = self.pretrained_hparams["num_tasks"]
992
+ elif backbone_cutoff == -1:
993
+ pred_layer_input_features = self.pretrained_hparams["final_hidden_size"]
994
+ elif backbone_cutoff == -2:
995
+ pred_layer_input_features = self.pretrained_hparams["embedding_len"]
996
+ elif backbone_cutoff == -3:
997
+ pred_layer_input_features = self.pretrained_hparams["embedding_len"] * kwargs["aa_seq_len"]
998
+ else:
999
+ raise ValueError("can't automatically determine pred_layer_input_features for given backbone_cutoff")
1000
+
1001
+ layers["flatten"] = nn.Flatten(start_dim=1)
1002
+
1003
+ # create a new prediction layer on top of the backbone
1004
+ if top_net_type == "linear":
1005
+ # linear layer for prediction
1006
+ layers["prediction"] = nn.Linear(in_features=pred_layer_input_features, out_features=1)
1007
+ elif top_net_type == "nonlinear":
1008
+ # fully connected with hidden layer
1009
+ fc_block = FCBlock(in_features=pred_layer_input_features,
1010
+ num_hidden_nodes=top_net_hidden_nodes,
1011
+ use_batchnorm=top_net_use_batchnorm,
1012
+ use_dropout=top_net_use_dropout,
1013
+ dropout_rate=top_net_dropout_rate)
1014
+
1015
+ pred_layer = nn.Linear(in_features=top_net_hidden_nodes, out_features=1)
1016
+
1017
+ layers["prediction"] = SequentialWithArgs(fc_block, pred_layer)
1018
+ else:
1019
+ raise ValueError("Unexpected type of top net layer: {}".format(top_net_type))
1020
+
1021
+ self.model = SequentialWithArgs(layers)
1022
+
1023
+ def forward(self, x, **kwargs):
1024
+ return self.model(x, **kwargs)
1025
+
1026
+
1027
+ def get_activation_fn(activation, functional=True):
1028
+ if activation == "relu":
1029
+ return F.relu if functional else nn.ReLU()
1030
+ elif activation == "gelu":
1031
+ return F.gelu if functional else nn.GELU()
1032
+ elif activation in ("silu", "silo", "swish"):  # accept "silu" as well; "silo" retained for existing hparams
1033
+ return F.silu if functional else nn.SiLU()
1034
+ elif activation == "leaky_relu" or activation == "lrelu":
1035
+ return F.leaky_relu if functional else nn.LeakyReLU()
1036
+ else:
1037
+ raise RuntimeError("unknown activation: {}".format(activation))
1038
+
1039
+
1040
+ class Model(enum.Enum):
1041
+ def __new__(cls, *args, **kwds):
1042
+ value = len(cls.__members__) + 1
1043
+ obj = object.__new__(cls)
1044
+ obj._value_ = value
1045
+ return obj
1046
+
1047
+ def __init__(self, cls, transfer_model):
1048
+ self.cls = cls
1049
+ self.transfer_model = transfer_model
1050
+
1051
+ linear = LRModel, False
1052
+ fully_connected = FCModel, False
1053
+ cnn = ConvModel, False
1054
+ cnn2 = ConvModel2, False
1055
+ transformer_encoder = AttnModel, False
1056
+ transfer_model = TransferModel, True
1057
+
1058
+
1059
+ def main():
1060
+ pass
1061
+
1062
+
1063
+ if __name__ == "__main__":
1064
+ main()
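For orientation, here is a minimal usage sketch of the Model enum and get_activation_fn defined above. It is illustrative only (not part of the uploaded files) and assumes the metl package is importable:

    import torch.nn as nn
    from metl.models import Model, get_activation_fn

    # look up a model class by its registered name and check whether it is a transfer model
    entry = Model["transfer_model"]
    print(entry.cls, entry.transfer_model)  # TransferModel, True

    # resolve an activation by name; functional=False returns an nn.Module instance
    act = get_activation_fn("gelu", functional=False)
    assert isinstance(act, nn.GELU)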
metl/relative_attention.py ADDED
@@ -0,0 +1,586 @@
1
+ """ implementation of transformer encoder with relative attention
2
+ references:
3
+ - https://medium.com/@_init_/how-self-attention-with-relative-position-representations-works-28173b8c245a
4
+ - https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer
5
+ - https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py
6
+ - https://github.com/jiezouguihuafu/ClassicalModelreproduced/blob/main/Transformer/transfor_rpe.py
7
+ """
8
+
9
+ import copy
10
+ from os.path import basename, dirname, join, isfile
11
+ from typing import Optional, Union
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from torch import Tensor
17
+ from torch.nn import Linear, Dropout, LayerNorm
18
+ import time
19
+ import networkx as nx
20
+
21
+ import metl.structure as structure
22
+ import metl.models as models
23
+
24
+
25
+ class RelativePosition3D(nn.Module):
26
+ """ Contact map-based relative position embeddings """
27
+
28
+ # need to compute a bucket_mtx for each structure
29
+ # need to know which bucket_mtx to use when grabbing the embeddings in forward()
30
+ # - on init, get a list of all PDB files we will be using
31
+ # - use a dictionary to store PDB files --> bucket_mtxs
32
+ # - forward() gets a new arg: the pdb file, which indexes into the dictionary to grab the right bucket_mtx
33
+ def __init__(self,
34
+ embedding_len: int,
35
+ contact_threshold: int,
36
+ clipping_threshold: int,
37
+ pdb_fns: Optional[Union[str, list, tuple]] = None,
38
+ default_pdb_dir: str = "data/pdb_files"):
39
+
40
+ # preferably, pdb_fns contains full paths to the PDBs, but if just the PDB filename is given
41
+ # then it defaults to the path data/pdb_files/<pdb_fn>
42
+ super().__init__()
43
+ self.embedding_len = embedding_len
44
+ self.clipping_threshold = clipping_threshold
45
+ self.contact_threshold = contact_threshold
46
+ self.default_pdb_dir = default_pdb_dir
47
+
48
+ # dummy buffer for getting correct device for on-the-fly bucket matrix generation
49
+ self.register_buffer("dummy_buffer", torch.empty(0), persistent=False)
50
+
51
+ # for 3D-based positions, the number of embeddings is generally the number of buckets
52
+ # for contact map-based distances, that is clipping_threshold + 1
53
+ num_embeddings = clipping_threshold + 1
54
+
55
+ # this is the embedding lookup table E_r
56
+ self.embeddings_table = nn.Embedding(num_embeddings, embedding_len)
57
+
58
+ # set up pdb_fns that were passed in on init (can also be set up during runtime in forward())
59
+ # todo: i'm using a hacky workaround to move the bucket_mtxs to the correct device
60
+ # i tried to make it more efficient by registering bucket matrices as buffers, but i was
61
+ # having problems with DDP syncing the buffers across processes
62
+ self.bucket_mtxs = {}
63
+ self.bucket_mtxs_device = self.dummy_buffer.device
64
+ self._init_pdbs(pdb_fns)
65
+
66
+ def forward(self, pdb_fn):
67
+ # compute matrix R by grabbing the embeddings from the embeddings lookup table
68
+ embeddings = self.embeddings_table(self._get_bucket_mtx(pdb_fn))
69
+ return embeddings
70
+
71
+ # def _get_bucket_mtx(self, pdb_fn):
72
+ # """ retrieve a bucket matrix given the pdb_fn.
73
+ # if the pdb_fn was provided at init or has already been computed, then the bucket matrix will be
74
+ # retrieved from the object buffer. if the bucket matrix has not been computed yet, it will be here """
75
+ # pdb_attr = self._pdb_key(pdb_fn)
76
+ # if hasattr(self, pdb_attr):
77
+ # return getattr(self, pdb_attr)
78
+ # else:
79
+ # # encountering a new PDB at runtime... process it
80
+ # # todo: if there's a new PDB at runtime, it will be initialized separately in each instance
81
+ # # of RelativePosition3D, for each layer. It would be more efficient to have a global
82
+ # # bucket_mtx registry... perhaps in the RelativeTransformerEncoder class, that can be passed through
83
+ # self._init_pdb(pdb_fn)
84
+ # return getattr(self, pdb_attr)
85
+
86
+ def _move_bucket_mtxs(self, device):
87
+ for k, v in self.bucket_mtxs.items():
88
+ self.bucket_mtxs[k] = v.to(device)
89
+ self.bucket_mtxs_device = device
90
+
91
+ def _get_bucket_mtx(self, pdb_fn):
92
+ """ retrieve a bucket matrix given the pdb_fn.
93
+ if the pdb_fn was provided at init or has already been computed, then the bucket matrix will be
94
+ retrieved from the bucket_mtxs dictionary. else, it will be computed now on-the-fly """
95
+
96
+ # ensure that all the bucket matrices are on the same device as the nn.Embedding
97
+ if self.bucket_mtxs_device != self.dummy_buffer.device:
98
+ self._move_bucket_mtxs(self.dummy_buffer.device)
99
+
100
+ pdb_attr = self._pdb_key(pdb_fn)
101
+ if pdb_attr in self.bucket_mtxs:
102
+ return self.bucket_mtxs[pdb_attr]
103
+ else:
104
+ # encountering a new PDB at runtime... process it
105
+ # todo: if there's a new PDB at runtime, it will be initialized separately in each instance
106
+ # of RelativePosition3D, for each layer. It would be more efficient to have a global
107
+ # bucket_mtx registry... perhaps in the RelativeTransformerEncoder class, that can be passed through
108
+ self._init_pdb(pdb_fn)
109
+ return self.bucket_mtxs[pdb_attr]
110
+
111
+ # def _set_bucket_mtx(self, pdb_fn, bucket_mtx):
112
+ # """ store a bucket matrix as a buffer """
113
+ # # if PyTorch ever implements a BufferDict, we could use it here efficiently
114
+ # # there is also BufferDict from https://botorch.org/api/_modules/botorch/utils/torch.html
115
+ # # would just need to modify it to have an option for persistent=False
116
+ # bucket_mtx = bucket_mtx.to(self.dummy_buffer.device)
117
+ #
118
+ # self.register_buffer(self._pdb_key(pdb_fn), bucket_mtx, persistent=False)
119
+
120
+ def _set_bucket_mtx(self, pdb_fn, bucket_mtx):
121
+ """ store a bucket matrix in the bucket dict """
122
+
123
+ # move the bucket_mtx to the same device that the other bucket matrices are on
124
+ bucket_mtx = bucket_mtx.to(self.bucket_mtxs_device)
125
+
126
+ self.bucket_mtxs[self._pdb_key(pdb_fn)] = bucket_mtx
127
+
128
+ @staticmethod
129
+ def _pdb_key(pdb_fn):
130
+ """ return a unique key for the given pdb_fn, used to map unique PDBs """
131
+ # note this key does NOT currently support PDBs with the same basename but different paths
132
+ # assumes every PDB is in the format <pdb_name>.pdb
133
+ # should be compatible with being a class attribute, as it is used as a pytorch buffer name
134
+ return f"pdb_{basename(pdb_fn).split('.')[0]}"
135
+
136
+ def _init_pdbs(self, pdb_fns):
137
+ start = time.time()
138
+
139
+ if pdb_fns is None:
140
+ # nothing to initialize if pdb_fns is None
141
+ return
142
+
143
+ # make sure pdb_fns is a list
144
+ if not isinstance(pdb_fns, list) and not isinstance(pdb_fns, tuple):
145
+ pdb_fns = [pdb_fns]
146
+
147
+ # init each pdb fn in the list
148
+ for pdb_fn in pdb_fns:
149
+ self._init_pdb(pdb_fn)
150
+
151
+ print("Initialized PDB bucket matrices in {:.3f} sec".format(time.time() - start))
152
+
153
+ def _init_pdb(self, pdb_fn):
154
+ """ process a pdb file for use with structure-based relative attention """
155
+ # if pdb_fn is not a full path, default to the path data/pdb_files/<pdb_fn>
156
+ if dirname(pdb_fn) == "":
157
+ # handle the case where the pdb file is in the current working directory
158
+ # if a file with that name exists in the cwd, use it as is; otherwise, look in the default pdb dir
159
+ if not isfile(pdb_fn):
160
+ pdb_fn = join(self.default_pdb_dir, pdb_fn)
161
+
162
+ # create a structure graph from the pdb_fn and contact threshold
163
+ cbeta_mtx = structure.cbeta_distance_matrix(pdb_fn)
164
+ structure_graph = structure.dist_thresh_graph(cbeta_mtx, self.contact_threshold)
165
+
166
+ # bucket_mtx indexes into the embedding lookup table to create the final distance matrix
167
+ bucket_mtx = self._compute_bucket_mtx(structure_graph)
168
+
169
+ self._set_bucket_mtx(pdb_fn, bucket_mtx)
170
+
171
+ def _compute_bucketed_neighbors(self, structure_graph, source_node):
172
+ """ gets the bucketed neighbors from the given source node and structure graph"""
173
+ if self.clipping_threshold < 0:
174
+ raise ValueError("Clipping threshold must be >= 0")
175
+
176
+ sspl = _inv_dict(nx.single_source_shortest_path_length(structure_graph, source_node))
177
+
178
+ if self.clipping_threshold is not None:
179
+ num_buckets = 1 + self.clipping_threshold
180
+ sspl = _combine_d(sspl, self.clipping_threshold, num_buckets - 1)
181
+
182
+ return sspl
183
+
184
+ def _compute_bucket_mtx(self, structure_graph):
185
+ """ get the bucket_mtx for the given structure_graph
186
+ calls _compute_bucketed_neighbors for every node in the structure_graph """
187
+ num_residues = len(list(structure_graph))
188
+
189
+ # index into the embedding lookup table to create the final distance matrix
190
+ bucket_mtx = torch.zeros(num_residues, num_residues, dtype=torch.long)
191
+
192
+ for node_num in sorted(list(structure_graph)):
193
+ bucketed_neighbors = self._compute_bucketed_neighbors(structure_graph, node_num)
194
+
195
+ for bucket_num, neighbors in bucketed_neighbors.items():
196
+ bucket_mtx[node_num, neighbors] = bucket_num
197
+
198
+ return bucket_mtx
199
+
200
+
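To make the shortest-path bucketing in _compute_bucketed_neighbors concrete, here is a small sketch on a toy path graph with an assumed clipping threshold of 2 (illustrative only, not part of the uploaded file):

    import networkx as nx

    # toy "structure graph": a 5-residue chain 0-1-2-3-4
    g = nx.path_graph(5)

    # shortest-path length from residue 0 to every residue
    sspl = nx.single_source_shortest_path_length(g, 0)
    # {0: 0, 1: 1, 2: 2, 3: 3, 4: 4}

    # with clipping_threshold = 2 there are 3 buckets (0, 1, 2), and every residue at
    # distance >= 2 is combined into the last bucket, so residues 2, 3, and 4 all map
    # to bucket index 2 in the bucket matrix row for residue 0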
201
+ class RelativePosition(nn.Module):
202
+ """ creates the embedding lookup table E_r and computes R
203
+ note this is a plain nn.Module; a registered (non-persistent) dummy buffer is used
204
+ to determine the correct device for the range vectors in forward(), rather than
205
+ inheriting from pl.LightningModule just to access `self.device` """
206
+
207
+ def __init__(self, embedding_len: int, clipping_threshold: int):
208
+ """
209
+ embedding_len: the length of the embedding, may be d_model, or d_model // num_heads for multihead
210
+ clipping_threshold: the maximum relative position, referred to as k by Shaw et al.
211
+ """
212
+ super().__init__()
213
+ self.embedding_len = embedding_len
214
+ self.clipping_threshold = clipping_threshold
215
+ # for sequence-based distances, the number of embeddings is 2*k+1, where k is the clipping threshold
216
+ num_embeddings = 2 * clipping_threshold + 1
217
+
218
+ # this is the embedding lookup table E_r
219
+ self.embeddings_table = nn.Embedding(num_embeddings, embedding_len)
220
+
221
+ # for getting the correct device for range vectors in forward
222
+ self.register_buffer("dummy_buffer", torch.empty(0), persistent=False)
223
+
224
+ def forward(self, length_q, length_k):
225
+ # supports different length sequences, but in self-attention length_q and length_k are the same
226
+ range_vec_q = torch.arange(length_q, device=self.dummy_buffer.device)
227
+ range_vec_k = torch.arange(length_k, device=self.dummy_buffer.device)
228
+
229
+ # this sets up the standard sequence-based distance matrix for relative positions
230
+ # the current position is 0, positions to the right are +1, +2, etc, and to the left -1, -2, etc
231
+ distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
232
+ distance_mat_clipped = torch.clamp(distance_mat, -self.clipping_threshold, self.clipping_threshold)
233
+
234
+ # convert to indices, indexing into the embedding table
235
+ final_mat = (distance_mat_clipped + self.clipping_threshold).long()
236
+
237
+ # compute matrix R by grabbing the embeddings from the embedding lookup table
238
+ embeddings = self.embeddings_table(final_mat)
239
+
240
+ return embeddings
241
+
242
+
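For intuition about the clamp-and-offset performed in RelativePosition.forward, here is what the index matrix looks like for a length-4 sequence with an assumed clipping threshold of 2 (illustrative only, not part of the uploaded file):

    import torch

    length = 4
    clipping_threshold = 2

    range_vec = torch.arange(length)
    distance_mat = range_vec[None, :] - range_vec[:, None]
    # tensor([[ 0,  1,  2,  3],
    #         [-1,  0,  1,  2],
    #         [-2, -1,  0,  1],
    #         [-3, -2, -1,  0]])

    final_mat = (torch.clamp(distance_mat, -clipping_threshold, clipping_threshold)
                 + clipping_threshold).long()
    # tensor([[2, 3, 4, 4],
    #         [1, 2, 3, 4],
    #         [0, 1, 2, 3],
    #         [0, 0, 1, 2]])
    # each entry indexes into an embedding table with 2 * clipping_threshold + 1 = 5 rows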
243
+ class RelativeMultiHeadAttention(nn.Module):
244
+ def __init__(self, embed_dim, num_heads, dropout, pos_encoding, clipping_threshold, contact_threshold, pdb_fns):
245
+ """
246
+ Multi-head attention with relative position embeddings. Input data should be in batch_first format.
247
+ :param embed_dim: aka d_model, aka hid_dim
248
+ :param num_heads: number of heads
249
+ :param dropout: how much dropout for scaled dot product attention
250
+
251
+ :param pos_encoding: what type of positional encoding to use, "relative" or "relative_3D"
252
+ :param clipping_threshold: clipping threshold for relative position embedding
253
+ :param contact_threshold: for relative_3D, the threshold in angstroms for the contact map
254
+ :param pdb_fns: pdb file(s) to set up the relative position object
255
+
256
+ """
257
+ super().__init__()
258
+
259
+ assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
260
+
261
+ # model dimensions
262
+ self.embed_dim = embed_dim
263
+ self.num_heads = num_heads
264
+ self.head_dim = embed_dim // num_heads
265
+
266
+ # pos encoding stuff
267
+ self.pos_encoding = pos_encoding
268
+ self.clipping_threshold = clipping_threshold
269
+ self.contact_threshold = contact_threshold
270
+ if pdb_fns is not None and not isinstance(pdb_fns, list):
271
+ pdb_fns = [pdb_fns]
272
+ self.pdb_fns = pdb_fns
273
+
274
+ # relative position embeddings for use with keys and values
275
+ # Shaw et al. uses relative position information for both keys and values
276
+ # Huang et al. only uses it for the keys, which is probably enough
277
+ if pos_encoding == "relative":
278
+ self.relative_position_k = RelativePosition(self.head_dim, self.clipping_threshold)
279
+ self.relative_position_v = RelativePosition(self.head_dim, self.clipping_threshold)
280
+ elif pos_encoding == "relative_3D":
281
+ self.relative_position_k = RelativePosition3D(self.head_dim, self.contact_threshold,
282
+ self.clipping_threshold, self.pdb_fns)
283
+ self.relative_position_v = RelativePosition3D(self.head_dim, self.contact_threshold,
284
+ self.clipping_threshold, self.pdb_fns)
285
+ else:
286
+ raise ValueError("unrecognized pos_encoding: {}".format(pos_encoding))
287
+
288
+ # WQ, WK, and WV from attention is all you need
289
+ # note these default to bias=True, same as PyTorch implementation
290
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
291
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
292
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
293
+
294
+ # WO from attention is all you need
295
+ # used for the final projection when computing multi-head attention
296
+ # PyTorch uses NonDynamicallyQuantizableLinear instead of Linear to avoid triggering an obscure
297
+ # error quantizing the model https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L122
298
+ # todo: if quantizing the model, explore if the above is a concern for us
299
+ self.out_proj = nn.Linear(embed_dim, embed_dim)
300
+
301
+ # dropout for scaled dot product attention
302
+ self.dropout = nn.Dropout(dropout)
303
+
304
+ # scaling factor for scaled dot product attention
305
+ scale = torch.sqrt(torch.FloatTensor([self.head_dim]))
306
+ # persistent=False if you don't want to save it inside state_dict
307
+ self.register_buffer('scale', scale)
308
+
309
+ # toggles meant to be set directly by user
310
+ self.need_weights = False
311
+ self.average_attn_weights = True
312
+
313
+ def _compute_attn_weights(self, query, key, len_q, len_k, batch_size, mask, pdb_fn):
314
+ """ computes the attention weights (a "compatibility function" of queries with corresponding keys) """
315
+
316
+ # calculate the first term in the numerator attn1, which is Q*K
317
+ # todo: pytorch reshapes q,k and v to 3 dimensions (similar to how r_q2 is below)
318
+ # is that functionally equivalent to what we're doing? is their way faster?
319
+ # r_q1 = [batch_size, num_heads, len_q, head_dim]
320
+ r_q1 = query.view(batch_size, len_q, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
321
+ # todo: we could directly permute r_k1 to [batch_size, num_heads, head_dim, len_k]
322
+ # to make it compatible for matrix multiplication with r_q1, instead of 2-step approach
323
+ # r_k1 = [batch_size, num_heads, len_k, head_dim]
324
+ r_k1 = key.view(batch_size, len_k, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
325
+ # attn1 = [batch_size, num_heads, len_q, len_k]
326
+ attn1 = torch.matmul(r_q1, r_k1.permute(0, 1, 3, 2))
327
+
328
+ # calculate the second term in the numerator attn2, which is Q*R
329
+ # r_q2 = [query_len, batch_size * num_heads, head_dim]
330
+ r_q2 = query.permute(1, 0, 2).contiguous().view(len_q, batch_size * self.num_heads, self.head_dim)
331
+
332
+ # todo: support multiple different PDB base structures per batch
333
+ # one option:
334
+ # - require batches to be all the same protein
335
+ # - add argument to forward() to accept the PDB file for the protein in the batch
336
+ # - then we just pass in the PDB file to relative position's forward()
337
+ # to support multiple different structures per batch:
338
+ # - add argument to forward() to accept PDB files, one for each item in batch
339
+ # - make corresponding changes in relative_position object to return R for each structure
340
+ # - note: if there are a lot of different structures, and the sequence lengths are long,
341
+ # this could be memory prohibitive because R (rel_pos_k) can take up a lot of mem for long seqs
342
+ # - adjust the attn2 calculation to factor in the multiple different R matrices.
343
+ # the way to do this might have to be to do multiple matmuls, one for each structure.
344
+ # basically, would split up r_q2 into several matrices grouped by structure, and then
345
+ # multiply with corresponding R, then combine back into the exact same order of the original r_q2
346
+ # note: this may be computationally intensive (splitting, more matrix multiplies, joining)
347
+ # another option would be to create views(?), repeating the different Rs so we can do a
348
+ # matrix multiply directly with r_q2
349
+ # - would shapes be affected if there was padding in the queries, keys, values?
350
+
351
+ if self.pos_encoding == "relative":
352
+ # rel_pos_k = [len_q, len_k, head_dim]
353
+ rel_pos_k = self.relative_position_k(len_q, len_k)
354
+ elif self.pos_encoding == "relative_3D":
355
+ # rel_pos_k = [sequence length (from PDB structure), head_dim]
356
+ rel_pos_k = self.relative_position_k(pdb_fn)
357
+ else:
358
+ raise ValueError("unrecognized pos_encoding: {}".format(self.pos_encoding))
359
+
360
+ # the matmul basically computes the dot product between each input position’s query vector and
361
+ # its corresponding relative position embeddings across all input sequences in the heads and batch
362
+ # attn2 = [batch_size * num_heads, len_q, len_k]
363
+ attn2 = torch.matmul(r_q2, rel_pos_k.transpose(1, 2)).transpose(0, 1)
364
+ # attn2 = [batch_size, num_heads, len_q, len_k]
365
+ attn2 = attn2.contiguous().view(batch_size, self.num_heads, len_q, len_k)
366
+
367
+ # calculate attention weights
368
+ attn_weights = (attn1 + attn2) / self.scale
369
+
370
+ # apply mask if given
371
+ if mask is not None:
372
+ # todo: pytorch uses float("-inf") instead of -1e10
373
+ attn_weights = attn_weights.masked_fill(mask == 0, -1e10)
374
+
375
+ # softmax gives us the final attention weights
376
+ attn_weights = torch.softmax(attn_weights, dim=-1)
377
+ # attn_weights = [batch_size, num_heads, len_q, len_k]
378
+ attn_weights = self.dropout(attn_weights)
379
+
380
+ return attn_weights
381
+
382
+ def _compute_avg_val(self, value, len_q, len_k, len_v, attn_weights, batch_size, pdb_fn):
383
+ # todo: add option to not factor in relative position embeddings in value calculation
384
+ # calculate the first term, the attn*values
385
+ # r_v1 = [batch_size, num_heads, len_v, head_dim]
386
+ r_v1 = value.view(batch_size, len_v, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
387
+ # avg1 = [batch_size, num_heads, len_q, head_dim]
388
+ avg1 = torch.matmul(attn_weights, r_v1)
389
+
390
+ # calculate the second term, the attn*R
391
+ # similar to how relative embeddings are factored in the attention weights calculation
392
+ if self.pos_encoding == "relative":
393
+ # rel_pos_v = [query_len, value_len, head_dim]
394
+ rel_pos_v = self.relative_position_v(len_q, len_v)
395
+ elif self.pos_encoding == "relative_3D":
396
+ # rel_pos_v = [sequence length (from PDB structure), head_dim]
397
+ rel_pos_v = self.relative_position_v(pdb_fn)
398
+ else:
399
+ raise ValueError("unrecognized pos_encoding: {}".format(self.pos_encoding))
400
+
401
+ # r_attn_weights = [len_q, batch_size * num_heads, len_v]
402
+ r_attn_weights = attn_weights.permute(2, 0, 1, 3).contiguous().view(len_q, batch_size * self.num_heads, len_k)
403
+ avg2 = torch.matmul(r_attn_weights, rel_pos_v)
404
+ # avg2 = [batch_size, num_heads, len_q, head_dim]
405
+ avg2 = avg2.transpose(0, 1).contiguous().view(batch_size, self.num_heads, len_q, self.head_dim)
406
+
407
+ # calculate avg value
408
+ x = avg1 + avg2 # [batch_size, num_heads, len_q, head_dim]
409
+ x = x.permute(0, 2, 1, 3).contiguous() # [batch_size, len_q, num_heads, head_dim]
410
+ # x = [batch_size, len_q, embed_dim]
411
+ x = x.view(batch_size, len_q, self.embed_dim)
412
+
413
+ return x
414
+
415
+ def forward(self, query, key, value, pdb_fn=None, mask=None):
416
+ # query = [batch_size, q_len, embed_dim]
417
+ # key = [batch_size, k_len, embed_dim]
418
+ # value = [batch_size, v_len, embed_dim]
419
+ batch_size = query.shape[0]
420
+ len_k, len_q, len_v = (key.shape[1], query.shape[1], value.shape[1])
421
+
422
+ # in projection (multiply inputs by WQ, WK, WV)
423
+ query = self.q_proj(query)
424
+ key = self.k_proj(key)
425
+ value = self.v_proj(value)
426
+
427
+ # first compute the attention weights, then multiply with values
428
+ # attn = [batch size, num_heads, len_q, len_k]
429
+ attn_weights = self._compute_attn_weights(query, key, len_q, len_k, batch_size, mask, pdb_fn)
430
+
431
+ # take weighted average of values (weighted by attention weights)
432
+ attn_output = self._compute_avg_val(value, len_q, len_k, len_v, attn_weights, batch_size, pdb_fn)
433
+
434
+ # output projection
435
+ # attn_output = [batch_size, len_q, embed_dim]
436
+ attn_output = self.out_proj(attn_output)
437
+
438
+ if self.need_weights:
439
+ # return attention weights in addition to attention
440
+ # average the weights over the heads (to get overall attention)
441
+ # attn_weights = [batch_size, len_q, len_k]
442
+ if self.average_attn_weights:
443
+ attn_weights = attn_weights.sum(dim=1) / self.num_heads
444
+ return {"attn_output": attn_output, "attn_weights": attn_weights}
445
+ else:
446
+ return attn_output
447
+
448
+
449
+ class RelativeTransformerEncoderLayer(nn.Module):
450
+ """
451
+ d_model: the number of expected features in the input (required).
452
+ nhead: the number of heads in the MultiHeadAttention models (required).
453
+ clipping_threshold: the clipping threshold for relative position embeddings
454
+ dim_feedforward: the dimension of the feedforward network model (default=2048).
455
+ dropout: the dropout value (default=0.1).
456
+ activation: the activation function of the intermediate layer, can be a string
457
+ ("relu" or "gelu") or a unary callable. Default: relu
458
+ layer_norm_eps: the eps value in layer normalization components (default=1e-5).
459
+ norm_first: if ``True``, layer norm is done prior to attention and feedforward
460
+ operations, respectively. Otherwise, it's done after. Default: ``False`` (after).
461
+ """
462
+
463
+ # __constants__ marks these attributes as constants for TorchScript (torch.jit); they should not change after init
464
+ __constants__ = ['batch_first', 'norm_first']
465
+
466
+ def __init__(self,
467
+ d_model,
468
+ nhead,
469
+ pos_encoding="relative",
470
+ clipping_threshold=3,
471
+ contact_threshold=7,
472
+ pdb_fns=None,
473
+ dim_feedforward=2048,
474
+ dropout=0.1,
475
+ activation=F.relu,
476
+ layer_norm_eps=1e-5,
477
+ norm_first=False) -> None:
478
+
479
+ self.batch_first = True
480
+
481
+ super(RelativeTransformerEncoderLayer, self).__init__()
482
+
483
+ self.self_attn = RelativeMultiHeadAttention(d_model, nhead, dropout,
484
+ pos_encoding, clipping_threshold, contact_threshold, pdb_fns)
485
+
486
+ # feed forward model
487
+ self.linear1 = Linear(d_model, dim_feedforward)
488
+ self.dropout = Dropout(dropout)
489
+ self.linear2 = Linear(dim_feedforward, d_model)
490
+
491
+ self.norm_first = norm_first
492
+ self.norm1 = LayerNorm(d_model, eps=layer_norm_eps)
493
+ self.norm2 = LayerNorm(d_model, eps=layer_norm_eps)
494
+ self.dropout1 = Dropout(dropout)
495
+ self.dropout2 = Dropout(dropout)
496
+
497
+ # Legacy string support for activation function.
498
+ if isinstance(activation, str):
499
+ self.activation = models.get_activation_fn(activation)
500
+ else:
501
+ self.activation = activation
502
+
503
+ def forward(self, src: Tensor, pdb_fn=None) -> Tensor:
504
+ x = src
505
+ if self.norm_first:
506
+ x = x + self._sa_block(self.norm1(x), pdb_fn=pdb_fn)
507
+ x = x + self._ff_block(self.norm2(x))
508
+ else:
509
+ x = self.norm1(x + self._sa_block(x, pdb_fn=pdb_fn))
510
+ x = self.norm2(x + self._ff_block(x))
511
+
512
+ return x
513
+
514
+ # self-attention block
515
+ def _sa_block(self, x: Tensor, pdb_fn=None) -> Tensor:
516
+ x = self.self_attn(x, x, x, pdb_fn=pdb_fn)
517
+ if isinstance(x, dict):
518
+ # handle the case where we are returning attention weights
519
+ x = x["attn_output"]
520
+ return self.dropout1(x)
521
+
522
+ # feed forward block
523
+ def _ff_block(self, x: Tensor) -> Tensor:
524
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
525
+ return self.dropout2(x)
526
+
527
+
528
+ class RelativeTransformerEncoder(nn.Module):
529
+ def __init__(self, encoder_layer, num_layers, norm=None, reset_params=True):
530
+ super(RelativeTransformerEncoder, self).__init__()
531
+ # using get_clones means all layers have the same initialization
532
+ # this is also a problem in PyTorch's TransformerEncoder implementation, which this is based on
533
+ # todo: PyTorch is changing its transformer API... check up on and see if there is a better way
534
+ self.layers = _get_clones(encoder_layer, num_layers)
535
+ self.num_layers = num_layers
536
+ self.norm = norm
537
+
538
+ # important because get_clones means all layers have same initialization
539
+ # should recursively reset parameters for all submodules
540
+ if reset_params:
541
+ self.apply(models.reset_parameters_helper)
542
+
543
+ def forward(self, src: Tensor, pdb_fn=None) -> Tensor:
544
+ output = src
545
+
546
+ for mod in self.layers:
547
+ output = mod(output, pdb_fn=pdb_fn)
548
+
549
+ if self.norm is not None:
550
+ output = self.norm(output)
551
+
552
+ return output
553
+
554
+
555
+ def _get_clones(module, num_clones):
556
+ return nn.ModuleList([copy.deepcopy(module) for _ in range(num_clones)])
557
+
558
+
559
+ def _inv_dict(d):
560
+ """ helper function for contact map-based position embeddings """
561
+ inv = dict()
562
+ for k, v in d.items():
563
+ # collect dict keys into lists based on value
564
+ inv.setdefault(v, list()).append(k)
565
+ for k, v in inv.items():
566
+ # put in sorted order
567
+ inv[k] = sorted(v)
568
+ return inv
569
+
570
+
571
+ def _combine_d(d, threshold, combined_key):
572
+ """ helper function for contact map-based position embeddings
573
+ d is a dictionary with ints as keys and lists as values.
574
+ for all keys >= threshold, this function combines the values of those keys into a single list """
575
+ out_d = {}
576
+ for k, v in d.items():
577
+ if k < threshold:
578
+ out_d[k] = v
579
+ elif k >= threshold:
580
+ if combined_key not in out_d:
581
+ out_d[combined_key] = v
582
+ else:
583
+ out_d[combined_key] += v
584
+ if combined_key in out_d:
585
+ out_d[combined_key] = sorted(out_d[combined_key])
586
+ return out_d
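A minimal sketch of stacking these layers into an encoder with sequence-based relative attention (the sizes are made up for illustration and are not part of the uploaded files; for pos_encoding="relative_3D" you would additionally pass pdb_fns at construction and a pdb_fn to forward()):

    import torch
    import torch.nn as nn
    from metl.relative_attention import (RelativeTransformerEncoderLayer,
                                         RelativeTransformerEncoder)

    d_model, nhead, seq_len, batch_size = 64, 4, 50, 8  # illustrative sizes

    layer = RelativeTransformerEncoderLayer(d_model=d_model,
                                            nhead=nhead,
                                            pos_encoding="relative",
                                            clipping_threshold=3,
                                            dim_feedforward=256,
                                            dropout=0.1)
    encoder = RelativeTransformerEncoder(layer, num_layers=2, norm=nn.LayerNorm(d_model))

    x = torch.randn(batch_size, seq_len, d_model)  # batch-first input
    out = encoder(x)                               # same shape as the input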
metl/structure.py ADDED
@@ -0,0 +1,184 @@
1
+ import os
2
+ from os.path import isfile
3
+ from enum import Enum, auto
4
+
5
+ import numpy as np
6
+ from scipy.spatial.distance import cdist
7
+ import networkx as nx
8
+ from biopandas.pdb import PandasPdb
9
+
10
+
11
+ class GraphType(Enum):
12
+ LINEAR = auto()
13
+ COMPLETE = auto()
14
+ DISCONNECTED = auto()
15
+ DIST_THRESH = auto()
16
+ DIST_THRESH_SHUFFLED = auto()
17
+
18
+
19
+ def save_graph(g, fn):
20
+ """ Saves graph to file """
21
+ nx.write_gexf(g, fn)
22
+
23
+
24
+ def load_graph(fn):
25
+ """ Loads graph from file """
26
+ g = nx.read_gexf(fn, node_type=int)
27
+ return g
28
+
29
+
30
+ def shuffle_nodes(g, seed=7):
31
+ """ Shuffles the nodes of the given graph and returns a copy of the shuffled graph """
32
+ # get the list of nodes in this graph
33
+ nodes = g.nodes()
34
+
35
+ # create a permuted list of nodes
36
+ np.random.seed(seed)
37
+ nodes_shuffled = np.random.permutation(nodes)
38
+
39
+ # create a dictionary mapping from old node label to new node label
40
+ mapping = {n: ns for n, ns in zip(nodes, nodes_shuffled)}
41
+
42
+ g_shuffled = nx.relabel_nodes(g, mapping, copy=True)
43
+
44
+ return g_shuffled
45
+
46
+
47
+ def linear_graph(num_residues):
48
+ """ Creates a linear graph where each node is connected to its sequence neighbor in order """
49
+ g = nx.Graph()
50
+ g.add_nodes_from(np.arange(0, num_residues))
51
+ for i in range(num_residues-1):
52
+ g.add_edge(i, i+1)
53
+ return g
54
+
55
+
56
+ def complete_graph(num_residues):
57
+ """ Creates a graph where each node is connected to all other nodes"""
58
+ g = nx.complete_graph(num_residues)
59
+ return g
60
+
61
+
62
+ def disconnected_graph(num_residues):
63
+ g = nx.Graph()
64
+ g.add_nodes_from(np.arange(0, num_residues))
65
+ return g
66
+
67
+
68
+ def dist_thresh_graph(dist_mtx, threshold):
69
+ """ Creates undirected graph based on a distance threshold """
70
+ g = nx.Graph()
71
+ g.add_nodes_from(np.arange(0, dist_mtx.shape[0]))
72
+
73
+ # loop through each residue
74
+ for rn1 in range(len(dist_mtx)):
75
+ # find all residues that are within threshold distance of current
76
+ rns_within_threshold = np.where(dist_mtx[rn1] < threshold)[0]
77
+
78
+ # add edges from current residue to those that are within threshold
79
+ for rn2 in rns_within_threshold:
80
+ # don't add self edges
81
+ if rn1 != rn2:
82
+ g.add_edge(rn1, rn2)
83
+ return g
84
+
85
+
86
+ def ordered_adjacency_matrix(g):
87
+ """ returns the adjacency matrix ordered by node label in increasing order as a numpy array """
88
+ node_order = sorted(g.nodes())
89
+ adj_mtx = nx.to_numpy_matrix(g, nodelist=node_order)
90
+ return np.asarray(adj_mtx).astype(np.float32)
91
+
92
+
93
+ def cbeta_distance_matrix(pdb_fn, start=0, end=None):
94
+ # note that start and end index into the ordered residue listing in the pdb file,
95
+ # not the residue numbers themselves
96
+
97
+ # read the pdb file into a biopandas object
98
+ ppdb = PandasPdb().read_pdb(pdb_fn)
99
+
100
+ # group by residue number
101
+ # important to specify sort=True so that group keys (residue number) are in order
102
+ # the reason is we loop through group keys below, and assume that residues are in order
103
+ # the pandas function has sort=True by default, but we specify it anyway because it is important
104
+ grouped = ppdb.df["ATOM"].groupby("residue_number", sort=True)
105
+
106
+ # a list of coords for the cbeta or calpha of each residue
107
+ coords = []
108
+
109
+ # loop through each residue and find the coordinates of cbeta
110
+ for i, (residue_number, values) in enumerate(grouped):
111
+
112
+ # skip residues not in the range
113
+ end_index = (len(grouped) if end is None else end)
114
+ if i not in range(start, end_index):
115
+ continue
116
+
117
+ residue_group = grouped.get_group(residue_number)
118
+
119
+ atom_names = residue_group["atom_name"]
120
+ if "CB" in atom_names.values:
121
+ # print("Using CB...")
122
+ atom_name = "CB"
123
+ elif "CA" in atom_names.values:
124
+ # print("Using CA...")
125
+ atom_name = "CA"
126
+ else:
127
+ raise ValueError("Couldn't find CB or CA for residue {}".format(residue_number))
128
+
129
+ # get the coordinates of cbeta (or calpha)
130
+ coords.append(
131
+ residue_group[residue_group["atom_name"] == atom_name][["x_coord", "y_coord", "z_coord"]].values[0])
132
+
133
+ # stack the coords into a numpy array where each row has the x,y,z coords for a different residue
134
+ coords = np.stack(coords)
135
+
136
+ # compute pairwise euclidean distance between all cbetas
137
+ dist_mtx = cdist(coords, coords, metric="euclidean")
138
+
139
+ return dist_mtx
140
+
141
+
142
+ def get_neighbors(g, nodes):
143
+ """ returns a list (set) of neighbors of all given nodes """
144
+ neighbors = set()
145
+ for n in nodes:
146
+ neighbors.update(g.neighbors(n))
147
+ return sorted(list(neighbors))
148
+
149
+
150
+ def gen_graph(graph_type, res_dist_mtx, dist_thresh=7, shuffle_seed=7, graph_save_dir=None, save=False):
151
+ """ generate the specified structure graph using the specified residue distance matrix """
152
+ if graph_type is GraphType.LINEAR:
153
+ g = linear_graph(len(res_dist_mtx))
154
+ save_fn = None if not save else os.path.join(graph_save_dir, "linear.graph")
155
+
156
+ elif graph_type is GraphType.COMPLETE:
157
+ g = complete_graph(len(res_dist_mtx))
158
+ save_fn = None if not save else os.path.join(graph_save_dir, "complete.graph")
159
+
160
+ elif graph_type is GraphType.DISCONNECTED:
161
+ g = disconnected_graph(len(res_dist_mtx))
162
+ save_fn = None if not save else os.path.join(graph_save_dir, "disconnected.graph")
163
+
164
+ elif graph_type is GraphType.DIST_THRESH:
165
+ g = dist_thresh_graph(res_dist_mtx, dist_thresh)
166
+ save_fn = None if not save else os.path.join(graph_save_dir, "dist_thresh_{}.graph".format(dist_thresh))
167
+
168
+ elif graph_type is GraphType.DIST_THRESH_SHUFFLED:
169
+ g = dist_thresh_graph(res_dist_mtx, dist_thresh)
170
+ g = shuffle_nodes(g, seed=shuffle_seed)
171
+ save_fn = None if not save else \
172
+ os.path.join(graph_save_dir, "dist_thresh_{}_shuffled_r{}.graph".format(dist_thresh, shuffle_seed))
173
+
174
+ else:
175
+ raise ValueError("Graph type {} is not implemented".format(graph_type))
176
+
177
+ if save:
178
+ if isfile(save_fn):
179
+ print("err: graph already exists: {}. to overwrite, delete the existing file first".format(save_fn))
180
+ else:
181
+ os.makedirs(graph_save_dir, exist_ok=True)
182
+ save_graph(g, save_fn)
183
+
184
+ return g
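A minimal sketch of building a contact-map graph with the functions above (the PDB path and distance threshold below are placeholders, not part of the uploaded files):

    from metl.structure import GraphType, cbeta_distance_matrix, gen_graph

    # pairwise C-beta (or C-alpha) distances for a single-chain structure
    dist_mtx = cbeta_distance_matrix("data/pdb_files/my_protein.pdb")  # placeholder path

    # undirected graph with an edge between residues closer than the threshold (in angstroms)
    g = gen_graph(GraphType.DIST_THRESH, dist_mtx, dist_thresh=7)
    print(g.number_of_nodes(), g.number_of_edges())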