"""Various utilities used in the film_net frame interpolator model."""

from typing import List, Optional

import cv2
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
|
|
def pad_batch(batch, align):
    """Pads an NHWC numpy batch so that its height and width are multiples of `align`.

    Returns the padded batch and the crop region [y_start, x_start, y_end, x_end]
    that recovers the original content from the padded result.
    """
    height, width = batch.shape[1:3]
    height_to_pad = (align - height % align) if height % align != 0 else 0
    width_to_pad = (align - width % align) if width % align != 0 else 0

    crop_region = [height_to_pad >> 1, width_to_pad >> 1, height + (height_to_pad >> 1), width + (width_to_pad >> 1)]
    batch = np.pad(batch, ((0, 0), (height_to_pad >> 1, height_to_pad - (height_to_pad >> 1)),
                           (width_to_pad >> 1, width_to_pad - (width_to_pad >> 1)), (0, 0)), mode='constant')
    return batch, crop_region
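
# Illustrative sketch (not part of the original module): with align=64, a 1x720x1280x3
# batch is padded to 1x768x1280x3 and crop_region becomes [24, 0, 744, 1280], i.e. the
# extra 48 rows are split 24 above / 24 below and cropping [24:744, 0:1280] undoes it.
#
#     padded, crop = pad_batch(np.zeros((1, 720, 1280, 3), np.float32), 64)
#     assert padded.shape == (1, 768, 1280, 3) and crop == [24, 0, 744, 1280]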
|
|
|
|
|
def load_image(path, align=64):
    """Loads an image as a 1xHxWx3 float32 RGB batch in [0, 1], padded to a multiple of `align`."""
    image = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB).astype(np.float32) / np.float32(255)
    image_batch, crop_region = pad_batch(np.expand_dims(image, axis=0), align)
    return image_batch, crop_region
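
# Usage sketch (illustrative; the file name and the NHWC -> NCHW conversion are
# assumptions, not part of this module):
#
#     batch, crop_region = load_image('frame0.png', align=64)
#     frame = torch.from_numpy(batch).permute(0, 3, 1, 2)   # 1xHxWx3 -> 1x3xHxW for the model
#     y0, x0, y1, x1 = crop_region
#     original = batch[:, y0:y1, x0:x1]                     # undo the padding afterwards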
|
|
|
|
|
def build_image_pyramid(image: torch.Tensor, pyramid_levels: int = 3) -> List[torch.Tensor]:
    """Builds an image pyramid from a given image.

    The original image is included in the pyramid and the rest are generated by
    successively halving the resolution.

    Args:
      image: the input image.
      pyramid_levels: the number of pyramid levels to build.

    Returns:
      A list of images starting from the finest, with pyramid_levels items.
    """
    pyramid = []
    for i in range(pyramid_levels):
        pyramid.append(image)
        if i < pyramid_levels - 1:
            image = F.avg_pool2d(image, 2, 2)
    return pyramid
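
# Illustrative sketch (assumed shapes, not from the original module): with
# pyramid_levels=3, a 1x3x256x256 input yields levels of 256, 128 and 64 pixels.
#
#     pyramid = build_image_pyramid(torch.zeros(1, 3, 256, 256), pyramid_levels=3)
#     assert [p.shape[-1] for p in pyramid] == [256, 128, 64]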
|
|
|
|
|
def warp(image: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
    """Backward warps the image using the given flow.

    Specifically, the output pixel in batch b, at position y, x will be computed
    as follows:
      (flowed_y, flowed_x) = (y + flow[b, 1, y, x], x + flow[b, 0, y, x])
      output[b, :, y, x] = bilinear_lookup(image[b], flowed_y, flowed_x)

    Note that the flow vectors are expected as (dx, dy), i.e. the x offset in
    channel 0 and the y offset in channel 1.

    Args:
      image: An image with shape BxCxHxW.
      flow: A flow with shape Bx2xHxW, with the two channels denoting the relative
        offset in order: (dx, dy).

    Returns:
      A warped image with the same shape as the input.
    """
    flow = -flow.flip(1)

    dtype = flow.dtype
    device = flow.device

    # Build a sampling grid in grid_sample's normalized [-1, 1] coordinates: with
    # align_corners=False, pixel centres span [-(1 - 1/W), 1 - 1/W] horizontally
    # and [-(1 - 1/H), 1 - 1/H] vertically.
    ls1 = 1 - 1 / flow.shape[3]
    ls2 = 1 - 1 / flow.shape[2]

    normalized_flow2 = flow.permute(0, 2, 3, 1) / torch.tensor(
        [flow.shape[2] * .5, flow.shape[3] * .5], dtype=dtype, device=device)[None, None, None]
    normalized_flow2 = torch.stack([
        torch.linspace(-ls1, ls1, flow.shape[3], dtype=dtype, device=device)[None, None, :] - normalized_flow2[..., 1],
        torch.linspace(-ls2, ls2, flow.shape[2], dtype=dtype, device=device)[None, :, None] - normalized_flow2[..., 0],
    ], dim=3)

    warped = F.grid_sample(image, normalized_flow2,
                           mode='bilinear', padding_mode='border', align_corners=False)
    return warped.reshape(image.shape)
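
# Illustrative sketch (assumed tensors, not from the original module): a flow whose dx
# channel is +1 everywhere shifts content one pixel to the left in the output, because
# each output pixel looks up the input one pixel to its right.
#
#     img = torch.arange(16., dtype=torch.float32).reshape(1, 1, 4, 4)
#     flow = torch.zeros(1, 2, 4, 4)
#     flow[:, 0] = 1.0                   # dx = +1, dy = 0
#     shifted = warp(img, flow)          # shifted[..., x] ~= img[..., x + 1]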
|
|
|
|
|
def multiply_pyramid(pyramid: List[torch.Tensor],
                     scalar: torch.Tensor) -> List[torch.Tensor]:
    """Multiplies all image batches in the pyramid by a batch of scalars.

    Args:
      pyramid: Pyramid of image batches.
      scalar: Batch of scalars.

    Returns:
      An image pyramid with all images multiplied by the scalar.
    """
    # Append two singleton dimensions so the per-example scalar broadcasts over
    # the spatial dimensions of each pyramid level.
    return [image * scalar[..., None, None] for image in pyramid]

|
|
def flow_pyramid_synthesis(
        residual_pyramid: List[torch.Tensor]) -> List[torch.Tensor]:
    """Converts a residual flow pyramid into a flow pyramid."""
    flow = residual_pyramid[-1]
    flow_pyramid: List[torch.Tensor] = [flow]
    for residual_flow in residual_pyramid[:-1][::-1]:
        level_size = residual_flow.shape[2:4]
        # Upsample the coarser flow to this level's resolution; the values are
        # doubled because flow is measured in pixels and the resolution doubles.
        flow = F.interpolate(2 * flow, size=level_size, mode='bilinear')
        flow = residual_flow + flow
        flow_pyramid.insert(0, flow)
    return flow_pyramid
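
# Illustrative sketch (assumed shapes, not from the original module): residuals ordered
# finest -> coarsest are combined coarse-to-fine, and the returned pyramid is again
# ordered finest -> coarsest.
#
#     residuals = [torch.zeros(1, 2, 64, 64), torch.zeros(1, 2, 32, 32), torch.zeros(1, 2, 16, 16)]
#     flows = flow_pyramid_synthesis(residuals)
#     assert [f.shape[-1] for f in flows] == [64, 32, 16]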
|
|
|
|
|
def pyramid_warp(feature_pyramid: List[torch.Tensor],
                 flow_pyramid: List[torch.Tensor]) -> List[torch.Tensor]:
    """Warps the feature pyramid using the flow pyramid.

    Args:
      feature_pyramid: feature pyramid starting from the finest level.
      flow_pyramid: flow fields, starting from the finest level.

    Returns:
      Backward-warped feature pyramid.
    """
    warped_feature_pyramid = []
    for features, flow in zip(feature_pyramid, flow_pyramid):
        warped_feature_pyramid.append(warp(features, flow))
    return warped_feature_pyramid

|
|
def concatenate_pyramids(pyramid1: List[torch.Tensor],
                         pyramid2: List[torch.Tensor]) -> List[torch.Tensor]:
    """Concatenates each pyramid level together in the channel dimension."""
    result = []
    for features1, features2 in zip(pyramid1, pyramid2):
        result.append(torch.cat([features1, features2], dim=1))
    return result

|
|
class Conv2d(nn.Sequential):
    """Conv2d with an optional activation; odd kernel sizes use 'same' padding,
    even ones get an explicit (0, 1, 0, 1) pad in forward()."""

    def __init__(self, in_channels, out_channels, size, activation: Optional[str] = 'relu'):
        assert activation in (None, 'relu')
        super().__init__(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=size,
                # Even kernel sizes are not padded here; forward() pads the input instead.
                padding='same' if size % 2 else 0)
        )
        self.size = size
        # Note: the 'relu' option maps to a LeakyReLU with negative slope 0.2.
        self.activation = nn.LeakyReLU(.2) if activation == 'relu' else None

    def forward(self, x):
        if not self.size % 2:
            x = F.pad(x, (0, 1, 0, 1))
        y = self[0](x)
        if self.activation is not None:
            y = self.activation(y)
        return y
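
# Usage sketch (illustrative; shapes are assumptions, not from the original module):
#
#     conv = Conv2d(in_channels=3, out_channels=16, size=3)   # 'same' padding, LeakyReLU(0.2)
#     out = conv(torch.zeros(1, 3, 64, 64))                   # -> torch.Size([1, 16, 64, 64])
#     plain = Conv2d(3, 16, size=3, activation=None)          # no activation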
|