# Project star-imports come first (as in the original file) so that the
# explicit imports below take precedence over anything they re-export.
from ugatit.ops import *
from ugatit.utils import *

import os
import time
from glob import glob

import numpy as np
import tensorflow as tf
from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch


class UgatitTest:
    """Inference-only wrapper around the UGATIT selfie2anime TF1 graph.

    Only the A->B generator (``generator_B``) is run by :meth:`test`; the
    discriminator builders are kept so the class mirrors the training model.
    """

    def __init__(self, sess, checkpoint_dir):
        """Store the fixed test-time hyper-parameters and print a summary.

        Args:
            sess: TF1 session used to initialise, restore and run the graph.
            checkpoint_dir: root directory; the checkpoint is looked up in
                ``os.path.join(checkpoint_dir, self.model_dir)``.
        """
        self.light = False
        self.model_name = 'UGATIT_light' if self.light else 'UGATIT'

        self.sess = sess
        self.phase = 'test'
        self.checkpoint_dir = checkpoint_dir
        self.result_dir = 'results'
        self.log_dir = 'logs'
        self.dataset_name = 'selfie2anime'
        self.augment_flag = True

        self.epoch = 100
        self.iteration = 10000
        self.decay_flag = True
        self.decay_epoch = 50

        self.gan_type = 'lsgan'

        self.batch_size = 1
        self.print_freq = 1000
        self.save_freq = 1000

        self.init_lr = 0.0001
        self.ch = 64

        # Weight
        self.adv_weight = 1
        self.cycle_weight = 10
        self.identity_weight = 10
        self.cam_weight = 1000
        self.ld = 10
        self.smoothing = True

        # Generator
        self.n_res = 4

        # Discriminator
        self.n_dis = 6
        self.n_critic = 1
        self.sn = True

        self.img_size = 256
        self.img_ch = 3

        # Training-set listing (relative to the CWD). At test time this is only
        # used for the "max dataset number" line of the summary below.
        self.trainA_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainA'))
        self.trainB_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainB'))
        self.dataset_num = max(len(self.trainA_dataset), len(self.trainB_dataset))

        print()

        print("##### Information #####")
        print("# light : ", self.light)
        print("# gan type : ", self.gan_type)
        print("# dataset : ", self.dataset_name)
        print("# max dataset number : ", self.dataset_num)
        print("# batch_size : ", self.batch_size)
        print("# epoch : ", self.epoch)
        print("# iteration per epoch : ", self.iteration)
        print("# smoothing : ", self.smoothing)

        print()

        print("##### Generator #####")
        print("# residual blocks : ", self.n_res)

        print()

        print("##### Discriminator #####")
        print("# discriminator layer : ", self.n_dis)
        print("# the number of critic : ", self.n_critic)
        print("# spectral normalization : ", self.sn)

        print()

        print("##### Weight #####")
        print("# adv_weight : ", self.adv_weight)
        print("# cycle_weight : ", self.cycle_weight)
        print("# identity_weight : ", self.identity_weight)
        print("# cam_weight : ", self.cam_weight)

    ##################################################################################
    # Generator
    ##################################################################################

    def generator(self, x_init, reuse=False, scope="generator"):
        """Build the UGATIT generator.

        Args:
            x_init: input image tensor.
            reuse: forwarded to ``tf.variable_scope``.
            scope: variable scope (``generator_A`` / ``generator_B`` in practice).

        Returns:
            ``(image, cam_logit, heatmap)``: the ``tanh``-activated output, the
            concatenated GAP/GMP CAM logits, and the channel-summed feature map
            after the 1x1 fusion conv.
        """
        channel = self.ch
        with tf.variable_scope(scope, reuse=reuse):
            x = conv(x_init, channel, kernel=7, stride=1, pad=3, pad_type='reflect', scope='conv')
            x = instance_norm(x, scope='ins_norm')
            x = relu(x)

            # Down-Sampling: two stride-2 convs, doubling the channel count each time.
            for i in range(2):
                x = conv(x, channel * 2, kernel=3, stride=2, pad=1, pad_type='reflect',
                         scope='conv_' + str(i))
                x = instance_norm(x, scope='ins_norm_' + str(i))
                x = relu(x)

                channel = channel * 2

            # Down-Sampling Bottleneck
            for i in range(self.n_res):
                x = resblock(x, channel, scope='resblock_' + str(i))

            # Class Activation Map: the GAP and GMP branches share one FC layer
            # (second call uses reuse=True) whose weights re-scale the features.
            cam_x = global_avg_pooling(x)
            cam_gap_logit, cam_x_weight = fully_connected_with_w(cam_x, scope='CAM_logit')
            x_gap = tf.multiply(x, cam_x_weight)

            cam_x = global_max_pooling(x)
            cam_gmp_logit, cam_x_weight = fully_connected_with_w(cam_x, reuse=True, scope='CAM_logit')
            x_gmp = tf.multiply(x, cam_x_weight)

            cam_logit = tf.concat([cam_gap_logit, cam_gmp_logit], axis=-1)
            x = tf.concat([x_gap, x_gmp], axis=-1)

            x = conv(x, channel, kernel=1, stride=1, scope='conv_1x1')
            x = relu(x)

            heatmap = tf.squeeze(tf.reduce_sum(x, axis=-1))

            # Gamma, Beta block (AdaLIN parameters)
            gamma, beta = self.MLP(x, reuse=reuse)

            # Up-Sampling Bottleneck
            for i in range(self.n_res):
                x = adaptive_ins_layer_resblock(x, channel, gamma, beta, smoothing=self.smoothing,
                                                scope='adaptive_resblock' + str(i))

            # Up-Sampling: two x2 up-samples, halving the channel count each time.
            for i in range(2):
                x = up_sample(x, scale_factor=2)
                x = conv(x, channel // 2, kernel=3, stride=1, pad=1, pad_type='reflect',
                         scope='up_conv_' + str(i))
                x = layer_instance_norm(x, scope='layer_ins_norm_' + str(i))
                x = relu(x)

                channel = channel // 2

            x = conv(x, channels=3, kernel=7, stride=1, pad=3, pad_type='reflect', scope='G_logit')
            x = tanh(x)

            return x, cam_logit, heatmap

    def MLP(self, x, use_bias=True, reuse=False, scope='MLP'):
        """Predict the AdaLIN ``gamma``/``beta`` tensors from feature map ``x``.

        Both outputs are reshaped to ``[batch_size, 1, 1, ch * n_res]``.
        """
        channel = self.ch * self.n_res

        if self.light:
            x = global_avg_pooling(x)

        with tf.variable_scope(scope, reuse=reuse):
            for i in range(2):
                x = fully_connected(x, channel, use_bias, scope='linear_' + str(i))
                x = relu(x)

            gamma = fully_connected(x, channel, use_bias, scope='gamma')
            beta = fully_connected(x, channel, use_bias, scope='beta')

            gamma = tf.reshape(gamma, shape=[self.batch_size, 1, 1, channel])
            beta = tf.reshape(beta, shape=[self.batch_size, 1, 1, channel])

            return gamma, beta

    ##################################################################################
    # Discriminator
    ##################################################################################

    def discriminator(self, x_init, reuse=False, scope="discriminator"):
        """Build the local and global discriminators on the same input.

        Returns:
            ``(D_logit, D_CAM_logit, local_heatmap, global_heatmap)`` where the
            first two are ``[local, global]`` lists.
        """
        with tf.variable_scope(scope, reuse=reuse):
            local_x, local_cam, local_heatmap = self.discriminator_local(x_init, reuse=reuse, scope='local')
            global_x, global_cam, global_heatmap = self.discriminator_global(x_init, reuse=reuse, scope='global')

            D_logit = [local_x, global_x]
            D_CAM_logit = [local_cam, global_cam]

            return D_logit, D_CAM_logit, local_heatmap, global_heatmap

    def discriminator_global(self, x_init, reuse=False, scope='discriminator_global'):
        """Global discriminator: ``n_dis - 1`` stride-2 convolutions."""
        return self._cam_discriminator(x_init, self.n_dis - 1, reuse=reuse, scope=scope)

    def discriminator_local(self, x_init, reuse=False, scope='discriminator_local'):
        """Local discriminator: ``n_dis - 3`` stride-2 convolutions."""
        return self._cam_discriminator(x_init, self.n_dis - 3, reuse=reuse, scope=scope)

    def _cam_discriminator(self, x_init, n_strided, reuse=False, scope='discriminator'):
        """Shared body of the global and local discriminators.

        Variable names inside ``scope`` are the same as in the previous separate
        global/local implementations, so existing checkpoints still restore.

        Args:
            x_init: input image tensor.
            n_strided: number of stride-2 convs (``conv_0`` .. ``conv_{n_strided-1}``)
                applied before ``conv_last``.
            reuse: forwarded to ``tf.variable_scope``.
            scope: variable scope name.

        Returns:
            ``(logit, cam_logit, heatmap)``.
        """
        with tf.variable_scope(scope, reuse=reuse):
            channel = self.ch
            x = conv(x_init, channel, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn,
                     scope='conv_0')
            x = lrelu(x, 0.2)

            for i in range(1, n_strided):
                x = conv(x, channel * 2, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn,
                         scope='conv_' + str(i))
                x = lrelu(x, 0.2)

                channel = channel * 2

            x = conv(x, channel * 2, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn,
                     scope='conv_last')
            x = lrelu(x, 0.2)

            channel = channel * 2

            # Class Activation Map (GAP/GMP branches share the CAM_logit weights).
            cam_x = global_avg_pooling(x)
            cam_gap_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, scope='CAM_logit')
            x_gap = tf.multiply(x, cam_x_weight)

            cam_x = global_max_pooling(x)
            cam_gmp_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, reuse=True,
                                                                 scope='CAM_logit')
            x_gmp = tf.multiply(x, cam_x_weight)

            cam_logit = tf.concat([cam_gap_logit, cam_gmp_logit], axis=-1)
            x = tf.concat([x_gap, x_gmp], axis=-1)

            x = conv(x, channel, kernel=1, stride=1, scope='conv_1x1')
            x = lrelu(x, 0.2)

            heatmap = tf.squeeze(tf.reduce_sum(x, axis=-1))

            x = conv(x, channels=1, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn,
                     scope='D_logit')

            return x, cam_logit, heatmap

    ##################################################################################
    # Model
    ##################################################################################

    def generate_a2b(self, x_A, reuse=False):
        """Translate domain A -> B with ``generator_B``; returns ``(image, cam_logit)``."""
        out, cam, _ = self.generator(x_A, reuse=reuse, scope="generator_B")
        return out, cam

    def generate_b2a(self, x_B, reuse=False):
        """Translate domain B -> A with ``generator_A``; returns ``(image, cam_logit)``."""
        out, cam, _ = self.generator(x_B, reuse=reuse, scope="generator_A")
        return out, cam

    def build_model(self):
        """Create the single-image test placeholders and both generators."""
        shape = [1, self.img_size, self.img_size, self.img_ch]
        self.test_domain_A = tf.placeholder(tf.float32, shape, name='test_domain_A')
        self.test_domain_B = tf.placeholder(tf.float32, shape, name='test_domain_B')

        self.test_fake_B, _ = self.generate_a2b(self.test_domain_A)
        self.test_fake_A, _ = self.generate_b2a(self.test_domain_B)

    @property
    def model_dir(self):
        """Sub-directory name encoding the hyper-parameters (checkpoint/result layout)."""
        n_res = str(self.n_res) + 'resblock'
        n_dis = str(self.n_dis) + 'dis'
        smoothing = '_smoothing' if self.smoothing else ''
        sn = '_sn' if self.sn else ''

        return "{}_{}_{}_{}_{}_{}_{}_{}_{}_{}{}{}".format(
            self.model_name, self.dataset_name, self.gan_type, n_res, n_dis, self.n_critic,
            self.adv_weight, self.cycle_weight, self.identity_weight, self.cam_weight, sn, smoothing)

    def load(self, checkpoint_dir):
        """Restore the latest checkpoint under ``checkpoint_dir/model_dir``.

        Requires ``self.saver`` (created in :meth:`loadModel`).

        Returns:
            ``(True, step)`` on success, where ``step`` is parsed from the text
            after the last ``-`` in the checkpoint name; ``(False, 0)`` otherwise.
        """
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            counter = int(ckpt_name.split('-')[-1])
            print(" [*] Success to read {}".format(ckpt_name))
            return True, counter

        print(" [*] Failed to find a checkpoint")
        return False, 0

    def loadModel(self):
        """Initialise variables, restore the checkpoint and prepare the result dir.

        If no checkpoint is found the graph keeps its freshly initialised
        weights; only a message is printed.
        """
        tf.global_variables_initializer().run(session=self.sess)
        self.saver = tf.train.Saver()
        could_load, _ = self.load(self.checkpoint_dir)
        self.result_dir = os.path.join(self.result_dir, self.model_dir)
        check_folder(self.result_dir)

        if could_load:
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

    def test(self, sample_file):
        """Translate one domain-A image to domain B and save it.

        Args:
            sample_file: path of the input image.

        Returns:
            Path of the saved output: ``result_dir`` joined with the input's
            base name.
        """
        # A -> B
        print('Processing A image: ' + sample_file)
        sample_image = np.asarray(load_test_data(sample_file, size=self.img_size))
        image_path = os.path.join(self.result_dir, os.path.basename(sample_file))

        fake_img = self.sess.run(self.test_fake_B, feed_dict={self.test_domain_A: sample_image})
        save_images(fake_img, [1, 1], image_path)

        return image_path


