|
import logging |
|
from collections import OrderedDict |
|
|
|
import torch |
|
import torch.nn as nn |
|
from torch.nn.parallel import DataParallel, DistributedDataParallel |
|
|
|
import models.networks as networks |
|
import models.lr_scheduler as lr_scheduler |
|
from .base_model import BaseModel |
|
from models.modules.loss import ReconstructionLoss, ReconstructionMsgLoss |
|
from models.modules.Quantization import Quantization |
|
from .modules.common import DWT,IWT |
|
from utils.jpegtest import JpegTest |
|
from utils.JPEG import DiffJPEG |
|
import utils.util as util |
|
|
|
|
|
import numpy as np |
|
import random |
|
import cv2 |
|
import time |
|
|
|
# Module-level logger shared with the rest of the training framework.
logger = logging.getLogger('base')

# Shared wavelet transform helpers: dwt moves images into the wavelet domain
# before hiding, iwt inverts the transform after recovery.
dwt=DWT()
iwt=IWT()
|
|
|
from diffusers import StableDiffusionInpaintPipeline |
|
from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDIMScheduler |
|
from diffusers import StableDiffusionXLInpaintPipeline |
|
from diffusers.utils import load_image |
|
from diffusers import RePaintPipeline, RePaintScheduler |
|
|
|
class Model_VSN(BaseModel):
    def __init__(self, opt):
        """Build the VSN model entirely from the ``opt`` config dict.

        Sets up: distributed rank, the generator network (DP/DDP-wrapped),
        quantization, optional diffusion inpainting pipelines used as simulated
        editing attacks, and — when training — losses, the Adam optimizer and
        LR schedulers.
        """
        super(Model_VSN, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non-distributed run

        self.gop = opt['gop']  # temporal group-of-pictures (frame window) size
        train_opt = opt['train']
        test_opt = opt['test']
        self.opt = opt
        self.train_opt = train_opt
        self.test_opt = test_opt
        self.opt_net = opt['network_G']
        self.center = self.gop // 2  # index of the middle frame in the window
        self.num_image = opt['num_image']
        self.mode = opt["mode"]  # "image" (hide images) or "bit" (hide a bit string)
        self.idxx = 0

        self.netG = networks.define_G_v2(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)

        self.print_network()
        self.load()

        self.Quantization = Quantization()

        # Extraction-only mode: reload previously recorded bit strings so
        # test() can compare recovered messages against the embedded ones
        # (the strings are written by test() when opt['bitrecord'] is set).
        if not self.opt['hide']:
            file_path = "bit_sequence.txt"

            data_list = []

            with open(file_path, "r") as file:
                for line in file:
                    # Each line is a string of '0'/'1' characters.
                    data = [int(bit) for bit in line.strip()]
                    data_list.append(data)

            self.msg_list = data_list

        # Optional diffusion inpainting pipelines, used at test time as
        # image-editing attacks. Each is loaded eagerly onto CUDA.
        if self.opt['sdinpaint']:
            self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
                "stabilityai/stable-diffusion-2-inpainting",
                torch_dtype=torch.float16,
            ).to("cuda")

        if self.opt['controlnetinpaint']:
            controlnet = ControlNetModel.from_pretrained(
                "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float32
            ).to("cuda")
            self.pipe_control = StableDiffusionControlNetInpaintPipeline.from_pretrained(
                "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float32
            ).to("cuda")

        if self.opt['sdxl']:
            self.pipe_sdxl = StableDiffusionXLInpaintPipeline.from_pretrained(
                "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
                torch_dtype=torch.float16,
                variant="fp16",
                use_safetensors=True,
            ).to("cuda")

        if self.opt['repaint']:
            self.scheduler = RePaintScheduler.from_pretrained("google/ddpm-ema-celebahq-256")
            self.pipe_repaint = RePaintPipeline.from_pretrained("google/ddpm-ema-celebahq-256", scheduler=self.scheduler)
            self.pipe_repaint = self.pipe_repaint.to("cuda")

        if self.is_train:
            self.netG.train()

            # Reconstruction losses for the forward (hiding) and backward
            # (recovery) passes, plus the bit-message decoding loss.
            self.Reconstruction_forw = ReconstructionLoss(losstype=self.train_opt['pixel_criterion_forw'])
            self.Reconstruction_back = ReconstructionLoss(losstype=self.train_opt['pixel_criterion_back'])
            self.Reconstruction_center = ReconstructionLoss(losstype="center")
            self.Reconstruction_msg = ReconstructionMsgLoss(losstype=self.opt['losstype'])

            # Optimizer: only the sub-modules relevant to the current mode are
            # trained (irn/pm for image hiding, bitencoder/bitdecoder for bits);
            # everything else is logged as frozen.
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            optim_params = []

            if self.mode == "image":
                for k, v in self.netG.named_parameters():
                    if (k.startswith('module.irn') or k.startswith('module.pm')) and v.requires_grad:
                        optim_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning('Params [{:s}] will not optimize.'.format(k))

            elif self.mode == "bit":
                for k, v in self.netG.named_parameters():
                    if (k.startswith('module.bitencoder') or k.startswith('module.bitdecoder')) and v.requires_grad:
                        optim_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning('Params [{:s}] will not optimize.'.format(k))

            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'], train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            # Learning-rate schedulers with warm restarts.
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
                                                         restarts=train_opt['restarts'],
                                                         weights=train_opt['restart_weights'],
                                                         gamma=train_opt['lr_gamma'],
                                                         clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

            # Scalar losses logged per step by optimize_parameters().
            self.log_dict = OrderedDict()
