import torch
import numpy as np
import cv2
import torch.nn.functional as F
from src.utils import *
import sys
sys.path.append("./src/ebsynth/deps/gmflow/")
from gmflow.geometry import flow_warp

"""
==========================================================================
* warp_tensor(): warp and fuse tensors based on optical flow and mask
* get_single_mapping_ind(): get pixel index correspondence between two frames
* get_mapping_ind(): get pixel index correspondence between consecutive frames within a batch
==========================================================================
"""

@torch.no_grad()
def warp_tensor(sample, flows, occs, saliency, unet_chunk_size):
    """
    Warp images or features based on optical flow
    Fuse the warped imges or features based on occusion masks and saliency map
    """
    # Ratio between the feature resolution and the flow resolution
    scale = sample.shape[2] * 1.0 / flows[0].shape[2]
    kernel = int(1 / scale)
    # Resize the backward flow to the feature resolution and rescale its magnitudes
    bwd_flow_ = F.interpolate(flows[1] * scale, scale_factor=scale, mode='bilinear')
    # Max-pool the occlusion mask so any occluded pixel marks the whole cell as occluded
    bwd_occ_ = F.max_pool2d(occs[1].unsqueeze(1), kernel_size=kernel) # (N-1)*1*H1*W1
    if scale == 1:
        bwd_occ_ = Dilate(kernel_size=13, device=sample.device)(bwd_occ_)
    # Same processing for the forward flow and occlusion mask
    fwd_flow_ = F.interpolate(flows[0] * scale, scale_factor=scale, mode='bilinear')
    fwd_occ_ = F.max_pool2d(occs[0].unsqueeze(1), kernel_size=kernel) # (N-1)*1*H1*W1
    if scale == 1:
        fwd_occ_ = Dilate(kernel_size=13, device=sample.device)(fwd_occ_)
    scale2 = sample.shape[2] * 1.0 / saliency.shape[2]
    saliency = F.interpolate(saliency, scale_factor=scale2, mode='bilinear')
    latent = sample.to(torch.float32)
    video_length = sample.shape[0] // unet_chunk_size
    warp_saliency = flow_warp(saliency, bwd_flow_)
    warp_saliency_ = flow_warp(saliency[0:1], fwd_flow_[video_length-1:video_length])
    
    for j in range(unet_chunk_size):
        # Propagate each frame to the next one within the chunk
        for ii in range(video_length-1):
            i = video_length * j + ii
            warped_image = flow_warp(latent[i:i+1], bwd_flow_[ii:ii+1])
            # Fuse only where the target pixel is visible (not occluded) and salient
            mask = (1 - bwd_occ_[ii:ii+1]) * saliency[ii+1:ii+2] * warp_saliency[ii:ii+1]
            latent[i+1:i+2] = latent[i+1:i+2] * (1-mask) + warped_image * mask
        # Close the loop: warp the first frame of the chunk onto the last one
        i = video_length * j
        ii = video_length - 1
        warped_image = flow_warp(latent[i:i+1], fwd_flow_[ii:ii+1])
        mask = (1 - fwd_occ_[ii:ii+1]) * saliency[ii:ii+1] * warp_saliency_
        latent[ii+i:ii+i+1] = latent[ii+i:ii+i+1] * (1-mask) + warped_image * mask
        
    return latent.to(sample.dtype)
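

# A minimal usage sketch for warp_tensor() (not part of the pipeline). All
# shapes below are illustrative assumptions: 2 UNet chunks of 8 frames,
# 4-channel 64x64 latents, and flows/occlusions/saliency at 512x512 with one
# extra loop-closing flow per direction. In the real pipeline these tensors
# come from GMFlow and a saliency detector.
def _demo_warp_tensor():
    sample = torch.randn(16, 4, 64, 64)                       # (unet_chunk_size * video_length) latents
    flows = [torch.randn(8, 2, 512, 512) for _ in range(2)]   # assumed [fwd_flows, bwd_flows]
    occs = [torch.rand(8, 512, 512) for _ in range(2)]        # assumed [fwd_occs, bwd_occs] in [0, 1]
    saliency = torch.rand(8, 1, 512, 512)                     # assumed per-frame saliency maps
    fused = warp_tensor(sample, flows, occs, saliency, unet_chunk_size=2)
    print(fused.shape)  # torch.Size([16, 4, 64, 64])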


@torch.no_grad()
def get_single_mapping_ind(bwd_flow, bwd_occ, imgs, scale=1.0):
    """
    FLATTEN: Optical fLow-guided attention (Temoporal-guided attention)
    Find the correspondence between every pixels in a pair of frames
    
    [input]
    bwd_flow: 1*2*H*W   
    bwd_occ: 1*H*W      i.e., f2 = warp(f1, bwd_flow) * bwd_occ
    imgs: 2*3*H*W       i.e., [f1,f2]
    
    [output]
    mapping_ind: pixel index correspondence
    unlinkedmask: indicate whether a pixel has no correspondence
    i.e., f2 = f1[mapping_ind] * unlinkedmask
    """
    flows = F.interpolate(bwd_flow, scale_factor=1./scale, mode='bilinear')[0][[1,0]] / scale # 2*H*W
    _, H, W = flows.shape
    masks = torch.logical_not(F.interpolate(bwd_occ[None], scale_factor=1./scale, mode='bilinear') > 0.5)[0] # 1*H*W
    frames = F.interpolate(imgs, scale_factor=1./scale, mode='bilinear').view(2, 3, -1) # 2*3*HW
    grid = torch.stack(torch.meshgrid([torch.arange(H), torch.arange(W)]), dim=0).to(flows.device) # 2*H*W
    # Integer pixel coordinates that each pixel lands on after warping
    warp_grid = torch.round(grid + flows)
    # A correspondence is valid only if it stays inside the image and is not occluded
    mask = torch.logical_and(torch.logical_and(torch.logical_and(torch.logical_and(warp_grid[0] >= 0, warp_grid[0] < H),
                         warp_grid[1] >= 0), warp_grid[1] < W), masks[0]).view(-1) # HW
    warp_grid = warp_grid.view(2, -1) # 2*HW
    warp_ind = (warp_grid[0] * W + warp_grid[1]).to(torch.long)  # HW, flattened pixel indices
    mapping_ind = torch.zeros_like(warp_ind) - 1 # HW, -1 marks "no correspondence yet"
    
    # Resolve collisions: several pixels may land on the same target pixel;
    # keep the one whose color best matches the target and unlink the other
    for f0ind, f1ind in enumerate(warp_ind):
        if mask[f0ind]:
            if mapping_ind[f1ind] == -1:
                mapping_ind[f1ind] = f0ind
            else:
                targetv = frames[0,:,f1ind]
                pref0ind = mapping_ind[f1ind]
                prev = frames[1,:,pref0ind]
                v = frames[1,:,f0ind]
                if ((prev - targetv)**2).mean() > ((v - targetv)**2).mean():
                    mask[pref0ind] = False
                    mapping_ind[f1ind] = f0ind
                else:
                    mask[f0ind] = False

    # Fill unlinked slots with the leftover indices so mapping_ind is a full permutation
    unusedind = torch.arange(len(mask)).to(mask.device)[~mask]
    unlinkedmask = mapping_ind == -1
    mapping_ind[unlinkedmask] = unusedind
    return mapping_ind, unlinkedmask
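

# A minimal usage sketch for get_single_mapping_ind() (not part of the
# pipeline), with tiny random inputs as an illustrative assumption; real
# flows and occlusion masks come from GMFlow.
def _demo_get_single_mapping_ind():
    H = W = 16
    bwd_flow = torch.randn(1, 2, H, W)  # backward flow between the frame pair
    bwd_occ = torch.zeros(1, H, W)      # no occlusion in this toy example
    imgs = torch.rand(2, 3, H, W)       # the frame pair [f1, f2]
    mapping_ind, unlinkedmask = get_single_mapping_ind(bwd_flow, bwd_occ, imgs)
    print(mapping_ind.shape, unlinkedmask.sum().item())  # HW indices, #unmatched pixels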


@torch.no_grad()
def get_mapping_ind(bwd_flows, bwd_occs, imgs, scale=1.0):
    """
    FLATTEN: Optical fLow-guided attention (Temoporal-guided attention)
    Find pixel correspondence between every consecutive frames in a batch
    
    [input]
    bwd_flow: (N-1)*2*H*W   
    bwd_occ: (N-1)*H*W        
    imgs: N*3*H*W             
    
    [output]
    fwd_mappings: N*1*HW 
    bwd_mappings: N*1*HW 
    flattn_mask: HW*1*N*N
    i.e., imgs[i,:,fwd_mappings[i]] corresponds to imgs[0]
    i.e., imgs[i,:,fwd_mappings[i]][:,bwd_mappings[i]] restore the original imgs[i]
    """
    N, H, W = imgs.shape[0], int(imgs.shape[2] // scale), int(imgs.shape[3] // scale)
    iterattn_mask = torch.ones(H*W, N, N, dtype=torch.bool).to(imgs.device)
    for i in range(len(imgs)-1):
        # Pixels whose correspondence breaks at frame i may only attend within
        # their own temporal segment on either side of the break
        one_mask = torch.ones(N, N, dtype=torch.bool).to(imgs.device)
        one_mask[:i+1,i+1:] = False
        one_mask[i+1:,:i+1] = False
        mapping_ind, unlinkedmask = get_single_mapping_ind(bwd_flows[i:i+1], bwd_occs[i:i+1], imgs[i:i+2], scale)
        if i == 0:
            fwd_mapping = [torch.arange(len(mapping_ind)).to(mapping_ind.device)]
            bwd_mapping = [torch.arange(len(mapping_ind)).to(mapping_ind.device)]
        iterattn_mask[unlinkedmask[fwd_mapping[-1]]] = torch.logical_and(iterattn_mask[unlinkedmask[fwd_mapping[-1]]], one_mask)
        # Chain the per-pair mapping onto the accumulated mapping to frame 0
        fwd_mapping += [mapping_ind[fwd_mapping[-1]]]
        bwd_mapping += [torch.sort(fwd_mapping[-1])[1]]  # inverse permutation
    fwd_mappings = torch.stack(fwd_mapping, dim=0).unsqueeze(1)
    bwd_mappings = torch.stack(bwd_mapping, dim=0).unsqueeze(1)
    return fwd_mappings, bwd_mappings, iterattn_mask.unsqueeze(1)
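

# A minimal usage sketch for get_mapping_ind() (not part of the pipeline),
# with tiny random inputs as an illustrative assumption. It checks the round
# trip promised by the docstring: gathering a frame with fwd_mappings and then
# bwd_mappings restores the frame.
def _demo_get_mapping_ind():
    N, H, W = 4, 16, 16
    bwd_flows = torch.randn(N - 1, 2, H, W)
    bwd_occs = torch.zeros(N - 1, H, W)
    imgs = torch.rand(N, 3, H, W)
    fwd_mappings, bwd_mappings, attn_mask = get_mapping_ind(bwd_flows, bwd_occs, imgs)
    flat = imgs.view(N, 3, -1)
    i = 2
    aligned = flat[i, :, fwd_mappings[i, 0]]   # pixels reordered to align with imgs[0]
    restored = aligned[:, bwd_mappings[i, 0]]  # reordering undone
    print(torch.allclose(restored, flat[i]))   # True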