|
import os
import sys
import time
from glob import glob

import numpy as np
from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch

from ugatit.ops import *
from ugatit.utils import *
|
|
|
class UgatitTest:
    """Inference-only U-GAT-IT model: builds the generators and translates one image.

    The hyper-parameters are hard-coded in ``__init__`` and are also encoded
    into ``model_dir``, which is the sub-directory used for both checkpoint
    lookup and result output.  Layer helpers (``conv``, ``relu``,
    ``instance_norm``, ...) and ``tf`` are expected to come from the
    module-level star imports.
    """

    def __init__(self, sess, checkpoint_dir):
        """Store the session and the fixed model configuration.

        Args:
            sess: TensorFlow session used for variable initialisation,
                checkpoint restore and inference.
            checkpoint_dir: Root directory; checkpoints are searched in
                ``<checkpoint_dir>/<model_dir>``.
        """
        self.light = False
        self.model_name = 'UGATIT_light' if self.light else 'UGATIT'

        self.sess = sess
        self.phase = 'test'
        self.checkpoint_dir = checkpoint_dir
        self.result_dir = 'results'
        self.log_dir = 'logs'
        self.dataset_name = 'selfie2anime'
        self.augment_flag = True

        self.epoch = 100
        self.iteration = 10000
        self.decay_flag = True
        self.decay_epoch = 50

        self.gan_type = 'lsgan'

        self.batch_size = 1
        self.print_freq = 1000
        self.save_freq = 1000

        self.init_lr = 0.0001
        self.ch = 64

        # Weight
        self.adv_weight = 1
        self.cycle_weight = 10
        self.identity_weight = 10
        self.cam_weight = 1000
        self.ld = 10
        self.smoothing = True

        # Generator
        self.n_res = 4

        # Discriminator
        self.n_dis = 6
        self.n_critic = 1
        self.sn = True

        self.img_size = 256
        self.img_ch = 3

        # The training file lists are only used for the "max dataset number"
        # line of the summary printed below.
        self.trainA_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainA'))
        self.trainB_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainB'))
        self.dataset_num = max(len(self.trainA_dataset), len(self.trainB_dataset))

        print()

        print("##### Information #####")
        print("# light : ", self.light)
        print("# gan type : ", self.gan_type)
        print("# dataset : ", self.dataset_name)
        print("# max dataset number : ", self.dataset_num)
        print("# batch_size : ", self.batch_size)
        print("# epoch : ", self.epoch)
        print("# iteration per epoch : ", self.iteration)
        print("# smoothing : ", self.smoothing)

        print()

        print("##### Generator #####")
        print("# residual blocks : ", self.n_res)

        print()

        print("##### Discriminator #####")
        print("# discriminator layer : ", self.n_dis)
        print("# the number of critic : ", self.n_critic)
        print("# spectral normalization : ", self.sn)

        print()

        print("##### Weight #####")
        print("# adv_weight : ", self.adv_weight)
        print("# cycle_weight : ", self.cycle_weight)
        print("# identity_weight : ", self.identity_weight)
        print("# cam_weight : ", self.cam_weight)

    def generator(self, x_init, reuse=False, scope="generator"):
        """Build the generator graph under variable scope ``scope``.

        Args:
            x_init: Input image tensor.
            reuse: Forwarded to ``tf.variable_scope`` (and to ``MLP``).
            scope: Variable scope name; it determines the variable names.

        Returns:
            Tuple ``(x, cam_logit, heatmap)``: the tanh-activated 3-channel
            output, the concatenated GAP/GMP CAM logits, and the squeezed
            channel-sum of the post-attention feature map.
        """
        channel = self.ch
        with tf.variable_scope(scope, reuse=reuse):
            x = conv(x_init, channel, kernel=7, stride=1, pad=3, pad_type='reflect', scope='conv')
            x = instance_norm(x, scope='ins_norm')
            x = relu(x)

            # Down-sampling: two stride-2 convolutions, doubling the channel
            # count each time.
            for i in range(2):
                x = conv(x, channel * 2, kernel=3, stride=2, pad=1, pad_type='reflect', scope='conv_' + str(i))
                x = instance_norm(x, scope='ins_norm_' + str(i))
                x = relu(x)

                channel = channel * 2

            # Bottleneck residual blocks.
            for i in range(self.n_res):
                x = resblock(x, channel, scope='resblock_' + str(i))

            # Class activation map: the same 'CAM_logit' layer is applied to
            # the average-pooled and the max-pooled features (second call
            # uses reuse=True), and its weights rescale the feature map.
            cam_x = global_avg_pooling(x)
            cam_gap_logit, cam_x_weight = fully_connected_with_w(cam_x, scope='CAM_logit')
            x_gap = tf.multiply(x, cam_x_weight)

            cam_x = global_max_pooling(x)
            cam_gmp_logit, cam_x_weight = fully_connected_with_w(cam_x, reuse=True, scope='CAM_logit')
            x_gmp = tf.multiply(x, cam_x_weight)

            cam_logit = tf.concat([cam_gap_logit, cam_gmp_logit], axis=-1)
            x = tf.concat([x_gap, x_gmp], axis=-1)

            x = conv(x, channel, kernel=1, stride=1, scope='conv_1x1')
            x = relu(x)

            heatmap = tf.squeeze(tf.reduce_sum(x, axis=-1))

            # Gamma / beta for the adaptive normalisation blocks.
            gamma, beta = self.MLP(x, reuse=reuse)

            for i in range(self.n_res):
                x = adaptive_ins_layer_resblock(x, channel, gamma, beta, smoothing=self.smoothing, scope='adaptive_resblock' + str(i))

            # Up-sampling: mirror of the down-sampling path, halving the
            # channel count each time.
            for i in range(2):
                x = up_sample(x, scale_factor=2)
                x = conv(x, channel // 2, kernel=3, stride=1, pad=1, pad_type='reflect', scope='up_conv_' + str(i))
                x = layer_instance_norm(x, scope='layer_ins_norm_' + str(i))
                x = relu(x)

                channel = channel // 2

            x = conv(x, channels=3, kernel=7, stride=1, pad=3, pad_type='reflect', scope='G_logit')
            x = tanh(x)

            return x, cam_logit, heatmap

    def MLP(self, x, use_bias=True, reuse=False, scope='MLP'):
        """Compute ``gamma`` and ``beta`` from feature map ``x``.

        Args:
            x: Feature tensor; globally average-pooled first when
                ``self.light`` is set.
            use_bias: Forwarded to every ``fully_connected`` call.
            reuse: Forwarded to ``tf.variable_scope``.
            scope: Variable scope name.

        Returns:
            Tuple ``(gamma, beta)``, each reshaped to
            ``[self.batch_size, 1, 1, self.ch * self.n_res]``.
        """
        channel = self.ch * self.n_res

        if self.light:
            x = global_avg_pooling(x)

        with tf.variable_scope(scope, reuse=reuse):
            for i in range(2):
                x = fully_connected(x, channel, use_bias, scope='linear_' + str(i))
                x = relu(x)

            gamma = fully_connected(x, channel, use_bias, scope='gamma')
            beta = fully_connected(x, channel, use_bias, scope='beta')

            gamma = tf.reshape(gamma, shape=[self.batch_size, 1, 1, channel])
            beta = tf.reshape(beta, shape=[self.batch_size, 1, 1, channel])

            return gamma, beta

    def discriminator(self, x_init, reuse=False, scope="discriminator"):
        """Build the local and global discriminators under one scope.

        Returns:
            Tuple ``(D_logit, D_CAM_logit, local_heatmap, global_heatmap)``
            where the first two are ``[local, global]`` lists.
        """
        D_logit = []
        D_CAM_logit = []
        with tf.variable_scope(scope, reuse=reuse):
            local_x, local_cam, local_heatmap = self.discriminator_local(x_init, reuse=reuse, scope='local')
            global_x, global_cam, global_heatmap = self.discriminator_global(x_init, reuse=reuse, scope='global')

            D_logit.extend([local_x, global_x])
            D_CAM_logit.extend([local_cam, global_cam])

            return D_logit, D_CAM_logit, local_heatmap, global_heatmap

    def _discriminator_branch(self, x_init, strided_stop, reuse, scope):
        """Shared body of ``discriminator_global`` / ``discriminator_local``.

        Args:
            x_init: Input image tensor.
            strided_stop: Exclusive upper bound of the stride-2 convolution
                loop (which starts at index 1, after ``conv_0``).
            reuse: Forwarded to ``tf.variable_scope``.
            scope: Variable scope name.

        Returns:
            Tuple ``(x, cam_logit, heatmap)``: the 1-channel logit map, the
            concatenated GAP/GMP CAM logits, and the squeezed channel-sum of
            the post-attention feature map.
        """
        with tf.variable_scope(scope, reuse=reuse):
            channel = self.ch
            x = conv(x_init, channel, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_0')
            x = lrelu(x, 0.2)

            for i in range(1, strided_stop):
                x = conv(x, channel * 2, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_' + str(i))
                x = lrelu(x, 0.2)

                channel = channel * 2

            x = conv(x, channel * 2, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn, scope='conv_last')
            x = lrelu(x, 0.2)

            channel = channel * 2

            # Same CAM construction as in the generator, with `sn` forwarded.
            cam_x = global_avg_pooling(x)
            cam_gap_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, scope='CAM_logit')
            x_gap = tf.multiply(x, cam_x_weight)

            cam_x = global_max_pooling(x)
            cam_gmp_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, reuse=True, scope='CAM_logit')
            x_gmp = tf.multiply(x, cam_x_weight)

            cam_logit = tf.concat([cam_gap_logit, cam_gmp_logit], axis=-1)
            x = tf.concat([x_gap, x_gmp], axis=-1)

            x = conv(x, channel, kernel=1, stride=1, scope='conv_1x1')
            x = lrelu(x, 0.2)

            heatmap = tf.squeeze(tf.reduce_sum(x, axis=-1))

            x = conv(x, channels=1, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn, scope='D_logit')

            return x, cam_logit, heatmap

    def discriminator_global(self, x_init, reuse=False, scope='discriminator_global'):
        """Discriminator whose stride-2 conv loop runs over ``range(1, n_dis - 1)``."""
        return self._discriminator_branch(x_init, self.n_dis - 1, reuse, scope)

    def discriminator_local(self, x_init, reuse=False, scope='discriminator_local'):
        """Discriminator whose stride-2 conv loop runs over ``range(1, n_dis - 2 - 1)``."""
        return self._discriminator_branch(x_init, self.n_dis - 2 - 1, reuse, scope)

    def generate_a2b(self, x_A, reuse=False):
        """Run the ``generator_B`` generator on ``x_A``; return ``(output, cam_logit)``."""
        out, cam, _ = self.generator(x_A, reuse=reuse, scope="generator_B")

        return out, cam

    def generate_b2a(self, x_B, reuse=False):
        """Run the ``generator_A`` generator on ``x_B``; return ``(output, cam_logit)``."""
        out, cam, _ = self.generator(x_B, reuse=reuse, scope="generator_A")

        return out, cam

    def build_model(self):
        """Create the single-image placeholders and both generator outputs."""
        self.test_domain_A = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='test_domain_A')
        self.test_domain_B = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='test_domain_B')

        self.test_fake_B, _ = self.generate_a2b(self.test_domain_A)
        self.test_fake_A, _ = self.generate_b2a(self.test_domain_B)

    @property
    def model_dir(self):
        """Directory name that encodes the model configuration."""
        n_res = str(self.n_res) + 'resblock'
        n_dis = str(self.n_dis) + 'dis'

        smoothing = '_smoothing' if self.smoothing else ''
        sn = '_sn' if self.sn else ''

        return "{}_{}_{}_{}_{}_{}_{}_{}_{}_{}{}{}".format(self.model_name, self.dataset_name,
                                                          self.gan_type, n_res, n_dis,
                                                          self.n_critic,
                                                          self.adv_weight, self.cycle_weight, self.identity_weight,
                                                          self.cam_weight, sn, smoothing)

    def load(self, checkpoint_dir):
        """Restore the latest checkpoint from ``<checkpoint_dir>/<model_dir>``.

        Requires ``self.saver`` to exist (it is created in ``loadModel``).

        Returns:
            ``(True, counter)`` on success, where ``counter`` is the integer
            after the last ``-`` in the checkpoint file name; ``(False, 0)``
            when no checkpoint state is found.
        """
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            counter = int(ckpt_name.split('-')[-1])
            print(" [*] Success to read {}".format(ckpt_name))
            return True, counter
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0

    def _resolve_result_dir(self):
        """Return ``result_dir`` with ``model_dir`` appended at most once."""
        model_dir = self.model_dir
        if os.path.basename(os.path.normpath(self.result_dir)) == model_dir:
            # Already resolved by an earlier loadModel() call.
            return self.result_dir
        return os.path.join(self.result_dir, model_dir)

    def loadModel(self):
        """Initialise variables, restore the checkpoint, prepare the result dir.

        A failed restore is only reported; the model then keeps its freshly
        initialised variables.  Calling this more than once no longer nests
        ``model_dir`` repeatedly inside ``result_dir``.
        """
        tf.global_variables_initializer().run(session=self.sess)

        self.saver = tf.train.Saver()
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        self.result_dir = self._resolve_result_dir()
        check_folder(self.result_dir)

        if could_load:
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

    def test(self, sample_file):
        """Translate one domain-A image and save the result.

        Args:
            sample_file: Path of the input image.

        Returns:
            Path of the written image: ``result_dir`` joined with the input
            file's base name.
        """
        print('Processing A image: ' + sample_file)
        sample_image = np.asarray(load_test_data(sample_file, size=self.img_size))
        image_path = os.path.join(self.result_dir, '{0}'.format(os.path.basename(sample_file)))

        fake_img = self.sess.run(self.test_fake_B, feed_dict={self.test_domain_A: sample_image})
        save_images(fake_img, [1, 1], image_path)

        return image_path
