# ID-Pose / src / pose_estimation.py
# tokenid
# fix plot stuck
# 146da98
# raw
# history blame
# 8.95 kB
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from datetime import datetime
from .ldm.util import load_and_preprocess, instantiate_from_config
from .pose_funcs import probe_pose, find_optimal_poses, get_inv_pose, add_pose, pairwise_loss
from .oee.utils.elev_est_api import elev_est_api, ElevEstHelper
from .sampling import sample_images
def load_image(img_path, mask_path=None, preprocessor=None, threshold=0.9):
    """Load an image file and place its content on a white background.

    Args:
        img_path: Path of the image file, opened with PIL.
        mask_path: Optional path of a mask image. Only consulted for RGB
            inputs, where it selects the pixels kept from the image; the
            remaining pixels are taken from a white canvas.
        preprocessor: Optional object forwarded to ``load_and_preprocess``.
            When given, that function's result is returned directly and
            ``mask_path`` / ``threshold`` are not used.
        threshold: Alpha cut-off (on the 0..1 scale) for RGBA inputs. Pixels
            whose alpha is at or below this value are painted white.

    Returns:
        The result of ``load_and_preprocess`` when a preprocessor is given.
        Otherwise an (H, W, 3) float32 array scaled to [0, 1] for RGBA and
        RGB inputs. For any other image mode a message is printed and the
        PIL image is returned unconverted.
    """
    picture = Image.open(img_path)

    if preprocessor is not None:
        return load_and_preprocess(preprocessor, picture)

    mode = picture.mode

    if mode == 'RGBA':
        rgba = np.asarray(picture, dtype=np.float32) / 255.
        # Treat (nearly) transparent pixels as background and whiten them.
        is_background = rgba[:, :, -1] <= threshold
        rgba[is_background] = [1., 1., 1., 1.]
        # Drop the alpha channel now that the background is baked in.
        return rgba[:, :, :3]

    if mode == 'RGB':
        if mask_path is not None:
            mask = Image.open(mask_path)
            white_canvas = Image.new(
                'RGB', (picture.width, picture.height), color=(255, 255, 255)
            )
            picture = Image.composite(picture, white_canvas, mask)
        return np.asarray(picture, dtype=np.float32) / 255.

    # Unsupported mode: report it and hand back the PIL image as-is.
    print('Wrong format:', img_path)
    return picture
