diff --git a/__pycache__/test_gradio.cpython-310.pyc b/__pycache__/test_gradio.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e7a426e2fa543fad5bc1ef0f41570ceb34fc013 Binary files /dev/null and b/__pycache__/test_gradio.cpython-310.pyc differ diff --git a/__pycache__/test_gradio.cpython-38.pyc b/__pycache__/test_gradio.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0263da67daa56a028df790b6c27a8da4ea0354af Binary files /dev/null and b/__pycache__/test_gradio.cpython-38.pyc differ diff --git a/data/__init__.py b/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..577cc917bd1a55204255a79e9b3db961fd8f0a64 --- /dev/null +++ b/data/__init__.py @@ -0,0 +1,43 @@ +'''create dataset and dataloader''' +import logging +import torch +import torch.utils.data + +def create_dataloader(dataset, dataset_opt, opt=None, sampler=None): + phase = dataset_opt['phase'] + if phase == 'train': + if opt['dist']: + world_size = torch.distributed.get_world_size() + num_workers = dataset_opt['n_workers'] + assert dataset_opt['batch_size'] % world_size == 0 + batch_size = dataset_opt['batch_size'] // world_size + shuffle = False + else: + num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids']) + batch_size = dataset_opt['batch_size'] + shuffle = True + return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, + num_workers=num_workers, sampler=sampler, drop_last=True, + pin_memory=False) + else: + return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1, + pin_memory=True) + + +def create_dataset(dataset_opt): + mode = dataset_opt['mode'] + if mode == 'test': + from data.coco_test_dataset import imageTestDataset as D + elif mode == 'train': + from data.coco_dataset import CoCoDataset as D + elif mode == 'td': + from data.test_dataset_td import imageTestDataset as D + else: + raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode)) + print(mode) + dataset = D(dataset_opt) + + logger = logging.getLogger('base') + logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__, + dataset_opt['name'])) + return dataset diff --git a/data/__pycache__/__init__.cpython-310.pyc b/data/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5169d63d6e9bdfa66a126d6438285ccf1fc9c5c7 Binary files /dev/null and b/data/__pycache__/__init__.cpython-310.pyc differ diff --git a/data/__pycache__/__init__.cpython-38.pyc b/data/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2803cbcccbeaeb1e9a2b1f15471e145329274111 Binary files /dev/null and b/data/__pycache__/__init__.cpython-38.pyc differ diff --git a/data/__pycache__/coco_dataset.cpython-38.pyc b/data/__pycache__/coco_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e072304f4febd0a6e513395965621dd2c782f38 Binary files /dev/null and b/data/__pycache__/coco_dataset.cpython-38.pyc differ diff --git a/data/__pycache__/coco_test_dataset.cpython-38.pyc b/data/__pycache__/coco_test_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44be6e94c3dd2db619ae3910c2e7a2da9e85de51 Binary files /dev/null and b/data/__pycache__/coco_test_dataset.cpython-38.pyc differ diff --git a/data/__pycache__/data_sampler.cpython-310.pyc b/data/__pycache__/data_sampler.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..df683c2f071e6d8829f9f52e38b0209bf2c353cd Binary files /dev/null and b/data/__pycache__/data_sampler.cpython-310.pyc differ diff --git a/data/__pycache__/data_sampler.cpython-38.pyc b/data/__pycache__/data_sampler.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42515c3e30a5be8ac112f0326c4e80a2f7ef1a99 Binary files /dev/null and b/data/__pycache__/data_sampler.cpython-38.pyc differ diff --git a/data/__pycache__/test_dataset_td.cpython-310.pyc b/data/__pycache__/test_dataset_td.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12489cb72d000e0b7f67dea8dfd008ed9b19eb0b Binary files /dev/null and b/data/__pycache__/test_dataset_td.cpython-310.pyc differ diff --git a/data/__pycache__/test_dataset_td.cpython-38.pyc b/data/__pycache__/test_dataset_td.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74a7f85dc5529f479c043d5d0583d2fe4365e327 Binary files /dev/null and b/data/__pycache__/test_dataset_td.cpython-38.pyc differ diff --git a/data/__pycache__/util.cpython-310.pyc b/data/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d97c049cd6b70ce2a284e6d8369dc3e02c43b74 Binary files /dev/null and b/data/__pycache__/util.cpython-310.pyc differ diff --git a/data/__pycache__/util.cpython-38.pyc b/data/__pycache__/util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fae5d0037c1261874dc3cb5a593b27ef0b66c24 Binary files /dev/null and b/data/__pycache__/util.cpython-38.pyc differ diff --git a/data/__pycache__/video_test_dataset.cpython-38.pyc b/data/__pycache__/video_test_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7392600c42c903af003068e2fb2193be2d60c24e Binary files /dev/null and b/data/__pycache__/video_test_dataset.cpython-38.pyc differ diff --git a/data/coco_dataset.py b/data/coco_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..1cc318930b62d52c118b1b82d0d9145c36ad08eb --- /dev/null +++ b/data/coco_dataset.py @@ -0,0 +1,90 @@ +''' +COCO training dataset +reads image paths from a txt list file and loads the images from an image folder +''' +import logging +import os +import os.path as osp +import pickle +import random + +import cv2 +import lmdb +import numpy as np +import torch +import torch.utils.data as data + +import data.util as util + +try: + import mc +except ImportError: + pass +logger = logging.getLogger('base') + +class CoCoDataset(data.Dataset): + def __init__(self, opt): + super(CoCoDataset, self).__init__() + self.opt = opt + # get train indexes + self.data_path = self.opt['data_path'] + self.txt_path = self.opt['txt_path'] + with open(self.txt_path) as f: + self.list_image = f.readlines() + self.list_image = [line.strip('\n') for line in self.list_image] + # temporal augmentation + self.interval_list = opt['interval_list'] + self.random_reverse = opt['random_reverse'] + logger.info('Temporal augmentation interval list: [{}], with random reverse is {}.'.format( + ','.join(str(x) for x in opt['interval_list']), self.random_reverse)) + self.data_type = self.opt['data_type'] + random.shuffle(self.list_image) + self.LR_input = True + self.num_image = self.opt['num_image'] + + def _ensure_memcached(self): + if self.mclient is None: + # specify the config files + server_list_config_file = None + client_config_file = None + self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file,
client_config_file) + + def __getitem__(self, index): + GT_size = self.opt['GT_size'] + image_name = self.list_image[index] + path_frame = os.path.join(self.data_path, image_name) + img_GT = util.read_img(None, path_frame) + index_h = random.randint(0, len(self.list_image) - 1) + + # random crop + H, W, C = img_GT.shape + rnd_h = random.randint(0, max(0, H - GT_size)) + rnd_w = random.randint(0, max(0, W - GT_size)) + img_frames = img_GT[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :] + # BGR to RGB, HWC to CHW, numpy to tensor + img_frames = img_frames[:, :, [2, 1, 0]] + img_frames = torch.from_numpy(np.ascontiguousarray(np.transpose(img_frames, (2, 0, 1)))).float().unsqueeze(0) + + # process h_list + if index_h % 100 == 0: + path_frame_h = "../dataset/locwatermark/blue.png" + else: + image_name_h = self.list_image[index_h] + path_frame_h = os.path.join(self.data_path, image_name_h) + + frame_h = util.read_img(None, path_frame_h) + H1, W1, C1 = frame_h.shape + rnd_h = random.randint(0, max(0, H1 - GT_size)) + rnd_w = random.randint(0, max(0, W1 - GT_size)) + img_frames_h = frame_h[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :] + img_frames_h = img_frames_h[:, :, [2, 1, 0]] + img_frames_h = torch.from_numpy(np.ascontiguousarray(np.transpose(img_frames_h, (2, 0, 1)))).float().unsqueeze(0) + + img_frames_h = torch.nn.functional.interpolate(img_frames_h, size=(512, 512), mode='nearest', align_corners=None).unsqueeze(0) + img_frames = torch.nn.functional.interpolate(img_frames, size=(512, 512), mode='nearest', align_corners=None) + + return {'GT': img_frames, 'LQ': img_frames_h} + + def __len__(self): + return len(self.list_image) \ No newline at end of file diff --git a/data/coco_test_dataset.py b/data/coco_test_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..6846fbae14831647951c17f9d30bf6605db4d293 --- /dev/null +++ b/data/coco_test_dataset.py @@ -0,0 +1,61 @@ +import os +import os.path as osp +import torch +import torch.utils.data as data +import data.util as util + +import random +import numpy as np +from PIL import Image + +class imageTestDataset(data.Dataset): + + def __init__(self, opt): + super(imageTestDataset, self).__init__() + self.opt = opt + self.half_N_frames = opt['N_frames'] // 2 + self.data_path = opt['data_path'] + self.bit_path = opt['bit_path'] + self.txt_path = self.opt['txt_path'] + self.num_image = self.opt['num_image'] + with open(self.txt_path) as f: + self.list_image = f.readlines() + self.list_image = [line.strip('\n') for line in self.list_image] + self.list_image.sort() + self.list_image = self.list_image + l = len(self.list_image) // (self.num_image + 1) + self.image_list_gt = self.list_image + + def __getitem__(self, index): + path_GT = self.image_list_gt[index] + + img_GT = util.read_img(None, osp.join(self.data_path, path_GT)) + img_GT = img_GT[:, :, [2, 1, 0]] + img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float().unsqueeze(0) + img_GT = torch.nn.functional.interpolate(img_GT, size=(512, 512), mode='nearest', align_corners=None) + + T, C, W, H = img_GT.shape + list_h = [] + R = 0 + G = 0 + B = 255 + image = Image.new('RGB', (W, H), (R, G, B)) + result = np.array(image) / 255.
+ expanded_matrix = np.expand_dims(result, axis=0) + expanded_matrix = np.repeat(expanded_matrix, T, axis=0) + imgs_LQ = torch.from_numpy(np.ascontiguousarray(expanded_matrix)).float() + imgs_LQ = imgs_LQ.permute(0, 3, 1, 2) + + imgs_LQ = torch.nn.functional.interpolate(imgs_LQ, size=(W, H), mode='nearest', align_corners=None) + + list_h.append(imgs_LQ) + + list_h = torch.stack(list_h, dim=0) + + return { + 'LQ': list_h, + 'GT': img_GT + } + + def __len__(self): + return len(self.image_list_gt) diff --git a/data/data_sampler.py b/data/data_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..cd9f92f4301e6b08282b0b5273c11b581d0344b1 --- /dev/null +++ b/data/data_sampler.py @@ -0,0 +1,65 @@ +""" +Modified from torch.utils.data.distributed.DistributedSampler +Support enlarging the dataset for *iter-oriented* training, for saving time when restart the +dataloader after each epoch +""" +import math +import torch +from torch.utils.data.sampler import Sampler +import torch.distributed as dist + + +class DistIterSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset. + + It is especially useful in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each + process can pass a DistributedSampler instance as a DataLoader sampler, + and load a subset of the original dataset that is exclusive to it. + + .. note:: + Dataset is assumed to be of constant size. + + Arguments: + dataset: Dataset used for sampling. + num_replicas (optional): Number of processes participating in + distributed training. + rank (optional): Rank of the current process within num_replicas. + """ + + def __init__(self, dataset, num_replicas=None, rank=None, ratio=100): + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = int(math.ceil(len(self.dataset) * ratio / self.num_replicas)) + self.total_size = self.num_samples * self.num_replicas + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(self.total_size, generator=g).tolist() + + dsize = len(self.dataset) + indices = [v % dsize for v in indices] + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/data/test_dataset_td.py b/data/test_dataset_td.py new file mode 100644 index 0000000000000000000000000000000000000000..cde7ae3fe5ee4a80078ded90fc259b3e2b906aa4 --- /dev/null +++ b/data/test_dataset_td.py @@ -0,0 +1,63 @@ +import os +import os.path as osp +import torch +import torch.utils.data as data +import data.util as util + +import random +import numpy as np +from PIL import Image + +class imageTestDataset(data.Dataset): + + def __init__(self, opt): + super(imageTestDataset, self).__init__() + self.opt = opt + self.half_N_frames = opt['N_frames'] // 2 + self.data_path = opt['data_path'] + self.bit_path = opt['bit_path'] + self.txt_path = self.opt['txt_path'] + self.num_image = self.opt['num_image'] + with open(self.txt_path) as 
f: + self.list_image = f.readlines() + self.list_image = [line.strip('\n') for line in self.list_image] + # self.list_image = sorted(self.list_image) + l = len(self.list_image) // (self.num_image + 1) + self.image_list_gt = self.list_image + self.image_list_bit = self.list_image + + + def __getitem__(self, index): + path_GT = self.image_list_gt[index] + + img_GT = util.read_img(None, osp.join(self.data_path, path_GT)) + img_GT = img_GT[:, :, [2, 1, 0]] + img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float().unsqueeze(0) + img_GT = torch.nn.functional.interpolate(img_GT, size=(512, 512), mode='nearest', align_corners=None) + + T, C, W, H = img_GT.shape + list_h = [] + R = 0 + G = 0 + B = 255 + image = Image.new('RGB', (W, H), (R, G, B)) + result = np.array(image) / 255. + expanded_matrix = np.expand_dims(result, axis=0) + expanded_matrix = np.repeat(expanded_matrix, T, axis=0) + imgs_LQ = torch.from_numpy(np.ascontiguousarray(expanded_matrix)).float() + imgs_LQ = imgs_LQ.permute(0, 3, 1, 2) + + + imgs_LQ = torch.nn.functional.interpolate(imgs_LQ, size=(W, H), mode='nearest', align_corners=None) + + list_h.append(imgs_LQ) + + list_h = torch.stack(list_h, dim=0) + + return { + 'LQ': list_h, + 'GT': img_GT + } + + def __len__(self): + return len(self.image_list_gt) diff --git a/data/util.py b/data/util.py new file mode 100644 index 0000000000000000000000000000000000000000..1ede1227fdef184eab0144c01b496d2173a7539e --- /dev/null +++ b/data/util.py @@ -0,0 +1,551 @@ +import os +import math +import pickle +import random +import numpy as np +import glob +import torch +import cv2 + +#################### +# Files & IO +#################### + +###################### get image path list ###################### +IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP'] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def _get_paths_from_images(path): + '''get image path list from image folder''' + assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) + images = [] + for dirpath, _, fnames in sorted(os.walk(path)): + for fname in sorted(fnames): + if is_image_file(fname): + img_path = os.path.join(dirpath, fname) + images.append(img_path) + assert images, '{:s} has no valid image file'.format(path) + return images + + +def _get_paths_from_lmdb(dataroot): + '''get image path list from lmdb meta info''' + meta_info = pickle.load(open(os.path.join(dataroot, 'meta_info.pkl'), 'rb')) + paths = meta_info['keys'] + sizes = meta_info['resolution'] + if len(sizes) == 1: + sizes = sizes * len(paths) + return paths, sizes + + +def get_image_paths(data_type, dataroot): + '''get image path list + support lmdb or image files''' + paths, sizes = None, None + if dataroot is not None: + if data_type == 'lmdb': + paths, sizes = _get_paths_from_lmdb(dataroot) + elif data_type == 'img': + paths = sorted(_get_paths_from_images(dataroot)) + else: + raise NotImplementedError('data_type [{:s}] is not recognized.'.format(data_type)) + return paths, sizes + + +def glob_file_list(root): + return sorted(glob.glob(os.path.join(root, '*'))) + + +###################### read images ###################### +def _read_img_lmdb(env, key, size): + '''read image from lmdb with key (w/ and w/o fixed size) + size: (C, H, W) tuple''' + with env.begin(write=False) as txn: + buf = txn.get(key.encode('ascii')) + img_flat = np.frombuffer(buf, dtype=np.uint8) + C, H, W = size + img = 
img_flat.reshape(H, W, C) + return img + + +def read_img(env, path, size=None): + '''read image by cv2 or from lmdb + return: Numpy float32, HWC, BGR, [0,1]''' + if env is None: # img +# print(path) + #img = cv2.imread(path, cv2.IMREAD_UNCHANGED) + img = cv2.imread(path, cv2.IMREAD_COLOR) + else: + img = _read_img_lmdb(env, path, size) +# print(img.shape) +# if img is None: + # print(path) +# print(img.shape) + img = img.astype(np.float32) / 255. + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + # some images have 4 channels + if img.shape[2] > 3: + img = img[:, :, :3] + return img + + +def read_img_seq(path): + """Read a sequence of images from a given folder path + Args: + path (list/str): list of image paths/image folder path + + Returns: + imgs (Tensor): size (T, C, H, W), RGB, [0, 1] + """ + if type(path) is list: + img_path_l = path + else: + img_path_l = sorted(glob.glob(os.path.join(path, '*.png'))) +# print(path) +# print(path,img_path_l) + img_l = [read_img(None, v) for v in img_path_l] + # stack to Torch tensor + imgs = np.stack(img_l, axis=0) + imgs = imgs[:, :, :, [2, 1, 0]] + imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float() + return imgs + + +def index_generation(crt_i, max_n, N, padding='reflection'): + """Generate an index list for reading N frames from a sequence of images + Args: + crt_i (int): current center index + max_n (int): max number of the sequence of images (calculated from 1) + N (int): reading N frames + padding (str): padding mode, one of replicate | reflection | new_info | circle + Example: crt_i = 0, N = 5 + replicate: [0, 0, 0, 1, 2] + reflection: [2, 1, 0, 1, 2] + new_info: [4, 3, 0, 1, 2] + circle: [3, 4, 0, 1, 2] + + Returns: + return_l (list [int]): a list of indexes + """ + max_n = max_n - 1 + n_pad = N // 2 + return_l = [] + + for i in range(crt_i - n_pad, crt_i + n_pad + 1): + if i < 0: + if padding == 'replicate': + add_idx = 0 + elif padding == 'reflection': + add_idx = -i + elif padding == 'new_info': + add_idx = (crt_i + n_pad) + (-i) + elif padding == 'circle': + add_idx = N + i + else: + raise ValueError('Wrong padding mode') + elif i > max_n: + if padding == 'replicate': + add_idx = max_n + elif padding == 'reflection': + add_idx = max_n * 2 - i + elif padding == 'new_info': + add_idx = (crt_i - n_pad) - (i - max_n) + elif padding == 'circle': + add_idx = i - N + else: + raise ValueError('Wrong padding mode') + else: + add_idx = i + return_l.append(add_idx) + return return_l + + +#################### +# image processing +# process on numpy image +#################### + + +def augment(img_list, hflip=True, rot=True): + # horizontal flip OR rotate + hflip = hflip and random.random() < 0.5 + vflip = rot and random.random() < 0.5 + rot90 = rot and random.random() < 0.5 + + def _augment(img): + if hflip: + img = img[:, ::-1, :] + if vflip: + img = img[::-1, :, :] + if rot90: + img = img.transpose(1, 0, 2) + return img + + return [_augment(img) for img in img_list] + + +def augment_flow(img_list, flow_list, hflip=True, rot=True): + # horizontal flip OR rotate + hflip = hflip and random.random() < 0.5 + vflip = rot and random.random() < 0.5 + rot90 = rot and random.random() < 0.5 + + def _augment(img): + if hflip: + img = img[:, ::-1, :] + if vflip: + img = img[::-1, :, :] + if rot90: + img = img.transpose(1, 0, 2) + return img + + def _augment_flow(flow): + if hflip: + flow = flow[:, ::-1, :] + flow[:, :, 0] *= -1 + if vflip: + flow = flow[::-1, :, :] + flow[:, :, 1] *= -1 + if rot90: + flow = 
flow.transpose(1, 0, 2) + flow = flow[:, :, [1, 0]] + return flow + + rlt_img_list = [_augment(img) for img in img_list] + rlt_flow_list = [_augment_flow(flow) for flow in flow_list] + + return rlt_img_list, rlt_flow_list + + +def channel_convert(in_c, tar_type, img_list): + # conversion among BGR, gray and y + if in_c == 3 and tar_type == 'gray': # BGR to gray + gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list] + return [np.expand_dims(img, axis=2) for img in gray_list] + elif in_c == 3 and tar_type == 'y': # BGR to y + y_list = [bgr2ycbcr(img, only_y=True) for img in img_list] + return [np.expand_dims(img, axis=2) for img in y_list] + elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR + return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list] + else: + return img_list + + +def rgb2ycbcr(img, only_y=True): + '''same as matlab rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], + [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def bgr2ycbcr(img, only_y=True): + '''bgr version of rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def ycbcr2rgb(img): + '''same as matlab ycbcr2rgb + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. 
+ return rlt.astype(in_img_type) + + +def modcrop(img_in, scale): + # img_in: Numpy, HWC or HW + img = np.copy(img_in) + if img.ndim == 2: + H, W = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r] + elif img.ndim == 3: + H, W, C = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r, :] + else: + raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim)) + return img + + +#################### +# Functions +#################### + + +# matlab 'imresize' function, now only support 'bicubic' +def cubic(x): + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5 * absx3 - 2.5 * absx2 + 1) * ( + (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (( + (absx > 1) * (absx <= 2)).type_as(absx)) + + +def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): + if (scale < 1) and (antialiasing): + # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5+scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + P = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( + 1, P).expand(out_length, P) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices + # apply cubic kernel + if (scale < 1) and (antialiasing): + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, P) + + # If a column in weights is all zero, get rid of it. only consider the first and last column. + weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, P - 2) + weights = weights.narrow(1, 1, P - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, P - 2) + weights = weights.narrow(1, 0, P - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +def imresize(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: CHW RGB [0,1] + # output: CHW RGB [0,1] w/o round + + in_C, in_H, in_W = img.size() + _, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. 
The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) + img_aug.narrow(1, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:, :sym_len_Hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_He:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_C, out_H, in_W) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + out_1[0, i, :] = img_aug[0, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + out_1[1, i, :] = img_aug[1, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + out_1[2, i, :] = img_aug[2, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We) + out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_Ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_We:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_C, out_H, out_W) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + out_2[0, :, i] = out_1_aug[0, :, idx:idx + kernel_width].mv(weights_W[i]) + out_2[1, :, i] = out_1_aug[1, :, idx:idx + kernel_width].mv(weights_W[i]) + out_2[2, :, i] = out_1_aug[2, :, idx:idx + kernel_width].mv(weights_W[i]) + + return out_2 + + +def imresize_np(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: Numpy, HWC BGR [0,1] + # output: HWC BGR [0,1] w/o round + img = torch.from_numpy(img) + + in_H, in_W, in_C = img.size() + _, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. 
+ + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) + img_aug.narrow(0, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:sym_len_Hs, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[-sym_len_He:, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(out_H, in_W, in_C) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + out_1[i, :, 0] = img_aug[idx:idx + kernel_width, :, 0].transpose(0, 1).mv(weights_H[i]) + out_1[i, :, 1] = img_aug[idx:idx + kernel_width, :, 1].transpose(0, 1).mv(weights_H[i]) + out_1[i, :, 2] = img_aug[idx:idx + kernel_width, :, 2].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C) + out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :sym_len_Ws, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, -sym_len_We:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(out_H, out_W, in_C) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + out_2[:, i, 0] = out_1_aug[:, idx:idx + kernel_width, 0].mv(weights_W[i]) + out_2[:, i, 1] = out_1_aug[:, idx:idx + kernel_width, 1].mv(weights_W[i]) + out_2[:, i, 2] = out_1_aug[:, idx:idx + kernel_width, 2].mv(weights_W[i]) + + return out_2.numpy() + + +if __name__ == '__main__': + # test imresize function + # read images + img = cv2.imread('test.png') + img = img * 1.0 / 255 + img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float() + # imresize + scale = 1 / 4 + import time + total_time = 0 + for i in range(10): + start_time = time.time() + rlt = imresize(img, scale, antialiasing=True) + use_time = time.time() - start_time + total_time += use_time + print('average time: {}'.format(total_time / 10)) + + import torchvision.utils + torchvision.utils.save_image((rlt * 255).round() / 255, 'rlt.png', nrow=1, padding=0, + normalize=False) diff --git a/models/IBSN.py b/models/IBSN.py new file mode 100644 index 0000000000000000000000000000000000000000..ff6554048daded07852164aee7144178552b504c --- /dev/null +++ b/models/IBSN.py @@ -0,0 +1,738 @@ +import logging +from collections import OrderedDict + +import torch +import torch.nn as nn +from torch.nn.parallel import DataParallel, DistributedDataParallel + +import models.networks as networks +import models.lr_scheduler as lr_scheduler +from .base_model import BaseModel +from models.modules.loss import ReconstructionLoss, ReconstructionMsgLoss +from models.modules.Quantization import 
Quantization +from .modules.common import DWT,IWT +from utils.jpegtest import JpegTest +from utils.JPEG import DiffJPEG +import utils.util as util + + +import numpy as np +import random +import cv2 +import time + +logger = logging.getLogger('base') +dwt=DWT() +iwt=IWT() + +from diffusers import StableDiffusionInpaintPipeline +from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDIMScheduler +from diffusers import StableDiffusionXLInpaintPipeline +from diffusers.utils import load_image +from diffusers import RePaintPipeline, RePaintScheduler + +class Model_VSN(BaseModel): + def __init__(self, opt): + super(Model_VSN, self).__init__(opt) + + if opt['dist']: + self.rank = torch.distributed.get_rank() + else: + self.rank = -1 # non dist training + + self.gop = opt['gop'] + train_opt = opt['train'] + test_opt = opt['test'] + self.opt = opt + self.train_opt = train_opt + self.test_opt = test_opt + self.opt_net = opt['network_G'] + self.center = self.gop // 2 + self.num_image = opt['num_image'] + self.mode = opt["mode"] + self.idxx = 0 + + self.netG = networks.define_G_v2(opt).to(self.device) + if opt['dist']: + self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()]) + else: + self.netG = DataParallel(self.netG) + # print network + self.print_network() + self.load() + + self.Quantization = Quantization() + + if not self.opt['hide']: + file_path = "bit_sequence.txt" + + data_list = [] + + with open(file_path, "r") as file: + for line in file: + data = [int(bit) for bit in line.strip()] + data_list.append(data) + + self.msg_list = data_list + + if self.opt['sdinpaint']: + self.pipe = StableDiffusionInpaintPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-inpainting", + torch_dtype=torch.float16, + ).to("cuda") + + if self.opt['controlnetinpaint']: + controlnet = ControlNetModel.from_pretrained( + "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float32 + ).to("cuda") + self.pipe_control = StableDiffusionControlNetInpaintPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float32 + ).to("cuda") + + if self.opt['sdxl']: + self.pipe_sdxl = StableDiffusionXLInpaintPipeline.from_pretrained( + "diffusers/stable-diffusion-xl-1.0-inpainting-0.1", + torch_dtype=torch.float16, + variant="fp16", + use_safetensors=True, + ).to("cuda") + + if self.opt['repaint']: + self.scheduler = RePaintScheduler.from_pretrained("google/ddpm-ema-celebahq-256") + self.pipe_repaint = RePaintPipeline.from_pretrained("google/ddpm-ema-celebahq-256", scheduler=self.scheduler) + self.pipe_repaint = self.pipe_repaint.to("cuda") + + if self.is_train: + self.netG.train() + + # loss + self.Reconstruction_forw = ReconstructionLoss(losstype=self.train_opt['pixel_criterion_forw']) + self.Reconstruction_back = ReconstructionLoss(losstype=self.train_opt['pixel_criterion_back']) + self.Reconstruction_center = ReconstructionLoss(losstype="center") + self.Reconstruction_msg = ReconstructionMsgLoss(losstype=self.opt['losstype']) + + # optimizers + wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 + optim_params = [] + + if self.mode == "image": + for k, v in self.netG.named_parameters(): + if (k.startswith('module.irn') or k.startswith('module.pm')) and v.requires_grad: + optim_params.append(v) + else: + if self.rank <= 0: + logger.warning('Params [{:s}] will not optimize.'.format(k)) + + elif self.mode == "bit": + for k, v in self.netG.named_parameters(): + if 
(k.startswith('module.bitencoder') or k.startswith('module.bitdecoder')) and v.requires_grad: + optim_params.append(v) + else: + if self.rank <= 0: + logger.warning('Params [{:s}] will not optimize.'.format(k)) + + + self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], + weight_decay=wd_G, + betas=(train_opt['beta1'], train_opt['beta2'])) + self.optimizers.append(self.optimizer_G) + + # schedulers + if train_opt['lr_scheme'] == 'MultiStepLR': + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'], + restarts=train_opt['restarts'], + weights=train_opt['restart_weights'], + gamma=train_opt['lr_gamma'], + clear_state=train_opt['clear_state'])) + elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.CosineAnnealingLR_Restart( + optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], + restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) + else: + raise NotImplementedError('MultiStepLR learning rate scheme is enough.') + + self.log_dict = OrderedDict() + + def feed_data(self, data): + self.ref_L = data['LQ'].to(self.device) + self.real_H = data['GT'].to(self.device) + self.mes = data['MES'] + + def init_hidden_state(self, z): + b, c, h, w = z.shape + h_t = [] + c_t = [] + for _ in range(self.opt_net['block_num_rbm']): + h_t.append(torch.zeros([b, c, h, w]).cuda()) + c_t.append(torch.zeros([b, c, h, w]).cuda()) + memory = torch.zeros([b, c, h, w]).cuda() + + return h_t, c_t, memory + + def loss_forward(self, out, y): + l_forw_fit = self.train_opt['lambda_fit_forw'] * self.Reconstruction_forw(out, y) + return l_forw_fit + + def loss_back_rec(self, out, x): + l_back_rec = self.train_opt['lambda_rec_back'] * self.Reconstruction_back(out, x) + return l_back_rec + + def loss_back_rec_mul(self, out, x): + l_back_rec = self.train_opt['lambda_rec_back'] * self.Reconstruction_back(out, x) + return l_back_rec + + def optimize_parameters(self, current_step): + self.optimizer_G.zero_grad() + + b, n, t, c, h, w = self.ref_L.shape + center = t // 2 + intval = self.gop // 2 + + message = torch.Tensor(np.random.choice([-0.5, 0.5], (self.ref_L.shape[0], self.opt['message_length']))).to(self.device) + + add_noise = self.opt['addnoise'] + add_jpeg = self.opt['addjpeg'] + add_possion = self.opt['addpossion'] + add_sdinpaint = self.opt['sdinpaint'] + degrade_shuffle = self.opt['degrade_shuffle'] + + self.host = self.real_H[:, center - intval:center + intval + 1] + self.secret = self.ref_L[:, :, center - intval:center + intval + 1] + self.output, container = self.netG(x=dwt(self.host.reshape(b, -1, h, w)), x_h=dwt(self.secret[:,0].reshape(b, -1, h, w)), message=message) + + Gt_ref = self.real_H[:, center - intval:center + intval + 1].detach() + + y_forw = container + + l_forw_fit = self.loss_forward(y_forw, self.host[:,0]) + + + if degrade_shuffle: + import random + choice = random.randint(0, 2) + + if choice == 0: + NL = float((np.random.randint(1, 16))/255) + noise = np.random.normal(0, NL, y_forw.shape) + torchnoise = torch.from_numpy(noise).cuda().float() + y_forw = y_forw + torchnoise + + elif choice == 1: + NL = int(np.random.randint(70,95)) + self.DiffJPEG = DiffJPEG(differentiable=True, quality=int(NL)).cuda() + y_forw = self.DiffJPEG(y_forw) + + elif choice == 2: + vals = 10**4 + if random.random() < 0.5: + noisy_img_tensor = torch.poisson(y_forw * vals) / vals + else: + img_gray_tensor = torch.mean(y_forw, 
dim=0, keepdim=True) + noisy_gray_tensor = torch.poisson(img_gray_tensor * vals) / vals + noisy_img_tensor = y_forw + (noisy_gray_tensor - img_gray_tensor) + + y_forw = torch.clamp(noisy_img_tensor, 0, 1) + + else: + + if add_noise: + NL = float((np.random.randint(1,16))/255) + noise = np.random.normal(0, NL, y_forw.shape) + torchnoise = torch.from_numpy(noise).cuda().float() + y_forw = y_forw + torchnoise + + elif add_jpeg: + NL = int(np.random.randint(70,95)) + self.DiffJPEG = DiffJPEG(differentiable=True, quality=int(NL)).cuda() + y_forw = self.DiffJPEG(y_forw) + + elif add_possion: + vals = 10**4 + if random.random() < 0.5: + noisy_img_tensor = torch.poisson(y_forw * vals) / vals + else: + img_gray_tensor = torch.mean(y_forw, dim=0, keepdim=True) + noisy_gray_tensor = torch.poisson(img_gray_tensor * vals) / vals + noisy_img_tensor = y_forw + (noisy_gray_tensor - img_gray_tensor) + + y_forw = torch.clamp(noisy_img_tensor, 0, 1) + + y = self.Quantization(y_forw) + all_zero = torch.zeros(message.shape).to(self.device) + + if self.mode == "image": + out_x, out_x_h, out_z, recmessage = self.netG(x=y, message=all_zero, rev=True) + out_x = iwt(out_x) + out_x_h = [iwt(out_x_h_i) for out_x_h_i in out_x_h] + + l_back_rec = self.loss_back_rec(out_x, self.host[:,0]) + out_x_h = torch.stack(out_x_h, dim=1) + + l_center_x = self.loss_back_rec(out_x_h[:, 0], self.secret[:,0].reshape(b, -1, h, w)) + + recmessage = torch.clamp(recmessage, -0.5, 0.5) + + l_msg = self.Reconstruction_msg(message, recmessage) + + loss = l_forw_fit*2 + l_back_rec + l_center_x*4 + + loss.backward() + + if self.train_opt['lambda_center'] != 0: + self.log_dict['l_center_x'] = l_center_x.item() + + # set log + self.log_dict['l_back_rec'] = l_back_rec.item() + self.log_dict['l_forw_fit'] = l_forw_fit.item() + self.log_dict['l_msg'] = l_msg.item() + + self.log_dict['l_h'] = (l_center_x*10).item() + + # gradient clipping + if self.train_opt['gradient_clipping']: + nn.utils.clip_grad_norm_(self.netG.parameters(), self.train_opt['gradient_clipping']) + + self.optimizer_G.step() + + elif self.mode == "bit": + recmessage = self.netG(x=y, message=all_zero, rev=True) + + recmessage = torch.clamp(recmessage, -0.5, 0.5) + + l_msg = self.Reconstruction_msg(message, recmessage) + + lambda_msg = self.train_opt['lambda_msg'] + + loss = l_msg * lambda_msg + l_forw_fit + + loss.backward() + + # set log + self.log_dict['l_forw_fit'] = l_forw_fit.item() + self.log_dict['l_msg'] = l_msg.item() + + # gradient clipping + if self.train_opt['gradient_clipping']: + nn.utils.clip_grad_norm_(self.netG.parameters(), self.train_opt['gradient_clipping']) + + self.optimizer_G.step() + + def test(self, image_id): + self.netG.eval() + add_noise = self.opt['addnoise'] + add_jpeg = self.opt['addjpeg'] + add_possion = self.opt['addpossion'] + add_sdinpaint = self.opt['sdinpaint'] + add_controlnet = self.opt['controlnetinpaint'] + add_sdxl = self.opt['sdxl'] + add_repaint = self.opt['repaint'] + degrade_shuffle = self.opt['degrade_shuffle'] + + with torch.no_grad(): + forw_L = [] + forw_L_h = [] + fake_H = [] + fake_H_h = [] + pred_z = [] + recmsglist = [] + msglist = [] + b, t, c, h, w = self.real_H.shape + center = t // 2 + intval = self.gop // 2 + b, n, t, c, h, w = self.ref_L.shape + id=0 + # forward downscaling + self.host = self.real_H[:, center - intval+id:center + intval + 1+id] + self.secret = self.ref_L[:, :, center - intval+id:center + intval + 1+id] + self.secret = [dwt(self.secret[:,i].reshape(b, -1, h, w)) for i in range(n)] + + messagenp = 
np.random.choice([-0.5, 0.5], (self.ref_L.shape[0], self.opt['message_length'])) + + message = torch.Tensor(messagenp).to(self.device) + + if self.opt['bitrecord']: + mymsg = message.clone() + + mymsg[mymsg>0] = 1 + mymsg[mymsg<0] = 0 + mymsg = mymsg.squeeze(0).to(torch.int) + + bit_list = mymsg.tolist() + + bit_string = ''.join(map(str, bit_list)) + + file_name = "bit_sequence.txt" + + with open(file_name, "a") as file: + file.write(bit_string + "\n") + + if self.opt['hide']: + self.output, container = self.netG(x=dwt(self.host.reshape(b, -1, h, w)), x_h=self.secret, message=message) + y_forw = container + else: + + message = torch.tensor(self.msg_list[image_id]).unsqueeze(0).cuda() + self.output = self.host + y_forw = self.output.squeeze(1) + + if add_sdinpaint: + import random + from PIL import Image + prompt = "" + + b, _, _, _ = y_forw.shape + + image_batch = y_forw.permute(0, 2, 3, 1).detach().cpu().numpy() + forw_list = [] + + for j in range(b): + i = image_id + 1 + masksrc = "../dataset/valAGE-Set-Mask/" + mask_image = Image.open(masksrc + str(i).zfill(4) + ".png").convert("L") + mask_image = mask_image.resize((512, 512)) + h, w = mask_image.size + + image = image_batch[j, :, :, :] + image_init = Image.fromarray((image * 255).astype(np.uint8), mode = "RGB") + image_inpaint = self.pipe(prompt=prompt, image=image_init, mask_image=mask_image, height=w, width=h).images[0] + image_inpaint = np.array(image_inpaint) / 255. + mask_image = np.array(mask_image) + mask_image = np.stack([mask_image] * 3, axis=-1) / 255. + mask_image = mask_image.astype(np.uint8) + image_fuse = image * (1 - mask_image) + image_inpaint * mask_image + forw_list.append(torch.from_numpy(image_fuse).permute(2, 0, 1)) + + y_forw = torch.stack(forw_list, dim=0).float().cuda() + + if add_controlnet: + from diffusers.utils import load_image + from PIL import Image + + b, _, _, _ = y_forw.shape + forw_list = [] + + image_batch = y_forw.permute(0, 2, 3, 1).detach().cpu().numpy() + generator = torch.Generator(device="cuda").manual_seed(1) + + for j in range(b): + i = image_id + 1 + mask_path = "../dataset/valAGE-Set-Mask/" + str(i).zfill(4) + ".png" + mask_image = load_image(mask_path) + mask_image = mask_image.resize((512, 512)) + image_init = image_batch[j, :, :, :] + image_init1 = Image.fromarray((image_init * 255).astype(np.uint8), mode = "RGB") + image_mask = np.array(mask_image.convert("L")).astype(np.float32) / 255.0 + + assert image_init.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size" + image_init[image_mask > 0.5] = -1.0 # set as masked pixel + image = np.expand_dims(image_init, 0).transpose(0, 3, 1, 2) + control_image = torch.from_numpy(image) + + # generate image + image_inpaint = self.pipe_control( + "", + num_inference_steps=20, + generator=generator, + eta=1.0, + image=image_init1, + mask_image=image_mask, + control_image=control_image, + ).images[0] + + image_inpaint = np.array(image_inpaint) / 255. 
+ image_mask = np.stack([image_mask] * 3, axis=-1) + image_mask = image_mask.astype(np.uint8) + image_fuse = image_init * (1 - image_mask) + image_inpaint * image_mask + forw_list.append(torch.from_numpy(image_fuse).permute(2, 0, 1)) + + y_forw = torch.stack(forw_list, dim=0).float().cuda() + + if add_sdxl: + import random + from PIL import Image + from diffusers.utils import load_image + prompt = "" + + b, _, _, _ = y_forw.shape + + image_batch = y_forw.permute(0, 2, 3, 1).detach().cpu().numpy() + forw_list = [] + + for j in range(b): + i = image_id + 1 + masksrc = "../dataset/valAGE-Set-Mask/" + mask_image = load_image(masksrc + str(i).zfill(4) + ".png").convert("RGB") + mask_image = mask_image.resize((512, 512)) + h, w = mask_image.size + + image = image_batch[j, :, :, :] + image_init = Image.fromarray((image * 255).astype(np.uint8), mode = "RGB") + image_inpaint = self.pipe_sdxl( + prompt=prompt, image=image_init, mask_image=mask_image, num_inference_steps=50, strength=0.80, target_size=(512, 512) + ).images[0] + image_inpaint = image_inpaint.resize((512, 512)) + image_inpaint = np.array(image_inpaint) / 255. + mask_image = np.array(mask_image) / 255. + mask_image = mask_image.astype(np.uint8) + image_fuse = image * (1 - mask_image) + image_inpaint * mask_image + forw_list.append(torch.from_numpy(image_fuse).permute(2, 0, 1)) + + y_forw = torch.stack(forw_list, dim=0).float().cuda() + + + if add_repaint: + from PIL import Image + + b, _, _, _ = y_forw.shape + + image_batch = y_forw.permute(0, 2, 3, 1).detach().cpu().numpy() + forw_list = [] + + generator = torch.Generator(device="cuda").manual_seed(0) + for j in range(b): + i = image_id + 1 + masksrc = "../dataset/valAGE-Set-Mask/" + str(i).zfill(4) + ".png" + mask_image = Image.open(masksrc).convert("RGB") + mask_image = mask_image.resize((256, 256)) + mask_image = Image.fromarray(255 - np.array(mask_image)) + image = image_batch[j, :, :, :] + original_image = Image.fromarray((image * 255).astype(np.uint8), mode = "RGB") + original_image = original_image.resize((256, 256)) + output = self.pipe_repaint( + image=original_image, + mask_image=mask_image, + num_inference_steps=150, + eta=0.0, + jump_length=10, + jump_n_sample=10, + generator=generator, + ) + image_inpaint = output.images[0] + image_inpaint = image_inpaint.resize((512, 512)) + image_inpaint = np.array(image_inpaint) / 255. + mask_image = mask_image.resize((512, 512)) + mask_image = np.array(mask_image) / 255. 
+ mask_image = mask_image.astype(np.uint8) + image_fuse = image * mask_image + image_inpaint * (1 - mask_image) + forw_list.append(torch.from_numpy(image_fuse).permute(2, 0, 1)) + + y_forw = torch.stack(forw_list, dim=0).float().cuda() + + if degrade_shuffle: + import random + choice = random.randint(0, 2) + + if choice == 0: + NL = float((np.random.randint(1,5))/255) + noise = np.random.normal(0, NL, y_forw.shape) + torchnoise = torch.from_numpy(noise).cuda().float() + y_forw = y_forw + torchnoise + + elif choice == 1: + NL = 90 + self.DiffJPEG = DiffJPEG(differentiable=True, quality=int(NL)).cuda() + y_forw = self.DiffJPEG(y_forw) + + elif choice == 2: + vals = 10**4 + if random.random() < 0.5: + noisy_img_tensor = torch.poisson(y_forw * vals) / vals + else: + img_gray_tensor = torch.mean(y_forw, dim=0, keepdim=True) + noisy_gray_tensor = torch.poisson(img_gray_tensor * vals) / vals + noisy_img_tensor = y_forw + (noisy_gray_tensor - img_gray_tensor) + + y_forw = torch.clamp(noisy_img_tensor, 0, 1) + + else: + + if add_noise: + NL = self.opt['noisesigma'] / 255.0 + noise = np.random.normal(0, NL, y_forw.shape) + torchnoise = torch.from_numpy(noise).cuda().float() + y_forw = y_forw + torchnoise + + elif add_jpeg: + Q = self.opt['jpegfactor'] + self.DiffJPEG = DiffJPEG(differentiable=True, quality=int(Q)).cuda() + y_forw = self.DiffJPEG(y_forw) + + elif add_possion: + vals = 10**4 + if random.random() < 0.5: + noisy_img_tensor = torch.poisson(y_forw * vals) / vals + else: + img_gray_tensor = torch.mean(y_forw, dim=0, keepdim=True) + noisy_gray_tensor = torch.poisson(img_gray_tensor * vals) / vals + noisy_img_tensor = y_forw + (noisy_gray_tensor - img_gray_tensor) + + y_forw = torch.clamp(noisy_img_tensor, 0, 1) + + # backward upscaling + if self.opt['hide']: + y = self.Quantization(y_forw) + else: + y = y_forw + + if self.mode == "image": + out_x, out_x_h, out_z, recmessage = self.netG(x=y, rev=True) + out_x = iwt(out_x) + + out_x_h = [iwt(out_x_h_i) for out_x_h_i in out_x_h] + out_x = out_x.reshape(-1, self.gop, 3, h, w) + out_x_h = torch.stack(out_x_h, dim=1) + out_x_h = out_x_h.reshape(-1, 1, self.gop, 3, h, w) + + forw_L.append(y_forw) + fake_H.append(out_x[:, self.gop//2]) + fake_H_h.append(out_x_h[:,:, self.gop//2]) + recmsglist.append(recmessage) + msglist.append(message) + + elif self.mode == "bit": + recmessage = self.netG(x=y, rev=True) + forw_L.append(y_forw) + recmsglist.append(recmessage) + msglist.append(message) + + if self.mode == "image": + self.fake_H = torch.clamp(torch.stack(fake_H, dim=1),0,1) + self.fake_H_h = torch.clamp(torch.stack(fake_H_h, dim=2),0,1) + + self.forw_L = torch.clamp(torch.stack(forw_L, dim=1),0,1) + remesg = torch.clamp(torch.stack(recmsglist, dim=0),-0.5,0.5) + + if self.opt['hide']: + mesg = torch.clamp(torch.stack(msglist, dim=0),-0.5,0.5) + else: + mesg = torch.stack(msglist, dim=0) + + self.recmessage = remesg.clone() + self.recmessage[remesg > 0] = 1 + self.recmessage[remesg <= 0] = 0 + + self.message = mesg.clone() + self.message[mesg > 0] = 1 + self.message[mesg <= 0] = 0 + + self.netG.train() + + + def image_hiding(self, ): + self.netG.eval() + with torch.no_grad(): + b, t, c, h, w = self.real_H.shape + center = t // 2 + intval = self.gop // 2 + b, n, t, c, h, w = self.ref_L.shape + id=0 + # forward downscaling + self.host = self.real_H[:, center - intval+id:center + intval + 1+id] + self.secret = self.ref_L[:, :, center - intval+id:center + intval + 1+id] + self.secret = [dwt(self.secret[:,i].reshape(b, -1, h, w)) for i in range(n)] + + 
message = torch.Tensor(self.mes).to(self.device) + + self.output, container = self.netG(x=dwt(self.host.reshape(b, -1, h, w)), x_h=self.secret, message=message) + y_forw = container + + result = torch.clamp(y_forw,0,1) + + lr_img = util.tensor2img(result) + + return lr_img + + def image_recovery(self, number): + self.netG.eval() + with torch.no_grad(): + b, t, c, h, w = self.real_H.shape + center = t // 2 + intval = self.gop // 2 + b, n, t, c, h, w = self.ref_L.shape + id=0 + # forward downscaling + self.host = self.real_H[:, center - intval+id:center + intval + 1+id] + self.secret = self.ref_L[:, :, center - intval+id:center + intval + 1+id] + template = self.secret.reshape(b, -1, h, w) + self.secret = [dwt(self.secret[:,i].reshape(b, -1, h, w)) for i in range(n)] + + self.output = self.host + y_forw = self.output.squeeze(1) + + y = self.Quantization(y_forw) + + out_x, out_x_h, out_z, recmessage = self.netG(x=y, rev=True) + out_x = iwt(out_x) + + out_x_h = [iwt(out_x_h_i) for out_x_h_i in out_x_h] + out_x = out_x.reshape(-1, self.gop, 3, h, w) + out_x_h = torch.stack(out_x_h, dim=1) + out_x_h = out_x_h.reshape(-1, 1, self.gop, 3, h, w) + + rec_loc = out_x_h[:,:, self.gop//2] + # from PIL import Image + # tmp = util.tensor2img(rec_loc) + # save + residual = torch.abs(template - rec_loc) + binary_residual = (residual > number).float() + residual = util.tensor2img(binary_residual) + mask = np.sum(residual, axis=2) + # print(mask) + + remesg = torch.clamp(recmessage,-0.5,0.5) + remesg[remesg > 0] = 1 + remesg[remesg <= 0] = 0 + + return mask, remesg + + def get_current_log(self): + return self.log_dict + + def get_current_visuals(self): + b, n, t, c, h, w = self.ref_L.shape + center = t // 2 + intval = self.gop // 2 + out_dict = OrderedDict() + LR_ref = self.ref_L[:, :, center - intval:center + intval + 1].detach()[0].float().cpu() + LR_ref = torch.chunk(LR_ref, self.num_image, dim=0) + out_dict['LR_ref'] = [image.squeeze(0) for image in LR_ref] + + if self.mode == "image": + out_dict['SR'] = self.fake_H.detach()[0].float().cpu() + SR_h = self.fake_H_h.detach()[0].float().cpu() + SR_h = torch.chunk(SR_h, self.num_image, dim=0) + out_dict['SR_h'] = [image.squeeze(0) for image in SR_h] + + out_dict['LR'] = self.forw_L.detach()[0].float().cpu() + out_dict['GT'] = self.real_H[:, center - intval:center + intval + 1].detach()[0].float().cpu() + out_dict['message'] = self.message + out_dict['recmessage'] = self.recmessage + + return out_dict + + def print_network(self): + s, n = self.get_network_description(self.netG) + if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel): + net_struc_str = '{} - {}'.format(self.netG.__class__.__name__, + self.netG.module.__class__.__name__) + else: + net_struc_str = '{}'.format(self.netG.__class__.__name__) + if self.rank <= 0: + logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) + logger.info(s) + + def load(self): + load_path_G = self.opt['path']['pretrain_model_G'] + if load_path_G is not None: + logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) + self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) + + def load_test(self,load_path_G): + self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) + + def save(self, iter_label): + self.save_network(self.netG, 'G', iter_label) diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5fd6718e0397bc26f2ae99886bf920c1cd0cf496 
--- /dev/null +++ b/models/__init__.py @@ -0,0 +1,11 @@ +import logging +logger = logging.getLogger('base') + +def create_model(opt): + model = opt['model'] + frame_num = opt['gop'] + from .IBSN import Model_VSN as M + + m = M(opt) + logger.info('Model [{:s}] is created.'.format(m.__class__.__name__)) + return m \ No newline at end of file diff --git a/models/__pycache__/IBSN.cpython-310.pyc b/models/__pycache__/IBSN.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94298df280ab282a356af5488b802de16e23423f Binary files /dev/null and b/models/__pycache__/IBSN.cpython-310.pyc differ diff --git a/models/__pycache__/IBSN.cpython-38.pyc b/models/__pycache__/IBSN.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90d245d25900a7ff6d4c19dd22635583a6dfb68e Binary files /dev/null and b/models/__pycache__/IBSN.cpython-38.pyc differ diff --git a/models/__pycache__/__init__.cpython-310.pyc b/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fe26e2d0b7051b7b73fd175c90d80c75d6f634b Binary files /dev/null and b/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/models/__pycache__/__init__.cpython-38.pyc b/models/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08098400bfee8f1aa3c9639fb6b6ac1022b22d5d Binary files /dev/null and b/models/__pycache__/__init__.cpython-38.pyc differ diff --git a/models/__pycache__/base_model.cpython-38.pyc b/models/__pycache__/base_model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9b1827e5a5958fd669e885efbfd2205e4c05a8c Binary files /dev/null and b/models/__pycache__/base_model.cpython-38.pyc differ diff --git a/models/__pycache__/lr_scheduler.cpython-38.pyc b/models/__pycache__/lr_scheduler.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0de58527a7ad737cae8ef665b657dcf3341aa5ab Binary files /dev/null and b/models/__pycache__/lr_scheduler.cpython-38.pyc differ diff --git a/models/__pycache__/networks.cpython-310.pyc b/models/__pycache__/networks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0d6317626285aabd5f1da35f68423337d1bd195 Binary files /dev/null and b/models/__pycache__/networks.cpython-310.pyc differ diff --git a/models/__pycache__/networks.cpython-38.pyc b/models/__pycache__/networks.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93053e00c4a28b11b96f47280c494640c037eac9 Binary files /dev/null and b/models/__pycache__/networks.cpython-38.pyc differ diff --git a/models/base_model.py b/models/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..fe45d62dae1611a6e141039eaa6f2a794920b9f9 --- /dev/null +++ b/models/base_model.py @@ -0,0 +1,119 @@ +import os +from collections import OrderedDict +import torch +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel + + +class BaseModel(): + def __init__(self, opt): + self.opt = opt + self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu') + self.is_train = opt['is_train'] + self.schedulers = [] + self.optimizers = [] + + def feed_data(self, data): + pass + + def optimize_parameters(self): + pass + + def get_current_visuals(self): + pass + + def get_current_losses(self): + pass + + def print_network(self): + pass + + def save(self, label): + pass + + def load(self): + pass + + def _set_lr(self, lr_groups_l): + 
''' set learning rate for warmup, + lr_groups_l: list for lr_groups. each for a optimizer''' + for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): + for param_group, lr in zip(optimizer.param_groups, lr_groups): + param_group['lr'] = lr + + def _get_init_lr(self): + # get the initial lr, which is set by the scheduler + init_lr_groups_l = [] + for optimizer in self.optimizers: + init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups]) + return init_lr_groups_l + + def update_learning_rate(self, cur_iter, warmup_iter=-1): + for scheduler in self.schedulers: + scheduler.step() + #### set up warm up learning rate + if cur_iter < warmup_iter: + # get initial lr for each group + init_lr_g_l = self._get_init_lr() + # modify warming-up learning rates + warm_up_lr_l = [] + for init_lr_g in init_lr_g_l: + warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g]) + # set learning rate + self._set_lr(warm_up_lr_l) + + def get_current_learning_rate(self): + # return self.schedulers[0].get_lr()[0] + return self.optimizers[0].param_groups[0]['lr'] + + def get_network_description(self, network): + '''Get the string and total parameters of the network''' + if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): + network = network.module + s = str(network) + n = sum(map(lambda x: x.numel(), network.parameters())) + return s, n + + def save_network(self, network, network_label, iter_label): + save_filename = '{}_{}.pth'.format(iter_label, network_label) + save_path = os.path.join(self.opt['path']['models'], save_filename) + if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): + network = network.module + state_dict = network.state_dict() + for key, param in state_dict.items(): + state_dict[key] = param.cpu() + torch.save(state_dict, save_path) + + def load_network(self, load_path, network, strict=True): + if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): + network = network.module + load_net = torch.load(load_path) + load_net_clean = OrderedDict() # remove unnecessary 'module.' 
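+ # DataParallel / DistributedDataParallel checkpoints store every parameter under a
+ # 'module.' prefix; stripping it here lets the same checkpoint load into both the
+ # wrapped and the bare network.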
+ for k, v in load_net.items(): + if k.startswith('module.'): + load_net_clean[k[7:]] = v + else: + load_net_clean[k] = v + network.load_state_dict(load_net_clean, strict=strict) + + def save_training_state(self, epoch, iter_step): + '''Saves training state during training, which will be used for resuming''' + state = {'epoch': epoch, 'iter': iter_step, 'schedulers': [], 'optimizers': []} + for s in self.schedulers: + state['schedulers'].append(s.state_dict()) + for o in self.optimizers: + state['optimizers'].append(o.state_dict()) + save_filename = '{}.state'.format(iter_step) + save_path = os.path.join(self.opt['path']['training_state'], save_filename) + torch.save(state, save_path) + + def resume_training(self, resume_state): + '''Resume the optimizers and schedulers for training''' + resume_optimizers = resume_state['optimizers'] + resume_schedulers = resume_state['schedulers'] + assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers' + assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers' + for i, o in enumerate(resume_optimizers): + self.optimizers[i].load_state_dict(o) + for i, s in enumerate(resume_schedulers): + self.schedulers[i].load_state_dict(s) diff --git a/models/bitnetwork/ConvBlock.py b/models/bitnetwork/ConvBlock.py new file mode 100644 index 0000000000000000000000000000000000000000..7a5e08bb87d6bf60bfbd9de1f4cc8bcb1d4209ac --- /dev/null +++ b/models/bitnetwork/ConvBlock.py @@ -0,0 +1,38 @@ +import torch.nn as nn + + +class ConvINRelu(nn.Module): + """ + A sequence of Convolution, Instance Normalization, and ReLU activation + """ + + def __init__(self, channels_in, channels_out, stride): + super(ConvINRelu, self).__init__() + + self.layers = nn.Sequential( + nn.Conv2d(channels_in, channels_out, 3, stride, padding=1), + nn.InstanceNorm2d(channels_out), + nn.ReLU(inplace=True) + ) + + def forward(self, x): + return self.layers(x) + + +class ConvBlock(nn.Module): + ''' + Network that composed by layers of ConvINRelu + ''' + + def __init__(self, in_channels, out_channels, blocks=1, stride=1): + super(ConvBlock, self).__init__() + + layers = [ConvINRelu(in_channels, out_channels, stride)] if blocks != 0 else [] + for _ in range(blocks - 1): + layer = ConvINRelu(out_channels, out_channels, 1) + layers.append(layer) + + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) diff --git a/models/bitnetwork/DW_EncoderDecoder.py b/models/bitnetwork/DW_EncoderDecoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c4943f83c7abb69570ad1a9286028ba6f4533bed --- /dev/null +++ b/models/bitnetwork/DW_EncoderDecoder.py @@ -0,0 +1,28 @@ +from . 
import * +from .Encoder_U import DW_Encoder +from .Decoder_U import DW_Decoder +from .Noise import Noise +from .Random_Noise import Random_Noise + + +class DW_EncoderDecoder(nn.Module): + ''' + A Sequential of Encoder_MP-Noise-Decoder + ''' + + def __init__(self, message_length, noise_layers_R, noise_layers_F, attention_encoder, attention_decoder): + super(DW_EncoderDecoder, self).__init__() + self.encoder = DW_Encoder(message_length, attention = attention_encoder) + self.noise = Random_Noise(noise_layers_R + noise_layers_F, len(noise_layers_R), len(noise_layers_F)) + self.decoder_C = DW_Decoder(message_length, attention = attention_decoder) + self.decoder_RF = DW_Decoder(message_length, attention = attention_decoder) + + + def forward(self, image, message, mask): + encoded_image = self.encoder(image, message) + noised_image_C, noised_image_R, noised_image_F = self.noise([encoded_image, image, mask]) + decoded_message_C = self.decoder_C(noised_image_C) + decoded_message_R = self.decoder_RF(noised_image_R) + decoded_message_F = self.decoder_RF(noised_image_F) + return encoded_image, noised_image_C, decoded_message_C, decoded_message_R, decoded_message_F + diff --git a/models/bitnetwork/Decoder_U.py b/models/bitnetwork/Decoder_U.py new file mode 100644 index 0000000000000000000000000000000000000000..e3ca17d4e2a57d710b71c12ff21c47c4d21ecddc --- /dev/null +++ b/models/bitnetwork/Decoder_U.py @@ -0,0 +1,87 @@ +from . import * + + +class DW_Decoder(nn.Module): + + def __init__(self, message_length, blocks=2, channels=64, attention=None): + super(DW_Decoder, self).__init__() + + self.conv1 = ConvBlock(3, 16, blocks=blocks) + self.down1 = Down(16, 32, blocks=blocks) + self.down2 = Down(32, 64, blocks=blocks) + self.down3 = Down(64, 128, blocks=blocks) + + self.down4 = Down(128, 256, blocks=blocks) + + self.up3 = UP(256, 128) + self.att3 = ResBlock(128 * 2, 128, blocks=blocks, attention=attention) + + self.up2 = UP(128, 64) + self.att2 = ResBlock(64 * 2, 64, blocks=blocks, attention=attention) + + self.up1 = UP(64, 32) + self.att1 = ResBlock(32 * 2, 32, blocks=blocks, attention=attention) + + self.up0 = UP(32, 16) + self.att0 = ResBlock(16 * 2, 16, blocks=blocks, attention=attention) + + self.Conv_1x1 = nn.Conv2d(16, 1, kernel_size=1, stride=1, padding=0, bias=False) + + self.message_layer = nn.Linear(message_length * message_length, message_length) + self.message_length = message_length + + + def forward(self, x): + d0 = self.conv1(x) + d1 = self.down1(d0) + d2 = self.down2(d1) + d3 = self.down3(d2) + + d4 = self.down4(d3) + + u3 = self.up3(d4) + u3 = torch.cat((d3, u3), dim=1) + u3 = self.att3(u3) + + u2 = self.up2(u3) + u2 = torch.cat((d2, u2), dim=1) + u2 = self.att2(u2) + + u1 = self.up1(u2) + u1 = torch.cat((d1, u1), dim=1) + u1 = self.att1(u1) + + u0 = self.up0(u1) + u0 = torch.cat((d0, u0), dim=1) + u0 = self.att0(u0) + + residual = self.Conv_1x1(u0) + + message = F.interpolate(residual, size=(self.message_length, self.message_length), + mode='nearest') + message = message.view(message.shape[0], -1) + message = self.message_layer(message) + + return message + + +class Down(nn.Module): + def __init__(self, in_channels, out_channels, blocks): + super(Down, self).__init__() + self.layer = torch.nn.Sequential( + ConvBlock(in_channels, in_channels, stride=2), + ConvBlock(in_channels, out_channels, blocks=blocks) + ) + + def forward(self, x): + return self.layer(x) + + +class UP(nn.Module): + def __init__(self, in_channels, out_channels): + super(UP, self).__init__() + self.conv = 
ConvBlock(in_channels, out_channels) + + def forward(self, x): + x = F.interpolate(x, scale_factor=2, mode='nearest') + return self.conv(x) diff --git a/models/bitnetwork/Dual_Mark.py b/models/bitnetwork/Dual_Mark.py new file mode 100644 index 0000000000000000000000000000000000000000..01fa7c76658bb8f0e50d38bb0d47a1b7c20b1463 --- /dev/null +++ b/models/bitnetwork/Dual_Mark.py @@ -0,0 +1,249 @@ +from .DW_EncoderDecoder import * +from .Patch_Discriminator import Patch_Discriminator +import torch +import kornia.losses +import lpips + + +class Network: + + def __init__(self, message_length, noise_layers_R, noise_layers_F, device, batch_size, lr, beta1, attention_encoder, attention_decoder, weight): + # device + self.device = device + + # loss function + self.criterion_MSE = nn.MSELoss().to(device) + self.criterion_LPIPS = lpips.LPIPS().to(device) + + # weight of encoder-decoder loss + self.encoder_weight = weight[0] + self.decoder_weight_C = weight[1] + self.decoder_weight_R = weight[2] + self.decoder_weight_F = weight[3] + self.discriminator_weight = weight[4] + + # network + self.encoder_decoder = DW_EncoderDecoder(message_length, noise_layers_R, noise_layers_F, attention_encoder, attention_decoder).to(device) + self.discriminator = Patch_Discriminator().to(device) + + self.encoder_decoder = torch.nn.DataParallel(self.encoder_decoder) + self.discriminator = torch.nn.DataParallel(self.discriminator) + + # mark "cover" as 1, "encoded" as -1 + self.label_cover = 1.0 + self.label_encoded = - 1.0 + + for p in self.encoder_decoder.module.noise.parameters(): + p.requires_grad = False + + # optimizer + self.opt_encoder_decoder = torch.optim.Adam( + filter(lambda p: p.requires_grad, self.encoder_decoder.parameters()), lr=lr, betas=(beta1, 0.999)) + self.opt_discriminator = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(beta1, 0.999)) + + + def train(self, images: torch.Tensor, messages: torch.Tensor, masks: torch.Tensor): + self.encoder_decoder.train() + self.discriminator.train() + + with torch.enable_grad(): + # use device to compute + images, messages, masks = images.to(self.device), messages.to(self.device), masks.to(self.device) + encoded_images, noised_images, decoded_messages_C, decoded_messages_R, decoded_messages_F = self.encoder_decoder(images, messages, masks) + + ''' + train discriminator + ''' + for p in self.discriminator.parameters(): + p.requires_grad = True + + self.opt_discriminator.zero_grad() + + # RAW : target label for image should be "cover"(1) + d_label_cover = self.discriminator(images) + #d_cover_loss = self.criterion_MSE(d_label_cover, torch.ones_like(d_label_cover)) + #d_cover_loss.backward() + + # GAN : target label for encoded image should be "encoded"(0) + d_label_encoded = self.discriminator(encoded_images.detach()) + #d_encoded_loss = self.criterion_MSE(d_label_encoded, torch.zeros_like(d_label_encoded)) + #d_encoded_loss.backward() + + d_loss = self.criterion_MSE(d_label_cover - torch.mean(d_label_encoded), self.label_cover * torch.ones_like(d_label_cover)) +\ + self.criterion_MSE(d_label_encoded - torch.mean(d_label_cover), self.label_encoded * torch.ones_like(d_label_encoded)) + d_loss.backward() + + self.opt_discriminator.step() + + ''' + train encoder and decoder + ''' + # Make it a tiny bit faster + for p in self.discriminator.parameters(): + p.requires_grad = False + + self.opt_encoder_decoder.zero_grad() + + # GAN : target label for encoded image should be "cover"(0) + g_label_cover = self.discriminator(images) + g_label_encoded = 
self.discriminator(encoded_images) + g_loss_on_discriminator = self.criterion_MSE(g_label_cover - torch.mean(g_label_encoded), self.label_encoded * torch.ones_like(g_label_cover)) +\ + self.criterion_MSE(g_label_encoded - torch.mean(g_label_cover), self.label_cover * torch.ones_like(g_label_encoded)) + + # RAW : the encoded image should be similar to cover image + g_loss_on_encoder_MSE = self.criterion_MSE(encoded_images, images) + g_loss_on_encoder_LPIPS = torch.mean(self.criterion_LPIPS(encoded_images, images)) + + # RESULT : the decoded message should be similar to the raw message /Dual + g_loss_on_decoder_C = self.criterion_MSE(decoded_messages_C, messages) + g_loss_on_decoder_R = self.criterion_MSE(decoded_messages_R, messages) + g_loss_on_decoder_F = self.criterion_MSE(decoded_messages_F, torch.zeros_like(messages)) + + # full loss + g_loss = self.discriminator_weight * g_loss_on_discriminator + self.encoder_weight * g_loss_on_encoder_MSE +\ + self.decoder_weight_C * g_loss_on_decoder_C + self.decoder_weight_R * g_loss_on_decoder_R + self.decoder_weight_F * g_loss_on_decoder_F + + g_loss.backward() + self.opt_encoder_decoder.step() + + # psnr + psnr = - kornia.losses.psnr_loss(encoded_images.detach(), images, 2) + + # ssim + ssim = 1 - 2 * kornia.losses.ssim_loss(encoded_images.detach(), images, window_size=11, reduction="mean") + + ''' + decoded message error rate /Dual + ''' + error_rate_C = self.decoded_message_error_rate_batch(messages, decoded_messages_C) + error_rate_R = self.decoded_message_error_rate_batch(messages, decoded_messages_R) + error_rate_F = self.decoded_message_error_rate_batch(messages, decoded_messages_F) + + result = { + "g_loss": g_loss, + "error_rate_C": error_rate_C, + "error_rate_R": error_rate_R, + "error_rate_F": error_rate_F, + "psnr": psnr, + "ssim": ssim, + "g_loss_on_discriminator": g_loss_on_discriminator, + "g_loss_on_encoder_MSE": g_loss_on_encoder_MSE, + "g_loss_on_encoder_LPIPS": g_loss_on_encoder_LPIPS, + "g_loss_on_decoder_C": g_loss_on_decoder_C, + "g_loss_on_decoder_R": g_loss_on_decoder_R, + "g_loss_on_decoder_F": g_loss_on_decoder_F, + "d_loss": d_loss + } + return result + + + def validation(self, images: torch.Tensor, messages: torch.Tensor, masks: torch.Tensor): + self.encoder_decoder.eval() + self.encoder_decoder.module.noise.train() + self.discriminator.eval() + + with torch.no_grad(): + # use device to compute + images, messages, masks = images.to(self.device), messages.to(self.device), masks.to(self.device) + encoded_images, noised_images, decoded_messages_C, decoded_messages_R, decoded_messages_F = self.encoder_decoder(images, messages, masks) + + ''' + validate discriminator + ''' + # RAW : target label for image should be "cover"(1) + d_label_cover = self.discriminator(images) + #d_cover_loss = self.criterion_MSE(d_label_cover, torch.ones_like(d_label_cover)) + + # GAN : target label for encoded image should be "encoded"(0) + d_label_encoded = self.discriminator(encoded_images.detach()) + #d_encoded_loss = self.criterion_MSE(d_label_encoded, torch.zeros_like(d_label_encoded)) + + d_loss = self.criterion_MSE(d_label_cover - torch.mean(d_label_encoded), self.label_cover * torch.ones_like(d_label_cover)) +\ + self.criterion_MSE(d_label_encoded - torch.mean(d_label_cover), self.label_encoded * torch.ones_like(d_label_encoded)) + + ''' + validate encoder and decoder + ''' + + # GAN : target label for encoded image should be "cover"(0) + g_label_cover = self.discriminator(images) + g_label_encoded = self.discriminator(encoded_images) + 
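+ # Relativistic average LSGAN generator terms, mirroring train(): (D(cover) - mean D(encoded))
+ # is regressed toward label_encoded (-1) and (D(encoded) - mean D(cover)) toward
+ # label_cover (+1), i.e. encoded images should score higher than covers on average.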
g_loss_on_discriminator = self.criterion_MSE(g_label_cover - torch.mean(g_label_encoded), self.label_encoded * torch.ones_like(g_label_cover)) +\ + self.criterion_MSE(g_label_encoded - torch.mean(g_label_cover), self.label_cover * torch.ones_like(g_label_encoded)) + + # RAW : the encoded image should be similar to cover image + g_loss_on_encoder_MSE = self.criterion_MSE(encoded_images, images) + g_loss_on_encoder_LPIPS = torch.mean(self.criterion_LPIPS(encoded_images, images)) + + # RESULT : the decoded message should be similar to the raw message /Dual + g_loss_on_decoder_C = self.criterion_MSE(decoded_messages_C, messages) + g_loss_on_decoder_R = self.criterion_MSE(decoded_messages_R, messages) + g_loss_on_decoder_F = self.criterion_MSE(decoded_messages_F, torch.zeros_like(messages)) + + # full loss + # unstable g_loss_on_discriminator is not used during validation + + g_loss = 0 * g_loss_on_discriminator + self.encoder_weight * g_loss_on_encoder_LPIPS +\ + self.decoder_weight_C * g_loss_on_decoder_C + self.decoder_weight_R * g_loss_on_decoder_R + self.decoder_weight_F * g_loss_on_decoder_F + + + # psnr + psnr = - kornia.losses.psnr_loss(encoded_images.detach(), images, 2) + + # ssim + ssim = 1 - 2 * kornia.losses.ssim_loss(encoded_images.detach(), images, window_size=11, reduction="mean") + + ''' + decoded message error rate /Dual + ''' + error_rate_C = self.decoded_message_error_rate_batch(messages, decoded_messages_C) + error_rate_R = self.decoded_message_error_rate_batch(messages, decoded_messages_R) + error_rate_F = self.decoded_message_error_rate_batch(messages, decoded_messages_F) + + result = { + "g_loss": g_loss, + "error_rate_C": error_rate_C, + "error_rate_R": error_rate_R, + "error_rate_F": error_rate_F, + "psnr": psnr, + "ssim": ssim, + "g_loss_on_discriminator": g_loss_on_discriminator, + "g_loss_on_encoder_MSE": g_loss_on_encoder_MSE, + "g_loss_on_encoder_LPIPS": g_loss_on_encoder_LPIPS, + "g_loss_on_decoder_C": g_loss_on_decoder_C, + "g_loss_on_decoder_R": g_loss_on_decoder_R, + "g_loss_on_decoder_F": g_loss_on_decoder_F, + "d_loss": d_loss + } + + return result, (images, encoded_images, noised_images) + + def decoded_message_error_rate(self, message, decoded_message): + length = message.shape[0] + + message = message.gt(0) + decoded_message = decoded_message.gt(0) + error_rate = float(sum(message != decoded_message)) / length + return error_rate + + def decoded_message_error_rate_batch(self, messages, decoded_messages): + error_rate = 0.0 + batch_size = len(messages) + for i in range(batch_size): + error_rate += self.decoded_message_error_rate(messages[i], decoded_messages[i]) + error_rate /= batch_size + return error_rate + + def save_model(self, path_encoder_decoder: str, path_discriminator: str): + torch.save(self.encoder_decoder.module.state_dict(), path_encoder_decoder) + torch.save(self.discriminator.module.state_dict(), path_discriminator) + + def load_model(self, path_encoder_decoder: str, path_discriminator: str): + self.load_model_ed(path_encoder_decoder) + self.load_model_dis(path_discriminator) + + def load_model_ed(self, path_encoder_decoder: str): + self.encoder_decoder.module.load_state_dict(torch.load(path_encoder_decoder), strict=False) + + def load_model_dis(self, path_discriminator: str): + self.discriminator.module.load_state_dict(torch.load(path_discriminator)) diff --git a/models/bitnetwork/Encoder_U.py b/models/bitnetwork/Encoder_U.py new file mode 100644 index 0000000000000000000000000000000000000000..af35295bfbc6bffe82f510bab19ca89a75556162 --- 
/dev/null +++ b/models/bitnetwork/Encoder_U.py @@ -0,0 +1,125 @@ +from . import * + +class DW_Encoder(nn.Module): + + def __init__(self, message_length, blocks=2, channels=64, attention=None): + super(DW_Encoder, self).__init__() + + self.conv1 = ConvBlock(3, 16, blocks=blocks) + self.down1 = Down(16, 32, blocks=blocks) + self.down2 = Down(32, 64, blocks=blocks) + self.down3 = Down(64, 128, blocks=blocks) + + self.down4 = Down(128, 256, blocks=blocks) + + self.up3 = UP(256, 128) + self.linear3 = nn.Linear(message_length, message_length * message_length) + self.Conv_message3 = ConvBlock(1, channels, blocks=blocks) + self.att3 = ResBlock(128 * 2 + channels, 128, blocks=blocks, attention=attention) + + self.up2 = UP(128, 64) + self.linear2 = nn.Linear(message_length, message_length * message_length) + self.Conv_message2 = ConvBlock(1, channels, blocks=blocks) + self.att2 = ResBlock(64 * 2 + channels, 64, blocks=blocks, attention=attention) + + self.up1 = UP(64, 32) + self.linear1 = nn.Linear(message_length, message_length * message_length) + self.Conv_message1 = ConvBlock(1, channels, blocks=blocks) + self.att1 = ResBlock(32 * 2 + channels, 32, blocks=blocks, attention=attention) + + self.up0 = UP(32, 16) + self.linear0 = nn.Linear(message_length, message_length * message_length) + self.Conv_message0 = ConvBlock(1, channels, blocks=blocks) + self.att0 = ResBlock(16 * 2 + channels, 16, blocks=blocks, attention=attention) + + self.Conv_1x1 = nn.Conv2d(16 + 3, 3, kernel_size=1, stride=1, padding=0) + + self.message_length = message_length + + self.transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) + ]) + + + def forward(self, x, watermark): + d0 = self.conv1(x) + d1 = self.down1(d0) + d2 = self.down2(d1) + d3 = self.down3(d2) + + d4 = self.down4(d3) + + u3 = self.up3(d4) + expanded_message = self.linear3(watermark) + expanded_message = expanded_message.view(-1, 1, self.message_length, self.message_length) + expanded_message = F.interpolate(expanded_message, size=(d3.shape[2], d3.shape[3]), + mode='nearest') + expanded_message = self.Conv_message3(expanded_message) + u3 = torch.cat((d3, u3, expanded_message), dim=1) + u3 = self.att3(u3) + + u2 = self.up2(u3) + expanded_message = self.linear2(watermark) + expanded_message = expanded_message.view(-1, 1, self.message_length, self.message_length) + expanded_message = F.interpolate(expanded_message, size=(d2.shape[2], d2.shape[3]), + mode='nearest') + expanded_message = self.Conv_message2(expanded_message) + u2 = torch.cat((d2, u2, expanded_message), dim=1) + u2 = self.att2(u2) + + u1 = self.up1(u2) + expanded_message = self.linear1(watermark) + expanded_message = expanded_message.view(-1, 1, self.message_length, self.message_length) + expanded_message = F.interpolate(expanded_message, size=(d1.shape[2], d1.shape[3]), + mode='nearest') + expanded_message = self.Conv_message1(expanded_message) + u1 = torch.cat((d1, u1, expanded_message), dim=1) + u1 = self.att1(u1) + + u0 = self.up0(u1) + expanded_message = self.linear0(watermark) + expanded_message = expanded_message.view(-1, 1, self.message_length, self.message_length) + expanded_message = F.interpolate(expanded_message, size=(d0.shape[2], d0.shape[3]), + mode='nearest') + expanded_message = self.Conv_message0(expanded_message) + u0 = torch.cat((d0, u0, expanded_message), dim=1) + u0 = self.att0(u0) + + image = self.Conv_1x1(torch.cat((x, u0), dim=1)) + + forward_image = image.clone().detach() + '''read_image = 
torch.zeros_like(forward_image) + + for index in range(forward_image.shape[0]): + single_image = ((forward_image[index].clamp(-1, 1).permute(1, 2, 0) + 1) / 2 * 255).add(0.5).clamp(0, 255).to('cpu', torch.uint8).numpy() + im = Image.fromarray(single_image) + read = np.array(im, dtype=np.uint8) + read_image[index] = self.transform(read).unsqueeze(0).to(image.device) + + gap = read_image - forward_image''' + gap = forward_image.clamp(-1, 1) - forward_image + + return image + gap + + +class Down(nn.Module): + def __init__(self, in_channels, out_channels, blocks): + super(Down, self).__init__() + self.layer = torch.nn.Sequential( + ConvBlock(in_channels, in_channels, stride=2), + ConvBlock(in_channels, out_channels, blocks=blocks) + ) + + def forward(self, x): + return self.layer(x) + + +class UP(nn.Module): + def __init__(self, in_channels, out_channels): + super(UP, self).__init__() + self.conv = ConvBlock(in_channels, out_channels) + + def forward(self, x): + x = F.interpolate(x, scale_factor=2, mode='nearest') + return self.conv(x) diff --git a/models/bitnetwork/Random_Noise.py b/models/bitnetwork/Random_Noise.py new file mode 100644 index 0000000000000000000000000000000000000000..d3a04c6e7ffb85ff0b3b1586c25c7b7d0c64aa31 --- /dev/null +++ b/models/bitnetwork/Random_Noise.py @@ -0,0 +1,59 @@ +from . import * +from .noise_layers import * + + +class Random_Noise(nn.Module): + + def __init__(self, layers, len_layers_R, len_layers_F): + super(Random_Noise, self).__init__() + for i in range(len(layers)): + layers[i] = eval(layers[i]) + self.noise = nn.Sequential(*layers) + self.len_layers_R = len_layers_R + self.len_layers_F = len_layers_F + print(self.noise) + self.transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) + ]) + + def forward(self, image_cover_mask): + image, cover_image, mask = image_cover_mask[0], image_cover_mask[1], image_cover_mask[2] + forward_image = image.clone().detach() + forward_cover_image = cover_image.clone().detach() + forward_mask = mask.clone().detach() + noised_image_C = torch.zeros_like(forward_image) + noised_image_R = torch.zeros_like(forward_image) + noised_image_F = torch.zeros_like(forward_image) + + for index in range(forward_image.shape[0]): + random_noise_layer_C = np.random.choice(self.noise, 1)[0] + random_noise_layer_R = np.random.choice(self.noise[0:self.len_layers_R], 1)[0] + random_noise_layer_F = np.random.choice(self.noise[self.len_layers_R:self.len_layers_R + self.len_layers_F], 1)[0] + noised_image_C[index] = random_noise_layer_C([forward_image[index].clone().unsqueeze(0), forward_cover_image[index].clone().unsqueeze(0), forward_mask[index].clone().unsqueeze(0)]) + noised_image_R[index] = random_noise_layer_R([forward_image[index].clone().unsqueeze(0), forward_cover_image[index].clone().unsqueeze(0), forward_mask[index].clone().unsqueeze(0)]) + noised_image_F[index] = random_noise_layer_F([forward_image[index].clone().unsqueeze(0), forward_cover_image[index].clone().unsqueeze(0), forward_mask[index].clone().unsqueeze(0)]) + + '''single_image = ((noised_image_C[index].clamp(-1, 1).permute(1, 2, 0) + 1) / 2 * 255).add(0.5).clamp(0, 255).to('cpu', torch.uint8).numpy() + im = Image.fromarray(single_image) + read = np.array(im, dtype=np.uint8) + noised_image_C[index] = self.transform(read).unsqueeze(0).to(image.device) + + single_image = ((noised_image_R[index].clamp(-1, 1).permute(1, 2, 0) + 1) / 2 * 255).add(0.5).clamp(0, 255).to('cpu', torch.uint8).numpy() + im = 
Image.fromarray(single_image) + read = np.array(im, dtype=np.uint8) + noised_image_R[index] = self.transform(read).unsqueeze(0).to(image.device) + + single_image = ((noised_image_F[index].clamp(-1, 1).permute(1, 2, 0) + 1) / 2 * 255).add(0.5).clamp(0, 255).to('cpu', torch.uint8).numpy() + im = Image.fromarray(single_image) + read = np.array(im, dtype=np.uint8) + noised_image_F[index] = self.transform(read).unsqueeze(0).to(image.device) + + noised_image_gap_C = noised_image_C - forward_image + noised_image_gap_R = noised_image_R - forward_image + noised_image_gap_F = noised_image_F - forward_image''' + noised_image_gap_C = noised_image_C.clamp(-1, 1) - forward_image + noised_image_gap_R = noised_image_R.clamp(-1, 1) - forward_image + noised_image_gap_F = noised_image_F.clamp(-1, 1) - forward_image + + return image + noised_image_gap_C, image + noised_image_gap_R, image + noised_image_gap_F diff --git a/models/bitnetwork/ResBlock.py b/models/bitnetwork/ResBlock.py new file mode 100644 index 0000000000000000000000000000000000000000..87a95f99f8a74fc168eaa34423b1b8b4ab49745e --- /dev/null +++ b/models/bitnetwork/ResBlock.py @@ -0,0 +1,222 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class SEAttention(nn.Module): + def __init__(self, in_channels, out_channels, reduction=8): + super(SEAttention, self).__init__() + self.se = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + nn.Conv2d(in_channels=in_channels, out_channels=out_channels // reduction, kernel_size=1, bias=False), + nn.ReLU(inplace=True), + nn.Conv2d(in_channels=out_channels // reduction, out_channels=out_channels, kernel_size=1, bias=False), + nn.Sigmoid() + ) + + def forward(self, x): + x = self.se(x) * x + return x + + +class ChannelAttention(nn.Module): + def __init__(self, in_channels, out_channels, reduction=8): + super(ChannelAttention, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) + self.max_pool = nn.AdaptiveMaxPool2d((1, 1)) + + self.fc = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels // reduction, kernel_size=1, bias=False), + nn.ReLU(inplace=True), + nn.Conv2d(in_channels=out_channels // reduction, out_channels=out_channels, kernel_size=1, bias=False)) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + avg_out = self.fc(self.avg_pool(x)) + max_out = self.fc(self.max_pool(x)) + out = avg_out + max_out + return self.sigmoid(out) + + +class SpatialAttention(nn.Module): + def __init__(self, kernel_size=7): + super(SpatialAttention, self).__init__() + + self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + avg_out = torch.mean(x, dim=1, keepdim=True) + max_out, _ = torch.max(x, dim=1, keepdim=True) + x = torch.cat([avg_out, max_out], dim=1) + x = self.conv1(x) + return self.sigmoid(x) + + +class CBAMAttention(nn.Module): + def __init__(self, in_channels, out_channels, reduction=8): + super(CBAMAttention, self).__init__() + self.ca = ChannelAttention(in_channels=in_channels, out_channels=out_channels, reduction=reduction) + self.sa = SpatialAttention() + + def forward(self, x): + x = self.ca(x) * x + x = self.sa(x) * x + return x + + +class h_sigmoid(nn.Module): + def __init__(self, inplace=True): + super(h_sigmoid, self).__init__() + self.relu = nn.ReLU6(inplace=inplace) + + def forward(self, x): + return self.relu(x + 3) / 6 + + +class h_swish(nn.Module): + def __init__(self, inplace=True): + super(h_swish, self).__init__() + self.sigmoid = 
h_sigmoid(inplace=inplace) + + def forward(self, x): + return x * self.sigmoid(x) + + +class CoordAttention(nn.Module): + def __init__(self, in_channels, out_channels, reduction=8): + super(CoordAttention, self).__init__() + self.pool_w, self.pool_h = nn.AdaptiveAvgPool2d((1, None)), nn.AdaptiveAvgPool2d((None, 1)) + temp_c = max(8, in_channels // reduction) + self.conv1 = nn.Conv2d(in_channels, temp_c, kernel_size=1, stride=1, padding=0) + + self.bn1 = nn.InstanceNorm2d(temp_c) + self.act1 = h_swish() # nn.SiLU() # nn.Hardswish() # nn.SiLU() + + self.conv2 = nn.Conv2d(temp_c, out_channels, kernel_size=1, stride=1, padding=0) + self.conv3 = nn.Conv2d(temp_c, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + short = x + n, c, H, W = x.shape + x_h, x_w = self.pool_h(x), self.pool_w(x).permute(0, 1, 3, 2) + x_cat = torch.cat([x_h, x_w], dim=2) + out = self.act1(self.bn1(self.conv1(x_cat))) + x_h, x_w = torch.split(out, [H, W], dim=2) + x_w = x_w.permute(0, 1, 3, 2) + out_h = torch.sigmoid(self.conv2(x_h)) + out_w = torch.sigmoid(self.conv3(x_w)) + return short * out_w * out_h + + +class BasicBlock(nn.Module): + def __init__(self, in_channels, out_channels, reduction, stride, attention=None): + super(BasicBlock, self).__init__() + + self.change = None + if (in_channels != out_channels or stride != 1): + self.change = nn.Sequential( + nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, padding=0, + stride=stride, bias=False), + nn.InstanceNorm2d(out_channels) + ) + + self.left = nn.Sequential( + nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, + stride=stride, bias=False), + nn.InstanceNorm2d(out_channels), + nn.ReLU(inplace=True), + nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_channels) + ) + + if attention == 'se': + print('SEAttention') + self.attention = SEAttention(in_channels=out_channels, out_channels=out_channels, reduction=reduction) + elif attention == 'cbam': + print('CBAMAttention') + self.attention = CBAMAttention(in_channels=out_channels, out_channels=out_channels, reduction=reduction) + elif attention == 'coord': + print('CoordAttention') + self.attention = CoordAttention(in_channels=out_channels, out_channels=out_channels, reduction=reduction) + else: + print('None Attention') + self.attention = nn.Identity() + + def forward(self, x): + identity = x + x = self.left(x) + x = self.attention(x) + + if self.change is not None: + identity = self.change(identity) + + x += identity + x = F.relu(x) + return x + + +class BottleneckBlock(nn.Module): + def __init__(self, in_channels, out_channels, reduction, stride, attention=None): + super(BottleneckBlock, self).__init__() + + self.change = None + if (in_channels != out_channels or stride != 1): + self.change = nn.Sequential( + nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, padding=0, + stride=stride, bias=False), + nn.InstanceNorm2d(out_channels) + ) + + self.left = nn.Sequential( + nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, + stride=stride, padding=0, bias=False), + nn.InstanceNorm2d(out_channels), + nn.ReLU(inplace=True), + nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_channels), + nn.ReLU(inplace=True), + nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=1, padding=0, bias=False), + 
nn.InstanceNorm2d(out_channels) + ) + + if attention == 'se': + print('SEAttention') + self.attention = SEAttention(in_channels=out_channels, out_channels=out_channels, reduction=reduction) + elif attention == 'cbam': + print('CBAMAttention') + self.attention = CBAMAttention(in_channels=out_channels, out_channels=out_channels, reduction=reduction) + elif attention == 'coord': + print('CoordAttention') + self.attention = CoordAttention(in_channels=out_channels, out_channels=out_channels, reduction=reduction) + else: + print('None Attention') + self.attention = nn.Identity() + + def forward(self, x): + identity = x + x = self.left(x) + x = self.attention(x) + + if self.change is not None: + identity = self.change(identity) + + x += identity + x = F.relu(x) + return x + + +class ResBlock(nn.Module): + + def __init__(self, in_channels, out_channels, blocks=1, block_type="BottleneckBlock", reduction=8, stride=1, attention=None): + super(ResBlock, self).__init__() + + layers = [eval(block_type)(in_channels, out_channels, reduction, stride, attention=attention)] if blocks != 0 else [] + for _ in range(blocks - 1): + layer = eval(block_type)(out_channels, out_channels, reduction, 1, attention=attention) + layers.append(layer) + + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + diff --git a/models/bitnetwork/__init__.py b/models/bitnetwork/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..534f5a7d29dc37bbfbfe7466fe44232909e45baa --- /dev/null +++ b/models/bitnetwork/__init__.py @@ -0,0 +1,9 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +# import kornia.losses +from PIL import Image +from torchvision import transforms +from .ResBlock import * +from .ConvBlock import * \ No newline at end of file diff --git a/models/bitnetwork/__pycache__/ConvBlock.cpython-38.pyc b/models/bitnetwork/__pycache__/ConvBlock.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd6998c1453987f2eadccb1f32af8ec8abaca701 Binary files /dev/null and b/models/bitnetwork/__pycache__/ConvBlock.cpython-38.pyc differ diff --git a/models/bitnetwork/__pycache__/Decoder_U.cpython-38.pyc b/models/bitnetwork/__pycache__/Decoder_U.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0da66a20b00a38dbff7daf11dfe59a9a67206f3 Binary files /dev/null and b/models/bitnetwork/__pycache__/Decoder_U.cpython-38.pyc differ diff --git a/models/bitnetwork/__pycache__/Encoder_U.cpython-38.pyc b/models/bitnetwork/__pycache__/Encoder_U.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ece58c2f16cfc2ab6b372d53031d0f84997bf63b Binary files /dev/null and b/models/bitnetwork/__pycache__/Encoder_U.cpython-38.pyc differ diff --git a/models/bitnetwork/__pycache__/ResBlock.cpython-38.pyc b/models/bitnetwork/__pycache__/ResBlock.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..525dab2e5670626a9b0c86d3e4b3bbe511cf9437 Binary files /dev/null and b/models/bitnetwork/__pycache__/ResBlock.cpython-38.pyc differ diff --git a/models/bitnetwork/__pycache__/__init__.cpython-38.pyc b/models/bitnetwork/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f93a244a251eb5e56e038b5e31eeb7ba189f7a9 Binary files /dev/null and b/models/bitnetwork/__pycache__/__init__.cpython-38.pyc differ diff --git a/models/discrim.py b/models/discrim.py new file mode 100644 index 
0000000000000000000000000000000000000000..535c6ffdac64a7a92345d28536cbb72c07089c3e --- /dev/null +++ b/models/discrim.py @@ -0,0 +1,169 @@ +from torch import nn as nn +from torch.nn import functional as F +from torch.nn.utils import spectral_norm + + +class UNetDiscriminatorSN(nn.Module): + """Defines a U-Net discriminator with spectral normalization (SN) + + It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. + + Arg: + num_in_ch (int): Channel number of inputs. Default: 3. + num_feat (int): Channel number of base intermediate features. Default: 64. + skip_connection (bool): Whether to use skip connections between U-Net. Default: True. + """ + + def __init__(self, num_in_ch, num_feat=64, skip_connection=True): + super(UNetDiscriminatorSN, self).__init__() + self.skip_connection = skip_connection + norm = spectral_norm + # the first convolution + self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1) + # downsample + self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False)) + self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False)) + self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False)) + # upsample + self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False)) + self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False)) + self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False)) + # extra convolutions + self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) + self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) + self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1) + + def forward(self, x): + # downsample + x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True) + x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True) + x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True) + x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True) + + # upsample + x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False) + x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True) + + if self.skip_connection: + x4 = x4 + x2 + x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False) + x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True) + + if self.skip_connection: + x5 = x5 + x1 + x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False) + x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True) + + if self.skip_connection: + x6 = x6 + x0 + + # extra convolutions + out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True) + out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True) + out = self.conv9(out) + + return out + + +class GANLoss(nn.Module): + """Define GAN loss. + + Args: + gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'. + real_label_val (float): The value for real label. Default: 1.0. + fake_label_val (float): The value for fake label. Default: 0.0. + loss_weight (float): Loss weight. Default: 1.0. + Note that loss_weight is only for generators; and it is always 1.0 + for discriminators. 
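+
+ Example (illustrative sketch only; `real_pred` and `fake_pred` stand for discriminator
+ outputs on real and generated images and are not defined in this file):
+ criterion = GANLoss('lsgan', loss_weight=0.01)
+ g_loss = criterion(fake_pred, True, is_disc=False)
+ d_loss = 0.5 * (criterion(real_pred, True, is_disc=True) + criterion(fake_pred, False, is_disc=True))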
+ """ + + def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0): + super(GANLoss, self).__init__() + self.gan_type = gan_type + self.loss_weight = loss_weight + self.real_label_val = real_label_val + self.fake_label_val = fake_label_val + + if self.gan_type == 'vanilla': + self.loss = nn.BCEWithLogitsLoss() + elif self.gan_type == 'lsgan': + self.loss = nn.MSELoss() + elif self.gan_type == 'wgan': + self.loss = self._wgan_loss + elif self.gan_type == 'wgan_softplus': + self.loss = self._wgan_softplus_loss + elif self.gan_type == 'hinge': + self.loss = nn.ReLU() + else: + raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.') + + def _wgan_loss(self, input, target): + """wgan loss. + + Args: + input (Tensor): Input tensor. + target (bool): Target label. + + Returns: + Tensor: wgan loss. + """ + return -input.mean() if target else input.mean() + + def _wgan_softplus_loss(self, input, target): + """wgan loss with soft plus. softplus is a smooth approximation to the + ReLU function. + + In StyleGAN2, it is called: + Logistic loss for discriminator; + Non-saturating loss for generator. + + Args: + input (Tensor): Input tensor. + target (bool): Target label. + + Returns: + Tensor: wgan loss. + """ + return F.softplus(-input).mean() if target else F.softplus(input).mean() + + def get_target_label(self, input, target_is_real): + """Get target label. + + Args: + input (Tensor): Input tensor. + target_is_real (bool): Whether the target is real or fake. + + Returns: + (bool | Tensor): Target tensor. Return bool for wgan, otherwise, + return Tensor. + """ + + if self.gan_type in ['wgan', 'wgan_softplus']: + return target_is_real + target_val = (self.real_label_val if target_is_real else self.fake_label_val) + return input.new_ones(input.size()) * target_val + + def forward(self, input, target_is_real, is_disc=False): + """ + Args: + input (Tensor): The input for the loss module, i.e., the network + prediction. + target_is_real (bool): Whether the targe is real or fake. + is_disc (bool): Whether the loss for discriminators or not. + Default: False. + + Returns: + Tensor: GAN loss value. + """ + target_label = self.get_target_label(input, target_is_real) + if self.gan_type == 'hinge': + if is_disc: # for discriminators in hinge-gan + input = -input if target_is_real else input + loss = self.loss(1 + input).mean() + else: # for generators in hinge-gan + loss = -input.mean() + else: # other gan types + loss = self.loss(input, target_label) + + # loss_weight is always 1.0 for discriminators + return loss if is_disc else loss * self.loss_weight \ No newline at end of file diff --git a/models/lr_scheduler.py b/models/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..19113a788ccc1e0818ad2c388d6c4601da102cdf --- /dev/null +++ b/models/lr_scheduler.py @@ -0,0 +1,142 @@ +import math +from collections import Counter +from collections import defaultdict +import torch +from torch.optim.lr_scheduler import _LRScheduler + + +class MultiStepLR_Restart(_LRScheduler): + def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1, + clear_state=False, last_epoch=-1): + self.milestones = Counter(milestones) + self.gamma = gamma + self.clear_state = clear_state + self.restarts = restarts if restarts else [0] + self.restart_weights = weights if weights else [1] + assert len(self.restarts) == len( + self.restart_weights), 'restarts and their weights do not match.' 
+ super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch) + + def get_lr(self): + if self.last_epoch in self.restarts: + if self.clear_state: + self.optimizer.state = defaultdict(dict) + weight = self.restart_weights[self.restarts.index(self.last_epoch)] + return [group['initial_lr'] * weight for group in self.optimizer.param_groups] + if self.last_epoch not in self.milestones: + return [group['lr'] for group in self.optimizer.param_groups] + return [ + group['lr'] * self.gamma**self.milestones[self.last_epoch] + for group in self.optimizer.param_groups + ] + + +class CosineAnnealingLR_Restart(_LRScheduler): + def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1): + self.T_period = T_period + self.T_max = self.T_period[0] # current T period + self.eta_min = eta_min + self.restarts = restarts if restarts else [0] + self.restart_weights = weights if weights else [1] + self.last_restart = 0 + assert len(self.restarts) == len( + self.restart_weights), 'restarts and their weights do not match.' + super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch) + + def get_lr(self): + if self.last_epoch == 0: + return self.base_lrs + elif self.last_epoch in self.restarts: + self.last_restart = self.last_epoch + self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1] + weight = self.restart_weights[self.restarts.index(self.last_epoch)] + return [group['initial_lr'] * weight for group in self.optimizer.param_groups] + elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0: + return [ + group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2 + for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + ] + return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) / + (1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) * + (group['lr'] - self.eta_min) + self.eta_min + for group in self.optimizer.param_groups] + + +if __name__ == "__main__": + optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0, + betas=(0.9, 0.99)) + ############################## + # MultiStepLR_Restart + ############################## + ## Original + lr_steps = [200000, 400000, 600000, 800000] + restarts = None + restart_weights = None + + ## two + lr_steps = [100000, 200000, 300000, 400000, 490000, 600000, 700000, 800000, 900000, 990000] + restarts = [500000] + restart_weights = [1] + + ## four + lr_steps = [ + 50000, 100000, 150000, 200000, 240000, 300000, 350000, 400000, 450000, 490000, 550000, + 600000, 650000, 700000, 740000, 800000, 850000, 900000, 950000, 990000 + ] + restarts = [250000, 500000, 750000] + restart_weights = [1, 1, 1] + + scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5, + clear_state=False) + + ############################## + # Cosine Annealing Restart + ############################## + ## two + T_period = [500000, 500000] + restarts = [500000] + restart_weights = [1] + + ## four + T_period = [250000, 250000, 250000, 250000] + restarts = [250000, 500000, 750000] + restart_weights = [1, 1, 1] + + scheduler = CosineAnnealingLR_Restart(optimizer, T_period, eta_min=1e-7, restarts=restarts, + weights=restart_weights) + + ############################## + # Draw figure + ############################## + N_iter = 1000000 + lr_l = list(range(N_iter)) + for i in range(N_iter): + scheduler.step() + current_lr = optimizer.param_groups[0]['lr'] + lr_l[i] = 
current_lr + + import matplotlib as mpl + from matplotlib import pyplot as plt + import matplotlib.ticker as mtick + mpl.style.use('default') + import seaborn + seaborn.set(style='whitegrid') + seaborn.set_context('paper') + + plt.figure(1) + plt.subplot(111) + plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0)) + plt.title('Title', fontsize=16, color='k') + plt.plot(list(range(N_iter)), lr_l, linewidth=1.5, label='learning rate scheme') + legend = plt.legend(loc='upper right', shadow=False) + ax = plt.gca() + labels = ax.get_xticks().tolist() + for k, v in enumerate(labels): + labels[k] = str(int(v / 1000)) + 'K' + ax.set_xticklabels(labels) + ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e')) + + ax.set_ylabel('Learning rate') + ax.set_xlabel('Iteration') + fig = plt.gcf() + plt.show() diff --git a/models/modules/Inv_arch.py b/models/modules/Inv_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..298c498374cb4e3ee3067b66d0b6055ec1fa9f6c --- /dev/null +++ b/models/modules/Inv_arch.py @@ -0,0 +1,584 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .module_util import initialize_weights_xavier +from torch.nn import init +from .common import DWT,IWT +import cv2 +from basicsr.archs.arch_util import flow_warp +from models.modules.Subnet_constructor import subnet +import numpy as np + +from pdb import set_trace as stx +import numbers + +from einops import rearrange +from models.bitnetwork.Encoder_U import DW_Encoder +from models.bitnetwork.Decoder_U import DW_Decoder + + +## Layer Norm +def to_3d(x): + return rearrange(x, 'b c h w -> b (h w) c') + + +def to_4d(x, h, w): + return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) + + +class BiasFree_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(BiasFree_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + sigma = x.var(-1, keepdim=True, unbiased=False) + return x / torch.sqrt(sigma + 1e-5) * self.weight + + +class WithBias_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(WithBias_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + mu = x.mean(-1, keepdim=True) + sigma = x.var(-1, keepdim=True, unbiased=False) + return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias + + +class LayerNorm(nn.Module): + def __init__(self, dim, LayerNorm_type): + super(LayerNorm, self).__init__() + if LayerNorm_type == 'BiasFree': + self.body = BiasFree_LayerNorm(dim) + else: + self.body = WithBias_LayerNorm(dim) + + def forward(self, x): + h, w = x.shape[-2:] + return to_4d(self.body(to_3d(x)), h, w) + + +########################################################################## +## Gated-Dconv Feed-Forward Network (GDFN) +class FeedForward(nn.Module): + def __init__(self, dim, ffn_expansion_factor, bias): + super(FeedForward, self).__init__() + + hidden_features = int(dim * 
ffn_expansion_factor) + + self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias) + + self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1, + groups=hidden_features * 2, bias=bias) + + self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) + + def forward(self, x): + x = self.project_in(x) + x1, x2 = self.dwconv(x).chunk(2, dim=1) + x = F.gelu(x1) * x2 + x = self.project_out(x) + return x + + +########################################################################## +## Multi-DConv Head Transposed Self-Attention (MDTA) +class Attention(nn.Module): + def __init__(self, dim, num_heads, bias): + super(Attention, self).__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + + self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + def forward(self, x): + b, c, h, w = x.shape + + qkv = self.qkv_dwconv(self.qkv(x)) + q, k, v = qkv.chunk(3, dim=1) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + + out = (attn @ v) + + out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) + + out = self.project_out(out) + return out + + +########################################################################## +class TransformerBlock(nn.Module): + def __init__(self, dim, num_heads=4, ffn_expansion_factor=4, bias=False, LayerNorm_type="withbias"): + super(TransformerBlock, self).__init__() + + self.norm1 = LayerNorm(dim, LayerNorm_type) + self.attn = Attention(dim, num_heads, bias) + self.norm2 = LayerNorm(dim, LayerNorm_type) + self.ffn = FeedForward(dim, ffn_expansion_factor, bias) + + def forward(self, x): + x = x + self.attn(self.norm1(x)) + x = x + self.ffn(self.norm2(x)) + + return x + +dwt=DWT() +iwt=IWT() + +class LayerNormFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, weight, bias, eps): + ctx.eps = eps + N, C, H, W = x.size() + mu = x.mean(1, keepdim=True) + var = (x - mu).pow(2).mean(1, keepdim=True) + y = (x - mu) / (var + eps).sqrt() + ctx.save_for_backward(y, var, weight) + y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1) + return y + + @staticmethod + def backward(ctx, grad_output): + eps = ctx.eps + + N, C, H, W = grad_output.size() + y, var, weight = ctx.saved_variables + g = grad_output * weight.view(1, C, 1, 1) + mean_g = g.mean(dim=1, keepdim=True) + + mean_gy = (g * y).mean(dim=1, keepdim=True) + gx = 1. 
/ torch.sqrt(var + eps) * (g - y * mean_gy - mean_g) + return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum( + dim=0), None + +class LayerNorm2d(nn.Module): + + def __init__(self, channels, eps=1e-6): + super(LayerNorm2d, self).__init__() + self.register_parameter('weight', nn.Parameter(torch.ones(channels))) + self.register_parameter('bias', nn.Parameter(torch.zeros(channels))) + self.eps = eps + + def forward(self, x): + return LayerNormFunction.apply(x, self.weight, self.bias, self.eps) + +class SimpleGate(nn.Module): + def forward(self, x): + x1, x2 = x.chunk(2, dim=1) + return x1 * x2 + +class NAFBlock(nn.Module): + def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.): + super().__init__() + dw_channel = c * DW_Expand + self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True) + self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel, + bias=True) + self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True) + + # Simplified Channel Attention + self.sca = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1, + groups=1, bias=True), + ) + + # SimpleGate + self.sg = SimpleGate() + + ffn_channel = FFN_Expand * c + self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True) + self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True) + + self.norm1 = LayerNorm2d(c) + self.norm2 = LayerNorm2d(c) + + self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity() + self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. 
else nn.Identity() + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + + def forward(self, inp): + x = inp + + x = self.norm1(x) + + x = self.conv1(x) + x = self.conv2(x) + x = self.sg(x) + x = x * self.sca(x) + x = self.conv3(x) + + x = self.dropout1(x) + + y = inp + x * self.beta + + x = self.conv4(self.norm2(y)) + x = self.sg(x) + x = self.conv5(x) + + x = self.dropout2(x) + + return y + x * self.gamma + +def thops_mean(tensor, dim=None, keepdim=False): + if dim is None: + # mean all dim + return torch.mean(tensor) + else: + if isinstance(dim, int): + dim = [dim] + dim = sorted(dim) + for d in dim: + tensor = tensor.mean(dim=d, keepdim=True) + if not keepdim: + for i, d in enumerate(dim): + tensor.squeeze_(d-i) + return tensor + + +class ResidualBlockNoBN(nn.Module): + def __init__(self, nf=64, model='MIMO-VRN'): + super(ResidualBlockNoBN, self).__init__() + self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + # honestly, there's no significant difference between ReLU and leaky ReLU in terms of performance here + # but this is how we trained the model in the first place and what we reported in the paper + if model == 'LSTM-VRN': + self.relu = nn.ReLU(inplace=True) + elif model == 'MIMO-VRN': + self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + # initialization + initialize_weights_xavier([self.conv1, self.conv2], 0.1) + + def forward(self, x): + identity = x + out = self.relu(self.conv1(x)) + out = self.conv2(out) + return identity + out + + +class InvBlock(nn.Module): + def __init__(self, subnet_constructor, subnet_constructor_v2, channel_num_ho, channel_num_hi, groups, clamp=1.): + super(InvBlock, self).__init__() + self.split_len1 = channel_num_ho # channel_split_num + self.split_len2 = channel_num_hi # channel_num - channel_split_num + self.clamp = clamp + + self.F = subnet_constructor_v2(self.split_len2, self.split_len1, groups=groups) + self.NF = NAFBlock(self.split_len2) + if groups == 1: + self.G = subnet_constructor(self.split_len1, self.split_len2, groups=groups) + self.NG = NAFBlock(self.split_len1) + self.H = subnet_constructor(self.split_len1, self.split_len2, groups=groups) + self.NH = NAFBlock(self.split_len1) + else: + self.G = subnet_constructor(self.split_len1, self.split_len2) + self.NG = NAFBlock(self.split_len1) + self.H = subnet_constructor(self.split_len1, self.split_len2) + self.NH = NAFBlock(self.split_len1) + + def forward(self, x1, x2, rev=False): + if not rev: + y1 = x1 + self.NF(self.F(x2)) + self.s = self.clamp * (torch.sigmoid(self.NH(self.H(y1))) * 2 - 1) + y2 = [x2i.mul(torch.exp(self.s)) + self.NG(self.G(y1)) for x2i in x2] + else: + self.s = self.clamp * (torch.sigmoid(self.NH(self.H(x1))) * 2 - 1) + y2 = [(x2i - self.NG(self.G(x1))).div(torch.exp(self.s)) for x2i in x2] + y1 = x1 - self.NF(self.F(y2)) + + return y1, y2 # torch.cat((y1, y2), 1) + + def jacobian(self, x, rev=False): + if not rev: + jac = torch.sum(self.s) + else: + jac = -torch.sum(self.s) + + return jac / x.shape[0] + +class InvNN(nn.Module): + def __init__(self, channel_in_ho=3, channel_in_hi=3, subnet_constructor=None, subnet_constructor_v2=None, block_num=[], down_num=2, groups=None): + super(InvNN, self).__init__() + operations = [] + + current_channel_ho = channel_in_ho + current_channel_hi = channel_in_hi + for i in range(down_num): + for j in range(block_num[i]): + b = InvBlock(subnet_constructor, 
subnet_constructor_v2, current_channel_ho, current_channel_hi, groups=groups) + operations.append(b) + + self.operations = nn.ModuleList(operations) + + def forward(self, x, x_h, rev=False, cal_jacobian=False): + # out = x + jacobian = 0 + + if not rev: + for op in self.operations: + x, x_h = op.forward(x, x_h, rev) + if cal_jacobian: + jacobian += op.jacobian(x, rev) + else: + for op in reversed(self.operations): + x, x_h = op.forward(x, x_h, rev) + if cal_jacobian: + jacobian += op.jacobian(x, rev) + + if cal_jacobian: + return x, x_h, jacobian + else: + return x, x_h + +class PredictiveModuleMIMO(nn.Module): + def __init__(self, channel_in, nf, block_num_rbm=8, block_num_trans=4): + super(PredictiveModuleMIMO, self).__init__() + self.conv_in = nn.Conv2d(channel_in, nf, 3, 1, 1, bias=True) + res_block = [] + trans_block = [] + for i in range(block_num_rbm): + res_block.append(ResidualBlockNoBN(nf)) + for j in range(block_num_trans): + trans_block.append(TransformerBlock(nf)) + + self.res_block = nn.Sequential(*res_block) + self.transformer_block = nn.Sequential(*trans_block) + + def forward(self, x): + x = self.conv_in(x) + x = self.res_block(x) + res = self.transformer_block(x) + x + + return res + +class ConvRelu(nn.Module): + def __init__(self, channels_in, channels_out, stride=1, init_zero=False): + super(ConvRelu, self).__init__() + self.init_zero = init_zero + if self.init_zero: + self.layers = nn.Conv2d(channels_in, channels_out, 3, stride, padding=1) + + else: + self.layers = nn.Sequential( + nn.Conv2d(channels_in, channels_out, 3, stride, padding=1), + nn.LeakyReLU(inplace=True) + ) + + def forward(self, x): + return self.layers(x) + +class PredictiveModuleBit(nn.Module): + def __init__(self, channel_in, nf, block_num_rbm=4, block_num_trans=2): + super(PredictiveModuleBit, self).__init__() + self.conv_in = nn.Conv2d(channel_in, nf, 3, 1, 1, bias=True) + res_block = [] + trans_block = [] + for i in range(block_num_rbm): + res_block.append(ResidualBlockNoBN(nf)) + for j in range(block_num_trans): + trans_block.append(TransformerBlock(nf)) + + blocks = 4 + layers = [ConvRelu(nf, 1, 2)] + for _ in range(blocks - 1): + layer = ConvRelu(1, 1, 2) + layers.append(layer) + self.layers = nn.Sequential(*layers) + + self.res_block = nn.Sequential(*res_block) + self.transformer_block = nn.Sequential(*trans_block) + + def forward(self, x): + x = self.conv_in(x) + x = self.res_block(x) + res = self.transformer_block(x) + x + res = self.layers(res) + + return res + + +##---------- Prompt Gen Module ----------------------- +class PromptGenBlock(nn.Module): + def __init__(self,prompt_dim=12,prompt_len=3,prompt_size = 36,lin_dim = 12): + super(PromptGenBlock,self).__init__() + self.prompt_param = nn.Parameter(torch.rand(1,prompt_len,prompt_dim,prompt_size,prompt_size)) + self.linear_layer = nn.Linear(lin_dim,prompt_len) + self.conv3x3 = nn.Conv2d(prompt_dim,prompt_dim,kernel_size=3,stride=1,padding=1,bias=False) + + + def forward(self,x): + B,C,H,W = x.shape + emb = x.mean(dim=(-2,-1)) + prompt_weights = F.softmax(self.linear_layer(emb),dim=1) + prompt = prompt_weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * self.prompt_param.unsqueeze(0).repeat(B,1,1,1,1,1).squeeze(1) + prompt = torch.sum(prompt,dim=1) + prompt = F.interpolate(prompt,(H,W),mode="bilinear") + prompt = self.conv3x3(prompt) + + return prompt + +class PredictiveModuleMIMO_prompt(nn.Module): + def __init__(self, channel_in, nf, prompt_len=3, block_num_rbm=8, block_num_trans=4): + super(PredictiveModuleMIMO_prompt, self).__init__() 
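+ # Prompt-conditioned variant of PredictiveModuleMIMO: residual and transformer blocks
+ # extract features, PromptGenBlock derives a spatial prompt from per-channel statistics,
+ # and a 3x3 conv fuses feature and prompt (see forward below).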
+ self.conv_in = nn.Conv2d(channel_in, nf, 3, 1, 1, bias=True) + res_block = [] + trans_block = [] + for i in range(block_num_rbm): + res_block.append(ResidualBlockNoBN(nf)) + for j in range(block_num_trans): + trans_block.append(TransformerBlock(nf)) + + self.res_block = nn.Sequential(*res_block) + self.transformer_block = nn.Sequential(*trans_block) + self.prompt = PromptGenBlock(prompt_dim=nf,prompt_len=prompt_len,prompt_size = 36,lin_dim = nf) + self.fuse = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) + + def forward(self, x): + x = self.conv_in(x) + x = self.res_block(x) + res = self.transformer_block(x) + x + prompt = self.prompt(res) + + result = self.fuse(torch.cat([res, prompt], dim=1)) + + return result + +def gauss_noise(shape): + noise = torch.zeros(shape).cuda() + for i in range(noise.shape[0]): + noise[i] = torch.randn(noise[i].shape).cuda() + + return noise + +def gauss_noise_mul(shape): + noise = torch.randn(shape).cuda() + + return noise + +class PredictiveModuleBit_prompt(nn.Module): + def __init__(self, channel_in, nf, prompt_length, block_num_rbm=4, block_num_trans=2): + super(PredictiveModuleBit_prompt, self).__init__() + self.conv_in = nn.Conv2d(channel_in, nf, 3, 1, 1, bias=True) + res_block = [] + trans_block = [] + for i in range(block_num_rbm): + res_block.append(ResidualBlockNoBN(nf)) + for j in range(block_num_trans): + trans_block.append(TransformerBlock(nf)) + + blocks = 4 + layers = [ConvRelu(nf, 1, 2)] + for _ in range(blocks - 1): + layer = ConvRelu(1, 1, 2) + layers.append(layer) + self.layers = nn.Sequential(*layers) + + self.res_block = nn.Sequential(*res_block) + self.transformer_block = nn.Sequential(*trans_block) + self.prompt = PromptGenBlock(prompt_dim=nf,prompt_len=prompt_length,prompt_size = 36,lin_dim = nf) + self.fuse = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) + + def forward(self, x): + x = self.conv_in(x) + x = self.res_block(x) + res = self.transformer_block(x) + x + prompt = self.prompt(res) + res = self.fuse(torch.cat([res, prompt], dim=1)) + res = self.layers(res) + + return res + +class VSN(nn.Module): + def __init__(self, opt, subnet_constructor=None, subnet_constructor_v2=None, down_num=2): + super(VSN, self).__init__() + self.model = opt['model'] + self.mode = opt['mode'] + opt_net = opt['network_G'] + self.num_image = opt['num_image'] + self.gop = opt['gop'] + self.channel_in = opt_net['in_nc'] * self.gop + self.channel_out = opt_net['out_nc'] * self.gop + self.channel_in_hi = opt_net['in_nc'] * self.gop + self.channel_in_ho = opt_net['in_nc'] * self.gop + self.message_len = opt['message_length'] + + self.block_num = opt_net['block_num'] + self.block_num_rbm = opt_net['block_num_rbm'] + self.block_num_trans = opt_net['block_num_trans'] + self.nf = self.channel_in_hi + + self.bitencoder = DW_Encoder(self.message_len, attention = "se") + self.bitdecoder = DW_Decoder(self.message_len, attention = "se") + self.irn = InvNN(self.channel_in_ho, self.channel_in_hi, subnet_constructor, subnet_constructor_v2, self.block_num, down_num, groups=self.num_image) + + if opt['prompt']: + self.pm = PredictiveModuleMIMO_prompt(self.channel_in_ho, self.nf* self.num_image, opt['prompt_len'], block_num_rbm=self.block_num_rbm, block_num_trans=self.block_num_trans) + else: + self.pm = PredictiveModuleMIMO(self.channel_in_ho, self.nf* self.num_image, opt['prompt_len'], block_num_rbm=self.block_num_rbm, block_num_trans=self.block_num_trans) + self.BitPM = PredictiveModuleBit(3, 4, block_num_rbm=4, block_num_trans=2) + + + def forward(self, x, x_h=None, 
message=None, rev=False, hs=[], direction='f'): + if not rev: + if self.mode == "image": + out_y, out_y_h = self.irn(x, x_h, rev) + out_y = iwt(out_y) + encoded_image = self.bitencoder(out_y, message) + return out_y, encoded_image + + elif self.mode == "bit": + out_y = iwt(x) + encoded_image = self.bitencoder(out_y, message) + return out_y, encoded_image + + else: + if self.mode == "image": + recmessage = self.bitdecoder(x) + + x = dwt(x) + out_z = self.pm(x).unsqueeze(1) + out_z_new = out_z.view(-1, self.num_image, self.channel_in, x.shape[-2], x.shape[-1]) + out_z_new = [out_z_new[:,i] for i in range(self.num_image)] + out_x, out_x_h = self.irn(x, out_z_new, rev) + + return out_x, out_x_h, out_z, recmessage + + elif self.mode == "bit": + recmessage = self.bitdecoder(x) + return recmessage + diff --git a/models/modules/Quantization.py b/models/modules/Quantization.py new file mode 100644 index 0000000000000000000000000000000000000000..7716b66bf96d75ff891b547c59be60c7df0e2853 --- /dev/null +++ b/models/modules/Quantization.py @@ -0,0 +1,21 @@ +import torch +import torch.nn as nn + +class Quant(torch.autograd.Function): + + @staticmethod + def forward(ctx, input): + input = torch.clamp(input, 0, 1) + output = (input * 255.).round() / 255. + return output + + @staticmethod + def backward(ctx, grad_output): + return grad_output + +class Quantization(nn.Module): + def __init__(self): + super(Quantization, self).__init__() + + def forward(self, input): + return Quant.apply(input) diff --git a/models/modules/Subnet_constructor.py b/models/modules/Subnet_constructor.py new file mode 100644 index 0000000000000000000000000000000000000000..2d3c42f2691117e5bf9abc09bfc0021a49ad1145 --- /dev/null +++ b/models/modules/Subnet_constructor.py @@ -0,0 +1,79 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import models.modules.module_util as mutil +from basicsr.archs.arch_util import flow_warp, ResidualBlockNoBN +from models.modules.module_util import initialize_weights_xavier + +class DenseBlock(nn.Module): + def __init__(self, channel_in, channel_out, init='xavier', gc=32, bias=True): + super(DenseBlock, self).__init__() + self.conv1 = nn.Conv2d(channel_in, gc, 3, 1, 1, bias=bias) + self.conv2 = nn.Conv2d(channel_in + gc, gc, 3, 1, 1, bias=bias) + self.conv3 = nn.Conv2d(channel_in + 2 * gc, gc, 3, 1, 1, bias=bias) + self.conv4 = nn.Conv2d(channel_in + 3 * gc, gc, 3, 1, 1, bias=bias) + self.conv5 = nn.Conv2d(channel_in + 4 * gc, channel_out, 3, 1, 1, bias=bias) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + self.H = None + + if init == 'xavier': + mutil.initialize_weights_xavier([self.conv1, self.conv2, self.conv3, self.conv4], 0.1) + else: + mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4], 0.1) + mutil.initialize_weights(self.conv5, 0) + + def forward(self, x): + if isinstance(x, list): + x = x[0] + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + + return x5 + +class DenseBlock_v2(nn.Module): + def __init__(self, channel_in, channel_out, groups, init='xavier', gc=32, bias=True): + super(DenseBlock_v2, self).__init__() + self.conv1 = nn.Conv2d(channel_in, gc, 3, 1, 1, bias=bias) + self.conv2 = nn.Conv2d(channel_in + gc, gc, 3, 1, 1, bias=bias) + self.conv3 = nn.Conv2d(channel_in + 2 * gc, gc, 3, 1, 1, bias=bias) + self.conv4 = 
nn.Conv2d(channel_in + 3 * gc, gc, 3, 1, 1, bias=bias) + self.conv5 = nn.Conv2d(channel_in + 4 * gc, channel_out, 3, 1, 1, bias=bias) + self.conv_final = nn.Conv2d(channel_out*groups, channel_out, 3, 1, 1, bias=bias) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + if init == 'xavier': + mutil.initialize_weights_xavier([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) + else: + mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) + mutil.initialize_weights(self.conv_final, 0) + + def forward(self, x): + res = [] + for xi in x: + x1 = self.lrelu(self.conv1(xi)) + x2 = self.lrelu(self.conv2(torch.cat((xi, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((xi, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((xi, x1, x2, x3), 1))) + x5 = self.lrelu(self.conv5(torch.cat((xi, x1, x2, x3, x4), 1))) + res.append(x5) + res = torch.cat(res, dim=1) + res = self.conv_final(res) + + return res + +def subnet(net_structure, init='xavier'): + def constructor(channel_in, channel_out, groups=None): + if net_structure == 'DBNet': + if init == 'xavier': + return DenseBlock(channel_in, channel_out, init) + elif init == 'xavier_v2': + return DenseBlock_v2(channel_in, channel_out, groups, 'xavier') + else: + return DenseBlock(channel_in, channel_out) + else: + return None + + return constructor diff --git a/models/modules/__init__.py b/models/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/modules/__pycache__/Conv1x1.cpython-38.pyc b/models/modules/__pycache__/Conv1x1.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..301adf5e98568b684287989c173fdac1bbcd11b5 Binary files /dev/null and b/models/modules/__pycache__/Conv1x1.cpython-38.pyc differ diff --git a/models/modules/__pycache__/DEM.cpython-38.pyc b/models/modules/__pycache__/DEM.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed582880f9cbd13c1bf7a30828524977a66a8f1c Binary files /dev/null and b/models/modules/__pycache__/DEM.cpython-38.pyc differ diff --git a/models/modules/__pycache__/DenseBlock.cpython-38.pyc b/models/modules/__pycache__/DenseBlock.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ed00d443cb51506f28b54b3a31bb3be994935cc Binary files /dev/null and b/models/modules/__pycache__/DenseBlock.cpython-38.pyc differ diff --git a/models/modules/__pycache__/FSM.cpython-38.pyc b/models/modules/__pycache__/FSM.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5339a32100732cf346c711e4780ffa636cec7421 Binary files /dev/null and b/models/modules/__pycache__/FSM.cpython-38.pyc differ diff --git a/models/modules/__pycache__/IM.cpython-38.pyc b/models/modules/__pycache__/IM.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96b2ee1660fb0e86e52c13c85b08556ec52c6ebe Binary files /dev/null and b/models/modules/__pycache__/IM.cpython-38.pyc differ diff --git a/models/modules/__pycache__/Inn.cpython-38.pyc b/models/modules/__pycache__/Inn.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1144c3ab2dbf8bf4df42423bd3ca078288767669 Binary files /dev/null and b/models/modules/__pycache__/Inn.cpython-38.pyc differ diff --git a/models/modules/__pycache__/InvArch.cpython-38.pyc b/models/modules/__pycache__/InvArch.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..a3706f61c8a74a218fbcd912e40d03b5a7993f5c Binary files /dev/null and b/models/modules/__pycache__/InvArch.cpython-38.pyc differ diff --git a/models/modules/__pycache__/InvDownscaling.cpython-38.pyc b/models/modules/__pycache__/InvDownscaling.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da12c89edc1dfd5eecfd5da19f15d9542698424e Binary files /dev/null and b/models/modules/__pycache__/InvDownscaling.cpython-38.pyc differ diff --git a/models/modules/__pycache__/Inv_arch.cpython-310.pyc b/models/modules/__pycache__/Inv_arch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0a352013bb8305d240ca60b1d73588db10d3efb Binary files /dev/null and b/models/modules/__pycache__/Inv_arch.cpython-310.pyc differ diff --git a/models/modules/__pycache__/Inv_arch.cpython-38.pyc b/models/modules/__pycache__/Inv_arch.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b124b06329a6697abb4c782092a9ac20ffd0c6a Binary files /dev/null and b/models/modules/__pycache__/Inv_arch.cpython-38.pyc differ diff --git a/models/modules/__pycache__/Quantization.cpython-38.pyc b/models/modules/__pycache__/Quantization.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1781dea64b44b0276be748df487ea314d39da20 Binary files /dev/null and b/models/modules/__pycache__/Quantization.cpython-38.pyc differ diff --git a/models/modules/__pycache__/Subnet_constructor.cpython-38.pyc b/models/modules/__pycache__/Subnet_constructor.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41855950fd1755484bfefd3c2786be6f81acdd12 Binary files /dev/null and b/models/modules/__pycache__/Subnet_constructor.cpython-38.pyc differ diff --git a/models/modules/__pycache__/__init__.cpython-310.pyc b/models/modules/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c7c4c3089bafb8721ced482c0eeb9a3930c1d593 Binary files /dev/null and b/models/modules/__pycache__/__init__.cpython-310.pyc differ diff --git a/models/modules/__pycache__/__init__.cpython-38.pyc b/models/modules/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..550804dd96f53175a431ba5ef71893ad11092a91 Binary files /dev/null and b/models/modules/__pycache__/__init__.cpython-38.pyc differ diff --git a/models/modules/__pycache__/common.cpython-310.pyc b/models/modules/__pycache__/common.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55244bef245f6afe91c49dec5b6a3d674021576c Binary files /dev/null and b/models/modules/__pycache__/common.cpython-310.pyc differ diff --git a/models/modules/__pycache__/common.cpython-38.pyc b/models/modules/__pycache__/common.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fcca2406c21887b042e778c267e0e9463e3cfa22 Binary files /dev/null and b/models/modules/__pycache__/common.cpython-38.pyc differ diff --git a/models/modules/__pycache__/gaussianblur.cpython-38.pyc b/models/modules/__pycache__/gaussianblur.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0aade6dbd26c4d1302ebab410c9252f3f93efbe9 Binary files /dev/null and b/models/modules/__pycache__/gaussianblur.cpython-38.pyc differ diff --git a/models/modules/__pycache__/loss.cpython-38.pyc b/models/modules/__pycache__/loss.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..6a51ffc1422b10a57e44775d81996d2b6ca0bdd1 Binary files /dev/null and b/models/modules/__pycache__/loss.cpython-38.pyc differ diff --git a/models/modules/__pycache__/module_util.cpython-310.pyc b/models/modules/__pycache__/module_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..051d0fff3b04fb633a88c38592cd526d61e9d9fc Binary files /dev/null and b/models/modules/__pycache__/module_util.cpython-310.pyc differ diff --git a/models/modules/__pycache__/module_util.cpython-38.pyc b/models/modules/__pycache__/module_util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5bec86bdf3e46a3b9f4802cd82d3f8a214845cd0 Binary files /dev/null and b/models/modules/__pycache__/module_util.cpython-38.pyc differ diff --git a/models/modules/common.py b/models/modules/common.py new file mode 100644 index 0000000000000000000000000000000000000000..3ec2258874c378f85951f7674e6f730107fc47b4 --- /dev/null +++ b/models/modules/common.py @@ -0,0 +1,79 @@ +import math + +import torch +import torch.nn as nn + +def dwt_init3d(x): + + x01 = x[:, :, :, 0::2, :] / 2 + x02 = x[:, :, :, 1::2, :] / 2 + x1 = x01[:, :, :, :, 0::2] + x2 = x02[:, :, :, :, 0::2] + x3 = x01[:, :, :, :, 1::2] + x4 = x02[:, :, :, :, 1::2] + x_LL = x1 + x2 + x3 + x4 + x_HL = -x1 - x2 + x3 + x4 + x_LH = -x1 + x2 - x3 + x4 + x_HH = x1 - x2 - x3 + x4 + + return torch.cat((x_LL, x_HL, x_LH, x_HH), 1) + +def dwt_init(x): + + x01 = x[:, :, 0::2, :] / 2 + x02 = x[:, :, 1::2, :] / 2 + x1 = x01[:, :, :, 0::2] + x2 = x02[:, :, :, 0::2] + x3 = x01[:, :, :, 1::2] + x4 = x02[:, :, :, 1::2] + x_LL = x1 + x2 + x3 + x4 + x_HL = -x1 - x2 + x3 + x4 + x_LH = -x1 + x2 - x3 + x4 + x_HH = x1 - x2 - x3 + x4 + + return torch.cat((x_LL, x_HL, x_LH, x_HH), 1) + +def iwt_init(x): + r = 2 + in_batch, in_channel, in_height, in_width = x.size() + #print([in_batch, in_channel, in_height, in_width]) + out_batch, out_channel, out_height, out_width = in_batch, int( + in_channel / (r ** 2)), r * in_height, r * in_width + x1 = x[:, 0:out_channel, :, :] / 2 + x2 = x[:, out_channel:out_channel * 2, :, :] / 2 + x3 = x[:, out_channel * 2:out_channel * 3, :, :] / 2 + x4 = x[:, out_channel * 3:out_channel * 4, :, :] / 2 + + + h = torch.zeros([out_batch, out_channel, out_height, out_width]).float().cuda() + + h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4 + h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4 + h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4 + h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4 + + return h + +class DWT(nn.Module): + def __init__(self): + super(DWT, self).__init__() + self.requires_grad = False + + def forward(self, x): + return dwt_init(x) + +class DWT3d(nn.Module): + def __init__(self): + super(DWT3d, self).__init__() + self.requires_grad = False + + def forward(self, x): + return dwt_init3d(x) + +class IWT(nn.Module): + def __init__(self): + super(IWT, self).__init__() + self.requires_grad = False + + def forward(self, x): + return iwt_init(x) \ No newline at end of file diff --git a/models/modules/loss.py b/models/modules/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..7485438b56e342daca739e354001658aba4b3773 --- /dev/null +++ b/models/modules/loss.py @@ -0,0 +1,102 @@ +import torch +import torch.nn as nn +import numpy as np + + +class ReconstructionLoss(nn.Module): + def __init__(self, losstype='l2', eps=1e-6): + super(ReconstructionLoss, self).__init__() + self.losstype = losstype + self.eps = eps + + def forward(self, x, target): + if self.losstype == 
'l2': + return torch.mean(torch.sum((x - target) ** 2, (1, 2, 3))) + elif self.losstype == 'l1': + diff = x - target + return torch.mean(torch.sum(torch.sqrt(diff * diff + self.eps), (1, 2, 3))) + elif self.losstype == 'center': + return torch.sum((x - target) ** 2, (1, 2, 3)) + + else: + print("reconstruction loss type error!") + return 0 + + +# Define GAN loss: [vanilla | lsgan | wgan-gp] +class GANLoss(nn.Module): + def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0): + super(GANLoss, self).__init__() + self.gan_type = gan_type.lower() + self.real_label_val = real_label_val + self.fake_label_val = fake_label_val + + if self.gan_type == 'gan' or self.gan_type == 'ragan': + self.loss = nn.BCEWithLogitsLoss() + elif self.gan_type == 'lsgan': + self.loss = nn.MSELoss() + elif self.gan_type == 'wgan-gp': + + def wgan_loss(input, target): + # target is boolean + return -1 * input.mean() if target else input.mean() + + self.loss = wgan_loss + else: + raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type)) + + def get_target_label(self, input, target_is_real): + if self.gan_type == 'wgan-gp': + return target_is_real + if target_is_real: + return torch.empty_like(input).fill_(self.real_label_val) + else: + return torch.empty_like(input).fill_(self.fake_label_val) + + def forward(self, input, target_is_real): + target_label = self.get_target_label(input, target_is_real) + loss = self.loss(input, target_label) + return loss + + +class GradientPenaltyLoss(nn.Module): + def __init__(self, device=torch.device('cpu')): + super(GradientPenaltyLoss, self).__init__() + self.register_buffer('grad_outputs', torch.Tensor()) + self.grad_outputs = self.grad_outputs.to(device) + + def get_grad_outputs(self, input): + if self.grad_outputs.size() != input.size(): + self.grad_outputs.resize_(input.size()).fill_(1.0) + return self.grad_outputs + + def forward(self, interp, interp_crit): + grad_outputs = self.get_grad_outputs(interp_crit) + grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp, + grad_outputs=grad_outputs, create_graph=True, + retain_graph=True, only_inputs=True)[0] + grad_interp = grad_interp.view(grad_interp.size(0), -1) + grad_interp_norm = grad_interp.norm(2, dim=1) + + loss = ((grad_interp_norm - 1) ** 2).mean() + return loss + + +class ReconstructionMsgLoss(nn.Module): + def __init__(self, losstype='mse'): + super(ReconstructionMsgLoss, self).__init__() + self.losstype = losstype + self.mse_loss = nn.MSELoss() + self.bce_loss = nn.BCELoss() + self.bce_logits_loss = nn.BCEWithLogitsLoss() + + def forward(self, messages, decoded_messages): + if self.losstype == 'mse': + return self.mse_loss(messages, decoded_messages) + elif self.losstype == 'bce': + return self.bce_loss(messages, decoded_messages) + elif self.losstype == 'bce_logits': + return self.bce_logits_loss(messages, decoded_messages) + else: + print("ReconstructionMsgLoss loss type error!") + return 0 diff --git a/models/modules/module_util.py b/models/modules/module_util.py new file mode 100644 index 0000000000000000000000000000000000000000..52b69b6250da1c413b26fc63a865052e6f8a9452 --- /dev/null +++ b/models/modules/module_util.py @@ -0,0 +1,99 @@ +import torch +import torch.nn as nn +import torch.nn.init as init +import torch.nn.functional as F + + +def initialize_weights(net_l, scale=1): + if not isinstance(net_l, list): + net_l = [net_l] + for net in net_l: + for m in net.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight, a=0, mode='fan_in') + 
m.weight.data *= scale # for residual block + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + init.kaiming_normal_(m.weight, a=0, mode='fan_in') + m.weight.data *= scale + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + init.constant_(m.weight, 1) + init.constant_(m.bias.data, 0.0) + + +def initialize_weights_xavier(net_l, scale=1): + if not isinstance(net_l, list): + net_l = [net_l] + for net in net_l: + for m in net.modules(): + if isinstance(m, nn.Conv2d): + init.xavier_normal_(m.weight) + m.weight.data *= scale # for residual block + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + init.xavier_normal_(m.weight) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + init.constant_(m.weight, 1) + init.constant_(m.bias.data, 0.0) + + +def make_layer(block, n_layers): + layers = [] + for _ in range(n_layers): + layers.append(block()) + return nn.Sequential(*layers) + + +class ResidualBlock_noBN(nn.Module): + '''Residual block w/o BN + ---Conv-ReLU-Conv-+- + |________________| + ''' + + def __init__(self, nf=64): + super(ResidualBlock_noBN, self).__init__() + self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + + # initialization + initialize_weights([self.conv1, self.conv2], 0.1) + + def forward(self, x): + identity = x + out = F.relu(self.conv1(x), inplace=True) + out = self.conv2(out) + return identity + out + + +def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'): + """Warp an image or feature map with optical flow + Args: + x (Tensor): size (N, C, H, W) + flow (Tensor): size (N, H, W, 2), normal value + interp_mode (str): 'nearest' or 'bilinear' + padding_mode (str): 'zeros' or 'border' or 'reflection' + + Returns: + Tensor: warped image or feature map + """ + assert x.size()[-2:] == flow.size()[1:3] + B, C, H, W = x.size() + # mesh grid + grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W)) + grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 + grid.requires_grad = False + grid = grid.type_as(x) + vgrid = grid + flow + # scale grid to [-1,1] + vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0 + vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0 + vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) + output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode) + return output diff --git a/models/networks.py b/models/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..4f62ed98c34ec23ec6b530fcfc9d172f3118e07e --- /dev/null +++ b/models/networks.py @@ -0,0 +1,23 @@ +import logging +import math + +from models.modules.Inv_arch import * +from models.modules.Subnet_constructor import subnet + +logger = logging.getLogger('base') + +#################### +# define network +#################### +def define_G_v2(opt): + opt_net = opt['network_G'] + which_model = opt_net['which_model_G'] + subnet_type = which_model['subnet_type'] + opt_datasets = opt['datasets'] + down_num = int(math.log(opt_net['scale'], 2)) + if opt['num_image'] == 1: + netG = VSN(opt, subnet(subnet_type, 'xavier'), subnet(subnet_type, 'xavier'), down_num) + else: + netG = VSN(opt, subnet(subnet_type, 'xavier'), subnet(subnet_type, 'xavier_v2'), down_num) + + return netG diff --git a/options/__init__.py b/options/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/options/__pycache__/__init__.cpython-310.pyc b/options/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cfbef1262de181d18767710d80d5fd82cdcba506 Binary files /dev/null and b/options/__pycache__/__init__.cpython-310.pyc differ diff --git a/options/__pycache__/__init__.cpython-38.pyc b/options/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9208aa7f729a956e74f6cae5acc8646700f0230e Binary files /dev/null and b/options/__pycache__/__init__.cpython-38.pyc differ diff --git a/options/__pycache__/options.cpython-310.pyc b/options/__pycache__/options.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..495245cd0e4b9f52d281e04527a94c717affd1cc Binary files /dev/null and b/options/__pycache__/options.cpython-310.pyc differ diff --git a/options/__pycache__/options.cpython-38.pyc b/options/__pycache__/options.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e754600cca106bcb37fb7c8a67e53125f4cd794 Binary files /dev/null and b/options/__pycache__/options.cpython-38.pyc differ diff --git a/options/options.py b/options/options.py new file mode 100644 index 0000000000000000000000000000000000000000..dc359c63728cb2db38563901fe8cd52fb068adaa --- /dev/null +++ b/options/options.py @@ -0,0 +1,114 @@ +import os +import os.path as osp +import logging +import yaml +from utils.util import OrderedYaml +Loader, Dumper = OrderedYaml() + + +def parse(opt_path, is_train=True): + with open(opt_path, mode='r') as f: + opt = yaml.load(f, Loader=Loader) + # export CUDA_VISIBLE_DEVICES + gpu_list = ','.join(str(x) for x in opt['gpu_ids']) + os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list + print('export CUDA_VISIBLE_DEVICES=' + gpu_list) + + opt['is_train'] = is_train + if opt['distortion'] == 'sr': + scale = opt['scale'] + + # datasets + for phase, dataset in opt['datasets'].items(): + phase = phase.split('_')[0] + dataset['phase'] = phase + if opt['distortion'] == 'sr': + dataset['scale'] = scale + is_lmdb = False + if dataset.get('dataroot_GT', None) is not None: + dataset['dataroot_GT'] = osp.expanduser(dataset['dataroot_GT']) + if dataset['dataroot_GT'].endswith('lmdb'): + is_lmdb = True + # if dataset.get('dataroot_GT_bg', None) is not None: + # dataset['dataroot_GT_bg'] = osp.expanduser(dataset['dataroot_GT_bg']) + if dataset.get('dataroot_LQ', None) is not None: + dataset['dataroot_LQ'] = osp.expanduser(dataset['dataroot_LQ']) + if dataset['dataroot_LQ'].endswith('lmdb'): + is_lmdb = True + dataset['data_type'] = 'lmdb' if is_lmdb else 'img' + if dataset['mode'].endswith('mc'): # for memcached + dataset['data_type'] = 'mc' + dataset['mode'] = dataset['mode'].replace('_mc', '') + + # path + for key, path in opt['path'].items(): + if path and key in opt['path'] and key != 'strict_load': + opt['path'][key] = osp.expanduser(path) + opt['path']['root'] = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir, osp.pardir)) + if is_train: + experiments_root = osp.join(opt['path']['root'], 'experiments', opt['name']) + opt['path']['experiments_root'] = experiments_root + opt['path']['models'] = osp.join(experiments_root, 'models') + opt['path']['training_state'] = osp.join(experiments_root, 'training_state') + opt['path']['log'] = experiments_root + opt['path']['val_images'] = osp.join(experiments_root, 'val_images') + + # change some options for debug 
mode + if 'debug' in opt['name']: + opt['train']['val_freq'] = 8 + opt['logger']['print_freq'] = 1 + opt['logger']['save_checkpoint_freq'] = 8 + else: # test + results_root = osp.join(opt['path']['root'], 'results', opt['name']) + opt['path']['results_root'] = results_root + opt['path']['log'] = results_root + + # network + if opt['distortion'] == 'sr': + opt['network_G']['scale'] = scale + + return opt + + +def dict2str(opt, indent_l=1): + '''dict to string for logger''' + msg = '' + for k, v in opt.items(): + if isinstance(v, dict): + msg += ' ' * (indent_l * 2) + k + ':[\n' + msg += dict2str(v, indent_l + 1) + msg += ' ' * (indent_l * 2) + ']\n' + else: + msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n' + return msg + + +class NoneDict(dict): + def __missing__(self, key): + return None + + +# convert to NoneDict, which return None for missing key. +def dict_to_nonedict(opt): + if isinstance(opt, dict): + new_opt = dict() + for key, sub_opt in opt.items(): + new_opt[key] = dict_to_nonedict(sub_opt) + return NoneDict(**new_opt) + elif isinstance(opt, list): + return [dict_to_nonedict(sub_opt) for sub_opt in opt] + else: + return opt + + +def check_resume(opt, resume_iter): + '''Check resume states and pretrain_model paths''' + logger = logging.getLogger('base') + if opt['path']['resume_state']: + if opt['path'].get('pretrain_model_G', None) is not None or opt['path'].get( + 'pretrain_model_D', None) is not None: + logger.warning('pretrain_model path will be ignored when resuming training.') + + opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'], + '{}_G.pth'.format(resume_iter)) + logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G']) diff --git a/options/test_editguard.yml b/options/test_editguard.yml new file mode 100644 index 0000000000000000000000000000000000000000..90d7513a6d2c413322428fe6d66c4c93cf6248bb --- /dev/null +++ b/options/test_editguard.yml @@ -0,0 +1,109 @@ +#### general settings + +name: test_age-set +use_tb_logger: true +model: MIMO-VRN-h +distortion: sr +scale: 4 +gpu_ids: [0] +gop: 1 +num_image: 1 + +addnoise: False +noisesigma: 1 + +addjpeg: False +jpegfactor: 70 +addpossion: False +sdinpaint: True +controlnetinpaint: False +sdxl: False +repaint: False + +hide: True +hidebit: True +degrade_shuffle: True +prompt: True +prompt_len: 3 +message_length: 64 +bitrecord: False + +mode: image + +#### datasets + +datasets: + TD: + num_image: 1 + name: AGE-Set + mode: td + + data_path: ../dataset/valAGE-Set + txt_path: ../dataset/sep_testlist.txt + + N_frames: 1 + padding: 'new_info' + pred_interval: -1 + + N_frames: 1 + padding: 'new_info' + pred_interval: -1 + + +#### network structures + +network_G: + which_model_G: + subnet_type: DBNet + in_nc: 12 + out_nc: 12 + block_num: [6, 6] + scale: 2 + init: xavier_group + block_num_rbm: 8 + block_num_trans: 4 + + +#### path + +path: + pretrain_model_G: + models: ckp/base + strict_load: true + resume_state: ~ + + +#### training settings: learning rate scheme, loss + +train: + + lr_G: !!float 1e-4 + beta1: 0.9 + beta2: 0.5 + niter: 250000 + warmup_iter: -1 # no warm up + + lr_scheme: MultiStepLR + lr_steps: [30000, 60000, 90000, 150000, 180000, 210000] + lr_gamma: 0.5 + + pixel_criterion_forw: l2 + pixel_criterion_back: l1 + + manual_seed: 10 + + val_freq: !!float 1000 #!!float 5e3 + + lambda_fit_forw: 64. 
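+  # lambda_fit_forw (above) weights the forward reconstruction term; lambda_rec_back and
+  # lambda_center (below) weight the backward/recovery terms, paired with pixel_criterion_forw/back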
+ lambda_rec_back: 1 + lambda_center: 0 + + weight_decay_G: !!float 1e-12 + gradient_clipping: 10 + + +#### logger + +logger: + print_freq: 100 + save_checkpoint_freq: !!float 5e3 diff --git a/options/train_editguard_bit.yml b/options/train_editguard_bit.yml new file mode 100644 index 0000000000000000000000000000000000000000..7cb0d1c31b3c012a386f50cd8afde55b9713de1a --- /dev/null +++ b/options/train_editguard_bit.yml @@ -0,0 +1,132 @@ +#### general settings + +name: train_ibsn_bit_64 +use_tb_logger: true +model: MIMO-VRN-h +distortion: sr +scale: 4 +gpu_ids: [0, 1] +gop: 1 +num_image: 1 + +addnoise: False +noisesigma: 0.05 + +addjpeg: False +jpegfactor: 90 +addpossion: False +sdinpaint: False +controlnetinpaint: False +sdxl: False +repaint: False +sdprompt: False +sdxlprompt: False +faceswap: False + +hide: True +bithide: True +degrade_shuffle: True +prompt: True +prompt_len: 3 +message_length: 64 + +losstype: mse + +mode: bit + +#### datasets + +datasets: + train: + name: Vimeo90K + mode: train + interval_list: [1] + random_reverse: false + border_mode: false + data_path: /userhome/train2017 + txt_path: /userhome/train2017.txt + dataroot_LQ: ~/vimeo90k/vimeo90k_train_LR7frames.lmdb + cache_keys: Vimeo90K_train_keys.pkl + num_image: 1 + + N_frames: 7 + use_shuffle: true + n_workers: 24 + batch_size: 4 + GT_size: 400 + LQ_size: 36 + use_flip: true + use_rot: true + color: RGB + + val: + num_image: 1 + name: Vid4 + mode: test + data_path: ../dataset/valAGE-Set + txt_path: ../dataset/sep_vallist.txt + + N_frames: 1 + padding: 'new_info' + pred_interval: -1 + + +#### network structures + +network_G: + which_model_G: + subnet_type: DBNet + in_nc: 12 + out_nc: 12 + block_num: [6, 6] + scale: 2 + init: xavier_group + block_num_rbm: 8 + block_num_trans: 4 + +#### path + +path: + pretrain_model_G: + models: ckp/base + strict_load: true + resume_state: ~ + + +#### training settings: learning rate scheme, loss + +train: + + lr_G: !!float 1e-4 + beta1: 0.9 + beta2: 0.5 + niter: 250000 + warmup_iter: -1 # no warm up + + lr_scheme: MultiStepLR + lr_steps: [30000, 100000, 250000] + lr_gamma: 0.5 + + pixel_criterion_forw: l2 + pixel_criterion_back: l1 + + manual_seed: 10 + + val_freq: !!float 500 #!!float 5e3 + + lambda_fit_forw: 1. 
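+  # in this bit-hiding configuration the message loss weight (lambda_msg below, 100) sits two
+  # orders of magnitude above the pixel terms (lambda_fit_forw / lambda_rec_back at 1)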
+ lambda_rec_back: 1 + lambda_center: 0 + lambda_msg: !!float 100 # 500000 + + progressive: False + + weight_decay_G: !!float 1e-12 + gradient_clipping: 10 + + +#### logger + +logger: + print_freq: 100 + save_checkpoint_freq: !!float 500 diff --git a/options/train_editguard_image.yml b/options/train_editguard_image.yml new file mode 100644 index 0000000000000000000000000000000000000000..3d42346b4023c5404abd91b6373f5ea3683551eb --- /dev/null +++ b/options/train_editguard_image.yml @@ -0,0 +1,126 @@ +#### general settings + +name: train_editguard +use_tb_logger: true +model: MIMO-VRN-h +distortion: sr +scale: 4 +gpu_ids: [0, 1, 2, 3] +gop: 1 +num_image: 1 + +addnoise: False +noisesigma: 10 + +addjpeg: False +jpegfactor: 90 +addpossion: False +sdinpaint: False +controlnetinpaint: False +sdxl: False +repaint: False + +hide: True +bithide: False +degrade_shuffle: False +prompt: True +prompt_len: 3 +message_length: 64 + +losstype: mse + +mode: image + +#### datasets + +datasets: + train: + name: CoCo + mode: train + interval_list: [1] + random_reverse: false + border_mode: false + data_path: /userhome/train2017 + txt_path: /userhome/train2017.txt + dataroot_LQ: ~/vimeo90k/vimeo90k_train_LR7frames.lmdb + cache_keys: Vimeo90K_train_keys.pkl + num_image: 1 + + N_frames: 7 + use_shuffle: true + n_workers: 24 + batch_size: 4 + GT_size: 400 + LQ_size: 36 + use_flip: true + use_rot: true + color: RGB + + val: + num_image: 1 + name: CoCo + mode: test + data_path: ../dataset/valAGE-Set + txt_path: ../dataset/sep_vallist.txt + + N_frames: 1 + padding: 'new_info' + pred_interval: -1 + + +#### network structures + +network_G: + which_model_G: + subnet_type: DBNet + in_nc: 12 + out_nc: 12 + block_num: [6, 6] + scale: 2 + init: xavier_group + block_num_rbm: 8 + block_num_trans: 4 + +#### path + +path: + pretrain_model_G: + models: ckp/base + strict_load: False + resume_state: ~ + + +#### training settings: learning rate scheme, loss + +train: + + lr_G: !!float 1e-4 + beta1: 0.9 + beta2: 0.5 + niter: 250000 + warmup_iter: -1 # no warm up + + lr_scheme: MultiStepLR + lr_steps: [30000, 60000, 90000, 150000, 180000, 210000] + lr_gamma: 0.5 + + pixel_criterion_forw: l2 + pixel_criterion_back: l1 + + manual_seed: 10 + + val_freq: !!float 100 #!!float 5e3 + + lambda_fit_forw: 100 + lambda_rec_back: 1 + lambda_center: 0 + + weight_decay_G: !!float 1e-12 + gradient_clipping: 10 + + +#### logger + +logger: + print_freq: 100 + save_checkpoint_freq: !!float 500 diff --git a/utils/JPEG.py b/utils/JPEG.py new file mode 100644 index 0000000000000000000000000000000000000000..8997ee98a41668b4737a9b2acc2341032f173bd3 --- /dev/null +++ b/utils/JPEG.py @@ -0,0 +1,43 @@ + + +import torch +import torch.nn as nn + +from .JPEG_utils import diff_round, quality_to_factor, Quantization +from .compression import compress_jpeg +from .decompression import decompress_jpeg + + +class DiffJPEG(nn.Module): + def __init__(self, differentiable=True, quality=75): + ''' Initialize the DiffJPEG layer + Inputs: + height(int): Original image height + width(int): Original image width + differentiable(bool): If true uses custom differentiable + rounding function, if false uses standrard torch.round + quality(float): Quality factor for jpeg compression scheme. 
+ ''' + super(DiffJPEG, self).__init__() + if differentiable: + rounding = diff_round + # rounding = Quantization() + else: + rounding = torch.round + factor = quality_to_factor(quality) + self.compress = compress_jpeg(rounding=rounding, factor=factor) + # self.decompress = decompress_jpeg(height, width, rounding=rounding, + # factor=factor) + self.decompress = decompress_jpeg(rounding=rounding, factor=factor) + + def forward(self, x): + ''' + ''' + org_height = x.shape[2] + org_width = x.shape[3] + y, cb, cr = self.compress(x) + + recovered = self.decompress(y, cb, cr, org_height, org_width) + return recovered + + diff --git a/utils/JPEG_utils.py b/utils/JPEG_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e2ebd9bdc184e869ade58eea1c6763baa1d9fc91 --- /dev/null +++ b/utils/JPEG_utils.py @@ -0,0 +1,75 @@ +# Standard libraries +import numpy as np +# PyTorch +import torch +import torch.nn as nn +import math + +y_table = np.array( + [[16, 11, 10, 16, 24, 40, 51, 61], [12, 12, 14, 19, 26, 58, 60, + 55], [14, 13, 16, 24, 40, 57, 69, 56], + [14, 17, 22, 29, 51, 87, 80, 62], [18, 22, 37, 56, 68, 109, 103, + 77], [24, 35, 55, 64, 81, 104, 113, 92], + [49, 64, 78, 87, 103, 121, 120, 101], [72, 92, 95, 98, 112, 100, 103, 99]], + dtype=np.float32).T + +y_table = nn.Parameter(torch.from_numpy(y_table)) +# +c_table = np.empty((8, 8), dtype=np.float32) +c_table.fill(99) +c_table[:4, :4] = np.array([[17, 18, 24, 47], [18, 21, 26, 66], + [24, 26, 56, 99], [47, 66, 99, 99]]).T +c_table = nn.Parameter(torch.from_numpy(c_table)) + + +def diff_round_back(x): + """ Differentiable rounding function + Input: + x(tensor) + Output: + x(tensor) + """ + return torch.round(x) + (x - torch.round(x))**3 + + + +def diff_round(input_tensor): + test = 0 + for n in range(1, 10): + test += math.pow(-1, n+1) / n * torch.sin(2 * math.pi * n * input_tensor) + final_tensor = input_tensor - 1 / math.pi * test + return final_tensor + + +class Quant(torch.autograd.Function): + + @staticmethod + def forward(ctx, input): + input = torch.clamp(input, 0, 1) + output = (input * 255.).round() / 255. + return output + + @staticmethod + def backward(ctx, grad_output): + return grad_output + +class Quantization(nn.Module): + def __init__(self): + super(Quantization, self).__init__() + + def forward(self, input): + return Quant.apply(input) + + +def quality_to_factor(quality): + """ Calculate factor corresponding to quality + Input: + quality(float): Quality for jpeg compression + Output: + factor(float): Compression factor + """ + if quality < 50: + quality = 5000. / quality + else: + quality = 200. - quality*2 + return quality / 100. 
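+# Example values: quality_to_factor(75) == 0.5 and quality_to_factor(25) == 2.0; the factor
+# scales y_table / c_table, so lower JPEG quality means coarser quantization. A minimal usage
+# sketch of the surrounding module (input assumed 4-D, in [0, 1], with H and W divisible by 16):
+#   from utils.JPEG import DiffJPEG
+#   degraded = DiffJPEG(differentiable=True, quality=75)(images)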
\ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/__pycache__/JPEG.cpython-310.pyc b/utils/__pycache__/JPEG.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3af37ab029e241f0cf4953f036121b111fd171a0 Binary files /dev/null and b/utils/__pycache__/JPEG.cpython-310.pyc differ diff --git a/utils/__pycache__/JPEG.cpython-38.pyc b/utils/__pycache__/JPEG.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e5f033898e6f619295a920143f17a4fb1320df6 Binary files /dev/null and b/utils/__pycache__/JPEG.cpython-38.pyc differ diff --git a/utils/__pycache__/JPEG_utils.cpython-310.pyc b/utils/__pycache__/JPEG_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5abe231409ad363d5c647a2fad63fa47dc1eaa7 Binary files /dev/null and b/utils/__pycache__/JPEG_utils.cpython-310.pyc differ diff --git a/utils/__pycache__/JPEG_utils.cpython-38.pyc b/utils/__pycache__/JPEG_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5c6af94d8a0f285ba012036d93eb266abd0d918 Binary files /dev/null and b/utils/__pycache__/JPEG_utils.cpython-38.pyc differ diff --git a/utils/__pycache__/__init__.cpython-310.pyc b/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69a547e724d122fcc636b7ad85efc1c1f426765b Binary files /dev/null and b/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/utils/__pycache__/__init__.cpython-38.pyc b/utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34d46c84b979a5e58ec3cba451666badc319bb61 Binary files /dev/null and b/utils/__pycache__/__init__.cpython-38.pyc differ diff --git a/utils/__pycache__/compression.cpython-310.pyc b/utils/__pycache__/compression.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55081325eb2bdefcd1dc3b7a917da86406dd7872 Binary files /dev/null and b/utils/__pycache__/compression.cpython-310.pyc differ diff --git a/utils/__pycache__/compression.cpython-38.pyc b/utils/__pycache__/compression.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cfcb3286d0a6a97e1e822ff2643ae889a40afe97 Binary files /dev/null and b/utils/__pycache__/compression.cpython-38.pyc differ diff --git a/utils/__pycache__/decompression.cpython-310.pyc b/utils/__pycache__/decompression.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43df45a2b71665f5ecbe498c136d0968a3815cd0 Binary files /dev/null and b/utils/__pycache__/decompression.cpython-310.pyc differ diff --git a/utils/__pycache__/decompression.cpython-38.pyc b/utils/__pycache__/decompression.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..387ee11d65d5205244da6177dfa4d51ac2ac6ff0 Binary files /dev/null and b/utils/__pycache__/decompression.cpython-38.pyc differ diff --git a/utils/__pycache__/face_detection.cpython-38.pyc b/utils/__pycache__/face_detection.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2f0533da2364d11e0db05a6935eeaab054acbd1 Binary files /dev/null and b/utils/__pycache__/face_detection.cpython-38.pyc differ diff --git a/utils/__pycache__/jpegtest.cpython-38.pyc b/utils/__pycache__/jpegtest.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..14a1fcf4b451de86575572a4bc1099df6a318dfd Binary files /dev/null and b/utils/__pycache__/jpegtest.cpython-38.pyc differ diff --git a/utils/__pycache__/util.cpython-310.pyc b/utils/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..973625a95c56e054826a9b8a439cae5fa667304e Binary files /dev/null and b/utils/__pycache__/util.cpython-310.pyc differ diff --git a/utils/__pycache__/util.cpython-38.pyc b/utils/__pycache__/util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb0cdcf747f540f4d6ba9020ced03fd44ec70376 Binary files /dev/null and b/utils/__pycache__/util.cpython-38.pyc differ diff --git a/utils/compression.py b/utils/compression.py new file mode 100644 index 0000000000000000000000000000000000000000..363d3416ac949a9c1f38693ef193698d2a84664c --- /dev/null +++ b/utils/compression.py @@ -0,0 +1,182 @@ +# Standard libraries +import itertools +import numpy as np +# PyTorch +import torch +import torch.nn as nn +# Local +from . import JPEG_utils + + +class rgb_to_ycbcr_jpeg(nn.Module): + """ Converts RGB image to YCbCr + + """ + def __init__(self): + super(rgb_to_ycbcr_jpeg, self).__init__() + matrix = np.array( + [[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5], + [0.5, -0.418688, -0.081312]], dtype=np.float32).T + self.shift = nn.Parameter(torch.tensor([0., 128., 128.])) + # + self.matrix = nn.Parameter(torch.from_numpy(matrix)) + + def forward(self, image): + image = image.permute(0, 2, 3, 1) + result = torch.tensordot(image, self.matrix, dims=1) + self.shift + # result = torch.from_numpy(result) + result.view(image.shape) + return result + + + +class chroma_subsampling(nn.Module): + """ Chroma subsampling on CbCv channels + Input: + image(tensor): batch x height x width x 3 + Output: + y(tensor): batch x height x width + cb(tensor): batch x height/2 x width/2 + cr(tensor): batch x height/2 x width/2 + """ + def __init__(self): + super(chroma_subsampling, self).__init__() + + def forward(self, image): + image_2 = image.permute(0, 3, 1, 2).clone() + avg_pool = nn.AvgPool2d(kernel_size=2, stride=(2, 2), + count_include_pad=False) + cb = avg_pool(image_2[:, 1, :, :].unsqueeze(1)) + cr = avg_pool(image_2[:, 2, :, :].unsqueeze(1)) + cb = cb.permute(0, 2, 3, 1) + cr = cr.permute(0, 2, 3, 1) + return image[:, :, :, 0], cb.squeeze(3), cr.squeeze(3) + + +class block_splitting(nn.Module): + """ Splitting image into patches + Input: + image(tensor): batch x height x width + Output: + patch(tensor): batch x h*w/64 x h x w + """ + def __init__(self): + super(block_splitting, self).__init__() + self.k = 8 + + def forward(self, image): + height, width = image.shape[1:3] + # print(height, width) + batch_size = image.shape[0] + # print(image.shape) + image_reshaped = image.view(batch_size, height // self.k, self.k, -1, self.k) + image_transposed = image_reshaped.permute(0, 1, 3, 2, 4) + return image_transposed.contiguous().view(batch_size, -1, self.k, self.k) + + +class dct_8x8(nn.Module): + """ Discrete Cosine Transformation + Input: + image(tensor): batch x height x width + Output: + dcp(tensor): batch x height x width + """ + def __init__(self): + super(dct_8x8, self).__init__() + tensor = np.zeros((8, 8, 8, 8), dtype=np.float32) + for x, y, u, v in itertools.product(range(8), repeat=4): + tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos( + (2 * y + 1) * v * np.pi / 16) + alpha = np.array([1. 
/ np.sqrt(2)] + [1] * 7) + # + self.tensor = nn.Parameter(torch.from_numpy(tensor).float()) + self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float() ) + + def forward(self, image): + image = image - 128 + result = self.scale * torch.tensordot(image, self.tensor, dims=2) + result.view(image.shape) + return result + + +class y_quantize(nn.Module): + """ JPEG Quantization for Y channel + Input: + image(tensor): batch x height x width + rounding(function): rounding function to use + factor(float): Degree of compression + Output: + image(tensor): batch x height x width + """ + def __init__(self, rounding, factor=1): + super(y_quantize, self).__init__() + self.rounding = rounding + self.factor = factor + self.y_table = JPEG_utils.y_table + + def forward(self, image): + image = image.float() / (self.y_table * self.factor) + image = self.rounding(image) + return image + + +class c_quantize(nn.Module): + """ JPEG Quantization for CrCb channels + Input: + image(tensor): batch x height x width + rounding(function): rounding function to use + factor(float): Degree of compression + Output: + image(tensor): batch x height x width + """ + def __init__(self, rounding, factor=1): + super(c_quantize, self).__init__() + self.rounding = rounding + self.factor = factor + self.c_table = JPEG_utils.c_table + + def forward(self, image): + image = image.float() / (self.c_table * self.factor) + image = self.rounding(image) + return image + + +class compress_jpeg(nn.Module): + """ Full JPEG compression algortihm + Input: + imgs(tensor): batch x 3 x height x width + rounding(function): rounding function to use + factor(float): Compression factor + Ouput: + compressed(dict(tensor)): batch x h*w/64 x 8 x 8 + """ + def __init__(self, rounding=torch.round, factor=1): + super(compress_jpeg, self).__init__() + self.l1 = nn.Sequential( + rgb_to_ycbcr_jpeg(), + # comment this line if no subsampling + chroma_subsampling() + ) + self.l2 = nn.Sequential( + block_splitting(), + dct_8x8() + ) + self.c_quantize = c_quantize(rounding=rounding, factor=factor) + self.y_quantize = y_quantize(rounding=rounding, factor=factor) + + def forward(self, image): + y, cb, cr = self.l1(image*255) # modify + + # y, cb, cr = result[:,:,:,0], result[:,:,:,1], result[:,:,:,2] + components = {'y': y, 'cb': cb, 'cr': cr} + for k in components.keys(): + comp = self.l2(components[k]) + # print(comp.shape) + if k in ('cb', 'cr'): + comp = self.c_quantize(comp) + else: + comp = self.y_quantize(comp) + + components[k] = comp + + return components['y'], components['cb'], components['cr'] \ No newline at end of file diff --git a/utils/decompression.py b/utils/decompression.py new file mode 100644 index 0000000000000000000000000000000000000000..b73ff96d5f6818e1d0464b9c4133f559a3b23fba --- /dev/null +++ b/utils/decompression.py @@ -0,0 +1,190 @@ +# Standard libraries +import itertools +import numpy as np +# PyTorch +import torch +import torch.nn as nn +# Local +from . 
import JPEG_utils as utils + + +class y_dequantize(nn.Module): + """ Dequantize Y channel + Inputs: + image(tensor): batch x height x width + factor(float): compression factor + Outputs: + image(tensor): batch x height x width + """ + def __init__(self, factor=1): + super(y_dequantize, self).__init__() + self.y_table = utils.y_table + self.factor = factor + + def forward(self, image): + return image * (self.y_table * self.factor) + + +class c_dequantize(nn.Module): + """ Dequantize CbCr channel + Inputs: + image(tensor): batch x height x width + factor(float): compression factor + Outputs: + image(tensor): batch x height x width + """ + def __init__(self, factor=1): + super(c_dequantize, self).__init__() + self.factor = factor + self.c_table = utils.c_table + + def forward(self, image): + return image * (self.c_table * self.factor) + + +class idct_8x8(nn.Module): + """ Inverse discrete Cosine Transformation + Input: + dcp(tensor): batch x height x width + Output: + image(tensor): batch x height x width + """ + def __init__(self): + super(idct_8x8, self).__init__() + alpha = np.array([1. / np.sqrt(2)] + [1] * 7) + self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float()) + tensor = np.zeros((8, 8, 8, 8), dtype=np.float32) + for x, y, u, v in itertools.product(range(8), repeat=4): + tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos( + (2 * v + 1) * y * np.pi / 16) + self.tensor = nn.Parameter(torch.from_numpy(tensor).float()) + + def forward(self, image): + + image = image * self.alpha + result = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128 + result.view(image.shape) + return result + + +class block_merging(nn.Module): + """ Merge pathces into image + Inputs: + patches(tensor) batch x height*width/64, height x width + height(int) + width(int) + Output: + image(tensor): batch x height x width + """ + def __init__(self): + super(block_merging, self).__init__() + + def forward(self, patches, height, width): + k = 8 + batch_size = patches.shape[0] + # print(patches.shape) # (1,1024,8,8) + image_reshaped = patches.view(batch_size, height//k, width//k, k, k) + image_transposed = image_reshaped.permute(0, 1, 3, 2, 4) + return image_transposed.contiguous().view(batch_size, height, width) + + +class chroma_upsampling(nn.Module): + """ Upsample chroma layers + Input: + y(tensor): y channel image + cb(tensor): cb channel + cr(tensor): cr channel + Ouput: + image(tensor): batch x height x width x 3 + """ + def __init__(self): + super(chroma_upsampling, self).__init__() + + def forward(self, y, cb, cr): + def repeat(x, k=2): + height, width = x.shape[1:3] + x = x.unsqueeze(-1) + x = x.repeat(1, 1, k, k) + x = x.view(-1, height * k, width * k) + return x + + cb = repeat(cb) + cr = repeat(cr) + + return torch.cat([y.unsqueeze(3), cb.unsqueeze(3), cr.unsqueeze(3)], dim=3) + + +class ycbcr_to_rgb_jpeg(nn.Module): + """ Converts YCbCr image to RGB JPEG + Input: + image(tensor): batch x height x width x 3 + Outpput: + result(tensor): batch x 3 x height x width + """ + def __init__(self): + super(ycbcr_to_rgb_jpeg, self).__init__() + + matrix = np.array( + [[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]], + dtype=np.float32).T + self.shift = nn.Parameter(torch.tensor([0, -128., -128.])) + self.matrix = nn.Parameter(torch.from_numpy(matrix)) + + def forward(self, image): + result = torch.tensordot(image + self.shift, self.matrix, dims=1) + #result = torch.from_numpy(result) + result.view(image.shape) + return result.permute(0, 3, 1, 2) + + +class 
decompress_jpeg(nn.Module): + """ Full JPEG decompression algortihm + Input: + compressed(dict(tensor)): batch x h*w/64 x 8 x 8 + rounding(function): rounding function to use + factor(float): Compression factor + Ouput: + image(tensor): batch x 3 x height x width + """ + # def __init__(self, height, width, rounding=torch.round, factor=1): + def __init__(self, rounding=torch.round, factor=1): + super(decompress_jpeg, self).__init__() + self.c_dequantize = c_dequantize(factor=factor) + self.y_dequantize = y_dequantize(factor=factor) + self.idct = idct_8x8() + self.merging = block_merging() + # comment this line if no subsampling + self.chroma = chroma_upsampling() + self.colors = ycbcr_to_rgb_jpeg() + + # self.height, self.width = height, width + + def forward(self, y, cb, cr, height, width): + components = {'y': y, 'cb': cb, 'cr': cr} + # height = y.shape[0] + # width = y.shape[1] + self.height = height + self.width = width + for k in components.keys(): + if k in ('cb', 'cr'): + comp = self.c_dequantize(components[k]) + # comment this line if no subsampling + height, width = int(self.height/2), int(self.width/2) + # height, width = int(self.height), int(self.width) + + else: + comp = self.y_dequantize(components[k]) + # comment this line if no subsampling + height, width = self.height, self.width + comp = self.idct(comp) + components[k] = self.merging(comp, height, width) + # + # comment this line if no subsampling + image = self.chroma(components['y'], components['cb'], components['cr']) + # image = torch.cat([components['y'].unsqueeze(3), components['cb'].unsqueeze(3), components['cr'].unsqueeze(3)], dim=3) + image = self.colors(image) + + image = torch.min(255*torch.ones_like(image), + torch.max(torch.zeros_like(image), image)) + return image/255 + diff --git a/utils/face_detection.py b/utils/face_detection.py new file mode 100644 index 0000000000000000000000000000000000000000..0d3abbf10553233bab9dbee92f4149c33614043c --- /dev/null +++ b/utils/face_detection.py @@ -0,0 +1,95 @@ +import cv2 +import dlib +import numpy as np + +## Face detection +def face_detection(img,upsample_times=1): + # Ask the detector to find the bounding boxes of each face. The 1 in the + # second argument indicates that we should upsample the image 1 time. This + # will make everything bigger and allow us to detect more faces. + detector = dlib.get_frontal_face_detector() + faces = detector(img, upsample_times) + + return faces + +PREDICTOR_PATH = '/userhome/NewIBSN/IBSN_SepMark/shape_predictor_68_face_landmarks.dat' +predictor = dlib.shape_predictor(PREDICTOR_PATH) +## Face and points detection +def face_points_detection(img, bbox:dlib.rectangle): + # Get the landmarks/parts for the face in box d. 
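+    # The 68-point predictor loaded above returns a dlib full_object_detection whose parts() follow the standard iBUG 68-landmark ordering (jaw, brows, nose, eyes, mouth).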
+ shape = predictor(img, bbox) + + # loop over the 68 facial landmarks and convert them + # to a 2-tuple of (x, y)-coordinates + coords = np.asarray(list([p.x, p.y] for p in shape.parts()), dtype=int) + + # return the array of (x, y)-coordinates + return coords + +def select_face(im, r=10, choose=True): + faces = face_detection(im) + + if len(faces) == 0: + return None, None, None + + if len(faces) == 1 or not choose: + idx = np.argmax([(face.right() - face.left()) * (face.bottom() - face.top()) for face in faces]) + bbox = faces[idx] + else: + bbox = [] + + def click_on_face(event, x, y, flags, params): + if event != cv2.EVENT_LBUTTONDOWN: + return + + for face in faces: + if face.left() < x < face.right() and face.top() < y < face.bottom(): + bbox.append(face) + break + + im_copy = im.copy() + for face in faces: + # draw the face bounding box + cv2.rectangle(im_copy, (face.left(), face.top()), (face.right(), face.bottom()), (0, 0, 255), 1) + cv2.imshow('Click the Face:', im_copy) + cv2.setMouseCallback('Click the Face:', click_on_face) + while len(bbox) == 0: + cv2.waitKey(1) + cv2.destroyAllWindows() + bbox = bbox[0] + + points = np.asarray(face_points_detection(im, bbox)) + + im_w, im_h = im.shape[:2] + left, top = np.min(points, 0) + right, bottom = np.max(points, 0) + + x, y = max(0, left - r), max(0, top - r) + w, h = min(right + r, im_h) - x, min(bottom + r, im_w) - y + + return points - np.asarray([[x, y]]), (x, y, w, h), im[y:y + h, x:x + w] + + +def select_all_faces(im, r=10): + faces = face_detection(im) + + if len(faces) == 0: + return None + + faceBoxes = {k : {"points" : None, + "shape" : None, + "face" : None} for k in range(len(faces))} + for i, bbox in enumerate(faces): + points = np.asarray(face_points_detection(im, bbox)) + + im_w, im_h = im.shape[:2] + left, top = np.min(points, 0) + right, bottom = np.max(points, 0) + + x, y = max(0, left - r), max(0, top - r) + w, h = min(right + r, im_h) - x, min(bottom + r, im_w) - y + faceBoxes[i]["points"] = points - np.asarray([[x, y]]) + faceBoxes[i]["shape"] = (x, y, w, h) + faceBoxes[i]["face"] = im[y:y + h, x:x + w] + + return faceBoxes diff --git a/utils/face_swap.py b/utils/face_swap.py new file mode 100644 index 0000000000000000000000000000000000000000..04358b4abccad5d462df49fe299cc0f8340370d9 --- /dev/null +++ b/utils/face_swap.py @@ -0,0 +1,242 @@ +#! /usr/bin/env python +import cv2 +import numpy as np +import scipy.spatial as spatial +import logging + + +## 3D Transform +def bilinear_interpolate(img, coords): + """ Interpolates over every image channel + http://en.wikipedia.org/wiki/Bilinear_interpolation + :param img: max 3 channel image + :param coords: 2 x _m_ array. 
1st row = xcoords, 2nd row = ycoords + :returns: array of interpolated pixels with same shape as coords + """ + int_coords = np.int32(coords) + x0, y0 = int_coords + dx, dy = coords - int_coords + + # 4 Neighour pixels + q11 = img[y0, x0] + q21 = img[y0, x0 + 1] + q12 = img[y0 + 1, x0] + q22 = img[y0 + 1, x0 + 1] + + btm = q21.T * dx + q11.T * (1 - dx) + top = q22.T * dx + q12.T * (1 - dx) + inter_pixel = top * dy + btm * (1 - dy) + + return inter_pixel.T + +def grid_coordinates(points): + """ x,y grid coordinates within the ROI of supplied points + :param points: points to generate grid coordinates + :returns: array of (x, y) coordinates + """ + xmin = np.min(points[:, 0]) + xmax = np.max(points[:, 0]) + 1 + ymin = np.min(points[:, 1]) + ymax = np.max(points[:, 1]) + 1 + + return np.asarray([(x, y) for y in range(ymin, ymax) + for x in range(xmin, xmax)], np.uint32) + + +def process_warp(src_img, result_img, tri_affines, dst_points, delaunay): + """ + Warp each triangle from the src_image only within the + ROI of the destination image (points in dst_points). + """ + roi_coords = grid_coordinates(dst_points) + # indices to vertices. -1 if pixel is not in any triangle + roi_tri_indices = delaunay.find_simplex(roi_coords) + + for simplex_index in range(len(delaunay.simplices)): + coords = roi_coords[roi_tri_indices == simplex_index] + num_coords = len(coords) + out_coords = np.dot(tri_affines[simplex_index], + np.vstack((coords.T, np.ones(num_coords)))) + x, y = coords.T + result_img[y, x] = bilinear_interpolate(src_img, out_coords) + + return None + + +def triangular_affine_matrices(vertices, src_points, dst_points): + """ + Calculate the affine transformation matrix for each + triangle (x,y) vertex from dst_points to src_points + :param vertices: array of triplet indices to corners of triangle + :param src_points: array of [x, y] points to landmarks for source image + :param dst_points: array of [x, y] points to landmarks for destination image + :returns: 2 x 3 affine matrix transformation for a triangle + """ + ones = [1, 1, 1] + for tri_indices in vertices: + src_tri = np.vstack((src_points[tri_indices, :].T, ones)) + dst_tri = np.vstack((dst_points[tri_indices, :].T, ones)) + mat = np.dot(src_tri, np.linalg.inv(dst_tri))[:2, :] + yield mat + + +def warp_image_3d(src_img, src_points, dst_points, dst_shape, dtype=np.uint8): + rows, cols = dst_shape[:2] + result_img = np.zeros((rows, cols, 3), dtype=dtype) + + delaunay = spatial.Delaunay(dst_points) + tri_affines = np.asarray(list(triangular_affine_matrices( + delaunay.simplices, src_points, dst_points))) + + process_warp(src_img, result_img, tri_affines, dst_points, delaunay) + + return result_img + + +## 2D Transform +def transformation_from_points(points1, points2): + points1 = points1.astype(np.float64) + points2 = points2.astype(np.float64) + + c1 = np.mean(points1, axis=0) + c2 = np.mean(points2, axis=0) + points1 -= c1 + points2 -= c2 + + s1 = np.std(points1) + s2 = np.std(points2) + points1 /= s1 + points2 /= s2 + + U, S, Vt = np.linalg.svd(np.dot(points1.T, points2)) + R = (np.dot(U, Vt)).T + + return np.vstack([np.hstack([s2 / s1 * R, + (c2.T - np.dot(s2 / s1 * R, c1.T))[:, np.newaxis]]), + np.array([[0., 0., 1.]])]) + + +def warp_image_2d(im, M, dshape): + output_im = np.zeros(dshape, dtype=im.dtype) + cv2.warpAffine(im, + M[:2], + (dshape[1], dshape[0]), + dst=output_im, + borderMode=cv2.BORDER_TRANSPARENT, + flags=cv2.WARP_INVERSE_MAP) + + return output_im + + +## Generate Mask +def mask_from_points(size, points,erode_flag=1): 
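+    """ Build a binary mask covering the convex hull of the landmark points
+    :param size: (height, width) of the output mask
+    :param points: array of (x, y) landmark coordinates
+    :param erode_flag: when non-zero, erode the mask so blending stays inside the face
+    :returns: uint8 mask, 255 inside the hull and 0 elsewhere
+    """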
+ radius = 10 # kernel size + kernel = np.ones((radius, radius), np.uint8) + + mask = np.zeros(size, np.uint8) + cv2.fillConvexPoly(mask, cv2.convexHull(points), 255) + if erode_flag: + mask = cv2.erode(mask, kernel,iterations=1) + + return mask + + +## Color Correction +def correct_colours(im1, im2, landmarks1): + COLOUR_CORRECT_BLUR_FRAC = 0.75 + LEFT_EYE_POINTS = list(range(42, 48)) + RIGHT_EYE_POINTS = list(range(36, 42)) + + blur_amount = COLOUR_CORRECT_BLUR_FRAC * np.linalg.norm( + np.mean(landmarks1[LEFT_EYE_POINTS], axis=0) - + np.mean(landmarks1[RIGHT_EYE_POINTS], axis=0)) + blur_amount = int(blur_amount) + if blur_amount % 2 == 0: + blur_amount += 1 + im1_blur = cv2.GaussianBlur(im1, (blur_amount, blur_amount), 0) + im2_blur = cv2.GaussianBlur(im2, (blur_amount, blur_amount), 0) + + # Avoid divide-by-zero errors. + im2_blur = im2_blur.astype(int) + im2_blur += 128*(im2_blur <= 1) + + result = im2.astype(np.float64) * im1_blur.astype(np.float64) / im2_blur.astype(np.float64) + result = np.clip(result, 0, 255).astype(np.uint8) + + return result + + +## Copy-and-paste +def apply_mask(img, mask): + """ Apply mask to supplied image + :param img: max 3 channel image + :param mask: [0-255] values in mask + :returns: new image with mask applied + """ + masked_img=cv2.bitwise_and(img,img,mask=mask) + + return masked_img + + +## Alpha blending +def alpha_feathering(src_img, dest_img, img_mask, blur_radius=15): + mask = cv2.blur(img_mask, (blur_radius, blur_radius)) + mask = mask / 255.0 + + result_img = np.empty(src_img.shape, np.uint8) + for i in range(3): + result_img[..., i] = src_img[..., i] * mask + dest_img[..., i] * (1-mask) + + return result_img + + +def check_points(img,points): + # Todo: I just consider one situation. + if points[8,1]>img.shape[0]: + logging.error("Jaw part out of image") + else: + return True + return False + + +def face_swap(src_face, dst_face, src_points, dst_points, dst_shape, dst_img, end=48): + h, w = dst_face.shape[:2] + + correct_color = False + warp_2d = True + + ## 3d warp + warped_src_face = warp_image_3d(src_face, src_points[:end], dst_points[:end], (h, w)) + ## Mask for blending + mask = mask_from_points((h, w), dst_points) + mask_src = np.mean(warped_src_face, axis=2) > 0 + mask = np.asarray(mask * mask_src, dtype=np.uint8) + ## Correct color + if correct_color: + warped_src_face = apply_mask(warped_src_face, mask) + dst_face_masked = apply_mask(dst_face, mask) + warped_src_face = correct_colours(dst_face_masked, warped_src_face, dst_points) + ## 2d warp + if warp_2d: + unwarped_src_face = warp_image_3d(warped_src_face, dst_points[:end], src_points[:end], src_face.shape[:2]) + warped_src_face = warp_image_2d(unwarped_src_face, transformation_from_points(dst_points, src_points), + (h, w, 3)) + + mask = mask_from_points((h, w), dst_points) + mask_src = np.mean(warped_src_face, axis=2) > 0 + mask = np.asarray(mask * mask_src, dtype=np.uint8) + + ## Shrink the mask + kernel = np.ones((10, 10), np.uint8) + mask = cv2.erode(mask, kernel, iterations=1) + cv2.imwrite("/userhome/NewIBSN/IBSN_SepMark/mask.png", mask) + ##Poisson Blending + r = cv2.boundingRect(mask) + center = ((r[0] + int(r[2] / 2), r[1] + int(r[3] / 2))) + output = cv2.seamlessClone(warped_src_face, dst_face, mask, center, cv2.NORMAL_CLONE) + + x, y, w, h = dst_shape + dst_img_cp = dst_img.copy() + dst_img_cp[y:y + h, x:x + w] = output + + return dst_img_cp diff --git a/utils/jpegtest.py b/utils/jpegtest.py new file mode 100644 index 
0000000000000000000000000000000000000000..e51b7a586bdabaf8695f49d95ecc0b0a91ec3a53 --- /dev/null +++ b/utils/jpegtest.py @@ -0,0 +1,43 @@ +import os +import numpy as np +import torch +import torch.nn as nn +from torchvision import transforms +from PIL import Image +import random, string + + +class JpegTest(nn.Module): + def __init__(self, Q=50, subsample=0, path="temp/"): + super(JpegTest, self).__init__() + self.Q = Q + self.subsample = subsample + self.path = path + if not os.path.exists(path): os.mkdir(path) + self.transform = transforms.Compose([ + transforms.ToTensor(), + # transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) + ]) + + def get_path(self): + return self.path + ''.join(random.sample(string.ascii_letters + string.digits, 16)) + ".jpg" + + def forward(self, image_cover_mask): + image = image_cover_mask + + noised_image = torch.zeros_like(image) + + for i in range(image.shape[0]): + single_image = ((image[i].clamp(0, 1).permute(1, 2, 0)) * 255).add(0.5).clamp(0, 255).to('cpu', torch.uint8).numpy() + im = Image.fromarray(single_image) + + file = self.get_path() + while os.path.exists(file): + file = self.get_path() + im.save(file, format="JPEG", quality=self.Q, subsampling=self.subsample) + jpeg = np.array(Image.open(file), dtype=np.uint8) + os.remove(file) + + noised_image[i] = self.transform(jpeg).unsqueeze(0).to(image.device) + + return noised_image diff --git a/utils/util.py b/utils/util.py new file mode 100644 index 0000000000000000000000000000000000000000..71b455bbb9f04d177a1b99be6458a1692c713701 --- /dev/null +++ b/utils/util.py @@ -0,0 +1,271 @@ +import os +import sys +import time +import math +from datetime import datetime +import random +import logging +from collections import OrderedDict +import numpy as np +import cv2 +import torch +from torchvision.utils import make_grid +from shutil import get_terminal_size + +import yaml +try: + from yaml import CLoader as Loader, CDumper as Dumper +except ImportError: + from yaml import Loader, Dumper + + +def OrderedYaml(): + '''yaml orderedDict support''' + _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG + + def dict_representer(dumper, data): + return dumper.represent_dict(data.items()) + + def dict_constructor(loader, node): + return OrderedDict(loader.construct_pairs(node)) + + Dumper.add_representer(OrderedDict, dict_representer) + Loader.add_constructor(_mapping_tag, dict_constructor) + return Loader, Dumper + + +#################### +# miscellaneous +#################### + + +def get_timestamp(): + return datetime.now().strftime('%y%m%d-%H%M%S') + + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + + +def mkdirs(paths): + if isinstance(paths, str): + mkdir(paths) + else: + for path in paths: + mkdir(path) + + +def mkdir_and_rename(path): +# print(path) +# exit(0) + if os.path.exists(path): + new_name = path + '_archived_' + get_timestamp() + print('Path already exists. Rename it to [{:s}]'.format(new_name)) + logger = logging.getLogger('base') + logger.info('Path already exists. 
Rename it to [{:s}]'.format(new_name)) +# path = new_name + os.rename(path, new_name) + os.makedirs(path) +# return path + + +def set_random_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False): + '''set up logger''' + lg = logging.getLogger(logger_name) + formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', + datefmt='%y-%m-%d %H:%M:%S') + lg.setLevel(level) + if tofile: + log_file = os.path.join(root, phase + '_{}.log'.format(get_timestamp())) + fh = logging.FileHandler(log_file, mode='w') + fh.setFormatter(formatter) + lg.addHandler(fh) + if screen: + sh = logging.StreamHandler() + sh.setFormatter(formatter) + lg.addHandler(sh) + + +#################### +# image convert +#################### + + +def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): + ''' + Converts a torch Tensor into an image Numpy array + Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order + Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) + ''' + tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp + tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] + n_dim = tensor.dim() + if n_dim == 4: + n_img = len(tensor) + img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 3: + img_np = tensor.numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 2: + img_np = tensor.numpy() + else: + raise TypeError( + 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) + if out_type == np.uint8: + img_np = (img_np * 255.0).round() + # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. 
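+        # Values were clamped to [0, 1] above, so after scaling they lie in [0, 255] and the uint8 cast cannot overflow.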
+ return img_np.astype(out_type) + + +def save_img(img, img_path, mode='RGB'): + cv2.imwrite(img_path, img) + + +#################### +# metric +#################### + + +def calculate_psnr(img1, img2): + # img1 and img2 have range [0, 255] + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + mse = np.mean((img1 - img2)**2) + if mse == 0: + return float('inf') + return 20 * math.log10(255.0 / math.sqrt(mse)) + + +def ssim(img1, img2): + C1 = (0.01 * 255)**2 + C2 = (0.03 * 255)**2 + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * + (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean() + + +def calculate_ssim(img1, img2): + '''calculate SSIM + the same outputs as MATLAB's + img1, img2: [0, 255] + ''' + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + if img1.ndim == 2: + return ssim(img1, img2) + elif img1.ndim == 3: + if img1.shape[2] == 3: + ssims = [] + for i in range(3): + ssims.append(ssim(img1, img2)) + return np.array(ssims).mean() + elif img1.shape[2] == 1: + return ssim(np.squeeze(img1), np.squeeze(img2)) + else: + raise ValueError('Wrong input image dimensions.') + + +class ProgressBar(object): + '''A progress bar which can print the progress + modified from https://github.com/hellock/cvbase/blob/master/cvbase/progress.py + ''' + + def __init__(self, task_num=0, bar_width=50, start=True): + self.task_num = task_num + max_bar_width = self._get_max_bar_width() + self.bar_width = (bar_width if bar_width <= max_bar_width else max_bar_width) + self.completed = 0 + if start: + self.start() + + def _get_max_bar_width(self): + terminal_width, _ = get_terminal_size() + max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50) + if max_bar_width < 10: + print('terminal width is too small ({}), please consider widen the terminal for better ' + 'progressbar visualization'.format(terminal_width)) + max_bar_width = 10 + return max_bar_width + + def start(self): + if self.task_num > 0: + sys.stdout.write('[{}] 0/{}, elapsed: 0s, ETA:\n{}\n'.format( + ' ' * self.bar_width, self.task_num, 'Start...')) + else: + sys.stdout.write('completed: 0, elapsed: 0s') + sys.stdout.flush() + self.start_time = time.time() + + def update(self, msg='In progress...'): + self.completed += 1 + elapsed = time.time() - self.start_time + fps = self.completed / elapsed + if self.task_num > 0: + percentage = self.completed / float(self.task_num) + eta = int(elapsed * (1 - percentage) / percentage + 0.5) + mark_width = int(self.bar_width * percentage) + bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width) + sys.stdout.write('\033[2F') # cursor up 2 lines + sys.stdout.write('\033[J') # clean the output (remove extra chars since last display) + sys.stdout.write('[{}] {}/{}, {:.1f} task/s, elapsed: {}s, ETA: {:5}s\n{}\n'.format( + bar_chars, self.completed, self.task_num, fps, int(elapsed + 0.5), eta, msg)) + else: + sys.stdout.write('completed: {}, elapsed: {}s, {:.1f} 
tasks/s'.format( + self.completed, int(elapsed + 0.5), fps)) + sys.stdout.flush() + +def bitWise_accurary(msg_fake, message): + # Bit-wise accuracy (%) and average bit error between decoded and ground-truth messages. + if msg_fake is None: + return None, None + else: + DecodedMsg_rounded = msg_fake.detach().cpu().numpy().round().clip(0, 1) + + diff = DecodedMsg_rounded - message.detach().cpu().numpy().round().clip(0, 1) + count = np.sum(np.abs(diff)) + b, l = msg_fake.shape + + accuracy = (1 - count / (b * l)) + BitWise_AvgErr = count / (b * l) + + return accuracy * 100, BitWise_AvgErr + +def decoded_message_error_rate(message, decoded_message): + message = message.view(message.shape[0], -1).squeeze() + length = message.shape[0] + message = message.gt(0) + decoded_message = decoded_message.gt(0) + error_rate = float(sum(message != decoded_message)) / length + return error_rate + +def decoded_message_error_rate_batch(messages, decoded_messages): + error_rate = 0.0 + batch_size = len(messages) + for i in range(batch_size): + error_rate += decoded_message_error_rate(messages[i], decoded_messages[i]) + error_rate /= batch_size + return error_rate
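+# Illustrative usage sketch (the names `decoded` and `target` are assumptions, not part of this module):
+#   decoded, target: float tensors of shape (batch, msg_len)
+#   acc_percent, avg_bit_err = bitWise_accurary(decoded, target)
+#   ber = decoded_message_error_rate_batch(target, decoded)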