# FRESCO / src/EGNet/solver.py
# Uploaded by SingleZombie ("upload files"), commit ff715ca, 9.46 kB.
import torch
from collections import OrderedDict
from torch.nn import utils, functional as F
from torch.optim import Adam, SGD
from torch.autograd import Variable
from torch.backends import cudnn
from model import build_model, weights_init
import scipy.misc as sm
import numpy as np
import os
import torchvision.utils as vutils
import cv2
import torch.nn.functional as F
import math
import time
import sys
import PIL.Image
import scipy.io
import os
import logging
# Small numerical constant (module level; not read by the code visible in this file).
EPSILON = 1e-8

# Optimizer hyper-parameters, looked up by key (p['lr_bone'], p['wd'], ...) in Solver.
p = OrderedDict()
from dataset import get_loader

# Backbone name passed to build_model(); Solver.build_model also handles 'vgg'.
base_model_cfg = 'resnet'
p.update([
    ('lr_bone', 5e-5),     # backbone learning rate (resnet: 5e-5, vgg: 2e-5)
    ('lr_branch', 0.025),  # learning rate for the 'loss_weight' child in Solver.get_params
    ('wd', 0.0005),        # weight decay passed to Adam
    ('momentum', 0.90),    # momentum; not passed to the Adam optimizer built by Solver
])

