# ML6-UniKP / dataset.py
# Last commit c7272f2 by Topallaj Denis: "copied the unikp model into this endpoint"
# (file-viewer residue: raw | history | blame | No virus | 1.57 kB)
import random
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from enumerator import SmilesEnumerator
from utils import split
# Presumably the padding token index; note Seq2seqDataset pads with
# vocab.pad_index rather than this constant — TODO confirm external users.
PAD = 0
# Maximum sequence length in tokens. Randomizer only accepts a randomized
# SMILES with at most MAX_LEN - 2 tokens, leaving room for the start/end
# tokens that Seq2seqDataset wraps around each sequence.
MAX_LEN = 220
class Randomizer(object):
    """Tokenize a SMILES string after attempting a random re-ordering of it.

    Intended as a ``transform`` for ``Seq2seqDataset``: each call asks a
    ``SmilesEnumerator`` for a randomized SMILES and returns its tokens as a
    list. If randomization fails, or its tokens would not fit in
    ``max_len - 2`` slots, the tokens of the original SMILES are returned
    instead.
    """

    def __init__(self, max_len=None):
        """
        Args:
            max_len: Maximum sequence length, counting the two special tokens
                (start/end) that ``Seq2seqDataset`` adds around the tokens.
                A randomized SMILES is only used if it has at most
                ``max_len - 2`` tokens. Defaults to the module-level
                ``MAX_LEN``; pass the dataset's ``seq_len`` to keep both in sync.
        """
        self.sme = SmilesEnumerator()
        # Resolve the default lazily so MAX_LEN is looked up at call time.
        self.max_len = MAX_LEN if max_len is None else max_len

    def __call__(self, sm):
        """Return the token list of a randomized (or, as a fallback, original) SMILES.

        Args:
            sm: A SMILES string.

        Returns:
            A list of token strings. The fallback (original SMILES) is
            returned as-is and may itself exceed ``max_len - 2`` tokens.
        """
        sm_r = self.sme.randomize_smiles(sm)  # Random transform; None on failure
        if sm_r is not None:
            tokens = split(sm_r).split()  # split() inserts spaces between tokens
            if len(tokens) <= self.max_len - 2:
                return tokens
        # Randomization failed or produced a sequence that is too long:
        # fall back to the tokens of the unmodified SMILES.
        return split(sm).split()

    def random_transform(self, sm):
        """Randomize a SMILES string without tokenizing it.

        This may take some time.

        Args:
            sm: A SMILES string.

        Returns:
            The result of ``SmilesEnumerator.randomize_smiles`` — a randomized
            SMILES, or ``None`` if randomization failed.
        """
        return self.sme.randomize_smiles(sm)
class Seq2seqDataset(Dataset):
    """Map-style dataset turning SMILES strings into fixed-length index tensors.

    Each item is ``[sos] + token_ids + [eos]``, right-padded with
    ``vocab.pad_index`` to exactly ``seq_len`` entries. Token sequences longer
    than ``seq_len - 2`` are truncated, so every item has the same length and
    batches can be stacked by the default DataLoader collate function.

    Args:
        smiles: Indexable sequence of SMILES strings.
        vocab: Vocabulary exposing ``stoi`` (token -> index mapping supporting
            ``.get``) plus ``unk_index``, ``sos_index``, ``eos_index`` and
            ``pad_index``.
        seq_len: Length of every returned tensor, special tokens included.
        transform: Callable mapping one SMILES string to a list of tokens.
            Defaults to a new ``Randomizer()`` created for this dataset.
    """

    def __init__(self, smiles, vocab, seq_len=220, transform=None):
        self.smiles = smiles
        self.vocab = vocab
        self.seq_len = seq_len
        # Create the default per instance. A `Randomizer()` default argument
        # would be built once at import time and shared by all datasets.
        self.transform = Randomizer() if transform is None else transform

    def __len__(self):
        return len(self.smiles)

    def __getitem__(self, item):
        """Return the padded index tensor for the SMILES at position ``item``."""
        tokens = self.transform(self.smiles[item])  # List of token strings
        content = [self.vocab.stoi.get(token, self.vocab.unk_index) for token in tokens]
        # Keep room for SOS/EOS so no item exceeds seq_len; otherwise the
        # padding below is empty and items end up with different lengths.
        content = content[: max(self.seq_len - 2, 0)]
        X = [self.vocab.sos_index] + content + [self.vocab.eos_index]
        X.extend([self.vocab.pad_index] * (self.seq_len - len(X)))
        return torch.tensor(X)