Spaces:
Running
Running
Update inference.py
Browse files
- inference.py +12 -12
inference.py
CHANGED
@@ -114,7 +114,7 @@ class Inference(object):
|
|
114 |
|
115 |
def decoder_load(self, dictionary_name):
|
116 |
''' Loading the atom and bond decoders'''
|
117 |
-
with open("
|
118 |
return pickle.load(f)
|
119 |
|
120 |
|
@@ -140,16 +140,16 @@ class Inference(object):
|
|
140 |
self.restore_model(self.submodel, self.inference_model)
|
141 |
|
142 |
# smiles data for metrics calculation.
|
143 |
-
chembl_smiles = [line for line in open("
|
144 |
-
chembl_test = [line for line in open("
|
145 |
-
drug_smiles = [line for line in open("
|
146 |
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
|
147 |
drug_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in drug_mols if x is not None]
|
148 |
|
149 |
|
150 |
# Make directories if not exist.
|
151 |
-
if not os.path.exists("
|
152 |
-
os.makedirs("
|
153 |
if self.correct:
|
154 |
correct = smi_correct(self.submodel, "DrugGEN_/experiments/inference/{}".format(self.submodel))
|
155 |
search_res = pd.DataFrame(columns=["submodel", "validity",
|
@@ -166,7 +166,7 @@ class Inference(object):
|
|
166 |
uniqueness_calc = []
|
167 |
real_smiles_snn = []
|
168 |
nodes_sample = torch.Tensor(size=[1,45,1]).to(self.device)
|
169 |
-
f = open("
|
170 |
f.write("SMILES")
|
171 |
f.write("\n")
|
172 |
val_counter = 0
|
@@ -226,16 +226,16 @@ class Inference(object):
|
|
226 |
f.close()
|
227 |
print("Inference completed, starting metrics calculation.")
|
228 |
if self.correct:
|
229 |
-
corrected = correct.correct("
|
230 |
gen_smi = corrected["SMILES"].tolist()
|
231 |
|
232 |
else:
|
233 |
-
gen_smi = pd.read_csv("
|
234 |
|
235 |
|
236 |
et = time.time() - start_time
|
237 |
|
238 |
-
with open("
|
239 |
for i in gen_smi:
|
240 |
f.write(i)
|
241 |
f.write("\n")
|
@@ -265,9 +265,9 @@ if __name__=="__main__":
|
|
265 |
|
266 |
# Data configuration.
|
267 |
parser.add_argument('--inf_dataset_file', type=str, default='chembl45_test.pt')
|
268 |
-
parser.add_argument('--inf_raw_file', type=str, default='
|
269 |
parser.add_argument('--inf_batch_size', type=int, default=1, help='Batch size for inference')
|
270 |
-
parser.add_argument('--mol_data_dir', type=str, default='
|
271 |
parser.add_argument('--features', type=str2bool, default=False, help='features dimension for nodes')
|
272 |
|
273 |
# Model configuration.
|
|
|
114 |
|
115 |
def decoder_load(self, dictionary_name):
|
116 |
''' Loading the atom and bond decoders'''
|
117 |
+
with open("data/decoders/" + dictionary_name + "_" + self.dataset_name + '.pkl', 'rb') as f:
|
118 |
return pickle.load(f)
|
119 |
|
120 |
|
|
|
140 |
self.restore_model(self.submodel, self.inference_model)
|
141 |
|
142 |
# smiles data for metrics calculation.
|
143 |
+
chembl_smiles = [line for line in open("data/chembl_train.smi", 'r').read().splitlines()]
|
144 |
+
chembl_test = [line for line in open("data/chembl_test.smi", 'r').read().splitlines()]
|
145 |
+
drug_smiles = [line for line in open("data/akt_inhibitors.smi", 'r').read().splitlines()]
|
146 |
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
|
147 |
drug_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in drug_mols if x is not None]
|
148 |
|
149 |
|
150 |
# Make directories if not exist.
|
151 |
+
if not os.path.exists("experiments/inference/{}".format(self.submodel)):
|
152 |
+
os.makedirs("experiments/inference/{}".format(self.submodel))
|
153 |
if self.correct:
|
154 |
correct = smi_correct(self.submodel, "DrugGEN_/experiments/inference/{}".format(self.submodel))
|
155 |
search_res = pd.DataFrame(columns=["submodel", "validity",
|
|
|
166 |
uniqueness_calc = []
|
167 |
real_smiles_snn = []
|
168 |
nodes_sample = torch.Tensor(size=[1,45,1]).to(self.device)
|
169 |
+
f = open("experiments/inference/{}/inference_drugs.txt".format(self.submodel), "w")
|
170 |
f.write("SMILES")
|
171 |
f.write("\n")
|
172 |
val_counter = 0
|
|
|
226 |
f.close()
|
227 |
print("Inference completed, starting metrics calculation.")
|
228 |
if self.correct:
|
229 |
+
corrected = correct.correct("experiments/inference/{}/inference_drugs.txt".format(self.submodel))
|
230 |
gen_smi = corrected["SMILES"].tolist()
|
231 |
|
232 |
else:
|
233 |
+
gen_smi = pd.read_csv("experiments/inference/{}/inference_drugs.txt".format(self.submodel))["SMILES"].tolist()
|
234 |
|
235 |
|
236 |
et = time.time() - start_time
|
237 |
|
238 |
+
with open("experiments/inference/{}/inference_drugs.txt".format(self.submodel), "w") as f:
|
239 |
for i in gen_smi:
|
240 |
f.write(i)
|
241 |
f.write("\n")
|
|
|
265 |
|
266 |
# Data configuration.
|
267 |
parser.add_argument('--inf_dataset_file', type=str, default='chembl45_test.pt')
|
268 |
+
parser.add_argument('--inf_raw_file', type=str, default='data/chembl_test.smi')
|
269 |
parser.add_argument('--inf_batch_size', type=int, default=1, help='Batch size for inference')
|
270 |
+
parser.add_argument('--mol_data_dir', type=str, default='data')
|
271 |
parser.add_argument('--features', type=str2bool, default=False, help='features dimension for nodes')
|
272 |
|
273 |
# Model configuration.
|