Spaces:
Running
on
T4
Running
on
T4
File size: 6,229 Bytes
561c629 8f3d49d 561c629 d26bbd5 561c629 d26bbd5 561c629 d26bbd5 561c629 8f3d49d 561c629 8f3d49d 561c629 8f3d49d 561c629 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
'''
This file runs the inference for a single input image or a folder of input images
'''
import argparse
import os, sys, cv2, shutil, warnings
import torch
import gradio as gr
from torchvision.transforms import ToTensor
from torchvision.utils import save_image
# Install the "default" warning filter action for this process and record the same setting in PYTHONWARNINGS
# NOTE(review): presumably the environment variable is meant for child processes -- confirm
warnings.simplefilter("default")
os.environ["PYTHONWARNINGS"] = "default"
# Put the current working directory on the module search path so that local packages
# (such as the test_code package imported right after this) can be imported
root_path = os.path.abspath('.')
sys.path.append(root_path)
from test_code.test_utils import load_grl, load_rrdb, load_cunet
@torch.no_grad()    # Gradient tracking must be disabled for inference, otherwise it will run Out of Memory
def super_resolve_img(generator, input_path, output_path=None, weight_dtype=torch.float32, downsample_threshold=720, crop_for_4x=True):
    ''' Super resolve a low resolution image
    Args:
        generator:                  the already loaded generator model; it is called directly on the input tensor
        input_path (str):           the path to the input lr image
        output_path (str):          the file path to store the generated image (nothing is saved if it is None)
        weight_dtype (torch.dtype): the dtype the input tensor is cast to (float32/float16)
        downsample_threshold (int): if the short side (min of height/width) is larger than this value, the input is
                                    downsampled so that its short side matches it; -1 disables the downsampling
        crop_for_4x (bool):         whether we crop the lr image so height and width are multiples of 4
    Returns:
        the tensor returned by generator for the preprocessed input
    Raises:
        gr.Error: if the image cannot be read, is too small after cropping, or has more than 720x1280 pixels
                  after preprocessing
    '''
    print("Processing image {}".format(input_path))

    # Read the image; cv2.imread signals a failure by returning None instead of raising
    img_lr = cv2.imread(input_path)
    if img_lr is None:
        raise gr.Error("Cannot read the input image {}. Please check that the path exists and is a valid image file!".format(input_path))
    h, w, _ = img_lr.shape

    # Downsample if needed (both sides are scaled by the same ratio)
    short_side = min(h, w)
    if downsample_threshold != -1 and short_side > downsample_threshold:
        resize_ratio = short_side / downsample_threshold
        img_lr = cv2.resize(img_lr, (int(w/resize_ratio), int(h/resize_ratio)), interpolation = cv2.INTER_LINEAR)

    # Crop if needed: drop the trailing rows/columns so that both sides are multiples of 4
    if crop_for_4x:
        h, w, _ = img_lr.shape
        img_lr = img_lr[:4*(h//4), :4*(w//4), :]

    # Check if the size is out of the boundary
    h, w, _ = img_lr.shape
    if h == 0 or w == 0:
        # A side shorter than 4 pixels is cropped away completely, which leaves nothing to process
        raise gr.Error("The input image is too small to be processed!")
    if h*w > 720*1280:
        raise gr.Error("The input image size is too large. The largest area we support is 720x1280=921600 pixel!")

    # Transform to tensor (OpenCV loads BGR, so convert to RGB first)
    img_lr = cv2.cvtColor(img_lr, cv2.COLOR_BGR2RGB)
    img_lr = ToTensor()(img_lr).unsqueeze(0).cuda()     # Use tensor format with a leading batch dimension
    img_lr = img_lr.to(dtype=weight_dtype)

    # Model inference
    print("lr shape is ", img_lr.shape)
    super_resolved_img = generator(img_lr)

    # Store the generated result
    # NOTE(review): autocast only wraps the saving step, not the forward pass above -- confirm this is intended
    with torch.cuda.amp.autocast():
        if output_path is not None:
            save_image(super_resolved_img, output_path)

    # Empty the cache every time you finish processing one image
    torch.cuda.empty_cache()

    return super_resolved_img
if __name__ == "__main__":

    def _str2bool(value):
        ''' Parse a command line boolean. A plain type=bool treats every non-empty string (even "False") as True '''
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0", ""):
            return False
        raise argparse.ArgumentTypeError("expected a boolean value (True/False), but got {!r}".format(value))


    # Fundamental setting
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_dir', type = str, default = '__assets__/lr_inputs', help="Can be either single image input or a folder input")
    parser.add_argument('--model', type = str, default = 'GRL', help=" 'GRL' || 'RRDB' (for ESRNET & ESRGAN) || 'CUNET' (for Real-ESRGAN) ")
    parser.add_argument('--scale', type = int, default = 4, help="Up scaler factor")
    parser.add_argument('--weight_path', type = str, default = 'pretrained/4x_APISR_GRL_GAN_generator.pth', help="Weight path directory, usually under saved_models folder")
    parser.add_argument('--store_dir', type = str, default = 'sample_outputs', help="The folder to store the super-resolved images")
    # Currently, float16 is only supported in RRDB, there is some bug with GRL model
    # Accepts "--float16_inference True/False"; the bare flag "--float16_inference" also enables it
    parser.add_argument('--float16_inference', type = _str2bool, nargs = '?', const = True, default = False, help="Float16 inference, only useful in RRDB now")
    args = parser.parse_args()

    # Sample Command
    # 4x GRL (Default): python test_code/inference.py --model GRL --scale 4 --weight_path pretrained/4x_APISR_GRL_GAN_generator.pth
    # 2x RRDB: python test_code/inference.py --model RRDB --scale 2 --weight_path pretrained/2x_APISR_RRDB_GAN_generator.pth


    # Read argument and prepare the folder needed
    input_dir = args.input_dir
    model = args.model
    weight_path = args.weight_path
    store_dir = args.store_dir
    scale = args.scale
    float16_inference = args.float16_inference


    # Check the path of the weight
    if not os.path.exists(weight_path):
        print("we cannot locate weight path ", weight_path)
        # TODO: I am not sure if I should automatically download weight from github release based on the upscale factor and model name.
        sys.exit(1)     # Exit with a failure code (os._exit(0) reported success and skipped the cleanup)

    # Check the input before the store folder is touched
    if not os.path.exists(input_dir):
        print("we cannot locate input path ", input_dir)
        sys.exit(1)
    if os.path.abspath(input_dir) == os.path.abspath(store_dir):
        # The store folder is wiped below, so sharing it with the input would delete the inputs
        print("the store folder must be different from the input path ", input_dir)
        sys.exit(1)


    # Prepare the store folder (any existing content is removed)
    if os.path.exists(store_dir):
        shutil.rmtree(store_dir)
    os.makedirs(store_dir)


    # Define the weight type
    if float16_inference:
        torch.backends.cudnn.benchmark = True
        weight_dtype = torch.float16
    else:
        weight_dtype = torch.float32


    # Load the model
    if model == "GRL":
        generator = load_grl(weight_path, scale=scale)      # GRL for Real-World SR only support 4x upscaling
    elif model == "RRDB":
        generator = load_rrdb(weight_path, scale=scale)     # Can be any size
    else:
        # Without this branch, generator would be undefined and the next line would fail with a NameError
        # NOTE(review): load_cunet is imported at the top of this file but is not wired up in this script -- confirm its usage before adding it
        print("model {} is not supported by this script, please use GRL or RRDB".format(model))
        sys.exit(1)
    generator = generator.to(dtype=weight_dtype)


    # Take the input path and do inference
    if os.path.isdir(input_dir):    # If the input is a directory, we will iterate it
        for filename in sorted(os.listdir(input_dir)):
            input_path = os.path.join(input_dir, filename)
            if not os.path.isfile(input_path):
                # Sub-folders cannot be read as an image, so we skip them
                continue
            output_path = os.path.join(store_dir, filename)
            # In default, we will automatically use crop to match 4x size
            super_resolve_img(generator, input_path, output_path, weight_dtype, crop_for_4x=True)

    else:   # If the input is a single image, we will process it directly and write it to the store folder
        # splitext only removes the extension, so a name with several dots is kept complete
        filename = os.path.splitext(os.path.basename(input_dir))[0]
        output_path = os.path.join(store_dir, filename+"_"+str(scale)+"x.png")
        # In default, we will automatically use crop to match 4x size
        super_resolve_img(generator, input_dir, output_path, weight_dtype, crop_for_4x=True)
|