MiniDPVO / mini_dpvo /altcorr /correlation.py
pablovela5620's picture
initial commit with working dpvo
899c526
raw
history blame
No virus
2.2 kB
import torch
import cuda_corr
class CorrLayer(torch.autograd.Function):
    """Autograd wrapper around the ``cuda_corr`` correlation kernels.

    The forward pass delegates to ``cuda_corr.forward`` and the backward pass
    to ``cuda_corr.backward``. Gradients are produced for ``fmap1`` and
    ``fmap2`` only; every other input receives ``None``.
    """

    @staticmethod
    def forward(ctx, fmap1, fmap2, coords, ii, jj, radius, dropout):
        """ forward correlation

        Args:
            ctx: autograd context used to stash state for ``backward``.
            fmap1, fmap2: feature maps handed to the kernel.
            coords: coordinates handed to the kernel; its dim 1 is indexed
                per edge in ``backward``.
            ii, jj: 1-D index tensors of equal length (one entry per edge).
                Presumably they select entries of ``fmap1`` / ``fmap2`` --
                confirm against the ``cuda_corr`` extension.
            radius: correlation radius forwarded to the kernel.
            dropout: probability of *keeping* an edge when gradients are
                computed; values >= 1 keep every edge. It has no effect on
                the forward result.

        Returns:
            The single tensor produced by ``cuda_corr.forward``.
        """
        ctx.save_for_backward(fmap1, fmap2, coords, ii, jj)
        ctx.radius = radius
        ctx.dropout = dropout

        # The extension returns a one-element sequence; unpack its only entry.
        corr, = cuda_corr.forward(fmap1, fmap2, coords, ii, jj, radius)
        return corr

    @staticmethod
    def backward(ctx, grad):
        """ backward correlation

        Args:
            ctx: autograd context populated by ``forward``.
            grad: gradient w.r.t. the forward output; its dim 1 is indexed
                per edge, like ``coords``.

        Returns:
            A 7-tuple aligned with the inputs of ``forward``: gradients for
            ``fmap1`` and ``fmap2`` followed by ``None`` for ``coords``,
            ``ii``, ``jj``, ``radius`` and ``dropout``.
        """
        fmap1, fmap2, coords, ii, jj = ctx.saved_tensors

        if ctx.dropout < 1:
            # Randomly keep each edge with probability ``ctx.dropout`` so the
            # backward kernel only processes a subset of edges. The mask is
            # created on the device that holds the edge indices (rather than
            # the default CUDA device) so indexing also works when the inputs
            # live on a non-default GPU.
            keep = torch.rand(len(ii), device=ii.device) < ctx.dropout
            coords = coords[:, keep]
            grad = grad[:, keep]
            ii = ii[keep]
            jj = jj[keep]

        fmap1_grad, fmap2_grad = cuda_corr.backward(
            fmap1, fmap2, coords, ii, jj, grad, ctx.radius)

        return fmap1_grad, fmap2_grad, None, None, None, None, None
class PatchLayer(torch.autograd.Function):
    """Autograd wrapper around the ``cuda_corr`` patch-extraction kernels.

    Only ``net`` receives a gradient; ``coords`` and ``radius`` get ``None``.
    """

    @staticmethod
    def forward(ctx, net, coords, radius):
        """ forward patchify

        Stores ``net``, ``coords`` and ``radius`` on ``ctx`` for the backward
        pass and returns the single tensor produced by
        ``cuda_corr.patchify_forward``.
        """
        ctx.save_for_backward(net, coords)
        ctx.radius = radius

        # The extension hands back a one-element sequence.
        (patches,) = cuda_corr.patchify_forward(net, coords, radius)
        return patches

    @staticmethod
    def backward(ctx, grad):
        """ backward patchify

        Maps the gradient of the extracted patches to a gradient for ``net``
        via ``cuda_corr.patchify_backward``. The returned 3-tuple lines up
        with ``forward``'s inputs ``(net, coords, radius)``.
        """
        net, coords = ctx.saved_tensors
        (net_grad,) = cuda_corr.patchify_backward(
            net, coords, grad, ctx.radius)
        return net_grad, None, None
def patchify(net, coords, radius, mode='bilinear'):
    """ extract patches

    Runs ``PatchLayer`` on ``net`` at ``coords``. For any ``mode`` other than
    ``'bilinear'`` the raw kernel output is returned unchanged. In
    ``'bilinear'`` mode four windows of the raw patches, each shifted by at
    most one element along the last two axes, are blended using the
    fractional part of ``coords`` as weights.

    Args:
        net: tensor the patches are extracted from.
        coords: patch coordinates; the last axis is unpacked into exactly two
            components, the first used as the x offset and the second as the
            y offset.
        radius: patch radius forwarded to the kernel; blended windows span
            ``2 * radius + 1`` elements.
        mode: ``'bilinear'`` (default) to interpolate, anything else to get
            the raw patches.

    Returns:
        The interpolated patches in bilinear mode, otherwise the raw patches.
    """
    patches = PatchLayer.apply(net, coords, radius)
    if mode != 'bilinear':
        return patches

    # Sub-pixel part of every coordinate, moved next to the feature tensor.
    frac = (coords - coords.floor()).to(net.device)
    # Three trailing singleton axes are inserted before the component axis so
    # the weights broadcast against the patch tensor.
    dx, dy = frac[:, :, None, None, None].unbind(dim=-1)

    span = 2 * radius + 1
    lead = slice(None, span)   # first `span` entries along an axis
    shifted = slice(1, None)   # same window moved by one entry
    # NOTE(review): the sum below assumes the raw patches are span + 1 wide
    # along the last two axes, so both windows have equal size -- confirm
    # against the cuda_corr kernel.

    w00 = (1 - dy) * (1 - dx)
    x00 = w00 * patches[..., lead, lead]

    w01 = (1 - dy) * dx
    x01 = w01 * patches[..., lead, shifted]

    w10 = dy * (1 - dx)
    x10 = w10 * patches[..., shifted, lead]

    w11 = dy * dx
    x11 = w11 * patches[..., shifted, shifted]

    return x00 + x01 + x10 + x11
def corr(fmap1, fmap2, coords, ii, jj, radius=1, dropout=1):
    """ functional entry point for ``CorrLayer``

    Forwards every argument, in order, to ``CorrLayer.apply`` and returns its
    result. With the default ``dropout=1`` the backward pass keeps all edges.
    """
    # autograd Functions are invoked positionally, so pass the arguments as an
    # ordered tuple.
    layer_args = (fmap1, fmap2, coords, ii, jj, radius, dropout)
    return CorrLayer.apply(*layer_args)