Ricoooo committed
Commit
5d21dd2
1 Parent(s): 74124af
This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. __pycache__/test_gradio.cpython-310.pyc +0 -0
  2. __pycache__/test_gradio.cpython-38.pyc +0 -0
  3. data/__init__.py +43 -0
  4. data/__pycache__/__init__.cpython-310.pyc +0 -0
  5. data/__pycache__/__init__.cpython-38.pyc +0 -0
  6. data/__pycache__/coco_dataset.cpython-38.pyc +0 -0
  7. data/__pycache__/coco_test_dataset.cpython-38.pyc +0 -0
  8. data/__pycache__/data_sampler.cpython-310.pyc +0 -0
  9. data/__pycache__/data_sampler.cpython-38.pyc +0 -0
  10. data/__pycache__/test_dataset_td.cpython-310.pyc +0 -0
  11. data/__pycache__/test_dataset_td.cpython-38.pyc +0 -0
  12. data/__pycache__/util.cpython-310.pyc +0 -0
  13. data/__pycache__/util.cpython-38.pyc +0 -0
  14. data/__pycache__/video_test_dataset.cpython-38.pyc +0 -0
  15. data/coco_dataset.py +90 -0
  16. data/coco_test_dataset.py +61 -0
  17. data/data_sampler.py +65 -0
  18. data/test_dataset_td.py +63 -0
  19. data/util.py +551 -0
  20. models/IBSN.py +738 -0
  21. models/__init__.py +11 -0
  22. models/__pycache__/IBSN.cpython-310.pyc +0 -0
  23. models/__pycache__/IBSN.cpython-38.pyc +0 -0
  24. models/__pycache__/__init__.cpython-310.pyc +0 -0
  25. models/__pycache__/__init__.cpython-38.pyc +0 -0
  26. models/__pycache__/base_model.cpython-38.pyc +0 -0
  27. models/__pycache__/lr_scheduler.cpython-38.pyc +0 -0
  28. models/__pycache__/networks.cpython-310.pyc +0 -0
  29. models/__pycache__/networks.cpython-38.pyc +0 -0
  30. models/base_model.py +119 -0
  31. models/bitnetwork/ConvBlock.py +38 -0
  32. models/bitnetwork/DW_EncoderDecoder.py +28 -0
  33. models/bitnetwork/Decoder_U.py +87 -0
  34. models/bitnetwork/Dual_Mark.py +249 -0
  35. models/bitnetwork/Encoder_U.py +125 -0
  36. models/bitnetwork/Random_Noise.py +59 -0
  37. models/bitnetwork/ResBlock.py +222 -0
  38. models/bitnetwork/__init__.py +9 -0
  39. models/bitnetwork/__pycache__/ConvBlock.cpython-38.pyc +0 -0
  40. models/bitnetwork/__pycache__/Decoder_U.cpython-38.pyc +0 -0
  41. models/bitnetwork/__pycache__/Encoder_U.cpython-38.pyc +0 -0
  42. models/bitnetwork/__pycache__/ResBlock.cpython-38.pyc +0 -0
  43. models/bitnetwork/__pycache__/__init__.cpython-38.pyc +0 -0
  44. models/discrim.py +169 -0
  45. models/lr_scheduler.py +142 -0
  46. models/modules/Inv_arch.py +584 -0
  47. models/modules/Quantization.py +21 -0
  48. models/modules/Subnet_constructor.py +79 -0
  49. models/modules/__init__.py +0 -0
  50. models/modules/__pycache__/Conv1x1.cpython-38.pyc +0 -0
__pycache__/test_gradio.cpython-310.pyc ADDED
Binary file (2.85 kB).
 
__pycache__/test_gradio.cpython-38.pyc ADDED
Binary file (2.83 kB).
 
data/__init__.py ADDED
@@ -0,0 +1,43 @@
+ '''create dataset and dataloader'''
+ import logging
+ import torch
+ import torch.utils.data
+
+ def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
+     phase = dataset_opt['phase']
+     if phase == 'train':
+         if opt['dist']:
+             world_size = torch.distributed.get_world_size()
+             num_workers = dataset_opt['n_workers']
+             assert dataset_opt['batch_size'] % world_size == 0
+             batch_size = dataset_opt['batch_size'] // world_size
+             shuffle = False
+         else:
+             num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids'])
+             batch_size = dataset_opt['batch_size']
+             shuffle = True
+         return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
+                                            num_workers=num_workers, sampler=sampler, drop_last=True,
+                                            pin_memory=False)
+     else:
+         return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1,
+                                            pin_memory=True)
+
+
+ def create_dataset(dataset_opt):
+     mode = dataset_opt['mode']
+     if mode == 'test':
+         from data.coco_test_dataset import imageTestDataset as D
+     elif mode == 'train':
+         from data.coco_dataset import CoCoDataset as D
+     elif mode == 'td':
+         from data.test_dataset_td import imageTestDataset as D
+     else:
+         raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
+     print(mode)
+     dataset = D(dataset_opt)
+
+     logger = logging.getLogger('base')
+     logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__,
+                                                            dataset_opt['name']))
+     return dataset
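
Note: a minimal sketch of driving the two factories above in the non-distributed case. The option keys mirror what create_dataloader, create_dataset, and CoCoDataset read; the concrete paths and values are placeholders, not part of this commit.

from data import create_dataset, create_dataloader

dataset_opt = {
    'name': 'coco_train',
    'phase': 'train',      # selects the training branch of create_dataloader
    'mode': 'train',       # routes create_dataset to data.coco_dataset.CoCoDataset
    'batch_size': 8,
    'n_workers': 4,
    # keys consumed by CoCoDataset itself (see data/coco_dataset.py below):
    'data_path': '/path/to/coco/images',
    'txt_path': '/path/to/train_list.txt',
    'data_type': 'img',
    'GT_size': 256,
    'num_image': 1,
    'interval_list': [1],
    'random_reverse': False,
}
opt = {'dist': False, 'gpu_ids': [0]}  # non-dist: workers scale with gpu count

train_set = create_dataset(dataset_opt)
train_loader = create_dataloader(train_set, dataset_opt, opt=opt, sampler=None)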
data/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.42 kB).
 
data/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (1.45 kB).
 
data/__pycache__/coco_dataset.cpython-38.pyc ADDED
Binary file (3.3 kB).
 
data/__pycache__/coco_test_dataset.cpython-38.pyc ADDED
Binary file (2.3 kB).
 
data/__pycache__/data_sampler.cpython-310.pyc ADDED
Binary file (2.6 kB).
 
data/__pycache__/data_sampler.cpython-38.pyc ADDED
Binary file (2.63 kB).
 
data/__pycache__/test_dataset_td.cpython-310.pyc ADDED
Binary file (2.32 kB).
 
data/__pycache__/test_dataset_td.cpython-38.pyc ADDED
Binary file (2.29 kB).
 
data/__pycache__/util.cpython-310.pyc ADDED
Binary file (13.7 kB).
 
data/__pycache__/util.cpython-38.pyc ADDED
Binary file (14.2 kB).
 
data/__pycache__/video_test_dataset.cpython-38.pyc ADDED
Binary file (2.48 kB).
 
data/coco_dataset.py ADDED
@@ -0,0 +1,90 @@
+ '''
+ COCO image dataset
+ supports reading images from lmdb, image folder and memcached
+ '''
+ import logging
+ import os
+ import os.path as osp
+ import pickle
+ import random
+
+ import cv2
+ import lmdb
+ import numpy as np
+ import torch
+ import torch.utils.data as data
+
+ import data.util as util
+
+ try:
+     import mc
+ except ImportError:
+     pass
+ logger = logging.getLogger('base')
+
+ class CoCoDataset(data.Dataset):
+     def __init__(self, opt):
+         super(CoCoDataset, self).__init__()
+         self.opt = opt
+         # get train indexes
+         self.data_path = self.opt['data_path']
+         self.txt_path = self.opt['txt_path']
+         with open(self.txt_path) as f:
+             self.list_image = f.readlines()
+         self.list_image = [line.strip('\n') for line in self.list_image]
+         # temporal augmentation
+         self.interval_list = opt['interval_list']
+         self.random_reverse = opt['random_reverse']
+         logger.info('Temporal augmentation interval list: [{}], with random reverse is {}.'.format(
+             ','.join(str(x) for x in opt['interval_list']), self.random_reverse))
+         self.data_type = self.opt['data_type']
+         random.shuffle(self.list_image)
+         self.LR_input = True
+         self.num_image = self.opt['num_image']
+
+     def _ensure_memcached(self):
+         if self.mclient is None:
+             # specify the config files
+             server_list_config_file = None
+             client_config_file = None
+             self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file,
+                                                           client_config_file)
+
+     def __getitem__(self, index):
+         GT_size = self.opt['GT_size']
+         image_name = self.list_image[index]
+         path_frame = os.path.join(self.data_path, image_name)
+         img_GT = util.read_img(None, path_frame)
+         index_h = random.randint(0, len(self.list_image) - 1)
+
+         # random crop
+         H, W, C = img_GT.shape
+         rnd_h = random.randint(0, max(0, H - GT_size))
+         rnd_w = random.randint(0, max(0, W - GT_size))
+         img_frames = img_GT[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :]
+         # BGR to RGB, HWC to CHW, numpy to tensor
+         img_frames = img_frames[:, :, [2, 1, 0]]
+         img_frames = torch.from_numpy(np.ascontiguousarray(np.transpose(img_frames, (2, 0, 1)))).float().unsqueeze(0)
+
+         # process h_list
+         if index_h % 100 == 0:
+             path_frame_h = "../dataset/locwatermark/blue.png"
+         else:
+             image_name_h = self.list_image[index_h]
+             path_frame_h = os.path.join(self.data_path, image_name_h)
+
+         frame_h = util.read_img(None, path_frame_h)
+         H1, W1, C1 = frame_h.shape
+         rnd_h = random.randint(0, max(0, H1 - GT_size))
+         rnd_w = random.randint(0, max(0, W1 - GT_size))
+         img_frames_h = frame_h[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :]
+         img_frames_h = img_frames_h[:, :, [2, 1, 0]]
+         img_frames_h = torch.from_numpy(np.ascontiguousarray(np.transpose(img_frames_h, (2, 0, 1)))).float().unsqueeze(0)
+
+         img_frames_h = torch.nn.functional.interpolate(img_frames_h, size=(512, 512), mode='nearest').unsqueeze(0)
+         img_frames = torch.nn.functional.interpolate(img_frames, size=(512, 512), mode='nearest')
+
+         return {'GT': img_frames, 'LQ': img_frames_h}
+
+     def __len__(self):
+         return len(self.list_image)
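
Note: every sample above goes through the same crop-and-convert idiom. A standalone sketch of that idiom, assuming a float32 BGR image in [0, 1] as produced by util.read_img; the helper name is illustrative, not from the commit.

import random
import numpy as np
import torch

def to_chw_rgb_tensor(img_bgr, gt_size):
    # random crop (degenerates to a no-op when the image is smaller than gt_size)
    H, W, _ = img_bgr.shape
    top = random.randint(0, max(0, H - gt_size))
    left = random.randint(0, max(0, W - gt_size))
    patch = img_bgr[top:top + gt_size, left:left + gt_size, :]
    patch = patch[:, :, [2, 1, 0]]  # BGR -> RGB channel swap
    # HWC -> CHW, then add a leading frame axis: (1, C, h, w)
    return torch.from_numpy(np.ascontiguousarray(patch.transpose(2, 0, 1))).float().unsqueeze(0)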
data/coco_test_dataset.py ADDED
@@ -0,0 +1,61 @@
+ import os
+ import os.path as osp
+ import torch
+ import torch.utils.data as data
+ import data.util as util
+
+ import random
+ import numpy as np
+ from PIL import Image
+
+ class imageTestDataset(data.Dataset):
+
+     def __init__(self, opt):
+         super(imageTestDataset, self).__init__()
+         self.opt = opt
+         self.half_N_frames = opt['N_frames'] // 2
+         self.data_path = opt['data_path']
+         self.bit_path = opt['bit_path']
+         self.txt_path = self.opt['txt_path']
+         self.num_image = self.opt['num_image']
+         with open(self.txt_path) as f:
+             self.list_image = f.readlines()
+         self.list_image = [line.strip('\n') for line in self.list_image]
+         self.list_image.sort()
+         self.list_image = self.list_image
+         l = len(self.list_image) // (self.num_image + 1)
+         self.image_list_gt = self.list_image
+
+     def __getitem__(self, index):
+         path_GT = self.image_list_gt[index]
+
+         img_GT = util.read_img(None, osp.join(self.data_path, path_GT))
+         img_GT = img_GT[:, :, [2, 1, 0]]
+         img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float().unsqueeze(0)
+         img_GT = torch.nn.functional.interpolate(img_GT, size=(512, 512), mode='nearest')
+
+         T, C, H, W = img_GT.shape
+         list_h = []
+         R = 0
+         G = 0
+         B = 255
+         image = Image.new('RGB', (W, H), (R, G, B))
+         result = np.array(image) / 255.
+         expanded_matrix = np.expand_dims(result, axis=0)
+         expanded_matrix = np.repeat(expanded_matrix, T, axis=0)
+         imgs_LQ = torch.from_numpy(np.ascontiguousarray(expanded_matrix)).float()
+         imgs_LQ = imgs_LQ.permute(0, 3, 1, 2)
+
+         imgs_LQ = torch.nn.functional.interpolate(imgs_LQ, size=(H, W), mode='nearest')
+
+         list_h.append(imgs_LQ)
+
+         list_h = torch.stack(list_h, dim=0)
+
+         return {
+             'LQ': list_h,
+             'GT': img_GT
+         }
+
+     def __len__(self):
+         return len(self.image_list_gt)
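
Note: the LQ returned above is just a constant blue canvas repeated along the frame axis. A condensed equivalent, assuming H = W = 512 as enforced by the resize above:

import torch

T, H, W = 1, 512, 512
blue = torch.tensor([0.0, 0.0, 1.0]).view(1, 3, 1, 1)  # RGB (0, 0, 255) / 255
imgs_LQ = blue.expand(T, 3, H, W).contiguous()         # (T, 3, H, W)
list_h = imgs_LQ.unsqueeze(0)                          # (1, T, 3, H, W), as stacked above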
data/data_sampler.py ADDED
@@ -0,0 +1,65 @@
+ """
+ Modified from torch.utils.data.distributed.DistributedSampler
+ Supports enlarging the dataset for *iter-oriented* training, saving time when restarting the
+ dataloader after each epoch
+ """
+ import math
+ import torch
+ from torch.utils.data.sampler import Sampler
+ import torch.distributed as dist
+
+
+ class DistIterSampler(Sampler):
+     """Sampler that restricts data loading to a subset of the dataset.
+
+     It is especially useful in conjunction with
+     :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
+     process can pass a DistributedSampler instance as a DataLoader sampler,
+     and load a subset of the original dataset that is exclusive to it.
+
+     .. note::
+         Dataset is assumed to be of constant size.
+
+     Arguments:
+         dataset: Dataset used for sampling.
+         num_replicas (optional): Number of processes participating in
+             distributed training.
+         rank (optional): Rank of the current process within num_replicas.
+     """
+
+     def __init__(self, dataset, num_replicas=None, rank=None, ratio=100):
+         if num_replicas is None:
+             if not dist.is_available():
+                 raise RuntimeError("Requires distributed package to be available")
+             num_replicas = dist.get_world_size()
+         if rank is None:
+             if not dist.is_available():
+                 raise RuntimeError("Requires distributed package to be available")
+             rank = dist.get_rank()
+         self.dataset = dataset
+         self.num_replicas = num_replicas
+         self.rank = rank
+         self.epoch = 0
+         self.num_samples = int(math.ceil(len(self.dataset) * ratio / self.num_replicas))
+         self.total_size = self.num_samples * self.num_replicas
+
+     def __iter__(self):
+         # deterministically shuffle based on epoch
+         g = torch.Generator()
+         g.manual_seed(self.epoch)
+         indices = torch.randperm(self.total_size, generator=g).tolist()
+
+         dsize = len(self.dataset)
+         indices = [v % dsize for v in indices]
+
+         # subsample
+         indices = indices[self.rank:self.total_size:self.num_replicas]
+         assert len(indices) == self.num_samples
+
+         return iter(indices)
+
+     def __len__(self):
+         return self.num_samples
+
+     def set_epoch(self, epoch):
+         self.epoch = epoch
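
Note: a minimal sketch of wiring DistIterSampler into the distributed branch of create_dataloader. It assumes torch.distributed has already been initialized and reuses train_set/dataset_opt from the earlier sketch; ratio=100 enlarges one nominal epoch to roughly 100 passes over the data, so the loader restarts rarely.

import torch.distributed as dist
from data import create_dataloader
from data.data_sampler import DistIterSampler

world_size, rank = dist.get_world_size(), dist.get_rank()
sampler = DistIterSampler(train_set, num_replicas=world_size, rank=rank, ratio=100)
loader = create_dataloader(train_set, dataset_opt, opt={'dist': True}, sampler=sampler)
# len(sampler) == ceil(len(train_set) * 100 / world_size); indices wrap modulo the
# real dataset size, so each rank sees a disjoint, shuffled, enlarged index stream.

for epoch in range(100):
    sampler.set_epoch(epoch)  # reseeds the deterministic shuffle per epoch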
data/test_dataset_td.py ADDED
@@ -0,0 +1,63 @@
+ import os
+ import os.path as osp
+ import torch
+ import torch.utils.data as data
+ import data.util as util
+
+ import random
+ import numpy as np
+ from PIL import Image
+
+ class imageTestDataset(data.Dataset):
+
+     def __init__(self, opt):
+         super(imageTestDataset, self).__init__()
+         self.opt = opt
+         self.half_N_frames = opt['N_frames'] // 2
+         self.data_path = opt['data_path']
+         self.bit_path = opt['bit_path']
+         self.txt_path = self.opt['txt_path']
+         self.num_image = self.opt['num_image']
+         with open(self.txt_path) as f:
+             self.list_image = f.readlines()
+         self.list_image = [line.strip('\n') for line in self.list_image]
+         # self.list_image = sorted(self.list_image)
+         l = len(self.list_image) // (self.num_image + 1)
+         self.image_list_gt = self.list_image
+         self.image_list_bit = self.list_image
+
+
+     def __getitem__(self, index):
+         path_GT = self.image_list_gt[index]
+
+         img_GT = util.read_img(None, osp.join(self.data_path, path_GT))
+         img_GT = img_GT[:, :, [2, 1, 0]]
+         img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float().unsqueeze(0)
+         img_GT = torch.nn.functional.interpolate(img_GT, size=(512, 512), mode='nearest')
+
+         T, C, H, W = img_GT.shape
+         list_h = []
+         R = 0
+         G = 0
+         B = 255
+         image = Image.new('RGB', (W, H), (R, G, B))
+         result = np.array(image) / 255.
+         expanded_matrix = np.expand_dims(result, axis=0)
+         expanded_matrix = np.repeat(expanded_matrix, T, axis=0)
+         imgs_LQ = torch.from_numpy(np.ascontiguousarray(expanded_matrix)).float()
+         imgs_LQ = imgs_LQ.permute(0, 3, 1, 2)
+
+
+         imgs_LQ = torch.nn.functional.interpolate(imgs_LQ, size=(H, W), mode='nearest')
+
+         list_h.append(imgs_LQ)
+
+         list_h = torch.stack(list_h, dim=0)
+
+         return {
+             'LQ': list_h,
+             'GT': img_GT
+         }
+
+     def __len__(self):
+         return len(self.image_list_gt)
data/util.py ADDED
@@ -0,0 +1,551 @@
+ import os
+ import math
+ import pickle
+ import random
+ import numpy as np
+ import glob
+ import torch
+ import cv2
+
+ ####################
+ # Files & IO
+ ####################
+
+ ###################### get image path list ######################
+ IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP']
+
+
+ def is_image_file(filename):
+     return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+ def _get_paths_from_images(path):
+     '''get image path list from image folder'''
+     assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
+     images = []
+     for dirpath, _, fnames in sorted(os.walk(path)):
+         for fname in sorted(fnames):
+             if is_image_file(fname):
+                 img_path = os.path.join(dirpath, fname)
+                 images.append(img_path)
+     assert images, '{:s} has no valid image file'.format(path)
+     return images
+
+
+ def _get_paths_from_lmdb(dataroot):
+     '''get image path list from lmdb meta info'''
+     meta_info = pickle.load(open(os.path.join(dataroot, 'meta_info.pkl'), 'rb'))
+     paths = meta_info['keys']
+     sizes = meta_info['resolution']
+     if len(sizes) == 1:
+         sizes = sizes * len(paths)
+     return paths, sizes
+
+
+ def get_image_paths(data_type, dataroot):
+     '''get image path list
+     supports lmdb or image files'''
+     paths, sizes = None, None
+     if dataroot is not None:
+         if data_type == 'lmdb':
+             paths, sizes = _get_paths_from_lmdb(dataroot)
+         elif data_type == 'img':
+             paths = sorted(_get_paths_from_images(dataroot))
+         else:
+             raise NotImplementedError('data_type [{:s}] is not recognized.'.format(data_type))
+     return paths, sizes
+
+
+ def glob_file_list(root):
+     return sorted(glob.glob(os.path.join(root, '*')))
+
+
+ ###################### read images ######################
+ def _read_img_lmdb(env, key, size):
+     '''read image from lmdb with key (w/ and w/o fixed size)
+     size: (C, H, W) tuple'''
+     with env.begin(write=False) as txn:
+         buf = txn.get(key.encode('ascii'))
+     img_flat = np.frombuffer(buf, dtype=np.uint8)
+     C, H, W = size
+     img = img_flat.reshape(H, W, C)
+     return img
+
+
+ def read_img(env, path, size=None):
+     '''read image by cv2 or from lmdb
+     return: Numpy float32, HWC, BGR, [0,1]'''
+     if env is None:  # img
+         # print(path)
+         # img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+         img = cv2.imread(path, cv2.IMREAD_COLOR)
+     else:
+         img = _read_img_lmdb(env, path, size)
+     # print(img.shape)
+     # if img is None:
+     #     print(path)
+     # print(img.shape)
+     img = img.astype(np.float32) / 255.
+     if img.ndim == 2:
+         img = np.expand_dims(img, axis=2)
+     # some images have 4 channels
+     if img.shape[2] > 3:
+         img = img[:, :, :3]
+     return img
+
+
+ def read_img_seq(path):
+     """Read a sequence of images from a given folder path
+     Args:
+         path (list/str): list of image paths/image folder path
+
+     Returns:
+         imgs (Tensor): size (T, C, H, W), RGB, [0, 1]
+     """
+     if type(path) is list:
+         img_path_l = path
+     else:
+         img_path_l = sorted(glob.glob(os.path.join(path, '*.png')))
+     # print(path)
+     # print(path, img_path_l)
+     img_l = [read_img(None, v) for v in img_path_l]
+     # stack to Torch tensor
+     imgs = np.stack(img_l, axis=0)
+     imgs = imgs[:, :, :, [2, 1, 0]]
+     imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float()
+     return imgs
+
+
+ def index_generation(crt_i, max_n, N, padding='reflection'):
+     """Generate an index list for reading N frames from a sequence of images
+     Args:
+         crt_i (int): current center index
+         max_n (int): max number of the sequence of images (calculated from 1)
+         N (int): reading N frames
+         padding (str): padding mode, one of replicate | reflection | new_info | circle
+             Example: crt_i = 0, N = 5
+             replicate: [0, 0, 0, 1, 2]
+             reflection: [2, 1, 0, 1, 2]
+             new_info: [4, 3, 0, 1, 2]
+             circle: [3, 4, 0, 1, 2]
+
+     Returns:
+         return_l (list [int]): a list of indexes
+     """
+     max_n = max_n - 1
+     n_pad = N // 2
+     return_l = []
+
+     for i in range(crt_i - n_pad, crt_i + n_pad + 1):
+         if i < 0:
+             if padding == 'replicate':
+                 add_idx = 0
+             elif padding == 'reflection':
+                 add_idx = -i
+             elif padding == 'new_info':
+                 add_idx = (crt_i + n_pad) + (-i)
+             elif padding == 'circle':
+                 add_idx = N + i
+             else:
+                 raise ValueError('Wrong padding mode')
+         elif i > max_n:
+             if padding == 'replicate':
+                 add_idx = max_n
+             elif padding == 'reflection':
+                 add_idx = max_n * 2 - i
+             elif padding == 'new_info':
+                 add_idx = (crt_i - n_pad) - (i - max_n)
+             elif padding == 'circle':
+                 add_idx = i - N
+             else:
+                 raise ValueError('Wrong padding mode')
+         else:
+             add_idx = i
+         return_l.append(add_idx)
+     return return_l
+
+
+ ####################
+ # image processing
+ # process on numpy image
+ ####################
+
+
+ def augment(img_list, hflip=True, rot=True):
+     # horizontal flip OR rotate
+     hflip = hflip and random.random() < 0.5
+     vflip = rot and random.random() < 0.5
+     rot90 = rot and random.random() < 0.5
+
+     def _augment(img):
+         if hflip:
+             img = img[:, ::-1, :]
+         if vflip:
+             img = img[::-1, :, :]
+         if rot90:
+             img = img.transpose(1, 0, 2)
+         return img
+
+     return [_augment(img) for img in img_list]
+
+
+ def augment_flow(img_list, flow_list, hflip=True, rot=True):
+     # horizontal flip OR rotate
+     hflip = hflip and random.random() < 0.5
+     vflip = rot and random.random() < 0.5
+     rot90 = rot and random.random() < 0.5
+
+     def _augment(img):
+         if hflip:
+             img = img[:, ::-1, :]
+         if vflip:
+             img = img[::-1, :, :]
+         if rot90:
+             img = img.transpose(1, 0, 2)
+         return img
+
+     def _augment_flow(flow):
+         if hflip:
+             flow = flow[:, ::-1, :]
+             flow[:, :, 0] *= -1
+         if vflip:
+             flow = flow[::-1, :, :]
+             flow[:, :, 1] *= -1
+         if rot90:
+             flow = flow.transpose(1, 0, 2)
+             flow = flow[:, :, [1, 0]]
+         return flow
+
+     rlt_img_list = [_augment(img) for img in img_list]
+     rlt_flow_list = [_augment_flow(flow) for flow in flow_list]
+
+     return rlt_img_list, rlt_flow_list
+
+
+ def channel_convert(in_c, tar_type, img_list):
+     # conversion among BGR, gray and y
+     if in_c == 3 and tar_type == 'gray':  # BGR to gray
+         gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
+         return [np.expand_dims(img, axis=2) for img in gray_list]
+     elif in_c == 3 and tar_type == 'y':  # BGR to y
+         y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
+         return [np.expand_dims(img, axis=2) for img in y_list]
+     elif in_c == 1 and tar_type == 'RGB':  # gray/y to BGR
+         return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
+     else:
+         return img_list
+
+
+ def rgb2ycbcr(img, only_y=True):
+     '''same as matlab rgb2ycbcr
+     only_y: only return Y channel
+     Input:
+         uint8, [0, 255]
+         float, [0, 1]
+     '''
+     in_img_type = img.dtype
+     img = img.astype(np.float32)
+     if in_img_type != np.uint8:
+         img *= 255.
+     # convert
+     if only_y:
+         rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
+     else:
+         rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
+                               [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
+     if in_img_type == np.uint8:
+         rlt = rlt.round()
+     else:
+         rlt /= 255.
+     return rlt.astype(in_img_type)
+
+
+ def bgr2ycbcr(img, only_y=True):
+     '''bgr version of rgb2ycbcr
+     only_y: only return Y channel
+     Input:
+         uint8, [0, 255]
+         float, [0, 1]
+     '''
+     in_img_type = img.dtype
+     img = img.astype(np.float32)
+     if in_img_type != np.uint8:
+         img *= 255.
+     # convert
+     if only_y:
+         rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
+     else:
+         rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
+                               [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
+     if in_img_type == np.uint8:
+         rlt = rlt.round()
+     else:
+         rlt /= 255.
+     return rlt.astype(in_img_type)
+
+
+ def ycbcr2rgb(img):
+     '''same as matlab ycbcr2rgb
+     Input:
+         uint8, [0, 255]
+         float, [0, 1]
+     '''
+     in_img_type = img.dtype
+     img = img.astype(np.float32)
+     if in_img_type != np.uint8:
+         img *= 255.
+     # convert
+     rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
+                           [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
+     if in_img_type == np.uint8:
+         rlt = rlt.round()
+     else:
+         rlt /= 255.
+     return rlt.astype(in_img_type)
+
+
+ def modcrop(img_in, scale):
+     # img_in: Numpy, HWC or HW
+     img = np.copy(img_in)
+     if img.ndim == 2:
+         H, W = img.shape
+         H_r, W_r = H % scale, W % scale
+         img = img[:H - H_r, :W - W_r]
+     elif img.ndim == 3:
+         H, W, C = img.shape
+         H_r, W_r = H % scale, W % scale
+         img = img[:H - H_r, :W - W_r, :]
+     else:
+         raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
+     return img
+
+
+ ####################
+ # Functions
+ ####################
+
+
+ # matlab 'imresize' function, now only supports 'bicubic'
+ def cubic(x):
+     absx = torch.abs(x)
+     absx2 = absx**2
+     absx3 = absx**3
+     return (1.5 * absx3 - 2.5 * absx2 + 1) * (
+         (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * ((
+             (absx > 1) * (absx <= 2)).type_as(absx))
+
+
+ def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
+     if (scale < 1) and (antialiasing):
+         # Use a modified kernel to simultaneously interpolate and antialias - larger kernel width
+         kernel_width = kernel_width / scale
+
+     # Output-space coordinates
+     x = torch.linspace(1, out_length, out_length)
+
+     # Input-space coordinates. Calculate the inverse mapping such that 0.5
+     # in output space maps to 0.5 in input space, and 0.5+scale in output
+     # space maps to 1.5 in input space.
+     u = x / scale + 0.5 * (1 - 1 / scale)
+
+     # What is the left-most pixel that can be involved in the computation?
+     left = torch.floor(u - kernel_width / 2)
+
+     # What is the maximum number of pixels that can be involved in the
+     # computation?  Note: it's OK to use an extra pixel here; if the
+     # corresponding weights are all zero, it will be eliminated at the end
+     # of this function.
+     P = math.ceil(kernel_width) + 2
+
+     # The indices of the input pixels involved in computing the k-th output
+     # pixel are in row k of the indices matrix.
+     indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
+         1, P).expand(out_length, P)
+
+     # The weights used to compute the k-th output pixel are in row k of the
+     # weights matrix.
+     distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
+     # apply cubic kernel
+     if (scale < 1) and (antialiasing):
+         weights = scale * cubic(distance_to_center * scale)
+     else:
+         weights = cubic(distance_to_center)
+     # Normalize the weights matrix so that each row sums to 1.
+     weights_sum = torch.sum(weights, 1).view(out_length, 1)
+     weights = weights / weights_sum.expand(out_length, P)
+
+     # If a column in weights is all zero, get rid of it. Only consider the first and last column.
+     weights_zero_tmp = torch.sum((weights == 0), 0)
+     if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
+         indices = indices.narrow(1, 1, P - 2)
+         weights = weights.narrow(1, 1, P - 2)
+     if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
+         indices = indices.narrow(1, 0, P - 2)
+         weights = weights.narrow(1, 0, P - 2)
+     weights = weights.contiguous()
+     indices = indices.contiguous()
+     sym_len_s = -indices.min() + 1
+     sym_len_e = indices.max() - in_length
+     indices = indices + sym_len_s - 1
+     return weights, indices, int(sym_len_s), int(sym_len_e)
+
+
+ def imresize(img, scale, antialiasing=True):
+     # Now the scale should be the same for H and W
+     # input: img: CHW RGB [0,1]
+     # output: CHW RGB [0,1] w/o round
+
+     in_C, in_H, in_W = img.size()
+     _, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
+     kernel_width = 4
+     kernel = 'cubic'
+
+     # Return the desired dimension order for performing the resize. The
+     # strategy is to perform the resize first along the dimension with the
+     # smallest scale factor.
+     # Now we do not support this.
+
+     # get weights and indices
+     weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
+         in_H, out_H, scale, kernel, kernel_width, antialiasing)
+     weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
+         in_W, out_W, scale, kernel, kernel_width, antialiasing)
+     # process H dimension
+     # symmetric copying
+     img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
+     img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
+
+     sym_patch = img[:, :sym_len_Hs, :]
+     inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+     sym_patch_inv = sym_patch.index_select(1, inv_idx)
+     img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
+
+     sym_patch = img[:, -sym_len_He:, :]
+     inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+     sym_patch_inv = sym_patch.index_select(1, inv_idx)
+     img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
+
+     out_1 = torch.FloatTensor(in_C, out_H, in_W)
+     kernel_width = weights_H.size(1)
+     for i in range(out_H):
+         idx = int(indices_H[i][0])
+         out_1[0, i, :] = img_aug[0, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
+         out_1[1, i, :] = img_aug[1, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
+         out_1[2, i, :] = img_aug[2, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
+
+     # process W dimension
+     # symmetric copying
+     out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
+     out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
+
+     sym_patch = out_1[:, :, :sym_len_Ws]
+     inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+     sym_patch_inv = sym_patch.index_select(2, inv_idx)
+     out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
+
+     sym_patch = out_1[:, :, -sym_len_We:]
+     inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+     sym_patch_inv = sym_patch.index_select(2, inv_idx)
+     out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
+
+     out_2 = torch.FloatTensor(in_C, out_H, out_W)
+     kernel_width = weights_W.size(1)
+     for i in range(out_W):
+         idx = int(indices_W[i][0])
+         out_2[0, :, i] = out_1_aug[0, :, idx:idx + kernel_width].mv(weights_W[i])
+         out_2[1, :, i] = out_1_aug[1, :, idx:idx + kernel_width].mv(weights_W[i])
+         out_2[2, :, i] = out_1_aug[2, :, idx:idx + kernel_width].mv(weights_W[i])
+
+     return out_2
+
+
+ def imresize_np(img, scale, antialiasing=True):
+     # Now the scale should be the same for H and W
+     # input: img: Numpy, HWC BGR [0,1]
+     # output: HWC BGR [0,1] w/o round
+     img = torch.from_numpy(img)
+
+     in_H, in_W, in_C = img.size()
+     _, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
+     kernel_width = 4
+     kernel = 'cubic'
+
+     # Return the desired dimension order for performing the resize. The
+     # strategy is to perform the resize first along the dimension with the
+     # smallest scale factor.
+     # Now we do not support this.
+
+     # get weights and indices
+     weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
+         in_H, out_H, scale, kernel, kernel_width, antialiasing)
+     weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
+         in_W, out_W, scale, kernel, kernel_width, antialiasing)
+     # process H dimension
+     # symmetric copying
+     img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
+     img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
+
+     sym_patch = img[:sym_len_Hs, :, :]
+     inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+     sym_patch_inv = sym_patch.index_select(0, inv_idx)
+     img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
+
+     sym_patch = img[-sym_len_He:, :, :]
+     inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+     sym_patch_inv = sym_patch.index_select(0, inv_idx)
+     img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
+
+     out_1 = torch.FloatTensor(out_H, in_W, in_C)
+     kernel_width = weights_H.size(1)
+     for i in range(out_H):
+         idx = int(indices_H[i][0])
+         out_1[i, :, 0] = img_aug[idx:idx + kernel_width, :, 0].transpose(0, 1).mv(weights_H[i])
+         out_1[i, :, 1] = img_aug[idx:idx + kernel_width, :, 1].transpose(0, 1).mv(weights_H[i])
+         out_1[i, :, 2] = img_aug[idx:idx + kernel_width, :, 2].transpose(0, 1).mv(weights_H[i])
+
+     # process W dimension
+     # symmetric copying
+     out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
+     out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
+
+     sym_patch = out_1[:, :sym_len_Ws, :]
+     inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+     sym_patch_inv = sym_patch.index_select(1, inv_idx)
+     out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
+
+     sym_patch = out_1[:, -sym_len_We:, :]
+     inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+     sym_patch_inv = sym_patch.index_select(1, inv_idx)
+     out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
+
+     out_2 = torch.FloatTensor(out_H, out_W, in_C)
+     kernel_width = weights_W.size(1)
+     for i in range(out_W):
+         idx = int(indices_W[i][0])
+         out_2[:, i, 0] = out_1_aug[:, idx:idx + kernel_width, 0].mv(weights_W[i])
+         out_2[:, i, 1] = out_1_aug[:, idx:idx + kernel_width, 1].mv(weights_W[i])
+         out_2[:, i, 2] = out_1_aug[:, idx:idx + kernel_width, 2].mv(weights_W[i])
+
+     return out_2.numpy()
+
+
+ if __name__ == '__main__':
+     # test imresize function
+     # read images
+     img = cv2.imread('test.png')
+     img = img * 1.0 / 255
+     img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
+     # imresize
+     scale = 1 / 4
+     import time
+     total_time = 0
+     for i in range(10):
+         start_time = time.time()
+         rlt = imresize(img, scale, antialiasing=True)
+         use_time = time.time() - start_time
+         total_time += use_time
+     print('average time: {}'.format(total_time / 10))
+
+     import torchvision.utils
+     torchvision.utils.save_image((rlt * 255).round() / 255, 'rlt.png', nrow=1, padding=0,
+                                  normalize=False)
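
Note: the __main__ block above benchmarks imresize on a CHW tensor; the numpy twin imresize_np is used the same way on HWC arrays. A short usage sketch ('example.png' is a placeholder path):

import cv2
import numpy as np
from data.util import read_img, imresize_np

img = read_img(None, 'example.png')  # HWC, BGR, float32 in [0, 1]
small = imresize_np(img, 1 / 4)      # MATLAB-style bicubic downscale, antialiased
back = imresize_np(small, 4)         # upscale (the antialias kernel only widens for scale < 1)
cv2.imwrite('small.png', (small * 255).round().clip(0, 255).astype(np.uint8))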
models/IBSN.py ADDED
@@ -0,0 +1,738 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from collections import OrderedDict
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.nn.parallel import DataParallel, DistributedDataParallel
7
+
8
+ import models.networks as networks
9
+ import models.lr_scheduler as lr_scheduler
10
+ from .base_model import BaseModel
11
+ from models.modules.loss import ReconstructionLoss, ReconstructionMsgLoss
12
+ from models.modules.Quantization import Quantization
13
+ from .modules.common import DWT,IWT
14
+ from utils.jpegtest import JpegTest
15
+ from utils.JPEG import DiffJPEG
16
+ import utils.util as util
17
+
18
+
19
+ import numpy as np
20
+ import random
21
+ import cv2
22
+ import time
23
+
24
+ logger = logging.getLogger('base')
25
+ dwt=DWT()
26
+ iwt=IWT()
27
+
28
+ from diffusers import StableDiffusionInpaintPipeline
29
+ from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDIMScheduler
30
+ from diffusers import StableDiffusionXLInpaintPipeline
31
+ from diffusers.utils import load_image
32
+ from diffusers import RePaintPipeline, RePaintScheduler
33
+
34
+ class Model_VSN(BaseModel):
35
+ def __init__(self, opt):
36
+ super(Model_VSN, self).__init__(opt)
37
+
38
+ if opt['dist']:
39
+ self.rank = torch.distributed.get_rank()
40
+ else:
41
+ self.rank = -1 # non dist training
42
+
43
+ self.gop = opt['gop']
44
+ train_opt = opt['train']
45
+ test_opt = opt['test']
46
+ self.opt = opt
47
+ self.train_opt = train_opt
48
+ self.test_opt = test_opt
49
+ self.opt_net = opt['network_G']
50
+ self.center = self.gop // 2
51
+ self.num_image = opt['num_image']
52
+ self.mode = opt["mode"]
53
+ self.idxx = 0
54
+
55
+ self.netG = networks.define_G_v2(opt).to(self.device)
56
+ if opt['dist']:
57
+ self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
58
+ else:
59
+ self.netG = DataParallel(self.netG)
60
+ # print network
61
+ self.print_network()
62
+ self.load()
63
+
64
+ self.Quantization = Quantization()
65
+
66
+ if not self.opt['hide']:
67
+ file_path = "bit_sequence.txt"
68
+
69
+ data_list = []
70
+
71
+ with open(file_path, "r") as file:
72
+ for line in file:
73
+ data = [int(bit) for bit in line.strip()]
74
+ data_list.append(data)
75
+
76
+ self.msg_list = data_list
77
+
78
+ if self.opt['sdinpaint']:
79
+ self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
80
+ "stabilityai/stable-diffusion-2-inpainting",
81
+ torch_dtype=torch.float16,
82
+ ).to("cuda")
83
+
84
+ if self.opt['controlnetinpaint']:
85
+ controlnet = ControlNetModel.from_pretrained(
86
+ "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float32
87
+ ).to("cuda")
88
+ self.pipe_control = StableDiffusionControlNetInpaintPipeline.from_pretrained(
89
+ "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float32
90
+ ).to("cuda")
91
+
92
+ if self.opt['sdxl']:
93
+ self.pipe_sdxl = StableDiffusionXLInpaintPipeline.from_pretrained(
94
+ "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
95
+ torch_dtype=torch.float16,
96
+ variant="fp16",
97
+ use_safetensors=True,
98
+ ).to("cuda")
99
+
100
+ if self.opt['repaint']:
101
+ self.scheduler = RePaintScheduler.from_pretrained("google/ddpm-ema-celebahq-256")
102
+ self.pipe_repaint = RePaintPipeline.from_pretrained("google/ddpm-ema-celebahq-256", scheduler=self.scheduler)
103
+ self.pipe_repaint = self.pipe_repaint.to("cuda")
104
+
105
+ if self.is_train:
106
+ self.netG.train()
107
+
108
+ # loss
109
+ self.Reconstruction_forw = ReconstructionLoss(losstype=self.train_opt['pixel_criterion_forw'])
110
+ self.Reconstruction_back = ReconstructionLoss(losstype=self.train_opt['pixel_criterion_back'])
111
+ self.Reconstruction_center = ReconstructionLoss(losstype="center")
112
+ self.Reconstruction_msg = ReconstructionMsgLoss(losstype=self.opt['losstype'])
113
+
114
+ # optimizers
115
+ wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
116
+ optim_params = []
117
+
118
+ if self.mode == "image":
119
+ for k, v in self.netG.named_parameters():
120
+ if (k.startswith('module.irn') or k.startswith('module.pm')) and v.requires_grad:
121
+ optim_params.append(v)
122
+ else:
123
+ if self.rank <= 0:
124
+ logger.warning('Params [{:s}] will not optimize.'.format(k))
125
+
126
+ elif self.mode == "bit":
127
+ for k, v in self.netG.named_parameters():
128
+ if (k.startswith('module.bitencoder') or k.startswith('module.bitdecoder')) and v.requires_grad:
129
+ optim_params.append(v)
130
+ else:
131
+ if self.rank <= 0:
132
+ logger.warning('Params [{:s}] will not optimize.'.format(k))
133
+
134
+
135
+ self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
136
+ weight_decay=wd_G,
137
+ betas=(train_opt['beta1'], train_opt['beta2']))
138
+ self.optimizers.append(self.optimizer_G)
139
+
140
+ # schedulers
141
+ if train_opt['lr_scheme'] == 'MultiStepLR':
142
+ for optimizer in self.optimizers:
143
+ self.schedulers.append(
144
+ lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
145
+ restarts=train_opt['restarts'],
146
+ weights=train_opt['restart_weights'],
147
+ gamma=train_opt['lr_gamma'],
148
+ clear_state=train_opt['clear_state']))
149
+ elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
150
+ for optimizer in self.optimizers:
151
+ self.schedulers.append(
152
+ lr_scheduler.CosineAnnealingLR_Restart(
153
+ optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
154
+ restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
155
+ else:
156
+ raise NotImplementedError('MultiStepLR learning rate scheme is enough.')
157
+
158
+ self.log_dict = OrderedDict()
159
+
160
+ def feed_data(self, data):
161
+ self.ref_L = data['LQ'].to(self.device)
162
+ self.real_H = data['GT'].to(self.device)
163
+ self.mes = data['MES']
164
+
165
+ def init_hidden_state(self, z):
166
+ b, c, h, w = z.shape
167
+ h_t = []
168
+ c_t = []
169
+ for _ in range(self.opt_net['block_num_rbm']):
170
+ h_t.append(torch.zeros([b, c, h, w]).cuda())
171
+ c_t.append(torch.zeros([b, c, h, w]).cuda())
172
+ memory = torch.zeros([b, c, h, w]).cuda()
173
+
174
+ return h_t, c_t, memory
175
+
176
+ def loss_forward(self, out, y):
177
+ l_forw_fit = self.train_opt['lambda_fit_forw'] * self.Reconstruction_forw(out, y)
178
+ return l_forw_fit
179
+
180
+ def loss_back_rec(self, out, x):
181
+ l_back_rec = self.train_opt['lambda_rec_back'] * self.Reconstruction_back(out, x)
182
+ return l_back_rec
183
+
184
+ def loss_back_rec_mul(self, out, x):
185
+ l_back_rec = self.train_opt['lambda_rec_back'] * self.Reconstruction_back(out, x)
186
+ return l_back_rec
187
+
188
+ def optimize_parameters(self, current_step):
189
+ self.optimizer_G.zero_grad()
190
+
191
+ b, n, t, c, h, w = self.ref_L.shape
192
+ center = t // 2
193
+ intval = self.gop // 2
194
+
195
+ message = torch.Tensor(np.random.choice([-0.5, 0.5], (self.ref_L.shape[0], self.opt['message_length']))).to(self.device)
196
+
197
+ add_noise = self.opt['addnoise']
198
+ add_jpeg = self.opt['addjpeg']
199
+ add_possion = self.opt['addpossion']
200
+ add_sdinpaint = self.opt['sdinpaint']
201
+ degrade_shuffle = self.opt['degrade_shuffle']
202
+
203
+ self.host = self.real_H[:, center - intval:center + intval + 1]
204
+ self.secret = self.ref_L[:, :, center - intval:center + intval + 1]
205
+ self.output, container = self.netG(x=dwt(self.host.reshape(b, -1, h, w)), x_h=dwt(self.secret[:,0].reshape(b, -1, h, w)), message=message)
206
+
207
+ Gt_ref = self.real_H[:, center - intval:center + intval + 1].detach()
208
+
209
+ y_forw = container
210
+
211
+ l_forw_fit = self.loss_forward(y_forw, self.host[:,0])
212
+
213
+
214
+ if degrade_shuffle:
215
+ import random
216
+ choice = random.randint(0, 2)
217
+
218
+ if choice == 0:
219
+ NL = float((np.random.randint(1, 16))/255)
220
+ noise = np.random.normal(0, NL, y_forw.shape)
221
+ torchnoise = torch.from_numpy(noise).cuda().float()
222
+ y_forw = y_forw + torchnoise
223
+
224
+ elif choice == 1:
225
+ NL = int(np.random.randint(70,95))
226
+ self.DiffJPEG = DiffJPEG(differentiable=True, quality=int(NL)).cuda()
227
+ y_forw = self.DiffJPEG(y_forw)
228
+
229
+ elif choice == 2:
230
+ vals = 10**4
231
+ if random.random() < 0.5:
232
+ noisy_img_tensor = torch.poisson(y_forw * vals) / vals
233
+ else:
234
+ img_gray_tensor = torch.mean(y_forw, dim=0, keepdim=True)
235
+ noisy_gray_tensor = torch.poisson(img_gray_tensor * vals) / vals
236
+ noisy_img_tensor = y_forw + (noisy_gray_tensor - img_gray_tensor)
237
+
238
+ y_forw = torch.clamp(noisy_img_tensor, 0, 1)
239
+
240
+ else:
241
+
242
+ if add_noise:
243
+ NL = float((np.random.randint(1,16))/255)
244
+ noise = np.random.normal(0, NL, y_forw.shape)
245
+ torchnoise = torch.from_numpy(noise).cuda().float()
246
+ y_forw = y_forw + torchnoise
247
+
248
+ elif add_jpeg:
249
+ NL = int(np.random.randint(70,95))
250
+ self.DiffJPEG = DiffJPEG(differentiable=True, quality=int(NL)).cuda()
251
+ y_forw = self.DiffJPEG(y_forw)
252
+
253
+ elif add_possion:
254
+ vals = 10**4
255
+ if random.random() < 0.5:
256
+ noisy_img_tensor = torch.poisson(y_forw * vals) / vals
257
+ else:
258
+ img_gray_tensor = torch.mean(y_forw, dim=0, keepdim=True)
259
+ noisy_gray_tensor = torch.poisson(img_gray_tensor * vals) / vals
260
+ noisy_img_tensor = y_forw + (noisy_gray_tensor - img_gray_tensor)
261
+
262
+ y_forw = torch.clamp(noisy_img_tensor, 0, 1)
263
+
264
+ y = self.Quantization(y_forw)
265
+ all_zero = torch.zeros(message.shape).to(self.device)
266
+
267
+ if self.mode == "image":
268
+ out_x, out_x_h, out_z, recmessage = self.netG(x=y, message=all_zero, rev=True)
269
+ out_x = iwt(out_x)
270
+ out_x_h = [iwt(out_x_h_i) for out_x_h_i in out_x_h]
271
+
272
+ l_back_rec = self.loss_back_rec(out_x, self.host[:,0])
273
+ out_x_h = torch.stack(out_x_h, dim=1)
274
+
275
+ l_center_x = self.loss_back_rec(out_x_h[:, 0], self.secret[:,0].reshape(b, -1, h, w))
276
+
277
+ recmessage = torch.clamp(recmessage, -0.5, 0.5)
278
+
279
+ l_msg = self.Reconstruction_msg(message, recmessage)
280
+
281
+ loss = l_forw_fit*2 + l_back_rec + l_center_x*4
282
+
283
+ loss.backward()
284
+
285
+ if self.train_opt['lambda_center'] != 0:
286
+ self.log_dict['l_center_x'] = l_center_x.item()
287
+
288
+ # set log
289
+ self.log_dict['l_back_rec'] = l_back_rec.item()
290
+ self.log_dict['l_forw_fit'] = l_forw_fit.item()
291
+ self.log_dict['l_msg'] = l_msg.item()
292
+
293
+ self.log_dict['l_h'] = (l_center_x*10).item()
294
+
295
+ # gradient clipping
296
+ if self.train_opt['gradient_clipping']:
297
+ nn.utils.clip_grad_norm_(self.netG.parameters(), self.train_opt['gradient_clipping'])
298
+
299
+ self.optimizer_G.step()
300
+
301
+ elif self.mode == "bit":
302
+ recmessage = self.netG(x=y, message=all_zero, rev=True)
303
+
304
+ recmessage = torch.clamp(recmessage, -0.5, 0.5)
305
+
306
+ l_msg = self.Reconstruction_msg(message, recmessage)
307
+
308
+ lambda_msg = self.train_opt['lambda_msg']
309
+
310
+ loss = l_msg * lambda_msg + l_forw_fit
311
+
312
+ loss.backward()
313
+
314
+ # set log
315
+ self.log_dict['l_forw_fit'] = l_forw_fit.item()
316
+ self.log_dict['l_msg'] = l_msg.item()
317
+
318
+ # gradient clipping
319
+ if self.train_opt['gradient_clipping']:
320
+ nn.utils.clip_grad_norm_(self.netG.parameters(), self.train_opt['gradient_clipping'])
321
+
322
+ self.optimizer_G.step()
323
+
324
+ def test(self, image_id):
325
+ self.netG.eval()
326
+ add_noise = self.opt['addnoise']
327
+ add_jpeg = self.opt['addjpeg']
328
+ add_possion = self.opt['addpossion']
329
+ add_sdinpaint = self.opt['sdinpaint']
330
+ add_controlnet = self.opt['controlnetinpaint']
331
+ add_sdxl = self.opt['sdxl']
332
+ add_repaint = self.opt['repaint']
333
+ degrade_shuffle = self.opt['degrade_shuffle']
334
+
335
+ with torch.no_grad():
336
+ forw_L = []
337
+ forw_L_h = []
338
+ fake_H = []
339
+ fake_H_h = []
340
+ pred_z = []
341
+ recmsglist = []
342
+ msglist = []
343
+ b, t, c, h, w = self.real_H.shape
344
+ center = t // 2
345
+ intval = self.gop // 2
346
+ b, n, t, c, h, w = self.ref_L.shape
347
+ id=0
348
+ # forward downscaling
349
+ self.host = self.real_H[:, center - intval+id:center + intval + 1+id]
350
+ self.secret = self.ref_L[:, :, center - intval+id:center + intval + 1+id]
351
+ self.secret = [dwt(self.secret[:,i].reshape(b, -1, h, w)) for i in range(n)]
352
+
353
+ messagenp = np.random.choice([-0.5, 0.5], (self.ref_L.shape[0], self.opt['message_length']))
354
+
355
+ message = torch.Tensor(messagenp).to(self.device)
356
+
357
+ if self.opt['bitrecord']:
358
+ mymsg = message.clone()
359
+
360
+ mymsg[mymsg>0] = 1
361
+ mymsg[mymsg<0] = 0
362
+ mymsg = mymsg.squeeze(0).to(torch.int)
363
+
364
+ bit_list = mymsg.tolist()
365
+
366
+ bit_string = ''.join(map(str, bit_list))
367
+
368
+ file_name = "bit_sequence.txt"
369
+
370
+ with open(file_name, "a") as file:
371
+ file.write(bit_string + "\n")
372
+
373
+ if self.opt['hide']:
374
+ self.output, container = self.netG(x=dwt(self.host.reshape(b, -1, h, w)), x_h=self.secret, message=message)
375
+ y_forw = container
376
+ else:
377
+
378
+ message = torch.tensor(self.msg_list[image_id]).unsqueeze(0).cuda()
379
+ self.output = self.host
380
+ y_forw = self.output.squeeze(1)
381
+
382
+ if add_sdinpaint:
383
+ import random
384
+ from PIL import Image
385
+ prompt = ""
386
+
387
+ b, _, _, _ = y_forw.shape
388
+
389
+ image_batch = y_forw.permute(0, 2, 3, 1).detach().cpu().numpy()
390
+ forw_list = []
391
+
392
+ for j in range(b):
393
+ i = image_id + 1
394
+ masksrc = "../dataset/valAGE-Set-Mask/"
395
+ mask_image = Image.open(masksrc + str(i).zfill(4) + ".png").convert("L")
396
+ mask_image = mask_image.resize((512, 512))
397
+ h, w = mask_image.size
398
+
399
+ image = image_batch[j, :, :, :]
400
+ image_init = Image.fromarray((image * 255).astype(np.uint8), mode = "RGB")
401
+ image_inpaint = self.pipe(prompt=prompt, image=image_init, mask_image=mask_image, height=w, width=h).images[0]
402
+ image_inpaint = np.array(image_inpaint) / 255.
403
+ mask_image = np.array(mask_image)
404
+ mask_image = np.stack([mask_image] * 3, axis=-1) / 255.
405
+ mask_image = mask_image.astype(np.uint8)
406
+ image_fuse = image * (1 - mask_image) + image_inpaint * mask_image
407
+ forw_list.append(torch.from_numpy(image_fuse).permute(2, 0, 1))
408
+
409
+ y_forw = torch.stack(forw_list, dim=0).float().cuda()
410
+
411
+ if add_controlnet:
412
+ from diffusers.utils import load_image
413
+ from PIL import Image
414
+
415
+ b, _, _, _ = y_forw.shape
416
+ forw_list = []
417
+
418
+ image_batch = y_forw.permute(0, 2, 3, 1).detach().cpu().numpy()
419
+ generator = torch.Generator(device="cuda").manual_seed(1)
420
+
421
+ for j in range(b):
422
+ i = image_id + 1
423
+ mask_path = "../dataset/valAGE-Set-Mask/" + str(i).zfill(4) + ".png"
424
+ mask_image = load_image(mask_path)
425
+ mask_image = mask_image.resize((512, 512))
426
+ image_init = image_batch[j, :, :, :]
427
+ image_init1 = Image.fromarray((image_init * 255).astype(np.uint8), mode = "RGB")
428
+ image_mask = np.array(mask_image.convert("L")).astype(np.float32) / 255.0
429
+
430
+ assert image_init.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size"
431
+ image_init[image_mask > 0.5] = -1.0 # set as masked pixel
432
+ image = np.expand_dims(image_init, 0).transpose(0, 3, 1, 2)
433
+ control_image = torch.from_numpy(image)
434
+
435
+ # generate image
436
+ image_inpaint = self.pipe_control(
437
+ "",
438
+ num_inference_steps=20,
439
+ generator=generator,
440
+ eta=1.0,
441
+ image=image_init1,
442
+ mask_image=image_mask,
443
+ control_image=control_image,
444
+ ).images[0]
445
+
446
+ image_inpaint = np.array(image_inpaint) / 255.
447
+ image_mask = np.stack([image_mask] * 3, axis=-1)
448
+ image_mask = image_mask.astype(np.uint8)
449
+ image_fuse = image_init * (1 - image_mask) + image_inpaint * image_mask
450
+ forw_list.append(torch.from_numpy(image_fuse).permute(2, 0, 1))
451
+
452
+ y_forw = torch.stack(forw_list, dim=0).float().cuda()
453
+
454
+ if add_sdxl:
455
+ import random
456
+ from PIL import Image
457
+ from diffusers.utils import load_image
458
+ prompt = ""
459
+
460
+ b, _, _, _ = y_forw.shape
461
+
462
+ image_batch = y_forw.permute(0, 2, 3, 1).detach().cpu().numpy()
463
+ forw_list = []
464
+
465
+ for j in range(b):
466
+ i = image_id + 1
467
+ masksrc = "../dataset/valAGE-Set-Mask/"
468
+ mask_image = load_image(masksrc + str(i).zfill(4) + ".png").convert("RGB")
469
+ mask_image = mask_image.resize((512, 512))
470
+ h, w = mask_image.size
471
+
472
+ image = image_batch[j, :, :, :]
473
+ image_init = Image.fromarray((image * 255).astype(np.uint8), mode = "RGB")
474
+ image_inpaint = self.pipe_sdxl(
475
+ prompt=prompt, image=image_init, mask_image=mask_image, num_inference_steps=50, strength=0.80, target_size=(512, 512)
476
+ ).images[0]
477
+ image_inpaint = image_inpaint.resize((512, 512))
478
+ image_inpaint = np.array(image_inpaint) / 255.
479
+ mask_image = np.array(mask_image) / 255.
480
+ mask_image = mask_image.astype(np.uint8)
481
+ image_fuse = image * (1 - mask_image) + image_inpaint * mask_image
482
+ forw_list.append(torch.from_numpy(image_fuse).permute(2, 0, 1))
483
+
484
+ y_forw = torch.stack(forw_list, dim=0).float().cuda()
485
+
486
+
487
+ if add_repaint:
488
+ from PIL import Image
489
+
490
+ b, _, _, _ = y_forw.shape
491
+
492
+ image_batch = y_forw.permute(0, 2, 3, 1).detach().cpu().numpy()
493
+ forw_list = []
494
+
495
+ generator = torch.Generator(device="cuda").manual_seed(0)
496
+ for j in range(b):
497
+ i = image_id + 1
498
+ masksrc = "../dataset/valAGE-Set-Mask/" + str(i).zfill(4) + ".png"
499
+ mask_image = Image.open(masksrc).convert("RGB")
500
+ mask_image = mask_image.resize((256, 256))
501
+ mask_image = Image.fromarray(255 - np.array(mask_image))
502
+ image = image_batch[j, :, :, :]
503
+ original_image = Image.fromarray((image * 255).astype(np.uint8), mode = "RGB")
504
+ original_image = original_image.resize((256, 256))
505
+ output = self.pipe_repaint(
506
+ image=original_image,
507
+ mask_image=mask_image,
508
+ num_inference_steps=150,
509
+ eta=0.0,
510
+ jump_length=10,
511
+ jump_n_sample=10,
512
+ generator=generator,
513
+ )
514
+ image_inpaint = output.images[0]
515
+ image_inpaint = image_inpaint.resize((512, 512))
516
+ image_inpaint = np.array(image_inpaint) / 255.
517
+ mask_image = mask_image.resize((512, 512))
518
+ mask_image = np.array(mask_image) / 255.
519
+ mask_image = mask_image.astype(np.uint8)
520
+ image_fuse = image * mask_image + image_inpaint * (1 - mask_image)
521
+ forw_list.append(torch.from_numpy(image_fuse).permute(2, 0, 1))
522
+
523
+ y_forw = torch.stack(forw_list, dim=0).float().cuda()
524
+
525
+ if degrade_shuffle:
526
+ import random
527
+ choice = random.randint(0, 2)  # pick one of the three degradations below
528
+
529
+ if choice == 0:
530
+ NL = float(np.random.randint(1, 5)) / 255
531
+ noise = np.random.normal(0, NL, y_forw.shape)
532
+ torchnoise = torch.from_numpy(noise).cuda().float()
533
+ y_forw = y_forw + torchnoise
534
+
535
+ elif choice == 1:
536
+ NL = 90
537
+ self.DiffJPEG = DiffJPEG(differentiable=True, quality=int(NL)).cuda()
538
+ y_forw = self.DiffJPEG(y_forw)
539
+
540
+ elif choice == 2:
541
+ vals = 10**4
542
+ if random.random() < 0.5:
543
+ noisy_img_tensor = torch.poisson(y_forw * vals) / vals
544
+ else:
545
+ img_gray_tensor = torch.mean(y_forw, dim=0, keepdim=True)
546
+ noisy_gray_tensor = torch.poisson(img_gray_tensor * vals) / vals
547
+ noisy_img_tensor = y_forw + (noisy_gray_tensor - img_gray_tensor)
548
+
549
+ y_forw = torch.clamp(noisy_img_tensor, 0, 1)
550
+
551
+ else:
552
+
553
+ if add_noise:
554
+ NL = self.opt['noisesigma'] / 255.0
555
+ noise = np.random.normal(0, NL, y_forw.shape)
556
+ torchnoise = torch.from_numpy(noise).cuda().float()
557
+ y_forw = y_forw + torchnoise
558
+
559
+ elif add_jpeg:
560
+ Q = self.opt['jpegfactor']
561
+ self.DiffJPEG = DiffJPEG(differentiable=True, quality=int(Q)).cuda()
562
+ y_forw = self.DiffJPEG(y_forw)
563
+
564
+ elif add_possion:  # Poisson noise (flag keeps the original "possion" spelling)
565
+ vals = 10**4
566
+ if random.random() < 0.5:
567
+ noisy_img_tensor = torch.poisson(y_forw * vals) / vals
568
+ else:
569
+ img_gray_tensor = torch.mean(y_forw, dim=0, keepdim=True)
570
+ noisy_gray_tensor = torch.poisson(img_gray_tensor * vals) / vals
571
+ noisy_img_tensor = y_forw + (noisy_gray_tensor - img_gray_tensor)
572
+
573
+ y_forw = torch.clamp(noisy_img_tensor, 0, 1)
574
+
575
+ # backward upscaling
576
+ if self.opt['hide']:
577
+ y = self.Quantization(y_forw)
578
+ else:
579
+ y = y_forw
580
+
581
+ if self.mode == "image":
582
+ out_x, out_x_h, out_z, recmessage = self.netG(x=y, rev=True)
583
+ out_x = iwt(out_x)
584
+
585
+ out_x_h = [iwt(out_x_h_i) for out_x_h_i in out_x_h]
586
+ out_x = out_x.reshape(-1, self.gop, 3, h, w)
587
+ out_x_h = torch.stack(out_x_h, dim=1)
588
+ out_x_h = out_x_h.reshape(-1, 1, self.gop, 3, h, w)
589
+
590
+ forw_L.append(y_forw)
591
+ fake_H.append(out_x[:, self.gop//2])
592
+ fake_H_h.append(out_x_h[:,:, self.gop//2])
593
+ recmsglist.append(recmessage)
594
+ msglist.append(message)
595
+
596
+ elif self.mode == "bit":
597
+ recmessage = self.netG(x=y, rev=True)
598
+ forw_L.append(y_forw)
599
+ recmsglist.append(recmessage)
600
+ msglist.append(message)
601
+
602
+ if self.mode == "image":
603
+ self.fake_H = torch.clamp(torch.stack(fake_H, dim=1),0,1)
604
+ self.fake_H_h = torch.clamp(torch.stack(fake_H_h, dim=2),0,1)
605
+
606
+ self.forw_L = torch.clamp(torch.stack(forw_L, dim=1),0,1)
607
+ remesg = torch.clamp(torch.stack(recmsglist, dim=0),-0.5,0.5)
608
+
609
+ if self.opt['hide']:
610
+ mesg = torch.clamp(torch.stack(msglist, dim=0),-0.5,0.5)
611
+ else:
612
+ mesg = torch.stack(msglist, dim=0)
613
+
614
+ self.recmessage = remesg.clone()
615
+ self.recmessage[remesg > 0] = 1
616
+ self.recmessage[remesg <= 0] = 0
617
+
618
+ self.message = mesg.clone()
619
+ self.message[mesg > 0] = 1
620
+ self.message[mesg <= 0] = 0
621
+
622
+ self.netG.train()
623
+
624
+
625
+ def image_hiding(self):
626
+ self.netG.eval()
627
+ with torch.no_grad():
628
+ b, t, c, h, w = self.real_H.shape
629
+ center = t // 2
630
+ intval = self.gop // 2
631
+ b, n, t, c, h, w = self.ref_L.shape
632
+ idx = 0
633
+ # forward downscaling
634
+ self.host = self.real_H[:, center - intval + idx:center + intval + 1 + idx]
635
+ self.secret = self.ref_L[:, :, center - intval + idx:center + intval + 1 + idx]
636
+ self.secret = [dwt(self.secret[:,i].reshape(b, -1, h, w)) for i in range(n)]
637
+
638
+ message = torch.Tensor(self.mes).to(self.device)
639
+
640
+ self.output, container = self.netG(x=dwt(self.host.reshape(b, -1, h, w)), x_h=self.secret, message=message)
641
+ y_forw = container
642
+
643
+ result = torch.clamp(y_forw,0,1)
644
+
645
+ lr_img = util.tensor2img(result)
646
+
647
+ return lr_img
648
+
649
+ def image_recovery(self, number):
650
+ self.netG.eval()
651
+ with torch.no_grad():
652
+ b, t, c, h, w = self.real_H.shape
653
+ center = t // 2
654
+ intval = self.gop // 2
655
+ b, n, t, c, h, w = self.ref_L.shape
656
+ idx = 0
657
+ # forward downscaling
658
+ self.host = self.real_H[:, center - intval + idx:center + intval + 1 + idx]
659
+ self.secret = self.ref_L[:, :, center - intval + idx:center + intval + 1 + idx]
660
+ template = self.secret.reshape(b, -1, h, w)
661
+ self.secret = [dwt(self.secret[:,i].reshape(b, -1, h, w)) for i in range(n)]
662
+
663
+ self.output = self.host
664
+ y_forw = self.output.squeeze(1)
665
+
666
+ y = self.Quantization(y_forw)
667
+
668
+ out_x, out_x_h, out_z, recmessage = self.netG(x=y, rev=True)
669
+ out_x = iwt(out_x)
670
+
671
+ out_x_h = [iwt(out_x_h_i) for out_x_h_i in out_x_h]
672
+ out_x = out_x.reshape(-1, self.gop, 3, h, w)
673
+ out_x_h = torch.stack(out_x_h, dim=1)
674
+ out_x_h = out_x_h.reshape(-1, 1, self.gop, 3, h, w)
675
+
676
+ rec_loc = out_x_h[:,:, self.gop//2]
677
+ # from PIL import Image
678
+ # tmp = util.tensor2img(rec_loc)
679
+ # save
680
+ residual = torch.abs(template - rec_loc)
681
+ binary_residual = (residual > number).float()
682
+ residual = util.tensor2img(binary_residual)
683
+ mask = np.sum(residual, axis=2)
684
+ # print(mask)
685
+
686
+ remesg = torch.clamp(recmessage,-0.5,0.5)
687
+ remesg[remesg > 0] = 1
688
+ remesg[remesg <= 0] = 0
689
+
690
+ return mask, remesg
691
+
692
+ def get_current_log(self):
693
+ return self.log_dict
694
+
695
+ def get_current_visuals(self):
696
+ b, n, t, c, h, w = self.ref_L.shape
697
+ center = t // 2
698
+ intval = self.gop // 2
699
+ out_dict = OrderedDict()
700
+ LR_ref = self.ref_L[:, :, center - intval:center + intval + 1].detach()[0].float().cpu()
701
+ LR_ref = torch.chunk(LR_ref, self.num_image, dim=0)
702
+ out_dict['LR_ref'] = [image.squeeze(0) for image in LR_ref]
703
+
704
+ if self.mode == "image":
705
+ out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
706
+ SR_h = self.fake_H_h.detach()[0].float().cpu()
707
+ SR_h = torch.chunk(SR_h, self.num_image, dim=0)
708
+ out_dict['SR_h'] = [image.squeeze(0) for image in SR_h]
709
+
710
+ out_dict['LR'] = self.forw_L.detach()[0].float().cpu()
711
+ out_dict['GT'] = self.real_H[:, center - intval:center + intval + 1].detach()[0].float().cpu()
712
+ out_dict['message'] = self.message
713
+ out_dict['recmessage'] = self.recmessage
714
+
715
+ return out_dict
716
+
717
+ def print_network(self):
718
+ s, n = self.get_network_description(self.netG)
719
+ if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
720
+ net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
721
+ self.netG.module.__class__.__name__)
722
+ else:
723
+ net_struc_str = '{}'.format(self.netG.__class__.__name__)
724
+ if self.rank <= 0:
725
+ logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
726
+ logger.info(s)
727
+
728
+ def load(self):
729
+ load_path_G = self.opt['path']['pretrain_model_G']
730
+ if load_path_G is not None:
731
+ logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
732
+ self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
733
+
734
+ def load_test(self, load_path_G):
735
+ self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
736
+
737
+ def save(self, iter_label):
738
+ self.save_network(self.netG, 'G', iter_label)
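
For reference, the Poisson branch used by both degrade_shuffle and add_possion above reads as a standalone transform. A minimal sketch, assuming a CHW float tensor in [0, 1] (the name poisson_degrade is illustrative, not repo code):

import random
import torch

def poisson_degrade(img, vals=10**4):
    # Half the time, apply Poisson noise per channel; otherwise estimate the
    # noise on a grayscale copy and add the same residual to every channel.
    if random.random() < 0.5:
        noisy = torch.poisson(img * vals) / vals
    else:
        gray = torch.mean(img, dim=0, keepdim=True)   # channel mean for CHW input
        noisy_gray = torch.poisson(gray * vals) / vals
        noisy = img + (noisy_gray - gray)
    return torch.clamp(noisy, 0, 1)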
models/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ import logging
2
+ logger = logging.getLogger('base')
3
+
4
+ def create_model(opt):
5
+ model = opt['model']
6
+ frame_num = opt['gop']
7
+ from .IBSN import Model_VSN as M
8
+
9
+ m = M(opt)
10
+ logger.info('Model [{:s}] is created.'.format(m.__class__.__name__))
11
+ return m
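
create_model reads opt['model'] and opt['gop'], then always builds Model_VSN from IBSN.py. A hypothetical usage sketch; the keys shown are placeholders, and a real option dict comes from the repo's config files:

from models import create_model

opt = {
    'model': 'IBSN',  # read above, though only the fixed import decides the class
    'gop': 1,
    # ... plus every key Model_VSN itself expects (paths, network settings, ...)
}
model = create_model(opt)  # logs: Model [Model_VSN] is created.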
models/__pycache__/IBSN.cpython-310.pyc ADDED
Binary file (18.5 kB). View file
 
models/__pycache__/IBSN.cpython-38.pyc ADDED
Binary file (18.7 kB). View file
 
models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (491 Bytes). View file
 
models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (504 Bytes). View file
 
models/__pycache__/base_model.cpython-38.pyc ADDED
Binary file (5.44 kB). View file
 
models/__pycache__/lr_scheduler.cpython-38.pyc ADDED
Binary file (4.98 kB). View file
 
models/__pycache__/networks.cpython-310.pyc ADDED
Binary file (731 Bytes). View file
 
models/__pycache__/networks.cpython-38.pyc ADDED
Binary file (740 Bytes). View file
 
models/base_model.py ADDED
@@ -0,0 +1,119 @@
1
+ import os
2
+ from collections import OrderedDict
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn.parallel import DistributedDataParallel
6
+
7
+
8
+ class BaseModel():
9
+ def __init__(self, opt):
10
+ self.opt = opt
11
+ self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
12
+ self.is_train = opt['is_train']
13
+ self.schedulers = []
14
+ self.optimizers = []
15
+
16
+ def feed_data(self, data):
17
+ pass
18
+
19
+ def optimize_parameters(self):
20
+ pass
21
+
22
+ def get_current_visuals(self):
23
+ pass
24
+
25
+ def get_current_losses(self):
26
+ pass
27
+
28
+ def print_network(self):
29
+ pass
30
+
31
+ def save(self, label):
32
+ pass
33
+
34
+ def load(self):
35
+ pass
36
+
37
+ def _set_lr(self, lr_groups_l):
38
+ ''' set learning rate for warmup,
39
+ lr_groups_l: list of lr_groups, one for each optimizer'''
40
+ for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
41
+ for param_group, lr in zip(optimizer.param_groups, lr_groups):
42
+ param_group['lr'] = lr
43
+
44
+ def _get_init_lr(self):
45
+ # get the initial lr, which is set by the scheduler
46
+ init_lr_groups_l = []
47
+ for optimizer in self.optimizers:
48
+ init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
49
+ return init_lr_groups_l
50
+
51
+ def update_learning_rate(self, cur_iter, warmup_iter=-1):
52
+ for scheduler in self.schedulers:
53
+ scheduler.step()
54
+ #### set up warm up learning rate
55
+ if cur_iter < warmup_iter:
56
+ # get initial lr for each group
57
+ init_lr_g_l = self._get_init_lr()
58
+ # modify warming-up learning rates
59
+ warm_up_lr_l = []
60
+ for init_lr_g in init_lr_g_l:
61
+ warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g])
62
+ # set learning rate
63
+ self._set_lr(warm_up_lr_l)
64
+
65
+ def get_current_learning_rate(self):
66
+ # return self.schedulers[0].get_lr()[0]
67
+ return self.optimizers[0].param_groups[0]['lr']
68
+
69
+ def get_network_description(self, network):
70
+ '''Get the string and total parameters of the network'''
71
+ if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
72
+ network = network.module
73
+ s = str(network)
74
+ n = sum(map(lambda x: x.numel(), network.parameters()))
75
+ return s, n
76
+
77
+ def save_network(self, network, network_label, iter_label):
78
+ save_filename = '{}_{}.pth'.format(iter_label, network_label)
79
+ save_path = os.path.join(self.opt['path']['models'], save_filename)
80
+ if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
81
+ network = network.module
82
+ state_dict = network.state_dict()
83
+ for key, param in state_dict.items():
84
+ state_dict[key] = param.cpu()
85
+ torch.save(state_dict, save_path)
86
+
87
+ def load_network(self, load_path, network, strict=True):
88
+ if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
89
+ network = network.module
90
+ load_net = torch.load(load_path)
91
+ load_net_clean = OrderedDict() # remove unnecessary 'module.'
92
+ for k, v in load_net.items():
93
+ if k.startswith('module.'):
94
+ load_net_clean[k[7:]] = v
95
+ else:
96
+ load_net_clean[k] = v
97
+ network.load_state_dict(load_net_clean, strict=strict)
98
+
99
+ def save_training_state(self, epoch, iter_step):
100
+ '''Saves training state during training, which will be used for resuming'''
101
+ state = {'epoch': epoch, 'iter': iter_step, 'schedulers': [], 'optimizers': []}
102
+ for s in self.schedulers:
103
+ state['schedulers'].append(s.state_dict())
104
+ for o in self.optimizers:
105
+ state['optimizers'].append(o.state_dict())
106
+ save_filename = '{}.state'.format(iter_step)
107
+ save_path = os.path.join(self.opt['path']['training_state'], save_filename)
108
+ torch.save(state, save_path)
109
+
110
+ def resume_training(self, resume_state):
111
+ '''Resume the optimizers and schedulers for training'''
112
+ resume_optimizers = resume_state['optimizers']
113
+ resume_schedulers = resume_state['schedulers']
114
+ assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
115
+ assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
116
+ for i, o in enumerate(resume_optimizers):
117
+ self.optimizers[i].load_state_dict(o)
118
+ for i, s in enumerate(resume_schedulers):
119
+ self.schedulers[i].load_state_dict(s)
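
update_learning_rate implements a linear warm-up: while cur_iter < warmup_iter, every parameter group runs at initial_lr * cur_iter / warmup_iter before the schedulers take over. The arithmetic in isolation (a standalone sketch, not repo code):

def warmup_lr(initial_lr, cur_iter, warmup_iter):
    # mirrors `v / warmup_iter * cur_iter` in update_learning_rate above
    return initial_lr / warmup_iter * cur_iter

print(warmup_lr(2e-4, 500, 5000))  # ~2e-05, i.e. 10% into warm-up -> 10% of the LR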
models/bitnetwork/ConvBlock.py ADDED
@@ -0,0 +1,38 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ class ConvINRelu(nn.Module):
5
+ """
6
+ A sequence of Convolution, Instance Normalization, and ReLU activation
7
+ """
8
+
9
+ def __init__(self, channels_in, channels_out, stride):
10
+ super(ConvINRelu, self).__init__()
11
+
12
+ self.layers = nn.Sequential(
13
+ nn.Conv2d(channels_in, channels_out, 3, stride, padding=1),
14
+ nn.InstanceNorm2d(channels_out),
15
+ nn.ReLU(inplace=True)
16
+ )
17
+
18
+ def forward(self, x):
19
+ return self.layers(x)
20
+
21
+
22
+ class ConvBlock(nn.Module):
23
+ '''
24
+ Network that composed by layers of ConvINRelu
25
+ '''
26
+
27
+ def __init__(self, in_channels, out_channels, blocks=1, stride=1):
28
+ super(ConvBlock, self).__init__()
29
+
30
+ layers = [ConvINRelu(in_channels, out_channels, stride)] if blocks != 0 else []
31
+ for _ in range(blocks - 1):
32
+ layer = ConvINRelu(out_channels, out_channels, 1)
33
+ layers.append(layer)
34
+
35
+ self.layers = nn.Sequential(*layers)
36
+
37
+ def forward(self, x):
38
+ return self.layers(x)
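
A quick shape check for ConvBlock (a sketch; the import path assumes this repo layout). The stride applies only to the first ConvINRelu; the remaining blocks keep the resolution:

import torch
from models.bitnetwork.ConvBlock import ConvBlock

block = ConvBlock(3, 16, blocks=2, stride=2)  # first layer halves H and W
x = torch.randn(1, 3, 128, 128)
print(block(x).shape)                         # torch.Size([1, 16, 64, 64])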
models/bitnetwork/DW_EncoderDecoder.py ADDED
@@ -0,0 +1,28 @@
1
+ from . import *
2
+ from .Encoder_U import DW_Encoder
3
+ from .Decoder_U import DW_Decoder
4
+ from .Noise import Noise
5
+ from .Random_Noise import Random_Noise
6
+
7
+
8
+ class DW_EncoderDecoder(nn.Module):
9
+ '''
10
+ A sequential pipeline: encoder -> random noise -> dual decoders
11
+ '''
12
+
13
+ def __init__(self, message_length, noise_layers_R, noise_layers_F, attention_encoder, attention_decoder):
14
+ super(DW_EncoderDecoder, self).__init__()
15
+ self.encoder = DW_Encoder(message_length, attention = attention_encoder)
16
+ self.noise = Random_Noise(noise_layers_R + noise_layers_F, len(noise_layers_R), len(noise_layers_F))
17
+ self.decoder_C = DW_Decoder(message_length, attention = attention_decoder)
18
+ self.decoder_RF = DW_Decoder(message_length, attention = attention_decoder)
19
+
20
+
21
+ def forward(self, image, message, mask):
22
+ encoded_image = self.encoder(image, message)
23
+ noised_image_C, noised_image_R, noised_image_F = self.noise([encoded_image, image, mask])
24
+ decoded_message_C = self.decoder_C(noised_image_C)
25
+ decoded_message_R = self.decoder_RF(noised_image_R)
26
+ decoded_message_F = self.decoder_RF(noised_image_F)
27
+ return encoded_image, noised_image_C, decoded_message_C, decoded_message_R, decoded_message_F
28
+
models/bitnetwork/Decoder_U.py ADDED
@@ -0,0 +1,87 @@
1
+ from . import *
2
+
3
+
4
+ class DW_Decoder(nn.Module):
5
+
6
+ def __init__(self, message_length, blocks=2, channels=64, attention=None):
7
+ super(DW_Decoder, self).__init__()
8
+
9
+ self.conv1 = ConvBlock(3, 16, blocks=blocks)
10
+ self.down1 = Down(16, 32, blocks=blocks)
11
+ self.down2 = Down(32, 64, blocks=blocks)
12
+ self.down3 = Down(64, 128, blocks=blocks)
13
+
14
+ self.down4 = Down(128, 256, blocks=blocks)
15
+
16
+ self.up3 = UP(256, 128)
17
+ self.att3 = ResBlock(128 * 2, 128, blocks=blocks, attention=attention)
18
+
19
+ self.up2 = UP(128, 64)
20
+ self.att2 = ResBlock(64 * 2, 64, blocks=blocks, attention=attention)
21
+
22
+ self.up1 = UP(64, 32)
23
+ self.att1 = ResBlock(32 * 2, 32, blocks=blocks, attention=attention)
24
+
25
+ self.up0 = UP(32, 16)
26
+ self.att0 = ResBlock(16 * 2, 16, blocks=blocks, attention=attention)
27
+
28
+ self.Conv_1x1 = nn.Conv2d(16, 1, kernel_size=1, stride=1, padding=0, bias=False)
29
+
30
+ self.message_layer = nn.Linear(message_length * message_length, message_length)
31
+ self.message_length = message_length
32
+
33
+
34
+ def forward(self, x):
35
+ d0 = self.conv1(x)
36
+ d1 = self.down1(d0)
37
+ d2 = self.down2(d1)
38
+ d3 = self.down3(d2)
39
+
40
+ d4 = self.down4(d3)
41
+
42
+ u3 = self.up3(d4)
43
+ u3 = torch.cat((d3, u3), dim=1)
44
+ u3 = self.att3(u3)
45
+
46
+ u2 = self.up2(u3)
47
+ u2 = torch.cat((d2, u2), dim=1)
48
+ u2 = self.att2(u2)
49
+
50
+ u1 = self.up1(u2)
51
+ u1 = torch.cat((d1, u1), dim=1)
52
+ u1 = self.att1(u1)
53
+
54
+ u0 = self.up0(u1)
55
+ u0 = torch.cat((d0, u0), dim=1)
56
+ u0 = self.att0(u0)
57
+
58
+ residual = self.Conv_1x1(u0)
59
+
60
+ message = F.interpolate(residual, size=(self.message_length, self.message_length),
61
+ mode='nearest')
62
+ message = message.view(message.shape[0], -1)
63
+ message = self.message_layer(message)
64
+
65
+ return message
66
+
67
+
68
+ class Down(nn.Module):
69
+ def __init__(self, in_channels, out_channels, blocks):
70
+ super(Down, self).__init__()
71
+ self.layer = torch.nn.Sequential(
72
+ ConvBlock(in_channels, in_channels, stride=2),
73
+ ConvBlock(in_channels, out_channels, blocks=blocks)
74
+ )
75
+
76
+ def forward(self, x):
77
+ return self.layer(x)
78
+
79
+
80
+ class UP(nn.Module):
81
+ def __init__(self, in_channels, out_channels):
82
+ super(UP, self).__init__()
83
+ self.conv = ConvBlock(in_channels, out_channels)
84
+
85
+ def forward(self, x):
86
+ x = F.interpolate(x, scale_factor=2, mode='nearest')
87
+ return self.conv(x)
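
The message head above turns the 1-channel residual into a bit vector: resample to a message_length x message_length grid, flatten, and map through one linear layer. A shape walkthrough, assuming message_length = 64 (the real value is set by the configs):

import torch
import torch.nn as nn
import torch.nn.functional as F

L = 64                                      # assumed message_length
residual = torch.randn(2, 1, 128, 128)      # Conv_1x1 output, (B, 1, H, W)
m = F.interpolate(residual, size=(L, L), mode='nearest')
m = m.view(m.shape[0], -1)                  # (B, L*L)
m = nn.Linear(L * L, L)(m)                  # (B, L) message logits
print(m.shape)                              # torch.Size([2, 64])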
models/bitnetwork/Dual_Mark.py ADDED
@@ -0,0 +1,249 @@
1
+ from .DW_EncoderDecoder import *
2
+ from .Patch_Discriminator import Patch_Discriminator
3
+ import torch
4
+ import kornia.losses
5
+ import lpips
6
+
7
+
8
+ class Network:
9
+
10
+ def __init__(self, message_length, noise_layers_R, noise_layers_F, device, batch_size, lr, beta1, attention_encoder, attention_decoder, weight):
11
+ # device
12
+ self.device = device
13
+
14
+ # loss function
15
+ self.criterion_MSE = nn.MSELoss().to(device)
16
+ self.criterion_LPIPS = lpips.LPIPS().to(device)
17
+
18
+ # weight of encoder-decoder loss
19
+ self.encoder_weight = weight[0]
20
+ self.decoder_weight_C = weight[1]
21
+ self.decoder_weight_R = weight[2]
22
+ self.decoder_weight_F = weight[3]
23
+ self.discriminator_weight = weight[4]
24
+
25
+ # network
26
+ self.encoder_decoder = DW_EncoderDecoder(message_length, noise_layers_R, noise_layers_F, attention_encoder, attention_decoder).to(device)
27
+ self.discriminator = Patch_Discriminator().to(device)
28
+
29
+ self.encoder_decoder = torch.nn.DataParallel(self.encoder_decoder)
30
+ self.discriminator = torch.nn.DataParallel(self.discriminator)
31
+
32
+ # mark "cover" as 1, "encoded" as -1
33
+ self.label_cover = 1.0
34
+ self.label_encoded = -1.0
35
+
36
+ for p in self.encoder_decoder.module.noise.parameters():
37
+ p.requires_grad = False
38
+
39
+ # optimizer
40
+ self.opt_encoder_decoder = torch.optim.Adam(
41
+ filter(lambda p: p.requires_grad, self.encoder_decoder.parameters()), lr=lr, betas=(beta1, 0.999))
42
+ self.opt_discriminator = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
43
+
44
+
45
+ def train(self, images: torch.Tensor, messages: torch.Tensor, masks: torch.Tensor):
46
+ self.encoder_decoder.train()
47
+ self.discriminator.train()
48
+
49
+ with torch.enable_grad():
50
+ # use device to compute
51
+ images, messages, masks = images.to(self.device), messages.to(self.device), masks.to(self.device)
52
+ encoded_images, noised_images, decoded_messages_C, decoded_messages_R, decoded_messages_F = self.encoder_decoder(images, messages, masks)
53
+
54
+ '''
55
+ train discriminator
56
+ '''
57
+ for p in self.discriminator.parameters():
58
+ p.requires_grad = True
59
+
60
+ self.opt_discriminator.zero_grad()
61
+
62
+ # RAW : target label for image should be "cover"(1)
63
+ d_label_cover = self.discriminator(images)
64
+ #d_cover_loss = self.criterion_MSE(d_label_cover, torch.ones_like(d_label_cover))
65
+ #d_cover_loss.backward()
66
+
67
+ # GAN : target label for encoded image should be "encoded"(-1)
68
+ d_label_encoded = self.discriminator(encoded_images.detach())
69
+ #d_encoded_loss = self.criterion_MSE(d_label_encoded, torch.zeros_like(d_label_encoded))
70
+ #d_encoded_loss.backward()
71
+
72
+ d_loss = self.criterion_MSE(d_label_cover - torch.mean(d_label_encoded), self.label_cover * torch.ones_like(d_label_cover)) +\
73
+ self.criterion_MSE(d_label_encoded - torch.mean(d_label_cover), self.label_encoded * torch.ones_like(d_label_encoded))
74
+ d_loss.backward()
75
+
76
+ self.opt_discriminator.step()
77
+
78
+ '''
79
+ train encoder and decoder
80
+ '''
81
+ # Make it a tiny bit faster
82
+ for p in self.discriminator.parameters():
83
+ p.requires_grad = False
84
+
85
+ self.opt_encoder_decoder.zero_grad()
86
+
87
+ # GAN : target label for encoded image should be "cover"(1)
88
+ g_label_cover = self.discriminator(images)
89
+ g_label_encoded = self.discriminator(encoded_images)
90
+ g_loss_on_discriminator = self.criterion_MSE(g_label_cover - torch.mean(g_label_encoded), self.label_encoded * torch.ones_like(g_label_cover)) +\
91
+ self.criterion_MSE(g_label_encoded - torch.mean(g_label_cover), self.label_cover * torch.ones_like(g_label_encoded))
92
+
93
+ # RAW : the encoded image should be similar to cover image
94
+ g_loss_on_encoder_MSE = self.criterion_MSE(encoded_images, images)
95
+ g_loss_on_encoder_LPIPS = torch.mean(self.criterion_LPIPS(encoded_images, images))
96
+
97
+ # RESULT : the decoded message should be similar to the raw message /Dual
98
+ g_loss_on_decoder_C = self.criterion_MSE(decoded_messages_C, messages)
99
+ g_loss_on_decoder_R = self.criterion_MSE(decoded_messages_R, messages)
100
+ g_loss_on_decoder_F = self.criterion_MSE(decoded_messages_F, torch.zeros_like(messages))
101
+
102
+ # full loss
103
+ g_loss = self.discriminator_weight * g_loss_on_discriminator + self.encoder_weight * g_loss_on_encoder_MSE +\
104
+ self.decoder_weight_C * g_loss_on_decoder_C + self.decoder_weight_R * g_loss_on_decoder_R + self.decoder_weight_F * g_loss_on_decoder_F
105
+
106
+ g_loss.backward()
107
+ self.opt_encoder_decoder.step()
108
+
109
+ # psnr
110
+ psnr = - kornia.losses.psnr_loss(encoded_images.detach(), images, 2)
111
+
112
+ # ssim
113
+ ssim = 1 - 2 * kornia.losses.ssim_loss(encoded_images.detach(), images, window_size=11, reduction="mean")
114
+
115
+ '''
116
+ decoded message error rate /Dual
117
+ '''
118
+ error_rate_C = self.decoded_message_error_rate_batch(messages, decoded_messages_C)
119
+ error_rate_R = self.decoded_message_error_rate_batch(messages, decoded_messages_R)
120
+ error_rate_F = self.decoded_message_error_rate_batch(messages, decoded_messages_F)
121
+
122
+ result = {
123
+ "g_loss": g_loss,
124
+ "error_rate_C": error_rate_C,
125
+ "error_rate_R": error_rate_R,
126
+ "error_rate_F": error_rate_F,
127
+ "psnr": psnr,
128
+ "ssim": ssim,
129
+ "g_loss_on_discriminator": g_loss_on_discriminator,
130
+ "g_loss_on_encoder_MSE": g_loss_on_encoder_MSE,
131
+ "g_loss_on_encoder_LPIPS": g_loss_on_encoder_LPIPS,
132
+ "g_loss_on_decoder_C": g_loss_on_decoder_C,
133
+ "g_loss_on_decoder_R": g_loss_on_decoder_R,
134
+ "g_loss_on_decoder_F": g_loss_on_decoder_F,
135
+ "d_loss": d_loss
136
+ }
137
+ return result
138
+
139
+
140
+ def validation(self, images: torch.Tensor, messages: torch.Tensor, masks: torch.Tensor):
141
+ self.encoder_decoder.eval()
142
+ self.encoder_decoder.module.noise.train()
143
+ self.discriminator.eval()
144
+
145
+ with torch.no_grad():
146
+ # use device to compute
147
+ images, messages, masks = images.to(self.device), messages.to(self.device), masks.to(self.device)
148
+ encoded_images, noised_images, decoded_messages_C, decoded_messages_R, decoded_messages_F = self.encoder_decoder(images, messages, masks)
149
+
150
+ '''
151
+ validate discriminator
152
+ '''
153
+ # RAW : target label for image should be "cover"(1)
154
+ d_label_cover = self.discriminator(images)
155
+ #d_cover_loss = self.criterion_MSE(d_label_cover, torch.ones_like(d_label_cover))
156
+
157
+ # GAN : target label for encoded image should be "encoded"(-1)
158
+ d_label_encoded = self.discriminator(encoded_images.detach())
159
+ #d_encoded_loss = self.criterion_MSE(d_label_encoded, torch.zeros_like(d_label_encoded))
160
+
161
+ d_loss = self.criterion_MSE(d_label_cover - torch.mean(d_label_encoded), self.label_cover * torch.ones_like(d_label_cover)) +\
162
+ self.criterion_MSE(d_label_encoded - torch.mean(d_label_cover), self.label_encoded * torch.ones_like(d_label_encoded))
163
+
164
+ '''
165
+ validate encoder and decoder
166
+ '''
167
+
168
+ # GAN : target label for encoded image should be "cover"(1)
169
+ g_label_cover = self.discriminator(images)
170
+ g_label_encoded = self.discriminator(encoded_images)
171
+ g_loss_on_discriminator = self.criterion_MSE(g_label_cover - torch.mean(g_label_encoded), self.label_encoded * torch.ones_like(g_label_cover)) +\
172
+ self.criterion_MSE(g_label_encoded - torch.mean(g_label_cover), self.label_cover * torch.ones_like(g_label_encoded))
173
+
174
+ # RAW : the encoded image should be similar to cover image
175
+ g_loss_on_encoder_MSE = self.criterion_MSE(encoded_images, images)
176
+ g_loss_on_encoder_LPIPS = torch.mean(self.criterion_LPIPS(encoded_images, images))
177
+
178
+ # RESULT : the decoded message should be similar to the raw message /Dual
179
+ g_loss_on_decoder_C = self.criterion_MSE(decoded_messages_C, messages)
180
+ g_loss_on_decoder_R = self.criterion_MSE(decoded_messages_R, messages)
181
+ g_loss_on_decoder_F = self.criterion_MSE(decoded_messages_F, torch.zeros_like(messages))
182
+
183
+ # full loss
184
+ # unstable g_loss_on_discriminator is not used during validation
185
+
186
+ g_loss = 0 * g_loss_on_discriminator + self.encoder_weight * g_loss_on_encoder_LPIPS +\
187
+ self.decoder_weight_C * g_loss_on_decoder_C + self.decoder_weight_R * g_loss_on_decoder_R + self.decoder_weight_F * g_loss_on_decoder_F
188
+
189
+
190
+ # psnr
191
+ psnr = - kornia.losses.psnr_loss(encoded_images.detach(), images, 2)
192
+
193
+ # ssim
194
+ ssim = 1 - 2 * kornia.losses.ssim_loss(encoded_images.detach(), images, window_size=11, reduction="mean")
195
+
196
+ '''
197
+ decoded message error rate /Dual
198
+ '''
199
+ error_rate_C = self.decoded_message_error_rate_batch(messages, decoded_messages_C)
200
+ error_rate_R = self.decoded_message_error_rate_batch(messages, decoded_messages_R)
201
+ error_rate_F = self.decoded_message_error_rate_batch(messages, decoded_messages_F)
202
+
203
+ result = {
204
+ "g_loss": g_loss,
205
+ "error_rate_C": error_rate_C,
206
+ "error_rate_R": error_rate_R,
207
+ "error_rate_F": error_rate_F,
208
+ "psnr": psnr,
209
+ "ssim": ssim,
210
+ "g_loss_on_discriminator": g_loss_on_discriminator,
211
+ "g_loss_on_encoder_MSE": g_loss_on_encoder_MSE,
212
+ "g_loss_on_encoder_LPIPS": g_loss_on_encoder_LPIPS,
213
+ "g_loss_on_decoder_C": g_loss_on_decoder_C,
214
+ "g_loss_on_decoder_R": g_loss_on_decoder_R,
215
+ "g_loss_on_decoder_F": g_loss_on_decoder_F,
216
+ "d_loss": d_loss
217
+ }
218
+
219
+ return result, (images, encoded_images, noised_images)
220
+
221
+ def decoded_message_error_rate(self, message, decoded_message):
222
+ length = message.shape[0]
223
+
224
+ message = message.gt(0)
225
+ decoded_message = decoded_message.gt(0)
226
+ error_rate = float(sum(message != decoded_message)) / length
227
+ return error_rate
228
+
229
+ def decoded_message_error_rate_batch(self, messages, decoded_messages):
230
+ error_rate = 0.0
231
+ batch_size = len(messages)
232
+ for i in range(batch_size):
233
+ error_rate += self.decoded_message_error_rate(messages[i], decoded_messages[i])
234
+ error_rate /= batch_size
235
+ return error_rate
236
+
237
+ def save_model(self, path_encoder_decoder: str, path_discriminator: str):
238
+ torch.save(self.encoder_decoder.module.state_dict(), path_encoder_decoder)
239
+ torch.save(self.discriminator.module.state_dict(), path_discriminator)
240
+
241
+ def load_model(self, path_encoder_decoder: str, path_discriminator: str):
242
+ self.load_model_ed(path_encoder_decoder)
243
+ self.load_model_dis(path_discriminator)
244
+
245
+ def load_model_ed(self, path_encoder_decoder: str):
246
+ self.encoder_decoder.module.load_state_dict(torch.load(path_encoder_decoder), strict=False)
247
+
248
+ def load_model_dis(self, path_discriminator: str):
249
+ self.discriminator.module.load_state_dict(torch.load(path_discriminator))
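
The discriminator objective above is a relativistic-average LSGAN: each score is judged against the mean score of the opposite class rather than an absolute label. A compact restatement of d_loss (the helper name ra_lsgan_d_loss is ours):

import torch
import torch.nn as nn

mse = nn.MSELoss()

def ra_lsgan_d_loss(d_real, d_fake, label_cover=1.0, label_encoded=-1.0):
    # cover scores should sit one unit above the mean encoded score,
    # encoded scores one unit below the mean cover score
    return mse(d_real - torch.mean(d_fake), label_cover * torch.ones_like(d_real)) + \
           mse(d_fake - torch.mean(d_real), label_encoded * torch.ones_like(d_fake))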
models/bitnetwork/Encoder_U.py ADDED
@@ -0,0 +1,125 @@
1
+ from . import *
2
+
3
+ class DW_Encoder(nn.Module):
4
+
5
+ def __init__(self, message_length, blocks=2, channels=64, attention=None):
6
+ super(DW_Encoder, self).__init__()
7
+
8
+ self.conv1 = ConvBlock(3, 16, blocks=blocks)
9
+ self.down1 = Down(16, 32, blocks=blocks)
10
+ self.down2 = Down(32, 64, blocks=blocks)
11
+ self.down3 = Down(64, 128, blocks=blocks)
12
+
13
+ self.down4 = Down(128, 256, blocks=blocks)
14
+
15
+ self.up3 = UP(256, 128)
16
+ self.linear3 = nn.Linear(message_length, message_length * message_length)
17
+ self.Conv_message3 = ConvBlock(1, channels, blocks=blocks)
18
+ self.att3 = ResBlock(128 * 2 + channels, 128, blocks=blocks, attention=attention)
19
+
20
+ self.up2 = UP(128, 64)
21
+ self.linear2 = nn.Linear(message_length, message_length * message_length)
22
+ self.Conv_message2 = ConvBlock(1, channels, blocks=blocks)
23
+ self.att2 = ResBlock(64 * 2 + channels, 64, blocks=blocks, attention=attention)
24
+
25
+ self.up1 = UP(64, 32)
26
+ self.linear1 = nn.Linear(message_length, message_length * message_length)
27
+ self.Conv_message1 = ConvBlock(1, channels, blocks=blocks)
28
+ self.att1 = ResBlock(32 * 2 + channels, 32, blocks=blocks, attention=attention)
29
+
30
+ self.up0 = UP(32, 16)
31
+ self.linear0 = nn.Linear(message_length, message_length * message_length)
32
+ self.Conv_message0 = ConvBlock(1, channels, blocks=blocks)
33
+ self.att0 = ResBlock(16 * 2 + channels, 16, blocks=blocks, attention=attention)
34
+
35
+ self.Conv_1x1 = nn.Conv2d(16 + 3, 3, kernel_size=1, stride=1, padding=0)
36
+
37
+ self.message_length = message_length
38
+
39
+ self.transform = transforms.Compose([
40
+ transforms.ToTensor(),
41
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
42
+ ])
43
+
44
+
45
+ def forward(self, x, watermark):
46
+ d0 = self.conv1(x)
47
+ d1 = self.down1(d0)
48
+ d2 = self.down2(d1)
49
+ d3 = self.down3(d2)
50
+
51
+ d4 = self.down4(d3)
52
+
53
+ u3 = self.up3(d4)
54
+ expanded_message = self.linear3(watermark)
55
+ expanded_message = expanded_message.view(-1, 1, self.message_length, self.message_length)
56
+ expanded_message = F.interpolate(expanded_message, size=(d3.shape[2], d3.shape[3]),
57
+ mode='nearest')
58
+ expanded_message = self.Conv_message3(expanded_message)
59
+ u3 = torch.cat((d3, u3, expanded_message), dim=1)
60
+ u3 = self.att3(u3)
61
+
62
+ u2 = self.up2(u3)
63
+ expanded_message = self.linear2(watermark)
64
+ expanded_message = expanded_message.view(-1, 1, self.message_length, self.message_length)
65
+ expanded_message = F.interpolate(expanded_message, size=(d2.shape[2], d2.shape[3]),
66
+ mode='nearest')
67
+ expanded_message = self.Conv_message2(expanded_message)
68
+ u2 = torch.cat((d2, u2, expanded_message), dim=1)
69
+ u2 = self.att2(u2)
70
+
71
+ u1 = self.up1(u2)
72
+ expanded_message = self.linear1(watermark)
73
+ expanded_message = expanded_message.view(-1, 1, self.message_length, self.message_length)
74
+ expanded_message = F.interpolate(expanded_message, size=(d1.shape[2], d1.shape[3]),
75
+ mode='nearest')
76
+ expanded_message = self.Conv_message1(expanded_message)
77
+ u1 = torch.cat((d1, u1, expanded_message), dim=1)
78
+ u1 = self.att1(u1)
79
+
80
+ u0 = self.up0(u1)
81
+ expanded_message = self.linear0(watermark)
82
+ expanded_message = expanded_message.view(-1, 1, self.message_length, self.message_length)
83
+ expanded_message = F.interpolate(expanded_message, size=(d0.shape[2], d0.shape[3]),
84
+ mode='nearest')
85
+ expanded_message = self.Conv_message0(expanded_message)
86
+ u0 = torch.cat((d0, u0, expanded_message), dim=1)
87
+ u0 = self.att0(u0)
88
+
89
+ image = self.Conv_1x1(torch.cat((x, u0), dim=1))
90
+
91
+ forward_image = image.clone().detach()
92
+ '''read_image = torch.zeros_like(forward_image)
93
+
94
+ for index in range(forward_image.shape[0]):
95
+ single_image = ((forward_image[index].clamp(-1, 1).permute(1, 2, 0) + 1) / 2 * 255).add(0.5).clamp(0, 255).to('cpu', torch.uint8).numpy()
96
+ im = Image.fromarray(single_image)
97
+ read = np.array(im, dtype=np.uint8)
98
+ read_image[index] = self.transform(read).unsqueeze(0).to(image.device)
99
+
100
+ gap = read_image - forward_image'''
101
+ gap = forward_image.clamp(-1, 1) - forward_image  # straight-through clamp: forward is clamped, gradient is not
102
+
103
+ return image + gap
104
+
105
+
106
+ class Down(nn.Module):
107
+ def __init__(self, in_channels, out_channels, blocks):
108
+ super(Down, self).__init__()
109
+ self.layer = torch.nn.Sequential(
110
+ ConvBlock(in_channels, in_channels, stride=2),
111
+ ConvBlock(in_channels, out_channels, blocks=blocks)
112
+ )
113
+
114
+ def forward(self, x):
115
+ return self.layer(x)
116
+
117
+
118
+ class UP(nn.Module):
119
+ def __init__(self, in_channels, out_channels):
120
+ super(UP, self).__init__()
121
+ self.conv = ConvBlock(in_channels, out_channels)
122
+
123
+ def forward(self, x):
124
+ x = F.interpolate(x, scale_factor=2, mode='nearest')
125
+ return self.conv(x)
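
The closing `gap` line in DW_Encoder.forward is a straight-through clamp: the returned image is numerically clamped to [-1, 1] (the commented-out block applies the same idea to a full uint8 write/read round-trip), while gradients reach the encoder as if no clamp were present. A minimal sketch of the trick:

import torch

def straight_through_clamp(x):
    # Forward value equals x.clamp(-1, 1); the backward pass sees the
    # identity, because the correction `gap` is built from detached copies.
    detached = x.detach()
    gap = detached.clamp(-1, 1) - detached
    return x + gap

x = (torch.randn(4) * 3).requires_grad_()
straight_through_clamp(x).sum().backward()
print(x.grad)  # all ones, even where the clamp was active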
models/bitnetwork/Random_Noise.py ADDED
@@ -0,0 +1,59 @@
1
+ from . import *
2
+ from .noise_layers import *
3
+
4
+
5
+ class Random_Noise(nn.Module):
6
+
7
+ def __init__(self, layers, len_layers_R, len_layers_F):
8
+ super(Random_Noise, self).__init__()
9
+ for i in range(len(layers)):
10
+ layers[i] = eval(layers[i])  # instantiate each noise layer from its config string
11
+ self.noise = nn.Sequential(*layers)
12
+ self.len_layers_R = len_layers_R
13
+ self.len_layers_F = len_layers_F
14
+ print(self.noise)
15
+ self.transform = transforms.Compose([
16
+ transforms.ToTensor(),
17
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
18
+ ])
19
+
20
+ def forward(self, image_cover_mask):
21
+ image, cover_image, mask = image_cover_mask[0], image_cover_mask[1], image_cover_mask[2]
22
+ forward_image = image.clone().detach()
23
+ forward_cover_image = cover_image.clone().detach()
24
+ forward_mask = mask.clone().detach()
25
+ noised_image_C = torch.zeros_like(forward_image)
26
+ noised_image_R = torch.zeros_like(forward_image)
27
+ noised_image_F = torch.zeros_like(forward_image)
28
+
29
+ for index in range(forward_image.shape[0]):
30
+ random_noise_layer_C = np.random.choice(self.noise, 1)[0]
31
+ random_noise_layer_R = np.random.choice(self.noise[0:self.len_layers_R], 1)[0]
32
+ random_noise_layer_F = np.random.choice(self.noise[self.len_layers_R:self.len_layers_R + self.len_layers_F], 1)[0]
33
+ noised_image_C[index] = random_noise_layer_C([forward_image[index].clone().unsqueeze(0), forward_cover_image[index].clone().unsqueeze(0), forward_mask[index].clone().unsqueeze(0)])
34
+ noised_image_R[index] = random_noise_layer_R([forward_image[index].clone().unsqueeze(0), forward_cover_image[index].clone().unsqueeze(0), forward_mask[index].clone().unsqueeze(0)])
35
+ noised_image_F[index] = random_noise_layer_F([forward_image[index].clone().unsqueeze(0), forward_cover_image[index].clone().unsqueeze(0), forward_mask[index].clone().unsqueeze(0)])
36
+
37
+ '''single_image = ((noised_image_C[index].clamp(-1, 1).permute(1, 2, 0) + 1) / 2 * 255).add(0.5).clamp(0, 255).to('cpu', torch.uint8).numpy()
38
+ im = Image.fromarray(single_image)
39
+ read = np.array(im, dtype=np.uint8)
40
+ noised_image_C[index] = self.transform(read).unsqueeze(0).to(image.device)
41
+
42
+ single_image = ((noised_image_R[index].clamp(-1, 1).permute(1, 2, 0) + 1) / 2 * 255).add(0.5).clamp(0, 255).to('cpu', torch.uint8).numpy()
43
+ im = Image.fromarray(single_image)
44
+ read = np.array(im, dtype=np.uint8)
45
+ noised_image_R[index] = self.transform(read).unsqueeze(0).to(image.device)
46
+
47
+ single_image = ((noised_image_F[index].clamp(-1, 1).permute(1, 2, 0) + 1) / 2 * 255).add(0.5).clamp(0, 255).to('cpu', torch.uint8).numpy()
48
+ im = Image.fromarray(single_image)
49
+ read = np.array(im, dtype=np.uint8)
50
+ noised_image_F[index] = self.transform(read).unsqueeze(0).to(image.device)
51
+
52
+ noised_image_gap_C = noised_image_C - forward_image
53
+ noised_image_gap_R = noised_image_R - forward_image
54
+ noised_image_gap_F = noised_image_F - forward_image'''
55
+ noised_image_gap_C = noised_image_C.clamp(-1, 1) - forward_image
56
+ noised_image_gap_R = noised_image_R.clamp(-1, 1) - forward_image
57
+ noised_image_gap_F = noised_image_F.clamp(-1, 1) - forward_image
58
+
59
+ return image + noised_image_gap_C, image + noised_image_gap_R, image + noised_image_gap_F
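
Random_Noise draws a different layer per sample: one from the full pool (the C branch), one from the first len_layers_R entries ("real" degradations), and one from the next len_layers_F entries (tamper-style degradations). The selection pattern in miniature (a sketch with stand-in layers, not the repo's noise classes):

import random
import torch.nn as nn

pool = [nn.Identity(), nn.Identity(), nn.Identity()]  # stand-ins for noise layers
len_R, len_F = 2, 1
layer_C = random.choice(pool)                         # combined pool
layer_R = random.choice(pool[:len_R])                 # "real" degradation pool
layer_F = random.choice(pool[len_R:len_R + len_F])    # "fake"/tamper pool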
models/bitnetwork/ResBlock.py ADDED
@@ -0,0 +1,222 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class SEAttention(nn.Module):
7
+ def __init__(self, in_channels, out_channels, reduction=8):
8
+ super(SEAttention, self).__init__()
9
+ self.se = nn.Sequential(
10
+ nn.AdaptiveAvgPool2d((1, 1)),
11
+ nn.Conv2d(in_channels=in_channels, out_channels=out_channels // reduction, kernel_size=1, bias=False),
12
+ nn.ReLU(inplace=True),
13
+ nn.Conv2d(in_channels=out_channels // reduction, out_channels=out_channels, kernel_size=1, bias=False),
14
+ nn.Sigmoid()
15
+ )
16
+
17
+ def forward(self, x):
18
+ x = self.se(x) * x
19
+ return x
20
+
21
+
22
+ class ChannelAttention(nn.Module):
23
+ def __init__(self, in_channels, out_channels, reduction=8):
24
+ super(ChannelAttention, self).__init__()
25
+ self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
26
+ self.max_pool = nn.AdaptiveMaxPool2d((1, 1))
27
+
28
+ self.fc = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels // reduction, kernel_size=1, bias=False),
29
+ nn.ReLU(inplace=True),
30
+ nn.Conv2d(in_channels=out_channels // reduction, out_channels=out_channels, kernel_size=1, bias=False))
31
+ self.sigmoid = nn.Sigmoid()
32
+
33
+ def forward(self, x):
34
+ avg_out = self.fc(self.avg_pool(x))
35
+ max_out = self.fc(self.max_pool(x))
36
+ out = avg_out + max_out
37
+ return self.sigmoid(out)
38
+
39
+
40
+ class SpatialAttention(nn.Module):
41
+ def __init__(self, kernel_size=7):
42
+ super(SpatialAttention, self).__init__()
43
+
44
+ self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
45
+ self.sigmoid = nn.Sigmoid()
46
+
47
+ def forward(self, x):
48
+ avg_out = torch.mean(x, dim=1, keepdim=True)
49
+ max_out, _ = torch.max(x, dim=1, keepdim=True)
50
+ x = torch.cat([avg_out, max_out], dim=1)
51
+ x = self.conv1(x)
52
+ return self.sigmoid(x)
53
+
54
+
55
+ class CBAMAttention(nn.Module):
56
+ def __init__(self, in_channels, out_channels, reduction=8):
57
+ super(CBAMAttention, self).__init__()
58
+ self.ca = ChannelAttention(in_channels=in_channels, out_channels=out_channels, reduction=reduction)
59
+ self.sa = SpatialAttention()
60
+
61
+ def forward(self, x):
62
+ x = self.ca(x) * x
63
+ x = self.sa(x) * x
64
+ return x
65
+
66
+
67
+ class h_sigmoid(nn.Module):
68
+ def __init__(self, inplace=True):
69
+ super(h_sigmoid, self).__init__()
70
+ self.relu = nn.ReLU6(inplace=inplace)
71
+
72
+ def forward(self, x):
73
+ return self.relu(x + 3) / 6
74
+
75
+
76
+ class h_swish(nn.Module):
77
+ def __init__(self, inplace=True):
78
+ super(h_swish, self).__init__()
79
+ self.sigmoid = h_sigmoid(inplace=inplace)
80
+
81
+ def forward(self, x):
82
+ return x * self.sigmoid(x)
83
+
84
+
85
+ class CoordAttention(nn.Module):
86
+ def __init__(self, in_channels, out_channels, reduction=8):
87
+ super(CoordAttention, self).__init__()
88
+ self.pool_w, self.pool_h = nn.AdaptiveAvgPool2d((1, None)), nn.AdaptiveAvgPool2d((None, 1))
89
+ temp_c = max(8, in_channels // reduction)
90
+ self.conv1 = nn.Conv2d(in_channels, temp_c, kernel_size=1, stride=1, padding=0)
91
+
92
+ self.bn1 = nn.InstanceNorm2d(temp_c)
93
+ self.act1 = h_swish() # nn.SiLU() # nn.Hardswish() # nn.SiLU()
94
+
95
+ self.conv2 = nn.Conv2d(temp_c, out_channels, kernel_size=1, stride=1, padding=0)
96
+ self.conv3 = nn.Conv2d(temp_c, out_channels, kernel_size=1, stride=1, padding=0)
97
+
98
+ def forward(self, x):
99
+ short = x
100
+ n, c, H, W = x.shape
101
+ x_h, x_w = self.pool_h(x), self.pool_w(x).permute(0, 1, 3, 2)
102
+ x_cat = torch.cat([x_h, x_w], dim=2)
103
+ out = self.act1(self.bn1(self.conv1(x_cat)))
104
+ x_h, x_w = torch.split(out, [H, W], dim=2)
105
+ x_w = x_w.permute(0, 1, 3, 2)
106
+ out_h = torch.sigmoid(self.conv2(x_h))
107
+ out_w = torch.sigmoid(self.conv3(x_w))
108
+ return short * out_w * out_h
109
+
110
+
111
+ class BasicBlock(nn.Module):
112
+ def __init__(self, in_channels, out_channels, reduction, stride, attention=None):
113
+ super(BasicBlock, self).__init__()
114
+
115
+ self.change = None
116
+ if (in_channels != out_channels or stride != 1):
117
+ self.change = nn.Sequential(
118
+ nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, padding=0,
119
+ stride=stride, bias=False),
120
+ nn.InstanceNorm2d(out_channels)
121
+ )
122
+
123
+ self.left = nn.Sequential(
124
+ nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1,
125
+ stride=stride, bias=False),
126
+ nn.InstanceNorm2d(out_channels),
127
+ nn.ReLU(inplace=True),
128
+ nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=False),
129
+ nn.InstanceNorm2d(out_channels)
130
+ )
131
+
132
+ if attention == 'se':
133
+ print('SEAttention')
134
+ self.attention = SEAttention(in_channels=out_channels, out_channels=out_channels, reduction=reduction)
135
+ elif attention == 'cbam':
136
+ print('CBAMAttention')
137
+ self.attention = CBAMAttention(in_channels=out_channels, out_channels=out_channels, reduction=reduction)
138
+ elif attention == 'coord':
139
+ print('CoordAttention')
140
+ self.attention = CoordAttention(in_channels=out_channels, out_channels=out_channels, reduction=reduction)
141
+ else:
142
+ print('None Attention')
143
+ self.attention = nn.Identity()
144
+
145
+ def forward(self, x):
146
+ identity = x
147
+ x = self.left(x)
148
+ x = self.attention(x)
149
+
150
+ if self.change is not None:
151
+ identity = self.change(identity)
152
+
153
+ x += identity
154
+ x = F.relu(x)
155
+ return x
156
+
157
+
158
+ class BottleneckBlock(nn.Module):
159
+ def __init__(self, in_channels, out_channels, reduction, stride, attention=None):
160
+ super(BottleneckBlock, self).__init__()
161
+
162
+ self.change = None
163
+ if (in_channels != out_channels or stride != 1):
164
+ self.change = nn.Sequential(
165
+ nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, padding=0,
166
+ stride=stride, bias=False),
167
+ nn.InstanceNorm2d(out_channels)
168
+ )
169
+
170
+ self.left = nn.Sequential(
171
+ nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1,
172
+ stride=stride, padding=0, bias=False),
173
+ nn.InstanceNorm2d(out_channels),
174
+ nn.ReLU(inplace=True),
175
+ nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=False),
176
+ nn.InstanceNorm2d(out_channels),
177
+ nn.ReLU(inplace=True),
178
+ nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=1, padding=0, bias=False),
179
+ nn.InstanceNorm2d(out_channels)
180
+ )
181
+
182
+ if attention == 'se':
183
+ print('SEAttention')
184
+ self.attention = SEAttention(in_channels=out_channels, out_channels=out_channels, reduction=reduction)
185
+ elif attention == 'cbam':
186
+ print('CBAMAttention')
187
+ self.attention = CBAMAttention(in_channels=out_channels, out_channels=out_channels, reduction=reduction)
188
+ elif attention == 'coord':
189
+ print('CoordAttention')
190
+ self.attention = CoordAttention(in_channels=out_channels, out_channels=out_channels, reduction=reduction)
191
+ else:
192
+ print('None Attention')
193
+ self.attention = nn.Identity()
194
+
195
+ def forward(self, x):
196
+ identity = x
197
+ x = self.left(x)
198
+ x = self.attention(x)
199
+
200
+ if self.change is not None:
201
+ identity = self.change(identity)
202
+
203
+ x += identity
204
+ x = F.relu(x)
205
+ return x
206
+
207
+
208
+ class ResBlock(nn.Module):
209
+
210
+ def __init__(self, in_channels, out_channels, blocks=1, block_type="BottleneckBlock", reduction=8, stride=1, attention=None):
211
+ super(ResBlock, self).__init__()
212
+
213
+ layers = [eval(block_type)(in_channels, out_channels, reduction, stride, attention=attention)] if blocks != 0 else []
214
+ for _ in range(blocks - 1):
215
+ layer = eval(block_type)(out_channels, out_channels, reduction, 1, attention=attention)
216
+ layers.append(layer)
217
+
218
+ self.layers = nn.Sequential(*layers)
219
+
220
+ def forward(self, x):
221
+ return self.layers(x)
222
+
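
ResBlock follows the same construction pattern as ConvBlock: only the first block receives the requested stride and channel change, and every block ends with the selected attention module. A usage sketch (import path assumed from this repo layout):

import torch
from models.bitnetwork.ResBlock import ResBlock

block = ResBlock(32, 64, blocks=2, attention='coord')  # BottleneckBlock by default
x = torch.randn(1, 32, 64, 64)
print(block(x).shape)                                  # torch.Size([1, 64, 64, 64])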
models/bitnetwork/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ # import kornia.losses
6
+ from PIL import Image
7
+ from torchvision import transforms
8
+ from .ResBlock import *
9
+ from .ConvBlock import *
models/bitnetwork/__pycache__/ConvBlock.cpython-38.pyc ADDED
Binary file (1.58 kB). View file
 
models/bitnetwork/__pycache__/Decoder_U.cpython-38.pyc ADDED
Binary file (2.75 kB). View file
 
models/bitnetwork/__pycache__/Encoder_U.cpython-38.pyc ADDED
Binary file (3.51 kB). View file
 
models/bitnetwork/__pycache__/ResBlock.cpython-38.pyc ADDED
Binary file (7.51 kB). View file
 
models/bitnetwork/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (410 Bytes). View file
 
models/discrim.py ADDED
@@ -0,0 +1,169 @@
1
+ from torch import nn as nn
2
+ from torch.nn import functional as F
3
+ from torch.nn.utils import spectral_norm
4
+
5
+
6
+ class UNetDiscriminatorSN(nn.Module):
7
+ """Defines a U-Net discriminator with spectral normalization (SN)
8
+
9
+ It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
10
+
11
+ Args:
12
+ num_in_ch (int): Channel number of inputs. Default: 3.
13
+ num_feat (int): Channel number of base intermediate features. Default: 64.
14
+ skip_connection (bool): Whether to use skip connections in the U-Net. Default: True.
15
+ """
16
+
17
+ def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
18
+ super(UNetDiscriminatorSN, self).__init__()
19
+ self.skip_connection = skip_connection
20
+ norm = spectral_norm
21
+ # the first convolution
22
+ self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
23
+ # downsample
24
+ self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
25
+ self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
26
+ self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
27
+ # upsample
28
+ self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
29
+ self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
30
+ self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
31
+ # extra convolutions
32
+ self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
33
+ self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
34
+ self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)
35
+
36
+ def forward(self, x):
37
+ # downsample
38
+ x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
39
+ x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
40
+ x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
41
+ x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)
42
+
43
+ # upsample
44
+ x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
45
+ x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)
46
+
47
+ if self.skip_connection:
48
+ x4 = x4 + x2
49
+ x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
50
+ x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)
51
+
52
+ if self.skip_connection:
53
+ x5 = x5 + x1
54
+ x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
55
+ x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)
56
+
57
+ if self.skip_connection:
58
+ x6 = x6 + x0
59
+
60
+ # extra convolutions
61
+ out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
62
+ out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
63
+ out = self.conv9(out)
64
+
65
+ return out
66
+
67
+
68
+ class GANLoss(nn.Module):
69
+ """Define GAN loss.
70
+
71
+ Args:
72
+ gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
73
+ real_label_val (float): The value for real label. Default: 1.0.
74
+ fake_label_val (float): The value for fake label. Default: 0.0.
75
+ loss_weight (float): Loss weight. Default: 1.0.
76
+ Note that loss_weight is only for generators; it is always 1.0
77
+ for discriminators.
78
+ """
79
+
80
+ def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
81
+ super(GANLoss, self).__init__()
82
+ self.gan_type = gan_type
83
+ self.loss_weight = loss_weight
84
+ self.real_label_val = real_label_val
85
+ self.fake_label_val = fake_label_val
86
+
87
+ if self.gan_type == 'vanilla':
88
+ self.loss = nn.BCEWithLogitsLoss()
89
+ elif self.gan_type == 'lsgan':
90
+ self.loss = nn.MSELoss()
91
+ elif self.gan_type == 'wgan':
92
+ self.loss = self._wgan_loss
93
+ elif self.gan_type == 'wgan_softplus':
94
+ self.loss = self._wgan_softplus_loss
95
+ elif self.gan_type == 'hinge':
96
+ self.loss = nn.ReLU()
97
+ else:
98
+ raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')
99
+
100
+ def _wgan_loss(self, input, target):
101
+ """wgan loss.
102
+
103
+ Args:
104
+ input (Tensor): Input tensor.
105
+ target (bool): Target label.
106
+
107
+ Returns:
108
+ Tensor: wgan loss.
109
+ """
110
+ return -input.mean() if target else input.mean()
111
+
112
+ def _wgan_softplus_loss(self, input, target):
113
+ """wgan loss with soft plus. softplus is a smooth approximation to the
114
+ ReLU function.
115
+
116
+ In StyleGAN2, it is called:
117
+ Logistic loss for discriminator;
118
+ Non-saturating loss for generator.
119
+
120
+ Args:
121
+ input (Tensor): Input tensor.
122
+ target (bool): Target label.
123
+
124
+ Returns:
125
+ Tensor: wgan loss.
126
+ """
127
+ return F.softplus(-input).mean() if target else F.softplus(input).mean()
128
+
129
+ def get_target_label(self, input, target_is_real):
130
+ """Get target label.
131
+
132
+ Args:
133
+ input (Tensor): Input tensor.
134
+ target_is_real (bool): Whether the target is real or fake.
135
+
136
+ Returns:
137
+ (bool | Tensor): Target tensor. Return bool for wgan, otherwise,
138
+ return Tensor.
139
+ """
140
+
141
+ if self.gan_type in ['wgan', 'wgan_softplus']:
142
+ return target_is_real
143
+ target_val = (self.real_label_val if target_is_real else self.fake_label_val)
144
+ return input.new_ones(input.size()) * target_val
145
+
146
+ def forward(self, input, target_is_real, is_disc=False):
147
+ """
148
+ Args:
149
+ input (Tensor): The input for the loss module, i.e., the network
150
+ prediction.
151
+ target_is_real (bool): Whether the target is real or fake.
152
+ is_disc (bool): Whether the loss for discriminators or not.
153
+ Default: False.
154
+
155
+ Returns:
156
+ Tensor: GAN loss value.
157
+ """
158
+ target_label = self.get_target_label(input, target_is_real)
159
+ if self.gan_type == 'hinge':
160
+ if is_disc: # for discriminators in hinge-gan
161
+ input = -input if target_is_real else input
162
+ loss = self.loss(1 + input).mean()
163
+ else: # for generators in hinge-gan
164
+ loss = -input.mean()
165
+ else: # other gan types
166
+ loss = self.loss(input, target_label)
167
+
168
+ # loss_weight is always 1.0 for discriminators
169
+ return loss if is_disc else loss * self.loss_weight
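
GANLoss pairs with the U-Net discriminator above; is_disc=True switches to the discriminator-side formulation (and hinge handling) and ignores loss_weight. A usage sketch (the 5e-3 weight is an arbitrary example, not a repo default):

import torch
from models.discrim import GANLoss, UNetDiscriminatorSN

netD = UNetDiscriminatorSN(num_in_ch=3)
cri_gan = GANLoss('vanilla', loss_weight=5e-3)

img = torch.rand(1, 3, 64, 64)
pred = netD(img)
g_loss = cri_gan(pred, target_is_real=True, is_disc=False)  # generator update
d_loss = cri_gan(pred.detach(), False, is_disc=True)        # discriminator update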
models/lr_scheduler.py ADDED
@@ -0,0 +1,142 @@
+ import math
+ from collections import Counter
+ from collections import defaultdict
+ import torch
+ from torch.optim.lr_scheduler import _LRScheduler
+
+
+ class MultiStepLR_Restart(_LRScheduler):
+     def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1,
+                  clear_state=False, last_epoch=-1):
+         self.milestones = Counter(milestones)
+         self.gamma = gamma
+         self.clear_state = clear_state
+         self.restarts = restarts if restarts else [0]
+         self.restart_weights = weights if weights else [1]
+         assert len(self.restarts) == len(
+             self.restart_weights), 'restarts and their weights do not match.'
+         super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch)
+
+     def get_lr(self):
+         if self.last_epoch in self.restarts:
+             if self.clear_state:
+                 self.optimizer.state = defaultdict(dict)
+             weight = self.restart_weights[self.restarts.index(self.last_epoch)]
+             return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
+         if self.last_epoch not in self.milestones:
+             return [group['lr'] for group in self.optimizer.param_groups]
+         return [
+             group['lr'] * self.gamma**self.milestones[self.last_epoch]
+             for group in self.optimizer.param_groups
+         ]
+
+
+ class CosineAnnealingLR_Restart(_LRScheduler):
+     def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1):
+         self.T_period = T_period
+         self.T_max = self.T_period[0]  # current T period
+         self.eta_min = eta_min
+         self.restarts = restarts if restarts else [0]
+         self.restart_weights = weights if weights else [1]
+         self.last_restart = 0
+         assert len(self.restarts) == len(
+             self.restart_weights), 'restarts and their weights do not match.'
+         super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch)
+
+     def get_lr(self):
+         if self.last_epoch == 0:
+             return self.base_lrs
+         elif self.last_epoch in self.restarts:
+             self.last_restart = self.last_epoch
+             self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1]
+             weight = self.restart_weights[self.restarts.index(self.last_epoch)]
+             return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
+         elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0:
+             return [
+                 group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2
+                 for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
+             ]
+         return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) /
+                 (1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) *
+                 (group['lr'] - self.eta_min) + self.eta_min
+                 for group in self.optimizer.param_groups]
+
+
+ if __name__ == "__main__":
+     optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0,
+                                  betas=(0.9, 0.99))
+     ##############################
+     # MultiStepLR_Restart
+     ##############################
+     ## Original
+     lr_steps = [200000, 400000, 600000, 800000]
+     restarts = None
+     restart_weights = None
+
+     ## two
+     lr_steps = [100000, 200000, 300000, 400000, 490000, 600000, 700000, 800000, 900000, 990000]
+     restarts = [500000]
+     restart_weights = [1]
+
+     ## four
+     lr_steps = [
+         50000, 100000, 150000, 200000, 240000, 300000, 350000, 400000, 450000, 490000, 550000,
+         600000, 650000, 700000, 740000, 800000, 850000, 900000, 950000, 990000
+     ]
+     restarts = [250000, 500000, 750000]
+     restart_weights = [1, 1, 1]
+
+     scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5,
+                                     clear_state=False)
+
+     ##############################
+     # Cosine Annealing Restart
+     ##############################
+     ## two
+     T_period = [500000, 500000]
+     restarts = [500000]
+     restart_weights = [1]
+
+     ## four
+     T_period = [250000, 250000, 250000, 250000]
+     restarts = [250000, 500000, 750000]
+     restart_weights = [1, 1, 1]
+
+     scheduler = CosineAnnealingLR_Restart(optimizer, T_period, eta_min=1e-7, restarts=restarts,
+                                           weights=restart_weights)
+
+     ##############################
+     # Draw figure
+     ##############################
+     N_iter = 1000000
+     lr_l = list(range(N_iter))
+     for i in range(N_iter):
+         scheduler.step()
+         current_lr = optimizer.param_groups[0]['lr']
+         lr_l[i] = current_lr
+
+     import matplotlib as mpl
+     from matplotlib import pyplot as plt
+     import matplotlib.ticker as mtick
+     mpl.style.use('default')
+     import seaborn
+     seaborn.set(style='whitegrid')
+     seaborn.set_context('paper')
+
+     plt.figure(1)
+     plt.subplot(111)
+     plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
+     plt.title('Title', fontsize=16, color='k')
+     plt.plot(list(range(N_iter)), lr_l, linewidth=1.5, label='learning rate scheme')
+     legend = plt.legend(loc='upper right', shadow=False)
+     ax = plt.gca()
+     labels = ax.get_xticks().tolist()
+     for k, v in enumerate(labels):
+         labels[k] = str(int(v / 1000)) + 'K'
+     ax.set_xticklabels(labels)
+     ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e'))
+
+     ax.set_ylabel('Learning rate')
+     ax.set_xlabel('Iteration')
+     fig = plt.gcf()
+     plt.show()
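The `__main__` block above doubles as documentation: both schedulers are stepped once per training iteration rather than per epoch. A condensed sketch of typical consumption, reusing the "four restarts" configuration from the demo:

    import torch

    optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4)
    scheduler = CosineAnnealingLR_Restart(optimizer, T_period=[250000] * 4,
                                          restarts=[250000, 500000, 750000],
                                          weights=[1, 1, 1], eta_min=1e-7)

    for it in range(1000):
        # ... forward pass, backward pass, optimizer.step() would run here ...
        scheduler.step()                          # advance the LR once per iteration
        lr_now = optimizer.param_groups[0]['lr']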
models/modules/Inv_arch.py ADDED
@@ -0,0 +1,584 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from .module_util import initialize_weights_xavier
+ from torch.nn import init
+ from .common import DWT, IWT
+ import cv2
+ from basicsr.archs.arch_util import flow_warp
+ from models.modules.Subnet_constructor import subnet
+
+ from pdb import set_trace as stx
+ import numbers
+
+ from einops import rearrange
+ from models.bitnetwork.Encoder_U import DW_Encoder
+ from models.bitnetwork.Decoder_U import DW_Decoder
+
+
+ ## Layer Norm
+ def to_3d(x):
+     return rearrange(x, 'b c h w -> b (h w) c')
+
+
+ def to_4d(x, h, w):
+     return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
+
+
+ class BiasFree_LayerNorm(nn.Module):
+     def __init__(self, normalized_shape):
+         super(BiasFree_LayerNorm, self).__init__()
+         if isinstance(normalized_shape, numbers.Integral):
+             normalized_shape = (normalized_shape,)
+         normalized_shape = torch.Size(normalized_shape)
+
+         assert len(normalized_shape) == 1
+
+         self.weight = nn.Parameter(torch.ones(normalized_shape))
+         self.normalized_shape = normalized_shape
+
+     def forward(self, x):
+         sigma = x.var(-1, keepdim=True, unbiased=False)
+         return x / torch.sqrt(sigma + 1e-5) * self.weight
+
+
+ class WithBias_LayerNorm(nn.Module):
+     def __init__(self, normalized_shape):
+         super(WithBias_LayerNorm, self).__init__()
+         if isinstance(normalized_shape, numbers.Integral):
+             normalized_shape = (normalized_shape,)
+         normalized_shape = torch.Size(normalized_shape)
+
+         assert len(normalized_shape) == 1
+
+         self.weight = nn.Parameter(torch.ones(normalized_shape))
+         self.bias = nn.Parameter(torch.zeros(normalized_shape))
+         self.normalized_shape = normalized_shape
+
+     def forward(self, x):
+         mu = x.mean(-1, keepdim=True)
+         sigma = x.var(-1, keepdim=True, unbiased=False)
+         return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias
+
+
+ class LayerNorm(nn.Module):
+     def __init__(self, dim, LayerNorm_type):
+         super(LayerNorm, self).__init__()
+         if LayerNorm_type == 'BiasFree':
+             self.body = BiasFree_LayerNorm(dim)
+         else:
+             self.body = WithBias_LayerNorm(dim)
+
+     def forward(self, x):
+         h, w = x.shape[-2:]
+         return to_4d(self.body(to_3d(x)), h, w)
+
+
+ ##########################################################################
+ ## Gated-Dconv Feed-Forward Network (GDFN)
+ class FeedForward(nn.Module):
+     def __init__(self, dim, ffn_expansion_factor, bias):
+         super(FeedForward, self).__init__()
+
+         hidden_features = int(dim * ffn_expansion_factor)
+
+         self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
+
+         self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1,
+                                 groups=hidden_features * 2, bias=bias)
+
+         self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
+
+     def forward(self, x):
+         x = self.project_in(x)
+         x1, x2 = self.dwconv(x).chunk(2, dim=1)
+         x = F.gelu(x1) * x2
+         x = self.project_out(x)
+         return x
+
+
+ ##########################################################################
+ ## Multi-DConv Head Transposed Self-Attention (MDTA)
+ class Attention(nn.Module):
+     def __init__(self, dim, num_heads, bias):
+         super(Attention, self).__init__()
+         self.num_heads = num_heads
+         self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
+
+         self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
+         self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias)
+         self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
+
+     def forward(self, x):
+         b, c, h, w = x.shape
+
+         qkv = self.qkv_dwconv(self.qkv(x))
+         q, k, v = qkv.chunk(3, dim=1)
+
+         q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
+         k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
+         v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
+
+         q = torch.nn.functional.normalize(q, dim=-1)
+         k = torch.nn.functional.normalize(k, dim=-1)
+
+         attn = (q @ k.transpose(-2, -1)) * self.temperature
+         attn = attn.softmax(dim=-1)
+
+         out = (attn @ v)
+
+         out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
+
+         out = self.project_out(out)
+         return out
+
+
+ ##########################################################################
+ class TransformerBlock(nn.Module):
+     def __init__(self, dim, num_heads=4, ffn_expansion_factor=4, bias=False, LayerNorm_type="withbias"):
+         super(TransformerBlock, self).__init__()
+
+         self.norm1 = LayerNorm(dim, LayerNorm_type)
+         self.attn = Attention(dim, num_heads, bias)
+         self.norm2 = LayerNorm(dim, LayerNorm_type)
+         self.ffn = FeedForward(dim, ffn_expansion_factor, bias)
+
+     def forward(self, x):
+         x = x + self.attn(self.norm1(x))
+         x = x + self.ffn(self.norm2(x))
+
+         return x
+
+
+ dwt = DWT()
+ iwt = IWT()
+
+
+ class LayerNormFunction(torch.autograd.Function):
+
+     @staticmethod
+     def forward(ctx, x, weight, bias, eps):
+         ctx.eps = eps
+         N, C, H, W = x.size()
+         mu = x.mean(1, keepdim=True)
+         var = (x - mu).pow(2).mean(1, keepdim=True)
+         y = (x - mu) / (var + eps).sqrt()
+         ctx.save_for_backward(y, var, weight)
+         y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
+         return y
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         eps = ctx.eps
+
+         N, C, H, W = grad_output.size()
+         y, var, weight = ctx.saved_tensors
+         g = grad_output * weight.view(1, C, 1, 1)
+         mean_g = g.mean(dim=1, keepdim=True)
+
+         mean_gy = (g * y).mean(dim=1, keepdim=True)
+         gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
+         return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum(
+             dim=0), None
+
+
+ class LayerNorm2d(nn.Module):
+
+     def __init__(self, channels, eps=1e-6):
+         super(LayerNorm2d, self).__init__()
+         self.register_parameter('weight', nn.Parameter(torch.ones(channels)))
+         self.register_parameter('bias', nn.Parameter(torch.zeros(channels)))
+         self.eps = eps
+
+     def forward(self, x):
+         return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
+
+
+ class SimpleGate(nn.Module):
+     def forward(self, x):
+         x1, x2 = x.chunk(2, dim=1)
+         return x1 * x2
+
+
+ class NAFBlock(nn.Module):
+     def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
+         super().__init__()
+         dw_channel = c * DW_Expand
+         self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
+         self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel,
+                                bias=True)
+         self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
+
+         # Simplified Channel Attention
+         self.sca = nn.Sequential(
+             nn.AdaptiveAvgPool2d(1),
+             nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1,
+                       groups=1, bias=True),
+         )
+
+         # SimpleGate
+         self.sg = SimpleGate()
+
+         ffn_channel = FFN_Expand * c
+         self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
+         self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
+
+         self.norm1 = LayerNorm2d(c)
+         self.norm2 = LayerNorm2d(c)
+
+         self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
+         self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
+
+         self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
+         self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
+
+     def forward(self, inp):
+         x = inp
+
+         x = self.norm1(x)
+
+         x = self.conv1(x)
+         x = self.conv2(x)
+         x = self.sg(x)
+         x = x * self.sca(x)
+         x = self.conv3(x)
+
+         x = self.dropout1(x)
+
+         y = inp + x * self.beta
+
+         x = self.conv4(self.norm2(y))
+         x = self.sg(x)
+         x = self.conv5(x)
+
+         x = self.dropout2(x)
+
+         return y + x * self.gamma
+
+
+ def thops_mean(tensor, dim=None, keepdim=False):
+     if dim is None:
+         # mean over all dims
+         return torch.mean(tensor)
+     else:
+         if isinstance(dim, int):
+             dim = [dim]
+         dim = sorted(dim)
+         for d in dim:
+             tensor = tensor.mean(dim=d, keepdim=True)
+         if not keepdim:
+             for i, d in enumerate(dim):
+                 tensor.squeeze_(d - i)
+         return tensor
+
+
+ class ResidualBlockNoBN(nn.Module):
+     def __init__(self, nf=64, model='MIMO-VRN'):
+         super(ResidualBlockNoBN, self).__init__()
+         self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+         self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+         # honestly, there's no significant difference between ReLU and leaky ReLU in terms of performance here
+         # but this is how we trained the model in the first place and what we reported in the paper
+         if model == 'LSTM-VRN':
+             self.relu = nn.ReLU(inplace=True)
+         elif model == 'MIMO-VRN':
+             self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+         # initialization
+         initialize_weights_xavier([self.conv1, self.conv2], 0.1)
+
+     def forward(self, x):
+         identity = x
+         out = self.relu(self.conv1(x))
+         out = self.conv2(out)
+         return identity + out
+
+
+ class InvBlock(nn.Module):
+     def __init__(self, subnet_constructor, subnet_constructor_v2, channel_num_ho, channel_num_hi, groups, clamp=1.):
+         super(InvBlock, self).__init__()
+         self.split_len1 = channel_num_ho  # channel_split_num
+         self.split_len2 = channel_num_hi  # channel_num - channel_split_num
+         self.clamp = clamp
+
+         self.F = subnet_constructor_v2(self.split_len2, self.split_len1, groups=groups)
+         self.NF = NAFBlock(self.split_len2)
+         if groups == 1:
+             self.G = subnet_constructor(self.split_len1, self.split_len2, groups=groups)
+             self.NG = NAFBlock(self.split_len1)
+             self.H = subnet_constructor(self.split_len1, self.split_len2, groups=groups)
+             self.NH = NAFBlock(self.split_len1)
+         else:
+             self.G = subnet_constructor(self.split_len1, self.split_len2)
+             self.NG = NAFBlock(self.split_len1)
+             self.H = subnet_constructor(self.split_len1, self.split_len2)
+             self.NH = NAFBlock(self.split_len1)
+
+     def forward(self, x1, x2, rev=False):
+         if not rev:
+             y1 = x1 + self.NF(self.F(x2))
+             self.s = self.clamp * (torch.sigmoid(self.NH(self.H(y1))) * 2 - 1)
+             y2 = [x2i.mul(torch.exp(self.s)) + self.NG(self.G(y1)) for x2i in x2]
+         else:
+             self.s = self.clamp * (torch.sigmoid(self.NH(self.H(x1))) * 2 - 1)
+             y2 = [(x2i - self.NG(self.G(x1))).div(torch.exp(self.s)) for x2i in x2]
+             y1 = x1 - self.NF(self.F(y2))
+
+         return y1, y2  # torch.cat((y1, y2), 1)
+
+     def jacobian(self, x, rev=False):
+         if not rev:
+             jac = torch.sum(self.s)
+         else:
+             jac = -torch.sum(self.s)
+
+         return jac / x.shape[0]
+
+
+ class InvNN(nn.Module):
+     def __init__(self, channel_in_ho=3, channel_in_hi=3, subnet_constructor=None, subnet_constructor_v2=None, block_num=[], down_num=2, groups=None):
+         super(InvNN, self).__init__()
+         operations = []
+
+         current_channel_ho = channel_in_ho
+         current_channel_hi = channel_in_hi
+         for i in range(down_num):
+             for j in range(block_num[i]):
+                 b = InvBlock(subnet_constructor, subnet_constructor_v2, current_channel_ho, current_channel_hi, groups=groups)
+                 operations.append(b)
+
+         self.operations = nn.ModuleList(operations)
+
+     def forward(self, x, x_h, rev=False, cal_jacobian=False):
+         jacobian = 0
+
+         if not rev:
+             for op in self.operations:
+                 x, x_h = op.forward(x, x_h, rev)
+                 if cal_jacobian:
+                     jacobian += op.jacobian(x, rev)
+         else:
+             for op in reversed(self.operations):
+                 x, x_h = op.forward(x, x_h, rev)
+                 if cal_jacobian:
+                     jacobian += op.jacobian(x, rev)
+
+         if cal_jacobian:
+             return x, x_h, jacobian
+         else:
+             return x, x_h
+
+
+ class PredictiveModuleMIMO(nn.Module):
+     def __init__(self, channel_in, nf, block_num_rbm=8, block_num_trans=4):
+         super(PredictiveModuleMIMO, self).__init__()
+         self.conv_in = nn.Conv2d(channel_in, nf, 3, 1, 1, bias=True)
+         res_block = []
+         trans_block = []
+         for i in range(block_num_rbm):
+             res_block.append(ResidualBlockNoBN(nf))
+         for j in range(block_num_trans):
+             trans_block.append(TransformerBlock(nf))
+
+         self.res_block = nn.Sequential(*res_block)
+         self.transformer_block = nn.Sequential(*trans_block)
+
+     def forward(self, x):
+         x = self.conv_in(x)
+         x = self.res_block(x)
+         res = self.transformer_block(x) + x
+
+         return res
+
+
+ class ConvRelu(nn.Module):
+     def __init__(self, channels_in, channels_out, stride=1, init_zero=False):
+         super(ConvRelu, self).__init__()
+         self.init_zero = init_zero
+         if self.init_zero:
+             self.layers = nn.Conv2d(channels_in, channels_out, 3, stride, padding=1)
+         else:
+             self.layers = nn.Sequential(
+                 nn.Conv2d(channels_in, channels_out, 3, stride, padding=1),
+                 nn.LeakyReLU(inplace=True)
+             )
+
+     def forward(self, x):
+         return self.layers(x)
+
+
+ class PredictiveModuleBit(nn.Module):
+     def __init__(self, channel_in, nf, block_num_rbm=4, block_num_trans=2):
+         super(PredictiveModuleBit, self).__init__()
+         self.conv_in = nn.Conv2d(channel_in, nf, 3, 1, 1, bias=True)
+         res_block = []
+         trans_block = []
+         for i in range(block_num_rbm):
+             res_block.append(ResidualBlockNoBN(nf))
+         for j in range(block_num_trans):
+             trans_block.append(TransformerBlock(nf))
+
+         blocks = 4
+         layers = [ConvRelu(nf, 1, 2)]
+         for _ in range(blocks - 1):
+             layer = ConvRelu(1, 1, 2)
+             layers.append(layer)
+         self.layers = nn.Sequential(*layers)
+
+         self.res_block = nn.Sequential(*res_block)
+         self.transformer_block = nn.Sequential(*trans_block)
+
+     def forward(self, x):
+         x = self.conv_in(x)
+         x = self.res_block(x)
+         res = self.transformer_block(x) + x
+         res = self.layers(res)
+
+         return res
+
+
+ ##---------- Prompt Gen Module -----------------------
+ class PromptGenBlock(nn.Module):
+     def __init__(self, prompt_dim=12, prompt_len=3, prompt_size=36, lin_dim=12):
+         super(PromptGenBlock, self).__init__()
+         self.prompt_param = nn.Parameter(torch.rand(1, prompt_len, prompt_dim, prompt_size, prompt_size))
+         self.linear_layer = nn.Linear(lin_dim, prompt_len)
+         self.conv3x3 = nn.Conv2d(prompt_dim, prompt_dim, kernel_size=3, stride=1, padding=1, bias=False)
+
+     def forward(self, x):
+         B, C, H, W = x.shape
+         emb = x.mean(dim=(-2, -1))
+         prompt_weights = F.softmax(self.linear_layer(emb), dim=1)
+         prompt = prompt_weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * self.prompt_param.unsqueeze(0).repeat(B, 1, 1, 1, 1, 1).squeeze(1)
+         prompt = torch.sum(prompt, dim=1)
+         prompt = F.interpolate(prompt, (H, W), mode="bilinear")
+         prompt = self.conv3x3(prompt)
+
+         return prompt
+
+
+ class PredictiveModuleMIMO_prompt(nn.Module):
+     def __init__(self, channel_in, nf, prompt_len=3, block_num_rbm=8, block_num_trans=4):
+         super(PredictiveModuleMIMO_prompt, self).__init__()
+         self.conv_in = nn.Conv2d(channel_in, nf, 3, 1, 1, bias=True)
+         res_block = []
+         trans_block = []
+         for i in range(block_num_rbm):
+             res_block.append(ResidualBlockNoBN(nf))
+         for j in range(block_num_trans):
+             trans_block.append(TransformerBlock(nf))
+
+         self.res_block = nn.Sequential(*res_block)
+         self.transformer_block = nn.Sequential(*trans_block)
+         self.prompt = PromptGenBlock(prompt_dim=nf, prompt_len=prompt_len, prompt_size=36, lin_dim=nf)
+         self.fuse = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)
+
+     def forward(self, x):
+         x = self.conv_in(x)
+         x = self.res_block(x)
+         res = self.transformer_block(x) + x
+         prompt = self.prompt(res)
+
+         result = self.fuse(torch.cat([res, prompt], dim=1))
+
+         return result
+
+
+ def gauss_noise(shape):
+     noise = torch.zeros(shape).cuda()
+     for i in range(noise.shape[0]):
+         noise[i] = torch.randn(noise[i].shape).cuda()
+
+     return noise
+
+
+ def gauss_noise_mul(shape):
+     noise = torch.randn(shape).cuda()
+
+     return noise
+
+
+ class PredictiveModuleBit_prompt(nn.Module):
+     def __init__(self, channel_in, nf, prompt_length, block_num_rbm=4, block_num_trans=2):
+         super(PredictiveModuleBit_prompt, self).__init__()
+         self.conv_in = nn.Conv2d(channel_in, nf, 3, 1, 1, bias=True)
+         res_block = []
+         trans_block = []
+         for i in range(block_num_rbm):
+             res_block.append(ResidualBlockNoBN(nf))
+         for j in range(block_num_trans):
+             trans_block.append(TransformerBlock(nf))
+
+         blocks = 4
+         layers = [ConvRelu(nf, 1, 2)]
+         for _ in range(blocks - 1):
+             layer = ConvRelu(1, 1, 2)
+             layers.append(layer)
+         self.layers = nn.Sequential(*layers)
+
+         self.res_block = nn.Sequential(*res_block)
+         self.transformer_block = nn.Sequential(*trans_block)
+         self.prompt = PromptGenBlock(prompt_dim=nf, prompt_len=prompt_length, prompt_size=36, lin_dim=nf)
+         self.fuse = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)
+
+     def forward(self, x):
+         x = self.conv_in(x)
+         x = self.res_block(x)
+         res = self.transformer_block(x) + x
+         prompt = self.prompt(res)
+         res = self.fuse(torch.cat([res, prompt], dim=1))
+         res = self.layers(res)
+
+         return res
+
+
+ class VSN(nn.Module):
+     def __init__(self, opt, subnet_constructor=None, subnet_constructor_v2=None, down_num=2):
+         super(VSN, self).__init__()
+         self.model = opt['model']
+         self.mode = opt['mode']
+         opt_net = opt['network_G']
+         self.num_image = opt['num_image']
+         self.gop = opt['gop']
+         self.channel_in = opt_net['in_nc'] * self.gop
+         self.channel_out = opt_net['out_nc'] * self.gop
+         self.channel_in_hi = opt_net['in_nc'] * self.gop
+         self.channel_in_ho = opt_net['in_nc'] * self.gop
+         self.message_len = opt['message_length']
+
+         self.block_num = opt_net['block_num']
+         self.block_num_rbm = opt_net['block_num_rbm']
+         self.block_num_trans = opt_net['block_num_trans']
+         self.nf = self.channel_in_hi
+
+         self.bitencoder = DW_Encoder(self.message_len, attention="se")
+         self.bitdecoder = DW_Decoder(self.message_len, attention="se")
+         self.irn = InvNN(self.channel_in_ho, self.channel_in_hi, subnet_constructor, subnet_constructor_v2, self.block_num, down_num, groups=self.num_image)
+
+         if opt['prompt']:
+             self.pm = PredictiveModuleMIMO_prompt(self.channel_in_ho, self.nf * self.num_image, opt['prompt_len'], block_num_rbm=self.block_num_rbm, block_num_trans=self.block_num_trans)
+         else:
+             # PredictiveModuleMIMO takes no prompt_len argument
+             self.pm = PredictiveModuleMIMO(self.channel_in_ho, self.nf * self.num_image, block_num_rbm=self.block_num_rbm, block_num_trans=self.block_num_trans)
+         self.BitPM = PredictiveModuleBit(3, 4, block_num_rbm=4, block_num_trans=2)
+
+     def forward(self, x, x_h=None, message=None, rev=False, hs=[], direction='f'):
+         if not rev:
+             if self.mode == "image":
+                 out_y, out_y_h = self.irn(x, x_h, rev)
+                 out_y = iwt(out_y)
+                 encoded_image = self.bitencoder(out_y, message)
+                 return out_y, encoded_image
+
+             elif self.mode == "bit":
+                 out_y = iwt(x)
+                 encoded_image = self.bitencoder(out_y, message)
+                 return out_y, encoded_image
+
+         else:
+             if self.mode == "image":
+                 recmessage = self.bitdecoder(x)
+
+                 x = dwt(x)
+                 out_z = self.pm(x).unsqueeze(1)
+                 out_z_new = out_z.view(-1, self.num_image, self.channel_in, x.shape[-2], x.shape[-1])
+                 out_z_new = [out_z_new[:, i] for i in range(self.num_image)]
+                 out_x, out_x_h = self.irn(x, out_z_new, rev)
+
+                 return out_x, out_x_h, out_z, recmessage
+
+             elif self.mode == "bit":
+                 recmessage = self.bitdecoder(x)
+                 return recmessage
models/modules/Quantization.py ADDED
@@ -0,0 +1,21 @@
+ import torch
+ import torch.nn as nn
+
+
+ class Quant(torch.autograd.Function):
+
+     @staticmethod
+     def forward(ctx, input):
+         input = torch.clamp(input, 0, 1)
+         output = (input * 255.).round() / 255.
+         return output
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         return grad_output
+
+
+ class Quantization(nn.Module):
+     def __init__(self):
+         super(Quantization, self).__init__()
+
+     def forward(self, input):
+         return Quant.apply(input)
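The custom `Function` implements a straight-through estimator: `round()` has zero gradient almost everywhere, so `backward` passes the incoming gradient through unchanged, keeping the 8-bit quantization usable inside end-to-end training. A quick illustrative check:

    import torch

    quant = Quantization()
    x = torch.rand(4, requires_grad=True)
    y = quant(x)              # values snapped to the 1/255 grid in [0, 1]
    y.sum().backward()
    print(x.grad)             # tensor of ones: the gradient of round() is ignored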
models/modules/Subnet_constructor.py ADDED
@@ -0,0 +1,79 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import models.modules.module_util as mutil
+ from basicsr.archs.arch_util import flow_warp, ResidualBlockNoBN
+ from models.modules.module_util import initialize_weights_xavier
+
+
+ class DenseBlock(nn.Module):
+     def __init__(self, channel_in, channel_out, init='xavier', gc=32, bias=True):
+         super(DenseBlock, self).__init__()
+         self.conv1 = nn.Conv2d(channel_in, gc, 3, 1, 1, bias=bias)
+         self.conv2 = nn.Conv2d(channel_in + gc, gc, 3, 1, 1, bias=bias)
+         self.conv3 = nn.Conv2d(channel_in + 2 * gc, gc, 3, 1, 1, bias=bias)
+         self.conv4 = nn.Conv2d(channel_in + 3 * gc, gc, 3, 1, 1, bias=bias)
+         self.conv5 = nn.Conv2d(channel_in + 4 * gc, channel_out, 3, 1, 1, bias=bias)
+         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+         self.H = None
+
+         if init == 'xavier':
+             mutil.initialize_weights_xavier([self.conv1, self.conv2, self.conv3, self.conv4], 0.1)
+         else:
+             mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4], 0.1)
+         mutil.initialize_weights(self.conv5, 0)
+
+     def forward(self, x):
+         if isinstance(x, list):
+             x = x[0]
+         x1 = self.lrelu(self.conv1(x))
+         x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
+         x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
+         x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
+         x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+
+         return x5
+
+
+ class DenseBlock_v2(nn.Module):
+     def __init__(self, channel_in, channel_out, groups, init='xavier', gc=32, bias=True):
+         super(DenseBlock_v2, self).__init__()
+         self.conv1 = nn.Conv2d(channel_in, gc, 3, 1, 1, bias=bias)
+         self.conv2 = nn.Conv2d(channel_in + gc, gc, 3, 1, 1, bias=bias)
+         self.conv3 = nn.Conv2d(channel_in + 2 * gc, gc, 3, 1, 1, bias=bias)
+         self.conv4 = nn.Conv2d(channel_in + 3 * gc, gc, 3, 1, 1, bias=bias)
+         self.conv5 = nn.Conv2d(channel_in + 4 * gc, channel_out, 3, 1, 1, bias=bias)
+         self.conv_final = nn.Conv2d(channel_out * groups, channel_out, 3, 1, 1, bias=bias)
+         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+         if init == 'xavier':
+             mutil.initialize_weights_xavier([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
+         else:
+             mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
+         mutil.initialize_weights(self.conv_final, 0)
+
+     def forward(self, x):
+         res = []
+         for xi in x:
+             x1 = self.lrelu(self.conv1(xi))
+             x2 = self.lrelu(self.conv2(torch.cat((xi, x1), 1)))
+             x3 = self.lrelu(self.conv3(torch.cat((xi, x1, x2), 1)))
+             x4 = self.lrelu(self.conv4(torch.cat((xi, x1, x2, x3), 1)))
+             x5 = self.lrelu(self.conv5(torch.cat((xi, x1, x2, x3, x4), 1)))
+             res.append(x5)
+         res = torch.cat(res, dim=1)
+         res = self.conv_final(res)
+
+         return res
+
+
+ def subnet(net_structure, init='xavier'):
+     def constructor(channel_in, channel_out, groups=None):
+         if net_structure == 'DBNet':
+             if init == 'xavier':
+                 return DenseBlock(channel_in, channel_out, init)
+             elif init == 'xavier_v2':
+                 return DenseBlock_v2(channel_in, channel_out, groups, 'xavier')
+             else:
+                 return DenseBlock(channel_in, channel_out)
+         else:
+             return None
+
+     return constructor
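Illustrative use of the factory: `subnet()` returns a constructor that `InvBlock`/`InvNN` above call with channel counts (and, for the v2 variant, a `groups` count matching the number of hidden images). A small sketch under those assumptions:

    import torch

    make_block = subnet('DBNet', init='xavier')        # DenseBlock factory
    block = make_block(12, 12)                          # channel_in=12, channel_out=12
    out = block(torch.randn(1, 12, 16, 16))             # -> (1, 12, 16, 16), same spatial size

    make_block_v2 = subnet('DBNet', init='xavier_v2')   # DenseBlock_v2 factory
    block_v2 = make_block_v2(12, 12, groups=2)
    outs = block_v2([torch.randn(1, 12, 16, 16) for _ in range(2)])  # list in -> fused (1, 12, 16, 16)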
models/modules/__init__.py ADDED
File without changes
models/modules/__pycache__/Conv1x1.cpython-38.pyc ADDED
Binary file (1.35 kB). View file