Spaces:
Running
on
Zero
Running
on
Zero
import numpy as np | |
import torch | |
class PoseT(torch.nn.Module):
    """Expand channel 1 of a pose tensor into a (sin, cos) pair.

    The last axis of the input is read as ``[c0, c1, c2, ...]`` and is
    returned as ``[c0, sin(c1), cos(c1), c2, ...]``, so that axis grows by
    one element. All leading axes are passed through unchanged.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pose):
        """Return ``pose`` with index 1 of the last axis replaced by sin/cos."""
        head, angle, tail = pose[..., :1], pose[..., 1:2], pose[..., 2:]
        pieces = (head, angle.sin(), angle.cos(), tail)
        return torch.cat(pieces, dim=-1)
def noise_loss(model, cond_image, target_image, pose, ts_range, bsz, noise=None):
    """Evaluate the model's loss for one (cond, target, pose) triple.

    The two images and the encoded pose are each repeated ``bsz`` times and
    passed to ``model.shared_step`` together with ``bsz`` evenly spaced
    timesteps taken from ``ts_range``.

    Args:
        model: Object exposing ``shared_step(batch, ts=..., noise=...)`` that
            returns a ``(loss, extra)`` pair, plus ``dtype`` and ``device``
            attributes (the latter two are only read when ``noise`` is given).
        cond_image: Conditioning image tensor; repeated along its first axis
            (assumed to be 4-D with a leading batch axis of 1 -- TODO confirm).
        target_image: Target image tensor, same assumptions as ``cond_image``.
        pose: Pose tensor; detached before encoding, so no gradient flows
            back into it (assumed shape ``(1, 3)`` -- TODO confirm).
        ts_range: ``(low, high)`` pair; timesteps start at ``low`` and stop
            before ``high``.
        bsz: Number of repetitions / timesteps.
        noise: Optional array-like of fixed noise. When given, it is converted
            to a tensor on the model's dtype/device and its first ``bsz``
            entries are used. When ``None``, ``None`` is forwarded to
            ``shared_step``.

    Returns:
        The loss as a Python float.
    """
    t_low, t_high = ts_range[0], ts_range[1]
    pose_layer = PoseT()

    batch = {
        'image_target': target_image.repeat(bsz, 1, 1, 1),
        'image_cond': cond_image.repeat(bsz, 1, 1, 1),
        'T': pose_layer(pose.detach()).repeat(bsz, 1),
    }

    # Slice only when noise was actually supplied; indexing None would raise.
    if noise is not None:
        noise = torch.tensor(noise, dtype=model.dtype, device=model.device)[:bsz]

    # linspace guarantees exactly `bsz` timesteps; arange with a float step
    # can return one extra element because of rounding.
    ts = np.linspace(t_low, t_high, num=bsz, endpoint=False)

    loss, _ = model.shared_step(batch, ts=ts, noise=noise)
    return loss.item()
def pairwise_loss(pose, model, cond_image, target_image, ts_range, probe_bsz, noise=None):
    """Sum the noise loss of a pose and of its inverse with the images swapped.

    ``pose`` is unpacked as ``(theta, azimuth, radius)``. The forward term
    scores ``cond_image -> target_image`` under that pose; the backward term
    scores ``target_image -> cond_image`` under
    ``(-theta, 2*pi - azimuth, -radius)``.

    Returns:
        The sum of the two losses as a Python float.
    """
    theta, azimuth, radius = pose
    tensor_kwargs = dict(device=model.device, dtype=torch.float32)

    forward_pose = torch.tensor([[theta, azimuth, radius]], **tensor_kwargs)
    backward_pose = torch.tensor([[-theta, np.pi*2-azimuth, -radius]], **tensor_kwargs)

    forward_term = noise_loss(
        model, cond_image, target_image, forward_pose, ts_range, probe_bsz, noise=noise
    )
    backward_term = noise_loss(
        model, target_image, cond_image, backward_pose, ts_range, probe_bsz, noise=noise
    )
    return forward_term + backward_term
def probe_pose(model, cond_image, target_image, ts_range, probe_bsz, theta_range=None, azimuth_range=None, radius_range=None, noise=None):
    """Score every pose on a (radius, azimuth, theta) grid with ``pairwise_loss``.

    Default grids, used for any range left as ``None``:
        * theta: from ``-2*pi/3`` to ``2*pi/3`` inclusive, in steps of ``pi/3``
        * azimuth: from ``0`` up to (but excluding) ``2*pi``, in steps of ``pi/4``
        * radius: the single value ``0.0``

    The grid is walked with radius as the outermost loop and theta as the
    innermost one.

    Returns:
        A list of ``(loss, [theta, azimuth, radius])`` tuples in visiting
        order, with the pose entries converted to plain Python floats.
    """
    eps = 1e-5  # nudges the arange stop so the nominal upper bound is included
    if theta_range is None:
        theta_range = np.arange(start=-np.pi*2/3, stop=np.pi*2/3+eps, step=np.pi/3)
    if azimuth_range is None:
        azimuth_range = np.arange(start=0.0, stop=np.pi*2, step=np.pi/4)
    if radius_range is None:
        radius_range = np.arange(start=0.0, stop=0.0+eps, step=0.1)

    # The grid values may be numpy scalars; they are stored as built-in floats.
    return [
        (
            pairwise_loss(
                [theta, azimuth, radius], model, cond_image, target_image,
                ts_range, probe_bsz, noise=noise,
            ),
            [float(theta), float(azimuth), float(radius)],
        )
        for radius in radius_range
        for azimuth in azimuth_range
        for theta in theta_range
    ]
def create_random_pose():
    """Draw a random ``[theta, azimuth, radius]`` pose from NumPy's global RNG.

    Ranges (each from one uniform ``[0, 1)`` sample, drawn in this order):
        * theta in ``[-pi/2, pi/2)``
        * azimuth in ``[0, 2*pi)``
        * radius in ``[-0.5, 0.5)``
    """
    u_theta = np.random.rand()
    u_azimuth = np.random.rand()
    u_radius = np.random.rand()
    return [
        u_theta * np.pi - np.pi / 2,
        u_azimuth * np.pi * 2,
        u_radius - 0.5,
    ]
def get_inv_pose(pose):
    """Return the inverse pose ``[-theta, 2*pi - azimuth, -radius]``.

    Only the first three entries of ``pose`` are read. The arithmetic is
    applied element-wise, so entries may be numbers or tensors.
    """
    theta, azimuth, radius = pose[0], pose[1], pose[2]
    return [-theta, np.pi*2 - azimuth, -radius]
def add_pose(pose1, pose2):
    """Add two ``[theta, azimuth, radius]`` poses component-wise.

    Theta and radius are summed as-is; the summed azimuth is wrapped into
    ``[0, 2*pi)`` with the modulo operator.
    """
    full_turn = np.pi*2
    theta = pose1[0] + pose2[0]
    azimuth = (pose1[1] + pose2[1]) % full_turn
    radius = pose1[2] + pose2[2]
    return [theta, azimuth, radius]
def create_pose_params(pose, device):
    """Wrap the first three pose components as trainable one-element tensors.

    Args:
        pose: Indexable holding theta, azimuth and radius at indices 0, 1, 2.
        device: Device on which the tensors are created.

    Returns:
        ``[theta, azimuth, radius]``, each a separate tensor of shape ``(1,)``
        created with ``requires_grad=True``.
    """
    return [
        torch.tensor([pose[index]], requires_grad=True, device=device)
        for index in (0, 1, 2)
    ]
def find_optimal_poses(model, images, learning_rate, bsz=1, n_iter=1000, init_poses=None, ts_range=(0.02, 0.92), combinations=None, print_n=50, avg_last_n=1):
    """Optimise the poses of ``images[1:]`` relative to ``images[0]`` with SGD.

    Every image ``i >= 1`` owns three trainable scalars (theta, azimuth,
    radius) describing its pose relative to image 0. Each iteration picks
    ``(i, j)`` image pairs, builds the relative pose ``i -> j`` from those
    parameters, and takes one SGD step on the loss returned by
    ``model.shared_step``.

    Args:
        model: Object exposing ``device`` and ``shared_step(batch, ts=...)``
            returning a ``(loss, extra)`` pair with a differentiable loss.
        images: Sequence of image tensors that can be concatenated along
            axis 0 (assumed to each carry a leading batch axis of 1 -- TODO
            confirm). ``images[0]`` is the reference view.
        learning_rate: SGD learning rate.
        bsz: Pairs per iteration. With ``bsz <= 1`` pairs are visited
            round-robin; otherwise ``bsz`` pairs are sampled with replacement.
        n_iter: Number of optimisation steps.
        init_poses: Optional mapping ``image index -> [theta, azimuth, radius]``.
            Missing indices get a random initial pose. The mapping is not
            modified.
        ts_range: ``(low, high)`` timestep interval; one timestep per pair is
            taken, evenly spaced from ``low`` and stopping before ``high``.
        combinations: Optional list of ``(i, j)`` index pairs to train on.
            Defaults to every ordered pair of distinct images.
        print_n: Print progress every ``print_n`` iterations (``<= 0`` disables).
        avg_last_n: Number of trailing trajectory entries that are averaged
            for the reported loss and poses.

    Returns:
        ``(result_poses, initial_poses, result_loss)`` where both pose lists
        are ordered by image index ``1 .. len(images) - 1``. ``initial_poses``
        holds the poses the optimisation actually started from (including
        randomly drawn ones). ``result_loss`` is ``None`` when ``n_iter <= 0``.
    """
    init_poses = {} if init_poses is None else init_poses
    layer = PoseT()
    num = len(images)

    # Remember the starting pose of every image, including random draws, so
    # it can be returned even when the caller did not supply it.
    start_poses = {}
    pose_params = {}
    pose_trajs = {i: [] for i in range(1, num)}
    for i in range(1, num):
        start_poses[i] = init_poses[i] if i in init_poses else create_random_pose()
        pose_params[i] = create_pose_params(start_poses[i], model.device)

    if combinations is None:
        combinations = [
            pair
            for i in range(num)
            for j in range(i + 1, num)
            for pair in ((i, j), (j, i))
        ]

    param_list = [param for i in pose_params for param in pose_params[i]]
    optimizer = torch.optim.SGD(param_list, lr=learning_rate)

    def current_pose(i):
        """Read the present parameter values of image ``i`` as floats."""
        return [param.item() for param in pose_params[i]]

    def averaged_pose(i):
        """Average the last ``avg_last_n`` recorded poses of image ``i``."""
        traj = pose_trajs[i]
        if not traj:
            # Never selected for an update: report the parameters as they
            # are instead of the NaN that averaging an empty list yields.
            return current_pose(i)
        return np.mean(traj[-avg_last_n:], axis=0).tolist()

    loss_traj = []
    for step in range(n_iter):
        if print_n > 0 and step % print_n == 0 and step > 0:
            print(step, np.mean(loss_traj[-avg_last_n:]), flush=True)
            for i in range(1, num):
                print(0, i, averaged_pose(i))

        if bsz > 1:
            choices = np.random.choice(len(combinations), size=bsz, replace=True)
        else:
            choices = [step % len(combinations)]

        conds = []
        targets = []
        rts = []
        updated = set()  # images whose parameters receive gradient this step
        for cho in choices:
            i, j = combinations[cho]
            conds.append(images[i])
            targets.append(images[j])

            if i == 0:
                pose = pose_params[j]
                updated.add(j)
            elif j == 0:
                pose = get_inv_pose(pose_params[i])
                updated.add(i)
            else:
                # Relative pose i -> j is (i -> 0) + (0 -> j). One side,
                # chosen at random, is turned into plain floats so only the
                # other image's parameters are trained by this pair.
                pose0j = pose_params[j]
                posei0 = get_inv_pose(pose_params[i])
                if np.random.rand() < 0.5:
                    posei0 = [a.item() for a in posei0]
                    updated.add(j)
                else:
                    pose0j = [b.item() for b in pose0j]
                    updated.add(i)
                # The summed azimuth is not wrapped: PoseT passes it through
                # sin/cos, which are periodic.
                pose = [a + b for a, b in zip(posei0, pose0j)]

            rts.append(torch.cat(pose)[None, ...])

        batch = {
            'image_cond': torch.cat(conds, dim=0),
            'image_target': torch.cat(targets, dim=0),
            'T': layer(torch.cat(rts, dim=0)),
        }
        # linspace yields exactly one timestep per pair; arange with a float
        # step can return one element too many because of rounding.
        ts = np.linspace(ts_range[0], ts_range[1], num=len(conds), endpoint=False)

        optimizer.zero_grad()
        loss, _ = model.shared_step(batch, ts=ts)
        loss.backward()
        optimizer.step()
        loss_traj.append(loss.item())

        # Record right after the step so the final update is also part of
        # the trajectory that gets averaged below.
        for i in updated:
            pose_trajs[i].append(current_pose(i))

    initial_poses = [start_poses[i] for i in range(1, num)]
    if n_iter > 0:
        result_poses = [averaged_pose(i) for i in range(1, num)]
        result_loss = np.mean(loss_traj[-avg_last_n:])
    else:
        result_poses = list(initial_poses)
        result_loss = None

    return result_poses, initial_poses, result_loss