This view is limited to 50 files because it contains too many changes.
- __pycache__/test_gradio.cpython-310.pyc +0 -0
- __pycache__/test_gradio.cpython-38.pyc +0 -0
- data/__init__.py +43 -0
- data/__pycache__/__init__.cpython-310.pyc +0 -0
- data/__pycache__/__init__.cpython-38.pyc +0 -0
- data/__pycache__/coco_dataset.cpython-38.pyc +0 -0
- data/__pycache__/coco_test_dataset.cpython-38.pyc +0 -0
- data/__pycache__/data_sampler.cpython-310.pyc +0 -0
- data/__pycache__/data_sampler.cpython-38.pyc +0 -0
- data/__pycache__/test_dataset_td.cpython-310.pyc +0 -0
- data/__pycache__/test_dataset_td.cpython-38.pyc +0 -0
- data/__pycache__/util.cpython-310.pyc +0 -0
- data/__pycache__/util.cpython-38.pyc +0 -0
- data/__pycache__/video_test_dataset.cpython-38.pyc +0 -0
- data/coco_dataset.py +90 -0
- data/coco_test_dataset.py +61 -0
- data/data_sampler.py +65 -0
- data/test_dataset_td.py +63 -0
- data/util.py +551 -0
- models/IBSN.py +738 -0
- models/__init__.py +11 -0
- models/__pycache__/IBSN.cpython-310.pyc +0 -0
- models/__pycache__/IBSN.cpython-38.pyc +0 -0
- models/__pycache__/__init__.cpython-310.pyc +0 -0
- models/__pycache__/__init__.cpython-38.pyc +0 -0
- models/__pycache__/base_model.cpython-38.pyc +0 -0
- models/__pycache__/lr_scheduler.cpython-38.pyc +0 -0
- models/__pycache__/networks.cpython-310.pyc +0 -0
- models/__pycache__/networks.cpython-38.pyc +0 -0
- models/base_model.py +119 -0
- models/bitnetwork/ConvBlock.py +38 -0
- models/bitnetwork/DW_EncoderDecoder.py +28 -0
- models/bitnetwork/Decoder_U.py +87 -0
- models/bitnetwork/Dual_Mark.py +249 -0
- models/bitnetwork/Encoder_U.py +125 -0
- models/bitnetwork/Random_Noise.py +59 -0
- models/bitnetwork/ResBlock.py +222 -0
- models/bitnetwork/__init__.py +9 -0
- models/bitnetwork/__pycache__/ConvBlock.cpython-38.pyc +0 -0
- models/bitnetwork/__pycache__/Decoder_U.cpython-38.pyc +0 -0
- models/bitnetwork/__pycache__/Encoder_U.cpython-38.pyc +0 -0
- models/bitnetwork/__pycache__/ResBlock.cpython-38.pyc +0 -0
- models/bitnetwork/__pycache__/__init__.cpython-38.pyc +0 -0
- models/discrim.py +169 -0
- models/lr_scheduler.py +142 -0
- models/modules/Inv_arch.py +584 -0
- models/modules/Quantization.py +21 -0
- models/modules/Subnet_constructor.py +79 -0
- models/modules/__init__.py +0 -0
- models/modules/__pycache__/Conv1x1.cpython-38.pyc +0 -0
__pycache__/test_gradio.cpython-310.pyc
ADDED
Binary file (2.85 kB)

__pycache__/test_gradio.cpython-38.pyc
ADDED
Binary file (2.83 kB)
data/__init__.py
ADDED
@@ -0,0 +1,43 @@
'''create dataset and dataloader'''
import logging
import torch
import torch.utils.data


def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
    phase = dataset_opt['phase']
    if phase == 'train':
        if opt['dist']:
            world_size = torch.distributed.get_world_size()
            num_workers = dataset_opt['n_workers']
            assert dataset_opt['batch_size'] % world_size == 0
            batch_size = dataset_opt['batch_size'] // world_size
            shuffle = False
        else:
            num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids'])
            batch_size = dataset_opt['batch_size']
            shuffle = True
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                                           num_workers=num_workers, sampler=sampler, drop_last=True,
                                           pin_memory=False)
    else:
        return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1,
                                           pin_memory=True)


def create_dataset(dataset_opt):
    mode = dataset_opt['mode']
    if mode == 'test':
        from data.coco_test_dataset import imageTestDataset as D
    elif mode == 'train':
        from data.coco_dataset import CoCoDataset as D
    elif mode == 'td':
        from data.test_dataset_td import imageTestDataset as D
    else:
        raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
    print(mode)
    dataset = D(dataset_opt)

    logger = logging.getLogger('base')
    logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__,
                                                           dataset_opt['name']))
    return dataset
data/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (1.42 kB)

data/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (1.45 kB)

data/__pycache__/coco_dataset.cpython-38.pyc
ADDED
Binary file (3.3 kB)

data/__pycache__/coco_test_dataset.cpython-38.pyc
ADDED
Binary file (2.3 kB)

data/__pycache__/data_sampler.cpython-310.pyc
ADDED
Binary file (2.6 kB)

data/__pycache__/data_sampler.cpython-38.pyc
ADDED
Binary file (2.63 kB)

data/__pycache__/test_dataset_td.cpython-310.pyc
ADDED
Binary file (2.32 kB)

data/__pycache__/test_dataset_td.cpython-38.pyc
ADDED
Binary file (2.29 kB)

data/__pycache__/util.cpython-310.pyc
ADDED
Binary file (13.7 kB)

data/__pycache__/util.cpython-38.pyc
ADDED
Binary file (14.2 kB)

data/__pycache__/video_test_dataset.cpython-38.pyc
ADDED
Binary file (2.48 kB)
data/coco_dataset.py
ADDED
@@ -0,0 +1,90 @@
'''
Vimeo90K dataset
support reading images from lmdb, image folder and memcached
'''
import logging
import os
import os.path as osp
import pickle
import random

import cv2
import lmdb
import numpy as np
import torch
import torch.utils.data as data

import data.util as util

try:
    import mc
except ImportError:
    pass
logger = logging.getLogger('base')

class CoCoDataset(data.Dataset):
    def __init__(self, opt):
        super(CoCoDataset, self).__init__()
        self.opt = opt
        # get train indexes
        self.data_path = self.opt['data_path']
        self.txt_path = self.opt['txt_path']
        with open(self.txt_path) as f:
            self.list_image = f.readlines()
        self.list_image = [line.strip('\n') for line in self.list_image]
        # temporal augmentation
        self.interval_list = opt['interval_list']
        self.random_reverse = opt['random_reverse']
        logger.info('Temporal augmentation interval list: [{}], with random reverse is {}.'.format(
            ','.join(str(x) for x in opt['interval_list']), self.random_reverse))
        self.data_type = self.opt['data_type']
        random.shuffle(self.list_image)
        self.LR_input = True
        self.num_image = self.opt['num_image']

    def _ensure_memcached(self):
        if self.mclient is None:
            # specify the config files
            server_list_config_file = None
            client_config_file = None
            self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file,
                                                          client_config_file)

    def __getitem__(self, index):
        GT_size = self.opt['GT_size']
        image_name = self.list_image[index]
        path_frame = os.path.join(self.data_path, image_name)
        img_GT = util.read_img(None, osp.join(path_frame, path_frame))
        index_h = random.randint(0, len(self.list_image) - 1)

        # random crop
        H, W, C = img_GT.shape
        rnd_h = random.randint(0, max(0, H - GT_size))
        rnd_w = random.randint(0, max(0, W - GT_size))
        img_frames = img_GT[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :]
        # BGR to RGB, HWC to CHW, numpy to tensor
        img_frames = img_frames[:, :, [2, 1, 0]]
        img_frames = torch.from_numpy(np.ascontiguousarray(np.transpose(img_frames, (2, 0, 1)))).float().unsqueeze(0)

        # process h_list
        if index_h % 100 == 0:
            path_frame_h = "../dataset/locwatermark/blue.png"
        else:
            image_name_h = self.list_image[index_h]
            path_frame_h = os.path.join(self.data_path, image_name_h)

        frame_h = util.read_img(None, osp.join(path_frame_h, path_frame_h))
        H1, W1, C1 = frame_h.shape
        rnd_h = random.randint(0, max(0, H1 - GT_size))
        rnd_w = random.randint(0, max(0, W1 - GT_size))
        img_frames_h = frame_h[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :]
        img_frames_h = img_frames_h[:, :, [2, 1, 0]]
        img_frames_h = torch.from_numpy(np.ascontiguousarray(np.transpose(img_frames_h, (2, 0, 1)))).float().unsqueeze(0)

        img_frames_h = torch.nn.functional.interpolate(img_frames_h, size=(512, 512), mode='nearest', align_corners=None).unsqueeze(0)
        img_frames = torch.nn.functional.interpolate(img_frames, size=(512, 512), mode='nearest', align_corners=None)

        return {'GT': img_frames, 'LQ': img_frames_h}

    def __len__(self):
        return len(self.list_image)
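
Editorial note, not part of the diff: tracing __getitem__ above, each sample is a pair of 512x512 tensors (the crop is resized with nearest-neighbour interpolation). A quick shape check under the hypothetical dataset_opt from the earlier sketch:

    ds = CoCoDataset(dataset_opt)   # dataset_opt as in the sketch above
    sample = ds[0]
    print(sample['GT'].shape)       # torch.Size([1, 3, 512, 512])    host image, RGB in [0, 1]
    print(sample['LQ'].shape)       # torch.Size([1, 1, 3, 512, 512]) secret/watermark image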
data/coco_test_dataset.py
ADDED
@@ -0,0 +1,61 @@
import os
import os.path as osp
import torch
import torch.utils.data as data
import data.util as util

import random
import numpy as np
from PIL import Image

class imageTestDataset(data.Dataset):

    def __init__(self, opt):
        super(imageTestDataset, self).__init__()
        self.opt = opt
        self.half_N_frames = opt['N_frames'] // 2
        self.data_path = opt['data_path']
        self.bit_path = opt['bit_path']
        self.txt_path = self.opt['txt_path']
        self.num_image = self.opt['num_image']
        with open(self.txt_path) as f:
            self.list_image = f.readlines()
        self.list_image = [line.strip('\n') for line in self.list_image]
        self.list_image.sort()
        self.list_image = self.list_image
        l = len(self.list_image) // (self.num_image + 1)
        self.image_list_gt = self.list_image

    def __getitem__(self, index):
        path_GT = self.image_list_gt[index]

        img_GT = util.read_img(None, osp.join(self.data_path, path_GT))
        img_GT = img_GT[:, :, [2, 1, 0]]
        img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float().unsqueeze(0)
        img_GT = torch.nn.functional.interpolate(img_GT, size=(512, 512), mode='nearest', align_corners=None)

        T, C, W, H = img_GT.shape
        list_h = []
        R = 0
        G = 0
        B = 255
        image = Image.new('RGB', (W, H), (R, G, B))
        result = np.array(image) / 255.
        expanded_matrix = np.expand_dims(result, axis=0)
        expanded_matrix = np.repeat(expanded_matrix, T, axis=0)
        imgs_LQ = torch.from_numpy(np.ascontiguousarray(expanded_matrix)).float()
        imgs_LQ = imgs_LQ.permute(0, 3, 1, 2)

        imgs_LQ = torch.nn.functional.interpolate(imgs_LQ, size=(W, H), mode='nearest', align_corners=None)

        list_h.append(imgs_LQ)

        list_h = torch.stack(list_h, dim=0)

        return {
            'LQ': list_h,
            'GT': img_GT
        }

    def __len__(self):
        return len(self.image_list_gt)
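
Worth noting (editorial): at test time the 'LQ' secret is not read from disk at all; __getitem__ builds a solid blue canvas (R=0, G=0, B=255) and repeats it to match the GT frame count. A two-line check of that construction, reusing the module's own imports:

    blue = np.array(Image.new('RGB', (512, 512), (0, 0, 255))) / 255.
    assert blue.shape == (512, 512, 3) and blue[..., 2].min() == 1.0  # all-blue HWC float image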
data/data_sampler.py
ADDED
@@ -0,0 +1,65 @@
"""
Modified from torch.utils.data.distributed.DistributedSampler
Support enlarging the dataset for *iter-oriented* training, for saving time when restart the
dataloader after each epoch
"""
import math
import torch
from torch.utils.data.sampler import Sampler
import torch.distributed as dist


class DistIterSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
    process can pass a DistributedSampler instance as a DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.

    .. note::
        Dataset is assumed to be of constant size.

    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, ratio=100):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * ratio / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        indices = torch.randperm(self.total_size, generator=g).tolist()

        dsize = len(self.dataset)
        indices = [v % dsize for v in indices]

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
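
A minimal usage sketch (editorial; assumes torch.distributed has already been initialized via init_process_group, and reuses the hypothetical train_set from the earlier sketch). Because DistIterSampler virtually enlarges the dataset by `ratio` and shuffles deterministically from the epoch seed, it pairs with a non-shuffling DataLoader:

    import torch
    import torch.distributed as dist
    from data.data_sampler import DistIterSampler

    world_size = dist.get_world_size()
    rank = dist.get_rank()
    sampler = DistIterSampler(train_set, world_size, rank, ratio=200)  # enlarge dataset 200x
    loader = torch.utils.data.DataLoader(train_set, batch_size=8, shuffle=False,
                                         sampler=sampler, num_workers=4, drop_last=True)
    sampler.set_epoch(0)  # reseed the deterministic shuffle when restarting an epoch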
data/test_dataset_td.py
ADDED
@@ -0,0 +1,63 @@
import os
import os.path as osp
import torch
import torch.utils.data as data
import data.util as util

import random
import numpy as np
from PIL import Image

class imageTestDataset(data.Dataset):

    def __init__(self, opt):
        super(imageTestDataset, self).__init__()
        self.opt = opt
        self.half_N_frames = opt['N_frames'] // 2
        self.data_path = opt['data_path']
        self.bit_path = opt['bit_path']
        self.txt_path = self.opt['txt_path']
        self.num_image = self.opt['num_image']
        with open(self.txt_path) as f:
            self.list_image = f.readlines()
        self.list_image = [line.strip('\n') for line in self.list_image]
        # self.list_image = sorted(self.list_image)
        l = len(self.list_image) // (self.num_image + 1)
        self.image_list_gt = self.list_image
        self.image_list_bit = self.list_image

    def __getitem__(self, index):
        path_GT = self.image_list_gt[index]

        img_GT = util.read_img(None, osp.join(self.data_path, path_GT))
        img_GT = img_GT[:, :, [2, 1, 0]]
        img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float().unsqueeze(0)
        img_GT = torch.nn.functional.interpolate(img_GT, size=(512, 512), mode='nearest', align_corners=None)

        T, C, W, H = img_GT.shape
        list_h = []
        R = 0
        G = 0
        B = 255
        image = Image.new('RGB', (W, H), (R, G, B))
        result = np.array(image) / 255.
        expanded_matrix = np.expand_dims(result, axis=0)
        expanded_matrix = np.repeat(expanded_matrix, T, axis=0)
        imgs_LQ = torch.from_numpy(np.ascontiguousarray(expanded_matrix)).float()
        imgs_LQ = imgs_LQ.permute(0, 3, 1, 2)

        imgs_LQ = torch.nn.functional.interpolate(imgs_LQ, size=(W, H), mode='nearest', align_corners=None)

        list_h.append(imgs_LQ)

        list_h = torch.stack(list_h, dim=0)

        return {
            'LQ': list_h,
            'GT': img_GT
        }

    def __len__(self):
        return len(self.image_list_gt)
data/util.py
ADDED
@@ -0,0 +1,551 @@
import os
import math
import pickle
import random
import numpy as np
import glob
import torch
import cv2

####################
# Files & IO
####################

###################### get image path list ######################
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP']


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def _get_paths_from_images(path):
    '''get image path list from image folder'''
    assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
    images = []
    for dirpath, _, fnames in sorted(os.walk(path)):
        for fname in sorted(fnames):
            if is_image_file(fname):
                img_path = os.path.join(dirpath, fname)
                images.append(img_path)
    assert images, '{:s} has no valid image file'.format(path)
    return images


def _get_paths_from_lmdb(dataroot):
    '''get image path list from lmdb meta info'''
    meta_info = pickle.load(open(os.path.join(dataroot, 'meta_info.pkl'), 'rb'))
    paths = meta_info['keys']
    sizes = meta_info['resolution']
    if len(sizes) == 1:
        sizes = sizes * len(paths)
    return paths, sizes


def get_image_paths(data_type, dataroot):
    '''get image path list
    support lmdb or image files'''
    paths, sizes = None, None
    if dataroot is not None:
        if data_type == 'lmdb':
            paths, sizes = _get_paths_from_lmdb(dataroot)
        elif data_type == 'img':
            paths = sorted(_get_paths_from_images(dataroot))
        else:
            raise NotImplementedError('data_type [{:s}] is not recognized.'.format(data_type))
    return paths, sizes


def glob_file_list(root):
    return sorted(glob.glob(os.path.join(root, '*')))


###################### read images ######################
def _read_img_lmdb(env, key, size):
    '''read image from lmdb with key (w/ and w/o fixed size)
    size: (C, H, W) tuple'''
    with env.begin(write=False) as txn:
        buf = txn.get(key.encode('ascii'))
    img_flat = np.frombuffer(buf, dtype=np.uint8)
    C, H, W = size
    img = img_flat.reshape(H, W, C)
    return img


def read_img(env, path, size=None):
    '''read image by cv2 or from lmdb
    return: Numpy float32, HWC, BGR, [0,1]'''
    if env is None:  # img
        # print(path)
        #img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        img = cv2.imread(path, cv2.IMREAD_COLOR)
    else:
        img = _read_img_lmdb(env, path, size)
    # print(img.shape)
    # if img is None:
    #     print(path)
    # print(img.shape)
    img = img.astype(np.float32) / 255.
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    # some images have 4 channels
    if img.shape[2] > 3:
        img = img[:, :, :3]
    return img


def read_img_seq(path):
    """Read a sequence of images from a given folder path
    Args:
        path (list/str): list of image paths/image folder path

    Returns:
        imgs (Tensor): size (T, C, H, W), RGB, [0, 1]
    """
    if type(path) is list:
        img_path_l = path
    else:
        img_path_l = sorted(glob.glob(os.path.join(path, '*.png')))
    # print(path)
    # print(path,img_path_l)
    img_l = [read_img(None, v) for v in img_path_l]
    # stack to Torch tensor
    imgs = np.stack(img_l, axis=0)
    imgs = imgs[:, :, :, [2, 1, 0]]
    imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float()
    return imgs


def index_generation(crt_i, max_n, N, padding='reflection'):
    """Generate an index list for reading N frames from a sequence of images
    Args:
        crt_i (int): current center index
        max_n (int): max number of the sequence of images (calculated from 1)
        N (int): reading N frames
        padding (str): padding mode, one of replicate | reflection | new_info | circle
            Example: crt_i = 0, N = 5
            replicate: [0, 0, 0, 1, 2]
            reflection: [2, 1, 0, 1, 2]
            new_info: [4, 3, 0, 1, 2]
            circle: [3, 4, 0, 1, 2]

    Returns:
        return_l (list [int]): a list of indexes
    """
    max_n = max_n - 1
    n_pad = N // 2
    return_l = []

    for i in range(crt_i - n_pad, crt_i + n_pad + 1):
        if i < 0:
            if padding == 'replicate':
                add_idx = 0
            elif padding == 'reflection':
                add_idx = -i
            elif padding == 'new_info':
                add_idx = (crt_i + n_pad) + (-i)
            elif padding == 'circle':
                add_idx = N + i
            else:
                raise ValueError('Wrong padding mode')
        elif i > max_n:
            if padding == 'replicate':
                add_idx = max_n
            elif padding == 'reflection':
                add_idx = max_n * 2 - i
            elif padding == 'new_info':
                add_idx = (crt_i - n_pad) - (i - max_n)
            elif padding == 'circle':
                add_idx = i - N
            else:
                raise ValueError('Wrong padding mode')
        else:
            add_idx = i
        return_l.append(add_idx)
    return return_l


####################
# image processing
# process on numpy image
####################


def augment(img_list, hflip=True, rot=True):
    # horizontal flip OR rotate
    hflip = hflip and random.random() < 0.5
    vflip = rot and random.random() < 0.5
    rot90 = rot and random.random() < 0.5

    def _augment(img):
        if hflip:
            img = img[:, ::-1, :]
        if vflip:
            img = img[::-1, :, :]
        if rot90:
            img = img.transpose(1, 0, 2)
        return img

    return [_augment(img) for img in img_list]


def augment_flow(img_list, flow_list, hflip=True, rot=True):
    # horizontal flip OR rotate
    hflip = hflip and random.random() < 0.5
    vflip = rot and random.random() < 0.5
    rot90 = rot and random.random() < 0.5

    def _augment(img):
        if hflip:
            img = img[:, ::-1, :]
        if vflip:
            img = img[::-1, :, :]
        if rot90:
            img = img.transpose(1, 0, 2)
        return img

    def _augment_flow(flow):
        if hflip:
            flow = flow[:, ::-1, :]
            flow[:, :, 0] *= -1
        if vflip:
            flow = flow[::-1, :, :]
            flow[:, :, 1] *= -1
        if rot90:
            flow = flow.transpose(1, 0, 2)
            flow = flow[:, :, [1, 0]]
        return flow

    rlt_img_list = [_augment(img) for img in img_list]
    rlt_flow_list = [_augment_flow(flow) for flow in flow_list]

    return rlt_img_list, rlt_flow_list


def channel_convert(in_c, tar_type, img_list):
    # conversion among BGR, gray and y
    if in_c == 3 and tar_type == 'gray':  # BGR to gray
        gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
        return [np.expand_dims(img, axis=2) for img in gray_list]
    elif in_c == 3 and tar_type == 'y':  # BGR to y
        y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
        return [np.expand_dims(img, axis=2) for img in y_list]
    elif in_c == 1 and tar_type == 'RGB':  # gray/y to BGR
        return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
    else:
        return img_list


def rgb2ycbcr(img, only_y=True):
    '''same as matlab rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    '''
    in_img_type = img.dtype
    img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    if only_y:
        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
                              [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)


def bgr2ycbcr(img, only_y=True):
    '''bgr version of rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    '''
    in_img_type = img.dtype
    img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                              [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)


def ycbcr2rgb(img):
    '''same as matlab ycbcr2rgb
    Input:
        uint8, [0, 255]
        float, [0, 1]
    '''
    in_img_type = img.dtype
    img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
                          [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)


def modcrop(img_in, scale):
    # img_in: Numpy, HWC or HW
    img = np.copy(img_in)
    if img.ndim == 2:
        H, W = img.shape
        H_r, W_r = H % scale, W % scale
        img = img[:H - H_r, :W - W_r]
    elif img.ndim == 3:
        H, W, C = img.shape
        H_r, W_r = H % scale, W % scale
        img = img[:H - H_r, :W - W_r, :]
    else:
        raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
    return img


####################
# Functions
####################


# matlab 'imresize' function, now only support 'bicubic'
def cubic(x):
    absx = torch.abs(x)
    absx2 = absx**2
    absx3 = absx**3
    return (1.5 * absx3 - 2.5 * absx2 + 1) * (
        (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * ((
            (absx > 1) * (absx <= 2)).type_as(absx))


def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
    if (scale < 1) and (antialiasing):
        # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
        kernel_width = kernel_width / scale

    # Output-space coordinates
    x = torch.linspace(1, out_length, out_length)

    # Input-space coordinates. Calculate the inverse mapping such that 0.5
    # in output space maps to 0.5 in input space, and 0.5+scale in output
    # space maps to 1.5 in input space.
    u = x / scale + 0.5 * (1 - 1 / scale)

    # What is the left-most pixel that can be involved in the computation?
    left = torch.floor(u - kernel_width / 2)

    # What is the maximum number of pixels that can be involved in the
    # computation? Note: it's OK to use an extra pixel here; if the
    # corresponding weights are all zero, it will be eliminated at the end
    # of this function.
    P = math.ceil(kernel_width) + 2

    # The indices of the input pixels involved in computing the k-th output
    # pixel are in row k of the indices matrix.
    indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
        1, P).expand(out_length, P)

    # The weights used to compute the k-th output pixel are in row k of the
    # weights matrix.
    distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
    # apply cubic kernel
    if (scale < 1) and (antialiasing):
        weights = scale * cubic(distance_to_center * scale)
    else:
        weights = cubic(distance_to_center)
    # Normalize the weights matrix so that each row sums to 1.
    weights_sum = torch.sum(weights, 1).view(out_length, 1)
    weights = weights / weights_sum.expand(out_length, P)

    # If a column in weights is all zero, get rid of it. only consider the first and last column.
    weights_zero_tmp = torch.sum((weights == 0), 0)
    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, P - 2)
        weights = weights.narrow(1, 1, P - 2)
    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, P - 2)
        weights = weights.narrow(1, 0, P - 2)
    weights = weights.contiguous()
    indices = indices.contiguous()
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)


def imresize(img, scale, antialiasing=True):
    # Now the scale should be the same for H and W
    # input: img: CHW RGB [0,1]
    # output: CHW RGB [0,1] w/o round

    in_C, in_H, in_W = img.size()
    _, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
    kernel_width = 4
    kernel = 'cubic'

    # Return the desired dimension order for performing the resize. The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)
    # process H dimension
    # symmetric copying
    img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
    img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)

    sym_patch = img[:, :sym_len_Hs, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)

    sym_patch = img[:, -sym_len_He:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)

    out_1 = torch.FloatTensor(in_C, out_H, in_W)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        out_1[0, i, :] = img_aug[0, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
        out_1[1, i, :] = img_aug[1, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
        out_1[2, i, :] = img_aug[2, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])

    # process W dimension
    # symmetric copying
    out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
    out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)

    sym_patch = out_1[:, :, :sym_len_Ws]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, :, -sym_len_We:]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)

    out_2 = torch.FloatTensor(in_C, out_H, out_W)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        out_2[0, :, i] = out_1_aug[0, :, idx:idx + kernel_width].mv(weights_W[i])
        out_2[1, :, i] = out_1_aug[1, :, idx:idx + kernel_width].mv(weights_W[i])
        out_2[2, :, i] = out_1_aug[2, :, idx:idx + kernel_width].mv(weights_W[i])

    return out_2


def imresize_np(img, scale, antialiasing=True):
    # Now the scale should be the same for H and W
    # input: img: Numpy, HWC BGR [0,1]
    # output: HWC BGR [0,1] w/o round
    img = torch.from_numpy(img)

    in_H, in_W, in_C = img.size()
    _, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
    kernel_width = 4
    kernel = 'cubic'

    # Return the desired dimension order for performing the resize. The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)
    # process H dimension
    # symmetric copying
    img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
    img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)

    sym_patch = img[:sym_len_Hs, :, :]
    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(0, inv_idx)
    img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)

    sym_patch = img[-sym_len_He:, :, :]
    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(0, inv_idx)
    img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)

    out_1 = torch.FloatTensor(out_H, in_W, in_C)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        out_1[i, :, 0] = img_aug[idx:idx + kernel_width, :, 0].transpose(0, 1).mv(weights_H[i])
        out_1[i, :, 1] = img_aug[idx:idx + kernel_width, :, 1].transpose(0, 1).mv(weights_H[i])
        out_1[i, :, 2] = img_aug[idx:idx + kernel_width, :, 2].transpose(0, 1).mv(weights_H[i])

    # process W dimension
    # symmetric copying
    out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
    out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)

    sym_patch = out_1[:, :sym_len_Ws, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, -sym_len_We:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)

    out_2 = torch.FloatTensor(out_H, out_W, in_C)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        out_2[:, i, 0] = out_1_aug[:, idx:idx + kernel_width, 0].mv(weights_W[i])
        out_2[:, i, 1] = out_1_aug[:, idx:idx + kernel_width, 1].mv(weights_W[i])
        out_2[:, i, 2] = out_1_aug[:, idx:idx + kernel_width, 2].mv(weights_W[i])

    return out_2.numpy()


if __name__ == '__main__':
    # test imresize function
    # read images
    img = cv2.imread('test.png')
    img = img * 1.0 / 255
    img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
    # imresize
    scale = 1 / 4
    import time
    total_time = 0
    for i in range(10):
        start_time = time.time()
        rlt = imresize(img, scale, antialiasing=True)
        use_time = time.time() - start_time
        total_time += use_time
    print('average time: {}'.format(total_time / 10))

    import torchvision.utils
    torchvision.utils.save_image((rlt * 255).round() / 255, 'rlt.png', nrow=1, padding=0,
                                 normalize=False)
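
A quick editorial check of index_generation against its own docstring (crt_i=0, N=5, a sequence of 100 frames); each padding mode handles the out-of-range left neighbours differently:

    from data.util import index_generation

    print(index_generation(0, 100, 5, padding='replicate'))   # [0, 0, 0, 1, 2]
    print(index_generation(0, 100, 5, padding='reflection'))  # [2, 1, 0, 1, 2]
    print(index_generation(0, 100, 5, padding='new_info'))    # [4, 3, 0, 1, 2]
    print(index_generation(0, 100, 5, padding='circle'))      # [3, 4, 0, 1, 2]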
models/IBSN.py
ADDED
@@ -0,0 +1,738 @@
import logging
from collections import OrderedDict

import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel, DistributedDataParallel

import models.networks as networks
import models.lr_scheduler as lr_scheduler
from .base_model import BaseModel
from models.modules.loss import ReconstructionLoss, ReconstructionMsgLoss
from models.modules.Quantization import Quantization
from .modules.common import DWT, IWT
from utils.jpegtest import JpegTest
from utils.JPEG import DiffJPEG
import utils.util as util


import numpy as np
import random
import cv2
import time

logger = logging.getLogger('base')
dwt = DWT()
iwt = IWT()

from diffusers import StableDiffusionInpaintPipeline
from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDIMScheduler
from diffusers import StableDiffusionXLInpaintPipeline
from diffusers.utils import load_image
from diffusers import RePaintPipeline, RePaintScheduler

class Model_VSN(BaseModel):
    def __init__(self, opt):
        super(Model_VSN, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training

        self.gop = opt['gop']
        train_opt = opt['train']
        test_opt = opt['test']
        self.opt = opt
        self.train_opt = train_opt
        self.test_opt = test_opt
        self.opt_net = opt['network_G']
        self.center = self.gop // 2
        self.num_image = opt['num_image']
        self.mode = opt["mode"]
        self.idxx = 0

        self.netG = networks.define_G_v2(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        self.Quantization = Quantization()

        if not self.opt['hide']:
            file_path = "bit_sequence.txt"

            data_list = []

            with open(file_path, "r") as file:
                for line in file:
                    data = [int(bit) for bit in line.strip()]
                    data_list.append(data)

            self.msg_list = data_list

        if self.opt['sdinpaint']:
            self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
                "stabilityai/stable-diffusion-2-inpainting",
                torch_dtype=torch.float16,
            ).to("cuda")

        if self.opt['controlnetinpaint']:
            controlnet = ControlNetModel.from_pretrained(
                "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float32
            ).to("cuda")
            self.pipe_control = StableDiffusionControlNetInpaintPipeline.from_pretrained(
                "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float32
            ).to("cuda")

        if self.opt['sdxl']:
            self.pipe_sdxl = StableDiffusionXLInpaintPipeline.from_pretrained(
                "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
                torch_dtype=torch.float16,
                variant="fp16",
                use_safetensors=True,
            ).to("cuda")

        if self.opt['repaint']:
            self.scheduler = RePaintScheduler.from_pretrained("google/ddpm-ema-celebahq-256")
            self.pipe_repaint = RePaintPipeline.from_pretrained("google/ddpm-ema-celebahq-256", scheduler=self.scheduler)
            self.pipe_repaint = self.pipe_repaint.to("cuda")

        if self.is_train:
            self.netG.train()

            # loss
            self.Reconstruction_forw = ReconstructionLoss(losstype=self.train_opt['pixel_criterion_forw'])
            self.Reconstruction_back = ReconstructionLoss(losstype=self.train_opt['pixel_criterion_back'])
            self.Reconstruction_center = ReconstructionLoss(losstype="center")
            self.Reconstruction_msg = ReconstructionMsgLoss(losstype=self.opt['losstype'])

            # optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            optim_params = []

            if self.mode == "image":
                for k, v in self.netG.named_parameters():
                    if (k.startswith('module.irn') or k.startswith('module.pm')) and v.requires_grad:
                        optim_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning('Params [{:s}] will not optimize.'.format(k))

            elif self.mode == "bit":
                for k, v in self.netG.named_parameters():
                    if (k.startswith('module.bitencoder') or k.startswith('module.bitdecoder')) and v.requires_grad:
                        optim_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning('Params [{:s}] will not optimize.'.format(k))

            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'], train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
                                                         restarts=train_opt['restarts'],
                                                         weights=train_opt['restart_weights'],
                                                         gamma=train_opt['lr_gamma'],
                                                         clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

    def feed_data(self, data):
        self.ref_L = data['LQ'].to(self.device)
        self.real_H = data['GT'].to(self.device)
        self.mes = data['MES']

    def init_hidden_state(self, z):
        b, c, h, w = z.shape
        h_t = []
        c_t = []
        for _ in range(self.opt_net['block_num_rbm']):
            h_t.append(torch.zeros([b, c, h, w]).cuda())
            c_t.append(torch.zeros([b, c, h, w]).cuda())
        memory = torch.zeros([b, c, h, w]).cuda()

        return h_t, c_t, memory

    def loss_forward(self, out, y):
        l_forw_fit = self.train_opt['lambda_fit_forw'] * self.Reconstruction_forw(out, y)
        return l_forw_fit

    def loss_back_rec(self, out, x):
        l_back_rec = self.train_opt['lambda_rec_back'] * self.Reconstruction_back(out, x)
        return l_back_rec

    def loss_back_rec_mul(self, out, x):
        l_back_rec = self.train_opt['lambda_rec_back'] * self.Reconstruction_back(out, x)
        return l_back_rec

    def optimize_parameters(self, current_step):
        self.optimizer_G.zero_grad()

        b, n, t, c, h, w = self.ref_L.shape
        center = t // 2
        intval = self.gop // 2

        message = torch.Tensor(np.random.choice([-0.5, 0.5], (self.ref_L.shape[0], self.opt['message_length']))).to(self.device)

        add_noise = self.opt['addnoise']
        add_jpeg = self.opt['addjpeg']
        add_possion = self.opt['addpossion']
        add_sdinpaint = self.opt['sdinpaint']
        degrade_shuffle = self.opt['degrade_shuffle']

        self.host = self.real_H[:, center - intval:center + intval + 1]
        self.secret = self.ref_L[:, :, center - intval:center + intval + 1]
        self.output, container = self.netG(x=dwt(self.host.reshape(b, -1, h, w)), x_h=dwt(self.secret[:, 0].reshape(b, -1, h, w)), message=message)

        Gt_ref = self.real_H[:, center - intval:center + intval + 1].detach()

        y_forw = container

        l_forw_fit = self.loss_forward(y_forw, self.host[:, 0])

        if degrade_shuffle:
            import random
            choice = random.randint(0, 2)

            if choice == 0:
                NL = float((np.random.randint(1, 16)) / 255)
                noise = np.random.normal(0, NL, y_forw.shape)
                torchnoise = torch.from_numpy(noise).cuda().float()
                y_forw = y_forw + torchnoise

            elif choice == 1:
                NL = int(np.random.randint(70, 95))
                self.DiffJPEG = DiffJPEG(differentiable=True, quality=int(NL)).cuda()
                y_forw = self.DiffJPEG(y_forw)

            elif choice == 2:
                vals = 10**4
                if random.random() < 0.5:
                    noisy_img_tensor = torch.poisson(y_forw * vals) / vals
                else:
                    img_gray_tensor = torch.mean(y_forw, dim=0, keepdim=True)
                    noisy_gray_tensor = torch.poisson(img_gray_tensor * vals) / vals
                    noisy_img_tensor = y_forw + (noisy_gray_tensor - img_gray_tensor)

                y_forw = torch.clamp(noisy_img_tensor, 0, 1)

        else:

            if add_noise:
                NL = float((np.random.randint(1, 16)) / 255)
                noise = np.random.normal(0, NL, y_forw.shape)
                torchnoise = torch.from_numpy(noise).cuda().float()
                y_forw = y_forw + torchnoise

            elif add_jpeg:
                NL = int(np.random.randint(70, 95))
                self.DiffJPEG = DiffJPEG(differentiable=True, quality=int(NL)).cuda()
                y_forw = self.DiffJPEG(y_forw)

            elif add_possion:
                vals = 10**4
                if random.random() < 0.5:
                    noisy_img_tensor = torch.poisson(y_forw * vals) / vals
                else:
                    img_gray_tensor = torch.mean(y_forw, dim=0, keepdim=True)
                    noisy_gray_tensor = torch.poisson(img_gray_tensor * vals) / vals
                    noisy_img_tensor = y_forw + (noisy_gray_tensor - img_gray_tensor)

                y_forw = torch.clamp(noisy_img_tensor, 0, 1)

        y = self.Quantization(y_forw)
        all_zero = torch.zeros(message.shape).to(self.device)

        if self.mode == "image":
            out_x, out_x_h, out_z, recmessage = self.netG(x=y, message=all_zero, rev=True)
            out_x = iwt(out_x)
            out_x_h = [iwt(out_x_h_i) for out_x_h_i in out_x_h]

            l_back_rec = self.loss_back_rec(out_x, self.host[:, 0])
            out_x_h = torch.stack(out_x_h, dim=1)

            l_center_x = self.loss_back_rec(out_x_h[:, 0], self.secret[:, 0].reshape(b, -1, h, w))

            recmessage = torch.clamp(recmessage, -0.5, 0.5)

            l_msg = self.Reconstruction_msg(message, recmessage)

            loss = l_forw_fit * 2 + l_back_rec + l_center_x * 4

            loss.backward()

            if self.train_opt['lambda_center'] != 0:
                self.log_dict['l_center_x'] = l_center_x.item()

            # set log
            self.log_dict['l_back_rec'] = l_back_rec.item()
            self.log_dict['l_forw_fit'] = l_forw_fit.item()
            self.log_dict['l_msg'] = l_msg.item()

            self.log_dict['l_h'] = (l_center_x * 10).item()

            # gradient clipping
            if self.train_opt['gradient_clipping']:
                nn.utils.clip_grad_norm_(self.netG.parameters(), self.train_opt['gradient_clipping'])

            self.optimizer_G.step()

        elif self.mode == "bit":
            recmessage = self.netG(x=y, message=all_zero, rev=True)

            recmessage = torch.clamp(recmessage, -0.5, 0.5)

            l_msg = self.Reconstruction_msg(message, recmessage)

            lambda_msg = self.train_opt['lambda_msg']

            loss = l_msg * lambda_msg + l_forw_fit

            loss.backward()

            # set log
            self.log_dict['l_forw_fit'] = l_forw_fit.item()
            self.log_dict['l_msg'] = l_msg.item()

            # gradient clipping
            if self.train_opt['gradient_clipping']:
                nn.utils.clip_grad_norm_(self.netG.parameters(), self.train_opt['gradient_clipping'])

            self.optimizer_G.step()

    def test(self, image_id):
        self.netG.eval()
        add_noise = self.opt['addnoise']
        add_jpeg = self.opt['addjpeg']
        add_possion = self.opt['addpossion']
        add_sdinpaint = self.opt['sdinpaint']
        add_controlnet = self.opt['controlnetinpaint']
        add_sdxl = self.opt['sdxl']
        add_repaint = self.opt['repaint']
        degrade_shuffle = self.opt['degrade_shuffle']

        with torch.no_grad():
            forw_L = []
            forw_L_h = []
            fake_H = []
            fake_H_h = []
            pred_z = []
            recmsglist = []
            msglist = []
            b, t, c, h, w = self.real_H.shape
            center = t // 2
            intval = self.gop // 2
            b, n, t, c, h, w = self.ref_L.shape
            id = 0
            # forward downscaling
            self.host = self.real_H[:, center - intval + id:center + intval + 1 + id]
            self.secret = self.ref_L[:, :, center - intval + id:center + intval + 1 + id]
            self.secret = [dwt(self.secret[:, i].reshape(b, -1, h, w)) for i in range(n)]

            messagenp = np.random.choice([-0.5, 0.5], (self.ref_L.shape[0], self.opt['message_length']))

            message = torch.Tensor(messagenp).to(self.device)

            if self.opt['bitrecord']:
                mymsg = message.clone()

                mymsg[mymsg > 0] = 1
                mymsg[mymsg < 0] = 0
                mymsg = mymsg.squeeze(0).to(torch.int)

                bit_list = mymsg.tolist()

                bit_string = ''.join(map(str, bit_list))

                file_name = "bit_sequence.txt"

                with open(file_name, "a") as file:
                    file.write(bit_string + "\n")

            if self.opt['hide']:
                self.output, container = self.netG(x=dwt(self.host.reshape(b, -1, h, w)), x_h=self.secret, message=message)
                y_forw = container
            else:
                message = torch.tensor(self.msg_list[image_id]).unsqueeze(0).cuda()
                self.output = self.host
                y_forw = self.output.squeeze(1)

            if add_sdinpaint:
                import random
                from PIL import Image
                prompt = ""

                b, _, _, _ = y_forw.shape

                image_batch = y_forw.permute(0, 2, 3, 1).detach().cpu().numpy()
                forw_list = []

                for j in range(b):
                    i = image_id + 1
                    masksrc = "../dataset/valAGE-Set-Mask/"
                    mask_image = Image.open(masksrc + str(i).zfill(4) + ".png").convert("L")
                    mask_image = mask_image.resize((512, 512))
                    h, w = mask_image.size

                    image = image_batch[j, :, :, :]
                    image_init = Image.fromarray((image * 255).astype(np.uint8), mode="RGB")
                    image_inpaint = self.pipe(prompt=prompt, image=image_init, mask_image=mask_image, height=w, width=h).images[0]
                    image_inpaint = np.array(image_inpaint) / 255.
                    mask_image = np.array(mask_image)
                    mask_image = np.stack([mask_image] * 3, axis=-1) / 255.
                    mask_image = mask_image.astype(np.uint8)
                    image_fuse = image * (1 - mask_image) + image_inpaint * mask_image
|
407 |
+
forw_list.append(torch.from_numpy(image_fuse).permute(2, 0, 1))
|
408 |
+
|
409 |
+
y_forw = torch.stack(forw_list, dim=0).float().cuda()
|
410 |
+
|
411 |
+
if add_controlnet:
|
412 |
+
from diffusers.utils import load_image
|
413 |
+
from PIL import Image
|
414 |
+
|
415 |
+
b, _, _, _ = y_forw.shape
|
416 |
+
forw_list = []
|
417 |
+
|
418 |
+
image_batch = y_forw.permute(0, 2, 3, 1).detach().cpu().numpy()
|
419 |
+
generator = torch.Generator(device="cuda").manual_seed(1)
|
420 |
+
|
421 |
+
for j in range(b):
|
422 |
+
i = image_id + 1
|
423 |
+
mask_path = "../dataset/valAGE-Set-Mask/" + str(i).zfill(4) + ".png"
|
424 |
+
mask_image = load_image(mask_path)
|
425 |
+
mask_image = mask_image.resize((512, 512))
|
426 |
+
image_init = image_batch[j, :, :, :]
|
427 |
+
image_init1 = Image.fromarray((image_init * 255).astype(np.uint8), mode = "RGB")
|
428 |
+
image_mask = np.array(mask_image.convert("L")).astype(np.float32) / 255.0
|
429 |
+
|
430 |
+
assert image_init.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size"
|
431 |
+
image_init[image_mask > 0.5] = -1.0 # set as masked pixel
|
432 |
+
image = np.expand_dims(image_init, 0).transpose(0, 3, 1, 2)
|
433 |
+
control_image = torch.from_numpy(image)
|
434 |
+
|
435 |
+
# generate image
|
436 |
+
image_inpaint = self.pipe_control(
|
437 |
+
"",
|
438 |
+
num_inference_steps=20,
|
439 |
+
generator=generator,
|
440 |
+
eta=1.0,
|
441 |
+
image=image_init1,
|
442 |
+
mask_image=image_mask,
|
443 |
+
control_image=control_image,
|
444 |
+
).images[0]
|
445 |
+
|
446 |
+
image_inpaint = np.array(image_inpaint) / 255.
|
447 |
+
image_mask = np.stack([image_mask] * 3, axis=-1)
|
448 |
+
image_mask = image_mask.astype(np.uint8)
|
449 |
+
image_fuse = image_init * (1 - image_mask) + image_inpaint * image_mask
|
450 |
+
forw_list.append(torch.from_numpy(image_fuse).permute(2, 0, 1))
|
451 |
+
|
452 |
+
y_forw = torch.stack(forw_list, dim=0).float().cuda()
|
453 |
+
|
454 |
+
if add_sdxl:
|
455 |
+
import random
|
456 |
+
from PIL import Image
|
457 |
+
from diffusers.utils import load_image
|
458 |
+
prompt = ""
|
459 |
+
|
460 |
+
b, _, _, _ = y_forw.shape
|
461 |
+
|
462 |
+
image_batch = y_forw.permute(0, 2, 3, 1).detach().cpu().numpy()
|
463 |
+
forw_list = []
|
464 |
+
|
465 |
+
for j in range(b):
|
466 |
+
i = image_id + 1
|
467 |
+
masksrc = "../dataset/valAGE-Set-Mask/"
|
468 |
+
mask_image = load_image(masksrc + str(i).zfill(4) + ".png").convert("RGB")
|
469 |
+
mask_image = mask_image.resize((512, 512))
|
470 |
+
h, w = mask_image.size
|
471 |
+
|
472 |
+
image = image_batch[j, :, :, :]
|
473 |
+
image_init = Image.fromarray((image * 255).astype(np.uint8), mode = "RGB")
|
474 |
+
image_inpaint = self.pipe_sdxl(
|
475 |
+
prompt=prompt, image=image_init, mask_image=mask_image, num_inference_steps=50, strength=0.80, target_size=(512, 512)
|
476 |
+
).images[0]
|
477 |
+
image_inpaint = image_inpaint.resize((512, 512))
|
478 |
+
image_inpaint = np.array(image_inpaint) / 255.
|
479 |
+
mask_image = np.array(mask_image) / 255.
|
480 |
+
mask_image = mask_image.astype(np.uint8)
|
481 |
+
image_fuse = image * (1 - mask_image) + image_inpaint * mask_image
|
482 |
+
forw_list.append(torch.from_numpy(image_fuse).permute(2, 0, 1))
|
483 |
+
|
484 |
+
y_forw = torch.stack(forw_list, dim=0).float().cuda()
|
485 |
+
|
486 |
+
|
487 |
+
if add_repaint:
|
488 |
+
from PIL import Image
|
489 |
+
|
490 |
+
b, _, _, _ = y_forw.shape
|
491 |
+
|
492 |
+
image_batch = y_forw.permute(0, 2, 3, 1).detach().cpu().numpy()
|
493 |
+
forw_list = []
|
494 |
+
|
495 |
+
generator = torch.Generator(device="cuda").manual_seed(0)
|
496 |
+
for j in range(b):
|
497 |
+
i = image_id + 1
|
498 |
+
masksrc = "../dataset/valAGE-Set-Mask/" + str(i).zfill(4) + ".png"
|
499 |
+
mask_image = Image.open(masksrc).convert("RGB")
|
500 |
+
mask_image = mask_image.resize((256, 256))
|
501 |
+
mask_image = Image.fromarray(255 - np.array(mask_image))
|
502 |
+
image = image_batch[j, :, :, :]
|
503 |
+
original_image = Image.fromarray((image * 255).astype(np.uint8), mode = "RGB")
|
504 |
+
original_image = original_image.resize((256, 256))
|
505 |
+
output = self.pipe_repaint(
|
506 |
+
image=original_image,
|
507 |
+
mask_image=mask_image,
|
508 |
+
num_inference_steps=150,
|
509 |
+
eta=0.0,
|
510 |
+
jump_length=10,
|
511 |
+
jump_n_sample=10,
|
512 |
+
generator=generator,
|
513 |
+
)
|
514 |
+
image_inpaint = output.images[0]
|
515 |
+
image_inpaint = image_inpaint.resize((512, 512))
|
516 |
+
image_inpaint = np.array(image_inpaint) / 255.
|
517 |
+
mask_image = mask_image.resize((512, 512))
|
518 |
+
mask_image = np.array(mask_image) / 255.
|
519 |
+
mask_image = mask_image.astype(np.uint8)
|
520 |
+
image_fuse = image * mask_image + image_inpaint * (1 - mask_image)
|
521 |
+
forw_list.append(torch.from_numpy(image_fuse).permute(2, 0, 1))
|
522 |
+
|
523 |
+
y_forw = torch.stack(forw_list, dim=0).float().cuda()
|
524 |
+
|
525 |
+
if degrade_shuffle:
|
526 |
+
import random
|
527 |
+
choice = random.randint(0, 2)
|
528 |
+
|
529 |
+
if choice == 0:
|
530 |
+
NL = float((np.random.randint(1,5))/255)
|
531 |
+
noise = np.random.normal(0, NL, y_forw.shape)
|
532 |
+
torchnoise = torch.from_numpy(noise).cuda().float()
|
533 |
+
y_forw = y_forw + torchnoise
|
534 |
+
|
535 |
+
elif choice == 1:
|
536 |
+
NL = 90
|
537 |
+
self.DiffJPEG = DiffJPEG(differentiable=True, quality=int(NL)).cuda()
|
538 |
+
y_forw = self.DiffJPEG(y_forw)
|
539 |
+
|
540 |
+
elif choice == 2:
|
541 |
+
vals = 10**4
|
542 |
+
if random.random() < 0.5:
|
543 |
+
noisy_img_tensor = torch.poisson(y_forw * vals) / vals
|
544 |
+
else:
|
545 |
+
img_gray_tensor = torch.mean(y_forw, dim=0, keepdim=True)
|
546 |
+
noisy_gray_tensor = torch.poisson(img_gray_tensor * vals) / vals
|
547 |
+
noisy_img_tensor = y_forw + (noisy_gray_tensor - img_gray_tensor)
|
548 |
+
|
549 |
+
y_forw = torch.clamp(noisy_img_tensor, 0, 1)
|
550 |
+
|
551 |
+
else:
|
552 |
+
|
553 |
+
if add_noise:
|
554 |
+
NL = self.opt['noisesigma'] / 255.0
|
555 |
+
noise = np.random.normal(0, NL, y_forw.shape)
|
556 |
+
torchnoise = torch.from_numpy(noise).cuda().float()
|
557 |
+
y_forw = y_forw + torchnoise
|
558 |
+
|
559 |
+
elif add_jpeg:
|
560 |
+
Q = self.opt['jpegfactor']
|
561 |
+
self.DiffJPEG = DiffJPEG(differentiable=True, quality=int(Q)).cuda()
|
562 |
+
y_forw = self.DiffJPEG(y_forw)
|
563 |
+
|
564 |
+
elif add_possion:
|
565 |
+
vals = 10**4
|
566 |
+
if random.random() < 0.5:
|
567 |
+
noisy_img_tensor = torch.poisson(y_forw * vals) / vals
|
568 |
+
else:
|
569 |
+
img_gray_tensor = torch.mean(y_forw, dim=0, keepdim=True)
|
570 |
+
noisy_gray_tensor = torch.poisson(img_gray_tensor * vals) / vals
|
571 |
+
noisy_img_tensor = y_forw + (noisy_gray_tensor - img_gray_tensor)
|
572 |
+
|
573 |
+
y_forw = torch.clamp(noisy_img_tensor, 0, 1)
|
574 |
+
|
575 |
+
# backward upscaling
|
576 |
+
if self.opt['hide']:
|
577 |
+
y = self.Quantization(y_forw)
|
578 |
+
else:
|
579 |
+
y = y_forw
|
580 |
+
|
581 |
+
if self.mode == "image":
|
582 |
+
out_x, out_x_h, out_z, recmessage = self.netG(x=y, rev=True)
|
583 |
+
out_x = iwt(out_x)
|
584 |
+
|
585 |
+
out_x_h = [iwt(out_x_h_i) for out_x_h_i in out_x_h]
|
586 |
+
out_x = out_x.reshape(-1, self.gop, 3, h, w)
|
587 |
+
out_x_h = torch.stack(out_x_h, dim=1)
|
588 |
+
out_x_h = out_x_h.reshape(-1, 1, self.gop, 3, h, w)
|
589 |
+
|
590 |
+
forw_L.append(y_forw)
|
591 |
+
fake_H.append(out_x[:, self.gop//2])
|
592 |
+
fake_H_h.append(out_x_h[:,:, self.gop//2])
|
593 |
+
recmsglist.append(recmessage)
|
594 |
+
msglist.append(message)
|
595 |
+
|
596 |
+
elif self.mode == "bit":
|
597 |
+
recmessage = self.netG(x=y, rev=True)
|
598 |
+
forw_L.append(y_forw)
|
599 |
+
recmsglist.append(recmessage)
|
600 |
+
msglist.append(message)
|
601 |
+
|
602 |
+
if self.mode == "image":
|
603 |
+
self.fake_H = torch.clamp(torch.stack(fake_H, dim=1),0,1)
|
604 |
+
self.fake_H_h = torch.clamp(torch.stack(fake_H_h, dim=2),0,1)
|
605 |
+
|
606 |
+
self.forw_L = torch.clamp(torch.stack(forw_L, dim=1),0,1)
|
607 |
+
remesg = torch.clamp(torch.stack(recmsglist, dim=0),-0.5,0.5)
|
608 |
+
|
609 |
+
if self.opt['hide']:
|
610 |
+
mesg = torch.clamp(torch.stack(msglist, dim=0),-0.5,0.5)
|
611 |
+
else:
|
612 |
+
mesg = torch.stack(msglist, dim=0)
|
613 |
+
|
614 |
+
self.recmessage = remesg.clone()
|
615 |
+
self.recmessage[remesg > 0] = 1
|
616 |
+
self.recmessage[remesg <= 0] = 0
|
617 |
+
|
618 |
+
self.message = mesg.clone()
|
619 |
+
self.message[mesg > 0] = 1
|
620 |
+
self.message[mesg <= 0] = 0
|
621 |
+
|
622 |
+
self.netG.train()
|
623 |
+
|
624 |
+
|
625 |
+
def image_hiding(self, ):
|
626 |
+
self.netG.eval()
|
627 |
+
with torch.no_grad():
|
628 |
+
b, t, c, h, w = self.real_H.shape
|
629 |
+
center = t // 2
|
630 |
+
intval = self.gop // 2
|
631 |
+
b, n, t, c, h, w = self.ref_L.shape
|
632 |
+
id=0
|
633 |
+
# forward downscaling
|
634 |
+
self.host = self.real_H[:, center - intval+id:center + intval + 1+id]
|
635 |
+
self.secret = self.ref_L[:, :, center - intval+id:center + intval + 1+id]
|
636 |
+
self.secret = [dwt(self.secret[:,i].reshape(b, -1, h, w)) for i in range(n)]
|
637 |
+
|
638 |
+
message = torch.Tensor(self.mes).to(self.device)
|
639 |
+
|
640 |
+
self.output, container = self.netG(x=dwt(self.host.reshape(b, -1, h, w)), x_h=self.secret, message=message)
|
641 |
+
y_forw = container
|
642 |
+
|
643 |
+
result = torch.clamp(y_forw,0,1)
|
644 |
+
|
645 |
+
lr_img = util.tensor2img(result)
|
646 |
+
|
647 |
+
return lr_img
|
648 |
+
|
649 |
+
def image_recovery(self, number):
|
650 |
+
self.netG.eval()
|
651 |
+
with torch.no_grad():
|
652 |
+
b, t, c, h, w = self.real_H.shape
|
653 |
+
center = t // 2
|
654 |
+
intval = self.gop // 2
|
655 |
+
b, n, t, c, h, w = self.ref_L.shape
|
656 |
+
id=0
|
657 |
+
# forward downscaling
|
658 |
+
self.host = self.real_H[:, center - intval+id:center + intval + 1+id]
|
659 |
+
self.secret = self.ref_L[:, :, center - intval+id:center + intval + 1+id]
|
660 |
+
template = self.secret.reshape(b, -1, h, w)
|
661 |
+
self.secret = [dwt(self.secret[:,i].reshape(b, -1, h, w)) for i in range(n)]
|
662 |
+
|
663 |
+
self.output = self.host
|
664 |
+
y_forw = self.output.squeeze(1)
|
665 |
+
|
666 |
+
y = self.Quantization(y_forw)
|
667 |
+
|
668 |
+
out_x, out_x_h, out_z, recmessage = self.netG(x=y, rev=True)
|
669 |
+
out_x = iwt(out_x)
|
670 |
+
|
671 |
+
out_x_h = [iwt(out_x_h_i) for out_x_h_i in out_x_h]
|
672 |
+
out_x = out_x.reshape(-1, self.gop, 3, h, w)
|
673 |
+
out_x_h = torch.stack(out_x_h, dim=1)
|
674 |
+
out_x_h = out_x_h.reshape(-1, 1, self.gop, 3, h, w)
|
675 |
+
|
676 |
+
rec_loc = out_x_h[:,:, self.gop//2]
|
677 |
+
# from PIL import Image
|
678 |
+
# tmp = util.tensor2img(rec_loc)
|
679 |
+
# save
|
680 |
+
residual = torch.abs(template - rec_loc)
|
681 |
+
binary_residual = (residual > number).float()
|
682 |
+
residual = util.tensor2img(binary_residual)
|
683 |
+
mask = np.sum(residual, axis=2)
|
684 |
+
# print(mask)
|
685 |
+
|
686 |
+
remesg = torch.clamp(recmessage,-0.5,0.5)
|
687 |
+
remesg[remesg > 0] = 1
|
688 |
+
remesg[remesg <= 0] = 0
|
689 |
+
|
690 |
+
return mask, remesg
|
691 |
+
|
692 |
+
def get_current_log(self):
|
693 |
+
return self.log_dict
|
694 |
+
|
695 |
+
def get_current_visuals(self):
|
696 |
+
b, n, t, c, h, w = self.ref_L.shape
|
697 |
+
center = t // 2
|
698 |
+
intval = self.gop // 2
|
699 |
+
out_dict = OrderedDict()
|
700 |
+
LR_ref = self.ref_L[:, :, center - intval:center + intval + 1].detach()[0].float().cpu()
|
701 |
+
LR_ref = torch.chunk(LR_ref, self.num_image, dim=0)
|
702 |
+
out_dict['LR_ref'] = [image.squeeze(0) for image in LR_ref]
|
703 |
+
|
704 |
+
if self.mode == "image":
|
705 |
+
out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
|
706 |
+
SR_h = self.fake_H_h.detach()[0].float().cpu()
|
707 |
+
SR_h = torch.chunk(SR_h, self.num_image, dim=0)
|
708 |
+
out_dict['SR_h'] = [image.squeeze(0) for image in SR_h]
|
709 |
+
|
710 |
+
out_dict['LR'] = self.forw_L.detach()[0].float().cpu()
|
711 |
+
out_dict['GT'] = self.real_H[:, center - intval:center + intval + 1].detach()[0].float().cpu()
|
712 |
+
out_dict['message'] = self.message
|
713 |
+
out_dict['recmessage'] = self.recmessage
|
714 |
+
|
715 |
+
return out_dict
|
716 |
+
|
717 |
+
def print_network(self):
|
718 |
+
s, n = self.get_network_description(self.netG)
|
719 |
+
if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
|
720 |
+
net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
|
721 |
+
self.netG.module.__class__.__name__)
|
722 |
+
else:
|
723 |
+
net_struc_str = '{}'.format(self.netG.__class__.__name__)
|
724 |
+
if self.rank <= 0:
|
725 |
+
logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
|
726 |
+
logger.info(s)
|
727 |
+
|
728 |
+
def load(self):
|
729 |
+
load_path_G = self.opt['path']['pretrain_model_G']
|
730 |
+
if load_path_G is not None:
|
731 |
+
logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
|
732 |
+
self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
|
733 |
+
|
734 |
+
def load_test(self,load_path_G):
|
735 |
+
self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
|
736 |
+
|
737 |
+
def save(self, iter_label):
|
738 |
+
self.save_network(self.netG, 'G', iter_label)
|
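Editor's note (a minimal sketch, not repo code): the message convention used throughout Model_VSN above is that bits are embedded as values in {-0.5, +0.5}, the recovered message is clamped to [-0.5, 0.5], and bits are read out by thresholding at zero, exactly what test() does before comparing self.message with self.recmessage. The helper names below (bits_from_message, bit_accuracy) are hypothetical, for illustration only.

import numpy as np
import torch

def bits_from_message(msg: torch.Tensor) -> torch.Tensor:
    # map a {-0.5, +0.5}-valued (or clamped real-valued) message to {0, 1} bits
    return (msg > 0).int()

def bit_accuracy(message: torch.Tensor, recmessage: torch.Tensor) -> float:
    # fraction of positions where the recovered bits match the embedded bits
    bits = bits_from_message(message)
    recbits = bits_from_message(torch.clamp(recmessage, -0.5, 0.5))
    return (bits == recbits).float().mean().item()

# example: a 64-bit message and a mildly noisy recovery
msg = torch.from_numpy(np.random.choice([-0.5, 0.5], (1, 64))).float()
rec = msg + 0.1 * torch.randn_like(msg)
print(bit_accuracy(msg, rec))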
models/__init__.py
ADDED
@@ -0,0 +1,11 @@
import logging
logger = logging.getLogger('base')

def create_model(opt):
    model = opt['model']
    frame_num = opt['gop']
    from .IBSN import Model_VSN as M

    m = M(opt)
    logger.info('Model [{:s}] is created.'.format(m.__class__.__name__))
    return m
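A hedged usage sketch of the factory above: create_model only reads opt['model'] and opt['gop'] before instantiating Model_VSN, but the constructed model expects the full options dict (presumably parsed from the repo's option files). Every key below other than 'model' and 'gop' is a placeholder, not a confirmed key.

# minimal illustration only
opt = {
    'model': 'IBSN',  # read by the factory (used for logging)
    'gop': 1,         # frame group size, also read by the factory
    # ... remaining keys required by Model_VSN go here ...
}
# from models import create_model
# model = create_model(opt)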
models/__pycache__/IBSN.cpython-310.pyc
ADDED
Binary file (18.5 kB). View file
models/__pycache__/IBSN.cpython-38.pyc
ADDED
Binary file (18.7 kB). View file
models/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (491 Bytes). View file
models/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (504 Bytes). View file
models/__pycache__/base_model.cpython-38.pyc
ADDED
Binary file (5.44 kB). View file
models/__pycache__/lr_scheduler.cpython-38.pyc
ADDED
Binary file (4.98 kB). View file
models/__pycache__/networks.cpython-310.pyc
ADDED
Binary file (731 Bytes). View file
models/__pycache__/networks.cpython-38.pyc
ADDED
Binary file (740 Bytes). View file
models/base_model.py
ADDED
@@ -0,0 +1,119 @@
import os
from collections import OrderedDict
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel


class BaseModel():
    def __init__(self, opt):
        self.opt = opt
        self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
        self.is_train = opt['is_train']
        self.schedulers = []
        self.optimizers = []

    def feed_data(self, data):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        pass

    def get_current_losses(self):
        pass

    def print_network(self):
        pass

    def save(self, label):
        pass

    def load(self):
        pass

    def _set_lr(self, lr_groups_l):
        '''Set the learning rate for warm-up.
        lr_groups_l: list of lr groups, one per optimizer.'''
        for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
            for param_group, lr in zip(optimizer.param_groups, lr_groups):
                param_group['lr'] = lr

    def _get_init_lr(self):
        # get the initial lr, which is set by the scheduler
        init_lr_groups_l = []
        for optimizer in self.optimizers:
            init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
        return init_lr_groups_l

    def update_learning_rate(self, cur_iter, warmup_iter=-1):
        for scheduler in self.schedulers:
            scheduler.step()
        #### set up warm-up learning rate
        if cur_iter < warmup_iter:
            # get initial lr for each group
            init_lr_g_l = self._get_init_lr()
            # modify warming-up learning rates
            warm_up_lr_l = []
            for init_lr_g in init_lr_g_l:
                warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g])
            # set learning rate
            self._set_lr(warm_up_lr_l)

    def get_current_learning_rate(self):
        # return self.schedulers[0].get_lr()[0]
        return self.optimizers[0].param_groups[0]['lr']

    def get_network_description(self, network):
        '''Get the string representation and total parameter count of the network'''
        if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
            network = network.module
        s = str(network)
        n = sum(map(lambda x: x.numel(), network.parameters()))
        return s, n

    def save_network(self, network, network_label, iter_label):
        save_filename = '{}_{}.pth'.format(iter_label, network_label)
        save_path = os.path.join(self.opt['path']['models'], save_filename)
        if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
            network = network.module
        state_dict = network.state_dict()
        for key, param in state_dict.items():
            state_dict[key] = param.cpu()
        torch.save(state_dict, save_path)

    def load_network(self, load_path, network, strict=True):
        if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
            network = network.module
        load_net = torch.load(load_path)
        load_net_clean = OrderedDict()  # remove unnecessary 'module.' prefix
        for k, v in load_net.items():
            if k.startswith('module.'):
                load_net_clean[k[7:]] = v
            else:
                load_net_clean[k] = v
        network.load_state_dict(load_net_clean, strict=strict)

    def save_training_state(self, epoch, iter_step):
        '''Saves the training state during training, which will be used for resuming'''
        state = {'epoch': epoch, 'iter': iter_step, 'schedulers': [], 'optimizers': []}
        for s in self.schedulers:
            state['schedulers'].append(s.state_dict())
        for o in self.optimizers:
            state['optimizers'].append(o.state_dict())
        save_filename = '{}.state'.format(iter_step)
        save_path = os.path.join(self.opt['path']['training_state'], save_filename)
        torch.save(state, save_path)

    def resume_training(self, resume_state):
        '''Resume the optimizers and schedulers for training'''
        resume_optimizers = resume_state['optimizers']
        resume_schedulers = resume_state['schedulers']
        assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
        assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
        for i, o in enumerate(resume_optimizers):
            self.optimizers[i].load_state_dict(o)
        for i, s in enumerate(resume_schedulers):
            self.schedulers[i].load_state_dict(s)
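A quick numeric check (editor's illustration, not repo code) of the warm-up rule in update_learning_rate above: while cur_iter < warmup_iter, each param group's lr is initial_lr / warmup_iter * cur_iter, a linear ramp from 0 up to the scheduler's initial value.

initial_lr = 1e-4
warmup_iter = 5000
for cur_iter in (0, 1000, 2500, 5000):
    # mirrors the warm-up branch; past warmup_iter the scheduler's own lr applies
    lr = initial_lr / warmup_iter * cur_iter if cur_iter < warmup_iter else initial_lr
    print(cur_iter, lr)  # 0.0, 2e-05, 5e-05, 0.0001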
models/bitnetwork/ConvBlock.py
ADDED
@@ -0,0 +1,38 @@
import torch.nn as nn


class ConvINRelu(nn.Module):
    """
    A sequence of Convolution, Instance Normalization, and ReLU activation
    """

    def __init__(self, channels_in, channels_out, stride):
        super(ConvINRelu, self).__init__()

        self.layers = nn.Sequential(
            nn.Conv2d(channels_in, channels_out, 3, stride, padding=1),
            nn.InstanceNorm2d(channels_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.layers(x)


class ConvBlock(nn.Module):
    '''
    Network composed of stacked ConvINRelu layers
    '''

    def __init__(self, in_channels, out_channels, blocks=1, stride=1):
        super(ConvBlock, self).__init__()

        layers = [ConvINRelu(in_channels, out_channels, stride)] if blocks != 0 else []
        for _ in range(blocks - 1):
            layer = ConvINRelu(out_channels, out_channels, 1)
            layers.append(layer)

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
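Shape sketch (illustrative only, assuming the classes above are in scope): ConvBlock preserves spatial size at stride=1 and halves it at stride=2, which is how the Down modules later in this diff use it.

import torch
x = torch.randn(1, 3, 64, 64)
print(ConvBlock(3, 16, blocks=2, stride=1)(x).shape)  # torch.Size([1, 16, 64, 64])
print(ConvBlock(3, 16, stride=2)(x).shape)            # torch.Size([1, 16, 32, 32])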
models/bitnetwork/DW_EncoderDecoder.py
ADDED
@@ -0,0 +1,28 @@
from . import *
from .Encoder_U import DW_Encoder
from .Decoder_U import DW_Decoder
from .Noise import Noise
from .Random_Noise import Random_Noise


class DW_EncoderDecoder(nn.Module):
    '''
    A sequential pipeline of Encoder, Noise, and Decoder
    '''

    def __init__(self, message_length, noise_layers_R, noise_layers_F, attention_encoder, attention_decoder):
        super(DW_EncoderDecoder, self).__init__()
        self.encoder = DW_Encoder(message_length, attention=attention_encoder)
        self.noise = Random_Noise(noise_layers_R + noise_layers_F, len(noise_layers_R), len(noise_layers_F))
        self.decoder_C = DW_Decoder(message_length, attention=attention_decoder)
        self.decoder_RF = DW_Decoder(message_length, attention=attention_decoder)

    def forward(self, image, message, mask):
        encoded_image = self.encoder(image, message)
        noised_image_C, noised_image_R, noised_image_F = self.noise([encoded_image, image, mask])
        decoded_message_C = self.decoder_C(noised_image_C)
        decoded_message_R = self.decoder_RF(noised_image_R)
        decoded_message_F = self.decoder_RF(noised_image_F)
        return encoded_image, noised_image_C, decoded_message_C, decoded_message_R, decoded_message_F
models/bitnetwork/Decoder_U.py
ADDED
@@ -0,0 +1,87 @@
from . import *


class DW_Decoder(nn.Module):

    def __init__(self, message_length, blocks=2, channels=64, attention=None):
        super(DW_Decoder, self).__init__()

        self.conv1 = ConvBlock(3, 16, blocks=blocks)
        self.down1 = Down(16, 32, blocks=blocks)
        self.down2 = Down(32, 64, blocks=blocks)
        self.down3 = Down(64, 128, blocks=blocks)

        self.down4 = Down(128, 256, blocks=blocks)

        self.up3 = UP(256, 128)
        self.att3 = ResBlock(128 * 2, 128, blocks=blocks, attention=attention)

        self.up2 = UP(128, 64)
        self.att2 = ResBlock(64 * 2, 64, blocks=blocks, attention=attention)

        self.up1 = UP(64, 32)
        self.att1 = ResBlock(32 * 2, 32, blocks=blocks, attention=attention)

        self.up0 = UP(32, 16)
        self.att0 = ResBlock(16 * 2, 16, blocks=blocks, attention=attention)

        self.Conv_1x1 = nn.Conv2d(16, 1, kernel_size=1, stride=1, padding=0, bias=False)

        self.message_layer = nn.Linear(message_length * message_length, message_length)
        self.message_length = message_length

    def forward(self, x):
        d0 = self.conv1(x)
        d1 = self.down1(d0)
        d2 = self.down2(d1)
        d3 = self.down3(d2)

        d4 = self.down4(d3)

        u3 = self.up3(d4)
        u3 = torch.cat((d3, u3), dim=1)
        u3 = self.att3(u3)

        u2 = self.up2(u3)
        u2 = torch.cat((d2, u2), dim=1)
        u2 = self.att2(u2)

        u1 = self.up1(u2)
        u1 = torch.cat((d1, u1), dim=1)
        u1 = self.att1(u1)

        u0 = self.up0(u1)
        u0 = torch.cat((d0, u0), dim=1)
        u0 = self.att0(u0)

        residual = self.Conv_1x1(u0)

        message = F.interpolate(residual, size=(self.message_length, self.message_length),
                                mode='nearest')
        message = message.view(message.shape[0], -1)
        message = self.message_layer(message)

        return message


class Down(nn.Module):
    def __init__(self, in_channels, out_channels, blocks):
        super(Down, self).__init__()
        self.layer = torch.nn.Sequential(
            ConvBlock(in_channels, in_channels, stride=2),
            ConvBlock(in_channels, out_channels, blocks=blocks)
        )

    def forward(self, x):
        return self.layer(x)


class UP(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UP, self).__init__()
        self.conv = ConvBlock(in_channels, out_channels)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(x)
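Shape walk-through (editor's sketch, run inside the package so ConvBlock, ResBlock and F resolve): with message_length=64 and a 128x128 input, the 1-channel residual is resized to 64x64, flattened to 64*64 = 4096 values, and projected by message_layer to a 64-dim message.

import torch
decoder = DW_Decoder(message_length=64)
img = torch.randn(2, 3, 128, 128)
print(decoder(img).shape)  # torch.Size([2, 64])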
models/bitnetwork/Dual_Mark.py
ADDED
@@ -0,0 +1,249 @@
from .DW_EncoderDecoder import *
from .Patch_Discriminator import Patch_Discriminator
import torch
import kornia.losses
import lpips


class Network:

    def __init__(self, message_length, noise_layers_R, noise_layers_F, device, batch_size, lr, beta1, attention_encoder, attention_decoder, weight):
        # device
        self.device = device

        # loss functions
        self.criterion_MSE = nn.MSELoss().to(device)
        self.criterion_LPIPS = lpips.LPIPS().to(device)

        # weights of the encoder-decoder losses
        self.encoder_weight = weight[0]
        self.decoder_weight_C = weight[1]
        self.decoder_weight_R = weight[2]
        self.decoder_weight_F = weight[3]
        self.discriminator_weight = weight[4]

        # networks
        self.encoder_decoder = DW_EncoderDecoder(message_length, noise_layers_R, noise_layers_F, attention_encoder, attention_decoder).to(device)
        self.discriminator = Patch_Discriminator().to(device)

        self.encoder_decoder = torch.nn.DataParallel(self.encoder_decoder)
        self.discriminator = torch.nn.DataParallel(self.discriminator)

        # mark "cover" as 1, "encoded" as -1
        self.label_cover = 1.0
        self.label_encoded = -1.0

        for p in self.encoder_decoder.module.noise.parameters():
            p.requires_grad = False

        # optimizers
        self.opt_encoder_decoder = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.encoder_decoder.parameters()), lr=lr, betas=(beta1, 0.999))
        self.opt_discriminator = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

    def train(self, images: torch.Tensor, messages: torch.Tensor, masks: torch.Tensor):
        self.encoder_decoder.train()
        self.discriminator.train()

        with torch.enable_grad():
            # move inputs to the compute device
            images, messages, masks = images.to(self.device), messages.to(self.device), masks.to(self.device)
            encoded_images, noised_images, decoded_messages_C, decoded_messages_R, decoded_messages_F = self.encoder_decoder(images, messages, masks)

            '''
            train discriminator
            '''
            for p in self.discriminator.parameters():
                p.requires_grad = True

            self.opt_discriminator.zero_grad()

            # RAW : target label for the cover image should be "cover" (1)
            d_label_cover = self.discriminator(images)
            # d_cover_loss = self.criterion_MSE(d_label_cover, torch.ones_like(d_label_cover))
            # d_cover_loss.backward()

            # GAN : target label for the encoded image should be "encoded" (-1)
            d_label_encoded = self.discriminator(encoded_images.detach())
            # d_encoded_loss = self.criterion_MSE(d_label_encoded, torch.zeros_like(d_label_encoded))
            # d_encoded_loss.backward()

            d_loss = self.criterion_MSE(d_label_cover - torch.mean(d_label_encoded), self.label_cover * torch.ones_like(d_label_cover)) + \
                     self.criterion_MSE(d_label_encoded - torch.mean(d_label_cover), self.label_encoded * torch.ones_like(d_label_encoded))
            d_loss.backward()

            self.opt_discriminator.step()

            '''
            train encoder and decoder
            '''
            # make it a tiny bit faster
            for p in self.discriminator.parameters():
                p.requires_grad = False

            self.opt_encoder_decoder.zero_grad()

            # GAN : target label for the encoded image should be "cover"
            g_label_cover = self.discriminator(images)
            g_label_encoded = self.discriminator(encoded_images)
            g_loss_on_discriminator = self.criterion_MSE(g_label_cover - torch.mean(g_label_encoded), self.label_encoded * torch.ones_like(g_label_cover)) + \
                                      self.criterion_MSE(g_label_encoded - torch.mean(g_label_cover), self.label_cover * torch.ones_like(g_label_encoded))

            # RAW : the encoded image should be similar to the cover image
            g_loss_on_encoder_MSE = self.criterion_MSE(encoded_images, images)
            g_loss_on_encoder_LPIPS = torch.mean(self.criterion_LPIPS(encoded_images, images))

            # RESULT : the decoded message should be similar to the raw message /Dual
            g_loss_on_decoder_C = self.criterion_MSE(decoded_messages_C, messages)
            g_loss_on_decoder_R = self.criterion_MSE(decoded_messages_R, messages)
            g_loss_on_decoder_F = self.criterion_MSE(decoded_messages_F, torch.zeros_like(messages))

            # full loss
            g_loss = self.discriminator_weight * g_loss_on_discriminator + self.encoder_weight * g_loss_on_encoder_MSE + \
                     self.decoder_weight_C * g_loss_on_decoder_C + self.decoder_weight_R * g_loss_on_decoder_R + self.decoder_weight_F * g_loss_on_decoder_F

            g_loss.backward()
            self.opt_encoder_decoder.step()

            # psnr
            psnr = -kornia.losses.psnr_loss(encoded_images.detach(), images, 2)

            # ssim
            ssim = 1 - 2 * kornia.losses.ssim_loss(encoded_images.detach(), images, window_size=11, reduction="mean")

            '''
            decoded message error rate /Dual
            '''
            error_rate_C = self.decoded_message_error_rate_batch(messages, decoded_messages_C)
            error_rate_R = self.decoded_message_error_rate_batch(messages, decoded_messages_R)
            error_rate_F = self.decoded_message_error_rate_batch(messages, decoded_messages_F)

            result = {
                "g_loss": g_loss,
                "error_rate_C": error_rate_C,
                "error_rate_R": error_rate_R,
                "error_rate_F": error_rate_F,
                "psnr": psnr,
                "ssim": ssim,
                "g_loss_on_discriminator": g_loss_on_discriminator,
                "g_loss_on_encoder_MSE": g_loss_on_encoder_MSE,
                "g_loss_on_encoder_LPIPS": g_loss_on_encoder_LPIPS,
                "g_loss_on_decoder_C": g_loss_on_decoder_C,
                "g_loss_on_decoder_R": g_loss_on_decoder_R,
                "g_loss_on_decoder_F": g_loss_on_decoder_F,
                "d_loss": d_loss
            }
        return result

    def validation(self, images: torch.Tensor, messages: torch.Tensor, masks: torch.Tensor):
        self.encoder_decoder.eval()
        self.encoder_decoder.module.noise.train()
        self.discriminator.eval()

        with torch.no_grad():
            # move inputs to the compute device
            images, messages, masks = images.to(self.device), messages.to(self.device), masks.to(self.device)
            encoded_images, noised_images, decoded_messages_C, decoded_messages_R, decoded_messages_F = self.encoder_decoder(images, messages, masks)

            '''
            validate discriminator
            '''
            # RAW : target label for the cover image should be "cover" (1)
            d_label_cover = self.discriminator(images)
            # d_cover_loss = self.criterion_MSE(d_label_cover, torch.ones_like(d_label_cover))

            # GAN : target label for the encoded image should be "encoded" (-1)
            d_label_encoded = self.discriminator(encoded_images.detach())
            # d_encoded_loss = self.criterion_MSE(d_label_encoded, torch.zeros_like(d_label_encoded))

            d_loss = self.criterion_MSE(d_label_cover - torch.mean(d_label_encoded), self.label_cover * torch.ones_like(d_label_cover)) + \
                     self.criterion_MSE(d_label_encoded - torch.mean(d_label_cover), self.label_encoded * torch.ones_like(d_label_encoded))

            '''
            validate encoder and decoder
            '''
            # GAN : target label for the encoded image should be "cover"
            g_label_cover = self.discriminator(images)
            g_label_encoded = self.discriminator(encoded_images)
            g_loss_on_discriminator = self.criterion_MSE(g_label_cover - torch.mean(g_label_encoded), self.label_encoded * torch.ones_like(g_label_cover)) + \
                                      self.criterion_MSE(g_label_encoded - torch.mean(g_label_cover), self.label_cover * torch.ones_like(g_label_encoded))

            # RAW : the encoded image should be similar to the cover image
            g_loss_on_encoder_MSE = self.criterion_MSE(encoded_images, images)
            g_loss_on_encoder_LPIPS = torch.mean(self.criterion_LPIPS(encoded_images, images))

            # RESULT : the decoded message should be similar to the raw message /Dual
            g_loss_on_decoder_C = self.criterion_MSE(decoded_messages_C, messages)
            g_loss_on_decoder_R = self.criterion_MSE(decoded_messages_R, messages)
            g_loss_on_decoder_F = self.criterion_MSE(decoded_messages_F, torch.zeros_like(messages))

            # full loss
            # the unstable g_loss_on_discriminator is not used during validation
            g_loss = 0 * g_loss_on_discriminator + self.encoder_weight * g_loss_on_encoder_LPIPS + \
                     self.decoder_weight_C * g_loss_on_decoder_C + self.decoder_weight_R * g_loss_on_decoder_R + self.decoder_weight_F * g_loss_on_decoder_F

            # psnr
            psnr = -kornia.losses.psnr_loss(encoded_images.detach(), images, 2)

            # ssim
            ssim = 1 - 2 * kornia.losses.ssim_loss(encoded_images.detach(), images, window_size=11, reduction="mean")

            '''
            decoded message error rate /Dual
            '''
            error_rate_C = self.decoded_message_error_rate_batch(messages, decoded_messages_C)
            error_rate_R = self.decoded_message_error_rate_batch(messages, decoded_messages_R)
            error_rate_F = self.decoded_message_error_rate_batch(messages, decoded_messages_F)

            result = {
                "g_loss": g_loss,
                "error_rate_C": error_rate_C,
                "error_rate_R": error_rate_R,
                "error_rate_F": error_rate_F,
                "psnr": psnr,
                "ssim": ssim,
                "g_loss_on_discriminator": g_loss_on_discriminator,
                "g_loss_on_encoder_MSE": g_loss_on_encoder_MSE,
                "g_loss_on_encoder_LPIPS": g_loss_on_encoder_LPIPS,
                "g_loss_on_decoder_C": g_loss_on_decoder_C,
                "g_loss_on_decoder_R": g_loss_on_decoder_R,
                "g_loss_on_decoder_F": g_loss_on_decoder_F,
                "d_loss": d_loss
            }

        return result, (images, encoded_images, noised_images)

    def decoded_message_error_rate(self, message, decoded_message):
        length = message.shape[0]

        message = message.gt(0)
        decoded_message = decoded_message.gt(0)
        error_rate = float(sum(message != decoded_message)) / length
        return error_rate

    def decoded_message_error_rate_batch(self, messages, decoded_messages):
        error_rate = 0.0
        batch_size = len(messages)
        for i in range(batch_size):
            error_rate += self.decoded_message_error_rate(messages[i], decoded_messages[i])
        error_rate /= batch_size
        return error_rate

    def save_model(self, path_encoder_decoder: str, path_discriminator: str):
        torch.save(self.encoder_decoder.module.state_dict(), path_encoder_decoder)
        torch.save(self.discriminator.module.state_dict(), path_discriminator)

    def load_model(self, path_encoder_decoder: str, path_discriminator: str):
        self.load_model_ed(path_encoder_decoder)
        self.load_model_dis(path_discriminator)

    def load_model_ed(self, path_encoder_decoder: str):
        self.encoder_decoder.module.load_state_dict(torch.load(path_encoder_decoder), strict=False)

    def load_model_dis(self, path_discriminator: str):
        self.discriminator.module.load_state_dict(torch.load(path_discriminator))
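Hedged note: the d_loss/g_loss pairs above follow a relativistic average least-squares formulation, where each side is scored relative to the mean score of the other side. A tiny standalone illustration with made-up scores:

import torch
mse = torch.nn.MSELoss()
d_cover = torch.tensor([0.2, 0.4])      # discriminator scores on cover images
d_encoded = torch.tensor([-0.7, -0.9])  # discriminator scores on encoded images
d_loss = mse(d_cover - d_encoded.mean(), torch.ones_like(d_cover)) + \
         mse(d_encoded - d_cover.mean(), -torch.ones_like(d_encoded))
print(d_loss)  # ~0.04: small because cover sits about +1 relative to encoded, and vice versa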
models/bitnetwork/Encoder_U.py
ADDED
@@ -0,0 +1,125 @@
from . import *

class DW_Encoder(nn.Module):

    def __init__(self, message_length, blocks=2, channels=64, attention=None):
        super(DW_Encoder, self).__init__()

        self.conv1 = ConvBlock(3, 16, blocks=blocks)
        self.down1 = Down(16, 32, blocks=blocks)
        self.down2 = Down(32, 64, blocks=blocks)
        self.down3 = Down(64, 128, blocks=blocks)

        self.down4 = Down(128, 256, blocks=blocks)

        self.up3 = UP(256, 128)
        self.linear3 = nn.Linear(message_length, message_length * message_length)
        self.Conv_message3 = ConvBlock(1, channels, blocks=blocks)
        self.att3 = ResBlock(128 * 2 + channels, 128, blocks=blocks, attention=attention)

        self.up2 = UP(128, 64)
        self.linear2 = nn.Linear(message_length, message_length * message_length)
        self.Conv_message2 = ConvBlock(1, channels, blocks=blocks)
        self.att2 = ResBlock(64 * 2 + channels, 64, blocks=blocks, attention=attention)

        self.up1 = UP(64, 32)
        self.linear1 = nn.Linear(message_length, message_length * message_length)
        self.Conv_message1 = ConvBlock(1, channels, blocks=blocks)
        self.att1 = ResBlock(32 * 2 + channels, 32, blocks=blocks, attention=attention)

        self.up0 = UP(32, 16)
        self.linear0 = nn.Linear(message_length, message_length * message_length)
        self.Conv_message0 = ConvBlock(1, channels, blocks=blocks)
        self.att0 = ResBlock(16 * 2 + channels, 16, blocks=blocks, attention=attention)

        self.Conv_1x1 = nn.Conv2d(16 + 3, 3, kernel_size=1, stride=1, padding=0)

        self.message_length = message_length

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

    def forward(self, x, watermark):
        d0 = self.conv1(x)
        d1 = self.down1(d0)
        d2 = self.down2(d1)
        d3 = self.down3(d2)

        d4 = self.down4(d3)

        u3 = self.up3(d4)
        expanded_message = self.linear3(watermark)
        expanded_message = expanded_message.view(-1, 1, self.message_length, self.message_length)
        expanded_message = F.interpolate(expanded_message, size=(d3.shape[2], d3.shape[3]),
                                         mode='nearest')
        expanded_message = self.Conv_message3(expanded_message)
        u3 = torch.cat((d3, u3, expanded_message), dim=1)
        u3 = self.att3(u3)

        u2 = self.up2(u3)
        expanded_message = self.linear2(watermark)
        expanded_message = expanded_message.view(-1, 1, self.message_length, self.message_length)
        expanded_message = F.interpolate(expanded_message, size=(d2.shape[2], d2.shape[3]),
                                         mode='nearest')
        expanded_message = self.Conv_message2(expanded_message)
        u2 = torch.cat((d2, u2, expanded_message), dim=1)
        u2 = self.att2(u2)

        u1 = self.up1(u2)
        expanded_message = self.linear1(watermark)
        expanded_message = expanded_message.view(-1, 1, self.message_length, self.message_length)
        expanded_message = F.interpolate(expanded_message, size=(d1.shape[2], d1.shape[3]),
                                         mode='nearest')
        expanded_message = self.Conv_message1(expanded_message)
        u1 = torch.cat((d1, u1, expanded_message), dim=1)
        u1 = self.att1(u1)

        u0 = self.up0(u1)
        expanded_message = self.linear0(watermark)
        expanded_message = expanded_message.view(-1, 1, self.message_length, self.message_length)
        expanded_message = F.interpolate(expanded_message, size=(d0.shape[2], d0.shape[3]),
                                         mode='nearest')
        expanded_message = self.Conv_message0(expanded_message)
        u0 = torch.cat((d0, u0, expanded_message), dim=1)
        u0 = self.att0(u0)

        image = self.Conv_1x1(torch.cat((x, u0), dim=1))

        forward_image = image.clone().detach()
        '''read_image = torch.zeros_like(forward_image)

        for index in range(forward_image.shape[0]):
            single_image = ((forward_image[index].clamp(-1, 1).permute(1, 2, 0) + 1) / 2 * 255).add(0.5).clamp(0, 255).to('cpu', torch.uint8).numpy()
            im = Image.fromarray(single_image)
            read = np.array(im, dtype=np.uint8)
            read_image[index] = self.transform(read).unsqueeze(0).to(image.device)

        gap = read_image - forward_image'''
        gap = forward_image.clamp(-1, 1) - forward_image

        return image + gap


class Down(nn.Module):
    def __init__(self, in_channels, out_channels, blocks):
        super(Down, self).__init__()
        self.layer = torch.nn.Sequential(
            ConvBlock(in_channels, in_channels, stride=2),
            ConvBlock(in_channels, out_channels, blocks=blocks)
        )

    def forward(self, x):
        return self.layer(x)


class UP(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UP, self).__init__()
        self.conv = ConvBlock(in_channels, out_channels)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(x)
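Sketch of the message-injection pattern repeated at every scale above (illustrative only): a length-L message is lifted to an LxL map by a linear layer, resized to the current feature resolution, and convolved before concatenation with the skip and upsampled features.

import torch
import torch.nn as nn
import torch.nn.functional as F

L = 64
linear = nn.Linear(L, L * L)
msg = torch.rand(2, L)
m = linear(msg).view(-1, 1, L, L)                    # lift message to an LxL map
m = F.interpolate(m, size=(32, 32), mode='nearest')  # match a 32x32 feature map
print(m.shape)                                       # torch.Size([2, 1, 32, 32])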
models/bitnetwork/Random_Noise.py
ADDED
@@ -0,0 +1,59 @@
from . import *
from .noise_layers import *


class Random_Noise(nn.Module):

    def __init__(self, layers, len_layers_R, len_layers_F):
        super(Random_Noise, self).__init__()
        for i in range(len(layers)):
            layers[i] = eval(layers[i])
        self.noise = nn.Sequential(*layers)
        self.len_layers_R = len_layers_R
        self.len_layers_F = len_layers_F
        print(self.noise)
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

    def forward(self, image_cover_mask):
        image, cover_image, mask = image_cover_mask[0], image_cover_mask[1], image_cover_mask[2]
        forward_image = image.clone().detach()
        forward_cover_image = cover_image.clone().detach()
        forward_mask = mask.clone().detach()
        noised_image_C = torch.zeros_like(forward_image)
        noised_image_R = torch.zeros_like(forward_image)
        noised_image_F = torch.zeros_like(forward_image)

        for index in range(forward_image.shape[0]):
            random_noise_layer_C = np.random.choice(self.noise, 1)[0]
            random_noise_layer_R = np.random.choice(self.noise[0:self.len_layers_R], 1)[0]
            random_noise_layer_F = np.random.choice(self.noise[self.len_layers_R:self.len_layers_R + self.len_layers_F], 1)[0]
            noised_image_C[index] = random_noise_layer_C([forward_image[index].clone().unsqueeze(0), forward_cover_image[index].clone().unsqueeze(0), forward_mask[index].clone().unsqueeze(0)])
            noised_image_R[index] = random_noise_layer_R([forward_image[index].clone().unsqueeze(0), forward_cover_image[index].clone().unsqueeze(0), forward_mask[index].clone().unsqueeze(0)])
            noised_image_F[index] = random_noise_layer_F([forward_image[index].clone().unsqueeze(0), forward_cover_image[index].clone().unsqueeze(0), forward_mask[index].clone().unsqueeze(0)])

        '''single_image = ((noised_image_C[index].clamp(-1, 1).permute(1, 2, 0) + 1) / 2 * 255).add(0.5).clamp(0, 255).to('cpu', torch.uint8).numpy()
        im = Image.fromarray(single_image)
        read = np.array(im, dtype=np.uint8)
        noised_image_C[index] = self.transform(read).unsqueeze(0).to(image.device)

        single_image = ((noised_image_R[index].clamp(-1, 1).permute(1, 2, 0) + 1) / 2 * 255).add(0.5).clamp(0, 255).to('cpu', torch.uint8).numpy()
        im = Image.fromarray(single_image)
        read = np.array(im, dtype=np.uint8)
        noised_image_R[index] = self.transform(read).unsqueeze(0).to(image.device)

        single_image = ((noised_image_F[index].clamp(-1, 1).permute(1, 2, 0) + 1) / 2 * 255).add(0.5).clamp(0, 255).to('cpu', torch.uint8).numpy()
        im = Image.fromarray(single_image)
        read = np.array(im, dtype=np.uint8)
        noised_image_F[index] = self.transform(read).unsqueeze(0).to(image.device)

        noised_image_gap_C = noised_image_C - forward_image
        noised_image_gap_R = noised_image_R - forward_image
        noised_image_gap_F = noised_image_F - forward_image'''
        noised_image_gap_C = noised_image_C.clamp(-1, 1) - forward_image
        noised_image_gap_R = noised_image_R.clamp(-1, 1) - forward_image
        noised_image_gap_F = noised_image_F.clamp(-1, 1) - forward_image

        return image + noised_image_gap_C, image + noised_image_gap_R, image + noised_image_gap_F
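Hedged note on the sampling above: np.random.choice over the module container picks one noise layer per batch element, so each image in a batch can receive a different distortion. A standalone illustration with stand-in layers:

import numpy as np
import torch.nn as nn

layers = [nn.Identity(), nn.Dropout(p=0.5)]  # stand-ins for the real noise layers
picked = np.random.choice(layers, 1)[0]      # numpy treats the list as an object array and draws one module
print(type(picked).__name__)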
models/bitnetwork/ResBlock.py
ADDED
@@ -0,0 +1,222 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class SEAttention(nn.Module):
    def __init__(self, in_channels, out_channels, reduction=8):
        super(SEAttention, self).__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels // reduction, kernel_size=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=out_channels // reduction, out_channels=out_channels, kernel_size=1, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.se(x) * x
        return x


class ChannelAttention(nn.Module):
    def __init__(self, in_channels, out_channels, reduction=8):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.max_pool = nn.AdaptiveMaxPool2d((1, 1))

        self.fc = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels // reduction, kernel_size=1, bias=False),
                                nn.ReLU(inplace=True),
                                nn.Conv2d(in_channels=out_channels // reduction, out_channels=out_channels, kernel_size=1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)


class CBAMAttention(nn.Module):
    def __init__(self, in_channels, out_channels, reduction=8):
        super(CBAMAttention, self).__init__()
        self.ca = ChannelAttention(in_channels=in_channels, out_channels=out_channels, reduction=reduction)
        self.sa = SpatialAttention()

    def forward(self, x):
        x = self.ca(x) * x
        x = self.sa(x) * x
        return x


class h_sigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        return self.relu(x + 3) / 6


class h_swish(nn.Module):
    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        return x * self.sigmoid(x)


class CoordAttention(nn.Module):
    def __init__(self, in_channels, out_channels, reduction=8):
        super(CoordAttention, self).__init__()
        self.pool_w, self.pool_h = nn.AdaptiveAvgPool2d((1, None)), nn.AdaptiveAvgPool2d((None, 1))
        temp_c = max(8, in_channels // reduction)
        self.conv1 = nn.Conv2d(in_channels, temp_c, kernel_size=1, stride=1, padding=0)

        self.bn1 = nn.InstanceNorm2d(temp_c)
        self.act1 = h_swish()  # nn.SiLU() # nn.Hardswish() # nn.SiLU()

        self.conv2 = nn.Conv2d(temp_c, out_channels, kernel_size=1, stride=1, padding=0)
        self.conv3 = nn.Conv2d(temp_c, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        short = x
        n, c, H, W = x.shape
        x_h, x_w = self.pool_h(x), self.pool_w(x).permute(0, 1, 3, 2)
        x_cat = torch.cat([x_h, x_w], dim=2)
        out = self.act1(self.bn1(self.conv1(x_cat)))
        x_h, x_w = torch.split(out, [H, W], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)
        out_h = torch.sigmoid(self.conv2(x_h))
        out_w = torch.sigmoid(self.conv3(x_w))
        return short * out_w * out_h


class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, reduction, stride, attention=None):
        super(BasicBlock, self).__init__()

        self.change = None
        if (in_channels != out_channels or stride != 1):
            self.change = nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, padding=0,
                          stride=stride, bias=False),
                nn.InstanceNorm2d(out_channels)
            )

        self.left = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1,
                      stride=stride, bias=False),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm2d(out_channels)
        )

        if attention == 'se':
            print('SEAttention')
            self.attention = SEAttention(in_channels=out_channels, out_channels=out_channels, reduction=reduction)
        elif attention == 'cbam':
            print('CBAMAttention')
            self.attention = CBAMAttention(in_channels=out_channels, out_channels=out_channels, reduction=reduction)
        elif attention == 'coord':
            print('CoordAttention')
            self.attention = CoordAttention(in_channels=out_channels, out_channels=out_channels, reduction=reduction)
        else:
            print('None Attention')
            self.attention = nn.Identity()

    def forward(self, x):
        identity = x
        x = self.left(x)
        x = self.attention(x)

        if self.change is not None:
            identity = self.change(identity)

        x += identity
        x = F.relu(x)
        return x


class BottleneckBlock(nn.Module):
    def __init__(self, in_channels, out_channels, reduction, stride, attention=None):
        super(BottleneckBlock, self).__init__()

        self.change = None
        if (in_channels != out_channels or stride != 1):
            self.change = nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, padding=0,
                          stride=stride, bias=False),
                nn.InstanceNorm2d(out_channels)
            )

        self.left = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1,
                      stride=stride, padding=0, bias=False),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=False),
|
130 |
+
)
|
131 |
+
|
132 |
+
if attention == 'se':
|
133 |
+
print('SEAttention')
|
134 |
+
self.attention = SEAttention(in_channels=out_channels, out_channels=out_channels, reduction=reduction)
|
135 |
+
elif attention == 'cbam':
|
136 |
+
print('CBAMAttention')
|
137 |
+
self.attention = CBAMAttention(in_channels=out_channels, out_channels=out_channels, reduction=reduction)
|
138 |
+
elif attention == 'coord':
|
139 |
+
print('CoordAttention')
|
140 |
+
self.attention = CoordAttention(in_channels=out_channels, out_channels=out_channels, reduction=reduction)
|
141 |
+
else:
|
142 |
+
print('None Attention')
|
143 |
+
self.attention = nn.Identity()
|
144 |
+
|
145 |
+
def forward(self, x):
|
146 |
+
identity = x
|
147 |
+
x = self.left(x)
|
148 |
+
x = self.attention(x)
|
149 |
+
|
150 |
+
if self.change is not None:
|
151 |
+
identity = self.change(identity)
|
152 |
+
|
153 |
+
x += identity
|
154 |
+
x = F.relu(x)
|
155 |
+
return x
|
156 |
+
|
157 |
+
|
158 |
+
class BottleneckBlock(nn.Module):
|
159 |
+
def __init__(self, in_channels, out_channels, reduction, stride, attention=None):
|
160 |
+
super(BottleneckBlock, self).__init__()
|
161 |
+
|
162 |
+
self.change = None
|
163 |
+
if (in_channels != out_channels or stride != 1):
|
164 |
+
self.change = nn.Sequential(
|
165 |
+
nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, padding=0,
|
166 |
+
stride=stride, bias=False),
|
167 |
+
nn.InstanceNorm2d(out_channels)
|
168 |
+
)
|
169 |
+
|
170 |
+
self.left = nn.Sequential(
|
171 |
+
nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1,
|
172 |
+
stride=stride, padding=0, bias=False),
|
173 |
+
nn.InstanceNorm2d(out_channels),
|
174 |
+
nn.ReLU(inplace=True),
|
175 |
+
nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=False),
|
176 |
+
nn.InstanceNorm2d(out_channels),
|
177 |
+
nn.ReLU(inplace=True),
|
178 |
+
nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=1, padding=0, bias=False),
|
179 |
+
nn.InstanceNorm2d(out_channels)
|
180 |
+
)
|
181 |
+
|
182 |
+
if attention == 'se':
|
183 |
+
print('SEAttention')
|
184 |
+
self.attention = SEAttention(in_channels=out_channels, out_channels=out_channels, reduction=reduction)
|
185 |
+
elif attention == 'cbam':
|
186 |
+
print('CBAMAttention')
|
187 |
+
self.attention = CBAMAttention(in_channels=out_channels, out_channels=out_channels, reduction=reduction)
|
188 |
+
elif attention == 'coord':
|
189 |
+
print('CoordAttention')
|
190 |
+
self.attention = CoordAttention(in_channels=out_channels, out_channels=out_channels, reduction=reduction)
|
191 |
+
else:
|
192 |
+
print('None Attention')
|
193 |
+
self.attention = nn.Identity()
|
194 |
+
|
195 |
+
def forward(self, x):
|
196 |
+
identity = x
|
197 |
+
x = self.left(x)
|
198 |
+
x = self.attention(x)
|
199 |
+
|
200 |
+
if self.change is not None:
|
201 |
+
identity = self.change(identity)
|
202 |
+
|
203 |
+
x += identity
|
204 |
+
x = F.relu(x)
|
205 |
+
return x
|
206 |
+
|
207 |
+
|
208 |
+
class ResBlock(nn.Module):
|
209 |
+
|
210 |
+
def __init__(self, in_channels, out_channels, blocks=1, block_type="BottleneckBlock", reduction=8, stride=1, attention=None):
|
211 |
+
super(ResBlock, self).__init__()
|
212 |
+
|
213 |
+
layers = [eval(block_type)(in_channels, out_channels, reduction, stride, attention=attention)] if blocks != 0 else []
|
214 |
+
for _ in range(blocks - 1):
|
215 |
+
layer = eval(block_type)(out_channels, out_channels, reduction, 1, attention=attention)
|
216 |
+
layers.append(layer)
|
217 |
+
|
218 |
+
self.layers = nn.Sequential(*layers)
|
219 |
+
|
220 |
+
def forward(self, x):
|
221 |
+
return self.layers(x)
|
222 |
+
|
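A quick shape sanity check for the blocks above (a minimal sketch, not part of the diff; the tensor sizes are arbitrary and the import path follows the repo layout):

import torch
from models.bitnetwork.ResBlock import ResBlock

x = torch.randn(2, 32, 64, 64)

# two stacked bottleneck blocks with SE attention; stride=2 halves the spatial size
down = ResBlock(32, 64, blocks=2, block_type="BottleneckBlock", stride=2, attention="se")
print(down(x).shape)  # torch.Size([2, 64, 32, 32])

# a basic block with coordinate attention keeps the input shape
same = ResBlock(32, 32, blocks=1, block_type="BasicBlock", attention="coord")
print(same(x).shape)  # torch.Size([2, 32, 64, 64])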
models/bitnetwork/__init__.py
ADDED
@@ -0,0 +1,9 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
# import kornia.losses
from PIL import Image
from torchvision import transforms
from .ResBlock import *
from .ConvBlock import *
models/bitnetwork/__pycache__/ConvBlock.cpython-38.pyc
ADDED
Binary file (1.58 kB)
models/bitnetwork/__pycache__/Decoder_U.cpython-38.pyc
ADDED
Binary file (2.75 kB)
models/bitnetwork/__pycache__/Encoder_U.cpython-38.pyc
ADDED
Binary file (3.51 kB)
models/bitnetwork/__pycache__/ResBlock.cpython-38.pyc
ADDED
Binary file (7.51 kB)
models/bitnetwork/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (410 Bytes)
models/discrim.py
ADDED
@@ -0,0 +1,169 @@
from torch import nn as nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm


class UNetDiscriminatorSN(nn.Module):
    """Defines a U-Net discriminator with spectral normalization (SN).

    It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_feat (int): Channel number of base intermediate features. Default: 64.
        skip_connection (bool): Whether to use skip connections between U-Net. Default: True.
    """

    def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
        super(UNetDiscriminatorSN, self).__init__()
        self.skip_connection = skip_connection
        norm = spectral_norm
        # the first convolution
        self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
        # downsample
        self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
        self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
        self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
        # upsample
        self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
        self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
        self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
        # extra convolutions
        self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
        self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
        self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)

    def forward(self, x):
        # downsample
        x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
        x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
        x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
        x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)

        # upsample
        x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
        x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)

        if self.skip_connection:
            x4 = x4 + x2
        x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)

        if self.skip_connection:
            x5 = x5 + x1
        x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
        x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)

        if self.skip_connection:
            x6 = x6 + x0

        # extra convolutions
        out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
        out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
        out = self.conv9(out)

        return out


class GANLoss(nn.Module):
    """Define GAN loss.

    Args:
        gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
        real_label_val (float): The value for real label. Default: 1.0.
        fake_label_val (float): The value for fake label. Default: 0.0.
        loss_weight (float): Loss weight. Default: 1.0.
            Note that loss_weight is only for generators; it is always 1.0
            for discriminators.
    """

    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
        super(GANLoss, self).__init__()
        self.gan_type = gan_type
        self.loss_weight = loss_weight
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        if self.gan_type == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif self.gan_type == 'lsgan':
            self.loss = nn.MSELoss()
        elif self.gan_type == 'wgan':
            self.loss = self._wgan_loss
        elif self.gan_type == 'wgan_softplus':
            self.loss = self._wgan_softplus_loss
        elif self.gan_type == 'hinge':
            self.loss = nn.ReLU()
        else:
            raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')

    def _wgan_loss(self, input, target):
        """wgan loss.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss.
        """
        return -input.mean() if target else input.mean()

    def _wgan_softplus_loss(self, input, target):
        """wgan loss with softplus. Softplus is a smooth approximation to the
        ReLU function.

        In StyleGAN2, it is called:
            Logistic loss for the discriminator;
            Non-saturating loss for the generator.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss.
        """
        return F.softplus(-input).mean() if target else F.softplus(input).mean()

    def get_target_label(self, input, target_is_real):
        """Get target label.

        Args:
            input (Tensor): Input tensor.
            target_is_real (bool): Whether the target is real or fake.

        Returns:
            (bool | Tensor): Target tensor. Return bool for wgan, otherwise
            return Tensor.
        """

        if self.gan_type in ['wgan', 'wgan_softplus']:
            return target_is_real
        target_val = (self.real_label_val if target_is_real else self.fake_label_val)
        return input.new_ones(input.size()) * target_val

    def forward(self, input, target_is_real, is_disc=False):
        """
        Args:
            input (Tensor): The input for the loss module, i.e., the network
                prediction.
            target_is_real (bool): Whether the target is real or fake.
            is_disc (bool): Whether the loss is computed for discriminators.
                Default: False.

        Returns:
            Tensor: GAN loss value.
        """
        target_label = self.get_target_label(input, target_is_real)
        if self.gan_type == 'hinge':
            if is_disc:  # for discriminators in hinge-gan
                input = -input if target_is_real else input
                loss = self.loss(1 + input).mean()
            else:  # for generators in hinge-gan
                loss = -input.mean()
        else:  # other gan types
            loss = self.loss(input, target_label)

        # loss_weight is always 1.0 for discriminators
        return loss if is_disc else loss * self.loss_weight
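For orientation, this is how the two pieces typically combine in adversarial training (a minimal sketch with placeholder tensors, independent of the rest of the repo). The U-Net discriminator returns a per-pixel logit map, which GANLoss matches against a full-size target built by input.new_ones:

import torch
from models.discrim import UNetDiscriminatorSN, GANLoss

netD = UNetDiscriminatorSN(num_in_ch=3)
cri_gan = GANLoss('lsgan')

real = torch.randn(2, 3, 64, 64)   # spatial size must be divisible by 8
fake = torch.randn(2, 3, 64, 64)

# discriminator step: real toward 1, fake toward 0 (loss_weight is ignored when is_disc=True)
d_loss = cri_gan(netD(real), True, is_disc=True) + cri_gan(netD(fake.detach()), False, is_disc=True)

# generator step: push the discriminator's per-pixel logits on fake toward "real"
g_loss = cri_gan(netD(fake), True, is_disc=False)
print(d_loss.item(), g_loss.item())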
models/lr_scheduler.py
ADDED
@@ -0,0 +1,142 @@
import math
from collections import Counter
from collections import defaultdict
import torch
from torch.optim.lr_scheduler import _LRScheduler


class MultiStepLR_Restart(_LRScheduler):
    def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1,
                 clear_state=False, last_epoch=-1):
        self.milestones = Counter(milestones)
        self.gamma = gamma
        self.clear_state = clear_state
        self.restarts = restarts if restarts else [0]
        self.restart_weights = weights if weights else [1]
        assert len(self.restarts) == len(
            self.restart_weights), 'restarts and their weights do not match.'
        super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch in self.restarts:
            if self.clear_state:
                self.optimizer.state = defaultdict(dict)
            weight = self.restart_weights[self.restarts.index(self.last_epoch)]
            return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
        if self.last_epoch not in self.milestones:
            return [group['lr'] for group in self.optimizer.param_groups]
        return [
            group['lr'] * self.gamma**self.milestones[self.last_epoch]
            for group in self.optimizer.param_groups
        ]


class CosineAnnealingLR_Restart(_LRScheduler):
    def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1):
        self.T_period = T_period
        self.T_max = self.T_period[0]  # current T period
        self.eta_min = eta_min
        self.restarts = restarts if restarts else [0]
        self.restart_weights = weights if weights else [1]
        self.last_restart = 0
        assert len(self.restarts) == len(
            self.restart_weights), 'restarts and their weights do not match.'
        super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch == 0:
            return self.base_lrs
        elif self.last_epoch in self.restarts:
            self.last_restart = self.last_epoch
            self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1]
            weight = self.restart_weights[self.restarts.index(self.last_epoch)]
            return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
        elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0:
            return [
                group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]
        return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) /
                (1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) *
                (group['lr'] - self.eta_min) + self.eta_min
                for group in self.optimizer.param_groups]


if __name__ == "__main__":
    optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0,
                                 betas=(0.9, 0.99))
    ##############################
    # MultiStepLR_Restart
    ##############################
    ## Original
    lr_steps = [200000, 400000, 600000, 800000]
    restarts = None
    restart_weights = None

    ## two
    lr_steps = [100000, 200000, 300000, 400000, 490000, 600000, 700000, 800000, 900000, 990000]
    restarts = [500000]
    restart_weights = [1]

    ## four
    lr_steps = [
        50000, 100000, 150000, 200000, 240000, 300000, 350000, 400000, 450000, 490000, 550000,
        600000, 650000, 700000, 740000, 800000, 850000, 900000, 950000, 990000
    ]
    restarts = [250000, 500000, 750000]
    restart_weights = [1, 1, 1]

    scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5,
                                    clear_state=False)

    ##############################
    # Cosine Annealing Restart
    ##############################
    ## two
    T_period = [500000, 500000]
    restarts = [500000]
    restart_weights = [1]

    ## four
    T_period = [250000, 250000, 250000, 250000]
    restarts = [250000, 500000, 750000]
    restart_weights = [1, 1, 1]

    scheduler = CosineAnnealingLR_Restart(optimizer, T_period, eta_min=1e-7, restarts=restarts,
                                          weights=restart_weights)

    ##############################
    # Draw figure
    ##############################
    N_iter = 1000000
    lr_l = list(range(N_iter))
    for i in range(N_iter):
        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']
        lr_l[i] = current_lr

    import matplotlib as mpl
    from matplotlib import pyplot as plt
    import matplotlib.ticker as mtick
    mpl.style.use('default')
    import seaborn
    seaborn.set(style='whitegrid')
    seaborn.set_context('paper')

    plt.figure(1)
    plt.subplot(111)
    plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
    plt.title('Title', fontsize=16, color='k')
    plt.plot(list(range(N_iter)), lr_l, linewidth=1.5, label='learning rate scheme')
    legend = plt.legend(loc='upper right', shadow=False)
    ax = plt.gca()
    labels = ax.get_xticks().tolist()
    for k, v in enumerate(labels):
        labels[k] = str(int(v / 1000)) + 'K'
    ax.set_xticklabels(labels)
    ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e'))

    ax.set_ylabel('Learning rate')
    ax.set_xlabel('Iteration')
    fig = plt.gcf()
    plt.show()
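One detail worth spelling out: within a period, get_lr applies the chained-form cosine update lr_t = eta_min + (lr_{t-1} - eta_min) * (1 + cos(pi*t/T_max)) / (1 + cos(pi*(t-1)/T_max)), which telescopes to the usual closed form eta_min + (eta_0 - eta_min) * (1 + cos(pi*t/T_max)) / 2; the extra elif branch handles the step where the cosine period wraps around. A minimal check over one 10-step period (a sketch, not part of the file; equality holds up to floating-point error):

import math
import torch
from models.lr_scheduler import CosineAnnealingLR_Restart

opt = torch.optim.Adam([torch.zeros(1, requires_grad=True)], lr=1.0)
sched = CosineAnnealingLR_Restart(opt, T_period=[10], eta_min=0.0)
for t in range(1, 11):
    sched.step()
    closed_form = (1 + math.cos(math.pi * t / 10)) / 2  # eta_min = 0, eta_0 = 1
    assert abs(opt.param_groups[0]['lr'] - closed_form) < 1e-6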
models/modules/Inv_arch.py
ADDED
@@ -0,0 +1,584 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .module_util import initialize_weights_xavier
from torch.nn import init
from .common import DWT, IWT
import cv2
from basicsr.archs.arch_util import flow_warp
from models.modules.Subnet_constructor import subnet

from pdb import set_trace as stx
import numbers

from einops import rearrange
from models.bitnetwork.Encoder_U import DW_Encoder
from models.bitnetwork.Decoder_U import DW_Decoder


## Layer Norm
def to_3d(x):
    return rearrange(x, 'b c h w -> b (h w) c')


def to_4d(x, h, w):
    return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)


class BiasFree_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(BiasFree_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return x / torch.sqrt(sigma + 1e-5) * self.weight


class WithBias_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(WithBias_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        mu = x.mean(-1, keepdim=True)
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias


class LayerNorm(nn.Module):
    def __init__(self, dim, LayerNorm_type):
        super(LayerNorm, self).__init__()
        if LayerNorm_type == 'BiasFree':
            self.body = BiasFree_LayerNorm(dim)
        else:
            self.body = WithBias_LayerNorm(dim)

    def forward(self, x):
        h, w = x.shape[-2:]
        return to_4d(self.body(to_3d(x)), h, w)


##########################################################################
## Gated-Dconv Feed-Forward Network (GDFN)
class FeedForward(nn.Module):
    def __init__(self, dim, ffn_expansion_factor, bias):
        super(FeedForward, self).__init__()

        hidden_features = int(dim * ffn_expansion_factor)

        self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)

        self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1,
                                groups=hidden_features * 2, bias=bias)

        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.project_in(x)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        x = F.gelu(x1) * x2
        x = self.project_out(x)
        return x


##########################################################################
## Multi-DConv Head Transposed Self-Attention (MDTA)
class Attention(nn.Module):
    def __init__(self, dim, num_heads, bias):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        b, c, h, w = x.shape

        qkv = self.qkv_dwconv(self.qkv(x))
        q, k, v = qkv.chunk(3, dim=1)

        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)

        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)

        out = (attn @ v)

        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)

        out = self.project_out(out)
        return out


##########################################################################
class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads=4, ffn_expansion_factor=4, bias=False, LayerNorm_type="withbias"):
        super(TransformerBlock, self).__init__()

        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.attn = Attention(dim, num_heads, bias)
        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))

        return x


dwt = DWT()
iwt = IWT()


class LayerNormFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, weight, bias, eps):
        ctx.eps = eps
        N, C, H, W = x.size()
        mu = x.mean(1, keepdim=True)
        var = (x - mu).pow(2).mean(1, keepdim=True)
        y = (x - mu) / (var + eps).sqrt()
        ctx.save_for_backward(y, var, weight)
        y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        eps = ctx.eps

        N, C, H, W = grad_output.size()
        y, var, weight = ctx.saved_tensors
        g = grad_output * weight.view(1, C, 1, 1)
        mean_g = g.mean(dim=1, keepdim=True)

        mean_gy = (g * y).mean(dim=1, keepdim=True)
        gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
        return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum(
            dim=0), None


class LayerNorm2d(nn.Module):

    def __init__(self, channels, eps=1e-6):
        super(LayerNorm2d, self).__init__()
        self.register_parameter('weight', nn.Parameter(torch.ones(channels)))
        self.register_parameter('bias', nn.Parameter(torch.zeros(channels)))
        self.eps = eps

    def forward(self, x):
        return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)


class SimpleGate(nn.Module):
    def forward(self, x):
        x1, x2 = x.chunk(2, dim=1)
        return x1 * x2


class NAFBlock(nn.Module):
    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
        super().__init__()
        dw_channel = c * DW_Expand
        self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel,
                               bias=True)
        self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

        # Simplified Channel Attention
        self.sca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1,
                      groups=1, bias=True),
        )

        # SimpleGate
        self.sg = SimpleGate()

        ffn_channel = FFN_Expand * c
        self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

        self.norm1 = LayerNorm2d(c)
        self.norm2 = LayerNorm2d(c)

        self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
        self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()

        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

    def forward(self, inp):
        x = inp

        x = self.norm1(x)

        x = self.conv1(x)
        x = self.conv2(x)
        x = self.sg(x)
        x = x * self.sca(x)
        x = self.conv3(x)

        x = self.dropout1(x)

        y = inp + x * self.beta

        x = self.conv4(self.norm2(y))
        x = self.sg(x)
        x = self.conv5(x)

        x = self.dropout2(x)

        return y + x * self.gamma


def thops_mean(tensor, dim=None, keepdim=False):
    if dim is None:
        # mean over all dims
        return torch.mean(tensor)
    else:
        if isinstance(dim, int):
            dim = [dim]
        dim = sorted(dim)
        for d in dim:
            tensor = tensor.mean(dim=d, keepdim=True)
        if not keepdim:
            for i, d in enumerate(dim):
                tensor.squeeze_(d - i)
        return tensor


class ResidualBlockNoBN(nn.Module):
    def __init__(self, nf=64, model='MIMO-VRN'):
        super(ResidualBlockNoBN, self).__init__()
        self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        # honestly, there's no significant difference between ReLU and leaky ReLU in terms of performance here,
        # but this is how we trained the model in the first place and what we reported in the paper
        if model == 'LSTM-VRN':
            self.relu = nn.ReLU(inplace=True)
        elif model == 'MIMO-VRN':
            self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization
        initialize_weights_xavier([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        identity = x
        out = self.relu(self.conv1(x))
        out = self.conv2(out)
        return identity + out


class InvBlock(nn.Module):
    def __init__(self, subnet_constructor, subnet_constructor_v2, channel_num_ho, channel_num_hi, groups, clamp=1.):
        super(InvBlock, self).__init__()
        self.split_len1 = channel_num_ho  # channel_split_num
        self.split_len2 = channel_num_hi  # channel_num - channel_split_num
        self.clamp = clamp

        self.F = subnet_constructor_v2(self.split_len2, self.split_len1, groups=groups)
        self.NF = NAFBlock(self.split_len2)
        if groups == 1:
            self.G = subnet_constructor(self.split_len1, self.split_len2, groups=groups)
            self.NG = NAFBlock(self.split_len1)
            self.H = subnet_constructor(self.split_len1, self.split_len2, groups=groups)
            self.NH = NAFBlock(self.split_len1)
        else:
            self.G = subnet_constructor(self.split_len1, self.split_len2)
            self.NG = NAFBlock(self.split_len1)
            self.H = subnet_constructor(self.split_len1, self.split_len2)
            self.NH = NAFBlock(self.split_len1)

    def forward(self, x1, x2, rev=False):
        # x1: the cover branch; x2: a list of hidden branches (one per image)
        if not rev:
            y1 = x1 + self.NF(self.F(x2))
            self.s = self.clamp * (torch.sigmoid(self.NH(self.H(y1))) * 2 - 1)
            y2 = [x2i.mul(torch.exp(self.s)) + self.NG(self.G(y1)) for x2i in x2]
        else:
            self.s = self.clamp * (torch.sigmoid(self.NH(self.H(x1))) * 2 - 1)
            y2 = [(x2i - self.NG(self.G(x1))).div(torch.exp(self.s)) for x2i in x2]
            y1 = x1 - self.NF(self.F(y2))

        return y1, y2  # torch.cat((y1, y2), 1)

    def jacobian(self, x, rev=False):
        if not rev:
            jac = torch.sum(self.s)
        else:
            jac = -torch.sum(self.s)

        return jac / x.shape[0]


class InvNN(nn.Module):
    def __init__(self, channel_in_ho=3, channel_in_hi=3, subnet_constructor=None, subnet_constructor_v2=None, block_num=[], down_num=2, groups=None):
        super(InvNN, self).__init__()
        operations = []

        current_channel_ho = channel_in_ho
        current_channel_hi = channel_in_hi
        for i in range(down_num):
            for j in range(block_num[i]):
                b = InvBlock(subnet_constructor, subnet_constructor_v2, current_channel_ho, current_channel_hi, groups=groups)
                operations.append(b)

        self.operations = nn.ModuleList(operations)

    def forward(self, x, x_h, rev=False, cal_jacobian=False):
        jacobian = 0

        if not rev:
            for op in self.operations:
                x, x_h = op.forward(x, x_h, rev)
                if cal_jacobian:
                    jacobian += op.jacobian(x, rev)
        else:
            for op in reversed(self.operations):
                x, x_h = op.forward(x, x_h, rev)
                if cal_jacobian:
                    jacobian += op.jacobian(x, rev)

        if cal_jacobian:
            return x, x_h, jacobian
        else:
            return x, x_h


class PredictiveModuleMIMO(nn.Module):
    def __init__(self, channel_in, nf, block_num_rbm=8, block_num_trans=4):
        super(PredictiveModuleMIMO, self).__init__()
        self.conv_in = nn.Conv2d(channel_in, nf, 3, 1, 1, bias=True)
        res_block = []
        trans_block = []
        for i in range(block_num_rbm):
            res_block.append(ResidualBlockNoBN(nf))
        for j in range(block_num_trans):
            trans_block.append(TransformerBlock(nf))

        self.res_block = nn.Sequential(*res_block)
        self.transformer_block = nn.Sequential(*trans_block)

    def forward(self, x):
        x = self.conv_in(x)
        x = self.res_block(x)
        res = self.transformer_block(x) + x

        return res


class ConvRelu(nn.Module):
    def __init__(self, channels_in, channels_out, stride=1, init_zero=False):
        super(ConvRelu, self).__init__()
        self.init_zero = init_zero
        if self.init_zero:
            self.layers = nn.Conv2d(channels_in, channels_out, 3, stride, padding=1)
        else:
            self.layers = nn.Sequential(
                nn.Conv2d(channels_in, channels_out, 3, stride, padding=1),
                nn.LeakyReLU(inplace=True)
            )

    def forward(self, x):
        return self.layers(x)


class PredictiveModuleBit(nn.Module):
    def __init__(self, channel_in, nf, block_num_rbm=4, block_num_trans=2):
        super(PredictiveModuleBit, self).__init__()
        self.conv_in = nn.Conv2d(channel_in, nf, 3, 1, 1, bias=True)
        res_block = []
        trans_block = []
        for i in range(block_num_rbm):
            res_block.append(ResidualBlockNoBN(nf))
        for j in range(block_num_trans):
            trans_block.append(TransformerBlock(nf))

        blocks = 4
        layers = [ConvRelu(nf, 1, 2)]
        for _ in range(blocks - 1):
            layer = ConvRelu(1, 1, 2)
            layers.append(layer)
        self.layers = nn.Sequential(*layers)

        self.res_block = nn.Sequential(*res_block)
        self.transformer_block = nn.Sequential(*trans_block)

    def forward(self, x):
        x = self.conv_in(x)
        x = self.res_block(x)
        res = self.transformer_block(x) + x
        res = self.layers(res)

        return res


##---------- Prompt Gen Module -----------------------
class PromptGenBlock(nn.Module):
    def __init__(self, prompt_dim=12, prompt_len=3, prompt_size=36, lin_dim=12):
        super(PromptGenBlock, self).__init__()
        self.prompt_param = nn.Parameter(torch.rand(1, prompt_len, prompt_dim, prompt_size, prompt_size))
        self.linear_layer = nn.Linear(lin_dim, prompt_len)
        self.conv3x3 = nn.Conv2d(prompt_dim, prompt_dim, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, x):
        B, C, H, W = x.shape
        emb = x.mean(dim=(-2, -1))
        prompt_weights = F.softmax(self.linear_layer(emb), dim=1)
        prompt = prompt_weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * self.prompt_param.unsqueeze(0).repeat(B, 1, 1, 1, 1, 1).squeeze(1)
        prompt = torch.sum(prompt, dim=1)
        prompt = F.interpolate(prompt, (H, W), mode="bilinear")
        prompt = self.conv3x3(prompt)

        return prompt


class PredictiveModuleMIMO_prompt(nn.Module):
    def __init__(self, channel_in, nf, prompt_len=3, block_num_rbm=8, block_num_trans=4):
        super(PredictiveModuleMIMO_prompt, self).__init__()
        self.conv_in = nn.Conv2d(channel_in, nf, 3, 1, 1, bias=True)
        res_block = []
        trans_block = []
        for i in range(block_num_rbm):
            res_block.append(ResidualBlockNoBN(nf))
        for j in range(block_num_trans):
            trans_block.append(TransformerBlock(nf))

        self.res_block = nn.Sequential(*res_block)
        self.transformer_block = nn.Sequential(*trans_block)
        self.prompt = PromptGenBlock(prompt_dim=nf, prompt_len=prompt_len, prompt_size=36, lin_dim=nf)
        self.fuse = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)

    def forward(self, x):
        x = self.conv_in(x)
        x = self.res_block(x)
        res = self.transformer_block(x) + x
        prompt = self.prompt(res)

        result = self.fuse(torch.cat([res, prompt], dim=1))

        return result


def gauss_noise(shape):
    noise = torch.zeros(shape).cuda()
    for i in range(noise.shape[0]):
        noise[i] = torch.randn(noise[i].shape).cuda()

    return noise


def gauss_noise_mul(shape):
    noise = torch.randn(shape).cuda()

    return noise


class PredictiveModuleBit_prompt(nn.Module):
    def __init__(self, channel_in, nf, prompt_length, block_num_rbm=4, block_num_trans=2):
        super(PredictiveModuleBit_prompt, self).__init__()
        self.conv_in = nn.Conv2d(channel_in, nf, 3, 1, 1, bias=True)
        res_block = []
        trans_block = []
        for i in range(block_num_rbm):
            res_block.append(ResidualBlockNoBN(nf))
        for j in range(block_num_trans):
            trans_block.append(TransformerBlock(nf))

        blocks = 4
        layers = [ConvRelu(nf, 1, 2)]
        for _ in range(blocks - 1):
            layer = ConvRelu(1, 1, 2)
            layers.append(layer)
        self.layers = nn.Sequential(*layers)

        self.res_block = nn.Sequential(*res_block)
        self.transformer_block = nn.Sequential(*trans_block)
        self.prompt = PromptGenBlock(prompt_dim=nf, prompt_len=prompt_length, prompt_size=36, lin_dim=nf)
        self.fuse = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)

    def forward(self, x):
        x = self.conv_in(x)
        x = self.res_block(x)
        res = self.transformer_block(x) + x
        prompt = self.prompt(res)
        res = self.fuse(torch.cat([res, prompt], dim=1))
        res = self.layers(res)

        return res


class VSN(nn.Module):
    def __init__(self, opt, subnet_constructor=None, subnet_constructor_v2=None, down_num=2):
        super(VSN, self).__init__()
        self.model = opt['model']
        self.mode = opt['mode']
        opt_net = opt['network_G']
        self.num_image = opt['num_image']
        self.gop = opt['gop']
        self.channel_in = opt_net['in_nc'] * self.gop
        self.channel_out = opt_net['out_nc'] * self.gop
        self.channel_in_hi = opt_net['in_nc'] * self.gop
        self.channel_in_ho = opt_net['in_nc'] * self.gop
        self.message_len = opt['message_length']

        self.block_num = opt_net['block_num']
        self.block_num_rbm = opt_net['block_num_rbm']
        self.block_num_trans = opt_net['block_num_trans']
        self.nf = self.channel_in_hi

        self.bitencoder = DW_Encoder(self.message_len, attention="se")
        self.bitdecoder = DW_Decoder(self.message_len, attention="se")
        self.irn = InvNN(self.channel_in_ho, self.channel_in_hi, subnet_constructor, subnet_constructor_v2, self.block_num, down_num, groups=self.num_image)

        if opt['prompt']:
            self.pm = PredictiveModuleMIMO_prompt(self.channel_in_ho, self.nf * self.num_image, opt['prompt_len'],
                                                  block_num_rbm=self.block_num_rbm, block_num_trans=self.block_num_trans)
        else:
            # PredictiveModuleMIMO takes no prompt-length argument
            self.pm = PredictiveModuleMIMO(self.channel_in_ho, self.nf * self.num_image,
                                           block_num_rbm=self.block_num_rbm, block_num_trans=self.block_num_trans)
        self.BitPM = PredictiveModuleBit(3, 4, block_num_rbm=4, block_num_trans=2)

    def forward(self, x, x_h=None, message=None, rev=False, hs=[], direction='f'):
        if not rev:
            if self.mode == "image":
                out_y, out_y_h = self.irn(x, x_h, rev)
                out_y = iwt(out_y)
                encoded_image = self.bitencoder(out_y, message)
                return out_y, encoded_image

            elif self.mode == "bit":
                out_y = iwt(x)
                encoded_image = self.bitencoder(out_y, message)
                return out_y, encoded_image

        else:
            if self.mode == "image":
                recmessage = self.bitdecoder(x)

                x = dwt(x)
                out_z = self.pm(x).unsqueeze(1)
                out_z_new = out_z.view(-1, self.num_image, self.channel_in, x.shape[-2], x.shape[-1])
                out_z_new = [out_z_new[:, i] for i in range(self.num_image)]
                out_x, out_x_h = self.irn(x, out_z_new, rev)

                return out_x, out_x_h, out_z, recmessage

            elif self.mode == "bit":
                recmessage = self.bitdecoder(x)
                return recmessage
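Since each InvBlock is an affine coupling layer, running the reverse pass on the forward outputs should reconstruct the inputs exactly, up to floating-point error. A minimal sketch of that check, assuming the DBNet constructors from models/modules/Subnet_constructor.py below and two hidden branches (groups=2); the sizes are arbitrary:

import torch
from models.modules.Inv_arch import InvBlock
from models.modules.Subnet_constructor import subnet

blk = InvBlock(subnet('DBNet', 'xavier'), subnet('DBNet', 'xavier_v2'),
               channel_num_ho=12, channel_num_hi=12, groups=2)

x1 = torch.randn(1, 12, 36, 36)                      # cover branch
x2 = [torch.randn(1, 12, 36, 36) for _ in range(2)]  # one hidden branch per image

y1, y2 = blk(x1, x2, rev=False)
r1, r2 = blk(y1, y2, rev=True)
print((r1 - x1).abs().max())                         # ~1e-6
print(max((r2i - x2i).abs().max() for r2i, x2i in zip(r2, x2)))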
models/modules/Quantization.py
ADDED
@@ -0,0 +1,21 @@
import torch
import torch.nn as nn

class Quant(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        input = torch.clamp(input, 0, 1)
        output = (input * 255.).round() / 255.
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # straight-through estimator: pass gradients through the rounding step unchanged
        return grad_output

class Quantization(nn.Module):
    def __init__(self):
        super(Quantization, self).__init__()

    def forward(self, input):
        return Quant.apply(input)
|
models/modules/Subnet_constructor.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import models.modules.module_util as mutil
|
5 |
+
from basicsr.archs.arch_util import flow_warp, ResidualBlockNoBN
|
6 |
+
from models.modules.module_util import initialize_weights_xavier
|
7 |
+
|
8 |
+
class DenseBlock(nn.Module):
|
9 |
+
def __init__(self, channel_in, channel_out, init='xavier', gc=32, bias=True):
|
10 |
+
super(DenseBlock, self).__init__()
|
11 |
+
self.conv1 = nn.Conv2d(channel_in, gc, 3, 1, 1, bias=bias)
|
12 |
+
self.conv2 = nn.Conv2d(channel_in + gc, gc, 3, 1, 1, bias=bias)
|
13 |
+
self.conv3 = nn.Conv2d(channel_in + 2 * gc, gc, 3, 1, 1, bias=bias)
|
14 |
+
self.conv4 = nn.Conv2d(channel_in + 3 * gc, gc, 3, 1, 1, bias=bias)
|
15 |
+
self.conv5 = nn.Conv2d(channel_in + 4 * gc, channel_out, 3, 1, 1, bias=bias)
|
16 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
17 |
+
self.H = None
|
18 |
+
|
19 |
+
if init == 'xavier':
|
20 |
+
mutil.initialize_weights_xavier([self.conv1, self.conv2, self.conv3, self.conv4], 0.1)
|
21 |
+
else:
|
22 |
+
mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4], 0.1)
|
23 |
+
mutil.initialize_weights(self.conv5, 0)
|
24 |
+
|
25 |
+
def forward(self, x):
|
26 |
+
if isinstance(x, list):
|
27 |
+
x = x[0]
|
28 |
+
x1 = self.lrelu(self.conv1(x))
|
29 |
+
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
30 |
+
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
31 |
+
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
32 |
+
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
33 |
+
|
34 |
+
return x5
|
35 |
+
|
36 |
+
class DenseBlock_v2(nn.Module):
|
37 |
+
def __init__(self, channel_in, channel_out, groups, init='xavier', gc=32, bias=True):
|
38 |
+
super(DenseBlock_v2, self).__init__()
|
39 |
+
self.conv1 = nn.Conv2d(channel_in, gc, 3, 1, 1, bias=bias)
|
40 |
+
self.conv2 = nn.Conv2d(channel_in + gc, gc, 3, 1, 1, bias=bias)
|
41 |
+
self.conv3 = nn.Conv2d(channel_in + 2 * gc, gc, 3, 1, 1, bias=bias)
|
42 |
+
self.conv4 = nn.Conv2d(channel_in + 3 * gc, gc, 3, 1, 1, bias=bias)
|
43 |
+
self.conv5 = nn.Conv2d(channel_in + 4 * gc, channel_out, 3, 1, 1, bias=bias)
|
44 |
+
self.conv_final = nn.Conv2d(channel_out*groups, channel_out, 3, 1, 1, bias=bias)
|
45 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
46 |
+
|
47 |
+
if init == 'xavier':
|
48 |
+
mutil.initialize_weights_xavier([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
|
49 |
+
else:
|
50 |
+
mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
|
51 |
+
mutil.initialize_weights(self.conv_final, 0)
|
52 |
+
|
53 |
+
def forward(self, x):
|
54 |
+
res = []
|
55 |
+
for xi in x:
|
56 |
+
x1 = self.lrelu(self.conv1(xi))
|
57 |
+
x2 = self.lrelu(self.conv2(torch.cat((xi, x1), 1)))
|
58 |
+
x3 = self.lrelu(self.conv3(torch.cat((xi, x1, x2), 1)))
|
59 |
+
x4 = self.lrelu(self.conv4(torch.cat((xi, x1, x2, x3), 1)))
|
60 |
+
x5 = self.lrelu(self.conv5(torch.cat((xi, x1, x2, x3, x4), 1)))
|
61 |
+
res.append(x5)
|
62 |
+
res = torch.cat(res, dim=1)
|
63 |
+
res = self.conv_final(res)
|
64 |
+
|
65 |
+
return res
|
66 |
+
|
67 |
+
def subnet(net_structure, init='xavier'):
|
68 |
+
def constructor(channel_in, channel_out, groups=None):
|
69 |
+
if net_structure == 'DBNet':
|
70 |
+
if init == 'xavier':
|
71 |
+
return DenseBlock(channel_in, channel_out, init)
|
72 |
+
elif init == 'xavier_v2':
|
73 |
+
return DenseBlock_v2(channel_in, channel_out, groups, 'xavier')
|
74 |
+
else:
|
75 |
+
return DenseBlock(channel_in, channel_out)
|
76 |
+
else:
|
77 |
+
return None
|
78 |
+
|
79 |
+
return constructor
|
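subnet returns a constructor rather than a module, so callers such as InvBlock above can decide channel counts (and, for the v2 variant, the group count) at build time. A short usage sketch (shapes are illustrative):

import torch
from models.modules.Subnet_constructor import subnet

make_block = subnet('DBNet', 'xavier')           # a constructor, not yet a module
block = make_block(12, 24)                       # DenseBlock mapping 12 -> 24 channels
print(block(torch.randn(1, 12, 32, 32)).shape)   # torch.Size([1, 24, 32, 32])

make_v2 = subnet('DBNet', 'xavier_v2')           # grouped variant over a list of tensors
block_v2 = make_v2(12, 24, groups=3)
xs = [torch.randn(1, 12, 32, 32) for _ in range(3)]
print(block_v2(xs).shape)                        # torch.Size([1, 24, 32, 32])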
models/modules/__init__.py
ADDED
File without changes
models/modules/__pycache__/Conv1x1.cpython-38.pyc
ADDED
Binary file (1.35 kB)