File size: 5,232 Bytes
b3f324b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import cv2
import argparse
from basicsr.test_img import image_sr  
from os import path as osp
import os
import shutil
from PIL import Image
import re
import imageio.v2 as imageio
import threading
from concurrent.futures import ThreadPoolExecutor
import time

def replace_filename(original_path, suffix):
    """Return ``original_path`` with ``suffix`` inserted before the extension.

    Only the last path component is altered; the directory part is kept
    as-is. For example ``/a/b/clip.mp4`` with suffix ``_x4`` becomes
    ``/a/b/clip_x4.mp4``.

    Args:
        original_path: Path whose file name should be changed.
        suffix: Text appended to the file name stem.

    Returns:
        The new path as a string.
    """
    folder, base_name = os.path.split(original_path)
    stem, extension = os.path.splitext(base_name)
    return os.path.join(folder, f"{stem}{suffix}{extension}")

def create_temp_folder(folder_path):
    """Create ``folder_path`` as an empty directory.

    Any existing directory tree at that location is removed first, so the
    folder is always empty afterwards. Missing parent directories are
    created as well.

    Args:
        folder_path: Directory to (re)create.
    """
    stale_copy_present = os.path.exists(folder_path)
    if stale_copy_present:
        # Wipe leftovers from a previous run before recreating the folder.
        shutil.rmtree(folder_path)
    os.makedirs(folder_path)

def delete_temp_folder(folder_path):
    """Remove ``folder_path`` and everything below it.

    A folder that does not exist is treated as already deleted, so cleanup
    code can call this unconditionally. Other failures (for example
    permission errors) still propagate.

    Args:
        folder_path: Directory tree to remove.
    """
    try:
        shutil.rmtree(folder_path)
    except FileNotFoundError:
        # Nothing to clean up; the goal (folder absent) is already met.
        pass

def extract_number(filename):
    """Return the first run of digits in ``filename`` as an int.

    Used as a sort key so that frame files are ordered numerically rather
    than lexicographically.

    Args:
        filename: File name to inspect.

    Returns:
        The integer value of the first digit run, or ``-1`` when the name
        contains no digits.
    """
    first_digits = re.search(r'\d+', filename)
    if first_digits is None:
        return -1
    return int(first_digits.group())

def bicubic_upsample_opencv(input_image_path, output_image_path, scale_factor):
    """Upscale an image file with bicubic interpolation and save the result.

    Args:
        input_image_path: Path of the image to read.
        output_image_path: Path the enlarged image is written to.
        scale_factor: Multiplier applied to both width and height; the
            resulting dimensions are truncated to integers.

    Raises:
        OSError: If the input image cannot be read or the output image
            cannot be written.
    """
    img = cv2.imread(input_image_path)
    # cv2.imread reports failure by returning None instead of raising.
    if img is None:
        raise OSError(f"Could not read image: {input_image_path}")

    original_height, original_width = img.shape[:2]
    # cv2.resize expects the target size as (width, height).
    new_size = (
        int(original_width * scale_factor),
        int(original_height * scale_factor),
    )

    upsampled_img = cv2.resize(img, new_size, interpolation=cv2.INTER_CUBIC)
    # cv2.imwrite returns False when the file could not be saved.
    if not cv2.imwrite(output_image_path, upsampled_img):
        raise OSError(f"Could not write image: {output_image_path}")


def process_frame(frame_count, frame, temp_LR_folder_path, temp_HR_folder_path, SR):
    """Save one video frame and its bicubic-upscaled counterpart.

    The frame is written to ``temp_LR_folder_path`` as
    ``frame_<frame_count><SR>.png``. For ``SR`` equal to ``'x4'`` or
    ``'x2'`` a bicubic enlargement by 4 or 2 is written to
    ``temp_HR_folder_path`` as ``frame_<frame_count>.png``; for any other
    ``SR`` value only the low-resolution frame is saved.

    Args:
        frame_count: Index of the frame, used in the file names.
        frame: Image data passed to ``cv2.imwrite``.
        temp_LR_folder_path: Folder receiving the original frame.
        temp_HR_folder_path: Folder receiving the upscaled frame.
        SR: Scale tag, ``'x2'`` or ``'x4'``.
    """
    lr_frame_file = os.path.join(temp_LR_folder_path, f"frame_{frame_count}{SR}.png")
    hr_frame_file = os.path.join(temp_HR_folder_path, f"frame_{frame_count}.png")
    cv2.imwrite(lr_frame_file, frame)

    # Map the scale tag to its numeric factor; unknown tags are skipped.
    factor_by_tag = {'x4': 4, 'x2': 2}
    factor = factor_by_tag.get(SR)
    if factor is not None:
        bicubic_upsample_opencv(lr_frame_file, hr_frame_file, factor)

