sketch2pose / src /spin /utils.py
kbrodt's picture
Upload utils.py
df84fee
raw
history blame
5.08 kB
import json
import cv2
import numpy as np
import torch
from skimage.transform import resize, rotate
from torchvision.transforms import Normalize
from .constants import IMG_NORM_MEAN, IMG_NORM_STD, IMG_RES
def get_transform(center, scale, res, rot=0):
    """Build the 3x3 homogeneous matrix mapping image pixels into the crop frame.

    The bounding box is a square of side ``200 * scale`` centred on
    ``center``; it is scaled and shifted so that it fills an output of
    size ``res``.

    Args:
        center: Box centre; ``center[0]`` is paired with ``res[1]`` and
            ``center[1]`` with ``res[0]``.
        scale: Box side length expressed in units of 200 pixels.
        res: Output resolution; ``res[0]`` scales the second coordinate and
            ``res[1]`` the first.
        rot: Rotation in degrees about the centre of the output frame.
            ``0`` skips the rotation entirely.

    Returns:
        A ``(3, 3)`` NumPy array.
    """
    box_side = 200 * scale
    out_h, out_w = res[0], res[1]

    # Scale + translate: box -> output frame.
    affine = np.zeros((3, 3))
    affine[0, 0] = float(out_w) / box_side
    affine[1, 1] = float(out_h) / box_side
    affine[0, 2] = out_w * (-float(center[0]) / box_side + 0.5)
    affine[1, 2] = out_h * (-float(center[1]) / box_side + 0.5)
    affine[2, 2] = 1
    if rot == 0:
        return affine

    # The angle is negated to match the direction of rotation from cropping.
    theta = -rot * np.pi / 180
    sin_t, cos_t = np.sin(theta), np.cos(theta)
    rotation = np.array(
        [
            [cos_t, -sin_t, 0.0],
            [sin_t, cos_t, 0.0],
            [0.0, 0.0, 1.0],
        ]
    )

    # Rotate around the centre of the output: shift the centre to the
    # origin, rotate, then shift back.
    to_origin = np.eye(3)
    to_origin[0, 2] = -out_w / 2
    to_origin[1, 2] = -out_h / 2
    from_origin = to_origin.copy()
    from_origin[:2, 2] *= -1

    return np.dot(from_origin, np.dot(rotation, np.dot(to_origin, affine)))
def transform(pt, center, scale, res, invert=0, rot=0):
    """Map a 1-based pixel location between the image and the crop frame.

    Args:
        pt: Point whose first two entries are the 1-based coordinates.
        center, scale, res, rot: Forwarded to ``get_transform``.
        invert: If truthy, apply the inverse of the matrix from
            ``get_transform`` instead of the matrix itself.

    Returns:
        Length-2 integer NumPy array with the mapped 1-based coordinates.
        The fractional part is truncated by ``astype(int)``, not rounded.
    """
    forward = get_transform(center, scale, res, rot=rot)
    matrix = np.linalg.inv(forward) if invert else forward

    # Shift to 0-based, apply in homogeneous coordinates, shift back.
    homogeneous = np.array([pt[0] - 1, pt[1] - 1, 1.0])
    mapped = np.dot(matrix, homogeneous)
    return mapped[:2].astype(int) + 1
def crop(img, center, scale, res, rot=0):
    """Crop image according to the supplied bounding box.

    The box described by ``center`` and ``scale`` is cut out of ``img``;
    parts of the box falling outside the image are left zero-filled. The
    cut-out is optionally rotated and finally resized to ``res``.

    Args:
        img: Image array; a third (channel) axis is carried over if present.
        center: Bounding-box centre, forwarded to ``transform``.
        scale: Bounding-box scale, forwarded to ``transform``.
        res: Output resolution passed to ``resize``.
        rot: Rotation in degrees passed to ``rotate``; ``0`` disables it.

    Returns:
        The cropped (and possibly rotated) image resized to ``res``.
    """
    # Upper-left and bottom-right corners of the box in source pixels.
    # NOTE(review): the bottom-right point is built as [res[0] + 1, res[1] + 1];
    # this is only unambiguous for square ``res`` -- confirm before using a
    # non-square resolution.
    ul = np.array(transform([1, 1], center, scale, res, invert=1)) - 1
    br = np.array(transform([res[0] + 1, res[1] + 1], center, scale, res, invert=1)) - 1

    # Padding so that when rotated proper amount of context is included.
    pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
    if rot != 0:
        ul -= pad
        br += pad

    new_shape = [br[1] - ul[1], br[0] - ul[0]]
    if img.ndim > 2:
        new_shape += [img.shape[2]]
    new_img = np.zeros(new_shape)

    # Range to fill in the new array.
    new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
    new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
    # Range to sample from the original image.
    old_x = max(0, ul[0]), min(len(img[0]), br[0])
    old_y = max(0, ul[1]), min(len(img), br[1])
    new_img[new_y[0] : new_y[1], new_x[0] : new_x[1]] = img[
        old_y[0] : old_y[1], old_x[0] : old_x[1]
    ]

    if rot != 0:
        new_img = rotate(new_img, rot)
        # Remove the padding. ``pad`` is never negative but can be 0 for very
        # small boxes; ``[0:-0]`` would select nothing and leave an empty
        # image, so the slice is skipped in that case.
        if pad > 0:
            new_img = new_img[pad:-pad, pad:-pad]

    new_img = resize(new_img, res)
    return new_img