# Epoch indices after which Solver.train multiplies lr_bone by 0.1.
lr_decay_epoch = [15, 24]
# Gradient accumulation: one optimizer step per nAveGrad forward/backward passes.
nAveGrad = 10
# Print running losses every showEvery iterations.
showEvery = 50
# Directory where Solver.train dumps intermediate prediction/input/target images.
tmp_path = 'tmp_see'
class Solver(object):
    """Train / evaluate the network built by ``build_model`` (``self.net_bone``).

    The network is called as ``up_edge, up_sal, up_sal_f = self.net_bone(images)``
    and each of the three results is iterated as a sequence of logit maps.
    """

    def __init__(self, train_loader, test_loader, config, save_fold=None):
        """Store loaders/config, build the network and load requested weights.

        Args:
            train_loader: iterable of dicts with 'sal_image', 'sal_label' and
                'sal_edge' entries (consumed by ``train``).
            test_loader: iterable of dicts with 'image' and 'name' entries
                (consumed by ``test``).
            config: options object; this class reads e.g. ``visdom``,
                ``pre_trained``, ``mode``, ``save_fold``, ``model``, ``cuda``,
                ``load_bone``, ``batch_size``, ``epoch`` and ``epoch_save``.
            save_fold: output directory used by ``test``.
        """
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.config = config
        self.save_fold = save_fold
        # Per-channel mean scaled to [0, 1]; not read elsewhere in this class.
        self.mean = torch.Tensor([123.68, 116.779, 103.939]).view(3, 1, 1) / 255.
        # inference: choose the side map (see paper)
        if config.visdom:
            # NOTE(review): Viz_visdom is not imported in this module; this
            # branch raises NameError unless the name is provided elsewhere.
            self.visual = Viz_visdom("trueUnify", 1)
        self.build_model()
        if self.config.pre_trained:
            # The network lives in self.net_bone (there is no self.net attribute).
            self.net_bone.load_state_dict(torch.load(self.config.pre_trained))
        if config.mode == 'train':
            os.makedirs(os.path.join(config.save_fold, 'logs'), exist_ok=True)
            # NOTE(review): this handle is never written to or closed in this class.
            self.log_output = open("%s/logs/log.txt" % config.save_fold, 'w')
        else:
            print('Loading pre-trained model from %s...' % self.config.model)
            self.net_bone.load_state_dict(torch.load(self.config.model))
            self.net_bone.eval()

    def print_network(self, model, name):
        """Print *name*, the module structure of *model* and its parameter count."""
        num_params = sum(param.numel() for param in model.parameters())
        print(name)
        print(model)
        print("The number of parameters: {}".format(num_params))

    def get_params(self, base_lr):
        """Return optimizer param groups, one per direct child of net_bone.

        The child named 'loss_weight' gets learning rate p['lr_branch']; all
        other groups use the optimizer default. *base_lr* is not used.
        """
        ml = []
        for name, module in self.net_bone.named_children():
            print(name)
            if name == 'loss_weight':
                ml.append({'params': module.parameters(), 'lr': p['lr_branch']})
            else:
                ml.append({'params': module.parameters()})
        return ml

    def _make_bone_optimizer(self):
        """Create an Adam optimizer over trainable net_bone params at self.lr_bone."""
        return Adam(filter(lambda param: param.requires_grad, self.net_bone.parameters()),
                    lr=self.lr_bone, weight_decay=p['wd'])

    # build the network
    def build_model(self):
        """Create self.net_bone, initialise its weights and create its optimizer."""
        self.net_bone = build_model(base_model_cfg)
        if self.config.cuda:
            self.net_bone = self.net_bone.cuda()
        # eval() makes layers such as BatchNorm use running statistics; train()
        # below never switches the network back with .train().
        self.net_bone.eval()  # use_global_stats = True
        self.net_bone.apply(weights_init)
        if self.config.mode == 'train':
            if self.config.load_bone == '':
                # No full checkpoint given: initialise only the backbone.
                if base_model_cfg == 'vgg':
                    self.net_bone.base.load_pretrained_model(torch.load(self.config.vgg))
                elif base_model_cfg == 'resnet':
                    self.net_bone.base.load_state_dict(torch.load(self.config.resnet))
            else:
                self.net_bone.load_state_dict(torch.load(self.config.load_bone))
        self.lr_bone = p['lr_bone']
        self.lr_branch = p['lr_branch']
        self.optimizer_bone = self._make_bone_optimizer()
        self.print_network(self.net_bone, 'trueUnify bone part')

    # update the learning rate
    def update_lr(self, rate):
        """Multiply the learning rate of every optimizer param group by *rate*.

        self.lr_bone is scaled as well, so the printed learning rate and the
        optimizer rebuilt at the lr-decay epochs in ``train`` stay consistent.
        """
        for param_group in self.optimizer_bone.param_groups:
            param_group['lr'] = param_group['lr'] * rate
        self.lr_bone = self.lr_bone * rate

    def test(self, test_mode=0):
        """Run net_bone over test_loader and write saliency maps as PNG files.

        For every batch, the last element of the third network output is
        passed through a sigmoid, scaled by 255 and written to
        ``<save_fold>/EGNet_ResNet50/<image stem>.png`` (only the first name
        of each batch is used).

        Args:
            test_mode: accepted for interface compatibility; not used.
        """
        time_t = 0.0
        name_t = 'EGNet_ResNet50/'
        # Loop-invariant; previously re-assigned and printed for every image.
        self.config.test_fold = self.save_fold
        print(self.config.test_fold)
        os.makedirs(os.path.join(self.save_fold, name_t), exist_ok=True)
        for data_batch in self.test_loader:
            images, name = data_batch['image'], data_batch['name'][0]
            with torch.no_grad():
                if self.config.cuda:
                    images = images.cuda()
                print(images.size())
                time_start = time.time()
                _, _, up_sal_f = self.net_bone(images)
                if self.config.cuda:
                    # Wait for queued GPU work so the timing is meaningful;
                    # calling this without CUDA raises, hence the guard.
                    torch.cuda.synchronize()
                time_end = time.time()
                print(time_end - time_start)
                time_t += time_end - time_start
                pred = np.squeeze(torch.sigmoid(up_sal_f[-1]).cpu().numpy())
                multi_fuse = 255 * pred
                # splitext handles extensions of any length (was name[:-4]).
                stem = os.path.splitext(name)[0]
                cv2.imwrite(os.path.join(self.config.test_fold, name_t, stem + '.png'), multi_fuse)
        print("--- %s seconds ---" % (time_t))
        print('Test Done!')

    # training phase
    def train(self):
        """Train net_bone with edge + saliency BCE losses and gradient accumulation.

        Losses are summed over pixels and divided by nAveGrad * batch_size.
        The optimizer steps once every nAveGrad batches. Checkpoints are saved
        to ``<config.save_fold>/models/`` every ``config.epoch_save`` epochs
        and at the end. lr_bone is multiplied by 0.1 after each epoch listed in
        lr_decay_epoch, and a fresh Adam optimizer is created.
        """
        iter_num = len(self.train_loader.dataset) // self.config.batch_size
        loss_scale = nAveGrad * self.config.batch_size
        os.makedirs(tmp_path, exist_ok=True)
        os.makedirs('%s/models' % self.config.save_fold, exist_ok=True)
        for epoch in range(self.config.epoch):
            r_edge_loss, r_sal_loss, r_sum_loss = 0, 0, 0
            # Discard any partially accumulated gradient and restart the
            # accumulation counter together with it.
            self.net_bone.zero_grad()
            aveGrad = 0
            for i, data_batch in enumerate(self.train_loader):
                sal_image, sal_label, sal_edge = data_batch['sal_image'], data_batch['sal_label'], data_batch['sal_edge']
                if sal_image.size()[2:] != sal_label.size()[2:]:
                    print("Skip this batch")
                    continue
                if self.config.cuda:
                    sal_image, sal_label, sal_edge = sal_image.cuda(), sal_label.cuda(), sal_edge.cuda()
                up_edge, up_sal, up_sal_f = self.net_bone(sal_image)

                # edge part: class-balanced BCE on every edge prediction
                edge_loss = sum(bce2d_new(ix, sal_edge, reduction='sum') for ix in up_edge) / loss_scale
                r_edge_loss += edge_loss.data

                # sal part: plain BCE on both groups of saliency predictions
                sal_loss1 = [F.binary_cross_entropy_with_logits(ix, sal_label, reduction='sum') for ix in up_sal]
                sal_loss2 = [F.binary_cross_entropy_with_logits(ix, sal_label, reduction='sum') for ix in up_sal_f]
                sal_loss = (sum(sal_loss1) + sum(sal_loss2)) / loss_scale
                r_sal_loss += sal_loss.data

                loss = sal_loss + edge_loss
                r_sum_loss += loss.data
                loss.backward()

                aveGrad += 1
                if aveGrad % nAveGrad == 0:
                    self.optimizer_bone.step()
                    self.optimizer_bone.zero_grad()
                    aveGrad = 0

                if i % showEvery == 0:
                    print('epoch: [%2d/%2d], iter: [%5d/%5d] || Edge : %10.4f || Sal : %10.4f || Sum : %10.4f' % (
                        epoch, self.config.epoch, i, iter_num,
                        r_edge_loss * loss_scale / showEvery,
                        r_sal_loss * loss_scale / showEvery,
                        r_sum_loss * loss_scale / showEvery))
                    print('Learning rate: ' + str(self.lr_bone))
                    r_edge_loss, r_sal_loss, r_sum_loss = 0, 0, 0

                if i % 200 == 0:
                    vutils.save_image(torch.sigmoid(up_sal_f[-1].data), tmp_path + '/iter%d-sal-0.jpg' % i, normalize=True, padding=0)
                    vutils.save_image(sal_image.data, tmp_path + '/iter%d-sal-data.jpg' % i, padding=0)
                    vutils.save_image(sal_label.data, tmp_path + '/iter%d-sal-target.jpg' % i, padding=0)

            if (epoch + 1) % self.config.epoch_save == 0:
                torch.save(self.net_bone.state_dict(), '%s/models/epoch_%d_bone.pth' % (self.config.save_fold, epoch + 1))

            if epoch in lr_decay_epoch:
                self.lr_bone = self.lr_bone * 0.1
                self.optimizer_bone = self._make_bone_optimizer()

        torch.save(self.net_bone.state_dict(), '%s/models/final_bone.pth' % self.config.save_fold)
