# NOTE(review): the lines below were non-Python residue from a hosted file
# viewer ("umuthopeyildirim's picture / here we go / bd86ed9 / raw /
# history blame / 10.9 kB") -- commented out so this module parses.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.data import Sampler
from torchvision import transforms
import matplotlib.pyplot as plt
import os, sys
import numpy as np
import math
import torch
def convert_arg_line_to_args(arg_line):
    """Split one line of an argparse fromfile into individual CLI tokens.

    Intended as an ``argparse.ArgumentParser.convert_arg_line_to_args`` hook,
    so a file may hold several arguments per line. Yields each whitespace
    separated token; blank tokens are skipped.
    """
    tokens = arg_line.split()
    for token in tokens:
        # split() already drops pure whitespace, so this guard is belt-and-braces.
        if not token.strip():
            continue
        yield token
def block_print():
    """Silence all print output by redirecting stdout to the null device.

    Undo with :func:`enable_print`. Note the previous stdout object is not
    saved here; restoration relies on ``sys.__stdout__``.
    """
    devnull = open(os.devnull, 'w')
    sys.stdout = devnull
def enable_print():
    """Re-enable printing after :func:`block_print` by restoring the
    interpreter's original stdout (``sys.__stdout__``)."""
    original = sys.__stdout__
    sys.stdout = original
def get_num_lines(file_path):
    """Return the number of lines in the text file at *file_path*.

    Improvements over the original: the file handle is managed with a
    context manager (closed even if reading raises), and lines are counted
    lazily instead of materializing the entire file with ``readlines()``.

    Parameters
    ----------
    file_path : str
        Path to a readable text file.

    Returns
    -------
    int
        Line count (same semantics as ``len(f.readlines())``).
    """
    with open(file_path, 'r') as f:
        return sum(1 for _ in f)
def colorize(value, vmin=None, vmax=None, cmap='Greys'):
    """Render a depth tensor as a (3, H, W) uint8 color image in log10 space.

    Parameters
    ----------
    value : torch.Tensor
        Depth map of shape (1, H, W) or (H, W); values are assumed strictly
        positive since log10 is taken -- TODO confirm with callers.
    vmin, vmax : float, optional
        Normalization bounds in log10 space; default to the data min/max.
    cmap : str
        Matplotlib colormap name.

    Returns
    -------
    numpy.ndarray
        RGB image of shape (3, H, W), dtype uint8.
    """
    value = value.cpu().numpy()
    if value.ndim == 3:
        # Collapse the leading channel axis (same convention as
        # normalize_result). The original sliced with [:, :, :], which kept
        # the extra axis and made the final 3-axis transpose fail.
        value = value[0, :, :]
    value = np.log10(value)
    vmin = value.min() if vmin is None else vmin
    vmax = value.max() if vmax is None else vmax
    if vmin != vmax:
        value = (value - vmin) / (vmax - vmin)
    else:
        # Constant input: avoid divide-by-zero, emit all zeros.
        value = value * 0.
    # Only matplotlib.pyplot is imported at module level, so the original
    # `matplotlib.cm.get_cmap(cmap)` raised NameError; plt.get_cmap is the
    # equivalent accessor.
    cmapper = plt.get_cmap(cmap)
    value = cmapper(value, bytes=True)  # (H, W, 4) uint8 RGBA
    img = value[:, :, :3]               # drop alpha
    return img.transpose((2, 0, 1))     # HWC -> CHW
def normalize_result(value, vmin=None, vmax=None):
    """Min-max normalize the first channel of a tensor into [0, 1].

    Parameters
    ----------
    value : torch.Tensor
        Tensor whose first channel ``value[0]`` is normalized.
    vmin, vmax : float, optional
        Normalization bounds; default to the channel's min/max.

    Returns
    -------
    numpy.ndarray
        Array of shape (1, H, W); a constant input maps to all zeros.
    """
    arr = value.cpu().numpy()[0, :, :]
    lo = arr.min() if vmin is None else vmin
    hi = arr.max() if vmax is None else vmax
    if lo == hi:
        # Degenerate range: return zeros rather than dividing by zero.
        arr = arr * 0.
    else:
        arr = (arr - lo) / (hi - lo)
    return np.expand_dims(arr, 0)
