import logging
import math
import os
import sys
import time
from collections import OrderedDict

import cv2
import numpy as np
import PIL.Image
import scipy.io
import scipy.misc as sm
import torch
import torch.nn.functional as F
import torchvision.utils as vutils
from torch.autograd import Variable
from torch.backends import cudnn
from torch.nn import utils
from torch.optim import Adam, SGD

from dataset import get_loader
from model import build_model, weights_init

EPSILON = 1e-8

# Hyper-parameters shared by the Solver.
p = OrderedDict()
base_model_cfg = 'resnet'
p['lr_bone'] = 5e-5  # Learning rate resnet:5e-5, vgg:2e-5
p['lr_branch'] = 0.025  # Learning rate of the 'loss_weight' child (see get_params)
p['wd'] = 0.0005  # Weight decay
p['momentum'] = 0.90  # Momentum (not used by the Adam optimizer below)
lr_decay_epoch = [15, 24]  # epochs after which the backbone lr is multiplied by 0.1
nAveGrad = 10  # Update the weights once in 'nAveGrad' forward passes
showEvery = 50  # print running losses every 'showEvery' iterations
tmp_path = 'tmp_see'  # directory for periodic training snapshots (images)


class Solver(object):
    """Owns the network, its optimizer, and the train / test loops.

    Args:
        train_loader: iterable of dicts with 'sal_image', 'sal_label' and
            'sal_edge' tensors (used by ``train``).
        test_loader: iterable of dicts with 'image' and 'name' entries
            (used by ``test``).
        config: options object; the attributes read here include ``cuda``,
            ``mode``, ``visdom``, ``pre_trained``, ``model``, ``load_bone``,
            ``vgg``/``resnet``, ``epoch``, ``epoch_save``, ``batch_size`` and
            ``save_fold``.
        save_fold: output directory for test predictions.
    """

    def __init__(self, train_loader, test_loader, config, save_fold=None):
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.config = config
        self.save_fold = save_fold
        # Per-channel mean rescaled to [0, 1]; presumably the ImageNet RGB mean
        # (TODO confirm). Not referenced anywhere else in this module.
        self.mean = torch.Tensor([123.68, 116.779, 103.939]).view(3, 1, 1) / 255.
        if config.visdom:
            # NOTE(review): Viz_visdom is not imported by this module, so this
            # branch raises NameError unless the name is provided elsewhere.
            self.visual = Viz_visdom("trueUnify", 1)
        self.build_model()
        if self.config.pre_trained:
            # The only network built by build_model() is self.net_bone.
            self.net_bone.load_state_dict(torch.load(self.config.pre_trained))
        if config.mode == 'train':
            # NOTE(review): this handle is never closed or written in this module.
            self.log_output = open("%s/logs/log.txt" % config.save_fold, 'w')
        else:
            print('Loading pre-trained model from %s...' % self.config.model)
            self.net_bone.load_state_dict(torch.load(self.config.model))
            self.net_bone.eval()

    def print_network(self, model, name):
        """Print ``name``, the module structure, and its total parameter count."""
        num_params = sum(param.numel() for param in model.parameters())
        print(name)
        print(model)
        print("The number of parameters: {}".format(num_params))

    def get_params(self, base_lr):
        """Build optimizer parameter groups, one per direct child of net_bone.

        The child named 'loss_weight' gets learning rate ``p['lr_branch']``;
        every other group inherits the optimizer's default lr. ``base_lr`` is
        accepted for API compatibility but is not used.
        """
        ml = []
        for name, module in self.net_bone.named_children():
            print(name)
            if name == 'loss_weight':
                ml.append({'params': module.parameters(), 'lr': p['lr_branch']})
            else:
                ml.append({'params': module.parameters()})
        return ml

    def _build_optimizer(self):
        """Return a fresh Adam over the trainable parameters at ``self.lr_bone``."""
        trainable = (prm for prm in self.net_bone.parameters() if prm.requires_grad)
        return Adam(trainable, lr=self.lr_bone, weight_decay=p['wd'])

    # build the network
    def build_model(self):
        """Create ``self.net_bone``, load initial weights, and create its optimizer."""
        self.net_bone = build_model(base_model_cfg)
        if self.config.cuda:
            self.net_bone = self.net_bone.cuda()
        # The network is put in eval mode here and train() never switches it
        # back; the original note "use_global_stats = True" suggests this is
        # meant to freeze BatchNorm statistics (TODO confirm).
        self.net_bone.eval()
        self.net_bone.apply(weights_init)
        if self.config.mode == 'train':
            if self.config.load_bone == '':
                # No full checkpoint given: initialise only the backbone.
                if base_model_cfg == 'vgg':
                    self.net_bone.base.load_pretrained_model(torch.load(self.config.vgg))
                elif base_model_cfg == 'resnet':
                    self.net_bone.base.load_state_dict(torch.load(self.config.resnet))
            else:
                self.net_bone.load_state_dict(torch.load(self.config.load_bone))

        self.lr_bone = p['lr_bone']
        self.lr_branch = p['lr_branch']
        self.optimizer_bone = self._build_optimizer()
        self.print_network(self.net_bone, 'trueUnify bone part')

    # update the learning rate
    def update_lr(self, rate):
        """Multiply the learning rate of every optimizer param group by ``rate``."""
        for param_group in self.optimizer_bone.param_groups:
            param_group['lr'] = param_group['lr'] * rate
        # Keep the logged / re-used lr consistent with the optimizer.
        self.lr_bone = self.lr_bone * rate

    def test(self, test_mode=0):
        """Predict saliency maps for ``self.test_loader`` and save them as PNGs.

        Each prediction is written to ``<save_fold>/EGNet_ResNet50/<stem>.png``,
        where ``<stem>`` is the input name with its extension removed.
        ``test_mode`` is accepted for API compatibility but is not used.
        """
        name_t = 'EGNet_ResNet50/'
        out_dir = os.path.join(self.save_fold, name_t)
        os.makedirs(out_dir, exist_ok=True)
        self.config.test_fold = self.save_fold
        print(self.config.test_fold)

        time_t = 0.0
        for data_batch in self.test_loader:
            # Only the first name of the batch is used and the prediction is
            # squeezed, so a test batch size of 1 is presumably expected.
            images, name = data_batch['image'], data_batch['name'][0]
            with torch.no_grad():
                if self.config.cuda:
                    images = images.cuda()
                print(images.size())
                time_start = time.time()
                up_edge, up_sal, up_sal_f = self.net_bone(images)
                if self.config.cuda:
                    # Wait for queued GPU work so the timing covers the forward
                    # pass; synchronize() would fail on CPU-only runs.
                    torch.cuda.synchronize()
                elapsed = time.time() - time_start
                print(elapsed)
                time_t += elapsed
                # Use the last fused saliency output as the final prediction.
                pred = np.squeeze(torch.sigmoid(up_sal_f[-1]).cpu().numpy())
                multi_fuse = 255 * pred
                stem = os.path.splitext(name)[0]
                cv2.imwrite(os.path.join(out_dir, stem + '.png'), multi_fuse)
        print("--- %s seconds ---" % (time_t))
        print('Test Done!')

    # training phase
    def train(self):
        """Train ``self.net_bone`` with gradient accumulation over nAveGrad batches.

        Per batch, the loss is the summed class-balanced BCE over all edge
        outputs plus the summed BCE over all saliency outputs. It is divided by
        ``nAveGrad * batch_size`` so that one optimizer step sees a per-sample
        average. Checkpoints are written every ``config.epoch_save`` epochs, and
        the learning rate is decayed by 0.1 after each epoch in
        ``lr_decay_epoch``.
        """
        iter_num = len(self.train_loader.dataset) // self.config.batch_size
        loss_scale = nAveGrad * self.config.batch_size
        os.makedirs(tmp_path, exist_ok=True)
        for epoch in range(self.config.epoch):
            r_edge_loss, r_sal_loss, r_sum_loss = 0, 0, 0
            self.net_bone.zero_grad()
            # Gradients were just cleared, so the accumulation count restarts too.
            aveGrad = 0
            for i, data_batch in enumerate(self.train_loader):
                sal_image = data_batch['sal_image']
                sal_label = data_batch['sal_label']
                sal_edge = data_batch['sal_edge']
                if sal_image.size()[2:] != sal_label.size()[2:]:
                    print("Skip this batch")
                    continue
                if self.config.cuda:
                    sal_image, sal_label, sal_edge = sal_image.cuda(), sal_label.cuda(), sal_edge.cuda()

                up_edge, up_sal, up_sal_f = self.net_bone(sal_image)

                # edge part
                edge_loss = sum(bce2d_new(ix, sal_edge, reduction='sum') for ix in up_edge) / loss_scale
                r_edge_loss += edge_loss.detach()

                # sal part
                sal_loss1 = sum(F.binary_cross_entropy_with_logits(ix, sal_label, reduction='sum')
                                for ix in up_sal)
                sal_loss2 = sum(F.binary_cross_entropy_with_logits(ix, sal_label, reduction='sum')
                                for ix in up_sal_f)
                sal_loss = (sal_loss1 + sal_loss2) / loss_scale
                r_sal_loss += sal_loss.detach()

                loss = sal_loss + edge_loss
                r_sum_loss += loss.detach()

                loss.backward()
                aveGrad += 1
                if aveGrad % nAveGrad == 0:
                    self.optimizer_bone.step()
                    self.optimizer_bone.zero_grad()
                    aveGrad = 0

                if i % showEvery == 0:
                    # Undo the loss scaling to report per-batch sums.
                    print('epoch: [%2d/%2d], iter: [%5d/%5d]  ||  Edge : %10.4f  ||  Sal : %10.4f  ||  Sum : %10.4f' % (
                        epoch, self.config.epoch, i, iter_num,
                        float(r_edge_loss * loss_scale / showEvery),
                        float(r_sal_loss * loss_scale / showEvery),
                        float(r_sum_loss * loss_scale / showEvery)))
                    print('Learning rate: ' + str(self.lr_bone))
                    r_edge_loss, r_sal_loss, r_sum_loss = 0, 0, 0

                if i % 200 == 0:
                    vutils.save_image(torch.sigmoid(up_sal_f[-1].detach()), tmp_path + '/iter%d-sal-0.jpg' % i,
                                      normalize=True, padding=0)
                    vutils.save_image(sal_image.detach(), tmp_path + '/iter%d-sal-data.jpg' % i, padding=0)
                    vutils.save_image(sal_label.detach(), tmp_path + '/iter%d-sal-target.jpg' % i, padding=0)

            if (epoch + 1) % self.config.epoch_save == 0:
                torch.save(self.net_bone.state_dict(),
                           '%s/models/epoch_%d_bone.pth' % (self.config.save_fold, epoch + 1))

            if epoch in lr_decay_epoch:
                # A new Adam is created, so its moment estimates start from zero.
                self.lr_bone = self.lr_bone * 0.1
                self.optimizer_bone = self._build_optimizer()

        torch.save(self.net_bone.state_dict(), '%s/models/final_bone.pth' % self.config.save_fold)


