# ML6-UniKP / utils.py
# Commit c7272f2 by Topallaj Denis: "copied the unikp model into this endpoint"
# (web file-view residue: raw | history | blame | No virus | 5.91 kB)
import torch
import math
import torch.nn as nn
from rdkit import Chem
from rdkit import rdBase
# Turn off all RDKit application log channels ('rdApp.*'). Presumably this
# stops Chem.MolFromSmiles from printing a parse error for every invalid
# SMILES checked in validity(). Being module-level, it runs on import.
rdBase.DisableLog('rdApp.*')
# Split SMILES into words
# Two-character tokens that split() keeps together. Any other character is
# emitted on its own, except '%', which always starts a three-character
# ring-closure token such as '%12'.
_SMILES_TWO_CHAR_TOKENS = frozenset({
    'Cl', 'Ca', 'Cu',
    'Br', 'Be', 'Ba', 'Bi',
    'Si', 'Se', 'Sr',
    'Na', 'Ni',
    'Rb', 'Ra',
    'Xe',
    'Li',
    'Al', 'As', 'Ag', 'Au',
    'Mg', 'Mn',
    'Te',
    'Zn',
    'si', 'se', 'te',
    'He',
    '+2', '+3', '+4',
    '-2', '-3', '-4',
    'Kr',
    'Fe',
})


def split(sm):
    """Tokenize a SMILES string into space-separated words.

    Multi-letter atoms from a fixed list (Cl, Br, Si, Se, Na, ...), aromatic
    ``si``/``se``/``te`` and multi-unit charges (``+2`` .. ``-4``) are kept
    as single tokens. ``%`` together with the next two characters forms
    one token. Every other character is a token by itself.

    Args:
        sm: A SMILES string.

    Returns:
        The tokens of ``sm`` joined by single spaces, e.g. ``'C Cl'`` for
        ``'CCl'``.
    """
    tokens = []
    pos = 0
    last = len(sm) - 1
    # Look ahead only while at least one more character follows ``pos``.
    # A final lone character (if not consumed by a wider token) is
    # appended after the loop.
    while pos < last:
        if sm[pos] == '%':
            width = 3
        elif sm[pos:pos + 2] in _SMILES_TWO_CHAR_TOKENS:
            width = 2
        else:
            width = 1
        tokens.append(sm[pos:pos + width])
        pos += width
    if pos == last:
        tokens.append(sm[pos])
    return ' '.join(tokens)
# Activation function
class GELU(nn.Module):
    """Gaussian Error Linear Unit, tanh approximation.

    Element-wise ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))``.
    """

    def forward(self, x):
        scale = math.sqrt(2 / math.pi)
        cubic = 0.044715 * torch.pow(x, 3)
        inner = scale * (x + cubic)
        return 0.5 * x * (1 + torch.tanh(inner))
# Position-wise feed-forward network
class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU -> Dropout -> Linear.

    Both ``nn.Linear`` layers act on the last dimension only, so the same
    weights are applied at every position along the leading dimensions.

    Args:
        d_model: Size of the input and output feature dimension.
        d_ff: Size of the hidden feature dimension.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        # state_dict keys derive from these attribute names. Renaming them,
        # or reordering the Linear constructions (RNG order of the init),
        # would change loading/initialisation behaviour.
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = GELU()

    def forward(self, x):
        hidden = self.activation(self.w_1(x))
        hidden = self.dropout(hidden)
        return self.w_2(hidden)
# Normalization layer
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with a learnable gain and bias.

    Unlike ``torch.nn.LayerNorm``, this divides by the *unbiased* standard
    deviation (the ``Tensor.std`` default) plus ``eps``, rather than by
    ``sqrt(var + eps)``.

    Args:
        features: Size of the last dimension of the input.
        eps: Small constant added to the standard deviation.
    """

    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))   # gain
        self.b_2 = nn.Parameter(torch.zeros(features))  # bias
        self.eps = eps

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        spread = x.std(-1, keepdim=True) + self.eps
        return self.a_2 * centered / spread + self.b_2
class SublayerConnection(nn.Module):
    """Pre-norm residual wrapper computing ``x + dropout(sublayer(norm(x)))``.

    Args:
        size: Feature size passed to the internal ``LayerNorm``.
        dropout: Dropout probability applied to the sublayer output.
    """

    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """Apply ``sublayer`` to the normalized input and add the result back onto ``x``."""
        normed = self.norm(x)
        residual = self.dropout(sublayer(normed))
        return x + residual
# Sample token indices from probabilistic distributions
def sample(msms):
    """Draw one index per row from each distribution in ``msms``.

    Each element ``msm`` is exponentiated before being passed to
    ``torch.multinomial``. Presumably it holds log-probabilities (e.g. a
    log_softmax output); TODO confirm against callers. One index is drawn
    per row, size-1 dimensions are squeezed away, and the per-element
    results are stacked along a new leading dimension.

    Args:
        msms: Iterable of 1-D or 2-D tensors.

    Returns:
        A tensor of sampled int64 indices.
    """
    draws = [torch.multinomial(msm.exp(), 1).squeeze() for msm in msms]
    return torch.stack(draws)
def validity(smiles):
    """Return the fraction of SMILES strings that RDKit can parse.

    A string counts as valid when ``Chem.MolFromSmiles`` returns a molecule
    rather than ``None``.

    Args:
        smiles: Iterable of SMILES strings (any iterable, not only lists).

    Returns:
        ``1 - invalid / total`` as a float in [0, 1]. An empty input returns
        ``0.0`` instead of raising ``ZeroDivisionError``.
    """
    # Materialise once so generators work and len() is available.
    smiles = list(smiles)
    if not smiles:
        return 0.0
    n_invalid = sum(1 for sm in smiles if Chem.MolFromSmiles(sm) is None)
    return 1 - n_invalid / len(smiles)