Spaces:
Runtime error
Runtime error
""" | |
Here are some use cases: | |
python main.py --config config/all.yaml --experiment experiment_8x1 --signature demo1 --target data/demo1.png | |
""" | |
import pydiffvg | |
import torch | |
import cv2 | |
import matplotlib.pyplot as plt | |
import random | |
import argparse | |
import math | |
import errno | |
from tqdm import tqdm | |
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR | |
from torch.nn.functional import adaptive_avg_pool2d | |
import warnings | |
warnings.filterwarnings("ignore") | |
import PIL | |
import PIL.Image | |
import os | |
import os.path as osp | |
import numpy as np | |
import numpy.random as npr | |
import shutil | |
import copy | |
# import skfmm | |
from xing_loss import xing_loss | |
import yaml | |
from easydict import EasyDict as edict | |
pydiffvg.set_print_timing(False) | |
gamma = 1.0 | |
########## | |
# helper # | |
########## | |
from utils import \ | |
get_experiment_id, \ | |
get_path_schedule, \ | |
edict_2_dict, \ | |
check_and_create_dir | |
def get_bezier_circle(radius=1, segments=4, bias=None): | |
points = [] | |
if bias is None: | |
bias = (random.random(), random.random()) | |
avg_degree = 360 / (segments*3) | |
for i in range(0, segments*3): | |
point = (np.cos(np.deg2rad(i * avg_degree)), | |
np.sin(np.deg2rad(i * avg_degree))) | |
points.append(point) | |
points = torch.tensor(points) | |
points = (points)*radius + torch.tensor(bias).unsqueeze(dim=0) | |
points = points.type(torch.FloatTensor) | |
return points | |
def get_sdf(phi, method='skfmm', **kwargs): | |
if method == 'skfmm': | |
import skfmm | |
phi = (phi-0.5)*2 | |
if (phi.max() <= 0) or (phi.min() >= 0): | |
return np.zeros(phi.shape).astype(np.float32) | |
sd = skfmm.distance(phi, dx=1) | |
flip_negative = kwargs.get('flip_negative', True) | |
if flip_negative: | |
sd = np.abs(sd) | |
truncate = kwargs.get('truncate', 10) | |
sd = np.clip(sd, -truncate, truncate) | |
# print(f"max sd value is: {sd.max()}") | |
zero2max = kwargs.get('zero2max', True) | |
if zero2max and flip_negative: | |
sd = sd.max() - sd | |
elif zero2max: | |
raise ValueError | |
normalize = kwargs.get('normalize', 'sum') | |
if normalize == 'sum': | |
sd /= sd.sum() | |
elif normalize == 'to1': | |
sd /= sd.max() | |
return sd | |
def parse_args(): | |
parser = argparse.ArgumentParser() | |
parser.add_argument('--debug', action='store_true', default=False) | |
parser.add_argument("--config", type=str) | |
parser.add_argument("--experiment", type=str) | |
parser.add_argument("--seed", type=int) | |
parser.add_argument("--target", type=str, help="target image path") | |
parser.add_argument('--log_dir', metavar='DIR', default="log/debug") | |
parser.add_argument('--initial', type=str, default="random", choices=['random', 'circle']) | |
parser.add_argument('--signature', nargs='+', type=str) | |
parser.add_argument('--seginit', nargs='+', type=str) | |
parser.add_argument("--num_segments", type=int, default=4) | |
# parser.add_argument("--num_paths", type=str, default="1,1,1") | |
# parser.add_argument("--num_iter", type=int, default=500) | |
# parser.add_argument('--free', action='store_true') | |
# Please ensure that image resolution is divisible by pool_size; otherwise the performance would drop a lot. | |
# parser.add_argument('--pool_size', type=int, default=40, help="the pooled image size for next path initialization") | |
# parser.add_argument('--save_loss', action='store_true') | |
# parser.add_argument('--save_init', action='store_true') | |
# parser.add_argument('--save_image', action='store_true') | |
# parser.add_argument('--save_video', action='store_true') | |
# parser.add_argument('--print_weight', action='store_true') | |
# parser.add_argument('--circle_init_radius', type=float) | |
cfg = edict() | |
args = parser.parse_args() | |
cfg.debug = args.debug | |
cfg.config = args.config | |
cfg.experiment = args.experiment | |
cfg.seed = args.seed | |
cfg.target = args.target | |
cfg.log_dir = args.log_dir | |
cfg.initial = args.initial | |
cfg.signature = args.signature | |
# set cfg num_segments in command | |
cfg.num_segments = args.num_segments | |
if args.seginit is not None: | |
cfg.seginit = edict() | |
cfg.seginit.type = args.seginit[0] | |
if cfg.seginit.type == 'circle': | |
cfg.seginit.radius = float(args.seginit[1]) | |
return cfg | |
def ycrcb_conversion(im, format='[bs x 3 x 2D]', reverse=False): | |
mat = torch.FloatTensor([ | |
[ 65.481/255, 128.553/255, 24.966/255], # ranged_from [0, 219/255] | |
[-37.797/255, -74.203/255, 112.000/255], # ranged_from [-112/255, 112/255] | |
[112.000/255, -93.786/255, -18.214/255], # ranged_from [-112/255, 112/255] | |
]).to(im.device) | |
if reverse: | |
mat = mat.inverse() | |
if format == '[bs x 3 x 2D]': | |
im = im.permute(0, 2, 3, 1) | |
im = torch.matmul(im, mat.T) | |
im = im.permute(0, 3, 1, 2).contiguous() | |
return im | |
elif format == '[2D x 3]': | |
im = torch.matmul(im, mat.T) | |
return im | |
else: | |
raise ValueError | |
class random_coord_init(): | |
def __init__(self, canvas_size): | |
self.canvas_size = canvas_size | |
def __call__(self): | |
h, w = self.canvas_size | |
return [npr.uniform(0, 1)*w, npr.uniform(0, 1)*h] | |
class naive_coord_init(): | |
def __init__(self, pred, gt, format='[bs x c x 2D]', replace_sampling=True): | |
if isinstance(pred, torch.Tensor): | |
pred = pred.detach().cpu().numpy() | |
if isinstance(gt, torch.Tensor): | |
gt = gt.detach().cpu().numpy() | |
if format == '[bs x c x 2D]': | |
self.map = ((pred[0] - gt[0])**2).sum(0) | |
elif format == ['[2D x c]']: | |
self.map = ((pred - gt)**2).sum(-1) | |
else: | |
raise ValueError | |
self.replace_sampling = replace_sampling | |
def __call__(self): | |
coord = np.where(self.map == self.map.max()) | |
coord_h, coord_w = coord[0][0], coord[1][0] | |
if self.replace_sampling: | |
self.map[coord_h, coord_w] = -1 | |
return [coord_w, coord_h] | |
class sparse_coord_init(): | |
def __init__(self, pred, gt, format='[bs x c x 2D]', quantile_interval=200, nodiff_thres=0.1): | |
if isinstance(pred, torch.Tensor): | |
pred = pred.detach().cpu().numpy() | |
if isinstance(gt, torch.Tensor): | |
gt = gt.detach().cpu().numpy() | |
if format == '[bs x c x 2D]': | |
self.map = ((pred[0] - gt[0])**2).sum(0) | |
self.reference_gt = copy.deepcopy( | |
np.transpose(gt[0], (1, 2, 0))) | |
elif format == ['[2D x c]']: | |
self.map = (np.abs(pred - gt)).sum(-1) | |
self.reference_gt = copy.deepcopy(gt[0]) | |
else: | |
raise ValueError | |
# OptionA: Zero too small errors to avoid the error too small deadloop | |
self.map[self.map < nodiff_thres] = 0 | |
quantile_interval = np.linspace(0., 1., quantile_interval) | |
quantized_interval = np.quantile(self.map, quantile_interval) | |
# remove redundant | |
quantized_interval = np.unique(quantized_interval) | |
quantized_interval = sorted(quantized_interval[1:-1]) | |
self.map = np.digitize(self.map, quantized_interval, right=False) | |
self.map = np.clip(self.map, 0, 255).astype(np.uint8) | |
self.idcnt = {} | |
for idi in sorted(np.unique(self.map)): | |
self.idcnt[idi] = (self.map==idi).sum() | |
self.idcnt.pop(min(self.idcnt.keys())) | |
# remove smallest one to remove the correct region | |
def __call__(self): | |
if len(self.idcnt) == 0: | |
h, w = self.map.shape | |
return [npr.uniform(0, 1)*w, npr.uniform(0, 1)*h] | |
target_id = max(self.idcnt, key=self.idcnt.get) | |
_, component, cstats, ccenter = cv2.connectedComponentsWithStats( | |
(self.map==target_id).astype(np.uint8), connectivity=4) | |
# remove cid = 0, it is the invalid area | |
csize = [ci[-1] for ci in cstats[1:]] | |
target_cid = csize.index(max(csize))+1 | |
center = ccenter[target_cid][::-1] | |
coord = np.stack(np.where(component == target_cid)).T | |
dist = np.linalg.norm(coord-center, axis=1) | |
target_coord_id = np.argmin(dist) | |
coord_h, coord_w = coord[target_coord_id] | |
# replace_sampling | |
self.idcnt[target_id] -= max(csize) | |
if self.idcnt[target_id] == 0: | |
self.idcnt.pop(target_id) | |
self.map[component == target_cid] = 0 | |
return [coord_w, coord_h] | |
def init_shapes(num_paths, | |
num_segments, | |
canvas_size, | |
seginit_cfg, | |
shape_cnt, | |
pos_init_method=None, | |
trainable_stroke=False, | |
gt=None, | |
**kwargs): | |
shapes = [] | |
shape_groups = [] | |
h, w = canvas_size | |
# change path init location | |
if pos_init_method is None: | |
pos_init_method = random_coord_init(canvas_size=canvas_size) | |
for i in range(num_paths): | |
num_control_points = [2] * num_segments | |
if seginit_cfg.type=="random": | |
points = [] | |
p0 = pos_init_method() | |
color_ref = copy.deepcopy(p0) | |
points.append(p0) | |
for j in range(num_segments): | |
radius = seginit_cfg.radius | |
p1 = (p0[0] + radius * npr.uniform(-0.5, 0.5), | |
p0[1] + radius * npr.uniform(-0.5, 0.5)) | |
p2 = (p1[0] + radius * npr.uniform(-0.5, 0.5), | |
p1[1] + radius * npr.uniform(-0.5, 0.5)) | |
p3 = (p2[0] + radius * npr.uniform(-0.5, 0.5), | |
p2[1] + radius * npr.uniform(-0.5, 0.5)) | |
points.append(p1) | |
points.append(p2) | |
if j < num_segments - 1: | |
points.append(p3) | |
p0 = p3 | |
points = torch.FloatTensor(points) | |
# circle points initialization | |
elif seginit_cfg.type=="circle": | |
radius = seginit_cfg.radius | |
if radius is None: | |
radius = npr.uniform(0.5, 1) | |
center = pos_init_method() | |
color_ref = copy.deepcopy(center) | |
points = get_bezier_circle( | |
radius=radius, segments=num_segments, | |
bias=center) | |
path = pydiffvg.Path(num_control_points = torch.LongTensor(num_control_points), | |
points = points, | |
stroke_width = torch.tensor(0.0), | |
is_closed = True) | |
shapes.append(path) | |
# !!!!!!problem is here. the shape group shape_ids is wrong | |
if gt is not None: | |
wref, href = color_ref | |
wref = max(0, min(int(wref), w-1)) | |
href = max(0, min(int(href), h-1)) | |
fill_color_init = list(gt[0, :, href, wref]) + [1.] | |
fill_color_init = torch.FloatTensor(fill_color_init) | |
stroke_color_init = torch.FloatTensor(npr.uniform(size=[4])) | |
else: | |
fill_color_init = torch.FloatTensor(npr.uniform(size=[4])) | |
stroke_color_init = torch.FloatTensor(npr.uniform(size=[4])) | |
path_group = pydiffvg.ShapeGroup( | |
shape_ids = torch.LongTensor([shape_cnt+i]), | |
fill_color = fill_color_init, | |
stroke_color = stroke_color_init, | |
) | |
shape_groups.append(path_group) | |
point_var = [] | |
color_var = [] | |
for path in shapes: | |
path.points.requires_grad = True | |
point_var.append(path.points) | |
for group in shape_groups: | |
group.fill_color.requires_grad = True | |
color_var.append(group.fill_color) | |
if trainable_stroke: | |
stroke_width_var = [] | |
stroke_color_var = [] | |
for path in shapes: | |
path.stroke_width.requires_grad = True | |
stroke_width_var.append(path.stroke_width) | |
for group in shape_groups: | |
group.stroke_color.requires_grad = True | |
stroke_color_var.append(group.stroke_color) | |
return shapes, shape_groups, point_var, color_var, stroke_width_var, stroke_color_var | |
else: | |
return shapes, shape_groups, point_var, color_var | |
class linear_decay_lrlambda_f(object): | |
def __init__(self, decay_every, decay_ratio): | |
self.decay_every = decay_every | |
self.decay_ratio = decay_ratio | |
def __call__(self, n): | |
decay_time = n//self.decay_every | |
decay_step = n %self.decay_every | |
lr_s = self.decay_ratio**decay_time | |
lr_e = self.decay_ratio**(decay_time+1) | |
r = decay_step/self.decay_every | |
lr = lr_s * (1-r) + lr_e * r | |
return lr | |
def main_func(target, experiment, num_iter, cfg_arg): | |
with open(cfg_arg.config, 'r') as f: | |
cfg = yaml.load(f, Loader=yaml.FullLoader) | |
cfg_default = edict(cfg['default']) | |
cfg = edict(cfg[cfg_arg.experiment]) | |
cfg.update(cfg_default) | |
cfg.update(cfg_arg) | |
cfg.exid = get_experiment_id(cfg.debug) | |
cfg.experiment_dir = \ | |
osp.join(cfg.log_dir, '{}_{}'.format(cfg.exid, '_'.join(cfg.signature))) | |
cfg.target = target | |
cfg.experiment = experiment | |
cfg.num_iter = num_iter | |
configfile = osp.join(cfg.experiment_dir, 'config.yaml') | |
check_and_create_dir(configfile) | |
with open(osp.join(configfile), 'w') as f: | |
yaml.dump(edict_2_dict(cfg), f) | |
# Use GPU if available | |
pydiffvg.set_use_gpu(torch.cuda.is_available()) | |
device = pydiffvg.get_device() | |
# gt = np.array(PIL.Image.open(cfg.target)) | |
gt = np.array(cfg.target) | |
print(f"Input image shape is: {gt.shape}") | |
if len(gt.shape) == 2: | |
print("Converting the gray-scale image to RGB.") | |
gt = gt.unsqueeze(dim=-1).repeat(1,1,3) | |
if gt.shape[2] == 4: | |
print("Input image includes alpha channel, simply dropout alpha channel.") | |
gt = gt[:, :, :3] | |
gt = (gt/255).astype(np.float32) | |
gt = torch.FloatTensor(gt).permute(2, 0, 1)[None].to(device) | |
if cfg.use_ycrcb: | |
gt = ycrcb_conversion(gt) | |
h, w = gt.shape[2:] | |
path_schedule = get_path_schedule(**cfg.path_schedule) | |
if cfg.seed is not None: | |
random.seed(cfg.seed) | |
npr.seed(cfg.seed) | |
torch.manual_seed(cfg.seed) | |
render = pydiffvg.RenderFunction.apply | |
shapes_record, shape_groups_record = [], [] | |
region_loss = None | |
loss_matrix = [] | |
para_point, para_color = {}, {} | |
if cfg.trainable.stroke: | |
para_stroke_width, para_stroke_color = {}, {} | |
pathn_record = [] | |
# Background | |
if cfg.trainable.bg: | |
# meancolor = gt.mean([2, 3])[0] | |
para_bg = torch.tensor([1., 1., 1.], requires_grad=True, device=device) | |
else: | |
if cfg.use_ycrcb: | |
para_bg = torch.tensor([219/255, 0, 0], requires_grad=False, device=device) | |
else: | |
para_bg = torch.tensor([1., 1., 1.], requires_grad=False, device=device) | |
################## | |
# start_training # | |
################## | |
loss_weight = None | |
loss_weight_keep = 0 | |
if cfg.coord_init.type == 'naive': | |
pos_init_method = naive_coord_init( | |
para_bg.view(1, -1, 1, 1).repeat(1, 1, h, w), gt) | |
elif cfg.coord_init.type == 'sparse': | |
pos_init_method = sparse_coord_init( | |
para_bg.view(1, -1, 1, 1).repeat(1, 1, h, w), gt) | |
elif cfg.coord_init.type == 'random': | |
pos_init_method = random_coord_init([h, w]) | |
else: | |
raise ValueError | |
lrlambda_f = linear_decay_lrlambda_f(cfg.num_iter, 0.4) | |
optim_schedular_dict = {} | |
for path_idx, pathn in enumerate(path_schedule): | |
loss_list = [] | |
print("=> Adding [{}] paths, [{}] ...".format(pathn, cfg.seginit.type)) | |
pathn_record.append(pathn) | |
pathn_record_str = '-'.join([str(i) for i in pathn_record]) | |
# initialize new shapes related stuffs. | |
if cfg.trainable.stroke: | |
shapes, shape_groups, point_var, color_var, stroke_width_var, stroke_color_var = init_shapes( | |
pathn, cfg.num_segments, (h, w), | |
cfg.seginit, len(shapes_record), | |
pos_init_method, | |
trainable_stroke=True, | |
gt=gt, ) | |
para_stroke_width[path_idx] = stroke_width_var | |
para_stroke_color[path_idx] = stroke_color_var | |
else: | |
shapes, shape_groups, point_var, color_var = init_shapes( | |
pathn, cfg.num_segments, (h, w), | |
cfg.seginit, len(shapes_record), | |
pos_init_method, | |
trainable_stroke=False, | |
gt=gt, ) | |
shapes_record += shapes | |
shape_groups_record += shape_groups | |
if cfg.save.init: | |
filename = os.path.join( | |
cfg.experiment_dir, "svg-init", | |
"{}-init.svg".format(pathn_record_str)) | |
check_and_create_dir(filename) | |
pydiffvg.save_svg( | |
filename, w, h, | |
shapes_record, shape_groups_record) | |
para = {} | |
if (cfg.trainable.bg) and (path_idx == 0): | |
para['bg'] = [para_bg] | |
para['point'] = point_var | |
para['color'] = color_var | |
if cfg.trainable.stroke: | |
para['stroke_width'] = stroke_width_var | |
para['stroke_color'] = stroke_color_var | |
pg = [{'params' : para[ki], 'lr' : cfg.lr_base[ki]} for ki in sorted(para.keys())] | |
optim = torch.optim.Adam(pg) | |
if cfg.trainable.record: | |
scheduler = LambdaLR( | |
optim, lr_lambda=lrlambda_f, last_epoch=-1) | |
else: | |
scheduler = LambdaLR( | |
optim, lr_lambda=lrlambda_f, last_epoch=cfg.num_iter) | |
optim_schedular_dict[path_idx] = (optim, scheduler) | |
# Inner loop training | |
t_range = tqdm(range(cfg.num_iter)) | |
for t in t_range: | |
for _, (optim, _) in optim_schedular_dict.items(): | |
optim.zero_grad() | |
# Forward pass: render the image. | |
scene_args = pydiffvg.RenderFunction.serialize_scene( | |
w, h, shapes_record, shape_groups_record) | |
img = render(w, h, 2, 2, t, None, *scene_args) | |
# Compose img with white background | |
img = img[:, :, 3:4] * img[:, :, :3] + \ | |
para_bg * (1 - img[:, :, 3:4]) | |
if cfg.save.video: | |
filename = os.path.join( | |
cfg.experiment_dir, "video-png", | |
"{}-iter{}.png".format(pathn_record_str, t)) | |
check_and_create_dir(filename) | |
if cfg.use_ycrcb: | |
imshow = ycrcb_conversion( | |
img, format='[2D x 3]', reverse=True).detach().cpu() | |
else: | |
imshow = img.detach().cpu() | |
pydiffvg.imwrite(imshow, filename, gamma=gamma) | |
# ### added for app | |
# if t%30==0 and t !=0 : | |
# # print(f"debug: {t}, {filename} {img.size()}") | |
# return img.detach().cpu().numpy(), t | |
x = img.unsqueeze(0).permute(0, 3, 1, 2) # HWC -> NCHW | |
if cfg.use_ycrcb: | |
color_reweight = torch.FloatTensor([255/219, 255/224, 255/255]).to(device) | |
loss = ((x-gt)*(color_reweight.view(1, -1, 1, 1)))**2 | |
else: | |
loss = ((x-gt)**2) | |
if cfg.loss.use_l1_loss: | |
loss = abs(x-gt) | |
if cfg.loss.use_distance_weighted_loss: | |
if cfg.use_ycrcb: | |
raise ValueError | |
shapes_forsdf = copy.deepcopy(shapes) | |
shape_groups_forsdf = copy.deepcopy(shape_groups) | |
for si in shapes_forsdf: | |
si.stroke_width = torch.FloatTensor([0]).to(device) | |
for sg_idx, sgi in enumerate(shape_groups_forsdf): | |
sgi.fill_color = torch.FloatTensor([1, 1, 1, 1]).to(device) | |
sgi.shape_ids = torch.LongTensor([sg_idx]).to(device) | |
sargs_forsdf = pydiffvg.RenderFunction.serialize_scene( | |
w, h, shapes_forsdf, shape_groups_forsdf) | |
with torch.no_grad(): | |
im_forsdf = render(w, h, 2, 2, 0, None, *sargs_forsdf) | |
# use alpha channel is a trick to get 0-1 image | |
im_forsdf = (im_forsdf[:, :, 3]).detach().cpu().numpy() | |
loss_weight = get_sdf(im_forsdf, normalize='to1') | |
loss_weight += loss_weight_keep | |
loss_weight = np.clip(loss_weight, 0, 1) | |
loss_weight = torch.FloatTensor(loss_weight).to(device) | |
if cfg.save.loss: | |
save_loss = loss.squeeze(dim=0).mean(dim=0,keepdim=False).cpu().detach().numpy() | |
save_weight = loss_weight.cpu().detach().numpy() | |
save_weighted_loss = save_loss*save_weight | |
# normalize to [0,1] | |
save_loss = (save_loss - np.min(save_loss))/np.ptp(save_loss) | |
save_weight = (save_weight - np.min(save_weight))/np.ptp(save_weight) | |
save_weighted_loss = (save_weighted_loss - np.min(save_weighted_loss))/np.ptp(save_weighted_loss) | |
# save | |
plt.imshow(save_loss, cmap='Reds') | |
plt.axis('off') | |
# plt.colorbar() | |
filename = os.path.join(cfg.experiment_dir, "loss", "{}-iter{}-mseloss.png".format(pathn_record_str, t)) | |
check_and_create_dir(filename) | |
plt.savefig(filename, dpi=800) | |
plt.close() | |
plt.imshow(save_weight, cmap='Greys') | |
plt.axis('off') | |
# plt.colorbar() | |
filename = os.path.join(cfg.experiment_dir, "loss", "{}-iter{}-sdfweight.png".format(pathn_record_str, t)) | |
plt.savefig(filename, dpi=800) | |
plt.close() | |
plt.imshow(save_weighted_loss, cmap='Reds') | |
plt.axis('off') | |
# plt.colorbar() | |
filename = os.path.join(cfg.experiment_dir, "loss", "{}-iter{}-weightedloss.png".format(pathn_record_str, t)) | |
plt.savefig(filename, dpi=800) | |
plt.close() | |
if loss_weight is None: | |
loss = loss.sum(1).mean() | |
else: | |
loss = (loss.sum(1)*loss_weight).mean() | |
# if (cfg.loss.bis_loss_weight is not None) and (cfg.loss.bis_loss_weight > 0): | |
# loss_bis = bezier_intersection_loss(point_var[0]) * cfg.loss.bis_loss_weight | |
# loss = loss + loss_bis | |
if (cfg.loss.xing_loss_weight is not None) \ | |
and (cfg.loss.xing_loss_weight > 0): | |
loss_xing = xing_loss(point_var) * cfg.loss.xing_loss_weight | |
loss = loss + loss_xing | |
loss_list.append(loss.item()) | |
t_range.set_postfix({'loss': loss.item()}) | |
loss.backward() | |
# step | |
for _, (optim, scheduler) in optim_schedular_dict.items(): | |
optim.step() | |
scheduler.step() | |
for group in shape_groups_record: | |
group.fill_color.data.clamp_(0.0, 1.0) | |
if cfg.loss.use_distance_weighted_loss: | |
loss_weight_keep = loss_weight.detach().cpu().numpy() * 1 | |
if not cfg.trainable.record: | |
for _, pi in pg.items(): | |
for ppi in pi: | |
pi.require_grad = False | |
optim_schedular_dict = {} | |
if cfg.save.image: | |
filename = os.path.join( | |
cfg.experiment_dir, "demo-png", "{}.png".format(pathn_record_str)) | |
check_and_create_dir(filename) | |
if cfg.use_ycrcb: | |
imshow = ycrcb_conversion( | |
img, format='[2D x 3]', reverse=True).detach().cpu() | |
else: | |
imshow = img.detach().cpu() | |
pydiffvg.imwrite(imshow, filename, gamma=gamma) | |
svg_app_file_name = "" | |
if cfg.save.output: | |
filename = os.path.join( | |
cfg.experiment_dir, "output-svg", "{}.svg".format(pathn_record_str)) | |
check_and_create_dir(filename) | |
pydiffvg.save_svg(filename, w, h, shapes_record, shape_groups_record) | |
svg_app_file_name = filename | |
loss_matrix.append(loss_list) | |
# calculate the pixel loss | |
# pixel_loss = ((x-gt)**2).sum(dim=1, keepdim=True).sqrt_() # [N,1,H, W] | |
# region_loss = adaptive_avg_pool2d(pixel_loss, cfg.region_loss_pool_size) | |
# loss_weight = torch.softmax(region_loss.reshape(1, 1, -1), dim=-1)\ | |
# .reshape_as(region_loss) | |
pos_init_method = naive_coord_init(x, gt) | |
if cfg.coord_init.type == 'naive': | |
pos_init_method = naive_coord_init(x, gt) | |
elif cfg.coord_init.type == 'sparse': | |
pos_init_method = sparse_coord_init(x, gt) | |
elif cfg.coord_init.type == 'random': | |
pos_init_method = random_coord_init([h, w]) | |
else: | |
raise ValueError | |
if cfg.save.video: | |
print("saving iteration video...") | |
img_array = [] | |
for ii in range(0, cfg.num_iter): | |
filename = os.path.join( | |
cfg.experiment_dir, "video-png", | |
"{}-iter{}.png".format(pathn_record_str, ii)) | |
img = cv2.imread(filename) | |
# cv2.putText( | |
# img, "Path:{} \nIteration:{}".format(pathn_record_str, ii), | |
# (10, 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1) | |
img_array.append(img) | |
videoname = os.path.join( | |
cfg.experiment_dir, "video-avi", | |
"{}.avi".format(pathn_record_str)) | |
check_and_create_dir(videoname) | |
out = cv2.VideoWriter( | |
videoname, | |
# cv2.VideoWriter_fourcc(*'mp4v'), | |
cv2.VideoWriter_fourcc(*'FFV1'), | |
20.0, (w, h)) | |
for iii in range(len(img_array)): | |
out.write(img_array[iii]) | |
out.release() | |
# shutil.rmtree(os.path.join(cfg.experiment_dir, "video-png")) | |
print("The last loss is: {}".format(loss.item())) | |
return img.detach().cpu().numpy(), svg_app_file_name | |
if __name__ == "__main__": | |
############### | |
# make config # | |
############### | |
cfg_arg = parse_args() | |
with open(cfg_arg.config, 'r') as f: | |
cfg = yaml.load(f, Loader=yaml.FullLoader) | |
cfg_default = edict(cfg['default']) | |
cfg = edict(cfg[cfg_arg.experiment]) | |
cfg.update(cfg_default) | |
cfg.update(cfg_arg) | |
cfg.exid = get_experiment_id(cfg.debug) | |
cfg.experiment_dir = \ | |
osp.join(cfg.log_dir, '{}_{}'.format(cfg.exid, '_'.join(cfg.signature))) | |
configfile = osp.join(cfg.experiment_dir, 'config.yaml') | |
check_and_create_dir(configfile) | |
with open(osp.join(configfile), 'w') as f: | |
yaml.dump(edict_2_dict(cfg), f) | |
# Use GPU if available | |
pydiffvg.set_use_gpu(torch.cuda.is_available()) | |
device = pydiffvg.get_device() | |
gt = np.array(PIL.Image.open(cfg.target)) | |
print(f"Input image shape is: {gt.shape}") | |
if len(gt.shape) == 2: | |
print("Converting the gray-scale image to RGB.") | |
gt = gt.unsqueeze(dim=-1).repeat(1,1,3) | |
if gt.shape[2] == 4: | |
print("Input image includes alpha channel, simply dropout alpha channel.") | |
gt = gt[:, :, :3] | |
gt = (gt/255).astype(np.float32) | |
gt = torch.FloatTensor(gt).permute(2, 0, 1)[None].to(device) | |
if cfg.use_ycrcb: | |
gt = ycrcb_conversion(gt) | |
h, w = gt.shape[2:] | |
path_schedule = get_path_schedule(**cfg.path_schedule) | |
if cfg.seed is not None: | |
random.seed(cfg.seed) | |
npr.seed(cfg.seed) | |
torch.manual_seed(cfg.seed) | |
render = pydiffvg.RenderFunction.apply | |
shapes_record, shape_groups_record = [], [] | |
region_loss = None | |
loss_matrix = [] | |
para_point, para_color = {}, {} | |
if cfg.trainable.stroke: | |
para_stroke_width, para_stroke_color = {}, {} | |
pathn_record = [] | |
# Background | |
if cfg.trainable.bg: | |
# meancolor = gt.mean([2, 3])[0] | |
para_bg = torch.tensor([1., 1., 1.], requires_grad=True, device=device) | |
else: | |
if cfg.use_ycrcb: | |
para_bg = torch.tensor([219/255, 0, 0], requires_grad=False, device=device) | |
else: | |
para_bg = torch.tensor([1., 1., 1.], requires_grad=False, device=device) | |
################## | |
# start_training # | |
################## | |
loss_weight = None | |
loss_weight_keep = 0 | |
if cfg.coord_init.type == 'naive': | |
pos_init_method = naive_coord_init( | |
para_bg.view(1, -1, 1, 1).repeat(1, 1, h, w), gt) | |
elif cfg.coord_init.type == 'sparse': | |
pos_init_method = sparse_coord_init( | |
para_bg.view(1, -1, 1, 1).repeat(1, 1, h, w), gt) | |
elif cfg.coord_init.type == 'random': | |
pos_init_method = random_coord_init([h, w]) | |
else: | |
raise ValueError | |
lrlambda_f = linear_decay_lrlambda_f(cfg.num_iter, 0.4) | |
optim_schedular_dict = {} | |
for path_idx, pathn in enumerate(path_schedule): | |
loss_list = [] | |
print("=> Adding [{}] paths, [{}] ...".format(pathn, cfg.seginit.type)) | |
pathn_record.append(pathn) | |
pathn_record_str = '-'.join([str(i) for i in pathn_record]) | |
# initialize new shapes related stuffs. | |
if cfg.trainable.stroke: | |
shapes, shape_groups, point_var, color_var, stroke_width_var, stroke_color_var = init_shapes( | |
pathn, cfg.num_segments, (h, w), | |
cfg.seginit, len(shapes_record), | |
pos_init_method, | |
trainable_stroke=True, | |
gt=gt, ) | |
para_stroke_width[path_idx] = stroke_width_var | |
para_stroke_color[path_idx] = stroke_color_var | |
else: | |
shapes, shape_groups, point_var, color_var = init_shapes( | |
pathn, cfg.num_segments, (h, w), | |
cfg.seginit, len(shapes_record), | |
pos_init_method, | |
trainable_stroke=False, | |
gt=gt, ) | |
shapes_record += shapes | |
shape_groups_record += shape_groups | |
if cfg.save.init: | |
filename = os.path.join( | |
cfg.experiment_dir, "svg-init", | |
"{}-init.svg".format(pathn_record_str)) | |
check_and_create_dir(filename) | |
pydiffvg.save_svg( | |
filename, w, h, | |
shapes_record, shape_groups_record) | |
para = {} | |
if (cfg.trainable.bg) and (path_idx == 0): | |
para['bg'] = [para_bg] | |
para['point'] = point_var | |
para['color'] = color_var | |
if cfg.trainable.stroke: | |
para['stroke_width'] = stroke_width_var | |
para['stroke_color'] = stroke_color_var | |
pg = [{'params' : para[ki], 'lr' : cfg.lr_base[ki]} for ki in sorted(para.keys())] | |
optim = torch.optim.Adam(pg) | |
if cfg.trainable.record: | |
scheduler = LambdaLR( | |
optim, lr_lambda=lrlambda_f, last_epoch=-1) | |
else: | |
scheduler = LambdaLR( | |
optim, lr_lambda=lrlambda_f, last_epoch=cfg.num_iter) | |
optim_schedular_dict[path_idx] = (optim, scheduler) | |
# Inner loop training | |
t_range = tqdm(range(cfg.num_iter)) | |
for t in t_range: | |
for _, (optim, _) in optim_schedular_dict.items(): | |
optim.zero_grad() | |
# Forward pass: render the image. | |
scene_args = pydiffvg.RenderFunction.serialize_scene( | |
w, h, shapes_record, shape_groups_record) | |
img = render(w, h, 2, 2, t, None, *scene_args) | |
# Compose img with white background | |
img = img[:, :, 3:4] * img[:, :, :3] + \ | |
para_bg * (1 - img[:, :, 3:4]) | |
if cfg.save.video: | |
filename = os.path.join( | |
cfg.experiment_dir, "video-png", | |
"{}-iter{}.png".format(pathn_record_str, t)) | |
check_and_create_dir(filename) | |
if cfg.use_ycrcb: | |
imshow = ycrcb_conversion( | |
img, format='[2D x 3]', reverse=True).detach().cpu() | |
else: | |
imshow = img.detach().cpu() | |
pydiffvg.imwrite(imshow, filename, gamma=gamma) | |
x = img.unsqueeze(0).permute(0, 3, 1, 2) # HWC -> NCHW | |
if cfg.use_ycrcb: | |
color_reweight = torch.FloatTensor([255/219, 255/224, 255/255]).to(device) | |
loss = ((x-gt)*(color_reweight.view(1, -1, 1, 1)))**2 | |
else: | |
loss = ((x-gt)**2) | |
if cfg.loss.use_l1_loss: | |
loss = abs(x-gt) | |
if cfg.loss.use_distance_weighted_loss: | |
if cfg.use_ycrcb: | |
raise ValueError | |
shapes_forsdf = copy.deepcopy(shapes) | |
shape_groups_forsdf = copy.deepcopy(shape_groups) | |
for si in shapes_forsdf: | |
si.stroke_width = torch.FloatTensor([0]).to(device) | |
for sg_idx, sgi in enumerate(shape_groups_forsdf): | |
sgi.fill_color = torch.FloatTensor([1, 1, 1, 1]).to(device) | |
sgi.shape_ids = torch.LongTensor([sg_idx]).to(device) | |
sargs_forsdf = pydiffvg.RenderFunction.serialize_scene( | |
w, h, shapes_forsdf, shape_groups_forsdf) | |
with torch.no_grad(): | |
im_forsdf = render(w, h, 2, 2, 0, None, *sargs_forsdf) | |
# use alpha channel is a trick to get 0-1 image | |
im_forsdf = (im_forsdf[:, :, 3]).detach().cpu().numpy() | |
loss_weight = get_sdf(im_forsdf, normalize='to1') | |
loss_weight += loss_weight_keep | |
loss_weight = np.clip(loss_weight, 0, 1) | |
loss_weight = torch.FloatTensor(loss_weight).to(device) | |
if cfg.save.loss: | |
save_loss = loss.squeeze(dim=0).mean(dim=0,keepdim=False).cpu().detach().numpy() | |
save_weight = loss_weight.cpu().detach().numpy() | |
save_weighted_loss = save_loss*save_weight | |
# normalize to [0,1] | |
save_loss = (save_loss - np.min(save_loss))/np.ptp(save_loss) | |
save_weight = (save_weight - np.min(save_weight))/np.ptp(save_weight) | |
save_weighted_loss = (save_weighted_loss - np.min(save_weighted_loss))/np.ptp(save_weighted_loss) | |
# save | |
plt.imshow(save_loss, cmap='Reds') | |
plt.axis('off') | |
# plt.colorbar() | |
filename = os.path.join(cfg.experiment_dir, "loss", "{}-iter{}-mseloss.png".format(pathn_record_str, t)) | |
check_and_create_dir(filename) | |
plt.savefig(filename, dpi=800) | |
plt.close() | |
plt.imshow(save_weight, cmap='Greys') | |
plt.axis('off') | |
# plt.colorbar() | |
filename = os.path.join(cfg.experiment_dir, "loss", "{}-iter{}-sdfweight.png".format(pathn_record_str, t)) | |
plt.savefig(filename, dpi=800) | |
plt.close() | |
plt.imshow(save_weighted_loss, cmap='Reds') | |
plt.axis('off') | |
# plt.colorbar() | |
filename = os.path.join(cfg.experiment_dir, "loss", "{}-iter{}-weightedloss.png".format(pathn_record_str, t)) | |
plt.savefig(filename, dpi=800) | |
plt.close() | |
if loss_weight is None: | |
loss = loss.sum(1).mean() | |
else: | |
loss = (loss.sum(1)*loss_weight).mean() | |
# if (cfg.loss.bis_loss_weight is not None) and (cfg.loss.bis_loss_weight > 0): | |
# loss_bis = bezier_intersection_loss(point_var[0]) * cfg.loss.bis_loss_weight | |
# loss = loss + loss_bis | |
if (cfg.loss.xing_loss_weight is not None) \ | |
and (cfg.loss.xing_loss_weight > 0): | |
loss_xing = xing_loss(point_var) * cfg.loss.xing_loss_weight | |
loss = loss + loss_xing | |
loss_list.append(loss.item()) | |
t_range.set_postfix({'loss': loss.item()}) | |
loss.backward() | |
# step | |
for _, (optim, scheduler) in optim_schedular_dict.items(): | |
optim.step() | |
scheduler.step() | |
for group in shape_groups_record: | |
group.fill_color.data.clamp_(0.0, 1.0) | |
if cfg.loss.use_distance_weighted_loss: | |
loss_weight_keep = loss_weight.detach().cpu().numpy() * 1 | |
if not cfg.trainable.record: | |
for _, pi in pg.items(): | |
for ppi in pi: | |
pi.require_grad = False | |
optim_schedular_dict = {} | |
if cfg.save.image: | |
filename = os.path.join( | |
cfg.experiment_dir, "demo-png", "{}.png".format(pathn_record_str)) | |
check_and_create_dir(filename) | |
if cfg.use_ycrcb: | |
imshow = ycrcb_conversion( | |
img, format='[2D x 3]', reverse=True).detach().cpu() | |
else: | |
imshow = img.detach().cpu() | |
pydiffvg.imwrite(imshow, filename, gamma=gamma) | |
if cfg.save.output: | |
filename = os.path.join( | |
cfg.experiment_dir, "output-svg", "{}.svg".format(pathn_record_str)) | |
check_and_create_dir(filename) | |
pydiffvg.save_svg(filename, w, h, shapes_record, shape_groups_record) | |
loss_matrix.append(loss_list) | |
# calculate the pixel loss | |
# pixel_loss = ((x-gt)**2).sum(dim=1, keepdim=True).sqrt_() # [N,1,H, W] | |
# region_loss = adaptive_avg_pool2d(pixel_loss, cfg.region_loss_pool_size) | |
# loss_weight = torch.softmax(region_loss.reshape(1, 1, -1), dim=-1)\ | |
# .reshape_as(region_loss) | |
pos_init_method = naive_coord_init(x, gt) | |
if cfg.coord_init.type == 'naive': | |
pos_init_method = naive_coord_init(x, gt) | |
elif cfg.coord_init.type == 'sparse': | |
pos_init_method = sparse_coord_init(x, gt) | |
elif cfg.coord_init.type == 'random': | |
pos_init_method = random_coord_init([h, w]) | |
else: | |
raise ValueError | |
if cfg.save.video: | |
print("saving iteration video...") | |
img_array = [] | |
for ii in range(0, cfg.num_iter): | |
filename = os.path.join( | |
cfg.experiment_dir, "video-png", | |
"{}-iter{}.png".format(pathn_record_str, ii)) | |
img = cv2.imread(filename) | |
# cv2.putText( | |
# img, "Path:{} \nIteration:{}".format(pathn_record_str, ii), | |
# (10, 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1) | |
img_array.append(img) | |
videoname = os.path.join( | |
cfg.experiment_dir, "video-avi", | |
"{}.avi".format(pathn_record_str)) | |
check_and_create_dir(videoname) | |
out = cv2.VideoWriter( | |
videoname, | |
# cv2.VideoWriter_fourcc(*'mp4v'), | |
cv2.VideoWriter_fourcc(*'FFV1'), | |
20.0, (w, h)) | |
for iii in range(len(img_array)): | |
out.write(img_array[iii]) | |
out.release() | |
# shutil.rmtree(os.path.join(cfg.experiment_dir, "video-png")) | |
print("The last loss is: {}".format(loss.item())) | |