# Adapted from https://github.com/dajes/frame-interpolation-pytorch
"""Frame interpolation helpers built around a TorchScript FILM model.

`init_frame_interpolation_model` loads the fp16 checkpoint onto the GPU;
`batch_images_interpolation_tool` inserts `inter_frames` synthesized frames
between every pair of consecutive frames of a video tensor.
"""
import bisect

import numpy as np
import torch
from tqdm import tqdm

def init_frame_interpolation_model():
    print("Initializing frame interpolation model")
    checkpoint_name = "./pretrained_model/film_net_fp16.pt"

    # Load the TorchScript FILM checkpoint on CPU, then run it in fp16 on the GPU.
    model = torch.jit.load(checkpoint_name, map_location='cpu')
    model.eval()
    model = model.half()
    model = model.to(device="cuda")
    return model
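
# Hypothetical variant (not part of the original file): the loader above
# assumes a CUDA device is present. A minimal sketch of a loader that falls
# back to fp32 on CPU when no GPU is available. Note that
# batch_images_interpolation_tool below still moves frames to CUDA, so this
# fallback only covers model loading.
def init_frame_interpolation_model_with_fallback():
    model = torch.jit.load("./pretrained_model/film_net_fp16.pt", map_location='cpu')
    model.eval()
    if torch.cuda.is_available():
        model = model.half().to(device="cuda")
    else:
        model = model.float()  # fp16 arithmetic is impractical on CPU
    return model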


def batch_images_interpolation_tool(input_tensor, model, inter_frames=1):
    """Insert `inter_frames` interpolated frames between consecutive frames."""
    video_tensor = []
    frame_num = input_tensor.shape[2]  # (bs, channel, frame, height, width)

    inter_frames = int(inter_frames)
    # Evenly spaced timestamps for one frame pair: 0 and 1 are the two real
    # frames, the values in between are the frames to synthesize.
    splits = torch.linspace(0, 1, inter_frames + 2)

    for idx in tqdm(range(frame_num - 1)):
        image1 = input_tensor[:, :, idx]
        image2 = input_tensor[:, :, idx + 1]

        results = [image1, image2]
        idxes = [0, inter_frames + 1]               # timestamps already available
        remains = list(range(1, inter_frames + 1))  # timestamps still to synthesize

        for _ in range(len(remains)):
            # Pick the remaining timestamp closest to the midpoint of its
            # enclosing interval, so each prediction interpolates between two
            # frames that already exist (binary subdivision).
            starts = splits[idxes[:-1]]
            ends = splits[idxes[1:]]
            distances = ((splits[None, remains] - starts[:, None]) / (ends[:, None] - starts[:, None]) - .5).abs()
            flat_idx = torch.argmin(distances).item()
            start_i, step = np.unravel_index(flat_idx, distances.shape)
            end_i = start_i + 1

            x0 = results[start_i].half().cuda()
            x1 = results[end_i].half().cuda()

            # Normalized time of the new frame within the (x0, x1) interval, in (0, 1).
            t_mid = ((splits[remains[step]] - splits[idxes[start_i]]) / (splits[idxes[end_i]] - splits[idxes[start_i]])).item()
            dt = x0.new_full((1, 1), t_mid)

            with torch.no_grad():
                prediction = model(x0, x1, dt)
            # Keep `idxes`/`results` sorted by timestamp as the gaps are filled in.
            insert_position = bisect.bisect_left(idxes, remains[step])
            idxes.insert(insert_position, remains[step])
            results.insert(insert_position, prediction.clamp(0, 1).cpu().float())
            del remains[step]

        # Drop the last frame of each pair; it is the first frame of the next pair.
        for sub_idx in range(len(results) - 1):
            video_tensor.append(results[sub_idx].unsqueeze(2))

    video_tensor.append(input_tensor[:, :, -1].unsqueeze(2))
    video_tensor = torch.cat(video_tensor, dim=2)
    return video_tensor
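

# Usage sketch (an assumption, not in the original file): interpolate a random
# 8-frame clip with 2 synthesized frames per pair, giving
# (frame_num - 1) * (inter_frames + 1) + 1 = 7 * 3 + 1 = 22 frames.
# Requires a CUDA device and the film_net_fp16.pt checkpoint.
if __name__ == "__main__":
    model = init_frame_interpolation_model()
    clip = torch.rand(1, 3, 8, 256, 256)  # (bs, channel, frame, height, width), values in [0, 1]
    out = batch_images_interpolation_tool(clip, model, inter_frames=2)
    print(out.shape)  # expected: torch.Size([1, 3, 22, 256, 256])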