def bce2d_new(input, target, reduction=None):
    """Class-balanced binary cross-entropy with logits (used for edge maps).

    Pixel weights are computed over the whole tensor, with all batch items
    counted together:
      * target == 1 -> weight = (#pixels with target == 0) / N
      * target == 0 -> weight = 1.1 * (#pixels with target == 1) / N
      * any other target value -> weight 0
    where N is the number of pixels whose target is exactly 0 or 1.

    Args:
        input: logits, same size as ``target``.
        target: target map, same size as ``input``.
        reduction: 'none', 'mean' or 'sum', forwarded to
            ``F.binary_cross_entropy_with_logits``. ``None`` (the default)
            means 'mean'; PyTorch itself rejects a ``None`` reduction.

    Returns:
        The weighted loss, reduced as requested.
    """
    assert input.size() == target.size(), (
        'input size %s does not match target size %s' % (tuple(input.size()), tuple(target.size())))
    if reduction is None:
        reduction = 'mean'
    pos = torch.eq(target, 1).float()
    neg = torch.eq(target, 0).float()
    num_pos = torch.sum(pos)
    num_neg = torch.sum(neg)
    # Counts are whole numbers, so clamping at 1 only affects the degenerate
    # case with no exact 0/1 targets: all weights become 0 instead of NaN (0/0).
    num_total = torch.clamp(num_pos + num_neg, min=1.0)
    alpha = num_neg / num_total         # weight for positive pixels (target == 1)
    beta = 1.1 * num_pos / num_total    # weight for negative pixels (target == 0)
    weights = alpha * pos + beta * neg
    return F.binary_cross_entropy_with_logits(input, target, weights, reduction=reduction)