# demo-painttransformer / inference.py
# Last commit: jaekookang, "first upload" (9ff1108); file size 2.78 kB.
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import network
import os
import math
import render_utils
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import cv2
import render_parallel
import render_serial
def main(input_path, model_path, output_dir, need_animation=False, resize_h=None, resize_w=None, serial=False):
    """Paint ``input_path`` with the Painter network and write the result under ``output_dir``.

    Args:
        input_path: Path of the image to paint.
        model_path: Path of the Painter weights, loaded with ``paddle.load``.
        output_dir: Directory that receives the output. It is created
            (including missing parents) when it does not exist.
        need_animation: When true, every frame returned by the serial renderer
            is written as ``NNN.png`` into a sub-directory of ``output_dir``
            named after the input file (its base name up to the first ``.``);
            no single final image is written in that case.
        resize_h: Forwarded to ``render_utils.read_img`` for the input image;
            ``None`` means do not resize.
        resize_w: Forwarded to ``render_utils.read_img`` for the input image;
            ``None`` means do not resize.
        serial: Render with ``render_serial`` instead of ``render_parallel``.
            Forced to ``True`` when ``need_animation`` is set.

    Returns:
        None. Images are written with ``cv2.imwrite`` and the frame count and
        elapsed time are printed.
    """
    import time

    # makedirs(exist_ok=True) creates missing parent directories and tolerates
    # a directory that already exists, which a guarded os.mkdir does not.
    os.makedirs(output_dir, exist_ok=True)
    input_name = os.path.basename(input_path)
    output_path = os.path.join(output_dir, input_name)

    frame_dir = None
    if need_animation:
        if not serial:
            print('It must be under serial mode if animation results are required, so serial flag is set to True!')
            serial = True
        # Text before the first '.', exactly as before for dotted names.
        # Slicing with str.find() cut off the last character when the name had
        # no dot (find() returns -1); fall back to the full name when there is
        # no dot or nothing precedes it.
        frame_stem = input_name.partition('.')[0] or input_name
        frame_dir = os.path.join(output_dir, frame_stem)
        os.makedirs(frame_dir, exist_ok=True)

    stroke_num = 8

    #* ----- load model ----- *#
    # Device is pinned to CPU (changed from 'gpu' on 2021-12-21 by jkang).
    paddle.set_device('cpu')
    # NOTE(review): the meaning of the positional Painter arguments is defined
    # in network.py, which is not visible here.
    net_g = network.Painter(5, stroke_num, 256, 8, 3, 3)
    net_g.set_state_dict(paddle.load(model_path))
    net_g.eval()
    # Inference only: mark every parameter as not requiring gradients.
    for param in net_g.parameters():
        param.stop_gradient = True

    #* ----- load brush ----- *#
    # Brush paths are relative to the current working directory.
    brush_large_vertical = render_utils.read_img('brush/brush_large_vertical.png', 'L')
    brush_large_horizontal = render_utils.read_img('brush/brush_large_horizontal.png', 'L')
    meta_brushes = paddle.concat([brush_large_vertical, brush_large_horizontal], axis=0)

    # perf_counter is monotonic, so the reported duration cannot be skewed by
    # system clock adjustments. Timing covers image loading, rendering and writing.
    t0 = time.perf_counter()
    original_img = render_utils.read_img(input_path, 'RGB', resize_h, resize_w)
    if serial:
        final_result_list = render_serial.render_serial(original_img, net_g, meta_brushes)
        if need_animation:
            print("total frame:", len(final_result_list))
            for idx, frame in enumerate(final_result_list):
                cv2.imwrite(os.path.join(frame_dir, '%03d.png' % idx), frame)
        else:
            # Only the last frame is kept as the final painting.
            cv2.imwrite(output_path, final_result_list[-1])
    else:
        final_result = render_parallel.render_parallel(original_img, net_g, meta_brushes)
        cv2.imwrite(output_path, final_result)

    print("total infer time:", time.perf_counter() - t0)
if __name__ == '__main__':
    # Demo run: paint the bundled sample image and dump every intermediate frame.
    demo_settings = dict(
        input_path='input/chicago.jpg',
        model_path='paint_best.pdparams',
        output_dir='output/',
        # Whether intermediate results are needed for animation.
        need_animation=True,
        # Size the original input is resized to; None means do not resize.
        resize_h=512,
        resize_w=512,
        # Animation requires serial mode, so this must be True here.
        serial=True,
    )
    main(**demo_settings)