def load_model_from_config(config, ckpt, device, verbose=False):
    """Instantiate the model described by ``config`` and load checkpoint weights.

    Args:
        config: Configuration object whose ``model`` attribute is passed to
            ``instantiate_from_config``.
        ckpt: Checkpoint path handed to ``torch.load``. The loaded object
            must contain a ``'state_dict'`` entry; a ``'global_step'`` entry
            is printed when present.
        device: Device used both as ``map_location`` for loading and as the
            target of ``model.to``.
        verbose: When true, print the missing / unexpected state-dict keys
            reported by the non-strict load (only if there are any).

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    print(f'Loading model from {ckpt}')
    checkpoint = torch.load(ckpt, map_location=device)

    if 'global_step' in checkpoint:
        global_step = checkpoint['global_step']
        print(f'Global Step: {global_step}')

    # Look up the weights before building the model so that a checkpoint
    # without a state dict fails before any model is constructed.
    weights = checkpoint['state_dict']
    model = instantiate_from_config(config.model)

    # strict=False: mismatching keys are reported rather than raising.
    missing, unexpected = model.load_state_dict(weights, strict=False)
    if verbose:
        reports = (('missing keys:', missing), ('unexpected keys:', unexpected))
        for heading, keys in reports:
            if len(keys) > 0:
                print(heading)
                print(keys)

    model.to(device)
    model.eval()
    return model
def estimate_elevs(model, images, est_type=None, matcher_ckpt_path=None):
    """Estimate per-view elevations and the elevation offsets to probe.

    Args:
        model: Model used to sample surrounding views; its ``device``
            attribute is passed to the feature-matcher factory.
        images: Sequence of input views; view 0 is the reference view.
        est_type: ``'all'`` estimates an elevation for every view,
            ``'simple'`` estimates only view 0. Any other value (including
            the default ``None``) performs no estimation.
        matcher_ckpt_path: Checkpoint path forwarded to
            ``ElevEstHelper.get_feature_matcher``.

    Returns:
        A pair ``(elevs, elev_ranges)`` of dicts keyed by view index.
        ``elevs[i]`` is the estimated elevation in radians (the estimator
        output is converted with ``np.deg2rad``) or ``None``.
        ``elev_ranges[i]`` is an array of elevation offsets of view ``i``
        relative to view 0, or ``None`` when nothing is known; index 0 is
        always ``None``.
    """
    count = len(images)
    elevs = dict.fromkeys(range(count))
    elev_ranges = dict.fromkeys(range(count))

    if est_type == 'all':
        matcher = ElevEstHelper.get_feature_matcher(matcher_ckpt_path, model.device)
        for idx in range(count):
            nearby_views = sample_surrounding_images(model, images[idx])
            degrees = elev_est_api(matcher, nearby_views, min_elev=20, max_elev=160)
            # The estimator may fail and return None; keep that as None.
            elevs[idx] = None if degrees is None else np.deg2rad(degrees)

        reference = elevs.get(0)
        for idx in range(1, count):
            current = elevs[idx]
            if current is None and reference is None:
                # Neither elevation is known: leave the range unset.
                continue
            if current is None:
                # Only the reference is known: probe offsets around it.
                elev_ranges[idx] = make_elev_probe_range(reference)
            elif reference is None:
                # Only this view is known: same offsets, opposite direction.
                elev_ranges[idx] = -make_elev_probe_range(current)
            else:
                # Both known: a single exact relative elevation.
                elev_ranges[idx] = np.array([current - reference])

    elif est_type == 'simple':
        matcher = ElevEstHelper.get_feature_matcher(matcher_ckpt_path, model.device)
        nearby_views = sample_surrounding_images(model, images[0])
        degrees = elev_est_api(matcher, nearby_views, min_elev=20, max_elev=160)
        elevs[0] = None if degrees is None else np.deg2rad(degrees)

        # Fall back to pi/2 for the reference when estimation failed; every
        # other view is assigned the offset from the reference to pi/2.
        anchor = np.pi/2 if elevs[0] is None else elevs[0]
        for idx in range(1, count):
            elev_ranges[idx] = np.array([np.pi/2 - anchor])

    return elevs, elev_ranges
def _cand_loss(cand):
    """Return the loss stored in the first slot of a pose candidate tuple."""
    return cand[0]


def _refine_combinations(num, refine_type):
    """Build the ordered view-index pairs optimised during refinement.

    Raises:
        ValueError: If ``refine_type`` is neither ``'pairwise'`` nor
            ``'triangular'``.
    """
    if refine_type == 'pairwise':
        # Every view paired with the reference view 0, in both directions.
        return [(0, i) for i in range(1, num)] + [(i, 0) for i in range(1, num)]
    if refine_type == 'triangular':
        # Every unordered pair of views, in both directions.
        pairs = []
        for i in range(0, num):
            for j in range(i + 1, num):
                pairs.append((i, j))
                pairs.append((j, i))
        return pairs
    raise ValueError(
        f"Unsupported refine_type {refine_type!r}; "
        "expected 'pairwise' or 'triangular'"
    )


def estimate_poses(
    model, images,
    seed_cand_num=8,
    explore_type='pairwise',
    refine_type='pairwise',
    probe_ts_range=[0.02, 0.98], ts_range=[0.02, 0.98],
    probe_bsz=16,
    adjust_factor=10.,
    adjust_iters=10,
    adjust_bsz=1,
    refine_factor=1.,
    refine_iters=600,
    refine_bsz=1,
    noise=None,
    elevs=None,
    elev_ranges=None
):
    """Estimate the pose of every view relative to view 0.

    The procedure has three stages: (1) exploration, which probes candidate
    poses for each pair (0, i) and optionally adjusts the best ones;
    (2) selection, which picks one candidate per view, optionally scoring
    candidates against each other across view pairs; (3) refinement, which
    jointly optimises all selected poses.

    Args:
        model: Model forwarded to ``probe_pose``, ``find_optimal_poses`` and
            ``pairwise_loss``.
        images: Sequence of 4-D tensors; each is permuted with
            ``permute(0, 2, 3, 1)`` before use (presumably
            (B, C, H, W) -> (B, H, W, C); TODO confirm). View 0 is the
            reference.
        seed_cand_num: Number of evenly spaced azimuth seeds in [0, 2*pi).
        explore_type: ``'pairwise'`` scores candidates by their own loss;
            ``'triangular'`` scores them through composed poses between
            non-reference views. Forced to ``'pairwise'`` when there are at
            most two views. Any other value leaves all scores at 0.0, so the
            first (lowest pairwise loss) candidate is selected.
        refine_type: ``'pairwise'`` or ``'triangular'``; selects the view
            pairs used during refinement.
        probe_ts_range: Forwarded to ``probe_pose`` / ``pairwise_loss``.
        ts_range: Forwarded to ``find_optimal_poses``.
        probe_bsz: Batch size forwarded to ``probe_pose`` / ``pairwise_loss``.
        adjust_factor: Second positional argument of ``find_optimal_poses``
            during candidate adjustment.
        adjust_iters: Iterations of candidate adjustment; ``0`` or less
            skips adjustment and keeps the five best probed candidates.
        adjust_bsz: Batch size of candidate adjustment.
        refine_factor: Second positional argument of ``find_optimal_poses``
            during refinement.
        refine_iters: Refinement iterations per non-reference view; the
            total is ``(num - 1) * refine_iters``.
        refine_bsz: Batch size of refinement.
        noise: Forwarded to ``probe_pose`` / ``pairwise_loss``.
        elevs: Optional dict of per-view elevations; only echoed back in the
            auxiliary output. Defaults to all ``None``.
        elev_ranges: Optional dict of per-view elevation offsets, passed as
            ``theta_range`` to ``probe_pose``. Defaults to all ``None``.

    Returns:
        ``(out_poses, aux_data)`` where ``out_poses`` is the first result of
        the final ``find_optimal_poses`` call and ``aux_data`` has the keys
        ``'tri_ep_sph'`` (selected poses), ``'pw_ep_sph'`` (best pairwise
        poses) and ``'elev'`` (the ``elevs`` dict).

    Raises:
        ValueError: If ``refine_type`` is unsupported, or if exploration
            yields no pose candidate for some view pair.

    Note:
        The list defaults of ``probe_ts_range`` / ``ts_range`` are shared
        between calls. This function only forwards them; callees are assumed
        not to mutate them (TODO confirm).
    """
    num = len(images)

    # Validate refine_type before the expensive exploration stage instead of
    # failing on it only once all of that work has been done.
    combinations = _refine_combinations(num, refine_type)

    if elevs is None:
        elevs = {i: None for i in range(num)}
    if elev_ranges is None:
        elev_ranges = {i: None for i in range(num)}

    # Cross-pair scoring needs at least two non-reference views.
    if num <= 2:
        explore_type = 'pairwise'

    cands = {}
    losses = {}
    ep_poses = {i: None for i in range(num)}
    pairwise_ep_poses = {i: None for i in range(num)}

    print('Start', datetime.now())

    images = [img.permute(0, 2, 3, 1) for img in images]

    for i in range(1, num):
        print('PAIR', 0, i, datetime.now())

        azimuth_range = np.arange(start=0.0, stop=np.pi*2, step=np.pi*2 / seed_cand_num)
        all_cands = probe_pose(model, images[0], images[i], probe_ts_range, probe_bsz, theta_range=elev_ranges[i], azimuth_range=azimuth_range, noise=noise)
        # Order by loss only, so that equal losses never fall through to
        # comparing the pose objects themselves.
        all_cands = sorted(all_cands, key=_cand_loss)

        print('Exploration', len(all_cands), datetime.now())

        adjusted_cands = all_cands[:5]

        if adjust_iters > 0:
            adjusted_cands = []
            # Only adjust the better half, but always at least one candidate
            # so that a single probed candidate is not discarded.
            n_adjust = max(1, len(all_cands) // 2)
            for cand in all_cands[:n_adjust]:
                out_poses, _, _ = find_optimal_poses(
                    model, [images[0], images[i]],
                    adjust_factor,
                    bsz=adjust_bsz,
                    n_iter=adjust_iters,
                    init_poses={1: cand[1]},
                    ts_range=ts_range,
                    print_n=100,
                    avg_last_n=1
                )
                loss = pairwise_loss(out_poses[0], model, images[0], images[i], probe_ts_range, probe_bsz, noise=noise)
                # Keep the pre-adjustment loss and pose alongside for logging.
                adjusted_cands.append((loss, out_poses[0], cand[0], cand[1]))
            adjusted_cands = sorted(adjusted_cands, key=_cand_loss)[:5]

        for cand in adjusted_cands:
            print(cand)

        if not adjusted_cands:
            raise ValueError(f'No pose candidates found for view pair (0, {i})')

        # Only (loss, pose) is needed from here on.
        cands[i] = [cand[:2] for cand in adjusted_cands]
        losses[i] = [loss if (explore_type == 'pairwise') else 0.0 for loss, _ in cands[i]]
        pairwise_ep_poses[i] = min(cands[i], key=_cand_loss)[1]

    print('Selection', datetime.now())

    if explore_type == 'triangular':
        for i in range(1, num):
            for j in range(i + 1, num):
                # iloss[u][v] == jloss[v][u]: combined loss of candidate u of
                # view i together with candidate v of view j.
                iloss = [[None for _ in cands[j]] for _ in cands[i]]
                jloss = [[None for _ in cands[i]] for _ in cands[j]]

                for u in range(0, len(cands[i])):
                    la, pa = cands[i][u]
                    # pose i -> 0
                    pa = get_inv_pose(pa)
                    for v in range(0, len(cands[j])):
                        # pose 0 -> j
                        lb, pb = cands[j][v]
                        theta, azimuth, radius = add_pose(pa, pb)
                        lp = pairwise_loss([theta, azimuth, radius], model, images[i], images[j], probe_ts_range, probe_bsz, noise=noise)
                        iloss[u][v] = la + lb + lp
                        jloss[v][u] = la + lb + lp

                # Each candidate takes its best partner, capped at three
                # times its own pairwise loss.
                for u in range(0, len(cands[i])):
                    losses[i][u] += min(min(iloss[u]), cands[i][u][0]*3)
                for v in range(0, len(cands[j])):
                    losses[j][v] += min(min(jloss[v]), cands[j][v][0]*3)

    for i in range(1, num):
        ranks = sorted(range(0, len(losses[i])), key=lambda x: losses[i][x])
        min_rank = ranks[0]
        for u in range(0, len(cands[i])):
            print(cands[i][u], losses[i][u])
        print(i, 'SELECT', min_rank, losses[i][min_rank])
        ep_poses[i] = cands[i][min_rank][1]

    print('Refinement', datetime.now())
    print('Combinations', len(combinations), combinations)

    # Joint refinement of all selected poses over the chosen view pairs.
    out_poses, _, _ = find_optimal_poses(
        model, images,
        refine_factor,
        bsz=refine_bsz,
        n_iter=(num-1)*refine_iters,
        init_poses=ep_poses,
        ts_range=ts_range,
        combinations=combinations,
        avg_last_n=20,
        print_n=100
    )

    print('Done', datetime.now())

    aux_data = {
        'tri_ep_sph': ep_poses,
        'pw_ep_sph': pairwise_ep_poses,
        'elev': elevs
    }

    return out_poses, aux_data
def make_elev_probe_range(elev, interval=np.pi/4):
    """Return elevation offsets, relative to ``elev``, to probe.

    Angles are generated on a grid anchored at ``elev``: first ``elev``
    itself and steps of ``interval`` decreasing toward 0 (exclusive), then
    steps increasing from ``elev + interval`` toward pi (exclusive). The
    anchor is then subtracted, so the result holds signed offsets with the
    zero offset first whenever ``elev`` is positive.

    Args:
        elev: Anchor elevation in radians.
        interval: Spacing between probed angles in radians.

    Returns:
        1-D array of offsets: non-positive ones first, then positive ones.
    """
    toward_zero = np.arange(elev, 0, -interval)
    toward_pi = np.arange(elev + interval, np.pi, interval)
    return np.concatenate((toward_zero, toward_pi)) - elev
def sample_surrounding_images(model, image):
    """Sample four nearby views of ``image``, offset by +/-10 degrees.

    ``sample_images`` is called once per offset with ``n_samples=1``. The
    first two calls offset its third positional argument by -10 and +10
    degrees (in radians), the last two offset its fourth positional
    argument the same way. These are presumably the elevation and azimuth
    angles; confirm against ``sample_images``.

    Args:
        model: Model forwarded to ``sample_images``.
        image: Source view forwarded to ``sample_images``.

    Returns:
        The four sampling results combined with ``+`` in call order.
    """
    step_deg = 10
    angle_offsets = (
        (float(np.deg2rad(-step_deg)), 0),
        (float(np.deg2rad(+step_deg)), 0),
        (0, float(np.deg2rad(-step_deg))),
        (0, float(np.deg2rad(+step_deg))),
    )

    combined = None
    for first_angle, second_angle in angle_offsets:
        batch = sample_images(model, image, first_angle, second_angle, 0, n_samples=1)
        # Fold left to right with `+`, exactly as a chained sum would.
        combined = batch if combined is None else combined + batch
    return combined