Spaces:
Running
Running
Upload 11 files
Browse files- app.py +148 -0
- gradio_app.py +164 -0
- inference.py +268 -0
- layers.py +106 -0
- loss.py +36 -0
- models.py +93 -0
- new_dataloader.py +311 -0
- packages.txt +1 -0
- requirements.txt +12 -0
- training_data.py +31 -0
- utils.py +421 -0
app.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import streamlit_ext as ste
|
3 |
+
|
4 |
+
from inference import Inference
|
5 |
+
import random
|
6 |
+
from rdkit.Chem import Draw
|
7 |
+
from rdkit import Chem
|
8 |
+
from rdkit.Chem.Draw import IPythonConsole
|
9 |
+
import io
|
10 |
+
from PIL import Image
|
11 |
+
|
12 |
+
class DrugGENConfig:
    """Default inference settings for the ``DrugGEN`` submodel.

    An instance of this class is passed to ``Inference``, which reads its
    settings as plain attributes.
    """

    # Model architecture.
    submodel = 'DrugGEN'
    act = 'relu'
    max_atom = 45
    dim = 32
    depth = 1
    heads = 8
    mlp_ratio = 3
    dropout = 0.
    features = False

    # Sampling.
    inference_sample_num = 1000
    inf_batch_size = 1

    # Data and experiment locations.
    protein_data_dir = 'data/akt'
    drug_index = 'data/drug_smiles.index'
    drug_data_dir = 'data/akt'
    mol_data_dir = 'data'
    log_dir = 'experiments/logs'
    model_save_dir = 'experiments/models'
    sample_dir = 'experiments/samples'
    result_dir = "experiments/tboard_output"
    inf_dataset_file = "chembl45_test.pt"
    inf_drug_dataset_file = 'akt_test.pt'
    inf_raw_file = 'data/chembl_test.smi'
    inf_drug_raw_file = "data/akt_test.smi"
    inference_model = "experiments/models/DrugGEN"

    # Logging and reproducibility.
    log_sample_step = 1000
    set_seed = False
    seed = 1

    @property
    def sample_num(self):
        """Alias of ``inference_sample_num``.

        ``Inference.__init__`` reads ``config.sample_num``; without this alias
        the lookup raises ``AttributeError``.
        """
        return self.inference_sample_num

    @sample_num.setter
    def sample_num(self, value):
        self.inference_sample_num = value
|
40 |
+
|
41 |
+
class NoTargetConfig(DrugGENConfig):
    """Settings for the ``NoTarget`` submodel.

    Overrides only the submodel name, the model width and the checkpoint
    directory; every other value is inherited from ``DrugGENConfig``.
    """

    submodel = "NoTarget"
    dim = 128
    inference_model = "experiments/models/NoTarget"
|
45 |
+
|
46 |
+
|
47 |
+
# One shared configuration instance per model that has a configuration,
# keyed by the name used to look it up.
model_configs = {
    "DrugGEN": DrugGENConfig(),
    "NoTarget": NoTargetConfig(),
}
|
51 |
+
|
52 |
+
|
53 |
+
# --- Sidebar: project links, model descriptions and the generation form. ---
with st.sidebar:
    st.title("DrugGEN: Target Centric De Novo Design of Drug Candidate Molecules with Graph Generative Deep Adversarial Networks")
    st.write("[![arXiv](https://img.shields.io/badge/arXiv-2302.07868-b31b1b.svg)](https://arxiv.org/abs/2302.07868) [![github-repository](https://img.shields.io/badge/GitHub-black?logo=github)](https://github.com/HUBioDataLab/DrugGEN)")

    with st.expander("Expand to display information about models"):
        st.write("""
### Model Variations
- **DrugGEN-Prot**: composed of two GANs, incorporates protein features to the transformer decoder module of GAN2 (together with the de novo molecules generated by GAN1) to direct the target centric molecule design.
- **DrugGEN-CrossLoss**: composed of one GAN, the input of the GAN1 generator is the real molecules dataset and the GAN1 discriminator compares the generated molecules with the real inhibitors of the given target.
- **DrugGEN-NoTarget**: composed of one GAN, focuses on learning the chemical properties from the ChEMBL training dataset, no target-specific generation.

""")

    with st.form("model_selection_from"):
        model_name = st.radio(
            'Select a model to make inference (DrugGEN-Prot and DrugGEN-CrossLoss models design molecules to target the AKT1 protein)',
            ('DrugGEN-Prot', 'DrugGEN-CrossLoss', 'DrugGEN-NoTarget')
        )

        # The configuration table and output folders use the name without the prefix.
        model_name = model_name.replace("DrugGEN-", "")

        molecule_num_input = st.number_input('Number of molecules to generate', min_value=1, max_value=100_000, value=1000, step=1)

        seed_input = st.number_input("RNG seed value (can be used for reproducibility):", min_value=0, value=42, step=1)

        submitted = st.form_submit_button("Start Computing")


# --- Main area: run inference for the submitted form and show the results. ---
if not submitted:
    st.warning("Please select a model to make inference")
elif model_name not in model_configs:
    # Not every selectable model has a configuration entry; report that
    # instead of failing with a KeyError.
    st.error(f"No inference configuration is registered for the {model_name} model.")
else:
    import os
    import time

    config = model_configs[model_name]

    # Inference reads `sample_num`; `inference_sample_num` is kept in sync.
    config.inference_sample_num = molecule_num_input
    config.sample_num = molecule_num_input
    # The seed only takes effect when `set_seed` is enabled.
    config.seed = seed_input
    config.set_seed = True

    # NOTE(review): this path must match where Inference.inference() writes its
    # SMILES file -- confirm against the working directory of the deployment.
    output_path = f'experiments/inference/{model_name}/inference_drugs.txt'

    # The inference routine appends to its output file, so drop results left
    # over from a previous run before generating new ones.
    if os.path.exists(output_path):
        os.remove(output_path)

    with st.spinner(f'Creating the trainer class instance for {model_name}...'):
        trainer = Inference(config)
    with st.spinner(f'Running inference function of {model_name} (this may take a while) ...'):
        started = time.perf_counter()
        results = trainer.inference()
        runtime = time.perf_counter() - started
    st.success(f"Inference of {model_name} took {runtime:.2f} seconds.")

    with st.expander("Expand to see the generation performance scores"):
        st.write("### Generation performance scores (novelty is calculated in comparison to the training dataset)")
        st.success(f"Validity: {results['Validity']}")
        st.success(f"Uniqueness: {results['Uniqueness']}")
        st.success(f"Novelty: {results['Novelty (Train)']}")

    if not os.path.exists(output_path):
        st.error(f"The generated molecules file was not found at {output_path}.")
    else:
        with open(output_path) as f:
            inference_drugs = f.read()
        ste.download_button(label="Click to download generated molecules", data=inference_drugs, file_name=f'DrugGEN-{model_name}_denovo_mols.smi', mime="text/plain")

        # Skip blank lines (the file ends with a newline) so that no empty
        # SMILES string ends up in the preview.
        generated_molecule_list = [line for line in inference_drugs.splitlines() if line.strip()]

        if not generated_molecule_list:
            st.warning("No valid molecules were generated.")
        else:
            st.write("Structures of randomly selected 12 de novo molecules from the inference set:")

            # Sample without replacement, and never more than what is available.
            preview_size = min(12, len(generated_molecule_list))
            selected_molecules = random.sample(generated_molecule_list, k=preview_size)
            selected_molecules = [Chem.MolFromSmiles(mol) for mol in selected_molecules]

            molecule_image = Draw.MolsToGridImage(
                selected_molecules,
                molsPerRow=3,
                subImgSize=(250, 250),
                maxMols=len(selected_molecules),
                returnPNG=False,
                highlightAtomLists=None,
                highlightBondLists=None,
            )
            molecule_image.save("result_grid.png")
            st.image(molecule_image)
|
148 |
+
|
gradio_app.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from inference import Inference
|
3 |
+
import PIL
|
4 |
+
from PIL import Image
|
5 |
+
import pandas as pd
|
6 |
+
import random
|
7 |
+
from rdkit import Chem
|
8 |
+
from rdkit.Chem import Draw
|
9 |
+
from rdkit.Chem.Draw import IPythonConsole
|
10 |
+
import shutil
|
11 |
+
|
12 |
+
class DrugGENConfig:
    """Default settings for the ``DrugGEN`` submodel.

    An instance of this class is passed to ``Inference``, which reads its
    settings as plain attributes.
    """

    # Model architecture.
    submodel = 'DrugGEN'
    act = 'relu'
    max_atom = 45
    dim = 32
    depth = 1
    heads = 8
    mlp_ratio = 3
    dropout = 0.
    features = False

    # Sampling.
    inference_sample_num = 1000
    inf_batch_size = 1

    # Data and experiment locations.
    protein_data_dir = 'data/akt'
    drug_index = 'data/drug_smiles.index'
    drug_data_dir = 'data/akt'
    mol_data_dir = 'data'
    log_dir = 'experiments/logs'
    model_save_dir = 'experiments/models'
    inference_model = "experiments/models/DrugGEN"
    sample_dir = 'experiments/samples'
    result_dir = "experiments/tboard_output"

    # Training datasets.
    dataset_file = "chembl45_train.pt"
    drug_dataset_file = "akt_train.pt"
    raw_file = 'data/chembl_train.smi'
    drug_raw_file = "data/akt_train.smi"

    # Inference datasets.
    inf_dataset_file = "chembl45_test.pt"
    inf_drug_dataset_file = 'akt_test.pt'
    inf_raw_file = 'data/chembl_test.smi'
    inf_drug_raw_file = "data/akt_test.smi"

    # Logging and reproducibility.
    log_sample_step = 1000
    set_seed = False
    seed = 1

    @property
    def sample_num(self):
        """Alias of ``inference_sample_num``.

        ``Inference.__init__`` reads ``config.sample_num``; without this alias
        the lookup raises ``AttributeError``.
        """
        return self.inference_sample_num

    @sample_num.setter
    def sample_num(self, value):
        self.inference_sample_num = value
|
44 |
+
|
45 |
+
|
46 |
+
class NoTargetConfig(DrugGENConfig):
    """Settings for the ``NoTarget`` submodel.

    Overrides only the submodel name, the model width and the checkpoint
    directory; every other value is inherited from ``DrugGENConfig``.
    """

    submodel = "NoTarget"
    dim = 128
    inference_model = "experiments/models/NoTarget"
