from ugatit.ops import * from utils import * from glob import glob import time from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch import numpy as np class UGATIT(object) : def __init__(self, sess, args): self.light = args.light if self.light : self.model_name = 'UGATIT_light' else : self.model_name = 'UGATIT' self.sess = sess self.phase = args.phase self.checkpoint_dir = args.checkpoint_dir self.result_dir = args.result_dir self.log_dir = args.log_dir self.dataset_name = args.dataset self.augment_flag = args.augment_flag self.epoch = args.epoch self.iteration = args.iteration self.decay_flag = args.decay_flag self.decay_epoch = args.decay_epoch self.gan_type = args.gan_type self.batch_size = args.batch_size self.print_freq = args.print_freq self.save_freq = args.save_freq self.init_lr = args.lr self.ch = args.ch """ Weight """ self.adv_weight = args.adv_weight self.cycle_weight = args.cycle_weight self.identity_weight = args.identity_weight self.cam_weight = args.cam_weight self.ld = args.GP_ld self.smoothing = args.smoothing """ Generator """ self.n_res = args.n_res """ Discriminator """ self.n_dis = args.n_dis self.n_critic = args.n_critic self.sn = args.sn self.img_size = args.img_size self.img_ch = args.img_ch self.sample_dir = os.path.join(args.sample_dir, self.model_dir) check_folder(self.sample_dir) # self.trainA, self.trainB = prepare_data(dataset_name=self.dataset_name, size=self.img_size self.trainA_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainA')) self.trainB_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainB')) self.dataset_num = max(len(self.trainA_dataset), len(self.trainB_dataset)) print() print("##### Information #####") print("# light : ", self.light) print("# gan type : ", self.gan_type) print("# dataset : ", self.dataset_name) print("# max dataset number : ", self.dataset_num) print("# batch_size : ", self.batch_size) print("# epoch : ", self.epoch) print("# iteration per epoch : ", self.iteration) print("# smoothing : ", self.smoothing) print() print("##### Generator #####") print("# residual blocks : ", self.n_res) print() print("##### Discriminator #####") print("# discriminator layer : ", self.n_dis) print("# the number of critic : ", self.n_critic) print("# spectral normalization : ", self.sn) print() print("##### Weight #####") print("# adv_weight : ", self.adv_weight) print("# cycle_weight : ", self.cycle_weight) print("# identity_weight : ", self.identity_weight) print("# cam_weight : ", self.cam_weight) ################################################################################## # Generator ################################################################################## def generator(self, x_init, reuse=False, scope="generator"): channel = self.ch with tf.variable_scope(scope, reuse=reuse) : x = conv(x_init, channel, kernel=7, stride=1, pad=3, pad_type='reflect', scope='conv') x = instance_norm(x, scope='ins_norm') x = relu(x) # Down-Sampling for i in range(2) : x = conv(x, channel*2, kernel=3, stride=2, pad=1, pad_type='reflect', scope='conv_'+str(i)) x = instance_norm(x, scope='ins_norm_'+str(i)) x = relu(x) channel = channel * 2 # Down-Sampling Bottleneck for i in range(self.n_res): x = resblock(x, channel, scope='resblock_' + str(i)) # Class Activation Map cam_x = global_avg_pooling(x) cam_gap_logit, cam_x_weight = fully_connected_with_w(cam_x, scope='CAM_logit') x_gap = tf.multiply(x, cam_x_weight) cam_x = global_max_pooling(x) cam_gmp_logit, cam_x_weight = fully_connected_with_w(cam_x, reuse=True, scope='CAM_logit') x_gmp = tf.multiply(x, cam_x_weight) cam_logit = tf.concat([cam_gap_logit, cam_gmp_logit], axis=-1) x = tf.concat([x_gap, x_gmp], axis=-1) x = conv(x, channel, kernel=1, stride=1, scope='conv_1x1') x = relu(x) heatmap = tf.squeeze(tf.reduce_sum(x, axis=-1)) # Gamma, Beta block gamma, beta = self.MLP(x, reuse=reuse) # Up-Sampling Bottleneck for i in range(self.n_res): x = adaptive_ins_layer_resblock(x, channel, gamma, beta, smoothing=self.smoothing, scope='adaptive_resblock' + str(i)) # Up-Sampling for i in range(2) : x = up_sample(x, scale_factor=2) x = conv(x, channel//2, kernel=3, stride=1, pad=1, pad_type='reflect', scope='up_conv_'+str(i)) x = layer_instance_norm(x, scope='layer_ins_norm_'+str(i)) x = relu(x) channel = channel // 2 x = conv(x, channels=3, kernel=7, stride=1, pad=3, pad_type='reflect', scope='G_logit') x = tanh(x) return x, cam_logit, heatmap def MLP(self, x, use_bias=True, reuse=False, scope='MLP'): channel = self.ch * self.n_res if self.light : x = global_avg_pooling(x) with tf.variable_scope(scope, reuse=reuse): for i in range(2) : x = fully_connected(x, channel, use_bias, scope='linear_' + str(i)) x = relu(x) gamma = fully_connected(x, channel, use_bias, scope='gamma') beta = fully_connected(x, channel, use_bias, scope='beta') gamma = tf.reshape(gamma, shape=[self.batch_size, 1, 1, channel]) beta = tf.reshape(beta, shape=[self.batch_size, 1, 1, channel]) return gamma, beta ################################################################################## # Discriminator ################################################################################## def discriminator(self, x_init, reuse=False, scope="discriminator"): D_logit = [] D_CAM_logit = [] with tf.variable_scope(scope, reuse=reuse) : local_x, local_cam, local_heatmap = self.discriminator_local(x_init, reuse=reuse, scope='local') global_x, global_cam, global_heatmap = self.discriminator_global(x_init, reuse=reuse, scope='global') D_logit.extend([local_x, global_x]) D_CAM_logit.extend([local_cam, global_cam]) return D_logit, D_CAM_logit, local_heatmap, global_heatmap def discriminator_global(self, x_init, reuse=False, scope='discriminator_global'): with tf.variable_scope(scope, reuse=reuse): channel = self.ch x = conv(x_init, channel, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_0') x = lrelu(x, 0.2) for i in range(1, self.n_dis - 1): x = conv(x, channel * 2, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_' + str(i)) x = lrelu(x, 0.2) channel = channel * 2 x = conv(x, channel * 2, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn, scope='conv_last') x = lrelu(x, 0.2) channel = channel * 2 cam_x = global_avg_pooling(x) cam_gap_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, scope='CAM_logit') x_gap = tf.multiply(x, cam_x_weight) cam_x = global_max_pooling(x) cam_gmp_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, reuse=True, scope='CAM_logit') x_gmp = tf.multiply(x, cam_x_weight) cam_logit = tf.concat([cam_gap_logit, cam_gmp_logit], axis=-1) x = tf.concat([x_gap, x_gmp], axis=-1) x = conv(x, channel, kernel=1, stride=1, scope='conv_1x1') x = lrelu(x, 0.2) heatmap = tf.squeeze(tf.reduce_sum(x, axis=-1)) x = conv(x, channels=1, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn, scope='D_logit') return x, cam_logit, heatmap def discriminator_local(self, x_init, reuse=False, scope='discriminator_local'): with tf.variable_scope(scope, reuse=reuse) : channel = self.ch x = conv(x_init, channel, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_0') x = lrelu(x, 0.2) for i in range(1, self.n_dis - 2 - 1): x = conv(x, channel * 2, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_' + str(i)) x = lrelu(x, 0.2) channel = channel * 2 x = conv(x, channel * 2, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn, scope='conv_last') x = lrelu(x, 0.2) channel = channel * 2 cam_x = global_avg_pooling(x) cam_gap_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, scope='CAM_logit') x_gap = tf.multiply(x, cam_x_weight) cam_x = global_max_pooling(x) cam_gmp_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, reuse=True, scope='CAM_logit') x_gmp = tf.multiply(x, cam_x_weight) cam_logit = tf.concat([cam_gap_logit, cam_gmp_logit], axis=-1) x = tf.concat([x_gap, x_gmp], axis=-1) x = conv(x, channel, kernel=1, stride=1, scope='conv_1x1') x = lrelu(x, 0.2) heatmap = tf.squeeze(tf.reduce_sum(x, axis=-1)) x = conv(x, channels=1, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn, scope='D_logit') return x, cam_logit, heatmap ################################################################################## # Model ################################################################################## def generate_a2b(self, x_A, reuse=False): out, cam, _ = self.generator(x_A, reuse=reuse, scope="generator_B") return out, cam def generate_b2a(self, x_B, reuse=False): out, cam, _ = self.generator(x_B, reuse=reuse, scope="generator_A") return out, cam def discriminate_real(self, x_A, x_B): real_A_logit, real_A_cam_logit, _, _ = self.discriminator(x_A, scope="discriminator_A") real_B_logit, real_B_cam_logit, _, _ = self.discriminator(x_B, scope="discriminator_B") return real_A_logit, real_A_cam_logit, real_B_logit, real_B_cam_logit def discriminate_fake(self, x_ba, x_ab): fake_A_logit, fake_A_cam_logit, _, _ = self.discriminator(x_ba, reuse=True, scope="discriminator_A") fake_B_logit, fake_B_cam_logit, _, _ = self.discriminator(x_ab, reuse=True, scope="discriminator_B") return fake_A_logit, fake_A_cam_logit, fake_B_logit, fake_B_cam_logit def gradient_panalty(self, real, fake, scope="discriminator_A"): if self.gan_type.__contains__('dragan'): eps = tf.random_uniform(shape=tf.shape(real), minval=0., maxval=1.) _, x_var = tf.nn.moments(real, axes=[0, 1, 2, 3]) x_std = tf.sqrt(x_var) # magnitude of noise decides the size of local region fake = real + 0.5 * x_std * eps alpha = tf.random_uniform(shape=[self.batch_size, 1, 1, 1], minval=0., maxval=1.) interpolated = real + alpha * (fake - real) logit, cam_logit, _, _ = self.discriminator(interpolated, reuse=True, scope=scope) GP = [] cam_GP = [] for i in range(2) : grad = tf.gradients(logit[i], interpolated)[0] # gradient of D(interpolated) grad_norm = tf.norm(flatten(grad), axis=1) # l2 norm # WGAN - LP if self.gan_type == 'wgan-lp' : GP.append(self.ld * tf.reduce_mean(tf.square(tf.maximum(0.0, grad_norm - 1.)))) elif self.gan_type == 'wgan-gp' or self.gan_type == 'dragan': GP.append(self.ld * tf.reduce_mean(tf.square(grad_norm - 1.))) for i in range(2) : grad = tf.gradients(cam_logit[i], interpolated)[0] # gradient of D(interpolated) grad_norm = tf.norm(flatten(grad), axis=1) # l2 norm # WGAN - LP if self.gan_type == 'wgan-lp' : cam_GP.append(self.ld * tf.reduce_mean(tf.square(tf.maximum(0.0, grad_norm - 1.)))) elif self.gan_type == 'wgan-gp' or self.gan_type == 'dragan': cam_GP.append(self.ld * tf.reduce_mean(tf.square(grad_norm - 1.))) return sum(GP), sum(cam_GP) def build_model(self): if self.phase == 'train' : self.lr = tf.placeholder(tf.float32, name='learning_rate') """ Input Image""" Image_Data_Class = ImageData(self.img_size, self.img_ch, self.augment_flag) trainA = tf.data.Dataset.from_tensor_slices(self.trainA_dataset) trainB = tf.data.Dataset.from_tensor_slices(self.trainB_dataset) gpu_device = '/gpu:0' trainA = trainA.apply(shuffle_and_repeat(self.dataset_num)).apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16, drop_remainder=True)).apply(prefetch_to_device(gpu_device, None)) trainB = trainB.apply(shuffle_and_repeat(self.dataset_num)).apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16, drop_remainder=True)).apply(prefetch_to_device(gpu_device, None)) trainA_iterator = trainA.make_one_shot_iterator() trainB_iterator = trainB.make_one_shot_iterator() self.domain_A = trainA_iterator.get_next() self.domain_B = trainB_iterator.get_next() """ Define Generator, Discriminator """ x_ab, cam_ab = self.generate_a2b(self.domain_A) # real a x_ba, cam_ba = self.generate_b2a(self.domain_B) # real b x_aba, _ = self.generate_b2a(x_ab, reuse=True) # real b x_bab, _ = self.generate_a2b(x_ba, reuse=True) # real a x_aa, cam_aa = self.generate_b2a(self.domain_A, reuse=True) # fake b x_bb, cam_bb = self.generate_a2b(self.domain_B, reuse=True) # fake a real_A_logit, real_A_cam_logit, real_B_logit, real_B_cam_logit = self.discriminate_real(self.domain_A, self.domain_B) fake_A_logit, fake_A_cam_logit, fake_B_logit, fake_B_cam_logit = self.discriminate_fake(x_ba, x_ab) """ Define Loss """ if self.gan_type.__contains__('wgan') or self.gan_type == 'dragan' : GP_A, GP_CAM_A = self.gradient_panalty(real=self.domain_A, fake=x_ba, scope="discriminator_A") GP_B, GP_CAM_B = self.gradient_panalty(real=self.domain_B, fake=x_ab, scope="discriminator_B") else : GP_A, GP_CAM_A = 0, 0 GP_B, GP_CAM_B = 0, 0 G_ad_loss_A = (generator_loss(self.gan_type, fake_A_logit) + generator_loss(self.gan_type, fake_A_cam_logit)) G_ad_loss_B = (generator_loss(self.gan_type, fake_B_logit) + generator_loss(self.gan_type, fake_B_cam_logit)) D_ad_loss_A = (discriminator_loss(self.gan_type, real_A_logit, fake_A_logit) + discriminator_loss(self.gan_type, real_A_cam_logit, fake_A_cam_logit) + GP_A + GP_CAM_A) D_ad_loss_B = (discriminator_loss(self.gan_type, real_B_logit, fake_B_logit) + discriminator_loss(self.gan_type, real_B_cam_logit, fake_B_cam_logit) + GP_B + GP_CAM_B) reconstruction_A = L1_loss(x_aba, self.domain_A) # reconstruction reconstruction_B = L1_loss(x_bab, self.domain_B) # reconstruction identity_A = L1_loss(x_aa, self.domain_A) identity_B = L1_loss(x_bb, self.domain_B) cam_A = cam_loss(source=cam_ba, non_source=cam_aa) cam_B = cam_loss(source=cam_ab, non_source=cam_bb) Generator_A_gan = self.adv_weight * G_ad_loss_A Generator_A_cycle = self.cycle_weight * reconstruction_B Generator_A_identity = self.identity_weight * identity_A Generator_A_cam = self.cam_weight * cam_A Generator_B_gan = self.adv_weight * G_ad_loss_B Generator_B_cycle = self.cycle_weight * reconstruction_A Generator_B_identity = self.identity_weight * identity_B Generator_B_cam = self.cam_weight * cam_B Generator_A_loss = Generator_A_gan + Generator_A_cycle + Generator_A_identity + Generator_A_cam Generator_B_loss = Generator_B_gan + Generator_B_cycle + Generator_B_identity + Generator_B_cam Discriminator_A_loss = self.adv_weight * D_ad_loss_A Discriminator_B_loss = self.adv_weight * D_ad_loss_B self.Generator_loss = Generator_A_loss + Generator_B_loss + regularization_loss('generator') self.Discriminator_loss = Discriminator_A_loss + Discriminator_B_loss + regularization_loss('discriminator') """ Result Image """ self.fake_A = x_ba self.fake_B = x_ab self.real_A = self.domain_A self.real_B = self.domain_B """ Training """ t_vars = tf.trainable_variables() G_vars = [var for var in t_vars if 'generator' in var.name] D_vars = [var for var in t_vars if 'discriminator' in var.name] self.G_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Generator_loss, var_list=G_vars) self.D_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Discriminator_loss, var_list=D_vars) """" Summary """ self.all_G_loss = tf.summary.scalar("Generator_loss", self.Generator_loss) self.all_D_loss = tf.summary.scalar("Discriminator_loss", self.Discriminator_loss) self.G_A_loss = tf.summary.scalar("G_A_loss", Generator_A_loss) self.G_A_gan = tf.summary.scalar("G_A_gan", Generator_A_gan) self.G_A_cycle = tf.summary.scalar("G_A_cycle", Generator_A_cycle) self.G_A_identity = tf.summary.scalar("G_A_identity", Generator_A_identity) self.G_A_cam = tf.summary.scalar("G_A_cam", Generator_A_cam) self.G_B_loss = tf.summary.scalar("G_B_loss", Generator_B_loss) self.G_B_gan = tf.summary.scalar("G_B_gan", Generator_B_gan) self.G_B_cycle = tf.summary.scalar("G_B_cycle", Generator_B_cycle) self.G_B_identity = tf.summary.scalar("G_B_identity", Generator_B_identity) self.G_B_cam = tf.summary.scalar("G_B_cam", Generator_B_cam) self.D_A_loss = tf.summary.scalar("D_A_loss", Discriminator_A_loss) self.D_B_loss = tf.summary.scalar("D_B_loss", Discriminator_B_loss) self.rho_var = [] for var in tf.trainable_variables(): if 'rho' in var.name: self.rho_var.append(tf.summary.histogram(var.name, var)) self.rho_var.append(tf.summary.scalar(var.name + "_min", tf.reduce_min(var))) self.rho_var.append(tf.summary.scalar(var.name + "_max", tf.reduce_max(var))) self.rho_var.append(tf.summary.scalar(var.name + "_mean", tf.reduce_mean(var))) g_summary_list = [self.G_A_loss, self.G_A_gan, self.G_A_cycle, self.G_A_identity, self.G_A_cam, self.G_B_loss, self.G_B_gan, self.G_B_cycle, self.G_B_identity, self.G_B_cam, self.all_G_loss] g_summary_list.extend(self.rho_var) d_summary_list = [self.D_A_loss, self.D_B_loss, self.all_D_loss] self.G_loss = tf.summary.merge(g_summary_list) self.D_loss = tf.summary.merge(d_summary_list) else : """ Test """ self.test_domain_A = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='test_domain_A') self.test_domain_B = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='test_domain_B') self.test_fake_B, _ = self.generate_a2b(self.test_domain_A) self.test_fake_A, _ = self.generate_b2a(self.test_domain_B) def train(self): # initialize all variables tf.global_variables_initializer().run() # saver to save model self.saver = tf.train.Saver() # summary writer self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph) # restore check-point if it exits could_load, checkpoint_counter = self.load(self.checkpoint_dir) if could_load: start_epoch = (int)(checkpoint_counter / self.iteration) start_batch_id = checkpoint_counter - start_epoch * self.iteration counter = checkpoint_counter print(" [*] Load SUCCESS") else: start_epoch = 0 start_batch_id = 0 counter = 1 print(" [!] Load failed...") # loop for epoch start_time = time.time() past_g_loss = -1. lr = self.init_lr for epoch in range(start_epoch, self.epoch): # lr = self.init_lr if epoch < self.decay_epoch else self.init_lr * (self.epoch - epoch) / (self.epoch - self.decay_epoch) if self.decay_flag : #lr = self.init_lr * pow(0.5, epoch // self.decay_epoch) lr = self.init_lr if epoch < self.decay_epoch else self.init_lr * (self.epoch - epoch) / (self.epoch - self.decay_epoch) for idx in range(start_batch_id, self.iteration): train_feed_dict = { self.lr : lr } # Update D _, d_loss, summary_str = self.sess.run([self.D_optim, self.Discriminator_loss, self.D_loss], feed_dict = train_feed_dict) self.writer.add_summary(summary_str, counter) # Update G g_loss = None if (counter - 1) % self.n_critic == 0 : batch_A_images, batch_B_images, fake_A, fake_B, _, g_loss, summary_str = self.sess.run([self.real_A, self.real_B, self.fake_A, self.fake_B, self.G_optim, self.Generator_loss, self.G_loss], feed_dict = train_feed_dict) self.writer.add_summary(summary_str, counter) past_g_loss = g_loss # display training status counter += 1 if g_loss == None : g_loss = past_g_loss print("Epoch: [%2d] [%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" % (epoch, idx, self.iteration, time.time() - start_time, d_loss, g_loss)) if np.mod(idx+1, self.print_freq) == 0 : save_images(batch_A_images, [self.batch_size, 1], './{}/real_A_{:03d}_{:05d}.png'.format(self.sample_dir, epoch, idx+1)) # save_images(batch_B_images, [self.batch_size, 1], # './{}/real_B_{:03d}_{:05d}.png'.format(self.sample_dir, epoch, idx+1)) # save_images(fake_A, [self.batch_size, 1], # './{}/fake_A_{:03d}_{:05d}.png'.format(self.sample_dir, epoch, idx+1)) save_images(fake_B, [self.batch_size, 1], './{}/fake_B_{:03d}_{:05d}.png'.format(self.sample_dir, epoch, idx+1)) if np.mod(idx + 1, self.save_freq) == 0: self.save(self.checkpoint_dir, counter) # After an epoch, start_batch_id is set to zero # non-zero value is only for the first epoch after loading pre-trained model start_batch_id = 0 # save model for final step self.save(self.checkpoint_dir, counter) @property def model_dir(self): n_res = str(self.n_res) + 'resblock' n_dis = str(self.n_dis) + 'dis' if self.smoothing : smoothing = '_smoothing' else : smoothing = '' if self.sn : sn = '_sn' else : sn = '' return "{}_{}_{}_{}_{}_{}_{}_{}_{}_{}{}{}".format(self.model_name, self.dataset_name, self.gan_type, n_res, n_dis, self.n_critic, self.adv_weight, self.cycle_weight, self.identity_weight, self.cam_weight, sn, smoothing) def save(self, checkpoint_dir, step): checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) if not os.path.exists(checkpoint_dir): os.makedirs(checkpoint_dir) self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step) def load(self, checkpoint_dir): print(" [*] Reading checkpoints...") checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) ckpt = tf.train.get_checkpoint_state(checkpoint_dir) if ckpt and ckpt.model_checkpoint_path: ckpt_name = os.path.basename(ckpt.model_checkpoint_path) self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name)) counter = int(ckpt_name.split('-')[-1]) print(" [*] Success to read {}".format(ckpt_name)) return True, counter else: print(" [*] Failed to find a checkpoint") return False, 0 def test(self): tf.global_variables_initializer().run() test_A_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testA')) test_B_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testB')) self.saver = tf.train.Saver() could_load, checkpoint_counter = self.load(self.checkpoint_dir) self.result_dir = os.path.join(self.result_dir, self.model_dir) check_folder(self.result_dir) if could_load : print(" [*] Load SUCCESS") else : print(" [!] Load failed...") # write html for visual comparison index_path = os.path.join(self.result_dir, 'index.html') index = open(index_path, 'w') index.write("
name | input | output | %s | " % os.path.basename(image_path)) index.write("" % (sample_file if os.path.isabs(sample_file) else ( '../..' + os.path.sep + sample_file), self.img_size, self.img_size)) index.write(" | " % (image_path if os.path.isabs(image_path) else ( '../..' + os.path.sep + image_path), self.img_size, self.img_size)) index.write("") for sample_file in test_B_files : # B -> A print('Processing B image: ' + sample_file) sample_image = np.asarray(load_test_data(sample_file, size=self.img_size)) image_path = os.path.join(self.result_dir,'{0}'.format(os.path.basename(sample_file))) fake_img = self.sess.run(self.test_fake_A, feed_dict = {self.test_domain_B : sample_image}) save_images(fake_img, [1, 1], image_path) index.write(" | %s | " % os.path.basename(image_path)) index.write("" % (sample_file if os.path.isabs(sample_file) else ( '../..' + os.path.sep + sample_file), self.img_size, self.img_size)) index.write(" | " % (image_path if os.path.isabs(image_path) else ( '../..' + os.path.sep + image_path), self.img_size, self.img_size)) index.write("") index.close() |
---|