|
|
|
def feed_data(self, data): |
|
self.ref_L = data['LQ'].to(self.device) |
|
self.real_H = data['GT'].to(self.device) |
|
self.mes = data['MES'] |
|
|
|
def init_hidden_state(self, z): |
|
b, c, h, w = z.shape |
|
h_t = [] |
|
c_t = [] |
|
for _ in range(self.opt_net['block_num_rbm']): |
|
h_t.append(torch.zeros([b, c, h, w]).cuda()) |
|
c_t.append(torch.zeros([b, c, h, w]).cuda()) |
|
memory = torch.zeros([b, c, h, w]).cuda() |
|
|
|
return h_t, c_t, memory |
|
|
|
def loss_forward(self, out, y): |
|
l_forw_fit = self.train_opt['lambda_fit_forw'] * self.Reconstruction_forw(out, y) |
|
return l_forw_fit |
|
|
|
def loss_back_rec(self, out, x): |
|
l_back_rec = self.train_opt['lambda_rec_back'] * self.Reconstruction_back(out, x) |
|
return l_back_rec |
|
|
|
def loss_back_rec_mul(self, out, x): |
|
l_back_rec = self.train_opt['lambda_rec_back'] * self.Reconstruction_back(out, x) |
|
return l_back_rec |
|
|
|
    def optimize_parameters(self, current_step):
        """Run one training step: hide, degrade, recover, and update netG.

        Embeds the centre secret frame and a random bit message into the host
        frame, optionally applies a differentiable degradation (Gaussian noise,
        JPEG, or Poisson noise), then recovers and optimizes mode-specific
        losses with Adam. Scalars are written to ``self.log_dict``.
        """
        self.optimizer_G.zero_grad()

        b, n, t, c, h, w = self.ref_L.shape
        center = t // 2
        intval = self.gop // 2

        # Fresh random message of +-0.5 values for this batch.
        message = torch.Tensor(np.random.choice([-0.5, 0.5], (self.ref_L.shape[0], self.opt['message_length']))).to(self.device)

        add_noise = self.opt['addnoise']
        add_jpeg = self.opt['addjpeg']
        add_possion = self.opt['addpossion']
        add_sdinpaint = self.opt['sdinpaint']
        degrade_shuffle = self.opt['degrade_shuffle']

        # Temporal window centred on the middle frame.
        self.host = self.real_H[:, center - intval:center + intval + 1]
        self.secret = self.ref_L[:, :, center - intval:center + intval + 1]
        # Forward (hiding) pass in the wavelet domain; `container` is the stego image.
        self.output, container = self.netG(x=dwt(self.host.reshape(b, -1, h, w)), x_h=dwt(self.secret[:,0].reshape(b, -1, h, w)), message=message)

        Gt_ref = self.real_H[:, center - intval:center + intval + 1].detach()

        y_forw = container

        # Fidelity of the container w.r.t. the host frame.
        l_forw_fit = self.loss_forward(y_forw, self.host[:,0])

        if degrade_shuffle:
            # Pick one of three degradations at random each step.
            import random
            choice = random.randint(0, 2)

            if choice == 0:
                # Gaussian noise, sigma drawn from [1/255, 15/255].
                NL = float((np.random.randint(1, 16))/255)
                noise = np.random.normal(0, NL, y_forw.shape)
                torchnoise = torch.from_numpy(noise).cuda().float()
                y_forw = y_forw + torchnoise

            elif choice == 1:
                # Differentiable JPEG with random quality in [70, 95).
                NL = int(np.random.randint(70,95))
                self.DiffJPEG = DiffJPEG(differentiable=True, quality=int(NL)).cuda()
                y_forw = self.DiffJPEG(y_forw)

            elif choice == 2:
                # Poisson (shot) noise, either per-channel or grayscale-coupled.
                vals = 10**4
                if random.random() < 0.5:
                    noisy_img_tensor = torch.poisson(y_forw * vals) / vals
                else:
                    img_gray_tensor = torch.mean(y_forw, dim=0, keepdim=True)
                    noisy_gray_tensor = torch.poisson(img_gray_tensor * vals) / vals
                    noisy_img_tensor = y_forw + (noisy_gray_tensor - img_gray_tensor)

                y_forw = torch.clamp(noisy_img_tensor, 0, 1)

        else:

            # Fixed single degradation chosen by config flags (first match wins).
            if add_noise:
                NL = float((np.random.randint(1,16))/255)
                noise = np.random.normal(0, NL, y_forw.shape)
                torchnoise = torch.from_numpy(noise).cuda().float()
                y_forw = y_forw + torchnoise

            elif add_jpeg:
                NL = int(np.random.randint(70,95))
                self.DiffJPEG = DiffJPEG(differentiable=True, quality=int(NL)).cuda()
                y_forw = self.DiffJPEG(y_forw)

            elif add_possion:
                vals = 10**4
                if random.random() < 0.5:
                    noisy_img_tensor = torch.poisson(y_forw * vals) / vals
                else:
                    img_gray_tensor = torch.mean(y_forw, dim=0, keepdim=True)
                    noisy_gray_tensor = torch.poisson(img_gray_tensor * vals) / vals
                    noisy_img_tensor = y_forw + (noisy_gray_tensor - img_gray_tensor)

                y_forw = torch.clamp(noisy_img_tensor, 0, 1)

        # Quantize to simulate 8-bit image storage before the reverse pass.
        y = self.Quantization(y_forw)
        all_zero = torch.zeros(message.shape).to(self.device)

        if self.mode == "image":
            # Reverse pass: recover host, hidden image(s) and the message.
            out_x, out_x_h, out_z, recmessage = self.netG(x=y, message=all_zero, rev=True)
            out_x = iwt(out_x)
            out_x_h = [iwt(out_x_h_i) for out_x_h_i in out_x_h]

            l_back_rec = self.loss_back_rec(out_x, self.host[:,0])
            out_x_h = torch.stack(out_x_h, dim=1)

            # Recovery quality of the centre hidden image.
            l_center_x = self.loss_back_rec(out_x_h[:, 0], self.secret[:,0].reshape(b, -1, h, w))

            recmessage = torch.clamp(recmessage, -0.5, 0.5)

            l_msg = self.Reconstruction_msg(message, recmessage)

            # NOTE(review): l_msg is logged but deliberately(?) not part of the
            # total loss in image mode — message decoding is trained in "bit"
            # mode. Confirm against the training recipe.
            loss = l_forw_fit*2 + l_back_rec + l_center_x*4

            loss.backward()

            if self.train_opt['lambda_center'] != 0:
                self.log_dict['l_center_x'] = l_center_x.item()

            self.log_dict['l_back_rec'] = l_back_rec.item()
            self.log_dict['l_forw_fit'] = l_forw_fit.item()
            self.log_dict['l_msg'] = l_msg.item()

            self.log_dict['l_h'] = (l_center_x*10).item()

            if self.train_opt['gradient_clipping']:
                nn.utils.clip_grad_norm_(self.netG.parameters(), self.train_opt['gradient_clipping'])

            self.optimizer_G.step()

        elif self.mode == "bit":
            # Reverse pass recovers only the bit message.
            recmessage = self.netG(x=y, message=all_zero, rev=True)

            recmessage = torch.clamp(recmessage, -0.5, 0.5)

            l_msg = self.Reconstruction_msg(message, recmessage)

            lambda_msg = self.train_opt['lambda_msg']

            loss = l_msg * lambda_msg + l_forw_fit

            loss.backward()

            self.log_dict['l_forw_fit'] = l_forw_fit.item()
            self.log_dict['l_msg'] = l_msg.item()

            if self.train_opt['gradient_clipping']:
                nn.utils.clip_grad_norm_(self.netG.parameters(), self.train_opt['gradient_clipping'])

            self.optimizer_G.step()