|
50 |
+
|
51 |
+
|
52 |
+
# One shared configuration instance per selectable model, keyed by the value
# of the model radio button.
model_configs = {
    "DrugGEN": DrugGENConfig(),
    "NoTarget": NoTargetConfig(),
}
|
56 |
+
|
57 |
+
|
58 |
+
|
59 |
+
def function(model_name: str, mol_num: int, seed: int) -> tuple[PIL.Image.Image, pd.DataFrame, str]:
    '''
    Run inference with the selected model and prepare the outputs for the UI.

    Args:
        model_name: key of ``model_configs`` selecting the model to run.
        mol_num: number of molecules to generate.
        seed: RNG seed used for the model run and for picking the preview molecules.

    Returns:
        image, score_df, file path

    Raises:
        gr.Error: if the run produced no molecules to display.
    '''
    import os
    import time

    config = model_configs[model_name]
    # Inference reads `sample_num`; `inference_sample_num` is kept in sync.
    config.inference_sample_num = mol_num
    config.sample_num = mol_num
    # The seed only takes effect when `set_seed` is enabled.
    config.seed = seed
    config.set_seed = True

    inferer = Inference(config)
    started = time.perf_counter()
    scores = inferer.inference()
    elapsed = time.perf_counter() - started

    # The score table has a runtime column; fill it in here unless the
    # inference routine already reports one under the same key.
    score_row = {"Runtime (seconds)": round(elapsed, 2)}
    score_row.update(scores)
    score_df = pd.DataFrame(score_row, index=[0])

    output_file_path = f'experiments/inference/{model_name}/inference_drugs.txt'

    # Moving the file also prevents the next run from appending to old results.
    # shutil.move also works when source and target are on different filesystems.
    new_path = f'{model_name}_denovo_mols.smi'
    shutil.move(output_file_path, new_path)

    with open(new_path) as f:
        inference_drugs = f.read()

    # Skip blank lines (the file ends with a newline) so that no empty SMILES
    # string ends up in the preview.
    generated_molecule_list = [line for line in inference_drugs.splitlines() if line.strip()]
    if not generated_molecule_list:
        raise gr.Error("No valid molecules were generated.")

    # Seeded, without replacement, and never more than what is available.
    rng = random.Random(seed)
    preview_size = min(12, len(generated_molecule_list))
    selected_molecules = rng.sample(generated_molecule_list, k=preview_size)
    selected_molecules = [Chem.MolFromSmiles(mol) for mol in selected_molecules]

    drawOptions = Draw.rdMolDraw2D.MolDrawOptions()
    drawOptions.prepareMolsBeforeDrawing = False
    drawOptions.bondLineWidth = 0.5

    molecule_image = Draw.MolsToGridImage(
        selected_molecules,
        molsPerRow=3,
        subImgSize=(400, 400),
        maxMols=len(selected_molecules),
        returnPNG=False,
        drawOptions=drawOptions,
        highlightAtomLists=None,
        highlightBondLists=None,
    )

    return molecule_image, score_df, new_path
|
106 |
+
|
107 |
+
|
108 |
+
|
109 |
+
# Two-column layout: inputs on the left, results on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("# DrugGEN: Target Centric De Novo Design of Drug Candidate Molecules with Graph Generative Deep Adversarial Networks")
            with gr.Row():
                gr.Markdown("[![arXiv](https://img.shields.io/badge/arXiv-2302.07868-b31b1b.svg)](https://arxiv.org/abs/2302.07868)")
                gr.Markdown("[![github-repository](https://img.shields.io/badge/GitHub-black?logo=github)](https://github.com/HUBioDataLab/DrugGEN)")

            with gr.Accordion("Expand to display information about models", open=False):
                gr.Markdown("""
### Model Variations
- **DrugGEN** is the default model. The input of the generator is the real molecules (ChEMBL) dataset (to ease the learning process) and the discriminator compares the generated molecules with the real inhibitors of the given target protein.
- **NoTarget** is the non-target-specific version of DrugGEN. This model only focuses on learning the chemical properties from the ChEMBL training dataset.
""")
            model_name = gr.Radio(
                choices=("DrugGEN", "NoTarget"),
                value="DrugGEN",
                label="Select a model to make inference",
                info=" DrugGEN model design molecules to target the AKT1 protein"
            )

            num_molecules = gr.Number(
                label="Number of molecules to generate",
                precision=0,  # integer input
                minimum=1,
                value=1000,
                maximum=10_000,
            )
            seed_num = gr.Number(
                label="RNG seed value (can be used for reproducibility):",
                precision=0,  # integer input
                minimum=0,
                value=42,
            )

            submit_button = gr.Button(
                value="Start Generating"
            )

        with gr.Column(scale=2):
            scores_df = gr.Dataframe(
                label="Scores",
                headers=["Runtime (seconds)", "Validity", "Uniqueness", "Novelty (Train)", "Novelty (Inference)", "Novelty (AKT)", "MaxLen", "MeanAtomType", "SNN (ChEMBL)", "SNN (AKT)"],
            )
            file_download = gr.File(
                label="Click to download generated molecules",
            )
            image_output = gr.Image(
                label="Structures of randomly selected 12 de novo molecules from the inference set:"
            )

    # The three outputs line up with the (image, score table, file path) tuple
    # returned by `function`.
    submit_button.click(function, inputs=[model_name, num_molecules, seed_num], outputs=[image_output, scores_df, file_download], api_name="inference")

