ZhaohanM committed • 0312a01
Parent(s): 91a8b2f
Initial commit
Files changed:
- tokenizer/special_tokens_map.json +1 -0
- tokenizer/vocab.json +1 -0
- tokenizer/vocab.txt +429 -0
- utils/.ipynb_checkpoints/drug_tokenizer-checkpoint.py +66 -0
- utils/__pycache__/drug_tokenizer.cpython-38.pyc +0 -0
- utils/__pycache__/metric_learning_models_att_maps.cpython-38.pyc +0 -0
- utils/drug_tokenizer.py +66 -0
- utils/metric_learning_models_att_maps.py +330 -0
tokenizer/special_tokens_map.json
ADDED
@@ -0,0 +1 @@
{"bos_token": {"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "eos_token": {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "unk_token": {"content": "<unk>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "sep_token": {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "pad_token": {"content": "<pad>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "cls_token": {"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": true}}
tokenizer/vocab.json
ADDED
@@ -0,0 +1 @@
{"<unk>":0,"<s>":1,"</s>":2,"<pad>":3,"<mask>":4,"\n":5,"#":6,"+":7,"-":8,".":9,"/":10,"0":11,"1":12,"2":13,"3":14,"4":15,"5":16,"6":17,"7":18,"8":19,"9":20,"=":21,"@":22,"A":23,"B":24,"C":25,"F":26,"H":27,"I":28,"K":29,"L":30,"M":31,"N":32,"O":33,"P":34,"R":35,"S":36,"T":37,"Z":38,"\\":39,"a":40,"b":41,"c":42,"e":43,"g":44,"h":45,"i":46,"l":47,"n":48,"r":49,"s":50,"Br":51,"an":52,"ch":53,"Bran":54,"Branch":55,"Branch1":56,"=C":57,"Ri":58,"ng":59,"Ring":60,"Ring1":61,"=Branch1":62,"Branch2":63,"=O":64,"Ring2":65,"H1":66,"C@":67,"=N":68,"#Branch1":69,"C@@":70,"=Branch2":71,"C@H1":72,"C@@H1":73,"#Branch2":74,"#C":75,"Cl":76,"/C":77,"NH1":78,"=Ring1":79,"+1":80,"-1":81,"O-1":82,"N+1":83,"\\C":84,"#N":85,"/N":86,"=Ring2":87,"=S":88,"=N+1":89,"\\N":90,"Na":91,"Na+1":92,"/O":93,"\\O":94,"Br-1":95,"Branch3":96,"\\S":97,"S+1":98,"Cl-1":99,"I-1":100,"/C@@H1":101,"Si":102,"/C@H1":103,"/S":104,"=N-1":105,"Se":106,"=P":107,"N-1":108,"Ring3":109,"2H":110,"P+1":111,"K+1":112,"\\C@@H1":113,"\\C@H1":114,"/N+1":115,"@@":116,"C-1":117,"#N+1":118,"B-1":119,"+3":120,"Cl+3":121,"\\NH1":122,"Li":123,"Li+1":124,"PH1":125,"18":126,"18F":127,"@+1":128,"3H":129,"P@@":130,"H0":131,"OH0":132,"12":133,"P@":134,"+2":135,"@@+1":136,"S-1":137,"/Br":138,"-/":139,"\\Cl":140,"-/Ring2":141,"\\O-1":142,"11":143,"5I":144,"125I":145,"11C":146,"H3":147,"\\N+1":148,"-\\":149,"/C@@":150,"S@+1":151,"As":152,"/Cl":153,"11CH3":154,"=Se":155,"S@@+1":156,"N@+1":157,"14":158,"-\\Ring2":159,"14C":160,"\\F":161,"/C@":162,"Te":163,"H2":164,"H1-1":165,"=O+1":166,"N@@+1":167,"C+1":168,"=S+1":169,"Zn":170,"/P":171,"a+2":172,"/I":173,"OH1-1":174,"Ca+2":175,"\\Br":176,"Mg":177,"Zn+2":178,"Al":179,"/F":180,"Mg+2":181,"123":182,"123I":183,"13":184,"I+1":185,"/O-1":186,"-\\Ring1":187,"BH2":188,"BH2-1":189,"\\I":190,"/NH1":191,"O+1":192,"131":193,"131I":194,"=14C":195,"/S+1":196,"=Ring3":197,"\\C@@":198,"H2+1":199,"\\C@":200,"Ag":201,"=As":202,"=Se+1":203,"NH2+1":204,"SeH1":205,"-/Ring1":206,"=Te":207,"Al+3":208,"NaH1":209,"=Te+1":210,"NH1+1":211,"Ag+1":212,"H1+1":213,"NH1-1":214,"\\P":215,"14CH2":216,"13C":217,"14CH1":218,"=11C":219,"S@@":220,"=P@@":221,"SiH2":222,"H3-1":223,"14CH3":224,"BH3-1":225,"S@":226,"=14CH1":227,"=PH1":228,"=P@":229,"=NH1+1":230,"\\S+1":231,"124":232,"CH1-1":233,"Sr":234,"=Si":235,"124I":236,"Sr+2":237,"#C-1":238,"/C-1":239,"N@":240,"/N-1":241,"13CH1":242,"/B":243,"19":244,"Ba+2":245,"H4":246,"SH1+1":247,"Se+1":248,"19F":249,"/125I":250,"P@+1":251,"Rb":252,"Cl+1":253,"SiH4":254,"Rb+1":255,"=Branch3":256,"N@@":257,"As+1":258,"/Si":259,"BH1-1":260,"SH1":261,"/123I":262,"32":263,"=Mg":264,"H+1":265,"\\B":266,"SiH1":267,"P@@+1":268,"-2":269,"15":270,"17":271,"35":272,"=13CH1":273,"Cs":274,"=NH2+1":275,"=SH1":276,"MgH2":277,"32P":278,"17F":279,"35S":280,"Cs+1":281,"#11C":282,"/131I":283,"Bi":284,"\\125I":285,"=S@@":286,"\\S-1":287,"6Br":288,"7I":289,"76Br":290,"=B":291,"eH1":292,"\\N-1":293,"18O":294,"127I":295,"11CH2":296,"14C@@H1":297,"TeH2":298,"15NH1":299,"Bi+3":300,"/P+1":301,"/13C":302,"/13CH1":303,"0B":304,"10B":305,"=Al":306,"=18O":307,"BH0":308,"F-1":309,"NH3":310,"S-2":311,"Br+2":312,"Cl+2":313,"\\Si":314,"/S-1":315,"=PH2":316,"14C@H1":317,"NH3+1":318,"#14C":319,"#O+1":320,"-3":321,"22":322,"4H":323,"5Se":324,"5Sr+2":325,"75Se":326,"85Sr+2":327,"=B-1":328,"=13C":329,"@-1":330,"Be":331,"B@@":332,"B@-1":333,"Ca":334,"CH1":335,"I+3":336,"KH1":337,"OH1+1":338,"Ra+2":339,"SH1-1":340,"\\PH1":341,"\\123I":342,"=Ca":343,"\\CH1-1":344,"=S@":345,"\\SeH1":346,"/SeH1":347,"Se-1":348,"LiH1":349,"18F-1":350,"125IH1":351,"11CH1":352,
"TeH1":353,"Zn+1":354,"Zn-2":355,"Al-3":356,"13CH3":357,"15N":358,"Be+2":359,"B@@-1":360,"#P":361,"#S":362,"-4":363,"/PH1":364,"/P@@":365,"/As":366,"/14C":367,"/14CH1":368,"2K+1":369,"2Rb+1":370,"3Se":371,"3Ra+2":372,"45":373,"47":374,"42K+1":375,"5I-1":376,"73Se":377,"89":378,"82Rb+1":379,"=32":380,"=32P":381,"CH0":382,"CH2":383,"I+2":384,"NH0":385,"NH4":386,"OH1":387,"PH2+1":388,"SH0":389,"SH2":390,"\\3H":391,"\\11CH3":392,"\\C-1":393,"\\Se":394,"Si@":395,"Si-1":396,"SiH1-1":397,"SiH3-1":398,"/Se":399,"Se-2":400,"\\NH1-1":401,"18FH1":402,"125I-1":403,"11C@@H1":404,"11C-1":405,"AsH1":406,"As-1":407,"14C@@":408,"Te-1":409,"Mg+1":410,"123I-1":411,"123Te":412,"123IH1":413,"135I":414,"131I-1":415,"Ag-4":416,"124I-1":417,"76BrH1":418,"18OH1":419,"22Na+1":420,"223Ra+2":421,"CaH2":422,"45Ca+2":423,"47Ca+2":424,"89Sr+2":425,"=32PH1":426,"NH4+1":427}
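These two files are everything the DrugTokenizer below reads. As a quick sanity check, a minimal sketch (assuming the files sit under tokenizer/ as committed) that loads them and resolves the special-token ids the same way the tokenizer does:

import json

with open("tokenizer/vocab.json", encoding="utf-8") as f:
    vocab = json.load(f)
with open("tokenizer/special_tokens_map.json", encoding="utf-8") as f:
    specials = json.load(f)

print(len(vocab))                               # 428 SELFIES-style tokens, ids 0-427
print(vocab[specials["cls_token"]["content"]])  # "<s>"  -> 1
print(vocab[specials["pad_token"]["content"]])  # "<pad>" -> 3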
tokenizer/vocab.txt
ADDED
@@ -0,0 +1,429 @@
<unk>
<s>
</s>
<pad>
<mask>


#
+
-
.
/
0
1
2
3
4
5
6
7
8
9
=
@
A
B
C
F
H
I
K
L
M
N
O
P
R
S
T
Z
\
a
b
c
e
g
h
i
l
n
r
s
Br
an
ch
Bran
Branch
Branch1
=C
Ri
ng
Ring
Ring1
=Branch1
Branch2
=O
Ring2
H1
C@
=N
#Branch1
C@@
=Branch2
C@H1
C@@H1
#Branch2
#C
Cl
/C
NH1
=Ring1
+1
-1
O-1
N+1
\C
#N
/N
=Ring2
=S
=N+1
\N
Na
Na+1
/O
\O
Br-1
Branch3
\S
S+1
Cl-1
I-1
/C@@H1
Si
/C@H1
/S
=N-1
Se
=P
N-1
Ring3
2H
P+1
K+1
\C@@H1
\C@H1
/N+1
@@
C-1
#N+1
B-1
+3
Cl+3
\NH1
Li
Li+1
PH1
18
18F
@+1
3H
P@@
H0
OH0
12
P@
+2
@@+1
S-1
/Br
-/
\Cl
-/Ring2
\O-1
11
5I
125I
11C
H3
\N+1
-\
/C@@
S@+1
As
/Cl
11CH3
=Se
S@@+1
N@+1
14
-\Ring2
14C
\F
/C@
Te
H2
H1-1
=O+1
N@@+1
C+1
=S+1
Zn
/P
a+2
/I
OH1-1
Ca+2
\Br
Mg
Zn+2
Al
/F
Mg+2
123
123I
13
I+1
/O-1
-\Ring1
BH2
BH2-1
\I
/NH1
O+1
131
131I
=14C
/S+1
=Ring3
\C@@
H2+1
\C@
Ag
=As
=Se+1
NH2+1
SeH1
-/Ring1
=Te
Al+3
NaH1
=Te+1
NH1+1
Ag+1
H1+1
NH1-1
\P
14CH2
13C
14CH1
=11C
S@@
=P@@
SiH2
H3-1
14CH3
BH3-1
S@
=14CH1
=PH1
=P@
=NH1+1
\S+1
124
CH1-1
Sr
=Si
124I
Sr+2
#C-1
/C-1
N@
/N-1
13CH1
/B
19
Ba+2
H4
SH1+1
Se+1
19F
/125I
P@+1
Rb
Cl+1
SiH4
Rb+1
=Branch3
N@@
As+1
/Si
BH1-1
SH1
/123I
32
=Mg
H+1
\B
SiH1
P@@+1
-2
15
17
35
=13CH1
Cs
=NH2+1
=SH1
MgH2
32P
17F
35S
Cs+1
#11C
/131I
Bi
\125I
=S@@
\S-1
6Br
7I
76Br
=B
eH1
\N-1
18O
127I
11CH2
14C@@H1
TeH2
15NH1
Bi+3
/P+1
/13C
/13CH1
0B
10B
=Al
=18O
BH0
F-1
NH3
S-2
Br+2
Cl+2
\Si
/S-1
=PH2
14C@H1
NH3+1
#14C
#O+1
-3
22
4H
5Se
5Sr+2
75Se
85Sr+2
=B-1
=13C
@-1
Be
B@@
B@-1
Ca
CH1
I+3
KH1
OH1+1
Ra+2
SH1-1
\PH1
\123I
=Ca
\CH1-1
=S@
\SeH1
/SeH1
Se-1
LiH1
18F-1
125IH1
11CH1
TeH1
Zn+1
Zn-2
Al-3
13CH3
15N
Be+2
B@@-1
#P
#S
-4
/PH1
/P@@
/As
/14C
/14CH1
2K+1
2Rb+1
3Se
3Ra+2
45
47
42K+1
5I-1
73Se
89
82Rb+1
=32
=32P
CH0
CH2
I+2
NH0
NH4
OH1
PH2+1
SH0
SH2
\3H
\11CH3
\C-1
\Se
Si@
Si-1
SiH1-1
SiH3-1
/Se
Se-2
\NH1-1
18FH1
125I-1
11C@@H1
11C-1
AsH1
As-1
14C@@
Te-1
Mg+1
123I-1
123Te
123IH1
135I
131I-1
Ag-4
124I-1
76BrH1
18OH1
22Na+1
223Ra+2
CaH2
45Ca+2
47Ca+2
89Sr+2
=32PH1
NH4+1
utils/.ipynb_checkpoints/drug_tokenizer-checkpoint.py
ADDED
@@ -0,0 +1,66 @@
(Jupyter checkpoint copy; contents identical to utils/drug_tokenizer.py below.)
utils/__pycache__/drug_tokenizer.cpython-38.pyc
ADDED
Binary file (3.25 kB).
utils/__pycache__/metric_learning_models_att_maps.cpython-38.pyc
ADDED
Binary file (10.9 kB).
utils/drug_tokenizer.py
ADDED
@@ -0,0 +1,66 @@
import json
import re
import torch
import torch.nn as nn
from torch.nn import functional as F

class DrugTokenizer:
    def __init__(self, vocab_path="tokenizer/vocab.json", special_tokens_path="tokenizer/special_tokens_map.json"):
        self.vocab, self.special_tokens = self.load_vocab_and_special_tokens(vocab_path, special_tokens_path)
        self.cls_token_id = self.vocab[self.special_tokens['cls_token']]
        self.sep_token_id = self.vocab[self.special_tokens['sep_token']]
        self.unk_token_id = self.vocab[self.special_tokens['unk_token']]
        self.pad_token_id = self.vocab[self.special_tokens['pad_token']]
        self.id_to_token = {v: k for k, v in self.vocab.items()}

    def load_vocab_and_special_tokens(self, vocab_path, special_tokens_path):
        with open(vocab_path, 'r', encoding='utf-8') as vocab_file:
            vocab = json.load(vocab_file)
        with open(special_tokens_path, 'r', encoding='utf-8') as special_tokens_file:
            special_tokens_raw = json.load(special_tokens_file)

        special_tokens = {key: value['content'] for key, value in special_tokens_raw.items()}
        return vocab, special_tokens

    def encode(self, sequence):
        # Extract bracketed SELFIES tokens, e.g. "[C][=O]" -> ["C", "=O"]
        tokens = re.findall(r'\[([^\[\]]+)\]', sequence)
        input_ids = [self.cls_token_id] + [self.vocab.get(token, self.unk_token_id) for token in tokens] + [self.sep_token_id]
        attention_mask = [1] * len(input_ids)
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask
        }

    def batch_encode_plus(self, sequences, max_length, padding, truncation, add_special_tokens, return_tensors):
        input_ids_list = []
        attention_mask_list = []

        for sequence in sequences:
            encoded = self.encode(sequence)
            input_ids = encoded['input_ids']
            attention_mask = encoded['attention_mask']

            # Truncate or pad every sequence to exactly max_length
            if len(input_ids) > max_length:
                input_ids = input_ids[:max_length]
                attention_mask = attention_mask[:max_length]
            elif len(input_ids) < max_length:
                pad_length = max_length - len(input_ids)
                input_ids = input_ids + [self.pad_token_id] * pad_length
                attention_mask = attention_mask + [0] * pad_length

            input_ids_list.append(input_ids)
            attention_mask_list.append(attention_mask)

        return {
            'input_ids': torch.tensor(input_ids_list, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask_list, dtype=torch.long)
        }

    def decode(self, input_ids, skip_special_tokens=False):
        tokens = []
        for token_id in input_ids:
            if skip_special_tokens and token_id in [self.cls_token_id, self.sep_token_id, self.pad_token_id]:
                continue
            tokens.append(self.id_to_token.get(token_id, self.special_tokens['unk_token']))
        sequence = ''.join([f'[{token}]' for token in tokens])
        return sequence
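For orientation, a minimal usage sketch of the class above (run from the repository root so the default tokenizer/ paths resolve; the SELFIES strings are illustrative examples):

from utils.drug_tokenizer import DrugTokenizer

tokenizer = DrugTokenizer()  # reads tokenizer/vocab.json and tokenizer/special_tokens_map.json

# Single sequence: <s> ... </s> wrapped around the bracketed token ids.
enc = tokenizer.encode("[C][C][=Branch1][C][=O][O]")
print(enc["input_ids"])  # [1, 25, 25, 62, 25, 64, 33, 2]

# Round-trip; skip_special_tokens drops <s>, </s> and <pad>.
print(tokenizer.decode(enc["input_ids"], skip_special_tokens=True))
# -> "[C][C][=Branch1][C][=O][O]"

# Fixed-length batch tensors for the drug encoder.
batch = tokenizer.batch_encode_plus(
    ["[C][=O]", "[N][C][=Branch1][C][=O][O]"],
    max_length=16, padding=True, truncation=True,
    add_special_tokens=True, return_tensors="pt",
)
print(batch["input_ids"].shape)  # torch.Size([2, 16])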
utils/metric_learning_models_att_maps.py
ADDED
@@ -0,0 +1,330 @@
import logging
import os
import sys

sys.path.append("../")

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import autocast
from torch.nn import Module
from tqdm import tqdm
from torch.nn.utils.weight_norm import weight_norm
from torch.utils.data import Dataset

LOGGER = logging.getLogger(__name__)

class FusionDTI(nn.Module):
    def __init__(self, prot_out_dim, disease_out_dim, args):
        super(FusionDTI, self).__init__()
        self.fusion = args.fusion
        self.drug_reg = nn.Linear(disease_out_dim, 512)
        self.prot_reg = nn.Linear(prot_out_dim, 512)

        if self.fusion == "CAN":
            self.can_layer = CAN_Layer(hidden_dim=512, num_heads=8, args=args)
            self.mlp_classifier = MlPdecoder_CAN(input_dim=1024)
        elif self.fusion == "BAN":
            self.ban_layer = weight_norm(BANLayer(512, 512, 256, 2), name='h_mat', dim=None)
            self.mlp_classifier = MlPdecoder_CAN(input_dim=256)
        elif self.fusion == "Nan":
            self.mlp_classifier_nan = MlPdecoder_CAN(input_dim=1214)

    def forward(self, prot_embed, drug_embed, prot_mask, drug_mask):
        # print("drug_embed", drug_embed.shape)
        if self.fusion == "Nan":
            prot_embed = prot_embed.mean(1)  # query : [batch_size, hidden]
            drug_embed = drug_embed.mean(1)  # query : [batch_size, hidden]
            joint_embed = torch.cat([prot_embed, drug_embed], dim=1)
            score = self.mlp_classifier_nan(joint_embed)
            att = None  # no attention map without a fusion layer
        else:
            prot_embed = self.prot_reg(prot_embed)
            drug_embed = self.drug_reg(drug_embed)

            if self.fusion == "CAN":
                joint_embed, att = self.can_layer(prot_embed, drug_embed, prot_mask, drug_mask)
            elif self.fusion == "BAN":
                joint_embed, att = self.ban_layer(prot_embed, drug_embed)

            score = self.mlp_classifier(joint_embed)

        return score, att

class Pre_encoded(nn.Module):
    def __init__(
        self, prot_encoder, drug_encoder, args
    ):
        """Constructor for the model.

        Args:
            prot_encoder (_type_): Protein structure-aware sequence encoder.
            drug_encoder (_type_): Drug SELFIES encoder.
            args (_type_): _description_
        """
        super(Pre_encoded, self).__init__()
        self.prot_encoder = prot_encoder
        self.drug_encoder = drug_encoder

    def encoding(self, prot_input_ids, prot_attention_mask, drug_input_ids, drug_attention_mask):
        # Process inputs through encoders
        prot_embed = self.prot_encoder(
            input_ids=prot_input_ids, attention_mask=prot_attention_mask, return_dict=True
        ).logits
        # prot_embed = self.prot_reg(prot_embed)

        drug_embed = self.drug_encoder(
            input_ids=drug_input_ids, attention_mask=drug_attention_mask, return_dict=True
        ).last_hidden_state  # .last_hidden_state

        # print("drug_embed", drug_embed.shape)

        return prot_embed, drug_embed


class CAN_Layer(nn.Module):
    def __init__(self, hidden_dim, num_heads, args):
        super(CAN_Layer, self).__init__()
        self.agg_mode = args.agg_mode
        self.group_size = args.group_size  # Control Fusion Scale
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_size = hidden_dim // num_heads

        self.query_p = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.key_p = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.value_p = nn.Linear(hidden_dim, hidden_dim, bias=False)

        self.query_d = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.key_d = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.value_d = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def alpha_logits(self, logits, mask_row, mask_col, inf=1e6):
        # Mask out padded positions before the softmax over dim 2
        N, L1, L2, H = logits.shape
        mask_row = mask_row.view(N, L1, 1).repeat(1, 1, H)
        mask_col = mask_col.view(N, L2, 1).repeat(1, 1, H)
        mask_pair = torch.einsum('blh, bkh->blkh', mask_row, mask_col)

        logits = torch.where(mask_pair, logits, logits - inf)
        alpha = torch.softmax(logits, dim=2)
        mask_row = mask_row.view(N, L1, 1, H).repeat(1, 1, L2, 1)
        alpha = torch.where(mask_row, alpha, torch.zeros_like(alpha))
        return alpha

    def apply_heads(self, x, n_heads, n_ch):
        s = list(x.size())[:-1] + [n_heads, n_ch]
        return x.view(*s)

    def group_embeddings(self, x, mask, group_size):
        # Mean-pool consecutive tokens into groups of group_size
        N, L, D = x.shape
        groups = L // group_size
        x_grouped = x.view(N, groups, group_size, D).mean(dim=2)
        mask_grouped = mask.view(N, groups, group_size).any(dim=2)
        return x_grouped, mask_grouped

    def forward(self, protein, drug, mask_prot, mask_drug):
        # Group embeddings before applying multi-head attention
        protein_grouped, mask_prot_grouped = self.group_embeddings(protein, mask_prot, self.group_size)
        drug_grouped, mask_drug_grouped = self.group_embeddings(drug, mask_drug, self.group_size)

        # print("protein_grouped:", protein_grouped.shape)
        # print("mask_prot_grouped:", mask_prot_grouped.shape)

        # Compute queries, keys, values for both protein and drug after grouping
        query_prot = self.apply_heads(self.query_p(protein_grouped), self.num_heads, self.head_size)
        key_prot = self.apply_heads(self.key_p(protein_grouped), self.num_heads, self.head_size)
        value_prot = self.apply_heads(self.value_p(protein_grouped), self.num_heads, self.head_size)

        query_drug = self.apply_heads(self.query_d(drug_grouped), self.num_heads, self.head_size)
        key_drug = self.apply_heads(self.key_d(drug_grouped), self.num_heads, self.head_size)
        value_drug = self.apply_heads(self.value_d(drug_grouped), self.num_heads, self.head_size)

        # Compute attention scores
        logits_pp = torch.einsum('blhd, bkhd->blkh', query_prot, key_prot)
        logits_pd = torch.einsum('blhd, bkhd->blkh', query_prot, key_drug)
        logits_dp = torch.einsum('blhd, bkhd->blkh', query_drug, key_prot)
        logits_dd = torch.einsum('blhd, bkhd->blkh', query_drug, key_drug)
        # print("logits_pp:", logits_pp.shape)

        alpha_pp = self.alpha_logits(logits_pp, mask_prot_grouped, mask_prot_grouped)
        alpha_pd = self.alpha_logits(logits_pd, mask_prot_grouped, mask_drug_grouped)
        alpha_dp = self.alpha_logits(logits_dp, mask_drug_grouped, mask_prot_grouped)
        alpha_dd = self.alpha_logits(logits_dd, mask_drug_grouped, mask_drug_grouped)

        prot_embedding = (torch.einsum('blkh, bkhd->blhd', alpha_pp, value_prot).flatten(-2) +
                          torch.einsum('blkh, bkhd->blhd', alpha_pd, value_drug).flatten(-2)) / 2
        drug_embedding = (torch.einsum('blkh, bkhd->blhd', alpha_dp, value_prot).flatten(-2) +
                          torch.einsum('blkh, bkhd->blhd', alpha_dd, value_drug).flatten(-2)) / 2

        # print("prot_embedding:", prot_embedding.shape)

        # Continue as usual with the aggregation mode
        if self.agg_mode == "cls":
            prot_embed = prot_embedding[:, 0]  # query : [batch_size, hidden]
            drug_embed = drug_embedding[:, 0]  # query : [batch_size, hidden]
        elif self.agg_mode == "mean_all_tok":
            prot_embed = prot_embedding.mean(1)  # query : [batch_size, hidden]
            drug_embed = drug_embedding.mean(1)  # query : [batch_size, hidden]
        elif self.agg_mode == "mean":
            prot_embed = (prot_embedding * mask_prot_grouped.unsqueeze(-1)).sum(1) / mask_prot_grouped.sum(-1).unsqueeze(-1)
            drug_embed = (drug_embedding * mask_drug_grouped.unsqueeze(-1)).sum(1) / mask_drug_grouped.sum(-1).unsqueeze(-1)
        else:
            raise NotImplementedError()

        # print("prot_embed:", prot_embed.shape)

        query_embed = torch.cat([prot_embed, drug_embed], dim=1)

        # Assemble the full attention map (assumes 512 grouped positions per modality)
        att = torch.zeros(1, 1, 1024, 1024)
        att[:, :, :512, :512] = alpha_pp.mean(dim=-1)  # Protein to Protein
        att[:, :, :512, 512:] = alpha_pd.mean(dim=-1)  # Protein to Drug
        att[:, :, 512:, :512] = alpha_dp.mean(dim=-1)  # Drug to Protein
        att[:, :, 512:, 512:] = alpha_dd.mean(dim=-1)  # Drug to Drug

        # print("query_embed:", query_embed.shape)
        return query_embed, att

class MlPdecoder_CAN(nn.Module):
    def __init__(self, input_dim):
        super(MlPdecoder_CAN, self).__init__()
        self.fc1 = nn.Linear(input_dim, input_dim)
        self.bn1 = nn.BatchNorm1d(input_dim)
        self.fc2 = nn.Linear(input_dim, input_dim // 2)
        self.bn2 = nn.BatchNorm1d(input_dim // 2)
        self.fc3 = nn.Linear(input_dim // 2, input_dim // 4)
        self.bn3 = nn.BatchNorm1d(input_dim // 4)
        self.output = nn.Linear(input_dim // 4, 1)

    def forward(self, x):
        x = self.bn1(torch.relu(self.fc1(x)))
        x = self.bn2(torch.relu(self.fc2(x)))
        x = self.bn3(torch.relu(self.fc3(x)))
        x = torch.sigmoid(self.output(x))
        return x

class MLPdecoder_BAN(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, binary=1):
        super(MLPdecoder_BAN, self).__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, out_dim)
        self.bn3 = nn.BatchNorm1d(out_dim)
        self.fc4 = nn.Linear(out_dim, binary)

    def forward(self, x):
        x = self.bn1(F.relu(self.fc1(x)))
        x = self.bn2(F.relu(self.fc2(x)))
        x = self.bn3(F.relu(self.fc3(x)))
        # x = self.fc4(x)
        x = torch.sigmoid(self.fc4(x))
        return x

class BANLayer(nn.Module):
    """Bilinear attention network
    Modified from https://github.com/peizhenbai/DrugBAN/blob/main/ban.py
    """
    def __init__(self, v_dim, q_dim, h_dim, h_out, act='ReLU', dropout=0.2, k=3):
        super(BANLayer, self).__init__()

        self.c = 32
        self.k = k
        self.v_dim = v_dim
        self.q_dim = q_dim
        self.h_dim = h_dim
        self.h_out = h_out

        self.v_net = FCNet([v_dim, h_dim * self.k], act=act, dropout=dropout)
        self.q_net = FCNet([q_dim, h_dim * self.k], act=act, dropout=dropout)
        # self.dropout = nn.Dropout(dropout[1])
        if 1 < k:
            self.p_net = nn.AvgPool1d(self.k, stride=self.k)

        if h_out <= self.c:
            self.h_mat = nn.Parameter(torch.Tensor(1, h_out, 1, h_dim * self.k).normal_())
            self.h_bias = nn.Parameter(torch.Tensor(1, h_out, 1, 1).normal_())
        else:
            self.h_net = weight_norm(nn.Linear(h_dim * self.k, h_out), dim=None)

        self.bn = nn.BatchNorm1d(h_dim)

    def attention_pooling(self, v, q, att_map):
        fusion_logits = torch.einsum('bvk,bvq,bqk->bk', (v, att_map, q))
        if 1 < self.k:
            fusion_logits = fusion_logits.unsqueeze(1)  # b x 1 x d
            fusion_logits = self.p_net(fusion_logits).squeeze(1) * self.k  # sum-pooling
        return fusion_logits

    def forward(self, v, q, softmax=False):
        v_num = v.size(1)
        q_num = q.size(1)
        # print("v_num", v_num)
        if self.h_out <= self.c:
            v_ = self.v_net(v)
            q_ = self.q_net(q)
            # print("v_", v_.shape)
            # print("q_ ", q_.shape)
            att_maps = torch.einsum('xhyk,bvk,bqk->bhvq', (self.h_mat, v_, q_)) + self.h_bias
            # print("Attention map_1", att_maps.shape)
        else:
            v_ = self.v_net(v).transpose(1, 2).unsqueeze(3)
            q_ = self.q_net(q).transpose(1, 2).unsqueeze(2)
            d_ = torch.matmul(v_, q_)  # b x h_dim x v x q
            att_maps = self.h_net(d_.transpose(1, 2).transpose(2, 3))  # b x v x q x h_out
            att_maps = att_maps.transpose(2, 3).transpose(1, 2)  # b x h_out x v x q
            # print("Attention map_2", att_maps.shape)
        if softmax:
            p = nn.functional.softmax(att_maps.view(-1, self.h_out, v_num * q_num), 2)
            att_maps = p.view(-1, self.h_out, v_num, q_num)
            # print("Attention map_softmax", att_maps.shape)
        logits = self.attention_pooling(v_, q_, att_maps[:, 0, :, :])
        for i in range(1, self.h_out):
            logits_i = self.attention_pooling(v_, q_, att_maps[:, i, :, :])
            logits += logits_i
        logits = self.bn(logits)
        return logits, att_maps


class FCNet(nn.Module):
    """Simple class for non-linear fully connect network
    Modified from https://github.com/jnhwkim/ban-vqa/blob/master/fc.py
    """

    def __init__(self, dims, act='ReLU', dropout=0):
        super(FCNet, self).__init__()

        layers = []
        for i in range(len(dims) - 2):
            in_dim = dims[i]
            out_dim = dims[i + 1]
            if 0 < dropout:
                layers.append(nn.Dropout(dropout))
            layers.append(weight_norm(nn.Linear(in_dim, out_dim), dim=None))
            if '' != act:
                layers.append(getattr(nn, act)())
        if 0 < dropout:
            layers.append(nn.Dropout(dropout))
        layers.append(weight_norm(nn.Linear(dims[-2], dims[-1]), dim=None))
        if '' != act:
            layers.append(getattr(nn, act)())

        self.main = nn.Sequential(*layers)

    def forward(self, x):
        return self.main(x)


class BatchFileDataset_Case(Dataset):
    def __init__(self, file_list):
        self.file_list = file_list

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        batch_file = self.file_list[idx]
        data = torch.load(batch_file)
        return data['prot'], data['drug'], data['prot_ids'], data['drug_ids'], data['prot_mask'], data['drug_mask'], data['y']
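To see how these pieces fit together, a minimal smoke-test sketch of the CAN fusion path (the args fields, encoder output dimensions, and sequence lengths are illustrative assumptions, not values fixed by this commit; note that CAN_Layer hard-codes a 1 x 1 x 1024 x 1024 attention buffer, so it expects batch size 1 and 512 grouped positions per modality, i.e. sequence length = 512 * group_size):

import torch
from types import SimpleNamespace
from utils.metric_learning_models_att_maps import FusionDTI

args = SimpleNamespace(fusion="CAN", agg_mode="mean", group_size=2)  # illustrative
model = FusionDTI(prot_out_dim=480, disease_out_dim=768, args=args).eval()

B, L = 1, 1024                          # 1024 tokens / group_size 2 -> 512 groups
prot = torch.randn(B, L, 480)           # stand-in for protein encoder output
drug = torch.randn(B, L, 768)           # stand-in for drug encoder output
prot_mask = torch.ones(B, L, dtype=torch.bool)  # bool masks: alpha_logits uses torch.where
drug_mask = torch.ones(B, L, dtype=torch.bool)

with torch.no_grad():                   # eval mode keeps BatchNorm1d usable at batch size 1
    score, att = model(prot, drug, prot_mask, drug_mask)
print(score.shape, att.shape)           # torch.Size([1, 1]) torch.Size([1, 1, 1024, 1024])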