Upload testing_utils.py
Browse files- testing_utils.py +210 -0
testing_utils.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
from PIL import Image
|
4 |
+
from torchvision import transforms
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from glob import glob
|
7 |
+
|
8 |
+
import cv2
|
9 |
+
import math
|
10 |
+
import numpy as np
|
11 |
+
import os
|
12 |
+
import os.path as osp
|
13 |
+
import random
|
14 |
+
import time
|
15 |
+
import torch
|
16 |
+
from pathlib import Path
|
17 |
+
from torch.utils import data as data
|
18 |
+
|
19 |
+
from basicsr.utils import DiffJPEG, USMSharp
|
20 |
+
from basicsr.utils.img_process_util import filter2D
|
21 |
+
from basicsr.data.transforms import paired_random_crop, triplet_random_crop
|
22 |
+
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt, random_add_speckle_noise_pt, random_add_saltpepper_noise_pt, bivariate_Gaussian
|
23 |
+
|
24 |
+
from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
|
25 |
+
from basicsr.data.transforms import augment
|
26 |
+
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
|
27 |
+
from basicsr.utils.registry import DATASET_REGISTRY
|
28 |
+
|
29 |
+
|
30 |
+
def parse_args_paired_testing(input_args=None):
    """
    Parses command-line arguments used for configuring a paired testing session (pix2pix-Turbo).
    This function sets up an argument parser covering model paths, tiling/chopping sizes,
    prompts, and runtime options. Several options keep training-oriented names and help
    texts (e.g. ``--checkpointing_steps``) -- presumably so the same launch scripts work
    for training and testing; verify against callers before removing any of them.

    Args:
        input_args (list[str] | None): Explicit argument list to parse. When ``None``,
            the arguments are read from ``sys.argv``.

    Returns:
        argparse.Namespace: The parsed command-line arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--ref_path", type=str, default=None,)
    parser.add_argument("--base_config", default="./configs/sr_test.yaml", type=str)
    parser.add_argument("--tracker_project_name", type=str, default="train_pix2pix_turbo", help="The name of the wandb project to log to.")

    # details about the model architecture
    parser.add_argument("--sd_path")
    parser.add_argument("--de_net_path")
    parser.add_argument("--pretrained_path", type=str, default=None,)
    parser.add_argument("--revision", type=str, default=None,)
    parser.add_argument("--variant", type=str, default=None,)
    parser.add_argument("--tokenizer_name", type=str, default=None)
    parser.add_argument("--lora_rank_unet", default=32, type=int)
    parser.add_argument("--lora_rank_vae", default=16, type=int)

    # super-resolution scale and chopped (patch-wise) forward settings
    parser.add_argument("--scale", type=int, default=4, help="Scale factor for SR.")
    parser.add_argument("--chop_size", type=int, default=128, choices=[512, 256, 128], help="Chopping forward.")
    parser.add_argument("--chop_stride", type=int, default=96, help="Chopping stride.")
    parser.add_argument("--padding_offset", type=int, default=32, help="padding offset.")

    # tiling sizes for the VAE and the latent space
    parser.add_argument("--vae_decoder_tiled_size", type=int, default=224)
    parser.add_argument("--vae_encoder_tiled_size", type=int, default=1024)
    parser.add_argument("--latent_tiled_size", type=int, default=96)
    parser.add_argument("--latent_tiled_overlap", type=int, default=32)

    parser.add_argument("--align_method", type=str, default="wavelet")

    parser.add_argument("--pos_prompt", type=str, default="A high-resolution, 8K, ultra-realistic image with sharp focus, vibrant colors, and natural lighting.")
    parser.add_argument("--neg_prompt", type=str, default="oil painting, cartoon, blur, dirty, messy, low quality, deformation, low resolution, oversmooth")

    # training details
    parser.add_argument("--output_dir", required=True)
    parser.add_argument("--cache_dir", default=None,)
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument("--resolution", type=int, default=512,)
    parser.add_argument("--checkpointing_steps", type=int, default=500,)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.",)
    parser.add_argument("--gradient_checkpointing", action="store_true",)

    parser.add_argument("--dataloader_num_workers", type=int, default=0,)
    parser.add_argument("--allow_tf32", action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    # The help text previously labelled "tensorboard" as the default although the
    # actual default below is "wandb".
    parser.add_argument("--report_to", type=str, default="wandb",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
            ' `"wandb"` (default) and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"],)
    parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.")
    parser.add_argument("--set_grads_to_none", action="store_true",)

    # distributed execution
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')

    # parse_args(None) falls back to sys.argv[1:], so no explicit branch is needed.
    return parser.parse_args(input_args)
|
106 |
+
|
107 |
+
|
108 |
+
class PlainDataset(data.Dataset):
    """Dataset that loads low-resolution (LR) images from disk for testing.

    Modified from the dataset used for the Real-ESRGAN model:
    Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.

    Unlike that training dataset, this class applies no augmentation and builds no
    degradation kernels: every item is just the decoded LR image and the path it was
    read from.

    Args:
        opt (dict): Config for the dataset. Keys read by this class:
            io_backend (dict): IO backend options; ``type`` selects the backend and the
                remaining entries are forwarded to ``FileClient`` as keyword arguments.
            lr_path (str | Path | list, optional): A single directory, or a list of
                directories, containing the LR images. When absent the dataset is empty.
            image_type (str, optional): File extension (without dot) matched in the
                list-of-directories case. Set to 'png' in ``opt`` when missing.
    """

    # Extensions collected when ``lr_path`` is a single directory.
    _SINGLE_DIR_EXTENSIONS = ('png', 'jpg', 'jpeg')
    # Number of read attempts per item, and the pause (seconds) between attempts.
    max_retries = 3
    retry_delay = 1.0

    def __init__(self, opt):
        super(PlainDataset, self).__init__()
        self.opt = opt
        # Created lazily on the first __getitem__ call.
        self.file_client = None
        self.io_backend_opt = opt['io_backend']

        if 'image_type' not in opt:
            opt['image_type'] = 'png'

        # support multiple type of data: file path and meta data, remove support of lmdb
        self.lr_paths = []
        if 'lr_path' in opt:
            lr_root = opt['lr_path']
            if isinstance(lr_root, (str, Path)):
                # Single directory: gather all supported extensions, then sort the
                # combined list so ordering is by full path regardless of extension.
                folder = Path(lr_root)
                found = []
                for ext in self._SINGLE_DIR_EXTENSIONS:
                    found.extend(str(p) for p in folder.glob('*.' + ext))
                self.lr_paths.extend(sorted(found))
            else:
                # Several directories: each one is sorted on its own and appended in
                # the order given, matching only the configured image type.
                pattern = '*.' + opt['image_type']
                for folder in lr_root:
                    self.lr_paths.extend(sorted(str(p) for p in Path(folder).glob(pattern)))

    def __getitem__(self, index):
        """Return ``{'lr': tensor, 'lr_path': str}`` for the image at ``index``.

        If reading fails with an I/O error, a randomly chosen other file is tried
        instead, so the returned ``lr_path`` may differ from the requested one.

        Raises:
            OSError: If every read attempt fails.
        """
        if self.file_client is None:
            # Work on a copy so the caller's config keeps its 'type' entry and can be
            # reused to build another dataset.
            backend_opt = dict(self.io_backend_opt)
            self.file_client = FileClient(backend_opt.pop('type'), **backend_opt)

        # -------------------------------- Load lr images -------------------------------- #
        # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
        lr_path = self.lr_paths[index]

        # avoid errors caused by high latency in reading files
        last_error = None
        for attempt in range(self.max_retries):
            try:
                lr_img_bytes = self.file_client.get(lr_path, 'gt')
            except (IOError, OSError) as err:
                last_error = err
                if attempt + 1 < self.max_retries:
                    # change another file to read
                    index = random.randint(0, len(self) - 1)
                    lr_path = self.lr_paths[index]
                    time.sleep(self.retry_delay)  # pause for occasional server congestion
            else:
                break
        else:
            # Previously this fell through with `lr_img_bytes` unbound and crashed
            # with an unrelated UnboundLocalError.
            raise OSError(
                f'Failed to read an image after {self.max_retries} attempts (last tried: {lr_path})'
            ) from last_error

        img_lr = imfrombytes(lr_img_bytes, float32=True)

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_lr = img2tensor([img_lr], bgr2rgb=True, float32=True)[0]

        return {'lr': img_lr, 'lr_path': lr_path}

    def __len__(self):
        """Return the number of LR image paths collected at construction."""
        return len(self.lr_paths)
|
185 |
+
|
186 |
+
|
187 |
+
def lr_proc(config, batch, device):
    """Prepare a batch of low-resolution images as network input.

    The LR batch is bicubically upsampled by ``config.sf``, mapped with ``x * 2 - 1``
    and clamped to [-1, 1], then padded on the bottom/right so both spatial sizes are
    multiples of 64.

    Args:
        config: Object exposing the integer scale factor as ``config.sf``.
        batch (dict): Must contain ``'lr'``, a 4-D image tensor (N, C, H, W).
            Values are presumably in [0, 1], which the ``* 2 - 1`` mapping relies on
            -- confirm against the dataset in use.
        device: Device on which all processing is done and results are returned.

    Returns:
        tuple: ``(im_lr, ori_lr, (ori_h, ori_w))`` where ``im_lr`` is the upsampled,
        normalised and padded tensor, ``ori_lr`` is the float LR input before
        upsampling, and ``(ori_h, ori_w)`` is the upsampled size before padding
        (needed to crop the padding away afterwards).
    """
    # Move straight to the requested device. The previous `.cuda()` ignored `device`,
    # which failed on CPU-only hosts and detoured through the default GPU otherwise.
    im_lr = batch['lr'].to(device)
    im_lr = im_lr.to(memory_format=torch.contiguous_format).float()

    ori_lr = im_lr

    im_lr = F.interpolate(
        im_lr,
        size=(im_lr.size(-2) * config.sf,
              im_lr.size(-1) * config.sf),
        mode='bicubic',
    )

    im_lr = im_lr.contiguous()
    im_lr = im_lr * 2 - 1.0
    # Bicubic interpolation can overshoot the input range, hence the clamp.
    im_lr = torch.clamp(im_lr, -1.0, 1.0)

    ori_h, ori_w = im_lr.size(-2), im_lr.size(-1)

    pad_h = (math.ceil(ori_h / 64)) * 64 - ori_h
    pad_w = (math.ceil(ori_w / 64)) * 64 - ori_w
    # Reflect padding requires the pad to be smaller than the padded dimension; for
    # very small images fall back to edge replication instead of raising.
    pad_mode = 'reflect' if (pad_h < ori_h and pad_w < ori_w) else 'replicate'
    im_lr = F.pad(im_lr, pad=(0, pad_w, 0, pad_h), mode=pad_mode)

    # Both tensors already live on `device`.
    return im_lr, ori_lr, (ori_h, ori_w)
|