azminetoushikwasi committed
Commit 4c3c1d1 • Parent(s): aaaa78b
Upload 10 files
Browse files:
- .gitattributes +2 -35
- Fig/model.png +0 -0
- LICENSE +21 -0
- Model/_DATA/all_chem_df.csv +0 -0
- Model/data/__init__.py +1 -0
- Model/data/dataset.py +277 -0
- Model/methods/MLP.py +31 -0
- Model/methods/__init__.py +1 -0
- Model/train-ngram.py +186 -0
- README.md +48 -3
.gitattributes
CHANGED
@@ -1,35 +1,2 @@
```diff
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+# Auto detect text files and perform LF normalization
+* text=auto
```
Fig/model.png
ADDED
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Azmine Toushik Wasi

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Model/_DATA/all_chem_df.csv
ADDED
The diff for this file is too large to render.
See raw diff
Model/data/__init__.py
ADDED
@@ -0,0 +1 @@
```python
from .dataset import get_loaders_sequence, get_loaders_n_gram
```
Model/data/dataset.py
ADDED
@@ -0,0 +1,277 @@
```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from collections import Counter
from itertools import product

import numpy as np
import pandas as pd
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset


def read_csv(
    csv_file,
    x_col="smiles",
    y_col="tags",
):
    df = pd.read_csv(csv_file)

    all_y = set()
    all_x = set()

    # drop rows with multiple (space-separated) tags
    df = df[~df[y_col].str.contains(" ")]

    x = df[x_col]
    y = df[y_col]

    # find all y
    for item_y in y:
        all_y.update(item_y.split(" "))

    # make y mapping
    mapping_y = {val: index for index, val in enumerate(sorted(list(all_y)))}

    # find all x
    for item_x in x:
        all_x.update(set(item_x))

    # make x mapping; index 0 is reserved for padding
    mapping_x = {val: index + 1 for index, val in enumerate(sorted(list(all_x)))}
    mapping_x["<pad>"] = 0

    # encode y
    ys = [mapping_y[i] for i in y]
    ys = np.array(ys)

    # encode x, one integer per character
    xs = []
    for item_x in x:
        encoded_item = [mapping_x[c] for c in item_x]
        xs.append(encoded_item)
    xs = [np.array(item) for item in xs]

    to_return = {
        "x": {"raw": x.values, "data": xs},
        "y": {"data": ys},
        "mapping": {"x": mapping_x, "y": mapping_y},
    }
    return to_return


def split_data(data, ratio_dev=0.1, ratio_test=0.1, seed=None):
    # random number generator
    rng = np.random.default_rng(seed=seed)

    # dataset sizes
    size_total = len(data["y"]["data"])
    ratios = {"dev": ratio_dev, "test": ratio_test}
    sizes = {}
    for split, ratio in ratios.items():
        sizes[split] = int(ratio * size_total)
    sizes["train"] = size_total - sum(sizes.values())

    # split
    index = np.arange(size_total)
    rng.shuffle(index)

    indices = {}
    start = 0
    for split, size in sizes.items():
        indices[split] = index[start : start + size]
        start += size

    splits = {}
    for split, split_index in indices.items():
        x_data = data["x"]
        x_data = {k: [v[i] for i in split_index] for k, v in x_data.items()}

        y_data = data["y"]
        y_data = {k: v[split_index] for k, v in y_data.items()}

        splits[split] = {"x": x_data, "y": y_data}

    return splits


def make_n_gram_mapping(mapping, n):
    values = mapping.keys()
    combos = product(values, repeat=n)
    mapping = {"".join(v): i for i, v in enumerate(sorted(combos))}
    return mapping


def count_n_grams(text, n):
    # slide a window of size n over the string and count occurrences
    len_gram = len(text) + 1 - n
    n_grams = [text[i : i + n] for i in range(len_gram)]
    return Counter(n_grams)


def get_topk_n_grams(data, n, topk=1000):
    counters = [count_n_grams(text, n) for text in data]
    counter = Counter()
    for c in counters:
        counter += c
    results = [w for w, _ in counter.most_common(topk)]
    return results


def sequence_collate(batch):
    x, y = zip(*batch)
    x = [torch.LongTensor(item) for item in x]
    lens = torch.LongTensor([len(i) for i in x])
    x_padded = pad_sequence(x, batch_first=True, padding_value=0)
    y = torch.LongTensor(np.array(y))
    # sort the batch by sequence length, longest first
    _, perm_idx = lens.sort(0, descending=True)
    return x_padded[perm_idx], y[perm_idx], lens[perm_idx]


class NgramDataset(Dataset):
    """
    Encoder based on n-grams: each item becomes a vector of counts over the
    top-k n-grams, plus one bucket for all remaining (unknown) n-grams.
    """

    def __init__(self, x, y, top_grams=None, n=1, topk=1000):
        data_x = x["raw"]
        data_y = y["data"]
        if top_grams is None:
            top_grams = get_topk_n_grams(data_x, n, topk=topk)

        all_grams = []
        for item_x in data_x:
            grams = count_n_grams(item_x, n)
            item = [grams[g] for g in top_grams]
            # all n-grams outside the top-k are lumped into one unk count
            unk = sum(v for k, v in grams.items() if k not in top_grams)
            item.append(unk)
            all_grams.append(item)

        self.top_grams = top_grams
        self.x = np.array(all_grams, dtype="float32")
        self.x_raw = data_x
        self.y = np.array(data_y, dtype="int64")

    def __getitem__(self, index):
        item_x = self.x[index]
        item_y = self.y[index]

        return item_x, item_y

    def __len__(self):
        return len(self.x)


class SequenceDataset(Dataset):
    """
    Encode each character in sequence.
    0: padding
    """

    def __init__(self, x, y, mapping_x, mapping_y, n=1):
        data_x = x["data"]
        data_y = y["data"]

        self.x = data_x

        self.x_raw = x["raw"]
        self.y = np.array(data_y, dtype="int64")

        self.mapping_x = mapping_x
        self.mapping_x_inverse = {v: k for k, v in self.mapping_x.items()}
        self.mapping_y = mapping_y
        self.mapping_y_inverse = {v: k for k, v in self.mapping_y.items()}

    def __getitem__(self, index):
        item_x = np.array(self.x[index], dtype="int64")
        item_y = self.y[index]

        return item_x, item_y

    def __len__(self):
        return len(self.x)


def get_loaders_n_gram(
    csv_file, n=1, topk=20, ratio_dev=0.1, ratio_test=0.1, batch_size=32, seed=None
):
    data = read_csv(csv_file)
    mapping_x = data["mapping"]["x"]
    mapping_y = data["mapping"]["y"]
    splits = split_data(
        data,
        ratio_dev=ratio_dev,
        ratio_test=ratio_test,
        seed=seed,
    )

    # make the train set first so its top n-grams can be reused for dev/test
    split_train = splits.pop("train")
    dataset_train = NgramDataset(split_train["x"], split_train["y"], n=n, topk=topk)
    top_grams = dataset_train.top_grams

    datasets = {
        k: NgramDataset(v["x"], v["y"], n=n, top_grams=top_grams)
        for k, v in splits.items()
    }
    datasets["train"] = dataset_train
    # double the batch size for the eval splits
    batch_sizes = {
        k: batch_size if k == "train" else batch_size * 2 for k in datasets.keys()
    }
    # shuffle only the train set
    shuffle = {k: True if k == "train" else False for k in datasets.keys()}
    # make loaders
    loaders = {
        k: DataLoader(v, batch_size=batch_sizes[k], shuffle=shuffle[k])
        for k, v in datasets.items()
    }
    # find sizes (+1 for the unknown-n-gram bucket)
    size_x = len(top_grams) + 1
    size_y = len(mapping_y)
    return {"loaders": loaders, "sizes": {"x": size_x, "y": size_y}}


def get_loaders_sequence(
    csv_file,
    ratio_dev=0.1,
    ratio_test=0.1,
    batch_size=32,
    seed=None,
):
    data = read_csv(csv_file)
    mapping_x = data["mapping"]["x"]
    mapping_y = data["mapping"]["y"]
    splits = split_data(
        data,
        ratio_dev=ratio_dev,
        ratio_test=ratio_test,
        seed=seed,
    )

    datasets = {
        k: SequenceDataset(v["x"], v["y"], mapping_x, mapping_y)
        for k, v in splits.items()
    }
    # double the batch size for the eval splits
    batch_sizes = {
        k: batch_size if k == "train" else batch_size * 2 for k in datasets.keys()
    }
    # shuffle only the train set
    shuffle = {k: True if k == "train" else False for k in datasets.keys()}
    # make loaders
    loaders = {
        k: DataLoader(
            v,
            batch_size=batch_sizes[k],
            shuffle=shuffle[k],
            collate_fn=sequence_collate,
        )
        for k, v in datasets.items()
    }
    # find sizes
    size_x = len(mapping_x)
    size_y = len(mapping_y)
    return {"loaders": loaders, "sizes": {"x": size_x, "y": size_y}}
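```

A hedged usage sketch of the n-gram entry point above; it assumes the repo layout shipped in this commit (working directory `./Model/` with the bundled CSV in place), mirroring how `train-ngram.py` calls it:

```python
from data import get_loaders_n_gram  # run from the ./Model/ directory

out = get_loaders_n_gram(
    "./_DATA/all_chem_df.csv", n=3, topk=1000, batch_size=32, seed=0
)
print(out["sizes"])  # {"x": topk + 1 (unk bucket), "y": number of drug classes}
x, y = next(iter(out["loaders"]["train"]))
print(x.shape, y.shape)  # (32, 1001) float32 n-gram counts, (32,) int64 labels
```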
Model/methods/MLP.py
ADDED
@@ -0,0 +1,31 @@
```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from torch import nn


class MLP(nn.Module):
    """
    Multi-layer perceptron: Linear -> ReLU -> Dropout blocks,
    followed by a final Linear layer.
    """

    def __init__(self, size_in, size_out, size_hidden=None, dropout=0.0):
        super().__init__()
        if size_hidden is None:
            size_hidden = []
        sizes = [size_in] + size_hidden + [size_out]

        net = []
        for i in range(len(sizes) - 2):
            net.append(nn.Linear(sizes[i], sizes[i + 1]))
            net.append(nn.ReLU())
            net.append(nn.Dropout(dropout))

        net.append(nn.Linear(sizes[-2], sizes[-1]))
        net = nn.Sequential(*net)
        self.net = net

    def forward(self, x):
        """
        Forward method.
        """
        x = self.net(x)
        return x
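```

A quick forward-pass sketch of the class above; the sizes are illustrative only (`size_out=12` stands in for the number of drug classes, which depends on the dataset), and the import assumes the working directory is `./Model/`:

```python
import torch

from methods import MLP  # assumes the working directory is ./Model/

model = MLP(size_in=1001, size_out=12, size_hidden=[512, 128], dropout=0.1)
x = torch.randn(4, 1001)  # a batch of 4 n-gram count vectors (topk=1000 + unk)
logits = model(x)
print(logits.shape)  # torch.Size([4, 12])
```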
Model/methods/__init__.py
ADDED
@@ -0,0 +1 @@
```python
from .MLP import MLP
```
Model/train-ngram.py
ADDED
@@ -0,0 +1,186 @@
```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import copy
import os

import torch
from torch import nn
from torch.optim import Adam
from tqdm import tqdm

from data import get_loaders_n_gram
from methods import MLP


def train(loader_train, loader_dev, model, device, optimizer, n_epochs):
    acc_best = 0
    model_best = None
    criterion = nn.CrossEntropyLoss()

    bar_epochs = tqdm(range(n_epochs), leave=False)
    for epoch in bar_epochs:
        # train
        bar_epoch = tqdm(loader_train, disable=True, leave=False)
        model.train()
        for x, y in bar_epoch:
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()  # clear gradients from the previous step
            y_out = model(x)
            loss = criterion(y_out, y.long())  # keep targets on the same device
            loss.backward()
            optimizer.step()
            loss_iter = loss.item()
            bar_epoch.set_postfix({"loss": loss_iter})
        bar_epoch.close()

        bar_dev = tqdm(loader_dev, disable=True, leave=False)
        model.eval()

        # validate; keep the checkpoint with the best dev accuracy
        ys_pred, ys_true = [], []
        with torch.no_grad():
            for x, y in bar_dev:
                x = x.to(device)
                y = y.to(device)
                y_out = model(x)
                y_pred = torch.argmax(y_out, axis=1)
                ys_pred.append(y_pred.cpu())
                ys_true.append(y.cpu())
        bar_dev.close()
        ys_pred = torch.cat(ys_pred)
        ys_true = torch.cat(ys_true)
        acc = (ys_pred == ys_true).float().mean()
        acc = acc.item() * 100
        if acc > acc_best:
            acc_best = acc
            model_best = copy.deepcopy(model)
        bar_epochs.set_postfix({"acc_best": acc_best})

    return model_best


def test(loader_test, model, device):
    model.eval()
    ys_pred, ys_true = [], []
    bar_test = tqdm(loader_test, leave=False)
    with torch.no_grad():
        for x, y in bar_test:
            x = x.to(device)
            y = y.to(device)
            y_pred = model(x)
            y_pred = torch.argmax(y_pred, axis=1)
            ys_pred.append(y_pred.cpu())
            ys_true.append(y.cpu())

    bar_test.close()

    ys_pred = torch.cat(ys_pred)
    ys_true = torch.cat(ys_true)

    return ys_pred, ys_true


def run(
    csv_file,
    seed,
    n=5,
    topk=1000,
    ratio_dev=0.1,
    ratio_test=0.1,
    batch_size=32,
    size_hidden=None,
    dropout=0.1,
    n_epochs=50,
    lr=3e-4,
    weight_decay=0,
):
    # data
    data = get_loaders_n_gram(
        csv_file,
        n=n,
        topk=topk,
        ratio_dev=ratio_dev,
        ratio_test=ratio_test,
        seed=seed,
        batch_size=batch_size,
    )
    size_x = data["sizes"]["x"]
    size_y = data["sizes"]["y"]
    loader_train = data["loaders"]["train"]
    loader_dev = data["loaders"]["dev"]
    loader_test = data["loaders"]["test"]
    # device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # model settings
    if size_hidden is None:
        size_hidden = [size_x // 2, size_x // 4]
    size_hidden = [size_x] + size_hidden
    model = MLP(
        size_in=size_x,
        size_out=size_y,
        size_hidden=size_hidden,
        dropout=dropout,
    )
    model = model.to(device)

    # optimizer
    optimizer = Adam(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay,
    )

    # train
    model_best = train(loader_train, loader_dev, model, device, optimizer, n_epochs)
    return test(loader_test, model_best, device)


if __name__ == "__main__":
    # data dir
    csv_file = "./_DATA/all_chem_df.csv"
    # number of trials
    n_trials = 5
    seeds = list(range(n_trials))
    # data settings
    topk = 1000
    ratio_dev = 0.1
    ratio_test = 0.2
    batch_size = 32
    # model settings
    n = 5
    dropout = 0.1
    size_hidden = [512, 256, 128, 32]
    # training settings
    n_epochs = 200
    lr = 3e-5
    weight_decay = 0

    for seed in seeds:
        y_pred, y_true = run(
            csv_file,
            seed,
            n,
            topk,
            ratio_dev,
            ratio_test,
            batch_size,
            size_hidden,
            dropout,
            n_epochs,
            lr,
        )
        log_file = f"./scores/MLP/{seed}-seed--{n}-gram--topk-{topk}--lr-{lr}.csv"
        os.makedirs(os.path.dirname(log_file), exist_ok=True)  # ensure the dir exists
        with open(log_file, "a") as f:
            f.write("pred,true\n")
            for p, t in zip(y_pred, y_true):
                f.write(f"{int(p)},{int(t)}\n")  # write plain ints, not tensor reprs
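```

Each trial appends its test predictions to a per-run CSV under `./scores/MLP/`. A minimal sketch for scoring one such file after a run with the default settings above (the filename follows the `log_file` pattern in the script; `f"{3e-5}"` renders as `3e-05`):

```python
import pandas as pd

df = pd.read_csv("./scores/MLP/0-seed--5-gram--topk-1000--lr-3e-05.csv")
acc = (df["pred"] == df["true"]).mean()
print(f"test accuracy: {acc:.3f}")
```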
README.md
CHANGED
@@ -1,3 +1,48 @@
# ***When SMILES have Language*: Drug Classification using Text Classification Methods on Drug SMILES Strings**
- **Authors:** Azmine Toushik Wasi, Šerbetar Karlo, Raima Islam, Taki Hasan Rafi, Dong-Kyu Chae
- Accepted (***invited to present***) to **The Second Tiny Papers Track at ICLR 2024**!
- Read the full paper on [arXiv](https://arxiv.org/abs/2403.12984).
---

<p align="center">
  <img src="Fig/model.png" width="1000"/>
</p>

**Abstract**: Complex chemical structures, like drugs, are usually defined by SMILES strings as a sequence of molecules and bonds. These SMILES strings are used in different complex machine learning-based drug-related research and representation works. Escaping from complex representation, in this work, we pose a single question: What if we treat drug SMILES as conventional sentences and engage in text classification for drug classification? Our experiments affirm the possibility with very competitive scores. The study explores the notion of viewing each atom and bond as sentence components, employing basic NLP methods to categorize drug types, proving that complex problems can also be solved with simpler perspectives.
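To make the "SMILES as sentences" idea concrete, here is a tiny illustrative sketch (the molecule and the choice of n are arbitrary): character n-grams of a SMILES string play the role of words, and their counts become the feature vector fed to the classifier.

```python
from collections import Counter

smiles = "CC(=O)O"  # acetic acid, used here only as an example
n = 2
bigrams = Counter(smiles[i : i + n] for i in range(len(smiles) + 1 - n))
print(bigrams)
# Counter({'CC': 1, 'C(': 1, '(=': 1, '=O': 1, 'O)': 1, ')O': 1})
```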
---

# Setup and run
- Data is available at `./Model/_DATA`
- The dataloader is available at `./Model/data`
- To run the training script, place the dataset from DrugBank in `./Model/_DATA`, go to the `./Model/` folder, and run: `python train-ngram.py`
- To change parameters, check and edit the settings block (around lines 145-165) of `./Model/train-ngram.py`, as sketched below
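For example, a minimal, illustrative edit of that settings block to try 3-gram features with a larger learning rate (the values below are hypothetical, not the paper's tuned ones):

```python
# inside the `if __name__ == "__main__":` block of Model/train-ngram.py
n = 3                              # n-gram size (the paper reports 1- to 5-gram MLPs)
topk = 1000                        # number of most frequent n-grams kept as features
lr = 3e-4                          # Adam learning rate
n_epochs = 200
size_hidden = [512, 256, 128, 32]  # hidden layer widths of the MLP
```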
# Experimental Results

| Model        | Accuracy | Precision | Recall   | F1 (Weighted) | F1 (Macro) | ROC-AUC  |
|--------------|----------|-----------|----------|---------------|------------|----------|
| 1-gram+MLP   | 0.622    | 0.610     | 0.622    | 0.604         | 0.406      | 0.760    |
| 2-gram+MLP   | 0.669    | 0.700     | 0.669    | 0.672         | 0.445      | 0.810    |
| 3-gram+MLP   | **0.737**| **0.764** | **0.737**| **0.744**     | 0.553      | **0.848**|
| 4-gram+MLP   | 0.726    | 0.758     | 0.726    | 0.731         | 0.524      | 0.841    |
| 5-gram+MLP   | 0.728    | 0.740     | 0.728    | 0.730         | **0.563**  | 0.838    |
| AtomPair+MLP | 0.799    | 0.804     | 0.800    | 0.799         | 0.702      | 0.876    |
| MACCS+MLP    | 0.797    | 0.801     | 0.797    | 0.796         | 0.702      | 0.873    |
| Morgan+MLP   | **0.800**| **0.804** | **0.800**| **0.799**     | **0.703**  | **0.876**|

# Citation
```
@inproceedings{wasi2024drug_nlp,
  author    = {Azmine Toushik Wasi and Šerbetar Karlo and Raima Islam and Taki Hasan Rafi and Dong-Kyu Chae},
  title     = {When SMILES have Language: Drug Classification using Text Classification Methods on Drug SMILES Strings},
  booktitle = {The Second Tiny Papers Track at {ICLR} 2024, Tiny Papers @ {ICLR} 2024, Vienna, Austria, May 11, 2024},
  publisher = {OpenReview.net},
  year      = {2024},
  url       = {https://openreview.net/forum?id=VUYCyH8fCw}
}
```