# File size: 6,517 bytes
# Commit: 8da8f47
import os
import math
import argparse
import random
import logging
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from data.data_sampler import DistIterSampler
import options.options as option
from utils import util
from data import create_dataloader, create_dataset
from models import create_model
import numpy as np
def init_dist(backend='nccl', **kwargs):
    """Initialize torch.distributed for multi-GPU runs.

    Forces the 'spawn' multiprocessing start method (required for CUDA
    tensors in worker processes), binds this process to a GPU chosen from
    its global RANK, and joins the default process group.

    Args:
        backend: torch.distributed backend name (default 'nccl').
        **kwargs: forwarded verbatim to dist.init_process_group.
    """
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn')
    # Map the global rank onto the GPUs visible on this node.
    global_rank = int(os.environ['RANK'])
    gpu_count = torch.cuda.device_count()
    torch.cuda.set_device(global_rank % gpu_count)
    dist.init_process_group(backend=backend, **kwargs)
def cal_pnsr(sr_img, gt_img):
    """Return the PSNR between a restored image and its ground truth.

    Both inputs are expected in the 0-255 range (presumably uint8 HxWxC
    arrays from util.tensor2img — confirm against callers). The divide /
    re-multiply by 255 is kept deliberately: it promotes integer images to
    float before util.calculate_psnr sees them.
    """
    gt_norm = gt_img / 255.
    sr_norm = sr_img / 255.
    return util.calculate_psnr(sr_norm * 255, gt_norm * 255)
def get_min_avg_and_indices(nums, k=900, out_path="indices.txt"):
    """Average the k smallest values of *nums* and record their indices.

    Fixes vs. the original: the divisor is the number of elements actually
    selected (``min(k, len(nums))``) instead of a hard-coded 900, which gave
    a wrong average whenever fewer than 900 values were available; the
    selection size and output path are now parameters (defaults preserve
    the original behavior for len(nums) >= 900).

    Args:
        nums: sequence of comparable values (e.g. per-image bit-error rates).
        k: how many of the smallest values to average (default 900).
        out_path: text file that receives one selected index per line.

    Returns:
        The mean of the selected values, or 0.0 when *nums* is empty
        (matching the original's sum()/N == 0.0 outcome).
    """
    # Indices of the k smallest elements, smallest value first; sorted() is
    # stable, so ties keep their original relative order.
    indices = sorted(range(len(nums)), key=lambda i: nums[i])[:k]
    count = len(indices)
    avg = sum(nums[i] for i in indices) / count if count else 0.0
    with open(out_path, "w") as file:
        for index in indices:
            file.write(str(index) + "\n")
    return avg
