BOPBTL / Face_Enhancement /util /visualizer.py
manhkhanhUIT's picture
Add code
7fab858
raw
history blame contribute delete
No virus
4.79 kB
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import ntpath
import time
from . import util
import scipy.misc
try:
from StringIO import StringIO # Python 2.7
except ImportError:
from io import BytesIO # Python 3.x
import torchvision.utils as vutils
from tensorboardX import SummaryWriter
import torch
import numpy as np
class Visualizer:
    """Log losses and sample images during training and export result images.

    Scalars and image grids are written through a tensorboardX
    ``SummaryWriter``; textual loss lines are appended to ``loss_log.txt``
    in the experiment's checkpoint directory.
    """

    def __init__(self, opt):
        """Set up log directories, the summary writer and the loss log file.

        Args:
            opt: options object. The attributes read here are ``isTrain``,
                ``tf_log``, ``tensorboard_log``, ``display_winsize``,
                ``name``, ``checkpoints_dir`` and (outside training, when
                ``tensorboard_log`` is set) ``results_dir``.
        """
        self.opt = opt
        self.tf_log = opt.isTrain and opt.tf_log
        self.tensorboard_log = opt.tensorboard_log
        self.win_size = opt.display_winsize
        self.name = opt.name
        # Always defined so the logging methods can test for a writer instead
        # of failing with AttributeError when no logging backend is enabled.
        self.writer = None

        if opt.isTrain and (self.tensorboard_log or self.tf_log):
            # Training: scalars and images go to <checkpoints>/<name>/logs.
            # ``tf_log`` shares this writer; the module never sets up a
            # separate TensorFlow writer.
            self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, "logs")
            os.makedirs(self.log_dir, exist_ok=True)
            self.writer = SummaryWriter(log_dir=self.log_dir)
        elif self.tensorboard_log:
            # Not training: image grids are saved as PNG files in this folder.
            self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, opt.results_dir)
            os.makedirs(self.log_dir, exist_ok=True)

        if opt.isTrain:
            self.log_name = os.path.join(opt.checkpoints_dir, opt.name, "loss_log.txt")
            with open(self.log_name, "a") as log_file:
                now = time.strftime("%c")
                log_file.write("================ Training Loss (%s) ================\n" % now)

    def display_current_results(self, visuals, epoch, step):
        """Log (training) or save to disk (otherwise) a grid of the visuals.

        Does nothing unless ``tensorboard_log`` is enabled, or when
        ``visuals`` is empty.

        Args:
            visuals: dictionary of image tensors to display or save.
            epoch: current epoch (unused; kept for interface compatibility).
            step: global step, used as the TensorBoard step or as the PNG
                file name.
        """
        if not self.tensorboard_log or not visuals:
            # Nothing to log; also avoids torch.cat failing on an empty list.
            return

        # (x + 1) / 2 rescaling; assumes inputs are in [-1, 1] -- TODO confirm.
        images = [(tensor.data.cpu() + 1) / 2 for tensor in visuals.values()]
        output = torch.cat(images, 0)

        if self.opt.isTrain:
            img_grid = vutils.make_grid(
                output, nrow=self.opt.batchSize, padding=0, normalize=False
            )
            self.writer.add_image("Face_SPADE/training_samples", img_grid, step)
        else:
            vutils.save_image(
                output,
                os.path.join(self.log_dir, str(step) + ".png"),
                nrow=self.opt.batchSize,
                padding=0,
                normalize=False,
            )

    def plot_current_errors(self, errors, step):
        """Write the current loss values to the summary writer.

        Args:
            errors: dictionary mapping loss labels to tensor values; each
                value is reduced with ``.mean().float()`` before logging.
            step: global step under which the scalars are recorded.
        """
        writer = self.writer
        if writer is None:
            return

        if self.tf_log:
            # One scalar per loss, logged under its own label.
            for tag, value in errors.items():
                writer.add_scalar(tag, value.mean().float(), step)

        if self.tensorboard_log:
            # Loss terms absent from ``errors`` are skipped rather than
            # raising KeyError.
            for key in ("GAN_Feat", "VGG"):
                if key in errors:
                    writer.add_scalar("Loss/" + key, errors[key].mean().float(), step)

            gan_scalars = {}
            if "GAN" in errors:
                gan_scalars["G"] = errors["GAN"].mean().float()
            if "D_Fake" in errors and "D_real" in errors:
                # Discriminator curve is the average of its fake and real terms.
                gan_scalars["D"] = (
                    errors["D_Fake"].mean().float() + errors["D_real"].mean().float()
                ) / 2
            if gan_scalars:
                writer.add_scalars("Loss/GAN", gan_scalars, step)

    def print_current_errors(self, epoch, i, errors, t):
        """Print the current losses and append the same line to the loss log.

        Args:
            epoch: current epoch number.
            i: iteration count.
            errors: same format as ``errors`` of ``plot_current_errors``.
            t: time value printed with three decimals.
        """
        header = "(epoch: %d, iters: %d, time: %.3f) " % (epoch, i, t)
        body = "".join(
            "%s: %.3f " % (label, value.mean().float())
            for label, value in errors.items()
        )
        message = header + body

        print(message)
        with open(self.log_name, "a") as log_file:
            log_file.write("%s\n" % message)

    def convert_visuals_to_numpy(self, visuals):
        """Convert every tensor in ``visuals`` to an image array, in place.

        The ``"input_label"`` entry is converted with ``util.tensor2label``
        (the original note describes the result as B*H*W*C, 0-255 numpy);
        all other entries go through ``util.tensor2im``.

        Args:
            visuals: dictionary of tensors; its values are replaced.

        Returns:
            The same ``visuals`` dictionary, now holding the converted values.
        """
        # Tiling is requested only for batches larger than 8.
        tile = self.opt.batchSize > 8
        for key, tensor in visuals.items():
            if key == "input_label":
                converted = util.tensor2label(tensor, self.opt.label_nc + 2, tile=tile)
            else:
                converted = util.tensor2im(tensor, tile=tile)
            # Only existing keys are reassigned, so mutating during iteration
            # is safe.
            visuals[key] = converted
        return visuals

    def save_images(self, webpage, visuals, image_path):
        """Save each visual as ``<label>/<name>.png`` and list it on the webpage.

        Args:
            webpage: object providing ``get_image_dir``, ``add_header`` and
                ``add_images``.
            visuals: dictionary of label -> tensor; converted to numpy first.
            image_path: sequence of paths; the base name of the first entry
                (without extension) names the saved files.
        """
        visuals = self.convert_visuals_to_numpy(visuals)
        image_dir = webpage.get_image_dir()
        short_path = ntpath.basename(image_path[0])
        name = os.path.splitext(short_path)[0]
        webpage.add_header(name)

        ims = []
        txts = []
        links = []
        for label, image_numpy in visuals.items():
            image_name = os.path.join(label, "%s.png" % (name))
            save_path = os.path.join(image_dir, image_name)
            util.save_image(image_numpy, save_path, create_dir=True)
            ims.append(image_name)
            txts.append(label)
            links.append(image_name)
        webpage.add_images(ims, txts, links, width=self.win_size)