jie1 committed on
Commit
95d4ac0
1 Parent(s): ccbe866

Upload 2 files

ProteinMPNN-main/protein_mpnn_run.py ADDED
@@ -0,0 +1,522 @@
+ import argparse
+ import os.path
+
+
+ def p_m_r(ca_only, path_to_model_weights, model_name, seed, save_score, save_probs, score_only, conditional_probs_only,
+           conditional_probs_only_backbone, unconditional_probs_only, backbone_noise, num_seq_per_target, batch_size,
+           max_length, sampling_temp, out_folder, pdb_path, pdb_path_chains, jsonl_path, chain_id_jsonl,
+           fixed_positions_jsonl, omit_AAs, bias_AA_jsonl, bias_by_res_jsonl, omit_AA_jsonl, pssm_jsonl, pssm_multi,
+           pssm_threshold, pssm_log_odds_flag, pssm_bias_flag, tied_positions_jsonl):
+     seed = int(seed)
+     save_score = int(save_score)
+     save_probs = int(save_probs)
+     score_only = int(score_only)
+     conditional_probs_only = int(conditional_probs_only)
+     conditional_probs_only_backbone = int(conditional_probs_only_backbone)
+     unconditional_probs_only = int(unconditional_probs_only)
+     num_seq_per_target = int(num_seq_per_target)
+     batch_size = int(batch_size)
+     max_length = int(max_length)
+     pssm_log_odds_flag = int(pssm_log_odds_flag)
+     pssm_bias_flag = int(pssm_bias_flag)
+     import json, time, os, sys, glob
+     import shutil
+     import warnings
+     import numpy as np
+     import torch
+     from torch import optim
+     from torch.utils.data import DataLoader
+     from torch.utils.data.dataset import random_split, Subset
+     import copy
+     import torch.nn as nn
+     import torch.nn.functional as F
+     import random
+     import os.path
+     import subprocess
+
+     from protein_mpnn_utils import loss_nll, loss_smoothed, gather_edges, gather_nodes, gather_nodes_t, \
+         cat_neighbors_nodes, _scores, _S_to_seq, tied_featurize, parse_PDB
+     from protein_mpnn_utils import StructureDataset, StructureDatasetPDB, ProteinMPNN
+
+     if not seed:
+         seed = int(np.random.randint(0, high=999, size=1, dtype=int)[0])
+
+     torch.manual_seed(seed)
+     random.seed(seed)
+     np.random.seed(seed)
+
+     hidden_dim = 128
+     num_layers = 3
+
+     if path_to_model_weights:
+         model_folder_path = path_to_model_weights
+         if model_folder_path[-1] != '/':
+             model_folder_path = model_folder_path + '/'
+     else:
+         file_path = os.path.realpath(__file__)
+         # changed: use os.path.dirname instead of rfind("\\") so the weights are found on POSIX as well as Windows
+         script_dir = os.path.dirname(file_path)
+         if ca_only:
+             model_folder_path = script_dir + '/ca_model_weights/'
+         else:
+             model_folder_path = script_dir + '/vanilla_model_weights/'
+
+     checkpoint_path = model_folder_path + f'{model_name}.pt'
+     folder_for_outputs = out_folder
+
+     NUM_BATCHES = num_seq_per_target // batch_size
+     BATCH_COPIES = batch_size
+     temperatures = [float(item) for item in sampling_temp.split()]
+     omit_AAs_list = omit_AAs
+     alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
+
+     omit_AAs_np = np.array([AA in omit_AAs_list for AA in alphabet]).astype(np.float32)
+     device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
+     # os.path.isfile(): checks whether the given (absolute) path points to a file
+     # changed: the *_jsonl arguments may be file-like objects, so .name is used to get the underlying path
+     # changed: fall back to None whenever a file is missing or not provided,
+     # so each dictionary is always defined before it is used below
+     if chain_id_jsonl and os.path.isfile(chain_id_jsonl.name):
+         with open(chain_id_jsonl.name, 'r') as json_file:
+             json_list = list(json_file)
+             for json_str in json_list:
+                 chain_id_dict = json.loads(json_str)
+     else:
+         chain_id_dict = None
+         print(40 * '-')
+         print('chain_id_jsonl is NOT loaded')
+     if fixed_positions_jsonl and os.path.isfile(fixed_positions_jsonl.name):
+         with open(fixed_positions_jsonl.name, 'r') as json_file:
+             json_list = list(json_file)
+             for json_str in json_list:
+                 fixed_positions_dict = json.loads(json_str)
+     else:
+         print(40 * '-')
+         print('fixed_positions_jsonl is NOT loaded')
+         fixed_positions_dict = None
+
+     if pssm_jsonl and os.path.isfile(pssm_jsonl):
+         with open(pssm_jsonl, 'r') as json_file:
+             json_list = list(json_file)
+         pssm_dict = {}
+         for json_str in json_list:
+             pssm_dict.update(json.loads(json_str))
+     else:
+         print(40 * '-')
+         print('pssm_jsonl is NOT loaded')
+         pssm_dict = None
+
+     if omit_AA_jsonl and os.path.isfile(omit_AA_jsonl):
+         with open(omit_AA_jsonl, 'r') as json_file:
+             json_list = list(json_file)
+             for json_str in json_list:
+                 omit_AA_dict = json.loads(json_str)
+     else:
+         print(40 * '-')
+         print('omit_AA_jsonl is NOT loaded')
+         omit_AA_dict = None
+     if bias_AA_jsonl and os.path.isfile(bias_AA_jsonl.name):
+         with open(bias_AA_jsonl.name, 'r') as json_file:
+             json_list = list(json_file)
+             for json_str in json_list:
+                 bias_AA_dict = json.loads(json_str)
+     else:
+         print(40 * '-')
+         print('bias_AA_jsonl is NOT loaded')
+         bias_AA_dict = None
+     if tied_positions_jsonl and os.path.isfile(tied_positions_jsonl.name):
+         with open(tied_positions_jsonl.name, 'r') as json_file:
+             json_list = list(json_file)
+             for json_str in json_list:
+                 tied_positions_dict = json.loads(json_str)
+     else:
+         print(40 * '-')
+         print('tied_positions_jsonl is NOT loaded')
+         tied_positions_dict = None
+
+     if bias_by_res_jsonl and os.path.isfile(bias_by_res_jsonl):
+         with open(bias_by_res_jsonl, 'r') as json_file:
+             json_list = list(json_file)
+             for json_str in json_list:
+                 bias_by_res_dict = json.loads(json_str)
+             print('bias by residue dictionary is loaded')
+     else:
+         print(40 * '-')
+         print('bias by residue dictionary is not loaded, or not provided')
+         bias_by_res_dict = None
+
+     print(40 * '-')
+     bias_AAs_np = np.zeros(len(alphabet))
+     if bias_AA_dict:
+         for n, AA in enumerate(alphabet):
+             if AA in list(bias_AA_dict.keys()):
+                 bias_AAs_np[n] = bias_AA_dict[AA]
+
+     # changed: pdb_path is a file-like object; .name gives the underlying path
+     if pdb_path:
+         pdb_dict_list = parse_PDB(pdb_path.name, ca_only=ca_only)
+         dataset_valid = StructureDatasetPDB(pdb_dict_list, truncate=None, max_length=max_length)
+         all_chain_list = [item[-1:] for item in list(pdb_dict_list[0]) if item[:9] == 'seq_chain']  # ['A','B', 'C',...]
+         if pdb_path_chains:
+             designed_chain_list = [str(item) for item in pdb_path_chains.split()]
+         else:
+             designed_chain_list = all_chain_list
+         fixed_chain_list = [letter for letter in all_chain_list if letter not in designed_chain_list]
+         chain_id_dict = {}
+         chain_id_dict[pdb_dict_list[0]['name']] = (designed_chain_list, fixed_chain_list)
+     else:
+         dataset_valid = StructureDataset(jsonl_path.name, truncate=None, max_length=max_length)
+
+     print(40 * '-')
+     checkpoint = torch.load(checkpoint_path, map_location=device)
+     print('Number of edges:', checkpoint['num_edges'])
+     noise_level_print = checkpoint['noise_level']
+     print(f'Training noise level: {noise_level_print}A')
+     model = ProteinMPNN(ca_only=ca_only, num_letters=21, node_features=hidden_dim, edge_features=hidden_dim,
+                         hidden_dim=hidden_dim, num_encoder_layers=num_layers, num_decoder_layers=num_layers,
+                         augment_eps=backbone_noise, k_neighbors=checkpoint['num_edges'])
+     model.to(device)
+     model.load_state_dict(checkpoint['model_state_dict'])
+     model.eval()
+
+     # Build paths for experiment
+     base_folder = folder_for_outputs
+     if base_folder[-1] != '/':
+         base_folder = base_folder + '/'
+     if not os.path.exists(base_folder):
+         os.makedirs(base_folder)
+
+     if not os.path.exists(base_folder + 'seqs'):
+         os.makedirs(base_folder + 'seqs')
+
+     if save_score:
+         if not os.path.exists(base_folder + 'scores'):
+             os.makedirs(base_folder + 'scores')
+
+     if score_only:
+         if not os.path.exists(base_folder + 'score_only'):
+             os.makedirs(base_folder + 'score_only')
+
+     if conditional_probs_only:
+         if not os.path.exists(base_folder + 'conditional_probs_only'):
+             os.makedirs(base_folder + 'conditional_probs_only')
+
+     if unconditional_probs_only:
+         if not os.path.exists(base_folder + 'unconditional_probs_only'):
+             os.makedirs(base_folder + 'unconditional_probs_only')
+
+     if save_probs:
+         if not os.path.exists(base_folder + 'probs'):
+             os.makedirs(base_folder + 'probs')
+
+     # Timing
+     start_time = time.time()
+     total_residues = 0
+     protein_list = []
+     total_step = 0
+     # Validation epoch
+     with torch.no_grad():
+         test_sum, test_weights = 0., 0.
+         # print('Generating sequences...')
+         # changed: collect the output file paths so they can be returned to the caller
+         results = []
+         # enumerate() pairs each item of an iterable with its index, typically inside a for loop
+         for ix, protein in enumerate(dataset_valid):
+             score_list = []
+             global_score_list = []
+             all_probs_list = []
+             all_log_probs_list = []
+             S_sample_list = []
+             # deep-copy the protein entry once per batch member
+             batch_clones = [copy.deepcopy(protein) for i in range(BATCH_COPIES)]
+             X, S, mask, lengths, chain_M, chain_encoding_all, chain_list_list, visible_list_list, masked_list_list, masked_chain_length_list_list, chain_M_pos, omit_AA_mask, residue_idx, dihedral_mask, tied_pos_list_of_lists_list, pssm_coef, pssm_bias, pssm_log_odds_all, bias_by_res_all, tied_beta = tied_featurize(
+                 batch_clones, device, chain_id_dict, fixed_positions_dict, omit_AA_dict, tied_positions_dict, pssm_dict,
+                 bias_by_res_dict, ca_only=ca_only)
+             pssm_log_odds_mask = (pssm_log_odds_all > pssm_threshold).float()  # 1.0 for true, 0.0 for false
+             name_ = batch_clones[0]['name']
+             if score_only:
+                 structure_sequence_score_file = base_folder + '/score_only/' + batch_clones[0]['name'] + '.npz'
+                 native_score_list = []
+                 global_native_score_list = []
+                 for j in range(NUM_BATCHES):
+                     randn_1 = torch.randn(chain_M.shape, device=X.device)
+                     log_probs = model(X, S, mask, chain_M * chain_M_pos, residue_idx, chain_encoding_all, randn_1)
+                     mask_for_loss = mask * chain_M * chain_M_pos
+                     scores = _scores(S, log_probs, mask_for_loss)
+                     native_score = scores.cpu().data.numpy()
+                     native_score_list.append(native_score)
+                     global_scores = _scores(S, log_probs, mask)
+                     global_native_score = global_scores.cpu().data.numpy()
+                     global_native_score_list.append(global_native_score)
+                 native_score = np.concatenate(native_score_list, 0)
+                 global_native_score = np.concatenate(global_native_score_list, 0)
+                 ns_mean = native_score.mean()
+                 ns_mean_print = np.format_float_positional(np.float32(ns_mean), unique=False, precision=4)
+                 ns_std = native_score.std()
+                 ns_std_print = np.format_float_positional(np.float32(ns_std), unique=False, precision=4)
+
+                 global_ns_mean = global_native_score.mean()
+                 global_ns_mean_print = np.format_float_positional(np.float32(global_ns_mean), unique=False, precision=4)
+                 global_ns_std = global_native_score.std()
+                 global_ns_std_print = np.format_float_positional(np.float32(global_ns_std), unique=False, precision=4)
+
+                 ns_sample_size = native_score.shape[0]
+                 np.savez(structure_sequence_score_file, score=native_score, global_score=global_native_score)
+                 print(
+                     f'Score for {name_}, mean: {ns_mean_print}, std: {ns_std_print}, sample size: {ns_sample_size}, Global Score for {name_}, mean: {global_ns_mean_print}, std: {global_ns_std_print}, sample size: {ns_sample_size}')
+                 results.append(structure_sequence_score_file)
+             elif conditional_probs_only:
+                 print(f'Calculating conditional probabilities for {name_}')
+                 conditional_probs_only_file = base_folder + '/conditional_probs_only/' + batch_clones[0]['name']
+                 log_conditional_probs_list = []
+                 for j in range(NUM_BATCHES):
+                     randn_1 = torch.randn(chain_M.shape, device=X.device)
+                     log_conditional_probs = model.conditional_probs(X, S, mask, chain_M * chain_M_pos, residue_idx,
+                                                                     chain_encoding_all, randn_1,
+                                                                     conditional_probs_only_backbone)
+                     log_conditional_probs_list.append(log_conditional_probs.cpu().numpy())
+                 concat_log_p = np.concatenate(log_conditional_probs_list, 0)  # [B, L, 21]
+                 mask_out = (chain_M * chain_M_pos * mask)[0,].cpu().numpy()
+                 np.savez(conditional_probs_only_file, log_p=concat_log_p, S=S[0,].cpu().numpy(),
+                          mask=mask[0,].cpu().numpy(), design_mask=mask_out)
+             elif unconditional_probs_only:
+                 print(f'Calculating sequence unconditional probabilities for {name_}')
+                 # changed: save with an explicit .npz extension
+                 unconditional_probs_only_file = base_folder + '/unconditional_probs_only/' + batch_clones[0]['name'] + '.npz'
+                 log_unconditional_probs_list = []
+                 for j in range(NUM_BATCHES):
+                     log_unconditional_probs = model.unconditional_probs(X, mask, residue_idx, chain_encoding_all)
+                     log_unconditional_probs_list.append(log_unconditional_probs.cpu().numpy())
+                 concat_log_p = np.concatenate(log_unconditional_probs_list, 0)  # [B, L, 21]
+                 mask_out = (chain_M * chain_M_pos * mask)[0,].cpu().numpy()
+                 np.savez(unconditional_probs_only_file, log_p=concat_log_p, S=S[0,].cpu().numpy(),
+                          mask=mask[0,].cpu().numpy(), design_mask=mask_out)
+                 results.append(unconditional_probs_only_file)
+             else:
+                 randn_1 = torch.randn(chain_M.shape, device=X.device)
+                 log_probs = model(X, S, mask, chain_M * chain_M_pos, residue_idx, chain_encoding_all, randn_1)
+                 mask_for_loss = mask * chain_M * chain_M_pos
+                 scores = _scores(S, log_probs, mask_for_loss)  # score only the redesigned part
+                 native_score = scores.cpu().data.numpy()
+                 global_scores = _scores(S, log_probs, mask)  # score the whole structure-sequence
+                 global_native_score = global_scores.cpu().data.numpy()
+                 # Generate some sequences
+                 ali_file = base_folder + '/seqs/' + batch_clones[0]['name'] + '.fa'
+                 score_file = base_folder + '/scores/' + batch_clones[0]['name'] + '.npz'
+                 probs_file = base_folder + '/probs/' + batch_clones[0]['name'] + '.npz'
+                 print(f'Generating sequences for: {name_}')
+                 t0 = time.time()
+                 with open(ali_file, 'w') as f:
+                     for temp in temperatures:
+                         for j in range(NUM_BATCHES):
+                             randn_2 = torch.randn(chain_M.shape, device=X.device)
+                             if tied_positions_dict is None:
+                                 sample_dict = model.sample(X, randn_2, S, chain_M, chain_encoding_all, residue_idx,
+                                                            mask=mask, temperature=temp, omit_AAs_np=omit_AAs_np,
+                                                            bias_AAs_np=bias_AAs_np, chain_M_pos=chain_M_pos,
+                                                            omit_AA_mask=omit_AA_mask, pssm_coef=pssm_coef,
+                                                            pssm_bias=pssm_bias, pssm_multi=pssm_multi,
+                                                            pssm_log_odds_flag=bool(pssm_log_odds_flag),
+                                                            pssm_log_odds_mask=pssm_log_odds_mask,
+                                                            pssm_bias_flag=bool(pssm_bias_flag),
+                                                            bias_by_res=bias_by_res_all)
+                                 S_sample = sample_dict["S"]
+                             else:
+                                 sample_dict = model.tied_sample(X, randn_2, S, chain_M, chain_encoding_all, residue_idx,
+                                                                 mask=mask, temperature=temp, omit_AAs_np=omit_AAs_np,
+                                                                 bias_AAs_np=bias_AAs_np, chain_M_pos=chain_M_pos,
+                                                                 omit_AA_mask=omit_AA_mask, pssm_coef=pssm_coef,
+                                                                 pssm_bias=pssm_bias, pssm_multi=pssm_multi,
+                                                                 pssm_log_odds_flag=bool(pssm_log_odds_flag),
+                                                                 pssm_log_odds_mask=pssm_log_odds_mask,
+                                                                 pssm_bias_flag=bool(pssm_bias_flag),
+                                                                 tied_pos=tied_pos_list_of_lists_list[0],
+                                                                 tied_beta=tied_beta, bias_by_res=bias_by_res_all)
+                                 # Compute scores
+                                 S_sample = sample_dict["S"]
+                             log_probs = model(X, S_sample, mask, chain_M * chain_M_pos, residue_idx, chain_encoding_all,
+                                               randn_2, use_input_decoding_order=True,
+                                               decoding_order=sample_dict["decoding_order"])
+                             mask_for_loss = mask * chain_M * chain_M_pos
+                             scores = _scores(S_sample, log_probs, mask_for_loss)
+                             scores = scores.cpu().data.numpy()
+
+                             global_scores = _scores(S_sample, log_probs, mask)  # score the whole structure-sequence
+                             global_scores = global_scores.cpu().data.numpy()
+
+                             all_probs_list.append(sample_dict["probs"].cpu().data.numpy())
+                             all_log_probs_list.append(log_probs.cpu().data.numpy())
+                             S_sample_list.append(S_sample.cpu().data.numpy())
+                             for b_ix in range(BATCH_COPIES):
+                                 masked_chain_length_list = masked_chain_length_list_list[b_ix]
+                                 masked_list = masked_list_list[b_ix]
+                                 seq_recovery_rate = torch.sum(torch.sum(
+                                     torch.nn.functional.one_hot(S[b_ix], 21) * torch.nn.functional.one_hot(
+                                         S_sample[b_ix], 21), axis=-1) * mask_for_loss[b_ix]) / torch.sum(
+                                     mask_for_loss[b_ix])
+                                 seq = _S_to_seq(S_sample[b_ix], chain_M[b_ix])
+                                 score = scores[b_ix]
+                                 score_list.append(score)
+                                 global_score = global_scores[b_ix]
+                                 global_score_list.append(global_score)
+                                 native_seq = _S_to_seq(S[b_ix], chain_M[b_ix])
+                                 if b_ix == 0 and j == 0 and temp == temperatures[0]:
+                                     start = 0
+                                     end = 0
+                                     list_of_AAs = []
+                                     for mask_l in masked_chain_length_list:
+                                         end += mask_l
+                                         list_of_AAs.append(native_seq[start:end])
+                                         start = end
+                                     native_seq = "".join(list(np.array(list_of_AAs)[np.argsort(masked_list)]))
+                                     l0 = 0
+                                     for mc_length in list(np.array(masked_chain_length_list)[np.argsort(masked_list)])[:-1]:
+                                         l0 += mc_length
+                                         native_seq = native_seq[:l0] + '/' + native_seq[l0:]
+                                         l0 += 1
+                                     sorted_masked_chain_letters = np.argsort(masked_list_list[0])
+                                     print_masked_chains = [masked_list_list[0][i] for i in sorted_masked_chain_letters]
+                                     sorted_visible_chain_letters = np.argsort(visible_list_list[0])
+                                     print_visible_chains = [visible_list_list[0][i] for i in sorted_visible_chain_letters]
+                                     native_score_print = np.format_float_positional(np.float32(native_score.mean()),
+                                                                                     unique=False, precision=4)
+                                     global_native_score_print = np.format_float_positional(
+                                         np.float32(global_native_score.mean()), unique=False, precision=4)
+                                     script_dir = os.path.dirname(os.path.realpath(__file__))
+                                     try:
+                                         commit_str = subprocess.check_output(
+                                             f'git --git-dir {script_dir}/.git rev-parse HEAD',
+                                             shell=True).decode().strip()
+                                     except subprocess.CalledProcessError:
+                                         commit_str = 'unknown'
+                                     if ca_only:
+                                         print_model_name = 'CA_model_name'
+                                     else:
+                                         print_model_name = 'model_name'
+                                     f.write(
+                                         '>{}, score={}, global_score={}, fixed_chains={}, designed_chains={}, {}={}, git_hash={}, seed={}\n{}\n'.format(
+                                             name_, native_score_print, global_native_score_print, print_visible_chains,
+                                             print_masked_chains, print_model_name, model_name, commit_str, seed,
+                                             native_seq))  # write the native sequence
+                                 start = 0
+                                 end = 0
+                                 list_of_AAs = []
+                                 for mask_l in masked_chain_length_list:
+                                     end += mask_l
+                                     list_of_AAs.append(seq[start:end])
+                                     start = end
+
+                                 seq = "".join(list(np.array(list_of_AAs)[np.argsort(masked_list)]))
+                                 l0 = 0
+                                 for mc_length in list(np.array(masked_chain_length_list)[np.argsort(masked_list)])[:-1]:
+                                     l0 += mc_length
+                                     seq = seq[:l0] + '/' + seq[l0:]
+                                     l0 += 1
+                                 score_print = np.format_float_positional(np.float32(score), unique=False, precision=4)
+                                 global_score_print = np.format_float_positional(np.float32(global_score), unique=False,
+                                                                                 precision=4)
+                                 seq_rec_print = np.format_float_positional(
+                                     np.float32(seq_recovery_rate.detach().cpu().numpy()), unique=False, precision=4)
+                                 sample_number = j * BATCH_COPIES + b_ix + 1
+                                 f.write(
+                                     '>T={}, sample={}, score={}, global_score={}, seq_recovery={}\n{}\n'.format(
+                                         temp, sample_number, score_print, global_score_print, seq_rec_print,
+                                         seq))  # write the generated sequence
+                 results.append(ali_file)
+                 if save_score:
+                     np.savez(score_file, score=np.array(score_list, np.float32),
+                              global_score=np.array(global_score_list, np.float32))
+                 if save_probs:
+                     all_probs_concat = np.concatenate(all_probs_list)
+                     all_log_probs_concat = np.concatenate(all_log_probs_list)
+                     S_sample_concat = np.concatenate(S_sample_list)
+                     np.savez(probs_file, probs=np.array(all_probs_concat, np.float32),
+                              log_probs=np.array(all_log_probs_concat, np.float32),
+                              S=np.array(S_sample_concat, np.int32), mask=mask_for_loss.cpu().data.numpy(),
+                              chain_order=chain_list_list)
+                 t1 = time.time()
+                 dt = round(float(t1 - t0), 4)
+                 num_seqs = len(temperatures) * NUM_BATCHES * BATCH_COPIES
+                 total_length = X.shape[1]
+                 print(f'{num_seqs} sequences of length {total_length} generated in {dt} seconds')
+
+     return results
+
+
+ # if __name__ == "__main__":
+ #     argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ #
+ #     argparser.add_argument("--ca_only", action="store_true", default=False,
+ #                            help="Parse CA-only structures and use CA-only models (default: false)")
+ #     argparser.add_argument("--path_to_model_weights", type=str, default="", help="Path to model weights folder;")
+ #     argparser.add_argument("--model_name", type=str, default="v_48_020",
+ #                            help="ProteinMPNN model name: v_48_002, v_48_010, v_48_020, v_48_030; v_48_010=version with 48 edges 0.10A noise")
+ #
+ #     argparser.add_argument("--seed", type=int, default=0, help="If set to 0 then a random seed will be picked;")
+ #
+ #     argparser.add_argument("--save_score", type=int, default=0,
+ #                            help="0 for False, 1 for True; save score=-log_prob to npy files")
+ #     argparser.add_argument("--save_probs", type=int, default=0,
+ #                            help="0 for False, 1 for True; save MPNN predicted probabilities per position")
+ #
+ #     argparser.add_argument("--score_only", type=int, default=0,
+ #                            help="0 for False, 1 for True; score input backbone-sequence pairs")
+ #
+ #     argparser.add_argument("--conditional_probs_only", type=int, default=0,
+ #                            help="0 for False, 1 for True; output conditional probabilities p(s_i given the rest of the sequence and backbone)")
+ #     argparser.add_argument("--conditional_probs_only_backbone", type=int, default=0,
+ #                            help="0 for False, 1 for True; if true output conditional probabilities p(s_i given backbone)")
+ #     argparser.add_argument("--unconditional_probs_only", type=int, default=0,
+ #                            help="0 for False, 1 for True; output unconditional probabilities p(s_i given backbone) in one forward pass")
+ #
+ #     argparser.add_argument("--backbone_noise", type=float, default=0.00,
+ #                            help="Standard deviation of Gaussian noise to add to backbone atoms")
+ #     argparser.add_argument("--num_seq_per_target", type=int, default=1,
+ #                            help="Number of sequences to generate per target")
+ #     argparser.add_argument("--batch_size", type=int, default=1,
+ #                            help="Batch size; can set higher for Titan or Quadro GPUs, reduce this if running out of GPU memory")
+ #     argparser.add_argument("--max_length", type=int, default=200000, help="Max sequence length")
+ #     argparser.add_argument("--sampling_temp", type=str, default="0.1",
+ #                            help="A string of temperatures, e.g. '0.2 0.25 0.5'. Sampling temperature for amino acids. Suggested values 0.1, 0.15, 0.2, 0.25, 0.3. Higher values will lead to more diversity.")
+ #
+ #     argparser.add_argument("--out_folder", type=str, help="Path to a folder to output sequences, e.g. /home/out/")
+ #     argparser.add_argument("--pdb_path", type=str, default='', help="Path to a single PDB to be designed")
+ #     argparser.add_argument("--pdb_path_chains", type=str, default='',
+ #                            help="Define which chains need to be designed for a single PDB")
+ #     argparser.add_argument("--jsonl_path", type=str, help="Path to a folder with parsed pdb into jsonl")
+ #     argparser.add_argument("--chain_id_jsonl", type=str, default='',
+ #                            help="Path to a dictionary specifying which chains need to be designed and which ones are fixed; if not specified all chains will be designed.")
+ #     argparser.add_argument("--fixed_positions_jsonl", type=str, default='',
+ #                            help="Path to a dictionary with fixed positions")
+ #     argparser.add_argument("--omit_AAs", type=list, default='X',
+ #                            help="Specify which amino acids should be omitted in the generated sequence, e.g. 'AC' would omit alanine and cysteine.")
+ #     argparser.add_argument("--bias_AA_jsonl", type=str, default='',
+ #                            help="Path to a dictionary which specifies AA composition bias if needed, e.g. {A: -1.1, F: 0.7} would make A less likely and F more likely.")
+ #
+ #     argparser.add_argument("--bias_by_res_jsonl", default='', help="Path to dictionary with per position bias.")
+ #     argparser.add_argument("--omit_AA_jsonl", type=str, default='',
+ #                            help="Path to a dictionary which specifies which amino acids need to be omitted from design at specific chain indices")
+ #     argparser.add_argument("--pssm_jsonl", type=str, default='', help="Path to a dictionary with pssm")
+ #     argparser.add_argument("--pssm_multi", type=float, default=0.0,
+ #                            help="A value between [0.0, 1.0], 0.0 means do not use pssm, 1.0 ignore MPNN predictions")
+ #     argparser.add_argument("--pssm_threshold", type=float, default=0.0,
+ #                            help="A value between -inf and +inf to restrict per position AAs")
+ #     argparser.add_argument("--pssm_log_odds_flag", type=int, default=0, help="0 for False, 1 for True")
+ #     argparser.add_argument("--pssm_bias_flag", type=int, default=0, help="0 for False, 1 for True")
+ #
+ #     argparser.add_argument("--tied_positions_jsonl", type=str, default='',
+ #                            help="Path to a dictionary with tied positions")
+ #
+ #     args = argparser.parse_args()
+ #     main(args)
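
A minimal sketch of how the refactored entry point might be invoked from Python, assuming the model weights sit next to protein_mpnn_run.py; my_protein.pdb and ./outputs/ are placeholder paths, and an open file handle is passed for pdb_path because p_m_r reads its .name attribute:

    from protein_mpnn_run import p_m_r

    # hypothetical inputs; the falsy ''/None arguments take the "NOT loaded" branches above
    with open('my_protein.pdb') as pdb_file:
        results = p_m_r(
            ca_only=False, path_to_model_weights='', model_name='v_48_020', seed=37,
            save_score=0, save_probs=0, score_only=0, conditional_probs_only=0,
            conditional_probs_only_backbone=0, unconditional_probs_only=0,
            backbone_noise=0.0, num_seq_per_target=2, batch_size=1, max_length=200000,
            sampling_temp='0.1', out_folder='./outputs/', pdb_path=pdb_file,
            pdb_path_chains='A', jsonl_path=None, chain_id_jsonl=None,
            fixed_positions_jsonl=None, omit_AAs='X', bias_AA_jsonl=None,
            bias_by_res_jsonl='', omit_AA_jsonl='', pssm_jsonl='', pssm_multi=0.0,
            pssm_threshold=0.0, pssm_log_odds_flag=0, pssm_bias_flag=0,
            tied_positions_jsonl=None)
    # results lists the generated FASTA/.npz file paths under ./outputs/
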
ProteinMPNN-main/protein_mpnn_utils.py ADDED
@@ -0,0 +1,1363 @@
+ from __future__ import print_function
+ import json, time, os, sys, glob
+ import shutil
+ import numpy as np
+ import torch
+ from torch import optim
+ from torch.utils.data import DataLoader
+ from torch.utils.data.dataset import random_split, Subset
+
+ import copy
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import random
+ import itertools
+
+ # A number of functions/classes are adapted from: https://github.com/jingraham/neurips19-graph-protein-design
+
+ def _scores(S, log_probs, mask):
+     """ Negative log probabilities """
+     criterion = torch.nn.NLLLoss(reduction='none')
+     loss = criterion(
+         log_probs.contiguous().view(-1, log_probs.size(-1)),
+         S.contiguous().view(-1)
+     ).view(S.size())
+     scores = torch.sum(loss * mask, dim=-1) / torch.sum(mask, dim=-1)
+     return scores
+
+ def _S_to_seq(S, mask):
+     alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
+     seq = ''.join([alphabet[c] for c, m in zip(S.tolist(), mask.tolist()) if m > 0])
+     return seq
+
+ def parse_PDB_biounits(x, atoms=['N','CA','C'], chain=None):
+     '''
+     input:  x = PDB filename
+             atoms = atoms to extract (optional)
+     output: (length, atoms, coords=(x,y,z)), sequence
+     '''
+
+     alpha_1 = list("ARNDCQEGHILKMFPSTWYV-")
+     states = len(alpha_1)
+     alpha_3 = ['ALA','ARG','ASN','ASP','CYS','GLN','GLU','GLY','HIS','ILE',
+                'LEU','LYS','MET','PHE','PRO','SER','THR','TRP','TYR','VAL','GAP']
+
+     aa_1_N = {a: n for n, a in enumerate(alpha_1)}
+     aa_3_N = {a: n for n, a in enumerate(alpha_3)}
+     aa_N_1 = {n: a for n, a in enumerate(alpha_1)}
+     aa_1_3 = {a: b for a, b in zip(alpha_1, alpha_3)}
+     aa_3_1 = {b: a for a, b in zip(alpha_1, alpha_3)}
+
+     def AA_to_N(x):
+         # ["ARND"] -> [[0,1,2,3]]
+         x = np.array(x)
+         if x.ndim == 0: x = x[None]
+         return [[aa_1_N.get(a, states-1) for a in y] for y in x]
+
+     def N_to_AA(x):
+         # [[0,1,2,3]] -> ["ARND"]
+         x = np.array(x)
+         if x.ndim == 1: x = x[None]
+         return ["".join([aa_N_1.get(a, "-") for a in y]) for y in x]
+
+     xyz, seq, min_resn, max_resn = {}, {}, 1e6, -1e6
+     for line in open(x, "rb"):
+         line = line.decode("utf-8", "ignore").rstrip()
+
+         if line[:6] == "HETATM" and line[17:17+3] == "MSE":
+             line = line.replace("HETATM", "ATOM  ")
+             line = line.replace("MSE", "MET")
+
+         if line[:4] == "ATOM":
+             ch = line[21:22]
+             if ch == chain or chain is None:
+                 atom = line[12:12+4].strip()
+                 resi = line[17:17+3]
+                 resn = line[22:22+5].strip()
+                 x, y, z = [float(line[i:(i+8)]) for i in [30, 38, 46]]
+
+                 if resn[-1].isalpha():
+                     resa, resn = resn[-1], int(resn[:-1]) - 1
+                 else:
+                     resa, resn = "", int(resn) - 1
+                 # resn = int(resn)
+                 if resn < min_resn:
+                     min_resn = resn
+                 if resn > max_resn:
+                     max_resn = resn
+                 if resn not in xyz:
+                     xyz[resn] = {}
+                 if resa not in xyz[resn]:
+                     xyz[resn][resa] = {}
+                 if resn not in seq:
+                     seq[resn] = {}
+                 if resa not in seq[resn]:
+                     seq[resn][resa] = resi
+
+                 if atom not in xyz[resn][resa]:
+                     xyz[resn][resa][atom] = np.array([x, y, z])
+
+     # convert to numpy arrays, fill in missing values
+     seq_, xyz_ = [], []
+     try:
+         for resn in range(min_resn, max_resn + 1):
+             if resn in seq:
+                 for k in sorted(seq[resn]): seq_.append(aa_3_N.get(seq[resn][k], 20))
+             else: seq_.append(20)
+             if resn in xyz:
+                 for k in sorted(xyz[resn]):
+                     for atom in atoms:
+                         if atom in xyz[resn][k]: xyz_.append(xyz[resn][k][atom])
+                         else: xyz_.append(np.full(3, np.nan))
+             else:
+                 for atom in atoms: xyz_.append(np.full(3, np.nan))
+         return np.array(xyz_).reshape(-1, len(atoms), 3), N_to_AA(np.array(seq_))
+     except TypeError:
+         return 'no_chain', 'no_chain'
+
+ def parse_PDB(path_to_pdb, input_chain_list=None, ca_only=False):
+     c = 0
+     pdb_dict_list = []
+     init_alphabet = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
+     extra_alphabet = [str(item) for item in list(np.arange(300))]
+     chain_alphabet = init_alphabet + extra_alphabet
+
+     if input_chain_list:
+         chain_alphabet = input_chain_list
+
+     biounit_names = [path_to_pdb]
+     for biounit in biounit_names:
+         my_dict = {}
+         s = 0
+         concat_seq = ''
+         concat_N = []
+         concat_CA = []
+         concat_C = []
+         concat_O = []
+         concat_mask = []
+         coords_dict = {}
+         for letter in chain_alphabet:
+             if ca_only:
+                 sidechain_atoms = ['CA']
+             else:
+                 sidechain_atoms = ['N', 'CA', 'C', 'O']
+             xyz, seq = parse_PDB_biounits(biounit, atoms=sidechain_atoms, chain=letter)
+             if type(xyz) != str:
+                 concat_seq += seq[0]
+                 my_dict['seq_chain_' + letter] = seq[0]
+                 coords_dict_chain = {}
+                 if ca_only:
+                     coords_dict_chain['CA_chain_' + letter] = xyz.tolist()
+                 else:
+                     coords_dict_chain['N_chain_' + letter] = xyz[:, 0, :].tolist()
+                     coords_dict_chain['CA_chain_' + letter] = xyz[:, 1, :].tolist()
+                     coords_dict_chain['C_chain_' + letter] = xyz[:, 2, :].tolist()
+                     coords_dict_chain['O_chain_' + letter] = xyz[:, 3, :].tolist()
+                 my_dict['coords_chain_' + letter] = coords_dict_chain
+                 s += 1
+         # changed: take the four characters after the last path separator as the entry name,
+         # handling both '/' and '\\' so this works on POSIX as well as Windows paths
+         fi = max(biounit.rfind("/"), biounit.rfind("\\"))
+         my_dict['name'] = biounit[(fi + 1):(fi + 5)]
+         my_dict['num_of_chains'] = s
+         my_dict['seq'] = concat_seq
+         if s <= len(chain_alphabet):
+             pdb_dict_list.append(my_dict)
+             c += 1
+     return pdb_dict_list
+
+
+ def tied_featurize(batch, device, chain_dict, fixed_position_dict=None, omit_AA_dict=None, tied_positions_dict=None, pssm_dict=None, bias_by_res_dict=None, ca_only=False):
+     """ Pack and pad batch into torch tensors """
+     alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
+     B = len(batch)
+     lengths = np.array([len(b['seq']) for b in batch], dtype=np.int32)  # sum of chain seq lengths
+     L_max = max([len(b['seq']) for b in batch])
+     if ca_only:
+         X = np.zeros([B, L_max, 1, 3])
+     else:
+         X = np.zeros([B, L_max, 4, 3])
+     residue_idx = -100 * np.ones([B, L_max], dtype=np.int32)
+     chain_M = np.zeros([B, L_max], dtype=np.int32)  # 1.0 for the bits that need to be predicted
+     pssm_coef_all = np.zeros([B, L_max], dtype=np.float32)  # per-position PSSM mixing coefficient
+     pssm_bias_all = np.zeros([B, L_max, 21], dtype=np.float32)  # per-position PSSM bias over the 21 letters
+     pssm_log_odds_all = 10000.0 * np.ones([B, L_max, 21], dtype=np.float32)  # per-position PSSM log-odds
+     chain_M_pos = np.zeros([B, L_max], dtype=np.int32)  # 1.0 for the bits that need to be predicted
+     bias_by_res_all = np.zeros([B, L_max, 21], dtype=np.float32)
+     chain_encoding_all = np.zeros([B, L_max], dtype=np.int32)  # integer chain id per position
+     S = np.zeros([B, L_max], dtype=np.int32)
+     omit_AA_mask = np.zeros([B, L_max, len(alphabet)], dtype=np.int32)
+     # Build the batch
+     letter_list_list = []
+     visible_list_list = []
+     masked_list_list = []
+     masked_chain_length_list_list = []
+     tied_pos_list_of_lists_list = []
+     # shuffle all chains before the main loop
+     for i, b in enumerate(batch):
+         if chain_dict is not None:
+             masked_chains, visible_chains = chain_dict[b['name']]  # masked_chains is a list of chain letters to predict, e.g. ['A', 'D', 'F']
+         else:
+             masked_chains = [item[-1:] for item in list(b) if item[:10] == 'seq_chain_']
+             visible_chains = []
+         num_chains = b['num_of_chains']
+         all_chains = masked_chains + visible_chains
+         # random.shuffle(all_chains)
+     for i, b in enumerate(batch):
+         mask_dict = {}
+         a = 0
+         x_chain_list = []
+         chain_mask_list = []
+         chain_seq_list = []
+         chain_encoding_list = []
+         c = 1
+         letter_list = []
+         global_idx_start_list = [0]
+         visible_list = []
+         masked_list = []
+         masked_chain_length_list = []
+         fixed_position_mask_list = []
+         omit_AA_mask_list = []
+         pssm_coef_list = []
+         pssm_bias_list = []
+         pssm_log_odds_list = []
+         bias_by_res_list = []
+         l0 = 0
+         l1 = 0
+         for step, letter in enumerate(all_chains):
+             if letter in visible_chains:
+                 letter_list.append(letter)
+                 visible_list.append(letter)
+                 chain_seq = b[f'seq_chain_{letter}']
+                 chain_seq = ''.join([a if a != '-' else 'X' for a in chain_seq])
+                 chain_length = len(chain_seq)
+                 global_idx_start_list.append(global_idx_start_list[-1] + chain_length)
+                 chain_coords = b[f'coords_chain_{letter}']  # this is a dictionary
+                 chain_mask = np.zeros(chain_length)  # 0.0 for visible chains
+                 if ca_only:
+                     x_chain = np.array(chain_coords[f'CA_chain_{letter}'])  # [chain_length,1,3] #CA_diff
+                     if len(x_chain.shape) == 2:
+                         x_chain = x_chain[:, None, :]
+                 else:
+                     x_chain = np.stack([chain_coords[c] for c in [f'N_chain_{letter}', f'CA_chain_{letter}', f'C_chain_{letter}', f'O_chain_{letter}']], 1)  # [chain_length,4,3]
+                 x_chain_list.append(x_chain)
+                 chain_mask_list.append(chain_mask)
+                 chain_seq_list.append(chain_seq)
+                 chain_encoding_list.append(c * np.ones(np.array(chain_mask).shape[0]))
+                 l1 += chain_length
+                 residue_idx[i, l0:l1] = 100 * (c - 1) + np.arange(l0, l1)
+                 l0 += chain_length
+                 c += 1
+                 fixed_position_mask = np.ones(chain_length)
+                 fixed_position_mask_list.append(fixed_position_mask)
+                 omit_AA_mask_temp = np.zeros([chain_length, len(alphabet)], np.int32)
+                 omit_AA_mask_list.append(omit_AA_mask_temp)
+                 pssm_coef = np.zeros(chain_length)
+                 pssm_bias = np.zeros([chain_length, 21])
+                 pssm_log_odds = 10000.0 * np.ones([chain_length, 21])
+                 pssm_coef_list.append(pssm_coef)
+                 pssm_bias_list.append(pssm_bias)
+                 pssm_log_odds_list.append(pssm_log_odds)
+                 bias_by_res_list.append(np.zeros([chain_length, 21]))
+             if letter in masked_chains:
+                 masked_list.append(letter)
+                 letter_list.append(letter)
+                 chain_seq = b[f'seq_chain_{letter}']
+                 chain_seq = ''.join([a if a != '-' else 'X' for a in chain_seq])
+                 chain_length = len(chain_seq)
+                 global_idx_start_list.append(global_idx_start_list[-1] + chain_length)
+                 masked_chain_length_list.append(chain_length)
+                 chain_coords = b[f'coords_chain_{letter}']  # this is a dictionary
+                 chain_mask = np.ones(chain_length)  # 1.0 for masked
+                 if ca_only:
+                     x_chain = np.array(chain_coords[f'CA_chain_{letter}'])  # [chain_length,1,3] #CA_diff
+                     if len(x_chain.shape) == 2:
+                         x_chain = x_chain[:, None, :]
+                 else:
+                     x_chain = np.stack([chain_coords[c] for c in [f'N_chain_{letter}', f'CA_chain_{letter}', f'C_chain_{letter}', f'O_chain_{letter}']], 1)  # [chain_length,4,3]
+                 x_chain_list.append(x_chain)
+                 chain_mask_list.append(chain_mask)
+                 chain_seq_list.append(chain_seq)
+                 chain_encoding_list.append(c * np.ones(np.array(chain_mask).shape[0]))
+                 l1 += chain_length
+                 residue_idx[i, l0:l1] = 100 * (c - 1) + np.arange(l0, l1)
+                 l0 += chain_length
+                 c += 1
+                 fixed_position_mask = np.ones(chain_length)
+                 if fixed_position_dict is not None:
+                     fixed_pos_list = fixed_position_dict[b['name']][letter]
+                     if fixed_pos_list:
+                         fixed_position_mask[np.array(fixed_pos_list) - 1] = 0.0
+                 fixed_position_mask_list.append(fixed_position_mask)
+                 omit_AA_mask_temp = np.zeros([chain_length, len(alphabet)], np.int32)
+                 if omit_AA_dict is not None:
+                     for item in omit_AA_dict[b['name']][letter]:
+                         idx_AA = np.array(item[0]) - 1
+                         AA_idx = np.array([np.argwhere(np.array(list(alphabet)) == AA)[0][0] for AA in item[1]]).repeat(idx_AA.shape[0])
+                         idx_ = np.array([[a, b] for a in idx_AA for b in AA_idx])
+                         omit_AA_mask_temp[idx_[:, 0], idx_[:, 1]] = 1
+                 omit_AA_mask_list.append(omit_AA_mask_temp)
+                 pssm_coef = np.zeros(chain_length)
+                 pssm_bias = np.zeros([chain_length, 21])
+                 pssm_log_odds = 10000.0 * np.ones([chain_length, 21])
+                 if pssm_dict:
+                     if pssm_dict[b['name']][letter]:
+                         pssm_coef = pssm_dict[b['name']][letter]['pssm_coef']
+                         pssm_bias = pssm_dict[b['name']][letter]['pssm_bias']
+                         pssm_log_odds = pssm_dict[b['name']][letter]['pssm_log_odds']
+                 pssm_coef_list.append(pssm_coef)
+                 pssm_bias_list.append(pssm_bias)
+                 pssm_log_odds_list.append(pssm_log_odds)
+                 if bias_by_res_dict:
+                     bias_by_res_list.append(bias_by_res_dict[b['name']][letter])
+                 else:
+                     bias_by_res_list.append(np.zeros([chain_length, 21]))
+
+         letter_list_np = np.array(letter_list)
+         tied_pos_list_of_lists = []
+         tied_beta = np.ones(L_max)
+         if tied_positions_dict is not None:
+             tied_pos_list = tied_positions_dict[b['name']]
+             if tied_pos_list:
+                 set_chains_tied = set(list(itertools.chain(*[list(item) for item in tied_pos_list])))
+                 for tied_item in tied_pos_list:
+                     one_list = []
+                     for k, v in tied_item.items():
+                         start_idx = global_idx_start_list[np.argwhere(letter_list_np == k)[0][0]]
+                         if isinstance(v[0], list):
+                             for v_count in range(len(v[0])):
+                                 one_list.append(start_idx + v[0][v_count] - 1)  # make 0 the first index
+                                 tied_beta[start_idx + v[0][v_count] - 1] = v[1][v_count]
+                         else:
+                             for v_ in v:
+                                 one_list.append(start_idx + v_ - 1)  # make 0 the first index
+                     tied_pos_list_of_lists.append(one_list)
+         tied_pos_list_of_lists_list.append(tied_pos_list_of_lists)
+
+         x = np.concatenate(x_chain_list, 0)  # [L, 4, 3]
+         all_sequence = "".join(chain_seq_list)
+         m = np.concatenate(chain_mask_list, 0)  # [L,], 1.0 for places that need to be predicted
+         chain_encoding = np.concatenate(chain_encoding_list, 0)
+         m_pos = np.concatenate(fixed_position_mask_list, 0)  # [L,], 1.0 for places that need to be predicted
+
+         pssm_coef_ = np.concatenate(pssm_coef_list, 0)  # [L,]
+         pssm_bias_ = np.concatenate(pssm_bias_list, 0)  # [L, 21]
+         pssm_log_odds_ = np.concatenate(pssm_log_odds_list, 0)  # [L, 21]
+
+         bias_by_res_ = np.concatenate(bias_by_res_list, 0)  # [L, 21], 0.0 for places where AA frequencies don't need to be tweaked
+
+         l = len(all_sequence)
+         x_pad = np.pad(x, [[0, L_max - l], [0, 0], [0, 0]], 'constant', constant_values=(np.nan,))
+         X[i, :, :, :] = x_pad
+
+         m_pad = np.pad(m, [[0, L_max - l]], 'constant', constant_values=(0.0,))
+         m_pos_pad = np.pad(m_pos, [[0, L_max - l]], 'constant', constant_values=(0.0,))
+         omit_AA_mask_pad = np.pad(np.concatenate(omit_AA_mask_list, 0), [[0, L_max - l]], 'constant', constant_values=(0.0,))
+         chain_M[i, :] = m_pad
+         chain_M_pos[i, :] = m_pos_pad
+         omit_AA_mask[i,] = omit_AA_mask_pad
+
+         chain_encoding_pad = np.pad(chain_encoding, [[0, L_max - l]], 'constant', constant_values=(0.0,))
+         chain_encoding_all[i, :] = chain_encoding_pad
+
+         pssm_coef_pad = np.pad(pssm_coef_, [[0, L_max - l]], 'constant', constant_values=(0.0,))
+         pssm_bias_pad = np.pad(pssm_bias_, [[0, L_max - l], [0, 0]], 'constant', constant_values=(0.0,))
+         pssm_log_odds_pad = np.pad(pssm_log_odds_, [[0, L_max - l], [0, 0]], 'constant', constant_values=(0.0,))
+
+         pssm_coef_all[i, :] = pssm_coef_pad
+         pssm_bias_all[i, :] = pssm_bias_pad
+         pssm_log_odds_all[i, :] = pssm_log_odds_pad
+
+         bias_by_res_pad = np.pad(bias_by_res_, [[0, L_max - l], [0, 0]], 'constant', constant_values=(0.0,))
+         bias_by_res_all[i, :] = bias_by_res_pad
+
+         # Convert to labels
+         indices = np.asarray([alphabet.index(a) for a in all_sequence], dtype=np.int32)
+         S[i, :l] = indices
+         letter_list_list.append(letter_list)
+         visible_list_list.append(visible_list)
+         masked_list_list.append(masked_list)
+         masked_chain_length_list_list.append(masked_chain_length_list)
+
+     isnan = np.isnan(X)
+     mask = np.isfinite(np.sum(X, (2, 3))).astype(np.float32)
+     X[isnan] = 0.
+
+     # Conversion
+     pssm_coef_all = torch.from_numpy(pssm_coef_all).to(dtype=torch.float32, device=device)
+     pssm_bias_all = torch.from_numpy(pssm_bias_all).to(dtype=torch.float32, device=device)
+     pssm_log_odds_all = torch.from_numpy(pssm_log_odds_all).to(dtype=torch.float32, device=device)
+
+     tied_beta = torch.from_numpy(tied_beta).to(dtype=torch.float32, device=device)
+
+     jumps = ((residue_idx[:, 1:] - residue_idx[:, :-1]) == 1).astype(np.float32)
+     bias_by_res_all = torch.from_numpy(bias_by_res_all).to(dtype=torch.float32, device=device)
+     phi_mask = np.pad(jumps, [[0, 0], [1, 0]])
+     psi_mask = np.pad(jumps, [[0, 0], [0, 1]])
+     omega_mask = np.pad(jumps, [[0, 0], [0, 1]])
+     dihedral_mask = np.concatenate([phi_mask[:, :, None], psi_mask[:, :, None], omega_mask[:, :, None]], -1)  # [B, L, 3]
+     dihedral_mask = torch.from_numpy(dihedral_mask).to(dtype=torch.float32, device=device)
+     residue_idx = torch.from_numpy(residue_idx).to(dtype=torch.long, device=device)
+     S = torch.from_numpy(S).to(dtype=torch.long, device=device)
+     X = torch.from_numpy(X).to(dtype=torch.float32, device=device)
+     mask = torch.from_numpy(mask).to(dtype=torch.float32, device=device)
+     chain_M = torch.from_numpy(chain_M).to(dtype=torch.float32, device=device)
+     chain_M_pos = torch.from_numpy(chain_M_pos).to(dtype=torch.float32, device=device)
+     omit_AA_mask = torch.from_numpy(omit_AA_mask).to(dtype=torch.float32, device=device)
+     chain_encoding_all = torch.from_numpy(chain_encoding_all).to(dtype=torch.long, device=device)
+     if ca_only:
+         X_out = X[:, :, 0]
+     else:
+         X_out = X
+     return X_out, S, mask, lengths, chain_M, chain_encoding_all, letter_list_list, visible_list_list, masked_list_list, masked_chain_length_list_list, chain_M_pos, omit_AA_mask, residue_idx, dihedral_mask, tied_pos_list_of_lists_list, pssm_coef_all, pssm_bias_all, pssm_log_odds_all, bias_by_res_all, tied_beta
+
+
+ def loss_nll(S, log_probs, mask):
+     """ Negative log probabilities """
+     criterion = torch.nn.NLLLoss(reduction='none')
+     loss = criterion(
+         log_probs.contiguous().view(-1, log_probs.size(-1)), S.contiguous().view(-1)
+     ).view(S.size())
+     loss_av = torch.sum(loss * mask) / torch.sum(mask)
+     return loss, loss_av
+
+
+ def loss_smoothed(S, log_probs, mask, weight=0.1):
+     """ Negative log probabilities with label smoothing """
+     S_onehot = torch.nn.functional.one_hot(S, 21).float()
+
+     # Label smoothing
+     S_onehot = S_onehot + weight / float(S_onehot.size(-1))
+     S_onehot = S_onehot / S_onehot.sum(-1, keepdim=True)
+
+     loss = -(S_onehot * log_probs).sum(-1)
+     loss_av = torch.sum(loss * mask) / torch.sum(mask)
+     return loss, loss_av
+
+ class StructureDataset():
+     def __init__(self, jsonl_file, verbose=True, truncate=None, max_length=100,
+                  alphabet='ACDEFGHIKLMNPQRSTVWYX-'):
+         alphabet_set = set([a for a in alphabet])
+         discard_count = {
+             'bad_chars': 0,
+             'too_long': 0,
+             'bad_seq_length': 0
+         }
+
+         with open(jsonl_file) as f:
+             self.data = []
+
+             lines = f.readlines()
+             start = time.time()
+             for i, line in enumerate(lines):
+                 entry = json.loads(line)
+                 seq = entry['seq']
+                 name = entry['name']
+
+                 # Convert raw coords to np arrays
+                 # for key, val in entry['coords'].items():
+                 #     entry['coords'][key] = np.asarray(val)
+
+                 # Check if in alphabet
+                 bad_chars = set([s for s in seq]).difference(alphabet_set)
+                 if len(bad_chars) == 0:
+                     if len(entry['seq']) <= max_length:
+                         self.data.append(entry)
+                     else:
+                         discard_count['too_long'] += 1
+                 else:
+                     print(name, bad_chars, entry['seq'])
+                     discard_count['bad_chars'] += 1
+
+                 # Truncate early
+                 if truncate is not None and len(self.data) == truncate:
+                     return
+
+                 if verbose and (i + 1) % 1000 == 0:
+                     elapsed = time.time() - start
+                     print('{} entries ({} loaded) in {:.1f} s'.format(len(self.data), i + 1, elapsed))
+
+             print('discarded', discard_count)
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         return self.data[idx]
+
+
+ class StructureDatasetPDB():
+     def __init__(self, pdb_dict_list, verbose=True, truncate=None, max_length=100,
+                  alphabet='ACDEFGHIKLMNPQRSTVWYX-'):
+         alphabet_set = set([a for a in alphabet])
+         discard_count = {
+             'bad_chars': 0,
+             'too_long': 0,
+             'bad_seq_length': 0
+         }
+
+         self.data = []
+
+         start = time.time()
+         for i, entry in enumerate(pdb_dict_list):
+             seq = entry['seq']
+             name = entry['name']
+
+             bad_chars = set([s for s in seq]).difference(alphabet_set)
+             if len(bad_chars) == 0:
+                 if len(entry['seq']) <= max_length:
+                     self.data.append(entry)
+                 else:
+                     discard_count['too_long'] += 1
+             else:
+                 discard_count['bad_chars'] += 1
+
+             # Truncate early
+             if truncate is not None and len(self.data) == truncate:
+                 return
+
+             if verbose and (i + 1) % 1000 == 0:
+                 elapsed = time.time() - start
+                 print('{} entries ({} loaded) in {:.1f} s'.format(len(self.data), i + 1, elapsed))
+
+         # print('Discarded', discard_count)
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         return self.data[idx]
+
+
+
+ class StructureLoader():
+     def __init__(self, dataset, batch_size=100, shuffle=True,
+                  collate_fn=lambda x: x, drop_last=False):
+         self.dataset = dataset
+         self.size = len(dataset)
+         self.lengths = [len(dataset[i]['seq']) for i in range(self.size)]
+         self.batch_size = batch_size
+         sorted_ix = np.argsort(self.lengths)
+
+         # Cluster into batches of similar sizes
+         clusters, batch = [], []
+         batch_max = 0
+         for ix in sorted_ix:
+             size = self.lengths[ix]
+             if size * (len(batch) + 1) <= self.batch_size:
+                 batch.append(ix)
+                 batch_max = size
+             else:
+                 clusters.append(batch)
+                 # start the next cluster with the current item so it is not dropped
+                 batch, batch_max = [ix], size
+         if len(batch) > 0:
+             clusters.append(batch)
+         self.clusters = clusters
+
+     def __len__(self):
+         return len(self.clusters)
+
+     def __iter__(self):
+         np.random.shuffle(self.clusters)
+         for b_idx in self.clusters:
+             batch = [self.dataset[i] for i in b_idx]
+             yield batch
+
+
+ # The following gather functions select per-node and per-edge features at given neighbor indices
+ def gather_edges(edges, neighbor_idx):
+     # Features [B,N,N,C] at Neighbor indices [B,N,K] => Neighbor features [B,N,K,C]
+     neighbors = neighbor_idx.unsqueeze(-1).expand(-1, -1, -1, edges.size(-1))
+     edge_features = torch.gather(edges, 2, neighbors)
+     return edge_features
+
+ def gather_nodes(nodes, neighbor_idx):
+     # Features [B,N,C] at Neighbor indices [B,N,K] => [B,N,K,C]
+     # Flatten and expand indices per batch [B,N,K] => [B,NK] => [B,NK,C]
+     neighbors_flat = neighbor_idx.view((neighbor_idx.shape[0], -1))
+     neighbors_flat = neighbors_flat.unsqueeze(-1).expand(-1, -1, nodes.size(2))
+     # Gather and re-pack
+     neighbor_features = torch.gather(nodes, 1, neighbors_flat)
+     neighbor_features = neighbor_features.view(list(neighbor_idx.shape)[:3] + [-1])
+     return neighbor_features
+
+ def gather_nodes_t(nodes, neighbor_idx):
+     # Features [B,N,C] at Neighbor index [B,K] => Neighbor features [B,K,C]
+     idx_flat = neighbor_idx.unsqueeze(-1).expand(-1, -1, nodes.size(2))
+     neighbor_features = torch.gather(nodes, 1, idx_flat)
+     return neighbor_features
+
+ def cat_neighbors_nodes(h_nodes, h_neighbors, E_idx):
+     h_nodes = gather_nodes(h_nodes, E_idx)
+     h_nn = torch.cat([h_neighbors, h_nodes], -1)
+     return h_nn
+
+
+ class EncLayer(nn.Module):
+     def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None, scale=30):
+         super(EncLayer, self).__init__()
+         self.num_hidden = num_hidden
+         self.num_in = num_in
+         self.scale = scale
+         self.dropout1 = nn.Dropout(dropout)
+         self.dropout2 = nn.Dropout(dropout)
+         self.dropout3 = nn.Dropout(dropout)
+         self.norm1 = nn.LayerNorm(num_hidden)
+         self.norm2 = nn.LayerNorm(num_hidden)
+         self.norm3 = nn.LayerNorm(num_hidden)
+
+         self.W1 = nn.Linear(num_hidden + num_in, num_hidden, bias=True)
+         self.W2 = nn.Linear(num_hidden, num_hidden, bias=True)
+         self.W3 = nn.Linear(num_hidden, num_hidden, bias=True)
+         self.W11 = nn.Linear(num_hidden + num_in, num_hidden, bias=True)
+         self.W12 = nn.Linear(num_hidden, num_hidden, bias=True)
+         self.W13 = nn.Linear(num_hidden, num_hidden, bias=True)
+         self.act = torch.nn.GELU()
+         self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)
+
+     def forward(self, h_V, h_E, E_idx, mask_V=None, mask_attend=None):
+         """ Parallel computation of full transformer layer """
+
+         h_EV = cat_neighbors_nodes(h_V, h_E, E_idx)
+         h_V_expand = h_V.unsqueeze(-2).expand(-1, -1, h_EV.size(-2), -1)
+         h_EV = torch.cat([h_V_expand, h_EV], -1)
+         h_message = self.W3(self.act(self.W2(self.act(self.W1(h_EV)))))
+         if mask_attend is not None:
+             h_message = mask_attend.unsqueeze(-1) * h_message
+         dh = torch.sum(h_message, -2) / self.scale
+         h_V = self.norm1(h_V + self.dropout1(dh))
+
+         dh = self.dense(h_V)
+         h_V = self.norm2(h_V + self.dropout2(dh))
+         if mask_V is not None:
+             mask_V = mask_V.unsqueeze(-1)
+             h_V = mask_V * h_V
+
+         h_EV = cat_neighbors_nodes(h_V, h_E, E_idx)
+         h_V_expand = h_V.unsqueeze(-2).expand(-1, -1, h_EV.size(-2), -1)
+         h_EV = torch.cat([h_V_expand, h_EV], -1)
+         h_message = self.W13(self.act(self.W12(self.act(self.W11(h_EV)))))
+         h_E = self.norm3(h_E + self.dropout3(h_message))
+         return h_V, h_E
+
+
+ class DecLayer(nn.Module):
+     def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None, scale=30):
+         super(DecLayer, self).__init__()
+         self.num_hidden = num_hidden
+         self.num_in = num_in
+         self.scale = scale
+         self.dropout1 = nn.Dropout(dropout)
+         self.dropout2 = nn.Dropout(dropout)
+         self.norm1 = nn.LayerNorm(num_hidden)
+         self.norm2 = nn.LayerNorm(num_hidden)
+
+         self.W1 = nn.Linear(num_hidden + num_in, num_hidden, bias=True)
+         self.W2 = nn.Linear(num_hidden, num_hidden, bias=True)
+         self.W3 = nn.Linear(num_hidden, num_hidden, bias=True)
+         self.act = torch.nn.GELU()
+         self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)
+
+     def forward(self, h_V, h_E, mask_V=None, mask_attend=None):
+         """ Parallel computation of full transformer layer """
+
+         # Concatenate h_V_i to h_E_ij
+         h_V_expand = h_V.unsqueeze(-2).expand(-1, -1, h_E.size(-2), -1)
+         h_EV = torch.cat([h_V_expand, h_E], -1)
+
+         h_message = self.W3(self.act(self.W2(self.act(self.W1(h_EV)))))
+         if mask_attend is not None:
+             h_message = mask_attend.unsqueeze(-1) * h_message
+         dh = torch.sum(h_message, -2) / self.scale
+
+         h_V = self.norm1(h_V + self.dropout1(dh))
+
+         # Position-wise feedforward
+         dh = self.dense(h_V)
+         h_V = self.norm2(h_V + self.dropout2(dh))
+
+         if mask_V is not None:
+             mask_V = mask_V.unsqueeze(-1)
+             h_V = mask_V * h_V
+         return h_V
+
+
693
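+ # A matching sketch for DecLayer (illustrative only): unlike EncLayer, the
+ # decoder layer receives pre-gathered edge features of width num_in (3*H in
+ # ProteinMPNN: edge + sequence + node context) and updates nodes only.
+ def _demo_dec_layer():
+     B, L, K, H = 2, 10, 5, 128
+     layer = DecLayer(H, H * 3)
+     h_V = torch.randn(B, L, H)
+     h_ESV = torch.randn(B, L, K, H * 3)     # already gathered per neighbor
+     mask = torch.ones(B, L)
+     return layer(h_V, h_ESV, mask_V=mask).shape  # torch.Size([2, 10, 128])
+
+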
+ class PositionWiseFeedForward(nn.Module):
+     def __init__(self, num_hidden, num_ff):
+         super(PositionWiseFeedForward, self).__init__()
+         self.W_in = nn.Linear(num_hidden, num_ff, bias=True)
+         self.W_out = nn.Linear(num_ff, num_hidden, bias=True)
+         self.act = torch.nn.GELU()
+
+     def forward(self, h_V):
+         h = self.act(self.W_in(h_V))
+         h = self.W_out(h)
+         return h
+
+ class PositionalEncodings(nn.Module):
+     def __init__(self, num_embeddings, max_relative_feature=32):
+         super(PositionalEncodings, self).__init__()
+         self.num_embeddings = num_embeddings
+         self.max_relative_feature = max_relative_feature
+         self.linear = nn.Linear(2*max_relative_feature+1+1, num_embeddings)
+
+     def forward(self, offset, mask):
+         # Clip relative sequence offsets into [0, 2*max_relative_feature]; pairs on
+         # different chains (mask == 0) get the extra out-of-range bin.
+         d = torch.clip(offset + self.max_relative_feature, 0, 2*self.max_relative_feature)*mask + (1-mask)*(2*self.max_relative_feature+1)
+         d_onehot = torch.nn.functional.one_hot(d, 2*self.max_relative_feature+1+1)
+         E = self.linear(d_onehot.float())
+         return E
+
+
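+ # Sketch of the relative positional encoding (illustrative only; the helper
+ # name and inputs are assumptions): offsets are clipped to [-32, 32] and
+ # one-hot encoded into 2*32+2 bins, where the extra bin marks inter-chain pairs.
+ def _demo_positional_encodings():
+     pe = PositionalEncodings(num_embeddings=16)
+     offset = torch.tensor([[[-100, -1, 0, 1, 100]]])  # [B=1, L=1, K=5]
+     mask = torch.tensor([[[1, 1, 1, 1, 0]]])          # last pair: different chains
+     return pe(offset, mask).shape                     # torch.Size([1, 1, 5, 16])
+
+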
+ class CA_ProteinFeatures(nn.Module):
+     def __init__(self, edge_features, node_features, num_positional_embeddings=16,
+                  num_rbf=16, top_k=30, augment_eps=0., num_chain_embeddings=16):
+         """ Extract protein features """
+         super(CA_ProteinFeatures, self).__init__()
+         self.edge_features = edge_features
+         self.node_features = node_features
+         self.top_k = top_k
+         self.augment_eps = augment_eps
+         self.num_rbf = num_rbf
+         self.num_positional_embeddings = num_positional_embeddings
+
+         # Positional encoding
+         self.embeddings = PositionalEncodings(num_positional_embeddings)
+         # Normalization and embedding
+         node_in, edge_in = 3, num_positional_embeddings + num_rbf*9 + 7
+         self.node_embedding = nn.Linear(node_in, node_features, bias=False)  # NOT USED
+         self.edge_embedding = nn.Linear(edge_in, edge_features, bias=False)
+         self.norm_nodes = nn.LayerNorm(node_features)
+         self.norm_edges = nn.LayerNorm(edge_features)
+
+     def _quaternions(self, R):
+         """ Convert a batch of 3D rotations [R] to quaternions [Q]
+             R [...,3,3]
+             Q [...,4]
+         """
+         # Simple Wikipedia version
+         # en.wikipedia.org/wiki/Rotation_matrix#Quaternion
+         # For other options see math.stackexchange.com/questions/2074316/calculating-rotation-axis-from-rotation-matrix
+         diag = torch.diagonal(R, dim1=-2, dim2=-1)
+         Rxx, Ryy, Rzz = diag.unbind(-1)
+         magnitudes = 0.5 * torch.sqrt(torch.abs(1 + torch.stack([
+             Rxx - Ryy - Rzz,
+             - Rxx + Ryy - Rzz,
+             - Rxx - Ryy + Rzz
+         ], -1)))
+         _R = lambda i, j: R[:, :, :, i, j]
+         signs = torch.sign(torch.stack([
+             _R(2, 1) - _R(1, 2),
+             _R(0, 2) - _R(2, 0),
+             _R(1, 0) - _R(0, 1)
+         ], -1))
+         xyz = signs * magnitudes
+         # The relu enforces a non-negative trace
+         w = torch.sqrt(F.relu(1 + diag.sum(-1, keepdim=True))) / 2.
+         Q = torch.cat((xyz, w), -1)
+         Q = F.normalize(Q, dim=-1)
+         return Q
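+         # Sanity sketch (illustrative, not from the original file): for R = I the
+         # trace term gives w = sqrt(1 + 3)/2 = 1 while the off-diagonal differences
+         # vanish, so xyz = 0 and Q normalizes to the identity quaternion (0, 0, 0, 1).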
+
+     def _orientations_coarse(self, X, E_idx, eps=1e-6):
+         dX = X[:,1:,:] - X[:,:-1,:]
+         dX_norm = torch.norm(dX, dim=-1)
+         dX_mask = (3.6 < dX_norm) & (dX_norm < 4.0)  # exclude CA-CA jumps
+         dX = dX*dX_mask[:,:,None]
+         U = F.normalize(dX, dim=-1)
+         u_2 = U[:,:-2,:]
+         u_1 = U[:,1:-1,:]
+         u_0 = U[:,2:,:]
+         # Backbone normals
+         n_2 = F.normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
+         n_1 = F.normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)
+
+         # Bond angle calculation
+         cosA = -(u_1 * u_0).sum(-1)
+         cosA = torch.clamp(cosA, -1+eps, 1-eps)
+         A = torch.acos(cosA)
+         # Angle between normals
+         cosD = (n_2 * n_1).sum(-1)
+         cosD = torch.clamp(cosD, -1+eps, 1-eps)
+         D = torch.sign((u_2 * n_1).sum(-1)) * torch.acos(cosD)
+         # Backbone features
+         AD_features = torch.stack((torch.cos(A), torch.sin(A) * torch.cos(D), torch.sin(A) * torch.sin(D)), 2)
+         AD_features = F.pad(AD_features, (0,0,1,2), 'constant', 0)
+
+         # Build relative orientations
+         o_1 = F.normalize(u_2 - u_1, dim=-1)
+         O = torch.stack((o_1, n_2, torch.cross(o_1, n_2, dim=-1)), 2)
+         O = O.view(list(O.shape[:2]) + [9])
+         O = F.pad(O, (0,0,1,2), 'constant', 0)
+         O_neighbors = gather_nodes(O, E_idx)
+         X_neighbors = gather_nodes(X, E_idx)
+
+         # Re-view as rotation matrices
+         O = O.view(list(O.shape[:2]) + [3,3])
+         O_neighbors = O_neighbors.view(list(O_neighbors.shape[:3]) + [3,3])
+
+         # Rotate into local reference frames
+         dX = X_neighbors - X.unsqueeze(-2)
+         dU = torch.matmul(O.unsqueeze(2), dX.unsqueeze(-1)).squeeze(-1)
+         dU = F.normalize(dU, dim=-1)
+         R = torch.matmul(O.unsqueeze(2).transpose(-1,-2), O_neighbors)
+         Q = self._quaternions(R)
+
+         # Orientation features
+         O_features = torch.cat((dU, Q), dim=-1)
+         return AD_features, O_features
+
+     def _dist(self, X, mask, eps=1E-6):
+         """ Pairwise Euclidean distances """
+         mask_2D = torch.unsqueeze(mask, 1) * torch.unsqueeze(mask, 2)
+         dX = torch.unsqueeze(X, 1) - torch.unsqueeze(X, 2)
+         D = mask_2D * torch.sqrt(torch.sum(dX**2, 3) + eps)
+
+         # Identify k nearest neighbors (including self)
+         D_max, _ = torch.max(D, -1, keepdim=True)
+         D_adjust = D + (1. - mask_2D) * D_max
+         D_neighbors, E_idx = torch.topk(D_adjust, np.minimum(self.top_k, X.shape[1]), dim=-1, largest=False)
+         mask_neighbors = gather_edges(mask_2D.unsqueeze(-1), E_idx)
+         return D_neighbors, E_idx, mask_neighbors
+
+     def _rbf(self, D):
+         # Distance radial basis function
+         device = D.device
+         D_min, D_max, D_count = 2., 22., self.num_rbf
+         D_mu = torch.linspace(D_min, D_max, D_count, device=device)
+         D_mu = D_mu.view([1,1,1,-1])
+         D_sigma = (D_max - D_min) / D_count
+         D_expand = torch.unsqueeze(D, -1)
+         RBF = torch.exp(-((D_expand - D_mu) / D_sigma)**2)
+         return RBF
+
+     def _get_rbf(self, A, B, E_idx):
+         D_A_B = torch.sqrt(torch.sum((A[:,:,None,:] - B[:,None,:,:])**2, -1) + 1e-6)  # [B, L, L]
+         D_A_B_neighbors = gather_edges(D_A_B[:,:,:,None], E_idx)[:,:,:,0]  # [B, L, K]
+         RBF_A_B = self._rbf(D_A_B_neighbors)
+         return RBF_A_B
+
+     def forward(self, Ca, mask, residue_idx, chain_labels):
+         """ Featurize coordinates as an attributed graph """
+         if self.augment_eps > 0:
+             Ca = Ca + self.augment_eps * torch.randn_like(Ca)
+
+         D_neighbors, E_idx, mask_neighbors = self._dist(Ca, mask)
+
+         # Sequence-shifted copies of the Ca trace (previous and next residue)
+         Ca_0 = torch.zeros(Ca.shape, device=Ca.device)
+         Ca_2 = torch.zeros(Ca.shape, device=Ca.device)
+         Ca_0[:,1:,:] = Ca[:,:-1,:]
+         Ca_1 = Ca
+         Ca_2[:,:-1,:] = Ca[:,1:,:]
+
+         V, O_features = self._orientations_coarse(Ca, E_idx)
+
+         # 9 distance maps between {previous, current, next} Ca atoms
+         RBF_all = []
+         RBF_all.append(self._rbf(D_neighbors))  # Ca_1-Ca_1
+         RBF_all.append(self._get_rbf(Ca_0, Ca_0, E_idx))
+         RBF_all.append(self._get_rbf(Ca_2, Ca_2, E_idx))
+         RBF_all.append(self._get_rbf(Ca_0, Ca_1, E_idx))
+         RBF_all.append(self._get_rbf(Ca_0, Ca_2, E_idx))
+         RBF_all.append(self._get_rbf(Ca_1, Ca_0, E_idx))
+         RBF_all.append(self._get_rbf(Ca_1, Ca_2, E_idx))
+         RBF_all.append(self._get_rbf(Ca_2, Ca_0, E_idx))
+         RBF_all.append(self._get_rbf(Ca_2, Ca_1, E_idx))
+         RBF_all = torch.cat(tuple(RBF_all), dim=-1)
+
+         offset = residue_idx[:,:,None] - residue_idx[:,None,:]
+         offset = gather_edges(offset[:,:,:,None], E_idx)[:,:,:,0]  # [B, L, K]
+
+         d_chains = ((chain_labels[:, :, None] - chain_labels[:,None,:])==0).long()
+         E_chains = gather_edges(d_chains[:,:,:,None], E_idx)[:,:,:,0]
+         E_positional = self.embeddings(offset.long(), E_chains)
+         E = torch.cat((E_positional, RBF_all, O_features), -1)
+
+         E = self.edge_embedding(E)
+         E = self.norm_edges(E)
+
+         return E, E_idx
+
+
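+ # Usage sketch for the Ca-only featurizer (illustrative; the helper name and
+ # shapes are assumptions): a single-chain backbone described by one Ca per residue.
+ def _demo_ca_protein_features():
+     B, L = 1, 12
+     feats = CA_ProteinFeatures(edge_features=128, node_features=128, top_k=8)
+     Ca = torch.randn(B, L, 3)               # [B, L, 3] Ca coordinates
+     mask = torch.ones(B, L)                 # all residues resolved
+     residue_idx = torch.arange(L)[None, :]  # sequence numbering
+     chain_labels = torch.zeros(B, L)        # one chain
+     E, E_idx = feats(Ca, mask, residue_idx, chain_labels)
+     return E.shape, E_idx.shape             # [1, 12, 8, 128], [1, 12, 8]
+
+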
+ class ProteinFeatures(nn.Module):
+     def __init__(self, edge_features, node_features, num_positional_embeddings=16,
+                  num_rbf=16, top_k=30, augment_eps=0., num_chain_embeddings=16):
+         """ Extract protein features """
+         super(ProteinFeatures, self).__init__()
+         self.edge_features = edge_features
+         self.node_features = node_features
+         self.top_k = top_k
+         self.augment_eps = augment_eps
+         self.num_rbf = num_rbf
+         self.num_positional_embeddings = num_positional_embeddings
+
+         self.embeddings = PositionalEncodings(num_positional_embeddings)
+         node_in, edge_in = 6, num_positional_embeddings + num_rbf*25
+         self.edge_embedding = nn.Linear(edge_in, edge_features, bias=False)
+         self.norm_edges = nn.LayerNorm(edge_features)
+
+     def _dist(self, X, mask, eps=1E-6):
+         mask_2D = torch.unsqueeze(mask, 1) * torch.unsqueeze(mask, 2)
+         dX = torch.unsqueeze(X, 1) - torch.unsqueeze(X, 2)
+         D = mask_2D * torch.sqrt(torch.sum(dX**2, 3) + eps)
+         D_max, _ = torch.max(D, -1, keepdim=True)
+         D_adjust = D + (1. - mask_2D) * D_max
+         D_neighbors, E_idx = torch.topk(D_adjust, np.minimum(self.top_k, X.shape[1]), dim=-1, largest=False)
+         return D_neighbors, E_idx
+
+     def _rbf(self, D):
+         device = D.device
+         D_min, D_max, D_count = 2., 22., self.num_rbf
+         D_mu = torch.linspace(D_min, D_max, D_count, device=device)
+         D_mu = D_mu.view([1,1,1,-1])
+         D_sigma = (D_max - D_min) / D_count
+         D_expand = torch.unsqueeze(D, -1)
+         RBF = torch.exp(-((D_expand - D_mu) / D_sigma)**2)
+         return RBF
+
+     def _get_rbf(self, A, B, E_idx):
+         D_A_B = torch.sqrt(torch.sum((A[:,:,None,:] - B[:,None,:,:])**2, -1) + 1e-6)  # [B, L, L]
+         D_A_B_neighbors = gather_edges(D_A_B[:,:,:,None], E_idx)[:,:,:,0]  # [B, L, K]
+         RBF_A_B = self._rbf(D_A_B_neighbors)
+         return RBF_A_B
+
+     def forward(self, X, mask, residue_idx, chain_labels):
+         if self.augment_eps > 0:
+             X = X + self.augment_eps * torch.randn_like(X)
+
+         # Virtual Cb built from N, Ca, C using an idealized-geometry construction
+         b = X[:,:,1,:] - X[:,:,0,:]
+         c = X[:,:,2,:] - X[:,:,1,:]
+         a = torch.cross(b, c, dim=-1)
+         Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + X[:,:,1,:]
+         Ca = X[:,:,1,:]
+         N = X[:,:,0,:]
+         C = X[:,:,2,:]
+         O = X[:,:,3,:]
+
+         D_neighbors, E_idx = self._dist(Ca, mask)
+
+         # 25 pairwise distance maps between the five atom types
+         RBF_all = []
+         RBF_all.append(self._rbf(D_neighbors))  # Ca-Ca
+         RBF_all.append(self._get_rbf(N, N, E_idx))  # N-N
+         RBF_all.append(self._get_rbf(C, C, E_idx))  # C-C
+         RBF_all.append(self._get_rbf(O, O, E_idx))  # O-O
+         RBF_all.append(self._get_rbf(Cb, Cb, E_idx))  # Cb-Cb
+         RBF_all.append(self._get_rbf(Ca, N, E_idx))  # Ca-N
+         RBF_all.append(self._get_rbf(Ca, C, E_idx))  # Ca-C
+         RBF_all.append(self._get_rbf(Ca, O, E_idx))  # Ca-O
+         RBF_all.append(self._get_rbf(Ca, Cb, E_idx))  # Ca-Cb
+         RBF_all.append(self._get_rbf(N, C, E_idx))  # N-C
+         RBF_all.append(self._get_rbf(N, O, E_idx))  # N-O
+         RBF_all.append(self._get_rbf(N, Cb, E_idx))  # N-Cb
+         RBF_all.append(self._get_rbf(Cb, C, E_idx))  # Cb-C
+         RBF_all.append(self._get_rbf(Cb, O, E_idx))  # Cb-O
+         RBF_all.append(self._get_rbf(O, C, E_idx))  # O-C
+         RBF_all.append(self._get_rbf(N, Ca, E_idx))  # N-Ca
+         RBF_all.append(self._get_rbf(C, Ca, E_idx))  # C-Ca
+         RBF_all.append(self._get_rbf(O, Ca, E_idx))  # O-Ca
+         RBF_all.append(self._get_rbf(Cb, Ca, E_idx))  # Cb-Ca
+         RBF_all.append(self._get_rbf(C, N, E_idx))  # C-N
+         RBF_all.append(self._get_rbf(O, N, E_idx))  # O-N
+         RBF_all.append(self._get_rbf(Cb, N, E_idx))  # Cb-N
+         RBF_all.append(self._get_rbf(C, Cb, E_idx))  # C-Cb
+         RBF_all.append(self._get_rbf(O, Cb, E_idx))  # O-Cb
+         RBF_all.append(self._get_rbf(C, O, E_idx))  # C-O
+         RBF_all = torch.cat(tuple(RBF_all), dim=-1)
+
+         offset = residue_idx[:,:,None] - residue_idx[:,None,:]
+         offset = gather_edges(offset[:,:,:,None], E_idx)[:,:,:,0]  # [B, L, K]
+
+         d_chains = ((chain_labels[:, :, None] - chain_labels[:,None,:])==0).long()  # find self vs non-self interaction
+         E_chains = gather_edges(d_chains[:,:,:,None], E_idx)[:,:,:,0]
+         E_positional = self.embeddings(offset.long(), E_chains)
+         E = torch.cat((E_positional, RBF_all), -1)
+         E = self.edge_embedding(E)
+         E = self.norm_edges(E)
+         return E, E_idx
+
+
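+ # Usage sketch for the full-backbone featurizer (illustrative; shapes are
+ # assumptions): X carries four atoms (N, Ca, C, O) per residue, and the raw edge
+ # width is num_positional_embeddings + 25*num_rbf = 416 before embedding.
+ def _demo_protein_features():
+     B, L = 1, 12
+     feats = ProteinFeatures(edge_features=128, node_features=128, top_k=8)
+     X = torch.randn(B, L, 4, 3)
+     mask = torch.ones(B, L)
+     residue_idx = torch.arange(L)[None, :]
+     chain_labels = torch.zeros(B, L)
+     E, E_idx = feats(X, mask, residue_idx, chain_labels)
+     return E.shape, E_idx.shape             # [1, 12, 8, 128], [1, 12, 8]
+
+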
+ class ProteinMPNN(nn.Module):
+     def __init__(self, num_letters, node_features, edge_features,
+                  hidden_dim, num_encoder_layers=3, num_decoder_layers=3,
+                  vocab=21, k_neighbors=64, augment_eps=0.05, dropout=0.1, ca_only=False):
+         super(ProteinMPNN, self).__init__()
+
+         # Hyperparameters
+         self.node_features = node_features
+         self.edge_features = edge_features
+         self.hidden_dim = hidden_dim
+
+         # Featurization layers (the featurizers declare (edge_features, node_features);
+         # passing (node_features, edge_features) is harmless as long as the two values
+         # are equal, as they are here)
+         if ca_only:
+             self.features = CA_ProteinFeatures(node_features, edge_features, top_k=k_neighbors, augment_eps=augment_eps)
+             self.W_v = nn.Linear(node_features, hidden_dim, bias=True)
+         else:
+             self.features = ProteinFeatures(node_features, edge_features, top_k=k_neighbors, augment_eps=augment_eps)
+
+         self.W_e = nn.Linear(edge_features, hidden_dim, bias=True)
+         self.W_s = nn.Embedding(vocab, hidden_dim)
+
+         # Encoder layers
+         self.encoder_layers = nn.ModuleList([
+             EncLayer(hidden_dim, hidden_dim*2, dropout=dropout)
+             for _ in range(num_encoder_layers)
+         ])
+
+         # Decoder layers
+         self.decoder_layers = nn.ModuleList([
+             DecLayer(hidden_dim, hidden_dim*3, dropout=dropout)
+             for _ in range(num_decoder_layers)
+         ])
+         self.W_out = nn.Linear(hidden_dim, num_letters, bias=True)
+
+         for p in self.parameters():
+             if p.dim() > 1:
+                 nn.init.xavier_uniform_(p)
+
+     def forward(self, X, S, mask, chain_M, residue_idx, chain_encoding_all, randn, use_input_decoding_order=False, decoding_order=None):
+         """ Graph-conditioned sequence model """
+         device = X.device
+         # Prepare node and edge embeddings
+         E, E_idx = self.features(X, mask, residue_idx, chain_encoding_all)
+         h_V = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=E.device)
+         h_E = self.W_e(E)
+
+         # Encoder is unmasked self-attention
+         mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
+         mask_attend = mask.unsqueeze(-1) * mask_attend
+         for layer in self.encoder_layers:
+             h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend)
+
+         # Concatenate sequence embeddings for autoregressive decoder
+         h_S = self.W_s(S)
+         h_ES = cat_neighbors_nodes(h_S, h_E, E_idx)
+
+         # Build encoder embeddings
+         h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
+         h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)
+
+         chain_M = chain_M*mask  # update chain_M to include missing regions
+         if not use_input_decoding_order:
+             # Sort keys are smaller where chain_M = 0.0 (fixed) and larger where
+             # chain_M = 1.0, so fixed positions decode first in a random order.
+             decoding_order = torch.argsort((chain_M+0.0001)*(torch.abs(randn)))
+         mask_size = E_idx.shape[1]
+         permutation_matrix_reverse = torch.nn.functional.one_hot(decoding_order, num_classes=mask_size).float()
+         order_mask_backward = torch.einsum('ij, biq, bjp->bqp', (1-torch.triu(torch.ones(mask_size, mask_size, device=device))), permutation_matrix_reverse, permutation_matrix_reverse)
+         mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
+         mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
+         mask_bw = mask_1D * mask_attend
+         mask_fw = mask_1D * (1. - mask_attend)
+
+         h_EXV_encoder_fw = mask_fw * h_EXV_encoder
+         for layer in self.decoder_layers:
+             # Neighbors decoded earlier contribute sequence information (h_ESV);
+             # positions not yet decoded are seen through encoder features only.
+             h_ESV = cat_neighbors_nodes(h_V, h_ES, E_idx)
+             h_ESV = mask_bw * h_ESV + h_EXV_encoder_fw
+             h_V = layer(h_V, h_ESV, mask)
+
+         logits = self.W_out(h_V)
+         log_probs = F.log_softmax(logits, dim=-1)
+         return log_probs
+
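+     # Decoding-order mask sketch (illustrative): for decoding_order = [2, 0, 1],
+     #     P = F.one_hot(torch.tensor([[2, 0, 1]]), 3).float()
+     #     torch.einsum('ij,biq,bjp->bqp', 1 - torch.triu(torch.ones(3, 3)), P, P)
+     # yields a matrix whose entry (q, p) = 1 iff position p is decoded before q,
+     # i.e. position 0 may attend to 2, and position 1 may attend to both 2 and 0.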
+     def sample(self, X, randn, S_true, chain_mask, chain_encoding_all, residue_idx, mask=None, temperature=1.0, omit_AAs_np=None, bias_AAs_np=None, chain_M_pos=None, omit_AA_mask=None, pssm_coef=None, pssm_bias=None, pssm_multi=None, pssm_log_odds_flag=None, pssm_log_odds_mask=None, pssm_bias_flag=None, bias_by_res=None):
+         device = X.device
+         # Prepare node and edge embeddings
+         E, E_idx = self.features(X, mask, residue_idx, chain_encoding_all)
+         h_V = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=device)
+         h_E = self.W_e(E)
+
+         # Encoder is unmasked self-attention
+         mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
+         mask_attend = mask.unsqueeze(-1) * mask_attend
+         for layer in self.encoder_layers:
+             h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend)
+
+         # Decoder uses masked self-attention
+         chain_mask = chain_mask*chain_M_pos*mask  # update chain_mask to include missing regions
+         decoding_order = torch.argsort((chain_mask+0.0001)*(torch.abs(randn)))  # fixed positions (chain_mask = 0.0) decode first, designed positions later
+         mask_size = E_idx.shape[1]
+         permutation_matrix_reverse = torch.nn.functional.one_hot(decoding_order, num_classes=mask_size).float()
+         order_mask_backward = torch.einsum('ij, biq, bjp->bqp', (1-torch.triu(torch.ones(mask_size, mask_size, device=device))), permutation_matrix_reverse, permutation_matrix_reverse)
+         mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
+         mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
+         mask_bw = mask_1D * mask_attend
+         mask_fw = mask_1D * (1. - mask_attend)
+
+         N_batch, N_nodes = X.size(0), X.size(1)
+         all_probs = torch.zeros((N_batch, N_nodes, 21), device=device, dtype=torch.float32)
+         h_S = torch.zeros_like(h_V, device=device)
+         S = torch.zeros((N_batch, N_nodes), dtype=torch.int64, device=device)
+         h_V_stack = [h_V] + [torch.zeros_like(h_V, device=device) for _ in range(len(self.decoder_layers))]
+         constant = torch.tensor(omit_AAs_np, device=device)
+         constant_bias = torch.tensor(bias_AAs_np, device=device)
+         omit_AA_mask_flag = omit_AA_mask is not None
+
+         h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
+         h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)
+         h_EXV_encoder_fw = mask_fw * h_EXV_encoder
+         # Decode one position per step, feeding sampled residues back in as context
+         for t_ in range(N_nodes):
+             t = decoding_order[:,t_]  # [B]
+             chain_mask_gathered = torch.gather(chain_mask, 1, t[:,None])  # [B]
+             mask_gathered = torch.gather(mask, 1, t[:,None])  # [B]
+             bias_by_res_gathered = torch.gather(bias_by_res, 1, t[:,None,None].repeat(1,1,21))[:,0,:]  # [B, 21]
+             if (mask_gathered==0).all():  # for padded or missing regions only
+                 S_t = torch.gather(S_true, 1, t[:,None])
+             else:
+                 # Hidden layers
+                 E_idx_t = torch.gather(E_idx, 1, t[:,None,None].repeat(1,1,E_idx.shape[-1]))
+                 h_E_t = torch.gather(h_E, 1, t[:,None,None,None].repeat(1,1,h_E.shape[-2], h_E.shape[-1]))
+                 h_ES_t = cat_neighbors_nodes(h_S, h_E_t, E_idx_t)
+                 h_EXV_encoder_t = torch.gather(h_EXV_encoder_fw, 1, t[:,None,None,None].repeat(1,1,h_EXV_encoder_fw.shape[-2], h_EXV_encoder_fw.shape[-1]))
+                 mask_t = torch.gather(mask, 1, t[:,None])
+                 for l, layer in enumerate(self.decoder_layers):
+                     # Updated relational features for future states
+                     h_ESV_decoder_t = cat_neighbors_nodes(h_V_stack[l], h_ES_t, E_idx_t)
+                     h_V_t = torch.gather(h_V_stack[l], 1, t[:,None,None].repeat(1,1,h_V_stack[l].shape[-1]))
+                     h_ESV_t = torch.gather(mask_bw, 1, t[:,None,None,None].repeat(1,1,mask_bw.shape[-2], mask_bw.shape[-1])) * h_ESV_decoder_t + h_EXV_encoder_t
+                     h_V_stack[l+1].scatter_(1, t[:,None,None].repeat(1,1,h_V.shape[-1]), layer(h_V_t, h_ESV_t, mask_V=mask_t))
+                 # Sampling step
+                 h_V_t = torch.gather(h_V_stack[-1], 1, t[:,None,None].repeat(1,1,h_V_stack[-1].shape[-1]))[:,0]
+                 logits = self.W_out(h_V_t) / temperature
+                 probs = F.softmax(logits-constant[None,:]*1e8+constant_bias[None,:]/temperature+bias_by_res_gathered/temperature, dim=-1)
+                 if pssm_bias_flag:
+                     pssm_coef_gathered = torch.gather(pssm_coef, 1, t[:,None])[:,0]
+                     pssm_bias_gathered = torch.gather(pssm_bias, 1, t[:,None,None].repeat(1,1,pssm_bias.shape[-1]))[:,0]
+                     probs = (1-pssm_multi*pssm_coef_gathered[:,None])*probs + pssm_multi*pssm_coef_gathered[:,None]*pssm_bias_gathered
+                 if pssm_log_odds_flag:
+                     pssm_log_odds_mask_gathered = torch.gather(pssm_log_odds_mask, 1, t[:,None,None].repeat(1,1,pssm_log_odds_mask.shape[-1]))[:,0]  # [B, 21]
+                     probs_masked = probs*pssm_log_odds_mask_gathered
+                     probs_masked += probs * 0.001
+                     probs = probs_masked/torch.sum(probs_masked, dim=-1, keepdim=True)  # [B, 21]
+                 if omit_AA_mask_flag:
+                     omit_AA_mask_gathered = torch.gather(omit_AA_mask, 1, t[:,None,None].repeat(1,1,omit_AA_mask.shape[-1]))[:,0]  # [B, 21]
+                     probs_masked = probs*(1.0-omit_AA_mask_gathered)
+                     probs = probs_masked/torch.sum(probs_masked, dim=-1, keepdim=True)  # [B, 21]
+                 S_t = torch.multinomial(probs, 1)
+                 all_probs.scatter_(1, t[:,None,None].repeat(1,1,21), (chain_mask_gathered[:,:,None]*probs[:,None,:]).float())
+             S_true_gathered = torch.gather(S_true, 1, t[:,None])
+             S_t = (S_t*chain_mask_gathered + S_true_gathered*(1.0-chain_mask_gathered)).long()
+             temp1 = self.W_s(S_t)
+             h_S.scatter_(1, t[:,None,None].repeat(1,1,temp1.shape[-1]), temp1)
+             S.scatter_(1, t[:,None], S_t)
+         output_dict = {"S": S, "probs": all_probs, "decoding_order": decoding_order}
+         return output_dict
+
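+     # Probability-shaping sketch (illustrative, assuming the 21-letter alphabet
+     # 'ACDEFGHIKLMNPQRSTVWYX' used by the run script): omitting "X" sets
+     # constant[20] = 1, so `logits - constant*1e8` drives P(X) to ~0 before
+     # torch.multinomial draws the next residue.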
+     def tied_sample(self, X, randn, S_true, chain_mask, chain_encoding_all, residue_idx, mask=None, temperature=1.0, omit_AAs_np=None, bias_AAs_np=None, chain_M_pos=None, omit_AA_mask=None, pssm_coef=None, pssm_bias=None, pssm_multi=None, pssm_log_odds_flag=None, pssm_log_odds_mask=None, pssm_bias_flag=None, tied_pos=None, tied_beta=None, bias_by_res=None):
+         device = X.device
+         # Prepare node and edge embeddings
+         E, E_idx = self.features(X, mask, residue_idx, chain_encoding_all)
+         h_V = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=device)
+         h_E = self.W_e(E)
+         # Encoder is unmasked self-attention
+         mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
+         mask_attend = mask.unsqueeze(-1) * mask_attend
+         for layer in self.encoder_layers:
+             h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend)
+
+         # Decoder uses masked self-attention
+         chain_mask = chain_mask*chain_M_pos*mask  # update chain_mask to include missing regions
+         decoding_order = torch.argsort((chain_mask+0.0001)*(torch.abs(randn)))  # fixed positions (chain_mask = 0.0) decode first
+
+         # Group the decoding order so that tied positions are decoded together
+         new_decoding_order = []
+         for t_dec in list(decoding_order[0,].cpu().data.numpy()):
+             if t_dec not in list(itertools.chain(*new_decoding_order)):
+                 list_a = [item for item in tied_pos if t_dec in item]
+                 if list_a:
+                     new_decoding_order.append(list_a[0])
+                 else:
+                     new_decoding_order.append([t_dec])
+         decoding_order = torch.tensor(list(itertools.chain(*new_decoding_order)), device=device)[None,].repeat(X.shape[0], 1)
+
+         mask_size = E_idx.shape[1]
+         permutation_matrix_reverse = torch.nn.functional.one_hot(decoding_order, num_classes=mask_size).float()
+         order_mask_backward = torch.einsum('ij, biq, bjp->bqp', (1-torch.triu(torch.ones(mask_size, mask_size, device=device))), permutation_matrix_reverse, permutation_matrix_reverse)
+         mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
+         mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
+         mask_bw = mask_1D * mask_attend
+         mask_fw = mask_1D * (1. - mask_attend)
+
+         N_batch, N_nodes = X.size(0), X.size(1)
+         all_probs = torch.zeros((N_batch, N_nodes, 21), device=device, dtype=torch.float32)
+         h_S = torch.zeros_like(h_V, device=device)
+         S = torch.zeros((N_batch, N_nodes), dtype=torch.int64, device=device)
+         h_V_stack = [h_V] + [torch.zeros_like(h_V, device=device) for _ in range(len(self.decoder_layers))]
+         constant = torch.tensor(omit_AAs_np, device=device)
+         constant_bias = torch.tensor(bias_AAs_np, device=device)
+         omit_AA_mask_flag = omit_AA_mask is not None
+
+         h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
+         h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)
+         h_EXV_encoder_fw = mask_fw * h_EXV_encoder
+         for t_list in new_decoding_order:
+             logits = 0.0
+             logit_list = []
+             done_flag = False
+             for t in t_list:
+                 if (mask[:,t]==0).all():
+                     S_t = S_true[:,t]
+                     for t in t_list:
+                         h_S[:,t,:] = self.W_s(S_t)
+                         S[:,t] = S_t
+                     done_flag = True
+                     break
+                 else:
+                     E_idx_t = E_idx[:,t:t+1,:]
+                     h_E_t = h_E[:,t:t+1,:,:]
+                     h_ES_t = cat_neighbors_nodes(h_S, h_E_t, E_idx_t)
+                     h_EXV_encoder_t = h_EXV_encoder_fw[:,t:t+1,:,:]
+                     mask_t = mask[:,t:t+1]
+                     for l, layer in enumerate(self.decoder_layers):
+                         h_ESV_decoder_t = cat_neighbors_nodes(h_V_stack[l], h_ES_t, E_idx_t)
+                         h_V_t = h_V_stack[l][:,t:t+1,:]
+                         h_ESV_t = mask_bw[:,t:t+1,:,:] * h_ESV_decoder_t + h_EXV_encoder_t
+                         h_V_stack[l+1][:,t,:] = layer(h_V_t, h_ESV_t, mask_V=mask_t).squeeze(1)
+                     h_V_t = h_V_stack[-1][:,t,:]
+                     # Tied positions share one distribution: average the per-position
+                     # logits, weighted by tied_beta
+                     logit_list.append((self.W_out(h_V_t) / temperature)/len(t_list))
+                     logits += tied_beta[t]*(self.W_out(h_V_t) / temperature)/len(t_list)
+             if not done_flag:
+                 bias_by_res_gathered = bias_by_res[:,t,:]  # [B, 21]
+                 probs = F.softmax(logits-constant[None,:]*1e8+constant_bias[None,:]/temperature+bias_by_res_gathered/temperature, dim=-1)
+                 if pssm_bias_flag:
+                     pssm_coef_gathered = pssm_coef[:,t]
+                     pssm_bias_gathered = pssm_bias[:,t]
+                     probs = (1-pssm_multi*pssm_coef_gathered[:,None])*probs + pssm_multi*pssm_coef_gathered[:,None]*pssm_bias_gathered
+                 if pssm_log_odds_flag:
+                     pssm_log_odds_mask_gathered = pssm_log_odds_mask[:,t]
+                     probs_masked = probs*pssm_log_odds_mask_gathered
+                     probs_masked += probs * 0.001
+                     probs = probs_masked/torch.sum(probs_masked, dim=-1, keepdim=True)  # [B, 21]
+                 if omit_AA_mask_flag:
+                     omit_AA_mask_gathered = omit_AA_mask[:,t]
+                     probs_masked = probs*(1.0-omit_AA_mask_gathered)
+                     probs = probs_masked/torch.sum(probs_masked, dim=-1, keepdim=True)  # [B, 21]
+                 S_t_repeat = torch.multinomial(probs, 1).squeeze(-1)
+                 S_t_repeat = (chain_mask[:,t]*S_t_repeat + (1-chain_mask[:,t])*S_true[:,t]).long()  # hard pick fixed positions
+                 for t in t_list:
+                     h_S[:,t,:] = self.W_s(S_t_repeat)
+                     S[:,t] = S_t_repeat
+                     all_probs[:,t,:] = probs.float()
+         output_dict = {"S": S, "probs": all_probs, "decoding_order": decoding_order}
+         return output_dict
+
+     def conditional_probs(self, X, S, mask, chain_M, residue_idx, chain_encoding_all, randn, backbone_only=False):
+         """ Per-position conditional log-probabilities, looping over designable positions """
+         device = X.device
+         # Prepare node and edge embeddings
+         E, E_idx = self.features(X, mask, residue_idx, chain_encoding_all)
+         h_V_enc = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=E.device)
+         h_E = self.W_e(E)
+
+         # Encoder is unmasked self-attention
+         mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
+         mask_attend = mask.unsqueeze(-1) * mask_attend
+         for layer in self.encoder_layers:
+             h_V_enc, h_E = layer(h_V_enc, h_E, E_idx, mask, mask_attend)
+
+         # Concatenate sequence embeddings for autoregressive decoder
+         h_S = self.W_s(S)
+         h_ES = cat_neighbors_nodes(h_S, h_E, E_idx)
+
+         # Build encoder embeddings
+         h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
+         h_EXV_encoder = cat_neighbors_nodes(h_V_enc, h_EX_encoder, E_idx)
+
+         chain_M = chain_M*mask  # update chain_M to include missing regions
+
+         chain_M_np = chain_M.cpu().numpy()
+         idx_to_loop = np.argwhere(chain_M_np[0,:]==1)[:,0]
+         log_conditional_probs = torch.zeros([X.shape[0], chain_M.shape[1], 21], device=device).float()
+
+         for idx in idx_to_loop:
+             h_V = torch.clone(h_V_enc)
+             if backbone_only:
+                 # Position idx decodes first, so it sees backbone information only
+                 order_mask = torch.ones(chain_M.shape[1], device=device).float()
+                 order_mask[idx] = 0.
+             else:
+                 # Position idx decodes last, so it conditions on the full sequence context
+                 order_mask = torch.zeros(chain_M.shape[1], device=device).float()
+                 order_mask[idx] = 1.
+             decoding_order = torch.argsort((order_mask[None,]+0.0001)*(torch.abs(randn)))  # smaller keys decode first
+             mask_size = E_idx.shape[1]
+             permutation_matrix_reverse = torch.nn.functional.one_hot(decoding_order, num_classes=mask_size).float()
+             order_mask_backward = torch.einsum('ij, biq, bjp->bqp', (1-torch.triu(torch.ones(mask_size, mask_size, device=device))), permutation_matrix_reverse, permutation_matrix_reverse)
+             mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
+             mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
+             mask_bw = mask_1D * mask_attend
+             mask_fw = mask_1D * (1. - mask_attend)
+
+             h_EXV_encoder_fw = mask_fw * h_EXV_encoder
+             for layer in self.decoder_layers:
+                 # Neighbors "decoded" earlier contribute sequence information;
+                 # later positions are seen through encoder features only.
+                 h_ESV = cat_neighbors_nodes(h_V, h_ES, E_idx)
+                 h_ESV = mask_bw * h_ESV + h_EXV_encoder_fw
+                 h_V = layer(h_V, h_ESV, mask)
+
+             logits = self.W_out(h_V)
+             log_probs = F.log_softmax(logits, dim=-1)
+             log_conditional_probs[:,idx,:] = log_probs[:,idx,:]
+         return log_conditional_probs
+
+     def unconditional_probs(self, X, mask, residue_idx, chain_encoding_all):
+         """ Per-position log-probabilities from the backbone alone (no sequence decoded) """
+         device = X.device
+         # Prepare node and edge embeddings
+         E, E_idx = self.features(X, mask, residue_idx, chain_encoding_all)
+         h_V = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=E.device)
+         h_E = self.W_e(E)
+
+         # Encoder is unmasked self-attention
+         mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
+         mask_attend = mask.unsqueeze(-1) * mask_attend
+         for layer in self.encoder_layers:
+             h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend)
+
+         # Build encoder embeddings
+         h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_V), h_E, E_idx)
+         h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)
+
+         # An all-zero backward mask means no position ever sees decoded sequence,
+         # so the decoder runs on encoder features only
+         order_mask_backward = torch.zeros([X.shape[0], X.shape[1], X.shape[1]], device=device)
+         mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
+         mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
+         mask_bw = mask_1D * mask_attend
+         mask_fw = mask_1D * (1. - mask_attend)
+
+         h_EXV_encoder_fw = mask_fw * h_EXV_encoder
+         for layer in self.decoder_layers:
+             h_V = layer(h_V, h_EXV_encoder_fw, mask)
+
+         logits = self.W_out(h_V)
+         log_probs = F.log_softmax(logits, dim=-1)
+         return log_probs
+
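+
+ # End-to-end scoring sketch (illustrative only; every shape and value below is
+ # an assumption for the demo, and dropout/coordinate noise stay active because
+ # the model is left in train mode):
+ def _demo_protein_mpnn_forward():
+     B, L = 1, 12
+     model = ProteinMPNN(num_letters=21, node_features=128, edge_features=128,
+                         hidden_dim=128, k_neighbors=8)
+     X = torch.randn(B, L, 4, 3)                  # backbone coordinates (N, Ca, C, O)
+     S = torch.randint(0, 21, (B, L))             # integer-encoded sequence
+     mask = torch.ones(B, L)
+     chain_M = torch.ones(B, L)                   # all positions designable
+     residue_idx = torch.arange(L)[None, :]
+     chain_encoding_all = torch.ones(B, L)
+     randn = torch.randn(B, L)                    # random decoding order
+     log_probs = model(X, S, mask, chain_M, residue_idx, chain_encoding_all, randn)
+     return log_probs.shape                       # torch.Size([1, 12, 21])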