# Process one request at a time.
demo.queue(concurrency_count=1)
demo.launch()
|
inference.py
ADDED
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import pickle
|
4 |
+
import random
|
5 |
+
from tqdm import tqdm
|
6 |
+
import argparse
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from torch_geometric.loader import DataLoader
|
10 |
+
import torch.utils.data
|
11 |
+
from rdkit import RDLogger
|
12 |
+
torch.set_num_threads(5)
|
13 |
+
RDLogger.DisableLog('rdApp.*')
|
14 |
+
|
15 |
+
from utils import *
|
16 |
+
from models import Generator
|
17 |
+
from new_dataloader import DruggenDataset
|
18 |
+
from loss import generator_loss
|
19 |
+
from training_data import load_molecules
|
20 |
+
|
21 |
+
|
22 |
+
class Inference(object):
    """Inference class for DrugGEN.

    Builds the inference dataset and the generator from a configuration
    object, then generates molecules and computes metrics in ``inference``.
    """

    def __init__(self, config):
        """Read the configuration, load the dataset and build the generator.

        Args:
            config: object exposing the settings as attributes (an argparse
                namespace or a plain configuration class instance).
        """
        if config.set_seed:
            np.random.seed(config.seed)
            random.seed(config.seed)
            torch.manual_seed(config.seed)
            torch.cuda.manual_seed_all(config.seed)

            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

            os.environ["PYTHONHASHSEED"] = str(config.seed)

            print(f'Using seed {config.seed}')

        self.device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

        # Initialize configurations
        self.submodel = config.submodel
        self.inference_model = config.inference_model

        # Some configuration objects name this setting `inference_sample_num`;
        # accept either spelling.
        sample_num = getattr(config, "sample_num", None)
        if sample_num is None:
            sample_num = config.inference_sample_num
        self.sample_num = sample_num

        # Data loader.
        self.inf_raw_file = config.inf_raw_file          # Text file with the SMILES of the dataset (full path).
        self.inf_dataset_file = config.inf_dataset_file  # Dataset file name.
        self.inf_batch_size = config.inf_batch_size
        self.mol_data_dir = config.mol_data_dir          # Directory where the dataset files are stored.
        self.dataset_name = self.inf_dataset_file.split(".")[0]
        self.max_atom = config.max_atom                  # One-shot generation: the max atom count must be fixed.
        self.features = config.features                  # False: atom types are the only node features.

        # Custom molecular graph dataset built from the SMILES file.
        self.inf_dataset = DruggenDataset(self.mol_data_dir,
                                          self.inf_dataset_file,
                                          self.inf_raw_file,
                                          self.max_atom,
                                          self.features)

        self.inf_loader = DataLoader(self.inf_dataset,
                                     shuffle=True,
                                     batch_size=self.inf_batch_size,
                                     drop_last=True)

        # Atom and bond type dimensions for the construction of the model.
        self.atom_decoders = self.decoder_load("atom")  # e.g. 0:0, 1:6 (C), 2:7 (N), 3:8 (O), 4:9 (F)
        self.bond_decoders = self.decoder_load("bond")  # e.g. 0: no bond, 1: single, 2: double, 3: triple, 4: aromatic
        self.m_dim = len(self.atom_decoders) if not self.features else int(self.inf_loader.dataset[0].x.shape[1])  # Atom type dimension.
        self.b_dim = len(self.bond_decoders)  # Bond type dimension.
        self.vertexes = int(self.inf_loader.dataset[0].x.shape[0])  # Number of nodes in the graph.

        # Transformer and Convolution configurations.
        self.act = config.act
        self.dim = config.dim
        self.depth = config.depth
        self.heads = config.heads
        self.mlp_ratio = config.mlp_ratio
        self.dropout = config.dropout

        self.build_model()

    def build_model(self):
        """Create the generator and move it to the selected device."""
        self.G = Generator(self.act,
                           self.vertexes,
                           self.b_dim,
                           self.m_dim,
                           self.dropout,
                           dim=self.dim,
                           depth=self.depth,
                           heads=self.heads,
                           mlp_ratio=self.mlp_ratio,
                           submodel=self.submodel)

        self.print_network(self.G, 'G')

        self.G.to(self.device)

    def _data_path(self, *parts):
        """Return the path of a data file below the configured data directory.

        Falls back to the historical hard-coded ``DrugGEN/data`` location when
        the file only exists there.
        """
        configured = os.path.join(self.mol_data_dir, *parts)
        legacy = os.path.join("DrugGEN", "data", *parts)
        if not os.path.exists(configured) and os.path.exists(legacy):
            return legacy
        return configured

    def _output_dir(self):
        """Return the directory that receives the generated SMILES file.

        It sits next to the data directory: ``DrugGEN/data`` gives
        ``DrugGEN/experiments/inference/<submodel>`` and ``data`` gives
        ``experiments/inference/<submodel>``.
        """
        project_root = os.path.dirname(os.path.normpath(self.mol_data_dir))
        return os.path.join(project_root, "experiments", "inference", self.submodel)

    @staticmethod
    def _read_smiles(path):
        """Return the lines of a SMILES file as a list of strings."""
        with open(path, 'r') as smiles_file:
            return smiles_file.read().splitlines()

    def decoder_load(self, dictionary_name):
        ''' Loading the atom and bond decoders'''
        decoder_path = self._data_path("decoders", dictionary_name + "_" + self.dataset_name + '.pkl')
        # NOTE: pickle executes code on load; only use decoder files you trust.
        with open(decoder_path, 'rb') as f:
            return pickle.load(f)

    def print_network(self, model, name):
        """Print out the network information."""
        num_params = sum(p.numel() for p in model.parameters())
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))

    def restore_model(self, submodel, model_directory):
        """Restore the trained generator from ``<model_directory>/<submodel>-G.ckpt``."""
        print('Loading the model...')
        G_path = os.path.join(model_directory, '{}-G.ckpt'.format(submodel))
        self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))

    def inference(self):
        """Generate molecules with the trained generator and score them.

        Valid SMILES are appended to ``inference_drugs.txt`` in the output
        directory. Generation stops once ``sample_num`` valid molecules (or
        ``sample_num`` invalid ones) were produced, or when the data loader
        is exhausted.

        Returns:
            dict mapping metric names to their values, plus the generation
            time under ``"Runtime (seconds)"``.
        """
        # Load the trained generator.
        self.restore_model(self.submodel, self.inference_model)

        # smiles data for metrics calculation.
        chembl_smiles = self._read_smiles(self._data_path("chembl_train.smi"))
        chembl_test = self._read_smiles(self._data_path("chembl_test.smi"))
        drug_smiles = self._read_smiles(self._data_path("akt_inhibitors.smi"))
        drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
        drug_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in drug_mols if x is not None]

        # Make directories if not exist.
        output_dir = self._output_dir()
        os.makedirs(output_dir, exist_ok=True)
        output_file = os.path.join(output_dir, "inference_drugs.txt")

        self.G.eval()

        start_time = time.time()
        metric_calc_dr = []   # every generated entry, None for invalid molecules
        uniqueness_calc = []  # valid SMILES only
        real_smiles_snn = []
        # Node-type indices of the valid molecules. Collected in a list so that
        # no uninitialized placeholder row ends up in the metric input.
        node_chunks = []
        none_counter = 0

        # Inference mode
        with torch.inference_mode(), open(output_file, "a") as out_file:
            with tqdm(total=self.sample_num) as pbar:
                pbar.set_description('Inference mode for {} model started'.format(self.submodel))
                for data in self.inf_loader:
                    # Preprocess dataset
                    _, a_tensor, x_tensor = load_molecules(
                        data=data,
                        batch_size=self.inf_batch_size,
                        device=self.device,
                        b_dim=self.b_dim,
                        m_dim=self.m_dim,
                    )

                    _, _, node_sample, edge_sample = self.G(a_tensor, x_tensor)

                    g_edges_hat_sample = torch.max(edge_sample, -1)[1]
                    g_nodes_hat_sample = torch.max(node_sample, -1)[1]

                    fake_mol_g = [self.inf_dataset.matrices2mol_drugs(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True, file_name=self.dataset_name)
                                  for e_, n_ in zip(g_edges_hat_sample, g_nodes_hat_sample)]

                    a_tensor_sample = torch.max(a_tensor, -1)[1]
                    x_tensor_sample = torch.max(x_tensor, -1)[1]
                    real_mols = [self.inf_dataset.matrices2mol_drugs(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True, file_name=self.dataset_name)
                                 for e_, n_ in zip(a_tensor_sample, x_tensor_sample)]

                    inference_drugs = [None if mol is None else Chem.MolToSmiles(mol) for mol in fake_mol_g]
                    # Keep only the largest fragment of a multi-fragment SMILES.
                    inference_drugs = [None if smi is None else max(smi.split('.'), key=len) for smi in inference_drugs]

                    for smiles, node_ids in zip(inference_drugs, g_nodes_hat_sample):
                        if smiles is None:
                            none_counter += 1
                        else:
                            # post-process * to Carbon atom in valid molecules
                            smiles = smiles.replace("*", "C")
                            out_file.write(smiles)
                            out_file.write("\n")
                            uniqueness_calc.append(smiles)
                            node_chunks.append(node_ids.reshape(1, -1, 1))
                            pbar.update(1)
                        metric_calc_dr.append(smiles)

                    # `uniqueness_calc` holds exactly the valid molecules, so
                    # its length is the number generated so far.
                    if len(uniqueness_calc) >= self.sample_num or none_counter >= self.sample_num:
                        break
                    real_smiles_snn.append(real_mols[0])

        if node_chunks:
            nodes_sample = torch.cat(node_chunks, 0).to(torch.float32)
        else:
            nodes_sample = torch.empty((0, self.vertexes, 1), device=self.device)

        et = time.time() - start_time
        gen_mols = (Chem.MolFromSmiles(x) for x in uniqueness_calc)
        gen_vecs = [AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=1024) for mol in gen_mols if mol is not None]
        real_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in real_smiles_snn if x is not None]
        print("Inference mode is lasted for {:.2f} seconds".format(et))

        print("Metrics calculation started using MOSES.")

        return {
            "Runtime (seconds)": et,
            "Validity": fraction_valid(metric_calc_dr),
            "Uniqueness": fraction_unique(uniqueness_calc),
            "Novelty (Train)": novelty(metric_calc_dr, chembl_smiles),
            "Novelty (Inference)": novelty(metric_calc_dr, chembl_test),
            "Novelty (AKT)": novelty(metric_calc_dr, drug_smiles),
            "MaxLen": Metrics.max_component(uniqueness_calc, self.vertexes),
            "MeanAtomType": Metrics.mean_atom_type(nodes_sample),
            "SNN (ChEMBL)": average_agg_tanimoto(np.array(real_vecs), np.array(gen_vecs)),
            "SNN (AKT)": average_agg_tanimoto(np.array(drug_vecs), np.array(gen_vecs)),
        }
|
236 |
+
|
237 |
+
|
238 |
+
# Command-line entry point: parse the settings and run one inference pass.
if __name__=="__main__":
    parser = argparse.ArgumentParser()

    # Inference configuration.
    parser.add_argument('--submodel', type=str, default="DrugGEN", help="Chose model subtype: DrugGEN, NoTarget", choices=['DrugGEN', 'NoTarget'])
    # Without a checkpoint directory the model cannot be restored, so fail
    # at argument parsing instead of later with a None path.
    parser.add_argument('--inference_model', type=str, required=True, help="Path to the model for inference")
    parser.add_argument('--sample_num', type=int, default=10000, help='inference samples')

    # Data configuration.
    parser.add_argument('--inf_dataset_file', type=str, default='chembl45_test.pt')
    parser.add_argument('--inf_raw_file', type=str, default='DrugGEN/data/chembl_test.smi')
    parser.add_argument('--inf_batch_size', type=int, default=1, help='Batch size for inference')
    parser.add_argument('--mol_data_dir', type=str, default='DrugGEN/data')
    parser.add_argument('--features', type=str2bool, default=False, help='features dimension for nodes')

    # Model configuration.
    parser.add_argument('--act', type=str, default="relu", help="Activation function for the model.", choices=['relu', 'tanh', 'leaky', 'sigmoid'])
    parser.add_argument('--max_atom', type=int, default=45, help='Max atom number for molecules must be specified.')
    parser.add_argument('--dim', type=int, default=128, help='Dimension of the Transformer Encoder model for the GAN.')
    parser.add_argument('--depth', type=int, default=1, help='Depth of the Transformer model from the GAN.')
    parser.add_argument('--heads', type=int, default=8, help='Number of heads for the MultiHeadAttention module from the GAN.')
    parser.add_argument('--mlp_ratio', type=int, default=3, help='MLP ratio for the Transformer.')
    parser.add_argument('--dropout', type=float, default=0., help='dropout rate')

    # Seed configuration.
    # `type=bool` would turn any non-empty string (including "False") into
    # True, so use the same converter as `--features`.
    parser.add_argument('--set_seed', type=str2bool, default=False, help='set seed for reproducibility')
    parser.add_argument('--seed', type=int, default=1, help='seed for reproducibility')

    config = parser.parse_args()
    inference = Inference(config)
    inference.inference()
|
layers.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
import math
|
5 |
+
|
6 |
+
class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> ReLU -> Linear -> Dropout."""

    def __init__(self, in_feat, hid_feat=None, out_feat=None, dropout=0.):
        """
        Args:
            in_feat: size of the input features.
            hid_feat: size of the hidden layer; falls back to ``in_feat``.
            out_feat: size of the output features; falls back to ``in_feat``.
            dropout: dropout probability applied to the output.
        """
        super().__init__()

        hid_feat = hid_feat or in_feat
        out_feat = out_feat or in_feat

        self.fc1 = nn.Linear(in_feat, hid_feat)
        self.act = torch.nn.ReLU()
        self.fc2 = nn.Linear(hid_feat, out_feat)
        self.droprateout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the network to the last dimension of ``x``."""
        hidden = self.act(self.fc1(x))
        return self.droprateout(self.fc2(hidden))
|
25 |
+
|
26 |
+
class Attention_new(nn.Module):
    """Multi-head attention over node features that is modulated by edge features.

    Returns updated node features and updated edge features.
    """

    def __init__(self, dim, heads, attention_dropout=0.):
        """
        Args:
            dim: feature size of nodes and edges; must be divisible by ``heads``.
            heads: number of attention heads.
            attention_dropout: accepted for interface compatibility; not used.
        """
        super().__init__()

        assert dim % heads == 0

        self.heads = heads
        self.scale = 1./dim**0.5
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.e = nn.Linear(dim, dim)
        self.d_k = dim // heads
        self.out_e = nn.Linear(dim, dim)
        self.out_n = nn.Linear(dim, dim)

    def forward(self, node, edge):
        """
        Args:
            node: tensor of shape (batch, n, dim).
            edge: tensor viewed as (batch, n, n, heads, dim // heads) after projection.

        Returns:
            Tuple ``(node, edge)`` with the updated node and edge features.
        """
        _, n, c = node.shape
        head_dim = c // self.heads

        # Split every projection into heads.
        q_embed = self.q(node).view(-1, n, self.heads, head_dim)
        k_embed = self.k(node).view(-1, n, self.heads, head_dim)
        v_embed = self.v(node).view(-1, n, self.heads, head_dim)
        e_embed = self.e(edge).view(-1, n, n, self.heads, head_dim)

        # Broadcast queries over the key axis and keys over the query axis so
        # that the element-wise product holds one entry per node pair.
        q_embed = q_embed.unsqueeze(2)
        k_embed = k_embed.unsqueeze(1)

        attn = q_embed * k_embed
        attn = attn / math.sqrt(self.d_k)
        # Modulate the pairwise scores with the projected edge features.
        attn = attn * (e_embed + 1) * e_embed

        # The pre-softmax scores are the new edge representation.
        edge = self.out_e(attn.flatten(3))

        attn = F.softmax(attn, dim=2)

        # Weight the values and sum over the key axis.
        v_embed = v_embed.unsqueeze(1)
        v_embed = attn * v_embed
        v_embed = v_embed.sum(dim=2).flatten(2)

        node = self.out_n(v_embed)
        return node, edge
