# UGATIT/ugatit_test.py
import os
import time
from glob import glob

import numpy as np
import tensorflow as tf
from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch  # TF 1.x only; unused in this test-only script

from ugatit.ops import *
from ugatit.utils import *

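# Inference-only wrapper around UGATIT: builds the test graph, restores a
# trained checkpoint, and translates single images (selfie2anime by default).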
class UgatitTest:
    def __init__(self, sess, checkpoint_dir):
        self.light = False  # the light variant pools features before the MLP
        if self.light:
            self.model_name = 'UGATIT_light'
        else:
            self.model_name = 'UGATIT'

        self.sess = sess
        self.phase = 'test'
        self.checkpoint_dir = checkpoint_dir
        self.result_dir = 'results'
        self.log_dir = 'logs'
        self.dataset_name = 'selfie2anime'
        self.augment_flag = True
        self.epoch = 100
        self.iteration = 10000
        self.decay_flag = True
        self.decay_epoch = 50

        self.gan_type = 'lsgan'

        self.batch_size = 1
        self.print_freq = 1000
        self.save_freq = 1000

        self.init_lr = 0.0001
        self.ch = 64

        """ Weight """
        self.adv_weight = 1
        self.cycle_weight = 10
        self.identity_weight = 10
        self.cam_weight = 1000
        self.ld = 10
        self.smoothing = True

        """ Generator """
        self.n_res = 4

        """ Discriminator """
        self.n_dis = 6
        self.n_critic = 1
        self.sn = True

        self.img_size = 256
        self.img_ch = 3

        # NOTE: the settings above must match the training run, because
        # model_dir (and therefore the checkpoint path) is derived from them.

        # self.sample_dir = os.path.join('/home/hylee/cartoon/UGATIT/samples', self.model_dir)
        # check_folder(self.sample_dir)
        # self.trainA, self.trainB = prepare_data(dataset_name=self.dataset_name, size=self.img_size
        self.trainA_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainA'))
        self.trainB_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainB'))
        self.dataset_num = max(len(self.trainA_dataset), len(self.trainB_dataset))

        print()
        print("##### Information #####")
        print("# light : ", self.light)
        print("# gan type : ", self.gan_type)
        print("# dataset : ", self.dataset_name)
        print("# max dataset number : ", self.dataset_num)
        print("# batch_size : ", self.batch_size)
        print("# epoch : ", self.epoch)
        print("# iterations per epoch : ", self.iteration)
        print("# smoothing : ", self.smoothing)

        print()
        print("##### Generator #####")
        print("# residual blocks : ", self.n_res)

        print()
        print("##### Discriminator #####")
        print("# discriminator layers : ", self.n_dis)
        print("# number of critic iterations : ", self.n_critic)
        print("# spectral normalization : ", self.sn)

        print()
        print("##### Weight #####")
        print("# adv_weight : ", self.adv_weight)
        print("# cycle_weight : ", self.cycle_weight)
        print("# identity_weight : ", self.identity_weight)
        print("# cam_weight : ", self.cam_weight)

    ##################################################################################
    # Generator
    ##################################################################################
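    # Structure: a reflection-padded stem, two stride-2 downsampling convs, a
    # residual bottleneck, a CAM attention block (GAP and GMP branches sharing
    # the 'CAM_logit' weights), then AdaLIN residual blocks and two upsampling
    # stages. Returns the translated image, the CAM logits, and a heatmap.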
    def generator(self, x_init, reuse=False, scope="generator"):
        channel = self.ch
        with tf.variable_scope(scope, reuse=reuse):
            x = conv(x_init, channel, kernel=7, stride=1, pad=3, pad_type='reflect', scope='conv')
            x = instance_norm(x, scope='ins_norm')
            x = relu(x)

            # Down-Sampling
            for i in range(2):
                x = conv(x, channel*2, kernel=3, stride=2, pad=1, pad_type='reflect', scope='conv_'+str(i))
                x = instance_norm(x, scope='ins_norm_'+str(i))
                x = relu(x)
                channel = channel * 2

            # Down-Sampling Bottleneck
            for i in range(self.n_res):
                x = resblock(x, channel, scope='resblock_' + str(i))

            # Class Activation Map
            cam_x = global_avg_pooling(x)
            cam_gap_logit, cam_x_weight = fully_connected_with_w(cam_x, scope='CAM_logit')
            x_gap = tf.multiply(x, cam_x_weight)

            cam_x = global_max_pooling(x)
            cam_gmp_logit, cam_x_weight = fully_connected_with_w(cam_x, reuse=True, scope='CAM_logit')
            x_gmp = tf.multiply(x, cam_x_weight)

            cam_logit = tf.concat([cam_gap_logit, cam_gmp_logit], axis=-1)

            x = tf.concat([x_gap, x_gmp], axis=-1)
            x = conv(x, channel, kernel=1, stride=1, scope='conv_1x1')
            x = relu(x)

            heatmap = tf.squeeze(tf.reduce_sum(x, axis=-1))

            # Gamma, Beta block
            gamma, beta = self.MLP(x, reuse=reuse)

            # Up-Sampling Bottleneck
            for i in range(self.n_res):
                x = adaptive_ins_layer_resblock(x, channel, gamma, beta, smoothing=self.smoothing, scope='adaptive_resblock' + str(i))

            # Up-Sampling
            for i in range(2):
                x = up_sample(x, scale_factor=2)
                x = conv(x, channel//2, kernel=3, stride=1, pad=1, pad_type='reflect', scope='up_conv_'+str(i))
                x = layer_instance_norm(x, scope='layer_ins_norm_'+str(i))
                x = relu(x)
                channel = channel // 2

            x = conv(x, channels=3, kernel=7, stride=1, pad=3, pad_type='reflect', scope='G_logit')
            x = tanh(x)

            return x, cam_logit, heatmap

    def MLP(self, x, use_bias=True, reuse=False, scope='MLP'):
        channel = self.ch * self.n_res

        if self.light:
            x = global_avg_pooling(x)

        with tf.variable_scope(scope, reuse=reuse):
            for i in range(2):
                x = fully_connected(x, channel, use_bias, scope='linear_' + str(i))
                x = relu(x)

            gamma = fully_connected(x, channel, use_bias, scope='gamma')
            beta = fully_connected(x, channel, use_bias, scope='beta')

            gamma = tf.reshape(gamma, shape=[self.batch_size, 1, 1, channel])
            beta = tf.reshape(beta, shape=[self.batch_size, 1, 1, channel])

            return gamma, beta

    ##################################################################################
    # Discriminator
    ##################################################################################
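    # Multi-scale PatchGAN: a shallow 'local' discriminator and a deeper
    # 'global' one, each with its own CAM branch; their logits and CAM logits
    # are returned as lists.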
    def discriminator(self, x_init, reuse=False, scope="discriminator"):
        D_logit = []
        D_CAM_logit = []
        with tf.variable_scope(scope, reuse=reuse):
            local_x, local_cam, local_heatmap = self.discriminator_local(x_init, reuse=reuse, scope='local')
            global_x, global_cam, global_heatmap = self.discriminator_global(x_init, reuse=reuse, scope='global')

            D_logit.extend([local_x, global_x])
            D_CAM_logit.extend([local_cam, global_cam])

            return D_logit, D_CAM_logit, local_heatmap, global_heatmap

    def discriminator_global(self, x_init, reuse=False, scope='discriminator_global'):
        with tf.variable_scope(scope, reuse=reuse):
            channel = self.ch
            x = conv(x_init, channel, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_0')
            x = lrelu(x, 0.2)

            for i in range(1, self.n_dis - 1):
                x = conv(x, channel * 2, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_' + str(i))
                x = lrelu(x, 0.2)
                channel = channel * 2

            x = conv(x, channel * 2, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn, scope='conv_last')
            x = lrelu(x, 0.2)
            channel = channel * 2

            # Class Activation Map
            cam_x = global_avg_pooling(x)
            cam_gap_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, scope='CAM_logit')
            x_gap = tf.multiply(x, cam_x_weight)

            cam_x = global_max_pooling(x)
            cam_gmp_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, reuse=True, scope='CAM_logit')
            x_gmp = tf.multiply(x, cam_x_weight)

            cam_logit = tf.concat([cam_gap_logit, cam_gmp_logit], axis=-1)

            x = tf.concat([x_gap, x_gmp], axis=-1)
            x = conv(x, channel, kernel=1, stride=1, scope='conv_1x1')
            x = lrelu(x, 0.2)

            heatmap = tf.squeeze(tf.reduce_sum(x, axis=-1))

            x = conv(x, channels=1, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn, scope='D_logit')

            return x, cam_logit, heatmap

    def discriminator_local(self, x_init, reuse=False, scope='discriminator_local'):
        with tf.variable_scope(scope, reuse=reuse):
            channel = self.ch
            x = conv(x_init, channel, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_0')
            x = lrelu(x, 0.2)

            for i in range(1, self.n_dis - 2 - 1):
                x = conv(x, channel * 2, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_' + str(i))
                x = lrelu(x, 0.2)
                channel = channel * 2

            x = conv(x, channel * 2, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn, scope='conv_last')
            x = lrelu(x, 0.2)
            channel = channel * 2

            # Class Activation Map
            cam_x = global_avg_pooling(x)
            cam_gap_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, scope='CAM_logit')
            x_gap = tf.multiply(x, cam_x_weight)

            cam_x = global_max_pooling(x)
            cam_gmp_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, reuse=True, scope='CAM_logit')
            x_gmp = tf.multiply(x, cam_x_weight)

            cam_logit = tf.concat([cam_gap_logit, cam_gmp_logit], axis=-1)

            x = tf.concat([x_gap, x_gmp], axis=-1)
            x = conv(x, channel, kernel=1, stride=1, scope='conv_1x1')
            x = lrelu(x, 0.2)

            heatmap = tf.squeeze(tf.reduce_sum(x, axis=-1))

            x = conv(x, channels=1, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn, scope='D_logit')

            return x, cam_logit, heatmap

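    # Scope naming follows the original repo: 'generator_B' maps domain A to
    # domain B (selfie -> anime) and 'generator_A' maps B back to A.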
    def generate_a2b(self, x_A, reuse=False):
        out, cam, _ = self.generator(x_A, reuse=reuse, scope="generator_B")
        return out, cam

    def generate_b2a(self, x_B, reuse=False):
        out, cam, _ = self.generator(x_B, reuse=reuse, scope="generator_A")
        return out, cam

    def build_model(self):
        self.test_domain_A = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='test_domain_A')
        self.test_domain_B = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='test_domain_B')

        self.test_fake_B, _ = self.generate_a2b(self.test_domain_A)
        self.test_fake_A, _ = self.generate_b2a(self.test_domain_B)

    @property
    def model_dir(self):
        n_res = str(self.n_res) + 'resblock'
        n_dis = str(self.n_dis) + 'dis'

        if self.smoothing:
            smoothing = '_smoothing'
        else:
            smoothing = ''

        if self.sn:
            sn = '_sn'
        else:
            sn = ''

        return "{}_{}_{}_{}_{}_{}_{}_{}_{}_{}{}{}".format(self.model_name, self.dataset_name,
                                                          self.gan_type, n_res, n_dis,
                                                          self.n_critic,
                                                          self.adv_weight, self.cycle_weight, self.identity_weight,
                                                          self.cam_weight, sn, smoothing)

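    # Checkpoints are looked up under <checkpoint_dir>/<model_dir>, the
    # layout used by the original UGATIT training code.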
    def load(self, checkpoint_dir):
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            counter = int(ckpt_name.split('-')[-1])
            print(" [*] Successfully read {}".format(ckpt_name))
            return True, counter
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0

    def loadModel(self):
        tf.global_variables_initializer().run(session=self.sess)
        self.saver = tf.train.Saver()
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)

        self.result_dir = os.path.join(self.result_dir, self.model_dir)
        check_folder(self.result_dir)

        if could_load:
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

    def test(self, sample_file):
        # A -> B only: translate one image and save it under result_dir.
        print('Processing A image: ' + sample_file)
        sample_image = np.asarray(load_test_data(sample_file, size=self.img_size))
        image_path = os.path.join(self.result_dir, '{0}'.format(os.path.basename(sample_file)))

        fake_img = self.sess.run(self.test_fake_B, feed_dict={self.test_domain_A: sample_image})
        save_images(fake_img, [1, 1], image_path)
        return image_path

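# Module-level cache so repeated calls to main_test() reuse one session and
# graph instead of rebuilding the model for every image.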
gan = None


def main_test(img_path, checkpoint_dir):
    global gan
    if gan is None:
        # Open the session and build the graph only on the first call;
        # calling build_model() a second time on the same default graph
        # would fail with a variable reuse error.
        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
        gan = UgatitTest(sess, checkpoint_dir)
        gan.build_model()
        show_all_variables()  # print the network architecture
        gan.loadModel()

    result = gan.test(img_path)
    print(" [*] Test finished!")
    print(result)
    return os.path.abspath(result)

if __name__ == '__main__':
    # NOTE: the original call omitted the required checkpoint_dir argument;
    # 'checkpoint' is a placeholder, so point it at your own directory.
    main_test('/home/hylee/cartoon/myp2c/imgs/src/im4.jpg', 'checkpoint')