bshor committed
Commit 446e400
1 Parent(s): bca3a49

add script used to generate dataset

.gitignore ADDED
@@ -0,0 +1,6 @@
+ # Ignore __pycache__ folders
+ __pycache__/
+ .idea/
+
+ # Ignore .DS_Store files
+ .DS_Store
env_consts.py CHANGED
@@ -3,7 +3,7 @@ import os
  TEST_INPUT_DIR = None
  TEST_OUTPUT_DIR = None
  THIS_FILE_DIR = os.path.dirname(os.path.abspath(__file__))
- CKPT_PATH = os.path.join(THIS_FILE_DIR, "resources", "only_weights_87-172000.ckpt")
+ CKPT_PATH = os.path.join(THIS_FILE_DIR, "resources", "only_weights_107-187000.ckpt")
  RUN_CONFIG_PATH = os.path.join(THIS_FILE_DIR, "resources", "run_config.json")
 
  OUTPUT_PROT_PATH = os.path.join(THIS_FILE_DIR, "predicted_protein_out.pdb")
prepare_plinder_dataset.py ADDED
@@ -0,0 +1,807 @@
+ import json
+ import os
+ import shutil
+ import random
+ import sys
+ import time
+ from typing import List, Tuple, Optional
+
+ import Bio.PDB
+ import Bio.SeqUtils
+ import pandas as pd
+ import numpy as np
+ import requests
+ from rdkit import Chem
+ from rdkit.Chem import AllChem
+
+
+ BASE_FOLDER = "/tmp/"
+
+ OUTPUT_FOLDER = f"{BASE_FOLDER}/processed"
+ # https://storage.googleapis.com/plinder/2024-06/v2/index/annotation_table.parquet
+ PLINDER_ANNOTATIONS = f'{BASE_FOLDER}/raw_data/2024-06_v2_index_annotation_table.parquet'
+ # https://storage.googleapis.com/plinder/2024-06/v2/splits/split.parquet
+ PLINDER_SPLITS = f'{BASE_FOLDER}/raw_data/2024-06_v2_splits_split.parquet'
+
+ # https://console.cloud.google.com/storage/browser/_details/plinder/2024-06/v2/links/kind%3Dapo/links.parquet
+ PLINDER_LINKED_APO_MAP = f"{BASE_FOLDER}/raw_data/2024-06_v2_links_kind=apo_links.parquet"
+ # https://console.cloud.google.com/storage/browser/_details/plinder/2024-06/v2/links/kind%3Dpred/links.parquet
+ PLINDER_LINKED_PRED_MAP = f"{BASE_FOLDER}/raw_data/2024-06_v2_links_kind=pred_links.parquet"
+ # https://storage.googleapis.com/plinder/2024-06/v2/linked_structures/apo.zip
+ PLINDER_LINKED_APO_STRUCTURES = f"{BASE_FOLDER}/raw_data/2024-06_v2_linked_structures_apo"
+ # https://storage.googleapis.com/plinder/2024-06/v2/linked_structures/pred.zip
+ PLINDER_LINKED_PRED_STRUCTURES = f"{BASE_FOLDER}/raw_data/2024-06_v2_linked_structures_pred"
+ GSUTIL_PATH = f"{BASE_FOLDER}/google-cloud-sdk/bin/gsutil"
+
+
+ def get_cached_systems_to_train(recompute=False):
+     output_path = os.path.join(OUTPUT_FOLDER, "to_train.pickle")
+     if os.path.exists(output_path) and not recompute:
+         return pd.read_pickle(output_path)
+
+     """
+     full:
+     loaded 1357906 409726 163816 433865
+     loaded 990260 409726 125818 106411
+     joined splits 409726
+     Has splits 311008
+     unique systems 311008
+     split
+     train 309140
+     test 1036
+     val 832
+     Name: count, dtype: int64
+     Has affinity 36856
+     Has affinity by splits split
+     train 36598
+     test 142
+     val 116
+     Name: count, dtype: int64
+     Total systems before pred 311008
+     Total systems after pred 311008
+     Has pred 83487
+     Has apo 75127
+     Has both 51506
+     Has either 107108
+     columns Index(['system_id', 'entry_pdb_id', 'ligand_binding_affinity',
+            'entry_release_date', 'system_pocket_UniProt',
+            'system_num_protein_chains', 'system_num_ligand_chains', 'uniqueness',
+            'split', 'cluster', 'cluster_for_val_split',
+            'system_pass_validation_criteria', 'system_pass_statistics_criteria',
+            'system_proper_num_ligand_chains', 'system_proper_pocket_num_residues',
+            'system_proper_num_interactions',
+            'system_proper_ligand_max_molecular_weight',
+            'system_has_binding_affinity', 'system_has_apo_or_pred', '_bucket_id',
+            'linked_pred_id', 'linked_apo_id'],
+           dtype='object')
+     total systems 311008
+     """
+
+     systems = pd.read_parquet(PLINDER_ANNOTATIONS,
+                               columns=['system_id', 'entry_pdb_id', 'ligand_binding_affinity',
+                                        'entry_release_date', 'system_pocket_UniProt', 'entry_resolution',
+                                        'system_num_protein_chains', 'system_num_ligand_chains'])
+     splits = pd.read_parquet(PLINDER_SPLITS)
+     linked_pred = pd.read_parquet(PLINDER_LINKED_PRED_MAP)
+     linked_apo = pd.read_parquet(PLINDER_LINKED_APO_MAP)
+
+     print("loaded", len(systems), len(splits), len(linked_pred), len(linked_apo))
+
+     # remove duplicated
+     systems = systems.drop_duplicates(subset=['system_id'])
+     splits = splits.drop_duplicates(subset=['system_id'])
+     linked_pred = linked_pred.drop_duplicates(subset=['reference_system_id'])
+     linked_apo = linked_apo.drop_duplicates(subset=['reference_system_id'])
+     print("loaded", len(systems), len(splits), len(linked_pred), len(linked_apo))
+
+     # join splits
+     systems = pd.merge(systems, splits, on='system_id', how='inner')
+     print("joined splits", len(systems))
+
+     systems['_bucket_id'] = systems['entry_pdb_id'].str[1:3]
+
+     # leave only with train/val/test splits
+     systems = systems[systems['split'].isin(['train', 'val', 'test'])]
+
+     print("Has splits", len(systems))
+     print("unique systems", systems['system_id'].nunique())
+     print(systems["split"].value_counts())
+
+     print("Has affinity", len(systems[systems['ligand_binding_affinity'].notna()]))
+
+     # print has affinity by splits
+     print("Has affinity by splits", systems[systems['ligand_binding_affinity'].notna()]['split'].value_counts())
+
+     print("Total systems before pred", len(systems))
+     # join linked structures - allow to not have linked structures
+     systems = pd.merge(systems, linked_pred[['reference_system_id', 'id']],
+                        left_on='system_id', right_on='reference_system_id',
+                        how='left')
+     print("Total systems after pred", len(systems))
+
+     # Rename the 'id' column from linked_pred to 'linked_pred_id'
+     systems.rename(columns={'id': 'linked_pred_id'}, inplace=True)
+
+     # Merge the result with linked_apo on the same condition
+     systems = pd.merge(systems, linked_apo[['reference_system_id', 'id']],
+                        left_on='system_id', right_on='reference_system_id',
+                        how='left')
+
+     # Rename the 'id' column from linked_apo to 'linked_apo_id'
+     systems.rename(columns={'id': 'linked_apo_id'}, inplace=True)
+
+     # Drop the reference_system_id columns that were added during the merge
+     systems.drop(columns=['reference_system_id_x', 'reference_system_id_y'], inplace=True)
+
+     cluster_sizes = systems["cluster"].value_counts()
+     systems["cluster_size"] = systems["cluster"].map(cluster_sizes)
+     # print(systems[['system_id', 'cluster', 'cluster_size']])
+
+     print("Has pred", systems['linked_pred_id'].notna().sum())
+     print("Has apo", systems['linked_apo_id'].notna().sum())
+     print("Has both", (systems['linked_pred_id'].notna() & systems['linked_apo_id'].notna()).sum())
+     print("Has either", (systems['linked_pred_id'].notna() | systems['linked_apo_id'].notna()).sum())
+
+     print("columns", systems.columns)
+
+     systems.to_pickle(output_path)
+     return systems
+
+
+ def create_conformers(smiles, output_path, num_conformers=100, multiplier_samples=1):
+     target_mol = Chem.MolFromSmiles(smiles)
+     target_mol = Chem.AddHs(target_mol)
+
+     params = AllChem.ETKDGv3()
+     params.numThreads = 0  # Use all available threads
+     params.pruneRmsThresh = 0.1  # Pruning threshold for RMSD
+     conformer_ids = AllChem.EmbedMultipleConfs(target_mol, numConfs=num_conformers * multiplier_samples, params=params)
+
+     # Optional: Optimize each conformer using MMFF94 force field
+     # for conf_id in conformer_ids:
+     #     AllChem.UFFOptimizeMolecule(target_mol, confId=conf_id)
+
+     # remove hydrogen atoms
+     target_mol = Chem.RemoveHs(target_mol)
+
+     # Save aligned conformers to a file (optional)
+     w = Chem.SDWriter(output_path)
+     for i, conf_id in enumerate(conformer_ids):
+         if i >= num_conformers:
+             break
+         w.write(target_mol, confId=conf_id)
+     w.close()
+
+
+ def do_robust_chain_object_renumber(chain: Bio.PDB.Chain.Chain, new_chain_id: str) -> Optional[Bio.PDB.Chain.Chain]:
+     all_residues = [res for res in chain.get_residues()
+                     if "CA" in res and Bio.SeqUtils.seq1(res.get_resname()) not in ("X", "", " ")]
+     if not all_residues:
+         return None
+
+     res_and_res_id = [(res, res.get_id()[1]) for res in all_residues]
+
+     min_res_id = min([i[1] for i in res_and_res_id])
+     if min_res_id < 1:
+         print("Negative res id", chain, min_res_id)
+         factor = -1 * min_res_id + 1
+         res_and_res_id = [(res, res_id + factor) for res, res_id in res_and_res_id]
+
+     res_and_res_id_no_collisions = []
+     for res, res_id in res_and_res_id[::-1]:
+         if res_and_res_id_no_collisions and res_and_res_id_no_collisions[-1][1] == res_id:
+             # there is a collision, usually an insertion residue
+             res_and_res_id_no_collisions = [(i, j + 1) for i, j in res_and_res_id_no_collisions]
+         res_and_res_id_no_collisions.append((res, res_id))
+
+     first_res_id = min([i[1] for i in res_and_res_id_no_collisions])
+     factor = 1 - first_res_id  # start from 1
+     new_chain = Bio.PDB.Chain.Chain(new_chain_id)
+
+     res_and_res_id_no_collisions.sort(key=lambda x: x[1])
+
+     for res, res_id in res_and_res_id_no_collisions:
+         chain.detach_child(res.id)
+         res.id = (" ", res_id + factor, " ")
+         new_chain.add(res)
+
+     return new_chain
+
+
+ def robust_renumber_protein(pdb_path: str, output_path: str):
+     if pdb_path.endswith(".pdb"):
+         pdb_parser = Bio.PDB.PDBParser(QUIET=True)
+         pdb_struct = pdb_parser.get_structure("original_pdb", pdb_path)
+     elif pdb_path.endswith(".cif"):
+         pdb_struct = Bio.PDB.MMCIFParser().get_structure("original_pdb", pdb_path)
+     else:
+         raise ValueError("Unknown file type", pdb_path)
+     assert len(list(pdb_struct)) == 1, "can't extract if more than one model"
+     model = next(iter(pdb_struct))
+     chains = list(model.get_chains())
+     new_model = Bio.PDB.Model.Model(0)
+     chain_ids = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
+     for chain, chain_id in zip(chains, chain_ids):
+         new_chain = do_robust_chain_object_renumber(chain, chain_id)
+         if new_chain is None:
+             continue
+         new_model.add(new_chain)
+     new_struct = Bio.PDB.Structure.Structure("renumbered_pdb")
+     new_struct.add(new_model)
+     io = Bio.PDB.PDBIO()
+     io.set_structure(new_struct)
+     io.save(output_path)
+
+
+ def _get_extra(extra_to_save: int, res_before: List[int], res_after: List[int]) -> set:
+     take_from_before = random.randint(0, extra_to_save)
+     take_from_after = extra_to_save - take_from_before
+     if take_from_before > len(res_before):
+         take_from_after = extra_to_save - len(res_before)
+         take_from_before = len(res_before)
+     if take_from_after > len(res_after):
+         take_from_before = extra_to_save - len(res_after)
+         take_from_after = len(res_after)
+
+     extra_to_add = set()
+     if take_from_before > 0:
+         extra_to_add.update(res_before[-take_from_before:])
+     extra_to_add.update(res_after[:take_from_after])
+
+     return extra_to_add
+
+
+ def crop_protein_cont(gt_pdb_path: str, ligand_pos: np.ndarray, output_path: str, max_length: int,
+                       distance_threshold: float):
+     protein = Chem.MolFromPDBFile(gt_pdb_path, sanitize=False)
+     ligand_size = ligand_pos.shape[0]
+
+     pdb_parser = Bio.PDB.PDBParser(QUIET=True)
+     gt_model = next(iter(pdb_parser.get_structure("gt_pdb", gt_pdb_path)))
+
+     all_res_ids_by_chain = {chain.id: sorted([res.id[1] for res in chain.get_residues() if "CA" in res])
+                             for chain in gt_model.get_chains()}
+
+     protein_conf = protein.GetConformer()
+     protein_pos = protein_conf.GetPositions()
+     protein_atoms = list(protein.GetAtoms())
+     assert len(protein_pos) == len(protein_atoms), f"Positions and atoms mismatch in {gt_pdb_path}"
+
+     inter_dists = ligand_pos[:, np.newaxis, :] - protein_pos[np.newaxis, :, :]
+     inter_dists = np.sqrt((inter_dists ** 2).sum(-1))
+     min_inter_dist_per_protein_atom = inter_dists.min(axis=0)
+
+     res_to_save_count = max_length - ligand_size
+
+     used_protein_idx = np.where(min_inter_dist_per_protein_atom < distance_threshold)[0]
+     pocket_residues_by_chain = {}
+     for idx in used_protein_idx:
+         res = protein_atoms[idx].GetPDBResidueInfo()
+         if res.GetIsHeteroAtom():
+             continue
+         if res.GetChainId() not in pocket_residues_by_chain:
+             pocket_residues_by_chain[res.GetChainId()] = set()
+         # get residue chain
+         pocket_residues_by_chain[res.GetChainId()].add(res.GetResidueNumber())
+
+     if not pocket_residues_by_chain:
+         print("No pocket residues found")
+         return -1
+
+     # print("pocket_residues_by_chain", pocket_residues_by_chain)
+
+     complete_pocket = []
+     extended_pocket_per_chain = {}
+     for chain_id, pocket_residues in pocket_residues_by_chain.items():
+         max_pocket_res = max(pocket_residues)
+         min_pocket_res = min(pocket_residues)
+
+         extended_pocket_per_chain[chain_id] = {res_id for res_id in all_res_ids_by_chain[chain_id]
+                                                if min_pocket_res <= res_id <= max_pocket_res}
+         for res_id in extended_pocket_per_chain[chain_id]:
+             complete_pocket.append((chain_id, res_id))
+
+     # print("extended_pocket_per_chain", pocket_residues_by_chain)
+
+     if len(complete_pocket) > res_to_save_count:
+         total_res_ids = sum([len(res_ids) for res_ids in all_res_ids_by_chain.values()])
+         total_pocket_res = sum([len(res_ids) for res_ids in pocket_residues_by_chain.values()])
+         print(f"Too many residues all: {total_res_ids} pocket:{total_pocket_res} {len(complete_pocket)} "
+               f"(ligand size: {ligand_size})")
+         return -1
+
+     extra_to_save = res_to_save_count - len(complete_pocket)
+
+     # divide extra_to_save between chains
+     for chain_id, pocket_residues in extended_pocket_per_chain.items():
+         extra_to_save_per_chain = extra_to_save // len(extended_pocket_per_chain)
+         res_before = [res_id for res_id in all_res_ids_by_chain[chain_id] if res_id < min(pocket_residues)]
+         res_after = [res_id for res_id in all_res_ids_by_chain[chain_id] if res_id > max(pocket_residues)]
+         extra_to_add = _get_extra(extra_to_save_per_chain, res_before, res_after)
+         for res_id in extra_to_add:
+             complete_pocket.append((chain_id, res_id))
+
+     total_res_ids = sum([len(res_ids) for res_ids in all_res_ids_by_chain.values()])
+     total_pocket_res = sum([len(res_ids) for res_ids in pocket_residues_by_chain.values()])
+     total_extended_res = sum([len(res_ids) for res_ids in extended_pocket_per_chain.values()])
+     print(f"Found valid pocket all: {total_res_ids} pocket:{total_pocket_res} {total_extended_res} "
+           f"{len(complete_pocket)} (ligand size: {ligand_size}) extra: {extra_to_save}")
+     # print("all_res_ids_by_chain", all_res_ids_by_chain)
+     # print("complete_pocket", sorted(complete_pocket))
+
+     res_to_remove = []
+     for res in gt_model.get_residues():
+         if (res.parent.id, res.id[1]) not in complete_pocket or res.id[0].strip() != "" or res.id[2].strip() != "":
+             res_to_remove.append(res)
+     for res in res_to_remove:
+         gt_model[res.parent.id].detach_child(res.id)
+
+     io = Bio.PDB.PDBIO()
+     io.set_structure(gt_model)
+     io.save(output_path)
+
+     return len(complete_pocket)
+
+
+ def crop_protein_simple(gt_pdb_path: str, ligand_pos: np.ndarray, output_path: str, max_length: int):
+     protein = Chem.MolFromPDBFile(gt_pdb_path, sanitize=False)
+     ligand_size = ligand_pos.shape[0]
+     res_to_save_count = max_length - ligand_size
+
+     pdb_parser = Bio.PDB.PDBParser(QUIET=True)
+     gt_model = next(iter(pdb_parser.get_structure("gt_pdb", gt_pdb_path)))
+
+     protein_conf = protein.GetConformer()
+     protein_pos = protein_conf.GetPositions()
+     protein_atoms = list(protein.GetAtoms())
+     assert len(protein_pos) == len(protein_atoms), f"Positions and atoms mismatch in {gt_pdb_path}"
+
+     inter_dists = ligand_pos[:, np.newaxis, :] - protein_pos[np.newaxis, :, :]
+     inter_dists = np.sqrt((inter_dists ** 2).sum(-1))
+     min_inter_dist_per_protein_atom = inter_dists.min(axis=0)
+
+     protein_idx_by_dist = np.argsort(min_inter_dist_per_protein_atom)
+     pocket_residues_by_chain = {}
+     total_found = 0
+     for idx in protein_idx_by_dist:
+         res = protein_atoms[idx].GetPDBResidueInfo()
+         if res.GetIsHeteroAtom():
+             continue
+
+         if res.GetChainId() not in pocket_residues_by_chain:
+             pocket_residues_by_chain[res.GetChainId()] = set()
+         # get residue chain
+         pocket_residues_by_chain[res.GetChainId()].add(res.GetResidueNumber())
+         total_found = sum([len(res_ids) for res_ids in pocket_residues_by_chain.values()])
+         if total_found >= res_to_save_count:
+             break
+     print("saved with simple", total_found)
+
+     if not pocket_residues_by_chain:
+         print("No pocket residues found")
+         return -1
+
+     res_to_remove = []
+     for res in gt_model.get_residues():
+         if res.id[1] not in pocket_residues_by_chain.get(res.parent.id, set()) \
+                 or res.id[0].strip() != "" or res.id[2].strip() != "":
+             res_to_remove.append(res)
+     for res in res_to_remove:
+         gt_model[res.parent.id].detach_child(res.id)
+
+     io = Bio.PDB.PDBIO()
+     io.set_structure(gt_model)
+     io.save(output_path)
+
+     return total_found
+
+
+ def cif_to_pdb(cif_path: str, pdb_path: str):
+     protein = Bio.PDB.MMCIFParser().get_structure("s_cif", cif_path)
+     io = Bio.PDB.PDBIO()
+     io.set_structure(protein)
+     io.save(pdb_path)
+
+
+ def get_chain_object_to_seq(chain: Bio.PDB.Chain.Chain) -> str:
+     res_id_to_res = {res.get_id()[1]: res for res in chain.get_residues() if "CA" in res}
+
+     if len(res_id_to_res) == 0:
+         print("skipping empty chain", chain.get_id())
+         return ""
+     seq = ""
+     for i in range(1, max(res_id_to_res) + 1):
+         if i in res_id_to_res:
+             seq += Bio.SeqUtils.seq1(res_id_to_res[i].get_resname())
+         else:
+             seq += "X"
+     return seq
+
+
+ def get_sequence_from_pdb(pdb_path: str) -> Tuple[str, List[int]]:
+     pdb_parser = Bio.PDB.PDBParser(QUIET=True)
+     pdb_struct = pdb_parser.get_structure("original_pdb", pdb_path)
+     # chain_to_seq = {chain.id: get_chain_object_to_seq(chain) for chain in pdb_struct.get_chains()}
+     all_chain_seqs = [get_chain_object_to_seq(chain) for chain in pdb_struct.get_chains()]
+     chain_lengths = [len(seq) for seq in all_chain_seqs]
+     return ("X" * 20).join(all_chain_seqs), chain_lengths
+
+
+ from Bio import PDB
+ from Bio import pairwise2
+
+
+ def extract_sequence(chain):
+     seq = ''
+     residues = []
+     for res in chain.get_residues():
+         seq_res = Bio.SeqUtils.seq1(res.get_resname())
+         if seq_res in ('X', "", " "):
+             continue
+         seq += seq_res
+         residues.append(res)
+     return seq, residues
+
+
+ def map_residues(alignment, residues_gt, residues_pred):
+     idx_gt = 0
+     idx_pred = 0
+     mapping = []
+     for i in range(len(alignment.seqA)):
+         aa_gt = alignment.seqA[i]
+         aa_pred = alignment.seqB[i]
+         res_gt = None
+         res_pred = None
+         if aa_gt != '-':
+             res_gt = residues_gt[idx_gt]
+             idx_gt += 1
+         if aa_pred != '-':
+             res_pred = residues_pred[idx_pred]
+             idx_pred += 1
+         if res_gt and res_pred:
+             mapping.append((res_gt, res_pred))
+     return mapping
+
+
+ class ResidueSelect(PDB.Select):
+     def __init__(self, residues_to_select):
+         self.residues_to_select = set(residues_to_select)
+
+     def accept_residue(self, residue):
+         return residue in self.residues_to_select
+
+
+ def align_gt_and_input(gt_pdb_path, input_pdb_path, output_gt_path, output_input_path):
+     parser = PDB.PDBParser(QUIET=True)
+     gt_structure = parser.get_structure('gt', gt_pdb_path)
+     pred_structure = parser.get_structure('pred', input_pdb_path)
+     matched_residues_gt = []
+     matched_residues_pred = []
+
+     used_chain_pred = []
+     total_mapping_size = 0
+     for chain_gt in gt_structure.get_chains():
+         seq_gt, residues_gt = extract_sequence(chain_gt)
+         best_alignment = None
+         best_chain_pred = None
+         best_score = -1
+         best_residues_pred = None
+         # Find the best matching chain in pred
+         for chain_pred in pred_structure.get_chains():
+             print("checking", chain_pred.get_id(), chain_gt.get_id())
+             if chain_pred in used_chain_pred:
+                 continue
+             seq_pred, residues_pred = extract_sequence(chain_pred)
+             print(seq_gt)
+             print(seq_pred)
+             alignments = pairwise2.align.globalxx(seq_gt, seq_pred, one_alignment_only=True)
+             if not alignments:
+                 continue
+             print("checking2", chain_pred.get_id(), chain_gt.get_id())
+
+             alignment = alignments[0]
+             score = alignment.score
+             if score > best_score:
+                 best_score = score
+                 best_alignment = alignment
+                 best_chain_pred = chain_pred
+                 best_residues_pred = residues_pred
+         if best_alignment:
+             mapping = map_residues(best_alignment, residues_gt, best_residues_pred)
+             total_mapping_size += len(mapping)
+             used_chain_pred.append(best_chain_pred)
+             for res_gt, res_pred in mapping:
+                 matched_residues_gt.append(res_gt)
+                 matched_residues_pred.append(res_pred)
+         else:
+             print(f"No matching chain found for chain {chain_gt.get_id()}")
+     print(f"Total mapping size: {total_mapping_size}")
+
+     # Write new PDB files with only matched residues
+     io = PDB.PDBIO()
+     io.set_structure(gt_structure)
+     io.save(output_gt_path, ResidueSelect(matched_residues_gt))
+     io.set_structure(pred_structure)
+     io.save(output_input_path, ResidueSelect(matched_residues_pred))
+
+
+ def validate_matching_input_gt(gt_pdb_path, input_pdb_path):
+     gt_residues = [res for res in PDB.PDBParser().get_structure('gt', gt_pdb_path).get_residues()]
+     input_residues = [res for res in PDB.PDBParser().get_structure('input', input_pdb_path).get_residues()]
+
+     if len(gt_residues) != len(input_residues):
+         print(f"Residue count mismatch: {len(gt_residues)} vs {len(input_residues)}")
+         return -1
+
+     for res_gt, res_input in zip(gt_residues, input_residues):
+         if res_gt.get_resname() != res_input.get_resname():
+             print(f"Residue name mismatch: {res_gt.get_resname()} vs {res_input.get_resname()}")
+             return -1
+     return len(input_residues)
+
+
+ def prepare_system(row, system_folder, output_models_folder, output_jsons_folder, should_overwrite=False):
+     output_json_path = os.path.join(output_jsons_folder, f"{row['system_id']}.json")
+     if os.path.exists(output_json_path) and not should_overwrite:
+         return "Already exists"
+
+     plinder_gt_pdb_path = os.path.join(system_folder, "receptor.pdb")
+     plinder_gt_ligand_paths = []
+     plinder_gt_ligands_folder = os.path.join(system_folder, "ligand_files")
+
+     gt_output_path = os.path.join(output_models_folder, f"{row['system_id']}_gt.pdb")
+     gt_output_relative_path = "plinder_models/" + f"{row['system_id']}_gt.pdb"
+
+     tmp_input_path = os.path.join(output_models_folder, f"tmp_{row['system_id']}_input.pdb")
+     protein_input_path = os.path.join(output_models_folder, f"{row['system_id']}_input.pdb")
+     protein_input_relative_path = "plinder_models/" + f"{row['system_id']}_input.pdb"
+
+     print("Copying ground truth files")
+     if not os.path.exists(plinder_gt_pdb_path):
+         print("no receptor", plinder_gt_pdb_path)
+         return "No receptor"
+
+     tmp_gt_pdb_path = os.path.join(output_models_folder, f"tmp_{row['system_id']}_gt.pdb")
+     robust_renumber_protein(plinder_gt_pdb_path, tmp_gt_pdb_path)
+
+     ligand_pos_list = []
+     for ligand_file in os.listdir(plinder_gt_ligands_folder):
+         if not ligand_file.endswith(".sdf"):
+             continue
+         plinder_gt_ligand_paths.append(os.path.join(plinder_gt_ligands_folder, ligand_file))
+         loaded_ligand = Chem.MolFromMolFile(os.path.join(plinder_gt_ligands_folder, ligand_file))
+         if loaded_ligand is None:
+             print("failed to load", plinder_gt_ligand_paths[-1])
+             return "Failed to load ligand"
+         ligand_pos_list.append(loaded_ligand.GetConformer().GetPositions())
+
+     # Crop ground truth protein, also removes insertion codes
+     ligand_pos = np.concatenate(ligand_pos_list, axis=0)
+
+     res_count_in_protein = crop_protein_cont(tmp_gt_pdb_path, ligand_pos, gt_output_path, max_length=350,
+                                              distance_threshold=5)
+     if res_count_in_protein == -1:
+         print("Failed to crop protein continuously, using simple crop")
+         crop_protein_simple(tmp_gt_pdb_path, ligand_pos, gt_output_path, max_length=350)
+
+     os.remove(tmp_gt_pdb_path)
+
+     # Generate input protein structure
+     input_protein_source = None
+     if pd.notna(row["linked_apo_id"]):
+         apo_pdb_path = os.path.join(PLINDER_LINKED_APO_STRUCTURES, f"{row['linked_apo_id']}.cif")
+         try:
+             robust_renumber_protein(apo_pdb_path, tmp_input_path)
+             input_protein_source = "apo"
+             print("Using input apo", row['linked_apo_id'])
+         except Exception as e:
+             print("Problem with apo", e, row["linked_apo_id"], apo_pdb_path)
+     if not os.path.exists(tmp_input_path) and pd.notna(row["linked_pred_id"]):
+         pred_pdb_path = os.path.join(PLINDER_LINKED_PRED_STRUCTURES, f"{row['linked_pred_id']}.cif")
+         try:
+             # cif_to_pdb(pred_pdb_path, tmp_input_path)
+             robust_renumber_protein(pred_pdb_path, tmp_input_path)
+             input_protein_source = "pred"
+             print("Using input pred", row['linked_pred_id'])
+         except Exception as e:
+             print("Problem with pred", e)
+     if not os.path.exists(tmp_input_path):
+         print("No linked structure found, running ESM")
+         url = "https://api.esmatlas.com/foldSequence/v1/pdb/"
+         sequence, chain_lengths = get_sequence_from_pdb(gt_output_path)
+         if len(sequence) <= 400:
+             try:
+                 response = requests.post(url, data=sequence)
+                 response.raise_for_status()
+                 pdb_text = response.text
+                 with open(tmp_input_path, "w") as f:
+                     f.write(pdb_text)
+
+                 # divide to chains
+                 if len(chain_lengths) > 1:
+                     pdb_parser = Bio.PDB.PDBParser(QUIET=True)
+                     pdb_struct = pdb_parser.get_structure("original_pdb", tmp_input_path)
+                     pdb_model = next(iter(pdb_struct))
+                     chain_ids = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"[:len(chain_lengths)]
+                     start_ind = 1
+                     esm_chain = next(pdb_model.get_chains())
+                     new_model = Bio.PDB.Model.Model(0)
+                     for chain_length, chain_id in zip(chain_lengths, chain_ids):
+                         end_ind = start_ind + chain_length
+                         new_chain = Bio.PDB.Chain.Chain(chain_id)
+                         for res in esm_chain.get_residues():
+                             if start_ind <= res.id[1] <= end_ind:
+                                 new_chain.add(res)
+                         new_model.add(new_chain)
+                         start_ind = end_ind + 20  # 20 is the gap in esm
+                     io = Bio.PDB.PDBIO()
+                     io.set_structure(new_model)
+                     io.save(tmp_input_path)
+
+                 input_protein_source = "esm"
+                 print("Using input ESM")
+             except requests.exceptions.RequestException as e:
+                 print(f"An error occurred in ESM: {e}")
+                 # return "No linked structure found"
+         else:
+             print("Sequence too long for ESM")
+     if not os.path.exists(tmp_input_path):
+         print("Using input GT")
+         shutil.copyfile(gt_output_path, tmp_input_path)
+         input_protein_source = "gt"
+
+     align_gt_and_input(gt_output_path, tmp_input_path, gt_output_path, protein_input_path)
+     protein_size = validate_matching_input_gt(gt_output_path, protein_input_path)
+     assert protein_size > -1, "Failed to validate matching input and gt"
+     os.remove(tmp_input_path)
+
+     rel_gt_lig_paths = []
+     rel_ref_lig_paths = []
+     input_smiles = []
+     for i, ligand_path in enumerate(sorted(plinder_gt_ligand_paths)):
+         gt_ligand_output_path = os.path.join(output_models_folder, f"{row['system_id']}_ligand_gt_{i}.sdf")
+         # rel_gt_lig_paths.append(f"plinder_models/{row['system_id']}_ref_ligand_{i}.sdf")
+         rel_gt_lig_paths.append(f"plinder_models/{row['system_id']}_ligand_gt_{i}.sdf")
+         shutil.copyfile(ligand_path, gt_ligand_output_path)
+
+         loaded_ligand = Chem.MolFromMolFile(gt_ligand_output_path)
+         input_smiles.append(Chem.MolToSmiles(loaded_ligand))
+
+         ref_ligand_output_path = os.path.join(output_models_folder, f"{row['system_id']}_ligand_ref_{i}.sdf")
+         rel_ref_lig_paths.append(f"plinder_models/{row['system_id']}_ligand_ref_{i}.sdf")
+         create_conformers(input_smiles[-1], ref_ligand_output_path, num_conformers=1)
+         # check if file is empty
+         if os.path.getsize(ref_ligand_output_path) == 0:
+             print("Empty ref ligand, copying from gt", ref_ligand_output_path)
+             shutil.copyfile(gt_ligand_output_path, ref_ligand_output_path)
+
+     affinity = row["ligand_binding_affinity"]
+     if pd.isna(affinity):
+         affinity = None
+
+     json_data = {
+         "input_structure": protein_input_relative_path,
+         "gt_structure": gt_output_relative_path,
+         "gt_sdf_list": rel_gt_lig_paths,
+         "input_smiles_list": input_smiles,
+         "resolution": row.fillna(99)["entry_resolution"],
+         "release_year": row["entry_release_date"],
+         "affinity": affinity,
+         "protein_seq_len": protein_size,
+         "uniprot": row["system_pocket_UniProt"],
+         "ligand_num_atoms": ligand_pos.shape[0],
+         "cluster": row["cluster"],
+         "cluster_size": row["cluster_size"],
+         "input_protein_source": input_protein_source,
+         "ref_sdf_list": rel_ref_lig_paths,
+         "pdb_id": row["system_id"],
+     }
+     open(output_json_path, "w").write(json.dumps(json_data, indent=4))
+
+     return "success"
+
+     # use linked structures
+     # input_structure_to_use = None
+     # apo_linked_structure = os.path.join(linked_structures_folder, "apo", system_id)
+     # pred_linked_structure = os.path.join(linked_structures_folder, "pred", system_id)
+     # if os.path.exists(apo_linked_structure):
+     #     for folder in os.listdir(apo_linked_structure):
+     #         if not os.path.isdir(os.path.join(pred_linked_structure, folder)):
+     #             continue
+     #         for filename in os.listdir(os.path.join(apo_linked_structure, folder)):
+     #             if filename.endswith(".cif"):
+     #                 input_structure_to_use = os.path.join(apo_linked_structure, folder, filename)
+     #                 break
+     #         if input_structure_to_use:
+     #             break
+     #     print(system_id, "found apo", input_structure_to_use)
+     # elif os.path.exists(pred_linked_structure):
+     #     for folder in os.listdir(pred_linked_structure):
+     #         if not os.path.isdir(os.path.join(pred_linked_structure, folder)):
+     #             continue
+     #         for filename in os.listdir(os.path.join(pred_linked_structure, folder)):
+     #             if filename.endswith(".cif"):
+     #                 input_structure_to_use = os.path.join(pred_linked_structure, folder, filename)
+     #                 break
+     #         if input_structure_to_use:
+     #             break
+     #     print(system_id, "found pred", input_structure_to_use)
+     # else:
+     #     print(system_id, "no linked structure found")
+     #     return "No linked structure found"
+
+
+ def main(prefix_bucket_id: str = "*"):
+     os.makedirs(OUTPUT_FOLDER, exist_ok=True)
+     systems = get_cached_systems_to_train()
+     print("total systems", len(systems))
+
+     print("clusters", systems["cluster"].value_counts())
+
+     # systems = systems[systems["system_num_protein_chains"] > 1]
+     # return
+
+     print("splits", systems["split"].value_counts())
+     val_or_test = systems[(systems["split"] == "val") | (systems["split"] == "test")]
+     print("validation or test", len(val_or_test))
+
+     output_models_folder = os.path.join(OUTPUT_FOLDER, "plinder_models")
+     output_train_jsons_folder = os.path.join(OUTPUT_FOLDER, "plinder_jsons_train")
+     output_val_jsons_folder = os.path.join(OUTPUT_FOLDER, "plinder_jsons_val")
+     output_test_jsons_folder = os.path.join(OUTPUT_FOLDER, "plinder_jsons_test")
+     output_info = os.path.join(OUTPUT_FOLDER, "plinder_generation_info.csv")
+     if prefix_bucket_id != "*":
+         output_info = os.path.join(OUTPUT_FOLDER, f"plinder_generation_info_{prefix_bucket_id}.csv")
+
+     os.makedirs(output_models_folder, exist_ok=True)
+     os.makedirs(output_train_jsons_folder, exist_ok=True)
+     os.makedirs(output_val_jsons_folder, exist_ok=True)
+     os.makedirs(output_test_jsons_folder, exist_ok=True)
+
+     split_to_folder = {
+         "train": output_train_jsons_folder,
+         "val": output_val_jsons_folder,
+         "test": output_test_jsons_folder
+     }
+
+     output_info_file = open(output_info, "a+")
+
+     for bucket_id, bucket_systems in systems.groupby('_bucket_id', sort=True):
+         if prefix_bucket_id != "*" and not str(bucket_id).startswith(prefix_bucket_id):
+             continue
+         # if bucket_id != "z2":
+         #     continue
+         # systems_folder = "{BASE_FOLDER}/processed/tmp_z2/systems"
+
+         print("Starting bucket", bucket_id, len(bucket_systems))
+         print(len(bucket_systems), bucket_systems["system_num_ligand_chains"].value_counts())
+
+         tmp_output_models_folder = os.path.join(OUTPUT_FOLDER, f"tmp_{bucket_id}")
+         os.makedirs(tmp_output_models_folder, exist_ok=True)
+         os.system(f'{GSUTIL_PATH} -m cp -r "gs://plinder/2024-06/v2/systems/{bucket_id}.zip" {tmp_output_models_folder}')
+         systems_folder = os.path.join(tmp_output_models_folder, "systems")
+         os.system(f'unzip -o {os.path.join(tmp_output_models_folder, f"{bucket_id}.zip")} -d {systems_folder}')
+
+         for i, row in bucket_systems.iterrows():
+             # if not str(row['system_id']).startswith("4z22__1__1.A__1.C"):
+             #     continue
+             print("doing", row['system_id'], row["system_num_protein_chains"], row["system_num_ligand_chains"])
+             system_folder = os.path.join(systems_folder, row['system_id'])
+             try:
+                 success = prepare_system(row, system_folder, output_models_folder, split_to_folder[row["split"]])
+                 print("done", row['system_id'], success)
+                 output_info_file.write(f"{bucket_id},{row['system_id']},{success}\n")
+             except Exception as e:
+                 print("Failed", row['system_id'], e)
+                 output_info_file.write(f"{bucket_id},{row['system_id']},Failed\n")
+             output_info_file.flush()
+
+         shutil.rmtree(tmp_output_models_folder)
+
+
+ if __name__ == '__main__':
+     prefix_bucket_id = "*"
+     if len(sys.argv) > 1:
+         prefix_bucket_id = sys.argv[1]
+     main(prefix_bucket_id)
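
Usage note: the script shards PLINDER by the two-character `_bucket_id` (characters 2-3 of the PDB id) and takes an optional bucket-prefix argument, so several workers can each process a disjoint prefix. A minimal driver sketch, assuming the parquet and linked-structure inputs listed at the top of the script are already under BASE_FOLDER/raw_data and gsutil is available at GSUTIL_PATH; "z2" is just an example prefix taken from the script's own comments:

    # equivalent CLI invocation: python prepare_plinder_dataset.py z2
    from prepare_plinder_dataset import main

    main("z2")  # only buckets whose id starts with "z2"; the default "*" processes every bucket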
resources/{only_weights_87-172000.ckpt → only_weights_107-187000.ckpt} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:9024331d1e9f39084686cbae36524589dcd1ce26896da4780d462698e6ecb83c
- size 53301307
+ oid sha256:c396fa56019277eb4a112dd2bb08f6cccc9f1a7393e4861f606b42035ca4cca9
+ size 53302016