|
68 |
+
|
69 |
+
|
70 |
+
class Encoder_Block(nn.Module):
    """Transformer encoder block that updates node and edge features together."""

    def __init__(self, dim, heads, act, mlp_ratio=4, drop_rate=0.):
        """
        Args:
            dim: feature size of nodes and edges.
            heads: number of attention heads.
            act: accepted for interface compatibility; not used in this block.
            mlp_ratio: hidden size of the MLPs as a multiple of ``dim``.
            drop_rate: dropout probability passed to the attention and MLP layers.
        """
        super().__init__()

        self.ln1 = nn.LayerNorm(dim)
        self.attn = Attention_new(dim, heads, drop_rate)
        self.ln3 = nn.LayerNorm(dim)
        self.ln4 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, dim*mlp_ratio, dim, dropout=drop_rate)
        self.mlp2 = MLP(dim, dim*mlp_ratio, dim, dropout=drop_rate)
        self.ln5 = nn.LayerNorm(dim)
        self.ln6 = nn.LayerNorm(dim)

    def forward(self, x, y):
        """Return updated ``(x, y)``: ``x`` are node features, ``y`` edge features."""
        # Attention on the normalized nodes.
        x1 = self.ln1(x)
        x2, y1 = self.attn(x1, y)

        # Residual connections: nodes add the normalized input, edges the raw input.
        x2 = self.ln3(x1 + x2)
        y2 = self.ln4(y1 + y)

        # Separate feed-forward networks for nodes and edges, each with a residual.
        x = self.ln5(x2 + self.mlp(x2))
        y = self.ln6(y2 + self.mlp2(y2))
        return x, y
|
93 |
+
|
94 |
+
|
95 |
+
class TransformerEncoder(nn.Module):
    """Stack of ``depth`` encoder blocks applied one after another."""

    def __init__(self, dim, depth, heads, act, mlp_ratio=4, drop_rate=0.1):
        """
        Args:
            dim: feature size of nodes and edges.
            depth: number of encoder blocks.
            heads: number of attention heads per block.
            act: forwarded to every encoder block.
            mlp_ratio: hidden size of the block MLPs as a multiple of ``dim``.
            drop_rate: dropout probability forwarded to every block.
        """
        super().__init__()

        self.Encoder_Blocks = nn.ModuleList([
            Encoder_Block(dim, heads, act, mlp_ratio, drop_rate)
            for _ in range(depth)])

    def forward(self, x, y):
        """Pass node features ``x`` and edge features ``y`` through every block."""
        # `block` avoids shadowing the Encoder_Block class inside this method.
        for block in self.Encoder_Blocks:
            x, y = block(x, y)
        return x, y
|
loss.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def discriminator_loss(generator, discriminator, mol_graph, batch_size, device, grad_pen, lambda_gp, z_edge, z_node):
    """Compute the discriminator loss with a gradient penalty term.

    Args:
        generator: Callable ``generator(z_edge, z_node)`` returning
            ``(node, edge, node_sample, edge_sample)``.
        discriminator: Callable mapping a flattened graph batch to logits.
        mol_graph: Flattened real graphs; first dimension is the batch.
        batch_size: Batch size used to flatten the generated samples.
        device: Device on which the interpolation coefficients are placed.
        grad_pen: Callable ``grad_pen(logits, inputs)`` returning the penalty.
        lambda_gp: Weight of the gradient penalty in the total loss.
        z_edge: Edge noise fed to the generator.
        z_node: Node noise fed to the generator.

    Returns:
        Tuple ``(node, edge, d_loss)``: the first two generator outputs and the
        scalar discriminator loss.
    """
    # Loss on real molecules.
    logits_real_disc = discriminator(mol_graph)
    prediction_real = - torch.mean(logits_real_disc)

    # Loss on fake molecules. The fake graph is detached once so that nothing
    # in the discriminator objective back-propagates into the generator.
    node, edge, node_sample, edge_sample = generator(z_edge, z_node)
    graph = torch.cat((node_sample.view(batch_size, -1), edge_sample.view(batch_size, -1)), dim=-1)
    fake_graph = graph.detach()
    logits_fake_disc = discriminator(fake_graph)
    prediction_fake = torch.mean(logits_fake_disc)

    # Gradient penalty on random interpolations between real and fake graphs.
    # Previously the non-detached graph was used here, which made the penalty
    # accumulate gradients in the generator (and the noise tensors) during the
    # discriminator update.
    eps = torch.rand(mol_graph.size(0),1).to(device)
    x_int0 = (eps * mol_graph + (1. - eps) * fake_graph).requires_grad_(True)
    grad0 = discriminator(x_int0)
    d_loss_gp = grad_pen(grad0, x_int0)

    # Total loss.
    d_loss = prediction_fake + prediction_real + d_loss_gp * lambda_gp
    return node, edge, d_loss
|
24 |
+
|
25 |
+
|
26 |
+
def generator_loss(generator, discriminator, adj, annot, batch_size):
    """Compute the generator loss from the discriminator's score on fakes.

    Args:
        generator: Callable ``generator(adj, annot)`` returning
            ``(node, edge, node_sample, edge_sample)``.
        discriminator: Callable mapping a flattened graph batch to logits.
        adj: First generator input.
        annot: Second generator input.
        batch_size: Batch size used to flatten the generated samples.

    Returns:
        Tuple ``(g_loss, node, edge, node_sample, edge_sample)`` where
        ``g_loss`` is the negated mean discriminator logit.
    """
    node, edge, node_sample, edge_sample = generator(adj, annot)

    # Flatten node and edge samples and concatenate them per molecule.
    flat_nodes = node_sample.view(batch_size, -1)
    flat_edges = edge_sample.view(batch_size, -1)
    graph = torch.cat((flat_nodes, flat_edges), dim=-1)

    g_loss = - torch.mean(discriminator(graph))

    return g_loss, node, edge, node_sample, edge_sample
|
models.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from layers import TransformerEncoder
|
4 |
+
|
5 |
+
class Generator(nn.Module):
    """Generator network.

    Embeds node and edge inputs with small MLPs, refines both with a
    ``TransformerEncoder`` and projects them back to per-node / per-edge
    class scores.

    Args:
        act: Activation name (``"relu"``, ``"leaky"``, ``"sigmoid"``,
            ``"tanh"``) or an already constructed ``nn.Module``.
        vertexes: Number of graph vertices (stored; used for size bookkeeping).
        edges: Size of the last dimension of the edge input / edge output.
        nodes: Size of the last dimension of the node input / node output.
        dropout: Dropout probability of the embedding layers and the encoder.
        dim: Embedding width.
        depth: Number of encoder blocks.
        heads: Number of attention heads.
        mlp_ratio: MLP hidden width multiplier inside the encoder.
        submodel: Stored as-is; not used inside this class.

    Raises:
        ValueError: If ``act`` is a string that names no known activation.
    """

    def __init__(self, act, vertexes, edges, nodes, dropout, dim, depth, heads, mlp_ratio, submodel):
        super(Generator, self).__init__()
        self.submodel = submodel
        self.vertexes = vertexes
        self.edges = edges
        self.nodes = nodes
        self.depth = depth
        self.dim = dim
        self.heads = heads
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout

        # Resolve activation names up front; an unknown name used to slip
        # through as a plain string and fail later with an opaque error.
        if isinstance(act, str):
            activations = {
                "relu": nn.ReLU,
                "leaky": nn.LeakyReLU,
                "sigmoid": nn.Sigmoid,
                "tanh": nn.Tanh,
            }
            if act not in activations:
                raise ValueError(
                    "Unknown activation {!r}; expected one of {}".format(act, sorted(activations)))
            act = activations[act]()

        self.features = vertexes * vertexes * edges + vertexes * nodes
        self.transformer_dim = vertexes * vertexes * dim + vertexes * dim
        self.pos_enc_dim = 5

        self.node_layers = nn.Sequential(nn.Linear(nodes, 64), act, nn.Linear(64,dim), act, nn.Dropout(self.dropout))
        self.edge_layers = nn.Sequential(nn.Linear(edges, 64), act, nn.Linear(64,dim), act, nn.Dropout(self.dropout))
        self.TransformerEncoder = TransformerEncoder(dim=self.dim, depth=self.depth, heads=self.heads, act = act,
                                                     mlp_ratio=self.mlp_ratio, drop_rate=self.dropout)

        self.readout_e = nn.Linear(self.dim, edges)
        self.readout_n = nn.Linear(self.dim, nodes)
        self.softmax = nn.Softmax(dim = -1)

    def _generate_square_subsequent_mask(self, sz):
        """Return a ``(sz, sz)`` float mask: 0 on/below the diagonal, -inf above."""
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def laplacian_positional_enc(self, adj):
        """Return Laplacian eigenvector positional encodings for one graph.

        Builds the symmetric normalised Laplacian ``I - D^-1/2 A D^-1/2`` from
        the non-zero pattern of the square 2-D ``adj`` and returns the
        eigenvectors belonging to the ``pos_enc_dim`` smallest eigenvalues
        after the first one, as a ``(n, <=pos_enc_dim)`` real tensor.

        The previous implementation multiplied ``D * A * D`` element-wise with
        the raw degree matrix, which discarded every off-diagonal entry of
        ``A`` and did not normalise by the inverse square-root degree.
        """
        # Degrees are counted from non-zero entries, so the adjacency is
        # binarised to stay consistent with that degree definition.
        A = (adj != 0).to(torch.get_default_dtype())
        degree = A.sum(dim=-1)
        # Isolated vertices get degree 1 to avoid division by zero.
        d_inv_sqrt = degree.clamp(min=1).pow(-0.5)
        L = torch.eye(A.shape[0], device=A.device) - d_inv_sqrt.unsqueeze(1) * A * d_inv_sqrt.unsqueeze(0)

        EigVal, EigVec = torch.linalg.eig(L)
        idx = torch.argsort(torch.real(EigVal))
        EigVec = torch.real(EigVec[:, idx])
        pos_enc = EigVec[:, 1:self.pos_enc_dim + 1]
        return pos_enc

    def forward(self, z_e, z_n):
        """Map edge/node inputs to ``(node, edge, node_sample, edge_sample)``.

        ``z_n`` must be 3-D and ``z_e`` 4-D; the unpacking below enforces that.
        """
        b, n, c = z_n.shape
        _, _, _ , d = z_e.shape

        node = self.node_layers(z_n)
        edge = self.edge_layers(z_e)
        # Symmetrise the edge embedding over its two vertex axes.
        edge = (edge + edge.permute(0, 2, 1, 3)) / 2

        node, edge = self.TransformerEncoder(node,edge)

        node_sample = self.readout_n(node)
        edge_sample = self.readout_e(edge)
        return node, edge, node_sample, edge_sample