# Undo torchvision's ImageNet normalization: since x_norm = (x - mean) / std,
# the inverse transform uses mean' = -mean/std and std' = 1/std, with the
# standard ImageNet per-channel constants (0.485/0.456/0.406, 0.229/0.224/0.225).
inv_normalize = transforms.Normalize(
    mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
    std=[1/0.229, 1/0.224, 1/0.225]
)
# Metric names, in the exact order compute_errors() returns them.
eval_metrics = ['silog', 'abs_rel', 'log10', 'rms', 'sq_rel', 'log_rms', 'd1', 'd2', 'd3']
def compute_errors(gt, pred):
    """Compute standard monocular-depth evaluation metrics.

    Parameters
    ----------
    gt, pred : numpy.ndarray
        Ground-truth and predicted depths; assumed strictly positive
        (logs and ratios are taken) -- callers are expected to pre-mask.

    Returns
    -------
    list of float
        [silog, abs_rel, log10, rms, sq_rel, log_rms, d1, d2, d3],
        matching the order of the module-level ``eval_metrics``.
    """
    # Threshold accuracies: fraction of pixels whose depth ratio is < 1.25^k.
    ratio = np.maximum(gt / pred, pred / gt)
    d1, d2, d3 = [(ratio < 1.25 ** k).mean() for k in (1, 2, 3)]
    diff = gt - pred
    rms = np.sqrt(np.mean(diff ** 2))
    log_diff = np.log(gt) - np.log(pred)
    log_rms = np.sqrt(np.mean(log_diff ** 2))
    abs_rel = np.mean(np.abs(diff) / gt)
    sq_rel = np.mean((diff ** 2) / gt)
    # Scale-invariant log error (Eigen et al.), scaled by 100.
    silog = np.sqrt(np.mean(log_diff ** 2) - np.mean(log_diff) ** 2) * 100
    log10 = np.mean(np.abs(np.log10(gt) - np.log10(pred)))
    return [silog, abs_rel, log10, rms, sq_rel, log_rms, d1, d2, d3]
class silog_loss(nn.Module):
    """Scale-invariant log (SILog) loss for depth estimation.

    Computes ``sqrt(E[d^2] - variance_focus * E[d]^2) * 10`` where
    ``d = log(depth_est) - log(depth_gt)`` over masked pixels only.
    ``variance_focus`` in [0, 1] weights the variance (scale-invariant) term.
    """

    def __init__(self, variance_focus):
        super().__init__()
        self.variance_focus = variance_focus

    def forward(self, depth_est, depth_gt, mask):
        # Log-space residuals restricted to valid (masked) pixels.
        log_diff = torch.log(depth_est[mask]) - torch.log(depth_gt[mask])
        mean_of_sq = (log_diff ** 2).mean()
        sq_of_mean = log_diff.mean() ** 2
        return torch.sqrt(mean_of_sq - self.variance_focus * sq_of_mean) * 10.0
def entropy_loss(preds, gt_label, mask):
    """Masked mean cross-entropy over per-pixel class logits.

    Parameters
    ----------
    preds : torch.Tensor (B, C, H, W)
        Per-pixel class logits.
    gt_label : torch.Tensor (B, H, W)
        Integer class ids.
    mask : torch.Tensor (B, H, W)
        Pixels with mask > 0 contribute to the loss.

    Returns
    -------
    torch.Tensor
        Scalar mean cross-entropy over the selected pixels.
    """
    valid = mask > 0.0                            # (B, H, W) boolean
    logits = preds.permute(0, 2, 3, 1)[valid]     # (N, C) valid-pixel logits
    targets = gt_label[valid]                     # (N,) matching labels
    return F.cross_entropy(logits, targets, reduction='mean')
