Spaces:
Runtime error
Runtime error
import random | |
import pandas as pd | |
import torch | |
from torch.utils.data import Dataset, DataLoader | |
from enumerator import SmilesEnumerator | |
from utils import split | |
# Padding token index. Not referenced by the code visible in this file, which
# pads with ``vocab.pad_index`` instead -- presumably used elsewhere; confirm.
PAD = 0
# Token budget: Randomizer only keeps a randomized token list when it has at
# most MAX_LEN - 2 tokens (the two spare slots presumably hold SOS/EOS).
MAX_LEN = 220
class Randomizer:
    """Callable transform that tokenizes a randomized variant of a SMILES string.

    Wraps a ``SmilesEnumerator`` and the project's ``split`` helper so that an
    instance can be used directly as a dataset ``transform``.
    """

    def __init__(self):
        self.sme = SmilesEnumerator()

    def __call__(self, sm):
        """Return the token list for a randomized form of ``sm``.

        The randomized SMILES is used only when randomization succeeded (did
        not return ``None``) and its token count is at most ``MAX_LEN - 2``.
        In every other case the tokens of the original ``sm`` are returned.

        Args:
            sm: A SMILES string.

        Returns:
            A list of token strings, obtained by whitespace-splitting the
            output of ``split``.

        Note:
            The fallback tokenization of ``sm`` is returned without any length
            check, so it may itself be longer than ``MAX_LEN - 2`` tokens.
        """
        sm_r = self.random_transform(sm)  # Random transform
        if sm_r is not None:
            tokens = split(sm_r).split()  # Spacing, then tokenizing
            if len(tokens) <= MAX_LEN - 2:
                return tokens
        # Randomization failed or produced too many tokens: use the original.
        return split(sm).split()

    def random_transform(self, sm):
        """Return a randomized SMILES for ``sm``. It may take some time.

        Args:
            sm: A SMILES string.

        Returns:
            Whatever ``SmilesEnumerator.randomize_smiles`` returns for ``sm``;
            ``__call__`` treats a ``None`` result as "no randomized form".
        """
        return self.sme.randomize_smiles(sm)
class Seq2seqDataset(Dataset):
    """Dataset that turns SMILES strings into fixed-length tensors of token ids.

    Args:
        smiles: Indexable, sized collection of SMILES strings.
        vocab: Object exposing ``stoi`` (token -> id mapping) and the
            ``unk_index``, ``sos_index``, ``eos_index`` and ``pad_index`` ids.
        seq_len: Length of every returned sequence, including the SOS and EOS
            markers and any padding.
        transform: Callable mapping a SMILES string to an iterable of tokens.
            When ``None`` (the default) a fresh ``Randomizer`` is created.
    """

    def __init__(self, smiles, vocab, seq_len=220, transform=None):
        self.smiles = smiles
        self.vocab = vocab
        self.seq_len = seq_len
        # Build the default per instance: a ``Randomizer()`` default argument
        # would be constructed once at class-definition time and shared by
        # every dataset.
        self.transform = Randomizer() if transform is None else transform

    def __len__(self):
        """Return the number of SMILES strings in the dataset."""
        return len(self.smiles)

    def __getitem__(self, item):
        """Return the encoded sequence for the SMILES at index ``item``.

        The layout is ``[sos] + token ids + [eos] + padding``. Tokens missing
        from ``vocab.stoi`` map to ``vocab.unk_index``.

        Returns:
            A 1-D ``torch.Tensor`` of token ids. For ``seq_len >= 2`` its
            length is always exactly ``seq_len``.
        """
        tokens = self.transform(self.smiles[item])  # Iterable of tokens
        stoi = self.vocab.stoi
        unk = self.vocab.unk_index
        # Reserve two slots for SOS/EOS and drop surplus tokens. Without this,
        # an over-long token list yields a tensor longer than seq_len, and
        # tensors of unequal length cannot be stacked into a batch.
        budget = max(self.seq_len - 2, 0)
        content = [stoi.get(token, unk) for token in tokens][:budget]
        X = [self.vocab.sos_index] + content + [self.vocab.eos_index]
        X.extend([self.vocab.pad_index] * (self.seq_len - len(X)))
        return torch.tensor(X)