ZhaohanM committed on
Commit
0312a01
1 Parent(s): 91a8b2f

Initial commit

tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1 @@
1
+ {"bos_token": {"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "eos_token": {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "unk_token": {"content": "<unk>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "sep_token": {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "pad_token": {"content": "<pad>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "cls_token": {"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": true}}
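The map stores each special token as an object whose "content" field holds the literal token string; the DrugTokenizer added later in this commit keeps only that field. A minimal sketch of that lookup, using a trimmed inline copy of the map rather than reading the file from disk:

import json

# Trimmed copy of tokenizer/special_tokens_map.json, for illustration only.
special_tokens_raw = json.loads(
    '{"bos_token": {"content": "<s>"},'
    ' "eos_token": {"content": "</s>"},'
    ' "pad_token": {"content": "<pad>"}}'
)

# Keep only the literal token strings, mirroring DrugTokenizer.load_vocab_and_special_tokens.
special_tokens = {key: value["content"] for key, value in special_tokens_raw.items()}
print(special_tokens)  # {'bos_token': '<s>', 'eos_token': '</s>', 'pad_token': '<pad>'}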
tokenizer/vocab.json ADDED
@@ -0,0 +1 @@
1
+ {"<unk>":0,"<s>":1,"</s>":2,"<pad>":3,"<mask>":4,"\n":5,"#":6,"+":7,"-":8,".":9,"/":10,"0":11,"1":12,"2":13,"3":14,"4":15,"5":16,"6":17,"7":18,"8":19,"9":20,"=":21,"@":22,"A":23,"B":24,"C":25,"F":26,"H":27,"I":28,"K":29,"L":30,"M":31,"N":32,"O":33,"P":34,"R":35,"S":36,"T":37,"Z":38,"\\":39,"a":40,"b":41,"c":42,"e":43,"g":44,"h":45,"i":46,"l":47,"n":48,"r":49,"s":50,"Br":51,"an":52,"ch":53,"Bran":54,"Branch":55,"Branch1":56,"=C":57,"Ri":58,"ng":59,"Ring":60,"Ring1":61,"=Branch1":62,"Branch2":63,"=O":64,"Ring2":65,"H1":66,"C@":67,"=N":68,"#Branch1":69,"C@@":70,"=Branch2":71,"C@H1":72,"C@@H1":73,"#Branch2":74,"#C":75,"Cl":76,"/C":77,"NH1":78,"=Ring1":79,"+1":80,"-1":81,"O-1":82,"N+1":83,"\\C":84,"#N":85,"/N":86,"=Ring2":87,"=S":88,"=N+1":89,"\\N":90,"Na":91,"Na+1":92,"/O":93,"\\O":94,"Br-1":95,"Branch3":96,"\\S":97,"S+1":98,"Cl-1":99,"I-1":100,"/C@@H1":101,"Si":102,"/C@H1":103,"/S":104,"=N-1":105,"Se":106,"=P":107,"N-1":108,"Ring3":109,"2H":110,"P+1":111,"K+1":112,"\\C@@H1":113,"\\C@H1":114,"/N+1":115,"@@":116,"C-1":117,"#N+1":118,"B-1":119,"+3":120,"Cl+3":121,"\\NH1":122,"Li":123,"Li+1":124,"PH1":125,"18":126,"18F":127,"@+1":128,"3H":129,"P@@":130,"H0":131,"OH0":132,"12":133,"P@":134,"+2":135,"@@+1":136,"S-1":137,"/Br":138,"-/":139,"\\Cl":140,"-/Ring2":141,"\\O-1":142,"11":143,"5I":144,"125I":145,"11C":146,"H3":147,"\\N+1":148,"-\\":149,"/C@@":150,"S@+1":151,"As":152,"/Cl":153,"11CH3":154,"=Se":155,"S@@+1":156,"N@+1":157,"14":158,"-\\Ring2":159,"14C":160,"\\F":161,"/C@":162,"Te":163,"H2":164,"H1-1":165,"=O+1":166,"N@@+1":167,"C+1":168,"=S+1":169,"Zn":170,"/P":171,"a+2":172,"/I":173,"OH1-1":174,"Ca+2":175,"\\Br":176,"Mg":177,"Zn+2":178,"Al":179,"/F":180,"Mg+2":181,"123":182,"123I":183,"13":184,"I+1":185,"/O-1":186,"-\\Ring1":187,"BH2":188,"BH2-1":189,"\\I":190,"/NH1":191,"O+1":192,"131":193,"131I":194,"=14C":195,"/S+1":196,"=Ring3":197,"\\C@@":198,"H2+1":199,"\\C@":200,"Ag":201,"=As":202,"=Se+1":203,"NH2+1":204,"SeH1":205,"-/Ring1":206,"=Te":207,"Al+3":208,"NaH1":209,"=Te+1":210,"NH1+1":211,"Ag+1":212,"H1+1":213,"NH1-1":214,"\\P":215,"14CH2":216,"13C":217,"14CH1":218,"=11C":219,"S@@":220,"=P@@":221,"SiH2":222,"H3-1":223,"14CH3":224,"BH3-1":225,"S@":226,"=14CH1":227,"=PH1":228,"=P@":229,"=NH1+1":230,"\\S+1":231,"124":232,"CH1-1":233,"Sr":234,"=Si":235,"124I":236,"Sr+2":237,"#C-1":238,"/C-1":239,"N@":240,"/N-1":241,"13CH1":242,"/B":243,"19":244,"Ba+2":245,"H4":246,"SH1+1":247,"Se+1":248,"19F":249,"/125I":250,"P@+1":251,"Rb":252,"Cl+1":253,"SiH4":254,"Rb+1":255,"=Branch3":256,"N@@":257,"As+1":258,"/Si":259,"BH1-1":260,"SH1":261,"/123I":262,"32":263,"=Mg":264,"H+1":265,"\\B":266,"SiH1":267,"P@@+1":268,"-2":269,"15":270,"17":271,"35":272,"=13CH1":273,"Cs":274,"=NH2+1":275,"=SH1":276,"MgH2":277,"32P":278,"17F":279,"35S":280,"Cs+1":281,"#11C":282,"/131I":283,"Bi":284,"\\125I":285,"=S@@":286,"\\S-1":287,"6Br":288,"7I":289,"76Br":290,"=B":291,"eH1":292,"\\N-1":293,"18O":294,"127I":295,"11CH2":296,"14C@@H1":297,"TeH2":298,"15NH1":299,"Bi+3":300,"/P+1":301,"/13C":302,"/13CH1":303,"0B":304,"10B":305,"=Al":306,"=18O":307,"BH0":308,"F-1":309,"NH3":310,"S-2":311,"Br+2":312,"Cl+2":313,"\\Si":314,"/S-1":315,"=PH2":316,"14C@H1":317,"NH3+1":318,"#14C":319,"#O+1":320,"-3":321,"22":322,"4H":323,"5Se":324,"5Sr+2":325,"75Se":326,"85Sr+2":327,"=B-1":328,"=13C":329,"@-1":330,"Be":331,"B@@":332,"B@-1":333,"Ca":334,"CH1":335,"I+3":336,"KH1":337,"OH1+1":338,"Ra+2":339,"SH1-1":340,"\\PH1":341,"\\123I":342,"=Ca":343,"\\CH1-1":344,"=S@":345,"\\SeH1":346,"/SeH1":347,"Se-1":348,"LiH1":349,"18F-1":350,"125IH1":351,"11CH1":35
2,"TeH1":353,"Zn+1":354,"Zn-2":355,"Al-3":356,"13CH3":357,"15N":358,"Be+2":359,"B@@-1":360,"#P":361,"#S":362,"-4":363,"/PH1":364,"/P@@":365,"/As":366,"/14C":367,"/14CH1":368,"2K+1":369,"2Rb+1":370,"3Se":371,"3Ra+2":372,"45":373,"47":374,"42K+1":375,"5I-1":376,"73Se":377,"89":378,"82Rb+1":379,"=32":380,"=32P":381,"CH0":382,"CH2":383,"I+2":384,"NH0":385,"NH4":386,"OH1":387,"PH2+1":388,"SH0":389,"SH2":390,"\\3H":391,"\\11CH3":392,"\\C-1":393,"\\Se":394,"Si@":395,"Si-1":396,"SiH1-1":397,"SiH3-1":398,"/Se":399,"Se-2":400,"\\NH1-1":401,"18FH1":402,"125I-1":403,"11C@@H1":404,"11C-1":405,"AsH1":406,"As-1":407,"14C@@":408,"Te-1":409,"Mg+1":410,"123I-1":411,"123Te":412,"123IH1":413,"135I":414,"131I-1":415,"Ag-4":416,"124I-1":417,"76BrH1":418,"18OH1":419,"22Na+1":420,"223Ra+2":421,"CaH2":422,"45Ca+2":423,"47Ca+2":424,"89Sr+2":425,"=32PH1":426,"NH4+1":427}
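vocab.json maps SELFIES-style tokens (atoms, charges, isotopes, ring and branch markers) to integer ids, with ids 0-4 reserved for the special tokens above. A small lookup sketch, assuming the file sits at the tokenizer/ path used throughout this commit:

import json

with open("tokenizer/vocab.json", encoding="utf-8") as f:
    vocab = json.load(f)

# Ids 0-4 are <unk>, <s>, </s>, <pad>, <mask>; chemistry tokens follow.
print(vocab["<unk>"], vocab["Branch1"], vocab["C@@H1"])  # 0 56 73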
tokenizer/vocab.txt ADDED
@@ -0,0 +1,429 @@
1
+ <unk>
2
+ <s>
3
+ </s>
4
+ <pad>
5
+ <mask>
6
+
7
+
8
+ #
9
+ +
10
+ -
11
+ .
12
+ /
13
+ 0
14
+ 1
15
+ 2
16
+ 3
17
+ 4
18
+ 5
19
+ 6
20
+ 7
21
+ 8
22
+ 9
23
+ =
24
+ @
25
+ A
26
+ B
27
+ C
28
+ F
29
+ H
30
+ I
31
+ K
32
+ L
33
+ M
34
+ N
35
+ O
36
+ P
37
+ R
38
+ S
39
+ T
40
+ Z
41
+ \
42
+ a
43
+ b
44
+ c
45
+ e
46
+ g
47
+ h
48
+ i
49
+ l
50
+ n
51
+ r
52
+ s
53
+ Br
54
+ an
55
+ ch
56
+ Bran
57
+ Branch
58
+ Branch1
59
+ =C
60
+ Ri
61
+ ng
62
+ Ring
63
+ Ring1
64
+ =Branch1
65
+ Branch2
66
+ =O
67
+ Ring2
68
+ H1
69
+ C@
70
+ =N
71
+ #Branch1
72
+ C@@
73
+ =Branch2
74
+ C@H1
75
+ C@@H1
76
+ #Branch2
77
+ #C
78
+ Cl
79
+ /C
80
+ NH1
81
+ =Ring1
82
+ +1
83
+ -1
84
+ O-1
85
+ N+1
86
+ \C
87
+ #N
88
+ /N
89
+ =Ring2
90
+ =S
91
+ =N+1
92
+ \N
93
+ Na
94
+ Na+1
95
+ /O
96
+ \O
97
+ Br-1
98
+ Branch3
99
+ \S
100
+ S+1
101
+ Cl-1
102
+ I-1
103
+ /C@@H1
104
+ Si
105
+ /C@H1
106
+ /S
107
+ =N-1
108
+ Se
109
+ =P
110
+ N-1
111
+ Ring3
112
+ 2H
113
+ P+1
114
+ K+1
115
+ \C@@H1
116
+ \C@H1
117
+ /N+1
118
+ @@
119
+ C-1
120
+ #N+1
121
+ B-1
122
+ +3
123
+ Cl+3
124
+ \NH1
125
+ Li
126
+ Li+1
127
+ PH1
128
+ 18
129
+ 18F
130
+ @+1
131
+ 3H
132
+ P@@
133
+ H0
134
+ OH0
135
+ 12
136
+ P@
137
+ +2
138
+ @@+1
139
+ S-1
140
+ /Br
141
+ -/
142
+ \Cl
143
+ -/Ring2
144
+ \O-1
145
+ 11
146
+ 5I
147
+ 125I
148
+ 11C
149
+ H3
150
+ \N+1
151
+ -\
152
+ /C@@
153
+ S@+1
154
+ As
155
+ /Cl
156
+ 11CH3
157
+ =Se
158
+ S@@+1
159
+ N@+1
160
+ 14
161
+ -\Ring2
162
+ 14C
163
+ \F
164
+ /C@
165
+ Te
166
+ H2
167
+ H1-1
168
+ =O+1
169
+ N@@+1
170
+ C+1
171
+ =S+1
172
+ Zn
173
+ /P
174
+ a+2
175
+ /I
176
+ OH1-1
177
+ Ca+2
178
+ \Br
179
+ Mg
180
+ Zn+2
181
+ Al
182
+ /F
183
+ Mg+2
184
+ 123
185
+ 123I
186
+ 13
187
+ I+1
188
+ /O-1
189
+ -\Ring1
190
+ BH2
191
+ BH2-1
192
+ \I
193
+ /NH1
194
+ O+1
195
+ 131
196
+ 131I
197
+ =14C
198
+ /S+1
199
+ =Ring3
200
+ \C@@
201
+ H2+1
202
+ \C@
203
+ Ag
204
+ =As
205
+ =Se+1
206
+ NH2+1
207
+ SeH1
208
+ -/Ring1
209
+ =Te
210
+ Al+3
211
+ NaH1
212
+ =Te+1
213
+ NH1+1
214
+ Ag+1
215
+ H1+1
216
+ NH1-1
217
+ \P
218
+ 14CH2
219
+ 13C
220
+ 14CH1
221
+ =11C
222
+ S@@
223
+ =P@@
224
+ SiH2
225
+ H3-1
226
+ 14CH3
227
+ BH3-1
228
+ S@
229
+ =14CH1
230
+ =PH1
231
+ =P@
232
+ =NH1+1
233
+ \S+1
234
+ 124
235
+ CH1-1
236
+ Sr
237
+ =Si
238
+ 124I
239
+ Sr+2
240
+ #C-1
241
+ /C-1
242
+ N@
243
+ /N-1
244
+ 13CH1
245
+ /B
246
+ 19
247
+ Ba+2
248
+ H4
249
+ SH1+1
250
+ Se+1
251
+ 19F
252
+ /125I
253
+ P@+1
254
+ Rb
255
+ Cl+1
256
+ SiH4
257
+ Rb+1
258
+ =Branch3
259
+ N@@
260
+ As+1
261
+ /Si
262
+ BH1-1
263
+ SH1
264
+ /123I
265
+ 32
266
+ =Mg
267
+ H+1
268
+ \B
269
+ SiH1
270
+ P@@+1
271
+ -2
272
+ 15
273
+ 17
274
+ 35
275
+ =13CH1
276
+ Cs
277
+ =NH2+1
278
+ =SH1
279
+ MgH2
280
+ 32P
281
+ 17F
282
+ 35S
283
+ Cs+1
284
+ #11C
285
+ /131I
286
+ Bi
287
+ \125I
288
+ =S@@
289
+ \S-1
290
+ 6Br
291
+ 7I
292
+ 76Br
293
+ =B
294
+ eH1
295
+ \N-1
296
+ 18O
297
+ 127I
298
+ 11CH2
299
+ 14C@@H1
300
+ TeH2
301
+ 15NH1
302
+ Bi+3
303
+ /P+1
304
+ /13C
305
+ /13CH1
306
+ 0B
307
+ 10B
308
+ =Al
309
+ =18O
310
+ BH0
311
+ F-1
312
+ NH3
313
+ S-2
314
+ Br+2
315
+ Cl+2
316
+ \Si
317
+ /S-1
318
+ =PH2
319
+ 14C@H1
320
+ NH3+1
321
+ #14C
322
+ #O+1
323
+ -3
324
+ 22
325
+ 4H
326
+ 5Se
327
+ 5Sr+2
328
+ 75Se
329
+ 85Sr+2
330
+ =B-1
331
+ =13C
332
+ @-1
333
+ Be
334
+ B@@
335
+ B@-1
336
+ Ca
337
+ CH1
338
+ I+3
339
+ KH1
340
+ OH1+1
341
+ Ra+2
342
+ SH1-1
343
+ \PH1
344
+ \123I
345
+ =Ca
346
+ \CH1-1
347
+ =S@
348
+ \SeH1
349
+ /SeH1
350
+ Se-1
351
+ LiH1
352
+ 18F-1
353
+ 125IH1
354
+ 11CH1
355
+ TeH1
356
+ Zn+1
357
+ Zn-2
358
+ Al-3
359
+ 13CH3
360
+ 15N
361
+ Be+2
362
+ B@@-1
363
+ #P
364
+ #S
365
+ -4
366
+ /PH1
367
+ /P@@
368
+ /As
369
+ /14C
370
+ /14CH1
371
+ 2K+1
372
+ 2Rb+1
373
+ 3Se
374
+ 3Ra+2
375
+ 45
376
+ 47
377
+ 42K+1
378
+ 5I-1
379
+ 73Se
380
+ 89
381
+ 82Rb+1
382
+ =32
383
+ =32P
384
+ CH0
385
+ CH2
386
+ I+2
387
+ NH0
388
+ NH4
389
+ OH1
390
+ PH2+1
391
+ SH0
392
+ SH2
393
+ \3H
394
+ \11CH3
395
+ \C-1
396
+ \Se
397
+ Si@
398
+ Si-1
399
+ SiH1-1
400
+ SiH3-1
401
+ /Se
402
+ Se-2
403
+ \NH1-1
404
+ 18FH1
405
+ 125I-1
406
+ 11C@@H1
407
+ 11C-1
408
+ AsH1
409
+ As-1
410
+ 14C@@
411
+ Te-1
412
+ Mg+1
413
+ 123I-1
414
+ 123Te
415
+ 123IH1
416
+ 135I
417
+ 131I-1
418
+ Ag-4
419
+ 124I-1
420
+ 76BrH1
421
+ 18OH1
422
+ 22Na+1
423
+ 223Ra+2
424
+ CaH2
425
+ 45Ca+2
426
+ 47Ca+2
427
+ 89Sr+2
428
+ =32PH1
429
+ NH4+1
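vocab.txt lists what appears to be the same vocabulary in plain text, one token per line in id order. A rough loading sketch, assuming the path from this commit; note that the literal newline token near the top of the file would need special handling in a line-based reader:

# Build a token-to-id map from the plain-text vocabulary.
with open("tokenizer/vocab.txt", encoding="utf-8") as f:
    tokens = [line.rstrip("\n") for line in f]

token_to_id = {tok: i for i, tok in enumerate(tokens)}
print(token_to_id["<mask>"])  # 4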
utils/.ipynb_checkpoints/drug_tokenizer-checkpoint.py ADDED
@@ -0,0 +1,66 @@
1
+ import json
2
+ import re
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+
7
+ class DrugTokenizer:
8
+ def __init__(self, vocab_path="tokenizer/vocab.json", special_tokens_path="tokenizer/special_tokens_map.json"):
9
+ self.vocab, self.special_tokens = self.load_vocab_and_special_tokens(vocab_path, special_tokens_path)
10
+ self.cls_token_id = self.vocab[self.special_tokens['cls_token']]
11
+ self.sep_token_id = self.vocab[self.special_tokens['sep_token']]
12
+ self.unk_token_id = self.vocab[self.special_tokens['unk_token']]
13
+ self.pad_token_id = self.vocab[self.special_tokens['pad_token']]
14
+ self.id_to_token = {v: k for k, v in self.vocab.items()}
15
+
16
+ def load_vocab_and_special_tokens(self, vocab_path, special_tokens_path):
17
+ with open(vocab_path, 'r', encoding='utf-8') as vocab_file:
18
+ vocab = json.load(vocab_file)
19
+ with open(special_tokens_path, 'r', encoding='utf-8') as special_tokens_file:
20
+ special_tokens_raw = json.load(special_tokens_file)
21
+
22
+ special_tokens = {key: value['content'] for key, value in special_tokens_raw.items()}
23
+ return vocab, special_tokens
24
+
25
+ def encode(self, sequence):
26
+ tokens = re.findall(r'\[([^\[\]]+)\]', sequence)
27
+ input_ids = [self.cls_token_id] + [self.vocab.get(token, self.unk_token_id) for token in tokens] + [self.sep_token_id]
28
+ attention_mask = [1] * len(input_ids)
29
+ return {
30
+ 'input_ids': input_ids,
31
+ 'attention_mask': attention_mask
32
+ }
33
+
34
+ def batch_encode_plus(self, sequences, max_length, padding, truncation, add_special_tokens, return_tensors):
35
+ input_ids_list = []
36
+ attention_mask_list = []
37
+
38
+ for sequence in sequences:
39
+ encoded = self.encode(sequence)
40
+ input_ids = encoded['input_ids']
41
+ attention_mask = encoded['attention_mask']
42
+
43
+ if len(input_ids) > max_length:
44
+ input_ids = input_ids[:max_length]
45
+ attention_mask = attention_mask[:max_length]
46
+ elif len(input_ids) < max_length:
47
+ pad_length = max_length - len(input_ids)
48
+ input_ids = input_ids + [self.vocab[self.special_tokens['pad_token']]] * pad_length
49
+ attention_mask = attention_mask + [0] * pad_length
50
+
51
+ input_ids_list.append(input_ids)
52
+ attention_mask_list.append(attention_mask)
53
+
54
+ return {
55
+ 'input_ids': torch.tensor(input_ids_list, dtype=torch.long),
56
+ 'attention_mask': torch.tensor(attention_mask_list, dtype=torch.long)
57
+ }
58
+
59
+ def decode(self, input_ids, skip_special_tokens=False):
60
+ tokens = []
61
+ for id in input_ids:
62
+ if skip_special_tokens and id in [self.cls_token_id, self.sep_token_id, self.pad_token_id]:
63
+ continue
64
+ tokens.append(self.id_to_token.get(id, self.special_tokens['unk_token']))
65
+ sequence = ''.join([f'[{token}]' for token in tokens])
66
+ return sequence
utils/__pycache__/drug_tokenizer.cpython-38.pyc ADDED
Binary file (3.25 kB).
 
utils/__pycache__/metric_learning_models_att_maps.cpython-38.pyc ADDED
Binary file (10.9 kB).
 
utils/drug_tokenizer.py ADDED
@@ -0,0 +1,66 @@
1
+ import json
2
+ import re
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+
7
+ class DrugTokenizer:
8
+ def __init__(self, vocab_path="tokenizer/vocab.json", special_tokens_path="tokenizer/special_tokens_map.json"):
9
+ self.vocab, self.special_tokens = self.load_vocab_and_special_tokens(vocab_path, special_tokens_path)
10
+ self.cls_token_id = self.vocab[self.special_tokens['cls_token']]
11
+ self.sep_token_id = self.vocab[self.special_tokens['sep_token']]
12
+ self.unk_token_id = self.vocab[self.special_tokens['unk_token']]
13
+ self.pad_token_id = self.vocab[self.special_tokens['pad_token']]
14
+ self.id_to_token = {v: k for k, v in self.vocab.items()}
15
+
16
+ def load_vocab_and_special_tokens(self, vocab_path, special_tokens_path):
17
+ with open(vocab_path, 'r', encoding='utf-8') as vocab_file:
18
+ vocab = json.load(vocab_file)
19
+ with open(special_tokens_path, 'r', encoding='utf-8') as special_tokens_file:
20
+ special_tokens_raw = json.load(special_tokens_file)
21
+
22
+ special_tokens = {key: value['content'] for key, value in special_tokens_raw.items()}
23
+ return vocab, special_tokens
24
+
25
+ def encode(self, sequence):
26
+ tokens = re.findall(r'\[([^\[\]]+)\]', sequence)
27
+ input_ids = [self.cls_token_id] + [self.vocab.get(token, self.unk_token_id) for token in tokens] + [self.sep_token_id]
28
+ attention_mask = [1] * len(input_ids)
29
+ return {
30
+ 'input_ids': input_ids,
31
+ 'attention_mask': attention_mask
32
+ }
33
+
34
+ def batch_encode_plus(self, sequences, max_length, padding, truncation, add_special_tokens, return_tensors):
35
+ input_ids_list = []
36
+ attention_mask_list = []
37
+
38
+ for sequence in sequences:
39
+ encoded = self.encode(sequence)
40
+ input_ids = encoded['input_ids']
41
+ attention_mask = encoded['attention_mask']
42
+
43
+ if len(input_ids) > max_length:
44
+ input_ids = input_ids[:max_length]
45
+ attention_mask = attention_mask[:max_length]
46
+ elif len(input_ids) < max_length:
47
+ pad_length = max_length - len(input_ids)
48
+ input_ids = input_ids + [self.vocab[self.special_tokens['pad_token']]] * pad_length
49
+ attention_mask = attention_mask + [0] * pad_length
50
+
51
+ input_ids_list.append(input_ids)
52
+ attention_mask_list.append(attention_mask)
53
+
54
+ return {
55
+ 'input_ids': torch.tensor(input_ids_list, dtype=torch.long),
56
+ 'attention_mask': torch.tensor(attention_mask_list, dtype=torch.long)
57
+ }
58
+
59
+ def decode(self, input_ids, skip_special_tokens=False):
60
+ tokens = []
61
+ for id in input_ids:
62
+ if skip_special_tokens and id in [self.cls_token_id, self.sep_token_id, self.pad_token_id]:
63
+ continue
64
+ tokens.append(self.id_to_token.get(id, self.special_tokens['unk_token']))
65
+ sequence = ''.join([f'[{token}]' for token in tokens])
66
+ return sequence
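For context, a usage sketch of the DrugTokenizer defined above. The SELFIES string is hypothetical, and the paths assume the repository root from this commit is the working directory and on sys.path:

from utils.drug_tokenizer import DrugTokenizer

tokenizer = DrugTokenizer(
    vocab_path="tokenizer/vocab.json",
    special_tokens_path="tokenizer/special_tokens_map.json",
)

selfies = "[C][C][=Branch1][C][=O][O][C]"  # hypothetical drug SELFIES, for illustration
encoded = tokenizer.encode(selfies)
print(encoded["input_ids"])  # starts with 1 (<s>) and ends with 2 (</s>)
print(tokenizer.decode(encoded["input_ids"], skip_special_tokens=True))

batch = tokenizer.batch_encode_plus(
    [selfies], max_length=16, padding=True, truncation=True,
    add_special_tokens=True, return_tensors="pt",
)
print(batch["input_ids"].shape)  # torch.Size([1, 16])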
utils/metric_learning_models_att_maps.py ADDED
@@ -0,0 +1,330 @@
1
+ import logging
2
+ import os
3
+ import sys
4
+
5
+ sys.path.append("../")
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.nn import functional as F
10
+ from torch.cuda.amp import autocast
11
+ from torch.nn import Module
12
+ from tqdm import tqdm
13
+ from torch.nn.utils.weight_norm import weight_norm
14
+ from torch.utils.data import Dataset
15
+
16
+ LOGGER = logging.getLogger(__name__)
17
+
18
+ class FusionDTI(nn.Module):
19
+ def __init__(self, prot_out_dim, disease_out_dim, args):
20
+ super(FusionDTI, self).__init__()
21
+ self.fusion = args.fusion
22
+ self.drug_reg = nn.Linear(disease_out_dim, 512)
23
+ self.prot_reg = nn.Linear(prot_out_dim, 512)
24
+
25
+ if self.fusion == "CAN":
26
+ self.can_layer = CAN_Layer(hidden_dim=512, num_heads=8, args=args)
27
+ self.mlp_classifier = MlPdecoder_CAN(input_dim=1024)
28
+ elif self.fusion == "BAN":
29
+ self.ban_layer = weight_norm(BANLayer(512, 512, 256, 2), name='h_mat', dim=None)
30
+ self.mlp_classifier = MlPdecoder_CAN(input_dim=256)
31
+ elif self.fusion == "Nan":
32
+ self.mlp_classifier_nan = MlPdecoder_CAN(input_dim=1214)
33
+
34
+ def forward(self, prot_embed, drug_embed, prot_mask, drug_mask):
35
+ # print("drug_embed", drug_embed.shape)
36
+ if self.fusion == "Nan":
37
+ prot_embed = prot_embed.mean(1) # query : [batch_size, hidden]
38
+ drug_embed = drug_embed.mean(1) # query : [batch_size, hidden]
39
+ joint_embed = torch.cat([prot_embed, drug_embed], dim=1)
40
+ score = self.mlp_classifier_nan(joint_embed)
41
+ else:
42
+ prot_embed = self.prot_reg(prot_embed)
43
+ drug_embed = self.drug_reg(drug_embed)
44
+
45
+ if self.fusion == "CAN":
46
+ joint_embed, att = self.can_layer(prot_embed, drug_embed, prot_mask, drug_mask)
47
+ elif self.fusion == "BAN":
48
+ joint_embed, att = self.ban_layer(prot_embed, drug_embed)
49
+
50
+ score = self.mlp_classifier(joint_embed)
51
+
52
+ return score, att
53
+
54
+ class Pre_encoded(nn.Module):
55
+ def __init__(
56
+ self, prot_encoder, drug_encoder, args
57
+ ):
58
+ """Constructor for the model.
59
+
60
+ Args:
61
+ prot_encoder (_type_): Protein structure-aware sequence encoder.
62
+ drug_encoder (_type_): Drug SELFIES encoder.
63
+ args (_type_): _description_
64
+ """
65
+ super(Pre_encoded, self).__init__()
66
+ self.prot_encoder = prot_encoder
67
+ self.drug_encoder = drug_encoder
68
+
69
+ def encoding(self, prot_input_ids, prot_attention_mask, drug_input_ids, drug_attention_mask):
70
+ # Process inputs through encoders
71
+ prot_embed = self.prot_encoder(
72
+ input_ids=prot_input_ids, attention_mask=prot_attention_mask, return_dict=True
73
+ ).logits
74
+ # prot_embed = self.prot_reg(prot_embed)
75
+
76
+ drug_embed = self.drug_encoder(
77
+ input_ids=drug_input_ids, attention_mask=drug_attention_mask, return_dict=True
78
+ ).last_hidden_state # .last_hidden_state
79
+
80
+ # print("drug_embed", drug_embed.shape)
81
+
82
+ return prot_embed, drug_embed
83
+
84
+
85
+ class CAN_Layer(nn.Module):
86
+ def __init__(self, hidden_dim, num_heads, args):
87
+ super(CAN_Layer, self).__init__()
88
+ self.agg_mode = args.agg_mode
89
+ self.group_size = args.group_size # Control Fusion Scale
90
+ self.hidden_dim = hidden_dim
91
+ self.num_heads = num_heads
92
+ self.head_size = hidden_dim // num_heads
93
+
94
+ self.query_p = nn.Linear(hidden_dim, hidden_dim, bias=False)
95
+ self.key_p = nn.Linear(hidden_dim, hidden_dim, bias=False)
96
+ self.value_p = nn.Linear(hidden_dim, hidden_dim, bias=False)
97
+
98
+ self.query_d = nn.Linear(hidden_dim, hidden_dim, bias=False)
99
+ self.key_d = nn.Linear(hidden_dim, hidden_dim, bias=False)
100
+ self.value_d = nn.Linear(hidden_dim, hidden_dim, bias=False)
101
+
102
+ def alpha_logits(self, logits, mask_row, mask_col, inf=1e6):
103
+ N, L1, L2, H = logits.shape
104
+ mask_row = mask_row.view(N, L1, 1).repeat(1, 1, H)
105
+ mask_col = mask_col.view(N, L2, 1).repeat(1, 1, H)
106
+ mask_pair = torch.einsum('blh, bkh->blkh', mask_row, mask_col)
107
+
108
+ logits = torch.where(mask_pair, logits, logits - inf)
109
+ alpha = torch.softmax(logits, dim=2)
110
+ mask_row = mask_row.view(N, L1, 1, H).repeat(1, 1, L2, 1)
111
+ alpha = torch.where(mask_row, alpha, torch.zeros_like(alpha))
112
+ return alpha
113
+
114
+ def apply_heads(self, x, n_heads, n_ch):
115
+ s = list(x.size())[:-1] + [n_heads, n_ch]
116
+ return x.view(*s)
117
+
118
+ def group_embeddings(self, x, mask, group_size):
119
+ N, L, D = x.shape
120
+ groups = L // group_size
121
+ x_grouped = x.view(N, groups, group_size, D).mean(dim=2)
122
+ mask_grouped = mask.view(N, groups, group_size).any(dim=2)
123
+ return x_grouped, mask_grouped
124
+
125
+ def forward(self, protein, drug, mask_prot, mask_drug):
126
+ # Group embeddings before applying multi-head attention
127
+ protein_grouped, mask_prot_grouped = self.group_embeddings(protein, mask_prot, self.group_size)
128
+ drug_grouped, mask_drug_grouped = self.group_embeddings(drug, mask_drug, self.group_size)
129
+
130
+ # print("protein_grouped:", protein_grouped.shape)
131
+ # print("mask_prot_grouped:", mask_prot_grouped.shape)
132
+
133
+ # Compute queries, keys, values for both protein and drug after grouping
134
+ query_prot = self.apply_heads(self.query_p(protein_grouped), self.num_heads, self.head_size)
135
+ key_prot = self.apply_heads(self.key_p(protein_grouped), self.num_heads, self.head_size)
136
+ value_prot = self.apply_heads(self.value_p(protein_grouped), self.num_heads, self.head_size)
137
+
138
+ query_drug = self.apply_heads(self.query_d(drug_grouped), self.num_heads, self.head_size)
139
+ key_drug = self.apply_heads(self.key_d(drug_grouped), self.num_heads, self.head_size)
140
+ value_drug = self.apply_heads(self.value_d(drug_grouped), self.num_heads, self.head_size)
141
+
142
+ # Compute attention scores
143
+ logits_pp = torch.einsum('blhd, bkhd->blkh', query_prot, key_prot)
144
+ logits_pd = torch.einsum('blhd, bkhd->blkh', query_prot, key_drug)
145
+ logits_dp = torch.einsum('blhd, bkhd->blkh', query_drug, key_prot)
146
+ logits_dd = torch.einsum('blhd, bkhd->blkh', query_drug, key_drug)
147
+ # print("logits_pp:", logits_pp.shape)
148
+
149
+ alpha_pp = self.alpha_logits(logits_pp, mask_prot_grouped, mask_prot_grouped)
150
+ alpha_pd = self.alpha_logits(logits_pd, mask_prot_grouped, mask_drug_grouped)
151
+ alpha_dp = self.alpha_logits(logits_dp, mask_drug_grouped, mask_prot_grouped)
152
+ alpha_dd = self.alpha_logits(logits_dd, mask_drug_grouped, mask_drug_grouped)
153
+
154
+ prot_embedding = (torch.einsum('blkh, bkhd->blhd', alpha_pp, value_prot).flatten(-2) +
155
+ torch.einsum('blkh, bkhd->blhd', alpha_pd, value_drug).flatten(-2)) / 2
156
+ drug_embedding = (torch.einsum('blkh, bkhd->blhd', alpha_dp, value_prot).flatten(-2) +
157
+ torch.einsum('blkh, bkhd->blhd', alpha_dd, value_drug).flatten(-2)) / 2
158
+
159
+ # print("prot_embedding:", prot_embedding.shape)
160
+
161
+ # Continue as usual with the aggregation mode
162
+ if self.agg_mode == "cls":
163
+ prot_embed = prot_embedding[:, 0] # query : [batch_size, hidden]
164
+ drug_embed = drug_embedding[:, 0] # query : [batch_size, hidden]
165
+ elif self.agg_mode == "mean_all_tok":
166
+ prot_embed = prot_embedding.mean(1) # query : [batch_size, hidden]
167
+ drug_embed = drug_embedding.mean(1) # query : [batch_size, hidden]
168
+ elif self.agg_mode == "mean":
169
+ prot_embed = (prot_embedding * mask_prot_grouped.unsqueeze(-1)).sum(1) / mask_prot_grouped.sum(-1).unsqueeze(-1)
170
+ drug_embed = (drug_embedding * mask_drug_grouped.unsqueeze(-1)).sum(1) / mask_drug_grouped.sum(-1).unsqueeze(-1)
171
+ else:
172
+ raise NotImplementedError()
173
+
174
+ # print("prot_embed:", prot_embed.shape)
175
+
176
+ query_embed = torch.cat([prot_embed, drug_embed], dim=1)
177
+
178
+
179
+ att = torch.zeros(1, 1, 1024, 1024)
180
+ att[:, :, :512, :512] = alpha_pp.mean(dim=-1) # Protein to Protein
181
+ att[:, :, :512, 512:] = alpha_pd.mean(dim=-1) # Protein to Drug
182
+ att[:, :, 512:, :512] = alpha_dp.mean(dim=-1) # Drug to Protein
183
+ att[:, :, 512:, 512:] = alpha_dd.mean(dim=-1) # Drug to Drug
184
+
185
+ # print("query_embed:", query_embed.shape)
186
+ return query_embed, att
187
+
188
+ class MlPdecoder_CAN(nn.Module):
189
+ def __init__(self, input_dim):
190
+ super(MlPdecoder_CAN, self).__init__()
191
+ self.fc1 = nn.Linear(input_dim, input_dim)
192
+ self.bn1 = nn.BatchNorm1d(input_dim)
193
+ self.fc2 = nn.Linear(input_dim, input_dim // 2)
194
+ self.bn2 = nn.BatchNorm1d(input_dim // 2)
195
+ self.fc3 = nn.Linear(input_dim // 2, input_dim // 4)
196
+ self.bn3 = nn.BatchNorm1d(input_dim // 4)
197
+ self.output = nn.Linear(input_dim // 4, 1)
198
+
199
+ def forward(self, x):
200
+ x = self.bn1(torch.relu(self.fc1(x)))
201
+ x = self.bn2(torch.relu(self.fc2(x)))
202
+ x = self.bn3(torch.relu(self.fc3(x)))
203
+ x = torch.sigmoid(self.output(x))
204
+ return x
205
+
206
+ class MLPdecoder_BAN(nn.Module):
207
+ def __init__(self, in_dim, hidden_dim, out_dim, binary=1):
208
+ super(MLPdecoder_BAN, self).__init__()
209
+ self.fc1 = nn.Linear(in_dim, hidden_dim)
210
+ self.bn1 = nn.BatchNorm1d(hidden_dim)
211
+ self.fc2 = nn.Linear(hidden_dim, hidden_dim)
212
+ self.bn2 = nn.BatchNorm1d(hidden_dim)
213
+ self.fc3 = nn.Linear(hidden_dim, out_dim)
214
+ self.bn3 = nn.BatchNorm1d(out_dim)
215
+ self.fc4 = nn.Linear(out_dim, binary)
216
+
217
+ def forward(self, x):
218
+ x = self.bn1(F.relu(self.fc1(x)))
219
+ x = self.bn2(F.relu(self.fc2(x)))
220
+ x = self.bn3(F.relu(self.fc3(x)))
221
+ # x = self.fc4(x)
222
+ x = torch.sigmoid(self.fc4(x))
223
+ return x
224
+
225
+ class BANLayer(nn.Module):
226
+ """ Bilinear attention network
227
+ Modified from https://github.com/peizhenbai/DrugBAN/blob/main/ban.py
228
+ """
229
+ def __init__(self, v_dim, q_dim, h_dim, h_out, act='ReLU', dropout=0.2, k=3):
230
+ super(BANLayer, self).__init__()
231
+
232
+ self.c = 32
233
+ self.k = k
234
+ self.v_dim = v_dim
235
+ self.q_dim = q_dim
236
+ self.h_dim = h_dim
237
+ self.h_out = h_out
238
+
239
+ self.v_net = FCNet([v_dim, h_dim * self.k], act=act, dropout=dropout)
240
+ self.q_net = FCNet([q_dim, h_dim * self.k], act=act, dropout=dropout)
241
+ # self.dropout = nn.Dropout(dropout[1])
242
+ if 1 < k:
243
+ self.p_net = nn.AvgPool1d(self.k, stride=self.k)
244
+
245
+ if h_out <= self.c:
246
+ self.h_mat = nn.Parameter(torch.Tensor(1, h_out, 1, h_dim * self.k).normal_())
247
+ self.h_bias = nn.Parameter(torch.Tensor(1, h_out, 1, 1).normal_())
248
+ else:
249
+ self.h_net = weight_norm(nn.Linear(h_dim * self.k, h_out), dim=None)
250
+
251
+ self.bn = nn.BatchNorm1d(h_dim)
252
+
253
+ def attention_pooling(self, v, q, att_map):
254
+ fusion_logits = torch.einsum('bvk,bvq,bqk->bk', (v, att_map, q))
255
+ if 1 < self.k:
256
+ fusion_logits = fusion_logits.unsqueeze(1) # b x 1 x d
257
+ fusion_logits = self.p_net(fusion_logits).squeeze(1) * self.k # sum-pooling
258
+ return fusion_logits
259
+
260
+ def forward(self, v, q, softmax=False):
261
+ v_num = v.size(1)
262
+ q_num = q.size(1)
263
+ # print("v_num", v_num)
264
+ # print("v_num ", v_num)
265
+ if self.h_out <= self.c:
266
+ v_ = self.v_net(v)
267
+ q_ = self.q_net(q)
268
+ # print("v_", v_.shape)
269
+ # print("q_ ", q_.shape)
270
+ att_maps = torch.einsum('xhyk,bvk,bqk->bhvq', (self.h_mat, v_, q_)) + self.h_bias
271
+ # print("Attention map_1",att_maps.shape)
272
+ else:
273
+ v_ = self.v_net(v).transpose(1, 2).unsqueeze(3)
274
+ q_ = self.q_net(q).transpose(1, 2).unsqueeze(2)
275
+ d_ = torch.matmul(v_, q_) # b x h_dim x v x q
276
+ att_maps = self.h_net(d_.transpose(1, 2).transpose(2, 3)) # b x v x q x h_out
277
+ att_maps = att_maps.transpose(2, 3).transpose(1, 2) # b x h_out x v x q
278
+ # print("Attention map_2",att_maps.shape)
279
+ if softmax:
280
+ p = nn.functional.softmax(att_maps.view(-1, self.h_out, v_num * q_num), 2)
281
+ att_maps = p.view(-1, self.h_out, v_num, q_num)
282
+ # print("Attention map_softmax", att_maps.shape)
283
+ logits = self.attention_pooling(v_, q_, att_maps[:, 0, :, :])
284
+ for i in range(1, self.h_out):
285
+ logits_i = self.attention_pooling(v_, q_, att_maps[:, i, :, :])
286
+ logits += logits_i
287
+ logits = self.bn(logits)
288
+ return logits, att_maps
289
+
290
+
291
+ class FCNet(nn.Module):
292
+ """Simple class for non-linear fully connected network
293
+ Modified from https://github.com/jnhwkim/ban-vqa/blob/master/fc.py
294
+ """
295
+
296
+ def __init__(self, dims, act='ReLU', dropout=0):
297
+ super(FCNet, self).__init__()
298
+
299
+ layers = []
300
+ for i in range(len(dims) - 2):
301
+ in_dim = dims[i]
302
+ out_dim = dims[i + 1]
303
+ if 0 < dropout:
304
+ layers.append(nn.Dropout(dropout))
305
+ layers.append(weight_norm(nn.Linear(in_dim, out_dim), dim=None))
306
+ if '' != act:
307
+ layers.append(getattr(nn, act)())
308
+ if 0 < dropout:
309
+ layers.append(nn.Dropout(dropout))
310
+ layers.append(weight_norm(nn.Linear(dims[-2], dims[-1]), dim=None))
311
+ if '' != act:
312
+ layers.append(getattr(nn, act)())
313
+
314
+ self.main = nn.Sequential(*layers)
315
+
316
+ def forward(self, x):
317
+ return self.main(x)
318
+
319
+
320
+ class BatchFileDataset_Case(Dataset):
321
+ def __init__(self, file_list):
322
+ self.file_list = file_list
323
+
324
+ def __len__(self):
325
+ return len(self.file_list)
326
+
327
+ def __getitem__(self, idx):
328
+ batch_file = self.file_list[idx]
329
+ data = torch.load(batch_file)
330
+ return data['prot'], data['drug'], data['prot_ids'], data['drug_ids'], data['prot_mask'], data['drug_mask'], data['y']
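A shape-level sketch of the bilinear-attention path in the module above, mirroring how the "BAN" branch of FusionDTI wires BANLayer(512, 512, 256, 2) into MlPdecoder_CAN(input_dim=256). Inputs are random tensors, eval() keeps the BatchNorm layers on their default statistics, and the repository root is assumed to be on sys.path with torch and tqdm installed; the CAN branch additionally needs boolean token masks and an args object carrying fusion, agg_mode and group_size, so it is not exercised here:

import torch
from utils.metric_learning_models_att_maps import BANLayer, MlPdecoder_CAN

ban = BANLayer(v_dim=512, q_dim=512, h_dim=256, h_out=2).eval()
decoder = MlPdecoder_CAN(input_dim=256).eval()

prot = torch.randn(2, 100, 512)  # [batch, protein tokens, hidden]
drug = torch.randn(2, 60, 512)   # [batch, drug tokens, hidden]

joint, att_maps = ban(prot, drug)  # joint embedding and per-head attention maps
score = decoder(joint)             # interaction score in (0, 1)
print(joint.shape, att_maps.shape, score.shape)
# torch.Size([2, 256]) torch.Size([2, 2, 100, 60]) torch.Size([2, 1])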