def colormap(inputs, normalize=True, torch_transpose=True):
    """Color-map an array/tensor with the 'jet' colormap for visualization.

    Parameters
    ----------
    inputs : torch.Tensor or numpy.ndarray
        4-D (B, C, H, W) with C treated as 1, 3-D (B, H, W), or 2-D (H, W).
    normalize : bool
        Min-max scale to [0, 1] first (constant input divides by 1e5,
        preserved from the original).
    torch_transpose : bool
        Move the color channel to the front (CHW ordering).

    Returns
    -------
    numpy.ndarray
        For batched input, the first element with batch axis stripped;
        for 2-D input, the colorized image itself.
    """
    if isinstance(inputs, torch.Tensor):
        inputs = inputs.detach().cpu().numpy()
    _DEPTH_COLORMAP = plt.get_cmap('jet', 256)  # for plotting
    vis = inputs
    if normalize:
        ma = float(vis.max())
        mi = float(vis.min())
        d = ma - mi if ma != mi else 1e5
        vis = (vis - mi) / d
    if vis.ndim == 4:
        vis = vis.transpose([0, 2, 3, 1])
        vis = _DEPTH_COLORMAP(vis)
        vis = vis[:, :, :, 0, :3]   # take channel 0, drop alpha
        if torch_transpose:
            vis = vis.transpose(0, 3, 1, 2)
    elif vis.ndim == 3:
        vis = _DEPTH_COLORMAP(vis)
        vis = vis[:, :, :, :3]
        if torch_transpose:
            vis = vis.transpose(0, 3, 1, 2)
    elif vis.ndim == 2:
        vis = _DEPTH_COLORMAP(vis)
        vis = vis[..., :3]
        if torch_transpose:
            vis = vis.transpose(2, 0, 1)
        # 2-D input has no batch axis to strip; returning here fixes the
        # IndexError the original raised at `vis[0, :, :, :]` (4 indices
        # into a 3-D array).
        return vis
    return vis[0, :, :, :]
def colormap_magma(inputs, normalize=True, torch_transpose=True):
    """Color-map an array/tensor with the 'magma' colormap for visualization.

    Parameters
    ----------
    inputs : torch.Tensor or numpy.ndarray
        4-D (B, C, H, W) with C treated as 1, 3-D (B, H, W), or 2-D (H, W).
    normalize : bool
        Min-max scale to [0, 1] first (constant input divides by 1e5,
        preserved from the original).
    torch_transpose : bool
        Move the color channel to the front (CHW ordering).

    Returns
    -------
    numpy.ndarray
        For batched input, the first element with batch axis stripped;
        for 2-D input, the colorized image itself.
    """
    if isinstance(inputs, torch.Tensor):
        inputs = inputs.detach().cpu().numpy()
    _DEPTH_COLORMAP = plt.get_cmap('magma', 256)  # for plotting
    vis = inputs
    if normalize:
        ma = float(vis.max())
        mi = float(vis.min())
        d = ma - mi if ma != mi else 1e5
        vis = (vis - mi) / d
    if vis.ndim == 4:
        vis = vis.transpose([0, 2, 3, 1])
        vis = _DEPTH_COLORMAP(vis)
        vis = vis[:, :, :, 0, :3]   # take channel 0, drop alpha
        if torch_transpose:
            vis = vis.transpose(0, 3, 1, 2)
    elif vis.ndim == 3:
        vis = _DEPTH_COLORMAP(vis)
        vis = vis[:, :, :, :3]
        if torch_transpose:
            vis = vis.transpose(0, 3, 1, 2)
    elif vis.ndim == 2:
        vis = _DEPTH_COLORMAP(vis)
        vis = vis[..., :3]
        if torch_transpose:
            vis = vis.transpose(2, 0, 1)
        # 2-D input has no batch axis to strip; returning here fixes the
        # IndexError the original raised at `vis[0, :, :, :]` (4 indices
        # into a 3-D array).
        return vis
    return vis[0, :, :, :]
def flip_lr(image):
    """Return *image* mirrored along its width axis.

    Parameters
    ----------
    image : torch.Tensor [B,C,H,W]
        Batched image tensor to mirror.

    Returns
    -------
    torch.Tensor [B,C,H,W]
        The horizontally flipped tensor.
    """
    assert image.dim() == 4, 'You need to provide a [B,C,H,W] image to flip'
    return image.flip(dims=[3])
def fuse_inv_depth(inv_depth, inv_depth_hat, method='mean'):
    """Combine an inverse depth map with its flipped-image counterpart.

    Parameters
    ----------
    inv_depth : torch.Tensor [B,1,H,W]
        Inverse depth map.
    inv_depth_hat : torch.Tensor [B,1,H,W]
        Inverse depth map produced from a horizontally flipped input.
    method : str
        'mean', 'max', or 'min'.

    Returns
    -------
    torch.Tensor [B,1,H,W]
        Element-wise fusion of the two maps.

    Raises
    ------
    ValueError
        If *method* is not one of the supported names.
    """
    fusers = {
        'mean': lambda a, b: 0.5 * (a + b),
        'max': torch.max,
        'min': torch.min,
    }
    if method not in fusers:
        raise ValueError('Unknown post-process method {}'.format(method))
    return fusers[method](inv_depth, inv_depth_hat)
