gokaygokay commited on
Commit
2f4febc
1 Parent(s): f17a2ad

full_files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +146 -201
  2. configs/inference/controlnet_c_3b_canny.yaml +14 -0
  3. configs/inference/controlnet_c_3b_identity.yaml +17 -0
  4. configs/inference/controlnet_c_3b_inpainting.yaml +15 -0
  5. configs/inference/controlnet_c_3b_sr.yaml +15 -0
  6. configs/inference/lora_c_3b.yaml +15 -0
  7. configs/inference/stage_b_1b.yaml +13 -0
  8. configs/inference/stage_b_3b.yaml +13 -0
  9. configs/inference/stage_c_1b.yaml +7 -0
  10. configs/inference/stage_c_3b.yaml +7 -0
  11. configs/training/cfg_control_lr.yaml +47 -0
  12. configs/training/lora_personalization.yaml +37 -0
  13. configs/training/t2i.yaml +29 -0
  14. core/__init__.py +372 -0
  15. core/data/__init__.py +69 -0
  16. core/data/bucketeer.py +88 -0
  17. core/data/bucketeer_deg.py +91 -0
  18. core/data/deg_kair_utils/utils_alignfaces.py +263 -0
  19. core/data/deg_kair_utils/utils_blindsr.py +631 -0
  20. core/data/deg_kair_utils/utils_bnorm.py +91 -0
  21. core/data/deg_kair_utils/utils_deblur.py +655 -0
  22. core/data/deg_kair_utils/utils_dist.py +201 -0
  23. core/data/deg_kair_utils/utils_googledownload.py +93 -0
  24. core/data/deg_kair_utils/utils_image.py +1016 -0
  25. core/data/deg_kair_utils/utils_lmdb.py +205 -0
  26. core/data/deg_kair_utils/utils_logger.py +66 -0
  27. core/data/deg_kair_utils/utils_mat.py +88 -0
  28. core/data/deg_kair_utils/utils_matconvnet.py +197 -0
  29. core/data/deg_kair_utils/utils_model.py +330 -0
  30. core/data/deg_kair_utils/utils_modelsummary.py +485 -0
  31. core/data/deg_kair_utils/utils_option.py +255 -0
  32. core/data/deg_kair_utils/utils_params.py +135 -0
  33. core/data/deg_kair_utils/utils_receptivefield.py +62 -0
  34. core/data/deg_kair_utils/utils_regularizers.py +104 -0
  35. core/data/deg_kair_utils/utils_sisr.py +848 -0
  36. core/data/deg_kair_utils/utils_video.py +493 -0
  37. core/data/deg_kair_utils/utils_videoio.py +555 -0
  38. core/scripts/__init__.py +0 -0
  39. core/scripts/cli.py +41 -0
  40. core/templates/__init__.py +1 -0
  41. core/templates/diffusion.py +236 -0
  42. core/utils/__init__.py +9 -0
  43. core/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  44. core/utils/__pycache__/__init__.cpython-39.pyc +0 -0
  45. core/utils/__pycache__/base_dto.cpython-310.pyc +0 -0
  46. core/utils/__pycache__/base_dto.cpython-39.pyc +0 -0
  47. core/utils/__pycache__/save_and_load.cpython-310.pyc +0 -0
  48. core/utils/__pycache__/save_and_load.cpython-39.pyc +0 -0
  49. core/utils/base_dto.py +56 -0
  50. core/utils/save_and_load.py +59 -0
app.py CHANGED
@@ -1,213 +1,158 @@
1
  import spaces
2
- import json
3
- import subprocess
4
  import os
5
- import sys
6
-
7
- def run_command(command):
8
- process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
9
- output, error = process.communicate()
10
- if process.returncode != 0:
11
- print(f"Error executing command: {command}")
12
- print(error.decode('utf-8'))
13
- exit(1)
14
- return output.decode('utf-8')
15
-
16
- # Download CUDA installer
17
- download_command = "wget https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"
18
- result = run_command(download_command)
19
- if result is None:
20
- print("Failed to download CUDA installer.")
21
- exit(1)
22
-
23
- # Run CUDA installer in silent mode
24
- install_command = "sh cuda_12.2.0_535.54.03_linux.run --silent --toolkit --samples --override"
25
- result = run_command(install_command)
26
- if result is None:
27
- print("Failed to run CUDA installer.")
28
- exit(1)
29
-
30
- print("CUDA installation process completed.")
31
-
32
- def install_packages():
33
-
34
- # Clone the repository with submodules
35
- run_command("git clone --recurse-submodules https://github.com/abetlen/llama-cpp-python.git")
36
-
37
- # Change to the cloned directory
38
- os.chdir("llama-cpp-python")
39
-
40
- # Checkout the specific commit in the llama.cpp submodule
41
- os.chdir("vendor/llama.cpp")
42
- run_command("git checkout 50e0535")
43
- os.chdir("../..")
44
-
45
- # Upgrade pip
46
- run_command("pip install --upgrade pip")
47
-
48
-
49
-
50
- # Install all optional dependencies with CUDA support
51
- run_command('CMAKE_ARGS="-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS -DCUDA_PATH=/usr/local/cuda-12.2 -DCUDAToolkit_ROOT=/usr/local/cuda-12.2 -DCUDAToolkit_INCLUDE_DIR=/usr/local/cuda-12/include -DCUDAToolkit_LIBRARY_DIR=/usr/local/cuda-12.2/lib64" FORCE_CMAKE=1 pip install -e .')
52
-
53
- run_command("make clean && GGML_OPENBLAS=1 make -j")
54
-
55
- # Reinstall the package with CUDA support
56
- run_command('CMAKE_ARGS="-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS -DCUDA_PATH=/usr/local/cuda-12.2 -DCUDAToolkit_ROOT=/usr/local/cuda-12.2 -DCUDAToolkit_INCLUDE_DIR=/usr/local/cuda-12/include -DCUDAToolkit_LIBRARY_DIR=/usr/local/cuda-12.2/lib64" FORCE_CMAKE=1 pip install -e .')
57
-
58
- # Install llama-cpp-agent
59
- run_command("pip install llama-cpp-agent")
60
-
61
- run_command("export PYTHONPATH=$PYTHONPATH:$(pwd)")
62
-
63
- print("Installation complete!")
64
-
65
- try:
66
- install_packages()
67
-
68
- # Add a delay to allow for package registration
69
- import time
70
- time.sleep(5)
71
-
72
- # Force Python to reload the site packages
73
- import site
74
- import importlib
75
- importlib.reload(site)
76
-
77
- # Now try to import the libraries
78
- from llama_cpp import Llama
79
- from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
80
- from llama_cpp_agent.providers import LlamaCppPythonProvider
81
- from llama_cpp_agent.chat_history import BasicChatHistory
82
- from llama_cpp_agent.chat_history.messages import Roles
83
-
84
- print("Libraries imported successfully!")
85
- except Exception as e:
86
- print(f"Installation failed or libraries couldn't be imported: {str(e)}")
87
- sys.exit(1)
88
-
89
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  from huggingface_hub import hf_hub_download
91
 
92
- hf_hub_download(
93
- repo_id="MaziyarPanahi/Mistral-Nemo-Instruct-2407-GGUF",
94
- filename="Mistral-Nemo-Instruct-2407.Q5_K_M.gguf",
95
- local_dir="./models"
96
- )
97
 
98
- # Initialize LLM outside the respond function
99
- llm = Llama(
100
- model_path="models/Mistral-Nemo-Instruct-2407.Q5_K_M.gguf",
101
- flash_attn=True,
102
- n_gpu_layers=81,
103
- n_batch=1024,
104
- n_ctx=32768,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  )
106
-
107
- provider = LlamaCppPythonProvider(llm)
108
-
109
- @spaces.GPU(duration=120)
110
- def respond(
111
- message,
112
- history: list[tuple[str, str]],
113
- system_message,
114
- max_tokens,
115
- temperature,
116
- top_p,
117
- top_k,
118
- repeat_penalty,
119
- ):
120
- chat_template = MessagesFormatterType.MISTRAL
121
-
122
- agent = LlamaCppAgent(
123
- provider,
124
- system_prompt=f"{system_message}",
125
- predefined_messages_formatter_type=chat_template,
126
- debug_output=True
127
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
- settings = provider.get_provider_default_settings()
130
- settings.temperature = temperature
131
- settings.top_k = top_k
132
- settings.top_p = top_p
133
- settings.max_tokens = max_tokens
134
- settings.repeat_penalty = repeat_penalty
135
- settings.stream = True
136
-
137
- messages = BasicChatHistory()
138
-
139
- for msn in history:
140
- user = {
141
- 'role': Roles.user,
142
- 'content': msn[0]
143
- }
144
- assistant = {
145
- 'role': Roles.assistant,
146
- 'content': msn[1]
147
- }
148
- messages.add_message(user)
149
- messages.add_message(assistant)
150
 
151
- stream = agent.get_chat_response(
152
- message,
153
- llm_sampling_settings=settings,
154
- chat_history=messages,
155
- returns_streaming_generator=True,
156
- print_output=False
157
- )
158
-
159
- outputs = ""
160
- for output in stream:
161
- outputs += output
162
- yield outputs
163
-
164
- description = """<p><center>
165
- <a href="https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407" target="_blank">[Instruct Model]</a>
166
- <a href="https://huggingface.co/mistralai/Mistral-Nemo-Base-2407" target="_blank">[Base Model]</a>
167
- <a href="https://huggingface.co/second-state/Mistral-Nemo-Instruct-2407-GGUF" target="_blank">[GGUF Version]</a>
168
- </center></p>
169
- """
170
-
171
- demo = gr.ChatInterface(
172
- respond,
173
- additional_inputs=[
174
- gr.Textbox(value="You are a helpful assistant.", label="System message"),
175
- gr.Slider(minimum=1, maximum=4096, value=2048, step=1, label="Max tokens"),
176
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
177
- gr.Slider(
178
- minimum=0.1,
179
- maximum=1.0,
180
- value=0.95,
181
- step=0.05,
182
- label="Top-p",
183
- ),
184
- gr.Slider(
185
- minimum=0,
186
- maximum=100,
187
- value=40,
188
- step=1,
189
- label="Top-k",
190
- ),
191
- gr.Slider(
192
- minimum=0.0,
193
- maximum=2.0,
194
- value=1.1,
195
- step=0.1,
196
- label="Repetition penalty",
197
- ),
198
  ],
199
- retry_btn="Retry",
200
- undo_btn="Undo",
201
- clear_btn="Clear",
202
- submit_btn="Send",
203
- title="Chat with Mistral-NeMo using llama.cpp",
204
- description=description,
205
- chatbot=gr.Chatbot(
206
- scale=1,
207
- likeable=False,
208
- show_copy_button=True
209
- )
210
  )
211
 
212
- if __name__ == "__main__":
213
- demo.launch(debug=True)
 
1
  import spaces
 
 
2
  import os
3
+ import requests
4
+ import yaml
5
+ import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  import gradio as gr
7
+ from PIL import Image
8
+ import sys
9
+ sys.path.append(os.path.abspath('./'))
10
+ from inference.utils import *
11
+ from core.utils import load_or_fail
12
+ from train import WurstCoreB
13
+ from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
14
+ from train import WurstCore_t2i as WurstCoreC
15
+ import torch.nn.functional as F
16
+ from core.utils import load_or_fail
17
+ import numpy as np
18
+ import random
19
+ import math
20
+ from einops import rearrange
21
  from huggingface_hub import hf_hub_download
22
 
 
 
 
 
 
23
 
24
+ def download_file(url, folder_path, filename):
25
+ if not os.path.exists(folder_path):
26
+ os.makedirs(folder_path)
27
+ file_path = os.path.join(folder_path, filename)
28
+
29
+ if os.path.isfile(file_path):
30
+ print(f"File already exists: {file_path}")
31
+ else:
32
+ response = requests.get(url, stream=True)
33
+ if response.status_code == 200:
34
+ with open(file_path, 'wb') as file:
35
+ for chunk in response.iter_content(chunk_size=1024):
36
+ file.write(chunk)
37
+ print(f"File successfully downloaded and saved: {file_path}")
38
+ else:
39
+ print(f"Error downloading the file. Status code: {response.status_code}")
40
+
41
+ def download_models():
42
+ models = {
43
+ "STABLEWURST_A": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_a.safetensors?download=true", "models/StableWurst", "stage_a.safetensors"),
44
+ "STABLEWURST_PREVIEWER": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/previewer.safetensors?download=true", "models/StableWurst", "previewer.safetensors"),
45
+ "STABLEWURST_EFFNET": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/effnet_encoder.safetensors?download=true", "models/StableWurst", "effnet_encoder.safetensors"),
46
+ "STABLEWURST_B_LITE": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b_lite_bf16.safetensors?download=true", "models/StableWurst", "stage_b_lite_bf16.safetensors"),
47
+ "STABLEWURST_C": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c_bf16.safetensors?download=true", "models/StableWurst", "stage_c_bf16.safetensors"),
48
+ "ULTRAPIXEL_T2I": ("https://huggingface.co/roubaofeipi/UltraPixel/resolve/main/ultrapixel_t2i.safetensors?download=true", "models/UltraPixel", "ultrapixel_t2i.safetensors"),
49
+ "ULTRAPIXEL_LORA_CAT": ("https://huggingface.co/roubaofeipi/UltraPixel/resolve/main/lora_cat.safetensors?download=true", "models/UltraPixel", "lora_cat.safetensors"),
50
+ }
51
+
52
+ for model, (url, folder, filename) in models.items():
53
+ download_file(url, folder, filename)
54
+
55
+ download_models()
56
+
57
+ # Global variables
58
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
59
+ dtype = torch.bfloat16
60
+
61
+ # Load configs and setup models
62
+ with open("configs/training/t2i.yaml", "r", encoding="utf-8") as file:
63
+ config_c = yaml.safe_load(file)
64
+
65
+ with open("configs/inference/stage_b_1b.yaml", "r", encoding="utf-8") as file:
66
+ config_b = yaml.safe_load(file)
67
+
68
+ core = WurstCoreC(config_dict=config_c, device=device, training=False)
69
+ core_b = WurstCoreB(config_dict=config_b, device=device, training=False)
70
+
71
+ extras = core.setup_extras_pre()
72
+ models = core.setup_models(extras)
73
+ models.generator.eval().requires_grad_(False)
74
+
75
+ extras_b = core_b.setup_extras_pre()
76
+ models_b = core_b.setup_models(extras_b, skip_clip=True)
77
+ models_b = WurstCoreB.Models(
78
+ **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model}
79
  )
80
+ models_b.generator.bfloat16().eval().requires_grad_(False)
81
+
82
+ # Load pretrained model
83
+ pretrained_path = "models/ultrapixel_t2i.safetensors"
84
+ sdd = torch.load(pretrained_path, map_location='cpu')
85
+ collect_sd = {k[7:]: v for k, v in sdd.items()}
86
+ models.train_norm.load_state_dict(collect_sd)
87
+ models.generator.eval()
88
+ models.train_norm.eval()
89
+
90
+ # Set up sampling configurations
91
+ extras.sampling_configs.update({
92
+ 'cfg': 4,
93
+ 'shift': 1,
94
+ 'timesteps': 20,
95
+ 't_start': 1.0,
96
+ 'sampler': DDPMSampler(extras.gdf)
97
+ })
98
+
99
+ extras_b.sampling_configs.update({
100
+ 'cfg': 1.1,
101
+ 'shift': 1,
102
+ 'timesteps': 10,
103
+ 't_start': 1.0
104
+ })
105
+
106
+ @spaces.GPU
107
+ def generate_image(prompt, height, width, seed):
108
+ torch.manual_seed(seed)
109
+ random.seed(seed)
110
+ np.random.seed(seed)
111
+
112
+ batch_size = 1
113
+ height_lr, width_lr = get_target_lr_size(height / width, std_size=32)
114
+ stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
115
+ stage_c_latent_shape_lr, stage_b_latent_shape_lr = calculate_latent_sizes(height_lr, width_lr, batch_size=batch_size)
116
+
117
+ batch = {'captions': [prompt] * batch_size}
118
+ conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False)
119
+ unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
120
 
121
+ conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
122
+ unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
+ with torch.no_grad():
125
+ models.generator.cuda()
126
+ with torch.cuda.amp.autocast(dtype=dtype):
127
+ sampled_c = generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device)
128
+
129
+ models.generator.cpu()
130
+ torch.cuda.empty_cache()
131
+
132
+ conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
133
+ unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
134
+ conditions_b['effnet'] = sampled_c
135
+ unconditions_b['effnet'] = torch.zeros_like(sampled_c)
136
+
137
+ with torch.cuda.amp.autocast(dtype=dtype):
138
+ sampled = decode_b(conditions_b, unconditions_b, models_b, stage_b_latent_shape, extras_b, device, stage_a_tiled=True)
139
+
140
+ torch.cuda.empty_cache()
141
+ imgs = show_images(sampled)
142
+ return imgs[0]
143
+
144
+ iface = gr.Interface(
145
+ fn=generate_image,
146
+ inputs=[
147
+ gr.Textbox(label="Prompt"),
148
+ gr.Slider(minimum=256, maximum=2560, step=32, label="Height", value=1024),
149
+ gr.Slider(minimum=256, maximum=5120, step=32, label="Width", value=1024),
150
+ gr.Number(label="Seed", value=42)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  ],
152
+ outputs=gr.Image(type="pil"),
153
+ title="UltraPixel Image Generation",
154
+ description="Generate high-resolution images using UltraPixel model.",
155
+ theme='bethecloud/storj_theme'
 
 
 
 
 
 
 
156
  )
157
 
158
+ iface.launch()
 
configs/inference/controlnet_c_3b_canny.yaml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GLOBAL STUFF
2
+ model_version: 3.6B
3
+ dtype: bfloat16
4
+
5
+ # ControlNet specific
6
+ controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
7
+ controlnet_filter: CannyFilter
8
+ controlnet_filter_params:
9
+ resize: 224
10
+
11
+ effnet_checkpoint_path: models/effnet_encoder.safetensors
12
+ previewer_checkpoint_path: models/previewer.safetensors
13
+ generator_checkpoint_path: models/stage_c_bf16.safetensors
14
+ controlnet_checkpoint_path: models/canny.safetensors
configs/inference/controlnet_c_3b_identity.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GLOBAL STUFF
2
+ model_version: 3.6B
3
+ dtype: bfloat16
4
+
5
+ # ControlNet specific
6
+ controlnet_bottleneck_mode: 'simple'
7
+ controlnet_blocks: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]
8
+ controlnet_filter: IdentityFilter
9
+ controlnet_filter_params:
10
+ max_faces: 4
11
+ p_drop: 0.00
12
+ p_full: 0.0
13
+
14
+ effnet_checkpoint_path: models/effnet_encoder.safetensors
15
+ previewer_checkpoint_path: models/previewer.safetensors
16
+ generator_checkpoint_path: models/stage_c_bf16.safetensors
17
+ controlnet_checkpoint_path:
configs/inference/controlnet_c_3b_inpainting.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GLOBAL STUFF
2
+ model_version: 3.6B
3
+ dtype: bfloat16
4
+
5
+ # ControlNet specific
6
+ controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
7
+ controlnet_filter: InpaintFilter
8
+ controlnet_filter_params:
9
+ thresold: [0.04, 0.4]
10
+ p_outpaint: 0.4
11
+
12
+ effnet_checkpoint_path: models/effnet_encoder.safetensors
13
+ previewer_checkpoint_path: models/previewer.safetensors
14
+ generator_checkpoint_path: models/stage_c_bf16.safetensors
15
+ controlnet_checkpoint_path: models/inpainting.safetensors
configs/inference/controlnet_c_3b_sr.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GLOBAL STUFF
2
+ model_version: 3.6B
3
+ dtype: bfloat16
4
+
5
+ # ControlNet specific
6
+ controlnet_bottleneck_mode: 'large'
7
+ controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
8
+ controlnet_filter: SREffnetFilter
9
+ controlnet_filter_params:
10
+ scale_factor: 0.5
11
+
12
+ effnet_checkpoint_path: models/effnet_encoder.safetensors
13
+ previewer_checkpoint_path: models/previewer.safetensors
14
+ generator_checkpoint_path: models/stage_c_bf16.safetensors
15
+ controlnet_checkpoint_path: models/super_resolution.safetensors
configs/inference/lora_c_3b.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GLOBAL STUFF
2
+ model_version: 3.6B
3
+ dtype: bfloat16
4
+
5
+ # LoRA specific
6
+ module_filters: ['.attn']
7
+ rank: 4
8
+ train_tokens:
9
+ # - ['^snail', null] # token starts with "snail" -> "snail" & "snails", don't need to be reinitialized
10
+ - ['[fernando]', '^dog</w>'] # custom token [snail], initialize as avg of snail & snails
11
+
12
+ effnet_checkpoint_path: models/effnet_encoder.safetensors
13
+ previewer_checkpoint_path: models/previewer.safetensors
14
+ generator_checkpoint_path: models/stage_c_bf16.safetensors
15
+ lora_checkpoint_path: models/lora_fernando_10k.safetensors
configs/inference/stage_b_1b.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GLOBAL STUFF
2
+ model_version: 700M
3
+ dtype: bfloat16
4
+
5
+ # For demonstration purposes in reconstruct_images.ipynb
6
+ webdataset_path: path to your dataset
7
+ batch_size: 1
8
+ image_size: 2048
9
+ grad_accum_steps: 1
10
+
11
+ effnet_checkpoint_path: models/effnet_encoder.safetensors
12
+ stage_a_checkpoint_path: models/stage_a.safetensors
13
+ generator_checkpoint_path: models/stage_b_lite_bf16.safetensors
configs/inference/stage_b_3b.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GLOBAL STUFF
2
+ model_version: 3B
3
+ dtype: bfloat16
4
+
5
+ # For demonstration purposes in reconstruct_images.ipynb
6
+ webdataset_path: path to your dataset
7
+ batch_size: 4
8
+ image_size: 1024
9
+ grad_accum_steps: 1
10
+
11
+ effnet_checkpoint_path: models/effnet_encoder.safetensors
12
+ stage_a_checkpoint_path: models/stage_a.safetensors
13
+ generator_checkpoint_path: models/stage_b_lite_bf16.safetensors
configs/inference/stage_c_1b.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # GLOBAL STUFF
2
+ model_version: 1B
3
+ dtype: bfloat16
4
+
5
+ effnet_checkpoint_path: models/effnet_encoder.safetensors
6
+ previewer_checkpoint_path: models/previewer.safetensors
7
+ generator_checkpoint_path: models/stage_c_lite_bf16.safetensors
configs/inference/stage_c_3b.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # GLOBAL STUFF
2
+ model_version: 3.6B
3
+ dtype: bfloat16
4
+
5
+ effnet_checkpoint_path: models/effnet_encoder.safetensors
6
+ previewer_checkpoint_path: models/previewer.safetensors
7
+ generator_checkpoint_path: models/stage_c_bf16.safetensors
configs/training/cfg_control_lr.yaml ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GLOBAL STUFF
2
+ experiment_id: Ultrapixel_controlnet
3
+
4
+ checkpoint_path: checkpoint output path
5
+ output_path: visual results output path
6
+ model_version: 3.6B
7
+ dtype: float32
8
+ # # WandB
9
+ # wandb_project: StableCascade
10
+ # wandb_entity: wandb_username
11
+ #module_filters: ['.depthwise', '.mapper', '.attn', '.channelwise' ]
12
+ #rank: 32
13
+ # TRAINING PARAMS
14
+ lr: 1.0e-4
15
+ batch_size: 12
16
+ #image_size: [1536, 2048, 2560, 3072, 4096]
17
+ image_size: [1024, 2048, 2560, 3072, 3584, 3840, 4096, 4608]
18
+ #image_size: [ 1024, 1536, 2048, 2560, 3072, 3584, 3840, 4096, 4608]
19
+ #image_size: [ 1024, 1280]
20
+ multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
21
+ grad_accum_steps: 2
22
+ updates: 40000
23
+ backup_every: 5000
24
+ save_every: 256
25
+ warmup_updates: 1
26
+ use_fsdp: True
27
+
28
+ # ControlNet specific
29
+ controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
30
+ controlnet_filter: CannyFilter
31
+ controlnet_filter_params:
32
+ resize: 224
33
+ # offset_noise: 0.1
34
+
35
+ # GDF
36
+ adaptive_loss_weight: True
37
+
38
+ ema_start_iters: 10
39
+ ema_iters: 50
40
+ ema_beta: 0.9
41
+
42
+ webdataset_path: path to your training dataset
43
+ effnet_checkpoint_path: models/effnet_encoder.safetensors
44
+ previewer_checkpoint_path: models/previewer.safetensors
45
+ generator_checkpoint_path: models/stage_c_bf16.safetensors
46
+ controlnet_checkpoint_path: pretrained controlnet path
47
+
configs/training/lora_personalization.yaml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GLOBAL STUFF
2
+ experiment_id: roubao_cat_personalized
3
+
4
+ checkpoint_path: checkpoint output path
5
+ output_path: visual results output path
6
+ model_version: 3.6B
7
+ dtype: float32
8
+
9
+ module_filters: [ '.attn']
10
+ rank: 4
11
+ train_tokens:
12
+ # - ['^snail', null] # token starts with "snail" -> "snail" & "snails", don't need to be reinitialized
13
+ - ['[roubaobao]', '^cat</w>'] # custom token [snail], initialize as avg of snail & snails
14
+ # TRAINING PARAMS
15
+ lr: 1.0e-4
16
+ batch_size: 4
17
+
18
+ image_size: [1024, 2048, 2560, 3072, 3584, 3840, 4096, 4608]
19
+ multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
20
+ grad_accum_steps: 2
21
+ updates: 40000
22
+ backup_every: 5000
23
+ save_every: 512
24
+ warmup_updates: 1
25
+ use_ddp: True
26
+
27
+ # GDF
28
+ adaptive_loss_weight: True
29
+
30
+
31
+ tmp_prompt: a photo of a cat [roubaobao]
32
+ webdataset_path: path to your personalized training dataset
33
+ effnet_checkpoint_path: models/effnet_encoder.safetensors
34
+ previewer_checkpoint_path: models/previewer.safetensors
35
+ generator_checkpoint_path: models/stage_c_bf16.safetensors
36
+ ultrapixel_path: models/ultrapixel_t2i.safetensors
37
+
configs/training/t2i.yaml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GLOBAL STUFF
2
+ experiment_id: ultrapixel_t2i
3
+ #strc_fixlrt_norm3_lite_1024_hrft_newdata
4
+ checkpoint_path: checkpoint output path #output model directory
5
+ output_path: visual results output path #experiment output directory
6
+ model_version: 3.6B # finetune large stage c model of stablecascade
7
+ dtype: float32
8
+
9
+
10
+ # TRAINING PARAMS
11
+ lr: 1.0e-4
12
+ batch_size: 4 # gpu_number * num_per_gpu * grad_accum_steps
13
+ image_size: [1024, 2048, 2560, 3072, 3584, 3840, 4096, 4608] # possible image resolution
14
+ multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
15
+ grad_accum_steps: 2
16
+ updates: 40000
17
+ backup_every: 5000
18
+ save_every: 256
19
+ warmup_updates: 1
20
+ use_ddp: True
21
+
22
+ # GDF
23
+ adaptive_loss_weight: True
24
+
25
+
26
+ webdataset_path: path to your personalized training dataset
27
+ effnet_checkpoint_path: models/effnet_encoder.safetensors
28
+ previewer_checkpoint_path: models/previewer.safetensors
29
+ generator_checkpoint_path: models/stage_c_bf16.safetensors
core/__init__.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import yaml
3
+ import torch
4
+ from torch import nn
5
+ import wandb
6
+ import json
7
+ from abc import ABC, abstractmethod
8
+ from dataclasses import dataclass
9
+ from torch.utils.data import Dataset, DataLoader
10
+
11
+ from torch.distributed import init_process_group, destroy_process_group, barrier
12
+ from torch.distributed.fsdp import (
13
+ FullyShardedDataParallel as FSDP,
14
+ FullStateDictConfig,
15
+ MixedPrecision,
16
+ ShardingStrategy,
17
+ StateDictType
18
+ )
19
+
20
+ from .utils import Base, EXPECTED, EXPECTED_TRAIN
21
+ from .utils import create_folder_if_necessary, safe_save, load_or_fail
22
+
23
+ # pylint: disable=unused-argument
24
+ class WarpCore(ABC):
25
+ @dataclass(frozen=True)
26
+ class Config(Base):
27
+ experiment_id: str = EXPECTED_TRAIN
28
+ checkpoint_path: str = EXPECTED_TRAIN
29
+ output_path: str = EXPECTED_TRAIN
30
+ checkpoint_extension: str = "safetensors"
31
+ dist_file_subfolder: str = ""
32
+ allow_tf32: bool = True
33
+
34
+ wandb_project: str = None
35
+ wandb_entity: str = None
36
+
37
+ @dataclass() # not frozen, means that fields are mutable
38
+ class Info(): # not inheriting from Base, because we don't want to enforce the default fields
39
+ wandb_run_id: str = None
40
+ total_steps: int = 0
41
+ iter: int = 0
42
+
43
+ @dataclass(frozen=True)
44
+ class Data(Base):
45
+ dataset: Dataset = EXPECTED
46
+ dataloader: DataLoader = EXPECTED
47
+ iterator: any = EXPECTED
48
+
49
+ @dataclass(frozen=True)
50
+ class Models(Base):
51
+ pass
52
+
53
+ @dataclass(frozen=True)
54
+ class Optimizers(Base):
55
+ pass
56
+
57
+ @dataclass(frozen=True)
58
+ class Schedulers(Base):
59
+ pass
60
+
61
+ @dataclass(frozen=True)
62
+ class Extras(Base):
63
+ pass
64
+ # ---------------------------------------
65
+ info: Info
66
+ config: Config
67
+
68
+ # FSDP stuff
69
+ fsdp_defaults = {
70
+ "sharding_strategy": ShardingStrategy.SHARD_GRAD_OP,
71
+ "cpu_offload": None,
72
+ "mixed_precision": MixedPrecision(
73
+ param_dtype=torch.bfloat16,
74
+ reduce_dtype=torch.bfloat16,
75
+ buffer_dtype=torch.bfloat16,
76
+ ),
77
+ "limit_all_gathers": True,
78
+ }
79
+ fsdp_fullstate_save_policy = FullStateDictConfig(
80
+ offload_to_cpu=True, rank0_only=True
81
+ )
82
+ # ------------
83
+
84
+ # OVERRIDEABLE METHODS
85
+
86
+ # [optionally] setup extra stuff, will be called BEFORE the models & optimizers are setup
87
+ def setup_extras_pre(self) -> Extras:
88
+ return self.Extras()
89
+
90
+ # setup dataset & dataloader, return a dict contained dataser, dataloader and/or iterator
91
+ @abstractmethod
92
+ def setup_data(self, extras: Extras) -> Data:
93
+ raise NotImplementedError("This method needs to be overriden")
94
+
95
+ # return a dict with all models that are going to be used in the training
96
+ @abstractmethod
97
+ def setup_models(self, extras: Extras) -> Models:
98
+ raise NotImplementedError("This method needs to be overriden")
99
+
100
+ # return a dict with all optimizers that are going to be used in the training
101
+ @abstractmethod
102
+ def setup_optimizers(self, extras: Extras, models: Models) -> Optimizers:
103
+ raise NotImplementedError("This method needs to be overriden")
104
+
105
+ # [optionally] return a dict with all schedulers that are going to be used in the training
106
+ def setup_schedulers(self, extras: Extras, models: Models, optimizers: Optimizers) -> Schedulers:
107
+ return self.Schedulers()
108
+
109
+ # [optionally] setup extra stuff, will be called AFTER the models & optimizers are setup
110
+ def setup_extras_post(self, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers) -> Extras:
111
+ return self.Extras.from_dict(extras.to_dict())
112
+
113
+ # perform the training here
114
+ @abstractmethod
115
+ def train(self, data: Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers):
116
+ raise NotImplementedError("This method needs to be overriden")
117
+ # ------------
118
+
119
+ def setup_info(self, full_path=None) -> Info:
120
+ if full_path is None:
121
+ full_path = (f"{self.config.checkpoint_path}/{self.config.experiment_id}/info.json")
122
+ info_dict = load_or_fail(full_path, wandb_run_id=None) or {}
123
+ info_dto = self.Info(**info_dict)
124
+ if info_dto.total_steps > 0 and self.is_main_node:
125
+ print(">>> RESUMING TRAINING FROM ITER ", info_dto.total_steps)
126
+ return info_dto
127
+
128
+ def setup_config(self, config_file_path=None, config_dict=None, training=True) -> Config:
129
+ if config_file_path is not None:
130
+ if config_file_path.endswith(".yml") or config_file_path.endswith(".yaml"):
131
+ with open(config_file_path, "r", encoding="utf-8") as file:
132
+ loaded_config = yaml.safe_load(file)
133
+ elif config_file_path.endswith(".json"):
134
+ with open(config_file_path, "r", encoding="utf-8") as file:
135
+ loaded_config = json.load(file)
136
+ else:
137
+ raise ValueError("Config file must be either a .yml|.yaml or .json file")
138
+ return self.Config.from_dict({**loaded_config, 'training': training})
139
+ if config_dict is not None:
140
+ return self.Config.from_dict({**config_dict, 'training': training})
141
+ return self.Config(training=training)
142
+
143
+ def setup_ddp(self, experiment_id, single_gpu=False):
144
+ if not single_gpu:
145
+ local_rank = int(os.environ.get("SLURM_LOCALID"))
146
+ process_id = int(os.environ.get("SLURM_PROCID"))
147
+ world_size = int(os.environ.get("SLURM_NNODES")) * torch.cuda.device_count()
148
+
149
+ self.process_id = process_id
150
+ self.is_main_node = process_id == 0
151
+ self.device = torch.device(local_rank)
152
+ self.world_size = world_size
153
+
154
+ dist_file_path = f"{os.getcwd()}/{self.config.dist_file_subfolder}dist_file_{experiment_id}"
155
+ # if os.path.exists(dist_file_path) and self.is_main_node:
156
+ # os.remove(dist_file_path)
157
+
158
+ torch.cuda.set_device(local_rank)
159
+ init_process_group(
160
+ backend="nccl",
161
+ rank=process_id,
162
+ world_size=world_size,
163
+ init_method=f"file://{dist_file_path}",
164
+ )
165
+ print(f"[GPU {process_id}] READY")
166
+ else:
167
+ print("Running in single thread, DDP not enabled.")
168
+
169
+ def setup_wandb(self):
170
+ if self.is_main_node and self.config.wandb_project is not None:
171
+ self.info.wandb_run_id = self.info.wandb_run_id or wandb.util.generate_id()
172
+ wandb.init(project=self.config.wandb_project, entity=self.config.wandb_entity, name=self.config.experiment_id, id=self.info.wandb_run_id, resume="allow", config=self.config.to_dict())
173
+
174
+ if self.info.total_steps > 0:
175
+ wandb.alert(title=f"Training {self.info.wandb_run_id} resumed", text=f"Training {self.info.wandb_run_id} resumed from step {self.info.total_steps}")
176
+ else:
177
+ wandb.alert(title=f"Training {self.info.wandb_run_id} started", text=f"Training {self.info.wandb_run_id} started")
178
+
179
+ # LOAD UTILITIES ----------
180
+ def load_model(self, model, model_id=None, full_path=None, strict=True):
181
+ print('in line 181 load model', type(model), model_id, full_path, strict)
182
+ if model_id is not None and full_path is None:
183
+ full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{model_id}.{self.config.checkpoint_extension}"
184
+ elif full_path is None and model_id is None:
185
+ raise ValueError(
186
+ "This method expects either 'model_id' or 'full_path' to be defined"
187
+ )
188
+
189
+ checkpoint = load_or_fail(full_path, wandb_run_id=self.info.wandb_run_id if self.is_main_node else None)
190
+ if checkpoint is not None:
191
+ model.load_state_dict(checkpoint, strict=strict)
192
+ del checkpoint
193
+
194
+ return model
195
+
196
+ def load_optimizer(self, optim, optim_id=None, full_path=None, fsdp_model=None):
197
+ if optim_id is not None and full_path is None:
198
+ full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{optim_id}.pt"
199
+ elif full_path is None and optim_id is None:
200
+ raise ValueError(
201
+ "This method expects either 'optim_id' or 'full_path' to be defined"
202
+ )
203
+
204
+ checkpoint = load_or_fail(full_path, wandb_run_id=self.info.wandb_run_id if self.is_main_node else None)
205
+ if checkpoint is not None:
206
+ try:
207
+ if fsdp_model is not None:
208
+ sharded_optimizer_state_dict = (
209
+ FSDP.scatter_full_optim_state_dict( # <---- FSDP
210
+ checkpoint
211
+ if (
212
+ self.is_main_node
213
+ or self.fsdp_defaults["sharding_strategy"]
214
+ == ShardingStrategy.NO_SHARD
215
+ )
216
+ else None,
217
+ fsdp_model,
218
+ )
219
+ )
220
+ optim.load_state_dict(sharded_optimizer_state_dict)
221
+ del checkpoint, sharded_optimizer_state_dict
222
+ else:
223
+ optim.load_state_dict(checkpoint)
224
+ # pylint: disable=broad-except
225
+ except Exception as e:
226
+ print("!!! Failed loading optimizer, skipping... Exception:", e)
227
+
228
+ return optim
229
+
230
+ # SAVE UTILITIES ----------
231
+ def save_info(self, info, suffix=""):
232
+ full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/info{suffix}.json"
233
+ create_folder_if_necessary(full_path)
234
+ if self.is_main_node:
235
+ safe_save(vars(self.info), full_path)
236
+
237
+ def save_model(self, model, model_id=None, full_path=None, is_fsdp=False):
238
+ if model_id is not None and full_path is None:
239
+ full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{model_id}.{self.config.checkpoint_extension}"
240
+ elif full_path is None and model_id is None:
241
+ raise ValueError(
242
+ "This method expects either 'model_id' or 'full_path' to be defined"
243
+ )
244
+ create_folder_if_necessary(full_path)
245
+ if is_fsdp:
246
+ with FSDP.summon_full_params(model):
247
+ pass
248
+ with FSDP.state_dict_type(
249
+ model, StateDictType.FULL_STATE_DICT, self.fsdp_fullstate_save_policy
250
+ ):
251
+ checkpoint = model.state_dict()
252
+ if self.is_main_node:
253
+ safe_save(checkpoint, full_path)
254
+ del checkpoint
255
+ else:
256
+ if self.is_main_node:
257
+ checkpoint = model.state_dict()
258
+ safe_save(checkpoint, full_path)
259
+ del checkpoint
260
+
261
+ def save_optimizer(self, optim, optim_id=None, full_path=None, fsdp_model=None):
262
+ if optim_id is not None and full_path is None:
263
+ full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{optim_id}.pt"
264
+ elif full_path is None and optim_id is None:
265
+ raise ValueError(
266
+ "This method expects either 'optim_id' or 'full_path' to be defined"
267
+ )
268
+ create_folder_if_necessary(full_path)
269
+ if fsdp_model is not None:
270
+ optim_statedict = FSDP.full_optim_state_dict(fsdp_model, optim)
271
+ if self.is_main_node:
272
+ safe_save(optim_statedict, full_path)
273
+ del optim_statedict
274
+ else:
275
+ if self.is_main_node:
276
+ checkpoint = optim.state_dict()
277
+ safe_save(checkpoint, full_path)
278
+ del checkpoint
279
+ # -----
280
+
281
+ def __init__(self, config_file_path=None, config_dict=None, device="cpu", training=True):
282
+ # Temporary setup, will be overriden by setup_ddp if required
283
+ self.device = device
284
+ self.process_id = 0
285
+ self.is_main_node = True
286
+ self.world_size = 1
287
+ # ----
288
+
289
+ self.config: self.Config = self.setup_config(config_file_path, config_dict, training)
290
+ self.info: self.Info = self.setup_info()
291
+
292
+ def __call__(self, single_gpu=False):
293
+ self.setup_ddp(self.config.experiment_id, single_gpu=single_gpu) # this will change the device to the CUDA rank
294
+ self.setup_wandb()
295
+ if self.config.allow_tf32:
296
+ torch.backends.cuda.matmul.allow_tf32 = True
297
+ torch.backends.cudnn.allow_tf32 = True
298
+
299
+ if self.is_main_node:
300
+ print()
301
+ print("**STARTIG JOB WITH CONFIG:**")
302
+ print(yaml.dump(self.config.to_dict(), default_flow_style=False))
303
+ print("------------------------------------")
304
+ print()
305
+ print("**INFO:**")
306
+ print(yaml.dump(vars(self.info), default_flow_style=False))
307
+ print("------------------------------------")
308
+ print()
309
+
310
+ # SETUP STUFF
311
+ extras = self.setup_extras_pre()
312
+ assert extras is not None, "setup_extras_pre() must return a DTO"
313
+
314
+ data = self.setup_data(extras)
315
+ assert data is not None, "setup_data() must return a DTO"
316
+ if self.is_main_node:
317
+ print("**DATA:**")
318
+ print(yaml.dump({k:type(v).__name__ for k, v in data.to_dict().items()}, default_flow_style=False))
319
+ print("------------------------------------")
320
+ print()
321
+
322
+ models = self.setup_models(extras)
323
+ assert models is not None, "setup_models() must return a DTO"
324
+ if self.is_main_node:
325
+ print("**MODELS:**")
326
+ print(yaml.dump({
327
+ k:f"{type(v).__name__} - {f'trainable params {sum(p.numel() for p in v.parameters() if p.requires_grad)}' if isinstance(v, nn.Module) else 'Not a nn.Module'}" for k, v in models.to_dict().items()
328
+ }, default_flow_style=False))
329
+ print("------------------------------------")
330
+ print()
331
+
332
+ optimizers = self.setup_optimizers(extras, models)
333
+ assert optimizers is not None, "setup_optimizers() must return a DTO"
334
+ if self.is_main_node:
335
+ print("**OPTIMIZERS:**")
336
+ print(yaml.dump({k:type(v).__name__ for k, v in optimizers.to_dict().items()}, default_flow_style=False))
337
+ print("------------------------------------")
338
+ print()
339
+
340
+ schedulers = self.setup_schedulers(extras, models, optimizers)
341
+ assert schedulers is not None, "setup_schedulers() must return a DTO"
342
+ if self.is_main_node:
343
+ print("**SCHEDULERS:**")
344
+ print(yaml.dump({k:type(v).__name__ for k, v in schedulers.to_dict().items()}, default_flow_style=False))
345
+ print("------------------------------------")
346
+ print()
347
+
348
+ post_extras =self.setup_extras_post(extras, models, optimizers, schedulers)
349
+ assert post_extras is not None, "setup_extras_post() must return a DTO"
350
+ extras = self.Extras.from_dict({ **extras.to_dict(),**post_extras.to_dict() })
351
+ if self.is_main_node:
352
+ print("**EXTRAS:**")
353
+ print(yaml.dump({k:f"{v}" for k, v in extras.to_dict().items()}, default_flow_style=False))
354
+ print("------------------------------------")
355
+ print()
356
+ # -------
357
+
358
+ # TRAIN
359
+ if self.is_main_node:
360
+ print("**TRAINING STARTING...**")
361
+ self.train(data, extras, models, optimizers, schedulers)
362
+
363
+ if single_gpu is False:
364
+ barrier()
365
+ destroy_process_group()
366
+ if self.is_main_node:
367
+ print()
368
+ print("------------------------------------")
369
+ print()
370
+ print("**TRAINING COMPLETE**")
371
+ if self.config.wandb_project is not None:
372
+ wandb.alert(title=f"Training {self.info.wandb_run_id} finished", text=f"Training {self.info.wandb_run_id} finished")
core/data/__init__.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import subprocess
3
+ import yaml
4
+ import os
5
+ from .bucketeer import Bucketeer
6
+
7
+ class MultiFilter():
8
+ def __init__(self, rules, default=False):
9
+ self.rules = rules
10
+ self.default = default
11
+
12
+ def __call__(self, x):
13
+ try:
14
+ x_json = x['json']
15
+ if isinstance(x_json, bytes):
16
+ x_json = json.loads(x_json)
17
+ validations = []
18
+ for k, r in self.rules.items():
19
+ if isinstance(k, tuple):
20
+ v = r(*[x_json[kv] for kv in k])
21
+ else:
22
+ v = r(x_json[k])
23
+ validations.append(v)
24
+ return all(validations)
25
+ except Exception:
26
+ return False
27
+
28
+ class MultiGetter():
29
+ def __init__(self, rules):
30
+ self.rules = rules
31
+
32
+ def __call__(self, x_json):
33
+ if isinstance(x_json, bytes):
34
+ x_json = json.loads(x_json)
35
+ outputs = []
36
+ for k, r in self.rules.items():
37
+ if isinstance(k, tuple):
38
+ v = r(*[x_json[kv] for kv in k])
39
+ else:
40
+ v = r(x_json[k])
41
+ outputs.append(v)
42
+ if len(outputs) == 1:
43
+ outputs = outputs[0]
44
+ return outputs
45
+
46
+ def setup_webdataset_path(paths, cache_path=None):
47
+ if cache_path is None or not os.path.exists(cache_path):
48
+ tar_paths = []
49
+ if isinstance(paths, str):
50
+ paths = [paths]
51
+ for path in paths:
52
+ if path.strip().endswith(".tar"):
53
+ # Avoid looking up s3 if we already have a tar file
54
+ tar_paths.append(path)
55
+ continue
56
+ bucket = "/".join(path.split("/")[:3])
57
+ result = subprocess.run([f"aws s3 ls {path} --recursive | awk '{{print $4}}'"], stdout=subprocess.PIPE, shell=True, check=True)
58
+ files = result.stdout.decode('utf-8').split()
59
+ files = [f"{bucket}/{f}" for f in files if f.endswith(".tar")]
60
+ tar_paths += files
61
+
62
+ with open(cache_path, 'w', encoding='utf-8') as outfile:
63
+ yaml.dump(tar_paths, outfile, default_flow_style=False)
64
+ else:
65
+ with open(cache_path, 'r', encoding='utf-8') as file:
66
+ tar_paths = yaml.safe_load(file)
67
+
68
+ tar_paths_str = ",".join([f"{p}" for p in tar_paths])
69
+ return f"pipe:aws s3 cp {{ {tar_paths_str} }} -"
core/data/bucketeer.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ import numpy as np
4
+ from torchtools.transforms import SmartCrop
5
+ import math
6
+
7
+ class Bucketeer():
8
+ def __init__(self, dataloader, density=256*256, factor=8, ratios=[1/1, 1/2, 3/4, 3/5, 4/5, 6/9, 9/16], reverse_list=True, randomize_p=0.3, randomize_q=0.2, crop_mode='random', p_random_ratio=0.0, interpolate_nearest=False):
9
+ assert crop_mode in ['center', 'random', 'smart']
10
+ self.crop_mode = crop_mode
11
+ self.ratios = ratios
12
+ if reverse_list:
13
+ for r in list(ratios):
14
+ if 1/r not in self.ratios:
15
+ self.ratios.append(1/r)
16
+ self.sizes = {}
17
+ for dd in density:
18
+ self.sizes[dd]= [(int(((dd/r)**0.5//factor)*factor), int(((dd*r)**0.5//factor)*factor)) for r in ratios]
19
+
20
+ self.batch_size = dataloader.batch_size
21
+ self.iterator = iter(dataloader)
22
+ all_sizes = []
23
+ for k, vs in self.sizes.items():
24
+ all_sizes += vs
25
+ self.buckets = {s: [] for s in all_sizes}
26
+ self.smartcrop = SmartCrop(int(density**0.5), randomize_p, randomize_q) if self.crop_mode=='smart' else None
27
+ self.p_random_ratio = p_random_ratio
28
+ self.interpolate_nearest = interpolate_nearest
29
+
30
+ def get_available_batch(self):
31
+ for b in self.buckets:
32
+ if len(self.buckets[b]) >= self.batch_size:
33
+ batch = self.buckets[b][:self.batch_size]
34
+ self.buckets[b] = self.buckets[b][self.batch_size:]
35
+ return batch
36
+ return None
37
+
38
+ def get_closest_size(self, x):
39
+ w, h = x.size(-1), x.size(-2)
40
+
41
+
42
+ best_size_idx = np.argmin([abs(w/h-r) for r in self.ratios])
43
+ find_dict = {dd : abs(w*h - self.sizes[dd][best_size_idx][0]*self.sizes[dd][best_size_idx][1]) for dd, vv in self.sizes.items()}
44
+ min_ = find_dict[list(find_dict.keys())[0]]
45
+ find_size = self.sizes[list(find_dict.keys())[0]][best_size_idx]
46
+ for dd, val in find_dict.items():
47
+ if val < min_:
48
+ min_ = val
49
+ find_size = self.sizes[dd][best_size_idx]
50
+
51
+ return find_size
52
+
53
+ def get_resize_size(self, orig_size, tgt_size):
54
+ if (tgt_size[1]/tgt_size[0] - 1) * (orig_size[1]/orig_size[0] - 1) >= 0:
55
+ alt_min = int(math.ceil(max(tgt_size)*min(orig_size)/max(orig_size)))
56
+ resize_size = max(alt_min, min(tgt_size))
57
+ else:
58
+ alt_max = int(math.ceil(min(tgt_size)*max(orig_size)/min(orig_size)))
59
+ resize_size = max(alt_max, max(tgt_size))
60
+
61
+ return resize_size
62
+
63
+ def __next__(self):
64
+ batch = self.get_available_batch()
65
+ while batch is None:
66
+ elements = next(self.iterator)
67
+ for dct in elements:
68
+ img = dct['images']
69
+ size = self.get_closest_size(img)
70
+ resize_size = self.get_resize_size(img.shape[-2:], size)
71
+
72
+ if self.interpolate_nearest:
73
+ img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST)
74
+ else:
75
+ img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True)
76
+ if self.crop_mode == 'center':
77
+ img = torchvision.transforms.functional.center_crop(img, size)
78
+ elif self.crop_mode == 'random':
79
+ img = torchvision.transforms.RandomCrop(size)(img)
80
+ elif self.crop_mode == 'smart':
81
+ self.smartcrop.output_size = size
82
+ img = self.smartcrop(img)
83
+
84
+ self.buckets[size].append({**{'images': img}, **{k:dct[k] for k in dct if k != 'images'}})
85
+ batch = self.get_available_batch()
86
+
87
+ out = {k:[batch[i][k] for i in range(len(batch))] for k in batch[0]}
88
+ return {k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()}
core/data/bucketeer_deg.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ import numpy as np
4
+ from torchtools.transforms import SmartCrop
5
+ import math
6
+
7
+ class Bucketeer():
8
+ def __init__(self, dataloader, density=256*256, factor=8, ratios=[1/1, 1/2, 3/4, 3/5, 4/5, 6/9, 9/16], reverse_list=True, randomize_p=0.3, randomize_q=0.2, crop_mode='random', p_random_ratio=0.0, interpolate_nearest=False):
9
+ assert crop_mode in ['center', 'random', 'smart']
10
+ self.crop_mode = crop_mode
11
+ self.ratios = ratios
12
+ if reverse_list:
13
+ for r in list(ratios):
14
+ if 1/r not in self.ratios:
15
+ self.ratios.append(1/r)
16
+ self.sizes = {}
17
+ for dd in density:
18
+ self.sizes[dd]= [(int(((dd/r)**0.5//factor)*factor), int(((dd*r)**0.5//factor)*factor)) for r in ratios]
19
+ print('in line 17 buckteer', self.sizes)
20
+ self.batch_size = dataloader.batch_size
21
+ self.iterator = iter(dataloader)
22
+ all_sizes = []
23
+ for k, vs in self.sizes.items():
24
+ all_sizes += vs
25
+ self.buckets = {s: [] for s in all_sizes}
26
+ self.smartcrop = SmartCrop(int(density**0.5), randomize_p, randomize_q) if self.crop_mode=='smart' else None
27
+ self.p_random_ratio = p_random_ratio
28
+ self.interpolate_nearest = interpolate_nearest
29
+
30
+ def get_available_batch(self):
31
+ for b in self.buckets:
32
+ if len(self.buckets[b]) >= self.batch_size:
33
+ batch = self.buckets[b][:self.batch_size]
34
+ self.buckets[b] = self.buckets[b][self.batch_size:]
35
+ return batch
36
+ return None
37
+
38
+ def get_closest_size(self, x):
39
+ w, h = x.size(-1), x.size(-2)
40
+ #if self.p_random_ratio > 0 and np.random.rand() < self.p_random_ratio:
41
+ # best_size_idx = np.random.randint(len(self.ratios))
42
+ #print('in line 41 get closes size', best_size_idx, x.shape, self.p_random_ratio)
43
+ #else:
44
+
45
+ best_size_idx = np.argmin([abs(w/h-r) for r in self.ratios])
46
+ find_dict = {dd : abs(w*h - self.sizes[dd][best_size_idx][0]*self.sizes[dd][best_size_idx][1]) for dd, vv in self.sizes.items()}
47
+ min_ = find_dict[list(find_dict.keys())[0]]
48
+ find_size = self.sizes[list(find_dict.keys())[0]][best_size_idx]
49
+ for dd, val in find_dict.items():
50
+ if val < min_:
51
+ min_ = val
52
+ find_size = self.sizes[dd][best_size_idx]
53
+
54
+ return find_size
55
+
56
+ def get_resize_size(self, orig_size, tgt_size):
57
+ if (tgt_size[1]/tgt_size[0] - 1) * (orig_size[1]/orig_size[0] - 1) >= 0:
58
+ alt_min = int(math.ceil(max(tgt_size)*min(orig_size)/max(orig_size)))
59
+ resize_size = max(alt_min, min(tgt_size))
60
+ else:
61
+ alt_max = int(math.ceil(min(tgt_size)*max(orig_size)/min(orig_size)))
62
+ resize_size = max(alt_max, max(tgt_size))
63
+ #print('in line 50', orig_size, tgt_size, resize_size)
64
+ return resize_size
65
+
66
+ def __next__(self):
67
+ batch = self.get_available_batch()
68
+ while batch is None:
69
+ elements = next(self.iterator)
70
+ for dct in elements:
71
+ img = dct['images']
72
+ size = self.get_closest_size(img)
73
+ resize_size = self.get_resize_size(img.shape[-2:], size)
74
+ #print('in line 74', img.size(), resize_size)
75
+ if self.interpolate_nearest:
76
+ img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST)
77
+ else:
78
+ img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True)
79
+ if self.crop_mode == 'center':
80
+ img = torchvision.transforms.functional.center_crop(img, size)
81
+ elif self.crop_mode == 'random':
82
+ img = torchvision.transforms.RandomCrop(size)(img)
83
+ elif self.crop_mode == 'smart':
84
+ self.smartcrop.output_size = size
85
+ img = self.smartcrop(img)
86
+ print('in line 86 bucketeer', type(img), img.shape, torch.max(img), torch.min(img))
87
+ self.buckets[size].append({**{'images': img}, **{k:dct[k] for k in dct if k != 'images'}})
88
+ batch = self.get_available_batch()
89
+
90
+ out = {k:[batch[i][k] for i in range(len(batch))] for k in batch[0]}
91
+ return {k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()}
core/data/deg_kair_utils/utils_alignfaces.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Mon Apr 24 15:43:29 2017
4
+ @author: zhaoy
5
+ """
6
+ import cv2
7
+ import numpy as np
8
+ from skimage import transform as trans
9
+
10
+ # reference facial points, a list of coordinates (x,y)
11
+ REFERENCE_FACIAL_POINTS = [
12
+ [30.29459953, 51.69630051],
13
+ [65.53179932, 51.50139999],
14
+ [48.02519989, 71.73660278],
15
+ [33.54930115, 92.3655014],
16
+ [62.72990036, 92.20410156]
17
+ ]
18
+
19
+ DEFAULT_CROP_SIZE = (96, 112)
20
+
21
+
22
+ def _umeyama(src, dst, estimate_scale=True, scale=1.0):
23
+ """Estimate N-D similarity transformation with or without scaling.
24
+ Parameters
25
+ ----------
26
+ src : (M, N) array
27
+ Source coordinates.
28
+ dst : (M, N) array
29
+ Destination coordinates.
30
+ estimate_scale : bool
31
+ Whether to estimate scaling factor.
32
+ Returns
33
+ -------
34
+ T : (N + 1, N + 1)
35
+ The homogeneous similarity transformation matrix. The matrix contains
36
+ NaN values only if the problem is not well-conditioned.
37
+ References
38
+ ----------
39
+ .. [1] "Least-squares estimation of transformation parameters between two
40
+ point patterns", Shinji Umeyama, PAMI 1991, :DOI:`10.1109/34.88573`
41
+ """
42
+
43
+ num = src.shape[0]
44
+ dim = src.shape[1]
45
+
46
+ # Compute mean of src and dst.
47
+ src_mean = src.mean(axis=0)
48
+ dst_mean = dst.mean(axis=0)
49
+
50
+ # Subtract mean from src and dst.
51
+ src_demean = src - src_mean
52
+ dst_demean = dst - dst_mean
53
+
54
+ # Eq. (38).
55
+ A = dst_demean.T @ src_demean / num
56
+
57
+ # Eq. (39).
58
+ d = np.ones((dim,), dtype=np.double)
59
+ if np.linalg.det(A) < 0:
60
+ d[dim - 1] = -1
61
+
62
+ T = np.eye(dim + 1, dtype=np.double)
63
+
64
+ U, S, V = np.linalg.svd(A)
65
+
66
+ # Eq. (40) and (43).
67
+ rank = np.linalg.matrix_rank(A)
68
+ if rank == 0:
69
+ return np.nan * T
70
+ elif rank == dim - 1:
71
+ if np.linalg.det(U) * np.linalg.det(V) > 0:
72
+ T[:dim, :dim] = U @ V
73
+ else:
74
+ s = d[dim - 1]
75
+ d[dim - 1] = -1
76
+ T[:dim, :dim] = U @ np.diag(d) @ V
77
+ d[dim - 1] = s
78
+ else:
79
+ T[:dim, :dim] = U @ np.diag(d) @ V
80
+
81
+ if estimate_scale:
82
+ # Eq. (41) and (42).
83
+ scale = 1.0 / src_demean.var(axis=0).sum() * (S @ d)
84
+ else:
85
+ scale = scale
86
+
87
+ T[:dim, dim] = dst_mean - scale * (T[:dim, :dim] @ src_mean.T)
88
+ T[:dim, :dim] *= scale
89
+
90
+ return T, scale
91
+
92
+
93
+ class FaceWarpException(Exception):
94
+ def __str__(self):
95
+ return 'In File {}:{}'.format(
96
+ __file__, super.__str__(self))
97
+
98
+
99
+ def get_reference_facial_points(output_size=None,
100
+ inner_padding_factor=0.0,
101
+ outer_padding=(0, 0),
102
+ default_square=False):
103
+ tmp_5pts = np.array(REFERENCE_FACIAL_POINTS)
104
+ tmp_crop_size = np.array(DEFAULT_CROP_SIZE)
105
+
106
+ # 0) make the inner region a square
107
+ if default_square:
108
+ size_diff = max(tmp_crop_size) - tmp_crop_size
109
+ tmp_5pts += size_diff / 2
110
+ tmp_crop_size += size_diff
111
+
112
+ if (output_size and
113
+ output_size[0] == tmp_crop_size[0] and
114
+ output_size[1] == tmp_crop_size[1]):
115
+ print('output_size == DEFAULT_CROP_SIZE {}: return default reference points'.format(tmp_crop_size))
116
+ return tmp_5pts
117
+
118
+ if (inner_padding_factor == 0 and
119
+ outer_padding == (0, 0)):
120
+ if output_size is None:
121
+ print('No paddings to do: return default reference points')
122
+ return tmp_5pts
123
+ else:
124
+ raise FaceWarpException(
125
+ 'No paddings to do, output_size must be None or {}'.format(tmp_crop_size))
126
+
127
+ # check output size
128
+ if not (0 <= inner_padding_factor <= 1.0):
129
+ raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)')
130
+
131
+ if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0)
132
+ and output_size is None):
133
+ output_size = tmp_crop_size * \
134
+ (1 + inner_padding_factor * 2).astype(np.int32)
135
+ output_size += np.array(outer_padding)
136
+ print(' deduced from paddings, output_size = ', output_size)
137
+
138
+ if not (outer_padding[0] < output_size[0]
139
+ and outer_padding[1] < output_size[1]):
140
+ raise FaceWarpException('Not (outer_padding[0] < output_size[0]'
141
+ 'and outer_padding[1] < output_size[1])')
142
+
143
+ # 1) pad the inner region according inner_padding_factor
144
+ # print('---> STEP1: pad the inner region according inner_padding_factor')
145
+ if inner_padding_factor > 0:
146
+ size_diff = tmp_crop_size * inner_padding_factor * 2
147
+ tmp_5pts += size_diff / 2
148
+ tmp_crop_size += np.round(size_diff).astype(np.int32)
149
+
150
+ # print(' crop_size = ', tmp_crop_size)
151
+ # print(' reference_5pts = ', tmp_5pts)
152
+
153
+ # 2) resize the padded inner region
154
+ # print('---> STEP2: resize the padded inner region')
155
+ size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2
156
+ # print(' crop_size = ', tmp_crop_size)
157
+ # print(' size_bf_outer_pad = ', size_bf_outer_pad)
158
+
159
+ if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]:
160
+ raise FaceWarpException('Must have (output_size - outer_padding)'
161
+ '= some_scale * (crop_size * (1.0 + inner_padding_factor)')
162
+
163
+ scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0]
164
+ # print(' resize scale_factor = ', scale_factor)
165
+ tmp_5pts = tmp_5pts * scale_factor
166
+ # size_diff = tmp_crop_size * (scale_factor - min(scale_factor))
167
+ # tmp_5pts = tmp_5pts + size_diff / 2
168
+ tmp_crop_size = size_bf_outer_pad
169
+ # print(' crop_size = ', tmp_crop_size)
170
+ # print(' reference_5pts = ', tmp_5pts)
171
+
172
+ # 3) add outer_padding to make output_size
173
+ reference_5point = tmp_5pts + np.array(outer_padding)
174
+ tmp_crop_size = output_size
175
+ # print('---> STEP3: add outer_padding to make output_size')
176
+ # print(' crop_size = ', tmp_crop_size)
177
+ # print(' reference_5pts = ', tmp_5pts)
178
+ #
179
+ # print('===> end get_reference_facial_points\n')
180
+
181
+ return reference_5point
182
+
183
+
184
+ def get_affine_transform_matrix(src_pts, dst_pts):
185
+ tfm = np.float32([[1, 0, 0], [0, 1, 0]])
186
+ n_pts = src_pts.shape[0]
187
+ ones = np.ones((n_pts, 1), src_pts.dtype)
188
+ src_pts_ = np.hstack([src_pts, ones])
189
+ dst_pts_ = np.hstack([dst_pts, ones])
190
+
191
+ A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_)
192
+
193
+ if rank == 3:
194
+ tfm = np.float32([
195
+ [A[0, 0], A[1, 0], A[2, 0]],
196
+ [A[0, 1], A[1, 1], A[2, 1]]
197
+ ])
198
+ elif rank == 2:
199
+ tfm = np.float32([
200
+ [A[0, 0], A[1, 0], 0],
201
+ [A[0, 1], A[1, 1], 0]
202
+ ])
203
+
204
+ return tfm
205
+
206
+
207
+ def warp_and_crop_face(src_img,
208
+ facial_pts,
209
+ reference_pts=None,
210
+ crop_size=(96, 112),
211
+ align_type='smilarity'): #smilarity cv2_affine affine
212
+ if reference_pts is None:
213
+ if crop_size[0] == 96 and crop_size[1] == 112:
214
+ reference_pts = REFERENCE_FACIAL_POINTS
215
+ else:
216
+ default_square = False
217
+ inner_padding_factor = 0
218
+ outer_padding = (0, 0)
219
+ output_size = crop_size
220
+
221
+ reference_pts = get_reference_facial_points(output_size,
222
+ inner_padding_factor,
223
+ outer_padding,
224
+ default_square)
225
+
226
+ ref_pts = np.float32(reference_pts)
227
+ ref_pts_shp = ref_pts.shape
228
+ if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2:
229
+ raise FaceWarpException(
230
+ 'reference_pts.shape must be (K,2) or (2,K) and K>2')
231
+
232
+ if ref_pts_shp[0] == 2:
233
+ ref_pts = ref_pts.T
234
+
235
+ src_pts = np.float32(facial_pts)
236
+ src_pts_shp = src_pts.shape
237
+ if max(src_pts_shp) < 3 or min(src_pts_shp) != 2:
238
+ raise FaceWarpException(
239
+ 'facial_pts.shape must be (K,2) or (2,K) and K>2')
240
+
241
+ if src_pts_shp[0] == 2:
242
+ src_pts = src_pts.T
243
+
244
+ if src_pts.shape != ref_pts.shape:
245
+ raise FaceWarpException(
246
+ 'facial_pts and reference_pts must have the same shape')
247
+
248
+ if align_type is 'cv2_affine':
249
+ tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3])
250
+ tfm_inv = cv2.getAffineTransform(ref_pts[0:3], src_pts[0:3])
251
+ elif align_type is 'affine':
252
+ tfm = get_affine_transform_matrix(src_pts, ref_pts)
253
+ tfm_inv = get_affine_transform_matrix(ref_pts, src_pts)
254
+ else:
255
+ params, scale = _umeyama(src_pts, ref_pts)
256
+ tfm = params[:2, :]
257
+
258
+ params, _ = _umeyama(ref_pts, src_pts, False, scale=1.0/scale)
259
+ tfm_inv = params[:2, :]
260
+
261
+ face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]), flags=3)
262
+
263
+ return face_img, tfm_inv
core/data/deg_kair_utils/utils_blindsr.py ADDED
@@ -0,0 +1,631 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import numpy as np
3
+ import cv2
4
+ import torch
5
+
6
+ from core.data.deg_kair_utils import utils_image as util
7
+
8
+ import random
9
+ from scipy import ndimage
10
+ import scipy
11
+ import scipy.stats as ss
12
+ from scipy.interpolate import interp2d
13
+ from scipy.linalg import orth
14
+
15
+
16
+
17
+
18
+ """
19
+ # --------------------------------------------
20
+ # Super-Resolution
21
+ # --------------------------------------------
22
+ #
23
+ # Kai Zhang ([email protected])
24
+ # https://github.com/cszn
25
+ # From 2019/03--2021/08
26
+ # --------------------------------------------
27
+ """
28
+
29
+ def modcrop_np(img, sf):
30
+ '''
31
+ Args:
32
+ img: numpy image, WxH or WxHxC
33
+ sf: scale factor
34
+
35
+ Return:
36
+ cropped image
37
+ '''
38
+ w, h = img.shape[:2]
39
+ im = np.copy(img)
40
+ return im[:w - w % sf, :h - h % sf, ...]
41
+
42
+
43
+ """
44
+ # --------------------------------------------
45
+ # anisotropic Gaussian kernels
46
+ # --------------------------------------------
47
+ """
48
+ def analytic_kernel(k):
49
+ """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
50
+ k_size = k.shape[0]
51
+ # Calculate the big kernels size
52
+ big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
53
+ # Loop over the small kernel to fill the big one
54
+ for r in range(k_size):
55
+ for c in range(k_size):
56
+ big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
57
+ # Crop the edges of the big kernel to ignore very small values and increase run time of SR
58
+ crop = k_size // 2
59
+ cropped_big_k = big_k[crop:-crop, crop:-crop]
60
+ # Normalize to 1
61
+ return cropped_big_k / cropped_big_k.sum()
62
+
63
+
64
+ def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
65
+ """ generate an anisotropic Gaussian kernel
66
+ Args:
67
+ ksize : e.g., 15, kernel size
68
+ theta : [0, pi], rotation angle range
69
+ l1 : [0.1,50], scaling of eigenvalues
70
+ l2 : [0.1,l1], scaling of eigenvalues
71
+ If l1 = l2, will get an isotropic Gaussian kernel.
72
+
73
+ Returns:
74
+ k : kernel
75
+ """
76
+
77
+ v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
78
+ V = np.array([[v[0], v[1]], [v[1], -v[0]]])
79
+ D = np.array([[l1, 0], [0, l2]])
80
+ Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
81
+ k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
82
+
83
+ return k
84
+
85
+
86
+ def gm_blur_kernel(mean, cov, size=15):
87
+ center = size / 2.0 + 0.5
88
+ k = np.zeros([size, size])
89
+ for y in range(size):
90
+ for x in range(size):
91
+ cy = y - center + 1
92
+ cx = x - center + 1
93
+ k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
94
+
95
+ k = k / np.sum(k)
96
+ return k
97
+
98
+
99
+ def shift_pixel(x, sf, upper_left=True):
100
+ """shift pixel for super-resolution with different scale factors
101
+ Args:
102
+ x: WxHxC or WxH
103
+ sf: scale factor
104
+ upper_left: shift direction
105
+ """
106
+ h, w = x.shape[:2]
107
+ shift = (sf-1)*0.5
108
+ xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
109
+ if upper_left:
110
+ x1 = xv + shift
111
+ y1 = yv + shift
112
+ else:
113
+ x1 = xv - shift
114
+ y1 = yv - shift
115
+
116
+ x1 = np.clip(x1, 0, w-1)
117
+ y1 = np.clip(y1, 0, h-1)
118
+
119
+ if x.ndim == 2:
120
+ x = interp2d(xv, yv, x)(x1, y1)
121
+ if x.ndim == 3:
122
+ for i in range(x.shape[-1]):
123
+ x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
124
+
125
+ return x
126
+
127
+
128
+ def blur(x, k):
129
+ '''
130
+ x: image, NxcxHxW
131
+ k: kernel, Nx1xhxw
132
+ '''
133
+ n, c = x.shape[:2]
134
+ p1, p2 = (k.shape[-2]-1)//2, (k.shape[-1]-1)//2
135
+ x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
136
+ k = k.repeat(1,c,1,1)
137
+ k = k.view(-1, 1, k.shape[2], k.shape[3])
138
+ x = x.view(1, -1, x.shape[2], x.shape[3])
139
+ x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n*c)
140
+ x = x.view(n, c, x.shape[2], x.shape[3])
141
+
142
+ return x
143
+
144
+
145
+
146
+ def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
147
+ """"
148
+ # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
149
+ # Kai Zhang
150
+ # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
151
+ # max_var = 2.5 * sf
152
+ """
153
+ # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
154
+ lambda_1 = min_var + np.random.rand() * (max_var - min_var)
155
+ lambda_2 = min_var + np.random.rand() * (max_var - min_var)
156
+ theta = np.random.rand() * np.pi # random theta
157
+ noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
158
+
159
+ # Set COV matrix using Lambdas and Theta
160
+ LAMBDA = np.diag([lambda_1, lambda_2])
161
+ Q = np.array([[np.cos(theta), -np.sin(theta)],
162
+ [np.sin(theta), np.cos(theta)]])
163
+ SIGMA = Q @ LAMBDA @ Q.T
164
+ INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
165
+
166
+ # Set expectation position (shifting kernel for aligned image)
167
+ MU = k_size // 2 - 0.5*(scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
168
+ MU = MU[None, None, :, None]
169
+
170
+ # Create meshgrid for Gaussian
171
+ [X,Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
172
+ Z = np.stack([X, Y], 2)[:, :, :, None]
173
+
174
+ # Calcualte Gaussian for every pixel of the kernel
175
+ ZZ = Z-MU
176
+ ZZ_t = ZZ.transpose(0,1,3,2)
177
+ raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
178
+
179
+ # shift the kernel so it will be centered
180
+ #raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
181
+
182
+ # Normalize the kernel and return
183
+ #kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
184
+ kernel = raw_kernel / np.sum(raw_kernel)
185
+ return kernel
186
+
187
+
188
+ def fspecial_gaussian(hsize, sigma):
189
+ hsize = [hsize, hsize]
190
+ siz = [(hsize[0]-1.0)/2.0, (hsize[1]-1.0)/2.0]
191
+ std = sigma
192
+ [x, y] = np.meshgrid(np.arange(-siz[1], siz[1]+1), np.arange(-siz[0], siz[0]+1))
193
+ arg = -(x*x + y*y)/(2*std*std)
194
+ h = np.exp(arg)
195
+ h[h < scipy.finfo(float).eps * h.max()] = 0
196
+ sumh = h.sum()
197
+ if sumh != 0:
198
+ h = h/sumh
199
+ return h
200
+
201
+
202
+ def fspecial_laplacian(alpha):
203
+ alpha = max([0, min([alpha,1])])
204
+ h1 = alpha/(alpha+1)
205
+ h2 = (1-alpha)/(alpha+1)
206
+ h = [[h1, h2, h1], [h2, -4/(alpha+1), h2], [h1, h2, h1]]
207
+ h = np.array(h)
208
+ return h
209
+
210
+
211
+ def fspecial(filter_type, *args, **kwargs):
212
+ '''
213
+ python code from:
214
+ https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
215
+ '''
216
+ if filter_type == 'gaussian':
217
+ return fspecial_gaussian(*args, **kwargs)
218
+ if filter_type == 'laplacian':
219
+ return fspecial_laplacian(*args, **kwargs)
220
+
221
+ """
222
+ # --------------------------------------------
223
+ # degradation models
224
+ # --------------------------------------------
225
+ """
226
+
227
+
228
+ def bicubic_degradation(x, sf=3):
229
+ '''
230
+ Args:
231
+ x: HxWxC image, [0, 1]
232
+ sf: down-scale factor
233
+
234
+ Return:
235
+ bicubicly downsampled LR image
236
+ '''
237
+ x = util.imresize_np(x, scale=1/sf)
238
+ return x
239
+
240
+
241
+ def srmd_degradation(x, k, sf=3):
242
+ ''' blur + bicubic downsampling
243
+
244
+ Args:
245
+ x: HxWxC image, [0, 1]
246
+ k: hxw, double
247
+ sf: down-scale factor
248
+
249
+ Return:
250
+ downsampled LR image
251
+
252
+ Reference:
253
+ @inproceedings{zhang2018learning,
254
+ title={Learning a single convolutional super-resolution network for multiple degradations},
255
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
256
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
257
+ pages={3262--3271},
258
+ year={2018}
259
+ }
260
+ '''
261
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
262
+ x = bicubic_degradation(x, sf=sf)
263
+ return x
264
+
265
+
266
+ def dpsr_degradation(x, k, sf=3):
267
+
268
+ ''' bicubic downsampling + blur
269
+
270
+ Args:
271
+ x: HxWxC image, [0, 1]
272
+ k: hxw, double
273
+ sf: down-scale factor
274
+
275
+ Return:
276
+ downsampled LR image
277
+
278
+ Reference:
279
+ @inproceedings{zhang2019deep,
280
+ title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
281
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
282
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
283
+ pages={1671--1681},
284
+ year={2019}
285
+ }
286
+ '''
287
+ x = bicubic_degradation(x, sf=sf)
288
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
289
+ return x
290
+
291
+
292
+ def classical_degradation(x, k, sf=3):
293
+ ''' blur + downsampling
294
+
295
+ Args:
296
+ x: HxWxC image, [0, 1]/[0, 255]
297
+ k: hxw, double
298
+ sf: down-scale factor
299
+
300
+ Return:
301
+ downsampled LR image
302
+ '''
303
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
304
+ #x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
305
+ st = 0
306
+ return x[st::sf, st::sf, ...]
307
+
308
+
309
+ def add_sharpening(img, weight=0.5, radius=50, threshold=10):
310
+ """USM sharpening. borrowed from real-ESRGAN
311
+ Input image: I; Blurry image: B.
312
+ 1. K = I + weight * (I - B)
313
+ 2. Mask = 1 if abs(I - B) > threshold, else: 0
314
+ 3. Blur mask:
315
+ 4. Out = Mask * K + (1 - Mask) * I
316
+ Args:
317
+ img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
318
+ weight (float): Sharp weight. Default: 1.
319
+ radius (float): Kernel size of Gaussian blur. Default: 50.
320
+ threshold (int):
321
+ """
322
+ if radius % 2 == 0:
323
+ radius += 1
324
+ blur = cv2.GaussianBlur(img, (radius, radius), 0)
325
+ residual = img - blur
326
+ mask = np.abs(residual) * 255 > threshold
327
+ mask = mask.astype('float32')
328
+ soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
329
+
330
+ K = img + weight * residual
331
+ K = np.clip(K, 0, 1)
332
+ return soft_mask * K + (1 - soft_mask) * img
333
+
334
+
335
+ def add_blur(img, sf=4):
336
+ wd2 = 4.0 + sf
337
+ wd = 2.0 + 0.2*sf
338
+ if random.random() < 0.5:
339
+ l1 = wd2*random.random()
340
+ l2 = wd2*random.random()
341
+ k = anisotropic_Gaussian(ksize=2*random.randint(2,11)+3, theta=random.random()*np.pi, l1=l1, l2=l2)
342
+ else:
343
+ k = fspecial('gaussian', 2*random.randint(2,11)+3, wd*random.random())
344
+ img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
345
+
346
+ return img
347
+
348
+
349
+ def add_resize(img, sf=4):
350
+ rnum = np.random.rand()
351
+ if rnum > 0.8: # up
352
+ sf1 = random.uniform(1, 2)
353
+ elif rnum < 0.7: # down
354
+ sf1 = random.uniform(0.5/sf, 1)
355
+ else:
356
+ sf1 = 1.0
357
+ img = cv2.resize(img, (int(sf1*img.shape[1]), int(sf1*img.shape[0])), interpolation=random.choice([1, 2, 3]))
358
+ img = np.clip(img, 0.0, 1.0)
359
+
360
+ return img
361
+
362
+
363
+ def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
364
+ noise_level = random.randint(noise_level1, noise_level2)
365
+ rnum = np.random.rand()
366
+ if rnum > 0.6: # add color Gaussian noise
367
+ img += np.random.normal(0, noise_level/255.0, img.shape).astype(np.float32)
368
+ elif rnum < 0.4: # add grayscale Gaussian noise
369
+ img += np.random.normal(0, noise_level/255.0, (*img.shape[:2], 1)).astype(np.float32)
370
+ else: # add noise
371
+ L = noise_level2/255.
372
+ D = np.diag(np.random.rand(3))
373
+ U = orth(np.random.rand(3,3))
374
+ conv = np.dot(np.dot(np.transpose(U), D), U)
375
+ img += np.random.multivariate_normal([0,0,0], np.abs(L**2*conv), img.shape[:2]).astype(np.float32)
376
+ img = np.clip(img, 0.0, 1.0)
377
+ return img
378
+
379
+
380
+ def add_speckle_noise(img, noise_level1=2, noise_level2=25):
381
+ noise_level = random.randint(noise_level1, noise_level2)
382
+ img = np.clip(img, 0.0, 1.0)
383
+ rnum = random.random()
384
+ if rnum > 0.6:
385
+ img += img*np.random.normal(0, noise_level/255.0, img.shape).astype(np.float32)
386
+ elif rnum < 0.4:
387
+ img += img*np.random.normal(0, noise_level/255.0, (*img.shape[:2], 1)).astype(np.float32)
388
+ else:
389
+ L = noise_level2/255.
390
+ D = np.diag(np.random.rand(3))
391
+ U = orth(np.random.rand(3,3))
392
+ conv = np.dot(np.dot(np.transpose(U), D), U)
393
+ img += img*np.random.multivariate_normal([0,0,0], np.abs(L**2*conv), img.shape[:2]).astype(np.float32)
394
+ img = np.clip(img, 0.0, 1.0)
395
+ return img
396
+
397
+
398
+ def add_Poisson_noise(img):
399
+ img = np.clip((img * 255.0).round(), 0, 255) / 255.
400
+ vals = 10**(2*random.random()+2.0) # [2, 4]
401
+ if random.random() < 0.5:
402
+ img = np.random.poisson(img * vals).astype(np.float32) / vals
403
+ else:
404
+ img_gray = np.dot(img[...,:3], [0.299, 0.587, 0.114])
405
+ img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
406
+ noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
407
+ img += noise_gray[:, :, np.newaxis]
408
+ img = np.clip(img, 0.0, 1.0)
409
+ return img
410
+
411
+
412
+ def add_JPEG_noise(img):
413
+ quality_factor = random.randint(30, 95)
414
+ img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
415
+ result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
416
+ img = cv2.imdecode(encimg, 1)
417
+ img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
418
+ return img
419
+
420
+
421
+ def random_crop(lq, hq, sf=4, lq_patchsize=64):
422
+ h, w = lq.shape[:2]
423
+ rnd_h = random.randint(0, h-lq_patchsize)
424
+ rnd_w = random.randint(0, w-lq_patchsize)
425
+ lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
426
+
427
+ rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
428
+ hq = hq[rnd_h_H:rnd_h_H + lq_patchsize*sf, rnd_w_H:rnd_w_H + lq_patchsize*sf, :]
429
+ return lq, hq
430
+
431
+
432
+ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
433
+ """
434
+ This is the degradation model of BSRGAN from the paper
435
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
436
+ ----------
437
+ img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
438
+ sf: scale factor
439
+ isp_model: camera ISP model
440
+
441
+ Returns
442
+ -------
443
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
444
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
445
+ """
446
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
447
+ sf_ori = sf
448
+
449
+ h1, w1 = img.shape[:2]
450
+ img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...] # mod crop
451
+ h, w = img.shape[:2]
452
+
453
+ if h < lq_patchsize*sf or w < lq_patchsize*sf:
454
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
455
+
456
+ hq = img.copy()
457
+
458
+ if sf == 4 and random.random() < scale2_prob: # downsample1
459
+ if np.random.rand() < 0.5:
460
+ img = cv2.resize(img, (int(1/2*img.shape[1]), int(1/2*img.shape[0])), interpolation=random.choice([1,2,3]))
461
+ else:
462
+ img = util.imresize_np(img, 1/2, True)
463
+ img = np.clip(img, 0.0, 1.0)
464
+ sf = 2
465
+
466
+ shuffle_order = random.sample(range(7), 7)
467
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
468
+ if idx1 > idx2: # keep downsample3 last
469
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
470
+
471
+ for i in shuffle_order:
472
+
473
+ if i == 0:
474
+ img = add_blur(img, sf=sf)
475
+
476
+ elif i == 1:
477
+ img = add_blur(img, sf=sf)
478
+
479
+ elif i == 2:
480
+ a, b = img.shape[1], img.shape[0]
481
+ # downsample2
482
+ if random.random() < 0.75:
483
+ sf1 = random.uniform(1,2*sf)
484
+ img = cv2.resize(img, (int(1/sf1*img.shape[1]), int(1/sf1*img.shape[0])), interpolation=random.choice([1,2,3]))
485
+ else:
486
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6*sf))
487
+ k_shifted = shift_pixel(k, sf)
488
+ k_shifted = k_shifted/k_shifted.sum() # blur with shifted kernel
489
+ img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
490
+ img = img[0::sf, 0::sf, ...] # nearest downsampling
491
+ img = np.clip(img, 0.0, 1.0)
492
+
493
+ elif i == 3:
494
+ # downsample3
495
+ img = cv2.resize(img, (int(1/sf*a), int(1/sf*b)), interpolation=random.choice([1,2,3]))
496
+ img = np.clip(img, 0.0, 1.0)
497
+
498
+ elif i == 4:
499
+ # add Gaussian noise
500
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
501
+
502
+ elif i == 5:
503
+ # add JPEG noise
504
+ if random.random() < jpeg_prob:
505
+ img = add_JPEG_noise(img)
506
+
507
+ elif i == 6:
508
+ # add processed camera sensor noise
509
+ if random.random() < isp_prob and isp_model is not None:
510
+ with torch.no_grad():
511
+ img, hq = isp_model.forward(img.copy(), hq)
512
+
513
+ # add final JPEG compression noise
514
+ img = add_JPEG_noise(img)
515
+
516
+ # random crop
517
+ img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
518
+
519
+ return img, hq
520
+
521
+
522
+
523
+
524
+ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=False, lq_patchsize=64, isp_model=None):
525
+ """
526
+ This is an extended degradation model by combining
527
+ the degradation models of BSRGAN and Real-ESRGAN
528
+ ----------
529
+ img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
530
+ sf: scale factor
531
+ use_shuffle: the degradation shuffle
532
+ use_sharp: sharpening the img
533
+
534
+ Returns
535
+ -------
536
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
537
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
538
+ """
539
+
540
+ h1, w1 = img.shape[:2]
541
+ img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...] # mod crop
542
+ h, w = img.shape[:2]
543
+
544
+ if h < lq_patchsize*sf or w < lq_patchsize*sf:
545
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
546
+
547
+ if use_sharp:
548
+ img = add_sharpening(img)
549
+ hq = img.copy()
550
+
551
+ if random.random() < shuffle_prob:
552
+ shuffle_order = random.sample(range(13), 13)
553
+ else:
554
+ shuffle_order = list(range(13))
555
+ # local shuffle for noise, JPEG is always the last one
556
+ shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
557
+ shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))
558
+
559
+ poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1
560
+
561
+ for i in shuffle_order:
562
+ if i == 0:
563
+ img = add_blur(img, sf=sf)
564
+ elif i == 1:
565
+ img = add_resize(img, sf=sf)
566
+ elif i == 2:
567
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
568
+ elif i == 3:
569
+ if random.random() < poisson_prob:
570
+ img = add_Poisson_noise(img)
571
+ elif i == 4:
572
+ if random.random() < speckle_prob:
573
+ img = add_speckle_noise(img)
574
+ elif i == 5:
575
+ if random.random() < isp_prob and isp_model is not None:
576
+ with torch.no_grad():
577
+ img, hq = isp_model.forward(img.copy(), hq)
578
+ elif i == 6:
579
+ img = add_JPEG_noise(img)
580
+ elif i == 7:
581
+ img = add_blur(img, sf=sf)
582
+ elif i == 8:
583
+ img = add_resize(img, sf=sf)
584
+ elif i == 9:
585
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
586
+ elif i == 10:
587
+ if random.random() < poisson_prob:
588
+ img = add_Poisson_noise(img)
589
+ elif i == 11:
590
+ if random.random() < speckle_prob:
591
+ img = add_speckle_noise(img)
592
+ elif i == 12:
593
+ if random.random() < isp_prob and isp_model is not None:
594
+ with torch.no_grad():
595
+ img, hq = isp_model.forward(img.copy(), hq)
596
+ else:
597
+ print('check the shuffle!')
598
+
599
+ # resize to desired size
600
+ img = cv2.resize(img, (int(1/sf*hq.shape[1]), int(1/sf*hq.shape[0])), interpolation=random.choice([1, 2, 3]))
601
+
602
+ # add final JPEG compression noise
603
+ img = add_JPEG_noise(img)
604
+
605
+ # random crop
606
+ img, hq = random_crop(img, hq, sf, lq_patchsize)
607
+
608
+ return img, hq
609
+
610
+
611
+
612
+ if __name__ == '__main__':
613
+ img = util.imread_uint('utils/test.png', 3)
614
+ img = util.uint2single(img)
615
+ sf = 4
616
+
617
+ for i in range(20):
618
+ img_lq, img_hq = degradation_bsrgan(img, sf=sf, lq_patchsize=72)
619
+ print(i)
620
+ lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf*img_lq.shape[1]), int(sf*img_lq.shape[0])), interpolation=0)
621
+ img_concat = np.concatenate([lq_nearest, util.single2uint(img_hq)], axis=1)
622
+ util.imsave(img_concat, str(i)+'.png')
623
+
624
+ # for i in range(10):
625
+ # img_lq, img_hq = degradation_bsrgan_plus(img, sf=sf, shuffle_prob=0.1, use_sharp=True, lq_patchsize=64)
626
+ # print(i)
627
+ # lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf*img_lq.shape[1]), int(sf*img_lq.shape[0])), interpolation=0)
628
+ # img_concat = np.concatenate([lq_nearest, util.single2uint(img_hq)], axis=1)
629
+ # util.imsave(img_concat, str(i)+'.png')
630
+
631
+ # run utils/utils_blindsr.py
core/data/deg_kair_utils/utils_bnorm.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ """
6
+ # --------------------------------------------
7
+ # Batch Normalization
8
+ # --------------------------------------------
9
+
10
+ # Kai Zhang ([email protected])
11
+ # https://github.com/cszn
12
+ # 01/Jan/2019
13
+ # --------------------------------------------
14
+ """
15
+
16
+
17
+ # --------------------------------------------
18
+ # remove/delete specified layer
19
+ # --------------------------------------------
20
+ def deleteLayer(model, layer_type=nn.BatchNorm2d):
21
+ ''' Kai Zhang, 11/Jan/2019.
22
+ '''
23
+ for k, m in list(model.named_children()):
24
+ if isinstance(m, layer_type):
25
+ del model._modules[k]
26
+ deleteLayer(m, layer_type)
27
+
28
+
29
+ # --------------------------------------------
30
+ # merge bn, "conv+bn" --> "conv"
31
+ # --------------------------------------------
32
+ def merge_bn(model):
33
+ ''' Kai Zhang, 11/Jan/2019.
34
+ merge all 'Conv+BN' (or 'TConv+BN') into 'Conv' (or 'TConv')
35
+ based on https://github.com/pytorch/pytorch/pull/901
36
+ '''
37
+ prev_m = None
38
+ for k, m in list(model.named_children()):
39
+ if (isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d)) and (isinstance(prev_m, nn.Conv2d) or isinstance(prev_m, nn.Linear) or isinstance(prev_m, nn.ConvTranspose2d)):
40
+
41
+ w = prev_m.weight.data
42
+
43
+ if prev_m.bias is None:
44
+ zeros = torch.Tensor(prev_m.out_channels).zero_().type(w.type())
45
+ prev_m.bias = nn.Parameter(zeros)
46
+ b = prev_m.bias.data
47
+
48
+ invstd = m.running_var.clone().add_(m.eps).pow_(-0.5)
49
+ if isinstance(prev_m, nn.ConvTranspose2d):
50
+ w.mul_(invstd.view(1, w.size(1), 1, 1).expand_as(w))
51
+ else:
52
+ w.mul_(invstd.view(w.size(0), 1, 1, 1).expand_as(w))
53
+ b.add_(-m.running_mean).mul_(invstd)
54
+ if m.affine:
55
+ if isinstance(prev_m, nn.ConvTranspose2d):
56
+ w.mul_(m.weight.data.view(1, w.size(1), 1, 1).expand_as(w))
57
+ else:
58
+ w.mul_(m.weight.data.view(w.size(0), 1, 1, 1).expand_as(w))
59
+ b.mul_(m.weight.data).add_(m.bias.data)
60
+
61
+ del model._modules[k]
62
+ prev_m = m
63
+ merge_bn(m)
64
+
65
+
66
+ # --------------------------------------------
67
+ # add bn, "conv" --> "conv+bn"
68
+ # --------------------------------------------
69
+ def add_bn(model):
70
+ ''' Kai Zhang, 11/Jan/2019.
71
+ '''
72
+ for k, m in list(model.named_children()):
73
+ if (isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose2d)):
74
+ b = nn.BatchNorm2d(m.out_channels, momentum=0.1, affine=True)
75
+ b.weight.data.fill_(1)
76
+ new_m = nn.Sequential(model._modules[k], b)
77
+ model._modules[k] = new_m
78
+ add_bn(m)
79
+
80
+
81
+ # --------------------------------------------
82
+ # tidy model after removing bn
83
+ # --------------------------------------------
84
+ def tidy_sequential(model):
85
+ ''' Kai Zhang, 11/Jan/2019.
86
+ '''
87
+ for k, m in list(model.named_children()):
88
+ if isinstance(m, nn.Sequential):
89
+ if m.__len__() == 1:
90
+ model._modules[k] = m.__getitem__(0)
91
+ tidy_sequential(m)
core/data/deg_kair_utils/utils_deblur.py ADDED
@@ -0,0 +1,655 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import numpy as np
3
+ import scipy
4
+ from scipy import fftpack
5
+ import torch
6
+
7
+ from math import cos, sin
8
+ from numpy import zeros, ones, prod, array, pi, log, min, mod, arange, sum, mgrid, exp, pad, round
9
+ from numpy.random import randn, rand
10
+ from scipy.signal import convolve2d
11
+ import cv2
12
+ import random
13
+ # import utils_image as util
14
+
15
+ '''
16
+ modified by Kai Zhang (github: https://github.com/cszn)
17
+ 03/03/2019
18
+ '''
19
+
20
+
21
+ def get_uperleft_denominator(img, kernel):
22
+ '''
23
+ img: HxWxC
24
+ kernel: hxw
25
+ denominator: HxWx1
26
+ upperleft: HxWxC
27
+ '''
28
+ V = psf2otf(kernel, img.shape[:2])
29
+ denominator = np.expand_dims(np.abs(V)**2, axis=2)
30
+ upperleft = np.expand_dims(np.conj(V), axis=2) * np.fft.fft2(img, axes=[0, 1])
31
+ return upperleft, denominator
32
+
33
+
34
+ def get_uperleft_denominator_pytorch(img, kernel):
35
+ '''
36
+ img: NxCxHxW
37
+ kernel: Nx1xhxw
38
+ denominator: Nx1xHxW
39
+ upperleft: NxCxHxWx2
40
+ '''
41
+ V = p2o(kernel, img.shape[-2:]) # Nx1xHxWx2
42
+ denominator = V[..., 0]**2+V[..., 1]**2 # Nx1xHxW
43
+ upperleft = cmul(cconj(V), rfft(img)) # Nx1xHxWx2 * NxCxHxWx2
44
+ return upperleft, denominator
45
+
46
+
47
+ def c2c(x):
48
+ return torch.from_numpy(np.stack([np.float32(x.real), np.float32(x.imag)], axis=-1))
49
+
50
+
51
+ def r2c(x):
52
+ return torch.stack([x, torch.zeros_like(x)], -1)
53
+
54
+
55
+ def cdiv(x, y):
56
+ a, b = x[..., 0], x[..., 1]
57
+ c, d = y[..., 0], y[..., 1]
58
+ cd2 = c**2 + d**2
59
+ return torch.stack([(a*c+b*d)/cd2, (b*c-a*d)/cd2], -1)
60
+
61
+
62
+ def cabs(x):
63
+ return torch.pow(x[..., 0]**2+x[..., 1]**2, 0.5)
64
+
65
+
66
+ def cmul(t1, t2):
67
+ '''
68
+ complex multiplication
69
+ t1: NxCxHxWx2
70
+ output: NxCxHxWx2
71
+ '''
72
+ real1, imag1 = t1[..., 0], t1[..., 1]
73
+ real2, imag2 = t2[..., 0], t2[..., 1]
74
+ return torch.stack([real1 * real2 - imag1 * imag2, real1 * imag2 + imag1 * real2], dim=-1)
75
+
76
+
77
+ def cconj(t, inplace=False):
78
+ '''
79
+ # complex's conjugation
80
+ t: NxCxHxWx2
81
+ output: NxCxHxWx2
82
+ '''
83
+ c = t.clone() if not inplace else t
84
+ c[..., 1] *= -1
85
+ return c
86
+
87
+
88
+ def rfft(t):
89
+ return torch.rfft(t, 2, onesided=False)
90
+
91
+
92
+ def irfft(t):
93
+ return torch.irfft(t, 2, onesided=False)
94
+
95
+
96
+ def fft(t):
97
+ return torch.fft(t, 2)
98
+
99
+
100
+ def ifft(t):
101
+ return torch.ifft(t, 2)
102
+
103
+
104
+ def p2o(psf, shape):
105
+ '''
106
+ # psf: NxCxhxw
107
+ # shape: [H,W]
108
+ # otf: NxCxHxWx2
109
+ '''
110
+ otf = torch.zeros(psf.shape[:-2] + shape).type_as(psf)
111
+ otf[...,:psf.shape[2],:psf.shape[3]].copy_(psf)
112
+ for axis, axis_size in enumerate(psf.shape[2:]):
113
+ otf = torch.roll(otf, -int(axis_size / 2), dims=axis+2)
114
+ otf = torch.rfft(otf, 2, onesided=False)
115
+ n_ops = torch.sum(torch.tensor(psf.shape).type_as(psf) * torch.log2(torch.tensor(psf.shape).type_as(psf)))
116
+ otf[...,1][torch.abs(otf[...,1])<n_ops*2.22e-16] = torch.tensor(0).type_as(psf)
117
+ return otf
118
+
119
+
120
+
121
+ # otf2psf: not sure where I got this one from. Maybe translated from Octave source code or whatever. It's just math.
122
+ def otf2psf(otf, outsize=None):
123
+ insize = np.array(otf.shape)
124
+ psf = np.fft.ifftn(otf, axes=(0, 1))
125
+ for axis, axis_size in enumerate(insize):
126
+ psf = np.roll(psf, np.floor(axis_size / 2).astype(int), axis=axis)
127
+ if type(outsize) != type(None):
128
+ insize = np.array(otf.shape)
129
+ outsize = np.array(outsize)
130
+ n = max(np.size(outsize), np.size(insize))
131
+ # outsize = postpad(outsize(:), n, 1);
132
+ # insize = postpad(insize(:) , n, 1);
133
+ colvec_out = outsize.flatten().reshape((np.size(outsize), 1))
134
+ colvec_in = insize.flatten().reshape((np.size(insize), 1))
135
+ outsize = np.pad(colvec_out, ((0, max(0, n - np.size(colvec_out))), (0, 0)), mode="constant")
136
+ insize = np.pad(colvec_in, ((0, max(0, n - np.size(colvec_in))), (0, 0)), mode="constant")
137
+
138
+ pad = (insize - outsize) / 2
139
+ if np.any(pad < 0):
140
+ print("otf2psf error: OUTSIZE must be smaller than or equal than OTF size")
141
+ prepad = np.floor(pad)
142
+ postpad = np.ceil(pad)
143
+ dims_start = prepad.astype(int)
144
+ dims_end = (insize - postpad).astype(int)
145
+ for i in range(len(dims_start.shape)):
146
+ psf = np.take(psf, range(dims_start[i][0], dims_end[i][0]), axis=i)
147
+ n_ops = np.sum(otf.size * np.log2(otf.shape))
148
+ psf = np.real_if_close(psf, tol=n_ops)
149
+ return psf
150
+
151
+
152
+ # psf2otf copied/modified from https://github.com/aboucaud/pypher/blob/master/pypher/pypher.py
153
+ def psf2otf(psf, shape=None):
154
+ """
155
+ Convert point-spread function to optical transfer function.
156
+ Compute the Fast Fourier Transform (FFT) of the point-spread
157
+ function (PSF) array and creates the optical transfer function (OTF)
158
+ array that is not influenced by the PSF off-centering.
159
+ By default, the OTF array is the same size as the PSF array.
160
+ To ensure that the OTF is not altered due to PSF off-centering, PSF2OTF
161
+ post-pads the PSF array (down or to the right) with zeros to match
162
+ dimensions specified in OUTSIZE, then circularly shifts the values of
163
+ the PSF array up (or to the left) until the central pixel reaches (1,1)
164
+ position.
165
+ Parameters
166
+ ----------
167
+ psf : `numpy.ndarray`
168
+ PSF array
169
+ shape : int
170
+ Output shape of the OTF array
171
+ Returns
172
+ -------
173
+ otf : `numpy.ndarray`
174
+ OTF array
175
+ Notes
176
+ -----
177
+ Adapted from MATLAB psf2otf function
178
+ """
179
+ if type(shape) == type(None):
180
+ shape = psf.shape
181
+ shape = np.array(shape)
182
+ if np.all(psf == 0):
183
+ # return np.zeros_like(psf)
184
+ return np.zeros(shape)
185
+ if len(psf.shape) == 1:
186
+ psf = psf.reshape((1, psf.shape[0]))
187
+ inshape = psf.shape
188
+ psf = zero_pad(psf, shape, position='corner')
189
+ for axis, axis_size in enumerate(inshape):
190
+ psf = np.roll(psf, -int(axis_size / 2), axis=axis)
191
+ # Compute the OTF
192
+ otf = np.fft.fft2(psf, axes=(0, 1))
193
+ # Estimate the rough number of operations involved in the FFT
194
+ # and discard the PSF imaginary part if within roundoff error
195
+ # roundoff error = machine epsilon = sys.float_info.epsilon
196
+ # or np.finfo().eps
197
+ n_ops = np.sum(psf.size * np.log2(psf.shape))
198
+ otf = np.real_if_close(otf, tol=n_ops)
199
+ return otf
200
+
201
+
202
+ def zero_pad(image, shape, position='corner'):
203
+ """
204
+ Extends image to a certain size with zeros
205
+ Parameters
206
+ ----------
207
+ image: real 2d `numpy.ndarray`
208
+ Input image
209
+ shape: tuple of int
210
+ Desired output shape of the image
211
+ position : str, optional
212
+ The position of the input image in the output one:
213
+ * 'corner'
214
+ top-left corner (default)
215
+ * 'center'
216
+ centered
217
+ Returns
218
+ -------
219
+ padded_img: real `numpy.ndarray`
220
+ The zero-padded image
221
+ """
222
+ shape = np.asarray(shape, dtype=int)
223
+ imshape = np.asarray(image.shape, dtype=int)
224
+ if np.alltrue(imshape == shape):
225
+ return image
226
+ if np.any(shape <= 0):
227
+ raise ValueError("ZERO_PAD: null or negative shape given")
228
+ dshape = shape - imshape
229
+ if np.any(dshape < 0):
230
+ raise ValueError("ZERO_PAD: target size smaller than source one")
231
+ pad_img = np.zeros(shape, dtype=image.dtype)
232
+ idx, idy = np.indices(imshape)
233
+ if position == 'center':
234
+ if np.any(dshape % 2 != 0):
235
+ raise ValueError("ZERO_PAD: source and target shapes "
236
+ "have different parity.")
237
+ offx, offy = dshape // 2
238
+ else:
239
+ offx, offy = (0, 0)
240
+ pad_img[idx + offx, idy + offy] = image
241
+ return pad_img
242
+
243
+
244
+ '''
245
+ Reducing boundary artifacts
246
+ '''
247
+
248
+
249
+ def opt_fft_size(n):
250
+ '''
251
+ Kai Zhang (github: https://github.com/cszn)
252
+ 03/03/2019
253
+ # opt_fft_size.m
254
+ # compute an optimal data length for Fourier transforms
255
+ # written by Sunghyun Cho ([email protected])
256
+ # persistent opt_fft_size_LUT;
257
+ '''
258
+
259
+ LUT_size = 2048
260
+ # print("generate opt_fft_size_LUT")
261
+ opt_fft_size_LUT = np.zeros(LUT_size)
262
+
263
+ e2 = 1
264
+ while e2 <= LUT_size:
265
+ e3 = e2
266
+ while e3 <= LUT_size:
267
+ e5 = e3
268
+ while e5 <= LUT_size:
269
+ e7 = e5
270
+ while e7 <= LUT_size:
271
+ if e7 <= LUT_size:
272
+ opt_fft_size_LUT[e7-1] = e7
273
+ if e7*11 <= LUT_size:
274
+ opt_fft_size_LUT[e7*11-1] = e7*11
275
+ if e7*13 <= LUT_size:
276
+ opt_fft_size_LUT[e7*13-1] = e7*13
277
+ e7 = e7 * 7
278
+ e5 = e5 * 5
279
+ e3 = e3 * 3
280
+ e2 = e2 * 2
281
+
282
+ nn = 0
283
+ for i in range(LUT_size, 0, -1):
284
+ if opt_fft_size_LUT[i-1] != 0:
285
+ nn = i-1
286
+ else:
287
+ opt_fft_size_LUT[i-1] = nn+1
288
+
289
+ m = np.zeros(len(n))
290
+ for c in range(len(n)):
291
+ nn = n[c]
292
+ if nn <= LUT_size:
293
+ m[c] = opt_fft_size_LUT[nn-1]
294
+ else:
295
+ m[c] = -1
296
+ return m
297
+
298
+
299
+ def wrap_boundary_liu(img, img_size):
300
+
301
+ """
302
+ Reducing boundary artifacts in image deconvolution
303
+ Renting Liu, Jiaya Jia
304
+ ICIP 2008
305
+ """
306
+ if img.ndim == 2:
307
+ ret = wrap_boundary(img, img_size)
308
+ elif img.ndim == 3:
309
+ ret = [wrap_boundary(img[:, :, i], img_size) for i in range(3)]
310
+ ret = np.stack(ret, 2)
311
+ return ret
312
+
313
+
314
+ def wrap_boundary(img, img_size):
315
+
316
+ """
317
+ python code from:
318
+ https://github.com/ys-koshelev/nla_deblur/blob/90fe0ab98c26c791dcbdf231fe6f938fca80e2a0/boundaries.py
319
+ Reducing boundary artifacts in image deconvolution
320
+ Renting Liu, Jiaya Jia
321
+ ICIP 2008
322
+ """
323
+ (H, W) = np.shape(img)
324
+ H_w = int(img_size[0]) - H
325
+ W_w = int(img_size[1]) - W
326
+
327
+ # ret = np.zeros((img_size[0], img_size[1]));
328
+ alpha = 1
329
+ HG = img[:, :]
330
+
331
+ r_A = np.zeros((alpha*2+H_w, W))
332
+ r_A[:alpha, :] = HG[-alpha:, :]
333
+ r_A[-alpha:, :] = HG[:alpha, :]
334
+ a = np.arange(H_w)/(H_w-1)
335
+ # r_A(alpha+1:end-alpha, 1) = (1-a)*r_A(alpha,1) + a*r_A(end-alpha+1,1)
336
+ r_A[alpha:-alpha, 0] = (1-a)*r_A[alpha-1, 0] + a*r_A[-alpha, 0]
337
+ # r_A(alpha+1:end-alpha, end) = (1-a)*r_A(alpha,end) + a*r_A(end-alpha+1,end)
338
+ r_A[alpha:-alpha, -1] = (1-a)*r_A[alpha-1, -1] + a*r_A[-alpha, -1]
339
+
340
+ r_B = np.zeros((H, alpha*2+W_w))
341
+ r_B[:, :alpha] = HG[:, -alpha:]
342
+ r_B[:, -alpha:] = HG[:, :alpha]
343
+ a = np.arange(W_w)/(W_w-1)
344
+ r_B[0, alpha:-alpha] = (1-a)*r_B[0, alpha-1] + a*r_B[0, -alpha]
345
+ r_B[-1, alpha:-alpha] = (1-a)*r_B[-1, alpha-1] + a*r_B[-1, -alpha]
346
+
347
+ if alpha == 1:
348
+ A2 = solve_min_laplacian(r_A[alpha-1:, :])
349
+ B2 = solve_min_laplacian(r_B[:, alpha-1:])
350
+ r_A[alpha-1:, :] = A2
351
+ r_B[:, alpha-1:] = B2
352
+ else:
353
+ A2 = solve_min_laplacian(r_A[alpha-1:-alpha+1, :])
354
+ r_A[alpha-1:-alpha+1, :] = A2
355
+ B2 = solve_min_laplacian(r_B[:, alpha-1:-alpha+1])
356
+ r_B[:, alpha-1:-alpha+1] = B2
357
+ A = r_A
358
+ B = r_B
359
+
360
+ r_C = np.zeros((alpha*2+H_w, alpha*2+W_w))
361
+ r_C[:alpha, :] = B[-alpha:, :]
362
+ r_C[-alpha:, :] = B[:alpha, :]
363
+ r_C[:, :alpha] = A[:, -alpha:]
364
+ r_C[:, -alpha:] = A[:, :alpha]
365
+
366
+ if alpha == 1:
367
+ C2 = C2 = solve_min_laplacian(r_C[alpha-1:, alpha-1:])
368
+ r_C[alpha-1:, alpha-1:] = C2
369
+ else:
370
+ C2 = solve_min_laplacian(r_C[alpha-1:-alpha+1, alpha-1:-alpha+1])
371
+ r_C[alpha-1:-alpha+1, alpha-1:-alpha+1] = C2
372
+ C = r_C
373
+ # return C
374
+ A = A[alpha-1:-alpha-1, :]
375
+ B = B[:, alpha:-alpha]
376
+ C = C[alpha:-alpha, alpha:-alpha]
377
+ ret = np.vstack((np.hstack((img, B)), np.hstack((A, C))))
378
+ return ret
379
+
380
+
381
+ def solve_min_laplacian(boundary_image):
382
+ (H, W) = np.shape(boundary_image)
383
+
384
+ # Laplacian
385
+ f = np.zeros((H, W))
386
+ # boundary image contains image intensities at boundaries
387
+ boundary_image[1:-1, 1:-1] = 0
388
+ j = np.arange(2, H)-1
389
+ k = np.arange(2, W)-1
390
+ f_bp = np.zeros((H, W))
391
+ f_bp[np.ix_(j, k)] = -4*boundary_image[np.ix_(j, k)] + boundary_image[np.ix_(j, k+1)] + boundary_image[np.ix_(j, k-1)] + boundary_image[np.ix_(j-1, k)] + boundary_image[np.ix_(j+1, k)]
392
+
393
+ del(j, k)
394
+ f1 = f - f_bp # subtract boundary points contribution
395
+ del(f_bp, f)
396
+
397
+ # DST Sine Transform algo starts here
398
+ f2 = f1[1:-1,1:-1]
399
+ del(f1)
400
+
401
+ # compute sine tranform
402
+ if f2.shape[1] == 1:
403
+ tt = fftpack.dst(f2, type=1, axis=0)/2
404
+ else:
405
+ tt = fftpack.dst(f2, type=1)/2
406
+
407
+ if tt.shape[0] == 1:
408
+ f2sin = np.transpose(fftpack.dst(np.transpose(tt), type=1, axis=0)/2)
409
+ else:
410
+ f2sin = np.transpose(fftpack.dst(np.transpose(tt), type=1)/2)
411
+ del(f2)
412
+
413
+ # compute Eigen Values
414
+ [x, y] = np.meshgrid(np.arange(1, W-1), np.arange(1, H-1))
415
+ denom = (2*np.cos(np.pi*x/(W-1))-2) + (2*np.cos(np.pi*y/(H-1)) - 2)
416
+
417
+ # divide
418
+ f3 = f2sin/denom
419
+ del(f2sin, x, y)
420
+
421
+ # compute Inverse Sine Transform
422
+ if f3.shape[0] == 1:
423
+ tt = fftpack.idst(f3*2, type=1, axis=1)/(2*(f3.shape[1]+1))
424
+ else:
425
+ tt = fftpack.idst(f3*2, type=1, axis=0)/(2*(f3.shape[0]+1))
426
+ del(f3)
427
+ if tt.shape[1] == 1:
428
+ img_tt = np.transpose(fftpack.idst(np.transpose(tt)*2, type=1)/(2*(tt.shape[0]+1)))
429
+ else:
430
+ img_tt = np.transpose(fftpack.idst(np.transpose(tt)*2, type=1, axis=0)/(2*(tt.shape[1]+1)))
431
+ del(tt)
432
+
433
+ # put solution in inner points; outer points obtained from boundary image
434
+ img_direct = boundary_image
435
+ img_direct[1:-1, 1:-1] = 0
436
+ img_direct[1:-1, 1:-1] = img_tt
437
+ return img_direct
438
+
439
+
440
+ """
441
+ Created on Thu Jan 18 15:36:32 2018
442
+ @author: italo
443
+ https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
444
+ """
445
+
446
+ """
447
+ Syntax
448
+ h = fspecial(type)
449
+ h = fspecial('average',hsize)
450
+ h = fspecial('disk',radius)
451
+ h = fspecial('gaussian',hsize,sigma)
452
+ h = fspecial('laplacian',alpha)
453
+ h = fspecial('log',hsize,sigma)
454
+ h = fspecial('motion',len,theta)
455
+ h = fspecial('prewitt')
456
+ h = fspecial('sobel')
457
+ """
458
+
459
+
460
+ def fspecial_average(hsize=3):
461
+ """Smoothing filter"""
462
+ return np.ones((hsize, hsize))/hsize**2
463
+
464
+
465
+ def fspecial_disk(radius):
466
+ """Disk filter"""
467
+ raise(NotImplemented)
468
+ rad = 0.6
469
+ crad = np.ceil(rad-0.5)
470
+ [x, y] = np.meshgrid(np.arange(-crad, crad+1), np.arange(-crad, crad+1))
471
+ maxxy = np.zeros(x.shape)
472
+ maxxy[abs(x) >= abs(y)] = abs(x)[abs(x) >= abs(y)]
473
+ maxxy[abs(y) >= abs(x)] = abs(y)[abs(y) >= abs(x)]
474
+ minxy = np.zeros(x.shape)
475
+ minxy[abs(x) <= abs(y)] = abs(x)[abs(x) <= abs(y)]
476
+ minxy[abs(y) <= abs(x)] = abs(y)[abs(y) <= abs(x)]
477
+ m1 = (rad**2 < (maxxy+0.5)**2 + (minxy-0.5)**2)*(minxy-0.5) +\
478
+ (rad**2 >= (maxxy+0.5)**2 + (minxy-0.5)**2)*\
479
+ np.sqrt((rad**2 + 0j) - (maxxy + 0.5)**2)
480
+ m2 = (rad**2 > (maxxy-0.5)**2 + (minxy+0.5)**2)*(minxy+0.5) +\
481
+ (rad**2 <= (maxxy-0.5)**2 + (minxy+0.5)**2)*\
482
+ np.sqrt((rad**2 + 0j) - (maxxy - 0.5)**2)
483
+ h = None
484
+ return h
485
+
486
+
487
+ def fspecial_gaussian(hsize, sigma):
488
+ hsize = [hsize, hsize]
489
+ siz = [(hsize[0]-1.0)/2.0, (hsize[1]-1.0)/2.0]
490
+ std = sigma
491
+ [x, y] = np.meshgrid(np.arange(-siz[1], siz[1]+1), np.arange(-siz[0], siz[0]+1))
492
+ arg = -(x*x + y*y)/(2*std*std)
493
+ h = np.exp(arg)
494
+ h[h < scipy.finfo(float).eps * h.max()] = 0
495
+ sumh = h.sum()
496
+ if sumh != 0:
497
+ h = h/sumh
498
+ return h
499
+
500
+
501
+ def fspecial_laplacian(alpha):
502
+ alpha = max([0, min([alpha,1])])
503
+ h1 = alpha/(alpha+1)
504
+ h2 = (1-alpha)/(alpha+1)
505
+ h = [[h1, h2, h1], [h2, -4/(alpha+1), h2], [h1, h2, h1]]
506
+ h = np.array(h)
507
+ return h
508
+
509
+
510
+ def fspecial_log(hsize, sigma):
511
+ raise(NotImplemented)
512
+
513
+
514
+ def fspecial_motion(motion_len, theta):
515
+ raise(NotImplemented)
516
+
517
+
518
+ def fspecial_prewitt():
519
+ return np.array([[1, 1, 1], [0, 0, 0], [-1, -1, -1]])
520
+
521
+
522
+ def fspecial_sobel():
523
+ return np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]])
524
+
525
+
526
+ def fspecial(filter_type, *args, **kwargs):
527
+ '''
528
+ python code from:
529
+ https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
530
+ '''
531
+ if filter_type == 'average':
532
+ return fspecial_average(*args, **kwargs)
533
+ if filter_type == 'disk':
534
+ return fspecial_disk(*args, **kwargs)
535
+ if filter_type == 'gaussian':
536
+ return fspecial_gaussian(*args, **kwargs)
537
+ if filter_type == 'laplacian':
538
+ return fspecial_laplacian(*args, **kwargs)
539
+ if filter_type == 'log':
540
+ return fspecial_log(*args, **kwargs)
541
+ if filter_type == 'motion':
542
+ return fspecial_motion(*args, **kwargs)
543
+ if filter_type == 'prewitt':
544
+ return fspecial_prewitt(*args, **kwargs)
545
+ if filter_type == 'sobel':
546
+ return fspecial_sobel(*args, **kwargs)
547
+
548
+
549
+ def fspecial_gauss(size, sigma):
550
+ x, y = mgrid[-size // 2 + 1 : size // 2 + 1, -size // 2 + 1 : size // 2 + 1]
551
+ g = exp(-((x ** 2 + y ** 2) / (2.0 * sigma ** 2)))
552
+ return g / g.sum()
553
+
554
+
555
+ def blurkernel_synthesis(h=37, w=None):
556
+ # https://github.com/tkkcc/prior/blob/879a0b6c117c810776d8cc6b63720bf29f7d0cc4/util/gen_kernel.py
557
+ w = h if w is None else w
558
+ kdims = [h, w]
559
+ x = randomTrajectory(250)
560
+ k = None
561
+ while k is None:
562
+ k = kernelFromTrajectory(x)
563
+
564
+ # center pad to kdims
565
+ pad_width = ((kdims[0] - k.shape[0]) // 2, (kdims[1] - k.shape[1]) // 2)
566
+ pad_width = [(pad_width[0],), (pad_width[1],)]
567
+
568
+ if pad_width[0][0]<0 or pad_width[1][0]<0:
569
+ k = k[0:h, 0:h]
570
+ else:
571
+ k = pad(k, pad_width, "constant")
572
+ x1,x2 = k.shape
573
+ if np.random.randint(0, 4) == 1:
574
+ k = cv2.resize(k, (random.randint(x1, 5*x1), random.randint(x2, 5*x2)), interpolation=cv2.INTER_LINEAR)
575
+ y1, y2 = k.shape
576
+ k = k[(y1-x1)//2: (y1-x1)//2+x1, (y2-x2)//2: (y2-x2)//2+x2]
577
+
578
+ if sum(k)<0.1:
579
+ k = fspecial_gaussian(h, 0.1+6*np.random.rand(1))
580
+ k = k / sum(k)
581
+ # import matplotlib.pyplot as plt
582
+ # plt.imshow(k, interpolation="nearest", cmap="gray")
583
+ # plt.show()
584
+ return k
585
+
586
+
587
+ def kernelFromTrajectory(x):
588
+ h = 5 - log(rand()) / 0.15
589
+ h = round(min([h, 27])).astype(int)
590
+ h = h + 1 - h % 2
591
+ w = h
592
+ k = zeros((h, w))
593
+
594
+ xmin = min(x[0])
595
+ xmax = max(x[0])
596
+ ymin = min(x[1])
597
+ ymax = max(x[1])
598
+ xthr = arange(xmin, xmax, (xmax - xmin) / w)
599
+ ythr = arange(ymin, ymax, (ymax - ymin) / h)
600
+
601
+ for i in range(1, xthr.size):
602
+ for j in range(1, ythr.size):
603
+ idx = (
604
+ (x[0, :] >= xthr[i - 1])
605
+ & (x[0, :] < xthr[i])
606
+ & (x[1, :] >= ythr[j - 1])
607
+ & (x[1, :] < ythr[j])
608
+ )
609
+ k[i - 1, j - 1] = sum(idx)
610
+ if sum(k) == 0:
611
+ return
612
+ k = k / sum(k)
613
+ k = convolve2d(k, fspecial_gauss(3, 1), "same")
614
+ k = k / sum(k)
615
+ return k
616
+
617
+
618
+ def randomTrajectory(T):
619
+ x = zeros((3, T))
620
+ v = randn(3, T)
621
+ r = zeros((3, T))
622
+ trv = 1 / 1
623
+ trr = 2 * pi / T
624
+ for t in range(1, T):
625
+ F_rot = randn(3) / (t + 1) + r[:, t - 1]
626
+ F_trans = randn(3) / (t + 1)
627
+ r[:, t] = r[:, t - 1] + trr * F_rot
628
+ v[:, t] = v[:, t - 1] + trv * F_trans
629
+ st = v[:, t]
630
+ st = rot3D(st, r[:, t])
631
+ x[:, t] = x[:, t - 1] + st
632
+ return x
633
+
634
+
635
+ def rot3D(x, r):
636
+ Rx = array([[1, 0, 0], [0, cos(r[0]), -sin(r[0])], [0, sin(r[0]), cos(r[0])]])
637
+ Ry = array([[cos(r[1]), 0, sin(r[1])], [0, 1, 0], [-sin(r[1]), 0, cos(r[1])]])
638
+ Rz = array([[cos(r[2]), -sin(r[2]), 0], [sin(r[2]), cos(r[2]), 0], [0, 0, 1]])
639
+ R = Rz @ Ry @ Rx
640
+ x = R @ x
641
+ return x
642
+
643
+
644
+ if __name__ == '__main__':
645
+ a = opt_fft_size([111])
646
+ print(a)
647
+
648
+ print(fspecial('gaussian', 5, 1))
649
+
650
+ print(p2o(torch.zeros(1,1,4,4).float(),(14,14)).shape)
651
+
652
+ k = blurkernel_synthesis(11)
653
+ import matplotlib.pyplot as plt
654
+ plt.imshow(k, interpolation="nearest", cmap="gray")
655
+ plt.show()
core/data/deg_kair_utils/utils_dist.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
2
+ import functools
3
+ import os
4
+ import subprocess
5
+ import torch
6
+ import torch.distributed as dist
7
+ import torch.multiprocessing as mp
8
+
9
+
10
+ # ----------------------------------
11
+ # init
12
+ # ----------------------------------
13
+ def init_dist(launcher, backend='nccl', **kwargs):
14
+ if mp.get_start_method(allow_none=True) is None:
15
+ mp.set_start_method('spawn')
16
+ if launcher == 'pytorch':
17
+ _init_dist_pytorch(backend, **kwargs)
18
+ elif launcher == 'slurm':
19
+ _init_dist_slurm(backend, **kwargs)
20
+ else:
21
+ raise ValueError(f'Invalid launcher type: {launcher}')
22
+
23
+
24
+ def _init_dist_pytorch(backend, **kwargs):
25
+ rank = int(os.environ['RANK'])
26
+ num_gpus = torch.cuda.device_count()
27
+ torch.cuda.set_device(rank % num_gpus)
28
+ dist.init_process_group(backend=backend, **kwargs)
29
+
30
+
31
+ def _init_dist_slurm(backend, port=None):
32
+ """Initialize slurm distributed training environment.
33
+ If argument ``port`` is not specified, then the master port will be system
34
+ environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
35
+ environment variable, then a default port ``29500`` will be used.
36
+ Args:
37
+ backend (str): Backend of torch.distributed.
38
+ port (int, optional): Master port. Defaults to None.
39
+ """
40
+ proc_id = int(os.environ['SLURM_PROCID'])
41
+ ntasks = int(os.environ['SLURM_NTASKS'])
42
+ node_list = os.environ['SLURM_NODELIST']
43
+ num_gpus = torch.cuda.device_count()
44
+ torch.cuda.set_device(proc_id % num_gpus)
45
+ addr = subprocess.getoutput(
46
+ f'scontrol show hostname {node_list} | head -n1')
47
+ # specify master port
48
+ if port is not None:
49
+ os.environ['MASTER_PORT'] = str(port)
50
+ elif 'MASTER_PORT' in os.environ:
51
+ pass # use MASTER_PORT in the environment variable
52
+ else:
53
+ # 29500 is torch.distributed default port
54
+ os.environ['MASTER_PORT'] = '29500'
55
+ os.environ['MASTER_ADDR'] = addr
56
+ os.environ['WORLD_SIZE'] = str(ntasks)
57
+ os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
58
+ os.environ['RANK'] = str(proc_id)
59
+ dist.init_process_group(backend=backend)
60
+
61
+
62
+
63
+ # ----------------------------------
64
+ # get rank and world_size
65
+ # ----------------------------------
66
+ def get_dist_info():
67
+ if dist.is_available():
68
+ initialized = dist.is_initialized()
69
+ else:
70
+ initialized = False
71
+ if initialized:
72
+ rank = dist.get_rank()
73
+ world_size = dist.get_world_size()
74
+ else:
75
+ rank = 0
76
+ world_size = 1
77
+ return rank, world_size
78
+
79
+
80
+ def get_rank():
81
+ if not dist.is_available():
82
+ return 0
83
+
84
+ if not dist.is_initialized():
85
+ return 0
86
+
87
+ return dist.get_rank()
88
+
89
+
90
+ def get_world_size():
91
+ if not dist.is_available():
92
+ return 1
93
+
94
+ if not dist.is_initialized():
95
+ return 1
96
+
97
+ return dist.get_world_size()
98
+
99
+
100
+ def master_only(func):
101
+
102
+ @functools.wraps(func)
103
+ def wrapper(*args, **kwargs):
104
+ rank, _ = get_dist_info()
105
+ if rank == 0:
106
+ return func(*args, **kwargs)
107
+
108
+ return wrapper
109
+
110
+
111
+
112
+
113
+
114
+
115
+ # ----------------------------------
116
+ # operation across ranks
117
+ # ----------------------------------
118
+ def reduce_sum(tensor):
119
+ if not dist.is_available():
120
+ return tensor
121
+
122
+ if not dist.is_initialized():
123
+ return tensor
124
+
125
+ tensor = tensor.clone()
126
+ dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
127
+
128
+ return tensor
129
+
130
+
131
+ def gather_grad(params):
132
+ world_size = get_world_size()
133
+
134
+ if world_size == 1:
135
+ return
136
+
137
+ for param in params:
138
+ if param.grad is not None:
139
+ dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
140
+ param.grad.data.div_(world_size)
141
+
142
+
143
+ def all_gather(data):
144
+ world_size = get_world_size()
145
+
146
+ if world_size == 1:
147
+ return [data]
148
+
149
+ buffer = pickle.dumps(data)
150
+ storage = torch.ByteStorage.from_buffer(buffer)
151
+ tensor = torch.ByteTensor(storage).to('cuda')
152
+
153
+ local_size = torch.IntTensor([tensor.numel()]).to('cuda')
154
+ size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]
155
+ dist.all_gather(size_list, local_size)
156
+ size_list = [int(size.item()) for size in size_list]
157
+ max_size = max(size_list)
158
+
159
+ tensor_list = []
160
+ for _ in size_list:
161
+ tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda'))
162
+
163
+ if local_size != max_size:
164
+ padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda')
165
+ tensor = torch.cat((tensor, padding), 0)
166
+
167
+ dist.all_gather(tensor_list, tensor)
168
+
169
+ data_list = []
170
+
171
+ for size, tensor in zip(size_list, tensor_list):
172
+ buffer = tensor.cpu().numpy().tobytes()[:size]
173
+ data_list.append(pickle.loads(buffer))
174
+
175
+ return data_list
176
+
177
+
178
+ def reduce_loss_dict(loss_dict):
179
+ world_size = get_world_size()
180
+
181
+ if world_size < 2:
182
+ return loss_dict
183
+
184
+ with torch.no_grad():
185
+ keys = []
186
+ losses = []
187
+
188
+ for k in sorted(loss_dict.keys()):
189
+ keys.append(k)
190
+ losses.append(loss_dict[k])
191
+
192
+ losses = torch.stack(losses, 0)
193
+ dist.reduce(losses, dst=0)
194
+
195
+ if dist.get_rank() == 0:
196
+ losses /= world_size
197
+
198
+ reduced_losses = {k: v for k, v in zip(keys, losses)}
199
+
200
+ return reduced_losses
201
+
core/data/deg_kair_utils/utils_googledownload.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import requests
3
+ from tqdm import tqdm
4
+
5
+
6
+ '''
7
+ borrowed from
8
+ https://github.com/xinntao/BasicSR/blob/28883e15eedc3381d23235ff3cf7c454c4be87e6/basicsr/utils/download_util.py
9
+ '''
10
+
11
+
12
+ def sizeof_fmt(size, suffix='B'):
13
+ """Get human readable file size.
14
+ Args:
15
+ size (int): File size.
16
+ suffix (str): Suffix. Default: 'B'.
17
+ Return:
18
+ str: Formated file siz.
19
+ """
20
+ for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
21
+ if abs(size) < 1024.0:
22
+ return f'{size:3.1f} {unit}{suffix}'
23
+ size /= 1024.0
24
+ return f'{size:3.1f} Y{suffix}'
25
+
26
+
27
+ def download_file_from_google_drive(file_id, save_path):
28
+ """Download files from google drive.
29
+ Ref:
30
+ https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501
31
+ Args:
32
+ file_id (str): File id.
33
+ save_path (str): Save path.
34
+ """
35
+
36
+ session = requests.Session()
37
+ URL = 'https://docs.google.com/uc?export=download'
38
+ params = {'id': file_id}
39
+
40
+ response = session.get(URL, params=params, stream=True)
41
+ token = get_confirm_token(response)
42
+ if token:
43
+ params['confirm'] = token
44
+ response = session.get(URL, params=params, stream=True)
45
+
46
+ # get file size
47
+ response_file_size = session.get(
48
+ URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
49
+ if 'Content-Range' in response_file_size.headers:
50
+ file_size = int(
51
+ response_file_size.headers['Content-Range'].split('/')[1])
52
+ else:
53
+ file_size = None
54
+
55
+ save_response_content(response, save_path, file_size)
56
+
57
+
58
+ def get_confirm_token(response):
59
+ for key, value in response.cookies.items():
60
+ if key.startswith('download_warning'):
61
+ return value
62
+ return None
63
+
64
+
65
+ def save_response_content(response,
66
+ destination,
67
+ file_size=None,
68
+ chunk_size=32768):
69
+ if file_size is not None:
70
+ pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
71
+
72
+ readable_file_size = sizeof_fmt(file_size)
73
+ else:
74
+ pbar = None
75
+
76
+ with open(destination, 'wb') as f:
77
+ downloaded_size = 0
78
+ for chunk in response.iter_content(chunk_size):
79
+ downloaded_size += chunk_size
80
+ if pbar is not None:
81
+ pbar.update(1)
82
+ pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} '
83
+ f'/ {readable_file_size}')
84
+ if chunk: # filter out keep-alive new chunks
85
+ f.write(chunk)
86
+ if pbar is not None:
87
+ pbar.close()
88
+
89
+
90
+ if __name__ == "__main__":
91
+ file_id = '1WNULM1e8gRNvsngVscsQ8tpaOqJ4mYtv'
92
+ save_path = 'BSRGAN.pth'
93
+ download_file_from_google_drive(file_id, save_path)
core/data/deg_kair_utils/utils_image.py ADDED
@@ -0,0 +1,1016 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import random
4
+ import numpy as np
5
+ import torch
6
+ import cv2
7
+ from torchvision.utils import make_grid
8
+ from datetime import datetime
9
+ # import torchvision.transforms as transforms
10
+ import matplotlib.pyplot as plt
11
+ from mpl_toolkits.mplot3d import Axes3D
12
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
13
+
14
+
15
+ '''
16
+ # --------------------------------------------
17
+ # Kai Zhang (github: https://github.com/cszn)
18
+ # 03/Mar/2019
19
+ # --------------------------------------------
20
+ # https://github.com/twhui/SRGAN-pyTorch
21
+ # https://github.com/xinntao/BasicSR
22
+ # --------------------------------------------
23
+ '''
24
+
25
+
26
+ IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']
27
+
28
+
29
+ def is_image_file(filename):
30
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
31
+
32
+
33
+ def get_timestamp():
34
+ return datetime.now().strftime('%y%m%d-%H%M%S')
35
+
36
+
37
+ def imshow(x, title=None, cbar=False, figsize=None):
38
+ plt.figure(figsize=figsize)
39
+ plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
40
+ if title:
41
+ plt.title(title)
42
+ if cbar:
43
+ plt.colorbar()
44
+ plt.show()
45
+
46
+
47
+ def surf(Z, cmap='rainbow', figsize=None):
48
+ plt.figure(figsize=figsize)
49
+ ax3 = plt.axes(projection='3d')
50
+
51
+ w, h = Z.shape[:2]
52
+ xx = np.arange(0,w,1)
53
+ yy = np.arange(0,h,1)
54
+ X, Y = np.meshgrid(xx, yy)
55
+ ax3.plot_surface(X,Y,Z,cmap=cmap)
56
+ #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap)
57
+ plt.show()
58
+
59
+
60
+ '''
61
+ # --------------------------------------------
62
+ # get image pathes
63
+ # --------------------------------------------
64
+ '''
65
+
66
+
67
+ def get_image_paths(dataroot):
68
+ paths = None # return None if dataroot is None
69
+ if isinstance(dataroot, str):
70
+ paths = sorted(_get_paths_from_images(dataroot))
71
+ elif isinstance(dataroot, list):
72
+ paths = []
73
+ for i in dataroot:
74
+ paths += sorted(_get_paths_from_images(i))
75
+ return paths
76
+
77
+
78
+ def _get_paths_from_images(path):
79
+ assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
80
+ images = []
81
+ for dirpath, _, fnames in sorted(os.walk(path)):
82
+ for fname in sorted(fnames):
83
+ if is_image_file(fname):
84
+ img_path = os.path.join(dirpath, fname)
85
+ images.append(img_path)
86
+ assert images, '{:s} has no valid image file'.format(path)
87
+ return images
88
+
89
+
90
+ '''
91
+ # --------------------------------------------
92
+ # split large images into small images
93
+ # --------------------------------------------
94
+ '''
95
+
96
+
97
+ def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
98
+ w, h = img.shape[:2]
99
+ patches = []
100
+ if w > p_max and h > p_max:
101
+ w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int))
102
+ h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int))
103
+ w1.append(w-p_size)
104
+ h1.append(h-p_size)
105
+ # print(w1)
106
+ # print(h1)
107
+ for i in w1:
108
+ for j in h1:
109
+ patches.append(img[i:i+p_size, j:j+p_size,:])
110
+ else:
111
+ patches.append(img)
112
+
113
+ return patches
114
+
115
+
116
+ def imssave(imgs, img_path):
117
+ """
118
+ imgs: list, N images of size WxHxC
119
+ """
120
+ img_name, ext = os.path.splitext(os.path.basename(img_path))
121
+ for i, img in enumerate(imgs):
122
+ if img.ndim == 3:
123
+ img = img[:, :, [2, 1, 0]]
124
+ new_path = os.path.join(os.path.dirname(img_path), img_name+str('_{:04d}'.format(i))+'.png')
125
+ cv2.imwrite(new_path, img)
126
+
127
+
128
+ def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=512, p_overlap=96, p_max=800):
129
+ """
130
+ split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size),
131
+ and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max)
132
+ will be splitted.
133
+
134
+ Args:
135
+ original_dataroot:
136
+ taget_dataroot:
137
+ p_size: size of small images
138
+ p_overlap: patch size in training is a good choice
139
+ p_max: images with smaller size than (p_max)x(p_max) keep unchanged.
140
+ """
141
+ paths = get_image_paths(original_dataroot)
142
+ for img_path in paths:
143
+ # img_name, ext = os.path.splitext(os.path.basename(img_path))
144
+ img = imread_uint(img_path, n_channels=n_channels)
145
+ patches = patches_from_image(img, p_size, p_overlap, p_max)
146
+ imssave(patches, os.path.join(taget_dataroot, os.path.basename(img_path)))
147
+ #if original_dataroot == taget_dataroot:
148
+ #del img_path
149
+
150
+ '''
151
+ # --------------------------------------------
152
+ # makedir
153
+ # --------------------------------------------
154
+ '''
155
+
156
+
157
+ def mkdir(path):
158
+ if not os.path.exists(path):
159
+ os.makedirs(path)
160
+
161
+
162
+ def mkdirs(paths):
163
+ if isinstance(paths, str):
164
+ mkdir(paths)
165
+ else:
166
+ for path in paths:
167
+ mkdir(path)
168
+
169
+
170
+ def mkdir_and_rename(path):
171
+ if os.path.exists(path):
172
+ new_name = path + '_archived_' + get_timestamp()
173
+ print('Path already exists. Rename it to [{:s}]'.format(new_name))
174
+ os.rename(path, new_name)
175
+ os.makedirs(path)
176
+
177
+
178
+ '''
179
+ # --------------------------------------------
180
+ # read image from path
181
+ # opencv is fast, but read BGR numpy image
182
+ # --------------------------------------------
183
+ '''
184
+
185
+
186
+ # --------------------------------------------
187
+ # get uint8 image of size HxWxn_channles (RGB)
188
+ # --------------------------------------------
189
+ def imread_uint(path, n_channels=3):
190
+ # input: path
191
+ # output: HxWx3(RGB or GGG), or HxWx1 (G)
192
+ if n_channels == 1:
193
+ img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE
194
+ img = np.expand_dims(img, axis=2) # HxWx1
195
+ elif n_channels == 3:
196
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G
197
+ if img.ndim == 2:
198
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG
199
+ else:
200
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB
201
+ return img
202
+
203
+
204
+ # --------------------------------------------
205
+ # matlab's imwrite
206
+ # --------------------------------------------
207
+ def imsave(img, img_path):
208
+ img = np.squeeze(img)
209
+ if img.ndim == 3:
210
+ img = img[:, :, [2, 1, 0]]
211
+ cv2.imwrite(img_path, img)
212
+
213
+ def imwrite(img, img_path):
214
+ img = np.squeeze(img)
215
+ if img.ndim == 3:
216
+ img = img[:, :, [2, 1, 0]]
217
+ cv2.imwrite(img_path, img)
218
+
219
+
220
+
221
+ # --------------------------------------------
222
+ # get single image of size HxWxn_channles (BGR)
223
+ # --------------------------------------------
224
+ def read_img(path):
225
+ # read image by cv2
226
+ # return: Numpy float32, HWC, BGR, [0,1]
227
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE
228
+ img = img.astype(np.float32) / 255.
229
+ if img.ndim == 2:
230
+ img = np.expand_dims(img, axis=2)
231
+ # some images have 4 channels
232
+ if img.shape[2] > 3:
233
+ img = img[:, :, :3]
234
+ return img
235
+
236
+
237
+ '''
238
+ # --------------------------------------------
239
+ # image format conversion
240
+ # --------------------------------------------
241
+ # numpy(single) <---> numpy(uint)
242
+ # numpy(single) <---> tensor
243
+ # numpy(uint) <---> tensor
244
+ # --------------------------------------------
245
+ '''
246
+
247
+
248
+ # --------------------------------------------
249
+ # numpy(single) [0, 1] <---> numpy(uint)
250
+ # --------------------------------------------
251
+
252
+
253
+ def uint2single(img):
254
+
255
+ return np.float32(img/255.)
256
+
257
+
258
+ def single2uint(img):
259
+
260
+ return np.uint8((img.clip(0, 1)*255.).round())
261
+
262
+
263
+ def uint162single(img):
264
+
265
+ return np.float32(img/65535.)
266
+
267
+
268
+ def single2uint16(img):
269
+
270
+ return np.uint16((img.clip(0, 1)*65535.).round())
271
+
272
+
273
+ # --------------------------------------------
274
+ # numpy(uint) (HxWxC or HxW) <---> tensor
275
+ # --------------------------------------------
276
+
277
+
278
+ # convert uint to 4-dimensional torch tensor
279
+ def uint2tensor4(img):
280
+ if img.ndim == 2:
281
+ img = np.expand_dims(img, axis=2)
282
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0)
283
+
284
+
285
+ # convert uint to 3-dimensional torch tensor
286
+ def uint2tensor3(img):
287
+ if img.ndim == 2:
288
+ img = np.expand_dims(img, axis=2)
289
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.)
290
+
291
+
292
+ # convert 2/3/4-dimensional torch tensor to uint
293
+ def tensor2uint(img):
294
+ img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
295
+ if img.ndim == 3:
296
+ img = np.transpose(img, (1, 2, 0))
297
+ return np.uint8((img*255.0).round())
298
+
299
+
300
+ # --------------------------------------------
301
+ # numpy(single) (HxWxC) <---> tensor
302
+ # --------------------------------------------
303
+
304
+
305
+ # convert single (HxWxC) to 3-dimensional torch tensor
306
+ def single2tensor3(img):
307
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()
308
+
309
+
310
+ # convert single (HxWxC) to 4-dimensional torch tensor
311
+ def single2tensor4(img):
312
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)
313
+
314
+
315
+ # convert torch tensor to single
316
+ def tensor2single(img):
317
+ img = img.data.squeeze().float().cpu().numpy()
318
+ if img.ndim == 3:
319
+ img = np.transpose(img, (1, 2, 0))
320
+
321
+ return img
322
+
323
+ # convert torch tensor to single
324
+ def tensor2single3(img):
325
+ img = img.data.squeeze().float().cpu().numpy()
326
+ if img.ndim == 3:
327
+ img = np.transpose(img, (1, 2, 0))
328
+ elif img.ndim == 2:
329
+ img = np.expand_dims(img, axis=2)
330
+ return img
331
+
332
+
333
+ def single2tensor5(img):
334
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)
335
+
336
+
337
+ def single32tensor5(img):
338
+ return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)
339
+
340
+
341
+ def single42tensor4(img):
342
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
343
+
344
+
345
+ # from skimage.io import imread, imsave
346
+ def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
347
+ '''
348
+ Converts a torch Tensor into an image Numpy array of BGR channel order
349
+ Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
350
+ Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
351
+ '''
352
+ tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp
353
+ tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
354
+ n_dim = tensor.dim()
355
+ if n_dim == 4:
356
+ n_img = len(tensor)
357
+ img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
358
+ img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
359
+ elif n_dim == 3:
360
+ img_np = tensor.numpy()
361
+ img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
362
+ elif n_dim == 2:
363
+ img_np = tensor.numpy()
364
+ else:
365
+ raise TypeError(
366
+ 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
367
+ if out_type == np.uint8:
368
+ img_np = (img_np * 255.0).round()
369
+ # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
370
+ return img_np.astype(out_type)
371
+
372
+
373
+ '''
374
+ # --------------------------------------------
375
+ # Augmentation, flipe and/or rotate
376
+ # --------------------------------------------
377
+ # The following two are enough.
378
+ # (1) augmet_img: numpy image of WxHxC or WxH
379
+ # (2) augment_img_tensor4: tensor image 1xCxWxH
380
+ # --------------------------------------------
381
+ '''
382
+
383
+
384
+ def augment_img(img, mode=0):
385
+ '''Kai Zhang (github: https://github.com/cszn)
386
+ '''
387
+ if mode == 0:
388
+ return img
389
+ elif mode == 1:
390
+ return np.flipud(np.rot90(img))
391
+ elif mode == 2:
392
+ return np.flipud(img)
393
+ elif mode == 3:
394
+ return np.rot90(img, k=3)
395
+ elif mode == 4:
396
+ return np.flipud(np.rot90(img, k=2))
397
+ elif mode == 5:
398
+ return np.rot90(img)
399
+ elif mode == 6:
400
+ return np.rot90(img, k=2)
401
+ elif mode == 7:
402
+ return np.flipud(np.rot90(img, k=3))
403
+
404
+
405
+ def augment_img_tensor4(img, mode=0):
406
+ '''Kai Zhang (github: https://github.com/cszn)
407
+ '''
408
+ if mode == 0:
409
+ return img
410
+ elif mode == 1:
411
+ return img.rot90(1, [2, 3]).flip([2])
412
+ elif mode == 2:
413
+ return img.flip([2])
414
+ elif mode == 3:
415
+ return img.rot90(3, [2, 3])
416
+ elif mode == 4:
417
+ return img.rot90(2, [2, 3]).flip([2])
418
+ elif mode == 5:
419
+ return img.rot90(1, [2, 3])
420
+ elif mode == 6:
421
+ return img.rot90(2, [2, 3])
422
+ elif mode == 7:
423
+ return img.rot90(3, [2, 3]).flip([2])
424
+
425
+
426
+ def augment_img_tensor(img, mode=0):
427
+ '''Kai Zhang (github: https://github.com/cszn)
428
+ '''
429
+ img_size = img.size()
430
+ img_np = img.data.cpu().numpy()
431
+ if len(img_size) == 3:
432
+ img_np = np.transpose(img_np, (1, 2, 0))
433
+ elif len(img_size) == 4:
434
+ img_np = np.transpose(img_np, (2, 3, 1, 0))
435
+ img_np = augment_img(img_np, mode=mode)
436
+ img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
437
+ if len(img_size) == 3:
438
+ img_tensor = img_tensor.permute(2, 0, 1)
439
+ elif len(img_size) == 4:
440
+ img_tensor = img_tensor.permute(3, 2, 0, 1)
441
+
442
+ return img_tensor.type_as(img)
443
+
444
+
445
+ def augment_img_np3(img, mode=0):
446
+ if mode == 0:
447
+ return img
448
+ elif mode == 1:
449
+ return img.transpose(1, 0, 2)
450
+ elif mode == 2:
451
+ return img[::-1, :, :]
452
+ elif mode == 3:
453
+ img = img[::-1, :, :]
454
+ img = img.transpose(1, 0, 2)
455
+ return img
456
+ elif mode == 4:
457
+ return img[:, ::-1, :]
458
+ elif mode == 5:
459
+ img = img[:, ::-1, :]
460
+ img = img.transpose(1, 0, 2)
461
+ return img
462
+ elif mode == 6:
463
+ img = img[:, ::-1, :]
464
+ img = img[::-1, :, :]
465
+ return img
466
+ elif mode == 7:
467
+ img = img[:, ::-1, :]
468
+ img = img[::-1, :, :]
469
+ img = img.transpose(1, 0, 2)
470
+ return img
471
+
472
+
473
+ def augment_imgs(img_list, hflip=True, rot=True):
474
+ # horizontal flip OR rotate
475
+ hflip = hflip and random.random() < 0.5
476
+ vflip = rot and random.random() < 0.5
477
+ rot90 = rot and random.random() < 0.5
478
+
479
+ def _augment(img):
480
+ if hflip:
481
+ img = img[:, ::-1, :]
482
+ if vflip:
483
+ img = img[::-1, :, :]
484
+ if rot90:
485
+ img = img.transpose(1, 0, 2)
486
+ return img
487
+
488
+ return [_augment(img) for img in img_list]
489
+
490
+
491
+ '''
492
+ # --------------------------------------------
493
+ # modcrop and shave
494
+ # --------------------------------------------
495
+ '''
496
+
497
+
498
+ def modcrop(img_in, scale):
499
+ # img_in: Numpy, HWC or HW
500
+ img = np.copy(img_in)
501
+ if img.ndim == 2:
502
+ H, W = img.shape
503
+ H_r, W_r = H % scale, W % scale
504
+ img = img[:H - H_r, :W - W_r]
505
+ elif img.ndim == 3:
506
+ H, W, C = img.shape
507
+ H_r, W_r = H % scale, W % scale
508
+ img = img[:H - H_r, :W - W_r, :]
509
+ else:
510
+ raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
511
+ return img
512
+
513
+
514
+ def shave(img_in, border=0):
515
+ # img_in: Numpy, HWC or HW
516
+ img = np.copy(img_in)
517
+ h, w = img.shape[:2]
518
+ img = img[border:h-border, border:w-border]
519
+ return img
520
+
521
+
522
+ '''
523
+ # --------------------------------------------
524
+ # image processing process on numpy image
525
+ # channel_convert(in_c, tar_type, img_list):
526
+ # rgb2ycbcr(img, only_y=True):
527
+ # bgr2ycbcr(img, only_y=True):
528
+ # ycbcr2rgb(img):
529
+ # --------------------------------------------
530
+ '''
531
+
532
+
533
+ def rgb2ycbcr(img, only_y=True):
534
+ '''same as matlab rgb2ycbcr
535
+ only_y: only return Y channel
536
+ Input:
537
+ uint8, [0, 255]
538
+ float, [0, 1]
539
+ '''
540
+ in_img_type = img.dtype
541
+ img.astype(np.float32)
542
+ if in_img_type != np.uint8:
543
+ img *= 255.
544
+ # convert
545
+ if only_y:
546
+ rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
547
+ else:
548
+ rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
549
+ [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
550
+ if in_img_type == np.uint8:
551
+ rlt = rlt.round()
552
+ else:
553
+ rlt /= 255.
554
+ return rlt.astype(in_img_type)
555
+
556
+
557
+ def ycbcr2rgb(img):
558
+ '''same as matlab ycbcr2rgb
559
+ Input:
560
+ uint8, [0, 255]
561
+ float, [0, 1]
562
+ '''
563
+ in_img_type = img.dtype
564
+ img.astype(np.float32)
565
+ if in_img_type != np.uint8:
566
+ img *= 255.
567
+ # convert
568
+ rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
569
+ [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
570
+ rlt = np.clip(rlt, 0, 255)
571
+ if in_img_type == np.uint8:
572
+ rlt = rlt.round()
573
+ else:
574
+ rlt /= 255.
575
+ return rlt.astype(in_img_type)
576
+
577
+
578
+ def bgr2ycbcr(img, only_y=True):
579
+ '''bgr version of rgb2ycbcr
580
+ only_y: only return Y channel
581
+ Input:
582
+ uint8, [0, 255]
583
+ float, [0, 1]
584
+ '''
585
+ in_img_type = img.dtype
586
+ img.astype(np.float32)
587
+ if in_img_type != np.uint8:
588
+ img *= 255.
589
+ # convert
590
+ if only_y:
591
+ rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
592
+ else:
593
+ rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
594
+ [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
595
+ if in_img_type == np.uint8:
596
+ rlt = rlt.round()
597
+ else:
598
+ rlt /= 255.
599
+ return rlt.astype(in_img_type)
600
+
601
+
602
+ def channel_convert(in_c, tar_type, img_list):
603
+ # conversion among BGR, gray and y
604
+ if in_c == 3 and tar_type == 'gray': # BGR to gray
605
+ gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
606
+ return [np.expand_dims(img, axis=2) for img in gray_list]
607
+ elif in_c == 3 and tar_type == 'y': # BGR to y
608
+ y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
609
+ return [np.expand_dims(img, axis=2) for img in y_list]
610
+ elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR
611
+ return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
612
+ else:
613
+ return img_list
614
+
615
+
616
+ '''
617
+ # --------------------------------------------
618
+ # metric, PSNR, SSIM and PSNRB
619
+ # --------------------------------------------
620
+ '''
621
+
622
+
623
+ # --------------------------------------------
624
+ # PSNR
625
+ # --------------------------------------------
626
+ def calculate_psnr(img1, img2, border=0):
627
+ # img1 and img2 have range [0, 255]
628
+ #img1 = img1.squeeze()
629
+ #img2 = img2.squeeze()
630
+ if not img1.shape == img2.shape:
631
+ raise ValueError('Input images must have the same dimensions.')
632
+ h, w = img1.shape[:2]
633
+ img1 = img1[border:h-border, border:w-border]
634
+ img2 = img2[border:h-border, border:w-border]
635
+
636
+ img1 = img1.astype(np.float64)
637
+ img2 = img2.astype(np.float64)
638
+ mse = np.mean((img1 - img2)**2)
639
+ if mse == 0:
640
+ return float('inf')
641
+ return 20 * math.log10(255.0 / math.sqrt(mse))
642
+
643
+
644
+ # --------------------------------------------
645
+ # SSIM
646
+ # --------------------------------------------
647
+ def calculate_ssim(img1, img2, border=0):
648
+ '''calculate SSIM
649
+ the same outputs as MATLAB's
650
+ img1, img2: [0, 255]
651
+ '''
652
+ #img1 = img1.squeeze()
653
+ #img2 = img2.squeeze()
654
+ if not img1.shape == img2.shape:
655
+ raise ValueError('Input images must have the same dimensions.')
656
+ h, w = img1.shape[:2]
657
+ img1 = img1[border:h-border, border:w-border]
658
+ img2 = img2[border:h-border, border:w-border]
659
+
660
+ if img1.ndim == 2:
661
+ return ssim(img1, img2)
662
+ elif img1.ndim == 3:
663
+ if img1.shape[2] == 3:
664
+ ssims = []
665
+ for i in range(3):
666
+ ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
667
+ return np.array(ssims).mean()
668
+ elif img1.shape[2] == 1:
669
+ return ssim(np.squeeze(img1), np.squeeze(img2))
670
+ else:
671
+ raise ValueError('Wrong input image dimensions.')
672
+
673
+
674
+ def ssim(img1, img2):
675
+ C1 = (0.01 * 255)**2
676
+ C2 = (0.03 * 255)**2
677
+
678
+ img1 = img1.astype(np.float64)
679
+ img2 = img2.astype(np.float64)
680
+ kernel = cv2.getGaussianKernel(11, 1.5)
681
+ window = np.outer(kernel, kernel.transpose())
682
+
683
+ mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
684
+ mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
685
+ mu1_sq = mu1**2
686
+ mu2_sq = mu2**2
687
+ mu1_mu2 = mu1 * mu2
688
+ sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
689
+ sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
690
+ sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
691
+
692
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
693
+ (sigma1_sq + sigma2_sq + C2))
694
+ return ssim_map.mean()
695
+
696
+
697
+ def _blocking_effect_factor(im):
698
+ block_size = 8
699
+
700
+ block_horizontal_positions = torch.arange(7, im.shape[3] - 1, 8)
701
+ block_vertical_positions = torch.arange(7, im.shape[2] - 1, 8)
702
+
703
+ horizontal_block_difference = (
704
+ (im[:, :, :, block_horizontal_positions] - im[:, :, :, block_horizontal_positions + 1]) ** 2).sum(
705
+ 3).sum(2).sum(1)
706
+ vertical_block_difference = (
707
+ (im[:, :, block_vertical_positions, :] - im[:, :, block_vertical_positions + 1, :]) ** 2).sum(3).sum(
708
+ 2).sum(1)
709
+
710
+ nonblock_horizontal_positions = np.setdiff1d(torch.arange(0, im.shape[3] - 1), block_horizontal_positions)
711
+ nonblock_vertical_positions = np.setdiff1d(torch.arange(0, im.shape[2] - 1), block_vertical_positions)
712
+
713
+ horizontal_nonblock_difference = (
714
+ (im[:, :, :, nonblock_horizontal_positions] - im[:, :, :, nonblock_horizontal_positions + 1]) ** 2).sum(
715
+ 3).sum(2).sum(1)
716
+ vertical_nonblock_difference = (
717
+ (im[:, :, nonblock_vertical_positions, :] - im[:, :, nonblock_vertical_positions + 1, :]) ** 2).sum(
718
+ 3).sum(2).sum(1)
719
+
720
+ n_boundary_horiz = im.shape[2] * (im.shape[3] // block_size - 1)
721
+ n_boundary_vert = im.shape[3] * (im.shape[2] // block_size - 1)
722
+ boundary_difference = (horizontal_block_difference + vertical_block_difference) / (
723
+ n_boundary_horiz + n_boundary_vert)
724
+
725
+ n_nonboundary_horiz = im.shape[2] * (im.shape[3] - 1) - n_boundary_horiz
726
+ n_nonboundary_vert = im.shape[3] * (im.shape[2] - 1) - n_boundary_vert
727
+ nonboundary_difference = (horizontal_nonblock_difference + vertical_nonblock_difference) / (
728
+ n_nonboundary_horiz + n_nonboundary_vert)
729
+
730
+ scaler = np.log2(block_size) / np.log2(min([im.shape[2], im.shape[3]]))
731
+ bef = scaler * (boundary_difference - nonboundary_difference)
732
+
733
+ bef[boundary_difference <= nonboundary_difference] = 0
734
+ return bef
735
+
736
+
737
+ def calculate_psnrb(img1, img2, border=0):
738
+ """Calculate PSNR-B (Peak Signal-to-Noise Ratio).
739
+ Ref: Quality assessment of deblocked images, for JPEG image deblocking evaluation
740
+ # https://gitlab.com/Queuecumber/quantization-guided-ac/-/blob/master/metrics/psnrb.py
741
+ Args:
742
+ img1 (ndarray): Images with range [0, 255].
743
+ img2 (ndarray): Images with range [0, 255].
744
+ border (int): Cropped pixels in each edge of an image. These
745
+ pixels are not involved in the PSNR calculation.
746
+ test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
747
+ Returns:
748
+ float: psnr result.
749
+ """
750
+
751
+ if not img1.shape == img2.shape:
752
+ raise ValueError('Input images must have the same dimensions.')
753
+
754
+ if img1.ndim == 2:
755
+ img1, img2 = np.expand_dims(img1, 2), np.expand_dims(img2, 2)
756
+
757
+ h, w = img1.shape[:2]
758
+ img1 = img1[border:h-border, border:w-border]
759
+ img2 = img2[border:h-border, border:w-border]
760
+
761
+ img1 = img1.astype(np.float64)
762
+ img2 = img2.astype(np.float64)
763
+
764
+ # follow https://gitlab.com/Queuecumber/quantization-guided-ac/-/blob/master/metrics/psnrb.py
765
+ img1 = torch.from_numpy(img1).permute(2, 0, 1).unsqueeze(0) / 255.
766
+ img2 = torch.from_numpy(img2).permute(2, 0, 1).unsqueeze(0) / 255.
767
+
768
+ total = 0
769
+ for c in range(img1.shape[1]):
770
+ mse = torch.nn.functional.mse_loss(img1[:, c:c + 1, :, :], img2[:, c:c + 1, :, :], reduction='none')
771
+ bef = _blocking_effect_factor(img1[:, c:c + 1, :, :])
772
+
773
+ mse = mse.view(mse.shape[0], -1).mean(1)
774
+ total += 10 * torch.log10(1 / (mse + bef))
775
+
776
+ return float(total) / img1.shape[1]
777
+
778
+ '''
779
+ # --------------------------------------------
780
+ # matlab's bicubic imresize (numpy and torch) [0, 1]
781
+ # --------------------------------------------
782
+ '''
783
+
784
+
785
+ # matlab 'imresize' function, now only support 'bicubic'
786
+ def cubic(x):
787
+ absx = torch.abs(x)
788
+ absx2 = absx**2
789
+ absx3 = absx**3
790
+ return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
791
+ (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))
792
+
793
+
794
+ def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
795
+ if (scale < 1) and (antialiasing):
796
+ # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
797
+ kernel_width = kernel_width / scale
798
+
799
+ # Output-space coordinates
800
+ x = torch.linspace(1, out_length, out_length)
801
+
802
+ # Input-space coordinates. Calculate the inverse mapping such that 0.5
803
+ # in output space maps to 0.5 in input space, and 0.5+scale in output
804
+ # space maps to 1.5 in input space.
805
+ u = x / scale + 0.5 * (1 - 1 / scale)
806
+
807
+ # What is the left-most pixel that can be involved in the computation?
808
+ left = torch.floor(u - kernel_width / 2)
809
+
810
+ # What is the maximum number of pixels that can be involved in the
811
+ # computation? Note: it's OK to use an extra pixel here; if the
812
+ # corresponding weights are all zero, it will be eliminated at the end
813
+ # of this function.
814
+ P = math.ceil(kernel_width) + 2
815
+
816
+ # The indices of the input pixels involved in computing the k-th output
817
+ # pixel are in row k of the indices matrix.
818
+ indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
819
+ 1, P).expand(out_length, P)
820
+
821
+ # The weights used to compute the k-th output pixel are in row k of the
822
+ # weights matrix.
823
+ distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
824
+ # apply cubic kernel
825
+ if (scale < 1) and (antialiasing):
826
+ weights = scale * cubic(distance_to_center * scale)
827
+ else:
828
+ weights = cubic(distance_to_center)
829
+ # Normalize the weights matrix so that each row sums to 1.
830
+ weights_sum = torch.sum(weights, 1).view(out_length, 1)
831
+ weights = weights / weights_sum.expand(out_length, P)
832
+
833
+ # If a column in weights is all zero, get rid of it. only consider the first and last column.
834
+ weights_zero_tmp = torch.sum((weights == 0), 0)
835
+ if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
836
+ indices = indices.narrow(1, 1, P - 2)
837
+ weights = weights.narrow(1, 1, P - 2)
838
+ if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
839
+ indices = indices.narrow(1, 0, P - 2)
840
+ weights = weights.narrow(1, 0, P - 2)
841
+ weights = weights.contiguous()
842
+ indices = indices.contiguous()
843
+ sym_len_s = -indices.min() + 1
844
+ sym_len_e = indices.max() - in_length
845
+ indices = indices + sym_len_s - 1
846
+ return weights, indices, int(sym_len_s), int(sym_len_e)
847
+
848
+
849
+ # --------------------------------------------
850
+ # imresize for tensor image [0, 1]
851
+ # --------------------------------------------
852
+ def imresize(img, scale, antialiasing=True):
853
+ # Now the scale should be the same for H and W
854
+ # input: img: pytorch tensor, CHW or HW [0,1]
855
+ # output: CHW or HW [0,1] w/o round
856
+ need_squeeze = True if img.dim() == 2 else False
857
+ if need_squeeze:
858
+ img.unsqueeze_(0)
859
+ in_C, in_H, in_W = img.size()
860
+ out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
861
+ kernel_width = 4
862
+ kernel = 'cubic'
863
+
864
+ # Return the desired dimension order for performing the resize. The
865
+ # strategy is to perform the resize first along the dimension with the
866
+ # smallest scale factor.
867
+ # Now we do not support this.
868
+
869
+ # get weights and indices
870
+ weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
871
+ in_H, out_H, scale, kernel, kernel_width, antialiasing)
872
+ weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
873
+ in_W, out_W, scale, kernel, kernel_width, antialiasing)
874
+ # process H dimension
875
+ # symmetric copying
876
+ img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
877
+ img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
878
+
879
+ sym_patch = img[:, :sym_len_Hs, :]
880
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
881
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
882
+ img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
883
+
884
+ sym_patch = img[:, -sym_len_He:, :]
885
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
886
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
887
+ img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
888
+
889
+ out_1 = torch.FloatTensor(in_C, out_H, in_W)
890
+ kernel_width = weights_H.size(1)
891
+ for i in range(out_H):
892
+ idx = int(indices_H[i][0])
893
+ for j in range(out_C):
894
+ out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
895
+
896
+ # process W dimension
897
+ # symmetric copying
898
+ out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
899
+ out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
900
+
901
+ sym_patch = out_1[:, :, :sym_len_Ws]
902
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
903
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
904
+ out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
905
+
906
+ sym_patch = out_1[:, :, -sym_len_We:]
907
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
908
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
909
+ out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
910
+
911
+ out_2 = torch.FloatTensor(in_C, out_H, out_W)
912
+ kernel_width = weights_W.size(1)
913
+ for i in range(out_W):
914
+ idx = int(indices_W[i][0])
915
+ for j in range(out_C):
916
+ out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
917
+ if need_squeeze:
918
+ out_2.squeeze_()
919
+ return out_2
920
+
921
+
922
+ # --------------------------------------------
923
+ # imresize for numpy image [0, 1]
924
+ # --------------------------------------------
925
+ def imresize_np(img, scale, antialiasing=True):
926
+ # Now the scale should be the same for H and W
927
+ # input: img: Numpy, HWC or HW [0,1]
928
+ # output: HWC or HW [0,1] w/o round
929
+ img = torch.from_numpy(img)
930
+ need_squeeze = True if img.dim() == 2 else False
931
+ if need_squeeze:
932
+ img.unsqueeze_(2)
933
+
934
+ in_H, in_W, in_C = img.size()
935
+ out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
936
+ kernel_width = 4
937
+ kernel = 'cubic'
938
+
939
+ # Return the desired dimension order for performing the resize. The
940
+ # strategy is to perform the resize first along the dimension with the
941
+ # smallest scale factor.
942
+ # Now we do not support this.
943
+
944
+ # get weights and indices
945
+ weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
946
+ in_H, out_H, scale, kernel, kernel_width, antialiasing)
947
+ weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
948
+ in_W, out_W, scale, kernel, kernel_width, antialiasing)
949
+ # process H dimension
950
+ # symmetric copying
951
+ img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
952
+ img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
953
+
954
+ sym_patch = img[:sym_len_Hs, :, :]
955
+ inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
956
+ sym_patch_inv = sym_patch.index_select(0, inv_idx)
957
+ img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
958
+
959
+ sym_patch = img[-sym_len_He:, :, :]
960
+ inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
961
+ sym_patch_inv = sym_patch.index_select(0, inv_idx)
962
+ img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
963
+
964
+ out_1 = torch.FloatTensor(out_H, in_W, in_C)
965
+ kernel_width = weights_H.size(1)
966
+ for i in range(out_H):
967
+ idx = int(indices_H[i][0])
968
+ for j in range(out_C):
969
+ out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
970
+
971
+ # process W dimension
972
+ # symmetric copying
973
+ out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
974
+ out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
975
+
976
+ sym_patch = out_1[:, :sym_len_Ws, :]
977
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
978
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
979
+ out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
980
+
981
+ sym_patch = out_1[:, -sym_len_We:, :]
982
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
983
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
984
+ out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
985
+
986
+ out_2 = torch.FloatTensor(out_H, out_W, in_C)
987
+ kernel_width = weights_W.size(1)
988
+ for i in range(out_W):
989
+ idx = int(indices_W[i][0])
990
+ for j in range(out_C):
991
+ out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
992
+ if need_squeeze:
993
+ out_2.squeeze_()
994
+
995
+ return out_2.numpy()
996
+
997
+
998
+ if __name__ == '__main__':
999
+ img = imread_uint('test.bmp', 3)
1000
+ # img = uint2single(img)
1001
+ # img_bicubic = imresize_np(img, 1/4)
1002
+ # imshow(single2uint(img_bicubic))
1003
+ #
1004
+ # img_tensor = single2tensor4(img)
1005
+ # for i in range(8):
1006
+ # imshow(np.concatenate((augment_img(img, i), tensor2single(augment_img_tensor4(img_tensor, i))), 1))
1007
+
1008
+ # patches = patches_from_image(img, p_size=128, p_overlap=0, p_max=200)
1009
+ # imssave(patches,'a.png')
1010
+
1011
+
1012
+
1013
+
1014
+
1015
+
1016
+
core/data/deg_kair_utils/utils_lmdb.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import lmdb
3
+ import sys
4
+ from multiprocessing import Pool
5
+ from os import path as osp
6
+ from tqdm import tqdm
7
+
8
+
9
+ def make_lmdb_from_imgs(data_path,
10
+ lmdb_path,
11
+ img_path_list,
12
+ keys,
13
+ batch=5000,
14
+ compress_level=1,
15
+ multiprocessing_read=False,
16
+ n_thread=40,
17
+ map_size=None):
18
+ """Make lmdb from images.
19
+
20
+ Contents of lmdb. The file structure is:
21
+ example.lmdb
22
+ ├── data.mdb
23
+ ├── lock.mdb
24
+ ├── meta_info.txt
25
+
26
+ The data.mdb and lock.mdb are standard lmdb files and you can refer to
27
+ https://lmdb.readthedocs.io/en/release/ for more details.
28
+
29
+ The meta_info.txt is a specified txt file to record the meta information
30
+ of our datasets. It will be automatically created when preparing
31
+ datasets by our provided dataset tools.
32
+ Each line in the txt file records 1)image name (with extension),
33
+ 2)image shape, and 3)compression level, separated by a white space.
34
+
35
+ For example, the meta information could be:
36
+ `000_00000000.png (720,1280,3) 1`, which means:
37
+ 1) image name (with extension): 000_00000000.png;
38
+ 2) image shape: (720,1280,3);
39
+ 3) compression level: 1
40
+
41
+ We use the image name without extension as the lmdb key.
42
+
43
+ If `multiprocessing_read` is True, it will read all the images to memory
44
+ using multiprocessing. Thus, your server needs to have enough memory.
45
+
46
+ Args:
47
+ data_path (str): Data path for reading images.
48
+ lmdb_path (str): Lmdb save path.
49
+ img_path_list (str): Image path list.
50
+ keys (str): Used for lmdb keys.
51
+ batch (int): After processing batch images, lmdb commits.
52
+ Default: 5000.
53
+ compress_level (int): Compress level when encoding images. Default: 1.
54
+ multiprocessing_read (bool): Whether use multiprocessing to read all
55
+ the images to memory. Default: False.
56
+ n_thread (int): For multiprocessing.
57
+ map_size (int | None): Map size for lmdb env. If None, use the
58
+ estimated size from images. Default: None
59
+ """
60
+
61
+ assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, '
62
+ f'but got {len(img_path_list)} and {len(keys)}')
63
+ print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
64
+ print(f'Totoal images: {len(img_path_list)}')
65
+ if not lmdb_path.endswith('.lmdb'):
66
+ raise ValueError("lmdb_path must end with '.lmdb'.")
67
+ if osp.exists(lmdb_path):
68
+ print(f'Folder {lmdb_path} already exists. Exit.')
69
+ sys.exit(1)
70
+
71
+ if multiprocessing_read:
72
+ # read all the images to memory (multiprocessing)
73
+ dataset = {} # use dict to keep the order for multiprocessing
74
+ shapes = {}
75
+ print(f'Read images with multiprocessing, #thread: {n_thread} ...')
76
+ pbar = tqdm(total=len(img_path_list), unit='image')
77
+
78
+ def callback(arg):
79
+ """get the image data and update pbar."""
80
+ key, dataset[key], shapes[key] = arg
81
+ pbar.update(1)
82
+ pbar.set_description(f'Read {key}')
83
+
84
+ pool = Pool(n_thread)
85
+ for path, key in zip(img_path_list, keys):
86
+ pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback)
87
+ pool.close()
88
+ pool.join()
89
+ pbar.close()
90
+ print(f'Finish reading {len(img_path_list)} images.')
91
+
92
+ # create lmdb environment
93
+ if map_size is None:
94
+ # obtain data size for one image
95
+ img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
96
+ _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
97
+ data_size_per_img = img_byte.nbytes
98
+ print('Data size per image is: ', data_size_per_img)
99
+ data_size = data_size_per_img * len(img_path_list)
100
+ map_size = data_size * 10
101
+
102
+ env = lmdb.open(lmdb_path, map_size=map_size)
103
+
104
+ # write data to lmdb
105
+ pbar = tqdm(total=len(img_path_list), unit='chunk')
106
+ txn = env.begin(write=True)
107
+ txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
108
+ for idx, (path, key) in enumerate(zip(img_path_list, keys)):
109
+ pbar.update(1)
110
+ pbar.set_description(f'Write {key}')
111
+ key_byte = key.encode('ascii')
112
+ if multiprocessing_read:
113
+ img_byte = dataset[key]
114
+ h, w, c = shapes[key]
115
+ else:
116
+ _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level)
117
+ h, w, c = img_shape
118
+
119
+ txn.put(key_byte, img_byte)
120
+ # write meta information
121
+ txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
122
+ if idx % batch == 0:
123
+ txn.commit()
124
+ txn = env.begin(write=True)
125
+ pbar.close()
126
+ txn.commit()
127
+ env.close()
128
+ txt_file.close()
129
+ print('\nFinish writing lmdb.')
130
+
131
+
132
+ def read_img_worker(path, key, compress_level):
133
+ """Read image worker.
134
+
135
+ Args:
136
+ path (str): Image path.
137
+ key (str): Image key.
138
+ compress_level (int): Compress level when encoding images.
139
+
140
+ Returns:
141
+ str: Image key.
142
+ byte: Image byte.
143
+ tuple[int]: Image shape.
144
+ """
145
+
146
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
147
+ # deal with `libpng error: Read Error`
148
+ if img is None:
149
+ print(f'To deal with `libpng error: Read Error`, use PIL to load {path}')
150
+ from PIL import Image
151
+ import numpy as np
152
+ img = Image.open(path)
153
+ img = np.asanyarray(img)
154
+ img = img[:, :, [2, 1, 0]]
155
+
156
+ if img.ndim == 2:
157
+ h, w = img.shape
158
+ c = 1
159
+ else:
160
+ h, w, c = img.shape
161
+ _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
162
+ return (key, img_byte, (h, w, c))
163
+
164
+
165
+ class LmdbMaker():
166
+ """LMDB Maker.
167
+
168
+ Args:
169
+ lmdb_path (str): Lmdb save path.
170
+ map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
171
+ batch (int): After processing batch images, lmdb commits.
172
+ Default: 5000.
173
+ compress_level (int): Compress level when encoding images. Default: 1.
174
+ """
175
+
176
+ def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1):
177
+ if not lmdb_path.endswith('.lmdb'):
178
+ raise ValueError("lmdb_path must end with '.lmdb'.")
179
+ if osp.exists(lmdb_path):
180
+ print(f'Folder {lmdb_path} already exists. Exit.')
181
+ sys.exit(1)
182
+
183
+ self.lmdb_path = lmdb_path
184
+ self.batch = batch
185
+ self.compress_level = compress_level
186
+ self.env = lmdb.open(lmdb_path, map_size=map_size)
187
+ self.txn = self.env.begin(write=True)
188
+ self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
189
+ self.counter = 0
190
+
191
+ def put(self, img_byte, key, img_shape):
192
+ self.counter += 1
193
+ key_byte = key.encode('ascii')
194
+ self.txn.put(key_byte, img_byte)
195
+ # write meta information
196
+ h, w, c = img_shape
197
+ self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n')
198
+ if self.counter % self.batch == 0:
199
+ self.txn.commit()
200
+ self.txn = self.env.begin(write=True)
201
+
202
+ def close(self):
203
+ self.txn.commit()
204
+ self.env.close()
205
+ self.txt_file.close()
core/data/deg_kair_utils/utils_logger.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import datetime
3
+ import logging
4
+
5
+
6
+ '''
7
+ # --------------------------------------------
8
+ # Kai Zhang (github: https://github.com/cszn)
9
+ # 03/Mar/2019
10
+ # --------------------------------------------
11
+ # https://github.com/xinntao/BasicSR
12
+ # --------------------------------------------
13
+ '''
14
+
15
+
16
+ def log(*args, **kwargs):
17
+ print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"), *args, **kwargs)
18
+
19
+
20
+ '''
21
+ # --------------------------------------------
22
+ # logger
23
+ # --------------------------------------------
24
+ '''
25
+
26
+
27
+ def logger_info(logger_name, log_path='default_logger.log'):
28
+ ''' set up logger
29
+ modified by Kai Zhang (github: https://github.com/cszn)
30
+ '''
31
+ log = logging.getLogger(logger_name)
32
+ if log.hasHandlers():
33
+ print('LogHandlers exist!')
34
+ else:
35
+ print('LogHandlers setup!')
36
+ level = logging.INFO
37
+ formatter = logging.Formatter('%(asctime)s.%(msecs)03d : %(message)s', datefmt='%y-%m-%d %H:%M:%S')
38
+ fh = logging.FileHandler(log_path, mode='a')
39
+ fh.setFormatter(formatter)
40
+ log.setLevel(level)
41
+ log.addHandler(fh)
42
+ # print(len(log.handlers))
43
+
44
+ sh = logging.StreamHandler()
45
+ sh.setFormatter(formatter)
46
+ log.addHandler(sh)
47
+
48
+
49
+ '''
50
+ # --------------------------------------------
51
+ # print to file and std_out simultaneously
52
+ # --------------------------------------------
53
+ '''
54
+
55
+
56
+ class logger_print(object):
57
+ def __init__(self, log_path="default.log"):
58
+ self.terminal = sys.stdout
59
+ self.log = open(log_path, 'a')
60
+
61
+ def write(self, message):
62
+ self.terminal.write(message)
63
+ self.log.write(message) # write the message
64
+
65
+ def flush(self):
66
+ pass
core/data/deg_kair_utils/utils_mat.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import scipy.io as spio
4
+ import pandas as pd
5
+
6
+
7
+ def loadmat(filename):
8
+ '''
9
+ this function should be called instead of direct spio.loadmat
10
+ as it cures the problem of not properly recovering python dictionaries
11
+ from mat files. It calls the function check keys to cure all entries
12
+ which are still mat-objects
13
+ '''
14
+ data = spio.loadmat(filename, struct_as_record=False, squeeze_me=True)
15
+ return dict_to_nonedict(_check_keys(data))
16
+
17
+ def _check_keys(dict):
18
+ '''
19
+ checks if entries in dictionary are mat-objects. If yes
20
+ todict is called to change them to nested dictionaries
21
+ '''
22
+ for key in dict:
23
+ if isinstance(dict[key], spio.matlab.mio5_params.mat_struct):
24
+ dict[key] = _todict(dict[key])
25
+ return dict
26
+
27
+ def _todict(matobj):
28
+ '''
29
+ A recursive function which constructs from matobjects nested dictionaries
30
+ '''
31
+ dict = {}
32
+ for strg in matobj._fieldnames:
33
+ elem = matobj.__dict__[strg]
34
+ if isinstance(elem, spio.matlab.mio5_params.mat_struct):
35
+ dict[strg] = _todict(elem)
36
+ else:
37
+ dict[strg] = elem
38
+ return dict
39
+
40
+
41
+ def dict_to_nonedict(opt):
42
+ if isinstance(opt, dict):
43
+ new_opt = dict()
44
+ for key, sub_opt in opt.items():
45
+ new_opt[key] = dict_to_nonedict(sub_opt)
46
+ return NoneDict(**new_opt)
47
+ elif isinstance(opt, list):
48
+ return [dict_to_nonedict(sub_opt) for sub_opt in opt]
49
+ else:
50
+ return opt
51
+
52
+
53
+ class NoneDict(dict):
54
+ def __missing__(self, key):
55
+ return None
56
+
57
+
58
+ def mat2json(mat_path=None, filepath = None):
59
+ """
60
+ Converts .mat file to .json and writes new file
61
+ Parameters
62
+ ----------
63
+ mat_path: Str
64
+ path/filename .mat存放路径
65
+ filepath: Str
66
+ 如果需要保存成json, 添加这一路径. 否则不保存
67
+ Returns
68
+ 返回转化的字典
69
+ -------
70
+ None
71
+ Examples
72
+ --------
73
+ >>> mat2json(blah blah)
74
+ """
75
+
76
+ matlabFile = loadmat(mat_path)
77
+ #pop all those dumb fields that don't let you jsonize file
78
+ matlabFile.pop('__header__')
79
+ matlabFile.pop('__version__')
80
+ matlabFile.pop('__globals__')
81
+ #jsonize the file - orientation is 'index'
82
+ matlabFile = pd.Series(matlabFile).to_json()
83
+
84
+ if filepath:
85
+ json_path = os.path.splitext(os.path.split(mat_path)[1])[0] + '.json'
86
+ with open(json_path, 'w') as f:
87
+ f.write(matlabFile)
88
+ return matlabFile
core/data/deg_kair_utils/utils_matconvnet.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import numpy as np
3
+ import torch
4
+ from collections import OrderedDict
5
+
6
+ # import scipy.io as io
7
+ import hdf5storage
8
+
9
+ """
10
+ # --------------------------------------------
11
+ # Convert matconvnet SimpleNN model into pytorch model
12
+ # --------------------------------------------
13
+ # Kai Zhang ([email protected])
14
+ # https://github.com/cszn
15
+ # 28/Nov/2019
16
+ # --------------------------------------------
17
+ """
18
+
19
+
20
+ def weights2tensor(x, squeeze=False, in_features=None, out_features=None):
21
+ """Modified version of https://github.com/albanie/pytorch-mcn
22
+ Adjust memory layout and load weights as torch tensor
23
+ Args:
24
+ x (ndaray): a numpy array, corresponding to a set of network weights
25
+ stored in column major order
26
+ squeeze (bool) [False]: whether to squeeze the tensor (i.e. remove
27
+ singletons from the trailing dimensions. So after converting to
28
+ pytorch layout (C_out, C_in, H, W), if the shape is (A, B, 1, 1)
29
+ it will be reshaped to a matrix with shape (A,B).
30
+ in_features (int :: None): used to reshape weights for a linear block.
31
+ out_features (int :: None): used to reshape weights for a linear block.
32
+ Returns:
33
+ torch.tensor: a permuted sets of weights, matching the pytorch layout
34
+ convention
35
+ """
36
+ if x.ndim == 4:
37
+ x = x.transpose((3, 2, 0, 1))
38
+ # for FFDNet, pixel-shuffle layer
39
+ # if x.shape[1]==13:
40
+ # x=x[:,[0,2,1,3, 4,6,5,7, 8,10,9,11, 12],:,:]
41
+ # if x.shape[0]==12:
42
+ # x=x[[0,2,1,3, 4,6,5,7, 8,10,9,11],:,:,:]
43
+ # if x.shape[1]==5:
44
+ # x=x[:,[0,2,1,3, 4],:,:]
45
+ # if x.shape[0]==4:
46
+ # x=x[[0,2,1,3],:,:,:]
47
+ ## for SRMD, pixel-shuffle layer
48
+ # if x.shape[0]==12:
49
+ # x=x[[0,2,1,3, 4,6,5,7, 8,10,9,11],:,:,:]
50
+ # if x.shape[0]==27:
51
+ # x=x[[0,3,6,1,4,7,2,5,8, 0+9,3+9,6+9,1+9,4+9,7+9,2+9,5+9,8+9, 0+18,3+18,6+18,1+18,4+18,7+18,2+18,5+18,8+18],:,:,:]
52
+ # if x.shape[0]==48:
53
+ # x=x[[0,4,8,12,1,5,9,13,2,6,10,14,3,7,11,15, 0+16,4+16,8+16,12+16,1+16,5+16,9+16,13+16,2+16,6+16,10+16,14+16,3+16,7+16,11+16,15+16, 0+32,4+32,8+32,12+32,1+32,5+32,9+32,13+32,2+32,6+32,10+32,14+32,3+32,7+32,11+32,15+32],:,:,:]
54
+
55
+ elif x.ndim == 3: # add by Kai
56
+ x = x[:,:,:,None]
57
+ x = x.transpose((3, 2, 0, 1))
58
+ elif x.ndim == 2:
59
+ if x.shape[1] == 1:
60
+ x = x.flatten()
61
+ if squeeze:
62
+ if in_features and out_features:
63
+ x = x.reshape((out_features, in_features))
64
+ x = np.squeeze(x)
65
+ return torch.from_numpy(np.ascontiguousarray(x))
66
+
67
+
68
+ def save_model(network, save_path):
69
+ state_dict = network.state_dict()
70
+ for key, param in state_dict.items():
71
+ state_dict[key] = param.cpu()
72
+ torch.save(state_dict, save_path)
73
+
74
+
75
+ if __name__ == '__main__':
76
+
77
+
78
+ # from utils import utils_logger
79
+ # import logging
80
+ # utils_logger.logger_info('a', 'a.log')
81
+ # logger = logging.getLogger('a')
82
+ #
83
+ # mcn = hdf5storage.loadmat('/model_zoo/matfile/FFDNet_Clip_gray.mat')
84
+ mcn = hdf5storage.loadmat('models/modelcolor.mat')
85
+
86
+
87
+ #logger.info(mcn['CNNdenoiser'][0][0][0][1][0][0][0][0])
88
+
89
+ mat_net = OrderedDict()
90
+ for idx in range(25):
91
+ mat_net[str(idx)] = OrderedDict()
92
+ count = -1
93
+
94
+ print(idx)
95
+ for i in range(13):
96
+
97
+ if mcn['CNNdenoiser'][0][idx][0][i][0][0][0][0] == 'conv':
98
+
99
+ count += 1
100
+ w = mcn['CNNdenoiser'][0][idx][0][i][0][1][0][0]
101
+ # print(w.shape)
102
+ w = weights2tensor(w)
103
+ # print(w.shape)
104
+
105
+ b = mcn['CNNdenoiser'][0][idx][0][i][0][1][0][1]
106
+ b = weights2tensor(b)
107
+ print(b.shape)
108
+
109
+ mat_net[str(idx)]['model.{:d}.weight'.format(count*2)] = w
110
+ mat_net[str(idx)]['model.{:d}.bias'.format(count*2)] = b
111
+
112
+ torch.save(mat_net, 'model_zoo/modelcolor.pth')
113
+
114
+
115
+
116
+ # from models.network_dncnn import IRCNN as net
117
+ # network = net(in_nc=3, out_nc=3, nc=64)
118
+ # state_dict = network.state_dict()
119
+ #
120
+ # #show_kv(state_dict)
121
+ #
122
+ # for i in range(len(mcn['net'][0][0][0])):
123
+ # print(mcn['net'][0][0][0][i][0][0][0][0])
124
+ #
125
+ # count = -1
126
+ # mat_net = OrderedDict()
127
+ # for i in range(len(mcn['net'][0][0][0])):
128
+ # if mcn['net'][0][0][0][i][0][0][0][0] == 'conv':
129
+ #
130
+ # count += 1
131
+ # w = mcn['net'][0][0][0][i][0][1][0][0]
132
+ # print(w.shape)
133
+ # w = weights2tensor(w)
134
+ # print(w.shape)
135
+ #
136
+ # b = mcn['net'][0][0][0][i][0][1][0][1]
137
+ # b = weights2tensor(b)
138
+ # print(b.shape)
139
+ #
140
+ # mat_net['model.{:d}.weight'.format(count*2)] = w
141
+ # mat_net['model.{:d}.bias'.format(count*2)] = b
142
+ #
143
+ # torch.save(mat_net, 'E:/pytorch/KAIR_ongoing/model_zoo/ffdnet_gray_clip.pth')
144
+ #
145
+ #
146
+ #
147
+ # crt_net = torch.load('E:/pytorch/KAIR_ongoing/model_zoo/imdn_x4.pth')
148
+ # def show_kv(net):
149
+ # for k, v in net.items():
150
+ # print(k)
151
+ #
152
+ # show_kv(crt_net)
153
+
154
+
155
+ # from models.network_dncnn import DnCNN as net
156
+ # network = net(in_nc=2, out_nc=1, nc=64, nb=20, act_mode='R')
157
+
158
+ # from models.network_srmd import SRMD as net
159
+ # #network = net(in_nc=1, out_nc=1, nc=64, nb=15, act_mode='R')
160
+ # network = net(in_nc=19, out_nc=3, nc=128, nb=12, upscale=4, act_mode='R', upsample_mode='pixelshuffle')
161
+ #
162
+ # from models.network_rrdb import RRDB as net
163
+ # network = net(in_nc=3, out_nc=3, nc=64, nb=23, gc=32, upscale=4, act_mode='L', upsample_mode='upconv')
164
+ #
165
+ # state_dict = network.state_dict()
166
+ # for key, param in state_dict.items():
167
+ # print(key)
168
+ # from models.network_imdn import IMDN as net
169
+ # network = net(in_nc=3, out_nc=3, nc=64, nb=8, upscale=4, act_mode='L', upsample_mode='pixelshuffle')
170
+ # state_dict = network.state_dict()
171
+ # mat_net = OrderedDict()
172
+ # for ((key, param),(key2, param2)) in zip(state_dict.items(), crt_net.items()):
173
+ # mat_net[key] = param2
174
+ # torch.save(mat_net, 'model_zoo/imdn_x4_1.pth')
175
+ #
176
+
177
+ # net_old = torch.load('net_old.pth')
178
+ # def show_kv(net):
179
+ # for k, v in net.items():
180
+ # print(k)
181
+ #
182
+ # show_kv(net_old)
183
+ # from models.network_dpsr import MSRResNet_prior as net
184
+ # model = net(in_nc=4, out_nc=3, nc=96, nb=16, upscale=4, act_mode='R', upsample_mode='pixelshuffle')
185
+ # state_dict = network.state_dict()
186
+ # net_new = OrderedDict()
187
+ # for ((key, param),(key_old, param_old)) in zip(state_dict.items(), net_old.items()):
188
+ # net_new[key] = param_old
189
+ # torch.save(net_new, 'net_new.pth')
190
+
191
+
192
+ # print(key)
193
+ # print(param.size())
194
+
195
+
196
+
197
+ # run utils/utils_matconvnet.py
core/data/deg_kair_utils/utils_model.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import numpy as np
3
+ import torch
4
+ from utils import utils_image as util
5
+ import re
6
+ import glob
7
+ import os
8
+
9
+
10
+ '''
11
+ # --------------------------------------------
12
+ # Model
13
+ # --------------------------------------------
14
+ # Kai Zhang (github: https://github.com/cszn)
15
+ # 03/Mar/2019
16
+ # --------------------------------------------
17
+ '''
18
+
19
+
20
+ def find_last_checkpoint(save_dir, net_type='G', pretrained_path=None):
21
+ """
22
+ # ---------------------------------------
23
+ # Kai Zhang (github: https://github.com/cszn)
24
+ # 03/Mar/2019
25
+ # ---------------------------------------
26
+ Args:
27
+ save_dir: model folder
28
+ net_type: 'G' or 'D' or 'optimizerG' or 'optimizerD'
29
+ pretrained_path: pretrained model path. If save_dir does not have any model, load from pretrained_path
30
+
31
+ Return:
32
+ init_iter: iteration number
33
+ init_path: model path
34
+ # ---------------------------------------
35
+ """
36
+
37
+ file_list = glob.glob(os.path.join(save_dir, '*_{}.pth'.format(net_type)))
38
+ if file_list:
39
+ iter_exist = []
40
+ for file_ in file_list:
41
+ iter_current = re.findall(r"(\d+)_{}.pth".format(net_type), file_)
42
+ iter_exist.append(int(iter_current[0]))
43
+ init_iter = max(iter_exist)
44
+ init_path = os.path.join(save_dir, '{}_{}.pth'.format(init_iter, net_type))
45
+ else:
46
+ init_iter = 0
47
+ init_path = pretrained_path
48
+ return init_iter, init_path
49
+
50
+
51
+ def test_mode(model, L, mode=0, refield=32, min_size=256, sf=1, modulo=1):
52
+ '''
53
+ # ---------------------------------------
54
+ # Kai Zhang (github: https://github.com/cszn)
55
+ # 03/Mar/2019
56
+ # ---------------------------------------
57
+ Args:
58
+ model: trained model
59
+ L: input Low-quality image
60
+ mode:
61
+ (0) normal: test(model, L)
62
+ (1) pad: test_pad(model, L, modulo=16)
63
+ (2) split: test_split(model, L, refield=32, min_size=256, sf=1, modulo=1)
64
+ (3) x8: test_x8(model, L, modulo=1) ^_^
65
+ (4) split and x8: test_split_x8(model, L, refield=32, min_size=256, sf=1, modulo=1)
66
+ refield: effective receptive filed of the network, 32 is enough
67
+ useful when split, i.e., mode=2, 4
68
+ min_size: min_sizeXmin_size image, e.g., 256X256 image
69
+ useful when split, i.e., mode=2, 4
70
+ sf: scale factor for super-resolution, otherwise 1
71
+ modulo: 1 if split
72
+ useful when pad, i.e., mode=1
73
+
74
+ Returns:
75
+ E: estimated image
76
+ # ---------------------------------------
77
+ '''
78
+ if mode == 0:
79
+ E = test(model, L)
80
+ elif mode == 1:
81
+ E = test_pad(model, L, modulo, sf)
82
+ elif mode == 2:
83
+ E = test_split(model, L, refield, min_size, sf, modulo)
84
+ elif mode == 3:
85
+ E = test_x8(model, L, modulo, sf)
86
+ elif mode == 4:
87
+ E = test_split_x8(model, L, refield, min_size, sf, modulo)
88
+ return E
89
+
90
+
91
+ '''
92
+ # --------------------------------------------
93
+ # normal (0)
94
+ # --------------------------------------------
95
+ '''
96
+
97
+
98
+ def test(model, L):
99
+ E = model(L)
100
+ return E
101
+
102
+
103
+ '''
104
+ # --------------------------------------------
105
+ # pad (1)
106
+ # --------------------------------------------
107
+ '''
108
+
109
+
110
+ def test_pad(model, L, modulo=16, sf=1):
111
+ h, w = L.size()[-2:]
112
+ paddingBottom = int(np.ceil(h/modulo)*modulo-h)
113
+ paddingRight = int(np.ceil(w/modulo)*modulo-w)
114
+ L = torch.nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(L)
115
+ E = model(L)
116
+ E = E[..., :h*sf, :w*sf]
117
+ return E
118
+
119
+
120
+ '''
121
+ # --------------------------------------------
122
+ # split (function)
123
+ # --------------------------------------------
124
+ '''
125
+
126
+
127
+ def test_split_fn(model, L, refield=32, min_size=256, sf=1, modulo=1):
128
+ """
129
+ Args:
130
+ model: trained model
131
+ L: input Low-quality image
132
+ refield: effective receptive filed of the network, 32 is enough
133
+ min_size: min_sizeXmin_size image, e.g., 256X256 image
134
+ sf: scale factor for super-resolution, otherwise 1
135
+ modulo: 1 if split
136
+
137
+ Returns:
138
+ E: estimated result
139
+ """
140
+ h, w = L.size()[-2:]
141
+ if h*w <= min_size**2:
142
+ L = torch.nn.ReplicationPad2d((0, int(np.ceil(w/modulo)*modulo-w), 0, int(np.ceil(h/modulo)*modulo-h)))(L)
143
+ E = model(L)
144
+ E = E[..., :h*sf, :w*sf]
145
+ else:
146
+ top = slice(0, (h//2//refield+1)*refield)
147
+ bottom = slice(h - (h//2//refield+1)*refield, h)
148
+ left = slice(0, (w//2//refield+1)*refield)
149
+ right = slice(w - (w//2//refield+1)*refield, w)
150
+ Ls = [L[..., top, left], L[..., top, right], L[..., bottom, left], L[..., bottom, right]]
151
+
152
+ if h * w <= 4*(min_size**2):
153
+ Es = [model(Ls[i]) for i in range(4)]
154
+ else:
155
+ Es = [test_split_fn(model, Ls[i], refield=refield, min_size=min_size, sf=sf, modulo=modulo) for i in range(4)]
156
+
157
+ b, c = Es[0].size()[:2]
158
+ E = torch.zeros(b, c, sf * h, sf * w).type_as(L)
159
+
160
+ E[..., :h//2*sf, :w//2*sf] = Es[0][..., :h//2*sf, :w//2*sf]
161
+ E[..., :h//2*sf, w//2*sf:w*sf] = Es[1][..., :h//2*sf, (-w + w//2)*sf:]
162
+ E[..., h//2*sf:h*sf, :w//2*sf] = Es[2][..., (-h + h//2)*sf:, :w//2*sf]
163
+ E[..., h//2*sf:h*sf, w//2*sf:w*sf] = Es[3][..., (-h + h//2)*sf:, (-w + w//2)*sf:]
164
+ return E
165
+
166
+
167
+ '''
168
+ # --------------------------------------------
169
+ # split (2)
170
+ # --------------------------------------------
171
+ '''
172
+
173
+
174
+ def test_split(model, L, refield=32, min_size=256, sf=1, modulo=1):
175
+ E = test_split_fn(model, L, refield=refield, min_size=min_size, sf=sf, modulo=modulo)
176
+ return E
177
+
178
+
179
+ '''
180
+ # --------------------------------------------
181
+ # x8 (3)
182
+ # --------------------------------------------
183
+ '''
184
+
185
+
186
+ def test_x8(model, L, modulo=1, sf=1):
187
+ E_list = [test_pad(model, util.augment_img_tensor4(L, mode=i), modulo=modulo, sf=sf) for i in range(8)]
188
+ for i in range(len(E_list)):
189
+ if i == 3 or i == 5:
190
+ E_list[i] = util.augment_img_tensor4(E_list[i], mode=8 - i)
191
+ else:
192
+ E_list[i] = util.augment_img_tensor4(E_list[i], mode=i)
193
+ output_cat = torch.stack(E_list, dim=0)
194
+ E = output_cat.mean(dim=0, keepdim=False)
195
+ return E
196
+
197
+
198
+ '''
199
+ # --------------------------------------------
200
+ # split and x8 (4)
201
+ # --------------------------------------------
202
+ '''
203
+
204
+
205
+ def test_split_x8(model, L, refield=32, min_size=256, sf=1, modulo=1):
206
+ E_list = [test_split_fn(model, util.augment_img_tensor4(L, mode=i), refield=refield, min_size=min_size, sf=sf, modulo=modulo) for i in range(8)]
207
+ for k, i in enumerate(range(len(E_list))):
208
+ if i==3 or i==5:
209
+ E_list[k] = util.augment_img_tensor4(E_list[k], mode=8-i)
210
+ else:
211
+ E_list[k] = util.augment_img_tensor4(E_list[k], mode=i)
212
+ output_cat = torch.stack(E_list, dim=0)
213
+ E = output_cat.mean(dim=0, keepdim=False)
214
+ return E
215
+
216
+
217
+ '''
218
+ # ^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-
219
+ # _^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^
220
+ # ^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-
221
+ '''
222
+
223
+
224
+ '''
225
+ # --------------------------------------------
226
+ # print
227
+ # --------------------------------------------
228
+ '''
229
+
230
+
231
+ # --------------------------------------------
232
+ # print model
233
+ # --------------------------------------------
234
+ def print_model(model):
235
+ msg = describe_model(model)
236
+ print(msg)
237
+
238
+
239
+ # --------------------------------------------
240
+ # print params
241
+ # --------------------------------------------
242
+ def print_params(model):
243
+ msg = describe_params(model)
244
+ print(msg)
245
+
246
+
247
+ '''
248
+ # --------------------------------------------
249
+ # information
250
+ # --------------------------------------------
251
+ '''
252
+
253
+
254
+ # --------------------------------------------
255
+ # model inforation
256
+ # --------------------------------------------
257
+ def info_model(model):
258
+ msg = describe_model(model)
259
+ return msg
260
+
261
+
262
+ # --------------------------------------------
263
+ # params inforation
264
+ # --------------------------------------------
265
+ def info_params(model):
266
+ msg = describe_params(model)
267
+ return msg
268
+
269
+
270
+ '''
271
+ # --------------------------------------------
272
+ # description
273
+ # --------------------------------------------
274
+ '''
275
+
276
+
277
+ # --------------------------------------------
278
+ # model name and total number of parameters
279
+ # --------------------------------------------
280
+ def describe_model(model):
281
+ if isinstance(model, torch.nn.DataParallel):
282
+ model = model.module
283
+ msg = '\n'
284
+ msg += 'models name: {}'.format(model.__class__.__name__) + '\n'
285
+ msg += 'Params number: {}'.format(sum(map(lambda x: x.numel(), model.parameters()))) + '\n'
286
+ msg += 'Net structure:\n{}'.format(str(model)) + '\n'
287
+ return msg
288
+
289
+
290
+ # --------------------------------------------
291
+ # parameters description
292
+ # --------------------------------------------
293
+ def describe_params(model):
294
+ if isinstance(model, torch.nn.DataParallel):
295
+ model = model.module
296
+ msg = '\n'
297
+ msg += ' | {:^6s} | {:^6s} | {:^6s} | {:^6s} || {:<20s}'.format('mean', 'min', 'max', 'std', 'shape', 'param_name') + '\n'
298
+ for name, param in model.state_dict().items():
299
+ if not 'num_batches_tracked' in name:
300
+ v = param.data.clone().float()
301
+ msg += ' | {:>6.3f} | {:>6.3f} | {:>6.3f} | {:>6.3f} | {} || {:s}'.format(v.mean(), v.min(), v.max(), v.std(), v.shape, name) + '\n'
302
+ return msg
303
+
304
+
305
+ if __name__ == '__main__':
306
+
307
+ class Net(torch.nn.Module):
308
+ def __init__(self, in_channels=3, out_channels=3):
309
+ super(Net, self).__init__()
310
+ self.conv = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)
311
+
312
+ def forward(self, x):
313
+ x = self.conv(x)
314
+ return x
315
+
316
+ start = torch.cuda.Event(enable_timing=True)
317
+ end = torch.cuda.Event(enable_timing=True)
318
+
319
+ model = Net()
320
+ model = model.eval()
321
+ print_model(model)
322
+ print_params(model)
323
+ x = torch.randn((2,3,401,401))
324
+ torch.cuda.empty_cache()
325
+ with torch.no_grad():
326
+ for mode in range(5):
327
+ y = test_mode(model, x, mode, refield=32, min_size=256, sf=1, modulo=1)
328
+ print(y.shape)
329
+
330
+ # run utils/utils_model.py
core/data/deg_kair_utils/utils_modelsummary.py ADDED
@@ -0,0 +1,485 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+ import numpy as np
4
+
5
+ '''
6
+ ---- 1) FLOPs: floating point operations
7
+ ---- 2) #Activations: the number of elements of all ‘Conv2d’ outputs
8
+ ---- 3) #Conv2d: the number of ‘Conv2d’ layers
9
+ # --------------------------------------------
10
+ # Kai Zhang (github: https://github.com/cszn)
11
+ # 21/July/2020
12
+ # --------------------------------------------
13
+ # Reference
14
+ https://github.com/sovrasov/flops-counter.pytorch.git
15
+
16
+ # If you use this code, please consider the following citation:
17
+
18
+ @inproceedings{zhang2020aim, %
19
+ title={AIM 2020 Challenge on Efficient Super-Resolution: Methods and Results},
20
+ author={Kai Zhang and Martin Danelljan and Yawei Li and Radu Timofte and others},
21
+ booktitle={European Conference on Computer Vision Workshops},
22
+ year={2020}
23
+ }
24
+ # --------------------------------------------
25
+ '''
26
+
27
+ def get_model_flops(model, input_res, print_per_layer_stat=True,
28
+ input_constructor=None):
29
+ assert type(input_res) is tuple, 'Please provide the size of the input image.'
30
+ assert len(input_res) >= 3, 'Input image should have 3 dimensions.'
31
+ flops_model = add_flops_counting_methods(model)
32
+ flops_model.eval().start_flops_count()
33
+ if input_constructor:
34
+ input = input_constructor(input_res)
35
+ _ = flops_model(**input)
36
+ else:
37
+ device = list(flops_model.parameters())[-1].device
38
+ batch = torch.FloatTensor(1, *input_res).to(device)
39
+ _ = flops_model(batch)
40
+
41
+ if print_per_layer_stat:
42
+ print_model_with_flops(flops_model)
43
+ flops_count = flops_model.compute_average_flops_cost()
44
+ flops_model.stop_flops_count()
45
+
46
+ return flops_count
47
+
48
+ def get_model_activation(model, input_res, input_constructor=None):
49
+ assert type(input_res) is tuple, 'Please provide the size of the input image.'
50
+ assert len(input_res) >= 3, 'Input image should have 3 dimensions.'
51
+ activation_model = add_activation_counting_methods(model)
52
+ activation_model.eval().start_activation_count()
53
+ if input_constructor:
54
+ input = input_constructor(input_res)
55
+ _ = activation_model(**input)
56
+ else:
57
+ device = list(activation_model.parameters())[-1].device
58
+ batch = torch.FloatTensor(1, *input_res).to(device)
59
+ _ = activation_model(batch)
60
+
61
+ activation_count, num_conv = activation_model.compute_average_activation_cost()
62
+ activation_model.stop_activation_count()
63
+
64
+ return activation_count, num_conv
65
+
66
+
67
+ def get_model_complexity_info(model, input_res, print_per_layer_stat=True, as_strings=True,
68
+ input_constructor=None):
69
+ assert type(input_res) is tuple
70
+ assert len(input_res) >= 3
71
+ flops_model = add_flops_counting_methods(model)
72
+ flops_model.eval().start_flops_count()
73
+ if input_constructor:
74
+ input = input_constructor(input_res)
75
+ _ = flops_model(**input)
76
+ else:
77
+ batch = torch.FloatTensor(1, *input_res)
78
+ _ = flops_model(batch)
79
+
80
+ if print_per_layer_stat:
81
+ print_model_with_flops(flops_model)
82
+ flops_count = flops_model.compute_average_flops_cost()
83
+ params_count = get_model_parameters_number(flops_model)
84
+ flops_model.stop_flops_count()
85
+
86
+ if as_strings:
87
+ return flops_to_string(flops_count), params_to_string(params_count)
88
+
89
+ return flops_count, params_count
90
+
91
+
92
+ def flops_to_string(flops, units='GMac', precision=2):
93
+ if units is None:
94
+ if flops // 10**9 > 0:
95
+ return str(round(flops / 10.**9, precision)) + ' GMac'
96
+ elif flops // 10**6 > 0:
97
+ return str(round(flops / 10.**6, precision)) + ' MMac'
98
+ elif flops // 10**3 > 0:
99
+ return str(round(flops / 10.**3, precision)) + ' KMac'
100
+ else:
101
+ return str(flops) + ' Mac'
102
+ else:
103
+ if units == 'GMac':
104
+ return str(round(flops / 10.**9, precision)) + ' ' + units
105
+ elif units == 'MMac':
106
+ return str(round(flops / 10.**6, precision)) + ' ' + units
107
+ elif units == 'KMac':
108
+ return str(round(flops / 10.**3, precision)) + ' ' + units
109
+ else:
110
+ return str(flops) + ' Mac'
111
+
112
+
113
+ def params_to_string(params_num):
114
+ if params_num // 10 ** 6 > 0:
115
+ return str(round(params_num / 10 ** 6, 2)) + ' M'
116
+ elif params_num // 10 ** 3:
117
+ return str(round(params_num / 10 ** 3, 2)) + ' k'
118
+ else:
119
+ return str(params_num)
120
+
121
+
122
+ def print_model_with_flops(model, units='GMac', precision=3):
123
+ total_flops = model.compute_average_flops_cost()
124
+
125
+ def accumulate_flops(self):
126
+ if is_supported_instance(self):
127
+ return self.__flops__ / model.__batch_counter__
128
+ else:
129
+ sum = 0
130
+ for m in self.children():
131
+ sum += m.accumulate_flops()
132
+ return sum
133
+
134
+ def flops_repr(self):
135
+ accumulated_flops_cost = self.accumulate_flops()
136
+ return ', '.join([flops_to_string(accumulated_flops_cost, units=units, precision=precision),
137
+ '{:.3%} MACs'.format(accumulated_flops_cost / total_flops),
138
+ self.original_extra_repr()])
139
+
140
+ def add_extra_repr(m):
141
+ m.accumulate_flops = accumulate_flops.__get__(m)
142
+ flops_extra_repr = flops_repr.__get__(m)
143
+ if m.extra_repr != flops_extra_repr:
144
+ m.original_extra_repr = m.extra_repr
145
+ m.extra_repr = flops_extra_repr
146
+ assert m.extra_repr != m.original_extra_repr
147
+
148
+ def del_extra_repr(m):
149
+ if hasattr(m, 'original_extra_repr'):
150
+ m.extra_repr = m.original_extra_repr
151
+ del m.original_extra_repr
152
+ if hasattr(m, 'accumulate_flops'):
153
+ del m.accumulate_flops
154
+
155
+ model.apply(add_extra_repr)
156
+ print(model)
157
+ model.apply(del_extra_repr)
158
+
159
+
160
+ def get_model_parameters_number(model):
161
+ params_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
162
+ return params_num
163
+
164
+
165
+ def add_flops_counting_methods(net_main_module):
166
+ # adding additional methods to the existing module object,
167
+ # this is done this way so that each function has access to self object
168
+ # embed()
169
+ net_main_module.start_flops_count = start_flops_count.__get__(net_main_module)
170
+ net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module)
171
+ net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module)
172
+ net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(net_main_module)
173
+
174
+ net_main_module.reset_flops_count()
175
+ return net_main_module
176
+
177
+
178
+ def compute_average_flops_cost(self):
179
+ """
180
+ A method that will be available after add_flops_counting_methods() is called
181
+ on a desired net object.
182
+
183
+ Returns current mean flops consumption per image.
184
+
185
+ """
186
+
187
+ flops_sum = 0
188
+ for module in self.modules():
189
+ if is_supported_instance(module):
190
+ flops_sum += module.__flops__
191
+
192
+ return flops_sum
193
+
194
+
195
+ def start_flops_count(self):
196
+ """
197
+ A method that will be available after add_flops_counting_methods() is called
198
+ on a desired net object.
199
+
200
+ Activates the computation of mean flops consumption per image.
201
+ Call it before you run the network.
202
+
203
+ """
204
+ self.apply(add_flops_counter_hook_function)
205
+
206
+
207
+ def stop_flops_count(self):
208
+ """
209
+ A method that will be available after add_flops_counting_methods() is called
210
+ on a desired net object.
211
+
212
+ Stops computing the mean flops consumption per image.
213
+ Call whenever you want to pause the computation.
214
+
215
+ """
216
+ self.apply(remove_flops_counter_hook_function)
217
+
218
+
219
+ def reset_flops_count(self):
220
+ """
221
+ A method that will be available after add_flops_counting_methods() is called
222
+ on a desired net object.
223
+
224
+ Resets statistics computed so far.
225
+
226
+ """
227
+ self.apply(add_flops_counter_variable_or_reset)
228
+
229
+
230
+ def add_flops_counter_hook_function(module):
231
+ if is_supported_instance(module):
232
+ if hasattr(module, '__flops_handle__'):
233
+ return
234
+
235
+ if isinstance(module, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d)):
236
+ handle = module.register_forward_hook(conv_flops_counter_hook)
237
+ elif isinstance(module, (nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6)):
238
+ handle = module.register_forward_hook(relu_flops_counter_hook)
239
+ elif isinstance(module, nn.Linear):
240
+ handle = module.register_forward_hook(linear_flops_counter_hook)
241
+ elif isinstance(module, (nn.BatchNorm2d)):
242
+ handle = module.register_forward_hook(bn_flops_counter_hook)
243
+ else:
244
+ handle = module.register_forward_hook(empty_flops_counter_hook)
245
+ module.__flops_handle__ = handle
246
+
247
+
248
+ def remove_flops_counter_hook_function(module):
249
+ if is_supported_instance(module):
250
+ if hasattr(module, '__flops_handle__'):
251
+ module.__flops_handle__.remove()
252
+ del module.__flops_handle__
253
+
254
+
255
+ def add_flops_counter_variable_or_reset(module):
256
+ if is_supported_instance(module):
257
+ module.__flops__ = 0
258
+
259
+
260
+ # ---- Internal functions
261
+ def is_supported_instance(module):
262
+ if isinstance(module,
263
+ (
264
+ nn.Conv2d, nn.ConvTranspose2d,
265
+ nn.BatchNorm2d,
266
+ nn.Linear,
267
+ nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6,
268
+ )):
269
+ return True
270
+
271
+ return False
272
+
273
+
274
+ def conv_flops_counter_hook(conv_module, input, output):
275
+ # Can have multiple inputs, getting the first one
276
+ # input = input[0]
277
+
278
+ batch_size = output.shape[0]
279
+ output_dims = list(output.shape[2:])
280
+
281
+ kernel_dims = list(conv_module.kernel_size)
282
+ in_channels = conv_module.in_channels
283
+ out_channels = conv_module.out_channels
284
+ groups = conv_module.groups
285
+
286
+ filters_per_channel = out_channels // groups
287
+ conv_per_position_flops = np.prod(kernel_dims) * in_channels * filters_per_channel
288
+
289
+ active_elements_count = batch_size * np.prod(output_dims)
290
+ overall_conv_flops = int(conv_per_position_flops) * int(active_elements_count)
291
+
292
+ # overall_flops = overall_conv_flops
293
+
294
+ conv_module.__flops__ += int(overall_conv_flops)
295
+ # conv_module.__output_dims__ = output_dims
296
+
297
+
298
+ def relu_flops_counter_hook(module, input, output):
299
+ active_elements_count = output.numel()
300
+ module.__flops__ += int(active_elements_count)
301
+ # print(module.__flops__, id(module))
302
+ # print(module)
303
+
304
+
305
+ def linear_flops_counter_hook(module, input, output):
306
+ input = input[0]
307
+ if len(input.shape) == 1:
308
+ batch_size = 1
309
+ module.__flops__ += int(batch_size * input.shape[0] * output.shape[0])
310
+ else:
311
+ batch_size = input.shape[0]
312
+ module.__flops__ += int(batch_size * input.shape[1] * output.shape[1])
313
+
314
+
315
+ def bn_flops_counter_hook(module, input, output):
316
+ # input = input[0]
317
+ # TODO: need to check here
318
+ # batch_flops = np.prod(input.shape)
319
+ # if module.affine:
320
+ # batch_flops *= 2
321
+ # module.__flops__ += int(batch_flops)
322
+ batch = output.shape[0]
323
+ output_dims = output.shape[2:]
324
+ channels = module.num_features
325
+ batch_flops = batch * channels * np.prod(output_dims)
326
+ if module.affine:
327
+ batch_flops *= 2
328
+ module.__flops__ += int(batch_flops)
329
+
330
+
331
+ # ---- Count the number of convolutional layers and the activation
332
+ def add_activation_counting_methods(net_main_module):
333
+ # adding additional methods to the existing module object,
334
+ # this is done this way so that each function has access to self object
335
+ # embed()
336
+ net_main_module.start_activation_count = start_activation_count.__get__(net_main_module)
337
+ net_main_module.stop_activation_count = stop_activation_count.__get__(net_main_module)
338
+ net_main_module.reset_activation_count = reset_activation_count.__get__(net_main_module)
339
+ net_main_module.compute_average_activation_cost = compute_average_activation_cost.__get__(net_main_module)
340
+
341
+ net_main_module.reset_activation_count()
342
+ return net_main_module
343
+
344
+
345
+ def compute_average_activation_cost(self):
346
+ """
347
+ A method that will be available after add_activation_counting_methods() is called
348
+ on a desired net object.
349
+
350
+ Returns current mean activation consumption per image.
351
+
352
+ """
353
+
354
+ activation_sum = 0
355
+ num_conv = 0
356
+ for module in self.modules():
357
+ if is_supported_instance_for_activation(module):
358
+ activation_sum += module.__activation__
359
+ num_conv += module.__num_conv__
360
+ return activation_sum, num_conv
361
+
362
+
363
+ def start_activation_count(self):
364
+ """
365
+ A method that will be available after add_activation_counting_methods() is called
366
+ on a desired net object.
367
+
368
+ Activates the computation of mean activation consumption per image.
369
+ Call it before you run the network.
370
+
371
+ """
372
+ self.apply(add_activation_counter_hook_function)
373
+
374
+
375
+ def stop_activation_count(self):
376
+ """
377
+ A method that will be available after add_activation_counting_methods() is called
378
+ on a desired net object.
379
+
380
+ Stops computing the mean activation consumption per image.
381
+ Call whenever you want to pause the computation.
382
+
383
+ """
384
+ self.apply(remove_activation_counter_hook_function)
385
+
386
+
387
+ def reset_activation_count(self):
388
+ """
389
+ A method that will be available after add_activation_counting_methods() is called
390
+ on a desired net object.
391
+
392
+ Resets statistics computed so far.
393
+
394
+ """
395
+ self.apply(add_activation_counter_variable_or_reset)
396
+
397
+
398
+ def add_activation_counter_hook_function(module):
399
+ if is_supported_instance_for_activation(module):
400
+ if hasattr(module, '__activation_handle__'):
401
+ return
402
+
403
+ if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
404
+ handle = module.register_forward_hook(conv_activation_counter_hook)
405
+ module.__activation_handle__ = handle
406
+
407
+
408
+ def remove_activation_counter_hook_function(module):
409
+ if is_supported_instance_for_activation(module):
410
+ if hasattr(module, '__activation_handle__'):
411
+ module.__activation_handle__.remove()
412
+ del module.__activation_handle__
413
+
414
+
415
+ def add_activation_counter_variable_or_reset(module):
416
+ if is_supported_instance_for_activation(module):
417
+ module.__activation__ = 0
418
+ module.__num_conv__ = 0
419
+
420
+
421
+ def is_supported_instance_for_activation(module):
422
+ if isinstance(module,
423
+ (
424
+ nn.Conv2d, nn.ConvTranspose2d,
425
+ )):
426
+ return True
427
+
428
+ return False
429
+
430
+ def conv_activation_counter_hook(module, input, output):
431
+ """
432
+ Calculate the activations in the convolutional operation.
433
+ Reference: Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár, Designing Network Design Spaces.
434
+ :param module:
435
+ :param input:
436
+ :param output:
437
+ :return:
438
+ """
439
+ module.__activation__ += output.numel()
440
+ module.__num_conv__ += 1
441
+
442
+
443
+ def empty_flops_counter_hook(module, input, output):
444
+ module.__flops__ += 0
445
+
446
+
447
+ def upsample_flops_counter_hook(module, input, output):
448
+ output_size = output[0]
449
+ batch_size = output_size.shape[0]
450
+ output_elements_count = batch_size
451
+ for val in output_size.shape[1:]:
452
+ output_elements_count *= val
453
+ module.__flops__ += int(output_elements_count)
454
+
455
+
456
+ def pool_flops_counter_hook(module, input, output):
457
+ input = input[0]
458
+ module.__flops__ += int(np.prod(input.shape))
459
+
460
+
461
+ def dconv_flops_counter_hook(dconv_module, input, output):
462
+ input = input[0]
463
+
464
+ batch_size = input.shape[0]
465
+ output_dims = list(output.shape[2:])
466
+
467
+ m_channels, in_channels, kernel_dim1, _, = dconv_module.weight.shape
468
+ out_channels, _, kernel_dim2, _, = dconv_module.projection.shape
469
+ # groups = dconv_module.groups
470
+
471
+ # filters_per_channel = out_channels // groups
472
+ conv_per_position_flops1 = kernel_dim1 ** 2 * in_channels * m_channels
473
+ conv_per_position_flops2 = kernel_dim2 ** 2 * out_channels * m_channels
474
+ active_elements_count = batch_size * np.prod(output_dims)
475
+
476
+ overall_conv_flops = (conv_per_position_flops1 + conv_per_position_flops2) * active_elements_count
477
+ overall_flops = overall_conv_flops
478
+
479
+ dconv_module.__flops__ += int(overall_flops)
480
+ # dconv_module.__output_dims__ = output_dims
481
+
482
+
483
+
484
+
485
+
core/data/deg_kair_utils/utils_option.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from collections import OrderedDict
3
+ from datetime import datetime
4
+ import json
5
+ import re
6
+ import glob
7
+
8
+
9
+ '''
10
+ # --------------------------------------------
11
+ # Kai Zhang (github: https://github.com/cszn)
12
+ # 03/Mar/2019
13
+ # --------------------------------------------
14
+ # https://github.com/xinntao/BasicSR
15
+ # --------------------------------------------
16
+ '''
17
+
18
+
19
+ def get_timestamp():
20
+ return datetime.now().strftime('_%y%m%d_%H%M%S')
21
+
22
+
23
+ def parse(opt_path, is_train=True):
24
+
25
+ # ----------------------------------------
26
+ # remove comments starting with '//'
27
+ # ----------------------------------------
28
+ json_str = ''
29
+ with open(opt_path, 'r') as f:
30
+ for line in f:
31
+ line = line.split('//')[0] + '\n'
32
+ json_str += line
33
+
34
+ # ----------------------------------------
35
+ # initialize opt
36
+ # ----------------------------------------
37
+ opt = json.loads(json_str, object_pairs_hook=OrderedDict)
38
+
39
+ opt['opt_path'] = opt_path
40
+ opt['is_train'] = is_train
41
+
42
+ # ----------------------------------------
43
+ # set default
44
+ # ----------------------------------------
45
+ if 'merge_bn' not in opt:
46
+ opt['merge_bn'] = False
47
+ opt['merge_bn_startpoint'] = -1
48
+
49
+ if 'scale' not in opt:
50
+ opt['scale'] = 1
51
+
52
+ # ----------------------------------------
53
+ # datasets
54
+ # ----------------------------------------
55
+ for phase, dataset in opt['datasets'].items():
56
+ phase = phase.split('_')[0]
57
+ dataset['phase'] = phase
58
+ dataset['scale'] = opt['scale'] # broadcast
59
+ dataset['n_channels'] = opt['n_channels'] # broadcast
60
+ if 'dataroot_H' in dataset and dataset['dataroot_H'] is not None:
61
+ dataset['dataroot_H'] = os.path.expanduser(dataset['dataroot_H'])
62
+ if 'dataroot_L' in dataset and dataset['dataroot_L'] is not None:
63
+ dataset['dataroot_L'] = os.path.expanduser(dataset['dataroot_L'])
64
+
65
+ # ----------------------------------------
66
+ # path
67
+ # ----------------------------------------
68
+ for key, path in opt['path'].items():
69
+ if path and key in opt['path']:
70
+ opt['path'][key] = os.path.expanduser(path)
71
+
72
+ path_task = os.path.join(opt['path']['root'], opt['task'])
73
+ opt['path']['task'] = path_task
74
+ opt['path']['log'] = path_task
75
+ opt['path']['options'] = os.path.join(path_task, 'options')
76
+
77
+ if is_train:
78
+ opt['path']['models'] = os.path.join(path_task, 'models')
79
+ opt['path']['images'] = os.path.join(path_task, 'images')
80
+ else: # test
81
+ opt['path']['images'] = os.path.join(path_task, 'test_images')
82
+
83
+ # ----------------------------------------
84
+ # network
85
+ # ----------------------------------------
86
+ opt['netG']['scale'] = opt['scale'] if 'scale' in opt else 1
87
+
88
+ # ----------------------------------------
89
+ # GPU devices
90
+ # ----------------------------------------
91
+ gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
92
+ os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
93
+ print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
94
+
95
+ # ----------------------------------------
96
+ # default setting for distributeddataparallel
97
+ # ----------------------------------------
98
+ if 'find_unused_parameters' not in opt:
99
+ opt['find_unused_parameters'] = True
100
+ if 'use_static_graph' not in opt:
101
+ opt['use_static_graph'] = False
102
+ if 'dist' not in opt:
103
+ opt['dist'] = False
104
+ opt['num_gpu'] = len(opt['gpu_ids'])
105
+ print('number of GPUs is: ' + str(opt['num_gpu']))
106
+
107
+ # ----------------------------------------
108
+ # default setting for perceptual loss
109
+ # ----------------------------------------
110
+ if 'F_feature_layer' not in opt['train']:
111
+ opt['train']['F_feature_layer'] = 34 # 25; [2,7,16,25,34]
112
+ if 'F_weights' not in opt['train']:
113
+ opt['train']['F_weights'] = 1.0 # 1.0; [0.1,0.1,1.0,1.0,1.0]
114
+ if 'F_lossfn_type' not in opt['train']:
115
+ opt['train']['F_lossfn_type'] = 'l1'
116
+ if 'F_use_input_norm' not in opt['train']:
117
+ opt['train']['F_use_input_norm'] = True
118
+ if 'F_use_range_norm' not in opt['train']:
119
+ opt['train']['F_use_range_norm'] = False
120
+
121
+ # ----------------------------------------
122
+ # default setting for optimizer
123
+ # ----------------------------------------
124
+ if 'G_optimizer_type' not in opt['train']:
125
+ opt['train']['G_optimizer_type'] = "adam"
126
+ if 'G_optimizer_betas' not in opt['train']:
127
+ opt['train']['G_optimizer_betas'] = [0.9,0.999]
128
+ if 'G_scheduler_restart_weights' not in opt['train']:
129
+ opt['train']['G_scheduler_restart_weights'] = 1
130
+ if 'G_optimizer_wd' not in opt['train']:
131
+ opt['train']['G_optimizer_wd'] = 0
132
+ if 'G_optimizer_reuse' not in opt['train']:
133
+ opt['train']['G_optimizer_reuse'] = False
134
+ if 'netD' in opt and 'D_optimizer_reuse' not in opt['train']:
135
+ opt['train']['D_optimizer_reuse'] = False
136
+
137
+ # ----------------------------------------
138
+ # default setting of strict for model loading
139
+ # ----------------------------------------
140
+ if 'G_param_strict' not in opt['train']:
141
+ opt['train']['G_param_strict'] = True
142
+ if 'netD' in opt and 'D_param_strict' not in opt['path']:
143
+ opt['train']['D_param_strict'] = True
144
+ if 'E_param_strict' not in opt['path']:
145
+ opt['train']['E_param_strict'] = True
146
+
147
+ # ----------------------------------------
148
+ # Exponential Moving Average
149
+ # ----------------------------------------
150
+ if 'E_decay' not in opt['train']:
151
+ opt['train']['E_decay'] = 0
152
+
153
+ # ----------------------------------------
154
+ # default setting for discriminator
155
+ # ----------------------------------------
156
+ if 'netD' in opt:
157
+ if 'net_type' not in opt['netD']:
158
+ opt['netD']['net_type'] = 'discriminator_patchgan' # discriminator_unet
159
+ if 'in_nc' not in opt['netD']:
160
+ opt['netD']['in_nc'] = 3
161
+ if 'base_nc' not in opt['netD']:
162
+ opt['netD']['base_nc'] = 64
163
+ if 'n_layers' not in opt['netD']:
164
+ opt['netD']['n_layers'] = 3
165
+ if 'norm_type' not in opt['netD']:
166
+ opt['netD']['norm_type'] = 'spectral'
167
+
168
+
169
+ return opt
170
+
171
+
172
+ def find_last_checkpoint(save_dir, net_type='G', pretrained_path=None):
173
+ """
174
+ Args:
175
+ save_dir: model folder
176
+ net_type: 'G' or 'D' or 'optimizerG' or 'optimizerD'
177
+ pretrained_path: pretrained model path. If save_dir does not have any model, load from pretrained_path
178
+
179
+ Return:
180
+ init_iter: iteration number
181
+ init_path: model path
182
+ """
183
+ file_list = glob.glob(os.path.join(save_dir, '*_{}.pth'.format(net_type)))
184
+ if file_list:
185
+ iter_exist = []
186
+ for file_ in file_list:
187
+ iter_current = re.findall(r"(\d+)_{}.pth".format(net_type), file_)
188
+ iter_exist.append(int(iter_current[0]))
189
+ init_iter = max(iter_exist)
190
+ init_path = os.path.join(save_dir, '{}_{}.pth'.format(init_iter, net_type))
191
+ else:
192
+ init_iter = 0
193
+ init_path = pretrained_path
194
+ return init_iter, init_path
195
+
196
+
197
+ '''
198
+ # --------------------------------------------
199
+ # convert the opt into json file
200
+ # --------------------------------------------
201
+ '''
202
+
203
+
204
+ def save(opt):
205
+ opt_path = opt['opt_path']
206
+ opt_path_copy = opt['path']['options']
207
+ dirname, filename_ext = os.path.split(opt_path)
208
+ filename, ext = os.path.splitext(filename_ext)
209
+ dump_path = os.path.join(opt_path_copy, filename+get_timestamp()+ext)
210
+ with open(dump_path, 'w') as dump_file:
211
+ json.dump(opt, dump_file, indent=2)
212
+
213
+
214
+ '''
215
+ # --------------------------------------------
216
+ # dict to string for logger
217
+ # --------------------------------------------
218
+ '''
219
+
220
+
221
+ def dict2str(opt, indent_l=1):
222
+ msg = ''
223
+ for k, v in opt.items():
224
+ if isinstance(v, dict):
225
+ msg += ' ' * (indent_l * 2) + k + ':[\n'
226
+ msg += dict2str(v, indent_l + 1)
227
+ msg += ' ' * (indent_l * 2) + ']\n'
228
+ else:
229
+ msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n'
230
+ return msg
231
+
232
+
233
+ '''
234
+ # --------------------------------------------
235
+ # convert OrderedDict to NoneDict,
236
+ # return None for missing key
237
+ # --------------------------------------------
238
+ '''
239
+
240
+
241
+ def dict_to_nonedict(opt):
242
+ if isinstance(opt, dict):
243
+ new_opt = dict()
244
+ for key, sub_opt in opt.items():
245
+ new_opt[key] = dict_to_nonedict(sub_opt)
246
+ return NoneDict(**new_opt)
247
+ elif isinstance(opt, list):
248
+ return [dict_to_nonedict(sub_opt) for sub_opt in opt]
249
+ else:
250
+ return opt
251
+
252
+
253
+ class NoneDict(dict):
254
+ def __missing__(self, key):
255
+ return None
core/data/deg_kair_utils/utils_params.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ import torchvision
4
+
5
+ from models import basicblock as B
6
+
7
+ def show_kv(net):
8
+ for k, v in net.items():
9
+ print(k)
10
+
11
+ # should run train debug mode first to get an initial model
12
+ #crt_net = torch.load('../../experiments/debug_SRResNet_bicx4_in3nf64nb16/models/8_G.pth')
13
+ #
14
+ #for k, v in crt_net.items():
15
+ # print(k)
16
+ #for k, v in crt_net.items():
17
+ # if k in pretrained_net:
18
+ # crt_net[k] = pretrained_net[k]
19
+ # print('replace ... ', k)
20
+
21
+ # x2 -> x4
22
+ #crt_net['model.5.weight'] = pretrained_net['model.2.weight']
23
+ #crt_net['model.5.bias'] = pretrained_net['model.2.bias']
24
+ #crt_net['model.8.weight'] = pretrained_net['model.5.weight']
25
+ #crt_net['model.8.bias'] = pretrained_net['model.5.bias']
26
+ #crt_net['model.10.weight'] = pretrained_net['model.7.weight']
27
+ #crt_net['model.10.bias'] = pretrained_net['model.7.bias']
28
+ #torch.save(crt_net, '../pretrained_tmp.pth')
29
+
30
+ # x2 -> x3
31
+ '''
32
+ in_filter = pretrained_net['model.2.weight'] # 256, 64, 3, 3
33
+ new_filter = torch.Tensor(576, 64, 3, 3)
34
+ new_filter[0:256, :, :, :] = in_filter
35
+ new_filter[256:512, :, :, :] = in_filter
36
+ new_filter[512:, :, :, :] = in_filter[0:576-512, :, :, :]
37
+ crt_net['model.2.weight'] = new_filter
38
+
39
+ in_bias = pretrained_net['model.2.bias'] # 256, 64, 3, 3
40
+ new_bias = torch.Tensor(576)
41
+ new_bias[0:256] = in_bias
42
+ new_bias[256:512] = in_bias
43
+ new_bias[512:] = in_bias[0:576 - 512]
44
+ crt_net['model.2.bias'] = new_bias
45
+
46
+ torch.save(crt_net, '../pretrained_tmp.pth')
47
+ '''
48
+
49
+ # x2 -> x8
50
+ '''
51
+ crt_net['model.5.weight'] = pretrained_net['model.2.weight']
52
+ crt_net['model.5.bias'] = pretrained_net['model.2.bias']
53
+ crt_net['model.8.weight'] = pretrained_net['model.2.weight']
54
+ crt_net['model.8.bias'] = pretrained_net['model.2.bias']
55
+ crt_net['model.11.weight'] = pretrained_net['model.5.weight']
56
+ crt_net['model.11.bias'] = pretrained_net['model.5.bias']
57
+ crt_net['model.13.weight'] = pretrained_net['model.7.weight']
58
+ crt_net['model.13.bias'] = pretrained_net['model.7.bias']
59
+ torch.save(crt_net, '../pretrained_tmp.pth')
60
+ '''
61
+
62
+ # x3/4/8 RGB -> Y
63
+
64
+ def rgb2gray_net(net, only_input=True):
65
+
66
+ if only_input:
67
+ in_filter = net['0.weight']
68
+ in_new_filter = in_filter[:,0,:,:]*0.2989 + in_filter[:,1,:,:]*0.587 + in_filter[:,2,:,:]*0.114
69
+ in_new_filter.unsqueeze_(1)
70
+ net['0.weight'] = in_new_filter
71
+
72
+ # out_filter = pretrained_net['model.13.weight']
73
+ # out_new_filter = out_filter[0, :, :, :] * 0.2989 + out_filter[1, :, :, :] * 0.587 + \
74
+ # out_filter[2, :, :, :] * 0.114
75
+ # out_new_filter.unsqueeze_(0)
76
+ # crt_net['model.13.weight'] = out_new_filter
77
+ # out_bias = pretrained_net['model.13.bias']
78
+ # out_new_bias = out_bias[0] * 0.2989 + out_bias[1] * 0.587 + out_bias[2] * 0.114
79
+ # out_new_bias = torch.Tensor(1).fill_(out_new_bias)
80
+ # crt_net['model.13.bias'] = out_new_bias
81
+
82
+ # torch.save(crt_net, '../pretrained_tmp.pth')
83
+
84
+ return net
85
+
86
+
87
+
88
+ if __name__ == '__main__':
89
+
90
+ net = torchvision.models.vgg19(pretrained=True)
91
+ for k,v in net.features.named_parameters():
92
+ if k=='0.weight':
93
+ in_new_filter = v[:,0,:,:]*0.2989 + v[:,1,:,:]*0.587 + v[:,2,:,:]*0.114
94
+ in_new_filter.unsqueeze_(1)
95
+ v = in_new_filter
96
+ print(v.shape)
97
+ print(v[0,0,0,0])
98
+ if k=='0.bias':
99
+ in_new_bias = v
100
+ print(v[0])
101
+
102
+ print(net.features[0])
103
+
104
+ net.features[0] = B.conv(1, 64, mode='C')
105
+
106
+ print(net.features[0])
107
+ net.features[0].weight.data=in_new_filter
108
+ net.features[0].bias.data=in_new_bias
109
+
110
+ for k,v in net.features.named_parameters():
111
+ if k=='0.weight':
112
+ print(v[0,0,0,0])
113
+ if k=='0.bias':
114
+ print(v[0])
115
+
116
+ # transfer parameters of old model to new one
117
+ model_old = torch.load(model_path)
118
+ state_dict = model.state_dict()
119
+ for ((key, param),(key2, param2)) in zip(model_old.items(), state_dict.items()):
120
+ state_dict[key2] = param
121
+ print([key, key2])
122
+ # print([param.size(), param2.size()])
123
+ torch.save(state_dict, 'model_new.pth')
124
+
125
+
126
+ # rgb2gray_net(net)
127
+
128
+
129
+
130
+
131
+
132
+
133
+
134
+
135
+
core/data/deg_kair_utils/utils_receptivefield.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # online calculation: https://fomoro.com/research/article/receptive-field-calculator#
4
+
5
+ # [filter size, stride, padding]
6
+ #Assume the two dimensions are the same
7
+ #Each kernel requires the following parameters:
8
+ # - k_i: kernel size
9
+ # - s_i: stride
10
+ # - p_i: padding (if padding is uneven, right padding will higher than left padding; "SAME" option in tensorflow)
11
+ #
12
+ #Each layer i requires the following parameters to be fully represented:
13
+ # - n_i: number of feature (data layer has n_1 = imagesize )
14
+ # - j_i: distance (projected to image pixel distance) between center of two adjacent features
15
+ # - r_i: receptive field of a feature in layer i
16
+ # - start_i: position of the first feature's receptive field in layer i (idx start from 0, negative means the center fall into padding)
17
+
18
+ import math
19
+
20
+ def outFromIn(conv, layerIn):
21
+ n_in = layerIn[0]
22
+ j_in = layerIn[1]
23
+ r_in = layerIn[2]
24
+ start_in = layerIn[3]
25
+ k = conv[0]
26
+ s = conv[1]
27
+ p = conv[2]
28
+
29
+ n_out = math.floor((n_in - k + 2*p)/s) + 1
30
+ actualP = (n_out-1)*s - n_in + k
31
+ pR = math.ceil(actualP/2)
32
+ pL = math.floor(actualP/2)
33
+
34
+ j_out = j_in * s
35
+ r_out = r_in + (k - 1)*j_in
36
+ start_out = start_in + ((k-1)/2 - pL)*j_in
37
+ return n_out, j_out, r_out, start_out
38
+
39
+ def printLayer(layer, layer_name):
40
+ print(layer_name + ":")
41
+ print(" n features: %s jump: %s receptive size: %s start: %s " % (layer[0], layer[1], layer[2], layer[3]))
42
+
43
+
44
+
45
+ layerInfos = []
46
+ if __name__ == '__main__':
47
+
48
+ convnet = [[3,1,1],[3,1,1],[3,1,1],[4,2,1],[2,2,0],[3,1,1]]
49
+ layer_names = ['conv1','conv2','conv3','conv4','conv5','conv6','conv7','conv8','conv9','conv10','conv11','conv12']
50
+ imsize = 128
51
+
52
+ print ("-------Net summary------")
53
+ currentLayer = [imsize, 1, 1, 0.5]
54
+ printLayer(currentLayer, "input image")
55
+ for i in range(len(convnet)):
56
+ currentLayer = outFromIn(convnet[i], currentLayer)
57
+ layerInfos.append(currentLayer)
58
+ printLayer(currentLayer, layer_names[i])
59
+
60
+
61
+ # run utils/utils_receptivefield.py
62
+
core/data/deg_kair_utils/utils_regularizers.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ '''
6
+ # --------------------------------------------
7
+ # Kai Zhang (github: https://github.com/cszn)
8
+ # 03/Mar/2019
9
+ # --------------------------------------------
10
+ '''
11
+
12
+
13
+ # --------------------------------------------
14
+ # SVD Orthogonal Regularization
15
+ # --------------------------------------------
16
+ def regularizer_orth(m):
17
+ """
18
+ # ----------------------------------------
19
+ # SVD Orthogonal Regularization
20
+ # ----------------------------------------
21
+ # Applies regularization to the training by performing the
22
+ # orthogonalization technique described in the paper
23
+ # This function is to be called by the torch.nn.Module.apply() method,
24
+ # which applies svd_orthogonalization() to every layer of the model.
25
+ # usage: net.apply(regularizer_orth)
26
+ # ----------------------------------------
27
+ """
28
+ classname = m.__class__.__name__
29
+ if classname.find('Conv') != -1:
30
+ w = m.weight.data.clone()
31
+ c_out, c_in, f1, f2 = w.size()
32
+ # dtype = m.weight.data.type()
33
+ w = w.permute(2, 3, 1, 0).contiguous().view(f1*f2*c_in, c_out)
34
+ # self.netG.apply(svd_orthogonalization)
35
+ u, s, v = torch.svd(w)
36
+ s[s > 1.5] = s[s > 1.5] - 1e-4
37
+ s[s < 0.5] = s[s < 0.5] + 1e-4
38
+ w = torch.mm(torch.mm(u, torch.diag(s)), v.t())
39
+ m.weight.data = w.view(f1, f2, c_in, c_out).permute(3, 2, 0, 1) # .type(dtype)
40
+ else:
41
+ pass
42
+
43
+
44
+ # --------------------------------------------
45
+ # SVD Orthogonal Regularization
46
+ # --------------------------------------------
47
+ def regularizer_orth2(m):
48
+ """
49
+ # ----------------------------------------
50
+ # Applies regularization to the training by performing the
51
+ # orthogonalization technique described in the paper
52
+ # This function is to be called by the torch.nn.Module.apply() method,
53
+ # which applies svd_orthogonalization() to every layer of the model.
54
+ # usage: net.apply(regularizer_orth2)
55
+ # ----------------------------------------
56
+ """
57
+ classname = m.__class__.__name__
58
+ if classname.find('Conv') != -1:
59
+ w = m.weight.data.clone()
60
+ c_out, c_in, f1, f2 = w.size()
61
+ # dtype = m.weight.data.type()
62
+ w = w.permute(2, 3, 1, 0).contiguous().view(f1*f2*c_in, c_out)
63
+ u, s, v = torch.svd(w)
64
+ s_mean = s.mean()
65
+ s[s > 1.5*s_mean] = s[s > 1.5*s_mean] - 1e-4
66
+ s[s < 0.5*s_mean] = s[s < 0.5*s_mean] + 1e-4
67
+ w = torch.mm(torch.mm(u, torch.diag(s)), v.t())
68
+ m.weight.data = w.view(f1, f2, c_in, c_out).permute(3, 2, 0, 1) # .type(dtype)
69
+ else:
70
+ pass
71
+
72
+
73
+
74
+ def regularizer_clip(m):
75
+ """
76
+ # ----------------------------------------
77
+ # usage: net.apply(regularizer_clip)
78
+ # ----------------------------------------
79
+ """
80
+ eps = 1e-4
81
+ c_min = -1.5
82
+ c_max = 1.5
83
+
84
+ classname = m.__class__.__name__
85
+ if classname.find('Conv') != -1 or classname.find('Linear') != -1:
86
+ w = m.weight.data.clone()
87
+ w[w > c_max] -= eps
88
+ w[w < c_min] += eps
89
+ m.weight.data = w
90
+
91
+ if m.bias is not None:
92
+ b = m.bias.data.clone()
93
+ b[b > c_max] -= eps
94
+ b[b < c_min] += eps
95
+ m.bias.data = b
96
+
97
+ # elif classname.find('BatchNorm2d') != -1:
98
+ #
99
+ # rv = m.running_var.data.clone()
100
+ # rm = m.running_mean.data.clone()
101
+ #
102
+ # if m.affine:
103
+ # m.weight.data
104
+ # m.bias.data
core/data/deg_kair_utils/utils_sisr.py ADDED
@@ -0,0 +1,848 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ from utils import utils_image as util
3
+ import random
4
+
5
+ import scipy
6
+ import scipy.stats as ss
7
+ import scipy.io as io
8
+ from scipy import ndimage
9
+ from scipy.interpolate import interp2d
10
+
11
+ import numpy as np
12
+ import torch
13
+
14
+
15
+ """
16
+ # --------------------------------------------
17
+ # Super-Resolution
18
+ # --------------------------------------------
19
+ #
20
+ # Kai Zhang ([email protected])
21
+ # https://github.com/cszn
22
+ # modified by Kai Zhang (github: https://github.com/cszn)
23
+ # 03/03/2020
24
+ # --------------------------------------------
25
+ """
26
+
27
+
28
+ """
29
+ # --------------------------------------------
30
+ # anisotropic Gaussian kernels
31
+ # --------------------------------------------
32
+ """
33
+
34
+
35
+ def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
36
+ """ generate an anisotropic Gaussian kernel
37
+ Args:
38
+ ksize : e.g., 15, kernel size
39
+ theta : [0, pi], rotation angle range
40
+ l1 : [0.1,50], scaling of eigenvalues
41
+ l2 : [0.1,l1], scaling of eigenvalues
42
+ If l1 = l2, will get an isotropic Gaussian kernel.
43
+ Returns:
44
+ k : kernel
45
+ """
46
+
47
+ v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
48
+ V = np.array([[v[0], v[1]], [v[1], -v[0]]])
49
+ D = np.array([[l1, 0], [0, l2]])
50
+ Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
51
+ k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
52
+
53
+ return k
54
+
55
+
56
+ def gm_blur_kernel(mean, cov, size=15):
57
+ center = size / 2.0 + 0.5
58
+ k = np.zeros([size, size])
59
+ for y in range(size):
60
+ for x in range(size):
61
+ cy = y - center + 1
62
+ cx = x - center + 1
63
+ k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
64
+
65
+ k = k / np.sum(k)
66
+ return k
67
+
68
+
69
+ """
70
+ # --------------------------------------------
71
+ # calculate PCA projection matrix
72
+ # --------------------------------------------
73
+ """
74
+
75
+
76
+ def get_pca_matrix(x, dim_pca=15):
77
+ """
78
+ Args:
79
+ x: 225x10000 matrix
80
+ dim_pca: 15
81
+ Returns:
82
+ pca_matrix: 15x225
83
+ """
84
+ C = np.dot(x, x.T)
85
+ w, v = scipy.linalg.eigh(C)
86
+ pca_matrix = v[:, -dim_pca:].T
87
+
88
+ return pca_matrix
89
+
90
+
91
+ def show_pca(x):
92
+ """
93
+ x: PCA projection matrix, e.g., 15x225
94
+ """
95
+ for i in range(x.shape[0]):
96
+ xc = np.reshape(x[i, :], (int(np.sqrt(x.shape[1])), -1), order="F")
97
+ util.surf(xc)
98
+
99
+
100
+ def cal_pca_matrix(path='PCA_matrix.mat', ksize=15, l_max=12.0, dim_pca=15, num_samples=500):
101
+ kernels = np.zeros([ksize*ksize, num_samples], dtype=np.float32)
102
+ for i in range(num_samples):
103
+
104
+ theta = np.pi*np.random.rand(1)
105
+ l1 = 0.1+l_max*np.random.rand(1)
106
+ l2 = 0.1+(l1-0.1)*np.random.rand(1)
107
+
108
+ k = anisotropic_Gaussian(ksize=ksize, theta=theta[0], l1=l1[0], l2=l2[0])
109
+
110
+ # util.imshow(k)
111
+
112
+ kernels[:, i] = np.reshape(k, (-1), order="F") # k.flatten(order='F')
113
+
114
+ # io.savemat('k.mat', {'k': kernels})
115
+
116
+ pca_matrix = get_pca_matrix(kernels, dim_pca=dim_pca)
117
+
118
+ io.savemat(path, {'p': pca_matrix})
119
+
120
+ return pca_matrix
121
+
122
+
123
+ """
124
+ # --------------------------------------------
125
+ # shifted anisotropic Gaussian kernels
126
+ # --------------------------------------------
127
+ """
128
+
129
+
130
+ def shifted_anisotropic_Gaussian(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
131
+ """"
132
+ # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
133
+ # Kai Zhang
134
+ # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
135
+ # max_var = 2.5 * sf
136
+ """
137
+ # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
138
+ lambda_1 = min_var + np.random.rand() * (max_var - min_var)
139
+ lambda_2 = min_var + np.random.rand() * (max_var - min_var)
140
+ theta = np.random.rand() * np.pi # random theta
141
+ noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
142
+
143
+ # Set COV matrix using Lambdas and Theta
144
+ LAMBDA = np.diag([lambda_1, lambda_2])
145
+ Q = np.array([[np.cos(theta), -np.sin(theta)],
146
+ [np.sin(theta), np.cos(theta)]])
147
+ SIGMA = Q @ LAMBDA @ Q.T
148
+ INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
149
+
150
+ # Set expectation position (shifting kernel for aligned image)
151
+ MU = k_size // 2 - 0.5*(scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
152
+ MU = MU[None, None, :, None]
153
+
154
+ # Create meshgrid for Gaussian
155
+ [X,Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
156
+ Z = np.stack([X, Y], 2)[:, :, :, None]
157
+
158
+ # Calcualte Gaussian for every pixel of the kernel
159
+ ZZ = Z-MU
160
+ ZZ_t = ZZ.transpose(0,1,3,2)
161
+ raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
162
+
163
+ # shift the kernel so it will be centered
164
+ #raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
165
+
166
+ # Normalize the kernel and return
167
+ #kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
168
+ kernel = raw_kernel / np.sum(raw_kernel)
169
+ return kernel
170
+
171
+
172
+ def gen_kernel(k_size=np.array([25, 25]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=12., noise_level=0):
173
+ """"
174
+ # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
175
+ # Kai Zhang
176
+ # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
177
+ # max_var = 2.5 * sf
178
+ """
179
+ sf = random.choice([1, 2, 3, 4])
180
+ scale_factor = np.array([sf, sf])
181
+ # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
182
+ lambda_1 = min_var + np.random.rand() * (max_var - min_var)
183
+ lambda_2 = min_var + np.random.rand() * (max_var - min_var)
184
+ theta = np.random.rand() * np.pi # random theta
185
+ noise = 0#-noise_level + np.random.rand(*k_size) * noise_level * 2
186
+
187
+ # Set COV matrix using Lambdas and Theta
188
+ LAMBDA = np.diag([lambda_1, lambda_2])
189
+ Q = np.array([[np.cos(theta), -np.sin(theta)],
190
+ [np.sin(theta), np.cos(theta)]])
191
+ SIGMA = Q @ LAMBDA @ Q.T
192
+ INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
193
+
194
+ # Set expectation position (shifting kernel for aligned image)
195
+ MU = k_size // 2 - 0.5*(scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
196
+ MU = MU[None, None, :, None]
197
+
198
+ # Create meshgrid for Gaussian
199
+ [X,Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
200
+ Z = np.stack([X, Y], 2)[:, :, :, None]
201
+
202
+ # Calcualte Gaussian for every pixel of the kernel
203
+ ZZ = Z-MU
204
+ ZZ_t = ZZ.transpose(0,1,3,2)
205
+ raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
206
+
207
+ # shift the kernel so it will be centered
208
+ #raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
209
+
210
+ # Normalize the kernel and return
211
+ #kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
212
+ kernel = raw_kernel / np.sum(raw_kernel)
213
+ return kernel
214
+
215
+
216
+ """
217
+ # --------------------------------------------
218
+ # degradation models
219
+ # --------------------------------------------
220
+ """
221
+
222
+
223
+ def bicubic_degradation(x, sf=3):
224
+ '''
225
+ Args:
226
+ x: HxWxC image, [0, 1]
227
+ sf: down-scale factor
228
+ Return:
229
+ bicubicly downsampled LR image
230
+ '''
231
+ x = util.imresize_np(x, scale=1/sf)
232
+ return x
233
+
234
+
235
+ def srmd_degradation(x, k, sf=3):
236
+ ''' blur + bicubic downsampling
237
+ Args:
238
+ x: HxWxC image, [0, 1]
239
+ k: hxw, double
240
+ sf: down-scale factor
241
+ Return:
242
+ downsampled LR image
243
+ Reference:
244
+ @inproceedings{zhang2018learning,
245
+ title={Learning a single convolutional super-resolution network for multiple degradations},
246
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
247
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
248
+ pages={3262--3271},
249
+ year={2018}
250
+ }
251
+ '''
252
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
253
+ x = bicubic_degradation(x, sf=sf)
254
+ return x
255
+
256
+
257
+ def dpsr_degradation(x, k, sf=3):
258
+
259
+ ''' bicubic downsampling + blur
260
+ Args:
261
+ x: HxWxC image, [0, 1]
262
+ k: hxw, double
263
+ sf: down-scale factor
264
+ Return:
265
+ downsampled LR image
266
+ Reference:
267
+ @inproceedings{zhang2019deep,
268
+ title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
269
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
270
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
271
+ pages={1671--1681},
272
+ year={2019}
273
+ }
274
+ '''
275
+ x = bicubic_degradation(x, sf=sf)
276
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
277
+ return x
278
+
279
+
280
+ def classical_degradation(x, k, sf=3):
281
+ ''' blur + downsampling
282
+
283
+ Args:
284
+ x: HxWxC image, [0, 1]/[0, 255]
285
+ k: hxw, double
286
+ sf: down-scale factor
287
+
288
+ Return:
289
+ downsampled LR image
290
+ '''
291
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
292
+ #x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
293
+ st = 0
294
+ return x[st::sf, st::sf, ...]
295
+
296
+
297
+ def modcrop_np(img, sf):
298
+ '''
299
+ Args:
300
+ img: numpy image, WxH or WxHxC
301
+ sf: scale factor
302
+ Return:
303
+ cropped image
304
+ '''
305
+ w, h = img.shape[:2]
306
+ im = np.copy(img)
307
+ return im[:w - w % sf, :h - h % sf, ...]
308
+
309
+
310
+ '''
311
+ # =================
312
+ # Numpy
313
+ # =================
314
+ '''
315
+
316
+
317
+ def shift_pixel(x, sf, upper_left=True):
318
+ """shift pixel for super-resolution with different scale factors
319
+ Args:
320
+ x: WxHxC or WxH, image or kernel
321
+ sf: scale factor
322
+ upper_left: shift direction
323
+ """
324
+ h, w = x.shape[:2]
325
+ shift = (sf-1)*0.5
326
+ xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
327
+ if upper_left:
328
+ x1 = xv + shift
329
+ y1 = yv + shift
330
+ else:
331
+ x1 = xv - shift
332
+ y1 = yv - shift
333
+
334
+ x1 = np.clip(x1, 0, w-1)
335
+ y1 = np.clip(y1, 0, h-1)
336
+
337
+ if x.ndim == 2:
338
+ x = interp2d(xv, yv, x)(x1, y1)
339
+ if x.ndim == 3:
340
+ for i in range(x.shape[-1]):
341
+ x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
342
+
343
+ return x
344
+
345
+
346
+ '''
347
+ # =================
348
+ # pytorch
349
+ # =================
350
+ '''
351
+
352
+
353
+ def splits(a, sf):
354
+ '''
355
+ a: tensor NxCxWxHx2
356
+ sf: scale factor
357
+ out: tensor NxCx(W/sf)x(H/sf)x2x(sf^2)
358
+ '''
359
+ b = torch.stack(torch.chunk(a, sf, dim=2), dim=5)
360
+ b = torch.cat(torch.chunk(b, sf, dim=3), dim=5)
361
+ return b
362
+
363
+
364
+ def c2c(x):
365
+ return torch.from_numpy(np.stack([np.float32(x.real), np.float32(x.imag)], axis=-1))
366
+
367
+
368
+ def r2c(x):
369
+ return torch.stack([x, torch.zeros_like(x)], -1)
370
+
371
+
372
+ def cdiv(x, y):
373
+ a, b = x[..., 0], x[..., 1]
374
+ c, d = y[..., 0], y[..., 1]
375
+ cd2 = c**2 + d**2
376
+ return torch.stack([(a*c+b*d)/cd2, (b*c-a*d)/cd2], -1)
377
+
378
+
379
+ def csum(x, y):
380
+ return torch.stack([x[..., 0] + y, x[..., 1]], -1)
381
+
382
+
383
+ def cabs(x):
384
+ return torch.pow(x[..., 0]**2+x[..., 1]**2, 0.5)
385
+
386
+
387
+ def cmul(t1, t2):
388
+ '''
389
+ complex multiplication
390
+ t1: NxCxHxWx2
391
+ output: NxCxHxWx2
392
+ '''
393
+ real1, imag1 = t1[..., 0], t1[..., 1]
394
+ real2, imag2 = t2[..., 0], t2[..., 1]
395
+ return torch.stack([real1 * real2 - imag1 * imag2, real1 * imag2 + imag1 * real2], dim=-1)
396
+
397
+
398
+ def cconj(t, inplace=False):
399
+ '''
400
+ # complex's conjugation
401
+ t: NxCxHxWx2
402
+ output: NxCxHxWx2
403
+ '''
404
+ c = t.clone() if not inplace else t
405
+ c[..., 1] *= -1
406
+ return c
407
+
408
+
409
+ def rfft(t):
410
+ return torch.rfft(t, 2, onesided=False)
411
+
412
+
413
+ def irfft(t):
414
+ return torch.irfft(t, 2, onesided=False)
415
+
416
+
417
+ def fft(t):
418
+ return torch.fft(t, 2)
419
+
420
+
421
+ def ifft(t):
422
+ return torch.ifft(t, 2)
423
+
424
+
425
+ def p2o(psf, shape):
426
+ '''
427
+ Args:
428
+ psf: NxCxhxw
429
+ shape: [H,W]
430
+
431
+ Returns:
432
+ otf: NxCxHxWx2
433
+ '''
434
+ otf = torch.zeros(psf.shape[:-2] + shape).type_as(psf)
435
+ otf[...,:psf.shape[2],:psf.shape[3]].copy_(psf)
436
+ for axis, axis_size in enumerate(psf.shape[2:]):
437
+ otf = torch.roll(otf, -int(axis_size / 2), dims=axis+2)
438
+ otf = torch.rfft(otf, 2, onesided=False)
439
+ n_ops = torch.sum(torch.tensor(psf.shape).type_as(psf) * torch.log2(torch.tensor(psf.shape).type_as(psf)))
440
+ otf[...,1][torch.abs(otf[...,1])<n_ops*2.22e-16] = torch.tensor(0).type_as(psf)
441
+ return otf
442
+
443
+
444
+ '''
445
+ # =================
446
+ PyTorch
447
+ # =================
448
+ '''
449
+
450
+ def INVLS_pytorch(FB, FBC, F2B, FR, tau, sf=2):
451
+ '''
452
+ FB: NxCxWxHx2
453
+ F2B: NxCxWxHx2
454
+
455
+ x1 = FB.*FR;
456
+ FBR = BlockMM(nr,nc,Nb,m,x1);
457
+ invW = BlockMM(nr,nc,Nb,m,F2B);
458
+ invWBR = FBR./(invW + tau*Nb);
459
+ fun = @(block_struct) block_struct.data.*invWBR;
460
+ FCBinvWBR = blockproc(FBC,[nr,nc],fun);
461
+ FX = (FR-FCBinvWBR)/tau;
462
+ Xest = real(ifft2(FX));
463
+ '''
464
+ x1 = cmul(FB, FR)
465
+ FBR = torch.mean(splits(x1, sf), dim=-1, keepdim=False)
466
+ invW = torch.mean(splits(F2B, sf), dim=-1, keepdim=False)
467
+ invWBR = cdiv(FBR, csum(invW, tau))
468
+ FCBinvWBR = cmul(FBC, invWBR.repeat(1,1,sf,sf,1))
469
+ FX = (FR-FCBinvWBR)/tau
470
+ Xest = torch.irfft(FX, 2, onesided=False)
471
+ return Xest
472
+
473
+
474
+ def real2complex(x):
475
+ return torch.stack([x, torch.zeros_like(x)], -1)
476
+
477
+
478
+ def modcrop(img, sf):
479
+ '''
480
+ img: tensor image, NxCxWxH or CxWxH or WxH
481
+ sf: scale factor
482
+ '''
483
+ w, h = img.shape[-2:]
484
+ im = img.clone()
485
+ return im[..., :w - w % sf, :h - h % sf]
486
+
487
+
488
+ def upsample(x, sf=3, center=False):
489
+ '''
490
+ x: tensor image, NxCxWxH
491
+ '''
492
+ st = (sf-1)//2 if center else 0
493
+ z = torch.zeros((x.shape[0], x.shape[1], x.shape[2]*sf, x.shape[3]*sf)).type_as(x)
494
+ z[..., st::sf, st::sf].copy_(x)
495
+ return z
496
+
497
+
498
+ def downsample(x, sf=3, center=False):
499
+ st = (sf-1)//2 if center else 0
500
+ return x[..., st::sf, st::sf]
501
+
502
+
503
+ def circular_pad(x, pad):
504
+ '''
505
+ # x[N, 1, W, H] -> x[N, 1, W + 2 pad, H + 2 pad] (pariodic padding)
506
+ '''
507
+ x = torch.cat([x, x[:, :, 0:pad, :]], dim=2)
508
+ x = torch.cat([x, x[:, :, :, 0:pad]], dim=3)
509
+ x = torch.cat([x[:, :, -2 * pad:-pad, :], x], dim=2)
510
+ x = torch.cat([x[:, :, :, -2 * pad:-pad], x], dim=3)
511
+ return x
512
+
513
+
514
+ def pad_circular(input, padding):
515
+ # type: (Tensor, List[int]) -> Tensor
516
+ """
517
+ Arguments
518
+ :param input: tensor of shape :math:`(N, C_{\text{in}}, H, [W, D]))`
519
+ :param padding: (tuple): m-elem tuple where m is the degree of convolution
520
+ Returns
521
+ :return: tensor of shape :math:`(N, C_{\text{in}}, [D + 2 * padding[0],
522
+ H + 2 * padding[1]], W + 2 * padding[2]))`
523
+ """
524
+ offset = 3
525
+ for dimension in range(input.dim() - offset + 1):
526
+ input = dim_pad_circular(input, padding[dimension], dimension + offset)
527
+ return input
528
+
529
+
530
+ def dim_pad_circular(input, padding, dimension):
531
+ # type: (Tensor, int, int) -> Tensor
532
+ input = torch.cat([input, input[[slice(None)] * (dimension - 1) +
533
+ [slice(0, padding)]]], dim=dimension - 1)
534
+ input = torch.cat([input[[slice(None)] * (dimension - 1) +
535
+ [slice(-2 * padding, -padding)]], input], dim=dimension - 1)
536
+ return input
537
+
538
+
539
+ def imfilter(x, k):
540
+ '''
541
+ x: image, NxcxHxW
542
+ k: kernel, cx1xhxw
543
+ '''
544
+ x = pad_circular(x, padding=((k.shape[-2]-1)//2, (k.shape[-1]-1)//2))
545
+ x = torch.nn.functional.conv2d(x, k, groups=x.shape[1])
546
+ return x
547
+
548
+
549
+ def G(x, k, sf=3, center=False):
550
+ '''
551
+ x: image, NxcxHxW
552
+ k: kernel, cx1xhxw
553
+ sf: scale factor
554
+ center: the first one or the moddle one
555
+
556
+ Matlab function:
557
+ tmp = imfilter(x,h,'circular');
558
+ y = downsample2(tmp,K);
559
+ '''
560
+ x = downsample(imfilter(x, k), sf=sf, center=center)
561
+ return x
562
+
563
+
564
+ def Gt(x, k, sf=3, center=False):
565
+ '''
566
+ x: image, NxcxHxW
567
+ k: kernel, cx1xhxw
568
+ sf: scale factor
569
+ center: the first one or the moddle one
570
+
571
+ Matlab function:
572
+ tmp = upsample2(x,K);
573
+ y = imfilter(tmp,h,'circular');
574
+ '''
575
+ x = imfilter(upsample(x, sf=sf, center=center), k)
576
+ return x
577
+
578
+
579
+ def interpolation_down(x, sf, center=False):
580
+ mask = torch.zeros_like(x)
581
+ if center:
582
+ start = torch.tensor((sf-1)//2)
583
+ mask[..., start::sf, start::sf] = torch.tensor(1).type_as(x)
584
+ LR = x[..., start::sf, start::sf]
585
+ else:
586
+ mask[..., ::sf, ::sf] = torch.tensor(1).type_as(x)
587
+ LR = x[..., ::sf, ::sf]
588
+ y = x.mul(mask)
589
+
590
+ return LR, y, mask
591
+
592
+
593
+ '''
594
+ # =================
595
+ Numpy
596
+ # =================
597
+ '''
598
+
599
+
600
+ def blockproc(im, blocksize, fun):
601
+ xblocks = np.split(im, range(blocksize[0], im.shape[0], blocksize[0]), axis=0)
602
+ xblocks_proc = []
603
+ for xb in xblocks:
604
+ yblocks = np.split(xb, range(blocksize[1], im.shape[1], blocksize[1]), axis=1)
605
+ yblocks_proc = []
606
+ for yb in yblocks:
607
+ yb_proc = fun(yb)
608
+ yblocks_proc.append(yb_proc)
609
+ xblocks_proc.append(np.concatenate(yblocks_proc, axis=1))
610
+
611
+ proc = np.concatenate(xblocks_proc, axis=0)
612
+
613
+ return proc
614
+
615
+
616
+ def fun_reshape(a):
617
+ return np.reshape(a, (-1,1,a.shape[-1]), order='F')
618
+
619
+
620
+ def fun_mul(a, b):
621
+ return a*b
622
+
623
+
624
+ def BlockMM(nr, nc, Nb, m, x1):
625
+ '''
626
+ myfun = @(block_struct) reshape(block_struct.data,m,1);
627
+ x1 = blockproc(x1,[nr nc],myfun);
628
+ x1 = reshape(x1,m,Nb);
629
+ x1 = sum(x1,2);
630
+ x = reshape(x1,nr,nc);
631
+ '''
632
+ fun = fun_reshape
633
+ x1 = blockproc(x1, blocksize=(nr, nc), fun=fun)
634
+ x1 = np.reshape(x1, (m, Nb, x1.shape[-1]), order='F')
635
+ x1 = np.sum(x1, 1)
636
+ x = np.reshape(x1, (nr, nc, x1.shape[-1]), order='F')
637
+ return x
638
+
639
+
640
+ def INVLS(FB, FBC, F2B, FR, tau, Nb, nr, nc, m):
641
+ '''
642
+ x1 = FB.*FR;
643
+ FBR = BlockMM(nr,nc,Nb,m,x1);
644
+ invW = BlockMM(nr,nc,Nb,m,F2B);
645
+ invWBR = FBR./(invW + tau*Nb);
646
+ fun = @(block_struct) block_struct.data.*invWBR;
647
+ FCBinvWBR = blockproc(FBC,[nr,nc],fun);
648
+ FX = (FR-FCBinvWBR)/tau;
649
+ Xest = real(ifft2(FX));
650
+ '''
651
+ x1 = FB*FR
652
+ FBR = BlockMM(nr, nc, Nb, m, x1)
653
+ invW = BlockMM(nr, nc, Nb, m, F2B)
654
+ invWBR = FBR/(invW + tau*Nb)
655
+ FCBinvWBR = blockproc(FBC, [nr, nc], lambda im: fun_mul(im, invWBR))
656
+ FX = (FR-FCBinvWBR)/tau
657
+ Xest = np.real(np.fft.ifft2(FX, axes=(0, 1)))
658
+ return Xest
659
+
660
+
661
+ def psf2otf(psf, shape=None):
662
+ """
663
+ Convert point-spread function to optical transfer function.
664
+ Compute the Fast Fourier Transform (FFT) of the point-spread
665
+ function (PSF) array and creates the optical transfer function (OTF)
666
+ array that is not influenced by the PSF off-centering.
667
+ By default, the OTF array is the same size as the PSF array.
668
+ To ensure that the OTF is not altered due to PSF off-centering, PSF2OTF
669
+ post-pads the PSF array (down or to the right) with zeros to match
670
+ dimensions specified in OUTSIZE, then circularly shifts the values of
671
+ the PSF array up (or to the left) until the central pixel reaches (1,1)
672
+ position.
673
+ Parameters
674
+ ----------
675
+ psf : `numpy.ndarray`
676
+ PSF array
677
+ shape : int
678
+ Output shape of the OTF array
679
+ Returns
680
+ -------
681
+ otf : `numpy.ndarray`
682
+ OTF array
683
+ Notes
684
+ -----
685
+ Adapted from MATLAB psf2otf function
686
+ """
687
+ if type(shape) == type(None):
688
+ shape = psf.shape
689
+ shape = np.array(shape)
690
+ if np.all(psf == 0):
691
+ # return np.zeros_like(psf)
692
+ return np.zeros(shape)
693
+ if len(psf.shape) == 1:
694
+ psf = psf.reshape((1, psf.shape[0]))
695
+ inshape = psf.shape
696
+ psf = zero_pad(psf, shape, position='corner')
697
+ for axis, axis_size in enumerate(inshape):
698
+ psf = np.roll(psf, -int(axis_size / 2), axis=axis)
699
+ # Compute the OTF
700
+ otf = np.fft.fft2(psf, axes=(0, 1))
701
+ # Estimate the rough number of operations involved in the FFT
702
+ # and discard the PSF imaginary part if within roundoff error
703
+ # roundoff error = machine epsilon = sys.float_info.epsilon
704
+ # or np.finfo().eps
705
+ n_ops = np.sum(psf.size * np.log2(psf.shape))
706
+ otf = np.real_if_close(otf, tol=n_ops)
707
+ return otf
708
+
709
+
710
+ def zero_pad(image, shape, position='corner'):
711
+ """
712
+ Extends image to a certain size with zeros
713
+ Parameters
714
+ ----------
715
+ image: real 2d `numpy.ndarray`
716
+ Input image
717
+ shape: tuple of int
718
+ Desired output shape of the image
719
+ position : str, optional
720
+ The position of the input image in the output one:
721
+ * 'corner'
722
+ top-left corner (default)
723
+ * 'center'
724
+ centered
725
+ Returns
726
+ -------
727
+ padded_img: real `numpy.ndarray`
728
+ The zero-padded image
729
+ """
730
+ shape = np.asarray(shape, dtype=int)
731
+ imshape = np.asarray(image.shape, dtype=int)
732
+ if np.alltrue(imshape == shape):
733
+ return image
734
+ if np.any(shape <= 0):
735
+ raise ValueError("ZERO_PAD: null or negative shape given")
736
+ dshape = shape - imshape
737
+ if np.any(dshape < 0):
738
+ raise ValueError("ZERO_PAD: target size smaller than source one")
739
+ pad_img = np.zeros(shape, dtype=image.dtype)
740
+ idx, idy = np.indices(imshape)
741
+ if position == 'center':
742
+ if np.any(dshape % 2 != 0):
743
+ raise ValueError("ZERO_PAD: source and target shapes "
744
+ "have different parity.")
745
+ offx, offy = dshape // 2
746
+ else:
747
+ offx, offy = (0, 0)
748
+ pad_img[idx + offx, idy + offy] = image
749
+ return pad_img
750
+
751
+
752
+ def upsample_np(x, sf=3, center=False):
753
+ st = (sf-1)//2 if center else 0
754
+ z = np.zeros((x.shape[0]*sf, x.shape[1]*sf, x.shape[2]))
755
+ z[st::sf, st::sf, ...] = x
756
+ return z
757
+
758
+
759
+ def downsample_np(x, sf=3, center=False):
760
+ st = (sf-1)//2 if center else 0
761
+ return x[st::sf, st::sf, ...]
762
+
763
+
764
+ def imfilter_np(x, k):
765
+ '''
766
+ x: image, NxcxHxW
767
+ k: kernel, cx1xhxw
768
+ '''
769
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
770
+ return x
771
+
772
+
773
+ def G_np(x, k, sf=3, center=False):
774
+ '''
775
+ x: image, NxcxHxW
776
+ k: kernel, cx1xhxw
777
+
778
+ Matlab function:
779
+ tmp = imfilter(x,h,'circular');
780
+ y = downsample2(tmp,K);
781
+ '''
782
+ x = downsample_np(imfilter_np(x, k), sf=sf, center=center)
783
+ return x
784
+
785
+
786
+ def Gt_np(x, k, sf=3, center=False):
787
+ '''
788
+ x: image, NxcxHxW
789
+ k: kernel, cx1xhxw
790
+
791
+ Matlab function:
792
+ tmp = upsample2(x,K);
793
+ y = imfilter(tmp,h,'circular');
794
+ '''
795
+ x = imfilter_np(upsample_np(x, sf=sf, center=center), k)
796
+ return x
797
+
798
+
799
+ if __name__ == '__main__':
800
+ img = util.imread_uint('test.bmp', 3)
801
+
802
+ img = util.uint2single(img)
803
+ k = anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6)
804
+ util.imshow(k*10)
805
+
806
+
807
+ for sf in [2, 3, 4]:
808
+
809
+ # modcrop
810
+ img = modcrop_np(img, sf=sf)
811
+
812
+ # 1) bicubic degradation
813
+ img_b = bicubic_degradation(img, sf=sf)
814
+ print(img_b.shape)
815
+
816
+ # 2) srmd degradation
817
+ img_s = srmd_degradation(img, k, sf=sf)
818
+ print(img_s.shape)
819
+
820
+ # 3) dpsr degradation
821
+ img_d = dpsr_degradation(img, k, sf=sf)
822
+ print(img_d.shape)
823
+
824
+ # 4) classical degradation
825
+ img_d = classical_degradation(img, k, sf=sf)
826
+ print(img_d.shape)
827
+
828
+ k = anisotropic_Gaussian(ksize=7, theta=0.25*np.pi, l1=0.01, l2=0.01)
829
+ #print(k)
830
+ # util.imshow(k*10)
831
+
832
+ k = shifted_anisotropic_Gaussian(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.8, max_var=10.8, noise_level=0.0)
833
+ # util.imshow(k*10)
834
+
835
+
836
+ # PCA
837
+ # pca_matrix = cal_pca_matrix(ksize=15, l_max=10.0, dim_pca=15, num_samples=12500)
838
+ # print(pca_matrix.shape)
839
+ # show_pca(pca_matrix)
840
+ # run utils/utils_sisr.py
841
+ # run utils_sisr.py
842
+
843
+
844
+
845
+
846
+
847
+
848
+
core/data/deg_kair_utils/utils_video.py ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ import random
6
+ from os import path as osp
7
+ from torch.nn import functional as F
8
+ from abc import ABCMeta, abstractmethod
9
+
10
+
11
+ def scandir(dir_path, suffix=None, recursive=False, full_path=False):
12
+ """Scan a directory to find the interested files.
13
+
14
+ Args:
15
+ dir_path (str): Path of the directory.
16
+ suffix (str | tuple(str), optional): File suffix that we are
17
+ interested in. Default: None.
18
+ recursive (bool, optional): If set to True, recursively scan the
19
+ directory. Default: False.
20
+ full_path (bool, optional): If set to True, include the dir_path.
21
+ Default: False.
22
+
23
+ Returns:
24
+ A generator for all the interested files with relative paths.
25
+ """
26
+
27
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
28
+ raise TypeError('"suffix" must be a string or tuple of strings')
29
+
30
+ root = dir_path
31
+
32
+ def _scandir(dir_path, suffix, recursive):
33
+ for entry in os.scandir(dir_path):
34
+ if not entry.name.startswith('.') and entry.is_file():
35
+ if full_path:
36
+ return_path = entry.path
37
+ else:
38
+ return_path = osp.relpath(entry.path, root)
39
+
40
+ if suffix is None:
41
+ yield return_path
42
+ elif return_path.endswith(suffix):
43
+ yield return_path
44
+ else:
45
+ if recursive:
46
+ yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
47
+ else:
48
+ continue
49
+
50
+ return _scandir(dir_path, suffix=suffix, recursive=recursive)
51
+
52
+
53
+ def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):
54
+ """Read a sequence of images from a given folder path.
55
+
56
+ Args:
57
+ path (list[str] | str): List of image paths or image folder path.
58
+ require_mod_crop (bool): Require mod crop for each image.
59
+ Default: False.
60
+ scale (int): Scale factor for mod_crop. Default: 1.
61
+ return_imgname(bool): Whether return image names. Default False.
62
+
63
+ Returns:
64
+ Tensor: size (t, c, h, w), RGB, [0, 1].
65
+ list[str]: Returned image name list.
66
+ """
67
+ if isinstance(path, list):
68
+ img_paths = path
69
+ else:
70
+ img_paths = sorted(list(scandir(path, full_path=True)))
71
+ imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
72
+
73
+ if require_mod_crop:
74
+ imgs = [mod_crop(img, scale) for img in imgs]
75
+ imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
76
+ imgs = torch.stack(imgs, dim=0)
77
+
78
+ if return_imgname:
79
+ imgnames = [osp.splitext(osp.basename(path))[0] for path in img_paths]
80
+ return imgs, imgnames
81
+ else:
82
+ return imgs
83
+
84
+
85
+ def img2tensor(imgs, bgr2rgb=True, float32=True):
86
+ """Numpy array to tensor.
87
+
88
+ Args:
89
+ imgs (list[ndarray] | ndarray): Input images.
90
+ bgr2rgb (bool): Whether to change bgr to rgb.
91
+ float32 (bool): Whether to change to float32.
92
+
93
+ Returns:
94
+ list[tensor] | tensor: Tensor images. If returned results only have
95
+ one element, just return tensor.
96
+ """
97
+
98
+ def _totensor(img, bgr2rgb, float32):
99
+ if img.shape[2] == 3 and bgr2rgb:
100
+ if img.dtype == 'float64':
101
+ img = img.astype('float32')
102
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
103
+ img = torch.from_numpy(img.transpose(2, 0, 1))
104
+ if float32:
105
+ img = img.float()
106
+ return img
107
+
108
+ if isinstance(imgs, list):
109
+ return [_totensor(img, bgr2rgb, float32) for img in imgs]
110
+ else:
111
+ return _totensor(imgs, bgr2rgb, float32)
112
+
113
+
114
+ def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
115
+ """Convert torch Tensors into image numpy arrays.
116
+
117
+ After clamping to [min, max], values will be normalized to [0, 1].
118
+
119
+ Args:
120
+ tensor (Tensor or list[Tensor]): Accept shapes:
121
+ 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
122
+ 2) 3D Tensor of shape (3/1 x H x W);
123
+ 3) 2D Tensor of shape (H x W).
124
+ Tensor channel should be in RGB order.
125
+ rgb2bgr (bool): Whether to change rgb to bgr.
126
+ out_type (numpy type): output types. If ``np.uint8``, transform outputs
127
+ to uint8 type with range [0, 255]; otherwise, float type with
128
+ range [0, 1]. Default: ``np.uint8``.
129
+ min_max (tuple[int]): min and max values for clamp.
130
+
131
+ Returns:
132
+ (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
133
+ shape (H x W). The channel order is BGR.
134
+ """
135
+ if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
136
+ raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
137
+
138
+ if torch.is_tensor(tensor):
139
+ tensor = [tensor]
140
+ result = []
141
+ for _tensor in tensor:
142
+ _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
143
+ _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
144
+
145
+ n_dim = _tensor.dim()
146
+ if n_dim == 4:
147
+ img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
148
+ img_np = img_np.transpose(1, 2, 0)
149
+ if rgb2bgr:
150
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
151
+ elif n_dim == 3:
152
+ img_np = _tensor.numpy()
153
+ img_np = img_np.transpose(1, 2, 0)
154
+ if img_np.shape[2] == 1: # gray image
155
+ img_np = np.squeeze(img_np, axis=2)
156
+ else:
157
+ if rgb2bgr:
158
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
159
+ elif n_dim == 2:
160
+ img_np = _tensor.numpy()
161
+ else:
162
+ raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
163
+ if out_type == np.uint8:
164
+ # Unlike MATLAB, numpy.unit8() WILL NOT round by default.
165
+ img_np = (img_np * 255.0).round()
166
+ img_np = img_np.astype(out_type)
167
+ result.append(img_np)
168
+ if len(result) == 1:
169
+ result = result[0]
170
+ return result
171
+
172
+
173
+ def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
174
+ """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
175
+
176
+ We use vertical flip and transpose for rotation implementation.
177
+ All the images in the list use the same augmentation.
178
+
179
+ Args:
180
+ imgs (list[ndarray] | ndarray): Images to be augmented. If the input
181
+ is an ndarray, it will be transformed to a list.
182
+ hflip (bool): Horizontal flip. Default: True.
183
+ rotation (bool): Ratotation. Default: True.
184
+ flows (list[ndarray]: Flows to be augmented. If the input is an
185
+ ndarray, it will be transformed to a list.
186
+ Dimension is (h, w, 2). Default: None.
187
+ return_status (bool): Return the status of flip and rotation.
188
+ Default: False.
189
+
190
+ Returns:
191
+ list[ndarray] | ndarray: Augmented images and flows. If returned
192
+ results only have one element, just return ndarray.
193
+
194
+ """
195
+ hflip = hflip and random.random() < 0.5
196
+ vflip = rotation and random.random() < 0.5
197
+ rot90 = rotation and random.random() < 0.5
198
+
199
+ def _augment(img):
200
+ if hflip: # horizontal
201
+ cv2.flip(img, 1, img)
202
+ if vflip: # vertical
203
+ cv2.flip(img, 0, img)
204
+ if rot90:
205
+ img = img.transpose(1, 0, 2)
206
+ return img
207
+
208
+ def _augment_flow(flow):
209
+ if hflip: # horizontal
210
+ cv2.flip(flow, 1, flow)
211
+ flow[:, :, 0] *= -1
212
+ if vflip: # vertical
213
+ cv2.flip(flow, 0, flow)
214
+ flow[:, :, 1] *= -1
215
+ if rot90:
216
+ flow = flow.transpose(1, 0, 2)
217
+ flow = flow[:, :, [1, 0]]
218
+ return flow
219
+
220
+ if not isinstance(imgs, list):
221
+ imgs = [imgs]
222
+ imgs = [_augment(img) for img in imgs]
223
+ if len(imgs) == 1:
224
+ imgs = imgs[0]
225
+
226
+ if flows is not None:
227
+ if not isinstance(flows, list):
228
+ flows = [flows]
229
+ flows = [_augment_flow(flow) for flow in flows]
230
+ if len(flows) == 1:
231
+ flows = flows[0]
232
+ return imgs, flows
233
+ else:
234
+ if return_status:
235
+ return imgs, (hflip, vflip, rot90)
236
+ else:
237
+ return imgs
238
+
239
+
240
+ def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
241
+ """Paired random crop. Support Numpy array and Tensor inputs.
242
+
243
+ It crops lists of lq and gt images with corresponding locations.
244
+
245
+ Args:
246
+ img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images
247
+ should have the same shape. If the input is an ndarray, it will
248
+ be transformed to a list containing itself.
249
+ img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
250
+ should have the same shape. If the input is an ndarray, it will
251
+ be transformed to a list containing itself.
252
+ gt_patch_size (int): GT patch size.
253
+ scale (int): Scale factor.
254
+ gt_path (str): Path to ground-truth. Default: None.
255
+
256
+ Returns:
257
+ list[ndarray] | ndarray: GT images and LQ images. If returned results
258
+ only have one element, just return ndarray.
259
+ """
260
+
261
+ if not isinstance(img_gts, list):
262
+ img_gts = [img_gts]
263
+ if not isinstance(img_lqs, list):
264
+ img_lqs = [img_lqs]
265
+
266
+ # determine input type: Numpy array or Tensor
267
+ input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'
268
+
269
+ if input_type == 'Tensor':
270
+ h_lq, w_lq = img_lqs[0].size()[-2:]
271
+ h_gt, w_gt = img_gts[0].size()[-2:]
272
+ else:
273
+ h_lq, w_lq = img_lqs[0].shape[0:2]
274
+ h_gt, w_gt = img_gts[0].shape[0:2]
275
+ lq_patch_size = gt_patch_size // scale
276
+
277
+ if h_gt != h_lq * scale or w_gt != w_lq * scale:
278
+ raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ',
279
+ f'multiplication of LQ ({h_lq}, {w_lq}).')
280
+ if h_lq < lq_patch_size or w_lq < lq_patch_size:
281
+ raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
282
+ f'({lq_patch_size}, {lq_patch_size}). '
283
+ f'Please remove {gt_path}.')
284
+
285
+ # randomly choose top and left coordinates for lq patch
286
+ top = random.randint(0, h_lq - lq_patch_size)
287
+ left = random.randint(0, w_lq - lq_patch_size)
288
+
289
+ # crop lq patch
290
+ if input_type == 'Tensor':
291
+ img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
292
+ else:
293
+ img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
294
+
295
+ # crop corresponding gt patch
296
+ top_gt, left_gt = int(top * scale), int(left * scale)
297
+ if input_type == 'Tensor':
298
+ img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
299
+ else:
300
+ img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
301
+ if len(img_gts) == 1:
302
+ img_gts = img_gts[0]
303
+ if len(img_lqs) == 1:
304
+ img_lqs = img_lqs[0]
305
+ return img_gts, img_lqs
306
+
307
+
308
+ # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501
309
+ class BaseStorageBackend(metaclass=ABCMeta):
310
+ """Abstract class of storage backends.
311
+
312
+ All backends need to implement two apis: ``get()`` and ``get_text()``.
313
+ ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
314
+ as texts.
315
+ """
316
+
317
+ @abstractmethod
318
+ def get(self, filepath):
319
+ pass
320
+
321
+ @abstractmethod
322
+ def get_text(self, filepath):
323
+ pass
324
+
325
+
326
+ class MemcachedBackend(BaseStorageBackend):
327
+ """Memcached storage backend.
328
+
329
+ Attributes:
330
+ server_list_cfg (str): Config file for memcached server list.
331
+ client_cfg (str): Config file for memcached client.
332
+ sys_path (str | None): Additional path to be appended to `sys.path`.
333
+ Default: None.
334
+ """
335
+
336
+ def __init__(self, server_list_cfg, client_cfg, sys_path=None):
337
+ if sys_path is not None:
338
+ import sys
339
+ sys.path.append(sys_path)
340
+ try:
341
+ import mc
342
+ except ImportError:
343
+ raise ImportError('Please install memcached to enable MemcachedBackend.')
344
+
345
+ self.server_list_cfg = server_list_cfg
346
+ self.client_cfg = client_cfg
347
+ self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
348
+ # mc.pyvector servers as a point which points to a memory cache
349
+ self._mc_buffer = mc.pyvector()
350
+
351
+ def get(self, filepath):
352
+ filepath = str(filepath)
353
+ import mc
354
+ self._client.Get(filepath, self._mc_buffer)
355
+ value_buf = mc.ConvertBuffer(self._mc_buffer)
356
+ return value_buf
357
+
358
+ def get_text(self, filepath):
359
+ raise NotImplementedError
360
+
361
+
362
+ class HardDiskBackend(BaseStorageBackend):
363
+ """Raw hard disks storage backend."""
364
+
365
+ def get(self, filepath):
366
+ filepath = str(filepath)
367
+ with open(filepath, 'rb') as f:
368
+ value_buf = f.read()
369
+ return value_buf
370
+
371
+ def get_text(self, filepath):
372
+ filepath = str(filepath)
373
+ with open(filepath, 'r') as f:
374
+ value_buf = f.read()
375
+ return value_buf
376
+
377
+
378
+ class LmdbBackend(BaseStorageBackend):
379
+ """Lmdb storage backend.
380
+
381
+ Args:
382
+ db_paths (str | list[str]): Lmdb database paths.
383
+ client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
384
+ readonly (bool, optional): Lmdb environment parameter. If True,
385
+ disallow any write operations. Default: True.
386
+ lock (bool, optional): Lmdb environment parameter. If False, when
387
+ concurrent access occurs, do not lock the database. Default: False.
388
+ readahead (bool, optional): Lmdb environment parameter. If False,
389
+ disable the OS filesystem readahead mechanism, which may improve
390
+ random read performance when a database is larger than RAM.
391
+ Default: False.
392
+
393
+ Attributes:
394
+ db_paths (list): Lmdb database path.
395
+ _client (list): A list of several lmdb envs.
396
+ """
397
+
398
+ def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs):
399
+ try:
400
+ import lmdb
401
+ except ImportError:
402
+ raise ImportError('Please install lmdb to enable LmdbBackend.')
403
+
404
+ if isinstance(client_keys, str):
405
+ client_keys = [client_keys]
406
+
407
+ if isinstance(db_paths, list):
408
+ self.db_paths = [str(v) for v in db_paths]
409
+ elif isinstance(db_paths, str):
410
+ self.db_paths = [str(db_paths)]
411
+ assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, '
412
+ f'but received {len(client_keys)} and {len(self.db_paths)}.')
413
+
414
+ self._client = {}
415
+ for client, path in zip(client_keys, self.db_paths):
416
+ self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)
417
+
418
+ def get(self, filepath, client_key):
419
+ """Get values according to the filepath from one lmdb named client_key.
420
+
421
+ Args:
422
+ filepath (str | obj:`Path`): Here, filepath is the lmdb key.
423
+ client_key (str): Used for distinguishing different lmdb envs.
424
+ """
425
+ filepath = str(filepath)
426
+ assert client_key in self._client, (f'client_key {client_key} is not ' 'in lmdb clients.')
427
+ client = self._client[client_key]
428
+ with client.begin(write=False) as txn:
429
+ value_buf = txn.get(filepath.encode('ascii'))
430
+ return value_buf
431
+
432
+ def get_text(self, filepath):
433
+ raise NotImplementedError
434
+
435
+
436
+ class FileClient(object):
437
+ """A general file client to access files in different backend.
438
+
439
+ The client loads a file or text in a specified backend from its path
440
+ and return it as a binary file. it can also register other backend
441
+ accessor with a given name and backend class.
442
+
443
+ Attributes:
444
+ backend (str): The storage backend type. Options are "disk",
445
+ "memcached" and "lmdb".
446
+ client (:obj:`BaseStorageBackend`): The backend object.
447
+ """
448
+
449
+ _backends = {
450
+ 'disk': HardDiskBackend,
451
+ 'memcached': MemcachedBackend,
452
+ 'lmdb': LmdbBackend,
453
+ }
454
+
455
+ def __init__(self, backend='disk', **kwargs):
456
+ if backend not in self._backends:
457
+ raise ValueError(f'Backend {backend} is not supported. Currently supported ones'
458
+ f' are {list(self._backends.keys())}')
459
+ self.backend = backend
460
+ self.client = self._backends[backend](**kwargs)
461
+
462
+ def get(self, filepath, client_key='default'):
463
+ # client_key is used only for lmdb, where different fileclients have
464
+ # different lmdb environments.
465
+ if self.backend == 'lmdb':
466
+ return self.client.get(filepath, client_key)
467
+ else:
468
+ return self.client.get(filepath)
469
+
470
+ def get_text(self, filepath):
471
+ return self.client.get_text(filepath)
472
+
473
+
474
+ def imfrombytes(content, flag='color', float32=False):
475
+ """Read an image from bytes.
476
+
477
+ Args:
478
+ content (bytes): Image bytes got from files or other streams.
479
+ flag (str): Flags specifying the color type of a loaded image,
480
+ candidates are `color`, `grayscale` and `unchanged`.
481
+ float32 (bool): Whether to change to float32., If True, will also norm
482
+ to [0, 1]. Default: False.
483
+
484
+ Returns:
485
+ ndarray: Loaded image array.
486
+ """
487
+ img_np = np.frombuffer(content, np.uint8)
488
+ imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
489
+ img = cv2.imdecode(img_np, imread_flags[flag])
490
+ if float32:
491
+ img = img.astype(np.float32) / 255.
492
+ return img
493
+
core/data/deg_kair_utils/utils_videoio.py ADDED
@@ -0,0 +1,555 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ import random
6
+ from os import path as osp
7
+ from torchvision.utils import make_grid
8
+ import sys
9
+ from pathlib import Path
10
+ import six
11
+ from collections import OrderedDict
12
+ import math
13
+ import glob
14
+ import av
15
+ import io
16
+ from cv2 import (CAP_PROP_FOURCC, CAP_PROP_FPS, CAP_PROP_FRAME_COUNT,
17
+ CAP_PROP_FRAME_HEIGHT, CAP_PROP_FRAME_WIDTH,
18
+ CAP_PROP_POS_FRAMES, VideoWriter_fourcc)
19
+
20
+ if sys.version_info <= (3, 3):
21
+ FileNotFoundError = IOError
22
+ else:
23
+ FileNotFoundError = FileNotFoundError
24
+
25
+
26
+ def is_str(x):
27
+ """Whether the input is an string instance."""
28
+ return isinstance(x, six.string_types)
29
+
30
+
31
+ def is_filepath(x):
32
+ return is_str(x) or isinstance(x, Path)
33
+
34
+
35
+ def fopen(filepath, *args, **kwargs):
36
+ if is_str(filepath):
37
+ return open(filepath, *args, **kwargs)
38
+ elif isinstance(filepath, Path):
39
+ return filepath.open(*args, **kwargs)
40
+ raise ValueError('`filepath` should be a string or a Path')
41
+
42
+
43
+ def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
44
+ if not osp.isfile(filename):
45
+ raise FileNotFoundError(msg_tmpl.format(filename))
46
+
47
+
48
+ def mkdir_or_exist(dir_name, mode=0o777):
49
+ if dir_name == '':
50
+ return
51
+ dir_name = osp.expanduser(dir_name)
52
+ os.makedirs(dir_name, mode=mode, exist_ok=True)
53
+
54
+
55
+ def symlink(src, dst, overwrite=True, **kwargs):
56
+ if os.path.lexists(dst) and overwrite:
57
+ os.remove(dst)
58
+ os.symlink(src, dst, **kwargs)
59
+
60
+
61
+ def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True):
62
+ """Scan a directory to find the interested files.
63
+ Args:
64
+ dir_path (str | :obj:`Path`): Path of the directory.
65
+ suffix (str | tuple(str), optional): File suffix that we are
66
+ interested in. Default: None.
67
+ recursive (bool, optional): If set to True, recursively scan the
68
+ directory. Default: False.
69
+ case_sensitive (bool, optional) : If set to False, ignore the case of
70
+ suffix. Default: True.
71
+ Returns:
72
+ A generator for all the interested files with relative paths.
73
+ """
74
+ if isinstance(dir_path, (str, Path)):
75
+ dir_path = str(dir_path)
76
+ else:
77
+ raise TypeError('"dir_path" must be a string or Path object')
78
+
79
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
80
+ raise TypeError('"suffix" must be a string or tuple of strings')
81
+
82
+ if suffix is not None and not case_sensitive:
83
+ suffix = suffix.lower() if isinstance(suffix, str) else tuple(
84
+ item.lower() for item in suffix)
85
+
86
+ root = dir_path
87
+
88
+ def _scandir(dir_path, suffix, recursive, case_sensitive):
89
+ for entry in os.scandir(dir_path):
90
+ if not entry.name.startswith('.') and entry.is_file():
91
+ rel_path = osp.relpath(entry.path, root)
92
+ _rel_path = rel_path if case_sensitive else rel_path.lower()
93
+ if suffix is None or _rel_path.endswith(suffix):
94
+ yield rel_path
95
+ elif recursive and os.path.isdir(entry.path):
96
+ # scan recursively if entry.path is a directory
97
+ yield from _scandir(entry.path, suffix, recursive,
98
+ case_sensitive)
99
+
100
+ return _scandir(dir_path, suffix, recursive, case_sensitive)
101
+
102
+
103
+ class Cache:
104
+
105
+ def __init__(self, capacity):
106
+ self._cache = OrderedDict()
107
+ self._capacity = int(capacity)
108
+ if capacity <= 0:
109
+ raise ValueError('capacity must be a positive integer')
110
+
111
+ @property
112
+ def capacity(self):
113
+ return self._capacity
114
+
115
+ @property
116
+ def size(self):
117
+ return len(self._cache)
118
+
119
+ def put(self, key, val):
120
+ if key in self._cache:
121
+ return
122
+ if len(self._cache) >= self.capacity:
123
+ self._cache.popitem(last=False)
124
+ self._cache[key] = val
125
+
126
+ def get(self, key, default=None):
127
+ val = self._cache[key] if key in self._cache else default
128
+ return val
129
+
130
+
131
+ class VideoReader:
132
+ """Video class with similar usage to a list object.
133
+
134
+ This video warpper class provides convenient apis to access frames.
135
+ There exists an issue of OpenCV's VideoCapture class that jumping to a
136
+ certain frame may be inaccurate. It is fixed in this class by checking
137
+ the position after jumping each time.
138
+ Cache is used when decoding videos. So if the same frame is visited for
139
+ the second time, there is no need to decode again if it is stored in the
140
+ cache.
141
+
142
+ """
143
+
144
+ def __init__(self, filename, cache_capacity=10):
145
+ # Check whether the video path is a url
146
+ if not filename.startswith(('https://', 'http://')):
147
+ check_file_exist(filename, 'Video file not found: ' + filename)
148
+ self._vcap = cv2.VideoCapture(filename)
149
+ assert cache_capacity > 0
150
+ self._cache = Cache(cache_capacity)
151
+ self._position = 0
152
+ # get basic info
153
+ self._width = int(self._vcap.get(CAP_PROP_FRAME_WIDTH))
154
+ self._height = int(self._vcap.get(CAP_PROP_FRAME_HEIGHT))
155
+ self._fps = self._vcap.get(CAP_PROP_FPS)
156
+ self._frame_cnt = int(self._vcap.get(CAP_PROP_FRAME_COUNT))
157
+ self._fourcc = self._vcap.get(CAP_PROP_FOURCC)
158
+
159
+ @property
160
+ def vcap(self):
161
+ """:obj:`cv2.VideoCapture`: The raw VideoCapture object."""
162
+ return self._vcap
163
+
164
+ @property
165
+ def opened(self):
166
+ """bool: Indicate whether the video is opened."""
167
+ return self._vcap.isOpened()
168
+
169
+ @property
170
+ def width(self):
171
+ """int: Width of video frames."""
172
+ return self._width
173
+
174
+ @property
175
+ def height(self):
176
+ """int: Height of video frames."""
177
+ return self._height
178
+
179
+ @property
180
+ def resolution(self):
181
+ """tuple: Video resolution (width, height)."""
182
+ return (self._width, self._height)
183
+
184
+ @property
185
+ def fps(self):
186
+ """float: FPS of the video."""
187
+ return self._fps
188
+
189
+ @property
190
+ def frame_cnt(self):
191
+ """int: Total frames of the video."""
192
+ return self._frame_cnt
193
+
194
+ @property
195
+ def fourcc(self):
196
+ """str: "Four character code" of the video."""
197
+ return self._fourcc
198
+
199
+ @property
200
+ def position(self):
201
+ """int: Current cursor position, indicating frame decoded."""
202
+ return self._position
203
+
204
+ def _get_real_position(self):
205
+ return int(round(self._vcap.get(CAP_PROP_POS_FRAMES)))
206
+
207
+ def _set_real_position(self, frame_id):
208
+ self._vcap.set(CAP_PROP_POS_FRAMES, frame_id)
209
+ pos = self._get_real_position()
210
+ for _ in range(frame_id - pos):
211
+ self._vcap.read()
212
+ self._position = frame_id
213
+
214
+ def read(self):
215
+ """Read the next frame.
216
+
217
+ If the next frame have been decoded before and in the cache, then
218
+ return it directly, otherwise decode, cache and return it.
219
+
220
+ Returns:
221
+ ndarray or None: Return the frame if successful, otherwise None.
222
+ """
223
+ # pos = self._position
224
+ if self._cache:
225
+ img = self._cache.get(self._position)
226
+ if img is not None:
227
+ ret = True
228
+ else:
229
+ if self._position != self._get_real_position():
230
+ self._set_real_position(self._position)
231
+ ret, img = self._vcap.read()
232
+ if ret:
233
+ self._cache.put(self._position, img)
234
+ else:
235
+ ret, img = self._vcap.read()
236
+ if ret:
237
+ self._position += 1
238
+ return img
239
+
240
+ def get_frame(self, frame_id):
241
+ """Get frame by index.
242
+
243
+ Args:
244
+ frame_id (int): Index of the expected frame, 0-based.
245
+
246
+ Returns:
247
+ ndarray or None: Return the frame if successful, otherwise None.
248
+ """
249
+ if frame_id < 0 or frame_id >= self._frame_cnt:
250
+ raise IndexError(
251
+ f'"frame_id" must be between 0 and {self._frame_cnt - 1}')
252
+ if frame_id == self._position:
253
+ return self.read()
254
+ if self._cache:
255
+ img = self._cache.get(frame_id)
256
+ if img is not None:
257
+ self._position = frame_id + 1
258
+ return img
259
+ self._set_real_position(frame_id)
260
+ ret, img = self._vcap.read()
261
+ if ret:
262
+ if self._cache:
263
+ self._cache.put(self._position, img)
264
+ self._position += 1
265
+ return img
266
+
267
+ def current_frame(self):
268
+ """Get the current frame (frame that is just visited).
269
+
270
+ Returns:
271
+ ndarray or None: If the video is fresh, return None, otherwise
272
+ return the frame.
273
+ """
274
+ if self._position == 0:
275
+ return None
276
+ return self._cache.get(self._position - 1)
277
+
278
+ def cvt2frames(self,
279
+ frame_dir,
280
+ file_start=0,
281
+ filename_tmpl='{:06d}.jpg',
282
+ start=0,
283
+ max_num=0,
284
+ show_progress=False):
285
+ """Convert a video to frame images.
286
+
287
+ Args:
288
+ frame_dir (str): Output directory to store all the frame images.
289
+ file_start (int): Filenames will start from the specified number.
290
+ filename_tmpl (str): Filename template with the index as the
291
+ placeholder.
292
+ start (int): The starting frame index.
293
+ max_num (int): Maximum number of frames to be written.
294
+ show_progress (bool): Whether to show a progress bar.
295
+ """
296
+ mkdir_or_exist(frame_dir)
297
+ if max_num == 0:
298
+ task_num = self.frame_cnt - start
299
+ else:
300
+ task_num = min(self.frame_cnt - start, max_num)
301
+ if task_num <= 0:
302
+ raise ValueError('start must be less than total frame number')
303
+ if start > 0:
304
+ self._set_real_position(start)
305
+
306
+ def write_frame(file_idx):
307
+ img = self.read()
308
+ if img is None:
309
+ return
310
+ filename = osp.join(frame_dir, filename_tmpl.format(file_idx))
311
+ cv2.imwrite(filename, img)
312
+
313
+ if show_progress:
314
+ pass
315
+ #track_progress(write_frame, range(file_start,file_start + task_num))
316
+ else:
317
+ for i in range(task_num):
318
+ write_frame(file_start + i)
319
+
320
+ def __len__(self):
321
+ return self.frame_cnt
322
+
323
+ def __getitem__(self, index):
324
+ if isinstance(index, slice):
325
+ return [
326
+ self.get_frame(i)
327
+ for i in range(*index.indices(self.frame_cnt))
328
+ ]
329
+ # support negative indexing
330
+ if index < 0:
331
+ index += self.frame_cnt
332
+ if index < 0:
333
+ raise IndexError('index out of range')
334
+ return self.get_frame(index)
335
+
336
+ def __iter__(self):
337
+ self._set_real_position(0)
338
+ return self
339
+
340
+ def __next__(self):
341
+ img = self.read()
342
+ if img is not None:
343
+ return img
344
+ else:
345
+ raise StopIteration
346
+
347
+ next = __next__
348
+
349
+ def __enter__(self):
350
+ return self
351
+
352
+ def __exit__(self, exc_type, exc_value, traceback):
353
+ self._vcap.release()
354
+
355
+
356
+ def frames2video(frame_dir,
357
+ video_file,
358
+ fps=30,
359
+ fourcc='XVID',
360
+ filename_tmpl='{:06d}.jpg',
361
+ start=0,
362
+ end=0,
363
+ show_progress=False):
364
+ """Read the frame images from a directory and join them as a video.
365
+
366
+ Args:
367
+ frame_dir (str): The directory containing video frames.
368
+ video_file (str): Output filename.
369
+ fps (float): FPS of the output video.
370
+ fourcc (str): Fourcc of the output video, this should be compatible
371
+ with the output file type.
372
+ filename_tmpl (str): Filename template with the index as the variable.
373
+ start (int): Starting frame index.
374
+ end (int): Ending frame index.
375
+ show_progress (bool): Whether to show a progress bar.
376
+ """
377
+ if end == 0:
378
+ ext = filename_tmpl.split('.')[-1]
379
+ end = len([name for name in scandir(frame_dir, ext)])
380
+ first_file = osp.join(frame_dir, filename_tmpl.format(start))
381
+ check_file_exist(first_file, 'The start frame not found: ' + first_file)
382
+ img = cv2.imread(first_file)
383
+ height, width = img.shape[:2]
384
+ resolution = (width, height)
385
+ vwriter = cv2.VideoWriter(video_file, VideoWriter_fourcc(*fourcc), fps,
386
+ resolution)
387
+
388
+ def write_frame(file_idx):
389
+ filename = osp.join(frame_dir, filename_tmpl.format(file_idx))
390
+ img = cv2.imread(filename)
391
+ vwriter.write(img)
392
+
393
+ if show_progress:
394
+ pass
395
+ # track_progress(write_frame, range(start, end))
396
+ else:
397
+ for i in range(start, end):
398
+ write_frame(i)
399
+ vwriter.release()
400
+
401
+
402
+ def video2images(video_path, output_dir):
403
+ vidcap = cv2.VideoCapture(video_path)
404
+ in_fps = vidcap.get(cv2.CAP_PROP_FPS)
405
+ print('video fps:', in_fps)
406
+ if not os.path.isdir(output_dir):
407
+ os.makedirs(output_dir)
408
+ loaded, frame = vidcap.read()
409
+ total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
410
+ print(f'number of total frames is: {total_frames:06}')
411
+ for i_frame in range(total_frames):
412
+ if i_frame % 100 == 0:
413
+ print(f'{i_frame:06} / {total_frames:06}')
414
+ frame_name = os.path.join(output_dir, f'{i_frame:06}' + '.png')
415
+ cv2.imwrite(frame_name, frame)
416
+ loaded, frame = vidcap.read()
417
+
418
+
419
+ def images2video(image_dir, video_path, fps=24, image_ext='png'):
420
+ '''
421
+ #codec = cv2.VideoWriter_fourcc(*'XVID')
422
+ #codec = cv2.VideoWriter_fourcc('A','V','C','1')
423
+ #codec = cv2.VideoWriter_fourcc('Y','U','V','1')
424
+ #codec = cv2.VideoWriter_fourcc('P','I','M','1')
425
+ #codec = cv2.VideoWriter_fourcc('M','J','P','G')
426
+ codec = cv2.VideoWriter_fourcc('M','P','4','2')
427
+ #codec = cv2.VideoWriter_fourcc('D','I','V','3')
428
+ #codec = cv2.VideoWriter_fourcc('D','I','V','X')
429
+ #codec = cv2.VideoWriter_fourcc('U','2','6','3')
430
+ #codec = cv2.VideoWriter_fourcc('I','2','6','3')
431
+ #codec = cv2.VideoWriter_fourcc('F','L','V','1')
432
+ #codec = cv2.VideoWriter_fourcc('H','2','6','4')
433
+ #codec = cv2.VideoWriter_fourcc('A','Y','U','V')
434
+ #codec = cv2.VideoWriter_fourcc('I','U','Y','V')
435
+ 编码器常用的几种:
436
+ cv2.VideoWriter_fourcc("I", "4", "2", "0")
437
+ 压缩的yuv颜色编码器,4:2:0色彩度子采样 兼容性好,产生很大的视频 avi
438
+ cv2.VideoWriter_fourcc("P", I", "M", "1")
439
+ 采用mpeg-1编码,文件为avi
440
+ cv2.VideoWriter_fourcc("X", "V", "T", "D")
441
+ 采用mpeg-4编码,得到视频大小平均 拓展名avi
442
+ cv2.VideoWriter_fourcc("T", "H", "E", "O")
443
+ Ogg Vorbis, 拓展名为ogv
444
+ cv2.VideoWriter_fourcc("F", "L", "V", "1")
445
+ FLASH视频,拓展名为.flv
446
+ '''
447
+ image_files = sorted(glob.glob(os.path.join(image_dir, '*.{}'.format(image_ext))))
448
+ print(len(image_files))
449
+ height, width, _ = cv2.imread(image_files[0]).shape
450
+ out_fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G') # cv2.VideoWriter_fourcc(*'MP4V')
451
+ out_video = cv2.VideoWriter(video_path, out_fourcc, fps, (width, height))
452
+
453
+ for image_file in image_files:
454
+ img = cv2.imread(image_file)
455
+ img = cv2.resize(img, (width, height), interpolation=3)
456
+ out_video.write(img)
457
+ out_video.release()
458
+
459
+
460
+ def add_video_compression(imgs):
461
+ codec_type = ['libx264', 'h264', 'mpeg4']
462
+ codec_prob = [1 / 3., 1 / 3., 1 / 3.]
463
+ codec = random.choices(codec_type, codec_prob)[0]
464
+ # codec = 'mpeg4'
465
+ bitrate = [1e4, 1e5]
466
+ bitrate = np.random.randint(bitrate[0], bitrate[1] + 1)
467
+
468
+ buf = io.BytesIO()
469
+ with av.open(buf, 'w', 'mp4') as container:
470
+ stream = container.add_stream(codec, rate=1)
471
+ stream.height = imgs[0].shape[0]
472
+ stream.width = imgs[0].shape[1]
473
+ stream.pix_fmt = 'yuv420p'
474
+ stream.bit_rate = bitrate
475
+
476
+ for img in imgs:
477
+ img = np.uint8((img.clip(0, 1)*255.).round())
478
+ frame = av.VideoFrame.from_ndarray(img, format='rgb24')
479
+ frame.pict_type = 'NONE'
480
+ # pdb.set_trace()
481
+ for packet in stream.encode(frame):
482
+ container.mux(packet)
483
+
484
+ # Flush stream
485
+ for packet in stream.encode():
486
+ container.mux(packet)
487
+
488
+ outputs = []
489
+ with av.open(buf, 'r', 'mp4') as container:
490
+ if container.streams.video:
491
+ for frame in container.decode(**{'video': 0}):
492
+ outputs.append(
493
+ frame.to_rgb().to_ndarray().astype(np.float32) / 255.)
494
+
495
+ #outputs = np.stack(outputs, axis=0)
496
+ return outputs
497
+
498
+
499
+ if __name__ == '__main__':
500
+
501
+ # -----------------------------------
502
+ # test VideoReader(filename, cache_capacity=10)
503
+ # -----------------------------------
504
+ # video_reader = VideoReader('utils/test.mp4')
505
+ # from utils import utils_image as util
506
+ # inputs = []
507
+ # for frame in video_reader:
508
+ # print(frame.dtype)
509
+ # util.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
510
+ # #util.imshow(np.flip(frame, axis=2))
511
+
512
+ # -----------------------------------
513
+ # test video2images(video_path, output_dir)
514
+ # -----------------------------------
515
+ # video2images('utils/test.mp4', 'frames')
516
+
517
+ # -----------------------------------
518
+ # test images2video(image_dir, video_path, fps=24, image_ext='png')
519
+ # -----------------------------------
520
+ # images2video('frames', 'video_02.mp4', fps=30, image_ext='png')
521
+
522
+
523
+ # -----------------------------------
524
+ # test frames2video(frame_dir, video_file, fps=30, fourcc='XVID', filename_tmpl='{:06d}.png')
525
+ # -----------------------------------
526
+ # frames2video('frames', 'video_01.mp4', filename_tmpl='{:06d}.png')
527
+
528
+
529
+ # -----------------------------------
530
+ # test add_video_compression(imgs)
531
+ # -----------------------------------
532
+ # imgs = []
533
+ # image_ext = 'png'
534
+ # frames = 'frames'
535
+ # from utils import utils_image as util
536
+ # image_files = sorted(glob.glob(os.path.join(frames, '*.{}'.format(image_ext))))
537
+ # for i, image_file in enumerate(image_files):
538
+ # if i < 7:
539
+ # img = util.imread_uint(image_file, 3)
540
+ # img = util.uint2single(img)
541
+ # imgs.append(img)
542
+ #
543
+ # results = add_video_compression(imgs)
544
+ # for i, img in enumerate(results):
545
+ # util.imshow(util.single2uint(img))
546
+ # util.imsave(util.single2uint(img),f'{i:05}.png')
547
+
548
+ # run utils/utils_video.py
549
+
550
+
551
+
552
+
553
+
554
+
555
+
core/scripts/__init__.py ADDED
File without changes
core/scripts/cli.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import argparse
3
+ from .. import WarpCore
4
+ from .. import templates
5
+
6
+
7
+ def template_init(args):
8
+ return ''''
9
+
10
+
11
+ '''.strip()
12
+
13
+
14
+ def init_template(args):
15
+ parser = argparse.ArgumentParser(description='WarpCore template init tool')
16
+ parser.add_argument('-t', '--template', type=str, default='WarpCore')
17
+ args = parser.parse_args(args)
18
+
19
+ if args.template == 'WarpCore':
20
+ template_cls = WarpCore
21
+ else:
22
+ try:
23
+ template_cls = __import__(args.template)
24
+ except ModuleNotFoundError:
25
+ template_cls = getattr(templates, args.template)
26
+ print(template_cls)
27
+
28
+
29
+ def main():
30
+ if len(sys.argv) < 2:
31
+ print('Usage: core <command>')
32
+ sys.exit(1)
33
+ if sys.argv[1] == 'init':
34
+ init_template(sys.argv[2:])
35
+ else:
36
+ print('Unknown command')
37
+ sys.exit(1)
38
+
39
+
40
+ if __name__ == '__main__':
41
+ main()
core/templates/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .diffusion import DiffusionCore
core/templates/diffusion.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .. import WarpCore
2
+ from ..utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary
3
+ from abc import abstractmethod
4
+ from dataclasses import dataclass
5
+ import torch
6
+ from torch import nn
7
+ from torch.utils.data import DataLoader
8
+ from gdf import GDF
9
+ import numpy as np
10
+ from tqdm import tqdm
11
+ import wandb
12
+
13
+ import webdataset as wds
14
+ from webdataset.handlers import warn_and_continue
15
+ from torch.distributed import barrier
16
+ from enum import Enum
17
+
18
+ class TargetReparametrization(Enum):
19
+ EPSILON = 'epsilon'
20
+ X0 = 'x0'
21
+
22
+ class DiffusionCore(WarpCore):
23
+ @dataclass(frozen=True)
24
+ class Config(WarpCore.Config):
25
+ # TRAINING PARAMS
26
+ lr: float = EXPECTED_TRAIN
27
+ grad_accum_steps: int = EXPECTED_TRAIN
28
+ batch_size: int = EXPECTED_TRAIN
29
+ updates: int = EXPECTED_TRAIN
30
+ warmup_updates: int = EXPECTED_TRAIN
31
+ save_every: int = 500
32
+ backup_every: int = 20000
33
+ use_fsdp: bool = True
34
+
35
+ # EMA UPDATE
36
+ ema_start_iters: int = None
37
+ ema_iters: int = None
38
+ ema_beta: float = None
39
+
40
+ # GDF setting
41
+ gdf_target_reparametrization: TargetReparametrization = None # epsilon or x0
42
+
43
+ @dataclass() # not frozen, means that fields are mutable. Doesn't support EXPECTED
44
+ class Info(WarpCore.Info):
45
+ ema_loss: float = None
46
+
47
+ @dataclass(frozen=True)
48
+ class Models(WarpCore.Models):
49
+ generator : nn.Module = EXPECTED
50
+ generator_ema : nn.Module = None # optional
51
+
52
+ @dataclass(frozen=True)
53
+ class Optimizers(WarpCore.Optimizers):
54
+ generator : any = EXPECTED
55
+
56
+ @dataclass(frozen=True)
57
+ class Schedulers(WarpCore.Schedulers):
58
+ generator: any = None
59
+
60
+ @dataclass(frozen=True)
61
+ class Extras(WarpCore.Extras):
62
+ gdf: GDF = EXPECTED
63
+ sampling_configs: dict = EXPECTED
64
+
65
+ # --------------------------------------------
66
+ info: Info
67
+ config: Config
68
+
69
+ @abstractmethod
70
+ def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
71
+ raise NotImplementedError("This method needs to be overriden")
72
+
73
+ @abstractmethod
74
+ def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
75
+ raise NotImplementedError("This method needs to be overriden")
76
+
77
+ @abstractmethod
78
+ def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False):
79
+ raise NotImplementedError("This method needs to be overriden")
80
+
81
+ @abstractmethod
82
+ def webdataset_path(self, extras: Extras):
83
+ raise NotImplementedError("This method needs to be overriden")
84
+
85
+ @abstractmethod
86
+ def webdataset_filters(self, extras: Extras):
87
+ raise NotImplementedError("This method needs to be overriden")
88
+
89
+ @abstractmethod
90
+ def webdataset_preprocessors(self, extras: Extras):
91
+ raise NotImplementedError("This method needs to be overriden")
92
+
93
+ @abstractmethod
94
+ def sample(self, models: Models, data: WarpCore.Data, extras: Extras):
95
+ raise NotImplementedError("This method needs to be overriden")
96
+ # -------------
97
+
98
+ def setup_data(self, extras: Extras) -> WarpCore.Data:
99
+ # SETUP DATASET
100
+ dataset_path = self.webdataset_path(extras)
101
+ preprocessors = self.webdataset_preprocessors(extras)
102
+ filters = self.webdataset_filters(extras)
103
+
104
+ handler = warn_and_continue # None
105
+ # handler = None
106
+ dataset = wds.WebDataset(
107
+ dataset_path, resampled=True, handler=handler
108
+ ).select(filters).shuffle(690, handler=handler).decode(
109
+ "pilrgb", handler=handler
110
+ ).to_tuple(
111
+ *[p[0] for p in preprocessors], handler=handler
112
+ ).map_tuple(
113
+ *[p[1] for p in preprocessors], handler=handler
114
+ ).map(lambda x: {p[2]:x[i] for i, p in enumerate(preprocessors)})
115
+
116
+ # SETUP DATALOADER
117
+ real_batch_size = self.config.batch_size//(self.world_size*self.config.grad_accum_steps)
118
+ dataloader = DataLoader(
119
+ dataset, batch_size=real_batch_size, num_workers=8, pin_memory=True
120
+ )
121
+
122
+ return self.Data(dataset=dataset, dataloader=dataloader, iterator=iter(dataloader))
123
+
124
+ def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models):
125
+ batch = next(data.iterator)
126
+
127
+ with torch.no_grad():
128
+ conditions = self.get_conditions(batch, models, extras)
129
+ latents = self.encode_latents(batch, models, extras)
130
+ noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1)
131
+
132
+ # FORWARD PASS
133
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
134
+ pred = models.generator(noised, noise_cond, **conditions)
135
+ if self.config.gdf_target_reparametrization == TargetReparametrization.EPSILON:
136
+ pred = extras.gdf.undiffuse(noised, logSNR, pred)[1] # transform whatever prediction to epsilon to use in the loss
137
+ target = noise
138
+ elif self.config.gdf_target_reparametrization == TargetReparametrization.X0:
139
+ pred = extras.gdf.undiffuse(noised, logSNR, pred)[0] # transform whatever prediction to x0 to use in the loss
140
+ target = latents
141
+ loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3])
142
+ loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps
143
+
144
+ return loss, loss_adjusted
145
+
146
+ def train(self, data: WarpCore.Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers):
147
+ start_iter = self.info.iter+1
148
+ max_iters = self.config.updates * self.config.grad_accum_steps
149
+ if self.is_main_node:
150
+ print(f"STARTING AT STEP: {start_iter}/{max_iters}")
151
+
152
+ pbar = tqdm(range(start_iter, max_iters+1)) if self.is_main_node else range(start_iter, max_iters+1) # <--- DDP
153
+ models.generator.train()
154
+ for i in pbar:
155
+ # FORWARD PASS
156
+ loss, loss_adjusted = self.forward_pass(data, extras, models)
157
+
158
+ # BACKWARD PASS
159
+ if i % self.config.grad_accum_steps == 0 or i == max_iters:
160
+ loss_adjusted.backward()
161
+ grad_norm = nn.utils.clip_grad_norm_(models.generator.parameters(), 1.0)
162
+ optimizers_dict = optimizers.to_dict()
163
+ for k in optimizers_dict:
164
+ optimizers_dict[k].step()
165
+ schedulers_dict = schedulers.to_dict()
166
+ for k in schedulers_dict:
167
+ schedulers_dict[k].step()
168
+ models.generator.zero_grad(set_to_none=True)
169
+ self.info.total_steps += 1
170
+ else:
171
+ with models.generator.no_sync():
172
+ loss_adjusted.backward()
173
+ self.info.iter = i
174
+
175
+ # UPDATE EMA
176
+ if models.generator_ema is not None and i % self.config.ema_iters == 0:
177
+ update_weights_ema(
178
+ models.generator_ema, models.generator,
179
+ beta=(self.config.ema_beta if i > self.config.ema_start_iters else 0)
180
+ )
181
+
182
+ # UPDATE LOSS METRICS
183
+ self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01
184
+
185
+ if self.is_main_node and self.config.wandb_project is not None and np.isnan(loss.mean().item()) or np.isnan(grad_norm.item()):
186
+ wandb.alert(
187
+ title=f"NaN value encountered in training run {self.info.wandb_run_id}",
188
+ text=f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}",
189
+ wait_duration=60*30
190
+ )
191
+
192
+ if self.is_main_node:
193
+ logs = {
194
+ 'loss': self.info.ema_loss,
195
+ 'raw_loss': loss.mean().item(),
196
+ 'grad_norm': grad_norm.item(),
197
+ 'lr': optimizers.generator.param_groups[0]['lr'],
198
+ 'total_steps': self.info.total_steps,
199
+ }
200
+
201
+ pbar.set_postfix(logs)
202
+ if self.config.wandb_project is not None:
203
+ wandb.log(logs)
204
+
205
+ if i == 1 or i % (self.config.save_every*self.config.grad_accum_steps) == 0 or i == max_iters:
206
+ # SAVE AND CHECKPOINT STUFF
207
+ if np.isnan(loss.mean().item()):
208
+ if self.is_main_node and self.config.wandb_project is not None:
209
+ tqdm.write("Skipping sampling & checkpoint because the loss is NaN")
210
+ wandb.alert(title=f"Skipping sampling & checkpoint for training run {self.config.run_id}", text=f"Skipping sampling & checkpoint at {self.info.total_steps} for training run {self.info.wandb_run_id} iters because loss is NaN")
211
+ else:
212
+ self.save_checkpoints(models, optimizers)
213
+ if self.is_main_node:
214
+ create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/')
215
+ self.sample(models, data, extras)
216
+
217
+ def models_to_save(self):
218
+ return ['generator', 'generator_ema']
219
+
220
+ def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None):
221
+ barrier()
222
+ suffix = '' if suffix is None else suffix
223
+ self.save_info(self.info, suffix=suffix)
224
+ models_dict = models.to_dict()
225
+ optimizers_dict = optimizers.to_dict()
226
+ for key in self.models_to_save():
227
+ model = models_dict[key]
228
+ if model is not None:
229
+ self.save_model(model, f"{key}{suffix}", is_fsdp=self.config.use_fsdp)
230
+ for key in optimizers_dict:
231
+ optimizer = optimizers_dict[key]
232
+ if optimizer is not None:
233
+ self.save_optimizer(optimizer, f'{key}_optim{suffix}', fsdp_model=models.generator if self.config.use_fsdp else None)
234
+ if suffix == '' and self.info.total_steps > 1 and self.info.total_steps % self.config.backup_every == 0:
235
+ self.save_checkpoints(models, optimizers, suffix=f"_{self.info.total_steps//1000}k")
236
+ torch.cuda.empty_cache()
core/utils/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from .base_dto import Base, nested_dto, EXPECTED, EXPECTED_TRAIN
2
+ from .save_and_load import create_folder_if_necessary, safe_save, load_or_fail
3
+
4
+ # MOVE IT SOMERWHERE ELSE
5
+ def update_weights_ema(tgt_model, src_model, beta=0.999):
6
+ for self_params, src_params in zip(tgt_model.parameters(), src_model.parameters()):
7
+ self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1-beta)
8
+ for self_buffers, src_buffers in zip(tgt_model.buffers(), src_model.buffers()):
9
+ self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1-beta)
core/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (763 Bytes). View file
 
core/utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (804 Bytes). View file
 
core/utils/__pycache__/base_dto.cpython-310.pyc ADDED
Binary file (3.09 kB). View file
 
core/utils/__pycache__/base_dto.cpython-39.pyc ADDED
Binary file (3.11 kB). View file
 
core/utils/__pycache__/save_and_load.cpython-310.pyc ADDED
Binary file (2.19 kB). View file
 
core/utils/__pycache__/save_and_load.cpython-39.pyc ADDED
Binary file (2.2 kB). View file
 
core/utils/base_dto.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ from dataclasses import dataclass, _MISSING_TYPE
3
+ from munch import Munch
4
+
5
+ EXPECTED = "___REQUIRED___"
6
+ EXPECTED_TRAIN = "___REQUIRED_TRAIN___"
7
+
8
+ # pylint: disable=invalid-field-call
9
+ def nested_dto(x, raw=False):
10
+ return dataclasses.field(default_factory=lambda: x if raw else Munch.fromDict(x))
11
+
12
+ @dataclass(frozen=True)
13
+ class Base:
14
+ training: bool = None
15
+ def __new__(cls, **kwargs):
16
+ training = kwargs.get('training', True)
17
+ setteable_fields = cls.setteable_fields(**kwargs)
18
+ mandatory_fields = cls.mandatory_fields(**kwargs)
19
+ invalid_kwargs = [
20
+ {k: v} for k, v in kwargs.items() if k not in setteable_fields or v == EXPECTED or (v == EXPECTED_TRAIN and training is not False)
21
+ ]
22
+ print(mandatory_fields)
23
+ assert (
24
+ len(invalid_kwargs) == 0
25
+ ), f"Invalid fields detected when initializing this DTO: {invalid_kwargs}.\nDeclare this field and set it to None or EXPECTED in order to make it setteable."
26
+ missing_kwargs = [f for f in mandatory_fields if f not in kwargs]
27
+ assert (
28
+ len(missing_kwargs) == 0
29
+ ), f"Required fields missing initializing this DTO: {missing_kwargs}."
30
+ return object.__new__(cls)
31
+
32
+
33
+ @classmethod
34
+ def setteable_fields(cls, **kwargs):
35
+ return [f.name for f in dataclasses.fields(cls) if f.default is None or isinstance(f.default, _MISSING_TYPE) or f.default == EXPECTED or f.default == EXPECTED_TRAIN]
36
+
37
+ @classmethod
38
+ def mandatory_fields(cls, **kwargs):
39
+ training = kwargs.get('training', True)
40
+ return [f.name for f in dataclasses.fields(cls) if isinstance(f.default, _MISSING_TYPE) and isinstance(f.default_factory, _MISSING_TYPE) or f.default == EXPECTED or (f.default == EXPECTED_TRAIN and training is not False)]
41
+
42
+ @classmethod
43
+ def from_dict(cls, kwargs):
44
+ for k in kwargs:
45
+ if isinstance(kwargs[k], (dict, list, tuple)):
46
+ kwargs[k] = Munch.fromDict(kwargs[k])
47
+ return cls(**kwargs)
48
+
49
+ def to_dict(self):
50
+ # selfdict = dataclasses.asdict(self) # needs to pickle stuff, doesn't support some more complex classes
51
+ selfdict = {}
52
+ for k in dataclasses.fields(self):
53
+ selfdict[k.name] = getattr(self, k.name)
54
+ if isinstance(selfdict[k.name], Munch):
55
+ selfdict[k.name] = selfdict[k.name].toDict()
56
+ return selfdict
core/utils/save_and_load.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import json
4
+ from pathlib import Path
5
+ import safetensors
6
+ import wandb
7
+
8
+
9
+ def create_folder_if_necessary(path):
10
+ path = "/".join(path.split("/")[:-1])
11
+ Path(path).mkdir(parents=True, exist_ok=True)
12
+
13
+
14
+ def safe_save(ckpt, path):
15
+ try:
16
+ os.remove(f"{path}.bak")
17
+ except OSError:
18
+ pass
19
+ try:
20
+ os.rename(path, f"{path}.bak")
21
+ except OSError:
22
+ pass
23
+ if path.endswith(".pt") or path.endswith(".ckpt"):
24
+ torch.save(ckpt, path)
25
+ elif path.endswith(".json"):
26
+ with open(path, "w", encoding="utf-8") as f:
27
+ json.dump(ckpt, f, indent=4)
28
+ elif path.endswith(".safetensors"):
29
+ safetensors.torch.save_file(ckpt, path)
30
+ else:
31
+ raise ValueError(f"File extension not supported: {path}")
32
+
33
+
34
+ def load_or_fail(path, wandb_run_id=None):
35
+ accepted_extensions = [".pt", ".ckpt", ".json", ".safetensors"]
36
+ try:
37
+ assert any(
38
+ [path.endswith(ext) for ext in accepted_extensions]
39
+ ), f"Automatic loading not supported for this extension: {path}"
40
+ if not os.path.exists(path):
41
+ checkpoint = None
42
+ elif path.endswith(".pt") or path.endswith(".ckpt"):
43
+ checkpoint = torch.load(path, map_location="cpu")
44
+ elif path.endswith(".json"):
45
+ with open(path, "r", encoding="utf-8") as f:
46
+ checkpoint = json.load(f)
47
+ elif path.endswith(".safetensors"):
48
+ checkpoint = {}
49
+ with safetensors.safe_open(path, framework="pt", device="cpu") as f:
50
+ for key in f.keys():
51
+ checkpoint[key] = f.get_tensor(key)
52
+ return checkpoint
53
+ except Exception as e:
54
+ if wandb_run_id is not None:
55
+ wandb.alert(
56
+ title=f"Corrupt checkpoint for run {wandb_run_id}",
57
+ text=f"Training {wandb_run_id} tried to load checkpoint {path} and failed",
58
+ )
59
+ raise e