wanghaofan's picture
Upload 10 files
35ed688 verified
import cv2
import numpy as np
import torch
from scipy.ndimage.filters import gaussian_filter
from skimage.measure import label
from . import util
from .model import handpose_model
class Hand(object):
    """Multi-scale hand keypoint estimator wrapping ``handpose_model``.

    Calling an instance on an image returns one ``[x, y]`` location per
    keypoint (21 rows) in the pixel coordinates of the input image.

    The inference settings below are class attributes so they can be tuned
    per instance or subclass without changing the call signature.
    """

    # Multiples of ``boxsize`` at which the image is evaluated.
    scale_search = (0.5, 1.0, 1.5, 2.0)
    # Base size handed to util.smart_resize at scale 1.0 (presumably the
    # target side length in pixels -- confirm against util.smart_resize).
    boxsize = 368
    # Network output stride; heatmaps are upsampled by this factor and the
    # padded input is padded to a multiple of it.
    stride = 8
    # Fill value passed to util.padRightDownCorner for the padded border.
    padValue = 128
    # Threshold on the Gaussian-smoothed heatmap for a keypoint to be kept.
    thre = 0.05
    # Side of the square grid on which per-scale heatmaps are averaged.
    wsize = 128

    def __init__(self, model_path):
        """Build ``handpose_model`` and load its weights from ``model_path``.

        The checkpoint is loaded onto the CPU (``map_location="cpu"``) so a
        file saved from a GPU process can still be opened on a CPU-only
        machine; call :meth:`to` afterwards to move the model to a device.
        """
        self.model = handpose_model()
        model_dict = util.transfer(self.model, torch.load(model_path, map_location="cpu"))
        self.model.load_state_dict(model_dict)
        self.model.eval()

    def to(self, device):
        """Move the wrapped model to ``device`` and return ``self`` for chaining."""
        self.model.to(device)
        return self

    def __call__(self, oriImgRaw):
        """Estimate hand keypoint locations in one image.

        Args:
            oriImgRaw: image array indexed as (height, width, channels); the
                ``__main__`` demo passes a BGR image from ``cv2.imread``.

        Returns:
            ``np.ndarray`` of shape (21, 2) holding ``[x, y]`` per keypoint in
            ``oriImgRaw`` pixel coordinates. ``[0, 0]`` marks a keypoint whose
            smoothed heatmap never exceeds ``thre``.
        """
        Hr, Wr, _ = oriImgRaw.shape
        # Light blur before the resizes below, presumably to limit aliasing.
        oriImg = cv2.GaussianBlur(oriImgRaw, (0, 0), 0.8)
        heatmap_avg = self._average_heatmaps(oriImg)
        return self._locate_peaks(heatmap_avg, Hr, Wr)

    def _average_heatmaps(self, oriImg):
        """Run the network at every scale and average the heatmaps on a wsize x wsize grid."""
        device = next(iter(self.model.parameters())).device
        multiplier = [x * self.boxsize for x in self.scale_search]
        # 22 output channels are expected from the model (21 keypoints plus,
        # presumably, a background channel -- TODO confirm); only the first
        # 21 are read in _locate_peaks.
        heatmap_avg = np.zeros((self.wsize, self.wsize, 22))
        for scale in multiplier:
            imageToTest = util.smart_resize(oriImg, (scale, scale))
            imageToTest_padded, pad = util.padRightDownCorner(imageToTest, self.stride, self.padValue)
            # (H, W, C) -> (1, C, H, W) float; with 8-bit input this lands in
            # roughly [-0.5, 0.5).
            im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
            im = np.ascontiguousarray(im)
            data = torch.from_numpy(im).float().to(device)
            with torch.no_grad():
                output = self.model(data).cpu().numpy()
            # assumes output is (1, C, h, w); squeeze + transpose gives (h, w, C)
            heatmap = np.transpose(np.squeeze(output), (1, 2, 0))
            heatmap = util.smart_resize_k(heatmap, fx=self.stride, fy=self.stride)
            # Crop away the bottom/right padding (presumably pad[2] is the
            # bottom pad and pad[3] the right pad from padRightDownCorner).
            heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
            heatmap = util.smart_resize(heatmap, (self.wsize, self.wsize))
            heatmap_avg += heatmap / len(multiplier)
        return heatmap_avg

    def _locate_peaks(self, heatmap_avg, Hr, Wr):
        """Pick one peak per keypoint channel and map it to Hr x Wr pixel coordinates."""
        all_peaks = []
        for part in range(21):
            # Copy so the zeroing below does not modify heatmap_avg.
            map_ori = heatmap_avg[:, :, part].copy()
            one_heatmap = gaussian_filter(map_ori, sigma=3)
            binary = np.ascontiguousarray(one_heatmap > self.thre, dtype=np.uint8)
            if np.sum(binary) == 0:
                all_peaks.append([0, 0])
                continue
            # Keep only the connected region carrying the most raw heatmap
            # mass, so a weaker secondary blob cannot supply the peak.
            label_img, label_numbers = label(binary, return_num=True, connectivity=binary.ndim)
            max_index = np.argmax([np.sum(map_ori[label_img == i]) for i in range(1, label_numbers + 1)]) + 1
            label_img[label_img != max_index] = 0
            map_ori[label_img == 0] = 0
            # util.npmax presumably returns the (row, col) of the maximum.
            y, x = util.npmax(map_ori)
            # Rescale from the wsize grid back to the original image size.
            y = int(float(y) * float(Hr) / float(self.wsize))
            x = int(float(x) * float(Wr) / float(self.wsize))
            all_peaks.append([x, y])
        return np.array(all_peaks)
if __name__ == "__main__":
    # Demo: run the estimator on a sample image and display the detected
    # keypoints. Paths are resolved relative to the current working directory.
    hand_estimation = Hand('../model/hand_pose_model.pth')
    test_image = '../images/hand.jpg'
    oriImg = cv2.imread(test_image)  # B,G,R order
    if oriImg is None:
        # cv2.imread returns None for a missing/unreadable file instead of
        # raising; fail here with a clear message rather than with an
        # AttributeError deep inside Hand.__call__.
        raise FileNotFoundError("could not read image: {}".format(test_image))
    peaks = hand_estimation(oriImg)
    canvas = util.draw_handpose(oriImg, peaks, True)
    cv2.imshow('', canvas)
    cv2.waitKey(0)
    # Release the HighGUI window once a key has been pressed.
    cv2.destroyAllWindows()