# Adapted from https://github.com/MCG-NJU/EMA-VFI/blob/main/model/flow_estimation
import torch
import torch.nn as nn
import torch.nn.functional as F

from .warplayer import warp
from .refine import *


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Conv2d followed by PReLU; the defaults preserve spatial resolution."""
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  padding=padding, dilation=dilation, bias=True),
        nn.PReLU(out_planes)
    )
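

# A quick shape sketch (hypothetical sizes, not part of the original module):
# with the default kernel_size=3, stride=1, padding=1 the block keeps H and W.
def _conv_example():
    block = conv(16, 32)            # Conv2d(16 -> 32, 3x3) + PReLU
    x = torch.randn(1, 16, 48, 48)  # (B, C_in, H, W)
    assert block(x).shape == (1, 32, 48, 48)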


class Head(nn.Module):
    """Predicts a residual bidirectional flow (4 channels) and a fusion mask
    (1 channel) from motion/appearance features at one pyramid level."""

    def __init__(self, in_planes, scale, c, in_else=17):
        super(Head, self).__init__()
        # Two PixelShuffle(2) stages: 4x spatial upsampling, channels / 16.
        self.upsample = nn.Sequential(nn.PixelShuffle(2), nn.PixelShuffle(2))
        self.scale = scale
        self.conv = nn.Sequential(
            conv(in_planes * 2 // (4 * 4) + in_else, c),
            conv(c, c),
            conv(c, 5),  # 4 flow channels + 1 mask channel
        )

    def forward(self, motion_feature, x, flow):  # motion_feature at /16 /8 /4
        motion_feature = self.upsample(motion_feature)  # -> /4 /2 /1
        if self.scale != 4:
            # Bring the image tensor down to the motion feature's resolution.
            x = F.interpolate(x, scale_factor=4. / self.scale, mode="bilinear", align_corners=False)
        if flow is not None:
            if self.scale != 4:
                # Rescale flow values along with their spatial resolution.
                flow = F.interpolate(flow, scale_factor=4. / self.scale, mode="bilinear",
                                     align_corners=False) * 4. / self.scale
            x = torch.cat((x, flow), 1)
        x = self.conv(torch.cat([motion_feature, x], 1))
        if self.scale != 4:
            x = F.interpolate(x, scale_factor=self.scale // 4, mode="bilinear", align_corners=False)
            flow = x[:, :4] * (self.scale // 4)
        else:
            flow = x[:, :4]
        mask = x[:, 4:5]
        return flow, mask
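

# Illustrative shape walk-through for Head (hypothetical dims, not taken from
# the paper's config): a Head built with in_planes=128 and scale=16 consumes a
# 2*in_planes-channel motion/appearance tensor at 1/16 resolution plus the two
# input frames, and emits full-resolution flow and mask.
def _head_shape_example():
    head = Head(in_planes=128, scale=16, c=64, in_else=6)
    motion_feature = torch.randn(1, 256, 14, 14)  # 2*in_planes ch at H/16 x W/16
    x = torch.randn(1, 6, 224, 224)               # img0 and img1, full resolution
    flow, mask = head(motion_feature, x, None)    # first stage: no prior flow
    assert flow.shape == (1, 4, 224, 224)
    assert mask.shape == (1, 1, 224, 224)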


class MultiScaleFlow(nn.Module):
    """Coarse-to-fine flow estimation plus Unet refinement on top of a shared
    appearance/motion feature backbone."""

    def __init__(self, backbone, **kwargs):
        super(MultiScaleFlow, self).__init__()
        self.flow_num_stage = len(kwargs['hidden_dims'])
        self.feature_bone = backbone
        # One Head per stage, coarsest level first; the first stage sees only
        # the two input frames (6ch), later stages also receive the warped
        # frames, mask and flow (17ch).
        self.block = nn.ModuleList([
            Head(kwargs['motion_dims'][-1 - i] * kwargs['depths'][-1 - i] + kwargs['embed_dims'][-1 - i],
                 kwargs['scales'][-1 - i],
                 kwargs['hidden_dims'][-1 - i],
                 6 if i == 0 else 17)
            for i in range(self.flow_num_stage)])
        self.unet = Unet(kwargs['c'] * 2)
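
    # Constructor contract, sketched with hypothetical values (the published
    # EMA-VFI configs live in the original repo):
    #
    #   MultiScaleFlow(backbone,
    #                  embed_dims=[32, 64, 128],  # appearance channels, fine -> coarse
    #                  motion_dims=[8, 16, 32],   # motion channels per depth unit
    #                  depths=[2, 2, 2],
    #                  scales=[4, 8, 16],         # feature stride per level
    #                  hidden_dims=[64, 64, 64],  # Head width; len == number of stages
    #                  c=32)                      # base width of the refinement Unet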

    def warp_features(self, xs, flow):
        """Backward-warp each pyramid level of appearance features with the
        progressively downscaled bidirectional flow."""
        y0 = []
        y1 = []
        B = xs[0].size(0) // 2  # img0 and img1 features are stacked along batch
        for x in xs:
            y0.append(warp(x[:B], flow[:, 0:2]))
            y1.append(warp(x[B:], flow[:, 2:4]))
            # Halve the flow (values and resolution) for the next, coarser level.
            flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
                                 align_corners=False, recompute_scale_factor=False) * 0.5
        return y0, y1
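
    # Usage sketch (hypothetical shapes): each entry of `xs` stacks the img0
    # and img1 features along the batch dim, and `flow` is halved in value and
    # resolution for every coarser level:
    #
    #   xs = [torch.randn(2, 32, 64, 64), torch.randn(2, 64, 32, 32)]
    #   flow = torch.randn(1, 4, 64, 64)    # [:, :2] warps img0, [:, 2:4] img1
    #   y0, y1 = model.warp_features(xs, flow)
    #   y0[1].shape  # -> (1, 64, 32, 32)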

    def calculate_flow(self, imgs, timestep, af=None, mf=None):
        img0, img1 = imgs[:, :3], imgs[:, 3:6]
        B = img0.size(0)
        flow, mask = None, None
        # Appearance & motion features (reuse precomputed pyramids if given).
        if (af is None) or (mf is None):
            af, mf = self.feature_bone(img0, img1)
        for i in range(self.flow_num_stage):
            t = torch.full(mf[-1 - i][:B].shape, timestep,
                           dtype=torch.float, device=imgs.device)
            if flow is not None:
                # Later stages refine the current estimate with residuals.
                warped_img0 = warp(img0, flow[:, :2])
                warped_img1 = warp(img1, flow[:, 2:4])
                flow_, mask_ = self.block[i](
                    torch.cat([t * mf[-1 - i][:B], (1 - t) * mf[-1 - i][B:],
                               af[-1 - i][:B], af[-1 - i][B:]], 1),
                    torch.cat((img0, img1, warped_img0, warped_img1, mask), 1),
                    flow
                )
                flow = flow + flow_
                mask = mask + mask_
            else:
                flow, mask = self.block[i](
                    torch.cat([t * mf[-1 - i][:B], (1 - t) * mf[-1 - i][B:],
                               af[-1 - i][:B], af[-1 - i][B:]], 1),
                    torch.cat((img0, img1), 1),
                    None
                )
        return flow, mask

    def coraseWarp_and_Refine(self, imgs, af, flow, mask):
        img0, img1 = imgs[:, :3], imgs[:, 3:6]
        warped_img0 = warp(img0, flow[:, :2])
        warped_img1 = warp(img1, flow[:, 2:4])
        c0, c1 = self.warp_features(af, flow)
        tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
        res = tmp[:, :3] * 2 - 1  # residual image in [-1, 1]
        mask_ = torch.sigmoid(mask)
        merged = warped_img0 * mask_ + warped_img1 * (1 - mask_)
        pred = torch.clamp(merged + res, 0, 1)
        return pred
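
    # Split-path sketch, equivalent to forward()'s final prediction: compute
    # the backbone features once, estimate flow/mask, then warp and refine.
    #
    #   af, mf = model.feature_bone(imgs[:, :3], imgs[:, 3:6])
    #   flow, mask = model.calculate_flow(imgs, timestep, af, mf)
    #   pred = model.coraseWarp_and_Refine(imgs, af, flow, mask)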

    # Runs 'calculate_flow' and 'coraseWarp_and_Refine' in a single pass
    # (see the sketch above).
    def forward(self, x, timestep=0.5):
        img0, img1 = x[:, :3], x[:, 3:6]
        B = x.size(0)
        flow_list = []
        merged = []
        mask_list = []
        warped_img0 = img0
        warped_img1 = img1
        flow = None
        # Appearance & motion features from the shared backbone.
        af, mf = self.feature_bone(img0, img1)
        for i in range(self.flow_num_stage):
            t = torch.full(mf[-1 - i][:B].shape, timestep,
                           dtype=torch.float, device=x.device)
            if flow is not None:
                flow_d, mask_d = self.block[i](
                    torch.cat([t * mf[-1 - i][:B], (1 - t) * mf[-1 - i][B:],
                               af[-1 - i][:B], af[-1 - i][B:]], 1),
                    torch.cat((img0, img1, warped_img0, warped_img1, mask), 1),
                    flow)
                flow = flow + flow_d
                mask = mask + mask_d
            else:
                flow, mask = self.block[i](
                    torch.cat([t * mf[-1 - i][:B], (1 - t) * mf[-1 - i][B:],
                               af[-1 - i][:B], af[-1 - i][B:]], 1),
                    torch.cat((img0, img1), 1),
                    None)
            mask_list.append(torch.sigmoid(mask))
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2])
            warped_img1 = warp(img1, flow[:, 2:4])
            merged.append(warped_img0 * mask_list[i] + warped_img1 * (1 - mask_list[i]))
        c0, c1 = self.warp_features(af, flow)
        tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
        res = tmp[:, :3] * 2 - 1  # residual image in [-1, 1]
        pred = torch.clamp(merged[-1] + res, 0, 1)
        return flow_list, mask_list, merged, pred
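

# --- End-to-end usage sketch (illustrative only) -----------------------------
# `model` must be built with a real feature backbone returning the appearance
# and motion pyramids this module expects (see the EMA-VFI repo); neither the
# backbone nor its config is defined in this file.
def _interpolate_midframe_example(model: MultiScaleFlow,
                                  img0: torch.Tensor,
                                  img1: torch.Tensor) -> torch.Tensor:
    x = torch.cat([img0, img1], dim=1)  # (B, 6, H, W), RGB in [0, 1]
    with torch.no_grad():
        flow_list, mask_list, merged, pred = model(x, timestep=0.5)
    return pred  # interpolated frame at t=0.5, clamped to [0, 1]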