Upload 6 files
Browse files- metl/__init__.py +2 -0
- metl/encode.py +58 -0
- metl/main.py +139 -0
- metl/models.py +1064 -0
- metl/relative_attention.py +586 -0
- metl/structure.py +184 -0
metl/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .main import *
|
2 |
+
__version__ = "0.1"
|
metl/encode.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Encodes data in different formats """
|
2 |
+
from enum import Enum, auto
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
|
7 |
+
class Encoding(Enum):
|
8 |
+
INT_SEQS = auto()
|
9 |
+
ONE_HOT = auto()
|
10 |
+
|
11 |
+
|
12 |
+
class DataEncoder:
|
13 |
+
chars = ["*", "A", "C", "D", "E", "F", "G", "H", "I", "K", "L", "M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y"]
|
14 |
+
num_chars = len(chars)
|
15 |
+
mapping = {c: i for i, c in enumerate(chars)}
|
16 |
+
|
17 |
+
def __init__(self, encoding: Encoding = Encoding.INT_SEQS):
|
18 |
+
self.encoding = encoding
|
19 |
+
|
20 |
+
def _encode_from_int_seqs(self, seq_ints):
|
21 |
+
if self.encoding == Encoding.INT_SEQS:
|
22 |
+
return seq_ints
|
23 |
+
elif self.encoding == Encoding.ONE_HOT:
|
24 |
+
one_hot = np.eye(self.num_chars)[seq_ints]
|
25 |
+
return one_hot.astype(np.float32)
|
26 |
+
|
27 |
+
def encode_sequences(self, char_seqs):
|
28 |
+
seq_ints = []
|
29 |
+
for char_seq in char_seqs:
|
30 |
+
int_seq = [self.mapping[c] for c in char_seq]
|
31 |
+
seq_ints.append(int_seq)
|
32 |
+
seq_ints = np.array(seq_ints).astype(int)
|
33 |
+
return self._encode_from_int_seqs(seq_ints)
|
34 |
+
|
35 |
+
def encode_variants(self, wt, variants):
|
36 |
+
# convert wild type seq to integer encoding
|
37 |
+
wt_int = np.zeros(len(wt), dtype=np.uint8)
|
38 |
+
for i, c in enumerate(wt):
|
39 |
+
wt_int[i] = self.mapping[c]
|
40 |
+
|
41 |
+
# tile the wild-type seq
|
42 |
+
seq_ints = np.tile(wt_int, (len(variants), 1))
|
43 |
+
|
44 |
+
for i, variant in enumerate(variants):
|
45 |
+
# special handling if we want to encode the wild-type seq (it's already correct!)
|
46 |
+
if variant == "_wt":
|
47 |
+
continue
|
48 |
+
|
49 |
+
# variants are a list of mutations [mutation1, mutation2, ....]
|
50 |
+
variant = variant.split(",")
|
51 |
+
for mutation in variant:
|
52 |
+
# mutations are in the form <original char><position><replacement char>
|
53 |
+
position = int(mutation[1:-1])
|
54 |
+
replacement = self.mapping[mutation[-1]]
|
55 |
+
seq_ints[i, position] = replacement
|
56 |
+
|
57 |
+
seq_ints = seq_ints.astype(int)
|
58 |
+
return self._encode_from_int_seqs(seq_ints)
|
metl/main.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.hub
|
3 |
+
|
4 |
+
import metl.models as models
|
5 |
+
from metl.encode import DataEncoder, Encoding
|
6 |
+
|
7 |
+
UUID_URL_MAP = {
|
8 |
+
# global source models
|
9 |
+
"D72M9aEp": "https://zenodo.org/records/11051645/files/METL-G-20M-1D-D72M9aEp.pt?download=1",
|
10 |
+
"Nr9zCKpR": "https://zenodo.org/records/11051645/files/METL-G-20M-3D-Nr9zCKpR.pt?download=1",
|
11 |
+
"auKdzzwX": "https://zenodo.org/records/11051645/files/METL-G-50M-1D-auKdzzwX.pt?download=1",
|
12 |
+
"6PSAzdfv": "https://zenodo.org/records/11051645/files/METL-G-50M-3D-6PSAzdfv.pt?download=1",
|
13 |
+
|
14 |
+
# local source models
|
15 |
+
"8gMPQJy4": "https://zenodo.org/records/11051645/files/METL-L-2M-1D-GFP-8gMPQJy4.pt?download=1",
|
16 |
+
"Hr4GNHws": "https://zenodo.org/records/11051645/files/METL-L-2M-3D-GFP-Hr4GNHws.pt?download=1",
|
17 |
+
"8iFoiYw2": "https://zenodo.org/records/11051645/files/METL-L-2M-1D-DLG4_2022-8iFoiYw2.pt?download=1",
|
18 |
+
"kt5DdWTa": "https://zenodo.org/records/11051645/files/METL-L-2M-3D-DLG4_2022-kt5DdWTa.pt?download=1",
|
19 |
+
"DMfkjVzT": "https://zenodo.org/records/11051645/files/METL-L-2M-1D-GB1-DMfkjVzT.pt?download=1",
|
20 |
+
"epegcFiH": "https://zenodo.org/records/11051645/files/METL-L-2M-3D-GB1-epegcFiH.pt?download=1",
|
21 |
+
"kS3rUS7h": "https://zenodo.org/records/11051645/files/METL-L-2M-1D-GRB2-kS3rUS7h.pt?download=1",
|
22 |
+
"X7w83g6S": "https://zenodo.org/records/11051645/files/METL-L-2M-3D-GRB2-X7w83g6S.pt?download=1",
|
23 |
+
"UKebCQGz": "https://zenodo.org/records/11051645/files/METL-L-2M-1D-Pab1-UKebCQGz.pt?download=1",
|
24 |
+
"2rr8V4th": "https://zenodo.org/records/11051645/files/METL-L-2M-3D-Pab1-2rr8V4th.pt?download=1",
|
25 |
+
"PREhfC22": "https://zenodo.org/records/11051645/files/METL-L-2M-1D-TEM-1-PREhfC22.pt?download=1",
|
26 |
+
"9ASvszux": "https://zenodo.org/records/11051645/files/METL-L-2M-3D-TEM-1-9ASvszux.pt?download=1",
|
27 |
+
"HscFFkAb": "https://zenodo.org/records/11051645/files/METL-L-2M-1D-Ube4b-HscFFkAb.pt?download=1",
|
28 |
+
"H48oiNZN": "https://zenodo.org/records/11051645/files/METL-L-2M-3D-Ube4b-H48oiNZN.pt?download=1",
|
29 |
+
|
30 |
+
# metl bind source models
|
31 |
+
"K6mw24Rg": "https://zenodo.org/records/11051645/files/METL-BIND-2M-3D-GB1-STANDARD-K6mw24Rg.pt?download=1",
|
32 |
+
"Bo5wn2SG": "https://zenodo.org/records/11051645/files/METL-BIND-2M-3D-GB1-BINDING-Bo5wn2SG.pt?download=1",
|
33 |
+
|
34 |
+
# finetuned models from GFP design experiment
|
35 |
+
"YoQkzoLD": "https://zenodo.org/records/11051645/files/FT-METL-L-2M-1D-GFP-YoQkzoLD.pt?download=1",
|
36 |
+
"PEkeRuxb": "https://zenodo.org/records/11051645/files/FT-METL-L-2M-3D-GFP-PEkeRuxb.pt?download=1",
|
37 |
+
|
38 |
+
}
|
39 |
+
|
40 |
+
IDENT_UUID_MAP = {
|
41 |
+
# the keys should be all lowercase
|
42 |
+
"metl-g-20m-1d": "D72M9aEp",
|
43 |
+
"metl-g-20m-3d": "Nr9zCKpR",
|
44 |
+
"metl-g-50m-1d": "auKdzzwX",
|
45 |
+
"metl-g-50m-3d": "6PSAzdfv",
|
46 |
+
|
47 |
+
# GFP local source models
|
48 |
+
"metl-l-2m-1d-gfp": "8gMPQJy4",
|
49 |
+
"metl-l-2m-3d-gfp": "Hr4GNHws",
|
50 |
+
|
51 |
+
# DLG4 local source models
|
52 |
+
"metl-l-2m-1d-dlg4": "8iFoiYw2",
|
53 |
+
"metl-l-2m-3d-dlg4": "kt5DdWTa",
|
54 |
+
|
55 |
+
# GB1 local source models
|
56 |
+
"metl-l-2m-1d-gb1": "DMfkjVzT",
|
57 |
+
"metl-l-2m-3d-gb1": "epegcFiH",
|
58 |
+
|
59 |
+
# GRB2 local source models
|
60 |
+
"metl-l-2m-1d-grb2": "kS3rUS7h",
|
61 |
+
"metl-l-2m-3d-grb2": "X7w83g6S",
|
62 |
+
|
63 |
+
# Pab1 local source models
|
64 |
+
"metl-l-2m-1d-pab1": "UKebCQGz",
|
65 |
+
"metl-l-2m-3d-pab1": "2rr8V4th",
|
66 |
+
|
67 |
+
# TEM-1 local source models
|
68 |
+
"metl-l-2m-1d-tem-1": "PREhfC22",
|
69 |
+
"metl-l-2m-3d-tem-1": "9ASvszux",
|
70 |
+
|
71 |
+
# Ube4b local source models
|
72 |
+
"metl-l-2m-1d-ube4b": "HscFFkAb",
|
73 |
+
"metl-l-2m-3d-ube4b": "H48oiNZN",
|
74 |
+
|
75 |
+
# METL-Bind for GB1
|
76 |
+
"metl-bind-2m-3d-gb1-standard": "K6mw24Rg",
|
77 |
+
"metl-bind-2m-3d-gb1-binding": "Bo5wn2SG",
|
78 |
+
|
79 |
+
# GFP design models, giving them an ident
|
80 |
+
"metl-l-2m-1d-gfp-ft-design": "YoQkzoLD",
|
81 |
+
"metl-l-2m-3d-gfp-ft-design": "PEkeRuxb",
|
82 |
+
|
83 |
+
}
|
84 |
+
|
85 |
+
|
86 |
+
def download_checkpoint(uuid):
|
87 |
+
ckpt = torch.hub.load_state_dict_from_url(UUID_URL_MAP[uuid],
|
88 |
+
map_location="cpu", file_name=f"{uuid}.pt")
|
89 |
+
state_dict = ckpt["state_dict"]
|
90 |
+
hyper_parameters = ckpt["hyper_parameters"]
|
91 |
+
|
92 |
+
return state_dict, hyper_parameters
|
93 |
+
|
94 |
+
|
95 |
+
def _get_data_encoding(hparams):
|
96 |
+
if "encoding" in hparams and hparams["encoding"] == "int_seqs":
|
97 |
+
encoding = Encoding.INT_SEQS
|
98 |
+
elif "encoding" in hparams and hparams["encoding"] == "one_hot":
|
99 |
+
encoding = Encoding.ONE_HOT
|
100 |
+
elif (("encoding" in hparams and hparams["encoding"] == "auto") or "encoding" not in hparams) and \
|
101 |
+
hparams["model_name"] in ["transformer_encoder"]:
|
102 |
+
encoding = Encoding.INT_SEQS
|
103 |
+
else:
|
104 |
+
raise ValueError("Detected unsupported encoding in hyperparameters")
|
105 |
+
|
106 |
+
return encoding
|
107 |
+
|
108 |
+
|
109 |
+
def load_model_and_data_encoder(state_dict, hparams):
|
110 |
+
model = models.Model[hparams["model_name"]].cls(**hparams)
|
111 |
+
model.load_state_dict(state_dict)
|
112 |
+
|
113 |
+
data_encoder = DataEncoder(_get_data_encoding(hparams))
|
114 |
+
|
115 |
+
return model, data_encoder
|
116 |
+
|
117 |
+
|
118 |
+
def get_from_uuid(uuid):
|
119 |
+
if uuid in UUID_URL_MAP:
|
120 |
+
state_dict, hparams = download_checkpoint(uuid)
|
121 |
+
return load_model_and_data_encoder(state_dict, hparams)
|
122 |
+
else:
|
123 |
+
raise ValueError(f"UUID {uuid} not found in UUID_URL_MAP")
|
124 |
+
|
125 |
+
|
126 |
+
def get_from_ident(ident):
|
127 |
+
ident = ident.lower()
|
128 |
+
if ident in IDENT_UUID_MAP:
|
129 |
+
state_dict, hparams = download_checkpoint(IDENT_UUID_MAP[ident])
|
130 |
+
return load_model_and_data_encoder(state_dict, hparams)
|
131 |
+
else:
|
132 |
+
raise ValueError(f"Identifier {ident} not found in IDENT_UUID_MAP")
|
133 |
+
|
134 |
+
|
135 |
+
def get_from_checkpoint(ckpt_fn):
|
136 |
+
ckpt = torch.load(ckpt_fn, map_location="cpu")
|
137 |
+
state_dict = ckpt["state_dict"]
|
138 |
+
hyper_parameters = ckpt["hyper_parameters"]
|
139 |
+
return load_model_and_data_encoder(state_dict, hyper_parameters)
|
metl/models.py
ADDED
@@ -0,0 +1,1064 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import collections
|
2 |
+
import math
|
3 |
+
from argparse import ArgumentParser
|
4 |
+
import enum
|
5 |
+
from os.path import isfile
|
6 |
+
from typing import List, Tuple, Optional
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from torch import Tensor
|
12 |
+
|
13 |
+
import metl.relative_attention as ra
|
14 |
+
|
15 |
+
|
16 |
+
def reset_parameters_helper(m: nn.Module):
|
17 |
+
""" helper function for resetting model parameters, meant to be used with model.apply() """
|
18 |
+
|
19 |
+
# the PyTorch MultiHeadAttention has a private function _reset_parameters()
|
20 |
+
# other layers have a public reset_parameters()... go figure
|
21 |
+
reset_parameters = getattr(m, "reset_parameters", None)
|
22 |
+
reset_parameters_private = getattr(m, "_reset_parameters", None)
|
23 |
+
|
24 |
+
if callable(reset_parameters) and callable(reset_parameters_private):
|
25 |
+
raise RuntimeError("Module has both public and private methods for resetting parameters. "
|
26 |
+
"This is unexpected... probably should just call the public one.")
|
27 |
+
|
28 |
+
if callable(reset_parameters):
|
29 |
+
m.reset_parameters()
|
30 |
+
|
31 |
+
if callable(reset_parameters_private):
|
32 |
+
m._reset_parameters()
|
33 |
+
|
34 |
+
|
35 |
+
class SequentialWithArgs(nn.Sequential):
|
36 |
+
def forward(self, x, **kwargs):
|
37 |
+
for module in self:
|
38 |
+
if isinstance(module, ra.RelativeTransformerEncoder) or isinstance(module, SequentialWithArgs):
|
39 |
+
# for relative transformer encoders, pass in kwargs (pdb_fn)
|
40 |
+
x = module(x, **kwargs)
|
41 |
+
else:
|
42 |
+
# for all modules, don't pass in kwargs
|
43 |
+
x = module(x)
|
44 |
+
return x
|
45 |
+
|
46 |
+
|
47 |
+
class PositionalEncoding(nn.Module):
|
48 |
+
# originally from https://pytorch.org/tutorials/beginner/transformer_tutorial.html
|
49 |
+
# they have since updated their implementation, but it is functionally equivalent
|
50 |
+
def __init__(self, d_model, dropout=0.1, max_len=5000):
|
51 |
+
super(PositionalEncoding, self).__init__()
|
52 |
+
self.dropout = nn.Dropout(p=dropout)
|
53 |
+
|
54 |
+
pe = torch.zeros(max_len, d_model)
|
55 |
+
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
56 |
+
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
57 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
58 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
59 |
+
# note the implementation on Pytorch's website expects [seq_len, batch_size, embedding_dim]
|
60 |
+
# however our data is in [batch_size, seq_len, embedding_dim] (i.e. batch_first)
|
61 |
+
# fixed by changing pe = pe.unsqueeze(0).transpose(0, 1) to pe = pe.unsqueeze(0)
|
62 |
+
# also down below, changing our indexing into the position encoding to reflect new dimensions
|
63 |
+
# pe = pe.unsqueeze(0).transpose(0, 1)
|
64 |
+
pe = pe.unsqueeze(0)
|
65 |
+
self.register_buffer('pe', pe)
|
66 |
+
|
67 |
+
def forward(self, x, **kwargs):
|
68 |
+
# note the implementation on Pytorch's website expects [seq_len, batch_size, embedding_dim]
|
69 |
+
# however our data is in [batch_size, seq_len, embedding_dim] (i.e. batch_first)
|
70 |
+
# fixed by changing x = x + self.pe[:x.size(0)] to x = x + self.pe[:, :x.size(1), :]
|
71 |
+
# x = x + self.pe[:x.size(0), :]
|
72 |
+
x = x + self.pe[:, :x.size(1), :]
|
73 |
+
return self.dropout(x)
|
74 |
+
|
75 |
+
|
76 |
+
class ScaledEmbedding(nn.Module):
|
77 |
+
# https://pytorch.org/tutorials/beginner/translation_transformer.html
|
78 |
+
# a helper function for embedding that scales by sqrt(d_model) in the forward()
|
79 |
+
# makes it, so we don't have to do the scaling in the main AttnModel forward()
|
80 |
+
|
81 |
+
# todo: be aware of embedding scaling factor
|
82 |
+
# regarding the scaling factor, it's unclear exactly what the purpose is and whether it is needed
|
83 |
+
# there are several theories on why it is used, and it shows up in all the transformer reference implementations
|
84 |
+
# https://datascience.stackexchange.com/questions/87906/transformer-model-why-are-word-embeddings-scaled-before-adding-positional-encod
|
85 |
+
# 1. Has something to do with weight sharing between the embedding and the decoder output
|
86 |
+
# 2. Scales up the embeddings so the signal doesn't get overwhelmed when adding the absolute positional encoding
|
87 |
+
# 3. It cancels out with the scaling factor in scaled dot product attention, and helps make the model robust
|
88 |
+
# to the choice of embedding_len
|
89 |
+
# 4. It's not actually needed
|
90 |
+
|
91 |
+
# Regarding #1, not really sure about this. In section 3.4 of attention is all you need,
|
92 |
+
# that's where they state they multiply the embedding weights by sqrt(d_model), and the context is that they
|
93 |
+
# are sharing the same weight matrix between the two embedding layers and the pre-softmax linear transformation.
|
94 |
+
# there may be a reason that we want those weights scaled differently for the embedding layers vs. the linear
|
95 |
+
# transformation. It might have something to do with the scale at which embedding weights are initialized
|
96 |
+
# is more appropriate for the decoder linear transform vs how they are used in the attention function. Might have
|
97 |
+
# something to do with computing the correct next-token probabilities. Overall, I'm really not sure about this,
|
98 |
+
# but we aren't using a decoder anyway. So if this is the reason, then we don't need to perform the multiply.
|
99 |
+
|
100 |
+
# Regarding #2, it seems like in one implementation of transformers (fairseq), the sinusoidal positional encoding
|
101 |
+
# has a range of (-1.0, 1.0), but the word embedding are initialized with mean 0 and s.d embedding_dim ** -0.5,
|
102 |
+
# which for embedding_dim=512, is a range closer to (-0.10, 0.10). Thus, the positional embedding would overwhelm
|
103 |
+
# the word embeddings when they are added together. The scaling factor increases the signal of the word embeddings.
|
104 |
+
# for embedding_dim=512, it scales word embeddings by 22, increasing range of the word embeddings to (-2.2, 2.2).
|
105 |
+
# link to fairseq implementation, search for nn.init to see them do the initialization
|
106 |
+
# https://fairseq.readthedocs.io/en/v0.7.1/_modules/fairseq/models/transformer.html
|
107 |
+
#
|
108 |
+
# For PyTorch, PyTorch initializes nn.Embedding with a standard normal distribution mean 0, variance 1: N(0,1).
|
109 |
+
# this puts the range for the word embeddings around (-3, 3). the pytorch implementation for positional encoding
|
110 |
+
# also has a range of (-1.0, 1.0). So already, these are much closer in scale, and it doesn't seem like we need
|
111 |
+
# to increase the scale of the word embeddings. However, PyTorch example still multiply by the scaling factor
|
112 |
+
# unclear whether this is just a carryover that is not actually needed, or if there is a different reason
|
113 |
+
#
|
114 |
+
# EDIT! I just realized that even though nn.Embedding defaults to a range of around (-3, 3), the PyTorch
|
115 |
+
# transformer example actually re-initializes them using a uniform distribution in the range of (-0.1, 0.1)
|
116 |
+
# that makes it very similar to the fairseq implementation, so the scaling factor that PyTorch uses actually would
|
117 |
+
# bring the word embedding and positional encodings much closer in scale. So this could be the reason why pytorch
|
118 |
+
# does it
|
119 |
+
|
120 |
+
# Regarding #3, I don't think so. Firstly, does it actually cancel there? Secondly, the purpose of the scaling
|
121 |
+
# factor in scaled dot product attention, according to attention is all you need, is to counteract dot products
|
122 |
+
# that are very high in magnitude due to choice of large mbedding length (aka d_k). The problem with high magnitude
|
123 |
+
# dot products is that potentially, the softmax is pushed into regions where it has extremely small gradients,
|
124 |
+
# making learning difficult. If the scaling factor in the embedding was meant to counteract the scaling factor in
|
125 |
+
# scaled dot product attention, then what would be the point of doing all that?
|
126 |
+
|
127 |
+
# Regarding #4, I don't think the scaling will have any effects in practice, it's probably not needed
|
128 |
+
|
129 |
+
# Overall, I think #2 is the most likely reason why this scaling is performed. In theory, I think
|
130 |
+
# even if the scaling wasn't performed, the network might learn to up-scale the word embedding weights to increase
|
131 |
+
# word embedding signal vs. the position signal on its own. Another question I have is why not just initialize
|
132 |
+
# the embedding weights to have higher initial values? Why put it in the range (-0.1, 0.1)?
|
133 |
+
#
|
134 |
+
# The fact that most implementations have this scaling concerns me, makes me think I might be missing something.
|
135 |
+
# For our purposes, we can train a couple models to see if scaling has any positive or negative effect.
|
136 |
+
# Still need to think about potential effects of this scaling on relative position embeddings.
|
137 |
+
|
138 |
+
def __init__(self, num_embeddings: int, embedding_dim: int, scale: bool):
|
139 |
+
super(ScaledEmbedding, self).__init__()
|
140 |
+
self.embedding = nn.Embedding(num_embeddings, embedding_dim)
|
141 |
+
self.emb_size = embedding_dim
|
142 |
+
self.embed_scale = math.sqrt(self.emb_size)
|
143 |
+
|
144 |
+
self.scale = scale
|
145 |
+
|
146 |
+
self.init_weights()
|
147 |
+
|
148 |
+
def init_weights(self):
|
149 |
+
# todo: not sure why PyTorch example initializes weights like this
|
150 |
+
# might have something to do with word embedding scaling factor (see above)
|
151 |
+
# could also just try the default weight initialization for nn.Embedding()
|
152 |
+
init_range = 0.1
|
153 |
+
self.embedding.weight.data.uniform_(-init_range, init_range)
|
154 |
+
|
155 |
+
def forward(self, tokens: Tensor, **kwargs):
|
156 |
+
if self.scale:
|
157 |
+
return self.embedding(tokens.long()) * self.embed_scale
|
158 |
+
else:
|
159 |
+
return self.embedding(tokens.long())
|
160 |
+
|
161 |
+
|
162 |
+
class FCBlock(nn.Module):
|
163 |
+
""" a fully connected block with options for batchnorm and dropout
|
164 |
+
can extend in the future with option for different activation, etc """
|
165 |
+
|
166 |
+
def __init__(self,
|
167 |
+
in_features: int,
|
168 |
+
num_hidden_nodes: int = 64,
|
169 |
+
use_batchnorm: bool = False,
|
170 |
+
use_layernorm: bool = False,
|
171 |
+
norm_before_activation: bool = False,
|
172 |
+
use_dropout: bool = False,
|
173 |
+
dropout_rate: float = 0.2,
|
174 |
+
activation: str = "relu"):
|
175 |
+
|
176 |
+
super().__init__()
|
177 |
+
|
178 |
+
if use_batchnorm and use_layernorm:
|
179 |
+
raise ValueError("Only one of use_batchnorm or use_layernorm can be set to True")
|
180 |
+
|
181 |
+
self.use_batchnorm = use_batchnorm
|
182 |
+
self.use_dropout = use_dropout
|
183 |
+
self.use_layernorm = use_layernorm
|
184 |
+
self.norm_before_activation = norm_before_activation
|
185 |
+
|
186 |
+
self.fc = nn.Linear(in_features=in_features, out_features=num_hidden_nodes)
|
187 |
+
|
188 |
+
self.activation = get_activation_fn(activation, functional=False)
|
189 |
+
|
190 |
+
if use_batchnorm:
|
191 |
+
self.norm = nn.BatchNorm1d(num_hidden_nodes)
|
192 |
+
|
193 |
+
if use_layernorm:
|
194 |
+
self.norm = nn.LayerNorm(num_hidden_nodes)
|
195 |
+
|
196 |
+
if use_dropout:
|
197 |
+
self.dropout = nn.Dropout(p=dropout_rate)
|
198 |
+
|
199 |
+
def forward(self, x, **kwargs):
|
200 |
+
x = self.fc(x)
|
201 |
+
|
202 |
+
# norm can be before or after activation, using flag
|
203 |
+
if (self.use_batchnorm or self.use_layernorm) and self.norm_before_activation:
|
204 |
+
x = self.norm(x)
|
205 |
+
|
206 |
+
x = self.activation(x)
|
207 |
+
|
208 |
+
# batchnorm being applied after activation, there is some discussion on this online
|
209 |
+
if (self.use_batchnorm or self.use_layernorm) and not self.norm_before_activation:
|
210 |
+
x = self.norm(x)
|
211 |
+
|
212 |
+
# dropout being applied last
|
213 |
+
if self.use_dropout:
|
214 |
+
x = self.dropout(x)
|
215 |
+
|
216 |
+
return x
|
217 |
+
|
218 |
+
|
219 |
+
class TaskSpecificPredictionLayers(nn.Module):
|
220 |
+
""" Constructs num_tasks [dense(num_hidden_nodes)+relu+dense(1)] layers, each independently transforming input
|
221 |
+
into a single output node. All num_tasks outputs are then concatenated into a single tensor. """
|
222 |
+
|
223 |
+
# todo: the independent layers are run in sequence rather than in parallel, causing a slowdown that
|
224 |
+
# scales with the number of tasks. might be able to run in parallel by hacking convolution operation
|
225 |
+
# https://stackoverflow.com/questions/58374980/run-multiple-models-of-an-ensemble-in-parallel-with-pytorch
|
226 |
+
# https://github.com/pytorch/pytorch/issues/54147
|
227 |
+
# https://github.com/pytorch/pytorch/issues/36459
|
228 |
+
|
229 |
+
def __init__(self,
|
230 |
+
num_tasks: int,
|
231 |
+
in_features: int,
|
232 |
+
num_hidden_nodes: int = 64,
|
233 |
+
use_batchnorm: bool = False,
|
234 |
+
use_dropout: bool = False,
|
235 |
+
dropout_rate: float = 0.2,
|
236 |
+
activation: str = "relu"):
|
237 |
+
|
238 |
+
super().__init__()
|
239 |
+
|
240 |
+
# each task-specific layer outputs a single node,
|
241 |
+
# which can be combined with torch.cat into prediction vector
|
242 |
+
self.task_specific_pred_layers = nn.ModuleList()
|
243 |
+
for i in range(num_tasks):
|
244 |
+
layers = [FCBlock(in_features=in_features,
|
245 |
+
num_hidden_nodes=num_hidden_nodes,
|
246 |
+
use_batchnorm=use_batchnorm,
|
247 |
+
use_dropout=use_dropout,
|
248 |
+
dropout_rate=dropout_rate,
|
249 |
+
activation=activation),
|
250 |
+
nn.Linear(in_features=num_hidden_nodes, out_features=1)]
|
251 |
+
self.task_specific_pred_layers.append(nn.Sequential(*layers))
|
252 |
+
|
253 |
+
def forward(self, x, **kwargs):
|
254 |
+
# run each task-specific layer and concatenate outputs into a single output vector
|
255 |
+
task_specific_outputs = []
|
256 |
+
for layer in self.task_specific_pred_layers:
|
257 |
+
task_specific_outputs.append(layer(x))
|
258 |
+
|
259 |
+
output = torch.cat(task_specific_outputs, dim=1)
|
260 |
+
return output
|
261 |
+
|
262 |
+
|
263 |
+
class GlobalAveragePooling(nn.Module):
|
264 |
+
""" helper class for global average pooling """
|
265 |
+
|
266 |
+
def __init__(self, dim=1):
|
267 |
+
super().__init__()
|
268 |
+
# our data is in [batch_size, sequence_length, embedding_length]
|
269 |
+
# with global pooling, we want to pool over the sequence dimension (dim=1)
|
270 |
+
self.dim = dim
|
271 |
+
|
272 |
+
def forward(self, x, **kwargs):
|
273 |
+
return torch.mean(x, dim=self.dim)
|
274 |
+
|
275 |
+
|
276 |
+
class CLSPooling(nn.Module):
|
277 |
+
""" helper class for CLS token extraction """
|
278 |
+
|
279 |
+
def __init__(self, cls_position=0):
|
280 |
+
super().__init__()
|
281 |
+
|
282 |
+
# the position of the CLS token in the sequence dimension
|
283 |
+
# currently, the CLS token is in the first position, but may move it to the last position
|
284 |
+
self.cls_position = cls_position
|
285 |
+
|
286 |
+
def forward(self, x, **kwargs):
|
287 |
+
# assumes input is in [batch_size, sequence_len, embedding_len]
|
288 |
+
# thus sequence dimension is dimension 1
|
289 |
+
return x[:, self.cls_position, :]
|
290 |
+
|
291 |
+
|
292 |
+
class TransformerEncoderWrapper(nn.TransformerEncoder):
|
293 |
+
""" wrapper around PyTorch's TransformerEncoder that re-initializes layer parameters,
|
294 |
+
so each transformer encoder layer has a different initialization """
|
295 |
+
|
296 |
+
# todo: PyTorch is changing its transformer API... check up on and see if there is a better way
|
297 |
+
def __init__(self, encoder_layer, num_layers, norm=None, reset_params=True):
|
298 |
+
super().__init__(encoder_layer, num_layers, norm)
|
299 |
+
if reset_params:
|
300 |
+
self.apply(reset_parameters_helper)
|
301 |
+
|
302 |
+
|
303 |
+
class AttnModel(nn.Module):
|
304 |
+
# https://pytorch.org/tutorials/beginner/transformer_tutorial.html
|
305 |
+
|
306 |
+
@staticmethod
|
307 |
+
def add_model_specific_args(parent_parser):
|
308 |
+
parser = ArgumentParser(parents=[parent_parser], add_help=False)
|
309 |
+
|
310 |
+
parser.add_argument('--pos_encoding', type=str, default="absolute",
|
311 |
+
choices=["none", "absolute", "relative", "relative_3D"],
|
312 |
+
help="what type of positional encoding to use")
|
313 |
+
parser.add_argument('--pos_encoding_dropout', type=float, default=0.1,
|
314 |
+
help="out much dropout to use in positional encoding, for pos_encoding==absolute")
|
315 |
+
parser.add_argument('--clipping_threshold', type=int, default=3,
|
316 |
+
help="clipping threshold for relative position embedding, for relative and relative_3D")
|
317 |
+
parser.add_argument('--contact_threshold', type=int, default=7,
|
318 |
+
help="threshold, in angstroms, for contact map, for relative_3D")
|
319 |
+
parser.add_argument('--embedding_len', type=int, default=128)
|
320 |
+
parser.add_argument('--num_heads', type=int, default=2)
|
321 |
+
parser.add_argument('--num_hidden', type=int, default=64)
|
322 |
+
parser.add_argument('--num_enc_layers', type=int, default=2)
|
323 |
+
parser.add_argument('--enc_layer_dropout', type=float, default=0.1)
|
324 |
+
parser.add_argument('--use_final_encoder_norm', action="store_true", default=False)
|
325 |
+
|
326 |
+
parser.add_argument('--global_average_pooling', action="store_true", default=False)
|
327 |
+
parser.add_argument('--cls_pooling', action="store_true", default=False)
|
328 |
+
|
329 |
+
parser.add_argument('--use_task_specific_layers', action="store_true", default=False,
|
330 |
+
help="exclusive with use_final_hidden_layer; takes priority over use_final_hidden_layer"
|
331 |
+
" if both flags are set")
|
332 |
+
parser.add_argument('--task_specific_hidden_nodes', type=int, default=64)
|
333 |
+
parser.add_argument('--use_final_hidden_layer', action="store_true", default=False)
|
334 |
+
parser.add_argument('--final_hidden_size', type=int, default=64)
|
335 |
+
parser.add_argument('--use_final_hidden_layer_norm', action="store_true", default=False)
|
336 |
+
parser.add_argument('--final_hidden_layer_norm_before_activation', action="store_true", default=False)
|
337 |
+
parser.add_argument('--use_final_hidden_layer_dropout', action="store_true", default=False)
|
338 |
+
parser.add_argument('--final_hidden_layer_dropout_rate', type=float, default=0.2)
|
339 |
+
|
340 |
+
parser.add_argument('--activation', type=str, default="relu",
|
341 |
+
help="activation function used for all activations in the network")
|
342 |
+
return parser
|
343 |
+
|
344 |
+
def __init__(self,
|
345 |
+
# data args
|
346 |
+
num_tasks: int,
|
347 |
+
aa_seq_len: int,
|
348 |
+
num_tokens: int,
|
349 |
+
# transformer encoder model args
|
350 |
+
pos_encoding: str = "absolute",
|
351 |
+
pos_encoding_dropout: float = 0.1,
|
352 |
+
clipping_threshold: int = 3,
|
353 |
+
contact_threshold: int = 7,
|
354 |
+
pdb_fns: List[str] = None,
|
355 |
+
embedding_len: int = 64,
|
356 |
+
num_heads: int = 2,
|
357 |
+
num_hidden: int = 64,
|
358 |
+
num_enc_layers: int = 2,
|
359 |
+
enc_layer_dropout: float = 0.1,
|
360 |
+
use_final_encoder_norm: bool = False,
|
361 |
+
# pooling to fixed-length representation
|
362 |
+
global_average_pooling: bool = True,
|
363 |
+
cls_pooling: bool = False,
|
364 |
+
# prediction layers
|
365 |
+
use_task_specific_layers: bool = False,
|
366 |
+
task_specific_hidden_nodes: int = 64,
|
367 |
+
use_final_hidden_layer: bool = False,
|
368 |
+
final_hidden_size: int = 64,
|
369 |
+
use_final_hidden_layer_norm: bool = False,
|
370 |
+
final_hidden_layer_norm_before_activation: bool = False,
|
371 |
+
use_final_hidden_layer_dropout: bool = False,
|
372 |
+
final_hidden_layer_dropout_rate: float = 0.2,
|
373 |
+
# activation function
|
374 |
+
activation: str = "relu",
|
375 |
+
*args, **kwargs):
|
376 |
+
|
377 |
+
super().__init__()
|
378 |
+
|
379 |
+
# store embedding length for use in the forward function
|
380 |
+
self.embedding_len = embedding_len
|
381 |
+
self.aa_seq_len = aa_seq_len
|
382 |
+
|
383 |
+
# build up layers
|
384 |
+
layers = collections.OrderedDict()
|
385 |
+
|
386 |
+
# amino acid embedding
|
387 |
+
layers["embedder"] = ScaledEmbedding(num_embeddings=num_tokens, embedding_dim=embedding_len, scale=True)
|
388 |
+
|
389 |
+
# absolute positional encoding
|
390 |
+
if pos_encoding == "absolute":
|
391 |
+
layers["pos_encoder"] = PositionalEncoding(embedding_len, dropout=pos_encoding_dropout, max_len=512)
|
392 |
+
|
393 |
+
# transformer encoder layer for none or absolute positional encoding
|
394 |
+
if pos_encoding in ["none", "absolute"]:
|
395 |
+
encoder_layer = torch.nn.TransformerEncoderLayer(d_model=embedding_len,
|
396 |
+
nhead=num_heads,
|
397 |
+
dim_feedforward=num_hidden,
|
398 |
+
dropout=enc_layer_dropout,
|
399 |
+
activation=get_activation_fn(activation),
|
400 |
+
norm_first=True,
|
401 |
+
batch_first=True)
|
402 |
+
|
403 |
+
# layer norm that is used after the transformer encoder layers
|
404 |
+
# if the norm_first is False, this is *redundant* and not needed
|
405 |
+
# but if norm_first is True, this can be used to normalize outputs from
|
406 |
+
# the transformer encoder before inputting to the final fully connected layer
|
407 |
+
encoder_norm = None
|
408 |
+
if use_final_encoder_norm:
|
409 |
+
encoder_norm = nn.LayerNorm(embedding_len)
|
410 |
+
|
411 |
+
layers["tr_encoder"] = TransformerEncoderWrapper(encoder_layer=encoder_layer,
|
412 |
+
num_layers=num_enc_layers,
|
413 |
+
norm=encoder_norm)
|
414 |
+
|
415 |
+
# transformer encoder layer for relative position encoding
|
416 |
+
elif pos_encoding in ["relative", "relative_3D"]:
|
417 |
+
relative_encoder_layer = ra.RelativeTransformerEncoderLayer(d_model=embedding_len,
|
418 |
+
nhead=num_heads,
|
419 |
+
pos_encoding=pos_encoding,
|
420 |
+
clipping_threshold=clipping_threshold,
|
421 |
+
contact_threshold=contact_threshold,
|
422 |
+
pdb_fns=pdb_fns,
|
423 |
+
dim_feedforward=num_hidden,
|
424 |
+
dropout=enc_layer_dropout,
|
425 |
+
activation=get_activation_fn(activation),
|
426 |
+
norm_first=True)
|
427 |
+
|
428 |
+
encoder_norm = None
|
429 |
+
if use_final_encoder_norm:
|
430 |
+
encoder_norm = nn.LayerNorm(embedding_len)
|
431 |
+
|
432 |
+
layers["tr_encoder"] = ra.RelativeTransformerEncoder(encoder_layer=relative_encoder_layer,
|
433 |
+
num_layers=num_enc_layers,
|
434 |
+
norm=encoder_norm)
|
435 |
+
|
436 |
+
# GLOBAL AVERAGE POOLING OR CLS TOKEN
|
437 |
+
# set up the layers and output shapes (i.e. input shapes for the pred layer)
|
438 |
+
if global_average_pooling:
|
439 |
+
# pool over the sequence dimension
|
440 |
+
layers["avg_pooling"] = GlobalAveragePooling(dim=1)
|
441 |
+
pred_layer_input_features = embedding_len
|
442 |
+
elif cls_pooling:
|
443 |
+
layers["cls_pooling"] = CLSPooling(cls_position=0)
|
444 |
+
pred_layer_input_features = embedding_len
|
445 |
+
else:
|
446 |
+
# no global average pooling or CLS token
|
447 |
+
# sequence dimension is still there, just flattened
|
448 |
+
layers["flatten"] = nn.Flatten()
|
449 |
+
pred_layer_input_features = embedding_len * aa_seq_len
|
450 |
+
|
451 |
+
# PREDICTION
|
452 |
+
if use_task_specific_layers:
|
453 |
+
# task specific prediction layers (nonlinear transform for each task)
|
454 |
+
layers["prediction"] = TaskSpecificPredictionLayers(num_tasks=num_tasks,
|
455 |
+
in_features=pred_layer_input_features,
|
456 |
+
num_hidden_nodes=task_specific_hidden_nodes,
|
457 |
+
activation=activation)
|
458 |
+
elif use_final_hidden_layer:
|
459 |
+
# combined prediction linear (linear transform for each task)
|
460 |
+
layers["fc1"] = FCBlock(in_features=pred_layer_input_features,
|
461 |
+
num_hidden_nodes=final_hidden_size,
|
462 |
+
use_batchnorm=False,
|
463 |
+
use_layernorm=use_final_hidden_layer_norm,
|
464 |
+
norm_before_activation=final_hidden_layer_norm_before_activation,
|
465 |
+
use_dropout=use_final_hidden_layer_dropout,
|
466 |
+
dropout_rate=final_hidden_layer_dropout_rate,
|
467 |
+
activation=activation)
|
468 |
+
|
469 |
+
layers["prediction"] = nn.Linear(in_features=final_hidden_size, out_features=num_tasks)
|
470 |
+
else:
|
471 |
+
layers["prediction"] = nn.Linear(in_features=pred_layer_input_features, out_features=num_tasks)
|
472 |
+
|
473 |
+
# FINAL MODEL
|
474 |
+
self.model = SequentialWithArgs(layers)
|
475 |
+
|
476 |
+
def forward(self, x, **kwargs):
|
477 |
+
return self.model(x, **kwargs)
|
478 |
+
|
479 |
+
|
480 |
+
class Transpose(nn.Module):
|
481 |
+
""" helper layer to swap data from (batch, seq, channels) to (batch, channels, seq)
|
482 |
+
used as a helper in the convolutional network which pytorch defaults to channels-first """
|
483 |
+
|
484 |
+
def __init__(self, dims: Tuple[int, ...] = (1, 2)):
|
485 |
+
super().__init__()
|
486 |
+
self.dims = dims
|
487 |
+
|
488 |
+
def forward(self, x, **kwargs):
|
489 |
+
x = x.transpose(*self.dims).contiguous()
|
490 |
+
return x
|
491 |
+
|
492 |
+
|
493 |
+
def conv1d_out_shape(seq_len, kernel_size, stride=1, pad=0, dilation=1):
|
494 |
+
return (seq_len + (2 * pad) - (dilation * (kernel_size - 1)) - 1 // stride) + 1
|
495 |
+
|
496 |
+
|
497 |
+
class ConvBlock(nn.Module):
|
498 |
+
def __init__(self,
|
499 |
+
in_channels: int,
|
500 |
+
out_channels: int,
|
501 |
+
kernel_size: int,
|
502 |
+
dilation: int = 1,
|
503 |
+
padding: str = "same",
|
504 |
+
use_batchnorm: bool = False,
|
505 |
+
use_layernorm: bool = False,
|
506 |
+
norm_before_activation: bool = False,
|
507 |
+
use_dropout: bool = False,
|
508 |
+
dropout_rate: float = 0.2,
|
509 |
+
activation: str = "relu"):
|
510 |
+
|
511 |
+
super().__init__()
|
512 |
+
|
513 |
+
if use_batchnorm and use_layernorm:
|
514 |
+
raise ValueError("Only one of use_batchnorm or use_layernorm can be set to True")
|
515 |
+
|
516 |
+
self.use_batchnorm = use_batchnorm
|
517 |
+
self.use_layernorm = use_layernorm
|
518 |
+
self.norm_before_activation = norm_before_activation
|
519 |
+
self.use_dropout = use_dropout
|
520 |
+
|
521 |
+
self.conv = nn.Conv1d(in_channels=in_channels,
|
522 |
+
out_channels=out_channels,
|
523 |
+
kernel_size=kernel_size,
|
524 |
+
padding=padding,
|
525 |
+
dilation=dilation)
|
526 |
+
|
527 |
+
self.activation = get_activation_fn(activation, functional=False)
|
528 |
+
|
529 |
+
if use_batchnorm:
|
530 |
+
self.norm = nn.BatchNorm1d(out_channels)
|
531 |
+
|
532 |
+
if use_layernorm:
|
533 |
+
self.norm = nn.LayerNorm(out_channels)
|
534 |
+
|
535 |
+
if use_dropout:
|
536 |
+
self.dropout = nn.Dropout(p=dropout_rate)
|
537 |
+
|
538 |
+
def forward(self, x, **kwargs):
|
539 |
+
x = self.conv(x)
|
540 |
+
|
541 |
+
# norm can be before or after activation, using flag
|
542 |
+
if self.use_batchnorm and self.norm_before_activation:
|
543 |
+
x = self.norm(x)
|
544 |
+
elif self.use_layernorm and self.norm_before_activation:
|
545 |
+
x = self.norm(x.transpose(1, 2)).transpose(1, 2)
|
546 |
+
|
547 |
+
x = self.activation(x)
|
548 |
+
|
549 |
+
# batchnorm being applied after activation, there is some discussion on this online
|
550 |
+
if self.use_batchnorm and not self.norm_before_activation:
|
551 |
+
x = self.norm(x)
|
552 |
+
elif self.use_layernorm and not self.norm_before_activation:
|
553 |
+
x = self.norm(x.transpose(1, 2)).transpose(1, 2)
|
554 |
+
|
555 |
+
# dropout being applied after batchnorm, there is some discussion on this online
|
556 |
+
if self.use_dropout:
|
557 |
+
x = self.dropout(x)
|
558 |
+
|
559 |
+
return x
|
560 |
+
|
561 |
+
|
562 |
+
class ConvModel2(nn.Module):
|
563 |
+
""" convolutional source model that supports padded inputs, pooling, etc """
|
564 |
+
|
565 |
+
@staticmethod
|
566 |
+
def add_model_specific_args(parent_parser):
|
567 |
+
parser = ArgumentParser(parents=[parent_parser], add_help=False)
|
568 |
+
parser.add_argument('--use_embedding', action="store_true", default=False)
|
569 |
+
parser.add_argument('--embedding_len', type=int, default=128)
|
570 |
+
|
571 |
+
parser.add_argument('--num_conv_layers', type=int, default=1)
|
572 |
+
parser.add_argument('--kernel_sizes', type=int, nargs="+", default=[7])
|
573 |
+
parser.add_argument('--out_channels', type=int, nargs="+", default=[128])
|
574 |
+
parser.add_argument('--dilations', type=int, nargs="+", default=[1])
|
575 |
+
parser.add_argument('--padding', type=str, default="valid", choices=["valid", "same"])
|
576 |
+
parser.add_argument('--use_conv_layer_norm', action="store_true", default=False)
|
577 |
+
parser.add_argument('--conv_layer_norm_before_activation', action="store_true", default=False)
|
578 |
+
parser.add_argument('--use_conv_layer_dropout', action="store_true", default=False)
|
579 |
+
parser.add_argument('--conv_layer_dropout_rate', type=float, default=0.2)
|
580 |
+
|
581 |
+
parser.add_argument('--global_average_pooling', action="store_true", default=False)
|
582 |
+
|
583 |
+
parser.add_argument('--use_task_specific_layers', action="store_true", default=False)
|
584 |
+
parser.add_argument('--task_specific_hidden_nodes', type=int, default=64)
|
585 |
+
parser.add_argument('--use_final_hidden_layer', action="store_true", default=False)
|
586 |
+
parser.add_argument('--final_hidden_size', type=int, default=64)
|
587 |
+
parser.add_argument('--use_final_hidden_layer_norm', action="store_true", default=False)
|
588 |
+
parser.add_argument('--final_hidden_layer_norm_before_activation', action="store_true", default=False)
|
589 |
+
parser.add_argument('--use_final_hidden_layer_dropout', action="store_true", default=False)
|
590 |
+
parser.add_argument('--final_hidden_layer_dropout_rate', type=float, default=0.2)
|
591 |
+
|
592 |
+
parser.add_argument('--activation', type=str, default="relu",
|
593 |
+
help="activation function used for all activations in the network")
|
594 |
+
|
595 |
+
return parser
|
596 |
+
|
597 |
+
def __init__(self,
|
598 |
+
# data
|
599 |
+
num_tasks: int,
|
600 |
+
aa_seq_len: int,
|
601 |
+
aa_encoding_len: int,
|
602 |
+
num_tokens: int,
|
603 |
+
# convolutional model args
|
604 |
+
use_embedding: bool = False,
|
605 |
+
embedding_len: int = 64,
|
606 |
+
num_conv_layers: int = 1,
|
607 |
+
kernel_sizes: List[int] = (7,),
|
608 |
+
out_channels: List[int] = (128,),
|
609 |
+
dilations: List[int] = (1,),
|
610 |
+
padding: str = "valid",
|
611 |
+
use_conv_layer_norm: bool = False,
|
612 |
+
conv_layer_norm_before_activation: bool = False,
|
613 |
+
use_conv_layer_dropout: bool = False,
|
614 |
+
conv_layer_dropout_rate: float = 0.2,
|
615 |
+
# pooling
|
616 |
+
global_average_pooling: bool = True,
|
617 |
+
# prediction layers
|
618 |
+
use_task_specific_layers: bool = False,
|
619 |
+
task_specific_hidden_nodes: int = 64,
|
620 |
+
use_final_hidden_layer: bool = False,
|
621 |
+
final_hidden_size: int = 64,
|
622 |
+
use_final_hidden_layer_norm: bool = False,
|
623 |
+
final_hidden_layer_norm_before_activation: bool = False,
|
624 |
+
use_final_hidden_layer_dropout: bool = False,
|
625 |
+
final_hidden_layer_dropout_rate: float = 0.2,
|
626 |
+
# activation function
|
627 |
+
activation: str = "relu",
|
628 |
+
*args, **kwargs):
|
629 |
+
|
630 |
+
super(ConvModel2, self).__init__()
|
631 |
+
|
632 |
+
# build up the layers
|
633 |
+
layers = collections.OrderedDict()
|
634 |
+
|
635 |
+
# amino acid embedding
|
636 |
+
if use_embedding:
|
637 |
+
layers["embedder"] = ScaledEmbedding(num_embeddings=num_tokens, embedding_dim=embedding_len, scale=False)
|
638 |
+
|
639 |
+
# transpose the input to match PyTorch's expected format
|
640 |
+
layers["transpose"] = Transpose(dims=(1, 2))
|
641 |
+
|
642 |
+
# build up the convolutional layers
|
643 |
+
for layer_num in range(num_conv_layers):
|
644 |
+
# determine the number of input channels for the first convolutional layer
|
645 |
+
if layer_num == 0 and use_embedding:
|
646 |
+
# for the first convolutional layer, the in_channels is the embedding_len
|
647 |
+
in_channels = embedding_len
|
648 |
+
elif layer_num == 0 and not use_embedding:
|
649 |
+
# for the first convolutional layer, the in_channels is the aa_encoding_len
|
650 |
+
in_channels = aa_encoding_len
|
651 |
+
else:
|
652 |
+
in_channels = out_channels[layer_num - 1]
|
653 |
+
|
654 |
+
layers[f"conv{layer_num}"] = ConvBlock(in_channels=in_channels,
|
655 |
+
out_channels=out_channels[layer_num],
|
656 |
+
kernel_size=kernel_sizes[layer_num],
|
657 |
+
dilation=dilations[layer_num],
|
658 |
+
padding=padding,
|
659 |
+
use_batchnorm=False,
|
660 |
+
use_layernorm=use_conv_layer_norm,
|
661 |
+
norm_before_activation=conv_layer_norm_before_activation,
|
662 |
+
use_dropout=use_conv_layer_dropout,
|
663 |
+
dropout_rate=conv_layer_dropout_rate,
|
664 |
+
activation=activation)
|
665 |
+
|
666 |
+
# handle transition from convolutional layers to fully connected layer
|
667 |
+
# either use global average pooling or flatten
|
668 |
+
# take into consideration whether we are using valid or same padding
|
669 |
+
if global_average_pooling:
|
670 |
+
# global average pooling (mean across the seq len dimension)
|
671 |
+
# the seq len dimensions is the last dimension (batch_size, num_filters, seq_len)
|
672 |
+
layers["avg_pooling"] = GlobalAveragePooling(dim=-1)
|
673 |
+
# the prediction layers will take num_filters input features
|
674 |
+
pred_layer_input_features = out_channels[-1]
|
675 |
+
|
676 |
+
else:
|
677 |
+
# no global average pooling. flatten instead.
|
678 |
+
layers["flatten"] = nn.Flatten()
|
679 |
+
# calculate the final output len of the convolutional layers
|
680 |
+
# and the number of input features for the prediction layers
|
681 |
+
if padding == "valid":
|
682 |
+
# valid padding (aka no padding) results in shrinking length in progressive layers
|
683 |
+
conv_out_len = conv1d_out_shape(aa_seq_len, kernel_size=kernel_sizes[0], dilation=dilations[0])
|
684 |
+
for layer_num in range(1, num_conv_layers):
|
685 |
+
conv_out_len = conv1d_out_shape(conv_out_len,
|
686 |
+
kernel_size=kernel_sizes[layer_num],
|
687 |
+
dilation=dilations[layer_num])
|
688 |
+
pred_layer_input_features = conv_out_len * out_channels[-1]
|
689 |
+
else:
|
690 |
+
# padding == "same"
|
691 |
+
pred_layer_input_features = aa_seq_len * out_channels[-1]
|
692 |
+
|
693 |
+
# prediction layer
|
694 |
+
if use_task_specific_layers:
|
695 |
+
layers["prediction"] = TaskSpecificPredictionLayers(num_tasks=num_tasks,
|
696 |
+
in_features=pred_layer_input_features,
|
697 |
+
num_hidden_nodes=task_specific_hidden_nodes,
|
698 |
+
activation=activation)
|
699 |
+
|
700 |
+
# final hidden layer (with potential additional dropout)
|
701 |
+
elif use_final_hidden_layer:
|
702 |
+
layers["fc1"] = FCBlock(in_features=pred_layer_input_features,
|
703 |
+
num_hidden_nodes=final_hidden_size,
|
704 |
+
use_batchnorm=False,
|
705 |
+
use_layernorm=use_final_hidden_layer_norm,
|
706 |
+
norm_before_activation=final_hidden_layer_norm_before_activation,
|
707 |
+
use_dropout=use_final_hidden_layer_dropout,
|
708 |
+
dropout_rate=final_hidden_layer_dropout_rate,
|
709 |
+
activation=activation)
|
710 |
+
layers["prediction"] = nn.Linear(in_features=final_hidden_size, out_features=num_tasks)
|
711 |
+
|
712 |
+
else:
|
713 |
+
layers["prediction"] = nn.Linear(in_features=pred_layer_input_features, out_features=num_tasks)
|
714 |
+
|
715 |
+
self.model = nn.Sequential(layers)
|
716 |
+
|
717 |
+
def forward(self, x, **kwargs):
|
718 |
+
output = self.model(x)
|
719 |
+
return output
|
720 |
+
|
721 |
+
|
722 |
+
class ConvModel(nn.Module):
|
723 |
+
""" a convolutional network with convolutional layers followed by a fully connected layer """
|
724 |
+
|
725 |
+
@staticmethod
|
726 |
+
def add_model_specific_args(parent_parser):
|
727 |
+
parser = ArgumentParser(parents=[parent_parser], add_help=False)
|
728 |
+
parser.add_argument('--num_conv_layers', type=int, default=1)
|
729 |
+
parser.add_argument('--kernel_sizes', type=int, nargs="+", default=[7])
|
730 |
+
parser.add_argument('--out_channels', type=int, nargs="+", default=[128])
|
731 |
+
parser.add_argument('--padding', type=str, default="valid", choices=["valid", "same"])
|
732 |
+
parser.add_argument('--use_final_hidden_layer', action="store_true",
|
733 |
+
help="whether to use a final hidden layer")
|
734 |
+
parser.add_argument('--final_hidden_size', type=int, default=128,
|
735 |
+
help="number of nodes in the final hidden layer")
|
736 |
+
parser.add_argument('--use_dropout', action="store_true",
|
737 |
+
help="whether to use dropout in the final hidden layer")
|
738 |
+
parser.add_argument('--dropout_rate', type=float, default=0.2,
|
739 |
+
help="dropout rate in the final hidden layer")
|
740 |
+
parser.add_argument('--use_task_specific_layers', action="store_true", default=False)
|
741 |
+
parser.add_argument('--task_specific_hidden_nodes', type=int, default=64)
|
742 |
+
return parser
|
743 |
+
|
744 |
+
def __init__(self,
|
745 |
+
num_tasks: int,
|
746 |
+
aa_seq_len: int,
|
747 |
+
aa_encoding_len: int,
|
748 |
+
num_conv_layers: int = 1,
|
749 |
+
kernel_sizes: List[int] = (7,),
|
750 |
+
out_channels: List[int] = (128,),
|
751 |
+
padding: str = "valid",
|
752 |
+
use_final_hidden_layer: bool = True,
|
753 |
+
final_hidden_size: int = 128,
|
754 |
+
use_dropout: bool = False,
|
755 |
+
dropout_rate: float = 0.2,
|
756 |
+
use_task_specific_layers: bool = False,
|
757 |
+
task_specific_hidden_nodes: int = 64,
|
758 |
+
*args, **kwargs):
|
759 |
+
|
760 |
+
super(ConvModel, self).__init__()
|
761 |
+
|
762 |
+
# set up the model as a Sequential block (less to do in forward())
|
763 |
+
layers = collections.OrderedDict()
|
764 |
+
|
765 |
+
layers["transpose"] = Transpose(dims=(1, 2))
|
766 |
+
|
767 |
+
for layer_num in range(num_conv_layers):
|
768 |
+
# for the first convolutional layer, the in_channels is the feature_len
|
769 |
+
in_channels = aa_encoding_len if layer_num == 0 else out_channels[layer_num - 1]
|
770 |
+
|
771 |
+
layers["conv{}".format(layer_num)] = nn.Sequential(
|
772 |
+
nn.Conv1d(in_channels=in_channels,
|
773 |
+
out_channels=out_channels[layer_num],
|
774 |
+
kernel_size=kernel_sizes[layer_num],
|
775 |
+
padding=padding),
|
776 |
+
nn.ReLU()
|
777 |
+
)
|
778 |
+
|
779 |
+
layers["flatten"] = nn.Flatten()
|
780 |
+
|
781 |
+
# calculate the final output len of the convolutional layers
|
782 |
+
# and the number of input features for the prediction layers
|
783 |
+
if padding == "valid":
|
784 |
+
# valid padding (aka no padding) results in shrinking length in progressive layers
|
785 |
+
conv_out_len = conv1d_out_shape(aa_seq_len, kernel_size=kernel_sizes[0])
|
786 |
+
for layer_num in range(1, num_conv_layers):
|
787 |
+
conv_out_len = conv1d_out_shape(conv_out_len, kernel_size=kernel_sizes[layer_num])
|
788 |
+
next_dim = conv_out_len * out_channels[-1]
|
789 |
+
elif padding == "same":
|
790 |
+
next_dim = aa_seq_len * out_channels[-1]
|
791 |
+
else:
|
792 |
+
raise ValueError("unexpected value for padding: {}".format(padding))
|
793 |
+
|
794 |
+
# final hidden layer (with potential additional dropout)
|
795 |
+
if use_final_hidden_layer:
|
796 |
+
layers["fc1"] = FCBlock(in_features=next_dim,
|
797 |
+
num_hidden_nodes=final_hidden_size,
|
798 |
+
use_batchnorm=False,
|
799 |
+
use_dropout=use_dropout,
|
800 |
+
dropout_rate=dropout_rate)
|
801 |
+
next_dim = final_hidden_size
|
802 |
+
|
803 |
+
# final prediction layer
|
804 |
+
# either task specific nonlinear layers or a single linear layer
|
805 |
+
if use_task_specific_layers:
|
806 |
+
layers["prediction"] = TaskSpecificPredictionLayers(num_tasks=num_tasks,
|
807 |
+
in_features=next_dim,
|
808 |
+
num_hidden_nodes=task_specific_hidden_nodes)
|
809 |
+
else:
|
810 |
+
layers["prediction"] = nn.Linear(in_features=next_dim, out_features=num_tasks)
|
811 |
+
|
812 |
+
self.model = nn.Sequential(layers)
|
813 |
+
|
814 |
+
def forward(self, x, **kwargs):
|
815 |
+
output = self.model(x)
|
816 |
+
return output
|
817 |
+
|
818 |
+
|
819 |
+
class FCModel(nn.Module):
|
820 |
+
|
821 |
+
@staticmethod
|
822 |
+
def add_model_specific_args(parent_parser):
|
823 |
+
parser = ArgumentParser(parents=[parent_parser], add_help=False)
|
824 |
+
parser.add_argument('--num_layers', type=int, default=1)
|
825 |
+
parser.add_argument('--num_hidden', nargs="+", type=int, default=[128])
|
826 |
+
parser.add_argument('--use_batchnorm', action="store_true", default=False)
|
827 |
+
parser.add_argument('--use_layernorm', action="store_true", default=False)
|
828 |
+
parser.add_argument('--norm_before_activation', action="store_true", default=False)
|
829 |
+
parser.add_argument('--use_dropout', action="store_true", default=False)
|
830 |
+
parser.add_argument('--dropout_rate', type=float, default=0.2)
|
831 |
+
return parser
|
832 |
+
|
833 |
+
def __init__(self,
|
834 |
+
num_tasks: int,
|
835 |
+
seq_encoding_len: int,
|
836 |
+
num_layers: int = 1,
|
837 |
+
num_hidden: List[int] = (128,),
|
838 |
+
use_batchnorm: bool = False,
|
839 |
+
use_layernorm: bool = False,
|
840 |
+
norm_before_activation: bool = False,
|
841 |
+
use_dropout: bool = False,
|
842 |
+
dropout_rate: float = 0.2,
|
843 |
+
activation: str = "relu",
|
844 |
+
*args, **kwargs):
|
845 |
+
super().__init__()
|
846 |
+
|
847 |
+
# set up the model as a Sequential block (less to do in forward())
|
848 |
+
layers = collections.OrderedDict()
|
849 |
+
|
850 |
+
# flatten inputs as this is all fully connected
|
851 |
+
layers["flatten"] = nn.Flatten()
|
852 |
+
|
853 |
+
# build up the variable number of hidden layers (fully connected + ReLU + dropout (if set))
|
854 |
+
for layer_num in range(num_layers):
|
855 |
+
# for the first layer (layer_num == 0), in_features is determined by given input
|
856 |
+
# for subsequent layers, the in_features is the previous layer's num_hidden
|
857 |
+
in_features = seq_encoding_len if layer_num == 0 else num_hidden[layer_num - 1]
|
858 |
+
|
859 |
+
layers["fc{}".format(layer_num)] = FCBlock(in_features=in_features,
|
860 |
+
num_hidden_nodes=num_hidden[layer_num],
|
861 |
+
use_batchnorm=use_batchnorm,
|
862 |
+
use_layernorm=use_layernorm,
|
863 |
+
norm_before_activation=norm_before_activation,
|
864 |
+
use_dropout=use_dropout,
|
865 |
+
dropout_rate=dropout_rate,
|
866 |
+
activation=activation)
|
867 |
+
|
868 |
+
# finally, the linear output layer
|
869 |
+
in_features = num_hidden[-1] if num_layers > 0 else seq_encoding_len
|
870 |
+
layers["output"] = nn.Linear(in_features=in_features, out_features=num_tasks)
|
871 |
+
|
872 |
+
self.model = nn.Sequential(layers)
|
873 |
+
|
874 |
+
def forward(self, x, **kwargs):
|
875 |
+
output = self.model(x)
|
876 |
+
return output
|
877 |
+
|
878 |
+
|
879 |
+
class LRModel(nn.Module):
|
880 |
+
""" a simple linear model """
|
881 |
+
|
882 |
+
def __init__(self, num_tasks, seq_encoding_len, *args, **kwargs):
|
883 |
+
super().__init__()
|
884 |
+
|
885 |
+
self.model = nn.Sequential(
|
886 |
+
nn.Flatten(),
|
887 |
+
nn.Linear(seq_encoding_len, out_features=num_tasks))
|
888 |
+
|
889 |
+
def forward(self, x, **kwargs):
|
890 |
+
output = self.model(x)
|
891 |
+
return output
|
892 |
+
|
893 |
+
|
894 |
+
class TransferModel(nn.Module):
|
895 |
+
""" transfer learning model """
|
896 |
+
|
897 |
+
@staticmethod
|
898 |
+
def add_model_specific_args(parent_parser):
|
899 |
+
|
900 |
+
def none_or_int(value: str):
|
901 |
+
return None if value.lower() == "none" else int(value)
|
902 |
+
|
903 |
+
p = ArgumentParser(parents=[parent_parser], add_help=False)
|
904 |
+
|
905 |
+
# for model set up
|
906 |
+
p.add_argument('--pretrained_ckpt_path', type=str, default=None)
|
907 |
+
|
908 |
+
# where to cut off the backbone
|
909 |
+
p.add_argument("--backbone_cutoff", type=none_or_int, default=-1,
|
910 |
+
help="where to cut off the backbone. can be a negative int, indexing back from "
|
911 |
+
"pretrained_model.model.model. a value of -1 would chop off the backbone prediction head. "
|
912 |
+
"a value of -2 chops the prediction head and FC layer. a value of -3 chops"
|
913 |
+
"the above, as well as the global average pooling layer. all depends on architecture.")
|
914 |
+
|
915 |
+
p.add_argument("--pred_layer_input_features", type=int, default=None,
|
916 |
+
help="if None, number of features will be determined based on backbone_cutoff and standard "
|
917 |
+
"architecture. otherwise, specify the number of input features for the prediction layer")
|
918 |
+
|
919 |
+
# top net args
|
920 |
+
p.add_argument("--top_net_type", type=str, default="linear", choices=["linear", "nonlinear", "sklearn"])
|
921 |
+
p.add_argument("--top_net_hidden_nodes", type=int, default=256)
|
922 |
+
p.add_argument("--top_net_use_batchnorm", action="store_true")
|
923 |
+
p.add_argument("--top_net_use_dropout", action="store_true")
|
924 |
+
p.add_argument("--top_net_dropout_rate", type=float, default=0.1)
|
925 |
+
|
926 |
+
return p
|
927 |
+
|
928 |
+
def __init__(self,
|
929 |
+
# pretrained model
|
930 |
+
pretrained_ckpt_path: Optional[str] = None,
|
931 |
+
pretrained_hparams: Optional[dict] = None,
|
932 |
+
backbone_cutoff: Optional[int] = -1,
|
933 |
+
# top net
|
934 |
+
pred_layer_input_features: Optional[int] = None,
|
935 |
+
top_net_type: str = "linear",
|
936 |
+
top_net_hidden_nodes: int = 256,
|
937 |
+
top_net_use_batchnorm: bool = False,
|
938 |
+
top_net_use_dropout: bool = False,
|
939 |
+
top_net_dropout_rate: float = 0.1,
|
940 |
+
*args, **kwargs):
|
941 |
+
|
942 |
+
super().__init__()
|
943 |
+
|
944 |
+
# error checking: if pretrained_ckpt_path is None, then pretrained_hparams must be specified
|
945 |
+
if pretrained_ckpt_path is None and pretrained_hparams is None:
|
946 |
+
raise ValueError("Either pretrained_ckpt_path or pretrained_hparams must be specified")
|
947 |
+
|
948 |
+
# note: pdb_fns is loaded from transfer model arguments rather than original source model hparams
|
949 |
+
# if pdb_fns is specified as a kwarg, pass it on for structure-based RPE
|
950 |
+
# otherwise, can just set pdb_fns to None, and structure-based RPE will handle new PDBs on the fly
|
951 |
+
pdb_fns = kwargs["pdb_fns"] if "pdb_fns" in kwargs else None
|
952 |
+
|
953 |
+
# generate a fresh backbone using pretrained_hparams if specified
|
954 |
+
# otherwise load the backbone from the pretrained checkpoint
|
955 |
+
# we prioritize pretrained_hparams over pretrained_ckpt_path because
|
956 |
+
# pretrained_hparams will only really be specified if we are loading from a DMSTask checkpoint
|
957 |
+
# meaning the TransferModel has already been fine-tuned on DMS data, and we are likely loading
|
958 |
+
# weights from that finetuning (including weights for the backbone)
|
959 |
+
# whereas if pretrained_hparams is not specified but pretrained_ckpt_path is, then we are
|
960 |
+
# likely finetuning the TransferModel for the first time, and we need the pretrained weights for the
|
961 |
+
# backbone from the RosettaTask checkpoint
|
962 |
+
if pretrained_hparams is not None:
|
963 |
+
# pretrained_hparams will only be specified if we are loading from a DMSTask checkpoint
|
964 |
+
pretrained_hparams["pdb_fns"] = pdb_fns
|
965 |
+
pretrained_model = Model[pretrained_hparams["model_name"]].cls(**pretrained_hparams)
|
966 |
+
self.pretrained_hparams = pretrained_hparams
|
967 |
+
else:
|
968 |
+
# not supported in metl-pretrained
|
969 |
+
raise NotImplementedError("Loading pretrained weights from RosettaTask checkpoint not supported")
|
970 |
+
|
971 |
+
layers = collections.OrderedDict()
|
972 |
+
|
973 |
+
# set the backbone to all layers except the last layer (the pre-trained prediction layer)
|
974 |
+
if backbone_cutoff is None:
|
975 |
+
layers["backbone"] = SequentialWithArgs(*list(pretrained_model.model.children()))
|
976 |
+
else:
|
977 |
+
layers["backbone"] = SequentialWithArgs(*list(pretrained_model.model.children())[0:backbone_cutoff])
|
978 |
+
|
979 |
+
if top_net_type == "sklearn":
|
980 |
+
# sklearn top not doesn't require any more layers, just return model for the repr layer
|
981 |
+
self.model = SequentialWithArgs(layers)
|
982 |
+
return
|
983 |
+
|
984 |
+
# figure out dimensions of input into the prediction layer
|
985 |
+
if pred_layer_input_features is None:
|
986 |
+
# todo: can make this more robust by checking if the pretrained_mode.hparams for use_final_hidden_layer,
|
987 |
+
# global_average_pooling, etc. then can determine what the layer will be based on backbone_cutoff.
|
988 |
+
# currently, assumes that pretrained_model uses global average pooling and a final_hidden_layer
|
989 |
+
if backbone_cutoff is None:
|
990 |
+
# no backbone cutoff... use the full network (including tasks) as the backbone
|
991 |
+
pred_layer_input_features = self.pretrained_hparams["num_tasks"]
|
992 |
+
elif backbone_cutoff == -1:
|
993 |
+
pred_layer_input_features = self.pretrained_hparams["final_hidden_size"]
|
994 |
+
elif backbone_cutoff == -2:
|
995 |
+
pred_layer_input_features = self.pretrained_hparams["embedding_len"]
|
996 |
+
elif backbone_cutoff == -3:
|
997 |
+
pred_layer_input_features = self.pretrained_hparams["embedding_len"] * kwargs["aa_seq_len"]
|
998 |
+
else:
|
999 |
+
raise ValueError("can't automatically determine pred_layer_input_features for given backbone_cutoff")
|
1000 |
+
|
1001 |
+
layers["flatten"] = nn.Flatten(start_dim=1)
|
1002 |
+
|
1003 |
+
# create a new prediction layer on top of the backbone
|
1004 |
+
if top_net_type == "linear":
|
1005 |
+
# linear layer for prediction
|
1006 |
+
layers["prediction"] = nn.Linear(in_features=pred_layer_input_features, out_features=1)
|
1007 |
+
elif top_net_type == "nonlinear":
|
1008 |
+
# fully connected with hidden layer
|
1009 |
+
fc_block = FCBlock(in_features=pred_layer_input_features,
|
1010 |
+
num_hidden_nodes=top_net_hidden_nodes,
|
1011 |
+
use_batchnorm=top_net_use_batchnorm,
|
1012 |
+
use_dropout=top_net_use_dropout,
|
1013 |
+
dropout_rate=top_net_dropout_rate)
|
1014 |
+
|
1015 |
+
pred_layer = nn.Linear(in_features=top_net_hidden_nodes, out_features=1)
|
1016 |
+
|
1017 |
+
layers["prediction"] = SequentialWithArgs(fc_block, pred_layer)
|
1018 |
+
else:
|
1019 |
+
raise ValueError("Unexpected type of top net layer: {}".format(top_net_type))
|
1020 |
+
|
1021 |
+
self.model = SequentialWithArgs(layers)
|
1022 |
+
|
1023 |
+
def forward(self, x, **kwargs):
|
1024 |
+
return self.model(x, **kwargs)
|
1025 |
+
|
1026 |
+
|
1027 |
+
def get_activation_fn(activation, functional=True):
|
1028 |
+
if activation == "relu":
|
1029 |
+
return F.relu if functional else nn.ReLU()
|
1030 |
+
elif activation == "gelu":
|
1031 |
+
return F.gelu if functional else nn.GELU()
|
1032 |
+
elif activation == "silo" or activation == "swish":
|
1033 |
+
return F.silu if functional else nn.SiLU()
|
1034 |
+
elif activation == "leaky_relu" or activation == "lrelu":
|
1035 |
+
return F.leaky_relu if functional else nn.LeakyReLU()
|
1036 |
+
else:
|
1037 |
+
raise RuntimeError("unknown activation: {}".format(activation))
|
1038 |
+
|
1039 |
+
|
1040 |
+
class Model(enum.Enum):
|
1041 |
+
def __new__(cls, *args, **kwds):
|
1042 |
+
value = len(cls.__members__) + 1
|
1043 |
+
obj = object.__new__(cls)
|
1044 |
+
obj._value_ = value
|
1045 |
+
return obj
|
1046 |
+
|
1047 |
+
def __init__(self, cls, transfer_model):
|
1048 |
+
self.cls = cls
|
1049 |
+
self.transfer_model = transfer_model
|
1050 |
+
|
1051 |
+
linear = LRModel, False
|
1052 |
+
fully_connected = FCModel, False
|
1053 |
+
cnn = ConvModel, False
|
1054 |
+
cnn2 = ConvModel2, False
|
1055 |
+
transformer_encoder = AttnModel, False
|
1056 |
+
transfer_model = TransferModel, True
|
1057 |
+
|
1058 |
+
|
1059 |
+
def main():
|
1060 |
+
pass
|
1061 |
+
|
1062 |
+
|
1063 |
+
if __name__ == "__main__":
|
1064 |
+
main()
|
metl/relative_attention.py
ADDED
@@ -0,0 +1,586 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" implementation of transformer encoder with relative attention
|
2 |
+
references:
|
3 |
+
- https://medium.com/@_init_/how-self-attention-with-relative-position-representations-works-28173b8c245a
|
4 |
+
- https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer
|
5 |
+
- https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py
|
6 |
+
- https://github.com/jiezouguihuafu/ClassicalModelreproduced/blob/main/Transformer/transfor_rpe.py
|
7 |
+
"""
|
8 |
+
|
9 |
+
import copy
|
10 |
+
from os.path import basename, dirname, join, isfile
|
11 |
+
from typing import Optional, Union
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.nn as nn
|
15 |
+
import torch.nn.functional as F
|
16 |
+
from torch import Tensor
|
17 |
+
from torch.nn import Linear, Dropout, LayerNorm
|
18 |
+
import time
|
19 |
+
import networkx as nx
|
20 |
+
|
21 |
+
import metl.structure as structure
|
22 |
+
import metl.models as models
|
23 |
+
|
24 |
+
|
25 |
+
class RelativePosition3D(nn.Module):
|
26 |
+
""" Contact map-based relative position embeddings """
|
27 |
+
|
28 |
+
# need to compute a bucket_mtx for each structure
|
29 |
+
# need to know which bucket_mtx to use when grabbing the embeddings in forward()
|
30 |
+
# - on init, get a list of all PDB files we will be using
|
31 |
+
# - use a dictionary to store PDB files --> bucket_mtxs
|
32 |
+
# - forward() gets a new arg: the pdb file, which indexes into the dictionary to grab the right bucket_mtx
|
33 |
+
def __init__(self,
|
34 |
+
embedding_len: int,
|
35 |
+
contact_threshold: int,
|
36 |
+
clipping_threshold: int,
|
37 |
+
pdb_fns: Optional[Union[str, list, tuple]] = None,
|
38 |
+
default_pdb_dir: str = "data/pdb_files"):
|
39 |
+
|
40 |
+
# preferably, pdb_fns contains full paths to the PDBs, but if just the PDB filename is given
|
41 |
+
# then it defaults to the path data/pdb_files/<pdb_fn>
|
42 |
+
super().__init__()
|
43 |
+
self.embedding_len = embedding_len
|
44 |
+
self.clipping_threshold = clipping_threshold
|
45 |
+
self.contact_threshold = contact_threshold
|
46 |
+
self.default_pdb_dir = default_pdb_dir
|
47 |
+
|
48 |
+
# dummy buffer for getting correct device for on-the-fly bucket matrix generation
|
49 |
+
self.register_buffer("dummy_buffer", torch.empty(0), persistent=False)
|
50 |
+
|
51 |
+
# for 3D-based positions, the number of embeddings is generally the number of buckets
|
52 |
+
# for contact map-based distances, that is clipping_threshold + 1
|
53 |
+
num_embeddings = clipping_threshold + 1
|
54 |
+
|
55 |
+
# this is the embedding lookup table E_r
|
56 |
+
self.embeddings_table = nn.Embedding(num_embeddings, embedding_len)
|
57 |
+
|
58 |
+
# set up pdb_fns that were passed in on init (can also be set up during runtime in forward())
|
59 |
+
# todo: i'm using a hacky workaround to move the bucket_mtxs to the correct device
|
60 |
+
# i tried to make it more efficient by registering bucket matrices as buffers, but i was
|
61 |
+
# having problems with DDP syncing the buffers across processes
|
62 |
+
self.bucket_mtxs = {}
|
63 |
+
self.bucket_mtxs_device = self.dummy_buffer.device
|
64 |
+
self._init_pdbs(pdb_fns)
|
65 |
+
|
66 |
+
def forward(self, pdb_fn):
|
67 |
+
# compute matrix R by grabbing the embeddings from the embeddings lookup table
|
68 |
+
embeddings = self.embeddings_table(self._get_bucket_mtx(pdb_fn))
|
69 |
+
return embeddings
|
70 |
+
|
71 |
+
# def _get_bucket_mtx(self, pdb_fn):
|
72 |
+
# """ retrieve a bucket matrix given the pdb_fn.
|
73 |
+
# if the pdb_fn was provided at init or has already been computed, then the bucket matrix will be
|
74 |
+
# retrieved from the object buffer. if the bucket matrix has not been computed yet, it will be here """
|
75 |
+
# pdb_attr = self._pdb_key(pdb_fn)
|
76 |
+
# if hasattr(self, pdb_attr):
|
77 |
+
# return getattr(self, pdb_attr)
|
78 |
+
# else:
|
79 |
+
# # encountering a new PDB at runtime... process it
|
80 |
+
# # todo: if there's a new PDB at runtime, it will be initialized separately in each instance
|
81 |
+
# # of RelativePosition3D, for each layer. It would be more efficient to have a global
|
82 |
+
# # bucket_mtx registry... perhaps in the RelativeTransformerEncoder class, that can be passed through
|
83 |
+
# self._init_pdb(pdb_fn)
|
84 |
+
# return getattr(self, pdb_attr)
|
85 |
+
|
86 |
+
def _move_bucket_mtxs(self, device):
|
87 |
+
for k, v in self.bucket_mtxs.items():
|
88 |
+
self.bucket_mtxs[k] = v.to(device)
|
89 |
+
self.bucket_mtxs_device = device
|
90 |
+
|
91 |
+
def _get_bucket_mtx(self, pdb_fn):
|
92 |
+
""" retrieve a bucket matrix given the pdb_fn.
|
93 |
+
if the pdb_fn was provided at init or has already been computed, then the bucket matrix will be
|
94 |
+
retrieved from the bucket_mtxs dictionary. else, it will be computed now on-the-fly """
|
95 |
+
|
96 |
+
# ensure that all the bucket matrices are on the same device as the nn.Embedding
|
97 |
+
if self.bucket_mtxs_device != self.dummy_buffer.device:
|
98 |
+
self._move_bucket_mtxs(self.dummy_buffer.device)
|
99 |
+
|
100 |
+
pdb_attr = self._pdb_key(pdb_fn)
|
101 |
+
if pdb_attr in self.bucket_mtxs:
|
102 |
+
return self.bucket_mtxs[pdb_attr]
|
103 |
+
else:
|
104 |
+
# encountering a new PDB at runtime... process it
|
105 |
+
# todo: if there's a new PDB at runtime, it will be initialized separately in each instance
|
106 |
+
# of RelativePosition3D, for each layer. It would be more efficient to have a global
|
107 |
+
# bucket_mtx registry... perhaps in the RelativeTransformerEncoder class, that can be passed through
|
108 |
+
self._init_pdb(pdb_fn)
|
109 |
+
return self.bucket_mtxs[pdb_attr]
|
110 |
+
|
111 |
+
# def _set_bucket_mtx(self, pdb_fn, bucket_mtx):
|
112 |
+
# """ store a bucket matrix as a buffer """
|
113 |
+
# # if PyTorch ever implements a BufferDict, we could use it here efficiently
|
114 |
+
# # there is also BufferDict from https://botorch.org/api/_modules/botorch/utils/torch.html
|
115 |
+
# # would just need to modify it to have an option for persistent=False
|
116 |
+
# bucket_mtx = bucket_mtx.to(self.dummy_buffer.device)
|
117 |
+
#
|
118 |
+
# self.register_buffer(self._pdb_key(pdb_fn), bucket_mtx, persistent=False)
|
119 |
+
|
120 |
+
def _set_bucket_mtx(self, pdb_fn, bucket_mtx):
|
121 |
+
""" store a bucket matrix in the bucket dict """
|
122 |
+
|
123 |
+
# move the bucket_mtx to the same device that the other bucket matrices are on
|
124 |
+
bucket_mtx = bucket_mtx.to(self.bucket_mtxs_device)
|
125 |
+
|
126 |
+
self.bucket_mtxs[self._pdb_key(pdb_fn)] = bucket_mtx
|
127 |
+
|
128 |
+
@staticmethod
|
129 |
+
def _pdb_key(pdb_fn):
|
130 |
+
""" return a unique key for the given pdb_fn, used to map unique PDBs """
|
131 |
+
# note this key does NOT currently support PDBs with the same basename but different paths
|
132 |
+
# assumes every PDB is in the format <pdb_name>.pdb
|
133 |
+
# should be a compatible with being a class attribute, as it is used as a pytorch buffer name
|
134 |
+
return f"pdb_{basename(pdb_fn).split('.')[0]}"
|
135 |
+
|
136 |
+
def _init_pdbs(self, pdb_fns):
|
137 |
+
start = time.time()
|
138 |
+
|
139 |
+
if pdb_fns is None:
|
140 |
+
# nothing to initialize if pdb_fns is None
|
141 |
+
return
|
142 |
+
|
143 |
+
# make sure pdb_fns is a list
|
144 |
+
if not isinstance(pdb_fns, list) and not isinstance(pdb_fns, tuple):
|
145 |
+
pdb_fns = [pdb_fns]
|
146 |
+
|
147 |
+
# init each pdb fn in the list
|
148 |
+
for pdb_fn in pdb_fns:
|
149 |
+
self._init_pdb(pdb_fn)
|
150 |
+
|
151 |
+
print("Initialized PDB bucket matrices in: {:.3f}".format(time.time() - start))
|
152 |
+
|
153 |
+
def _init_pdb(self, pdb_fn):
|
154 |
+
""" process a pdb file for use with structure-based relative attention """
|
155 |
+
# if pdb_fn is not a full path, default to the path data/pdb_files/<pdb_fn>
|
156 |
+
if dirname(pdb_fn) == "":
|
157 |
+
# handle the case where the pdb file is in the current working directory
|
158 |
+
# if there is a PDB file in the cwd.... then just use it as is. otherwise, append the default.
|
159 |
+
if not isfile(pdb_fn):
|
160 |
+
pdb_fn = join(self.default_pdb_dir, pdb_fn)
|
161 |
+
|
162 |
+
# create a structure graph from the pdb_fn and contact threshold
|
163 |
+
cbeta_mtx = structure.cbeta_distance_matrix(pdb_fn)
|
164 |
+
structure_graph = structure.dist_thresh_graph(cbeta_mtx, self.contact_threshold)
|
165 |
+
|
166 |
+
# bucket_mtx indexes into the embedding lookup table to create the final distance matrix
|
167 |
+
bucket_mtx = self._compute_bucket_mtx(structure_graph)
|
168 |
+
|
169 |
+
self._set_bucket_mtx(pdb_fn, bucket_mtx)
|
170 |
+
|
171 |
+
def _compute_bucketed_neighbors(self, structure_graph, source_node):
|
172 |
+
""" gets the bucketed neighbors from the given source node and structure graph"""
|
173 |
+
if self.clipping_threshold < 0:
|
174 |
+
raise ValueError("Clipping threshold must be >= 0")
|
175 |
+
|
176 |
+
sspl = _inv_dict(nx.single_source_shortest_path_length(structure_graph, source_node))
|
177 |
+
|
178 |
+
if self.clipping_threshold is not None:
|
179 |
+
num_buckets = 1 + self.clipping_threshold
|
180 |
+
sspl = _combine_d(sspl, self.clipping_threshold, num_buckets - 1)
|
181 |
+
|
182 |
+
return sspl
|
183 |
+
|
184 |
+
def _compute_bucket_mtx(self, structure_graph):
|
185 |
+
""" get the bucket_mtx for the given structure_graph
|
186 |
+
calls _get_bucketed_neighbors for every node in the structure_graph """
|
187 |
+
num_residues = len(list(structure_graph))
|
188 |
+
|
189 |
+
# index into the embedding lookup table to create the final distance matrix
|
190 |
+
bucket_mtx = torch.zeros(num_residues, num_residues, dtype=torch.long)
|
191 |
+
|
192 |
+
for node_num in sorted(list(structure_graph)):
|
193 |
+
bucketed_neighbors = self._compute_bucketed_neighbors(structure_graph, node_num)
|
194 |
+
|
195 |
+
for bucket_num, neighbors in bucketed_neighbors.items():
|
196 |
+
bucket_mtx[node_num, neighbors] = bucket_num
|
197 |
+
|
198 |
+
return bucket_mtx
|
199 |
+
|
200 |
+
|
201 |
+
class RelativePosition(nn.Module):
|
202 |
+
""" creates the embedding lookup table E_r and computes R
|
203 |
+
note this inherits from pl.LightningModule instead of nn.Module
|
204 |
+
makes it easier to access the device with `self.device`
|
205 |
+
might be able to keep it as an nn.Module using the hacky dummy_param or commented out .device property """
|
206 |
+
|
207 |
+
def __init__(self, embedding_len: int, clipping_threshold: int):
|
208 |
+
"""
|
209 |
+
embedding_len: the length of the embedding, may be d_model, or d_model // num_heads for multihead
|
210 |
+
clipping_threshold: the maximum relative position, referred to as k by Shaw et al.
|
211 |
+
"""
|
212 |
+
super().__init__()
|
213 |
+
self.embedding_len = embedding_len
|
214 |
+
self.clipping_threshold = clipping_threshold
|
215 |
+
# for sequence-based distances, the number of embeddings is 2*k+1, where k is the clipping threshold
|
216 |
+
num_embeddings = 2 * clipping_threshold + 1
|
217 |
+
|
218 |
+
# this is the embedding lookup table E_r
|
219 |
+
self.embeddings_table = nn.Embedding(num_embeddings, embedding_len)
|
220 |
+
|
221 |
+
# for getting the correct device for range vectors in forward
|
222 |
+
self.register_buffer("dummy_buffer", torch.empty(0), persistent=False)
|
223 |
+
|
224 |
+
def forward(self, length_q, length_k):
|
225 |
+
# supports different length sequences, but in self-attention length_q and length_k are the same
|
226 |
+
range_vec_q = torch.arange(length_q, device=self.dummy_buffer.device)
|
227 |
+
range_vec_k = torch.arange(length_k, device=self.dummy_buffer.device)
|
228 |
+
|
229 |
+
# this sets up the standard sequence-based distance matrix for relative positions
|
230 |
+
# the current position is 0, positions to the right are +1, +2, etc, and to the left -1, -2, etc
|
231 |
+
distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
|
232 |
+
distance_mat_clipped = torch.clamp(distance_mat, -self.clipping_threshold, self.clipping_threshold)
|
233 |
+
|
234 |
+
# convert to indices, indexing into the embedding table
|
235 |
+
final_mat = (distance_mat_clipped + self.clipping_threshold).long()
|
236 |
+
|
237 |
+
# compute matrix R by grabbing the embeddings from the embedding lookup table
|
238 |
+
embeddings = self.embeddings_table(final_mat)
|
239 |
+
|
240 |
+
return embeddings
|
241 |
+
|
242 |
+
|
243 |
+
class RelativeMultiHeadAttention(nn.Module):
|
244 |
+
def __init__(self, embed_dim, num_heads, dropout, pos_encoding, clipping_threshold, contact_threshold, pdb_fns):
|
245 |
+
"""
|
246 |
+
Multi-head attention with relative position embeddings. Input data should be in batch_first format.
|
247 |
+
:param embed_dim: aka d_model, aka hid_dim
|
248 |
+
:param num_heads: number of heads
|
249 |
+
:param dropout: how much dropout for scaled dot product attention
|
250 |
+
|
251 |
+
:param pos_encoding: what type of positional encoding to use, relative or relative3D
|
252 |
+
:param clipping_threshold: clipping threshold for relative position embedding
|
253 |
+
:param contact_threshold: for relative_3D, the threshold in angstroms for the contact map
|
254 |
+
:param pdb_fns: pdb file(s) to set up the relative position object
|
255 |
+
|
256 |
+
"""
|
257 |
+
super().__init__()
|
258 |
+
|
259 |
+
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
|
260 |
+
|
261 |
+
# model dimensions
|
262 |
+
self.embed_dim = embed_dim
|
263 |
+
self.num_heads = num_heads
|
264 |
+
self.head_dim = embed_dim // num_heads
|
265 |
+
|
266 |
+
# pos encoding stuff
|
267 |
+
self.pos_encoding = pos_encoding
|
268 |
+
self.clipping_threshold = clipping_threshold
|
269 |
+
self.contact_threshold = contact_threshold
|
270 |
+
if pdb_fns is not None and not isinstance(pdb_fns, list):
|
271 |
+
pdb_fns = [pdb_fns]
|
272 |
+
self.pdb_fns = pdb_fns
|
273 |
+
|
274 |
+
# relative position embeddings for use with keys and values
|
275 |
+
# Shaw et al. uses relative position information for both keys and values
|
276 |
+
# Huang et al. only uses it for the keys, which is probably enough
|
277 |
+
if pos_encoding == "relative":
|
278 |
+
self.relative_position_k = RelativePosition(self.head_dim, self.clipping_threshold)
|
279 |
+
self.relative_position_v = RelativePosition(self.head_dim, self.clipping_threshold)
|
280 |
+
elif pos_encoding == "relative_3D":
|
281 |
+
self.relative_position_k = RelativePosition3D(self.head_dim, self.contact_threshold,
|
282 |
+
self.clipping_threshold, self.pdb_fns)
|
283 |
+
self.relative_position_v = RelativePosition3D(self.head_dim, self.contact_threshold,
|
284 |
+
self.clipping_threshold, self.pdb_fns)
|
285 |
+
else:
|
286 |
+
raise ValueError("unrecognized pos_encoding: {}".format(pos_encoding))
|
287 |
+
|
288 |
+
# WQ, WK, and WV from attention is all you need
|
289 |
+
# note these default to bias=True, same as PyTorch implementation
|
290 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
291 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
292 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
293 |
+
|
294 |
+
# WO from attention is all you need
|
295 |
+
# used for the final projection when computing multi-head attention
|
296 |
+
# PyTorch uses NonDynamicallyQuantizableLinear instead of Linear to avoid triggering an obscure
|
297 |
+
# error quantizing the model https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L122
|
298 |
+
# todo: if quantizing the model, explore if the above is a concern for us
|
299 |
+
self.out_proj = nn.Linear(embed_dim, embed_dim)
|
300 |
+
|
301 |
+
# dropout for scaled dot product attention
|
302 |
+
self.dropout = nn.Dropout(dropout)
|
303 |
+
|
304 |
+
# scaling factor for scaled dot product attention
|
305 |
+
scale = torch.sqrt(torch.FloatTensor([self.head_dim]))
|
306 |
+
# persistent=False if you don't want to save it inside state_dict
|
307 |
+
self.register_buffer('scale', scale)
|
308 |
+
|
309 |
+
# toggles meant to be set directly by user
|
310 |
+
self.need_weights = False
|
311 |
+
self.average_attn_weights = True
|
312 |
+
|
313 |
+
def _compute_attn_weights(self, query, key, len_q, len_k, batch_size, mask, pdb_fn):
|
314 |
+
""" computes the attention weights (a "compatability function" of queries with corresponding keys) """
|
315 |
+
|
316 |
+
# calculate the first term in the numerator attn1, which is Q*K
|
317 |
+
# todo: pytorch reshapes q,k and v to 3 dimensions (similar to how r_q2 is below)
|
318 |
+
# is that functionally equivalent to what we're doing? is their way faster?
|
319 |
+
# r_q1 = [batch_size, num_heads, len_q, head_dim]
|
320 |
+
r_q1 = query.view(batch_size, len_q, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
|
321 |
+
# todo: we could directly permute r_k1 to [batch_size, num_heads, head_dim, len_k]
|
322 |
+
# to make it compatible for matrix multiplication with r_q1, instead of 2-step approach
|
323 |
+
# r_k1 = [batch_size, num_heads, len_k, head_dim]
|
324 |
+
r_k1 = key.view(batch_size, len_k, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
|
325 |
+
# attn1 = [batch_size, num_heads, len_q, len_k]
|
326 |
+
attn1 = torch.matmul(r_q1, r_k1.permute(0, 1, 3, 2))
|
327 |
+
|
328 |
+
# calculate the second term in the numerator attn2, which is Q*R
|
329 |
+
# r_q2 = [query_len, batch_size * num_heads, head_dim]
|
330 |
+
r_q2 = query.permute(1, 0, 2).contiguous().view(len_q, batch_size * self.num_heads, self.head_dim)
|
331 |
+
|
332 |
+
# todo: support multiple different PDB base structures per batch
|
333 |
+
# one option:
|
334 |
+
# - require batches to be all the same protein
|
335 |
+
# - add argument to forward() to accept the PDB file for the protein in the batch
|
336 |
+
# - then we just pass in the PDB file to relative position's forward()
|
337 |
+
# to support multiple different structures per batch:
|
338 |
+
# - add argument to forward() to accept PDB files, one for each item in batch
|
339 |
+
# - make corresponding changing in relative_position object to return R for each structure
|
340 |
+
# - note: if there are a lot of of different structures, and the sequence lengths are long,
|
341 |
+
# this could be memory prohibitive because R (rel_pos_k) can take up a lot of mem for long seqs
|
342 |
+
# - adjust the attn2 calculation to factor in the multiple different R matrices.
|
343 |
+
# the way to do this might have to be to do multiple matmuls, one for each each structure.
|
344 |
+
# basically, would split up r_q2 into several matrices grouped by structure, and then
|
345 |
+
# multiply with corresponding R, then combine back into the exact same order of the original r_q2
|
346 |
+
# note: this may be computationally intensive (splitting, more matrix muliplies, joining)
|
347 |
+
# another option would be to create views(?), repeating the different Rs so we can do a
|
348 |
+
# a matris multiply directly with r_q2
|
349 |
+
# - would shapes be affected if there was padding in the queries, keys, values?
|
350 |
+
|
351 |
+
if self.pos_encoding == "relative":
|
352 |
+
# rel_pos_k = [len_q, len_k, head_dim]
|
353 |
+
rel_pos_k = self.relative_position_k(len_q, len_k)
|
354 |
+
elif self.pos_encoding == "relative_3D":
|
355 |
+
# rel_pos_k = [sequence length (from PDB structure), head_dim]
|
356 |
+
rel_pos_k = self.relative_position_k(pdb_fn)
|
357 |
+
else:
|
358 |
+
raise ValueError("unrecognized pos_encoding: {}".format(self.pos_encoding))
|
359 |
+
|
360 |
+
# the matmul basically computes the dot product between each input position’s query vector and
|
361 |
+
# its corresponding relative position embeddings across all input sequences in the heads and batch
|
362 |
+
# attn2 = [batch_size * num_heads, len_q, len_k]
|
363 |
+
attn2 = torch.matmul(r_q2, rel_pos_k.transpose(1, 2)).transpose(0, 1)
|
364 |
+
# attn2 = [batch_size, num_heads, len_q, len_k]
|
365 |
+
attn2 = attn2.contiguous().view(batch_size, self.num_heads, len_q, len_k)
|
366 |
+
|
367 |
+
# calculate attention weights
|
368 |
+
attn_weights = (attn1 + attn2) / self.scale
|
369 |
+
|
370 |
+
# apply mask if given
|
371 |
+
if mask is not None:
|
372 |
+
# todo: pytorch uses float("-inf") instead of -1e10
|
373 |
+
attn_weights = attn_weights.masked_fill(mask == 0, -1e10)
|
374 |
+
|
375 |
+
# softmax gives us attn_weights weights
|
376 |
+
attn_weights = torch.softmax(attn_weights, dim=-1)
|
377 |
+
# attn_weights = [batch_size, num_heads, len_q, len_k]
|
378 |
+
attn_weights = self.dropout(attn_weights)
|
379 |
+
|
380 |
+
return attn_weights
|
381 |
+
|
382 |
+
def _compute_avg_val(self, value, len_q, len_k, len_v, attn_weights, batch_size, pdb_fn):
|
383 |
+
# todo: add option to not factor in relative position embeddings in value calculation
|
384 |
+
# calculate the first term, the attn*values
|
385 |
+
# r_v1 = [batch_size, num_heads, len_v, head_dim]
|
386 |
+
r_v1 = value.view(batch_size, len_v, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
|
387 |
+
# avg1 = [batch_size, num_heads, len_q, head_dim]
|
388 |
+
avg1 = torch.matmul(attn_weights, r_v1)
|
389 |
+
|
390 |
+
# calculate the second term, the attn*R
|
391 |
+
# similar to how relative embeddings are factored in the attention weights calculation
|
392 |
+
if self.pos_encoding == "relative":
|
393 |
+
# rel_pos_v = [query_len, value_len, head_dim]
|
394 |
+
rel_pos_v = self.relative_position_v(len_q, len_v)
|
395 |
+
elif self.pos_encoding == "relative_3D":
|
396 |
+
# rel_pos_v = [sequence length (from PDB structure), head_dim]
|
397 |
+
rel_pos_v = self.relative_position_v(pdb_fn)
|
398 |
+
else:
|
399 |
+
raise ValueError("unrecognized pos_encoding: {}".format(self.pos_encoding))
|
400 |
+
|
401 |
+
# r_attn_weights = [len_q, batch_size * num_heads, len_v]
|
402 |
+
r_attn_weights = attn_weights.permute(2, 0, 1, 3).contiguous().view(len_q, batch_size * self.num_heads, len_k)
|
403 |
+
avg2 = torch.matmul(r_attn_weights, rel_pos_v)
|
404 |
+
# avg2 = [batch_size, num_heads, len_q, head_dim]
|
405 |
+
avg2 = avg2.transpose(0, 1).contiguous().view(batch_size, self.num_heads, len_q, self.head_dim)
|
406 |
+
|
407 |
+
# calculate avg value
|
408 |
+
x = avg1 + avg2 # [batch_size, num_heads, len_q, head_dim]
|
409 |
+
x = x.permute(0, 2, 1, 3).contiguous() # [batch_size, len_q, num_heads, head_dim]
|
410 |
+
# x = [batch_size, len_q, embed_dim]
|
411 |
+
x = x.view(batch_size, len_q, self.embed_dim)
|
412 |
+
|
413 |
+
return x
|
414 |
+
|
415 |
+
def forward(self, query, key, value, pdb_fn=None, mask=None):
|
416 |
+
# query = [batch_size, q_len, embed_dim]
|
417 |
+
# key = [batch_size, k_len, embed_dim]
|
418 |
+
# value = [batch_size, v_en, embed_dim]
|
419 |
+
batch_size = query.shape[0]
|
420 |
+
len_k, len_q, len_v = (key.shape[1], query.shape[1], value.shape[1])
|
421 |
+
|
422 |
+
# in projection (multiply inputs by WQ, WK, WV)
|
423 |
+
query = self.q_proj(query)
|
424 |
+
key = self.k_proj(key)
|
425 |
+
value = self.v_proj(value)
|
426 |
+
|
427 |
+
# first compute the attention weights, then multiply with values
|
428 |
+
# attn = [batch size, num_heads, len_q, len_k]
|
429 |
+
attn_weights = self._compute_attn_weights(query, key, len_q, len_k, batch_size, mask, pdb_fn)
|
430 |
+
|
431 |
+
# take weighted average of values (weighted by attention weights)
|
432 |
+
attn_output = self._compute_avg_val(value, len_q, len_k, len_v, attn_weights, batch_size, pdb_fn)
|
433 |
+
|
434 |
+
# output projection
|
435 |
+
# attn_output = [batch_size, len_q, embed_dim]
|
436 |
+
attn_output = self.out_proj(attn_output)
|
437 |
+
|
438 |
+
if self.need_weights:
|
439 |
+
# return attention weights in addition to attention
|
440 |
+
# average the weights over the heads (to get overall attention)
|
441 |
+
# attn_weights = [batch_size, len_q, len_k]
|
442 |
+
if self.average_attn_weights:
|
443 |
+
attn_weights = attn_weights.sum(dim=1) / self.num_heads
|
444 |
+
return {"attn_output": attn_output, "attn_weights": attn_weights}
|
445 |
+
else:
|
446 |
+
return attn_output
|
447 |
+
|
448 |
+
|
449 |
+
class RelativeTransformerEncoderLayer(nn.Module):
|
450 |
+
"""
|
451 |
+
d_model: the number of expected features in the input (required).
|
452 |
+
nhead: the number of heads in the MultiHeadAttention models (required).
|
453 |
+
clipping_threshold: the clipping threshold for relative position embeddings
|
454 |
+
dim_feedforward: the dimension of the feedforward network model (default=2048).
|
455 |
+
dropout: the dropout value (default=0.1).
|
456 |
+
activation: the activation function of the intermediate layer, can be a string
|
457 |
+
("relu" or "gelu") or a unary callable. Default: relu
|
458 |
+
layer_norm_eps: the eps value in layer normalization components (default=1e-5).
|
459 |
+
norm_first: if ``True``, layer norm is done prior to attention and feedforward
|
460 |
+
operations, respectively. Otherwise, it's done after. Default: ``False`` (after).
|
461 |
+
"""
|
462 |
+
|
463 |
+
# this is some kind of torch jit compiling helper... will also ensure these values don't change
|
464 |
+
__constants__ = ['batch_first', 'norm_first']
|
465 |
+
|
466 |
+
def __init__(self,
|
467 |
+
d_model,
|
468 |
+
nhead,
|
469 |
+
pos_encoding="relative",
|
470 |
+
clipping_threshold=3,
|
471 |
+
contact_threshold=7,
|
472 |
+
pdb_fns=None,
|
473 |
+
dim_feedforward=2048,
|
474 |
+
dropout=0.1,
|
475 |
+
activation=F.relu,
|
476 |
+
layer_norm_eps=1e-5,
|
477 |
+
norm_first=False) -> None:
|
478 |
+
|
479 |
+
self.batch_first = True
|
480 |
+
|
481 |
+
super(RelativeTransformerEncoderLayer, self).__init__()
|
482 |
+
|
483 |
+
self.self_attn = RelativeMultiHeadAttention(d_model, nhead, dropout,
|
484 |
+
pos_encoding, clipping_threshold, contact_threshold, pdb_fns)
|
485 |
+
|
486 |
+
# feed forward model
|
487 |
+
self.linear1 = Linear(d_model, dim_feedforward)
|
488 |
+
self.dropout = Dropout(dropout)
|
489 |
+
self.linear2 = Linear(dim_feedforward, d_model)
|
490 |
+
|
491 |
+
self.norm_first = norm_first
|
492 |
+
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps)
|
493 |
+
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps)
|
494 |
+
self.dropout1 = Dropout(dropout)
|
495 |
+
self.dropout2 = Dropout(dropout)
|
496 |
+
|
497 |
+
# Legacy string support for activation function.
|
498 |
+
if isinstance(activation, str):
|
499 |
+
self.activation = models.get_activation_fn(activation)
|
500 |
+
else:
|
501 |
+
self.activation = activation
|
502 |
+
|
503 |
+
def forward(self, src: Tensor, pdb_fn=None) -> Tensor:
|
504 |
+
x = src
|
505 |
+
if self.norm_first:
|
506 |
+
x = x + self._sa_block(self.norm1(x), pdb_fn=pdb_fn)
|
507 |
+
x = x + self._ff_block(self.norm2(x))
|
508 |
+
else:
|
509 |
+
x = self.norm1(x + self._sa_block(x))
|
510 |
+
x = self.norm2(x + self._ff_block(x))
|
511 |
+
|
512 |
+
return x
|
513 |
+
|
514 |
+
# self-attention block
|
515 |
+
def _sa_block(self, x: Tensor, pdb_fn=None) -> Tensor:
|
516 |
+
x = self.self_attn(x, x, x, pdb_fn=pdb_fn)
|
517 |
+
if isinstance(x, dict):
|
518 |
+
# handle the case where we are returning attention weights
|
519 |
+
x = x["attn_output"]
|
520 |
+
return self.dropout1(x)
|
521 |
+
|
522 |
+
# feed forward block
|
523 |
+
def _ff_block(self, x: Tensor) -> Tensor:
|
524 |
+
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
525 |
+
return self.dropout2(x)
|
526 |
+
|
527 |
+
|
528 |
+
class RelativeTransformerEncoder(nn.Module):
|
529 |
+
def __init__(self, encoder_layer, num_layers, norm=None, reset_params=True):
|
530 |
+
super(RelativeTransformerEncoder, self).__init__()
|
531 |
+
# using get_clones means all layers have the same initialization
|
532 |
+
# this is also a problem in PyTorch's TransformerEncoder implementation, which this is based on
|
533 |
+
# todo: PyTorch is changing its transformer API... check up on and see if there is a better way
|
534 |
+
self.layers = _get_clones(encoder_layer, num_layers)
|
535 |
+
self.num_layers = num_layers
|
536 |
+
self.norm = norm
|
537 |
+
|
538 |
+
# important because get_clones means all layers have same initialization
|
539 |
+
# should recursively reset parameters for all submodules
|
540 |
+
if reset_params:
|
541 |
+
self.apply(models.reset_parameters_helper)
|
542 |
+
|
543 |
+
def forward(self, src: Tensor, pdb_fn=None) -> Tensor:
|
544 |
+
output = src
|
545 |
+
|
546 |
+
for mod in self.layers:
|
547 |
+
output = mod(output, pdb_fn=pdb_fn)
|
548 |
+
|
549 |
+
if self.norm is not None:
|
550 |
+
output = self.norm(output)
|
551 |
+
|
552 |
+
return output
|
553 |
+
|
554 |
+
|
555 |
+
def _get_clones(module, num_clones):
|
556 |
+
return nn.ModuleList([copy.deepcopy(module) for _ in range(num_clones)])
|
557 |
+
|
558 |
+
|
559 |
+
def _inv_dict(d):
|
560 |
+
""" helper function for contact map-based position embeddings """
|
561 |
+
inv = dict()
|
562 |
+
for k, v in d.items():
|
563 |
+
# collect dict keys into lists based on value
|
564 |
+
inv.setdefault(v, list()).append(k)
|
565 |
+
for k, v in inv.items():
|
566 |
+
# put in sorted order
|
567 |
+
inv[k] = sorted(v)
|
568 |
+
return inv
|
569 |
+
|
570 |
+
|
571 |
+
def _combine_d(d, threshold, combined_key):
|
572 |
+
""" helper function for contact map-based position embeddings
|
573 |
+
d is a dictionary with ints as keys and lists as values.
|
574 |
+
for all keys >= threshold, this function combines the values of those keys into a single list """
|
575 |
+
out_d = {}
|
576 |
+
for k, v in d.items():
|
577 |
+
if k < threshold:
|
578 |
+
out_d[k] = v
|
579 |
+
elif k >= threshold:
|
580 |
+
if combined_key not in out_d:
|
581 |
+
out_d[combined_key] = v
|
582 |
+
else:
|
583 |
+
out_d[combined_key] += v
|
584 |
+
if combined_key in out_d:
|
585 |
+
out_d[combined_key] = sorted(out_d[combined_key])
|
586 |
+
return out_d
|
metl/structure.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from os.path import isfile
|
3 |
+
from enum import Enum, auto
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
from scipy.spatial.distance import cdist
|
7 |
+
import networkx as nx
|
8 |
+
from biopandas.pdb import PandasPdb
|
9 |
+
|
10 |
+
|
11 |
+
class GraphType(Enum):
|
12 |
+
LINEAR = auto()
|
13 |
+
COMPLETE = auto()
|
14 |
+
DISCONNECTED = auto()
|
15 |
+
DIST_THRESH = auto()
|
16 |
+
DIST_THRESH_SHUFFLED = auto()
|
17 |
+
|
18 |
+
|
19 |
+
def save_graph(g, fn):
|
20 |
+
""" Saves graph to file """
|
21 |
+
nx.write_gexf(g, fn)
|
22 |
+
|
23 |
+
|
24 |
+
def load_graph(fn):
|
25 |
+
""" Loads graph from file """
|
26 |
+
g = nx.read_gexf(fn, node_type=int)
|
27 |
+
return g
|
28 |
+
|
29 |
+
|
30 |
+
def shuffle_nodes(g, seed=7):
|
31 |
+
""" Shuffles the nodes of the given graph and returns a copy of the shuffled graph """
|
32 |
+
# get the list of nodes in this graph
|
33 |
+
nodes = g.nodes()
|
34 |
+
|
35 |
+
# create a permuted list of nodes
|
36 |
+
np.random.seed(seed)
|
37 |
+
nodes_shuffled = np.random.permutation(nodes)
|
38 |
+
|
39 |
+
# create a dictionary mapping from old node label to new node label
|
40 |
+
mapping = {n: ns for n, ns in zip(nodes, nodes_shuffled)}
|
41 |
+
|
42 |
+
g_shuffled = nx.relabel_nodes(g, mapping, copy=True)
|
43 |
+
|
44 |
+
return g_shuffled
|
45 |
+
|
46 |
+
|
47 |
+
def linear_graph(num_residues):
|
48 |
+
""" Creates a linear graph where each node is connected to its sequence neighbor in order """
|
49 |
+
g = nx.Graph()
|
50 |
+
g.add_nodes_from(np.arange(0, num_residues))
|
51 |
+
for i in range(num_residues-1):
|
52 |
+
g.add_edge(i, i+1)
|
53 |
+
return g
|
54 |
+
|
55 |
+
|
56 |
+
def complete_graph(num_residues):
|
57 |
+
""" Creates a graph where each node is connected to all other nodes"""
|
58 |
+
g = nx.complete_graph(num_residues)
|
59 |
+
return g
|
60 |
+
|
61 |
+
|
62 |
+
def disconnected_graph(num_residues):
|
63 |
+
g = nx.Graph()
|
64 |
+
g.add_nodes_from(np.arange(0, num_residues))
|
65 |
+
return g
|
66 |
+
|
67 |
+
|
68 |
+
def dist_thresh_graph(dist_mtx, threshold):
|
69 |
+
""" Creates undirected graph based on a distance threshold """
|
70 |
+
g = nx.Graph()
|
71 |
+
g.add_nodes_from(np.arange(0, dist_mtx.shape[0]))
|
72 |
+
|
73 |
+
# loop through each residue
|
74 |
+
for rn1 in range(len(dist_mtx)):
|
75 |
+
# find all residues that are within threshold distance of current
|
76 |
+
rns_within_threshold = np.where(dist_mtx[rn1] < threshold)[0]
|
77 |
+
|
78 |
+
# add edges from current residue to those that are within threshold
|
79 |
+
for rn2 in rns_within_threshold:
|
80 |
+
# don't add self edges
|
81 |
+
if rn1 != rn2:
|
82 |
+
g.add_edge(rn1, rn2)
|
83 |
+
return g
|
84 |
+
|
85 |
+
|
86 |
+
def ordered_adjacency_matrix(g):
|
87 |
+
""" returns the adjacency matrix ordered by node label in increasing order as a numpy array """
|
88 |
+
node_order = sorted(g.nodes())
|
89 |
+
adj_mtx = nx.to_numpy_matrix(g, nodelist=node_order)
|
90 |
+
return np.asarray(adj_mtx).astype(np.float32)
|
91 |
+
|
92 |
+
|
93 |
+
def cbeta_distance_matrix(pdb_fn, start=0, end=None):
|
94 |
+
# note that start and end are not going by residue number
|
95 |
+
# they are going by whatever the listing in the pdb file is
|
96 |
+
|
97 |
+
# read the pdb file into a biopandas object
|
98 |
+
ppdb = PandasPdb().read_pdb(pdb_fn)
|
99 |
+
|
100 |
+
# group by residue number
|
101 |
+
# important to specify sort=True so that group keys (residue number) are in order
|
102 |
+
# the reason is we loop through group keys below, and assume that residues are in order
|
103 |
+
# the pandas function has sort=True by default, but we specify it anyway because it is important
|
104 |
+
grouped = ppdb.df["ATOM"].groupby("residue_number", sort=True)
|
105 |
+
|
106 |
+
# a list of coords for the cbeta or calpha of each residue
|
107 |
+
coords = []
|
108 |
+
|
109 |
+
# loop through each residue and find the coordinates of cbeta
|
110 |
+
for i, (residue_number, values) in enumerate(grouped):
|
111 |
+
|
112 |
+
# skip residues not in the range
|
113 |
+
end_index = (len(grouped) if end is None else end)
|
114 |
+
if i not in range(start, end_index):
|
115 |
+
continue
|
116 |
+
|
117 |
+
residue_group = grouped.get_group(residue_number)
|
118 |
+
|
119 |
+
atom_names = residue_group["atom_name"]
|
120 |
+
if "CB" in atom_names.values:
|
121 |
+
# print("Using CB...")
|
122 |
+
atom_name = "CB"
|
123 |
+
elif "CA" in atom_names.values:
|
124 |
+
# print("Using CA...")
|
125 |
+
atom_name = "CA"
|
126 |
+
else:
|
127 |
+
raise ValueError("Couldn't find CB or CA for residue {}".format(residue_number))
|
128 |
+
|
129 |
+
# get the coordinates of cbeta (or calpha)
|
130 |
+
coords.append(
|
131 |
+
residue_group[residue_group["atom_name"] == atom_name][["x_coord", "y_coord", "z_coord"]].values[0])
|
132 |
+
|
133 |
+
# stack the coords into a numpy array where each row has the x,y,z coords for a different residue
|
134 |
+
coords = np.stack(coords)
|
135 |
+
|
136 |
+
# compute pairwise euclidean distance between all cbetas
|
137 |
+
dist_mtx = cdist(coords, coords, metric="euclidean")
|
138 |
+
|
139 |
+
return dist_mtx
|
140 |
+
|
141 |
+
|
142 |
+
def get_neighbors(g, nodes):
|
143 |
+
""" returns a list (set) of neighbors of all given nodes """
|
144 |
+
neighbors = set()
|
145 |
+
for n in nodes:
|
146 |
+
neighbors.update(g.neighbors(n))
|
147 |
+
return sorted(list(neighbors))
|
148 |
+
|
149 |
+
|
150 |
+
def gen_graph(graph_type, res_dist_mtx, dist_thresh=7, shuffle_seed=7, graph_save_dir=None, save=False):
|
151 |
+
""" generate the specified structure graph using the specified residue distance matrix """
|
152 |
+
if graph_type is GraphType.LINEAR:
|
153 |
+
g = linear_graph(len(res_dist_mtx))
|
154 |
+
save_fn = None if not save else os.path.join(graph_save_dir, "linear.graph")
|
155 |
+
|
156 |
+
elif graph_type is GraphType.COMPLETE:
|
157 |
+
g = complete_graph(len(res_dist_mtx))
|
158 |
+
save_fn = None if not save else os.path.join(graph_save_dir, "complete.graph")
|
159 |
+
|
160 |
+
elif graph_type is GraphType.DISCONNECTED:
|
161 |
+
g = disconnected_graph(len(res_dist_mtx))
|
162 |
+
save_fn = None if not save else os.path.join(graph_save_dir, "disconnected.graph")
|
163 |
+
|
164 |
+
elif graph_type is GraphType.DIST_THRESH:
|
165 |
+
g = dist_thresh_graph(res_dist_mtx, dist_thresh)
|
166 |
+
save_fn = None if not save else os.path.join(graph_save_dir, "dist_thresh_{}.graph".format(dist_thresh))
|
167 |
+
|
168 |
+
elif graph_type is GraphType.DIST_THRESH_SHUFFLED:
|
169 |
+
g = dist_thresh_graph(res_dist_mtx, dist_thresh)
|
170 |
+
g = shuffle_nodes(g, seed=shuffle_seed)
|
171 |
+
save_fn = None if not save else \
|
172 |
+
os.path.join(graph_save_dir, "dist_thresh_{}_shuffled_r{}.graph".format(dist_thresh, shuffle_seed))
|
173 |
+
|
174 |
+
else:
|
175 |
+
raise ValueError("Graph type {} is not implemented".format(graph_type))
|
176 |
+
|
177 |
+
if save:
|
178 |
+
if isfile(save_fn):
|
179 |
+
print("err: graph already exists: {}. to overwrite, delete the existing file first".format(save_fn))
|
180 |
+
else:
|
181 |
+
os.makedirs(graph_save_dir, exist_ok=True)
|
182 |
+
save_graph(g, save_fn)
|
183 |
+
|
184 |
+
return g
|