def post_process_depth(depth, depth_flipped, method='mean'):
    """Fuse a depth map with its flipped-input prediction, blending borders.

    Parameters
    ----------
    depth : torch.Tensor [B,C,H,W]
        Prediction from the original image.
    depth_flipped : torch.Tensor [B,C,H,W]
        Prediction from the horizontally flipped image (still in flipped
        orientation; it is un-flipped here).
    method : str
        Fusion method forwarded to :func:`fuse_inv_depth`.

    Returns
    -------
    torch.Tensor [B,C,H,W]
        Post-processed depth: fused everywhere except near the left/right
        borders, where each single-view prediction dominates.
    """
    batch, channels, height, width = depth.shape
    depth_hat = flip_lr(depth_flipped)
    fused = fuse_inv_depth(depth, depth_hat, method=method)
    # Normalized x coordinate per column, broadcast to the full tensor shape.
    xs = torch.linspace(0., 1., width, device=depth.device,
                        dtype=depth.dtype).repeat(batch, channels, height, 1)
    # Ramp 1 -> 0 over x in [0.05, 0.10]: favors one view near each border.
    mask = 1.0 - torch.clamp(20. * (xs - 0.05), 0., 1.)
    mask_hat = flip_lr(mask)
    border_blend = mask_hat * depth + mask * depth_hat
    return border_blend + (1.0 - mask - mask_hat) * fused
class DistributedSamplerNoEvenlyDivisible(Sampler):
    """Distributed sampler that does NOT pad the dataset.

    Unlike ``torch.utils.data.DistributedSampler``, no indices are duplicated
    to make the per-rank counts equal: ranks whose index is below the
    remainder simply receive one extra sample. Useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel` when exact, duplicate
    free coverage of the dataset matters (e.g. evaluation).

    .. note::
        Dataset is assumed to be of constant size.

    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training (defaults to the world size).
        rank (optional): Rank of the current process within num_replicas.
        shuffle (optional): If true (default), sampler will shuffle the indices.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        base = len(dataset) // num_replicas
        leftover = len(dataset) - base * num_replicas
        # The first `leftover` ranks absorb one extra sample each.
        self.num_samples = base + 1 if rank < leftover else base
        self.total_size = len(dataset)
        self.shuffle = shuffle

    def __iter__(self):
        # Deterministic shuffle keyed on the epoch so all ranks agree.
        g = torch.Generator()
        g.manual_seed(self.epoch)
        if self.shuffle:
            order = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            order = list(range(len(self.dataset)))
        # Strided subsampling without padding: rank r takes r, r+R, r+2R, ...
        mine = order[self.rank:self.total_size:self.num_replicas]
        self.num_samples = len(mine)
        return iter(mine)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
class D_to_cloud(nn.Module):
    """Back-project a depth map into a 3-D point cloud using fixed intrinsics.

    Precomputes homogeneous pixel coordinates for a fixed (batch_size,
    height, width); ``forward`` scales the unprojected rays by depth.
    """

    def __init__(self, batch_size, height, width):
        super().__init__()
        self.batch_size = batch_size
        self.height = height
        self.width = width
        # 'xy' indexing: id_coords[0] holds x (column) indices,
        # id_coords[1] holds y (row) indices.
        grid = np.meshgrid(range(width), range(height), indexing='xy')
        self.id_coords = nn.Parameter(
            torch.from_numpy(np.stack(grid, axis=0).astype(np.float32)),
            requires_grad=False)  # (2, H, W)
        self.ones = nn.Parameter(
            torch.ones(batch_size, 1, height * width),
            requires_grad=False)  # (B, 1, H*W)
        flat_xy = torch.stack(
            [self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0
        ).unsqueeze(0).repeat(batch_size, 1, 1)  # (B, 2, H*W)
        # Homogeneous pixel coordinates [x, y, 1] per location.
        self.pix_coords = nn.Parameter(
            torch.cat([flat_xy, self.ones], 1), requires_grad=False)  # (B, 3, H*W)

    def forward(self, depth, inv_K):
        """Unproject *depth* (B, 1, H, W) with *inv_K* (uses its top-left 3x3).

        Returns points of shape (B, H*W, 3).
        """
        rays = torch.matmul(inv_K[:, :3, :3], self.pix_coords)
        points = depth.view(self.batch_size, 1, -1) * rays
        return points.permute(0, 2, 1)