|
"""PyTorch layer for extracting image features for the film_net interpolator. |
|
|
|
The feature extractor implemented here converts an image pyramid into a pyramid |
|
of deep features. The feature pyramid serves a similar purpose as U-Net |
|
architecture's encoder, but we use a special cascaded architecture described in |
|
Multi-view Image Fusion [1]. |
|
|
|
For comprehensiveness, below is a short description of the idea. While the |
|
description is a bit involved, the cascaded feature pyramid can be used just |
|
like any image feature pyramid. |
|
|
|
Why cascaded architecture?
|
========================= |
|
To understand the concept it is worth reviewing a traditional feature pyramid |
|
first: *A traditional feature pyramid* as in U-net or in many optical flow |
|
networks is built by alternating between convolutions and pooling, starting |
|
from the input image. |
|
|
|
It is well known that early features of such architecture correspond to low |
|
level concepts such as edges in the image whereas later layers extract |
|
semantically higher level concepts such as object classes etc. In other words, |
|
the meaning of the filters in each resolution level is different. For problems |
|
such as semantic segmentation and many others this is a desirable property. |
|
|
|
However, the asymmetric features preclude sharing weights across resolution |
|
levels in the feature extractor itself and in any subsequent neural networks |
|
that follow. This can be a downside, since optical flow prediction, for |
|
instance is symmetric across resolution levels. The cascaded feature |
|
architecture addresses this shortcoming. |
|
|
|
How is it built? |
|
================ |
|
The *cascaded* feature pyramid contains feature vectors that have constant |
|
length and meaning on each resolution level, except few of the finest ones. The |
|
advantage of this is that the subsequent optical flow layer can learn |
|
synergistically from many resolutions. This means that coarse level prediction can
|
benefit from finer resolution training examples, which can be useful with |
|
moderately sized datasets to avoid overfitting. |
|
|
|
The cascaded feature pyramid is built by extracting shallower subtree pyramids, |
|
each one of them similar to the traditional architecture. Each subtree |
|
pyramid S_i is extracted starting from each resolution level: |
|
|
|
image resolution 0 -> S_0 |
|
image resolution 1 -> S_1 |
|
image resolution 2 -> S_2 |
|
... |
|
|
|
If we denote the features at level j of subtree i as S_i_j, the cascaded pyramid |
|
is constructed by concatenating features as follows (assuming subtree depth=3): |
|
|
|
lvl |
|
feat_0 = concat( S_0_0 ) |
|
feat_1 = concat( S_1_0 S_0_1 ) |
|
feat_2 = concat( S_2_0 S_1_1 S_0_2 ) |
|
feat_3 = concat( S_3_0 S_2_1 S_1_2 ) |
|
feat_4 = concat( S_4_0 S_3_1 S_2_2 ) |
|
feat_5 = concat( S_5_0 S_4_1 S_3_2 ) |
|
.... |
|
|
|
In above, all levels except feat_0 and feat_1 have the same number of features |
|
with similar semantic meaning. This enables training a single optical flow |
|
predictor module shared by levels 2,3,4,5... . For more details and evaluation |
|
see [1]. |
|
|
|
[1] Multi-view Image Fusion, Trinidad et al. 2019 |
|
""" |
|
from typing import List |
|
|
|
import torch |
|
from torch import nn |
|
from torch.nn import functional as F |
|
|
|
from util import Conv2d |
|
|
|
|
|
class SubTreeExtractor(nn.Module):
  """Extracts a hierarchical set of features from an image.

  This is a conventional, hierarchical image feature extractor that extracts
  [channels, channels*2, channels*4, ...] filters for the image pyramid, with
  n_layers stages in total. Each stage (two 3x3 convolutions) is followed by
  2x2 average pooling before the next stage.
  """

  def __init__(self, in_channels=3, channels=64, n_layers=4):
    """Builds the stacked conv stages.

    Args:
      in_channels: number of channels in the input image.
      channels: number of filters extracted at the finest level; doubled at
        each subsequent level.
      n_layers: number of conv stages (maximum pyramid depth).
    """
    super().__init__()
    convs = []
    for i in range(n_layers):
      out_channels = channels << i  # channels, 2*channels, 4*channels, ...
      convs.append(nn.Sequential(
          Conv2d(in_channels, out_channels, 3),
          Conv2d(out_channels, out_channels, 3),
      ))
      in_channels = out_channels
    self.convs = nn.ModuleList(convs)

  def forward(self, image: torch.Tensor, n: int) -> List[torch.Tensor]:
    """Extracts a pyramid of features from the image.

    Args:
      image: torch.Tensor with shape BATCH_SIZE x CHANNELS x HEIGHT x WIDTH
        (NCHW layout — the pooling below and the channel-wise concatenation
        in FeatureExtractor operate on dim 1).
      n: number of pyramid levels to extract. Must be less than or equal to
        the n_layers given in __init__.
    Returns:
      A list of exactly n feature tensors, starting from the finest level.
      Each element contains the output after the last convolution on the
      corresponding pyramid level.
    Raises:
      ValueError: if n exceeds the number of available conv stages.
    """
    if n > len(self.convs):
      raise ValueError(
          f'Requested {n} pyramid levels, but only {len(self.convs)} '
          'stages were built.')
    head = image
    pyramid: List[torch.Tensor] = []
    # Run only the first n stages. (Previously every stage ran regardless of
    # n, producing unused — and un-pooled, hence wrong-resolution — extra
    # levels and wasting compute at coarse resolutions.)
    for i in range(n):
      head = self.convs[i](head)
      pyramid.append(head)
      if i < n - 1:
        # Halve the resolution before the next, coarser stage.
        head = F.avg_pool2d(head, kernel_size=2, stride=2)
    return pyramid
|
|
|
|
|
class FeatureExtractor(nn.Module):
  """Extracts features from an image pyramid using a cascaded architecture.

  See the module docstring: level i of the output concatenates level 0 of
  subtree pyramid i with level 1 of subtree i-1, level 2 of subtree i-2, etc.,
  so that all but the few finest levels carry features with identical length
  and semantics.
  """

  def __init__(self, in_channels=3, channels=64, sub_levels=4):
    """Builds the shared subtree extractor.

    Args:
      in_channels: number of channels in the input images.
      channels: number of filters at the finest subtree level.
      sub_levels: depth of each subtree pyramid.
    """
    super().__init__()
    self.extract_sublevels = SubTreeExtractor(in_channels, channels, sub_levels)
    self.sub_levels = sub_levels

  def forward(self, image_pyramid: List[torch.Tensor]) -> List[torch.Tensor]:
    """Extracts a cascaded feature pyramid.

    Args:
      image_pyramid: Image pyramid as a list, starting from the finest level.
    Returns:
      A pyramid of cascaded features, one tensor per input level.
    """
    num_levels = len(image_pyramid)
    # One shallow subtree pyramid per input resolution. Near the coarse end
    # there is no room for a full subtree, so its depth is capped by the
    # number of remaining levels.
    sub_pyramids: List[List[torch.Tensor]] = [
        self.extract_sublevels(level_image,
                               min(num_levels - i, self.sub_levels))
        for i, level_image in enumerate(image_pyramid)
    ]
    # Cascade: output level i concatenates S_i_0, S_{i-1}_1, S_{i-2}_2, ...
    # along the channel dimension (all parts share level i's resolution).
    feature_pyramid: List[torch.Tensor] = []
    for i in range(num_levels):
      parts = [
          sub_pyramids[i - j][j]
          for j in range(min(i, self.sub_levels - 1) + 1)
      ]
      feature_pyramid.append(torch.cat(parts, dim=1))
    return feature_pyramid
|
|