# Lazily built model. Repeated main_test() calls reuse its graph and session.
gan = None


def main_test(img_path, checkpoint_dir='checkpoint'):
    """Translate ``img_path`` and return the absolute path of the result.

    On the first call the model is built in its own ``tf.Graph`` and its
    checkpoint is restored. Later calls reuse the cached model, so
    ``checkpoint_dir`` only matters on the first successful call.

    Args:
        img_path: input image path.
        checkpoint_dir: checkpoint root directory. The default ``'checkpoint'``
            presumably matches the upstream UGATIT CLI default; adjust to your
            layout.
    """
    global gan
    if gan is None:
        # Open a session only when the model is first built. Previously a
        # session was created (and leaked) on every call. A dedicated graph
        # lets a failed build be retried without variable-name clashes.
        graph = tf.Graph()
        sess = tf.Session(graph=graph, config=tf.ConfigProto(allow_soft_placement=True))
        try:
            with graph.as_default():
                model = UgatitTest(sess, checkpoint_dir)
                # build graph
                model.build_model()
                # show network architecture
                show_all_variables()
                model.loadModel()
        except Exception:
            sess.close()
            raise
        # Publish only a fully built model; the session lives as long as `gan`.
        gan = model

    result = gan.test(img_path)
    print(" [*] Test finished!")
    print(result)
    return os.path.abspath(result)


if __name__ == '__main__':
    import sys

    # Usage: python <this file> [image_path [checkpoint_dir]]
    _argv = sys.argv[1:]
    _image = _argv[0] if _argv else '/home/hylee/cartoon/myp2c/imgs/src/im4.jpg'
    _ckpt = _argv[1] if len(_argv) > 1 else 'checkpoint'
    main_test(_image, _ckpt)