def video_sr(args):
    """Super-resolve a video file frame by frame and re-encode the result.

    Steps:
      1. Decode the video at ``args.input_dir`` and, in a thread pool, dump
         every frame (plus a bicubic-upscaled copy) into temporary folders
         below ``args.output_dir``.
      2. Call ``image_sr(args)``; afterwards the super-resolved frames are
         read from ``<root_path>/results/test_RGT_<SR>/visualization/Set5``.
      3. Assemble those frames, sorted by frame number, into a video written
         next to the temporary folders with ``_x2`` / ``_x4`` appended to the
         input file name.
      4. Remove the temporary folders and ``<root_path>/results``.

    Args:
        args: Namespace providing ``SR`` (``'x2'`` or ``'x4'``),
            ``input_dir`` (path of the input video), ``output_dir``,
            ``root_path`` and ``mul_numwork`` (thread pool size). The same
            object is forwarded to ``image_sr``.

    Raises:
        ValueError: If ``args.SR`` is neither ``'x2'`` nor ``'x4'``.
    """
    # (LR sub-folder, output-name suffix, folder the SR frames are read from)
    layout_by_scale = {
        'x4': ('temp_LR/X4', '_x4', 'results/test_RGT_x4/visualization/Set5'),
        'x2': ('temp_LR/X2', '_x2', 'results/test_RGT_x2/visualization/Set5'),
    }
    if args.SR not in layout_by_scale:
        raise ValueError(f"Unsupported SR value {args.SR!r}; expected 'x2' or 'x4'.")
    lr_subdir, name_suffix, result_subdir = layout_by_scale[args.SR]

    file_name = os.path.basename(args.input_dir)
    video_output_path = replace_filename(
        os.path.join(args.output_dir, file_name), name_suffix)
    temp_LR_folder_path = os.path.join(args.output_dir, lr_subdir)
    temp_HR_folder_path = os.path.join(args.output_dir, 'temp_HR')
    result_temp = osp.join(args.root_path, result_subdir)

    cap = cv2.VideoCapture(args.input_dir)
    if not cap.isOpened():
        cap.release()
        print("Error opening video file.")
        return

    t1 = time.time()
    total_frames = 0
    pending = []
    try:
        # Read the frame rate up front so the capture can be released as
        # soon as decoding is finished.
        fps = cap.get(cv2.CAP_PROP_FPS)

        create_temp_folder(temp_LR_folder_path)
        create_temp_folder(temp_HR_folder_path)

        # Hand frames to the pool as they are decoded instead of buffering
        # the whole video in a list first.
        with ThreadPoolExecutor(max_workers=args.mul_numwork) as executor:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                pending.append(executor.submit(
                    process_frame, total_frames, frame,
                    temp_LR_folder_path, temp_HR_folder_path, args.SR))
                total_frames += 1
    finally:
        cap.release()

    # Surface any exception raised inside a worker; without this a failed
    # frame would go unnoticed until the SR step.
    for job in pending:
        job.result()

    print("total frames:", total_frames)
    print("fps :", fps)

    if total_frames == 0:
        print("No frames could be read from the video.")
        delete_temp_folder(os.path.dirname(temp_LR_folder_path))
        delete_temp_folder(temp_HR_folder_path)
        return

    t2 = time.time()
    print('mul threads: ',t2 - t1,'s')

    # Run super-resolution on all extracted frames.
    image_sr(args)

    t3 = time.time()
    print('image super resolution: ',t3 - t2,'s')

    # Rebuild the video from the super-resolved frames, in frame order.
    frame_files = sorted(os.listdir(result_temp), key=extract_number)
    video_frames = [
        imageio.imread(os.path.join(result_temp, frame_file))
        for frame_file in frame_files
    ]
    imageio.mimwrite(video_output_path, video_frames, fps=fps, quality=9)

    t4 = time.time()
    print('tranformer frames to  video: ',t4 - t3,'s')

    # Remove every intermediate artifact.
    delete_temp_folder(os.path.dirname(temp_LR_folder_path))
    delete_temp_folder(temp_HR_folder_path)
    delete_temp_folder(os.path.join(args.root_path, 'results'))

    t5 = time.time()
    print('delete time: ',t5 - t4,'s')

def str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` treats every non-empty string, including
    ``"False"``, as ``True``; this parser interprets the text instead. The
    empty string maps to ``False``, as it did with ``type=bool``.

    Args:
        value: The raw argument text, or an actual ``bool``.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="RGT for Video Super-Resolution")
    # Make sure --SR matches the scale of the checkpoint given by --ckpt_path.
    parser.add_argument("--SR", type=str, choices=['x2', 'x4'], default='x4', help='image resolution')
    parser.add_argument("--ckpt_path", type=str, default = "/remote-home/lzy/RGT/experiments/pretrained_models/RGT_x4.pth")

    parser.add_argument("--root_path", type=str, default = "/remote-home/lzy/RGT")
    parser.add_argument("--input_dir", type=str, default= "/remote-home/lzy/RGT/datasets/video/video_test1.mp4")
    parser.add_argument("--output_dir", type=str, default= "/remote-home/lzy/RGT/datasets/video_output")

    parser.add_argument("--mul_numwork", type=int, default = 16, help ='max_workers to execute Multi')
    # str2bool instead of bool: bool("False") is True, so the flag could
    # previously never be switched off with "--use_chop False".
    parser.add_argument("--use_chop", type=str2bool, default = True,  help ='use_chop: True  # True to save memory, if img too large')
    args = parser.parse_args()
    video_sr(args)