Spaces:
Runtime error
Runtime error
File size: 1,573 Bytes
c7272f2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
import random
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from enumerator import SmilesEnumerator
from utils import split
# Padding index constant; not referenced in the visible part of this file —
# presumably mirrors vocab.pad_index (confirm against the vocab).
PAD = 0
# Token-length budget: Randomizer accepts a randomized SMILES only if it has
# at most MAX_LEN - 2 tokens (presumably leaving room for SOS/EOS).
MAX_LEN = 220
class Randomizer(object):
    """Callable that tokenizes a SMILES string, randomizing it when possible.

    Intended for use as a dataset ``transform``: each call asks
    ``SmilesEnumerator.randomize_smiles`` for a fresh randomized SMILES and
    returns its tokens as a list of strings.
    """

    def __init__(self, max_len=MAX_LEN):
        """
        Args:
            max_len: Token-length budget. A randomized SMILES is accepted only
                if it has at most ``max_len - 2`` tokens (presumably reserving
                two slots for SOS/EOS — confirm against the consumer).
                Defaults to the module-level ``MAX_LEN``.
        """
        self.sme = SmilesEnumerator()
        self.max_len = max_len

    def __call__(self, sm):
        """Return the token list of a randomized variant of ``sm``.

        Falls back to the tokens of ``sm`` itself when randomization returns
        ``None`` or when the randomized form exceeds ``max_len - 2`` tokens.
        The fallback is NOT length-checked, so an input that is already too
        long is returned at full length; consumers must handle that.

        Args:
            sm: A SMILES string.

        Returns:
            list[str]: Tokens obtained by whitespace-splitting ``split(...)``.
        """
        sm_r = self.random_transform(sm)  # Random transform
        # randomize_smiles signals failure with None; keep the input then.
        candidate = sm if sm_r is None else sm_r
        tokens = split(candidate).split()
        if len(tokens) <= self.max_len - 2:
            return tokens
        # Randomized form is too long: use the original SMILES instead.
        return split(sm).split()

    def random_transform(self, sm):
        """Return a randomized SMILES for ``sm``. It may take some time.

        Thin wrapper around ``SmilesEnumerator.randomize_smiles``. The result
        can be ``None`` (``__call__`` treats that as a failed randomization),
        so callers should check for it.

        Args:
            sm: A SMILES string.
        """
        return self.sme.randomize_smiles(sm)
class Seq2seqDataset(Dataset):
    """Map-style dataset yielding fixed-length token-index tensors for SMILES.

    Each item is ``[sos] + token_ids + [eos]`` padded with ``pad_index`` to
    exactly ``seq_len`` entries, so PyTorch's default collate can stack items
    into a batch.
    """

    def __init__(self, smiles, vocab, seq_len=220, transform=None):
        """
        Args:
            smiles: Indexable sequence of inputs accepted by ``transform``
                (SMILES strings for the default transform).
            vocab: Object exposing ``stoi`` (token -> index mapping supporting
                ``.get``), ``unk_index``, ``sos_index``, ``eos_index`` and
                ``pad_index``.
            seq_len: Length of every returned tensor, SOS and EOS included.
            transform: Callable mapping one entry of ``smiles`` to a list of
                tokens. ``None`` (the default) builds a new ``Randomizer``.

        Raises:
            ValueError: If ``seq_len`` is too small to hold SOS and EOS.
        """
        if seq_len < 2:
            raise ValueError(
                f"seq_len must be >= 2 to fit SOS and EOS, got {seq_len}")
        self.smiles = smiles
        self.vocab = vocab
        self.seq_len = seq_len
        # Built here rather than as a default argument so each dataset gets
        # its own instance and nothing is constructed at import time.
        self.transform = Randomizer() if transform is None else transform

    def __len__(self):
        return len(self.smiles)

    def __getitem__(self, item):
        """Return the padded 1-D index tensor for ``self.smiles[item]``."""
        tokens = self.transform(self.smiles[item])  # List
        content = [self.vocab.stoi.get(token, self.vocab.unk_index)
                   for token in tokens]
        # Truncate so SOS + content + EOS never exceeds seq_len; otherwise an
        # overlong item yields a longer tensor and batch stacking fails.
        content = content[:self.seq_len - 2]
        X = [self.vocab.sos_index] + content + [self.vocab.eos_index]
        X.extend([self.vocab.pad_index] * (self.seq_len - len(X)))
        return torch.tensor(X)
|