|
71 |
+
|
72 |
+
|
73 |
+
class simple_disc(nn.Module):
    """Fully connected discriminator over a flattened molecular graph.

    Args:
        act: Activation name (``"relu"``, ``"leaky"``, ``"sigmoid"``,
            ``"tanh"``); any other value is handed to ``nn.Sequential`` as-is.
        m_dim: Per-vertex feature width of the flattened node part.
        vertexes: Number of vertices per graph.
        b_dim: Per-edge feature width of the flattened edge part.
    """

    def __init__(self, act, m_dim, vertexes, b_dim):
        super().__init__()

        factory = {
            "relu": nn.ReLU,
            "leaky": nn.LeakyReLU,
            "sigmoid": nn.Sigmoid,
            "tanh": nn.Tanh,
        }.get(act)
        if factory is not None:
            act = factory()

        # Input width: flattened node features plus flattened edge features.
        features = vertexes * m_dim + vertexes * vertexes * b_dim

        # Linear layers shrink the width step by step; the same activation
        # instance follows every hidden layer and the last layer emits 1 logit.
        widths = (features, 256, 128, 64, 32, 16)
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(act)
        layers.append(nn.Linear(16, 1))
        self.predictor = nn.Sequential(*layers)

    def forward(self, x):
        """Return one logit per row of ``x``."""
        return self.predictor(x)
|
new_dataloader.py
ADDED
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pickle
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from rdkit import Chem
|
5 |
+
from torch_geometric.data import (Data, InMemoryDataset)
|
6 |
+
import os.path as osp
|
7 |
+
from tqdm import tqdm
|
8 |
+
import re
|
9 |
+
from rdkit import RDLogger
|
10 |
+
import pandas as pd
|
11 |
+
|
12 |
+
RDLogger.DisableLog('rdApp.*')
|
13 |
+
class DruggenDataset(InMemoryDataset):
    """In-memory graph dataset built from a file of SMILES strings.

    Each molecule is converted to one-hot atom features, a bond-type labelled
    adjacency matrix and the matching sparse ``edge_index`` / ``edge_attr``.
    Atom and bond encoders/decoders are derived from the data and pickled
    under ``DrugGEN/data/encoders`` and ``DrugGEN/data/decoders``.

    Args:
        root: Directory that holds the processed dataset file.
        dataset_file: Name of the processed file inside ``root``.
        raw_files: Path of the raw SMILES file read by ``process``.
        max_atom: Molecules with more atoms than this are dropped.
        features: When true, extra per-atom features are appended to ``x``.
        transform, pre_transform, pre_filter: Forwarded to ``InMemoryDataset``.
    """

    def __init__(self, root, dataset_file, raw_files, max_atom, features, transform=None, pre_transform=None, pre_filter=None):
        self.dataset_name = dataset_file.split(".")[0]
        self.dataset_file = dataset_file
        self.raw_files = raw_files
        self.max_atom = max_atom
        self.features = features
        super().__init__(root, transform, pre_transform, pre_filter)
        path = osp.join(self.processed_dir, dataset_file)
        # NOTE(review): newer PyTorch releases default torch.load to
        # weights_only=True, which may reject pickled graph objects - confirm
        # against the pinned torch version.
        self.data, self.slices = torch.load(path)
        self.root = root

    @property
    def processed_dir(self):
        # Processed files live directly in ``root``.
        return self.root

    @property
    def raw_file_names(self):
        return self.raw_files

    @property
    def processed_file_names(self):
        return self.dataset_file

    @staticmethod
    def _dump_pickle(directory, file_name, obj):
        """Pickle ``obj`` to ``directory + file_name``, creating the directory."""
        import os
        os.makedirs(directory, exist_ok=True)
        with open(directory + file_name, "wb") as handle:
            pickle.dump(obj, handle)

    def _generate_encoders_decoders(self, data):
        """Build and persist atom/bond label maps; return the filtered data.

        Args:
            data: Iterable of SMILES strings.

        Returns:
            ``(max_length, smiles_list)``: the largest atom count among kept
            molecules and the SMILES that parsed and fit within ``max_atom``.
        """
        self.data = data
        print('Creating atoms and bonds encoder and decoder..')

        atom_labels = set()
        bond_labels = set()
        max_length = 0
        smiles_list = []
        for smiles in tqdm(data):
            mol = Chem.MolFromSmiles(smiles)
            # MolFromSmiles returns None for unparsable input; skip those
            # entries instead of crashing on the attribute access below.
            if mol is None:
                continue
            molecule_size = mol.GetNumAtoms()
            if molecule_size > self.max_atom:
                continue
            smiles_list.append(smiles)
            atom_labels.update([atom.GetAtomicNum() for atom in mol.GetAtoms()])
            max_length = max(max_length, molecule_size)
            bond_labels.update([bond.GetBondType() for bond in mol.GetBonds()])

        atom_labels.update([0])  # add PAD symbol (for unknown atoms)
        atom_labels = sorted(atom_labels)  # turn set into list and sort it

        bond_labels = sorted(bond_labels)
        bond_labels = [Chem.rdchem.BondType.ZERO] + bond_labels

        self.atom_encoder_m = {l: i for i, l in enumerate(atom_labels)}
        self.atom_decoder_m = {i: l for i, l in enumerate(atom_labels)}
        self.atom_num_types = len(atom_labels)
        print('Created atoms encoder and decoder with {} atom types and 1 PAD symbol!'.format(
            self.atom_num_types - 1))
        print("atom_labels", atom_labels)

        print("bond labels", bond_labels)
        self.bond_encoder_m = {l: i for i, l in enumerate(bond_labels)}
        self.bond_decoder_m = {i: l for i, l in enumerate(bond_labels)}
        self.bond_num_types = len(bond_labels)
        print('Created bonds encoder and decoder with {} bond types and 1 PAD symbol!'.format(
            self.bond_num_types - 1))

        self._dump_pickle("DrugGEN/data/encoders/", "atom_" + self.dataset_name + ".pkl", self.atom_encoder_m)
        self._dump_pickle("DrugGEN/data/decoders/", "atom_" + self.dataset_name + ".pkl", self.atom_decoder_m)
        self._dump_pickle("DrugGEN/data/encoders/", "bond_" + self.dataset_name + ".pkl", self.bond_encoder_m)
        self._dump_pickle("DrugGEN/data/decoders/", "bond_" + self.dataset_name + ".pkl", self.bond_decoder_m)

        return max_length, smiles_list  # data is filtered now

    def _genA(self, mol, connected=True, max_length=None):
        """Return the bond-label adjacency matrix padded to ``max_length``.

        Returns ``None`` unless ``connected`` is true and every atom has at
        least one bond.
        """
        max_length = max_length if max_length is not None else mol.GetNumAtoms()

        A = np.zeros(shape=(max_length, max_length))

        begin, end = [b.GetBeginAtomIdx() for b in mol.GetBonds()], [b.GetEndAtomIdx() for b in mol.GetBonds()]
        bond_type = [self.bond_encoder_m[b.GetBondType()] for b in mol.GetBonds()]

        # Fill both triangles so the matrix is symmetric.
        A[begin, end] = bond_type
        A[end, begin] = bond_type

        degree = np.sum(A[:mol.GetNumAtoms(), :mol.GetNumAtoms()], axis=-1)

        return A if connected and (degree > 0).all() else None

    def _genX(self, mol, max_length=None):
        """Return encoded atom labels, right-padded with 0 to ``max_length``."""
        max_length = max_length if max_length is not None else mol.GetNumAtoms()

        return np.array([self.atom_encoder_m[atom.GetAtomicNum()] for atom in mol.GetAtoms()] + [0] * (
            max_length - mol.GetNumAtoms()))

    def _genF(self, mol, max_length=None):
        """Return per-atom indicator features, zero-padded to ``max_length`` rows."""
        max_length = max_length if max_length is not None else mol.GetNumAtoms()

        features = np.array([[*[a.GetDegree() == i for i in range(5)],
                              *[a.GetExplicitValence() == i for i in range(9)],
                              *[int(a.GetHybridization()) == i for i in range(1, 7)],
                              *[a.GetImplicitValence() == i for i in range(9)],
                              a.GetIsAromatic(),
                              a.GetNoImplicit(),
                              *[a.GetNumExplicitHs() == i for i in range(5)],
                              *[a.GetNumImplicitHs() == i for i in range(5)],
                              *[a.GetNumRadicalElectrons() == i for i in range(5)],
                              a.IsInRing(),
                              *[a.IsInRingSize(i) for i in range(2, 9)]] for a in mol.GetAtoms()], dtype=np.int32)

        return np.vstack((features, np.zeros((max_length - features.shape[0], features.shape[1]))))

    def decoder_load(self, dictionary_name, file):
        """Load the pickled decoder ``<dictionary_name>_<file>.pkl``."""
        # NOTE: pickle must only be used on files this project wrote itself.
        with open("DrugGEN/data/decoders/" + dictionary_name + "_" + file + '.pkl', 'rb') as f:
            return pickle.load(f)

    def drugs_decoder_load(self, dictionary_name):
        """Load the pickled decoder ``<dictionary_name>.pkl``."""
        with open("DrugGEN/data/decoders/" + dictionary_name +'.pkl', 'rb') as f:
            return pickle.load(f)

    @staticmethod
    def _build_mol(node_labels, edge_labels, atom_decoders, bond_decoders, strict):
        """Assemble an RDKit molecule from label arrays and decoder maps."""
        mol = Chem.RWMol()
        RDLogger.DisableLog('rdApp.*')

        for node_label in node_labels:
            mol.AddAtom(Chem.Atom(atom_decoders[node_label]))

        # Only the lower triangle is used so each bond is added once.
        for start, end in zip(*np.nonzero(edge_labels)):
            if start > end:
                mol.AddBond(int(start), int(end), bond_decoders[edge_labels[start, end]])

        if strict:
            try:
                Chem.SanitizeMol(mol)
            except Exception:
                # Chemically invalid graphs are reported as None; interpreter
                # exits and keyboard interrupts are no longer swallowed.
                mol = None

        return mol

    def matrices2mol(self, node_labels, edge_labels, strict=True, file_name=None):
        """Decode label matrices into a molecule (``None`` if sanitisation fails)."""
        atom_decoders = self.decoder_load("atom", file_name)
        bond_decoders = self.decoder_load("bond", file_name)
        return self._build_mol(node_labels, edge_labels, atom_decoders, bond_decoders, strict)

    def drug_decoder_load(self, dictionary_name, file):
        ''' Loading the atom and bond decoders '''
        with open("DrugGEN/data/decoders/" + dictionary_name +"_" + file +'.pkl', 'rb') as f:
            return pickle.load(f)

    def matrices2mol_drugs(self, node_labels, edge_labels, strict=True, file_name=None):
        """Same as ``matrices2mol`` but reads decoders via ``drug_decoder_load``."""
        atom_decoders = self.drug_decoder_load("atom", file_name)
        bond_decoders = self.drug_decoder_load("bond", file_name)
        return self._build_mol(node_labels, edge_labels, atom_decoders, bond_decoders, strict)

    def check_valency(self,mol):
        """
        Checks that no atoms in the mol have exceeded their possible
        valency
        :return: ``(True, None)`` if no valency issues, otherwise
            ``(False, numbers)`` with the integers found after ``#`` in the
            sanitisation error message
        """
        try:
            Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES)
            return True, None
        except ValueError as e:
            e = str(e)
            p = e.find('#')
            e_sub = e[p:]
            atomid_valence = list(map(int, re.findall(r'\d+', e_sub)))
            return False, atomid_valence

    def correct_mol(self,x):
        """Remove bonds from over-valent atoms until the valency check passes.

        On every failed check the highest-order bond of the offending atom is
        removed. The loop stops when the check passes or when the offending
        atom has no bond left to remove (previously that case never ended).
        """
        mol = x
        while True:
            flag, atomid_valence = self.check_valency(mol)
            if flag:
                break
            assert len (atomid_valence) == 2
            idx = atomid_valence[0]
            queue = [
                (b.GetIdx(), int(b.GetBondType()), b.GetBeginAtomIdx(), b.GetEndAtomIdx())
                for b in mol.GetAtomWithIdx(idx).GetBonds()
            ]
            if not queue:
                # Nothing left to strip from this atom: give up rather than spin.
                break
            queue.sort(key=lambda tup: tup[1], reverse=True)
            start = queue[0][2]
            end = queue[0][3]
            mol.RemoveBond(start, end)

        return mol

    def label2onehot(self, labels, dim):
        """Convert label indices to one-hot vectors."""
        out = torch.zeros(list(labels.size())+[dim])
        out.scatter_(len(out.size())-1,labels.unsqueeze(-1),1.)

        return out.float()

    def process(self, size= None):
        """Read the raw SMILES file, build graph objects and save them."""
        smiles_list = pd.read_csv(self.raw_files, header=None)[0].tolist()
        max_length, smiles_list = self._generate_encoders_decoders(smiles_list)

        data_list = []

        self.m_dim = len(self.atom_decoder_m)
        for smiles in tqdm(smiles_list, desc='Processing chembl dataset', total=len(smiles_list)):
            mol = Chem.MolFromSmiles(smiles)
            A = self._genA(mol, connected=True, max_length=max_length)
            if A is None:
                continue

            x = torch.from_numpy(self._genX(mol, max_length=max_length)).to(torch.long).view(1, -1)

            x = self.label2onehot(x,self.m_dim).squeeze()
            if self.features:
                f = torch.from_numpy(self._genF(mol, max_length=max_length)).to(torch.long).view(x.shape[0], -1)
                x = torch.concat((x,f), dim=-1)

            adjacency = torch.from_numpy(A)

            edge_index = adjacency.nonzero(as_tuple=False).t().contiguous()
            edge_attr = adjacency[edge_index[0], edge_index[1]].to(torch.long)

            data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, smiles=smiles)

            if self.pre_filter is not None and not self.pre_filter(data):
                continue

            if self.pre_transform is not None:
                data = self.pre_transform(data)

            data_list.append(data)

        torch.save(self.collate(data_list), osp.join(self.processed_dir, self.dataset_file))