|
|
|
|
|
# Process-wide cached model; created by the first main_test() call.
gan = None


def main_test(img_path, checkpoint_dir):
    """Translate ``img_path`` with the cached model and return the output path.

    The session, graph and checkpoint restore are set up only on the first
    call.  Previously every call opened a new ``tf.Session`` and re-ran
    ``build_model()`` / ``loadModel()`` on the cached instance, which leaks
    sessions and re-creates variables that already exist in the graph.

    Args:
        img_path: Path of the input image.
        checkpoint_dir: Root checkpoint directory.  Only used on the first
            call; later calls reuse the already-loaded model.

    Returns:
        Absolute path of the generated image.
    """
    global gan
    if gan is None:
        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
        try:
            model = UgatitTest(sess, checkpoint_dir)
            model.build_model()
            show_all_variables()
            model.loadModel()
        except Exception:
            # Do not keep a session around for a model that never came up.
            sess.close()
            raise
        # Publish the instance only once it is fully built and loaded.
        gan = model

    result = gan.test(img_path)
    print(" [*] Test finished!")
    print(result)
    return os.path.abspath(result)
|
|
|
if __name__ == '__main__':
    # main_test() needs an image path AND a checkpoint directory; the old
    # call passed only the image path and raised TypeError.
    # Usage: python <script> [image_path] [checkpoint_dir]
    cli_args = sys.argv[1:]
    img_path = cli_args[0] if len(cli_args) > 0 else '/home/hylee/cartoon/myp2c/imgs/src/im4.jpg'
    # NOTE(review): 'checkpoint' is an assumed default location - confirm
    # where the pretrained model actually lives.
    checkpoint_dir = cli_args[1] if len(cli_args) > 1 else 'checkpoint'
    main_test(img_path, checkpoint_dir)