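# NOTE: the next three lines appear to be web-page banner text captured with this file, not source code.
# Spaces:
# Runtime error
# Runtime error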
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os | |
import ntpath | |
import time | |
from . import util | |
import scipy.misc | |
try: | |
from StringIO import StringIO # Python 2.7 | |
except ImportError: | |
from io import BytesIO # Python 3.x | |
import torchvision.utils as vutils | |
from tensorboardX import SummaryWriter | |
import torch | |
import numpy as np | |
class Visualizer:
    """Report images and losses to tensorboardX, stdout and disk.

    Behavior is driven by the option object passed to ``__init__``:

    * ``opt.tensorboard_log`` enables image grids and grouped loss scalars.
    * ``opt.tf_log`` (read only when training) enables one scalar per loss tag.
    * When training, a ``loss_log.txt`` file receives every printed loss line.
    """

    def __init__(self, opt):
        """Set up log directories, the summary writer and the loss log file.

        Args:
            opt: option namespace. Reads ``isTrain``, ``tensorboard_log``,
                ``display_winsize``, ``name``; additionally ``tf_log`` and
                ``checkpoints_dir`` when training, and ``checkpoints_dir`` /
                ``results_dir`` when testing with ``tensorboard_log``.
        """
        self.opt = opt
        # `and` short-circuits, so `opt.tf_log` is only read when training.
        self.tf_log = opt.isTrain and opt.tf_log
        self.tensorboard_log = opt.tensorboard_log
        self.win_size = opt.display_winsize
        self.name = opt.name

        # Explicit defaults let the other methods check for availability
        # instead of failing with AttributeError in configurations where
        # no writer / log file is created.
        self.writer = None
        self.log_name = None

        if opt.isTrain:
            # Either logging flag needs the writer: the per-tag `tf_log`
            # scalars are emitted through the same tensorboardX writer.
            if self.tensorboard_log or self.tf_log:
                self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, "logs")
                self._ensure_dir(self.log_dir)
                self.writer = SummaryWriter(log_dir=self.log_dir)

            self.log_name = os.path.join(opt.checkpoints_dir, opt.name, "loss_log.txt")
            with open(self.log_name, "a") as log_file:
                now = time.strftime("%c")
                log_file.write("================ Training Loss (%s) ================\n" % now)
        elif self.tensorboard_log:
            # NOTE(review): this print looks like leftover debug output; kept
            # so stdout stays identical for existing users.
            print("hi :)")
            # At test time images are written as PNG files into this folder.
            self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, opt.results_dir)
            self._ensure_dir(self.log_dir)

    @staticmethod
    def _ensure_dir(path):
        """Create ``path`` (and parents) unless it already is a directory."""
        try:
            os.makedirs(path)
        except OSError:
            # Tolerate "already exists" (including a concurrent creator);
            # re-raise anything that left us without a directory.
            if not os.path.isdir(path):
                raise

    @staticmethod
    def _scalar(value):
        """Reduce a loss value to a plain Python float via its mean."""
        return float(value.mean())

    # |visuals|: dictionary of images to display or save
    def display_current_results(self, visuals, epoch, step):
        """Log the current images as one grid (training) or PNG file (testing).

        Does nothing unless ``tensorboard_log`` is enabled and ``visuals`` is
        non-empty.

        Args:
            visuals: mapping of name -> image tensor; all tensors are
                concatenated along dimension 0.
            epoch: unused; kept for interface compatibility.
            step: global step for the summary, or the PNG file stem.
        """
        if not self.tensorboard_log or not visuals:
            # An empty dict would make torch.cat raise on an empty list.
            return

        # (x + 1) / 2 rescaling; presumably maps [-1, 1] images to [0, 1]
        # -- TODO confirm against the generator's output range.
        output = torch.cat([(tensor.data.cpu() + 1) / 2 for tensor in visuals.values()], 0)

        if self.opt.isTrain:
            img_grid = vutils.make_grid(output, nrow=self.opt.batchSize, padding=0, normalize=False)
            self.writer.add_image("Face_SPADE/training_samples", img_grid, step)
        else:
            vutils.save_image(
                output,
                os.path.join(self.log_dir, str(step) + ".png"),
                nrow=self.opt.batchSize,
                padding=0,
                normalize=False,
            )

    # errors: dictionary of error labels and values
    def plot_current_errors(self, errors, step):
        """Write the current loss values to the summary writer.

        Args:
            errors: mapping of loss name -> value exposing ``.mean()``.
            step: global step attached to every scalar.

        With ``tf_log`` every entry is logged under its own tag. With
        ``tensorboard_log`` the entries "GAN_Feat" and "VGG" go to
        "Loss/<name>", and "GAN" plus the average of "D_Fake"/"D_real" are
        grouped under "Loss/GAN". Missing entries are skipped.
        """
        if self.writer is None:
            # No writer exists outside training; nothing to plot to.
            return

        if self.tf_log:
            for tag, value in errors.items():
                self.writer.add_scalar(tag, self._scalar(value), step)

        if self.tensorboard_log:
            for key in ("GAN_Feat", "VGG"):
                if key in errors:
                    self.writer.add_scalar("Loss/" + key, self._scalar(errors[key]), step)

            gan_scalars = {}
            if "GAN" in errors:
                gan_scalars["G"] = self._scalar(errors["GAN"])
            if "D_Fake" in errors and "D_real" in errors:
                gan_scalars["D"] = (
                    self._scalar(errors["D_Fake"]) + self._scalar(errors["D_real"])
                ) / 2
            if gan_scalars:
                self.writer.add_scalars("Loss/GAN", gan_scalars, step)

    # errors: same format as |errors| of plot_current_errors
    def print_current_errors(self, epoch, i, errors, t):
        """Print one line of loss values and append it to the loss log.

        Args:
            epoch: current epoch number.
            i: iteration count.
            errors: mapping of loss name -> value exposing ``.mean()``.
            t: time value printed with three decimals.

        The line is appended to ``loss_log.txt`` only when that file was
        opened in ``__init__`` (training); otherwise it is just printed.
        """
        message = "(epoch: %d, iters: %d, time: %.3f) " % (epoch, i, t)
        message += "".join("%s: %.3f " % (k, self._scalar(v)) for k, v in errors.items())

        print(message)
        if self.log_name is not None:
            with open(self.log_name, "a") as log_file:
                log_file.write("%s\n" % message)

    def convert_visuals_to_numpy(self, visuals):
        """Convert every tensor in ``visuals`` to an image array, in place.

        The "input_label" entry goes through ``util.tensor2label`` with
        ``opt.label_nc + 2`` labels; every other entry goes through
        ``util.tensor2im``. Tiling is requested when ``opt.batchSize > 8``.

        Returns:
            The same ``visuals`` dict, with its values replaced.
        """
        tile = self.opt.batchSize > 8
        for key, t in visuals.items():
            if key == "input_label":
                # Per the original note: B*H*W*C, 0-255 numpy.
                visuals[key] = util.tensor2label(t, self.opt.label_nc + 2, tile=tile)
            else:
                visuals[key] = util.tensor2im(t, tile=tile)
        return visuals

    # save image to the disk
    def save_images(self, webpage, visuals, image_path):
        """Save each visual as ``<label>/<name>.png`` and list it on ``webpage``.

        Args:
            webpage: object providing ``get_image_dir``, ``add_header`` and
                ``add_images``.
            visuals: mapping of label -> image tensor (converted in place).
            image_path: sequence of source paths; only the first one is used
                to derive the output file name.
        """
        visuals = self.convert_visuals_to_numpy(visuals)

        image_dir = webpage.get_image_dir()
        # ntpath.basename splits on both "/" and "\\" separators.
        name = os.path.splitext(ntpath.basename(image_path[0]))[0]

        webpage.add_header(name)
        ims = []
        txts = []
        links = []

        for label, image_numpy in visuals.items():
            image_name = os.path.join(label, "%s.png" % name)
            util.save_image(image_numpy, os.path.join(image_dir, image_name), create_dir=True)

            ims.append(image_name)
            txts.append(label)
            links.append(image_name)
        webpage.add_images(ims, txts, links, width=self.win_size)