Spaces:
Build error
Build error
zejunyang
commited on
Commit
•
0c9dedf
1
Parent(s):
202b7b1
update frame interpolation model
Browse files
src/utils/frame_interpolation.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import bisect
|
6 |
+
import shutil
|
7 |
+
|
8 |
+
def init_frame_interpolation_model():
|
9 |
+
print("Initializing frame interpolation model")
|
10 |
+
checkpoint_name = os.path.join("./pretrained_model/film_net_fp16.pt")
|
11 |
+
|
12 |
+
model = torch.load(checkpoint_name, map_location='cpu')
|
13 |
+
model.eval()
|
14 |
+
model = model.half()
|
15 |
+
model = model.to(device="cuda")
|
16 |
+
return model
|
17 |
+
|
18 |
+
|
19 |
+
def batch_images_interpolation_tool(input_file, model, fps, inter_frames=1):
|
20 |
+
|
21 |
+
image_save_dir = input_file + '_tmp'
|
22 |
+
os.makedirs(image_save_dir, exist_ok=True)
|
23 |
+
|
24 |
+
input_img_list = os.listdir(input_file)
|
25 |
+
input_img_list.sort()
|
26 |
+
|
27 |
+
for idx in range(len(input_img_list)-1):
|
28 |
+
img1 = cv2.imread(os.path.join(input_file, input_img_list[idx]))
|
29 |
+
img2 = cv2.imread(os.path.join(input_file, input_img_list[idx+1]))
|
30 |
+
|
31 |
+
image1 = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB).astype(np.float32) / np.float32(255)
|
32 |
+
image2 = cv2.cvtColor(img2, cv2.COLOR_BGR2RGB).astype(np.float32) / np.float32(255)
|
33 |
+
image1 = torch.from_numpy(image1).unsqueeze(0).permute(0, 3, 1, 2)
|
34 |
+
image2 = torch.from_numpy(image2).unsqueeze(0).permute(0, 3, 1, 2)
|
35 |
+
|
36 |
+
results = [image1, image2]
|
37 |
+
|
38 |
+
inter_frames = int(inter_frames)
|
39 |
+
idxes = [0, inter_frames + 1]
|
40 |
+
remains = list(range(1, inter_frames + 1))
|
41 |
+
|
42 |
+
splits = torch.linspace(0, 1, inter_frames + 2)
|
43 |
+
|
44 |
+
for _ in range(len(remains)):
|
45 |
+
starts = splits[idxes[:-1]]
|
46 |
+
ends = splits[idxes[1:]]
|
47 |
+
distances = ((splits[None, remains] - starts[:, None]) / (ends[:, None] - starts[:, None]) - .5).abs()
|
48 |
+
matrix = torch.argmin(distances).item()
|
49 |
+
start_i, step = np.unravel_index(matrix, distances.shape)
|
50 |
+
end_i = start_i + 1
|
51 |
+
|
52 |
+
x0 = results[start_i]
|
53 |
+
x1 = results[end_i]
|
54 |
+
|
55 |
+
x0 = x0.half()
|
56 |
+
x1 = x1.half()
|
57 |
+
x0 = x0.cuda()
|
58 |
+
x1 = x1.cuda()
|
59 |
+
|
60 |
+
dt = x0.new_full((1, 1), (splits[remains[step]] - splits[idxes[start_i]])) / (splits[idxes[end_i]] - splits[idxes[start_i]])
|
61 |
+
|
62 |
+
with torch.no_grad():
|
63 |
+
prediction = model(x0, x1, dt)
|
64 |
+
insert_position = bisect.bisect_left(idxes, remains[step])
|
65 |
+
idxes.insert(insert_position, remains[step])
|
66 |
+
results.insert(insert_position, prediction.clamp(0, 1).cpu().float())
|
67 |
+
del remains[step]
|
68 |
+
|
69 |
+
frames = [(tensor[0] * 255).byte().flip(0).permute(1, 2, 0).numpy().copy() for tensor in results]
|
70 |
+
|
71 |
+
for sub_idx in range(len(frames)):
|
72 |
+
img_path = os.path.join(image_save_dir, f'{sub_idx+idx*(inter_frames+1):06d}.png')
|
73 |
+
cv2.imwrite(img_path, frames[sub_idx])
|
74 |
+
|
75 |
+
final_frames = []
|
76 |
+
final_img_list = os.listdir(image_save_dir)
|
77 |
+
final_img_list.sort()
|
78 |
+
for item in final_img_list:
|
79 |
+
final_frames.append(cv2.imread(os.path.join(image_save_dir, item)))
|
80 |
+
w, h = final_frames[0].shape[1::-1]
|
81 |
+
fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
|
82 |
+
video_save_dir = input_file + '.mp4'
|
83 |
+
writer = cv2.VideoWriter(video_save_dir, fourcc, fps, (w, h))
|
84 |
+
for frame in final_frames:
|
85 |
+
writer.write(frame)
|
86 |
+
writer.release()
|
87 |
+
|
88 |
+
shutil.rmtree(image_save_dir)
|
89 |
+
|
90 |
+
return video_save_dir
|