lev1's picture
Initial commit
8fd2f2f
raw
history blame
6.13 kB
# Adapted from https://github.com/MCG-NJU/EMA-VFI/blob/main/model/flow_estimation
import torch
import torch.nn as nn
import torch.nn.functional as F
from .warplayer import warp
from .refine import *
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Return a biased ``nn.Conv2d`` followed by a per-channel ``nn.PReLU``.

    Args:
        in_planes: number of input channels of the convolution.
        out_planes: number of output channels (also the PReLU parameter count).
        kernel_size, stride, padding, dilation: forwarded to ``nn.Conv2d``.

    Returns:
        An ``nn.Sequential`` whose child ``0`` is the convolution and child ``1``
        is the activation.
    """
    layers = [
        nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=True,
        ),
        nn.PReLU(out_planes),
    ]
    return nn.Sequential(*layers)
class Head(nn.Module):
    """Predict a 4-channel flow and a 1-channel mask for one pyramid stage.

    In ``forward``:

    1. The feature map is upsampled 4x by two ``PixelShuffle(2)`` layers, which
       also divide its channel count by 16.
    2. The image tensor ``x`` (and the current ``flow``, if given) is resized by
       ``4 / scale`` so that it lines up with the upsampled features.
    3. A small conv stack predicts 5 channels.
    4. Those channels are resized back by ``scale / 4``.

    Args:
        in_planes: half the channel count of the feature map passed to
            ``forward``. The conv stack expects ``in_planes * 2 // 16`` feature
            channels after the two pixel shuffles.
        scale: the feature map is at ``1 / scale`` of ``x``'s resolution. This is
            required for the concatenation in ``forward`` to line up.
        c: hidden channel width of the conv stack.
        in_else: number of channels of ``x`` plus ``flow`` (if any) after they
            are concatenated.
    """

    def __init__(self, in_planes, scale, c, in_else=17):
        super().__init__()
        self.upsample = nn.Sequential(nn.PixelShuffle(2), nn.PixelShuffle(2))
        self.scale = scale
        self.conv = nn.Sequential(
            conv(in_planes * 2 // (4 * 4) + in_else, c),
            conv(c, c),
            conv(c, 5),
        )

    def forward(self, motion_feature, x, flow):
        """Return ``(flow, mask)`` predicted for this stage.

        Args:
            motion_feature: feature map with ``in_planes * 2`` channels at
                ``1 / scale`` of ``x``'s resolution.
            x: image-like tensor concatenated with the features.
            flow: 4-channel flow at ``x``'s resolution, or ``None`` (first stage).

        Returns:
            ``flow`` (channels 0:4 of the prediction) and ``mask`` (channel 4).
            Both are resized back to (approximately, up to interpolation
            rounding) ``x``'s input resolution.
        """
        # The feature grid goes from 1/scale to 4/scale of x's resolution
        # (e.g. /16 /8 /4 -> /4 /2 /1).
        motion_feature = self.upsample(motion_feature)
        if self.scale != 4:
            x = F.interpolate(x, scale_factor=4. / self.scale, mode="bilinear", align_corners=False)
        if flow is not None:
            if self.scale != 4:
                # Resizing a flow field must also rescale its pixel displacements.
                flow = F.interpolate(flow, scale_factor=4. / self.scale, mode="bilinear", align_corners=False) * 4. / self.scale
            x = torch.cat((x, flow), 1)
        x = self.conv(torch.cat([motion_feature, x], 1))
        if self.scale != 4:
            # Use true division, mirroring the 4 / scale used above.
            # Floor division would give 0 for scale=2 and 1 for scale=6.
            # For scale 8 and 16 the value is unchanged.
            up = self.scale / 4
            x = F.interpolate(x, scale_factor=up, mode="bilinear", align_corners=False)
            flow = x[:, :4] * up
        else:
            flow = x[:, :4]
        mask = x[:, 4:5]
        return flow, mask
class MultiScaleFlow(nn.Module):
    """Multi-stage flow estimation plus U-Net refinement for intermediate-frame synthesis.

    ``backbone`` is called as ``backbone(img0, img1)`` and returns a pair
    ``(af, mf)`` of per-level feature lists. Each entry holds the features of
    both frames stacked on the batch axis: the first half belongs to ``img0``
    and the second half to ``img1``.

    Stage ``i`` reads level ``-1 - i`` of those lists and is served by
    ``self.block[i]``, a ``Head`` built from the ``-1 - i`` entries of the
    per-level configuration lists.

    Keyword args (all required):
        hidden_dims: per-level Head widths. Its length is the number of stages.
        motion_dims, depths, embed_dims, scales: per-level lists indexed from the
            end. Each must have at least ``len(hidden_dims)`` entries.
        c: base width. The refinement ``Unet`` is built with ``c * 2``.
    """

    def __init__(self, backbone, **kargs):
        super().__init__()
        self.flow_num_stage = len(kargs['hidden_dims'])
        self.feature_bone = backbone
        self.block = nn.ModuleList([
            Head(
                kargs['motion_dims'][-1 - i] * kargs['depths'][-1 - i] + kargs['embed_dims'][-1 - i],
                kargs['scales'][-1 - i],
                kargs['hidden_dims'][-1 - i],
                # Stage 0 sees only img0 and img1 (3 + 3 channels).
                # Later stages add both warped images (3 + 3), the mask (1) and
                # the 4-channel flow that Head concatenates, for 17 in total.
                6 if i == 0 else 17,
            )
            for i in range(self.flow_num_stage)
        ])
        self.unet = Unet(kargs['c'] * 2)

    def _stage_features(self, af, mf, i, B, timestep):
        """Concatenate the stage-``i`` features of both frames along channels.

        Motion features of frame 0 are weighted by ``timestep`` and those of
        frame 1 by ``1 - timestep``. Appearance features are appended unweighted.
        """
        mf_i = mf[-1 - i]
        af_i = af[-1 - i]
        # Allocate on the features' device instead of a hard-coded ``.cuda()``.
        # This lets the model run on CPU and on non-default GPUs.
        t = torch.full(mf_i[:B].shape, timestep, dtype=torch.float, device=mf_i.device)
        return torch.cat([t * mf_i[:B], (1 - t) * mf_i[B:], af_i[:B], af_i[B:]], 1)

    def warp_features(self, xs, flow):
        """Warp every feature level of both frames with a progressively halved flow.

        Args:
            xs: feature maps with both frames stacked on the batch axis (first
                half frame 0, second half frame 1).
            flow: 4-channel flow. Channels 0:2 warp frame-0 features and channels
                2:4 warp frame-1 features.

        Returns:
            ``(y0, y1)``: lists of warped frame-0 / frame-1 features, one per level.
        """
        y0 = []
        y1 = []
        B = xs[0].size(0) // 2
        for x in xs:
            y0.append(warp(x[:B], flow[:, 0:2]))
            y1.append(warp(x[B:], flow[:, 2:4]))
            # Assumes each successive level is at half the previous resolution
            # (TODO confirm against the backbone). Both the flow grid and its
            # displacements are halved.
            flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 0.5
        return y0, y1

    def calculate_flow(self, imgs, timestep, af=None, mf=None):
        """Run every flow stage and return the final ``(flow, mask)``.

        Args:
            imgs: both frames concatenated on channels (img0 = 0:3, img1 = 3:6).
            timestep: interpolation time, used as the fill value of ``torch.full``.
            af, mf: optional precomputed appearance/motion features. They are
                recomputed from ``imgs`` unless both are given.

        Returns:
            The accumulated 4-channel flow and the 1-channel mask logits
            (before sigmoid).
        """
        img0, img1 = imgs[:, :3], imgs[:, 3:6]
        B = img0.size(0)
        flow, mask = None, None
        # appearance_features & motion_features
        if (af is None) or (mf is None):
            af, mf = self.feature_bone(img0, img1)
        for i in range(self.flow_num_stage):
            feats = self._stage_features(af, mf, i, B, timestep)
            if flow is None:
                flow, mask = self.block[i](feats, torch.cat((img0, img1), 1), None)
            else:
                warped_img0 = warp(img0, flow[:, :2])
                warped_img1 = warp(img1, flow[:, 2:4])
                flow_, mask_ = self.block[i](
                    feats,
                    torch.cat((img0, img1, warped_img0, warped_img1, mask), 1),
                    flow,
                )
                # Each later stage predicts a residual update.
                flow = flow + flow_
                mask = mask + mask_
        return flow, mask

    def _refine(self, img0, img1, warped_img0, warped_img1, merged, af, flow, mask):
        """Add the U-Net residual to the blended frame ``merged`` and clamp to [0, 1].

        ``mask`` is passed to the U-Net as raw logits.
        """
        c0, c1 = self.warp_features(af, flow)
        tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
        # ``* 2 - 1`` maps [0, 1] to [-1, 1]. This assumes the U-Net output lies
        # in [0, 1]; TODO confirm in .refine.
        res = tmp[:, :3] * 2 - 1
        return torch.clamp(merged + res, 0, 1)

    def coraseWarp_and_Refine(self, imgs, af, flow, mask):
        """Warp both frames, blend them, and refine the blend; return the predicted frame.

        The method name is historical (sic: "coarse") and is kept for
        compatibility. The two warped frames are blended with
        ``sigmoid(mask)`` as the frame-0 weight.
        """
        img0, img1 = imgs[:, :3], imgs[:, 3:6]
        warped_img0 = warp(img0, flow[:, :2])
        warped_img1 = warp(img1, flow[:, 2:4])
        mask_ = torch.sigmoid(mask)
        merged = warped_img0 * mask_ + warped_img1 * (1 - mask_)
        return self._refine(img0, img1, warped_img0, warped_img1, merged, af, flow, mask)

    def forward(self, x, timestep=0.5):
        """Estimate flow stage by stage, then refine; also return intermediates.

        Computes the same result as ``calculate_flow`` followed by
        ``coraseWarp_and_Refine``, while keeping every stage's outputs.

        Args:
            x: both frames concatenated on channels (img0 = 0:3, img1 = 3:6).
            timestep: interpolation time.

        Returns:
            A tuple ``(flow_list, mask_list, merged, pred)``:

            - ``flow_list``: the accumulated flow after each stage.
            - ``mask_list``: ``sigmoid`` of the accumulated mask after each stage.
            - ``merged``: the blend of the two warped frames after each stage.
            - ``pred``: the refined final frame, clamped to [0, 1].
        """
        img0, img1 = x[:, :3], x[:, 3:6]
        B = x.size(0)
        flow_list = []
        merged = []
        mask_list = []
        flow = None
        # appearance_features & motion_features
        af, mf = self.feature_bone(img0, img1)
        for i in range(self.flow_num_stage):
            feats = self._stage_features(af, mf, i, B, timestep)
            if flow is None:
                flow, mask = self.block[i](feats, torch.cat((img0, img1), 1), None)
            else:
                # warped_img0/1 were computed from the previous stage's flow.
                flow_d, mask_d = self.block[i](
                    feats,
                    torch.cat((img0, img1, warped_img0, warped_img1, mask), 1),
                    flow,
                )
                flow = flow + flow_d
                mask = mask + mask_d
            mask_list.append(torch.sigmoid(mask))
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2])
            warped_img1 = warp(img1, flow[:, 2:4])
            merged.append(warped_img0 * mask_list[i] + warped_img1 * (1 - mask_list[i]))
        pred = self._refine(img0, img1, warped_img0, warped_img1, merged[-1], af, flow, mask)
        return flow_list, mask_list, merged, pred