|
305 |
+
|
306 |
+
|
307 |
+
|
308 |
+
|
309 |
+
# Script entry point: builds the dataset rooted at "DrugGEN/data".
# NOTE(review): DruggenDataset.__init__ in this file also requires
# dataset_file, raw_files, max_atom and features; this call passes only
# `root`, so it looks like it would raise TypeError - confirm and supply them.
if __name__ == '__main__':
    data = DruggenDataset("DrugGEN/data")
|
311 |
+
|
packages.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
libcairo2-dev
|
requirements.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
rdkit-pypi
|
3 |
+
tqdm
|
4 |
+
numpy
|
5 |
+
seaborn
|
6 |
+
matplotlib
|
7 |
+
pandas
|
8 |
+
torch_geometric
|
9 |
+
# demo related installs
|
10 |
+
streamlit
|
11 |
+
ipython
|
12 |
+
streamlit-ext
|
training_data.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch_geometric.utils as geoutils
|
3 |
+
from utils import label2onehot
|
4 |
+
|
5 |
+
def generate_z_values(batch_size=32, z_dim=32, vertexes=32, b_dim=32, m_dim=32, device=None):
    """Sample standard-normal noise tensors that track gradients.

    Returns:
        Tuple ``(z, z_edge, z_node)`` with shapes ``(batch_size, z_dim)``,
        ``(batch_size, vertexes, vertexes, b_dim)`` and
        ``(batch_size, vertexes, m_dim)``; each has ``requires_grad=True``.
    """
    shapes = (
        (batch_size, z_dim),                       # global noise
        (batch_size, vertexes, vertexes, b_dim),   # edge noise
        (batch_size, vertexes, m_dim),             # node noise
    )
    # The tensors are drawn in this fixed order so seeded runs stay reproducible.
    z, z_edge, z_node = (
        torch.normal(mean=0, std=1, size=shape, device=device).float().requires_grad_(True)
        for shape in shapes
    )
    return z, z_edge, z_node
|
14 |
+
|
15 |
+
|
16 |
+
def load_molecules(data=None, b_dim=32, m_dim=32, device=None, batch_size=32):
    """Turn a batched graph object into dense, flattened real-graph tensors.

    Args:
        data: Batched graph exposing ``x``, ``edge_index``, ``edge_attr`` and
            ``batch``; it is moved to ``device`` first.
        b_dim: Number of classes used to one-hot encode the dense adjacency.
        m_dim: Unused; kept for signature compatibility.
        device: Target device.
        batch_size: Number of graphs in the batch.

    Returns:
        Tuple ``(real_graphs, a_tensor, x_tensor)``: the per-graph
        concatenation of flattened node and edge tensors, the one-hot
        adjacency and the reshaped node features.
    """
    data = data.to(device)

    # Total node count divided by the batch size
    # (presumably every graph is padded to the same size - verify upstream).
    nodes_per_graph = int(data.batch.shape[0]/batch_size)

    dense_adj = geoutils.to_dense_adj(
        edge_index = data.edge_index,
        batch=data.batch,
        edge_attr=data.edge_attr,
        max_num_nodes=nodes_per_graph
    )
    x_tensor = data.x.view(batch_size, nodes_per_graph, -1)
    a_tensor = label2onehot(dense_adj, b_dim, device)

    flat_parts = (x_tensor.reshape(batch_size,-1), a_tensor.reshape(batch_size,-1))
    real_graphs = torch.concat(flat_parts, dim=-1)

    return real_graphs, a_tensor, x_tensor
|
utils.py
ADDED
@@ -0,0 +1,421 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from statistics import mean
|
2 |
+
import os
|
3 |
+
import math
|
4 |
+
import time
|
5 |
+
import datetime
|
6 |
+
from rdkit import DataStructs
|
7 |
+
from rdkit import Chem
|
8 |
+
from rdkit import RDLogger
|
9 |
+
from rdkit.Chem import AllChem
|
10 |
+
from rdkit.Chem import Draw
|
11 |
+
from rdkit.Chem.Scaffolds import MurckoScaffold
|
12 |
+
import numpy as np
|
13 |
+
import matplotlib.pyplot as plt
|
14 |
+
from matplotlib.lines import Line2D
|
15 |
+
import torch
|
16 |
+
import wandb
|
17 |
+
RDLogger.DisableLog('rdApp.*')
|
18 |
+
import warnings
|
19 |
+
from multiprocessing import Pool
|
20 |
+
class Metrics(object):
    """Static helpers that score generated molecules."""

    @staticmethod
    def valid(x):
        """Return True when ``x`` is a molecule with a non-empty SMILES."""
        return x is not None and Chem.MolToSmiles(x) != ''

    @staticmethod
    def tanimoto_sim_1v2(data1, data2):
        """Return the mean pairwise fingerprint similarity of two sequences.

        Fingerprints are compared position by position up to the length of
        the shorter sequence; ``0.0`` is returned when there is no pair.

        The previous version picked the larger size (and ``data2`` itself as
        a fallback) for the loop bound and averaged the last similarity value
        instead of the collected list.
        """
        sims = [DataStructs.FingerprintSimilarity(fp1, fp2) for fp1, fp2 in zip(data1, data2)]
        if not sims:
            return 0.0
        return mean(sims)

    @staticmethod
    def mol_length(x):
        """Count the alphabetic characters of the longest SMILES fragment.

        ``x`` is a SMILES string (fragments separated by ``.``) or ``None``,
        for which 0 is returned. The longest fragment is now selected by
        length; before, ``max`` compared the fragments lexicographically.
        """
        if x is not None:
            largest = max(x.split(sep ="."), key=len)
            return len([char for char in largest.upper() if char.isalpha()])
        else:
            return 0

    @staticmethod
    def max_component(data, max_len):
        """Return the mean of ``mol_length`` over ``data`` divided by ``max_len``."""
        return ((np.array(list(map(Metrics.mol_length, data)), dtype=np.float32)/max_len).mean())

    @staticmethod
    def mean_atom_type(data):
        """Return the mean number of distinct labels per entry, minus one.

        Each entry must offer ``.unique().tolist()`` (e.g. a tensor of atom
        labels); one is subtracted, presumably for the padding label.
        """
        atom_types_used = [len(i.unique().tolist()) for i in data]
        av_type = np.mean(atom_types_used) - 1

        return av_type