def main():
    """Run validation: load options and a checkpoint, iterate the val
    loader, save SR/GT/LR/secret images, and report PSNR and bit-error
    statistics."""
    # ---- options -----------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YMAL file.')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                        help='job launcher')
    parser.add_argument('--ckpt', type=str, default='/userhome/NewIBSN/EditGuard_open/checkpoints/clean.pth', help='Path to pre-trained model.')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    # is_train=True even though this is an eval script — presumably so the
    # option parser fills in training-side defaults; TODO confirm.
    opt = option.parse(args.opt, is_train=True)

    # ---- distributed settings ---------------------------------------
    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        init_dist()
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

    # ---- loading resume state if exists -----------------------------
    if opt['path'].get('resume_state', None):
        # distributed resuming: all processes load onto their current GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(opt['path']['resume_state'],
                                  map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)
    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    #### create train and val dataloader
    dataset_ratio = 200  # enlarge the size of each epoch (unused here)
    for phase, dataset_opt in opt['datasets'].items():
        print("phase", phase)
        # 'TD' and 'val' are handled identically; both build a plain
        # (non-distributed) dataloader. Any other phase is rejected.
        if phase == 'TD':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
        else:
            raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))

    # ---- create model and load checkpoint ---------------------------
    model = create_model(opt)
    model.load_test(args.ckpt)

    # ---- validation --------------------------------------------------
    avg_psnr = 0.0                       # cover-image PSNR accumulator
    avg_psnr_h = [0.0]*opt['num_image']  # one accumulator per hidden/secret image
    avg_psnr_lr = 0.0                    # stego (LR) PSNR accumulator
    biterr = []                          # per-batch message bit-error rates
    idx = 0                              # total number of frames seen
    for image_id, val_data in enumerate(val_loader):
        img_dir = os.path.join('results', opt['name'])
        util.mkdir(img_dir)
        model.feed_data(val_data)
        model.test(image_id)

        visuals = model.get_current_visuals()
        # t_step: frames in this batch; presumably visuals['SR'] is
        # (T, C, H, W) — TODO confirm against the model's get_current_visuals.
        t_step = visuals['SR'].shape[0]
        idx += t_step
        n = len(visuals['SR_h'])  # number of recovered hidden images

        # Bit-error rate between the recovered and embedded message.
        a = visuals['recmessage'][0]
        b = visuals['message'][0]
        bitrecord = util.decoded_message_error_rate_batch(a, b)
        print(bitrecord)
        biterr.append(bitrecord)

        for i in range(t_step):
            sr_img = util.tensor2img(visuals['SR'][i])  # uint8
            sr_img_h = []
            for j in range(n):
                sr_img_h.append(util.tensor2img(visuals['SR_h'][j][i]))  # uint8
            gt_img = util.tensor2img(visuals['GT'][i])  # uint8
            lr_img = util.tensor2img(visuals['LR'][i])
            lrgt_img = []
            for j in range(n):
                lrgt_img.append(util.tensor2img(visuals['LR_ref'][j][i]))

            # Save SR images for reference
            save_img_path = os.path.join(img_dir, '{:d}_{:d}_{:s}.png'.format(image_id, i, 'SR'))
            util.save_img(sr_img, save_img_path)
            for j in range(n):
                save_img_path = os.path.join(img_dir, '{:d}_{:d}_{:d}_{:s}.png'.format(image_id, i, j, 'SR_h'))
                util.save_img(sr_img_h[j], save_img_path)
            save_img_path = os.path.join(img_dir, '{:d}_{:d}_{:s}.png'.format(image_id, i, 'GT'))
            util.save_img(gt_img, save_img_path)
            save_img_path = os.path.join(img_dir, '{:d}_{:d}_{:s}.png'.format(image_id, i, 'LR'))
            util.save_img(lr_img, save_img_path)
            for j in range(n):
                save_img_path = os.path.join(img_dir, '{:d}_{:d}_{:d}_{:s}.png'.format(image_id, i, j, 'LRGT'))
                util.save_img(lrgt_img[j], save_img_path)

            # Accumulate PSNRs: SR vs GT, each hidden image vs its
            # reference, and LR (stego) vs GT.
            psnr = cal_pnsr(sr_img, gt_img)
            psnr_h = []
            for j in range(n):
                psnr_h.append(cal_pnsr(sr_img_h[j], lrgt_img[j]))
            psnr_lr = cal_pnsr(lr_img, gt_img)
            avg_psnr += psnr
            for j in range(n):
                avg_psnr_h[j] += psnr_h[j]
            avg_psnr_lr += psnr_lr

    # NOTE(review): if the val loader is empty, idx == 0 and len(biterr) == 0,
    # so the divisions below would raise ZeroDivisionError — confirm the
    # loader is never empty.
    avg_psnr = avg_psnr / idx
    avg_biterr = sum(biterr) / len(biterr)
    # Also prints the mean over the best-900 batches and dumps their
    # indices to indices.txt (see get_min_avg_and_indices).
    print(get_min_avg_and_indices(biterr))
    avg_psnr_h = [psnr / idx for psnr in avg_psnr_h]
    avg_psnr_lr = avg_psnr_lr / idx
    res_psnr_h = ''
    for p in avg_psnr_h:
        res_psnr_h += ('_{:.4e}'.format(p))
    print('# Validation # PSNR_Cover: {:.4e}, PSNR_Secret: {:s}, PSNR_Stego: {:.4e}, Bit_Error: {:.4e}'.format(avg_psnr, res_psnr_h, avg_psnr_lr, avg_biterr))


if __name__ == '__main__':
    main()