#!/usr/bin/env python
# coding: utf-8
from __future__ import absolute_import, division, print_function

import os
import pickle

import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms
from tensorboardX import SummaryWriter

from utils.dataset import GraphDataset
from utils.lr_scheduler import LR_Scheduler
from helper import Trainer, Evaluator, collate
from option import Options
from models.GraphTransformer import Classifier
from models.weight_init import weight_init

args = Options().parse()
with open(os.path.join(args.dataset_metadata_path, 'label_map.pkl'), 'rb') as f:
    label_map = pickle.load(f)
n_class = len(label_map)
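
# cuDNN determinism trades some speed for reproducible runs; the CUDA guard
# below is an added safety check so the script also starts on CPU-only hosts.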
if torch.cuda.is_available():
    torch.cuda.synchronize()
torch.backends.cudnn.deterministic = True

data_path = args.data_path
model_path = args.model_path
os.makedirs(model_path, exist_ok=True)
log_path = args.log_path
os.makedirs(log_path, exist_ok=True)
task_name = args.task_name
print(task_name)

###################################
train = args.train
test = args.test
graphcam = args.graphcam
print("train:", train, "test:", test, "graphcam:", graphcam)

##### Load datasets
print("preparing datasets and dataloaders......")
batch_size = args.batch_size
if train:
    ids_train = open(args.train_set).readlines()
    dataset_train = GraphDataset(os.path.join(data_path, ""), ids_train, args.dataset_metadata_path)
    dataloader_train = torch.utils.data.DataLoader(dataset=dataset_train, batch_size=batch_size, num_workers=10, collate_fn=collate, shuffle=True, pin_memory=True, drop_last=True)
    total_train_num = len(dataloader_train) * batch_size

# The validation loader lives outside `if train:` because it is also needed in test mode.
ids_val = open(args.val_set).readlines()
dataset_val = GraphDataset(os.path.join(data_path, ""), ids_val, args.dataset_metadata_path)
dataloader_val = torch.utils.data.DataLoader(dataset=dataset_val, batch_size=batch_size, num_workers=10, collate_fn=collate, shuffle=False, pin_memory=True)
total_val_num = len(dataloader_val) * batch_size
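# NOTE: these totals assume full batches; the val loader has no drop_last, so
# total_val_num can overshoot by up to batch_size - 1. Both counts are only
# used for progress printing below.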

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

##### creating models #############
print("creating models......")
num_epochs = args.num_epochs
learning_rate = args.lr
model = Classifier(n_class)
model = nn.DataParallel(model)
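# DataParallel prefixes state_dict keys with 'module.', so checkpoints saved
# below must be loaded back into the wrapped model, as done here on resume.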
if args.resume:
    print('load model {}'.format(args.resume))
    model.load_state_dict(torch.load(args.resume))
if torch.cuda.is_available():
    model = model.cuda()
# model.apply(weight_init)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4)  # best: 5e-4, 4e-3
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20, 100], gamma=0.1)  # gamma=0.3 # 30,90,130 # 20,90,130 -> 150
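# MultiStepLR multiplies the LR by gamma=0.1 at epochs 20 and 100; it is
# stepped once per epoch at the end of the training loop below.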
##################################

criterion = nn.CrossEntropyLoss()
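# NOTE: the loss optimized below is the one returned by trainer.train();
# whether it is this criterion depends on the Trainer helper.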

if not test:
    writer = SummaryWriter(log_dir=os.path.join(log_path, task_name))
    f_log = open(os.path.join(log_path, task_name + ".log"), 'w')

trainer = Trainer(n_class)
evaluator = Evaluator(n_class)
best_pred = 0.0

for epoch in range(num_epochs):
    # optimizer.zero_grad()
    model.train()
    train_loss = 0.
    total = 0.

    current_lr = optimizer.param_groups[0]['lr']
    print('\n=> Epoch %i, learning rate = %.7f, previous best = %.4f' % (epoch + 1, current_lr, best_pred))

    if train:
        for i_batch, sample_batched in enumerate(dataloader_train):
            preds, labels, loss = trainer.train(sample_batched, model)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()  # .item() detaches; accumulating the tensor would hold the graph in memory
            total += len(labels)
            trainer.metrics.update(labels, preds)
            if (i_batch + 1) % args.log_interval_local == 0:
                print("[%d/%d] train loss: %.3f; agg acc: %.3f" % (total, total_train_num, train_loss / total, trainer.get_scores()))
                trainer.plot_cm()
        # step the LR schedule once per epoch, not once per batch
        scheduler.step()

    if not test:
        print("[%d/%d] train loss: %.3f; agg acc: %.3f" % (total_train_num, total_train_num, train_loss / total, trainer.get_scores()))
        trainer.plot_cm()

    if epoch % 1 == 0:  # validate after every epoch
        with torch.no_grad():
            model.eval()
            print("evaluating...")

            total = 0.
            batch_idx = 0
            for i_batch, sample_batched in enumerate(dataloader_val):
                preds, labels, _ = evaluator.eval_test(sample_batched, model, graphcam)

                total += len(labels)
                evaluator.metrics.update(labels, preds)
                if (i_batch + 1) % args.log_interval_local == 0:
                    print('[%d/%d] val agg acc: %.3f' % (total, total_val_num, evaluator.get_scores()))
                    evaluator.plot_cm()

            print('[%d/%d] val agg acc: %.3f' % (total_val_num, total_val_num, evaluator.get_scores()))
            evaluator.plot_cm()
            # torch.cuda.empty_cache()

            val_acc = evaluator.get_scores()
            if val_acc > best_pred:
                best_pred = val_acc
                if not test:
                    print("saving model...")
                    torch.save(model.state_dict(), os.path.join(model_path, task_name + ".pth"))

            log = ""
            log += 'epoch [{}/{}] ------ acc: train = {:.4f}, val = {:.4f}'.format(epoch + 1, num_epochs, trainer.get_scores(), evaluator.get_scores()) + "\n"
            log += "================================\n"
            print(log)

            if test:
                break

            f_log.write(log)
            f_log.flush()

            writer.add_scalars('accuracy', {'train acc': trainer.get_scores(), 'val acc': evaluator.get_scores()}, epoch + 1)

    trainer.reset_metrics()
    evaluator.reset_metrics()

if not test:
    f_log.close()