|
57 |
+
|
58 |
+
|
59 |
+
def sim_reward(mol_gen, fps_r):
    """Score generated molecules by scaffold similarity to reference prints.

    Args:
        mol_gen: Iterable of molecules; ``None`` entries are ignored.
        fps_r: Reference fingerprints handed to ``average_agg_tanimoto``.

    Returns:
        The value of ``average_agg_tanimoto(fps_r, fps)`` over the Murcko
        scaffolds of the generated molecules, or ``1`` when no scaffold could
        be computed or the result is NaN.
    """
    gen_scaf = []

    for x in mol_gen:
        if x is not None:
            try:
                gen_scaf.append(MurckoScaffold.GetScaffoldForMol(x))
            except Exception:
                # Best effort: molecules without a computable scaffold are
                # skipped. Only ordinary errors are ignored now, so Ctrl-C and
                # interpreter exit are no longer swallowed.
                pass

    if not gen_scaf:
        rew = 1
    else:
        fps = [Chem.RDKFingerprint(x) for x in gen_scaf]

        fps = np.array(fps)
        fps_r = np.array(fps_r)

        rew = average_agg_tanimoto(fps_r,fps)
        if math.isnan(rew):
            rew = 1

    return rew  ## change this to penalty
|
86 |
+
|
87 |
+
##########################################
|
88 |
+
##########################################
|
89 |
+
##########################################
|
90 |
+
|
91 |
+
def mols2grid_image(mols,path):
    """Render every valid molecule to ``<path>/<position>.png`` (1-based).

    ``None`` entries are replaced by an empty molecule, which fails the
    validity check and is therefore skipped.
    """
    prepared = [Chem.RWMol() if m is None else m for m in mols]

    for position, molecule in enumerate(prepared, start=1):
        if not Metrics.valid(molecule):
            continue
        AllChem.Compute2DCoords(molecule)
        Draw.MolToFile(molecule, os.path.join(path,"{}.png".format(position)), size=(1200,1200))
|
101 |
+
|
102 |
+
def save_smiles_matrices(mols,edges_hard, nodes_hard, path, data_source = None):
    """Append edge matrix, node matrix and SMILES of each valid molecule to a file.

    For the molecule at position ``i`` the output goes to
    ``<path>/<i+1>.txt``. ``None`` entries are replaced by an empty molecule,
    which fails the validity check and is skipped.

    Args:
        mols: Sequence of molecules (or ``None``).
        edges_hard: Indexable of tensors with the edge labels per molecule.
        nodes_hard: Indexable of tensors with the node labels per molecule.
        path: Output directory.
        data_source: Unused; kept for signature compatibility.
    """
    mols = [e if e is not None else Chem.RWMol() for e in mols]

    for i, mol in enumerate(mols):
        if not Metrics.valid(mol):
            continue
        save_path = os.path.join(path,"{}.txt".format(i+1))
        with open(save_path, "a") as f:
            np.savetxt(f, edges_hard[i].cpu().numpy(), header="edge matrix:\n",fmt='%1.2f')
            f.write("\n")
            np.savetxt(f, nodes_hard[i].cpu().numpy(), header="node matrix:\n", footer="\nsmiles:",fmt='%1.2f')
            f.write("\n")
            f.write("\n")
            # The SMILES line is written through the same handle. It used to
            # go through a second, never-closed open() of the same file, which
            # leaked a descriptor and left the write order to buffering.
            f.write(Chem.MolToSmiles(mol) + "\n")
|
119 |
+
|
120 |
+
|
121 |
+
##########################################
|
122 |
+
##########################################
|
123 |
+
##########################################
|
124 |
+
|
125 |
+
|
126 |
+
def dense_to_sparse_with_attr(adj):
    """Return non-zero positions and their values of a dense adjacency.

    Args:
        adj: Square 2-D matrix or 3-D batch of square matrices.

    Returns:
        ``(index, edge_attr)`` where ``index`` is a tuple ``(row, col)`` of
        index tensors and ``edge_attr`` holds the matching entries. For
        batched input, row and column indices are shifted by
        ``batch_index * n`` so each graph gets its own index range.
    """
    rank = adj.dim()
    assert rank >= 2 and rank <= 3
    assert adj.size(-1) == adj.size(-2)

    positions = adj.nonzero(as_tuple=True)
    edge_attr = adj[positions]

    if len(positions) != 3:
        return positions, edge_attr

    graph_id, row, col = positions
    offset = graph_id * adj.size(-1)
    return (offset + row, offset + col), edge_attr
|
138 |
+
|
139 |
+
|
140 |
+
def label2onehot(labels, dim, device):
    """Convert label indices to one-hot vectors.

    A trailing axis of size ``dim`` is appended to ``labels`` and a ``1.`` is
    written at each label's position; the result lives on ``device``.
    """
    target_shape = list(labels.size()) + [dim]
    one_hot = torch.zeros(target_shape).to(device)
    last_axis = len(one_hot.size()) - 1
    one_hot.scatter_(last_axis, labels.unsqueeze(-1), 1.)
    return one_hot.float()
|
146 |
+
|
147 |
+
|
148 |
+
def mol_sample(sample_directory, edges, nodes, idx, i,matrices2mol, dataset_name):
    """Decode generated graphs and store images and matrices of the valid ones.

    Output goes to ``<sample_directory>/<idx+1>_<i+1>-epoch_iteration``; the
    folder is removed again when nothing was written into it.

    Args:
        sample_directory: Parent directory for the sample folder.
        edges: Edge scores; the arg-max over the last axis gives the labels.
        nodes: Node scores; the arg-max over the last axis gives the labels.
        idx: Zero-based first counter used in the folder name.
        i: Zero-based second counter used in the folder name.
        matrices2mol: Callable that turns label arrays into a molecule.
        dataset_name: Passed to ``matrices2mol`` as ``file_name``.
    """
    folder_name = "{}_{}-epoch_iteration".format(idx+1, i+1)
    sample_path = os.path.join(sample_directory, folder_name)

    _, edge_labels = torch.max(edges, -1)
    _, node_labels = torch.max(nodes , -1)

    mol = []
    for e_, n_ in zip(edge_labels, node_labels):
        mol.append(matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True, file_name=dataset_name))

    if not os.path.exists(sample_path):
        os.makedirs(sample_path)

    mols2grid_image(mol,sample_path)
    save_smiles_matrices(mol,edge_labels.detach(), node_labels.detach(), sample_path)

    # Drop the folder again if no molecule was valid.
    if len(os.listdir(sample_path)) == 0:
        os.rmdir(sample_path)

    print("Valid molecules are saved.")
    print("Valid matrices and smiles are saved")
|
166 |
+
|
167 |
+
|
168 |
+
def logging(log_path, start_time, i, idx, loss, save_path, drug_smiles, edge, node,
            matrices2mol, dataset_name, real_adj, real_annot, drug_vecs):
    """Evaluate one generated batch and record the metrics.

    Decodes generated and real graphs into molecules, appends the generated
    SMILES to ``<save_path>/samples.txt``, computes the metrics, adds them to
    ``loss`` (mutated in place), sends them to wandb, and appends a summary
    line to ``log_path`` as well as printing it.
    """
    def _decode(edge_labels, node_labels):
        # One molecule (or None) per graph in the batch.
        return [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True, file_name=dataset_name)
                for e_, n_ in zip(edge_labels, node_labels)]

    _, g_edges_hat_sample = torch.max(edge, -1)
    _, g_nodes_hat_sample = torch.max(node , -1)

    a_tensor_sample = torch.max(real_adj, -1)[1].float()
    x_tensor_sample = torch.max(real_annot, -1)[1].float()

    mols = _decode(g_edges_hat_sample, g_nodes_hat_sample)
    real_mol = _decode(a_tensor_sample, x_tensor_sample)

    atom_types_average = Metrics.mean_atom_type(g_nodes_hat_sample)
    real_smiles = [Chem.MolToSmiles(x) for x in real_mol if x is not None]

    # gen_smiles keeps a None placeholder per invalid molecule, uniq_smiles
    # holds only the valid ones.
    gen_smiles = [None if m is None else Chem.MolToSmiles(m) for m in mols]
    uniq_smiles = [s for s in gen_smiles if s is not None]

    # Keep only the longest fragment of multi-fragment SMILES.
    gen_smiles_saves = [None if s is None else max(s.split('.'), key=len) for s in gen_smiles]
    uniq_smiles_saves = [None if s is None else max(s.split('.'), key=len) for s in uniq_smiles]

    sample_save_dir = os.path.join(save_path, "samples.txt")
    with open(sample_save_dir, "a") as f:
        for smiles in gen_smiles_saves:
            if smiles is None:
                continue
            f.write(smiles)
            f.write("\n")

    k = len(set(uniq_smiles_saves) - {None})
    elapsed = time.time() - start_time
    elapsed = str(datetime.timedelta(seconds=elapsed))[:-7]
    log = "Elapsed [{}], Epoch/Iteration [{}/{}]".format(elapsed, idx, i+1)
    gen_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in mols if x is not None]
    chembl_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in real_mol if x is not None]

    # Metric computation.
    valid = fraction_valid(gen_smiles_saves)
    unique = fraction_unique(uniq_smiles_saves, k, check_validity=False)
    novel_starting_mol = novelty(gen_smiles_saves, real_smiles)
    novel_akt = novelty(gen_smiles_saves, drug_smiles)
    if not uniq_smiles_saves:
        snn_chembl = 0
        snn_akt = 0
        maxlen = 0
    else:
        snn_chembl = average_agg_tanimoto(np.array(chembl_vecs),np.array(gen_vecs))
        snn_akt = average_agg_tanimoto(np.array(drug_vecs),np.array(gen_vecs))
        maxlen = Metrics.max_component(uniq_smiles_saves, 45)

    metrics = {
        'Validity': valid,
        'Uniqueness': unique,
        'Novelty': novel_starting_mol,
        'Novelty_akt': novel_akt,
        'SNN_chembl': snn_chembl,
        'SNN_akt': snn_akt,
        'MaxLen': maxlen,
        'Atom_types': atom_types_average,
    }
    loss.update(metrics)
    wandb.log(dict(metrics))

    log += "".join(", {}: {:.4f}".format(tag, value) for tag, value in loss.items())
    with open(log_path, "a") as f:
        f.write(log)
        f.write("\n")
    print(log)
    print("\n")
