|
"""The film_net frame interpolator main model code. |
|
|
|
Basics |
|
====== |
|
The film_net is an end-to-end learned neural frame interpolator implemented as |
|
a PyTorch model. It has the following inputs and outputs: |
|
|
|
Inputs: |
|
x0: image A. |
|
x1: image B. |
|
time: desired sub-frame time, in range [0, 1].
|
|
|
Outputs: |
|
image: the predicted in-between image at the chosen time.
|
|
|
Additional outputs include forward and backward warped image pyramids, flow |
|
pyramids, etc., that can be visualized for debugging and analysis. |
|
|
|
Note that many training sets only contain triplets with ground truth at |
|
time=0.5. If a model has been trained with such a training set, it will only work
|
well for synthesizing frames at time=0.5. Such models can only generate more |
|
in-between frames using recursion. |
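
For example, three evenly spaced in-between frames can be generated from a

time=0.5 model by recursing on midpoints. The sketch below is illustrative only;

interpolate() stands for a single time=0.5 invocation of the model and is not a

function defined in this module:

  mid = interpolate(x0, x1)             # t = 0.50
  quarter = interpolate(x0, mid)        # t = 0.25
  three_quarter = interpolate(mid, x1)  # t = 0.75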
|
|
|
Architecture |
|
============ |
|
Inference consists of three main stages: 1) feature extraction, 2) warping, and

3) fusion. At a high level, the architecture has similarities to Context-aware

Synthesis for Video Frame Interpolation [1], but the exact architecture is

closer to Multi-view Image Fusion [2], with some modifications for the frame

interpolation use case.
|
|
|
The feature extraction stage employs the cascaded multi-scale architecture

described in [2]. The advantage of this architecture is that coarse-level flow

prediction can be learned from finer-resolution image samples, which is

especially useful for avoiding overfitting with moderately sized datasets.
|
|
|
The warping stage uses a residual flow prediction idea that is similar to |
|
PWC-Net [3], Multi-view Image Fusion [2] and many others. |
|
|
|
The fusion stage is similar to a U-Net decoder, with the skip connections

connected to the warped image and feature pyramids. This is described in [2].
|
|
|
Implementation Conventions |
|
==========================
|
Pyramids |
|
-------- |
|
Throughout the model, all image and feature pyramids are stored as Python lists

with the finest level first, followed by downscaled versions obtained by successively
|
halving the resolution. The depths of all pyramids are determined by the

pyramid_levels constructor argument. The only exception to this is internal to the

feature extractor, where smaller feature pyramids are temporarily constructed with

depth sub_levels.
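
As an illustration (assuming the usual [B, C, H, W] tensor layout; the sizes are

hypothetical), a 3-level pyramid built from a 256x256 input is stored as:

  pyramid = [
      level_0,  # [B, C, 256, 256], finest resolution
      level_1,  # [B, C, 128, 128]
      level_2,  # [B, C, 64, 64], coarsest stored level
  ]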
|
|
|
Color ranges & gamma |
|
-------------------- |
|
The model code makes no assumptions about whether the images are in gamma-encoded

or linearized space, or about the range of the RGB values, so a model can be

trained with different choices. This does not mean that all choices lead to

similar results. In practice the model has been shown to work well with RGB

values in [0, 1] and gamma-space (i.e. not linearized) images.
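
For example, a uint8 sRGB frame loaded as an HWC numpy array (an assumed input

format, not something this module requires) can be brought into that range with:

  x = torch.from_numpy(frame).permute(2, 0, 1).float() / 255.0  # [3, H, W] in [0, 1]
  x0 = x.unsqueeze(0)                                           # [1, 3, H, W]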
|
|
|
[1] Context-aware Synthesis for Video Frame Interpolation, Niklaus and Liu, 2018

[2] Multi-view Image Fusion, Trinidad et al., 2019

[3] PWC-Net: CNNs for Optical Flow Using Pyramid, Warping, and Cost Volume, Sun et al., 2018
|
""" |
|
from typing import Dict, List |
|
|
|
import torch |
|
from torch import nn |
|
|
|
import util |
|
from feature_extractor import FeatureExtractor |
|
from fusion import Fusion |
|
from pyramid_flow_estimator import PyramidFlowEstimator |
|
|
|
|
|
class Interpolator(nn.Module): |
|
def __init__( |
|
self, |
|
pyramid_levels=7, |
|
fusion_pyramid_levels=5, |
|
specialized_levels=3, |
|
sub_levels=4, |
|
filters=64, |
|
flow_convs=(3, 3, 3, 3), |
|
flow_filters=(32, 64, 128, 256), |
|
): |
|
super().__init__() |
|
self.pyramid_levels = pyramid_levels |
|
self.fusion_pyramid_levels = fusion_pyramid_levels |
|
|
|
self.extract = FeatureExtractor(3, filters, sub_levels) |
|
self.predict_flow = PyramidFlowEstimator(filters, flow_convs, flow_filters) |
|
self.fuse = Fusion(sub_levels, specialized_levels, filters) |
|
|
|
def shuffle_images(self, x0, x1): |
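        """Builds an image pyramid (with self.pyramid_levels levels) for each of the two inputs."""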
|
return [ |
|
util.build_image_pyramid(x0, self.pyramid_levels), |
|
util.build_image_pyramid(x1, self.pyramid_levels) |
|
] |
|
|
|
def debug_forward(self, x0, x1, batch_dt) -> Dict[str, List[torch.Tensor]]: |
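        """Runs the interpolator and returns intermediate pyramids for debugging.

        The returned dict contains the predicted image (as a single-element list)
        together with the forward/backward residual flow pyramids and the
        synthesized forward/backward flow pyramids.
        """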
|
image_pyramids = self.shuffle_images(x0, x1) |
|
|
|
|
|
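        # Extract feature pyramids from both images with the shared (Siamese) extractor.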
feature_pyramids = [self.extract(image_pyramids[0]), self.extract(image_pyramids[1])] |
|
|
|
|
|
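        # Predict residual flow pyramids in both directions from the feature pyramids.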
forward_residual_flow_pyramid = self.predict_flow(feature_pyramids[0], feature_pyramids[1]) |
|
|
|
|
|
backward_residual_flow_pyramid = self.predict_flow(feature_pyramids[1], feature_pyramids[0]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
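        # Accumulate the residual flows (coarse to fine) into full flow pyramids,
        # keeping only the levels used by the fusion stage.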
forward_flow_pyramid = util.flow_pyramid_synthesis(forward_residual_flow_pyramid)[:self.fusion_pyramid_levels] |
|
|
|
backward_flow_pyramid = util.flow_pyramid_synthesis(backward_residual_flow_pyramid)[:self.fusion_pyramid_levels] |
|
|
|
|
|
|
|
|
|
|
|
|
|
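        # Scale the flows by t and 1 - t so each input is warped to the requested sub-frame time.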
backward_flow = util.multiply_pyramid(backward_flow_pyramid, batch_dt) |
|
forward_flow = util.multiply_pyramid(forward_flow_pyramid, 1 - batch_dt) |
|
|
|
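        # Concatenate images with their features so both are warped by the same flow below.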
pyramids_to_warp = [ |
|
util.concatenate_pyramids(image_pyramids[0][:self.fusion_pyramid_levels], |
|
feature_pyramids[0][:self.fusion_pyramid_levels]), |
|
util.concatenate_pyramids(image_pyramids[1][:self.fusion_pyramid_levels], |
|
feature_pyramids[1][:self.fusion_pyramid_levels]) |
|
] |
|
|
|
|
|
|
|
|
|
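        # Backward warping is used: the backward flow (scaled by t) warps image A
        # forward in time, and the forward flow (scaled by 1 - t) warps image B back.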
forward_warped_pyramid = util.pyramid_warp(pyramids_to_warp[0], backward_flow) |
|
backward_warped_pyramid = util.pyramid_warp(pyramids_to_warp[1], forward_flow) |
|
|
|
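        # Stack the two warped pyramids together with both flows into a single
        # aligned pyramid for the fusion network.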
aligned_pyramid = util.concatenate_pyramids(forward_warped_pyramid, |
|
backward_warped_pyramid) |
|
aligned_pyramid = util.concatenate_pyramids(aligned_pyramid, backward_flow) |
|
aligned_pyramid = util.concatenate_pyramids(aligned_pyramid, forward_flow) |
|
|
|
return { |
|
'image': [self.fuse(aligned_pyramid)], |
|
'forward_residual_flow_pyramid': forward_residual_flow_pyramid, |
|
'backward_residual_flow_pyramid': backward_residual_flow_pyramid, |
|
'forward_flow_pyramid': forward_flow_pyramid, |
|
'backward_flow_pyramid': backward_flow_pyramid, |
|
} |
|
|
|
def forward(self, x0, x1, batch_dt) -> torch.Tensor: |
|
return self.debug_forward(x0, x1, batch_dt)['image'][0] |
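

# A minimal usage sketch, not part of the original module: it assumes the default
# constructor arguments and that batch_dt is a per-example tensor of shape [B],
# which is an assumption about what util.multiply_pyramid expects. Loading of
# pretrained weights is omitted.
if __name__ == "__main__":
    model = Interpolator().eval()

    # Spatial size should be divisible by 2 ** (pyramid_levels - 1) = 64 so that
    # successive halving stays exact with the default 7-level pyramids.
    x0 = torch.rand(1, 3, 256, 256)   # image A, gamma-space RGB in [0, 1]
    x1 = torch.rand(1, 3, 256, 256)   # image B
    batch_dt = torch.full((1,), 0.5)  # request the midpoint frame

    with torch.no_grad():
        midpoint = model(x0, x1, batch_dt)
    print(midpoint.shape)  # expected: torch.Size([1, 3, 256, 256])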
|
|