Spaces:
Runtime error
Runtime error
uragankatrrin
commited on
Commit
•
2956799
1
Parent(s):
6dd21b5
Upload 12 files
Browse files- mhnreact/.gitkeep +0 -0
- mhnreact/__init__.py +1 -0
- mhnreact/data.py +338 -0
- mhnreact/inference.py +13 -0
- mhnreact/inspect.py +95 -0
- mhnreact/model.py +660 -0
- mhnreact/molutils.py +772 -0
- mhnreact/plotutils.py +158 -0
- mhnreact/retroeval.py +240 -0
- mhnreact/train.py +804 -0
- mhnreact/utils.py +126 -0
- mhnreact/view.py +60 -0
mhnreact/.gitkeep
ADDED
File without changes
|
mhnreact/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
__version__ = "0.0.1"
|
mhnreact/data.py
ADDED
@@ -0,0 +1,338 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
Author: Philipp Seidl
|
4 |
+
ELLIS Unit Linz, LIT AI Lab, Institute for Machine Learning
|
5 |
+
Johannes Kepler University Linz
|
6 |
+
Contact: [email protected]
|
7 |
+
|
8 |
+
File contains functions that help prepare and download USPTO-related datasets
|
9 |
+
"""
|
10 |
+
|
11 |
+
import os
|
12 |
+
import gzip
|
13 |
+
import pickle
|
14 |
+
import requests
|
15 |
+
import subprocess
|
16 |
+
import pandas as pd
|
17 |
+
import numpy as np
|
18 |
+
from scipy import sparse
|
19 |
+
import json
|
20 |
+
|
21 |
+
def download_temprel_repo(save_path='data/temprel-fortunato', chunk_size=128):
|
22 |
+
"downloads the template-relevance master branch"
|
23 |
+
url = "https://gitlab.com/mefortunato/template-relevance/-/archive/master/template-relevance-master.zip"
|
24 |
+
r = requests.get(url, stream=True)
|
25 |
+
with open(save_path, 'wb') as fd:
|
26 |
+
for chunk in r.iter_content(chunk_size=chunk_size):
|
27 |
+
fd.write(chunk)
|
28 |
+
|
29 |
+
def unzip(path):
|
30 |
+
"unzips a file given a path"
|
31 |
+
import zipfile
|
32 |
+
with zipfile.ZipFile(path, 'r') as zip_ref:
|
33 |
+
zip_ref.extractall(path.replace('.zip',''))
|
34 |
+
|
35 |
+
|
36 |
+
def download_file(url, output_path=None):
|
37 |
+
"""
|
38 |
+
# code from fortunato
|
39 |
+
# could also import from temprel.data.download import get_uspto_50k but slightly altered ;)
|
40 |
+
|
41 |
+
"""
|
42 |
+
if not output_path:
|
43 |
+
output_path = url.split('/')[-1]
|
44 |
+
with requests.get(url, stream=True) as r:
|
45 |
+
r.raise_for_status()
|
46 |
+
with open(output_path, 'wb') as f:
|
47 |
+
for chunk in r.iter_content(chunk_size=8192):
|
48 |
+
if chunk:
|
49 |
+
f.write(chunk)
|
50 |
+
|
51 |
+
def get_uspto_480k():
|
52 |
+
if not os.path.exists('data'):
|
53 |
+
os.mkdir('data')
|
54 |
+
if not os.path.exists('data/raw'):
|
55 |
+
os.mkdir('data/raw')
|
56 |
+
os.chdir('data/raw')
|
57 |
+
download_file(
|
58 |
+
'https://github.com/connorcoley/rexgen_direct/raw/master/rexgen_direct/data/train.txt.tar.gz',
|
59 |
+
'train.txt.tar.gz'
|
60 |
+
)
|
61 |
+
subprocess.run(['tar', 'zxf', 'train.txt.tar.gz'])
|
62 |
+
download_file(
|
63 |
+
'https://github.com/connorcoley/rexgen_direct/raw/master/rexgen_direct/data/valid.txt.tar.gz',
|
64 |
+
'valid.txt.tar.gz'
|
65 |
+
)
|
66 |
+
subprocess.run(['tar', 'zxf', 'valid.txt.tar.gz'])
|
67 |
+
download_file(
|
68 |
+
'https://github.com/connorcoley/rexgen_direct/raw/master/rexgen_direct/data/test.txt.tar.gz',
|
69 |
+
'test.txt.tar.gz'
|
70 |
+
)
|
71 |
+
subprocess.run(['tar', 'zxf', 'test.txt.tar.gz'])
|
72 |
+
|
73 |
+
with open('train.txt') as f:
|
74 |
+
train = [
|
75 |
+
{
|
76 |
+
'reaction_smiles': line.strip(),
|
77 |
+
'split': 'train'
|
78 |
+
}
|
79 |
+
for line in f.readlines()
|
80 |
+
]
|
81 |
+
with open('valid.txt') as f:
|
82 |
+
valid = [
|
83 |
+
{
|
84 |
+
'reaction_smiles': line.strip(),
|
85 |
+
'split': 'valid'
|
86 |
+
}
|
87 |
+
for line in f.readlines()
|
88 |
+
]
|
89 |
+
with open('test.txt') as f:
|
90 |
+
test = [
|
91 |
+
{
|
92 |
+
'reaction_smiles': line.strip(),
|
93 |
+
'split': 'test'
|
94 |
+
}
|
95 |
+
for line in f.readlines()
|
96 |
+
]
|
97 |
+
|
98 |
+
df = pd.concat([
|
99 |
+
pd.DataFrame(train),
|
100 |
+
pd.DataFrame(valid),
|
101 |
+
pd.DataFrame(test)
|
102 |
+
]).reset_index()
|
103 |
+
df.to_json('uspto_lg_reactions.json.gz', compression='gzip')
|
104 |
+
os.chdir('..')
|
105 |
+
os.chdir('..')
|
106 |
+
return df
|
107 |
+
|
108 |
+
def get_uspto_50k():
|
109 |
+
'''
|
110 |
+
get SI from:
|
111 |
+
Nadine Schneider; Daniel M. Lowe; Roger A. Sayle; Gregory A. Landrum. J. Chem. Inf. Model.201555139-53
|
112 |
+
'''
|
113 |
+
if not os.path.exists('data'):
|
114 |
+
os.mkdir('data')
|
115 |
+
if not os.path.exists('data/raw'):
|
116 |
+
os.mkdir('data/raw')
|
117 |
+
os.chdir('data/raw')
|
118 |
+
subprocess.run(['wget', 'https://pubs.acs.org/doi/suppl/10.1021/ci5006614/suppl_file/ci5006614_si_002.zip'])
|
119 |
+
subprocess.run(['unzip', '-o', 'ci5006614_si_002.zip'])
|
120 |
+
data = []
|
121 |
+
with gzip.open('ChemReactionClassification/data/training_test_set_patent_data.pkl.gz') as f:
|
122 |
+
while True:
|
123 |
+
try:
|
124 |
+
data.append(pickle.load(f))
|
125 |
+
except EOFError:
|
126 |
+
break
|
127 |
+
reaction_smiles = [d[0] for d in data]
|
128 |
+
reaction_reference = [d[1] for d in data]
|
129 |
+
reaction_class = [d[2] for d in data]
|
130 |
+
df = pd.DataFrame()
|
131 |
+
df['reaction_smiles'] = reaction_smiles
|
132 |
+
df['reaction_reference'] = reaction_reference
|
133 |
+
df['reaction_class'] = reaction_class
|
134 |
+
df.to_json('uspto_sm_reactions.json.gz', compression='gzip')
|
135 |
+
os.chdir('..')
|
136 |
+
os.chdir('..')
|
137 |
+
return df
|
138 |
+
|
139 |
+
def get_uspto_golden():
|
140 |
+
""" get uspto golden and convert it to smiles dataframe from
|
141 |
+
Lin, Arkadii; Dyubankova, Natalia; Madzhidov, Timur; Nugmanov, Ramil;
|
142 |
+
Rakhimbekova, Assima; Ibragimova, Zarina; Akhmetshin, Tagir; Gimadiev,
|
143 |
+
Timur; Suleymanov, Rail; Verhoeven, Jonas; Wegner, Jörg Kurt;
|
144 |
+
Ceulemans, Hugo; Varnek, Alexandre (2020):
|
145 |
+
Atom-to-Atom Mapping: A Benchmarking Study of Popular Mapping Algorithms and Consensus Strategies.
|
146 |
+
ChemRxiv. Preprint. https://doi.org/10.26434/chemrxiv.13012679.v1
|
147 |
+
"""
|
148 |
+
if os.path.exists('data/raw/uspto_golden.json.gz'):
|
149 |
+
print('loading precomputed')
|
150 |
+
return pd.read_json('data/raw/uspto_golden.json.gz', compression='gzip')
|
151 |
+
if not os.path.exists('data'):
|
152 |
+
os.mkdir('data')
|
153 |
+
if not os.path.exists('data/raw'):
|
154 |
+
os.mkdir('data/raw')
|
155 |
+
os.chdir('data/raw')
|
156 |
+
subprocess.run(['wget', 'https://github.com/Laboratoire-de-Chemoinformatique/Reaction_Data_Cleaning/raw/master/data/golden_dataset.zip'])
|
157 |
+
subprocess.run(['unzip', '-o', 'golden_dataset.zip']) #return golden_dataset.rdf
|
158 |
+
|
159 |
+
from CGRtools.files import RDFRead
|
160 |
+
import CGRtools
|
161 |
+
from rdkit.Chem import AllChem
|
162 |
+
def cgr2rxnsmiles(cgr_rx):
|
163 |
+
smiles_rx = '.'.join([AllChem.MolToSmiles(CGRtools.to_rdkit_molecule(m)) for m in cgr_rx.reactants])
|
164 |
+
smiles_rx += '>>'+'.'.join([AllChem.MolToSmiles(CGRtools.to_rdkit_molecule(m)) for m in cgr_rx.products])
|
165 |
+
return smiles_rx
|
166 |
+
|
167 |
+
data = {}
|
168 |
+
input_file = 'golden_dataset.rdf'
|
169 |
+
do_basic_standardization=True
|
170 |
+
print('reading and converting the rdf-file')
|
171 |
+
with RDFRead(input_file) as f:
|
172 |
+
while True:
|
173 |
+
try:
|
174 |
+
r = next(f)
|
175 |
+
key = r.meta['Reaction_ID']
|
176 |
+
if do_basic_standardization:
|
177 |
+
r.thiele()
|
178 |
+
r.standardize()
|
179 |
+
data[key] = cgr2rxnsmiles(r)
|
180 |
+
except StopIteration:
|
181 |
+
break
|
182 |
+
|
183 |
+
print('saving as a dataframe to data/uspto_golden.json.gz')
|
184 |
+
df = pd.DataFrame([data],index=['reaction_smiles']).T
|
185 |
+
df['reaction_reference'] = df.index
|
186 |
+
df.index = range(len(df)) #reindex
|
187 |
+
df.to_json('uspto_golden.json.gz', compression='gzip')
|
188 |
+
|
189 |
+
os.chdir('..')
|
190 |
+
os.chdir('..')
|
191 |
+
return df
|
192 |
+
|
193 |
+
def load_USPTO_fortu(path='data/processed', which='uspto_sm_', is_appl_matrix=False):
|
194 |
+
"""
|
195 |
+
loads the fortunato preprocessed data as
|
196 |
+
dict X containing X['train'], X['valid'], and X['test']
|
197 |
+
as well as the labels containing the corresponding splits
|
198 |
+
returns X, y
|
199 |
+
"""
|
200 |
+
|
201 |
+
X = {}
|
202 |
+
y = {}
|
203 |
+
|
204 |
+
for split in ['train','valid', 'test']:
|
205 |
+
tmp = np.load(f'{path}/{which}{split}.input.smiles.npy', allow_pickle=True)
|
206 |
+
X[split] = []
|
207 |
+
for ii in range(len(tmp)):
|
208 |
+
X[split].append( tmp[ii].split('.'))
|
209 |
+
|
210 |
+
if is_appl_matrix:
|
211 |
+
y[split] = sparse.load_npz(f'{path}/{which}{split}.appl_matrix.npz')
|
212 |
+
else:
|
213 |
+
y[split] = np.load(f'{path}/{which}{split}.labels.classes.npy', allow_pickle=True)
|
214 |
+
print(split, y[split].shape[0], 'samples (', y[split].max() if not is_appl_matrix else y[split].shape[1],'max label)')
|
215 |
+
return X, y
|
216 |
+
|
217 |
+
#TODO one should load in this file pd.read_json('uspto_R_retro.templates.uspto_R_.json.gz')
|
218 |
+
# this only holds the templates.. the other holds everything
|
219 |
+
def load_templates_sm(path = 'data/processed/uspto_sm_templates.df.json.gz', get_complete_df=False):
|
220 |
+
"returns a dict mapping from class index to mapped reaction_smarts from the templates_df"
|
221 |
+
df = pd.read_json(path)
|
222 |
+
if get_complete_df: return df
|
223 |
+
template_dict = {}
|
224 |
+
for row in range(len(df)):
|
225 |
+
template_dict[df.iloc[row]['index']] = df.iloc[row].reaction_smarts
|
226 |
+
return template_dict
|
227 |
+
|
228 |
+
def load_templates_lg(path = 'data/processed/uspto_lg_templates.df.json.gz', get_complete_df=False):
|
229 |
+
return load_templates_sm(path=path, get_complete_df=get_complete_df)
|
230 |
+
|
231 |
+
def load_USPTO_sm():
|
232 |
+
"loads the default dataset"
|
233 |
+
return load_USPTO_fortu(which='uspto_sm_')
|
234 |
+
|
235 |
+
def load_USPTO_lg():
|
236 |
+
"loads the default dataset"
|
237 |
+
return load_USPTO_fortu(which='uspto_lg_')
|
238 |
+
|
239 |
+
def load_USPTO_sm_pretraining():
|
240 |
+
"loads the default application matrix label and dataset"
|
241 |
+
return load_USPTO_fortu(which='uspto_sm_', is_appl_matrix=True)
|
242 |
+
def load_USPTO_lg_pretraining():
|
243 |
+
"loads the default application matrix label and dataset"
|
244 |
+
return load_USPTO_fortu(which='uspto_lg_', is_appl_matrix=True)
|
245 |
+
|
246 |
+
def load_USPTO_df_sm():
|
247 |
+
"loads the USPTO small Sm dataset dataframe"
|
248 |
+
return pd.read_json('data/raw/uspto_sm_reactions.json.gz')
|
249 |
+
|
250 |
+
def load_USPTO_df_lg():
|
251 |
+
"loads the USPTO large Lg dataset dataframe"
|
252 |
+
return pd.read_json('data/raw/uspto_sm_reactions.json.gz')
|
253 |
+
|
254 |
+
def load_USPTO_golden():
|
255 |
+
"loads the golden USPTO dataset"
|
256 |
+
return load_USPTO_fortu(which=f'uspto_golden_', is_appl_matrix=False)
|
257 |
+
|
258 |
+
def load_USPTO(which = 'sm', is_appl_matrix=False):
|
259 |
+
return load_USPTO_fortu(which=f'uspto_{which}_', is_appl_matrix=is_appl_matrix)
|
260 |
+
|
261 |
+
def load_templates(which = 'sm',fdir='data/processed', get_complete_df=False):
|
262 |
+
return load_templates_sm(path=f'{fdir}/uspto_{which}_templates.df.json.gz', get_complete_df=get_complete_df)
|
263 |
+
|
264 |
+
def load_data(dataset, path):
|
265 |
+
splits = ['train', 'valid', 'test']
|
266 |
+
split2smiles = {}
|
267 |
+
split2label = {}
|
268 |
+
split2reactants = {}
|
269 |
+
split2appl = {}
|
270 |
+
split2prod_idx_reactants = {}
|
271 |
+
|
272 |
+
for split in splits:
|
273 |
+
label_fn = os.path.join(path, f'{dataset}_{split}.labels.classes.npy')
|
274 |
+
split2label[split] = np.load(label_fn, allow_pickle=True)
|
275 |
+
|
276 |
+
smiles_fn = os.path.join(path, f'{dataset}_{split}.input.smiles.npy')
|
277 |
+
split2smiles[split] = np.load(smiles_fn, allow_pickle=True)
|
278 |
+
|
279 |
+
reactants_fn = os.path.join(path, f'uspto_R_{split}.reactants.canonical.npy')
|
280 |
+
split2reactants[split] = np.load(reactants_fn, allow_pickle=True)
|
281 |
+
|
282 |
+
|
283 |
+
split2appl[split] = np.load(os.path.join(path, f'{dataset}_{split}.applicability.npy'))
|
284 |
+
|
285 |
+
pir_fn = os.path.join(path, f'{dataset}_{split}.prod.idx.reactants.p')
|
286 |
+
if os.path.isfile(pir_fn):
|
287 |
+
with open(pir_fn, 'rb') as f:
|
288 |
+
split2prod_idx_reactants[split] = pickle.load(f)
|
289 |
+
|
290 |
+
|
291 |
+
if len(split2prod_idx_reactants) == 0:
|
292 |
+
split2prod_idx_reactants = None
|
293 |
+
|
294 |
+
with open(os.path.join(path, f'{dataset}_templates.json'), 'r') as f:
|
295 |
+
label2template = json.load(f)
|
296 |
+
label2template = {int(k): v for k,v in label2template.items()}
|
297 |
+
|
298 |
+
return split2smiles, split2label, split2reactants, split2appl, split2prod_idx_reactants, label2template
|
299 |
+
|
300 |
+
|
301 |
+
def load_dataset_from_csv(csv_path='', split_col='split', input_col='prod_smiles', ssretroeval=False, reactants_col='reactants_can', ret_df=False, **kwargs):
|
302 |
+
"""loads the dataset from a CSV file containing a split-column, and input-column which can be defined,
|
303 |
+
as well as a 'reaction_smarts' column containing the extracted template, a 'label' column (the index of the template)
|
304 |
+
:returns
|
305 |
+
|
306 |
+
"""
|
307 |
+
print('loading X, y from csv')
|
308 |
+
df = pd.read_csv(csv_path)
|
309 |
+
X = {}
|
310 |
+
y = {}
|
311 |
+
|
312 |
+
for spli in set(df[split_col]):
|
313 |
+
#X[spli] = list(df[df[split_col]==spli]['prod_smiles'].apply(lambda k: [k]))
|
314 |
+
X[spli] = list(df[df[split_col]==spli][input_col].apply(lambda k: [k]))
|
315 |
+
y[spli] = (df[df[split_col]==spli]['label']).values
|
316 |
+
print(spli, len(X[spli]), 'samples')
|
317 |
+
|
318 |
+
# template to dict
|
319 |
+
tmp = df[['reaction_smarts','label']].drop_duplicates(subset=['reaction_smarts','label']).sort_values('label')
|
320 |
+
tmp.index= tmp.label
|
321 |
+
template_list = tmp['reaction_smarts'].to_dict()
|
322 |
+
print(len(template_list),'templates')
|
323 |
+
|
324 |
+
if ssretroeval:
|
325 |
+
# setup for ttest
|
326 |
+
test_reactants_can = list(df[df[split_col]=='test'][reactants_col])
|
327 |
+
|
328 |
+
only_in_test = set(y['test']) - set(y['train']).union(set(y['valid']))
|
329 |
+
print('obfuscating', len(only_in_test), 'templates because they are only in test')
|
330 |
+
for ii in only_in_test:
|
331 |
+
template_list[ii] = 'CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC.CCCCCCCCCCCCCCCCCCCCCCCCCCC.CCCCCCCCCCCCCCCCCCCCCC>>CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC.CCCCCCCCCCCCCCCCCCCCC' #obfuscate them
|
332 |
+
if ret_df:
|
333 |
+
return X, y, template_list, test_reactants_can, df
|
334 |
+
return X, y, template_list, test_reactants_can
|
335 |
+
|
336 |
+
if ret_df:
|
337 |
+
return X, y, template_list, None, df
|
338 |
+
return X, y, template_list, None
|
mhnreact/inference.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
Author: Philipp Seidl
|
4 |
+
ELLIS Unit Linz, LIT AI Lab, Institute for Machine Learning
|
5 |
+
Johannes Kepler University Linz
|
6 |
+
Contact: [email protected]
|
7 |
+
|
8 |
+
File contains functions that help prepare and download USPTO-related datasets
|
9 |
+
"""
|
10 |
+
|
11 |
+
# Cell
|
12 |
+
from .model import ModelConfig, MHN
|
13 |
+
import torch
|
mhnreact/inspect.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
Author: Philipp Seidl
|
4 |
+
ELLIS Unit Linz, LIT AI Lab, Institute for Machine Learning
|
5 |
+
Johannes Kepler University Linz
|
6 |
+
Contact: [email protected]
|
7 |
+
|
8 |
+
File contains functions that
|
9 |
+
"""
|
10 |
+
|
11 |
+
from . import model
|
12 |
+
import torch
|
13 |
+
import os
|
14 |
+
|
15 |
+
MODEL_PATH = 'data/model/'
|
16 |
+
|
17 |
+
def smarts2svg(smarts, useSmiles=True, highlightByReactant=True, save_to=''):
|
18 |
+
"""
|
19 |
+
draws smiles of smarts to an SVG and displays it in the Notebook,
|
20 |
+
or optinally can be saved to a file `save_to`
|
21 |
+
adapted from https://www.kesci.com/mw/project/5c7685191ce0af002b556cc5
|
22 |
+
"""
|
23 |
+
# adapted from https://www.kesci.com/mw/project/5c7685191ce0af002b556cc5
|
24 |
+
from rdkit import RDConfig
|
25 |
+
from rdkit import Chem
|
26 |
+
from rdkit.Chem import Draw, AllChem
|
27 |
+
from rdkit.Chem.Draw import rdMolDraw2D
|
28 |
+
from rdkit import Geometry
|
29 |
+
import matplotlib.pyplot as plt
|
30 |
+
import matplotlib.cm as cm
|
31 |
+
import matplotlib
|
32 |
+
from IPython.display import SVG, display
|
33 |
+
|
34 |
+
rxn = AllChem.ReactionFromSmarts(smarts,useSmiles=useSmiles)
|
35 |
+
d = Draw.MolDraw2DSVG(900, 100)
|
36 |
+
|
37 |
+
# rxn = AllChem.ReactionFromSmarts('[CH3:1][C:2](=[O:3])[OH:4].[CH3:5][NH2:6]>CC(O)C.[Pt]>[CH3:1][C:2](=[O:3])[NH:6][CH3:5].[OH2:4]',useSmiles=True)
|
38 |
+
colors=[(0.3, 0.7, 0.9),(0.9, 0.7, 0.9),(0.6,0.9,0.3),(0.9,0.9,0.1)]
|
39 |
+
try:
|
40 |
+
d.DrawReaction(rxn,highlightByReactant=highlightByReactant)
|
41 |
+
d.FinishDrawing()
|
42 |
+
|
43 |
+
txt = d.GetDrawingText()
|
44 |
+
# self.assertTrue(txt.find("<svg") != -1)
|
45 |
+
# self.assertTrue(txt.find("</svg>") != -1)
|
46 |
+
|
47 |
+
svg = d.GetDrawingText()
|
48 |
+
svg2 = svg.replace('svg:','')
|
49 |
+
svg3 = SVG(svg2)
|
50 |
+
display(svg3)
|
51 |
+
|
52 |
+
if save_to!='':
|
53 |
+
with open(save_to, 'w') as f_handle:
|
54 |
+
f_handle.write(svg3.data)
|
55 |
+
except:
|
56 |
+
print('Error drawing')
|
57 |
+
|
58 |
+
return svg2
|
59 |
+
|
60 |
+
def list_models(model_path=MODEL_PATH):
|
61 |
+
"""returns a list of loadable models"""
|
62 |
+
return dict(enumerate(list(filter(lambda k: str(k)[-3:]=='.pt', os.listdir(model_path)))))
|
63 |
+
|
64 |
+
def load_clf(model_fn='', model_path=MODEL_PATH, device='cpu', model_type='mhn'):
|
65 |
+
""" returns the model with loaded weights given a filename"""
|
66 |
+
import json
|
67 |
+
config_fn = '_'.join(model_fn.split('_')[-2:]).split('.pt')[0]
|
68 |
+
conf_dict = json.load( open( f"{model_path}{config_fn}_config.json" ) )
|
69 |
+
train_conf_dict = json.load( open( f"{model_path}{config_fn}_config.json" ) )
|
70 |
+
|
71 |
+
# specify the config the saved model had
|
72 |
+
conf = model.ModelConfig(**conf_dict)
|
73 |
+
conf.device = device
|
74 |
+
print(conf.__dict__)
|
75 |
+
|
76 |
+
if model_type == 'staticQK':
|
77 |
+
clf = model.StaticQK(conf)
|
78 |
+
elif model_type == 'mhn':
|
79 |
+
clf = model.MHN(conf)
|
80 |
+
elif model_type == 'segler':
|
81 |
+
clf = model.SeglerBaseline(conf)
|
82 |
+
elif model_type == 'fortunato':
|
83 |
+
clf = model.SeglerBaseline(conf)
|
84 |
+
else:
|
85 |
+
raise NotImplementedError('model_type',model_type,'not found')
|
86 |
+
|
87 |
+
# load the model
|
88 |
+
PATH = model_path+model_fn
|
89 |
+
params = torch.load(PATH, map_location=torch.device('cpu')) #!!!
|
90 |
+
clf.load_state_dict(params, strict=False)
|
91 |
+
if 'templates+noise' in params.keys():
|
92 |
+
print('loading templates+noise')
|
93 |
+
clf.templates = params['templates+noise']
|
94 |
+
#clf.templates.to(clf.config.device)
|
95 |
+
return clf
|
mhnreact/model.py
ADDED
@@ -0,0 +1,660 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
Author: Philipp Seidl
|
4 |
+
ELLIS Unit Linz, LIT AI Lab, Institute for Machine Learning
|
5 |
+
Johannes Kepler University Linz
|
6 |
+
Contact: [email protected]
|
7 |
+
|
8 |
+
Model related functionality
|
9 |
+
"""
|
10 |
+
from .utils import top_k_accuracy
|
11 |
+
from .plotutils import plot_loss, plot_topk, plot_nte
|
12 |
+
from .molutils import convert_smiles_to_fp
|
13 |
+
import os
|
14 |
+
import numpy as np
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
from collections import defaultdict
|
18 |
+
from scipy import sparse
|
19 |
+
import logging
|
20 |
+
from tqdm import tqdm
|
21 |
+
import wandb
|
22 |
+
|
23 |
+
log = logging.getLogger(__name__)
|
24 |
+
|
25 |
+
class ChemRXNDataset(torch.utils.data.Dataset):
|
26 |
+
"Torch Dataset for ChemRXN containing Xs: the input as np array, target: the target molecules (or nothing), and ys: the label"
|
27 |
+
def __init__(self, Xs, target, ys, is_smiles=False, fp_size=2048, fingerprint_type='morgan'):
|
28 |
+
self.is_smiles=is_smiles
|
29 |
+
if is_smiles:
|
30 |
+
self.Xs = Xs
|
31 |
+
self.target = target
|
32 |
+
self.fp_size = fp_size
|
33 |
+
self.fingerprint_type = fingerprint_type
|
34 |
+
else:
|
35 |
+
self.Xs = Xs.astype(np.float32)
|
36 |
+
self.target = target.astype(np.float32)
|
37 |
+
self.ys = ys
|
38 |
+
self.ys_is_sparse = isinstance(self.ys, sparse.csr.csr_matrix)
|
39 |
+
|
40 |
+
def __getitem__(self, k):
|
41 |
+
mol_fp = self.Xs[k]
|
42 |
+
if self.is_smiles:
|
43 |
+
mol_fp = convert_smiles_to_fp(mol_fp, fp_size=self.fp_size, which=self.fingerprint_type).astype(np.float32)
|
44 |
+
|
45 |
+
target = None if self.target is None else self.target[k]
|
46 |
+
if self.is_smiles and self.target:
|
47 |
+
target = convert_smiles_to_fp(target, fp_size=self.fp_size, which=self.fingerprint_type).astype(np.float32)
|
48 |
+
|
49 |
+
label = self.ys[k]
|
50 |
+
if isinstance(self.ys, sparse.csr.csr_matrix):
|
51 |
+
label = label.toarray()[0]
|
52 |
+
|
53 |
+
return (mol_fp, target, label)
|
54 |
+
|
55 |
+
def __len__(self):
|
56 |
+
return len(self.Xs)
|
57 |
+
|
58 |
+
class ModelConfig(object):
|
59 |
+
def __init__(self, **kwargs):
|
60 |
+
self.fingerprint_type = kwargs.pop("fingerprint_type", 'morgan')
|
61 |
+
self.template_fp_type = kwargs.pop("template_fp_type", 'rdk')
|
62 |
+
self.num_templates = kwargs.pop("num_templates", 401)
|
63 |
+
self.fp_size = kwargs.pop("fp_size", 2048)
|
64 |
+
self.fp_radius = kwargs.pop("fp_radius", 4)
|
65 |
+
|
66 |
+
self.device = kwargs.pop("device", 'cuda' if torch.cuda.is_available() else 'cpu')
|
67 |
+
self.batch_size = kwargs.pop("batch_size", 32)
|
68 |
+
self.pooling_operation_state_embedding = kwargs.pop('pooling_operation_state_embedding', 'mean')
|
69 |
+
self.pooling_operation_head = kwargs.pop('pooling_operation_head', 'max')
|
70 |
+
|
71 |
+
self.dropout = kwargs.pop('dropout', 0.0)
|
72 |
+
|
73 |
+
self.lr = kwargs.pop('lr', 1e-4)
|
74 |
+
self.optimizer = kwargs.pop("optimizer", "Adam")
|
75 |
+
|
76 |
+
self.activation_function = kwargs.pop('activation_function', 'ReLU')
|
77 |
+
self.verbose = kwargs.pop("verbose", False) # debugging or printing additional warnings / information set tot True
|
78 |
+
|
79 |
+
self.hopf_input_size = kwargs.pop('hopf_input_size', 2048)
|
80 |
+
self.hopf_output_size = kwargs.pop("hopf_output_size", 768)
|
81 |
+
self.hopf_num_heads = kwargs.pop("hopf_num_heads", 1)
|
82 |
+
self.hopf_asso_dim = kwargs.pop("hopf_asso_dim", 768)
|
83 |
+
self.hopf_association_activation = kwargs.pop("hopf_association_activation", None)
|
84 |
+
self.hopf_beta = kwargs.pop("hopf_beta",0.125) # 1/(self.hopf_asso_dim**(1/2) sqrt(d_k)
|
85 |
+
self.norm_input = kwargs.pop("norm_input",False)
|
86 |
+
self.norm_asso = kwargs.pop("norm_asso", False)
|
87 |
+
|
88 |
+
# additional experimental hyperparams
|
89 |
+
if 'hopf_n_layers' in kwargs.keys():
|
90 |
+
self.hopf_n_layers = kwargs.pop('hopf_n_layers', 0)
|
91 |
+
if 'mol_encoder_layers' in kwargs.keys():
|
92 |
+
self.mol_encoder_layers = kwargs.pop('mol_encoder_layers', 1)
|
93 |
+
if 'temp_encoder_layers' in kwargs.keys():
|
94 |
+
self.temp_encoder_layers = kwargs.pop('temp_encoder_layers', 1)
|
95 |
+
if 'encoder_af' in kwargs.keys():
|
96 |
+
self.encoder_af = kwargs.pop('encoder_af', 'ReLU')
|
97 |
+
|
98 |
+
# additional kwargs
|
99 |
+
for key, value in kwargs.items():
|
100 |
+
try:
|
101 |
+
setattr(self, key, value)
|
102 |
+
except AttributeError as err:
|
103 |
+
log.error(f"Can't set {key} with value {value} for {self}")
|
104 |
+
raise err
|
105 |
+
|
106 |
+
|
107 |
+
class Encoder(nn.Module):
|
108 |
+
"""Simple FFNN"""
|
109 |
+
def __init__(self, input_size: int = 2048, output_size: int = 1024,
|
110 |
+
num_layers: int = 1, dropout: float = 0.3, af_name: str ='None',
|
111 |
+
norm_in: bool = False, norm_out: bool = False):
|
112 |
+
super().__init__()
|
113 |
+
self.ws = []
|
114 |
+
self.setup_af(af_name)
|
115 |
+
self.norm_in = (lambda k: k) if not norm_in else torch.nn.LayerNorm(input_size, elementwise_affine=False)
|
116 |
+
self.norm_out = (lambda k: k) if not norm_out else torch.nn.LayerNorm(output_size, elementwise_affine=False)
|
117 |
+
self.setup_ff(input_size, output_size, num_layers)
|
118 |
+
self.dropout = nn.Dropout(p=dropout)
|
119 |
+
|
120 |
+
def forward(self, x: torch.Tensor):
|
121 |
+
x = self.norm_in(x)
|
122 |
+
for i, w in enumerate(self.ws):
|
123 |
+
if i==(len(self.ws)-1):
|
124 |
+
x = self.dropout(w(x)) # all except last haf ff_af
|
125 |
+
else:
|
126 |
+
x = self.dropout(self.af(w(x)))
|
127 |
+
x = self.norm_out(x)
|
128 |
+
return x
|
129 |
+
|
130 |
+
def setup_ff(self, input_size:int, output_size:int, num_layers=1):
|
131 |
+
"""setup feed-forward NN with n-layers"""
|
132 |
+
for n in range(0, num_layers):
|
133 |
+
w = nn.Linear(input_size if n==0 else output_size, output_size)
|
134 |
+
torch.nn.init.kaiming_normal_(w.weight, mode='fan_in', nonlinearity='linear') # eqiv to LeCun init
|
135 |
+
setattr(self, f'W_{n}', w) # consider doing a step-wise reduction
|
136 |
+
self.ws.append(getattr(self, f'W_{n}'))
|
137 |
+
|
138 |
+
def setup_af(self, af_name : str):
|
139 |
+
"""set activation function"""
|
140 |
+
if af_name is None or (af_name == 'None'):
|
141 |
+
self.af = lambda k: k
|
142 |
+
else:
|
143 |
+
try:
|
144 |
+
self.af = getattr(nn, af_name)()
|
145 |
+
except AttributeError as err:
|
146 |
+
log.error(f"Can't find activation-function {af_name} in torch.nn")
|
147 |
+
raise err
|
148 |
+
|
149 |
+
|
150 |
+
class MoleculeEncoder(Encoder):
|
151 |
+
"""
|
152 |
+
Class for Molecule encoder: can be any class mapping Smiles to a Vector (preferable differentiable ;)
|
153 |
+
"""
|
154 |
+
def __init__(self, config):
|
155 |
+
self.config = config
|
156 |
+
|
157 |
+
class FPMolEncoder(Encoder):
|
158 |
+
"""
|
159 |
+
Fingerprint Based Molecular encoder
|
160 |
+
"""
|
161 |
+
def __init__(self, config):
|
162 |
+
super().__init__(input_size = config.hopf_input_size*config.hopf_num_heads,
|
163 |
+
output_size = config.hopf_asso_dim*config.hopf_num_heads,
|
164 |
+
num_layers = config.mol_encoder_layers,
|
165 |
+
dropout = config.dropout,
|
166 |
+
af_name = config.encoder_af,
|
167 |
+
norm_in = config.norm_input,
|
168 |
+
norm_out = config.norm_asso,
|
169 |
+
)
|
170 |
+
# number of layers = self.config.mol_encoder_layers
|
171 |
+
# layer-dimension = self.config.hopf_asso_dim
|
172 |
+
# activation-function = self.config.af
|
173 |
+
|
174 |
+
self.config = config
|
175 |
+
|
176 |
+
def forward_smiles(self, list_of_smiles: list):
|
177 |
+
fp_tensor = self.convert_smiles_to_tensor(list_of_smiles)
|
178 |
+
return self.forward(fp_tensor)
|
179 |
+
|
180 |
+
def convert_smiles_to_tensor(self, list_of_smiles):
|
181 |
+
fps = convert_smiles_to_fp(list_of_smiles, fp_size=self.config.fp_size,
|
182 |
+
which=self.config.fingerprint_type, radius=self.config.fp_radius)
|
183 |
+
fps_tensor = torch.from_numpy(fps.astype(np.float)).to(dtype=torch.float).to(self.config.device)
|
184 |
+
return fps_tensor
|
185 |
+
|
186 |
+
class TemplateEncoder(Encoder):
|
187 |
+
"""
|
188 |
+
Class for Template encoder: can be any class mapping a Smarts-Reaction to a Vector (preferable differentiable ;)
|
189 |
+
"""
|
190 |
+
def __init__(self, config):
|
191 |
+
super().__init__(input_size = config.hopf_input_size*config.hopf_num_heads,
|
192 |
+
output_size = config.hopf_asso_dim*config.hopf_num_heads,
|
193 |
+
num_layers = config.temp_encoder_layers,
|
194 |
+
dropout = config.dropout,
|
195 |
+
af_name = config.encoder_af,
|
196 |
+
norm_in = config.norm_input,
|
197 |
+
norm_out = config.norm_asso,
|
198 |
+
)
|
199 |
+
self.config = config
|
200 |
+
#number of layers
|
201 |
+
#template fingerprint type
|
202 |
+
#random template threshold
|
203 |
+
#reactant pooling
|
204 |
+
if config.temp_encoder_layers==0:
|
205 |
+
print('No Key-Projection = Static Key/Templates')
|
206 |
+
assert self.config.hopf_asso_dim==self.config.fp_size
|
207 |
+
self.wks = []
|
208 |
+
|
209 |
+
|
210 |
+
class MHN(nn.Module):
|
211 |
+
"""
|
212 |
+
MHN - modern Hopfield Network -- for Template relevance prediction
|
213 |
+
"""
|
214 |
+
def __init__(self, config=None, layer2weight=0.05, use_template_encoder=True):
|
215 |
+
super().__init__()
|
216 |
+
if config:
|
217 |
+
self.config = config
|
218 |
+
else:
|
219 |
+
self.config = ModelConfig()
|
220 |
+
self.beta = self.config.hopf_beta
|
221 |
+
# hopf_num_heads
|
222 |
+
self.mol_encoder = FPMolEncoder(self.config)
|
223 |
+
if use_template_encoder:
|
224 |
+
self.template_encoder = TemplateEncoder(self.config)
|
225 |
+
|
226 |
+
self.W_v = None
|
227 |
+
self.layer2weight = layer2weight
|
228 |
+
|
229 |
+
# more MHN layers -- added recursively
|
230 |
+
if hasattr(self.config, 'hopf_n_layers'):
|
231 |
+
di = self.config.__dict__
|
232 |
+
di['hopf_n_layers'] -= 1
|
233 |
+
if di['hopf_n_layers']>0:
|
234 |
+
conf_wo_hopf_nlayers = ModelConfig(**di)
|
235 |
+
self.layer = MHN(conf_wo_hopf_nlayers)
|
236 |
+
if di['hopf_n_layers']!=0:
|
237 |
+
self.W_v = nn.Linear(self.config.hopf_asso_dim, self.config.hopf_input_size)
|
238 |
+
torch.nn.init.kaiming_normal_(self.W_v.weight, mode='fan_in', nonlinearity='linear') # eqiv to LeCun init
|
239 |
+
|
240 |
+
self.softmax = torch.nn.Softmax(dim=1)
|
241 |
+
|
242 |
+
self.lossfunction = nn.CrossEntropyLoss(reduction='none')#, weight=class_weights)
|
243 |
+
self.pretrain_lossfunction = nn.BCEWithLogitsLoss(reduction='none')#, weight=class_weights)
|
244 |
+
|
245 |
+
self.lr = self.config.lr
|
246 |
+
|
247 |
+
if self.config.hopf_association_activation is None or (self.config.hopf_association_activation.lower()=='none'):
|
248 |
+
self.af = lambda k: k
|
249 |
+
else:
|
250 |
+
self.af = getattr(nn, self.config.hopf_association_activation)()
|
251 |
+
|
252 |
+
self.pooling_operation_head = getattr(torch, self.config.pooling_operation_head)
|
253 |
+
|
254 |
+
self.X = None # templates projected to Hopfield Layer
|
255 |
+
|
256 |
+
self.optimizer = getattr(torch.optim, self.config.optimizer)(self.parameters(), lr=self.lr)
|
257 |
+
self.steps = 0
|
258 |
+
self.hist = defaultdict(list)
|
259 |
+
self.to(self.config.device)
|
260 |
+
|
261 |
+
def set_templates(self, template_list, which='rdk', fp_size=None, radius=2, learnable=False, njobs=1, only_templates_in_batch=False):
|
262 |
+
self.template_list = template_list.copy()
|
263 |
+
if fp_size is None:
|
264 |
+
fp_size = self.config.fp_size
|
265 |
+
if len(template_list)>=100000:
|
266 |
+
import math
|
267 |
+
print('batch-wise template_calculation')
|
268 |
+
bs = 30000
|
269 |
+
final_temp_emb = torch.zeros((len(template_list), fp_size)).float().to(self.config.device)
|
270 |
+
for b in range(math.ceil(len(template_list)//bs)+1):
|
271 |
+
self.template_list = template_list[bs*b:min(bs*(b+1), len(template_list))]
|
272 |
+
templ_emb = self.update_template_embedding(which=which, fp_size=fp_size, radius=radius, learnable=learnable, njobs=njobs, only_templates_in_batch=only_templates_in_batch)
|
273 |
+
final_temp_emb[bs*b:min(bs*(b+1), len(template_list))] = torch.from_numpy(templ_emb)
|
274 |
+
self.templates = final_temp_emb
|
275 |
+
else:
|
276 |
+
self.update_template_embedding(which=which, fp_size=fp_size, radius=radius, learnable=learnable, njobs=njobs, only_templates_in_batch=only_templates_in_batch)
|
277 |
+
|
278 |
+
self.set_templates_recursively()
|
279 |
+
|
280 |
+
def set_templates_recursively(self):
|
281 |
+
if 'hopf_n_layers' in self.config.__dict__.keys():
|
282 |
+
if self.config.hopf_n_layers >0:
|
283 |
+
self.layer.templates = self.templates
|
284 |
+
self.layer.set_templates_recursively()
|
285 |
+
|
286 |
+
def update_template_embedding(self,fp_size=2048, radius=4, which='rdk', learnable=False, njobs=1, only_templates_in_batch=False):
|
287 |
+
print('updating template-embedding; (just computing the template-fingerprint and using that)')
|
288 |
+
bs = self.config.batch_size
|
289 |
+
|
290 |
+
split_template_list = [str(t).split('>')[0].split('.') for t in self.template_list]
|
291 |
+
templates_np = convert_smiles_to_fp(split_template_list, is_smarts=True, fp_size=fp_size, radius=radius, which=which, njobs=njobs)
|
292 |
+
|
293 |
+
split_template_list = [str(t).split('>')[-1].split('.') for t in self.template_list]
|
294 |
+
reactants_np = convert_smiles_to_fp(split_template_list, is_smarts=True, fp_size=fp_size, radius=radius, which=which, njobs=njobs)
|
295 |
+
|
296 |
+
template_representation = templates_np-(reactants_np*0.5)
|
297 |
+
if learnable:
|
298 |
+
self.templates = torch.nn.Parameter(torch.from_numpy(template_representation).float(), requires_grad=True).to(self.config.device)
|
299 |
+
self.register_parameter(name='templates', param=self.templates)
|
300 |
+
else:
|
301 |
+
if only_templates_in_batch:
|
302 |
+
self.templates_np = template_representation
|
303 |
+
else:
|
304 |
+
self.templates = torch.from_numpy(template_representation).float().to(self.config.device)
|
305 |
+
|
306 |
+
return template_representation
|
307 |
+
|
308 |
+
|
309 |
+
def np_fp_to_tensor(self, np_fp):
|
310 |
+
return torch.from_numpy(np_fp.astype(np.float64)).to(self.config.device).float()
|
311 |
+
|
312 |
+
def masked_loss_fun(self, loss_fun, h_out, ys_batch):
|
313 |
+
if loss_fun == self.BCEWithLogitsLoss:
|
314 |
+
mask = (ys_batch != -1).float()
|
315 |
+
ys_batch = ys_batch.float()
|
316 |
+
else:
|
317 |
+
mask = (ys_batch.long() != -1).long()
|
318 |
+
mask_sum = int(mask.sum().cpu().numpy())
|
319 |
+
if mask_sum == 0:
|
320 |
+
return 0
|
321 |
+
|
322 |
+
ys_batch = ys_batch * mask
|
323 |
+
|
324 |
+
loss = (loss_fun(h_out, ys_batch * mask) * mask.float()).sum() / mask_sum # only mean from non -1
|
325 |
+
return loss
|
326 |
+
|
327 |
+
def compute_losses(self, out, ys_batch, head_loss_weight=None):
|
328 |
+
|
329 |
+
if len(ys_batch.shape)==2:
|
330 |
+
if ys_batch.shape[1]==self.config.num_templates: # it is in pretraining_mode
|
331 |
+
loss = self.pretrain_lossfunction(out, ys_batch.float()).mean()
|
332 |
+
else:
|
333 |
+
# legacy from policyNN
|
334 |
+
loss = self.lossfunction(out, ys_batch[:, 2]).mean() # WARNING: HEAD4 Reaction Template is ys[:,2]
|
335 |
+
else:
|
336 |
+
loss = self.lossfunction(out, ys_batch).mean()
|
337 |
+
return loss
|
338 |
+
|
339 |
+
def forward_smiles(self, list_of_smiles, templates=None):
|
340 |
+
state_tensor = self.mol_encoder.convert_smiles_to_tensor(list_of_smiles)
|
341 |
+
return self.forward(state_tensor, templates=templates)
|
342 |
+
|
343 |
+
def forward(self, m, templates=None):
|
344 |
+
"""
|
345 |
+
m: molecule in the form batch x fingerprint
|
346 |
+
templates: None or newly given templates if not instanciated
|
347 |
+
returns logits ranking the templates for each molecule
|
348 |
+
"""
|
349 |
+
#states_emb = self.fcfe(state_fp)
|
350 |
+
bs = m.shape[0] #batch_size
|
351 |
+
#templates = self.temp_emb(torch.arange(0,2000).long())
|
352 |
+
if (templates is None) and (self.X is None) and (self.templates is None):
|
353 |
+
raise Exception('Either pass in templates, or init templates by runnting clf.set_templates')
|
354 |
+
n_temp = len(templates) if templates is not None else len(self.templates)
|
355 |
+
if self.training or (templates is None) or (self.X is not None):
|
356 |
+
templates = templates if templates is not None else self.templates
|
357 |
+
X = self.template_encoder(templates)
|
358 |
+
else:
|
359 |
+
X = self.X # precomputed from last forward run
|
360 |
+
|
361 |
+
Xi = self.mol_encoder(m)
|
362 |
+
|
363 |
+
Xi = Xi.view(bs, self.config.hopf_num_heads, self.config.hopf_asso_dim) # [bs, H, A]
|
364 |
+
X = X.view(1, n_temp, self.config.hopf_asso_dim, self.config.hopf_num_heads) #[1, T, A, H]
|
365 |
+
|
366 |
+
XXi = torch.tensordot(Xi, X, dims=[(2,1), (2,0)]) # AxA -> [bs, T, H]
|
367 |
+
|
368 |
+
# pooling over heads
|
369 |
+
if self.config.hopf_num_heads<=1:
|
370 |
+
#QKt_pooled = QKt
|
371 |
+
XXi = XXi[:,:,0] #torch.squeeze(QKt, dim=2)
|
372 |
+
else:
|
373 |
+
XXi = self.pooling_operation_head(XXi, dim=2) # default is max pooling over H [bs, T]
|
374 |
+
if (self.config.pooling_operation_head =='max') or (self.config.pooling_operation_head =='min'):
|
375 |
+
XXi = XXi[0] #max and min also return the indices =S
|
376 |
+
|
377 |
+
out = self.beta*XXi # [bs, T, H] # softmax over dim=1 #pooling_operation_head
|
378 |
+
|
379 |
+
self.xinew = self.softmax(out)@X.view(n_temp, self.config.hopf_asso_dim) # [bs,T]@[T,emb] -> [bs,emb]
|
380 |
+
|
381 |
+
if self.W_v:
|
382 |
+
# call layers recursive
|
383 |
+
hopfout = self.W_v(self.xinew) # [bs,emb]@[emb,hopf_inp] --> [bs, hopf_inp]
|
384 |
+
# TODO check if using x_pooled or if not going through mol_encoder again
|
385 |
+
hopfout = hopfout + m # skip-connection
|
386 |
+
# give it to the next layer
|
387 |
+
out2 = self.layer.forward(hopfout) #templates=self.W_v(self.K)
|
388 |
+
out = out*(1-self.layer2weight)+out2*self.layer2weight
|
389 |
+
|
390 |
+
return out
|
391 |
+
|
392 |
+
def train_from_np(self, Xs, targets, ys, is_smiles=False, epochs=2, lr=0.001, bs=32,
|
393 |
+
permute_batches=False, shuffle=True, optimizer=None,
|
394 |
+
use_dataloader=True, verbose=False,
|
395 |
+
wandb=None, scheduler=None, only_templates_in_batch=False):
|
396 |
+
"""
|
397 |
+
Xs in the form sample x states
|
398 |
+
targets
|
399 |
+
ys in the form sample x [y_h1, y_h2, y_h3, y_h4]
|
400 |
+
"""
|
401 |
+
self.train()
|
402 |
+
if optimizer is None:
|
403 |
+
try:
|
404 |
+
self.optimizer = getattr(torch.optim, self.config.optimizer)(self.parameters(), lr=self.lr if lr is None else lr)
|
405 |
+
except AttributeError as err:
|
406 |
+
log.error(f"Can't find optimizer {config.optimizer} in torch.optim")
|
407 |
+
raise err
|
408 |
+
optimizer = self.optimizer
|
409 |
+
|
410 |
+
dataset = ChemRXNDataset(Xs, targets, ys, is_smiles=is_smiles,
|
411 |
+
fp_size=self.config.fp_size, fingerprint_type=self.config.fingerprint_type)
|
412 |
+
|
413 |
+
dataloader = torch.utils.data.DataLoader(dataset, batch_size=bs, shuffle=shuffle, sampler=None,
|
414 |
+
batch_sampler=None, num_workers=0, collate_fn=None,
|
415 |
+
pin_memory=False, drop_last=False, timeout=0,
|
416 |
+
worker_init_fn=None)
|
417 |
+
|
418 |
+
for epoch in range(epochs): # loop over the dataset multiple times
|
419 |
+
running_loss = 0.0
|
420 |
+
running_loss_dict = defaultdict(int)
|
421 |
+
batch_order = range(0, len(Xs), bs)
|
422 |
+
if permute_batches:
|
423 |
+
batch_order = np.random.permutation(batch_order)
|
424 |
+
|
425 |
+
for step, s in tqdm(enumerate(dataloader),mininterval=2):
|
426 |
+
batch = [b.to(self.config.device, non_blocking=True) for b in s]
|
427 |
+
Xs_batch, target_batch, ys_batch = batch
|
428 |
+
|
429 |
+
# zero the parameter gradients
|
430 |
+
optimizer.zero_grad()
|
431 |
+
|
432 |
+
# forward + backward + optimize
|
433 |
+
out = self.forward(Xs_batch)
|
434 |
+
total_loss = self.compute_losses(out, ys_batch)
|
435 |
+
|
436 |
+
loss_dict = {'CE_loss': total_loss}
|
437 |
+
|
438 |
+
total_loss.backward()
|
439 |
+
|
440 |
+
optimizer.step()
|
441 |
+
if scheduler:
|
442 |
+
scheduler.step()
|
443 |
+
self.steps += 1
|
444 |
+
|
445 |
+
# print statistics
|
446 |
+
for k in loss_dict:
|
447 |
+
running_loss_dict[k] += loss_dict[k].item()
|
448 |
+
try:
|
449 |
+
running_loss += total_loss.item()
|
450 |
+
except:
|
451 |
+
running_loss += 0
|
452 |
+
|
453 |
+
rs = min(100,len(Xs)//bs) # reporting/logging steps
|
454 |
+
if step % rs == (rs-1): # print every 2000 mini-batches
|
455 |
+
if verbose: print('[%d, %5d] loss: %.3f' %
|
456 |
+
(epoch + 1, step + 1, running_loss / rs))
|
457 |
+
self.hist['step'].append(self.steps)
|
458 |
+
self.hist['loss'].append(running_loss/rs)
|
459 |
+
self.hist['trianing_running_loss'].append(running_loss/rs)
|
460 |
+
|
461 |
+
[self.hist[k].append(running_loss_dict[k]/rs) for k in running_loss_dict]
|
462 |
+
|
463 |
+
if wandb:
|
464 |
+
wandb.log({'trianing_running_loss': running_loss / rs})
|
465 |
+
|
466 |
+
running_loss = 0.0
|
467 |
+
running_loss_dict = defaultdict(int)
|
468 |
+
|
469 |
+
if verbose: print('Finished Training')
|
470 |
+
return optimizer
|
471 |
+
|
472 |
+
def evaluate(self, Xs, targets, ys, split='test', is_smiles=False, bs = 32, shuffle=False, wandb=None, only_loss=False):
|
473 |
+
self.eval()
|
474 |
+
y_preds = np.zeros( (ys.shape[0], self.config.num_templates), dtype=np.float16)
|
475 |
+
|
476 |
+
loss_metrics = defaultdict(int)
|
477 |
+
new_hist = defaultdict(float)
|
478 |
+
with torch.no_grad():
|
479 |
+
dataset = ChemRXNDataset(Xs, targets, ys, is_smiles=is_smiles,
|
480 |
+
fp_size=self.config.fp_size, fingerprint_type=self.config.fingerprint_type)
|
481 |
+
dataloader = torch.utils.data.DataLoader(dataset, batch_size=bs, shuffle=shuffle, sampler=None,
|
482 |
+
batch_sampler=None, num_workers=0, collate_fn=None,
|
483 |
+
pin_memory=False, drop_last=False, timeout=0,
|
484 |
+
worker_init_fn=None)
|
485 |
+
|
486 |
+
#for step, s in eoutputs = self.forward(batch[0], batchnumerate(range(0, len(Xs), bs)):
|
487 |
+
for step, batch in enumerate(dataloader):#
|
488 |
+
batch = [b.to(self.config.device, non_blocking=True) for b in batch]
|
489 |
+
ys_batch = batch[2]
|
490 |
+
|
491 |
+
if hasattr(self, 'templates_np'):
|
492 |
+
outputs = []
|
493 |
+
for ii in range(10):
|
494 |
+
tlen = len(self.templates_np)
|
495 |
+
i_tlen = tlen//10
|
496 |
+
templates = torch.from_numpy(self.templates_np[(i_tlen*ii):min(i_tlen*(ii+1), tlen)]).float().to(self.config.device)
|
497 |
+
outputs.append( self.forward(batch[0], templates = templates ) )
|
498 |
+
outputs = torch.cat(outputs, dim=0)
|
499 |
+
|
500 |
+
else:
|
501 |
+
outputs = self.forward(batch[0])
|
502 |
+
|
503 |
+
loss = self.compute_losses(outputs, ys_batch, None)
|
504 |
+
|
505 |
+
# not quite right because in every batch there might be different number of valid samples
|
506 |
+
weight = 1/len(batch[0])#len(Xs[s:min(s + bs, len(Xs))]) / len(Xs)
|
507 |
+
|
508 |
+
loss_metrics['loss'] += (loss.item())
|
509 |
+
|
510 |
+
if len(ys.shape)>1:
|
511 |
+
outputs = self.softmax(outputs) if not (ys.shape[1]==self.config.num_templates) else torch.sigmoid(outputs)
|
512 |
+
else:
|
513 |
+
outputs = self.softmax(outputs)
|
514 |
+
|
515 |
+
outputs_np = [None if o is None else o.to('cpu').numpy().astype(np.float16) for o in outputs]
|
516 |
+
|
517 |
+
if not only_loss:
|
518 |
+
ks = [1, 2, 3, 4, 5, 10, 20, 30, 40, 50, 100]
|
519 |
+
topkacc, mrocc = top_k_accuracy(ys_batch, outputs, k=ks, ret_arocc=True, ret_mrocc=False)
|
520 |
+
# mrocc -- median rank of correct choice
|
521 |
+
for k, tkacc in zip(ks, topkacc):
|
522 |
+
#iterative average update
|
523 |
+
new_hist[f't{k}_acc_{split}'] += (tkacc-new_hist[f't{k}_acc_{split}']) / (step+1)
|
524 |
+
# todo weight by batch-size
|
525 |
+
new_hist[f'meanrank_{split}'] = mrocc
|
526 |
+
|
527 |
+
y_preds[step*bs : min((step+1)*bs,len(y_preds))] = outputs_np
|
528 |
+
|
529 |
+
|
530 |
+
new_hist[f'steps_{split}'] = (self.steps)
|
531 |
+
new_hist[f'loss_{split}'] = (loss_metrics['loss'] / (step+1))
|
532 |
+
|
533 |
+
for k in new_hist:
|
534 |
+
self.hist[k].append(new_hist[k])
|
535 |
+
|
536 |
+
if wandb:
|
537 |
+
wandb.log(new_hist)
|
538 |
+
|
539 |
+
|
540 |
+
self.hist[f'loss_{split}'].append(loss_metrics[f'loss'] / (step+1))
|
541 |
+
|
542 |
+
return y_preds
|
543 |
+
|
544 |
+
def save_hist(self, prefix='', postfix=''):
|
545 |
+
HIST_PATH = 'data/hist/'
|
546 |
+
if not os.path.exists(HIST_PATH):
|
547 |
+
os.mkdir(HIST_PATH)
|
548 |
+
fn_hist = HIST_PATH+prefix+postfix+'.csv'
|
549 |
+
with open(fn_hist, 'w') as fh:
|
550 |
+
print(dict(self.hist), file=fh)
|
551 |
+
return fn_hist
|
552 |
+
|
553 |
+
def save_model(self, prefix='', postfix='', name_as_conf=False):
|
554 |
+
MODEL_PATH = 'data/model/'
|
555 |
+
if not os.path.exists(MODEL_PATH):
|
556 |
+
os.mkdir(MODEL_PATH)
|
557 |
+
if name_as_conf:
|
558 |
+
confi_str = str(self.config.__dict__.values()).replace("'","").replace(': ','_').replace(', ',';')
|
559 |
+
else:
|
560 |
+
confi_str = ''
|
561 |
+
model_name = prefix+confi_str+postfix+'.pt'
|
562 |
+
torch.save(self.state_dict(), MODEL_PATH+model_name)
|
563 |
+
return MODEL_PATH+model_name
|
564 |
+
|
565 |
+
def plot_loss(self):
|
566 |
+
plot_loss(self.hist)
|
567 |
+
|
568 |
+
def plot_topk(self, sets=['train', 'valid', 'test'], with_last = 2):
|
569 |
+
plot_topk(self.hist, sets=sets, with_last = with_last)
|
570 |
+
|
571 |
+
def plot_nte(self, last_cpt=1, dataset='Sm', include_bar=True):
|
572 |
+
plot_nte(self.hist, dataset=dataset, last_cpt=last_cpt, include_bar=include_bar)
|
573 |
+
|
574 |
+
|
575 |
+
class SeglerBaseline(MHN):
|
576 |
+
"""FFNN - only the Molecule Encoder + an output projection"""
|
577 |
+
def __init__(self, config=None):
|
578 |
+
config.template_fp_type = 'none'
|
579 |
+
config.temp_encoder_layers = 0
|
580 |
+
super().__init__(config, use_template_encoder=False)
|
581 |
+
self.W_out = torch.nn.Linear(config.hopf_asso_dim, config.num_templates)
|
582 |
+
self.optimizer = getattr(torch.optim, self.config.optimizer)(self.parameters(), lr=self.lr)
|
583 |
+
self.steps = 0
|
584 |
+
self.hist = defaultdict(list)
|
585 |
+
self.to(self.config.device)
|
586 |
+
|
587 |
+
def forward(self, m, templates=None):
|
588 |
+
"""
|
589 |
+
m: molecule in the form batch x fingerprint
|
590 |
+
templates: won't be used in this case
|
591 |
+
returns logits ranking the templates for each molecule
|
592 |
+
"""
|
593 |
+
bs = m.shape[0] #batch_size
|
594 |
+
Xi = self.mol_encoder(m)
|
595 |
+
Xi = self.mol_encoder.af(Xi) # is not applied in encoder for last layer
|
596 |
+
out = self.W_out(Xi) # [bs, T] # softmax over dim=1
|
597 |
+
return out
|
598 |
+
|
599 |
+
class StaticQK(MHN):
|
600 |
+
""" Static QK baseline - beware to have the same fingerprint for mol_encoder as for the template_encoder (fp2048 r4 rdk by default)"""
|
601 |
+
def __init__(self, config=None):
|
602 |
+
if config:
|
603 |
+
self.config = config
|
604 |
+
else:
|
605 |
+
self.config = ModelConfig()
|
606 |
+
super().__init__(config)
|
607 |
+
|
608 |
+
self.fp_size = 2048
|
609 |
+
self.fingerprint_type = 'rdk'
|
610 |
+
self.beta = 1
|
611 |
+
|
612 |
+
def update_template_embedding(self, which='rdk', fp_size=2048, radius=4, learnable=False):
|
613 |
+
bs = self.config.batch_size
|
614 |
+
split_template_list = [t.split('>>')[0].split('.') for t in self.template_list]
|
615 |
+
self.templates = torch.from_numpy(convert_smiles_to_fp(split_template_list,
|
616 |
+
is_smarts=True, fp_size=fp_size,
|
617 |
+
radius=radius, which=which).max(1)).float().to(self.config.device)
|
618 |
+
|
619 |
+
|
620 |
+
def forward(self, m, templates=None):
|
621 |
+
"""
|
622 |
+
|
623 |
+
"""
|
624 |
+
#states_emb = self.fcfe(state_fp)
|
625 |
+
bs = m.shape[0] #batch_size
|
626 |
+
|
627 |
+
Xi = m #[bs, emb]
|
628 |
+
X = self.templates #[T, emb])
|
629 |
+
|
630 |
+
XXi = [email protected] # [bs, T]
|
631 |
+
|
632 |
+
# normalize
|
633 |
+
t_sum = templates.sum(1) #[T]
|
634 |
+
t_sum = t_sum.view(1,-1).expand(bs, -1) #[bs, T]
|
635 |
+
XXi = XXi / t_sum
|
636 |
+
|
637 |
+
# not neccecaire because it is not trained
|
638 |
+
out = self.beta*XXi # [bs, T] # softmax over dim=1
|
639 |
+
return out
|
640 |
+
|
641 |
+
class Retrosim(StaticQK):
|
642 |
+
""" Retrosim-like baseline only for template relevance prediction """
|
643 |
+
def fit_with_train(self, X_fp_train, y_train):
|
644 |
+
self.templates = torch.from_numpy(X_fp_train).float().to(self.config.device)
|
645 |
+
# train_samples, num_templates
|
646 |
+
self.sample2acttemplate = torch.nn.functional.one_hot(torch.from_numpy(y_train), self.config.num_templates).float()
|
647 |
+
tmpnorm = self.sample2acttemplate.sum(0)
|
648 |
+
tmpnorm[tmpnorm==0] = 1
|
649 |
+
self.sample2acttemplate = (self.sample2acttemplate / tmpnorm).to(self.config.device) # results in an average after dot product
|
650 |
+
|
651 |
+
def forward(self, m, templates=None):
|
652 |
+
"""
|
653 |
+
"""
|
654 |
+
out = super().forward(m, templates=templates)
|
655 |
+
# bs, train_samples
|
656 |
+
|
657 |
+
# map out to actual templates
|
658 |
+
out = out @ self.sample2acttemplate
|
659 |
+
|
660 |
+
return out
|
mhnreact/molutils.py
ADDED
@@ -0,0 +1,772 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
Author: Philipp Seidl, Philipp Renz
|
4 |
+
ELLIS Unit Linz, LIT AI Lab, Institute for Machine Learning
|
5 |
+
Johannes Kepler University Linz
|
6 |
+
Contact: [email protected]
|
7 |
+
|
8 |
+
Molutils contains functions that aid in handling molecules or templates
|
9 |
+
"""
|
10 |
+
|
11 |
+
import logging
|
12 |
+
import re
|
13 |
+
import warnings
|
14 |
+
from itertools import product, permutations
|
15 |
+
|
16 |
+
from multiprocessing import Pool
|
17 |
+
from tqdm.contrib.concurrent import process_map
|
18 |
+
from tqdm.notebook import tqdm
|
19 |
+
import swifter
|
20 |
+
|
21 |
+
import rdkit.RDLogger as rkl
|
22 |
+
from rdkit import Chem
|
23 |
+
from rdkit.Chem import AllChem
|
24 |
+
from rdkit.Chem.rdMolDescriptors import GetMorganFingerprint
|
25 |
+
from rdkit.Chem.rdmolops import FastFindRings
|
26 |
+
from rdkit.Chem.rdMHFPFingerprint import MHFPEncoder
|
27 |
+
|
28 |
+
from scipy import sparse
|
29 |
+
from sklearn.feature_extraction import DictVectorizer
|
30 |
+
|
31 |
+
import warnings
|
32 |
+
import rdkit.RDLogger as rkl
|
33 |
+
import numpy as np
|
34 |
+
|
35 |
+
log = logging.getLogger(__name__)
|
36 |
+
logger = rkl.logger()
|
37 |
+
|
38 |
+
def remove_attom_mapping(smiles):
|
39 |
+
""" removes a number after a ':' """
|
40 |
+
return re.sub(r':\d+', '', str(smiles))
|
41 |
+
|
42 |
+
|
43 |
+
def canonicalize_smi(smi, is_smarts=False, remove_atom_mapping=True):
|
44 |
+
r"""
|
45 |
+
Canonicalize SMARTS from https://github.com/rxn4chemistry/rxnfp/blob/master/rxnfp/tokenization.py#L249
|
46 |
+
"""
|
47 |
+
mol = Chem.MolFromSmarts(smi)
|
48 |
+
if not mol:
|
49 |
+
raise ValueError("Molecule not canonicalizable")
|
50 |
+
if remove_atom_mapping:
|
51 |
+
for atom in mol.GetAtoms():
|
52 |
+
if atom.HasProp("molAtomMapNumber"):
|
53 |
+
atom.ClearProp("molAtomMapNumber")
|
54 |
+
return Chem.MolToSmiles(mol)
|
55 |
+
|
56 |
+
|
57 |
+
def canonicalize_template(smarts):
|
58 |
+
smarts = str(smarts)
|
59 |
+
# remove attom-mapping
|
60 |
+
#smarts = remove_attom_mapping(smarts)
|
61 |
+
|
62 |
+
# order the list of smiles + canonicalize it
|
63 |
+
results = []
|
64 |
+
for part in smarts.split('>>'):
|
65 |
+
a = part.split('.')
|
66 |
+
a = [canonicalize_smi(x, is_smarts=True, remove_atom_mapping=True) for x in a]
|
67 |
+
#a = [remove_attom_mapping(x) for x in a]
|
68 |
+
a.sort()
|
69 |
+
results.append( '.'.join(a) )
|
70 |
+
return '>>'.join(results)
|
71 |
+
|
72 |
+
def ebv2np(ebv):
|
73 |
+
"""Explicit bit vector returned by rdkit to numpy array. """
|
74 |
+
return np.frombuffer(bytes(ebv.ToBitString(), 'utf-8'), 'u1') - ord('0')
|
75 |
+
|
76 |
+
def smiles2morgan(smiles, radius=2):
|
77 |
+
""" computes ecfp from smiles """
|
78 |
+
return GetMorganFingerprint(smiles, radius)
|
79 |
+
|
80 |
+
|
81 |
+
def getFingerprint(smiles, fp_size=4096, radius=2, is_smarts=False, which='morgan', sanitize=True):
|
82 |
+
"""maccs+morganc+topologicaltorsion+erg+atompair+pattern+rdkc"""
|
83 |
+
if isinstance(smiles, list):
|
84 |
+
return np.array([getFingerprint(smi, fp_size, radius, is_smarts, which) for smi in smiles]).max(0) # max pooling if it's list of lists
|
85 |
+
|
86 |
+
if is_smarts:
|
87 |
+
mol = Chem.MolFromSmarts(str(smiles), mergeHs=False)
|
88 |
+
#mol.UpdatePropertyCache() #Correcting valence info
|
89 |
+
#FastFindRings(mol) #Providing ring info
|
90 |
+
else:
|
91 |
+
mol = Chem.MolFromSmiles(str(smiles), sanitize=False)
|
92 |
+
|
93 |
+
if mol is None:
|
94 |
+
msg = f"{smiles} couldn't be converted to a fingerprint using 0's instead"
|
95 |
+
logger.warning(msg)
|
96 |
+
#warnings.warn(msg)
|
97 |
+
return np.zeros(fp_size).astype(np.bool)
|
98 |
+
|
99 |
+
if sanitize:
|
100 |
+
faild_op = Chem.SanitizeMol(mol, catchErrors=True)
|
101 |
+
FastFindRings(mol) #Providing ring info
|
102 |
+
|
103 |
+
mol.UpdatePropertyCache(strict=False) #Correcting valence info # important operation
|
104 |
+
|
105 |
+
def mol2np(mol, which, fp_size):
|
106 |
+
is_dict = False
|
107 |
+
if which=='morgan':
|
108 |
+
fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=fp_size, useFeatures=False, useChirality=True)
|
109 |
+
elif which=='rdk':
|
110 |
+
fp = Chem.RDKFingerprint(mol, fpSize=fp_size, maxPath=6)
|
111 |
+
elif which=='rdkc':
|
112 |
+
# https://greglandrum.github.io/rdkit-blog/similarity/reference/2021/05/26/similarity-threshold-observations1.html
|
113 |
+
# -- maxPath 6 found to be better for retrieval in databases
|
114 |
+
fp = AllChem.UnfoldedRDKFingerprintCountBased(mol, maxPath=6).GetNonzeroElements()
|
115 |
+
is_dict = True
|
116 |
+
elif which=='morganc':
|
117 |
+
fp = AllChem.GetMorganFingerprint(mol, radius, useChirality=True, useBondTypes=True, useFeatures=True, useCounts=True).GetNonzeroElements()
|
118 |
+
is_dict = True
|
119 |
+
elif which=='topologicaltorsion':
|
120 |
+
fp = AllChem.GetTopologicalTorsionFingerprint(mol).GetNonzeroElements()
|
121 |
+
is_dict = True
|
122 |
+
elif which=='maccs':
|
123 |
+
fp = AllChem.GetMACCSKeysFingerprint(mol)
|
124 |
+
elif which=='erg':
|
125 |
+
v = AllChem.GetErGFingerprint(mol)
|
126 |
+
fp = {idx:v[idx] for idx in np.nonzero(v)[0]}
|
127 |
+
is_dict = True
|
128 |
+
elif which=='atompair':
|
129 |
+
fp = AllChem.GetAtomPairFingerprint(mol).GetNonzeroElements()
|
130 |
+
is_dict = True
|
131 |
+
elif which=='pattern':
|
132 |
+
fp = Chem.PatternFingerprint(mol, fpSize=fp_size)
|
133 |
+
elif which=='ecfp4':
|
134 |
+
# roughly equivalent to ECFP4
|
135 |
+
fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=fp_size, useFeatures=False, useChirality=True)
|
136 |
+
elif which=='layered':
|
137 |
+
fp = AllChem.LayeredFingerprint(mol, fpSize=fp_size, maxPath=7)
|
138 |
+
elif which=='mhfp':
|
139 |
+
#TODO check if one can avoid instantiating the MHFP encoder
|
140 |
+
fp = MHFPEncoder().EncodeMol(mol, radius=radius, rings=True, isomeric=False, kekulize=False, min_radius=1)
|
141 |
+
fp = {f:1 for f in fp}
|
142 |
+
is_dict = True
|
143 |
+
elif not (type(which)==str):
|
144 |
+
fp = which(mol)
|
145 |
+
|
146 |
+
if is_dict:
|
147 |
+
nd = np.zeros(fp_size)
|
148 |
+
for k in fp:
|
149 |
+
nk = k%fp_size #remainder
|
150 |
+
#print(nk, k, fp_size)
|
151 |
+
#3160 36322170 3730
|
152 |
+
#print(nd[nk], fp[k])
|
153 |
+
if nd[nk]!=0:
|
154 |
+
#print('c',end='')
|
155 |
+
nd[nk] = nd[nk]+fp[k] #pooling colisions
|
156 |
+
nd[nk] = fp[k]
|
157 |
+
|
158 |
+
return nd #np.log(1+nd) # discussion with segler
|
159 |
+
|
160 |
+
return ebv2np(fp)
|
161 |
+
|
162 |
+
""" + for folding * for concat """
|
163 |
+
cc_symb = '*'
|
164 |
+
if ('+' in which) or (cc_symb in which):
|
165 |
+
concat = False
|
166 |
+
split_sym = '+'
|
167 |
+
if cc_symb in which:
|
168 |
+
concat=True
|
169 |
+
split_sym = '*'
|
170 |
+
|
171 |
+
np_fp = np.zeros(fp_size)
|
172 |
+
|
173 |
+
remaining_fps = (which.count(split_sym)+1)
|
174 |
+
fp_length_remain = fp_size
|
175 |
+
|
176 |
+
for fp_type in which.split(split_sym):
|
177 |
+
if concat:
|
178 |
+
fpp = mol2np(mol, fp_type, fp_length_remain//remaining_fps)
|
179 |
+
np_fp[(fp_size-fp_length_remain):(fp_size-fp_length_remain+len(fpp))] += fpp
|
180 |
+
fp_length_remain -= len(fpp)
|
181 |
+
remaining_fps -=1
|
182 |
+
else:
|
183 |
+
try:
|
184 |
+
fpp = mol2np(mol, fp_type, fp_size)
|
185 |
+
np_fp[:len(fpp)] += fpp
|
186 |
+
except:
|
187 |
+
pass
|
188 |
+
#print(fp_type,end='')
|
189 |
+
|
190 |
+
return np.log(1 + np_fp)
|
191 |
+
else:
|
192 |
+
return mol2np(mol, which, fp_size)
|
193 |
+
|
194 |
+
|
195 |
+
def _getFingerprint(inp):
|
196 |
+
return getFingerprint(inp[0], inp[1], inp[2], inp[3], inp[4])
|
197 |
+
|
198 |
+
|
199 |
+
def disable_rdkit_logging():
|
200 |
+
"""
|
201 |
+
Disables RDKit whiny logging.
|
202 |
+
"""
|
203 |
+
import rdkit.rdBase as rkrb
|
204 |
+
import rdkit.RDLogger as rkl
|
205 |
+
logger.setLevel(rkl.ERROR)
|
206 |
+
rkrb.DisableLog('rdApp.error')
|
207 |
+
|
208 |
+
|
209 |
+
def convert_smiles_to_fp(list_of_smiles, fp_size=2048, is_smarts=False, which='morgan', radius=2, njobs=1, verbose=False):
|
210 |
+
"""
|
211 |
+
list of smiles can be list of lists, than the resulting array will pe badded to the max list len
|
212 |
+
which: morgan, rdk, ecfp4, or object
|
213 |
+
NOTE: morgan or ecfp4 throws error for is_smarts
|
214 |
+
"""
|
215 |
+
|
216 |
+
inp = [(smi, fp_size, radius, is_smarts, which) for smi in list_of_smiles]
|
217 |
+
#print(inp)
|
218 |
+
if verbose: print(f'starting pool with {njobs} workers')
|
219 |
+
if njobs>1:
|
220 |
+
#with Pool(njobs) as pool:
|
221 |
+
# fps = pool.map(_getFingerprint, inp)
|
222 |
+
fps = process_map(_getFingerprint, inp, max_workers=njobs, chunksize=1, mininterval=0)
|
223 |
+
else:
|
224 |
+
fps = [getFingerprint(smi, fp_size=fp_size, radius=radius, is_smarts=is_smarts, which=which) for smi in list_of_smiles]
|
225 |
+
return np.array(fps)
|
226 |
+
|
227 |
+
|
228 |
+
def convert_smartes_to_fp(list_of_smarts, fp_size=2048):
|
229 |
+
if isinstance(list_of_smarts, np.ndarray):
|
230 |
+
list_of_smarts = list_of_smarts.tolist()
|
231 |
+
if isinstance(list_of_smarts, list):
|
232 |
+
if isinstance(list_of_smarts[0], list):
|
233 |
+
pad = len(max(list_of_smarts, key=len))
|
234 |
+
fps = [[getTemplateFingerprint(smarts, fp_size=fp_size) for smarts in sample]
|
235 |
+
+ [np.zeros(fp_size, dtype=np.bool)] * (pad - len(sample)) # zero padding
|
236 |
+
for sample in list_of_smarts]
|
237 |
+
else:
|
238 |
+
fps = [[getTemplateFingerprint(smarts, fp_size=fp_size) for smarts in list_of_smarts]]
|
239 |
+
return np.asarray(fps)
|
240 |
+
|
241 |
+
|
242 |
+
def get_reactants_from_smarts(smarts):
|
243 |
+
"""
|
244 |
+
from a (forward-)reaction given as a smart, only returns the reactants (not e.g. solvents or reagents)
|
245 |
+
returns list of smiles or empty list
|
246 |
+
"""
|
247 |
+
from rdkit.Chem import RDConfig
|
248 |
+
import sys
|
249 |
+
sys.path.append(RDConfig.RDContribDir)
|
250 |
+
from RxnRoleAssignment import identifyReactants
|
251 |
+
try:
|
252 |
+
rdk_reaction = AllChem.ReactionFromSmarts(smarts)
|
253 |
+
rx_idx = identifyReactants.identifyReactants(rdk_reaction)[0][0]
|
254 |
+
except ValueError:
|
255 |
+
return []
|
256 |
+
# TODO what if a product is recognized as a reactanat.. is that possible??
|
257 |
+
return [Chem.MolToSmiles(rdk_reaction.GetReactants()[i]) for i in rx_idx]
|
258 |
+
|
259 |
+
|
260 |
+
def smarts2rdkfp(smart, fp_size=2048):
|
261 |
+
mol = Chem.MolFromSmarts(str(smart))
|
262 |
+
if mol is None: return np.zeros(fp_size).astype(np.bool)
|
263 |
+
return AllChem.RDKFingerprint(mol)
|
264 |
+
# fp = np.asarray(fp).astype(np.bool) # takes ages =/
|
265 |
+
|
266 |
+
|
267 |
+
def smiles2rdkfp(smiles, fp_size=2048):
|
268 |
+
mol = Chem.MolFromSmiles(str(smiles))
|
269 |
+
if mol is None: return np.zeros(fp_size).astype(np.bool)
|
270 |
+
return AllChem.RDKFingerprint(mol)
|
271 |
+
|
272 |
+
|
273 |
+
def mol2morganfp(mol, radius=2, fp_size=2048):
|
274 |
+
try:
|
275 |
+
Chem.SanitizeMol(mol) # due to error --> see https://sourceforge.net/p/rdkit/mailman/message/34828604/
|
276 |
+
except:
|
277 |
+
pass
|
278 |
+
# print(mol)
|
279 |
+
# return np.zeros(fp_size).astype(np.bool)
|
280 |
+
# TODO
|
281 |
+
return AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=fp_size)
|
282 |
+
|
283 |
+
|
284 |
+
def smarts2morganfp(smart, fp_size=2048, radius=2):
|
285 |
+
mol = Chem.MolFromSmarts(str(smart))
|
286 |
+
if mol is None: return np.zeros(fp_size).astype(np.bool)
|
287 |
+
return mol2morganfp(mol)
|
288 |
+
|
289 |
+
|
290 |
+
def smiles2morganfp(smiles, fp_size=2048, radius=2):
|
291 |
+
mol = Chem.MolFromSmiles(str(smiles))
|
292 |
+
if mol is None: return np.zeros(fp_size).astype(np.bool)
|
293 |
+
return mol2morganfp(mol)
|
294 |
+
|
295 |
+
|
296 |
+
def smarts2fp(smart, which='morgan', fp_size=2048, radius=2):
|
297 |
+
if which == 'rdk':
|
298 |
+
return smarts2rdkfp(smart, fp_size=fp_size)
|
299 |
+
else:
|
300 |
+
return smarts2morganfp(smart, fp_size=fp_size, radius=radius)
|
301 |
+
|
302 |
+
|
303 |
+
def smiles2fp(smiles, which='morgan', fp_size=2048, radius=2):
|
304 |
+
if which == 'rdk':
|
305 |
+
return smiles2rdkfp(smiles, fp_size=fp_size)
|
306 |
+
else:
|
307 |
+
return smiles2morganfp(smiles, fp_size=fp_size, radius=radius)
|
308 |
+
|
309 |
+
|
310 |
+
class FP_featurizer():
|
311 |
+
"FP_featurizer: Fingerprint featurizer"
|
312 |
+
def __init__(self,
|
313 |
+
fp_types = ['MACCS','Morgan2CBF', 'Morgan6CBF', 'ErG','AtomPair','TopologicalTorsion','RDK','ECFP6'],
|
314 |
+
max_features = 4096, counts=True, log_scale=True, folding=None, collision_pooling='max'):
|
315 |
+
|
316 |
+
self.v = DictVectorizer(sparse=True, dtype=np.uint16)
|
317 |
+
self.max_features = max_features
|
318 |
+
self.idx_col = None
|
319 |
+
self.counts = counts
|
320 |
+
self.fp_types = [fp_types] if isinstance(fp_types, str) else fp_types
|
321 |
+
|
322 |
+
self.log_scale = log_scale # from discussion with segler
|
323 |
+
|
324 |
+
self.folding = None
|
325 |
+
self.colision_pooling = collision_pooling
|
326 |
+
|
327 |
+
def compute_fp_list(self, smiles_list, is_smarts=False):
|
328 |
+
fp_list = []
|
329 |
+
for smiles in smiles_list:
|
330 |
+
try:
|
331 |
+
if isinstance(smiles, list):
|
332 |
+
smiles = smiles[0]
|
333 |
+
if is_smarts:
|
334 |
+
mol = Chem.MolFromSmarts(smiles)
|
335 |
+
else:
|
336 |
+
mol = Chem.MolFromSmiles(smiles) #TODO small hack only applicable here!!!
|
337 |
+
fp_dict = {}
|
338 |
+
for fp_type in self.fp_types:
|
339 |
+
fp_dict.update( fingerprintTypes[fp_type](mol) ) #returns a dict
|
340 |
+
fp_list.append(fp_dict)
|
341 |
+
except:
|
342 |
+
fp_list.append({})
|
343 |
+
return fp_list
|
344 |
+
|
345 |
+
def fit(self, x_train, is_smarts=False):
|
346 |
+
fp_list = self.compute_fp_list(x_train, is_smarts=is_smarts)
|
347 |
+
Xraw = self.v.fit_transform(fp_list)
|
348 |
+
# compute variance of a csr_matrix E[x**2] - E[x]**2
|
349 |
+
axis = 0
|
350 |
+
Xraw_sqrd = Xraw.copy()
|
351 |
+
Xraw_sqrd.data **= 2
|
352 |
+
var_col = Xraw_sqrd.mean(axis) - np.square(Xraw.mean(axis))
|
353 |
+
#idx_col = (-np.array((Xraw>0).var(axis=0)).argpartition(self.max_features))
|
354 |
+
#idx_col = np.array((Xraw>0).sum(axis=0)>=self.min_fragm_occur).flatten()
|
355 |
+
self.idx_col = (-np.array(var_col)).flatten().argpartition(min(self.max_features, Xraw.shape[1]-1))[:min(self.max_features, Xraw.shape[1])]
|
356 |
+
print(f'from {var_col.shape[1]} to {len(self.idx_col)}')
|
357 |
+
return self.scale(Xraw[:,self.idx_col].toarray())
|
358 |
+
|
359 |
+
def transform(self, x_test, is_smarts=False):
|
360 |
+
fp_list = self.compute_fp_list(x_test, is_smarts=is_smarts)
|
361 |
+
X_raw = self.v.transform(fp_list)
|
362 |
+
return self.scale(X_raw[:,self.idx_col].toarray())
|
363 |
+
|
364 |
+
def scale(self, X):
|
365 |
+
if self.log_scale:
|
366 |
+
return np.log(1 + X)
|
367 |
+
return X
|
368 |
+
|
369 |
+
def save(self, path='data/fpfeat.pkl'):
|
370 |
+
import pickle
|
371 |
+
with open(path, 'wb') as output:
|
372 |
+
pickle.dump(self, output, pickle.HIGHEST_PROTOCOL)
|
373 |
+
|
374 |
+
def load(self, path='data/fpfeat.pkl'):
|
375 |
+
import pickle
|
376 |
+
with open(path, 'rb') as input:
|
377 |
+
self = pickle.load(input)
|
378 |
+
|
379 |
+
|
380 |
+
def getTemplateFingerprintOnBits(smarts, fp_size=2048):
|
381 |
+
rxn = AllChem.ReactionFromSmarts(str(smarts))
|
382 |
+
#construct a structural fingerprint for a ChemicalReaction by concatenating the reactant fingerprint and the product fingerprint
|
383 |
+
return (AllChem.CreateStructuralFingerprintForReaction(rxn)).GetOnBits()
|
384 |
+
|
385 |
+
|
386 |
+
def calc_template_fingerprint_group_mapping(template_list, fp_size, save_path=''):
|
387 |
+
"""
|
388 |
+
calculate the mapping from old idx to new idx for the templates
|
389 |
+
returns a set with a numpy array with the mapping and the indices to take
|
390 |
+
"""
|
391 |
+
|
392 |
+
templ_df = pd.DataFrame()
|
393 |
+
templ_df['smarts'] = template_list
|
394 |
+
templ_df['templ_emb'] = templ_df['smarts'].swifter.apply(lambda smarts: str(list(getTemplateFingerprintOnBits(smarts, fp_size))))
|
395 |
+
templ_df['idx_orig'] = [ii for ii in range(len(templ_df))]
|
396 |
+
|
397 |
+
grouped_templ = templ_df.groupby('templ_emb').apply(lambda x: x.index.tolist())
|
398 |
+
|
399 |
+
grouped_templ = templ_df.groupby('templ_emb')
|
400 |
+
grouped_templ = grouped_templ.min().sort_values('idx_orig')
|
401 |
+
grouped_templ['new_idx'] = range(len(grouped_templ))
|
402 |
+
|
403 |
+
new_templ_df = templ_df.join(grouped_templ, on='templ_emb',how='right', lsuffix='_l', rsuffix='_r').sort_values('idx_orig_l')
|
404 |
+
|
405 |
+
map_orig2new = new_templ_df['new_idx'].values
|
406 |
+
take_those_indices_from_orig = grouped_templ.idx_orig.values
|
407 |
+
if save_path!='':
|
408 |
+
suffix_maporig2new = '_maporig2new.npy'
|
409 |
+
suffix_takethose = '_tfp_take_idxs.npy'
|
410 |
+
np.save(f'{save_path}{suffix_maporig2new}', map_orig2new,allow_pickle=False)
|
411 |
+
np.save(f'{save_path}{suffix_takethose}', take_those_indices_from_orig,allow_pickle=False)
|
412 |
+
return (map_orig2new, take_those_indices_from_orig)
|
413 |
+
|
414 |
+
|
415 |
+
class ECFC_featurizer():
|
416 |
+
def __init__(self, radius=6, min_fragm_occur=50, useChirality=True, useFeatures=False):
|
417 |
+
self.v = DictVectorizer(sparse=True, dtype=np.uint16)
|
418 |
+
self.min_fragm_occur=min_fragm_occur
|
419 |
+
self.idx_col = None
|
420 |
+
self.radius=radius
|
421 |
+
self.useChirality = useChirality
|
422 |
+
self.useFeatures = useFeatures
|
423 |
+
|
424 |
+
def compute_fp_list(self, smiles_list):
|
425 |
+
fp_list = []
|
426 |
+
for smiles in smiles_list:
|
427 |
+
try:
|
428 |
+
if isinstance(smiles, list):
|
429 |
+
smiles = smiles[0]
|
430 |
+
mol = Chem.MolFromSmiles(smiles) #TODO small hack only applicable here!!!
|
431 |
+
fp_list.append( AllChem.GetMorganFingerprint(mol, self.radius, useChirality=self.useChirality,
|
432 |
+
useFeatures=self.useFeatures).GetNonzeroElements() ) #returns a dict
|
433 |
+
except:
|
434 |
+
fp_list.append({})
|
435 |
+
return fp_list
|
436 |
+
|
437 |
+
def fit(self, x_train):
|
438 |
+
fp_list = self.compute_fp_list(x_train)
|
439 |
+
Xraw = self.v.fit_transform(fp_list)
|
440 |
+
idx_col = np.array((Xraw>0).sum(axis=0)>=self.min_fragm_occur).flatten()
|
441 |
+
self.idx_col = idx_col
|
442 |
+
return Xraw[:,self.idx_col].toarray()
|
443 |
+
|
444 |
+
def transform(self, x_test):
|
445 |
+
fp_list = self.compute_fp_list(x_test)
|
446 |
+
X_raw = self.v.transform(fp_list)
|
447 |
+
return X_raw[:,self.idx_col].toarray()
|
448 |
+
|
449 |
+
|
450 |
+
def ecfp2dict(mol, radius=3):
|
451 |
+
#SECFP (SMILES Extended Connectifity Fingerprint)
|
452 |
+
# from mhfp.encoder import MHFPEncoder
|
453 |
+
from mhfp.encoder import MHFPEncoder
|
454 |
+
v = MHFPEncoder.secfp_from_mol(mol, length=4068, radius=radius, rings=True, kekulize=True, min_radius=1)
|
455 |
+
return {f'ECFP{radius*2}_'+str(idx):1 for idx in np.nonzero(v)[0]}
|
456 |
+
|
457 |
+
|
458 |
+
def erg2dict(mol):
|
459 |
+
v = AllChem.GetErGFingerprint(mol)
|
460 |
+
return {'erg'+str(idx):v[idx] for idx in np.nonzero(v)[0]}
|
461 |
+
|
462 |
+
|
463 |
+
def morgan2dict(mol, radius=2, useChirality=True, useBondTypes=True, useFeatures=True, useConts=True):
|
464 |
+
mdic = AllChem.GetMorganFingerprint(mol, radius=radius, useChirality=useChirality, useBondTypes=True,
|
465 |
+
useFeatures=True, useCounts=True).GetNonzeroElements()
|
466 |
+
return {f'm{radius}{useChirality}{useBondTypes}{useFeatures}'+str(kk):mdic[kk]for kk in mdic}
|
467 |
+
|
468 |
+
|
469 |
+
def atompair2dict(mol):
|
470 |
+
mdic = AllChem.GetAtomPairFingerprint(mol).GetNonzeroElements()
|
471 |
+
return {f'ap'+str(kk):mdic[kk]for kk in mdic}
|
472 |
+
|
473 |
+
|
474 |
+
def tt2dict(mol):
|
475 |
+
mdic = AllChem.GetTopologicalTorsionFingerprint(mol).GetNonzeroElements()
|
476 |
+
return {f'tt'+str(kk):mdic[kk]for kk in mdic}
|
477 |
+
|
478 |
+
|
479 |
+
def rdk2dict(mol):
|
480 |
+
mdic = AllChem.UnfoldedRDKFingerprintCountBased(mol).GetNonzeroElements()
|
481 |
+
return {f'rdk'+str(kk):mdic[kk]for kk in mdic}
|
482 |
+
|
483 |
+
|
484 |
+
def pattern2dict(mol):
|
485 |
+
mdic = AllChem.PatternFingerprint(mol, fpSize=16384).GetOnBits()
|
486 |
+
return {'pt'+str(kk):1 for kk in mdic}
|
487 |
+
|
488 |
+
|
489 |
+
fingerprintTypes = {
|
490 |
+
'MACCS' : lambda k: {'MCCS'+str(ob):1 for ob in AllChem.GetMACCSKeysFingerprint(k).GetOnBits()},
|
491 |
+
'Morgan2CBF' : lambda mol: morgan2dict(mol, 2, True, True, True, True),
|
492 |
+
'Morgan4CBF' : lambda mol: morgan2dict(mol, 4, True, True, True, True),
|
493 |
+
'Morgan6CBF' : lambda mol: morgan2dict(mol, 6, True, True, True, True),
|
494 |
+
'ErG' : erg2dict,
|
495 |
+
'AtomPair' : atompair2dict,
|
496 |
+
'TopologicalTorsion' : tt2dict,
|
497 |
+
#'RDK' : lambda k: {'MCCS'+str(ob):1 for ob in AllChem.RDKFingerprint(k).GetOnBits()},
|
498 |
+
'RDK' : rdk2dict,
|
499 |
+
'ECFP6' : lambda mol: ecfp2dict(mol, radius=3),
|
500 |
+
'Pattern': pattern2dict,
|
501 |
+
}
|
502 |
+
|
503 |
+
|
504 |
+
def smarts2appl(product_smarts, template_product_smarts, fpsize=2048, v=False, use_tqdm=False, njobs=1, nsplits=1):
|
505 |
+
"""This takes in a list of product smiles (misnamed in code) and a list of product sides
|
506 |
+
of templates and calculates which templates are applicable to which product.
|
507 |
+
This is basically a substructure search. Maybe there are faster versions but I wrote this one.
|
508 |
+
|
509 |
+
Args:
|
510 |
+
product_smarts: List of smiles of molecules to check.
|
511 |
+
template_product_smarts: List of substructures to check
|
512 |
+
fpsize: fingerprint size to use in screening
|
513 |
+
v: if v then information will be printed
|
514 |
+
use_tdqm: if True then a progressbar will be displayed but slows down the computation.
|
515 |
+
njobs: how many parallel jobs to run in parallel.
|
516 |
+
nsplits: how many splits should be made along the product_smarts list. Useful to avoid memory
|
517 |
+
explosion.
|
518 |
+
Returns: list of tuples (i,j) that indicates the product i has substructure j.
|
519 |
+
"""
|
520 |
+
if v: print("Calculating template molecules")
|
521 |
+
template_mols = [Chem.MolFromSmarts(s) for s in template_product_smarts]
|
522 |
+
if v: print("Calculating template fingerprints")
|
523 |
+
template_ebvs = [Chem.PatternFingerprint(m, fpSize=fpsize) for m in template_mols]
|
524 |
+
if v: print(f'Building template ints: [{len(template_mols)}, {fpsize}]')
|
525 |
+
template_ints = [int(e.ToBitString(), base=2) for e in template_ebvs]
|
526 |
+
del template_ebvs
|
527 |
+
|
528 |
+
if njobs == 1 and nsplits == 1:
|
529 |
+
return _smarts2appl(product_smarts, template_product_smarts, template_ints, fpsize, v, use_tqdm)
|
530 |
+
elif nsplits == 1:
|
531 |
+
nsplits = njobs
|
532 |
+
|
533 |
+
|
534 |
+
# split products into batches
|
535 |
+
product_splits = np.array_split(np.array(product_smarts), nsplits)
|
536 |
+
ioffsets = [0] + list(np.cumsum([p.shape[0] for p in product_splits[:-1]]))
|
537 |
+
inps = [(ps, template_product_smarts, template_ints, fpsize, v, use_tqdm, ioff, 0) for ps, ioff in zip(product_splits, ioffsets)]
|
538 |
+
|
539 |
+
if v: print("Creating workers")
|
540 |
+
#results = process_map(__smarts2appl, inps, max_workers=njobs, chunksize=1)
|
541 |
+
with Pool(njobs) as pool:
|
542 |
+
results = pool.starmap(_smarts2appl, inps)
|
543 |
+
imatch = np.concatenate([r[0] for r in results])
|
544 |
+
jmatch = np.concatenate([r[1] for r in results])
|
545 |
+
return imatch, jmatch
|
546 |
+
|
547 |
+
|
548 |
+
def __smarts2appl(inp):
|
549 |
+
return _smarts2appl(*inp)
|
550 |
+
|
551 |
+
|
552 |
+
def _smarts2appl(product_smarts, template_product_smarts, template_ints, fpsize=2048, v=False, use_tqdm=True, ioffset=0, joffset=0):
|
553 |
+
"""See smarts2appl for a description"""
|
554 |
+
|
555 |
+
if v: print("Calculating product molecules")
|
556 |
+
product_mols = [Chem.MolFromSmiles(s) for s in product_smarts]
|
557 |
+
if v: print("Calculating product fingerprints")
|
558 |
+
product_ebvs = [Chem.PatternFingerprint(m, fpSize=fpsize) for m in product_mols]
|
559 |
+
if v: print(f'Building product ints: [{len(product_mols)}, {fpsize}]')
|
560 |
+
# This loads each fingerprint into a python integer on which we can use bitwise operations.
|
561 |
+
product_ints = [int(e.ToBitString(), base=2) for e in product_ebvs]
|
562 |
+
del product_ebvs
|
563 |
+
|
564 |
+
# product_mols = {i: m for i,m in enumerate(product_mols)}
|
565 |
+
|
566 |
+
|
567 |
+
if v: print('Checking symbolically')
|
568 |
+
# buffer for template molecules. This are handed over as smarts as they are slow to pickle
|
569 |
+
template_mols = {}
|
570 |
+
|
571 |
+
# create iterator and add progressbar if use_tqdm is True
|
572 |
+
iterator = product(enumerate(product_ints), enumerate(template_ints))
|
573 |
+
if use_tqdm:
|
574 |
+
nelem = len(product_ints) * len(template_ints)
|
575 |
+
iterator = tqdm(iterator, total=nelem, miniters=1_000_000)
|
576 |
+
|
577 |
+
imatch = []
|
578 |
+
jmatch = []
|
579 |
+
for (i, p_int), (j, t_int) in iterator:
|
580 |
+
if (p_int & t_int) == t_int: # fingerprint based screen
|
581 |
+
p = product_mols[i]
|
582 |
+
t = template_mols.get(j, False)
|
583 |
+
if not t:
|
584 |
+
t = Chem.MolFromSmarts(template_product_smarts[j])
|
585 |
+
template_mols[j] = t
|
586 |
+
if p.HasSubstructMatch(t):
|
587 |
+
imatch.append(i)
|
588 |
+
jmatch.append(j)
|
589 |
+
if v: print("Finished loop")
|
590 |
+
return np.array(imatch)+ioffset, np.array(jmatch)+joffset
|
591 |
+
|
592 |
+
|
593 |
+
def extract_from_reaction(reaction, radius=1, verbose=False):
|
594 |
+
"""adapted from rdchiral package"""
|
595 |
+
from rdchiral.template_extractor import mols_from_smiles_list, replace_deuterated, get_fragments_for_changed_atoms, expand_changed_atom_tags, canonicalize_transform, get_changed_atoms
|
596 |
+
reactants = mols_from_smiles_list(replace_deuterated(reaction['reactants']).split('.'))
|
597 |
+
products = mols_from_smiles_list(replace_deuterated(reaction['products']).split('.'))
|
598 |
+
|
599 |
+
# if rdkit cant understand molecule, return
|
600 |
+
if None in reactants: return {'reaction_id': reaction['_id']}
|
601 |
+
if None in products: return {'reaction_id': reaction['_id']}
|
602 |
+
|
603 |
+
# try to sanitize molecules
|
604 |
+
try:
|
605 |
+
#for i in range(len(reactants)):
|
606 |
+
# reactants[i] = AllChem.RemoveHs(reactants[i]) # *might* not be safe
|
607 |
+
#for i in range(len(products)):
|
608 |
+
# products[i] = AllChem.RemoveHs(products[i]) # *might* not be safe
|
609 |
+
|
610 |
+
#[Chem.SanitizeMol(mol) for mol in reactants + products] # redundant w/ RemoveHs
|
611 |
+
for mol in reactants + products:
|
612 |
+
Chem.SanitizeMol(mol, catchErrors=True)
|
613 |
+
FastFindRings(mol) #Providing ring info
|
614 |
+
mol.UpdatePropertyCache(strict=False) #Correcting valence info # important operation
|
615 |
+
|
616 |
+
#changed
|
617 |
+
#[Chem.SanitizeMol(mol, catchErrors=True) for mol in reactants + products] # redundant w/ RemoveHs
|
618 |
+
|
619 |
+
#[mol.UpdatePropertyCache() for mol in reactants + products]
|
620 |
+
except Exception as e:
|
621 |
+
# can't sanitize -> skip
|
622 |
+
print(e)
|
623 |
+
print('Could not load SMILES or sanitize')
|
624 |
+
print('ID: {}'.format(reaction['_id']))
|
625 |
+
return {'reaction_id': reaction['_id']}
|
626 |
+
|
627 |
+
are_unmapped_product_atoms = False
|
628 |
+
extra_reactant_fragment = ''
|
629 |
+
for product in products:
|
630 |
+
prod_atoms = product.GetAtoms()
|
631 |
+
if sum([a.HasProp('molAtomMapNumber') for a in prod_atoms]) < len(prod_atoms):
|
632 |
+
if verbose: print('Not all product atoms have atom mapping')
|
633 |
+
if verbose: print('ID: {}'.format(reaction['_id']))
|
634 |
+
are_unmapped_product_atoms = True
|
635 |
+
|
636 |
+
if are_unmapped_product_atoms: # add fragment to template
|
637 |
+
for product in products:
|
638 |
+
prod_atoms = product.GetAtoms()
|
639 |
+
# Get unmapped atoms
|
640 |
+
unmapped_ids = [
|
641 |
+
a.GetIdx() for a in prod_atoms if not a.HasProp('molAtomMapNumber')
|
642 |
+
]
|
643 |
+
if len(unmapped_ids) > MAXIMUM_NUMBER_UNMAPPED_PRODUCT_ATOMS:
|
644 |
+
# Skip this example - too many unmapped product atoms!
|
645 |
+
return
|
646 |
+
# Define new atom symbols for fragment with atom maps, generalizing fully
|
647 |
+
atom_symbols = ['[{}]'.format(a.GetSymbol()) for a in prod_atoms]
|
648 |
+
# And bond symbols...
|
649 |
+
bond_symbols = ['~' for b in product.GetBonds()]
|
650 |
+
if unmapped_ids:
|
651 |
+
extra_reactant_fragment += AllChem.MolFragmentToSmiles(
|
652 |
+
product, unmapped_ids,
|
653 |
+
allHsExplicit = False, isomericSmiles = USE_STEREOCHEMISTRY,
|
654 |
+
atomSymbols = atom_symbols, bondSymbols = bond_symbols
|
655 |
+
) + '.'
|
656 |
+
if extra_reactant_fragment:
|
657 |
+
extra_reactant_fragment = extra_reactant_fragment[:-1]
|
658 |
+
if verbose: print(' extra reactant fragment: {}'.format(extra_reactant_fragment))
|
659 |
+
|
660 |
+
# Consolidate repeated fragments (stoichometry)
|
661 |
+
extra_reactant_fragment = '.'.join(sorted(list(set(extra_reactant_fragment.split('.')))))
|
662 |
+
|
663 |
+
|
664 |
+
if None in reactants + products:
|
665 |
+
print('Could not parse all molecules in reaction, skipping')
|
666 |
+
print('ID: {}'.format(reaction['_id']))
|
667 |
+
return {'reaction_id': reaction['_id']}
|
668 |
+
|
669 |
+
# Calculate changed atoms
|
670 |
+
changed_atoms, changed_atom_tags, err = get_changed_atoms(reactants, products)
|
671 |
+
if err:
|
672 |
+
if verbose:
|
673 |
+
print('Could not get changed atoms')
|
674 |
+
print('ID: {}'.format(reaction['_id']))
|
675 |
+
return
|
676 |
+
if not changed_atom_tags:
|
677 |
+
if verbose:
|
678 |
+
print('No atoms changed?')
|
679 |
+
print('ID: {}'.format(reaction['_id']))
|
680 |
+
# print('Reaction SMILES: {}'.format(example_doc['RXN_SMILES']))
|
681 |
+
return {'reaction_id': reaction['_id']}
|
682 |
+
|
683 |
+
try:
|
684 |
+
# Get fragments for reactants
|
685 |
+
reactant_fragments, intra_only, dimer_only = get_fragments_for_changed_atoms(reactants, changed_atom_tags,
|
686 |
+
radius = radius, expansion = [], category = 'reactants')
|
687 |
+
# Get fragments for products
|
688 |
+
# (WITHOUT matching groups but WITH the addition of reactant fragments)
|
689 |
+
product_fragments, _, _ = get_fragments_for_changed_atoms(products, changed_atom_tags,
|
690 |
+
radius = radius-1, expansion = expand_changed_atom_tags(changed_atom_tags, reactant_fragments),
|
691 |
+
category = 'products')
|
692 |
+
except ValueError as e:
|
693 |
+
if verbose:
|
694 |
+
print(e)
|
695 |
+
print(reaction['_id'])
|
696 |
+
return {'reaction_id': reaction['_id']}
|
697 |
+
|
698 |
+
# Put together and canonicalize (as best as possible)
|
699 |
+
rxn_string = '{}>>{}'.format(reactant_fragments, product_fragments)
|
700 |
+
rxn_canonical = canonicalize_transform(rxn_string)
|
701 |
+
# Change from inter-molecular to intra-molecular
|
702 |
+
rxn_canonical_split = rxn_canonical.split('>>')
|
703 |
+
rxn_canonical = rxn_canonical_split[0][1:-1].replace(').(', '.') + \
|
704 |
+
'>>' + rxn_canonical_split[1][1:-1].replace(').(', '.')
|
705 |
+
|
706 |
+
reactants_string = rxn_canonical.split('>>')[0]
|
707 |
+
products_string = rxn_canonical.split('>>')[1]
|
708 |
+
|
709 |
+
retro_canonical = products_string + '>>' + reactants_string
|
710 |
+
|
711 |
+
# Load into RDKit
|
712 |
+
rxn = AllChem.ReactionFromSmarts(retro_canonical)
|
713 |
+
# edited
|
714 |
+
#if rxn.Validate()[1] != 0:
|
715 |
+
# print('Could not validate reaction successfully')
|
716 |
+
# print('ID: {}'.format(reaction['_id']))
|
717 |
+
# print('retro_canonical: {}'.format(retro_canonical))
|
718 |
+
# if VERBOSE: raw_input('Pausing...')
|
719 |
+
# return {'reaction_id': reaction['_id']}
|
720 |
+
n_warning, n_errors = rxn.Validate()
|
721 |
+
if n_errors:
|
722 |
+
# resolves some errors
|
723 |
+
rxn = AllChem.ReactionFromSmarts(AllChem.ReactionToSmiles(rxn))
|
724 |
+
n_warning, n_errors = rxn.Validate()
|
725 |
+
|
726 |
+
template = {
|
727 |
+
'products': products_string,
|
728 |
+
'reactants': reactants_string,
|
729 |
+
'reaction_smarts': retro_canonical,
|
730 |
+
'intra_only': intra_only,
|
731 |
+
'dimer_only': dimer_only,
|
732 |
+
'reaction_id': reaction['_id'],
|
733 |
+
'necessary_reagent': extra_reactant_fragment,
|
734 |
+
'num_errors': n_errors,
|
735 |
+
'num_warnings': n_warning,
|
736 |
+
}
|
737 |
+
|
738 |
+
return template
|
739 |
+
|
740 |
+
|
741 |
+
def extract_template(rxn_smi, radius=1):
|
742 |
+
if isinstance(rxn_smi, str):
|
743 |
+
reaction = {
|
744 |
+
'reactants': rxn_smi.split('>')[0],
|
745 |
+
'products': rxn_smi.split('>')[-1],
|
746 |
+
'id': rxn_smi,
|
747 |
+
'_id': rxn_smi
|
748 |
+
}
|
749 |
+
else:
|
750 |
+
reaction = rxn_smi
|
751 |
+
try:
|
752 |
+
res = extract_from_reaction(reaction, radius=radius)
|
753 |
+
return res['reaction_smarts'] # returns a retro-template
|
754 |
+
except:
|
755 |
+
msg = f'failed to extract template from "{rxn_smi}"'
|
756 |
+
log.warning(msg)
|
757 |
+
return None
|
758 |
+
|
759 |
+
|
760 |
+
def getTemplateFingerprint(smarts, fp_size=4096):
|
761 |
+
""" CreateStructuralFingerprintForReaction """
|
762 |
+
if isinstance(smarts, (list,)):
|
763 |
+
return np.vstack([getTemplateFingerprint(sm) for sm in smarts])
|
764 |
+
|
765 |
+
rxn = AllChem.ReactionFromSmarts(str(smarts))
|
766 |
+
if rxn is None:
|
767 |
+
msg = f"{smarts} couldn't be converted to a fingerprint using 0's instead"
|
768 |
+
log.warning(msg)
|
769 |
+
#warnings.warn(msg)
|
770 |
+
return np.zeros(fp_size).astype(np.bool)
|
771 |
+
|
772 |
+
return np.array(list(AllChem.CreateStructuralFingerprintForReaction(rxn, )), dtype=np.bool)
|
mhnreact/plotutils.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
Author: Philipp Seidl
|
4 |
+
ELLIS Unit Linz, LIT AI Lab, Institute for Machine Learning
|
5 |
+
Johannes Kepler University Linz
|
6 |
+
Contact: [email protected]
|
7 |
+
|
8 |
+
Plot utils
|
9 |
+
"""
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import pandas as pd
|
13 |
+
import matplotlib.pyplot as plt
|
14 |
+
from matplotlib import pyplot as plt
|
15 |
+
|
16 |
+
plt.style.use('default')
|
17 |
+
|
18 |
+
|
19 |
+
def normal_approx_interval(p_hat, n, z=1.96):
|
20 |
+
""" approximating the distribution of error about a binomially-distributed observation, {\hat {p)), with a normal distribution
|
21 |
+
z = 1.96 --> alpha =0.05
|
22 |
+
z = 1 --> std
|
23 |
+
https://www.wikiwand.com/en/Binomial_proportion_confidence_interval"""
|
24 |
+
return z*((p_hat*(1-p_hat))/n)**(1/2)
|
25 |
+
|
26 |
+
|
27 |
+
our_colors = {
|
28 |
+
"lightblue": ( 0/255, 132/255, 187/255),
|
29 |
+
"red": (217/255, 92/255, 76/255),
|
30 |
+
"blue": ( 0/255, 132/255, 187/255),
|
31 |
+
"green": ( 91/255, 167/255, 85/255),
|
32 |
+
"yellow": (241/255, 188/255, 63/255),
|
33 |
+
"cyan": ( 79/255, 176/255, 191/255),
|
34 |
+
"grey": (125/255, 130/255, 140/255),
|
35 |
+
"lightgreen":(191/255, 206/255, 82/255),
|
36 |
+
"violett": (174/255, 97/255, 157/255),
|
37 |
+
}
|
38 |
+
|
39 |
+
|
40 |
+
def plot_std(p_hats, n_samples,z=1.96, color=our_colors['red'], alpha=0.2, xs=None):
|
41 |
+
p_hats = np.array(p_hats)
|
42 |
+
stds = np.array([normal_approx_interval(p_hats[ii], n_samples[ii], z=z) for ii in range(len(p_hats))])
|
43 |
+
xs = range(len(p_hats)) if xs is None else xs
|
44 |
+
plt.fill_between(xs, p_hats-(stds), p_hats+stds, color=color, alpha=alpha)
|
45 |
+
#plt.errorbar(range(13), asdf, [normal_approx_interval(asdf[ii], n_samples[ii], z=z) for ii in range(len(asdf))],
|
46 |
+
# c=our_colors['red'], linestyle='None', marker='.', ecolor=our_colors['red'])
|
47 |
+
|
48 |
+
|
49 |
+
def plot_loss(hist):
|
50 |
+
plt.plot(hist['step'], hist['loss'] )
|
51 |
+
plt.plot(hist['steps_valid'], np.array(hist['loss_valid']))
|
52 |
+
plt.legend(['train','validation'])
|
53 |
+
plt.xlabel('update-step')
|
54 |
+
plt.ylabel('loss (categorical-crossentropy-loss)')
|
55 |
+
|
56 |
+
|
57 |
+
def plot_topk(hist, sets=['train', 'valid', 'test'], with_last = 2):
|
58 |
+
ks = [1, 2, 3, 4, 5, 10, 20, 30, 40, 50, 100]
|
59 |
+
baseline_val_res = {1:0.4061, 10:0.6827, 50: 0.7883, 100:0.8400}
|
60 |
+
plt.plot(list(baseline_val_res.keys()), list(baseline_val_res.values()), 'k.--')
|
61 |
+
for i in range(1,with_last):
|
62 |
+
for s in sets:
|
63 |
+
plt.plot(ks, [hist[f't{k}_acc_{s}'][-i] for k in ks],'.--', alpha=1/i)
|
64 |
+
plt.xlabel('top-k')
|
65 |
+
plt.ylabel('Accuracy')
|
66 |
+
plt.legend(sets)
|
67 |
+
plt.title('Hopfield-NN')
|
68 |
+
plt.ylim([-0.02,1])
|
69 |
+
|
70 |
+
|
71 |
+
def plot_nte(hist, dataset='Sm', last_cpt=1, include_bar=True, model_legend='MHN (ours)',
|
72 |
+
draw_std=True, z=1.96, n_samples=None, group_by_template_fp=False, schwaller_hist=None, fortunato_hist=None): #1.96 for 95%CI
|
73 |
+
markers = ['.']*4#['1','2','3','4']#['8','P','p','*']
|
74 |
+
lw = 2
|
75 |
+
ms = 8
|
76 |
+
k = 100
|
77 |
+
ntes = range(13)
|
78 |
+
if dataset=='Sm':
|
79 |
+
basel_values = [0. , 0.38424785, 0.66807858, 0.7916149 , 0.9051132 ,
|
80 |
+
0.92531258, 0.87295875, 0.94865587, 0.91830721, 0.95993717,
|
81 |
+
0.97215858, 0.9896713 , 0.99917817] #old basel_values = [0.0, 0.3882, 0.674, 0.7925, 0.9023, 0.9272, 0.874, 0.947, 0.9185, 0.959, 0.9717, 0.9927, 1.0]
|
82 |
+
pretr_values = [0.08439423, 0.70743412, 0.85555528, 0.95200267, 0.96513376,
|
83 |
+
0.96976397, 0.98373613, 0.99960286, 0.98683919, 0.96684724,
|
84 |
+
0.95907246, 0.9839079 , 0.98683919]# old [0.094, 0.711, 0.8584, 0.952, 0.9683, 0.9717, 0.988, 1.0, 1.0, 0.984, 0.9717, 1.0, 1.0]
|
85 |
+
staticQK = [0.2096, 0.1992, 0.2291, 0.1787, 0.2301, 0.1753, 0.2142, 0.2693, 0.2651, 0.1786, 0.2834, 0.5366, 0.6636]
|
86 |
+
if group_by_template_fp:
|
87 |
+
staticQK = [0.2651, 0.2617, 0.261 , 0.2181, 0.2622, 0.2393, 0.2157, 0.2184, 0.2 , 0.225 , 0.2039, 0.4568, 0.5293]
|
88 |
+
if dataset=='Lg':
|
89 |
+
pretr_values = [0.03410448, 0.65397054, 0.7254572 , 0.78969294, 0.81329924,
|
90 |
+
0.8651173 , 0.86775655, 0.8593128 , 0.88184124, 0.87764794,
|
91 |
+
0.89734215, 0.93328846, 0.99531597]
|
92 |
+
basel_values = [0. , 0.62478044, 0.68784314, 0.75089511, 0.77044644,
|
93 |
+
0.81229423, 0.82968149, 0.82965544, 0.83778338, 0.83049176,
|
94 |
+
0.8662873 , 0.92308414, 1.00042408]
|
95 |
+
#staticQK = [0.03638, 0.0339 , 0.03732, 0.03506, 0.03717, 0.0331 , 0.03003, 0.03613, 0.0304 , 0.02109, 0.0297 , 0.02632, 0.02217] # on 90k templates
|
96 |
+
staticQK = [0.006416,0.00686, 0.00616, 0.00825, 0.005085,0.006718,0.01041, 0.0015335,0.006668,0.004673,0.001706,0.02551,0.04074]
|
97 |
+
if dataset=='Golden':
|
98 |
+
staticQK = [0]*13
|
99 |
+
pretr_values = [0]*13
|
100 |
+
basel_values = [0]*13
|
101 |
+
|
102 |
+
if schwaller_hist:
|
103 |
+
midx = np.argmin(schwaller_hist['loss_valid'])
|
104 |
+
basel_values = ([schwaller_hist[f't100_acc_nte_{k}'][midx] for k in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, '>10', '>49']])
|
105 |
+
if fortunato_hist:
|
106 |
+
midx = np.argmin(fortunato_hist['loss_valid'])
|
107 |
+
pretr_values = ([fortunato_hist[f't100_acc_nte_{k}'][midx] for k in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, '>10', '>49']])
|
108 |
+
|
109 |
+
#hand_val = [0.0 , 0.4, 0.68, 0.79, 0.89, 0.91, 0.86, 0.9,0.88, 0.9, 0.93]
|
110 |
+
|
111 |
+
|
112 |
+
if include_bar:
|
113 |
+
if dataset=='Sm':
|
114 |
+
if n_samples is None:
|
115 |
+
n_samples = [610, 1699, 287, 180, 143, 105, 70, 48, 124, 86, 68, 2539, 1648]
|
116 |
+
if group_by_template_fp:
|
117 |
+
n_samples = [460, 993, 433, 243, 183, 117, 102, 87, 110, 80, 103, 3048, 2203]
|
118 |
+
if dataset=='Lg':
|
119 |
+
if n_samples is None:
|
120 |
+
n_samples = [18861, 32226, 4220, 2546, 1573, 1191, 865, 652, 1350, 642, 586, 11638, 4958] #new
|
121 |
+
if group_by_template_fp:
|
122 |
+
n_samples = [13923, 17709, 7637, 4322, 2936, 2137, 1586, 1260, 1272, 1044, 829, 21695, 10559]
|
123 |
+
#[5169, 15904, 2814, 1853, 1238, 966, 766, 609, 1316, 664, 640, 30699, 21471]
|
124 |
+
#[13424,17246, 7681, 4332, 2844,2129,1698,1269, 1336,1067, 833, 22491, 11202] #grouped fp
|
125 |
+
plt.bar(range(11+2), np.array(n_samples)/sum(n_samples[:-1]), alpha=0.4, color=our_colors['grey'])
|
126 |
+
|
127 |
+
xti = [*[str(i) for i in range(11)], '>10', '>49']
|
128 |
+
asdf = []
|
129 |
+
for nte in xti:
|
130 |
+
try:
|
131 |
+
asdf.append( hist[f't{k}_acc_nte_{nte}'][-last_cpt])
|
132 |
+
except:
|
133 |
+
asdf.append(None)
|
134 |
+
|
135 |
+
plt.plot(range(13), asdf,f'{markers[3]}--', markersize=ms,c=our_colors['red'], linewidth=lw,alpha=1)
|
136 |
+
plt.plot(ntes, pretr_values,f'{markers[1]}--', c=our_colors['green'],
|
137 |
+
linewidth=lw, alpha=1,markersize=ms) #old [0.08, 0.7, 0.85, 0.9, 0.91, 0.95, 0.98, 0.97,0.98, 1, 1]
|
138 |
+
plt.plot(ntes, basel_values,f'{markers[0]}--',linewidth=lw,
|
139 |
+
c=our_colors['blue'], markersize=ms,alpha=1)
|
140 |
+
plt.plot(range(len(staticQK)), staticQK, f'{markers[2]}--',markersize=ms,c=our_colors['yellow'],linewidth=lw, alpha=1)
|
141 |
+
|
142 |
+
plt.title(f'USPTO-{dataset}')
|
143 |
+
plt.xlabel('number of training examples')
|
144 |
+
plt.ylabel('top-100 test-accuracy')
|
145 |
+
plt.legend([model_legend, 'Fortunato et al.','FNN baseline',"FPM baseline", #static${\\xi X}: \\dfrac{|{\\xi} \\cap {X}|}{|{X}|}$
|
146 |
+
'test sample proportion'])
|
147 |
+
|
148 |
+
if draw_std:
|
149 |
+
alpha=0.2
|
150 |
+
plot_std(asdf, n_samples, z=z, color=our_colors['red'], alpha=alpha)
|
151 |
+
plot_std(pretr_values, n_samples, z=z, color=our_colors['green'], alpha=alpha)
|
152 |
+
plot_std(basel_values, n_samples, z=z, color=our_colors['blue'], alpha=alpha)
|
153 |
+
plot_std(staticQK, n_samples, z=z, color=our_colors['yellow'], alpha=alpha)
|
154 |
+
|
155 |
+
|
156 |
+
plt.xticks(range(13),xti);
|
157 |
+
plt.yticks(np.arange(0,1.05,0.1))
|
158 |
+
plt.grid('on', alpha=0.3)
|
mhnreact/retroeval.py
ADDED
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
Author: Philipp Seidl, Philipp Renz
|
4 |
+
ELLIS Unit Linz, LIT AI Lab, Institute for Machine Learning
|
5 |
+
Johannes Kepler University Linz
|
6 |
+
Contact: [email protected]
|
7 |
+
|
8 |
+
Evaluation functions for single-step-retrosynthesis
|
9 |
+
"""
|
10 |
+
import sys
|
11 |
+
|
12 |
+
import rdchiral
|
13 |
+
from rdchiral.main import rdchiralRun, rdchiralReaction, rdchiralReactants
|
14 |
+
import hashlib
|
15 |
+
from rdkit import Chem
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import numpy as np
|
19 |
+
import pandas as pd
|
20 |
+
from collections import defaultdict
|
21 |
+
from copy import deepcopy
|
22 |
+
from glob import glob
|
23 |
+
import os
|
24 |
+
import pickle
|
25 |
+
|
26 |
+
from multiprocessing import Pool
|
27 |
+
import hashlib
|
28 |
+
import pickle
|
29 |
+
import logging
|
30 |
+
|
31 |
+
#import timeout_decorator
|
32 |
+
|
33 |
+
|
34 |
+
def _cont_hash(fn):
|
35 |
+
with open(fn, 'rb') as f:
|
36 |
+
return hashlib.md5(f.read()).hexdigest()
|
37 |
+
|
38 |
+
def load_templates_only(path, cache_dir='/tmp'):
|
39 |
+
arg_hash_base = 'load_templates_only' + path
|
40 |
+
arg_hash = hashlib.md5(arg_hash_base.encode()).hexdigest()
|
41 |
+
matches = glob(os.path.join(cache_dir, arg_hash+'*'))
|
42 |
+
|
43 |
+
if len(matches) > 1:
|
44 |
+
raise RuntimeError('Too many matches')
|
45 |
+
elif len(matches) == 1:
|
46 |
+
fn = matches[0]
|
47 |
+
content_hash = _cont_hash(path)
|
48 |
+
content_hash_file = os.path.basename(fn).split('_')[1].split('.')[0]
|
49 |
+
if content_hash_file == content_hash:
|
50 |
+
with open(fn, 'rb') as f:
|
51 |
+
return pickle.load(f)
|
52 |
+
|
53 |
+
df = pd.read_json(path)
|
54 |
+
template_dict = {}
|
55 |
+
for row in range(len(df)):
|
56 |
+
template_dict[df.iloc[row]['index']] = df.iloc[row].reaction_smarts
|
57 |
+
|
58 |
+
# cache the file
|
59 |
+
content_hash = _cont_hash(path)
|
60 |
+
fn = os.path.join(cache_dir, f"{arg_hash}_{content_hash}.p")
|
61 |
+
with open(fn, 'wb') as f:
|
62 |
+
pickle.dump(template_dict, f)
|
63 |
+
|
64 |
+
def load_templates_v2(path, get_complete_df=False):
|
65 |
+
if get_complete_df:
|
66 |
+
df = pd.read_json(path)
|
67 |
+
return df
|
68 |
+
|
69 |
+
return load_templates_only(path)
|
70 |
+
|
71 |
+
def canonicalize_reactants(smiles, can_steps=2):
|
72 |
+
if can_steps==0:
|
73 |
+
return smiles
|
74 |
+
|
75 |
+
mol = Chem.MolFromSmiles(smiles)
|
76 |
+
for a in mol.GetAtoms():
|
77 |
+
a.ClearProp('molAtomMapNumber')
|
78 |
+
|
79 |
+
smiles = Chem.MolToSmiles(mol, True)
|
80 |
+
if can_steps==1:
|
81 |
+
return smiles
|
82 |
+
|
83 |
+
smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles), True)
|
84 |
+
if can_steps==2:
|
85 |
+
return smiles
|
86 |
+
|
87 |
+
raise ValueError("Invalid can_steps")
|
88 |
+
|
89 |
+
|
90 |
+
|
91 |
+
def load_test_set(fn):
|
92 |
+
df = pd.read_csv(fn, index_col=0)
|
93 |
+
test = df[df.dataset=='test']
|
94 |
+
|
95 |
+
test_product_smarts = list(test.prod_smiles) # we make predictions for these
|
96 |
+
for s in test_product_smarts:
|
97 |
+
assert len(s.split('.')) == 1
|
98 |
+
assert '>' not in s
|
99 |
+
|
100 |
+
test_reactants = [] # we want to predict these
|
101 |
+
for rs in list(test.rxn_smiles):
|
102 |
+
rs = rs.split('>>')
|
103 |
+
assert len(rs) == 2
|
104 |
+
reactants_ori, products = rs
|
105 |
+
reactants = reactants_ori.split('.')
|
106 |
+
products = products.split('.')
|
107 |
+
assert len(reactants) >= 1
|
108 |
+
assert len(products) == 1
|
109 |
+
|
110 |
+
test_reactants.append(reactants_ori)
|
111 |
+
|
112 |
+
return test_product_smarts, test_reactants
|
113 |
+
|
114 |
+
|
115 |
+
#@timeout_decorator.timeout(1, use_signals=False)
|
116 |
+
def time_out_rdchiralRun(temp, prod_rct, combine_enantiomers=False):
|
117 |
+
rxn = rdchiralReaction(temp)
|
118 |
+
return rdchiralRun(rxn, prod_rct, combine_enantiomers=combine_enantiomers)
|
119 |
+
|
120 |
+
def _run_templates_rdchiral(prod_appl):
|
121 |
+
prod, applicable_templates = prod_appl
|
122 |
+
prod_rct = rdchiralReactants(prod) # preprocess reactants with rdchiral
|
123 |
+
|
124 |
+
results = {}
|
125 |
+
for idx, temp in applicable_templates:
|
126 |
+
temp = str(temp)
|
127 |
+
try:
|
128 |
+
results[(idx, temp)] = time_out_rdchiralRun(temp, prod_rct, combine_enantiomers=False)
|
129 |
+
except:
|
130 |
+
pass
|
131 |
+
|
132 |
+
return results
|
133 |
+
|
134 |
+
def _run_templates_rdchiral_original(prod_appl):
|
135 |
+
prod, applicable_templates = prod_appl
|
136 |
+
prod_rct = rdchiralReactants(prod) # preprocess reactants with rdchiral
|
137 |
+
|
138 |
+
results = {}
|
139 |
+
rxn_cache = {}
|
140 |
+
for idx, temp in applicable_templates:
|
141 |
+
temp = str(temp)
|
142 |
+
if temp in rxn_cache:
|
143 |
+
rxn = rxn_cache[(temp)]
|
144 |
+
else:
|
145 |
+
try:
|
146 |
+
rxn = rdchiralReaction(temp)
|
147 |
+
rxn_cache[temp] = rxn
|
148 |
+
except:
|
149 |
+
rxn_cache[temp] = None
|
150 |
+
msg = temp+' error converting to rdchiralReaction'
|
151 |
+
logging.debug(msg)
|
152 |
+
try:
|
153 |
+
res = rdchiralRun(rxn, prod_rct, combine_enantiomers=False)
|
154 |
+
results[(idx, temp)] = res
|
155 |
+
except:
|
156 |
+
pass
|
157 |
+
|
158 |
+
return results
|
159 |
+
|
160 |
+
def run_templates(test_product_smarts, templates, appl, njobs=32, cache_dir='/tmp'):
|
161 |
+
appl_dict = defaultdict(list)
|
162 |
+
for i,j in zip(*appl):
|
163 |
+
appl_dict[i].append(j)
|
164 |
+
|
165 |
+
prod_appl_list = []
|
166 |
+
for prod_idx, prod in enumerate(test_product_smarts):
|
167 |
+
applicable_templates = [(idx, templates[idx]) for idx in appl_dict[prod_idx]]
|
168 |
+
prod_appl_list.append((prod, applicable_templates))
|
169 |
+
|
170 |
+
arg_hash = hashlib.md5(pickle.dumps(prod_appl_list)).hexdigest()
|
171 |
+
cache_file = os.path.join(cache_dir, arg_hash+'.p')
|
172 |
+
|
173 |
+
if os.path.isfile(cache_file):
|
174 |
+
with open(cache_file, 'rb') as f:
|
175 |
+
print('loading results from file',f)
|
176 |
+
all_results = pickle.load(f)
|
177 |
+
|
178 |
+
#find /tmp -type f \( ! -user root \) -atime +3 -delete
|
179 |
+
# to delete the tmp files that havent been accessed 3 days
|
180 |
+
|
181 |
+
else:
|
182 |
+
#with Pool(njobs) as pool:
|
183 |
+
# all_results = pool.map(_run_templates_rdchiral, prod_appl_list)
|
184 |
+
|
185 |
+
from tqdm.contrib.concurrent import process_map
|
186 |
+
all_results = process_map(_run_templates_rdchiral, prod_appl_list, max_workers=njobs, chunksize=1, mininterval=2)
|
187 |
+
|
188 |
+
#with open(cache_file, 'wb') as f:
|
189 |
+
# print('saving applicable_templates to cache', cache_file)
|
190 |
+
# pickle.dump(all_results, f)
|
191 |
+
|
192 |
+
|
193 |
+
|
194 |
+
prod_idx_reactants = []
|
195 |
+
prod_temp_reactants = []
|
196 |
+
|
197 |
+
for prod, idx_temp_reactants in zip(test_product_smarts, all_results):
|
198 |
+
prod_idx_reactants.append({idx_temp[0]: r for idx_temp, r in idx_temp_reactants.items()})
|
199 |
+
prod_temp_reactants.append({idx_temp[1]: r for idx_temp, r in idx_temp_reactants.items()})
|
200 |
+
|
201 |
+
return prod_idx_reactants, prod_temp_reactants
|
202 |
+
|
203 |
+
def sort_by_template(template_scores, prod_idx_reactants):
|
204 |
+
sorted_results = []
|
205 |
+
for i, predictions in enumerate(prod_idx_reactants):
|
206 |
+
score_row = template_scores[i]
|
207 |
+
appl_idxs = np.array(list(predictions.keys()))
|
208 |
+
if len(appl_idxs) == 0:
|
209 |
+
sorted_results.append([])
|
210 |
+
continue
|
211 |
+
scores = score_row[appl_idxs]
|
212 |
+
sorted_idxs = appl_idxs[np.argsort(scores)][::-1]
|
213 |
+
sorted_reactants = [predictions[idx] for idx in sorted_idxs]
|
214 |
+
sorted_results.append(sorted_reactants)
|
215 |
+
return sorted_results
|
216 |
+
|
217 |
+
def no_dup_same_order(l):
|
218 |
+
return list({r: 0 for r in l}.keys())
|
219 |
+
|
220 |
+
def flatten_per_product(sorted_results, remove_duplicates=True):
|
221 |
+
flat_results = [sum((r for r in row), []) for row in sorted_results]
|
222 |
+
if remove_duplicates:
|
223 |
+
flat_results = [no_dup_same_order(row) for row in flat_results]
|
224 |
+
return flat_results
|
225 |
+
|
226 |
+
|
227 |
+
def topkaccuracy(test_reactants, predicted_reactants, ks=[1], ret_ranks=False):
|
228 |
+
ks = [k if k is not None else 1e10 for k in ks]
|
229 |
+
ranks = []
|
230 |
+
for true, pred in zip(test_reactants, predicted_reactants):
|
231 |
+
try:
|
232 |
+
rank = pred.index(true) + 1
|
233 |
+
except ValueError:
|
234 |
+
rank = 1e15
|
235 |
+
ranks.append(rank)
|
236 |
+
ranks = np.array(ranks)
|
237 |
+
if ret_ranks:
|
238 |
+
return ranks
|
239 |
+
|
240 |
+
return [np.mean([ranks <= k]) for k in ks]
|
mhnreact/train.py
ADDED
@@ -0,0 +1,804 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
Author: Philipp Seidl
|
4 |
+
ELLIS Unit Linz, LIT AI Lab, Institute for Machine Learning
|
5 |
+
Johannes Kepler University Linz
|
6 |
+
Contact: [email protected]
|
7 |
+
|
8 |
+
Training
|
9 |
+
"""
|
10 |
+
|
11 |
+
from .utils import str2bool, lgamma, multinom_gk, top_k_accuracy
|
12 |
+
from .data import load_templates, load_dataset_from_csv, load_USPTO
|
13 |
+
from .model import ModelConfig, MHN, StaticQK, SeglerBaseline, Retrosim
|
14 |
+
from .molutils import convert_smiles_to_fp, FP_featurizer, smarts2appl, getTemplateFingerprint, disable_rdkit_logging
|
15 |
+
from collections import defaultdict
|
16 |
+
import argparse
|
17 |
+
import os
|
18 |
+
import numpy as np
|
19 |
+
import pandas as pd
|
20 |
+
import datetime
|
21 |
+
import sys
|
22 |
+
from time import time
|
23 |
+
import matplotlib.pyplot as plt
|
24 |
+
import torch
|
25 |
+
import multiprocessing
|
26 |
+
import warnings
|
27 |
+
from joblib import Memory
|
28 |
+
|
29 |
+
cachedir = 'data/cache/'
|
30 |
+
memory = Memory(cachedir, verbose=0, bytes_limit=80e9)
|
31 |
+
|
32 |
+
def parse_args():
|
33 |
+
parser = argparse.ArgumentParser(description="Train MHNreact.",
|
34 |
+
epilog="--", prog="Train")
|
35 |
+
parser.add_argument('-f', type=str)
|
36 |
+
parser.add_argument('--model_type', type=str, default='mhn',
|
37 |
+
help="Model-type: choose from 'segler', 'fortunato', 'mhn' or 'staticQK', default:'mhn'")
|
38 |
+
parser.add_argument("--exp_name", type=str, default='', help="experiment name, (added as postfix to the file-names)")
|
39 |
+
parser.add_argument("-d", "--dataset_type", type=str, default='sm',
|
40 |
+
help="Input Dataset 'sm' for Scheider-USPTO-50k 'lg' for USPTO large or 'golden' or use keyword '--csv_path to specify an input file', default: 'sm'")
|
41 |
+
parser.add_argument("--csv_path", default=None, type=str, help="path to preprocessed trainings file + split columns, default: None")
|
42 |
+
parser.add_argument("--split_col", default='split', type=str, help="split column of csv, default: 'split'")
|
43 |
+
parser.add_argument("--input_col", default='prod_smiles', type=str, help="input column of csv, default: 'pro_smiles'")
|
44 |
+
parser.add_argument("--reactants_col", default='reactants_can', type=str, help="reactant colum of csv, default: 'reactants_can'")
|
45 |
+
|
46 |
+
parser.add_argument("--fp_type", type=str, default='morganc',
|
47 |
+
help="Fingerprint type for the input only!: default: 'morgan', other options: 'rdk', 'ECFP', 'ECFC', 'MxFP', 'Morgan2CBF' or a combination of fingerprints with '+'' for max-pooling and '&' for concatination e.g. maccs+morganc+topologicaltorsion+erg+atompair+pattern+rdkc+layered+mhfp, default: 'morganc'")
|
48 |
+
parser.add_argument("--template_fp_type", type=str, default='rdk',
|
49 |
+
help="Fingerprint type for the template fingerprint, default: 'rdk'")
|
50 |
+
parser.add_argument("--device", type=str, default='best',
|
51 |
+
help="Device to run the model on, preferably 'cuda:0', default: 'best' (takes the gpu with most RAM)")
|
52 |
+
parser.add_argument("--fp_size", type=int, default=4096,
|
53 |
+
help="fingerprint-size used for templates as well as for inputs, default: 4096")
|
54 |
+
parser.add_argument("--fp_radius", type=int, default=2, help="fingerprint-radius (if applicable to the fingerprint-type), default: 2")
|
55 |
+
parser.add_argument("--epochs", type=int, default=10, help='number of epochs, default: 10')
|
56 |
+
|
57 |
+
parser.add_argument("--pretrain_epochs", type=int, default=0,
|
58 |
+
help="applicability-matrix pretraining epochs if applicable (e.g. fortunato model_type), default: 0")
|
59 |
+
parser.add_argument("--save_model", type=str2bool, default=False, help="save the model, default: False")
|
60 |
+
|
61 |
+
parser.add_argument("--dropout", type=float, default=0.2, help="dropout rate for encoders, default: 0.2")
|
62 |
+
parser.add_argument("--lr", type=float, default=5e-4, help="learning-rate, dfeault: 5e-4")
|
63 |
+
parser.add_argument("--hopf_beta", type=float, default=0.05, help="hopfield beta parameter, default: 0.125")
|
64 |
+
parser.add_argument("--hopf_asso_dim", type=int, default=512, help="association dimension, default: 512")
|
65 |
+
parser.add_argument("--hopf_num_heads", type=int, default=1, help="hopfield number of heads, default: 1")
|
66 |
+
parser.add_argument("--hopf_association_activation", type=str, default='None',
|
67 |
+
help="hopfield association activation function recommended:'Tanh' or 'None', other: 'ReLU', 'SeLU', 'GeLU', or 'None' for more, see torch.nn, default: 'None'")
|
68 |
+
|
69 |
+
parser.add_argument("--norm_input", default=True, type=str2bool,
|
70 |
+
help="input-normalization, default: True")
|
71 |
+
parser.add_argument("--norm_asso", default=True, type=str2bool,
|
72 |
+
help="association-normalization, default: True")
|
73 |
+
|
74 |
+
# additional experimental hyperparams
|
75 |
+
parser.add_argument("--hopf_n_layers", default=1, type=int, help="Number of hopfield-layers, default: 1")
|
76 |
+
parser.add_argument("--mol_encoder_layers", default=1, type=int, help="Number of molecule-encoder layers, default: 1")
|
77 |
+
parser.add_argument("--temp_encoder_layers", default=1, type=int, help="Number of template-encoder layers, default: 1")
|
78 |
+
parser.add_argument("--encoder_af", default='ReLU', type=str,
|
79 |
+
help="Encoder-NN intermediate activation function (before association_activation function), default: 'ReLU'")
|
80 |
+
parser.add_argument("--hopf_pooling_operation_head", default='mean', type=str, help="Pooling operation over heads default=max, (max, min, mean, ...), default: 'mean'")
|
81 |
+
|
82 |
+
parser.add_argument("--splitting_scheme", default=None, type=str, help="Splitting_scheme for non-csv-input, default: None, other options: 'class-freq', 'random'")
|
83 |
+
|
84 |
+
parser.add_argument("--concat_rand_template_thresh", default=-1, type=int, help="Concatinates a random vector to the tempalte-fingerprint at all templates with num_training samples > this threshold; -1 (default) means deactivated")
|
85 |
+
parser.add_argument("--repl_quotient", default=10, type=float, help="Only if --concat_rand_template_thresh >= 0 - Quotient of how much should be replaced by random in template-embedding, (default: 10)")
|
86 |
+
parser.add_argument("--verbose", default=False, type=str2bool, help="If verbose, will print out more stuff, default: False")
|
87 |
+
parser.add_argument("--batch_size", default=128, type=int, help="Training batch-size, default: 128")
|
88 |
+
parser.add_argument("--eval_every_n_epochs", default=1, type=int, help="Evaluate every _ epochs (Evaluation is costly for USPTO-Lg), default: 1")
|
89 |
+
parser.add_argument("--save_preds", default=False, type=str2bool, help="Save predictions for test split at the end of training, default: False")
|
90 |
+
parser.add_argument("--wandb", default=False, type=str2bool, help="Save to wandb; login required, default: False")
|
91 |
+
parser.add_argument("--seed", default=None, type=int, help="Seed your run to make it reproducible, defualt: None")
|
92 |
+
|
93 |
+
parser.add_argument("--template_fp_type2", default=None, type=str, help="experimental template_fp_type for layer 2, default: None")
|
94 |
+
parser.add_argument("--layer2weight",default=0.2, type=float, help="hopf-layer2 weight of p, default: 0.2")
|
95 |
+
|
96 |
+
parser.add_argument("--reactant_pooling", default='max', type=str, help="reactant pooling operation over template-fingerprint, default: 'max', options: 'min','mean','lgamma'")
|
97 |
+
|
98 |
+
|
99 |
+
parser.add_argument("--ssretroeval", default=False, type=str2bool, help="single-step retro-synthesis eval, default: False")
|
100 |
+
parser.add_argument("--addval2train", default=False, type=str2bool, help="adds the validation set to the training set, default: False")
|
101 |
+
parser.add_argument("--njobs",default=-1, type=int, help="Number of jobs, default: -1 -> uses all available")
|
102 |
+
|
103 |
+
parser.add_argument("--eval_only_loss", default=False, type=str2bool, help="if only loss should be evaluated (if top-k acc may be time consuming), default: False")
|
104 |
+
parser.add_argument("--only_templates_in_batch", default=False, type=str2bool, help="while training only forwards templates that are in the batch, default: False")
|
105 |
+
|
106 |
+
parser.add_argument("--plot_res", default=False, type=str2bool, help="Plotting results for USPTO-sm/lg, default: False")
|
107 |
+
args = parser.parse_args()
|
108 |
+
|
109 |
+
if args.njobs ==-1:
|
110 |
+
args.njobs = int(multiprocessing.cpu_count())
|
111 |
+
|
112 |
+
if args.device=='best':
|
113 |
+
from .utils import get_best_gpu
|
114 |
+
try:
|
115 |
+
args.device = get_best_gpu()
|
116 |
+
except:
|
117 |
+
print('couldnt get the best gpu, using cpu instead')
|
118 |
+
args.device = 'cpu'
|
119 |
+
|
120 |
+
# some save checks on model type
|
121 |
+
if (args.model_type == 'segler') & (args.pretrain_epochs>=1):
|
122 |
+
print('changing model type to fortunato because of pretraining_epochs>0')
|
123 |
+
args.model_type = 'fortunato'
|
124 |
+
if ((args.model_type == 'staticQK') or (args.model_type == 'retrosim')) & (args.epochs>1):
|
125 |
+
print('changing epochs to 1 (StaticQK is not lernable ;)')
|
126 |
+
args.epochs=1
|
127 |
+
if args.template_fp_type != args.fp_type:
|
128 |
+
print('fp_type must be the same as template_fp_type --> setting template_fp_type to fp_type')
|
129 |
+
args.template_fp_type = args.fp_type
|
130 |
+
if args.save_model & (args.fp_type=='MxFP'):
|
131 |
+
warnings.warn('Currently MxFP is not recommended for saving the model paprameter (fragment dict for others would need to be saved or compued again, currently not implemented)')
|
132 |
+
|
133 |
+
return args
|
134 |
+
|
135 |
+
@memory.cache(ignore=['njobs'])
|
136 |
+
def featurize_smiles(X, fp_type='morgan', fp_size=4096, fp_radius=2, njobs=1, verbose=False):
|
137 |
+
X_fp = {}
|
138 |
+
|
139 |
+
if fp_type in ['MxFP','MACCS','Morgan2CBF','Morgan4CBF', 'Morgan6CBF', 'ErG','AtomPair','TopologicalTorsion','RDK']:
|
140 |
+
print('computing', fp_type)
|
141 |
+
if fp_type == 'MxFP':
|
142 |
+
fp_types = ['MACCS','Morgan2CBF','Morgan4CBF', 'Morgan6CBF', 'ErG','AtomPair','TopologicalTorsion','RDK']
|
143 |
+
else:
|
144 |
+
fp_types = [fp_type]
|
145 |
+
|
146 |
+
remaining = int(fp_size)
|
147 |
+
for fp_type in fp_types:
|
148 |
+
print(fp_type,end=' ')
|
149 |
+
feat = FP_featurizer(fp_types=fp_type,
|
150 |
+
max_features= (fp_size//len(fp_types)) if (fp_type != fp_types[-1]) else remaining )
|
151 |
+
X_fp[f'train_{fp_type}'] = feat.fit(X['train'])
|
152 |
+
X_fp[f'valid_{fp_type}'] = feat.transform(X['valid'])
|
153 |
+
X_fp[f'test_{fp_type}'] = feat.transform(X['test'])
|
154 |
+
|
155 |
+
remaining -= X_fp[f'train_{fp_type}'].shape[1]
|
156 |
+
#X_fp['train'].shape, X_fp['test'].shape
|
157 |
+
X_fp['train'] = np.hstack([ X_fp[f'train_{fp_type}'] for fp_type in fp_types])
|
158 |
+
X_fp['valid'] = np.hstack([ X_fp[f'valid_{fp_type}'] for fp_type in fp_types])
|
159 |
+
X_fp['test'] = np.hstack([ X_fp[f'test_{fp_type}'] for fp_type in fp_types])
|
160 |
+
|
161 |
+
else: #fp_type in ['rdk','morgan','ecfp4','pattern','morganc','rdkc']:
|
162 |
+
if verbose: print('computing', fp_type, 'folded')
|
163 |
+
for split in X.keys():
|
164 |
+
X_fp[split] = convert_smiles_to_fp(X[split], fp_size=fp_size, which=fp_type, radius=fp_radius, njobs=njobs, verbose=verbose)
|
165 |
+
|
166 |
+
return X_fp
|
167 |
+
|
168 |
+
|
169 |
+
def compute_template_fp(fp_len=2048, reactant_pooling='max', do_log=True):
|
170 |
+
"""Pre-Compute the template-fingerprint"""
|
171 |
+
# combine them to one fingerprint
|
172 |
+
comb_template_fp = np.zeros((max(template_list.keys())+1,fp_len if reactant_pooling!='concat' else fp_len*6))
|
173 |
+
for i in template_list:
|
174 |
+
tpl = template_list[i]
|
175 |
+
try:
|
176 |
+
pr, rea = str(tpl).split('>>')
|
177 |
+
idxx = temp_part_to_fp[pr]
|
178 |
+
prod_fp = templates_fp['fp'][idxx]
|
179 |
+
except:
|
180 |
+
print('err', pr, end='\r')
|
181 |
+
prod_fp = np.zeros(fp_len)
|
182 |
+
|
183 |
+
rea_fp = templates_fp['fp'][[temp_part_to_fp[r] for r in str(rea).split('.')]] # max-pooling
|
184 |
+
|
185 |
+
if reactant_pooling=='only_product':
|
186 |
+
rea_fp = np.zeros(fp_len)
|
187 |
+
if reactant_pooling=='max':
|
188 |
+
rea_fp = np.log(1 + rea_fp.max(0))
|
189 |
+
elif reactant_pooling=='mean':
|
190 |
+
rea_fp = np.log(1 + rea_fp.mean(0))
|
191 |
+
elif reactant_pooling=='sum':
|
192 |
+
rea_fp = np.log(1 + rea_fp.mean(0))
|
193 |
+
elif reactant_pooling=='lgamma':
|
194 |
+
rea_fp = multinom_gk(rea_fp, axis=0)
|
195 |
+
elif reactant_pooling=='concat':
|
196 |
+
rs = str(rea).split('.')
|
197 |
+
rs.sort()
|
198 |
+
for ii, r in enumerate(rs):
|
199 |
+
idx = temp_part_to_fp[r]
|
200 |
+
rea_fp = templates_fp['fp'][idx]
|
201 |
+
comb_template_fp[i, (fp_len*(ii+1)):(fp_len*(ii+2))] = np.log(1 + rea_fp)
|
202 |
+
|
203 |
+
comb_template_fp[i,:prod_fp.shape[0]] = np.log(1 + prod_fp) #- rea_fp*0.5
|
204 |
+
if reactant_pooling!='concat':
|
205 |
+
#comb_template_fp[i] = multinom_gk(np.stack([np.log(1+prod_fp), rea_fp]))
|
206 |
+
#comb_template_fp[i,fp_len:] = rea_fp
|
207 |
+
comb_template_fp[i,:rea_fp.shape[0]] = comb_template_fp[i, :rea_fp.shape[0]] - rea_fp*0.5
|
208 |
+
|
209 |
+
return comb_template_fp
|
210 |
+
|
211 |
+
|
212 |
+
def set_up_model(args, template_list=None):
|
213 |
+
hpn_config = ModelConfig(num_templates = int(max(template_list.keys()))+1,
|
214 |
+
#len(template_list.values()), #env.num_templates, #
|
215 |
+
dropout=args.dropout,
|
216 |
+
fingerprint_type=args.fp_type,
|
217 |
+
template_fp_type = args.template_fp_type,
|
218 |
+
fp_size = args.fp_size,
|
219 |
+
fp_radius= args.fp_radius,
|
220 |
+
device=args.device,
|
221 |
+
lr=args.lr,
|
222 |
+
hopf_beta=args.hopf_beta, #1/(128**0.5),#1/(2048**0.5),
|
223 |
+
hopf_input_size=args.fp_size,
|
224 |
+
hopf_output_size=None,
|
225 |
+
hopf_num_heads=args.hopf_num_heads,
|
226 |
+
hopf_asso_dim=args.hopf_asso_dim,
|
227 |
+
|
228 |
+
hopf_association_activation = args.hopf_association_activation, #or ReLU, Tanh works better, SELU, GELU
|
229 |
+
norm_input = args.norm_input,
|
230 |
+
norm_asso = args.norm_asso,
|
231 |
+
|
232 |
+
hopf_n_layers= args.hopf_n_layers,
|
233 |
+
mol_encoder_layers=args.mol_encoder_layers,
|
234 |
+
temp_encoder_layers=args.temp_encoder_layers,
|
235 |
+
encoder_af=args.encoder_af,
|
236 |
+
|
237 |
+
hopf_pooling_operation_head = args.hopf_pooling_operation_head,
|
238 |
+
batch_size=args.batch_size,
|
239 |
+
)
|
240 |
+
print(hpn_config.__dict__)
|
241 |
+
|
242 |
+
if args.model_type=='segler': # baseline
|
243 |
+
clf = SeglerBaseline(hpn_config)
|
244 |
+
elif args.model_type=='mhn':
|
245 |
+
clf = MHN(hpn_config, layer2weight=args.layer2weight)
|
246 |
+
elif args.model_type=='fortunato': # pretraining with applicability-matrix
|
247 |
+
clf = SeglerBaseline(hpn_config)
|
248 |
+
elif args.model_type=='staticQK': # staticQK
|
249 |
+
clf = StaticQK(hpn_config)
|
250 |
+
elif args.model_type=='retrosim': # staticQK
|
251 |
+
clf = Retrosim(hpn_config)
|
252 |
+
else:
|
253 |
+
raise NotImplementedError
|
254 |
+
|
255 |
+
return clf, hpn_config
|
256 |
+
|
257 |
+
def set_up_template_encoder(args, clf, label_to_n_train_samples=None, template_list=None):
|
258 |
+
|
259 |
+
if isinstance(clf, SeglerBaseline):
|
260 |
+
clf.templates = []
|
261 |
+
elif args.model_type=='staticQK':
|
262 |
+
clf.template_list = list(template_list.values())
|
263 |
+
clf.update_template_embedding(which=args.template_fp_type, fp_size=args.fp_size, radius=args.fp_radius, njobs=args.njobs)
|
264 |
+
elif args.model_type=='retrosim':
|
265 |
+
#clf.template_list = list(X['train'].values())
|
266 |
+
clf.fit_with_train(X_fp['train'], y['train'])
|
267 |
+
else:
|
268 |
+
import hashlib
|
269 |
+
PATH = './data/cache/'
|
270 |
+
if not os.path.exists(PATH):
|
271 |
+
os.mkdir(PATH)
|
272 |
+
fn_templ_emb = f'{PATH}templ_emb_{args.fp_size}_{args.template_fp_type}{args.fp_radius}_{len(template_list)}_{int(hashlib.sha512((str(template_list)).encode()).hexdigest(), 16)}.npy'
|
273 |
+
if (os.path.exists(fn_templ_emb)): # load the template embedding
|
274 |
+
print(f'loading tfp from file {fn_templ_emb}')
|
275 |
+
templ_emb = np.load(fn_templ_emb)
|
276 |
+
# !!! beware of different fingerprint types
|
277 |
+
clf.template_list = list(template_list.values())
|
278 |
+
|
279 |
+
if args.only_templates_in_batch:
|
280 |
+
clf.templates_np = templ_emb
|
281 |
+
clf.templates = None
|
282 |
+
else:
|
283 |
+
clf.templates = torch.from_numpy(templ_emb).float().to(clf.config.device)
|
284 |
+
else:
|
285 |
+
if args.template_fp_type=='MxFP':
|
286 |
+
clf.template_list = list(template_list.values())
|
287 |
+
clf.templates = torch.from_numpy(comb_template_fp).float().to(clf.config.device)
|
288 |
+
clf.set_templates_recursively()
|
289 |
+
elif args.template_fp_type=='Tfidf':
|
290 |
+
clf.template_list = list(template_list.values())
|
291 |
+
clf.templates = torch.from_numpy(tfidf_template_fp).float().to(clf.config.device)
|
292 |
+
clf.set_templates_recursively()
|
293 |
+
elif args.template_fp_type=='random':
|
294 |
+
clf.template_list = list(template_list.values())
|
295 |
+
clf.templates = torch.from_numpy(np.random.rand(len(template_list),args.fp_size)).float().to(clf.config.device)
|
296 |
+
clf.set_templates_recursively()
|
297 |
+
else:
|
298 |
+
clf.set_templates(list(template_list.values()), which=args.template_fp_type, fp_size=args.fp_size,
|
299 |
+
radius=args.fp_radius, learnable=False, njobs=args.njobs, only_templates_in_batch=args.only_templates_in_batch)
|
300 |
+
#if len(template_list)<100000:
|
301 |
+
np.save(fn_templ_emb, clf.templates_np if args.only_templates_in_batch else clf.templates.detach().cpu().numpy().astype(np.float16))
|
302 |
+
|
303 |
+
# concatinate the current fingerprint with a random fingerprint if the threshold is above
|
304 |
+
if (args.concat_rand_template_thresh != -1) & (args.repl_quotient>0):
|
305 |
+
REPLACE_FACTOR = int(args.repl_quotient) # default was 8
|
306 |
+
|
307 |
+
# fold the original fingerprint
|
308 |
+
pre_comp_templates = clf.templates_np if args.only_templates_in_batch else clf.templates.detach().cpu().numpy()
|
309 |
+
|
310 |
+
# mask of labels with mor than 49 training samples
|
311 |
+
l_mask = np.array([label_to_n_train_samples[k]>=args.concat_rand_template_thresh for k in template_list])
|
312 |
+
print(f'Num of templates with added rand-vect of size {pre_comp_templates.shape[1]//REPLACE_FACTOR} due to >=thresh ({args.concat_rand_template_thresh}):',l_mask.sum())
|
313 |
+
|
314 |
+
# remove the bits with the lowest variance
|
315 |
+
v = pre_comp_templates.var(0)
|
316 |
+
idx_lowest_var_half = v.argsort()[:(pre_comp_templates.shape[1]//REPLACE_FACTOR)]
|
317 |
+
|
318 |
+
# the new zero-init-vectors
|
319 |
+
pre = np.zeros([pre_comp_templates.shape[0], pre_comp_templates.shape[1]//REPLACE_FACTOR]).astype(np.float)
|
320 |
+
print(pre.shape, l_mask.shape, l_mask.sum()) #(616, 1700) (11790,) 519
|
321 |
+
print(pre_comp_templates.shape, len(template_list)) #(616, 17000) 616
|
322 |
+
# only the ones with >thresh will receive a random vect
|
323 |
+
pre[l_mask] = np.random.rand(l_mask.sum(), pre.shape[1])
|
324 |
+
|
325 |
+
pre_comp_templates[:,idx_lowest_var_half] = pre
|
326 |
+
|
327 |
+
#clf.templates = torch.from_numpy(pre_comp_templates).float().to(clf.config.device)
|
328 |
+
if pre_comp_templates.shape[0]<100000:
|
329 |
+
print('adding template_matrix to params')
|
330 |
+
param = torch.nn.Parameter(torch.from_numpy(pre_comp_templates).float(), requires_grad=False)
|
331 |
+
clf.register_parameter(name='templates+noise', param=param)
|
332 |
+
clf.templates = param.to(clf.config.device)
|
333 |
+
clf.set_templates_recursively()
|
334 |
+
else: #otherwise might cause memory issues
|
335 |
+
print('more than 100k templates')
|
336 |
+
if args.only_templates_in_batch:
|
337 |
+
clf.templates = None
|
338 |
+
clf.templates_np = pre_comp_templates
|
339 |
+
else:
|
340 |
+
clf.templates = torch.from_numpy(pre_comp_templates).float()
|
341 |
+
clf.set_templates_recursively()
|
342 |
+
|
343 |
+
|
344 |
+
# set's this for the first layer!!
|
345 |
+
if args.template_fp_type2=='MxFP':
|
346 |
+
print('first_layer template_fingerprint is set to MxFP')
|
347 |
+
clf.templates = torch.from_numpy(comb_template_fp).float().to(clf.config.device)
|
348 |
+
elif args.template_fp_type2=='Tfidf':
|
349 |
+
print('first_layer template_fingerprint is set to Tfidf')
|
350 |
+
clf.templates = torch.from_numpy(tfidf_template_fp).float().to(clf.config.device)
|
351 |
+
elif args.template_fp_type2=='random':
|
352 |
+
print('first_layer template_fingerprint is set to random')
|
353 |
+
clf.templates = torch.from_numpy(np.random.rand(len(template_list),args.fp_size)).float().to(clf.config.device)
|
354 |
+
elif args.template_fp_type2=='stfp':
|
355 |
+
print('first_layer template_fingerprint is set to stfp ! only works with 4096 fp_size')
|
356 |
+
tfp = getTemplateFingerprint(list(template_list.values()))
|
357 |
+
clf.templates = torch.from_numpy(tfp).float().to(clf.config.device)
|
358 |
+
|
359 |
+
return clf
|
360 |
+
|
361 |
+
|
362 |
+
if __name__ == '__main__':
|
363 |
+
|
364 |
+
args = parse_args()
|
365 |
+
|
366 |
+
run_id = str(time()).split('.')[0]
|
367 |
+
fn_postfix = str(args.exp_name) + '_' + run_id
|
368 |
+
|
369 |
+
if args.wandb:
|
370 |
+
import wandb
|
371 |
+
wandb.init(project='mhn-react', entity='phseidl', name=args.dataset_type+'_'+args.model_type+'_'+fn_postfix, config=args.__dict__)
|
372 |
+
else:
|
373 |
+
wandb=None
|
374 |
+
|
375 |
+
if not args.verbose:
|
376 |
+
disable_rdkit_logging()
|
377 |
+
|
378 |
+
if args.seed is not None:
|
379 |
+
from .utils import seed_everything
|
380 |
+
seed_everything(args.seed)
|
381 |
+
print('seeded with',args.seed)
|
382 |
+
|
383 |
+
# load csv or data
|
384 |
+
if args.csv_path is None:
|
385 |
+
X, y = load_USPTO(which=args.dataset_type)
|
386 |
+
template_list = load_templates(which=args.dataset_type)
|
387 |
+
else:
|
388 |
+
X, y, template_list, test_reactants_can = load_dataset_from_csv(**vars(args))
|
389 |
+
|
390 |
+
if args.addval2train:
|
391 |
+
print('adding val to train')
|
392 |
+
X['train'] = [*X['train'],*X['valid']]
|
393 |
+
y['train'] = np.concatenate([y['train'],y['valid']])
|
394 |
+
|
395 |
+
splits = ['train', 'valid', 'test']
|
396 |
+
|
397 |
+
#TODO split up in seperate class
|
398 |
+
if args.splitting_scheme == 'class-freq':
|
399 |
+
X_all = np.concatenate([X[split] for split in splits], axis=0)
|
400 |
+
y_all = np.concatenate([y[split] for split in splits])
|
401 |
+
|
402 |
+
# sort class by frequency / assumes class-index is ordered (wich is mildely violated)
|
403 |
+
res = y_all.argsort()
|
404 |
+
|
405 |
+
# use same split proportions
|
406 |
+
cum_split_lens = np.cumsum([len(y[split]) for split in splits]) #cumulative split length
|
407 |
+
|
408 |
+
X['train'] = X_all[res[0:cum_split_lens[0]]]
|
409 |
+
y['train'] = y_all[res[0:cum_split_lens[0]]]
|
410 |
+
|
411 |
+
X['valid'] = X_all[res[cum_split_lens[0]:cum_split_lens[1]]]
|
412 |
+
y['valid'] = y_all[res[cum_split_lens[0]:cum_split_lens[1]]]
|
413 |
+
|
414 |
+
X['test'] = X_all[res[cum_split_lens[1]:]]
|
415 |
+
y['test'] = y_all[res[cum_split_lens[1]:]]
|
416 |
+
for split in splits:
|
417 |
+
print(split, y[split].shape[0], 'samples (', y[split].max(),'max label)')
|
418 |
+
|
419 |
+
if args.splitting_scheme == 'remove_once_in_train_and_not_in_test':
|
420 |
+
print('remove_once_in_train')
|
421 |
+
from collections import Counter
|
422 |
+
cc = Counter()
|
423 |
+
cc.update(y['train'])
|
424 |
+
classes_set_only_once_in_train = set(np.array(list(cc.keys()))[ (np.array(list(cc.values())))==1])
|
425 |
+
not_in_test = set(y['train']).union(y['valid']) - (set(y['test']))
|
426 |
+
classes_set_only_once_in_train = (classes_set_only_once_in_train.intersection(not_in_test))
|
427 |
+
remove_those_mask = np.array([yii in classes_set_only_once_in_train for yii in y['train']])
|
428 |
+
X['train'] = np.array(X['train'])[~remove_those_mask]
|
429 |
+
y['train'] = np.array(y['train'])[~remove_those_mask]
|
430 |
+
print(remove_those_mask.mean(),'%', remove_those_mask.sum(), 'samples removed')
|
431 |
+
|
432 |
+
if args.splitting_scheme == 'random':
|
433 |
+
print('random-splitting-scheme:8-1-1')
|
434 |
+
if args.ssretroeval:
|
435 |
+
print('ssretroeval not available')
|
436 |
+
raise NotImplementedError
|
437 |
+
import numpy as np
|
438 |
+
from sklearn.model_selection import train_test_split
|
439 |
+
|
440 |
+
def _unpack(lod):
|
441 |
+
r = []
|
442 |
+
for k,v in lod.items():
|
443 |
+
[r.append(i) for i in v]
|
444 |
+
return r
|
445 |
+
|
446 |
+
X_all = _unpack(X)
|
447 |
+
y_all = np.array( _unpack(y) )
|
448 |
+
|
449 |
+
X['train'], X['test'], y['train'], y['test'] = train_test_split(X_all, y_all, test_size=0.2, random_state=70135)
|
450 |
+
X['test'], X['valid'], y['test'], y['valid'] = train_test_split(X['test'], y['test'], test_size=0.5, random_state=70135)
|
451 |
+
|
452 |
+
zero_shot = set(y['test']).difference( set(y['train']).union(set(y['valid'])) )
|
453 |
+
zero_shot_mask = np.array([yi in zero_shot for yi in y['test']])
|
454 |
+
print(sum(zero_shot_mask))
|
455 |
+
#y['test'][zero_shot_mask] = list(zero_shot)[0] #not right but quick
|
456 |
+
|
457 |
+
|
458 |
+
if args.model_type=='staticQK' or args.model_type=='retrosim':
|
459 |
+
print('staticQK model: caution: use pattern, or rdk -fingerprint-embedding')
|
460 |
+
|
461 |
+
fp_size = args.fp_size
|
462 |
+
radius = args.fp_radius #quite important ;)
|
463 |
+
fp_embedding = args.fp_type
|
464 |
+
|
465 |
+
X_fp = featurize_smiles(X, fp_type=args.fp_type, fp_size=args.fp_size, fp_radius=args.fp_radius, njobs=args.njobs)
|
466 |
+
|
467 |
+
if args.template_fp_type=='MxFP' or (args.template_fp_type2=='MxFP'):
|
468 |
+
temp_part_to_fp = {}
|
469 |
+
for i in template_list:
|
470 |
+
tpl = template_list[i]
|
471 |
+
for part in str(tpl).split('>>'):
|
472 |
+
for p in str(part).split('.'):
|
473 |
+
temp_part_to_fp[p]=None
|
474 |
+
for i, k in enumerate(temp_part_to_fp):
|
475 |
+
temp_part_to_fp[k] = i
|
476 |
+
|
477 |
+
fp_types = ['Morgan2CBF','Morgan4CBF', 'Morgan6CBF','AtomPair','TopologicalTorsion', 'Pattern', 'RDK']
|
478 |
+
#MACCS ErG don't work --> errors with explicit / inplicit valence
|
479 |
+
templates_fp = {}
|
480 |
+
remaining = args.fp_size
|
481 |
+
for fp_type in fp_types:
|
482 |
+
#print(fp_type, end='\t')
|
483 |
+
# if it's that last use up the remaining fps
|
484 |
+
te_feat = FP_featurizer(fp_types=fp_type,
|
485 |
+
max_features=(args.fp_size//len(fp_types)) if (fp_type != fp_types[-1]) else remaining,
|
486 |
+
log_scale=False
|
487 |
+
)
|
488 |
+
templates_fp[fp_type] = te_feat.fit(list(temp_part_to_fp.keys())[:], is_smarts=True)
|
489 |
+
#print(np.unique(templates_fp[fp_type]), end='\r')
|
490 |
+
remaining -= templates_fp[fp_type].shape[1]
|
491 |
+
templates_fp['fp'] = np.hstack([ templates_fp[f'{fp_type}'] for fp_type in fp_types])
|
492 |
+
|
493 |
+
|
494 |
+
if args.template_fp_type=='MxFP' or (args.template_fp_type2=='MxFP'):
|
495 |
+
comb_template_fp = compute_template_fp(fp_len= args.fp_size, reactant_pooling=args.reactant_pooling)
|
496 |
+
|
497 |
+
|
498 |
+
|
499 |
+
if args.template_fp_type=='Tfidf' or (args.template_fp_type2 == 'Tfidf'):
|
500 |
+
print('using tfidf template-fingerprint')
|
501 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
502 |
+
corpus = (list(template_list.values()))
|
503 |
+
vectorizer = TfidfVectorizer(analyzer='char', ngram_range=(1,12), max_features=args.fp_size)
|
504 |
+
tfidf_template_fp = vectorizer.fit_transform(corpus).toarray()
|
505 |
+
tfidf_template_fp.shape
|
506 |
+
|
507 |
+
|
508 |
+
acutal_fp_size = X_fp['train'].shape[1]
|
509 |
+
if acutal_fp_size != args.fp_size:
|
510 |
+
args.fp_size = int(X_fp['train'].shape[1])
|
511 |
+
print('Warning: fp-size has changed to', acutal_fp_size)
|
512 |
+
|
513 |
+
|
514 |
+
label_to_n_train_samples = {}
|
515 |
+
n_train_samples_to_label = defaultdict(list)
|
516 |
+
n_templates = max(template_list.keys())+1 #max(max(y['train']), max(y['test']), max(y['valid']))
|
517 |
+
for i in range(n_templates):
|
518 |
+
n_train_samples = (y['train']==i).sum()
|
519 |
+
label_to_n_train_samples[i] = n_train_samples
|
520 |
+
n_train_samples_to_label[n_train_samples].append(i)
|
521 |
+
|
522 |
+
|
523 |
+
up_to = 11
|
524 |
+
n_samples = []
|
525 |
+
masks = []
|
526 |
+
ntes = range(up_to)
|
527 |
+
mask_dict = {}
|
528 |
+
|
529 |
+
for nte in ntes: # Number of training examples
|
530 |
+
split = f'nte_{nte}'
|
531 |
+
#print(split)
|
532 |
+
mask = np.zeros(y['test'].shape)
|
533 |
+
|
534 |
+
if isinstance(nte, int):
|
535 |
+
for label_with_nte in n_train_samples_to_label[nte]:
|
536 |
+
mask += (y['test'] == label_with_nte)
|
537 |
+
|
538 |
+
mask = mask>=1
|
539 |
+
masks.append(mask)
|
540 |
+
mask_dict[str(nte)] = mask
|
541 |
+
n_samples.append(mask.sum())
|
542 |
+
|
543 |
+
# for greater than 10 # >10
|
544 |
+
n_samples.append((np.array(masks).max(0)==0).sum())
|
545 |
+
mask_dict['>10'] = (np.array(masks).max(0)==0)
|
546 |
+
|
547 |
+
sum(n_samples), mask.shape
|
548 |
+
|
549 |
+
ntes = range(50) #to 49
|
550 |
+
for nte in ntes: # Number of training examples
|
551 |
+
split = f'nte_{nte}'
|
552 |
+
#print(split)
|
553 |
+
mask = np.zeros(y['test'].shape)
|
554 |
+
for label_with_nte in n_train_samples_to_label[nte]:
|
555 |
+
mask += (y['test'] == label_with_nte)
|
556 |
+
mask = mask>=1
|
557 |
+
masks.append(mask)
|
558 |
+
# for greater than 10 # >49
|
559 |
+
n_samples.append((np.array(masks).max(0)==0).sum())
|
560 |
+
mask_dict['>49'] = np.array(masks).max(0)==0
|
561 |
+
|
562 |
+
print(n_samples)
|
563 |
+
|
564 |
+
clf, hpn_config = set_up_model(args, template_list=template_list)
|
565 |
+
clf = set_up_template_encoder(args, clf, label_to_n_train_samples=label_to_n_train_samples, template_list=template_list)
|
566 |
+
|
567 |
+
if args.verbose:
|
568 |
+
print(clf.config.__dict__)
|
569 |
+
print(clf)
|
570 |
+
|
571 |
+
wda = torch.optim.AdamW(clf.parameters(), lr=args.lr, weight_decay=1e-2)
|
572 |
+
|
573 |
+
if args.wandb:
|
574 |
+
wandb.watch(clf)
|
575 |
+
|
576 |
+
|
577 |
+
# pretraining with applicablity matrix, if applicable
|
578 |
+
if args.model_type == 'fortunato' or args.pretrain_epochs>1:
|
579 |
+
print('pretraining on applicability-matrix -- loading the matrix')
|
580 |
+
_, y_appl = load_USPTO(args.dataset_type, is_appl_matrix=True)
|
581 |
+
if args.splitting_scheme == 'remove_once_in_train_and_not_in_test':
|
582 |
+
y_appl['train'] = y_appl['train'][~remove_those_mask]
|
583 |
+
|
584 |
+
# check random if the applicability is true for y
|
585 |
+
splt = 'train'
|
586 |
+
for i in range(500):
|
587 |
+
i = np.random.randint(len(y[splt]))
|
588 |
+
#assert ( y_appl[splt][i].indices == y[splt][i] ).sum()==1
|
589 |
+
|
590 |
+
print('pre-training (BCE-loss)')
|
591 |
+
for epoch in range(args.pretrain_epochs):
|
592 |
+
clf.train_from_np(X_fp['train'], X_fp['train'], y_appl['train'], use_dataloader=True, is_smiles=False,
|
593 |
+
epochs=1, wandb=wandb, verbose=args.verbose, bs=args.batch_size,
|
594 |
+
permute_batches=True, shuffle=True, optimizer=wda,
|
595 |
+
only_templates_in_batch=args.only_templates_in_batch)
|
596 |
+
y_pred = clf.evaluate(X_fp['valid'], X_fp['valid'], y_appl['valid'],
|
597 |
+
split='pretrain_valid', is_smiles=False, only_loss=True,
|
598 |
+
bs=args.batch_size,wandb=wandb)
|
599 |
+
appl_acc = ((y_appl['valid'].toarray()) == (y_pred>0.5)).mean()
|
600 |
+
print(f'{epoch:2.0f} -- train_loss: {clf.hist["loss"][-1]:1.3f}, loss_valid: {clf.hist["loss_pretrain_valid"][-1]:1.3f}, train_acc: {appl_acc:1.5f}')
|
601 |
+
|
602 |
+
fn_hist = None
|
603 |
+
y_preds = None
|
604 |
+
|
605 |
+
for epoch in range(round(args.epochs / args.eval_every_n_epochs)):
|
606 |
+
if not isinstance(clf, StaticQK):
|
607 |
+
now = time()
|
608 |
+
clf.train_from_np(X_fp['train'], X_fp['train'], y['train'], use_dataloader=True, is_smiles=False,
|
609 |
+
epochs=args.eval_every_n_epochs, wandb=wandb, verbose=args.verbose, bs=args.batch_size,
|
610 |
+
permute_batches=True, shuffle=True, optimizer=wda, only_templates_in_batch=args.only_templates_in_batch)
|
611 |
+
if args.verbose: print(f'training took {(time()-now)/60:3.1f} min for {args.eval_every_n_epochs} epochs')
|
612 |
+
for split in ['valid', 'test']:
|
613 |
+
print(split, 'evaluating', end='\r')
|
614 |
+
now = time()
|
615 |
+
#only_loss = ((epoch%5)==4) if args.dataset_type=='lg' else True
|
616 |
+
y_preds = clf.evaluate(X_fp[split], X_fp[split], y[split], is_smiles=False, split=split, bs=args.batch_size, only_loss=args.eval_only_loss, wandb=wandb);
|
617 |
+
|
618 |
+
if args.verbose: print(f'eval {split} took',(time()-now)/60,'min')
|
619 |
+
if not isinstance(clf, StaticQK):
|
620 |
+
try:
|
621 |
+
print(f'{epoch:2.0f} -- train_loss: {clf.hist["loss"][-1]:1.3f}, loss_valid: {clf.hist["loss_valid"][-1]:1.3f}, val_t1acc: {clf.hist["t1_acc_valid"][-1]:1.3f}, val_t100acc: {clf.hist["t100_acc_valid"][-1]:1.3f}')
|
622 |
+
except:
|
623 |
+
pass
|
624 |
+
|
625 |
+
now = time()
|
626 |
+
ks = [1, 2, 3, 4, 5, 10, 20, 30, 40, 50, 100]
|
627 |
+
for nte in mask_dict: # Number of training examples
|
628 |
+
split = f'nte_{nte}'
|
629 |
+
#print(split)
|
630 |
+
mask = mask_dict[nte]
|
631 |
+
|
632 |
+
topkacc = top_k_accuracy(np.array(y['test'])[mask], y_preds[mask, :], k=ks, ret_arocc=False)
|
633 |
+
|
634 |
+
new_hist = {}
|
635 |
+
for k, tkacc in zip(ks, topkacc):
|
636 |
+
new_hist[f't{k}_acc_{split}'] = tkacc
|
637 |
+
#new_hist[(f'arocc_{split}')] = (arocc)
|
638 |
+
new_hist[f'steps_{split}'] = (clf.steps)
|
639 |
+
|
640 |
+
for k in new_hist:
|
641 |
+
clf.hist[k].append(new_hist[k])
|
642 |
+
|
643 |
+
if args.verbose: print(f'eval nte-test took',(time()-now)/60,'min')
|
644 |
+
|
645 |
+
fn_hist = clf.save_hist(prefix=f'USTPO_{args.dataset_type}_{args.model_type}_', postfix=fn_postfix)
|
646 |
+
|
647 |
+
if args.save_preds:
|
648 |
+
PATH = './data/preds/'
|
649 |
+
if not os.path.exists(PATH):
|
650 |
+
os.mkdir(PATH)
|
651 |
+
pred_fn = f'{PATH}USPTO_{args.dataset_type}_test_{args.model_type}_{fn_postfix}.npy'
|
652 |
+
print('saving predictions to',pred_fn)
|
653 |
+
np.save(pred_fn,y_preds)
|
654 |
+
args.save_preds = pred_fn
|
655 |
+
|
656 |
+
|
657 |
+
if args.save_model:
|
658 |
+
model_save_path = clf.save_model(prefix=f'USPTO_{args.dataset_type}_{args.model_type}_valloss{clf.hist.get("loss_valid",[-1])[-1]:1.3f}_',name_as_conf=False, postfix=fn_postfix)
|
659 |
+
|
660 |
+
# Serialize data into file:
|
661 |
+
import json
|
662 |
+
json.dump( args.__dict__, open( f"data/model/{fn_postfix}_args.json", 'w' ) )
|
663 |
+
json.dump( hpn_config.__dict__,
|
664 |
+
open( f"data/model/{fn_postfix}_config.json", 'w' ) )
|
665 |
+
|
666 |
+
print('model saved to', model_save_path)
|
667 |
+
|
668 |
+
print(min(clf.hist.get('loss_valid',[-1])))
|
669 |
+
|
670 |
+
if args.plot_res:
|
671 |
+
from plotutils import plot_topk, plot_nte
|
672 |
+
|
673 |
+
plt.figure()
|
674 |
+
clf.plot_loss()
|
675 |
+
plt.draw()
|
676 |
+
|
677 |
+
plt.figure()
|
678 |
+
plot_topk(clf.hist, sets=['valid'])
|
679 |
+
if args.dataset_type=='sm':
|
680 |
+
baseline_val_res = {1:0.4061, 10:0.6827, 50: 0.7883, 100:0.8400}
|
681 |
+
plt.plot(list(baseline_val_res.keys()), list(baseline_val_res.values()), 'k.--')
|
682 |
+
plt.draw()
|
683 |
+
plt.figure()
|
684 |
+
|
685 |
+
best_cpt = np.array(clf.hist['loss_valid'])[::-1].argmin()+1
|
686 |
+
print(best_cpt)
|
687 |
+
try:
|
688 |
+
best_cpt = np.array(clf.hist['t10_acc_valid'])[::-1].argmax()+1
|
689 |
+
print(best_cpt)
|
690 |
+
except:
|
691 |
+
print('err with t10_acc_valid')
|
692 |
+
plot_nte(clf.hist, dataset=args.dataset_type.capitalize(), last_cpt=best_cpt, include_bar=True, model_legend=args.exp_name,
|
693 |
+
n_samples=n_samples, z=1.96)
|
694 |
+
if os.path.exists('data/figs/'):
|
695 |
+
try:
|
696 |
+
os.mkdir(f'data/figs/{args.exp_name}/')
|
697 |
+
except:
|
698 |
+
pass
|
699 |
+
plt.savefig(f'data/figs/{args.exp_name}/training_examples_vs_top100_acc_{args.dataset_type}_{hash(str(args))}.pdf')
|
700 |
+
plt.draw()
|
701 |
+
fn_hist = clf.save_hist(prefix=f'USTPO_{args.dataset_type}_{args.model_type}_', postfix=fn_postfix)
|
702 |
+
|
703 |
+
|
704 |
+
if args.ssretroeval:
|
705 |
+
print('testing on the real test set ;)')
|
706 |
+
from .data import load_templates
|
707 |
+
from .retroeval import run_templates, topkaccuracy
|
708 |
+
from .utils import sort_by_template_and_flatten
|
709 |
+
|
710 |
+
|
711 |
+
a = list(template_list.keys())
|
712 |
+
#assert list(range(len(a))) == a
|
713 |
+
templates = list(template_list.values())
|
714 |
+
#templates = [*templates, *expert_templates]
|
715 |
+
template_product_smarts = [str(s).split('>')[0] for s in templates]
|
716 |
+
|
717 |
+
#execute all template
|
718 |
+
print('execute all templates')
|
719 |
+
test_product_smarts = [xi[0] for xi in X['test']] #added later
|
720 |
+
smarts2appl = memory.cache(smarts2appl, ignore=['njobs','nsplits', 'use_tqdm'])
|
721 |
+
appl = smarts2appl(test_product_smarts, template_product_smarts, njobs=args.njobs)
|
722 |
+
n_pairs = len(test_product_smarts) * len(template_product_smarts)
|
723 |
+
n_appl = len(appl[0])
|
724 |
+
print(n_pairs, n_appl, n_appl/n_pairs)
|
725 |
+
|
726 |
+
#forward
|
727 |
+
split = 'test'
|
728 |
+
print('len(X_fp[test]):',len(X_fp[split]))
|
729 |
+
y[split] = np.zeros(len(X[split])).astype(np.int)
|
730 |
+
clf.eval()
|
731 |
+
if y_preds is None:
|
732 |
+
y_preds = clf.evaluate(X_fp[split], X_fp[split], y[split], is_smiles=False,
|
733 |
+
split='ttest', bs=args.batch_size, only_loss=True, wandb=None);
|
734 |
+
|
735 |
+
template_scores = y_preds #this should allready be test
|
736 |
+
|
737 |
+
####
|
738 |
+
if y_preds.shape[1]>100000:
|
739 |
+
kth = 200
|
740 |
+
print(f'only evaluating top {kth} applicable predicted templates')
|
741 |
+
# only take top kth and multiply by applicability matrix
|
742 |
+
appl_mtrx = np.zeros_like(y_preds, dtype=bool)
|
743 |
+
appl_mtrx[appl[0], appl[1]] = 1
|
744 |
+
|
745 |
+
appl_and_topkth = ([], [])
|
746 |
+
for row in range(len(y_preds)):
|
747 |
+
argpreds = (np.argpartition(-(y_preds[row]*appl_mtrx[row]), kth, axis=0)[:kth])
|
748 |
+
# if there are less than kth applicable
|
749 |
+
mask = appl_mtrx[row][argpreds]
|
750 |
+
argpreds = argpreds[mask]
|
751 |
+
#if len(argpreds)!=kth:
|
752 |
+
# print('changed to ', len(argpreds))
|
753 |
+
|
754 |
+
appl_and_topkth[0].extend([row for _ in range(len(argpreds))])
|
755 |
+
appl_and_topkth[1].extend(list(argpreds))
|
756 |
+
|
757 |
+
appl = appl_and_topkth
|
758 |
+
####
|
759 |
+
|
760 |
+
print('running the templates')
|
761 |
+
run_templates = run_templates #memory.cache( ) ... allready cached to tmp
|
762 |
+
prod_idx_reactants, prod_temp_reactants = run_templates(test_product_smarts, templates, appl, njobs=args.njobs)
|
763 |
+
#sorted_results = sort_by_template(template_scores, prod_idx_reactants)
|
764 |
+
#flat_results = flatten_per_product(sorted_results, remove_duplicates=True)
|
765 |
+
#now aglomerates over same outcome
|
766 |
+
flat_results = sort_by_template_and_flatten(y_preds, prod_idx_reactants, agglo_fun=sum)
|
767 |
+
accs = topkaccuracy(test_reactants_can, flat_results, [*list(range(1,101)), 100000])
|
768 |
+
|
769 |
+
mtrcs2 = {f't{k}acc_ttest':accs[k-1] for k in [1,2,3,5,10,20,50,100,101]}
|
770 |
+
if wandb:
|
771 |
+
wandb.log(mtrcs2)
|
772 |
+
print('Single-step retrosynthesis-evaluation, results on ttest:')
|
773 |
+
#print([k[:-6]+'|' for k in mtrcs2.keys()])
|
774 |
+
[print(k[:-6],end='\t') for k in mtrcs2.keys()]
|
775 |
+
print()
|
776 |
+
for k,v in mtrcs2.items():
|
777 |
+
print(f'{v*100:2.2f}',end='\t')
|
778 |
+
|
779 |
+
|
780 |
+
# save the history of this experiment
|
781 |
+
EXP_DIR = 'data/experiments/'
|
782 |
+
|
783 |
+
df = pd.DataFrame([args.__dict__])
|
784 |
+
df['min_loss_valid'] = min(clf.hist.get('loss_valid', [-1]))
|
785 |
+
df['min_loss_train'] = 0 if ((args.model_type=='staticQK') or (args.model_type=='retrosim')) else min(clf.hist.get('loss',[-1]))
|
786 |
+
try:
|
787 |
+
df['max_t1_acc_valid'] = max(clf.hist.get('t1_acc_valid', [0]))
|
788 |
+
df['max_t100_acc_valid'] = max(clf.hist.get('t100_acc_valid', [0]))
|
789 |
+
except:
|
790 |
+
pass
|
791 |
+
df['hist'] = [clf.hist]
|
792 |
+
df['n_samples'] = [n_samples]
|
793 |
+
|
794 |
+
df['fn_hist'] = fn_hist if fn_hist else None
|
795 |
+
df['fn_model'] = '' if not args.save_model else model_save_path
|
796 |
+
df['date'] = str(datetime.datetime.fromtimestamp(time()))
|
797 |
+
df['cmd'] = ' '.join(sys.argv[:])
|
798 |
+
|
799 |
+
|
800 |
+
if not os.path.exists(EXP_DIR):
|
801 |
+
os.mkdir(EXP_DIR)
|
802 |
+
|
803 |
+
df.to_csv(f'{EXP_DIR}{run_id}.tsv', sep='\t')
|
804 |
+
df
|
mhnreact/utils.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
Author: Philipp Seidl
|
4 |
+
ELLIS Unit Linz, LIT AI Lab, Institute for Machine Learning
|
5 |
+
Johannes Kepler University Linz
|
6 |
+
Contact: [email protected]
|
7 |
+
|
8 |
+
General utility functions
|
9 |
+
"""
|
10 |
+
|
11 |
+
import argparse
|
12 |
+
from collections import defaultdict
|
13 |
+
import numpy as np
|
14 |
+
import pandas as pd
|
15 |
+
import math
|
16 |
+
import torch
|
17 |
+
|
18 |
+
# used and fastest version
|
19 |
+
def top_k_accuracy(y_true, y_pred, k=5, ret_arocc=False, ret_mrocc=False, verbose=False, count_equal_as_correct=False, eps_noise=0):
|
20 |
+
""" partly from http://stephantul.github.io/python/pytorch/2020/09/18/fast_topk/
|
21 |
+
count_equal counts equal values as beein a correct choice e.g. all preds = 0 --> T1acc = 1
|
22 |
+
ret_mrocc ... also return median rank of correct choice
|
23 |
+
eps_noise ... if >0 ads noise*eps to y_pred .. recommended e.g. 1e-10
|
24 |
+
"""
|
25 |
+
if eps_noise>0:
|
26 |
+
if torch.is_tensor(y_pred):
|
27 |
+
y_pred = y_pred + torch.rand(y_pred.shape)*eps_noise
|
28 |
+
else:
|
29 |
+
y_pred = y_pred + np.random.rand(*y_pred.shape)*eps_noise
|
30 |
+
|
31 |
+
if count_equal_as_correct:
|
32 |
+
greater = (y_pred > y_pred[range(len(y_pred)), y_true][:,None]).sum(1) # how many are bigger
|
33 |
+
else:
|
34 |
+
greater = (y_pred >= y_pred[range(len(y_pred)), y_true][:,None]).sum(1) # how many are bigger or equal
|
35 |
+
if torch.is_tensor(y_pred):
|
36 |
+
greater = greater.long()
|
37 |
+
if isinstance(k, int): k = [k] # pack it into a list
|
38 |
+
tkaccs = []
|
39 |
+
for ki in k:
|
40 |
+
if count_equal_as_correct:
|
41 |
+
tkacc = (greater<=(ki-1))
|
42 |
+
else:
|
43 |
+
tkacc = (greater<=(ki))
|
44 |
+
if torch.is_tensor(y_pred):
|
45 |
+
tkacc = tkacc.float().mean().detach().cpu().numpy()
|
46 |
+
else:
|
47 |
+
tkacc = tkacc.mean()
|
48 |
+
tkaccs.append(tkacc)
|
49 |
+
if verbose: print('Top', ki, 'acc:\t', str(tkacc)[:6])
|
50 |
+
|
51 |
+
if ret_arocc:
|
52 |
+
arocc = greater.float().mean()+1
|
53 |
+
if torch.is_tensor(arocc):
|
54 |
+
arocc = arocc.detach().cpu().numpy()
|
55 |
+
return (tkaccs[0], arocc) if len(tkaccs) == 1 else (tkaccs, arocc)
|
56 |
+
if ret_mrocc:
|
57 |
+
mrocc = greater.median()+1
|
58 |
+
if torch.is_tensor(mrocc):
|
59 |
+
mrocc = mrocc.float().detach().cpu().numpy()
|
60 |
+
return (tkaccs[0], mrocc) if len(tkaccs) == 1 else (tkaccs, mrocc)
|
61 |
+
|
62 |
+
|
63 |
+
return tkaccs[0] if len(tkaccs) == 1 else tkaccs
|
64 |
+
|
65 |
+
|
66 |
+
def seed_everything(seed=70135):
|
67 |
+
""" does what it says ;) - from https://gist.github.com/KirillVladimirov/005ec7f762293d2321385580d3dbe335"""
|
68 |
+
import numpy as np
|
69 |
+
import random
|
70 |
+
import os
|
71 |
+
import torch
|
72 |
+
|
73 |
+
random.seed(seed)
|
74 |
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
75 |
+
np.random.seed(seed)
|
76 |
+
torch.manual_seed(seed)
|
77 |
+
torch.cuda.manual_seed(seed)
|
78 |
+
torch.backends.cudnn.deterministic = True
|
79 |
+
|
80 |
+
def get_best_gpu():
|
81 |
+
'''Get the gpu with most RAM on the machine. From P. Neves'''
|
82 |
+
import torch
|
83 |
+
if torch.cuda.is_available():
|
84 |
+
gpus_ram = []
|
85 |
+
for ind in range(torch.cuda.device_count()):
|
86 |
+
gpus_ram.append(torch.cuda.get_device_properties(ind).total_memory/1e9)
|
87 |
+
return f"cuda:{gpus_ram.index(max(gpus_ram))}"
|
88 |
+
else:
|
89 |
+
raise ValueError("No gpus were detected in this machine.")
|
90 |
+
|
91 |
+
|
92 |
+
def sort_by_template_and_flatten(template_scores, prod_idx_reactants, agglo_fun=sum):
|
93 |
+
flat_results = []
|
94 |
+
for ii in range(len(template_scores)):
|
95 |
+
idx_prod_reactants = defaultdict(list)
|
96 |
+
for k,v in prod_idx_reactants[ii].items():
|
97 |
+
for iv in v:
|
98 |
+
idx_prod_reactants[iv].append(template_scores[ii,k])
|
99 |
+
d2 = {k: agglo_fun(v) for k, v in idx_prod_reactants.items()}
|
100 |
+
if len(d2)==0:
|
101 |
+
flat_results.append([])
|
102 |
+
else:
|
103 |
+
flat_results.append(pd.DataFrame.from_dict(d2, orient='index').sort_values(0, ascending=False).index.values.tolist())
|
104 |
+
return flat_results
|
105 |
+
|
106 |
+
|
107 |
+
def str2bool(v):
|
108 |
+
"""adapted from https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse"""
|
109 |
+
if isinstance(v, bool):
|
110 |
+
return v
|
111 |
+
if v.lower() in ('yes', 'true', 't', 'y', '1', '',' '):
|
112 |
+
return True
|
113 |
+
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
|
114 |
+
return False
|
115 |
+
else:
|
116 |
+
raise argparse.ArgumentTypeError('Boolean value expected.')
|
117 |
+
|
118 |
+
|
119 |
+
@np.vectorize
|
120 |
+
def lgamma(x):
|
121 |
+
return math.lgamma(x)
|
122 |
+
|
123 |
+
def multinom_gk(array, axis=0):
|
124 |
+
"""Multinomial lgamma pooling over a given axis"""
|
125 |
+
res = lgamma(np.sum(array,axis=axis)+2) - np.sum(lgamma(array+1),axis=axis)
|
126 |
+
return res
|
mhnreact/view.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
Author: Philipp Seidl
|
4 |
+
ELLIS Unit Linz, LIT AI Lab, Institute for Machine Learning
|
5 |
+
Johannes Kepler University Linz
|
6 |
+
Contact: [email protected]
|
7 |
+
|
8 |
+
Loading log-files from training
|
9 |
+
"""
|
10 |
+
|
11 |
+
from pathlib import Path
|
12 |
+
import os
|
13 |
+
import datetime
|
14 |
+
import pandas as pd
|
15 |
+
import numpy as np
|
16 |
+
import pandas as pd
|
17 |
+
import matplotlib.pyplot as plt
|
18 |
+
|
19 |
+
def load_experiments(EXP_DIR = Path('data/experiments/')):
|
20 |
+
dfs = []
|
21 |
+
for fn in os.listdir(EXP_DIR):
|
22 |
+
print(fn, end='\r')
|
23 |
+
if fn.split('.')[-1]=='tsv':
|
24 |
+
df = pd.read_csv(EXP_DIR/fn, sep='\t', index_col=0)
|
25 |
+
try:
|
26 |
+
with open(df['fn_hist'][0]) as f:
|
27 |
+
hist = eval(f.readlines()[0] )
|
28 |
+
df['hist'] = [hist]
|
29 |
+
df['fn'] = fn
|
30 |
+
except:
|
31 |
+
print('err')
|
32 |
+
#print(df['fn_hist'])
|
33 |
+
dfs.append( df )
|
34 |
+
df = pd.concat(dfs,ignore_index=True)
|
35 |
+
return df
|
36 |
+
|
37 |
+
def get_x(k, kw, operation='max', index=None):
|
38 |
+
operation = getattr(np,operation)
|
39 |
+
try:
|
40 |
+
if index is not None:
|
41 |
+
return k[kw][index]
|
42 |
+
|
43 |
+
return operation(k[kw])
|
44 |
+
except:
|
45 |
+
return 0
|
46 |
+
|
47 |
+
def get_min_val_loss_idx(k):
|
48 |
+
return get_x(k, 'loss_valid', 'argmin') #changed from argmax to argmin!!
|
49 |
+
|
50 |
+
def get_tauc(hist):
|
51 |
+
idx = get_min_val_loss_idx(hist)
|
52 |
+
# takes max TODO take idx
|
53 |
+
return np.mean([get_x(hist, f't100_acc_nte_{nt}') for nt in [*range(11),'>10']])
|
54 |
+
|
55 |
+
def get_stats_from_hist(df):
|
56 |
+
df['0shot_acc'] = df['hist'].apply(lambda k: get_x(k, 't100_acc_nte_0'))
|
57 |
+
df['1shot_acc'] = df['hist'].apply(lambda k: get_x(k, 't100_acc_nte_1'))
|
58 |
+
df['>49shot_acc'] = df['hist'].apply(lambda k: get_x(k, 't100_acc_nte_>49'))
|
59 |
+
df['min_loss_valid'] = df['hist'].apply(lambda k: get_x(k, 'loss_valid', 'min'))
|
60 |
+
return df
|