|
246 |
+
|
247 |
+
|
248 |
+
def plot_grad_flow(named_parameters, model, itera, epoch, grad_flow_directory):
    '''Plot per-layer gradient magnitudes and save the figure as a PNG.

    Based on https://discuss.pytorch.org/t/check-gradient-flow-in-network/15063/10

    Can be used for checking for possible gradient vanishing / exploding problems.

    Usage: call after loss.backward(), e.g.
    "plot_grad_flow(net.named_parameters(), "G", itera, epoch, out_dir)".

    Parameters:
        named_parameters: iterable of (name, parameter) pairs
        model: string tag inserted into the output file name
        itera: iteration index, used in the output file name
        epoch: epoch index, used in the output file name
        grad_flow_directory: directory the PNG is written to

    The bars are drawn on the current pyplot figure; this function neither
    clears nor closes it.
    '''
    ave_grads = []
    max_grads = []
    layers = []
    for n, p in named_parameters:
        if not p.requires_grad or "bias" in n:
            continue
        # A trainable parameter that did not take part in the backward pass
        # has grad None; skip it instead of crashing on None.abs().
        if p.grad is None:
            continue
        grad_abs = p.grad.abs()
        layers.append(n)
        ave_grads.append(grad_abs.mean().cpu())
        max_grads.append(grad_abs.max().cpu())
    plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c")
    plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b")
    plt.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k")
    plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
    plt.xlim(left=0, right=len(ave_grads))
    plt.ylim(bottom=-0.001, top=1)  # zoom in on the lower gradient regions
    plt.xlabel("Layers")
    plt.ylabel("average gradient")
    plt.title("Gradient flow")
    plt.grid(True)
    plt.legend([Line2D([0], [0], color="c", lw=4),
                Line2D([0], [0], color="b", lw=4),
                Line2D([0], [0], color="k", lw=4)], ['max-gradient', 'mean-gradient', 'zero-gradient'])
    pltsavedir = grad_flow_directory
    plt.savefig(os.path.join(pltsavedir, "weights_" + model + "_" + str(itera) + "_" + str(epoch) + ".png"), dpi=500, bbox_inches='tight')
|
279 |
+
|
280 |
+
|
281 |
+
def get_mol(smiles_or_mol):
    '''
    Turn a SMILES string into an RDKit molecule.

    Non-string inputs are handed back unchanged. For strings, returns None
    when the string is empty, when Chem.MolFromSmiles gives None, or when
    Chem.SanitizeMol raises ValueError; otherwise returns the parsed molecule.
    '''
    # Anything that is not a string is passed through as-is.
    if not isinstance(smiles_or_mol, str):
        return smiles_or_mol
    if not smiles_or_mol:
        return None
    parsed = Chem.MolFromSmiles(smiles_or_mol)
    if parsed is None:
        return None
    try:
        Chem.SanitizeMol(parsed)
    except ValueError:
        return None
    return parsed
|
297 |
+
|
298 |
+
|
299 |
+
def mapper(n_jobs):
    '''
    Build a map-like callable for the requested degree of parallelism.

    n_jobs == 1: wraps the builtin map and returns a list
    n_jobs is any other int: creates a Pool(n_jobs) right away; the returned
        callable runs pool.map and then terminates that pool
    otherwise: n_jobs is treated as a pool-like object and its .map is returned
    '''
    if n_jobs == 1:
        def _serial(*args, **kwargs):
            return list(map(*args, **kwargs))

        return _serial

    if not isinstance(n_jobs, int):
        return n_jobs.map

    workers = Pool(n_jobs)

    def _parallel(*args, **kwargs):
        # The pool is terminated whether or not the map call succeeds.
        try:
            return workers.map(*args, **kwargs)
        finally:
            workers.terminate()

    return _parallel
|
323 |
+
|
324 |
+
|
325 |
+
def remove_invalid(gen, canonize=True, n_jobs=1):
    """
    Drop entries of gen that do not yield a molecule.

    With canonize=True the result holds the non-None outputs of
    canonic_smiles; with canonize=False it holds the original entries of gen
    for which get_mol returned something other than None.
    n_jobs is forwarded to mapper.
    """
    run_map = mapper(n_jobs)
    if canonize:
        canonical = run_map(canonic_smiles, gen)
        return [smi for smi in canonical if smi is not None]
    parsed = run_map(get_mol, gen)
    return [entry for entry, mol in zip(gen, parsed) if mol is not None]
|
334 |
+
|
335 |
+
|
336 |
+
def fraction_valid(gen, n_jobs=1):
    """
    Computes the fraction of valid molecules
    Parameters:
        gen: list of SMILES
        n_jobs: forwarded to mapper (1, a pool size, or a pool object)
    Returns:
        share of entries for which get_mol does not return None;
        0 for an empty gen, matching fraction_unique and novelty
    """
    gen = mapper(n_jobs)(get_mol, gen)
    # Guard the division: an empty batch has no valid molecules.
    if len(gen) == 0:
        return 0
    return 1 - gen.count(None) / len(gen)
|
345 |
+
def canonic_smiles(smiles_or_mol):
    """
    Return Chem.MolToSmiles of the molecule obtained via get_mol,
    or None when get_mol returns None.
    """
    mol = get_mol(smiles_or_mol)
    return None if mol is None else Chem.MolToSmiles(mol)
|
350 |
+
def fraction_unique(gen, k=None, n_jobs=1, check_validity=False):
    """
    Computes the fraction of unique molecules
    Parameters:
        gen: list of SMILES
        k: compute unique@k (only the first k entries are used; a warning
            is issued when gen has fewer than k entries)
        n_jobs: forwarded to mapper
        check_validity: raises ValueError if invalid molecules are present
    Returns:
        number of distinct canonic_smiles results divided by the number of
        entries considered, or 0 when there are none. When check_validity is
        False, a None result counts as one distinct value.
    """
    if k is not None:
        available = len(gen)
        if available < k:
            warnings.warn(
                "Can't compute unique@{}.".format(k) +
                "gen contains only {} molecules".format(available)
            )
        gen = gen[:k]
    distinct = set(mapper(n_jobs)(canonic_smiles, gen))
    if check_validity and None in distinct:
        raise ValueError("Invalid molecule passed to unique@k")
    if len(gen) == 0:
        return 0
    return len(distinct) / len(gen)
|
371 |
+
|
372 |
+
def novelty(gen, train, n_jobs=1):
    """
    Fraction of distinct generated molecules that do not occur in train.

    gen entries are canonicalised with canonic_smiles (None results are
    discarded); train entries are compared as given. Returns 0 when no
    generated molecule remains.
    """
    canonical = mapper(n_jobs)(canonic_smiles, gen)
    generated = {smi for smi in canonical if smi is not None}
    known = set(train)
    if not generated:
        return 0
    unseen = generated - known
    return len(unseen) / len(generated)
|
377 |
+
|
378 |
+
|
379 |
+
|
380 |
+
def average_agg_tanimoto(stock_vecs, gen_vecs,
                         batch_size=5000, agg='max',
                         device='cpu', p=1):
    """
    Mean aggregated Tanimoto similarity of gen_vecs against stock_vecs.

    Every vector in gen_vecs is scored against all vectors in stock_vecs;
    the per-vector scores are aggregated (max or mean over stock_vecs) and
    the average of those aggregates is returned.

    Parameters:
        stock_vecs: numpy array <n_vectors x dim>
        gen_vecs: numpy array <n_vectors' x dim>
        batch_size: number of rows taken from each array per block
        agg: max or mean
        device: torch device the blocks are moved to
        p: power for averaging: (mean x^p)^(1/p)
    """
    assert agg in ['max', 'mean'], "Can aggregate only max or mean"
    n_gen = len(gen_vecs)
    scores = np.zeros(n_gen)
    counts = np.zeros(n_gen)
    for s_start in range(0, stock_vecs.shape[0], batch_size):
        stock_batch = torch.tensor(
            stock_vecs[s_start:s_start + batch_size]).to(device).float()
        for g_start in range(0, gen_vecs.shape[0], batch_size):
            gen_batch = torch.tensor(
                gen_vecs[g_start:g_start + batch_size]).to(device).float()
            # Columns of gen_batch are the generated vectors after this.
            gen_batch = gen_batch.transpose(0, 1)
            window = slice(g_start, g_start + gen_batch.shape[1])

            common = torch.mm(stock_batch, gen_batch)
            union = (stock_batch.sum(1, keepdim=True) +
                     gen_batch.sum(0, keepdim=True) - common)
            sim = (common / union).cpu().numpy()
            # 0/0 (two all-zero vectors) is scored as similarity 1.
            sim[np.isnan(sim)] = 1
            if p != 1:
                sim = sim**p

            if agg == 'max':
                scores[window] = np.maximum(scores[window], sim.max(0))
            elif agg == 'mean':
                scores[window] += sim.sum(0)
                counts[window] += sim.shape[0]
    if agg == 'mean':
        scores /= counts
    if p != 1:
        scores = (scores)**(1/p)
    return np.mean(scores)
|
419 |
+
|
420 |
+
def str2bool(v):
    """Return True only when v equals 'true', ignoring case.

    The previous membership test ran against the plain string 'true'
    (parentheses without a comma do not make a tuple), so any substring
    such as '' or 't' was also accepted.
    """
    return v.lower() == 'true'
|