import argparse
import math
import os

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from build_vocab import WordVocab
from dataset import Seq2seqDataset
# Special token indices; these should match the ids used by build_vocab.WordVocab.
PAD = 0
UNK = 1
EOS = 2
SOS = 3
MASK = 4

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for sequence-first (T, B, H) inputs."""

    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)  # (max_len, H)
        position = torch.arange(0., max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0., d_model, 2) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(1)  # (max_len, 1, H) so it broadcasts over the batch dimension
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (T, B, H); add the encoding for the first T positions along the time axis.
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
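
# For reference, the buffer built above is the standard sinusoidal scheme:
#   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
#   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
# where div_term = exp(-(2i) * log(10000) / d_model) = 10000^(-2i / d_model).
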
class TrfmSeq2seq(nn.Module):
    def __init__(self, in_size, hidden_size, out_size, n_layers, dropout=0.1):
        super(TrfmSeq2seq, self).__init__()
        self.in_size = in_size
        self.hidden_size = hidden_size
        self.embed = nn.Embedding(in_size, hidden_size)
        self.pe = PositionalEncoding(hidden_size, dropout)
        # Note: the number of attention heads is fixed at 4 here; the --n_head
        # command-line flag is not wired through to this constructor.
        self.trfm = nn.Transformer(d_model=hidden_size, nhead=4,
                                   num_encoder_layers=n_layers, num_decoder_layers=n_layers,
                                   dim_feedforward=hidden_size)
        self.out = nn.Linear(hidden_size, out_size)

    def forward(self, src):
        # src: (T, B) token ids
        embedded = self.embed(src)    # (T, B, H)
        embedded = self.pe(embedded)  # (T, B, H)
        # The same sequence is fed as both encoder source and decoder target.
        hidden = self.trfm(embedded, embedded)  # (T, B, H)
        out = self.out(hidden)                  # (T, B, V)
        out = F.log_softmax(out, dim=2)         # (T, B, V)
        return out

    def _encode(self, src):
        # src: (T, B) token ids
        embedded = self.embed(src)    # (T, B, H)
        embedded = self.pe(embedded)  # (T, B, H)
        output = embedded
        # Run all but the last encoder layer, keeping the penultimate activation.
        for i in range(self.trfm.encoder.num_layers - 1):
            output = self.trfm.encoder.layers[i](output, None)  # (T, B, H)
        penul = output.detach().cpu().numpy()
        output = self.trfm.encoder.layers[-1](output, None)     # (T, B, H)
        if self.trfm.encoder.norm is not None:
            output = self.trfm.encoder.norm(output)  # (T, B, H)
        output = output.detach().cpu().numpy()
        # Concatenate mean pooling, max pooling, the first token of the last layer,
        # and the first token of the penultimate layer.
        return np.hstack([np.mean(output, axis=0), np.max(output, axis=0),
                          output[0, :, :], penul[0, :, :]])  # (B, 4H)

    def encode(self, src):
        # src: (T, B) token ids
        batch_size = src.shape[1]
        if batch_size <= 100:
            return self._encode(src)
        else:
            # The batch is too large to encode at once; process 100 molecules at a time.
            print('There are {:d} molecules. It will take a little time.'.format(batch_size))
            st, ed = 0, 100
            out = self._encode(src[:, st:ed])  # (chunk, 4H)
            while ed < batch_size:
                st += 100
                ed += 100
                out = np.concatenate([out, self._encode(src[:, st:ed])], axis=0)
            return out
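
# Usage sketch (illustrative, not executed by this script): embedding molecules with a
# trained TrfmSeq2seq. `seqs` is assumed to be a (T, B) LongTensor of token ids built
# with the same WordVocab, and the checkpoint path is a placeholder.
#
#   vocab = WordVocab.load_vocab('data/vocab.pkl')
#   trfm = TrfmSeq2seq(len(vocab), 256, len(vocab), 4)
#   trfm.load_state_dict(torch.load('.save/trfm_new_1_0.pkl', map_location='cpu'))
#   trfm.eval()
#   fingerprints = trfm.encode(seqs)  # numpy array of shape (B, 4 * hidden)
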

def parse_arguments():
    parser = argparse.ArgumentParser(description='Hyperparams')
    parser.add_argument('--n_epoch', '-e', type=int, default=5, help='number of epochs')
    parser.add_argument('--vocab', '-v', type=str, default='data/vocab.pkl', help='vocabulary (.pkl)')
    parser.add_argument('--data', '-d', type=str, default='data/chembl_25.csv', help='train corpus (.csv)')
    parser.add_argument('--out-dir', '-o', type=str, default='../result', help='output directory')
    parser.add_argument('--name', '-n', type=str, default='ST', help='model name')
    parser.add_argument('--seq_len', type=int, default=220, help='maximum length of the paired sequence')
    parser.add_argument('--batch_size', '-b', type=int, default=8, help='batch size')
    parser.add_argument('--n_worker', '-w', type=int, default=16, help='number of workers')
    parser.add_argument('--hidden', type=int, default=256, help='size of the hidden vector')
    parser.add_argument('--n_layer', '-l', type=int, default=4, help='number of layers')
    parser.add_argument('--n_head', type=int, default=4, help='number of attention heads')
    parser.add_argument('--lr', type=float, default=1e-4, help='Adam learning rate')
    parser.add_argument('--gpu', metavar='N', type=int, nargs='+', help='list of GPU IDs to use')
    return parser.parse_args()

def evaluate(model, test_loader, vocab):
    model.eval()
    total_loss = 0
    for b, sm in enumerate(test_loader):
        sm = torch.t(sm.cuda())  # (T, B)
        with torch.no_grad():
            output = model(sm)  # (T, B, V)
        loss = F.nll_loss(output.view(-1, len(vocab)),
                          sm.contiguous().view(-1),
                          ignore_index=PAD)
        total_loss += loss.item()
    return total_loss / len(test_loader)

def main():
    args = parse_arguments()
    assert torch.cuda.is_available()

    print('Loading dataset...')
    vocab = WordVocab.load_vocab(args.vocab)
    dataset = Seq2seqDataset(pd.read_csv(args.data)['canonical_smiles'].values, vocab)
    test_size = 10000
    train, test = torch.utils.data.random_split(dataset, [len(dataset) - test_size, test_size])
    train_loader = DataLoader(train, batch_size=args.batch_size, shuffle=True, num_workers=args.n_worker)
    test_loader = DataLoader(test, batch_size=args.batch_size, shuffle=False, num_workers=args.n_worker)
    print('Train size:', len(train))
    print('Test size:', len(test))
    del dataset, train, test

    model = TrfmSeq2seq(len(vocab), args.hidden, len(vocab), args.n_layer).cuda()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    print(model)
    print('Total parameters:', sum(p.numel() for p in model.parameters()))

    best_loss = None
    for e in range(1, args.n_epoch + 1):
        for b, sm in enumerate(tqdm(train_loader)):
            sm = torch.t(sm.cuda())  # (T, B)
            optimizer.zero_grad()
            output = model(sm)  # (T, B, V)
            loss = F.nll_loss(output.view(-1, len(vocab)),
                              sm.contiguous().view(-1), ignore_index=PAD)
            loss.backward()
            optimizer.step()
            if b % 1000 == 0:
                print('Train {:3d}: iter {:5d} | loss {:.3f} | ppl {:.3f}'.format(e, b, loss.item(), math.exp(loss.item())))
            if b % 10000 == 0:
                loss = evaluate(model, test_loader, vocab)
                print('Val {:3d}: iter {:5d} | loss {:.3f} | ppl {:.3f}'.format(e, b, loss, math.exp(loss)))
                # Save the model if the validation loss is the best we've seen so far.
                if not best_loss or loss < best_loss:
                    print("[!] saving model...")
                    if not os.path.isdir(".save"):
                        os.makedirs(".save")
                    torch.save(model.state_dict(), './.save/trfm_new_%d_%d.pkl' % (e, b))
                    best_loss = loss
                model.train()  # evaluate() switched to eval mode; switch back for training

if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt as e:
        print("[STOP]", e)