|
|
|
    def test(self, image_id):
        """Evaluate one sample end to end.

        Embeds a random message (or replays a recorded one when not hiding),
        applies the configured attacks — diffusion inpainting (SD, ControlNet,
        SDXL, RePaint) and/or pixel degradations (noise, JPEG, Poisson) — then
        recovers the hidden image and/or bit message. Results land in
        ``self.forw_L``, ``self.fake_H``, ``self.fake_H_h``, ``self.message``
        and ``self.recmessage``.

        Args:
            image_id: 0-based sample index; also selects the 1-based,
                zero-padded mask file under ../dataset/valAGE-Set-Mask/.
        """
        self.netG.eval()
        add_noise = self.opt['addnoise']
        add_jpeg = self.opt['addjpeg']
        add_possion = self.opt['addpossion']
        add_sdinpaint = self.opt['sdinpaint']
        add_controlnet = self.opt['controlnetinpaint']
        add_sdxl = self.opt['sdxl']
        add_repaint = self.opt['repaint']
        degrade_shuffle = self.opt['degrade_shuffle']

        with torch.no_grad():
            forw_L = []
            forw_L_h = []
            fake_H = []
            fake_H_h = []
            pred_z = []
            recmsglist = []
            msglist = []
            b, t, c, h, w = self.real_H.shape
            center = t // 2
            intval = self.gop // 2
            # shapes below use the reference tensor's unpacking (overwrites b/t/c/h/w)
            b, n, t, c, h, w = self.ref_L.shape
            id=0

            # Temporal window around the centre frame (id is a fixed 0 offset).
            self.host = self.real_H[:, center - intval+id:center + intval + 1+id]
            self.secret = self.ref_L[:, :, center - intval+id:center + intval + 1+id]
            self.secret = [dwt(self.secret[:,i].reshape(b, -1, h, w)) for i in range(n)]

            # Fresh random +-0.5 message for this sample.
            messagenp = np.random.choice([-0.5, 0.5], (self.ref_L.shape[0], self.opt['message_length']))

            message = torch.Tensor(messagenp).to(self.device)

            if self.opt['bitrecord']:
                # Append the message as a '0'/'1' string to bit_sequence.txt so
                # a later extraction-only run can reload it (see __init__).
                mymsg = message.clone()

                mymsg[mymsg>0] = 1
                mymsg[mymsg<0] = 0
                mymsg = mymsg.squeeze(0).to(torch.int)

                bit_list = mymsg.tolist()

                bit_string = ''.join(map(str, bit_list))

                file_name = "bit_sequence.txt"

                with open(file_name, "a") as file:
                    file.write(bit_string + "\n")

            if self.opt['hide']:
                # Embed secret + message; `container` is the stego image.
                self.output, container = self.netG(x=dwt(self.host.reshape(b, -1, h, w)), x_h=self.secret, message=message)
                y_forw = container
            else:

                # Extraction-only: the host already carries the payload; replay
                # the message recorded for this sample.
                message = torch.tensor(self.msg_list[image_id]).unsqueeze(0).cuda()
                self.output = self.host
                y_forw = self.output.squeeze(1)

            if add_sdinpaint:
                # Attack: Stable Diffusion 2 inpainting inside the dataset mask.
                import random
                from PIL import Image
                prompt = ""

                b, _, _, _ = y_forw.shape

                image_batch = y_forw.permute(0, 2, 3, 1).detach().cpu().numpy()
                forw_list = []

                for j in range(b):
                    i = image_id + 1  # mask filenames are 1-based, zero-padded to 4 digits
                    masksrc = "../dataset/valAGE-Set-Mask/"
                    mask_image = Image.open(masksrc + str(i).zfill(4) + ".png").convert("L")
                    mask_image = mask_image.resize((512, 512))
                    # NOTE(review): PIL .size is (width, height), so h/w are
                    # swapped here; passing height=w, width=h below undoes the
                    # swap, and both are 512 after the resize anyway.
                    h, w = mask_image.size

                    image = image_batch[j, :, :, :]
                    image_init = Image.fromarray((image * 255).astype(np.uint8), mode = "RGB")
                    image_inpaint = self.pipe(prompt=prompt, image=image_init, mask_image=mask_image, height=w, width=h).images[0]
                    image_inpaint = np.array(image_inpaint) / 255.
                    mask_image = np.array(mask_image)
                    mask_image = np.stack([mask_image] * 3, axis=-1) / 255.
                    mask_image = mask_image.astype(np.uint8)
                    # Keep original pixels outside the mask, inpainted inside.
                    image_fuse = image * (1 - mask_image) + image_inpaint * mask_image
                    forw_list.append(torch.from_numpy(image_fuse).permute(2, 0, 1))

                y_forw = torch.stack(forw_list, dim=0).float().cuda()

            if add_controlnet:
                # Attack: ControlNet-guided SD 1.5 inpainting.
                from diffusers.utils import load_image
                from PIL import Image

                b, _, _, _ = y_forw.shape
                forw_list = []

                image_batch = y_forw.permute(0, 2, 3, 1).detach().cpu().numpy()
                generator = torch.Generator(device="cuda").manual_seed(1)

                for j in range(b):
                    i = image_id + 1
                    mask_path = "../dataset/valAGE-Set-Mask/" + str(i).zfill(4) + ".png"
                    mask_image = load_image(mask_path)
                    mask_image = mask_image.resize((512, 512))
                    image_init = image_batch[j, :, :, :]
                    image_init1 = Image.fromarray((image_init * 255).astype(np.uint8), mode = "RGB")
                    image_mask = np.array(mask_image.convert("L")).astype(np.float32) / 255.0

                    assert image_init.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size"
                    # Build the ControlNet conditioning image: masked pixels set to -1.
                    image_init[image_mask > 0.5] = -1.0
                    image = np.expand_dims(image_init, 0).transpose(0, 3, 1, 2)
                    control_image = torch.from_numpy(image)

                    image_inpaint = self.pipe_control(
                        "",
                        num_inference_steps=20,
                        generator=generator,
                        eta=1.0,
                        image=image_init1,
                        mask_image=image_mask,
                        control_image=control_image,
                    ).images[0]

                    image_inpaint = np.array(image_inpaint) / 255.
                    image_mask = np.stack([image_mask] * 3, axis=-1)
                    image_mask = image_mask.astype(np.uint8)
                    image_fuse = image_init * (1 - image_mask) + image_inpaint * image_mask
                    forw_list.append(torch.from_numpy(image_fuse).permute(2, 0, 1))

                y_forw = torch.stack(forw_list, dim=0).float().cuda()

            if add_sdxl:
                # Attack: SDXL inpainting at 512x512.
                import random
                from PIL import Image
                from diffusers.utils import load_image
                prompt = ""

                b, _, _, _ = y_forw.shape

                image_batch = y_forw.permute(0, 2, 3, 1).detach().cpu().numpy()
                forw_list = []

                for j in range(b):
                    i = image_id + 1
                    masksrc = "../dataset/valAGE-Set-Mask/"
                    mask_image = load_image(masksrc + str(i).zfill(4) + ".png").convert("RGB")
                    mask_image = mask_image.resize((512, 512))
                    # NOTE(review): PIL .size is (width, height); both are 512 here.
                    h, w = mask_image.size

                    image = image_batch[j, :, :, :]
                    image_init = Image.fromarray((image * 255).astype(np.uint8), mode = "RGB")
                    image_inpaint = self.pipe_sdxl(
                        prompt=prompt, image=image_init, mask_image=mask_image, num_inference_steps=50, strength=0.80, target_size=(512, 512)
                    ).images[0]
                    image_inpaint = image_inpaint.resize((512, 512))
                    image_inpaint = np.array(image_inpaint) / 255.
                    mask_image = np.array(mask_image) / 255.
                    mask_image = mask_image.astype(np.uint8)
                    image_fuse = image * (1 - mask_image) + image_inpaint * mask_image
                    forw_list.append(torch.from_numpy(image_fuse).permute(2, 0, 1))

                y_forw = torch.stack(forw_list, dim=0).float().cuda()

            if add_repaint:
                # Attack: RePaint diffusion inpainting at 256x256, upscaled back to 512.
                from PIL import Image

                b, _, _, _ = y_forw.shape

                image_batch = y_forw.permute(0, 2, 3, 1).detach().cpu().numpy()
                forw_list = []

                generator = torch.Generator(device="cuda").manual_seed(0)
                for j in range(b):
                    i = image_id + 1
                    masksrc = "../dataset/valAGE-Set-Mask/" + str(i).zfill(4) + ".png"
                    mask_image = Image.open(masksrc).convert("RGB")
                    mask_image = mask_image.resize((256, 256))
                    # RePaint keeps white regions, so invert the dataset mask.
                    mask_image = Image.fromarray(255 - np.array(mask_image))
                    image = image_batch[j, :, :, :]
                    original_image = Image.fromarray((image * 255).astype(np.uint8), mode = "RGB")
                    original_image = original_image.resize((256, 256))
                    output = self.pipe_repaint(
                        image=original_image,
                        mask_image=mask_image,
                        num_inference_steps=150,
                        eta=0.0,
                        jump_length=10,
                        jump_n_sample=10,
                        generator=generator,
                    )
                    image_inpaint = output.images[0]
                    image_inpaint = image_inpaint.resize((512, 512))
                    image_inpaint = np.array(image_inpaint) / 255.
                    mask_image = mask_image.resize((512, 512))
                    mask_image = np.array(mask_image) / 255.
                    mask_image = mask_image.astype(np.uint8)
                    # Mask is inverted here: keep original where mask==1.
                    image_fuse = image * mask_image + image_inpaint * (1 - mask_image)
                    forw_list.append(torch.from_numpy(image_fuse).permute(2, 0, 1))

                y_forw = torch.stack(forw_list, dim=0).float().cuda()

            if degrade_shuffle:
                # Random pick among three degradations (milder than training).
                import random
                choice = random.randint(0, 2)

                if choice == 0:
                    # Gaussian noise, sigma in [1/255, 4/255].
                    NL = float((np.random.randint(1,5))/255)
                    noise = np.random.normal(0, NL, y_forw.shape)
                    torchnoise = torch.from_numpy(noise).cuda().float()
                    y_forw = y_forw + torchnoise

                elif choice == 1:
                    # Fixed JPEG quality at test time.
                    NL = 90
                    self.DiffJPEG = DiffJPEG(differentiable=True, quality=int(NL)).cuda()
                    y_forw = self.DiffJPEG(y_forw)

                elif choice == 2:
                    # Poisson (shot) noise.
                    vals = 10**4
                    if random.random() < 0.5:
                        noisy_img_tensor = torch.poisson(y_forw * vals) / vals
                    else:
                        img_gray_tensor = torch.mean(y_forw, dim=0, keepdim=True)
                        noisy_gray_tensor = torch.poisson(img_gray_tensor * vals) / vals
                        noisy_img_tensor = y_forw + (noisy_gray_tensor - img_gray_tensor)

                    y_forw = torch.clamp(noisy_img_tensor, 0, 1)

            else:

                # Fixed degradation chosen by config flags (first match wins).
                if add_noise:
                    NL = self.opt['noisesigma'] / 255.0
                    noise = np.random.normal(0, NL, y_forw.shape)
                    torchnoise = torch.from_numpy(noise).cuda().float()
                    y_forw = y_forw + torchnoise

                elif add_jpeg:
                    Q = self.opt['jpegfactor']
                    self.DiffJPEG = DiffJPEG(differentiable=True, quality=int(Q)).cuda()
                    y_forw = self.DiffJPEG(y_forw)

                elif add_possion:
                    vals = 10**4
                    if random.random() < 0.5:
                        noisy_img_tensor = torch.poisson(y_forw * vals) / vals
                    else:
                        img_gray_tensor = torch.mean(y_forw, dim=0, keepdim=True)
                        noisy_gray_tensor = torch.poisson(img_gray_tensor * vals) / vals
                        noisy_img_tensor = y_forw + (noisy_gray_tensor - img_gray_tensor)

                    y_forw = torch.clamp(noisy_img_tensor, 0, 1)

            # Quantize only when we actually embedded (simulates 8-bit storage).
            if self.opt['hide']:
                y = self.Quantization(y_forw)
            else:
                y = y_forw

            if self.mode == "image":
                # Reverse pass: recover host approx, hidden image(s) and message.
                out_x, out_x_h, out_z, recmessage = self.netG(x=y, rev=True)
                out_x = iwt(out_x)

                out_x_h = [iwt(out_x_h_i) for out_x_h_i in out_x_h]
                out_x = out_x.reshape(-1, self.gop, 3, h, w)
                out_x_h = torch.stack(out_x_h, dim=1)
                out_x_h = out_x_h.reshape(-1, 1, self.gop, 3, h, w)

                forw_L.append(y_forw)
                fake_H.append(out_x[:, self.gop//2])
                fake_H_h.append(out_x_h[:,:, self.gop//2])
                recmsglist.append(recmessage)
                msglist.append(message)

            elif self.mode == "bit":
                # Reverse pass recovers only the bit message.
                recmessage = self.netG(x=y, rev=True)
                forw_L.append(y_forw)
                recmsglist.append(recmessage)
                msglist.append(message)

            if self.mode == "image":
                self.fake_H = torch.clamp(torch.stack(fake_H, dim=1),0,1)
                self.fake_H_h = torch.clamp(torch.stack(fake_H_h, dim=2),0,1)

            self.forw_L = torch.clamp(torch.stack(forw_L, dim=1),0,1)
            remesg = torch.clamp(torch.stack(recmsglist, dim=0),-0.5,0.5)

            if self.opt['hide']:
                mesg = torch.clamp(torch.stack(msglist, dim=0),-0.5,0.5)
            else:
                # Replayed messages are already 0/1 integers; no clamping needed.
                mesg = torch.stack(msglist, dim=0)

            # Binarize recovered and reference messages to 0/1 for accuracy stats.
            self.recmessage = remesg.clone()
            self.recmessage[remesg > 0] = 1
            self.recmessage[remesg <= 0] = 0

            self.message = mesg.clone()
            self.message[mesg > 0] = 1
            self.message[mesg <= 0] = 0

        self.netG.train()
|
|
|
|
|
def image_hiding(self, ): |
|
self.netG.eval() |
|
with torch.no_grad(): |
|
b, t, c, h, w = self.real_H.shape |
|
center = t // 2 |
|
intval = self.gop // 2 |
|
b, n, t, c, h, w = self.ref_L.shape |
|
id=0 |
|
|
|
self.host = self.real_H[:, center - intval+id:center + intval + 1+id] |
|
self.secret = self.ref_L[:, :, center - intval+id:center + intval + 1+id] |
|
self.secret = [dwt(self.secret[:,i].reshape(b, -1, h, w)) for i in range(n)] |
|
|
|
message = torch.Tensor(self.mes).to(self.device) |
|
|
|
self.output, container = self.netG(x=dwt(self.host.reshape(b, -1, h, w)), x_h=self.secret, message=message) |
|
y_forw = container |
|
|
|
result = torch.clamp(y_forw,0,1) |
|
|
|
lr_img = util.tensor2img(result) |
|
|
|
return lr_img |
|
|
|
    def image_recovery(self, number):
        """Recover the hidden content from the (possibly tampered) received image.

        Args:
            number: threshold on the absolute residual between the expected
                secret template and the recovered one; pixels whose residual
                exceeds it are flagged as tampered.

        Returns:
            (mask, remesg): channel-summed binary residual image marking
            tampered regions, and the recovered message binarized to 0/1.
        """
        self.netG.eval()
        with torch.no_grad():
            b, t, c, h, w = self.real_H.shape
            center = t // 2
            intval = self.gop // 2
            # shapes below use the reference tensor's unpacking (overwrites b/t/c/h/w)
            b, n, t, c, h, w = self.ref_L.shape
            id=0

            self.host = self.real_H[:, center - intval+id:center + intval + 1+id]
            self.secret = self.ref_L[:, :, center - intval+id:center + intval + 1+id]
            # Expected (clean) secret template to compare the recovery against.
            template = self.secret.reshape(b, -1, h, w)
            self.secret = [dwt(self.secret[:,i].reshape(b, -1, h, w)) for i in range(n)]

            # The received image itself is the container here (no re-hiding).
            self.output = self.host
            y_forw = self.output.squeeze(1)

            y = self.Quantization(y_forw)

            # Reverse pass: recover host approx, hidden image(s) and message.
            out_x, out_x_h, out_z, recmessage = self.netG(x=y, rev=True)
            out_x = iwt(out_x)

            out_x_h = [iwt(out_x_h_i) for out_x_h_i in out_x_h]
            out_x = out_x.reshape(-1, self.gop, 3, h, w)
            out_x_h = torch.stack(out_x_h, dim=1)
            out_x_h = out_x_h.reshape(-1, 1, self.gop, 3, h, w)

            # Recovered secret for the centre frame.
            rec_loc = out_x_h[:,:, self.gop//2]

            # Tamper localization: threshold the residual and collapse channels.
            residual = torch.abs(template - rec_loc)
            binary_residual = (residual > number).float()
            residual = util.tensor2img(binary_residual)
            mask = np.sum(residual, axis=2)

            # Binarize the recovered message to 0/1.
            remesg = torch.clamp(recmessage,-0.5,0.5)
            remesg[remesg > 0] = 1
            remesg[remesg <= 0] = 0

            return mask, remesg
|
|
|
    def get_current_log(self):
        """Return the dict of scalar losses recorded by optimize_parameters()."""
        return self.log_dict
|
|
|
def get_current_visuals(self): |
|
b, n, t, c, h, w = self.ref_L.shape |
|
center = t // 2 |
|
intval = self.gop // 2 |
|
out_dict = OrderedDict() |
|
LR_ref = self.ref_L[:, :, center - intval:center + intval + 1].detach()[0].float().cpu() |
|
LR_ref = torch.chunk(LR_ref, self.num_image, dim=0) |
|
out_dict['LR_ref'] = [image.squeeze(0) for image in LR_ref] |
|
|
|
if self.mode == "image": |
|
out_dict['SR'] = self.fake_H.detach()[0].float().cpu() |
|
SR_h = self.fake_H_h.detach()[0].float().cpu() |
|
SR_h = torch.chunk(SR_h, self.num_image, dim=0) |
|
out_dict['SR_h'] = [image.squeeze(0) for image in SR_h] |
|
|
|
out_dict['LR'] = self.forw_L.detach()[0].float().cpu() |
|
out_dict['GT'] = self.real_H[:, center - intval:center + intval + 1].detach()[0].float().cpu() |
|
out_dict['message'] = self.message |
|
out_dict['recmessage'] = self.recmessage |
|
|
|
return out_dict |
|
|
|
def print_network(self): |
|
s, n = self.get_network_description(self.netG) |
|
if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel): |
|
net_struc_str = '{} - {}'.format(self.netG.__class__.__name__, |
|
self.netG.module.__class__.__name__) |
|
else: |
|
net_struc_str = '{}'.format(self.netG.__class__.__name__) |
|
if self.rank <= 0: |
|
logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) |
|
logger.info(s) |
|
|
|
def load(self): |
|
load_path_G = self.opt['path']['pretrain_model_G'] |
|
if load_path_G is not None: |
|
logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) |
|
self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) |
|
|
|
def load_test(self,load_path_G): |
|
self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) |
|
|
|
    def save(self, iter_label):
        """Checkpoint the generator under the given iteration label."""
        self.save_network(self.netG, 'G', iter_label)
|
|