import logging |
from collections import OrderedDict |
import torch |
import torch.nn as nn |
from torch.nn.parallel import DataParallel, DistributedDataParallel |
import models.networks as networks |
import models.lr_scheduler as lr_scheduler |
from .base_model import BaseModel |
from models.modules.loss import ReconstructionLoss, ReconstructionMsgLoss |
from models.modules.Quantization import Quantization |
from .modules.common import DWT,IWT |
from utils.jpegtest import JpegTest |
from utils.JPEG import DiffJPEG |
import utils.util as util |
import numpy as np |
import random |
import cv2 |
import time |
# Module-level logger registered under the name 'base'.
logger = logging.getLogger('base')
# Single DWT / IWT instances (from .modules.common). In this file `dwt` is
# applied to host/secret tensors before they enter the generator and `iwt`
# is applied to the generator's reverse-pass outputs.
dwt=DWT()
iwt=IWT()
from diffusers import StableDiffusionInpaintPipeline |
from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDIMScheduler |
from diffusers import StableDiffusionXLInpaintPipeline |
from diffusers.utils import load_image |
from diffusers import RePaintPipeline, RePaintScheduler |
class Model_VSN(BaseModel):
    """Training / inference wrapper around the generator from ``networks.define_G_v2``.

    The generator is called in two directions:

    * forward  -- ``netG(x=..., x_h=..., message=...)`` returns ``(output, container)``;
    * reverse  -- ``netG(x=..., rev=True)`` returns the recovered tensors and/or the
      recovered message, depending on ``opt['mode']`` (``"image"`` or ``"bit"``).

    Between the two passes the container can be degraded (Gaussian noise,
    differentiable JPEG, Poisson noise) and, at test time, attacked by one of
    several diffusers in-painting pipelines.
    """

    # File holding one bit string per line; written by ``test`` when
    # ``opt['bitrecord']`` is set and read by ``__init__`` when ``opt['hide']``
    # is falsy.
    BIT_SEQUENCE_FILE = "bit_sequence.txt"
    # Directory prefix of the in-painting masks; files are named
    # ``<image_id + 1, zero padded to 4 digits>.png``.
    MASK_DIR = "../dataset/valAGE-Set-Mask/"

    def __init__(self, opt):
        """Build the generator, optional attack pipelines and training state.

        Args:
            opt: option mapping. Keys read here include ``dist``, ``gop``,
                ``train``, ``test``, ``network_G``, ``num_image``, ``mode``,
                ``hide``, ``sdinpaint``, ``controlnetinpaint``, ``sdxl`` and
                ``repaint``.
        """
        super(Model_VSN, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1

        self.gop = opt['gop']
        self.opt = opt
        self.train_opt = opt['train']
        self.test_opt = opt['test']
        self.opt_net = opt['network_G']
        self.center = self.gop // 2
        self.num_image = opt['num_image']
        self.mode = opt["mode"]
        self.idxx = 0
        # Created unconditionally so get_current_log() also works in
        # test-only runs.
        self.log_dict = OrderedDict()

        self.netG = networks.define_G_v2(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        self.print_network()
        self.load()

        self.Quantization = Quantization()

        if not self.opt['hide']:
            # Recovery-only run: reload the bit strings recorded earlier,
            # one list of ints per line.
            with open(self.BIT_SEQUENCE_FILE, "r") as file:
                self.msg_list = [[int(bit) for bit in line.strip()] for line in file]

        self._load_inpainting_pipelines()

        if self.is_train:
            self._setup_training()

    def _load_inpainting_pipelines(self):
        """Instantiate only the diffusers pipelines enabled in ``self.opt``."""
        if self.opt['sdinpaint']:
            self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
                "stabilityai/stable-diffusion-2-inpainting",
                torch_dtype=torch.float16,
            ).to("cuda")

        if self.opt['controlnetinpaint']:
            controlnet = ControlNetModel.from_pretrained(
                "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float32
            ).to("cuda")
            self.pipe_control = StableDiffusionControlNetInpaintPipeline.from_pretrained(
                "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float32
            ).to("cuda")

        if self.opt['sdxl']:
            self.pipe_sdxl = StableDiffusionXLInpaintPipeline.from_pretrained(
                "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
                torch_dtype=torch.float16,
                variant="fp16",
                use_safetensors=True,
            ).to("cuda")

        if self.opt['repaint']:
            self.scheduler = RePaintScheduler.from_pretrained("google/ddpm-ema-celebahq-256")
            self.pipe_repaint = RePaintPipeline.from_pretrained(
                "google/ddpm-ema-celebahq-256", scheduler=self.scheduler)
            self.pipe_repaint = self.pipe_repaint.to("cuda")

    def _setup_training(self):
        """Create losses, the Adam optimizer and the LR schedulers.

        Raises:
            NotImplementedError: if ``train_opt['lr_scheme']`` is neither
                ``'MultiStepLR'`` nor ``'CosineAnnealingLR_Restart'``.
        """
        train_opt = self.train_opt
        self.netG.train()

        self.Reconstruction_forw = ReconstructionLoss(losstype=train_opt['pixel_criterion_forw'])
        self.Reconstruction_back = ReconstructionLoss(losstype=train_opt['pixel_criterion_back'])
        self.Reconstruction_center = ReconstructionLoss(losstype="center")
        self.Reconstruction_msg = ReconstructionMsgLoss(losstype=self.opt['losstype'])

        wd_G = train_opt['weight_decay_G'] or 0

        # Only the sub-networks that belong to the active mode are optimised.
        trainable_prefixes = {
            "image": ('module.irn', 'module.pm'),
            "bit": ('module.bitencoder', 'module.bitdecoder'),
        }.get(self.mode)
        optim_params = []
        if trainable_prefixes is not None:
            for k, v in self.netG.named_parameters():
                if k.startswith(trainable_prefixes) and v.requires_grad:
                    optim_params.append(v)
                elif self.rank <= 0:
                    logger.warning('Params [{:s}] will not optimize.'.format(k))

        self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                            weight_decay=wd_G,
                                            betas=(train_opt['beta1'], train_opt['beta2']))
        self.optimizers.append(self.optimizer_G)

        if train_opt['lr_scheme'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
                                                     restarts=train_opt['restarts'],
                                                     weights=train_opt['restart_weights'],
                                                     gamma=train_opt['lr_gamma'],
                                                     clear_state=train_opt['clear_state']))
        elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.CosineAnnealingLR_Restart(
                        optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
                        restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
        else:
            raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

    def feed_data(self, data):
        """Store one batch: ``data['LQ']`` and ``data['GT']`` go to ``self.device``,
        ``data['MES']`` is kept as given."""
        self.ref_L = data['LQ'].to(self.device)
        self.real_H = data['GT'].to(self.device)
        self.mes = data['MES']

    def init_hidden_state(self, z):
        """Return zero-initialised ``(h_t, c_t, memory)`` shaped like ``z``.

        ``h_t`` and ``c_t`` are lists with ``opt_net['block_num_rbm']`` CUDA
        tensors each; ``memory`` is a single CUDA tensor. ``z`` must be 4-D.
        """
        b, c, h, w = z.shape
        shape = [b, c, h, w]
        num_blocks = self.opt_net['block_num_rbm']
        h_t = [torch.zeros(shape).cuda() for _ in range(num_blocks)]
        c_t = [torch.zeros(shape).cuda() for _ in range(num_blocks)]
        memory = torch.zeros(shape).cuda()
        return h_t, c_t, memory

    def loss_forward(self, out, y):
        """Forward (container fidelity) loss scaled by ``lambda_fit_forw``."""
        return self.train_opt['lambda_fit_forw'] * self.Reconstruction_forw(out, y)

    def loss_back_rec(self, out, x):
        """Backward (reconstruction) loss scaled by ``lambda_rec_back``."""
        return self.train_opt['lambda_rec_back'] * self.Reconstruction_back(out, x)

    def loss_back_rec_mul(self, out, x):
        """Same computation as :meth:`loss_back_rec`; kept as a separate name."""
        return self.loss_back_rec(out, x)

    # ------------------------------------------------------------------
    # Degradations applied to the container between hiding and recovery
    # ------------------------------------------------------------------
    @staticmethod
    def _add_gaussian_noise(y_forw, sigma):
        """Add zero-mean Gaussian noise (NumPy RNG) with std ``sigma``."""
        noise = np.random.normal(0, sigma, y_forw.shape)
        return y_forw + torch.from_numpy(noise).cuda().float()

    def _add_jpeg(self, y_forw, quality):
        """Pass ``y_forw`` through a fresh ``DiffJPEG`` of the given quality."""
        self.DiffJPEG = DiffJPEG(differentiable=True, quality=int(quality)).cuda()
        return self.DiffJPEG(y_forw)

    @staticmethod
    def _add_poisson_noise(y_forw):
        """Apply Poisson noise (per-element or shared 'gray' variant), clamp to [0, 1]."""
        vals = 10 ** 4
        if random.random() < 0.5:
            noisy = torch.poisson(y_forw * vals) / vals
        else:
            # NOTE(review): dim=0 averages over the first axis of y_forw, which
            # looks like the batch axis rather than the channel axis -- confirm
            # whether dim=1 was intended. Left unchanged to keep training
            # behaviour identical.
            gray = torch.mean(y_forw, dim=0, keepdim=True)
            noisy_gray = torch.poisson(gray * vals) / vals
            noisy = y_forw + (noisy_gray - gray)
        return torch.clamp(noisy, 0, 1)

    def _degrade_for_training(self, y_forw):
        """Apply the training-time degradation selected by ``self.opt``.

        With ``degrade_shuffle`` one of the three degradations is drawn
        uniformly; otherwise the first enabled flag among ``addnoise``,
        ``addjpeg``, ``addpossion`` wins. Returns ``y_forw`` unchanged when
        nothing is enabled.
        """
        add_noise = self.opt['addnoise']
        add_jpeg = self.opt['addjpeg']
        add_possion = self.opt['addpossion']

        if self.opt['degrade_shuffle']:
            choice = random.randint(0, 2)
            if choice == 0:
                return self._add_gaussian_noise(y_forw, float(np.random.randint(1, 16) / 255))
            if choice == 1:
                return self._add_jpeg(y_forw, int(np.random.randint(70, 95)))
            return self._add_poisson_noise(y_forw)

        if add_noise:
            return self._add_gaussian_noise(y_forw, float(np.random.randint(1, 16) / 255))
        if add_jpeg:
            return self._add_jpeg(y_forw, int(np.random.randint(70, 95)))
        if add_possion:
            return self._add_poisson_noise(y_forw)
        return y_forw

    def _degrade_for_testing(self, y_forw):
        """Apply the test-time degradation selected by ``self.opt``.

        Shuffle mode uses sigma in ``[1, 5) / 255`` and JPEG quality 90; the
        explicit flags use ``opt['noisesigma']`` and ``opt['jpegfactor']``.
        """
        add_noise = self.opt['addnoise']
        add_jpeg = self.opt['addjpeg']
        add_possion = self.opt['addpossion']

        if self.opt['degrade_shuffle']:
            choice = random.randint(0, 2)
            if choice == 0:
                return self._add_gaussian_noise(y_forw, float(np.random.randint(1, 5) / 255))
            if choice == 1:
                return self._add_jpeg(y_forw, 90)
            return self._add_poisson_noise(y_forw)

        if add_noise:
            return self._add_gaussian_noise(y_forw, self.opt['noisesigma'] / 255.0)
        if add_jpeg:
            return self._add_jpeg(y_forw, self.opt['jpegfactor'])
        if add_possion:
            return self._add_poisson_noise(y_forw)
        return y_forw

    def _clip_and_step(self):
        """Optionally clip gradients, then take one optimizer step."""
        if self.train_opt['gradient_clipping']:
            nn.utils.clip_grad_norm_(self.netG.parameters(), self.train_opt['gradient_clipping'])
        self.optimizer_G.step()

    def optimize_parameters(self, current_step):
        """Run one training iteration on the batch stored by :meth:`feed_data`.

        A random +/-0.5 message is embedded, the container is degraded and
        quantized, and the reverse pass is supervised according to
        ``self.mode``. ``current_step`` is accepted for interface
        compatibility and not used.
        """
        self.optimizer_G.zero_grad()
        b, _, t, _, h, w = self.ref_L.shape
        center = t // 2
        intval = self.gop // 2

        message = torch.Tensor(
            np.random.choice([-0.5, 0.5], (self.ref_L.shape[0], self.opt['message_length']))
        ).to(self.device)

        self.host = self.real_H[:, center - intval:center + intval + 1]
        self.secret = self.ref_L[:, :, center - intval:center + intval + 1]
        self.output, container = self.netG(
            x=dwt(self.host.reshape(b, -1, h, w)),
            x_h=dwt(self.secret[:, 0].reshape(b, -1, h, w)),
            message=message)

        # The fidelity loss is taken on the clean container, before degradation.
        l_forw_fit = self.loss_forward(container, self.host[:, 0])

        y_forw = self._degrade_for_training(container)
        y = self.Quantization(y_forw)
        all_zero = torch.zeros(message.shape).to(self.device)

        if self.mode == "image":
            out_x, out_x_h, out_z, recmessage = self.netG(x=y, message=all_zero, rev=True)
            out_x = iwt(out_x)
            out_x_h = torch.stack([iwt(item) for item in out_x_h], dim=1)

            l_back_rec = self.loss_back_rec(out_x, self.host[:, 0])
            l_center_x = self.loss_back_rec(out_x_h[:, 0], self.secret[:, 0].reshape(b, -1, h, w))

            # The message loss is only logged in this mode; it is not part of
            # the optimised objective.
            recmessage = torch.clamp(recmessage, -0.5, 0.5)
            l_msg = self.Reconstruction_msg(message, recmessage)

            loss = l_forw_fit*2 + l_back_rec + l_center_x*4
            loss.backward()

            if self.train_opt['lambda_center'] != 0:
                self.log_dict['l_center_x'] = l_center_x.item()
            self.log_dict['l_back_rec'] = l_back_rec.item()
            self.log_dict['l_forw_fit'] = l_forw_fit.item()
            self.log_dict['l_msg'] = l_msg.item()
            self.log_dict['l_h'] = (l_center_x*10).item()

            self._clip_and_step()

        elif self.mode == "bit":
            recmessage = self.netG(x=y, message=all_zero, rev=True)
            recmessage = torch.clamp(recmessage, -0.5, 0.5)
            l_msg = self.Reconstruction_msg(message, recmessage)

            lambda_msg = self.train_opt['lambda_msg']
            loss = l_msg * lambda_msg + l_forw_fit
            loss.backward()

            self.log_dict['l_forw_fit'] = l_forw_fit.item()
            self.log_dict['l_msg'] = l_msg.item()

            self._clip_and_step()

    # ------------------------------------------------------------------
    # Test-time in-painting attacks
    # ------------------------------------------------------------------
    def _mask_path(self, image_id):
        """Return the mask file path for ``image_id`` (files are 1-based)."""
        return self.MASK_DIR + str(image_id + 1).zfill(4) + ".png"

    @staticmethod
    def _to_pil_rgb(image):
        """Convert an HxWx3 float array to an 8-bit RGB PIL image.

        Values are clipped to [0, 1] first: the container is not clamped, and
        casting out-of-range floats to uint8 would wrap around.
        """
        from PIL import Image
        pixels = (np.clip(image, 0, 1) * 255).astype(np.uint8)
        return Image.fromarray(pixels, mode="RGB")

    @staticmethod
    def _to_hwc_batch(y_forw):
        """Detach a (B, C, H, W) tensor into a (B, H, W, C) NumPy array."""
        return y_forw.permute(0, 2, 3, 1).detach().cpu().numpy()

    @staticmethod
    def _stack_hwc_images(images):
        """Stack HxWxC arrays into a float (B, C, H, W) CUDA tensor."""
        tensors = [torch.from_numpy(img).permute(2, 0, 1) for img in images]
        return torch.stack(tensors, dim=0).float().cuda()

    def _attack_sdinpaint(self, y_forw, image_id):
        """In-paint the masked region with ``self.pipe`` and fuse it back."""
        from PIL import Image
        mask_image = Image.open(self._mask_path(image_id)).convert("L").resize((512, 512))
        width, height = mask_image.size
        # Flooring after /255 keeps only fully white mask pixels as 1.
        mask_bin = (np.stack([np.array(mask_image)] * 3, axis=-1) / 255.).astype(np.uint8)

        fused = []
        for image in self._to_hwc_batch(y_forw):
            inpainted = self.pipe(prompt="", image=self._to_pil_rgb(image), mask_image=mask_image,
                                  height=height, width=width).images[0]
            inpainted = np.array(inpainted) / 255.
            fused.append(image * (1 - mask_bin) + inpainted * mask_bin)
        return self._stack_hwc_images(fused)

    def _attack_controlnet(self, y_forw, image_id):
        """In-paint the masked region with ``self.pipe_control`` and fuse it back.

        Raises:
            ValueError: if the image and the resized mask differ in height or width.
        """
        mask_image = load_image(self._mask_path(image_id)).resize((512, 512))
        image_mask = np.array(mask_image.convert("L")).astype(np.float32) / 255.0
        mask_bin = np.stack([image_mask] * 3, axis=-1).astype(np.uint8)

        image_batch = self._to_hwc_batch(y_forw)
        # One generator shared by all images of the batch.
        generator = torch.Generator(device="cuda").manual_seed(1)

        fused = []
        for image in image_batch:
            if image.shape[0:2] != image_mask.shape[0:2]:
                raise ValueError("image and image_mask must have the same image size")
            init_pil = self._to_pil_rgb(image)

            # Mark masked pixels with -1 on a COPY: the original image is still
            # needed untouched for the fusion below.
            control = image.copy()
            control[image_mask > 0.5] = -1.0
            control_image = torch.from_numpy(np.expand_dims(control, 0).transpose(0, 3, 1, 2))

            inpainted = self.pipe_control(
                "",
                num_inference_steps=20,
                generator=generator,
                eta=1.0,
                image=init_pil,
                mask_image=image_mask,
                control_image=control_image,
            ).images[0]
            inpainted = np.array(inpainted) / 255.
            fused.append(image * (1 - mask_bin) + inpainted * mask_bin)
        return self._stack_hwc_images(fused)

    def _attack_sdxl(self, y_forw, image_id):
        """In-paint the masked region with ``self.pipe_sdxl`` and fuse it back."""
        mask_image = load_image(self._mask_path(image_id)).convert("RGB").resize((512, 512))
        mask_bin = (np.array(mask_image) / 255.).astype(np.uint8)

        fused = []
        for image in self._to_hwc_batch(y_forw):
            inpainted = self.pipe_sdxl(
                prompt="", image=self._to_pil_rgb(image), mask_image=mask_image,
                num_inference_steps=50, strength=0.80, target_size=(512, 512)
            ).images[0]
            inpainted = np.array(inpainted.resize((512, 512))) / 255.
            fused.append(image * (1 - mask_bin) + inpainted * mask_bin)
        return self._stack_hwc_images(fused)

    def _attack_repaint(self, y_forw, image_id):
        """In-paint with ``self.pipe_repaint`` at 256x256 and fuse at 512x512.

        The mask handed to the pipeline is inverted (``255 - mask``), so in the
        fusion a value of 1 selects the original image.
        """
        from PIL import Image
        mask_image = Image.open(self._mask_path(image_id)).convert("RGB").resize((256, 256))
        mask_image = Image.fromarray(255 - np.array(mask_image))
        keep_bin = (np.array(mask_image.resize((512, 512))) / 255.).astype(np.uint8)

        image_batch = self._to_hwc_batch(y_forw)
        # One generator shared by all images of the batch.
        generator = torch.Generator(device="cuda").manual_seed(0)

        fused = []
        for image in image_batch:
            original_image = self._to_pil_rgb(image).resize((256, 256))
            output = self.pipe_repaint(
                image=original_image,
                mask_image=mask_image,
                num_inference_steps=150,
                eta=0.0,
                jump_length=10,
                jump_n_sample=10,
                generator=generator,
            )
            inpainted = np.array(output.images[0].resize((512, 512))) / 255.
            fused.append(image * keep_bin + inpainted * (1 - keep_bin))
        return self._stack_hwc_images(fused)

    def test(self, image_id):
        """Hide (optional), attack/degrade, and recover for the stored batch.

        Results are stored on ``self``: ``forw_L``, ``message``, ``recmessage``
        and, in image mode, ``fake_H`` / ``fake_H_h``. The generator is put
        back into training mode before returning.

        Args:
            image_id: zero-based sample index; selects the mask file and, when
                ``opt['hide']`` is falsy, the row of ``self.msg_list``.
        """
        self.netG.eval()
        add_sdinpaint = self.opt['sdinpaint']
        add_controlnet = self.opt['controlnetinpaint']
        add_sdxl = self.opt['sdxl']
        add_repaint = self.opt['repaint']

        with torch.no_grad():
            # Kept as single-element lists so that stacking below adds the
            # same singleton axis as before.
            forw_L = []
            fake_H = []
            fake_H_h = []
            recmsglist = []
            msglist = []

            # `center` comes from real_H's temporal length; the remaining
            # dimensions are taken from ref_L.
            center = self.real_H.shape[1] // 2
            intval = self.gop // 2
            b, n, _, _, h, w = self.ref_L.shape

            self.host = self.real_H[:, center - intval:center + intval + 1]
            self.secret = self.ref_L[:, :, center - intval:center + intval + 1]
            self.secret = [dwt(self.secret[:, i].reshape(b, -1, h, w)) for i in range(n)]

            messagenp = np.random.choice([-0.5, 0.5], (self.ref_L.shape[0], self.opt['message_length']))
            message = torch.Tensor(messagenp).to(self.device)

            if self.opt['bitrecord']:
                # Append the embedded message as a 0/1 string.
                mymsg = message.clone()
                mymsg[mymsg > 0] = 1
                mymsg[mymsg < 0] = 0
                mymsg = mymsg.squeeze(0).to(torch.int)
                bit_string = ''.join(map(str, mymsg.tolist()))
                with open(self.BIT_SEQUENCE_FILE, "a") as file:
                    file.write(bit_string + "\n")

            if self.opt['hide']:
                self.output, container = self.netG(
                    x=dwt(self.host.reshape(b, -1, h, w)), x_h=self.secret, message=message)
                y_forw = container
            else:
                # Recovery only: the host is already a container and the
                # reference message comes from the recorded bit file.
                message = torch.tensor(self.msg_list[image_id]).unsqueeze(0).cuda()
                self.output = self.host
                y_forw = self.output.squeeze(1)

            if add_sdinpaint:
                y_forw = self._attack_sdinpaint(y_forw, image_id)
            if add_controlnet:
                y_forw = self._attack_controlnet(y_forw, image_id)
            if add_sdxl:
                y_forw = self._attack_sdxl(y_forw, image_id)
            if add_repaint:
                y_forw = self._attack_repaint(y_forw, image_id)

            y_forw = self._degrade_for_testing(y_forw)

            y = self.Quantization(y_forw) if self.opt['hide'] else y_forw

            if self.mode == "image":
                out_x, out_x_h, out_z, recmessage = self.netG(x=y, rev=True)
                out_x = iwt(out_x)
                out_x_h = [iwt(item) for item in out_x_h]

                out_x = out_x.reshape(-1, self.gop, 3, h, w)
                out_x_h = torch.stack(out_x_h, dim=1)
                out_x_h = out_x_h.reshape(-1, 1, self.gop, 3, h, w)

                forw_L.append(y_forw)
                fake_H.append(out_x[:, self.gop // 2])
                fake_H_h.append(out_x_h[:, :, self.gop // 2])
                recmsglist.append(recmessage)
                msglist.append(message)

            elif self.mode == "bit":
                recmessage = self.netG(x=y, rev=True)
                forw_L.append(y_forw)
                recmsglist.append(recmessage)
                msglist.append(message)

            if self.mode == "image":
                self.fake_H = torch.clamp(torch.stack(fake_H, dim=1), 0, 1)
                self.fake_H_h = torch.clamp(torch.stack(fake_H_h, dim=2), 0, 1)

            self.forw_L = torch.clamp(torch.stack(forw_L, dim=1), 0, 1)
            remesg = torch.clamp(torch.stack(recmsglist, dim=0), -0.5, 0.5)

            if self.opt['hide']:
                mesg = torch.clamp(torch.stack(msglist, dim=0), -0.5, 0.5)
            else:
                mesg = torch.stack(msglist, dim=0)

            # Binarise both messages: positive -> 1, otherwise -> 0.
            self.recmessage = remesg.clone()
            self.recmessage[remesg > 0] = 1
            self.recmessage[remesg <= 0] = 0

            self.message = mesg.clone()
            self.message[mesg > 0] = 1
            self.message[mesg <= 0] = 0

        self.netG.train()

    def image_hiding(self, ):
        """Embed ``self.mes`` and the secret into the host; return the container image.

        The container is clamped to [0, 1] and converted with
        ``util.tensor2img``.
        """
        self.netG.eval()
        with torch.no_grad():
            center = self.real_H.shape[1] // 2
            intval = self.gop // 2
            b, n, _, _, h, w = self.ref_L.shape

            self.host = self.real_H[:, center - intval:center + intval + 1]
            self.secret = self.ref_L[:, :, center - intval:center + intval + 1]
            self.secret = [dwt(self.secret[:, i].reshape(b, -1, h, w)) for i in range(n)]

            message = torch.Tensor(self.mes).to(self.device)

            self.output, container = self.netG(
                x=dwt(self.host.reshape(b, -1, h, w)), x_h=self.secret, message=message)
            result = torch.clamp(container, 0, 1)
            lr_img = util.tensor2img(result)

            return lr_img

    def image_recovery(self, number):
        """Recover the tamper mask and the embedded bits from the stored host.

        Args:
            number: threshold on the absolute difference between the secret
                template and the recovered secret.

        Returns:
            ``(mask, remesg)``: ``mask`` is the thresholded residual converted
            with ``util.tensor2img`` and summed over axis 2; ``remesg`` is the
            recovered message binarised to 0/1.
        """
        self.netG.eval()
        with torch.no_grad():
            center = self.real_H.shape[1] // 2
            intval = self.gop // 2
            b, n, _, _, h, w = self.ref_L.shape

            self.host = self.real_H[:, center - intval:center + intval + 1]
            self.secret = self.ref_L[:, :, center - intval:center + intval + 1]
            # Keep the un-transformed secret for the residual below.
            template = self.secret.reshape(b, -1, h, w)
            self.secret = [dwt(self.secret[:, i].reshape(b, -1, h, w)) for i in range(n)]

            self.output = self.host
            y_forw = self.output.squeeze(1)
            y = self.Quantization(y_forw)

            out_x, out_x_h, out_z, recmessage = self.netG(x=y, rev=True)
            out_x = iwt(out_x)
            out_x_h = [iwt(item) for item in out_x_h]

            out_x = out_x.reshape(-1, self.gop, 3, h, w)
            out_x_h = torch.stack(out_x_h, dim=1)
            out_x_h = out_x_h.reshape(-1, 1, self.gop, 3, h, w)

            rec_loc = out_x_h[:, :, self.gop // 2]
            residual = torch.abs(template - rec_loc)
            binary_residual = (residual > number).float()
            residual = util.tensor2img(binary_residual)
            mask = np.sum(residual, axis=2)

            remesg = torch.clamp(recmessage, -0.5, 0.5)
            remesg[remesg > 0] = 1
            remesg[remesg <= 0] = 0

            return mask, remesg

    def get_current_log(self):
        """Return the dict of the most recently logged loss values."""
        return self.log_dict

    def get_current_visuals(self):
        """Collect CPU copies of the first batch item for visualisation.

        Keys: ``LR_ref``, ``LR``, ``GT``, ``message``, ``recmessage`` and, in
        image mode, ``SR`` and ``SR_h``. Requires :meth:`test` to have run.
        """
        _, _, t, _, _, _ = self.ref_L.shape
        center = t // 2
        intval = self.gop // 2
        window = slice(center - intval, center + intval + 1)

        out_dict = OrderedDict()

        LR_ref = self.ref_L[:, :, window].detach()[0].float().cpu()
        out_dict['LR_ref'] = [image.squeeze(0) for image in torch.chunk(LR_ref, self.num_image, dim=0)]

        if self.mode == "image":
            out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
            SR_h = self.fake_H_h.detach()[0].float().cpu()
            out_dict['SR_h'] = [image.squeeze(0) for image in torch.chunk(SR_h, self.num_image, dim=0)]

        out_dict['LR'] = self.forw_L.detach()[0].float().cpu()
        out_dict['GT'] = self.real_H[:, window].detach()[0].float().cpu()
        out_dict['message'] = self.message
        out_dict['recmessage'] = self.recmessage

        return out_dict

    def print_network(self):
        """Log the generator's structure and parameter count (rank <= 0 only)."""
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, (nn.DataParallel, DistributedDataParallel)):
            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                             self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
            logger.info(s)

    def load(self):
        """Load generator weights from ``opt['path']['pretrain_model_G']`` if set."""
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])

    def load_test(self, load_path_G):
        """Load generator weights from an explicit checkpoint path."""
        self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])

    def save(self, iter_label):
        """Save the generator under the label ``'G'`` for ``iter_label``."""
        self.save_network(self.netG, 'G', iter_label)