def bbox_from_openpose(openpose_file, rescale=1.2, detection_thresh=0.2):
    """Get center and scale for bounding box from openpose detections.

    Only the first entry of the ``"people"`` list is used. Its
    ``"pose_keypoints_2d"`` values are read as ``(x, y, confidence)``
    triplets, and keypoints whose confidence is not above
    ``detection_thresh`` are ignored.

    Args:
        openpose_file: Path to the OpenPose JSON file.
        rescale: Factor applied to the scale to loosen the tight box.
        detection_thresh: Minimum (exclusive) keypoint confidence.

    Returns:
        ``(center, scale)``: ``center`` is the mean of the valid keypoints
        and ``scale`` is the larger side of their extent divided by 200 and
        multiplied by ``rescale``.

    Raises:
        IndexError: If the file contains no people.
        ValueError: If no keypoint passes ``detection_thresh``.
    """
    with open(openpose_file, "r") as f:
        people = json.load(f)["people"]
    if not people:
        raise IndexError(
            "no people detected in openpose file: {}".format(openpose_file)
        )

    keypoints = np.reshape(np.array(people[0]["pose_keypoints_2d"]), (-1, 3))
    valid = keypoints[:, -1] > detection_thresh
    valid_keypoints = keypoints[valid][:, :-1]
    if valid_keypoints.shape[0] == 0:
        # Fail early: reducing an empty array would otherwise emit a NumPy
        # warning and then raise an uninformative ValueError.
        raise ValueError(
            "no keypoints with confidence above detection_thresh={} in {}".format(
                detection_thresh, openpose_file
            )
        )

    center = valid_keypoints.mean(axis=0)
    bbox_size = (valid_keypoints.max(axis=0) - valid_keypoints.min(axis=0)).max()
    # Adjust bounding box tightness.
    scale = bbox_size / 200.0
    scale *= rescale
    return center, scale
def bbox_from_json(bbox_file):
    """Read a bounding-box annotation and return its center and scale.

    The JSON file must hold a ``"bbox"`` entry in the format
    [top_left(x), top_left(y), width, height].

    Returns:
        ``(center, scale)``: ``center`` is the float32 midpoint of the box
        and ``scale`` is its longer side divided by 200.
    """
    with open(bbox_file, "r") as f:
        raw_bbox = json.load(f)["bbox"]
    bbox = np.array(raw_bbox).astype(np.float32)

    top_left, extent = bbox[:2], bbox[2:]
    center = top_left + 0.5 * extent

    # Use the longer side for both dimensions so the box becomes square.
    longer_side = max(bbox[2], bbox[3])
    return center, longer_side / 200.0
def process_image(img_file, bbox_file=None, openpose_file=None, input_res=IMG_RES):
    """Read image, do preprocessing and possibly crop it according to the bounding box.

    If there are bounding box annotations, use them to crop the image.
    If no bounding box is specified but openpose detections are available,
    use them to get the bounding box. With neither, the box is centered on
    the image and sized by its longer side.

    Args:
        img_file: Path to the image (converted with ``str``).
        bbox_file: Optional path for ``bbox_from_json``.
        openpose_file: Optional path for ``bbox_from_openpose``; only used
            when ``bbox_file`` is ``None``.
        input_res: Side length of the square crop.

    Returns:
        ``(img, norm_img)``: ``img`` is the channel-first float32 tensor of
        the crop divided by 255; ``norm_img`` is a copy of it normalized
        with ``IMG_NORM_MEAN`` / ``IMG_NORM_STD``.

    Raises:
        OSError: If the image cannot be read.
    """
    img_file = str(img_file)
    normalize_img = Normalize(mean=IMG_NORM_MEAN, std=IMG_NORM_STD)

    bgr = cv2.imread(img_file)
    if bgr is None:
        # cv2.imread signals a missing or undecodable file by returning None
        # instead of raising; fail here with a clear message.
        raise OSError("could not read image file: {}".format(img_file))
    # Reverse the channel axis; the copy is needed because PyTorch does not
    # support negative strides.
    img = bgr[:, :, ::-1].copy()

    if bbox_file is not None:
        center, scale = bbox_from_json(bbox_file)
    elif openpose_file is not None:
        center, scale = bbox_from_openpose(openpose_file)
    else:
        # Assume that the person is centered in the image.
        height, width = img.shape[0], img.shape[1]
        center = np.array([width // 2, height // 2])
        scale = max(height, width) / 200

    img = crop(img, center, scale, (input_res, input_res))
    img = img.astype(np.float32) / 255.0
    img = torch.from_numpy(img).permute(2, 0, 1)
    norm_img = normalize_img(img.clone())
    return img, norm_img