def bce2d_new(input, target, reduction=None):
    """Class-balanced binary cross-entropy on logits.

    Per-pixel weights depend on the target value:
      * target == 1 -> num_neg / num_total
      * target == 0 -> 1.1 * num_pos / num_total
      * any other target value -> 0
    where num_pos / num_neg count the pixels equal to 1 / 0 and
    num_total = num_pos + num_neg.

    Args:
        input: logits, same size as ``target``.
        target: ground-truth map.
        reduction: 'none', 'mean' or 'sum'; ``None`` means 'mean' (the
            default of ``F.binary_cross_entropy_with_logits``).
    """
    if reduction is None:
        reduction = 'mean'
    assert input.size() == target.size()
    pos = torch.eq(target, 1).float()
    neg = torch.eq(target, 0).float()
    num_pos = torch.sum(pos)
    num_neg = torch.sum(neg)
    # Counts are whole numbers, so clamping at 1 only changes the degenerate
    # case with no 0/1 pixels, turning NaN weights (0/0) into zero weights.
    num_total = torch.clamp(num_pos + num_neg, min=1.0)
    alpha = num_neg / num_total
    beta = 1.1 * num_pos / num_total
    weights = alpha * pos + beta * neg
    return F.binary_cross_entropy_with_logits(input, target, weights, reduction=reduction)