diff --git a/app.py b/app.py
index bf08cc43b90980c3b7c915b94c841af31e36963e..3c4ac8cf1a8d4a201305effc5566332bf0acf82d 100644
--- a/app.py
+++ b/app.py
@@ -1,213 +1,158 @@
import spaces
-import json
-import subprocess
import os
-import sys
-
-def run_command(command):
- process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
- output, error = process.communicate()
- if process.returncode != 0:
- print(f"Error executing command: {command}")
- print(error.decode('utf-8'))
- exit(1)
- return output.decode('utf-8')
-
-# Download CUDA installer
-download_command = "wget https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"
-result = run_command(download_command)
-if result is None:
- print("Failed to download CUDA installer.")
- exit(1)
-
-# Run CUDA installer in silent mode
-install_command = "sh cuda_12.2.0_535.54.03_linux.run --silent --toolkit --samples --override"
-result = run_command(install_command)
-if result is None:
- print("Failed to run CUDA installer.")
- exit(1)
-
-print("CUDA installation process completed.")
-
-def install_packages():
-
- # Clone the repository with submodules
- run_command("git clone --recurse-submodules https://github.com/abetlen/llama-cpp-python.git")
-
- # Change to the cloned directory
- os.chdir("llama-cpp-python")
-
- # Checkout the specific commit in the llama.cpp submodule
- os.chdir("vendor/llama.cpp")
- run_command("git checkout 50e0535")
- os.chdir("../..")
-
- # Upgrade pip
- run_command("pip install --upgrade pip")
-
-
-
- # Install all optional dependencies with CUDA support
- run_command('CMAKE_ARGS="-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS -DCUDA_PATH=/usr/local/cuda-12.2 -DCUDAToolkit_ROOT=/usr/local/cuda-12.2 -DCUDAToolkit_INCLUDE_DIR=/usr/local/cuda-12/include -DCUDAToolkit_LIBRARY_DIR=/usr/local/cuda-12.2/lib64" FORCE_CMAKE=1 pip install -e .')
-
- run_command("make clean && GGML_OPENBLAS=1 make -j")
-
- # Reinstall the package with CUDA support
- run_command('CMAKE_ARGS="-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS -DCUDA_PATH=/usr/local/cuda-12.2 -DCUDAToolkit_ROOT=/usr/local/cuda-12.2 -DCUDAToolkit_INCLUDE_DIR=/usr/local/cuda-12/include -DCUDAToolkit_LIBRARY_DIR=/usr/local/cuda-12.2/lib64" FORCE_CMAKE=1 pip install -e .')
-
- # Install llama-cpp-agent
- run_command("pip install llama-cpp-agent")
-
- run_command("export PYTHONPATH=$PYTHONPATH:$(pwd)")
-
- print("Installation complete!")
-
-try:
- install_packages()
-
- # Add a delay to allow for package registration
- import time
- time.sleep(5)
-
- # Force Python to reload the site packages
- import site
- import importlib
- importlib.reload(site)
-
- # Now try to import the libraries
- from llama_cpp import Llama
- from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
- from llama_cpp_agent.providers import LlamaCppPythonProvider
- from llama_cpp_agent.chat_history import BasicChatHistory
- from llama_cpp_agent.chat_history.messages import Roles
-
- print("Libraries imported successfully!")
-except Exception as e:
- print(f"Installation failed or libraries couldn't be imported: {str(e)}")
- sys.exit(1)
-
+import requests
+import yaml
+import torch
import gradio as gr
+from PIL import Image
+import sys
+sys.path.append(os.path.abspath('./'))
+from inference.utils import *
+from core.utils import load_or_fail
+from train import WurstCoreB
+from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
+from train import WurstCore_t2i as WurstCoreC
+import torch.nn.functional as F
+import numpy as np
+import random
+import math
+from einops import rearrange
from huggingface_hub import hf_hub_download
-hf_hub_download(
- repo_id="MaziyarPanahi/Mistral-Nemo-Instruct-2407-GGUF",
- filename="Mistral-Nemo-Instruct-2407.Q5_K_M.gguf",
- local_dir="./models"
-)
-# Initialize LLM outside the respond function
-llm = Llama(
- model_path="models/Mistral-Nemo-Instruct-2407.Q5_K_M.gguf",
- flash_attn=True,
- n_gpu_layers=81,
- n_batch=1024,
- n_ctx=32768,
+def download_file(url, folder_path, filename):
+ if not os.path.exists(folder_path):
+ os.makedirs(folder_path)
+ file_path = os.path.join(folder_path, filename)
+
+ if os.path.isfile(file_path):
+ print(f"File already exists: {file_path}")
+ else:
+ response = requests.get(url, stream=True)
+ if response.status_code == 200:
+ with open(file_path, 'wb') as file:
+ for chunk in response.iter_content(chunk_size=1024):
+ file.write(chunk)
+ print(f"File successfully downloaded and saved: {file_path}")
+ else:
+ print(f"Error downloading the file. Status code: {response.status_code}")
+
+def download_models():
+ models = {
+ "STABLEWURST_A": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_a.safetensors?download=true", "models/StableWurst", "stage_a.safetensors"),
+ "STABLEWURST_PREVIEWER": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/previewer.safetensors?download=true", "models/StableWurst", "previewer.safetensors"),
+ "STABLEWURST_EFFNET": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/effnet_encoder.safetensors?download=true", "models/StableWurst", "effnet_encoder.safetensors"),
+ "STABLEWURST_B_LITE": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b_lite_bf16.safetensors?download=true", "models/StableWurst", "stage_b_lite_bf16.safetensors"),
+ "STABLEWURST_C": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c_bf16.safetensors?download=true", "models/StableWurst", "stage_c_bf16.safetensors"),
+ "ULTRAPIXEL_T2I": ("https://huggingface.co/roubaofeipi/UltraPixel/resolve/main/ultrapixel_t2i.safetensors?download=true", "models/UltraPixel", "ultrapixel_t2i.safetensors"),
+ "ULTRAPIXEL_LORA_CAT": ("https://huggingface.co/roubaofeipi/UltraPixel/resolve/main/lora_cat.safetensors?download=true", "models/UltraPixel", "lora_cat.safetensors"),
+ }
+
+ for model, (url, folder, filename) in models.items():
+ download_file(url, folder, filename)
+
+download_models()
+
+# Global variables
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+dtype = torch.bfloat16
+
+# Load configs and setup models
+with open("configs/training/t2i.yaml", "r", encoding="utf-8") as file:
+ config_c = yaml.safe_load(file)
+
+with open("configs/inference/stage_b_1b.yaml", "r", encoding="utf-8") as file:
+ config_b = yaml.safe_load(file)
+
+core = WurstCoreC(config_dict=config_c, device=device, training=False)
+core_b = WurstCoreB(config_dict=config_b, device=device, training=False)
+
+extras = core.setup_extras_pre()
+models = core.setup_models(extras)
+models.generator.eval().requires_grad_(False)
+
+extras_b = core_b.setup_extras_pre()
+models_b = core_b.setup_models(extras_b, skip_clip=True)
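+# Stage B is set up with skip_clip=True; reuse the tokenizer and text encoder already loaded for Stage C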
+models_b = WurstCoreB.Models(
+ **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model}
)
-
-provider = LlamaCppPythonProvider(llm)
-
-@spaces.GPU(duration=120)
-def respond(
- message,
- history: list[tuple[str, str]],
- system_message,
- max_tokens,
- temperature,
- top_p,
- top_k,
- repeat_penalty,
-):
- chat_template = MessagesFormatterType.MISTRAL
-
- agent = LlamaCppAgent(
- provider,
- system_prompt=f"{system_message}",
- predefined_messages_formatter_type=chat_template,
- debug_output=True
- )
+models_b.generator.bfloat16().eval().requires_grad_(False)
+
+# Load pretrained model
+pretrained_path = "models/ultrapixel_t2i.safetensors"
+sdd = torch.load(pretrained_path, map_location='cpu')
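+# Strip the leading "module." prefix (7 characters) that DataParallel/DDP training adds to state-dict keys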
+collect_sd = {k[7:]: v for k, v in sdd.items()}
+models.train_norm.load_state_dict(collect_sd)
+models.generator.eval()
+models.train_norm.eval()
+
+# Set up sampling configurations
+extras.sampling_configs.update({
+ 'cfg': 4,
+ 'shift': 1,
+ 'timesteps': 20,
+ 't_start': 1.0,
+ 'sampler': DDPMSampler(extras.gdf)
+})
+
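+# Stage B uses light guidance (cfg 1.1) and fewer steps; it decodes the Stage C latent rather than generating from scratch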
+extras_b.sampling_configs.update({
+ 'cfg': 1.1,
+ 'shift': 1,
+ 'timesteps': 10,
+ 't_start': 1.0
+})
+
+@spaces.GPU
+def generate_image(prompt, height, width, seed):
+    seed = int(seed)  # gr.Number may deliver the seed as a float; seeding expects an int
+    torch.manual_seed(seed)
+    random.seed(seed)
+    np.random.seed(seed)
+
+ batch_size = 1
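+    # Besides the requested resolution, prepare a low-resolution latent shape; generation_c samples with both shapes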
+ height_lr, width_lr = get_target_lr_size(height / width, std_size=32)
+ stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
+ stage_c_latent_shape_lr, stage_b_latent_shape_lr = calculate_latent_sizes(height_lr, width_lr, batch_size=batch_size)
+
+ batch = {'captions': [prompt] * batch_size}
+ conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False)
+ unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
- settings = provider.get_provider_default_settings()
- settings.temperature = temperature
- settings.top_k = top_k
- settings.top_p = top_p
- settings.max_tokens = max_tokens
- settings.repeat_penalty = repeat_penalty
- settings.stream = True
-
- messages = BasicChatHistory()
-
- for msn in history:
- user = {
- 'role': Roles.user,
- 'content': msn[0]
- }
- assistant = {
- 'role': Roles.assistant,
- 'content': msn[1]
- }
- messages.add_message(user)
- messages.add_message(assistant)
+ conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
+ unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
- stream = agent.get_chat_response(
- message,
- llm_sampling_settings=settings,
- chat_history=messages,
- returns_streaming_generator=True,
- print_output=False
- )
-
- outputs = ""
- for output in stream:
- outputs += output
- yield outputs
-
-description = """
-[Instruct Model]
-[Base Model]
-[GGUF Version]
-
-"""
-
-demo = gr.ChatInterface(
- respond,
- additional_inputs=[
- gr.Textbox(value="You are a helpful assistant.", label="System message"),
- gr.Slider(minimum=1, maximum=4096, value=2048, step=1, label="Max tokens"),
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
- gr.Slider(
- minimum=0.1,
- maximum=1.0,
- value=0.95,
- step=0.05,
- label="Top-p",
- ),
- gr.Slider(
- minimum=0,
- maximum=100,
- value=40,
- step=1,
- label="Top-k",
- ),
- gr.Slider(
- minimum=0.0,
- maximum=2.0,
- value=1.1,
- step=0.1,
- label="Repetition penalty",
- ),
+ with torch.no_grad():
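+        # Keep the Stage C generator on the GPU only while sampling; it is moved back to CPU below to free VRAM for Stage B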
+ models.generator.cuda()
+ with torch.cuda.amp.autocast(dtype=dtype):
+ sampled_c = generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device)
+
+ models.generator.cpu()
+ torch.cuda.empty_cache()
+
+ conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
+ unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
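+        # Condition Stage B on the Stage C output in EfficientNet-latent space; zeros serve as the unconditional branch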
+ conditions_b['effnet'] = sampled_c
+ unconditions_b['effnet'] = torch.zeros_like(sampled_c)
+
+ with torch.cuda.amp.autocast(dtype=dtype):
+ sampled = decode_b(conditions_b, unconditions_b, models_b, stage_b_latent_shape, extras_b, device, stage_a_tiled=True)
+
+ torch.cuda.empty_cache()
+ imgs = show_images(sampled)
+ return imgs[0]
+
+iface = gr.Interface(
+ fn=generate_image,
+ inputs=[
+ gr.Textbox(label="Prompt"),
+ gr.Slider(minimum=256, maximum=2560, step=32, label="Height", value=1024),
+ gr.Slider(minimum=256, maximum=5120, step=32, label="Width", value=1024),
+ gr.Number(label="Seed", value=42)
],
- retry_btn="Retry",
- undo_btn="Undo",
- clear_btn="Clear",
- submit_btn="Send",
- title="Chat with Mistral-NeMo using llama.cpp",
- description=description,
- chatbot=gr.Chatbot(
- scale=1,
- likeable=False,
- show_copy_button=True
- )
+ outputs=gr.Image(type="pil"),
+ title="UltraPixel Image Generation",
+    description="Generate high-resolution images with the UltraPixel model.",
+ theme='bethecloud/storj_theme'
)
-if __name__ == "__main__":
- demo.launch(debug=True)
\ No newline at end of file
+iface.launch()
\ No newline at end of file
diff --git a/configs/inference/controlnet_c_3b_canny.yaml b/configs/inference/controlnet_c_3b_canny.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..286d7a6c8017e922a020d6ae5633cc3e27f9b702
--- /dev/null
+++ b/configs/inference/controlnet_c_3b_canny.yaml
@@ -0,0 +1,14 @@
+# GLOBAL STUFF
+model_version: 3.6B
+dtype: bfloat16
+
+# ControlNet specific
+controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
+controlnet_filter: CannyFilter
+controlnet_filter_params:
+ resize: 224
+
+effnet_checkpoint_path: models/effnet_encoder.safetensors
+previewer_checkpoint_path: models/previewer.safetensors
+generator_checkpoint_path: models/stage_c_bf16.safetensors
+controlnet_checkpoint_path: models/canny.safetensors
diff --git a/configs/inference/controlnet_c_3b_identity.yaml b/configs/inference/controlnet_c_3b_identity.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8a20fa860fed5f6eea1d33113535c2633205e327
--- /dev/null
+++ b/configs/inference/controlnet_c_3b_identity.yaml
@@ -0,0 +1,17 @@
+# GLOBAL STUFF
+model_version: 3.6B
+dtype: bfloat16
+
+# ControlNet specific
+controlnet_bottleneck_mode: 'simple'
+controlnet_blocks: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]
+controlnet_filter: IdentityFilter
+controlnet_filter_params:
+ max_faces: 4
+ p_drop: 0.00
+ p_full: 0.0
+
+effnet_checkpoint_path: models/effnet_encoder.safetensors
+previewer_checkpoint_path: models/previewer.safetensors
+generator_checkpoint_path: models/stage_c_bf16.safetensors
+controlnet_checkpoint_path:
diff --git a/configs/inference/controlnet_c_3b_inpainting.yaml b/configs/inference/controlnet_c_3b_inpainting.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a94bd7953dfa407184d9094b481a56cdbbb73549
--- /dev/null
+++ b/configs/inference/controlnet_c_3b_inpainting.yaml
@@ -0,0 +1,15 @@
+# GLOBAL STUFF
+model_version: 3.6B
+dtype: bfloat16
+
+# ControlNet specific
+controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
+controlnet_filter: InpaintFilter
+controlnet_filter_params:
+ thresold: [0.04, 0.4]
+ p_outpaint: 0.4
+
+effnet_checkpoint_path: models/effnet_encoder.safetensors
+previewer_checkpoint_path: models/previewer.safetensors
+generator_checkpoint_path: models/stage_c_bf16.safetensors
+controlnet_checkpoint_path: models/inpainting.safetensors
diff --git a/configs/inference/controlnet_c_3b_sr.yaml b/configs/inference/controlnet_c_3b_sr.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..13c4a2cd2dcd2a3cf87fb32bd6e34269e796a747
--- /dev/null
+++ b/configs/inference/controlnet_c_3b_sr.yaml
@@ -0,0 +1,15 @@
+# GLOBAL STUFF
+model_version: 3.6B
+dtype: bfloat16
+
+# ControlNet specific
+controlnet_bottleneck_mode: 'large'
+controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
+controlnet_filter: SREffnetFilter
+controlnet_filter_params:
+ scale_factor: 0.5
+
+effnet_checkpoint_path: models/effnet_encoder.safetensors
+previewer_checkpoint_path: models/previewer.safetensors
+generator_checkpoint_path: models/stage_c_bf16.safetensors
+controlnet_checkpoint_path: models/super_resolution.safetensors
diff --git a/configs/inference/lora_c_3b.yaml b/configs/inference/lora_c_3b.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7468078c657c1f569c6c052a14b265d69082ab25
--- /dev/null
+++ b/configs/inference/lora_c_3b.yaml
@@ -0,0 +1,15 @@
+# GLOBAL STUFF
+model_version: 3.6B
+dtype: bfloat16
+
+# LoRA specific
+module_filters: ['.attn']
+rank: 4
+train_tokens:
+ # - ['^snail', null] # token starts with "snail" -> "snail" & "snails", don't need to be reinitialized
+  - ['[fernando]', '^dog'] # custom token [fernando], initialized as the avg of tokens starting with "dog"
+
+effnet_checkpoint_path: models/effnet_encoder.safetensors
+previewer_checkpoint_path: models/previewer.safetensors
+generator_checkpoint_path: models/stage_c_bf16.safetensors
+lora_checkpoint_path: models/lora_fernando_10k.safetensors
diff --git a/configs/inference/stage_b_1b.yaml b/configs/inference/stage_b_1b.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0811cae75622614e91de6532262acb2c062bf344
--- /dev/null
+++ b/configs/inference/stage_b_1b.yaml
@@ -0,0 +1,13 @@
+# GLOBAL STUFF
+model_version: 700M
+dtype: bfloat16
+
+# For demonstration purposes in reconstruct_images.ipynb
+webdataset_path: path to your dataset
+batch_size: 1
+image_size: 2048
+grad_accum_steps: 1
+
+effnet_checkpoint_path: models/effnet_encoder.safetensors
+stage_a_checkpoint_path: models/stage_a.safetensors
+generator_checkpoint_path: models/stage_b_lite_bf16.safetensors
\ No newline at end of file
diff --git a/configs/inference/stage_b_3b.yaml b/configs/inference/stage_b_3b.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..840268980103e0c629599b966705043d6a616578
--- /dev/null
+++ b/configs/inference/stage_b_3b.yaml
@@ -0,0 +1,13 @@
+# GLOBAL STUFF
+model_version: 3B
+dtype: bfloat16
+
+# For demonstration purposes in reconstruct_images.ipynb
+webdataset_path: path to your dataset
+batch_size: 4
+image_size: 1024
+grad_accum_steps: 1
+
+effnet_checkpoint_path: models/effnet_encoder.safetensors
+stage_a_checkpoint_path: models/stage_a.safetensors
+generator_checkpoint_path: models/stage_b_lite_bf16.safetensors
\ No newline at end of file
diff --git a/configs/inference/stage_c_1b.yaml b/configs/inference/stage_c_1b.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..781886e515d80e7870abb89bf8fd0ce7c7c8d4b6
--- /dev/null
+++ b/configs/inference/stage_c_1b.yaml
@@ -0,0 +1,7 @@
+# GLOBAL STUFF
+model_version: 1B
+dtype: bfloat16
+
+effnet_checkpoint_path: models/effnet_encoder.safetensors
+previewer_checkpoint_path: models/previewer.safetensors
+generator_checkpoint_path: models/stage_c_lite_bf16.safetensors
\ No newline at end of file
diff --git a/configs/inference/stage_c_3b.yaml b/configs/inference/stage_c_3b.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b22897e71996ad78f3832af78f5bc44ca06d206d
--- /dev/null
+++ b/configs/inference/stage_c_3b.yaml
@@ -0,0 +1,7 @@
+# GLOBAL STUFF
+model_version: 3.6B
+dtype: bfloat16
+
+effnet_checkpoint_path: models/effnet_encoder.safetensors
+previewer_checkpoint_path: models/previewer.safetensors
+generator_checkpoint_path: models/stage_c_bf16.safetensors
\ No newline at end of file
diff --git a/configs/training/cfg_control_lr.yaml b/configs/training/cfg_control_lr.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2955b6a925504525b981e7004b65a33573c08aef
--- /dev/null
+++ b/configs/training/cfg_control_lr.yaml
@@ -0,0 +1,47 @@
+# GLOBAL STUFF
+experiment_id: Ultrapixel_controlnet
+
+checkpoint_path: checkpoint output path
+output_path: visual results output path
+model_version: 3.6B
+dtype: float32
+# # WandB
+# wandb_project: StableCascade
+# wandb_entity: wandb_username
+#module_filters: ['.depthwise', '.mapper', '.attn', '.channelwise' ]
+#rank: 32
+# TRAINING PARAMS
+lr: 1.0e-4
+batch_size: 12
+#image_size: [1536, 2048, 2560, 3072, 4096]
+image_size: [1024, 2048, 2560, 3072, 3584, 3840, 4096, 4608]
+#image_size: [ 1024, 1536, 2048, 2560, 3072, 3584, 3840, 4096, 4608]
+#image_size: [ 1024, 1280]
+multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
+grad_accum_steps: 2
+updates: 40000
+backup_every: 5000
+save_every: 256
+warmup_updates: 1
+use_fsdp: True
+
+# ControlNet specific
+controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63]
+controlnet_filter: CannyFilter
+controlnet_filter_params:
+ resize: 224
+# offset_noise: 0.1
+
+# GDF
+adaptive_loss_weight: True
+
+ema_start_iters: 10
+ema_iters: 50
+ema_beta: 0.9
+
+webdataset_path: path to your training dataset
+effnet_checkpoint_path: models/effnet_encoder.safetensors
+previewer_checkpoint_path: models/previewer.safetensors
+generator_checkpoint_path: models/stage_c_bf16.safetensors
+controlnet_checkpoint_path: pretrained controlnet path
+
diff --git a/configs/training/lora_personalization.yaml b/configs/training/lora_personalization.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..857795e6d37e9cb61bd76aa588f432978ed90ad2
--- /dev/null
+++ b/configs/training/lora_personalization.yaml
@@ -0,0 +1,37 @@
+# GLOBAL STUFF
+experiment_id: roubao_cat_personalized
+
+checkpoint_path: checkpoint output path
+output_path: visual results output path
+model_version: 3.6B
+dtype: float32
+
+module_filters: [ '.attn']
+rank: 4
+train_tokens:
+ # - ['^snail', null] # token starts with "snail" -> "snail" & "snails", don't need to be reinitialized
+  - ['[roubaobao]', '^cat'] # custom token [roubaobao], initialized as the avg of tokens starting with "cat"
+# TRAINING PARAMS
+lr: 1.0e-4
+batch_size: 4
+
+image_size: [1024, 2048, 2560, 3072, 3584, 3840, 4096, 4608]
+multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
+grad_accum_steps: 2
+updates: 40000
+backup_every: 5000
+save_every: 512
+warmup_updates: 1
+use_ddp: True
+
+# GDF
+adaptive_loss_weight: True
+
+
+tmp_prompt: a photo of a cat [roubaobao]
+webdataset_path: path to your personalized training dataset
+effnet_checkpoint_path: models/effnet_encoder.safetensors
+previewer_checkpoint_path: models/previewer.safetensors
+generator_checkpoint_path: models/stage_c_bf16.safetensors
+ultrapixel_path: models/ultrapixel_t2i.safetensors
+
diff --git a/configs/training/t2i.yaml b/configs/training/t2i.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8a0ceaca0ad8813e3c9b998661ac3e9b3c0937fd
--- /dev/null
+++ b/configs/training/t2i.yaml
@@ -0,0 +1,29 @@
+# GLOBAL STUFF
+experiment_id: ultrapixel_t2i
+#strc_fixlrt_norm3_lite_1024_hrft_newdata
+checkpoint_path: checkpoint output path #output model directory
+output_path: visual results output path #experiment output directory
+model_version: 3.6B # finetune the large Stage C model of StableCascade
+dtype: float32
+
+
+# TRAINING PARAMS
+lr: 1.0e-4
+batch_size: 4 # gpu_number * num_per_gpu * grad_accum_steps
+image_size: [1024, 2048, 2560, 3072, 3584, 3840, 4096, 4608] # possible image resolution
+multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
+grad_accum_steps: 2
+updates: 40000
+backup_every: 5000
+save_every: 256
+warmup_updates: 1
+use_ddp: True
+
+# GDF
+adaptive_loss_weight: True
+
+
+webdataset_path: path to your personalized training dataset
+effnet_checkpoint_path: models/effnet_encoder.safetensors
+previewer_checkpoint_path: models/previewer.safetensors
+generator_checkpoint_path: models/stage_c_bf16.safetensors
\ No newline at end of file
diff --git a/core/__init__.py b/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed382f1907ddc86c7e9a9618c21441755a6221a9
--- /dev/null
+++ b/core/__init__.py
@@ -0,0 +1,372 @@
+import os
+import yaml
+import torch
+from torch import nn
+import wandb
+import json
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from torch.utils.data import Dataset, DataLoader
+
+from torch.distributed import init_process_group, destroy_process_group, barrier
+from torch.distributed.fsdp import (
+ FullyShardedDataParallel as FSDP,
+ FullStateDictConfig,
+ MixedPrecision,
+ ShardingStrategy,
+ StateDictType
+)
+
+from .utils import Base, EXPECTED, EXPECTED_TRAIN
+from .utils import create_folder_if_necessary, safe_save, load_or_fail
+
+# pylint: disable=unused-argument
+class WarpCore(ABC):
+ @dataclass(frozen=True)
+ class Config(Base):
+ experiment_id: str = EXPECTED_TRAIN
+ checkpoint_path: str = EXPECTED_TRAIN
+ output_path: str = EXPECTED_TRAIN
+ checkpoint_extension: str = "safetensors"
+ dist_file_subfolder: str = ""
+ allow_tf32: bool = True
+
+ wandb_project: str = None
+ wandb_entity: str = None
+
+ @dataclass() # not frozen, means that fields are mutable
+ class Info(): # not inheriting from Base, because we don't want to enforce the default fields
+ wandb_run_id: str = None
+ total_steps: int = 0
+ iter: int = 0
+
+ @dataclass(frozen=True)
+ class Data(Base):
+ dataset: Dataset = EXPECTED
+ dataloader: DataLoader = EXPECTED
+ iterator: any = EXPECTED
+
+ @dataclass(frozen=True)
+ class Models(Base):
+ pass
+
+ @dataclass(frozen=True)
+ class Optimizers(Base):
+ pass
+
+ @dataclass(frozen=True)
+ class Schedulers(Base):
+ pass
+
+ @dataclass(frozen=True)
+ class Extras(Base):
+ pass
+ # ---------------------------------------
+ info: Info
+ config: Config
+
+ # FSDP stuff
+ fsdp_defaults = {
+ "sharding_strategy": ShardingStrategy.SHARD_GRAD_OP,
+ "cpu_offload": None,
+ "mixed_precision": MixedPrecision(
+ param_dtype=torch.bfloat16,
+ reduce_dtype=torch.bfloat16,
+ buffer_dtype=torch.bfloat16,
+ ),
+ "limit_all_gathers": True,
+ }
+ fsdp_fullstate_save_policy = FullStateDictConfig(
+ offload_to_cpu=True, rank0_only=True
+ )
+ # ------------
+
+ # OVERRIDEABLE METHODS
+
+ # [optionally] setup extra stuff, will be called BEFORE the models & optimizers are setup
+ def setup_extras_pre(self) -> Extras:
+ return self.Extras()
+
+    # setup dataset & dataloader, return a DTO containing the dataset, dataloader and/or iterator
+ @abstractmethod
+ def setup_data(self, extras: Extras) -> Data:
+        raise NotImplementedError("This method needs to be overridden")
+
+ # return a dict with all models that are going to be used in the training
+ @abstractmethod
+ def setup_models(self, extras: Extras) -> Models:
+        raise NotImplementedError("This method needs to be overridden")
+
+ # return a dict with all optimizers that are going to be used in the training
+ @abstractmethod
+ def setup_optimizers(self, extras: Extras, models: Models) -> Optimizers:
+        raise NotImplementedError("This method needs to be overridden")
+
+ # [optionally] return a dict with all schedulers that are going to be used in the training
+ def setup_schedulers(self, extras: Extras, models: Models, optimizers: Optimizers) -> Schedulers:
+ return self.Schedulers()
+
+ # [optionally] setup extra stuff, will be called AFTER the models & optimizers are setup
+ def setup_extras_post(self, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers) -> Extras:
+ return self.Extras.from_dict(extras.to_dict())
+
+ # perform the training here
+ @abstractmethod
+ def train(self, data: Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers):
+        raise NotImplementedError("This method needs to be overridden")
+ # ------------
+
+ def setup_info(self, full_path=None) -> Info:
+ if full_path is None:
+ full_path = (f"{self.config.checkpoint_path}/{self.config.experiment_id}/info.json")
+ info_dict = load_or_fail(full_path, wandb_run_id=None) or {}
+ info_dto = self.Info(**info_dict)
+ if info_dto.total_steps > 0 and self.is_main_node:
+ print(">>> RESUMING TRAINING FROM ITER ", info_dto.total_steps)
+ return info_dto
+
+ def setup_config(self, config_file_path=None, config_dict=None, training=True) -> Config:
+ if config_file_path is not None:
+ if config_file_path.endswith(".yml") or config_file_path.endswith(".yaml"):
+ with open(config_file_path, "r", encoding="utf-8") as file:
+ loaded_config = yaml.safe_load(file)
+ elif config_file_path.endswith(".json"):
+ with open(config_file_path, "r", encoding="utf-8") as file:
+ loaded_config = json.load(file)
+ else:
+ raise ValueError("Config file must be either a .yml|.yaml or .json file")
+ return self.Config.from_dict({**loaded_config, 'training': training})
+ if config_dict is not None:
+ return self.Config.from_dict({**config_dict, 'training': training})
+ return self.Config(training=training)
+
+ def setup_ddp(self, experiment_id, single_gpu=False):
+ if not single_gpu:
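+            # Rank and world size come from SLURM environment variables, so multi-GPU runs assume a SLURM launcher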
+ local_rank = int(os.environ.get("SLURM_LOCALID"))
+ process_id = int(os.environ.get("SLURM_PROCID"))
+ world_size = int(os.environ.get("SLURM_NNODES")) * torch.cuda.device_count()
+
+ self.process_id = process_id
+ self.is_main_node = process_id == 0
+ self.device = torch.device(local_rank)
+ self.world_size = world_size
+
+ dist_file_path = f"{os.getcwd()}/{self.config.dist_file_subfolder}dist_file_{experiment_id}"
+ # if os.path.exists(dist_file_path) and self.is_main_node:
+ # os.remove(dist_file_path)
+
+ torch.cuda.set_device(local_rank)
+ init_process_group(
+ backend="nccl",
+ rank=process_id,
+ world_size=world_size,
+ init_method=f"file://{dist_file_path}",
+ )
+ print(f"[GPU {process_id}] READY")
+ else:
+ print("Running in single thread, DDP not enabled.")
+
+ def setup_wandb(self):
+ if self.is_main_node and self.config.wandb_project is not None:
+ self.info.wandb_run_id = self.info.wandb_run_id or wandb.util.generate_id()
+ wandb.init(project=self.config.wandb_project, entity=self.config.wandb_entity, name=self.config.experiment_id, id=self.info.wandb_run_id, resume="allow", config=self.config.to_dict())
+
+ if self.info.total_steps > 0:
+ wandb.alert(title=f"Training {self.info.wandb_run_id} resumed", text=f"Training {self.info.wandb_run_id} resumed from step {self.info.total_steps}")
+ else:
+ wandb.alert(title=f"Training {self.info.wandb_run_id} started", text=f"Training {self.info.wandb_run_id} started")
+
+ # LOAD UTILITIES ----------
+ def load_model(self, model, model_id=None, full_path=None, strict=True):
+ print('in line 181 load model', type(model), model_id, full_path, strict)
+ if model_id is not None and full_path is None:
+ full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{model_id}.{self.config.checkpoint_extension}"
+ elif full_path is None and model_id is None:
+ raise ValueError(
+ "This method expects either 'model_id' or 'full_path' to be defined"
+ )
+
+ checkpoint = load_or_fail(full_path, wandb_run_id=self.info.wandb_run_id if self.is_main_node else None)
+ if checkpoint is not None:
+ model.load_state_dict(checkpoint, strict=strict)
+ del checkpoint
+
+ return model
+
+ def load_optimizer(self, optim, optim_id=None, full_path=None, fsdp_model=None):
+ if optim_id is not None and full_path is None:
+ full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{optim_id}.pt"
+ elif full_path is None and optim_id is None:
+ raise ValueError(
+ "This method expects either 'optim_id' or 'full_path' to be defined"
+ )
+
+ checkpoint = load_or_fail(full_path, wandb_run_id=self.info.wandb_run_id if self.is_main_node else None)
+ if checkpoint is not None:
+ try:
+ if fsdp_model is not None:
+ sharded_optimizer_state_dict = (
+ FSDP.scatter_full_optim_state_dict( # <---- FSDP
+ checkpoint
+ if (
+ self.is_main_node
+ or self.fsdp_defaults["sharding_strategy"]
+ == ShardingStrategy.NO_SHARD
+ )
+ else None,
+ fsdp_model,
+ )
+ )
+ optim.load_state_dict(sharded_optimizer_state_dict)
+ del checkpoint, sharded_optimizer_state_dict
+ else:
+ optim.load_state_dict(checkpoint)
+ # pylint: disable=broad-except
+ except Exception as e:
+ print("!!! Failed loading optimizer, skipping... Exception:", e)
+
+ return optim
+
+ # SAVE UTILITIES ----------
+ def save_info(self, info, suffix=""):
+ full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/info{suffix}.json"
+ create_folder_if_necessary(full_path)
+ if self.is_main_node:
+ safe_save(vars(self.info), full_path)
+
+ def save_model(self, model, model_id=None, full_path=None, is_fsdp=False):
+ if model_id is not None and full_path is None:
+ full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{model_id}.{self.config.checkpoint_extension}"
+ elif full_path is None and model_id is None:
+ raise ValueError(
+ "This method expects either 'model_id' or 'full_path' to be defined"
+ )
+ create_folder_if_necessary(full_path)
+ if is_fsdp:
+ with FSDP.summon_full_params(model):
+ pass
+ with FSDP.state_dict_type(
+ model, StateDictType.FULL_STATE_DICT, self.fsdp_fullstate_save_policy
+ ):
+ checkpoint = model.state_dict()
+ if self.is_main_node:
+ safe_save(checkpoint, full_path)
+ del checkpoint
+ else:
+ if self.is_main_node:
+ checkpoint = model.state_dict()
+ safe_save(checkpoint, full_path)
+ del checkpoint
+
+ def save_optimizer(self, optim, optim_id=None, full_path=None, fsdp_model=None):
+ if optim_id is not None and full_path is None:
+ full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{optim_id}.pt"
+ elif full_path is None and optim_id is None:
+ raise ValueError(
+ "This method expects either 'optim_id' or 'full_path' to be defined"
+ )
+ create_folder_if_necessary(full_path)
+ if fsdp_model is not None:
+ optim_statedict = FSDP.full_optim_state_dict(fsdp_model, optim)
+ if self.is_main_node:
+ safe_save(optim_statedict, full_path)
+ del optim_statedict
+ else:
+ if self.is_main_node:
+ checkpoint = optim.state_dict()
+ safe_save(checkpoint, full_path)
+ del checkpoint
+ # -----
+
+ def __init__(self, config_file_path=None, config_dict=None, device="cpu", training=True):
+        # Temporary setup, will be overridden by setup_ddp if required
+ self.device = device
+ self.process_id = 0
+ self.is_main_node = True
+ self.world_size = 1
+ # ----
+
+ self.config: self.Config = self.setup_config(config_file_path, config_dict, training)
+ self.info: self.Info = self.setup_info()
+
+ def __call__(self, single_gpu=False):
+ self.setup_ddp(self.config.experiment_id, single_gpu=single_gpu) # this will change the device to the CUDA rank
+ self.setup_wandb()
+ if self.config.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+
+ if self.is_main_node:
+ print()
+            print("**STARTING JOB WITH CONFIG:**")
+ print(yaml.dump(self.config.to_dict(), default_flow_style=False))
+ print("------------------------------------")
+ print()
+ print("**INFO:**")
+ print(yaml.dump(vars(self.info), default_flow_style=False))
+ print("------------------------------------")
+ print()
+
+ # SETUP STUFF
+ extras = self.setup_extras_pre()
+ assert extras is not None, "setup_extras_pre() must return a DTO"
+
+ data = self.setup_data(extras)
+ assert data is not None, "setup_data() must return a DTO"
+ if self.is_main_node:
+ print("**DATA:**")
+ print(yaml.dump({k:type(v).__name__ for k, v in data.to_dict().items()}, default_flow_style=False))
+ print("------------------------------------")
+ print()
+
+ models = self.setup_models(extras)
+ assert models is not None, "setup_models() must return a DTO"
+ if self.is_main_node:
+ print("**MODELS:**")
+ print(yaml.dump({
+ k:f"{type(v).__name__} - {f'trainable params {sum(p.numel() for p in v.parameters() if p.requires_grad)}' if isinstance(v, nn.Module) else 'Not a nn.Module'}" for k, v in models.to_dict().items()
+ }, default_flow_style=False))
+ print("------------------------------------")
+ print()
+
+ optimizers = self.setup_optimizers(extras, models)
+ assert optimizers is not None, "setup_optimizers() must return a DTO"
+ if self.is_main_node:
+ print("**OPTIMIZERS:**")
+ print(yaml.dump({k:type(v).__name__ for k, v in optimizers.to_dict().items()}, default_flow_style=False))
+ print("------------------------------------")
+ print()
+
+ schedulers = self.setup_schedulers(extras, models, optimizers)
+ assert schedulers is not None, "setup_schedulers() must return a DTO"
+ if self.is_main_node:
+ print("**SCHEDULERS:**")
+ print(yaml.dump({k:type(v).__name__ for k, v in schedulers.to_dict().items()}, default_flow_style=False))
+ print("------------------------------------")
+ print()
+
+        post_extras = self.setup_extras_post(extras, models, optimizers, schedulers)
+        assert post_extras is not None, "setup_extras_post() must return a DTO"
+        extras = self.Extras.from_dict({**extras.to_dict(), **post_extras.to_dict()})
+ if self.is_main_node:
+ print("**EXTRAS:**")
+ print(yaml.dump({k:f"{v}" for k, v in extras.to_dict().items()}, default_flow_style=False))
+ print("------------------------------------")
+ print()
+ # -------
+
+ # TRAIN
+ if self.is_main_node:
+ print("**TRAINING STARTING...**")
+ self.train(data, extras, models, optimizers, schedulers)
+
+ if single_gpu is False:
+ barrier()
+ destroy_process_group()
+ if self.is_main_node:
+ print()
+ print("------------------------------------")
+ print()
+ print("**TRAINING COMPLETE**")
+ if self.config.wandb_project is not None:
+ wandb.alert(title=f"Training {self.info.wandb_run_id} finished", text=f"Training {self.info.wandb_run_id} finished")
diff --git a/core/data/__init__.py b/core/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b687719914b2e303909f7c280347e4bdee607d13
--- /dev/null
+++ b/core/data/__init__.py
@@ -0,0 +1,69 @@
+import json
+import subprocess
+import yaml
+import os
+from .bucketeer import Bucketeer
+
+class MultiFilter():
+ def __init__(self, rules, default=False):
+ self.rules = rules
+ self.default = default
+
+ def __call__(self, x):
+ try:
+ x_json = x['json']
+ if isinstance(x_json, bytes):
+ x_json = json.loads(x_json)
+ validations = []
+ for k, r in self.rules.items():
+ if isinstance(k, tuple):
+ v = r(*[x_json[kv] for kv in k])
+ else:
+ v = r(x_json[k])
+ validations.append(v)
+ return all(validations)
+ except Exception:
+ return False
+
+class MultiGetter():
+ def __init__(self, rules):
+ self.rules = rules
+
+ def __call__(self, x_json):
+ if isinstance(x_json, bytes):
+ x_json = json.loads(x_json)
+ outputs = []
+ for k, r in self.rules.items():
+ if isinstance(k, tuple):
+ v = r(*[x_json[kv] for kv in k])
+ else:
+ v = r(x_json[k])
+ outputs.append(v)
+ if len(outputs) == 1:
+ outputs = outputs[0]
+ return outputs
+
+def setup_webdataset_path(paths, cache_path=None):
+ if cache_path is None or not os.path.exists(cache_path):
+ tar_paths = []
+ if isinstance(paths, str):
+ paths = [paths]
+ for path in paths:
+ if path.strip().endswith(".tar"):
+ # Avoid looking up s3 if we already have a tar file
+ tar_paths.append(path)
+ continue
+ bucket = "/".join(path.split("/")[:3])
+ result = subprocess.run([f"aws s3 ls {path} --recursive | awk '{{print $4}}'"], stdout=subprocess.PIPE, shell=True, check=True)
+ files = result.stdout.decode('utf-8').split()
+ files = [f"{bucket}/{f}" for f in files if f.endswith(".tar")]
+ tar_paths += files
+
+ with open(cache_path, 'w', encoding='utf-8') as outfile:
+ yaml.dump(tar_paths, outfile, default_flow_style=False)
+ else:
+ with open(cache_path, 'r', encoding='utf-8') as file:
+ tar_paths = yaml.safe_load(file)
+
+ tar_paths_str = ",".join([f"{p}" for p in tar_paths])
+ return f"pipe:aws s3 cp {{ {tar_paths_str} }} -"
diff --git a/core/data/bucketeer.py b/core/data/bucketeer.py
new file mode 100644
index 0000000000000000000000000000000000000000..131e6ba4293bd7c00399f08609aba184b712d5e8
--- /dev/null
+++ b/core/data/bucketeer.py
@@ -0,0 +1,88 @@
+import torch
+import torchvision
+import numpy as np
+from torchtools.transforms import SmartCrop
+import math
+
+class Bucketeer():
+ def __init__(self, dataloader, density=256*256, factor=8, ratios=[1/1, 1/2, 3/4, 3/5, 4/5, 6/9, 9/16], reverse_list=True, randomize_p=0.3, randomize_q=0.2, crop_mode='random', p_random_ratio=0.0, interpolate_nearest=False):
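+        # `density` is iterated over below, so pass a list/tuple of target pixel areas (each gets its own set of bucket sizes)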
+ assert crop_mode in ['center', 'random', 'smart']
+ self.crop_mode = crop_mode
+ self.ratios = ratios
+ if reverse_list:
+ for r in list(ratios):
+ if 1/r not in self.ratios:
+ self.ratios.append(1/r)
+ self.sizes = {}
+ for dd in density:
+ self.sizes[dd]= [(int(((dd/r)**0.5//factor)*factor), int(((dd*r)**0.5//factor)*factor)) for r in ratios]
+
+ self.batch_size = dataloader.batch_size
+ self.iterator = iter(dataloader)
+ all_sizes = []
+ for k, vs in self.sizes.items():
+ all_sizes += vs
+ self.buckets = {s: [] for s in all_sizes}
+ self.smartcrop = SmartCrop(int(density**0.5), randomize_p, randomize_q) if self.crop_mode=='smart' else None
+ self.p_random_ratio = p_random_ratio
+ self.interpolate_nearest = interpolate_nearest
+
+ def get_available_batch(self):
+ for b in self.buckets:
+ if len(self.buckets[b]) >= self.batch_size:
+ batch = self.buckets[b][:self.batch_size]
+ self.buckets[b] = self.buckets[b][self.batch_size:]
+ return batch
+ return None
+
+ def get_closest_size(self, x):
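+        # Pick the aspect-ratio bucket closest to the image's w/h, then the density whose bucket area best matches the image area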
+ w, h = x.size(-1), x.size(-2)
+
+
+ best_size_idx = np.argmin([abs(w/h-r) for r in self.ratios])
+ find_dict = {dd : abs(w*h - self.sizes[dd][best_size_idx][0]*self.sizes[dd][best_size_idx][1]) for dd, vv in self.sizes.items()}
+ min_ = find_dict[list(find_dict.keys())[0]]
+ find_size = self.sizes[list(find_dict.keys())[0]][best_size_idx]
+ for dd, val in find_dict.items():
+ if val < min_:
+ min_ = val
+ find_size = self.sizes[dd][best_size_idx]
+
+ return find_size
+
+ def get_resize_size(self, orig_size, tgt_size):
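+        # Choose the edge length passed to torchvision resize so the resized image is large enough to be cropped to tgt_size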
+ if (tgt_size[1]/tgt_size[0] - 1) * (orig_size[1]/orig_size[0] - 1) >= 0:
+ alt_min = int(math.ceil(max(tgt_size)*min(orig_size)/max(orig_size)))
+ resize_size = max(alt_min, min(tgt_size))
+ else:
+ alt_max = int(math.ceil(min(tgt_size)*max(orig_size)/min(orig_size)))
+ resize_size = max(alt_max, max(tgt_size))
+
+ return resize_size
+
+ def __next__(self):
+ batch = self.get_available_batch()
+ while batch is None:
+ elements = next(self.iterator)
+ for dct in elements:
+ img = dct['images']
+ size = self.get_closest_size(img)
+ resize_size = self.get_resize_size(img.shape[-2:], size)
+
+ if self.interpolate_nearest:
+ img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST)
+ else:
+ img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True)
+ if self.crop_mode == 'center':
+ img = torchvision.transforms.functional.center_crop(img, size)
+ elif self.crop_mode == 'random':
+ img = torchvision.transforms.RandomCrop(size)(img)
+ elif self.crop_mode == 'smart':
+ self.smartcrop.output_size = size
+ img = self.smartcrop(img)
+
+ self.buckets[size].append({**{'images': img}, **{k:dct[k] for k in dct if k != 'images'}})
+ batch = self.get_available_batch()
+
+ out = {k:[batch[i][k] for i in range(len(batch))] for k in batch[0]}
+ return {k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()}
diff --git a/core/data/bucketeer_deg.py b/core/data/bucketeer_deg.py
new file mode 100644
index 0000000000000000000000000000000000000000..7206ccf08932f617abb811221cc7bbe1d126f184
--- /dev/null
+++ b/core/data/bucketeer_deg.py
@@ -0,0 +1,91 @@
+import torch
+import torchvision
+import numpy as np
+from torchtools.transforms import SmartCrop
+import math
+
+class Bucketeer():
+ def __init__(self, dataloader, density=256*256, factor=8, ratios=[1/1, 1/2, 3/4, 3/5, 4/5, 6/9, 9/16], reverse_list=True, randomize_p=0.3, randomize_q=0.2, crop_mode='random', p_random_ratio=0.0, interpolate_nearest=False):
+ assert crop_mode in ['center', 'random', 'smart']
+ self.crop_mode = crop_mode
+ self.ratios = ratios
+ if reverse_list:
+ for r in list(ratios):
+ if 1/r not in self.ratios:
+ self.ratios.append(1/r)
+ self.sizes = {}
+ for dd in density:
+ self.sizes[dd]= [(int(((dd/r)**0.5//factor)*factor), int(((dd*r)**0.5//factor)*factor)) for r in ratios]
+ print('in line 17 buckteer', self.sizes)
+ self.batch_size = dataloader.batch_size
+ self.iterator = iter(dataloader)
+ all_sizes = []
+ for k, vs in self.sizes.items():
+ all_sizes += vs
+ self.buckets = {s: [] for s in all_sizes}
+ self.smartcrop = SmartCrop(int(density**0.5), randomize_p, randomize_q) if self.crop_mode=='smart' else None
+ self.p_random_ratio = p_random_ratio
+ self.interpolate_nearest = interpolate_nearest
+
+ def get_available_batch(self):
+ for b in self.buckets:
+ if len(self.buckets[b]) >= self.batch_size:
+ batch = self.buckets[b][:self.batch_size]
+ self.buckets[b] = self.buckets[b][self.batch_size:]
+ return batch
+ return None
+
+ def get_closest_size(self, x):
+ w, h = x.size(-1), x.size(-2)
+ #if self.p_random_ratio > 0 and np.random.rand() < self.p_random_ratio:
+ # best_size_idx = np.random.randint(len(self.ratios))
+ #print('in line 41 get closes size', best_size_idx, x.shape, self.p_random_ratio)
+ #else:
+
+ best_size_idx = np.argmin([abs(w/h-r) for r in self.ratios])
+ find_dict = {dd : abs(w*h - self.sizes[dd][best_size_idx][0]*self.sizes[dd][best_size_idx][1]) for dd, vv in self.sizes.items()}
+ min_ = find_dict[list(find_dict.keys())[0]]
+ find_size = self.sizes[list(find_dict.keys())[0]][best_size_idx]
+ for dd, val in find_dict.items():
+ if val < min_:
+ min_ = val
+ find_size = self.sizes[dd][best_size_idx]
+
+ return find_size
+
+ def get_resize_size(self, orig_size, tgt_size):
+ if (tgt_size[1]/tgt_size[0] - 1) * (orig_size[1]/orig_size[0] - 1) >= 0:
+ alt_min = int(math.ceil(max(tgt_size)*min(orig_size)/max(orig_size)))
+ resize_size = max(alt_min, min(tgt_size))
+ else:
+ alt_max = int(math.ceil(min(tgt_size)*max(orig_size)/min(orig_size)))
+ resize_size = max(alt_max, max(tgt_size))
+ #print('in line 50', orig_size, tgt_size, resize_size)
+ return resize_size
+
+ def __next__(self):
+ batch = self.get_available_batch()
+ while batch is None:
+ elements = next(self.iterator)
+ for dct in elements:
+ img = dct['images']
+ size = self.get_closest_size(img)
+ resize_size = self.get_resize_size(img.shape[-2:], size)
+ #print('in line 74', img.size(), resize_size)
+ if self.interpolate_nearest:
+ img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST)
+ else:
+ img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True)
+ if self.crop_mode == 'center':
+ img = torchvision.transforms.functional.center_crop(img, size)
+ elif self.crop_mode == 'random':
+ img = torchvision.transforms.RandomCrop(size)(img)
+ elif self.crop_mode == 'smart':
+ self.smartcrop.output_size = size
+ img = self.smartcrop(img)
+ print('in line 86 bucketeer', type(img), img.shape, torch.max(img), torch.min(img))
+ self.buckets[size].append({**{'images': img}, **{k:dct[k] for k in dct if k != 'images'}})
+ batch = self.get_available_batch()
+
+ out = {k:[batch[i][k] for i in range(len(batch))] for k in batch[0]}
+ return {k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()}
diff --git a/core/data/deg_kair_utils/utils_alignfaces.py b/core/data/deg_kair_utils/utils_alignfaces.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa74e8a2e8984f5075d0cbd06afd494c9661a015
--- /dev/null
+++ b/core/data/deg_kair_utils/utils_alignfaces.py
@@ -0,0 +1,263 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Mon Apr 24 15:43:29 2017
+@author: zhaoy
+"""
+import cv2
+import numpy as np
+from skimage import transform as trans
+
+# reference facial points, a list of coordinates (x,y)
+REFERENCE_FACIAL_POINTS = [
+ [30.29459953, 51.69630051],
+ [65.53179932, 51.50139999],
+ [48.02519989, 71.73660278],
+ [33.54930115, 92.3655014],
+ [62.72990036, 92.20410156]
+]
+
+DEFAULT_CROP_SIZE = (96, 112)
+
+
+def _umeyama(src, dst, estimate_scale=True, scale=1.0):
+ """Estimate N-D similarity transformation with or without scaling.
+ Parameters
+ ----------
+ src : (M, N) array
+ Source coordinates.
+ dst : (M, N) array
+ Destination coordinates.
+ estimate_scale : bool
+ Whether to estimate scaling factor.
+ Returns
+ -------
+ T : (N + 1, N + 1)
+ The homogeneous similarity transformation matrix. The matrix contains
+ NaN values only if the problem is not well-conditioned.
+ References
+ ----------
+ .. [1] "Least-squares estimation of transformation parameters between two
+ point patterns", Shinji Umeyama, PAMI 1991, :DOI:`10.1109/34.88573`
+ """
+
+ num = src.shape[0]
+ dim = src.shape[1]
+
+ # Compute mean of src and dst.
+ src_mean = src.mean(axis=0)
+ dst_mean = dst.mean(axis=0)
+
+ # Subtract mean from src and dst.
+ src_demean = src - src_mean
+ dst_demean = dst - dst_mean
+
+ # Eq. (38).
+ A = dst_demean.T @ src_demean / num
+
+ # Eq. (39).
+ d = np.ones((dim,), dtype=np.double)
+ if np.linalg.det(A) < 0:
+ d[dim - 1] = -1
+
+ T = np.eye(dim + 1, dtype=np.double)
+
+ U, S, V = np.linalg.svd(A)
+
+ # Eq. (40) and (43).
+ rank = np.linalg.matrix_rank(A)
+ if rank == 0:
+ return np.nan * T
+ elif rank == dim - 1:
+ if np.linalg.det(U) * np.linalg.det(V) > 0:
+ T[:dim, :dim] = U @ V
+ else:
+ s = d[dim - 1]
+ d[dim - 1] = -1
+ T[:dim, :dim] = U @ np.diag(d) @ V
+ d[dim - 1] = s
+ else:
+ T[:dim, :dim] = U @ np.diag(d) @ V
+
+ if estimate_scale:
+ # Eq. (41) and (42).
+ scale = 1.0 / src_demean.var(axis=0).sum() * (S @ d)
+ else:
+ scale = scale
+
+ T[:dim, dim] = dst_mean - scale * (T[:dim, :dim] @ src_mean.T)
+ T[:dim, :dim] *= scale
+
+ return T, scale
+
+
+class FaceWarpException(Exception):
+ def __str__(self):
+        return 'In File {}:{}'.format(
+            __file__, super().__str__())
+
+
+def get_reference_facial_points(output_size=None,
+ inner_padding_factor=0.0,
+ outer_padding=(0, 0),
+ default_square=False):
+ tmp_5pts = np.array(REFERENCE_FACIAL_POINTS)
+ tmp_crop_size = np.array(DEFAULT_CROP_SIZE)
+
+ # 0) make the inner region a square
+ if default_square:
+ size_diff = max(tmp_crop_size) - tmp_crop_size
+ tmp_5pts += size_diff / 2
+ tmp_crop_size += size_diff
+
+ if (output_size and
+ output_size[0] == tmp_crop_size[0] and
+ output_size[1] == tmp_crop_size[1]):
+ print('output_size == DEFAULT_CROP_SIZE {}: return default reference points'.format(tmp_crop_size))
+ return tmp_5pts
+
+ if (inner_padding_factor == 0 and
+ outer_padding == (0, 0)):
+ if output_size is None:
+ print('No paddings to do: return default reference points')
+ return tmp_5pts
+ else:
+ raise FaceWarpException(
+ 'No paddings to do, output_size must be None or {}'.format(tmp_crop_size))
+
+ # check output size
+ if not (0 <= inner_padding_factor <= 1.0):
+ raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)')
+
+ if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0)
+ and output_size is None):
+        output_size = (tmp_crop_size * (1 + inner_padding_factor * 2)).astype(np.int32)
+ output_size += np.array(outer_padding)
+ print(' deduced from paddings, output_size = ', output_size)
+
+ if not (outer_padding[0] < output_size[0]
+ and outer_padding[1] < output_size[1]):
+ raise FaceWarpException('Not (outer_padding[0] < output_size[0]'
+ 'and outer_padding[1] < output_size[1])')
+
+    # 1) pad the inner region according to inner_padding_factor
+ # print('---> STEP1: pad the inner region according inner_padding_factor')
+ if inner_padding_factor > 0:
+ size_diff = tmp_crop_size * inner_padding_factor * 2
+ tmp_5pts += size_diff / 2
+ tmp_crop_size += np.round(size_diff).astype(np.int32)
+
+ # print(' crop_size = ', tmp_crop_size)
+ # print(' reference_5pts = ', tmp_5pts)
+
+ # 2) resize the padded inner region
+ # print('---> STEP2: resize the padded inner region')
+ size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2
+ # print(' crop_size = ', tmp_crop_size)
+ # print(' size_bf_outer_pad = ', size_bf_outer_pad)
+
+ if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]:
+ raise FaceWarpException('Must have (output_size - outer_padding)'
+ '= some_scale * (crop_size * (1.0 + inner_padding_factor)')
+
+ scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0]
+ # print(' resize scale_factor = ', scale_factor)
+ tmp_5pts = tmp_5pts * scale_factor
+ # size_diff = tmp_crop_size * (scale_factor - min(scale_factor))
+ # tmp_5pts = tmp_5pts + size_diff / 2
+ tmp_crop_size = size_bf_outer_pad
+ # print(' crop_size = ', tmp_crop_size)
+ # print(' reference_5pts = ', tmp_5pts)
+
+ # 3) add outer_padding to make output_size
+ reference_5point = tmp_5pts + np.array(outer_padding)
+ tmp_crop_size = output_size
+ # print('---> STEP3: add outer_padding to make output_size')
+ # print(' crop_size = ', tmp_crop_size)
+ # print(' reference_5pts = ', tmp_5pts)
+ #
+ # print('===> end get_reference_facial_points\n')
+
+ return reference_5point
+
+
+def get_affine_transform_matrix(src_pts, dst_pts):
+ tfm = np.float32([[1, 0, 0], [0, 1, 0]])
+ n_pts = src_pts.shape[0]
+ ones = np.ones((n_pts, 1), src_pts.dtype)
+ src_pts_ = np.hstack([src_pts, ones])
+ dst_pts_ = np.hstack([dst_pts, ones])
+
+ A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_)
+
+ if rank == 3:
+ tfm = np.float32([
+ [A[0, 0], A[1, 0], A[2, 0]],
+ [A[0, 1], A[1, 1], A[2, 1]]
+ ])
+ elif rank == 2:
+ tfm = np.float32([
+ [A[0, 0], A[1, 0], 0],
+ [A[0, 1], A[1, 1], 0]
+ ])
+
+ return tfm
+
+
+def warp_and_crop_face(src_img,
+ facial_pts,
+ reference_pts=None,
+ crop_size=(96, 112),
+                       align_type='similarity'):  # similarity | cv2_affine | affine
+ if reference_pts is None:
+ if crop_size[0] == 96 and crop_size[1] == 112:
+ reference_pts = REFERENCE_FACIAL_POINTS
+ else:
+ default_square = False
+ inner_padding_factor = 0
+ outer_padding = (0, 0)
+ output_size = crop_size
+
+ reference_pts = get_reference_facial_points(output_size,
+ inner_padding_factor,
+ outer_padding,
+ default_square)
+
+ ref_pts = np.float32(reference_pts)
+ ref_pts_shp = ref_pts.shape
+ if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2:
+ raise FaceWarpException(
+ 'reference_pts.shape must be (K,2) or (2,K) and K>2')
+
+ if ref_pts_shp[0] == 2:
+ ref_pts = ref_pts.T
+
+ src_pts = np.float32(facial_pts)
+ src_pts_shp = src_pts.shape
+ if max(src_pts_shp) < 3 or min(src_pts_shp) != 2:
+ raise FaceWarpException(
+ 'facial_pts.shape must be (K,2) or (2,K) and K>2')
+
+ if src_pts_shp[0] == 2:
+ src_pts = src_pts.T
+
+ if src_pts.shape != ref_pts.shape:
+ raise FaceWarpException(
+ 'facial_pts and reference_pts must have the same shape')
+
+    if align_type == 'cv2_affine':
+ tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3])
+ tfm_inv = cv2.getAffineTransform(ref_pts[0:3], src_pts[0:3])
+    elif align_type == 'affine':
+ tfm = get_affine_transform_matrix(src_pts, ref_pts)
+ tfm_inv = get_affine_transform_matrix(ref_pts, src_pts)
+ else:
+ params, scale = _umeyama(src_pts, ref_pts)
+ tfm = params[:2, :]
+
+ params, _ = _umeyama(ref_pts, src_pts, False, scale=1.0/scale)
+ tfm_inv = params[:2, :]
+
+ face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]), flags=3)
+
+ return face_img, tfm_inv
diff --git a/core/data/deg_kair_utils/utils_blindsr.py b/core/data/deg_kair_utils/utils_blindsr.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a1a7baf99473043e216c16f464f4e168cbd94ab
--- /dev/null
+++ b/core/data/deg_kair_utils/utils_blindsr.py
@@ -0,0 +1,631 @@
+# -*- coding: utf-8 -*-
+import numpy as np
+import cv2
+import torch
+
+from core.data.deg_kair_utils import utils_image as util
+
+import random
+from scipy import ndimage
+import scipy
+import scipy.stats as ss
+from scipy.interpolate import interp2d
+from scipy.linalg import orth
+
+
+
+
+"""
+# --------------------------------------------
+# Super-Resolution
+# --------------------------------------------
+#
+# Kai Zhang (cskaizhang@gmail.com)
+# https://github.com/cszn
+# From 2019/03--2021/08
+# --------------------------------------------
+"""
+
+def modcrop_np(img, sf):
+ '''
+ Args:
+ img: numpy image, WxH or WxHxC
+ sf: scale factor
+
+ Return:
+ cropped image
+ '''
+ w, h = img.shape[:2]
+ im = np.copy(img)
+ return im[:w - w % sf, :h - h % sf, ...]
+
+
+"""
+# --------------------------------------------
+# anisotropic Gaussian kernels
+# --------------------------------------------
+"""
+def analytic_kernel(k):
+ """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
+ k_size = k.shape[0]
+ # Calculate the big kernels size
+ big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
+ # Loop over the small kernel to fill the big one
+ for r in range(k_size):
+ for c in range(k_size):
+ big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
+ # Crop the edges of the big kernel to ignore very small values and increase run time of SR
+ crop = k_size // 2
+ cropped_big_k = big_k[crop:-crop, crop:-crop]
+ # Normalize to 1
+ return cropped_big_k / cropped_big_k.sum()
+
+
+def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
+ """ generate an anisotropic Gaussian kernel
+ Args:
+ ksize : e.g., 15, kernel size
+ theta : [0, pi], rotation angle range
+ l1 : [0.1,50], scaling of eigenvalues
+ l2 : [0.1,l1], scaling of eigenvalues
+ If l1 = l2, will get an isotropic Gaussian kernel.
+
+ Returns:
+ k : kernel
+ """
+
+ v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
+ V = np.array([[v[0], v[1]], [v[1], -v[0]]])
+ D = np.array([[l1, 0], [0, l2]])
+ Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
+ k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
+
+ return k
+
+
+def gm_blur_kernel(mean, cov, size=15):
+ center = size / 2.0 + 0.5
+ k = np.zeros([size, size])
+ for y in range(size):
+ for x in range(size):
+ cy = y - center + 1
+ cx = x - center + 1
+ k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
+
+ k = k / np.sum(k)
+ return k
+
+
+def shift_pixel(x, sf, upper_left=True):
+ """shift pixel for super-resolution with different scale factors
+ Args:
+ x: WxHxC or WxH
+ sf: scale factor
+ upper_left: shift direction
+ """
+ h, w = x.shape[:2]
+ shift = (sf-1)*0.5
+ xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
+ if upper_left:
+ x1 = xv + shift
+ y1 = yv + shift
+ else:
+ x1 = xv - shift
+ y1 = yv - shift
+
+ x1 = np.clip(x1, 0, w-1)
+ y1 = np.clip(y1, 0, h-1)
+
+ if x.ndim == 2:
+ x = interp2d(xv, yv, x)(x1, y1)
+ if x.ndim == 3:
+ for i in range(x.shape[-1]):
+ x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
+
+ return x
+
+
+def blur(x, k):
+ '''
+ x: image, NxcxHxW
+ k: kernel, Nx1xhxw
+ '''
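+    # Per-sample kernels are applied with a single grouped convolution: the batch is
+    # flattened to shape (1, N*C, H, W) and each kernel is repeated across its C channels,
+    # so groups=n*c convolves every channel with its own sample's kernel.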
+ n, c = x.shape[:2]
+ p1, p2 = (k.shape[-2]-1)//2, (k.shape[-1]-1)//2
+ x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
+ k = k.repeat(1,c,1,1)
+ k = k.view(-1, 1, k.shape[2], k.shape[3])
+ x = x.view(1, -1, x.shape[2], x.shape[3])
+ x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n*c)
+ x = x.view(n, c, x.shape[2], x.shape[3])
+
+ return x
+
+
+
+def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
+ """"
+ # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
+ # Kai Zhang
+ # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
+ # max_var = 2.5 * sf
+ """
+ # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
+ lambda_1 = min_var + np.random.rand() * (max_var - min_var)
+ lambda_2 = min_var + np.random.rand() * (max_var - min_var)
+ theta = np.random.rand() * np.pi # random theta
+ noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
+
+ # Set COV matrix using Lambdas and Theta
+ LAMBDA = np.diag([lambda_1, lambda_2])
+ Q = np.array([[np.cos(theta), -np.sin(theta)],
+ [np.sin(theta), np.cos(theta)]])
+ SIGMA = Q @ LAMBDA @ Q.T
+ INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
+
+ # Set expectation position (shifting kernel for aligned image)
+ MU = k_size // 2 - 0.5*(scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
+ MU = MU[None, None, :, None]
+
+ # Create meshgrid for Gaussian
+ [X,Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
+ Z = np.stack([X, Y], 2)[:, :, :, None]
+
+    # Calculate Gaussian for every pixel of the kernel
+ ZZ = Z-MU
+ ZZ_t = ZZ.transpose(0,1,3,2)
+ raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
+
+ # shift the kernel so it will be centered
+ #raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
+
+ # Normalize the kernel and return
+ #kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
+ kernel = raw_kernel / np.sum(raw_kernel)
+ return kernel
+
+
+def fspecial_gaussian(hsize, sigma):
+ hsize = [hsize, hsize]
+ siz = [(hsize[0]-1.0)/2.0, (hsize[1]-1.0)/2.0]
+ std = sigma
+ [x, y] = np.meshgrid(np.arange(-siz[1], siz[1]+1), np.arange(-siz[0], siz[0]+1))
+ arg = -(x*x + y*y)/(2*std*std)
+ h = np.exp(arg)
+    h[h < np.finfo(float).eps * h.max()] = 0
+ sumh = h.sum()
+ if sumh != 0:
+ h = h/sumh
+ return h
+
+
+def fspecial_laplacian(alpha):
+ alpha = max([0, min([alpha,1])])
+ h1 = alpha/(alpha+1)
+ h2 = (1-alpha)/(alpha+1)
+ h = [[h1, h2, h1], [h2, -4/(alpha+1), h2], [h1, h2, h1]]
+ h = np.array(h)
+ return h
+
+
+def fspecial(filter_type, *args, **kwargs):
+ '''
+ python code from:
+ https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
+ '''
+ if filter_type == 'gaussian':
+ return fspecial_gaussian(*args, **kwargs)
+ if filter_type == 'laplacian':
+ return fspecial_laplacian(*args, **kwargs)
+
+"""
+# --------------------------------------------
+# degradation models
+# --------------------------------------------
+"""
+
+
+def bicubic_degradation(x, sf=3):
+ '''
+ Args:
+ x: HxWxC image, [0, 1]
+ sf: down-scale factor
+
+ Return:
+ bicubicly downsampled LR image
+ '''
+ x = util.imresize_np(x, scale=1/sf)
+ return x
+
+
+def srmd_degradation(x, k, sf=3):
+ ''' blur + bicubic downsampling
+
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+
+ Return:
+ downsampled LR image
+
+ Reference:
+ @inproceedings{zhang2018learning,
+ title={Learning a single convolutional super-resolution network for multiple degradations},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={3262--3271},
+ year={2018}
+ }
+ '''
+    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')  # 'nearest' | 'mirror'
+ x = bicubic_degradation(x, sf=sf)
+ return x
+
+
+def dpsr_degradation(x, k, sf=3):
+
+ ''' bicubic downsampling + blur
+
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+
+ Return:
+ downsampled LR image
+
+ Reference:
+ @inproceedings{zhang2019deep,
+ title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={1671--1681},
+ year={2019}
+ }
+ '''
+ x = bicubic_degradation(x, sf=sf)
+    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ return x
+
+
+def classical_degradation(x, k, sf=3):
+ ''' blur + downsampling
+
+ Args:
+ x: HxWxC image, [0, 1]/[0, 255]
+ k: hxw, double
+ sf: down-scale factor
+
+ Return:
+ downsampled LR image
+ '''
+    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ #x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
+ st = 0
+ return x[st::sf, st::sf, ...]
+
+
+def add_sharpening(img, weight=0.5, radius=50, threshold=10):
+ """USM sharpening. borrowed from real-ESRGAN
+ Input image: I; Blurry image: B.
+ 1. K = I + weight * (I - B)
+ 2. Mask = 1 if abs(I - B) > threshold, else: 0
+    3. Blur the mask to obtain the soft mask.
+ 4. Out = Mask * K + (1 - Mask) * I
+ Args:
+ img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
+        weight (float): Sharpening weight. Default: 0.5.
+        radius (float): Kernel size of the Gaussian blur. Default: 50.
+        threshold (int): Threshold on the residual, on the [0, 255] scale. Default: 10.
+ """
+ if radius % 2 == 0:
+ radius += 1
+ blur = cv2.GaussianBlur(img, (radius, radius), 0)
+ residual = img - blur
+ mask = np.abs(residual) * 255 > threshold
+ mask = mask.astype('float32')
+ soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
+
+ K = img + weight * residual
+ K = np.clip(K, 0, 1)
+ return soft_mask * K + (1 - soft_mask) * img
+
+
+def add_blur(img, sf=4):
+ wd2 = 4.0 + sf
+ wd = 2.0 + 0.2*sf
+ if random.random() < 0.5:
+ l1 = wd2*random.random()
+ l2 = wd2*random.random()
+ k = anisotropic_Gaussian(ksize=2*random.randint(2,11)+3, theta=random.random()*np.pi, l1=l1, l2=l2)
+ else:
+ k = fspecial('gaussian', 2*random.randint(2,11)+3, wd*random.random())
+    img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
+
+ return img
+
+
+def add_resize(img, sf=4):
+ rnum = np.random.rand()
+ if rnum > 0.8: # up
+ sf1 = random.uniform(1, 2)
+ elif rnum < 0.7: # down
+ sf1 = random.uniform(0.5/sf, 1)
+ else:
+ sf1 = 1.0
+ img = cv2.resize(img, (int(sf1*img.shape[1]), int(sf1*img.shape[0])), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ return img
+
+
+def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ rnum = np.random.rand()
+ if rnum > 0.6: # add color Gaussian noise
+ img += np.random.normal(0, noise_level/255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4: # add grayscale Gaussian noise
+ img += np.random.normal(0, noise_level/255.0, (*img.shape[:2], 1)).astype(np.float32)
+    else:  # add channel-correlated Gaussian noise
+ L = noise_level2/255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3,3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img += np.random.multivariate_normal([0,0,0], np.abs(L**2*conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_speckle_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ img = np.clip(img, 0.0, 1.0)
+ rnum = random.random()
+ if rnum > 0.6:
+ img += img*np.random.normal(0, noise_level/255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4:
+ img += img*np.random.normal(0, noise_level/255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else:
+ L = noise_level2/255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3,3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img += img*np.random.multivariate_normal([0,0,0], np.abs(L**2*conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_Poisson_noise(img):
+ img = np.clip((img * 255.0).round(), 0, 255) / 255.
+ vals = 10**(2*random.random()+2.0) # [2, 4]
+ if random.random() < 0.5:
+ img = np.random.poisson(img * vals).astype(np.float32) / vals
+ else:
+ img_gray = np.dot(img[...,:3], [0.299, 0.587, 0.114])
+ img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
+ noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
+ img += noise_gray[:, :, np.newaxis]
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_JPEG_noise(img):
+ quality_factor = random.randint(30, 95)
+ img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
+ result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
+ img = cv2.imdecode(encimg, 1)
+ img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
+ return img
+
+
+def random_crop(lq, hq, sf=4, lq_patchsize=64):
+ h, w = lq.shape[:2]
+ rnd_h = random.randint(0, h-lq_patchsize)
+ rnd_w = random.randint(0, w-lq_patchsize)
+ lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
+
+ rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
+ hq = hq[rnd_h_H:rnd_h_H + lq_patchsize*sf, rnd_w_H:rnd_w_H + lq_patchsize*sf, :]
+ return lq, hq
+
+
+def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+    img: HXWXC, [0, 1], its size should be larger than (lq_patchsizexsf)x(lq_patchsizexsf)
+ sf: scale factor
+ isp_model: camera ISP model
+
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = img.shape[:2]
+ img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...] # mod crop
+ h, w = img.shape[:2]
+
+ if h < lq_patchsize*sf or w < lq_patchsize*sf:
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+ hq = img.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ img = cv2.resize(img, (int(1/2*img.shape[1]), int(1/2*img.shape[0])), interpolation=random.choice([1,2,3]))
+ else:
+ img = util.imresize_np(img, 1/2, True)
+ img = np.clip(img, 0.0, 1.0)
+ sf = 2
+
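+    # shuffled stages: 0/1 blur, 2 random downsample (resize, or shifted-kernel blur + nearest),
+    # 3 resize to the 1/sf target size, 4 Gaussian noise, 5 JPEG compression, 6 camera ISP model;
+    # a final JPEG compression and a random crop are always applied afterwards.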
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ img = add_blur(img, sf=sf)
+
+ elif i == 1:
+ img = add_blur(img, sf=sf)
+
+ elif i == 2:
+ a, b = img.shape[1], img.shape[0]
+ # downsample2
+ if random.random() < 0.75:
+ sf1 = random.uniform(1,2*sf)
+ img = cv2.resize(img, (int(1/sf1*img.shape[1]), int(1/sf1*img.shape[0])), interpolation=random.choice([1,2,3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6*sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted/k_shifted.sum() # blur with shifted kernel
+                img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ img = img[0::sf, 0::sf, ...] # nearest downsampling
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ img = cv2.resize(img, (int(1/sf*a), int(1/sf*b)), interpolation=random.choice([1,2,3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ img = add_JPEG_noise(img)
+
+ elif i == 6:
+ # add processed camera sensor noise
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ img = add_JPEG_noise(img)
+
+ # random crop
+ img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
+
+ return img, hq
+
+
+
+
+def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=False, lq_patchsize=64, isp_model=None):
+ """
+ This is an extended degradation model by combining
+ the degradation models of BSRGAN and Real-ESRGAN
+ ----------
+    img: HXWXC, [0, 1], its size should be larger than (lq_patchsizexsf)x(lq_patchsizexsf)
+    sf: scale factor
+    shuffle_prob: probability of fully shuffling the degradation order
+    use_sharp: whether to sharpen the image first
+
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+
+ h1, w1 = img.shape[:2]
+ img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...] # mod crop
+ h, w = img.shape[:2]
+
+ if h < lq_patchsize*sf or w < lq_patchsize*sf:
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+ if use_sharp:
+ img = add_sharpening(img)
+ hq = img.copy()
+
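+    # 13 degradation stages: 0 blur, 1 resize, 2 Gaussian noise, 3 Poisson noise, 4 speckle noise,
+    # 5 camera ISP, 6 JPEG, then a second round 7 blur, 8 resize, 9 Gaussian, 10 Poisson,
+    # 11 speckle, 12 camera ISP; a final resize + JPEG compression is always applied afterwards.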
+ if random.random() < shuffle_prob:
+ shuffle_order = random.sample(range(13), 13)
+ else:
+ shuffle_order = list(range(13))
+ # local shuffle for noise, JPEG is always the last one
+ shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
+ shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))
+
+ poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1
+
+ for i in shuffle_order:
+ if i == 0:
+ img = add_blur(img, sf=sf)
+ elif i == 1:
+ img = add_resize(img, sf=sf)
+ elif i == 2:
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+ elif i == 3:
+ if random.random() < poisson_prob:
+ img = add_Poisson_noise(img)
+ elif i == 4:
+ if random.random() < speckle_prob:
+ img = add_speckle_noise(img)
+ elif i == 5:
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+ elif i == 6:
+ img = add_JPEG_noise(img)
+ elif i == 7:
+ img = add_blur(img, sf=sf)
+ elif i == 8:
+ img = add_resize(img, sf=sf)
+ elif i == 9:
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+ elif i == 10:
+ if random.random() < poisson_prob:
+ img = add_Poisson_noise(img)
+ elif i == 11:
+ if random.random() < speckle_prob:
+ img = add_speckle_noise(img)
+ elif i == 12:
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+ else:
+ print('check the shuffle!')
+
+ # resize to desired size
+ img = cv2.resize(img, (int(1/sf*hq.shape[1]), int(1/sf*hq.shape[0])), interpolation=random.choice([1, 2, 3]))
+
+ # add final JPEG compression noise
+ img = add_JPEG_noise(img)
+
+ # random crop
+ img, hq = random_crop(img, hq, sf, lq_patchsize)
+
+ return img, hq
+
+
+
+if __name__ == '__main__':
+ img = util.imread_uint('utils/test.png', 3)
+ img = util.uint2single(img)
+ sf = 4
+
+ for i in range(20):
+ img_lq, img_hq = degradation_bsrgan(img, sf=sf, lq_patchsize=72)
+ print(i)
+ lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf*img_lq.shape[1]), int(sf*img_lq.shape[0])), interpolation=0)
+ img_concat = np.concatenate([lq_nearest, util.single2uint(img_hq)], axis=1)
+ util.imsave(img_concat, str(i)+'.png')
+
+# for i in range(10):
+# img_lq, img_hq = degradation_bsrgan_plus(img, sf=sf, shuffle_prob=0.1, use_sharp=True, lq_patchsize=64)
+# print(i)
+# lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf*img_lq.shape[1]), int(sf*img_lq.shape[0])), interpolation=0)
+# img_concat = np.concatenate([lq_nearest, util.single2uint(img_hq)], axis=1)
+# util.imsave(img_concat, str(i)+'.png')
+
+# run utils/utils_blindsr.py
diff --git a/core/data/deg_kair_utils/utils_bnorm.py b/core/data/deg_kair_utils/utils_bnorm.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bd346e05b66efd074f81f1961068e2de45ac5da
--- /dev/null
+++ b/core/data/deg_kair_utils/utils_bnorm.py
@@ -0,0 +1,91 @@
+import torch
+import torch.nn as nn
+
+
+"""
+# --------------------------------------------
+# Batch Normalization
+# --------------------------------------------
+
+# Kai Zhang (cskaizhang@gmail.com)
+# https://github.com/cszn
+# 01/Jan/2019
+# --------------------------------------------
+"""
+
+
+# --------------------------------------------
+# remove/delete specified layer
+# --------------------------------------------
+def deleteLayer(model, layer_type=nn.BatchNorm2d):
+ ''' Kai Zhang, 11/Jan/2019.
+ '''
+ for k, m in list(model.named_children()):
+ if isinstance(m, layer_type):
+ del model._modules[k]
+ deleteLayer(m, layer_type)
+
+
+# --------------------------------------------
+# merge bn, "conv+bn" --> "conv"
+# --------------------------------------------
+def merge_bn(model):
+ ''' Kai Zhang, 11/Jan/2019.
+ merge all 'Conv+BN' (or 'TConv+BN') into 'Conv' (or 'TConv')
+ based on https://github.com/pytorch/pytorch/pull/901
+ '''
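+    # Folding uses the standard identities: w' = w * gamma / sqrt(running_var + eps) and
+    # b' = (b - running_mean) * gamma / sqrt(running_var + eps) + beta (gamma/beta only if affine).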
+ prev_m = None
+ for k, m in list(model.named_children()):
+ if (isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d)) and (isinstance(prev_m, nn.Conv2d) or isinstance(prev_m, nn.Linear) or isinstance(prev_m, nn.ConvTranspose2d)):
+
+ w = prev_m.weight.data
+
+ if prev_m.bias is None:
+ zeros = torch.Tensor(prev_m.out_channels).zero_().type(w.type())
+ prev_m.bias = nn.Parameter(zeros)
+ b = prev_m.bias.data
+
+ invstd = m.running_var.clone().add_(m.eps).pow_(-0.5)
+ if isinstance(prev_m, nn.ConvTranspose2d):
+ w.mul_(invstd.view(1, w.size(1), 1, 1).expand_as(w))
+ else:
+ w.mul_(invstd.view(w.size(0), 1, 1, 1).expand_as(w))
+ b.add_(-m.running_mean).mul_(invstd)
+ if m.affine:
+ if isinstance(prev_m, nn.ConvTranspose2d):
+ w.mul_(m.weight.data.view(1, w.size(1), 1, 1).expand_as(w))
+ else:
+ w.mul_(m.weight.data.view(w.size(0), 1, 1, 1).expand_as(w))
+ b.mul_(m.weight.data).add_(m.bias.data)
+
+ del model._modules[k]
+ prev_m = m
+ merge_bn(m)
+
+
+# --------------------------------------------
+# add bn, "conv" --> "conv+bn"
+# --------------------------------------------
+def add_bn(model):
+ ''' Kai Zhang, 11/Jan/2019.
+ '''
+ for k, m in list(model.named_children()):
+ if (isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose2d)):
+ b = nn.BatchNorm2d(m.out_channels, momentum=0.1, affine=True)
+ b.weight.data.fill_(1)
+ new_m = nn.Sequential(model._modules[k], b)
+ model._modules[k] = new_m
+ add_bn(m)
+
+
+# --------------------------------------------
+# tidy model after removing bn
+# --------------------------------------------
+def tidy_sequential(model):
+ ''' Kai Zhang, 11/Jan/2019.
+ '''
+ for k, m in list(model.named_children()):
+ if isinstance(m, nn.Sequential):
+ if m.__len__() == 1:
+ model._modules[k] = m.__getitem__(0)
+ tidy_sequential(m)
diff --git a/core/data/deg_kair_utils/utils_deblur.py b/core/data/deg_kair_utils/utils_deblur.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ab5852d0cb334627abcd9476409d632740be389
--- /dev/null
+++ b/core/data/deg_kair_utils/utils_deblur.py
@@ -0,0 +1,655 @@
+# -*- coding: utf-8 -*-
+import numpy as np
+import scipy
+from scipy import fftpack
+import torch
+
+from math import cos, sin
+from numpy import zeros, ones, prod, array, pi, log, min, mod, arange, sum, mgrid, exp, pad, round
+from numpy.random import randn, rand
+from scipy.signal import convolve2d
+import cv2
+import random
+# import utils_image as util
+
+'''
+modified by Kai Zhang (github: https://github.com/cszn)
+03/03/2019
+'''
+
+
+def get_uperleft_denominator(img, kernel):
+ '''
+ img: HxWxC
+ kernel: hxw
+ denominator: HxWx1
+ upperleft: HxWxC
+ '''
+ V = psf2otf(kernel, img.shape[:2])
+ denominator = np.expand_dims(np.abs(V)**2, axis=2)
+ upperleft = np.expand_dims(np.conj(V), axis=2) * np.fft.fft2(img, axes=[0, 1])
+ return upperleft, denominator
+
+
+def get_uperleft_denominator_pytorch(img, kernel):
+ '''
+ img: NxCxHxW
+ kernel: Nx1xhxw
+ denominator: Nx1xHxW
+ upperleft: NxCxHxWx2
+ '''
+ V = p2o(kernel, img.shape[-2:]) # Nx1xHxWx2
+ denominator = V[..., 0]**2+V[..., 1]**2 # Nx1xHxW
+ upperleft = cmul(cconj(V), rfft(img)) # Nx1xHxWx2 * NxCxHxWx2
+ return upperleft, denominator
+
+
+def c2c(x):
+ return torch.from_numpy(np.stack([np.float32(x.real), np.float32(x.imag)], axis=-1))
+
+
+def r2c(x):
+ return torch.stack([x, torch.zeros_like(x)], -1)
+
+
+def cdiv(x, y):
+ a, b = x[..., 0], x[..., 1]
+ c, d = y[..., 0], y[..., 1]
+ cd2 = c**2 + d**2
+ return torch.stack([(a*c+b*d)/cd2, (b*c-a*d)/cd2], -1)
+
+
+def cabs(x):
+ return torch.pow(x[..., 0]**2+x[..., 1]**2, 0.5)
+
+
+def cmul(t1, t2):
+ '''
+ complex multiplication
+ t1: NxCxHxWx2
+ output: NxCxHxWx2
+ '''
+ real1, imag1 = t1[..., 0], t1[..., 1]
+ real2, imag2 = t2[..., 0], t2[..., 1]
+ return torch.stack([real1 * real2 - imag1 * imag2, real1 * imag2 + imag1 * real2], dim=-1)
+
+
+def cconj(t, inplace=False):
+ '''
+ # complex's conjugation
+ t: NxCxHxWx2
+ output: NxCxHxWx2
+ '''
+ c = t.clone() if not inplace else t
+ c[..., 1] *= -1
+ return c
+
+
+def rfft(t):
+ return torch.rfft(t, 2, onesided=False)
+
+
+def irfft(t):
+ return torch.irfft(t, 2, onesided=False)
+
+
+def fft(t):
+ return torch.fft(t, 2)
+
+
+def ifft(t):
+ return torch.ifft(t, 2)
+
+
+def p2o(psf, shape):
+ '''
+ # psf: NxCxhxw
+ # shape: [H,W]
+ # otf: NxCxHxWx2
+ '''
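+    # PSF -> OTF: zero-pad the kernel to the image size, circularly shift it so its centre
+    # sits at the origin, then take the two-sided FFT (old real/imag last-dimension layout).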
+ otf = torch.zeros(psf.shape[:-2] + shape).type_as(psf)
+ otf[...,:psf.shape[2],:psf.shape[3]].copy_(psf)
+ for axis, axis_size in enumerate(psf.shape[2:]):
+ otf = torch.roll(otf, -int(axis_size / 2), dims=axis+2)
+ otf = torch.rfft(otf, 2, onesided=False)
+ n_ops = torch.sum(torch.tensor(psf.shape).type_as(psf) * torch.log2(torch.tensor(psf.shape).type_as(psf)))
+    otf[..., 1][torch.abs(otf[..., 1]) < n_ops*2.22e-16] = torch.tensor(0).type_as(psf)
+    return otf
+
+
+def fspecial_disk(radius):
+    """Disk filter (unfinished port of MATLAB's fspecial('disk'))."""
+    rad = radius
+    crad = np.ceil(rad - 0.5)
+    [x, y] = np.meshgrid(np.arange(-crad, crad + 1), np.arange(-crad, crad + 1))
+    maxxy = np.zeros(x.shape)
+    maxxy[abs(x) >= abs(y)] = abs(x)[abs(x) >= abs(y)]
+ maxxy[abs(y) >= abs(x)] = abs(y)[abs(y) >= abs(x)]
+ minxy = np.zeros(x.shape)
+ minxy[abs(x) <= abs(y)] = abs(x)[abs(x) <= abs(y)]
+ minxy[abs(y) <= abs(x)] = abs(y)[abs(y) <= abs(x)]
+ m1 = (rad**2 < (maxxy+0.5)**2 + (minxy-0.5)**2)*(minxy-0.5) +\
+ (rad**2 >= (maxxy+0.5)**2 + (minxy-0.5)**2)*\
+ np.sqrt((rad**2 + 0j) - (maxxy + 0.5)**2)
+ m2 = (rad**2 > (maxxy-0.5)**2 + (minxy+0.5)**2)*(minxy+0.5) +\
+ (rad**2 <= (maxxy-0.5)**2 + (minxy+0.5)**2)*\
+ np.sqrt((rad**2 + 0j) - (maxxy - 0.5)**2)
+ h = None
+ return h
+
+
+def fspecial_gaussian(hsize, sigma):
+ hsize = [hsize, hsize]
+ siz = [(hsize[0]-1.0)/2.0, (hsize[1]-1.0)/2.0]
+ std = sigma
+ [x, y] = np.meshgrid(np.arange(-siz[1], siz[1]+1), np.arange(-siz[0], siz[0]+1))
+ arg = -(x*x + y*y)/(2*std*std)
+ h = np.exp(arg)
+    h[h < np.finfo(float).eps * h.max()] = 0
+ sumh = h.sum()
+ if sumh != 0:
+ h = h/sumh
+ return h
+
+
+def fspecial_laplacian(alpha):
+ alpha = max([0, min([alpha,1])])
+ h1 = alpha/(alpha+1)
+ h2 = (1-alpha)/(alpha+1)
+ h = [[h1, h2, h1], [h2, -4/(alpha+1), h2], [h1, h2, h1]]
+ h = np.array(h)
+ return h
+
+
+def fspecial_log(hsize, sigma):
+    raise NotImplementedError
+
+
+def fspecial_motion(motion_len, theta):
+    raise NotImplementedError
+
+
+def fspecial_prewitt():
+ return np.array([[1, 1, 1], [0, 0, 0], [-1, -1, -1]])
+
+
+def fspecial_sobel():
+ return np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]])
+
+
+def fspecial(filter_type, *args, **kwargs):
+ '''
+ python code from:
+ https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
+ '''
+ if filter_type == 'average':
+ return fspecial_average(*args, **kwargs)
+ if filter_type == 'disk':
+ return fspecial_disk(*args, **kwargs)
+ if filter_type == 'gaussian':
+ return fspecial_gaussian(*args, **kwargs)
+ if filter_type == 'laplacian':
+ return fspecial_laplacian(*args, **kwargs)
+ if filter_type == 'log':
+ return fspecial_log(*args, **kwargs)
+ if filter_type == 'motion':
+ return fspecial_motion(*args, **kwargs)
+ if filter_type == 'prewitt':
+ return fspecial_prewitt(*args, **kwargs)
+ if filter_type == 'sobel':
+ return fspecial_sobel(*args, **kwargs)
+
+
+def fspecial_gauss(size, sigma):
+ x, y = mgrid[-size // 2 + 1 : size // 2 + 1, -size // 2 + 1 : size // 2 + 1]
+ g = exp(-((x ** 2 + y ** 2) / (2.0 * sigma ** 2)))
+ return g / g.sum()
+
+
+def blurkernel_synthesis(h=37, w=None):
+ # https://github.com/tkkcc/prior/blob/879a0b6c117c810776d8cc6b63720bf29f7d0cc4/util/gen_kernel.py
+ w = h if w is None else w
+ kdims = [h, w]
+ x = randomTrajectory(250)
+ k = None
+ while k is None:
+ k = kernelFromTrajectory(x)
+
+ # center pad to kdims
+ pad_width = ((kdims[0] - k.shape[0]) // 2, (kdims[1] - k.shape[1]) // 2)
+ pad_width = [(pad_width[0],), (pad_width[1],)]
+
+ if pad_width[0][0]<0 or pad_width[1][0]<0:
+ k = k[0:h, 0:h]
+ else:
+ k = pad(k, pad_width, "constant")
+ x1,x2 = k.shape
+ if np.random.randint(0, 4) == 1:
+ k = cv2.resize(k, (random.randint(x1, 5*x1), random.randint(x2, 5*x2)), interpolation=cv2.INTER_LINEAR)
+ y1, y2 = k.shape
+ k = k[(y1-x1)//2: (y1-x1)//2+x1, (y2-x2)//2: (y2-x2)//2+x2]
+
+ if sum(k)<0.1:
+ k = fspecial_gaussian(h, 0.1+6*np.random.rand(1))
+ k = k / sum(k)
+ # import matplotlib.pyplot as plt
+ # plt.imshow(k, interpolation="nearest", cmap="gray")
+ # plt.show()
+ return k
+
+
+def kernelFromTrajectory(x):
+ h = 5 - log(rand()) / 0.15
+ h = round(min([h, 27])).astype(int)
+ h = h + 1 - h % 2
+ w = h
+ k = zeros((h, w))
+
+ xmin = min(x[0])
+ xmax = max(x[0])
+ ymin = min(x[1])
+ ymax = max(x[1])
+ xthr = arange(xmin, xmax, (xmax - xmin) / w)
+ ythr = arange(ymin, ymax, (ymax - ymin) / h)
+
+ for i in range(1, xthr.size):
+ for j in range(1, ythr.size):
+ idx = (
+ (x[0, :] >= xthr[i - 1])
+ & (x[0, :] < xthr[i])
+ & (x[1, :] >= ythr[j - 1])
+ & (x[1, :] < ythr[j])
+ )
+ k[i - 1, j - 1] = sum(idx)
+ if sum(k) == 0:
+ return
+ k = k / sum(k)
+ k = convolve2d(k, fspecial_gauss(3, 1), "same")
+ k = k / sum(k)
+ return k
+
+
+def randomTrajectory(T):
+ x = zeros((3, T))
+ v = randn(3, T)
+ r = zeros((3, T))
+ trv = 1 / 1
+ trr = 2 * pi / T
+ for t in range(1, T):
+ F_rot = randn(3) / (t + 1) + r[:, t - 1]
+ F_trans = randn(3) / (t + 1)
+ r[:, t] = r[:, t - 1] + trr * F_rot
+ v[:, t] = v[:, t - 1] + trv * F_trans
+ st = v[:, t]
+ st = rot3D(st, r[:, t])
+ x[:, t] = x[:, t - 1] + st
+ return x
+
+
+def rot3D(x, r):
+ Rx = array([[1, 0, 0], [0, cos(r[0]), -sin(r[0])], [0, sin(r[0]), cos(r[0])]])
+ Ry = array([[cos(r[1]), 0, sin(r[1])], [0, 1, 0], [-sin(r[1]), 0, cos(r[1])]])
+ Rz = array([[cos(r[2]), -sin(r[2]), 0], [sin(r[2]), cos(r[2]), 0], [0, 0, 1]])
+ R = Rz @ Ry @ Rx
+ x = R @ x
+ return x
+
+
+if __name__ == '__main__':
+ a = opt_fft_size([111])
+ print(a)
+
+ print(fspecial('gaussian', 5, 1))
+
+ print(p2o(torch.zeros(1,1,4,4).float(),(14,14)).shape)
+
+ k = blurkernel_synthesis(11)
+ import matplotlib.pyplot as plt
+ plt.imshow(k, interpolation="nearest", cmap="gray")
+ plt.show()
diff --git a/core/data/deg_kair_utils/utils_dist.py b/core/data/deg_kair_utils/utils_dist.py
new file mode 100644
index 0000000000000000000000000000000000000000..88811737a8fc7cb6e12d9226a9242dbf8391d86b
--- /dev/null
+++ b/core/data/deg_kair_utils/utils_dist.py
@@ -0,0 +1,201 @@
+# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
+import functools
+import os
+import pickle
+import subprocess
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+
+# ----------------------------------
+# init
+# ----------------------------------
+def init_dist(launcher, backend='nccl', **kwargs):
+ if mp.get_start_method(allow_none=True) is None:
+ mp.set_start_method('spawn')
+ if launcher == 'pytorch':
+ _init_dist_pytorch(backend, **kwargs)
+ elif launcher == 'slurm':
+ _init_dist_slurm(backend, **kwargs)
+ else:
+ raise ValueError(f'Invalid launcher type: {launcher}')
+
+
+def _init_dist_pytorch(backend, **kwargs):
+ rank = int(os.environ['RANK'])
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(rank % num_gpus)
+ dist.init_process_group(backend=backend, **kwargs)
+
+
+def _init_dist_slurm(backend, port=None):
+ """Initialize slurm distributed training environment.
+    If argument ``port`` is not specified, the master port will be taken from the
+    system environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not set,
+    the default port ``29500`` will be used.
+ Args:
+ backend (str): Backend of torch.distributed.
+ port (int, optional): Master port. Defaults to None.
+ """
+ proc_id = int(os.environ['SLURM_PROCID'])
+ ntasks = int(os.environ['SLURM_NTASKS'])
+ node_list = os.environ['SLURM_NODELIST']
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(proc_id % num_gpus)
+ addr = subprocess.getoutput(
+ f'scontrol show hostname {node_list} | head -n1')
+ # specify master port
+ if port is not None:
+ os.environ['MASTER_PORT'] = str(port)
+ elif 'MASTER_PORT' in os.environ:
+ pass # use MASTER_PORT in the environment variable
+ else:
+ # 29500 is torch.distributed default port
+ os.environ['MASTER_PORT'] = '29500'
+ os.environ['MASTER_ADDR'] = addr
+ os.environ['WORLD_SIZE'] = str(ntasks)
+ os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
+ os.environ['RANK'] = str(proc_id)
+ dist.init_process_group(backend=backend)
+
+
+
+# ----------------------------------
+# get rank and world_size
+# ----------------------------------
+def get_dist_info():
+ if dist.is_available():
+ initialized = dist.is_initialized()
+ else:
+ initialized = False
+ if initialized:
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ else:
+ rank = 0
+ world_size = 1
+ return rank, world_size
+
+
+def get_rank():
+ if not dist.is_available():
+ return 0
+
+ if not dist.is_initialized():
+ return 0
+
+ return dist.get_rank()
+
+
+def get_world_size():
+ if not dist.is_available():
+ return 1
+
+ if not dist.is_initialized():
+ return 1
+
+ return dist.get_world_size()
+
+
+def master_only(func):
+
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ rank, _ = get_dist_info()
+ if rank == 0:
+ return func(*args, **kwargs)
+
+ return wrapper
+
+
+
+
+
+
+# ----------------------------------
+# operation across ranks
+# ----------------------------------
+def reduce_sum(tensor):
+ if not dist.is_available():
+ return tensor
+
+ if not dist.is_initialized():
+ return tensor
+
+ tensor = tensor.clone()
+ dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
+
+ return tensor
+
+
+def gather_grad(params):
+ world_size = get_world_size()
+
+ if world_size == 1:
+ return
+
+ for param in params:
+ if param.grad is not None:
+ dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
+ param.grad.data.div_(world_size)
+
+
+def all_gather(data):
+ world_size = get_world_size()
+
+ if world_size == 1:
+ return [data]
+
+ buffer = pickle.dumps(data)
+ storage = torch.ByteStorage.from_buffer(buffer)
+ tensor = torch.ByteTensor(storage).to('cuda')
+
+ local_size = torch.IntTensor([tensor.numel()]).to('cuda')
+ size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]
+ dist.all_gather(size_list, local_size)
+ size_list = [int(size.item()) for size in size_list]
+ max_size = max(size_list)
+
+ tensor_list = []
+ for _ in size_list:
+ tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda'))
+
+ if local_size != max_size:
+ padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda')
+ tensor = torch.cat((tensor, padding), 0)
+
+ dist.all_gather(tensor_list, tensor)
+
+ data_list = []
+
+ for size, tensor in zip(size_list, tensor_list):
+ buffer = tensor.cpu().numpy().tobytes()[:size]
+ data_list.append(pickle.loads(buffer))
+
+ return data_list
+
+
+def reduce_loss_dict(loss_dict):
+ world_size = get_world_size()
+
+ if world_size < 2:
+ return loss_dict
+
+ with torch.no_grad():
+ keys = []
+ losses = []
+
+ for k in sorted(loss_dict.keys()):
+ keys.append(k)
+ losses.append(loss_dict[k])
+
+ losses = torch.stack(losses, 0)
+ dist.reduce(losses, dst=0)
+
+ if dist.get_rank() == 0:
+ losses /= world_size
+
+ reduced_losses = {k: v for k, v in zip(keys, losses)}
+
+ return reduced_losses
+
diff --git a/core/data/deg_kair_utils/utils_googledownload.py b/core/data/deg_kair_utils/utils_googledownload.py
new file mode 100644
index 0000000000000000000000000000000000000000..25533d4e0d90bac7519874a654ffd833d16ae289
--- /dev/null
+++ b/core/data/deg_kair_utils/utils_googledownload.py
@@ -0,0 +1,93 @@
+import math
+import requests
+from tqdm import tqdm
+
+
+'''
+borrowed from
+https://github.com/xinntao/BasicSR/blob/28883e15eedc3381d23235ff3cf7c454c4be87e6/basicsr/utils/download_util.py
+'''
+
+
+def sizeof_fmt(size, suffix='B'):
+ """Get human readable file size.
+ Args:
+ size (int): File size.
+ suffix (str): Suffix. Default: 'B'.
+ Return:
+        str: Formatted file size.
+ """
+ for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
+ if abs(size) < 1024.0:
+ return f'{size:3.1f} {unit}{suffix}'
+ size /= 1024.0
+ return f'{size:3.1f} Y{suffix}'
+
+
+def download_file_from_google_drive(file_id, save_path):
+ """Download files from google drive.
+ Ref:
+ https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501
+ Args:
+ file_id (str): File id.
+ save_path (str): Save path.
+ """
+
+ session = requests.Session()
+ URL = 'https://docs.google.com/uc?export=download'
+ params = {'id': file_id}
+
+ response = session.get(URL, params=params, stream=True)
+ token = get_confirm_token(response)
+ if token:
+ params['confirm'] = token
+ response = session.get(URL, params=params, stream=True)
+
+ # get file size
+ response_file_size = session.get(
+ URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
+ if 'Content-Range' in response_file_size.headers:
+ file_size = int(
+ response_file_size.headers['Content-Range'].split('/')[1])
+ else:
+ file_size = None
+
+ save_response_content(response, save_path, file_size)
+
+
+def get_confirm_token(response):
+ for key, value in response.cookies.items():
+ if key.startswith('download_warning'):
+ return value
+ return None
+
+
+def save_response_content(response,
+ destination,
+ file_size=None,
+ chunk_size=32768):
+ if file_size is not None:
+ pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
+
+ readable_file_size = sizeof_fmt(file_size)
+ else:
+ pbar = None
+
+ with open(destination, 'wb') as f:
+ downloaded_size = 0
+ for chunk in response.iter_content(chunk_size):
+ downloaded_size += chunk_size
+ if pbar is not None:
+ pbar.update(1)
+ pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} '
+ f'/ {readable_file_size}')
+ if chunk: # filter out keep-alive new chunks
+ f.write(chunk)
+ if pbar is not None:
+ pbar.close()
+
+
+if __name__ == "__main__":
+ file_id = '1WNULM1e8gRNvsngVscsQ8tpaOqJ4mYtv'
+ save_path = 'BSRGAN.pth'
+ download_file_from_google_drive(file_id, save_path)
diff --git a/core/data/deg_kair_utils/utils_image.py b/core/data/deg_kair_utils/utils_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e513a8bc1594c9ce2ba47ce3fe3b497269b7f16
--- /dev/null
+++ b/core/data/deg_kair_utils/utils_image.py
@@ -0,0 +1,1016 @@
+import os
+import math
+import random
+import numpy as np
+import torch
+import cv2
+from torchvision.utils import make_grid
+from datetime import datetime
+# import torchvision.transforms as transforms
+import matplotlib.pyplot as plt
+from mpl_toolkits.mplot3d import Axes3D
+os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
+
+
+'''
+# --------------------------------------------
+# Kai Zhang (github: https://github.com/cszn)
+# 03/Mar/2019
+# --------------------------------------------
+# https://github.com/twhui/SRGAN-pyTorch
+# https://github.com/xinntao/BasicSR
+# --------------------------------------------
+'''
+
+
+IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+def get_timestamp():
+ return datetime.now().strftime('%y%m%d-%H%M%S')
+
+
+def imshow(x, title=None, cbar=False, figsize=None):
+ plt.figure(figsize=figsize)
+ plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
+ if title:
+ plt.title(title)
+ if cbar:
+ plt.colorbar()
+ plt.show()
+
+
+def surf(Z, cmap='rainbow', figsize=None):
+ plt.figure(figsize=figsize)
+ ax3 = plt.axes(projection='3d')
+
+ w, h = Z.shape[:2]
+ xx = np.arange(0,w,1)
+ yy = np.arange(0,h,1)
+ X, Y = np.meshgrid(xx, yy)
+ ax3.plot_surface(X,Y,Z,cmap=cmap)
+ #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap)
+ plt.show()
+
+
+'''
+# --------------------------------------------
+# get image paths
+# --------------------------------------------
+'''
+
+
+def get_image_paths(dataroot):
+ paths = None # return None if dataroot is None
+ if isinstance(dataroot, str):
+ paths = sorted(_get_paths_from_images(dataroot))
+ elif isinstance(dataroot, list):
+ paths = []
+ for i in dataroot:
+ paths += sorted(_get_paths_from_images(i))
+ return paths
+
+
+def _get_paths_from_images(path):
+ assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
+ images = []
+ for dirpath, _, fnames in sorted(os.walk(path)):
+ for fname in sorted(fnames):
+ if is_image_file(fname):
+ img_path = os.path.join(dirpath, fname)
+ images.append(img_path)
+ assert images, '{:s} has no valid image file'.format(path)
+ return images
+
+
+'''
+# --------------------------------------------
+# split large images into small images
+# --------------------------------------------
+'''
+
+
+def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
+ w, h = img.shape[:2]
+ patches = []
+ if w > p_max and h > p_max:
+        w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=int))
+        h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=int))
+ w1.append(w-p_size)
+ h1.append(h-p_size)
+ # print(w1)
+ # print(h1)
+ for i in w1:
+ for j in h1:
+ patches.append(img[i:i+p_size, j:j+p_size,:])
+ else:
+ patches.append(img)
+
+ return patches
+
+
+def imssave(imgs, img_path):
+ """
+ imgs: list, N images of size WxHxC
+ """
+ img_name, ext = os.path.splitext(os.path.basename(img_path))
+ for i, img in enumerate(imgs):
+ if img.ndim == 3:
+ img = img[:, :, [2, 1, 0]]
+ new_path = os.path.join(os.path.dirname(img_path), img_name+str('_{:04d}'.format(i))+'.png')
+ cv2.imwrite(new_path, img)
+
+
+def split_imageset(original_dataroot, target_dataroot, n_channels=3, p_size=512, p_overlap=96, p_max=800):
+    """
+    Split the large images from original_dataroot into small overlapped images of size (p_size)x(p_size),
+    and save them into target_dataroot; only images larger than (p_max)x(p_max) will be split.
+
+    Args:
+        original_dataroot: directory containing the source images
+        target_dataroot: directory to save the small images to
+        n_channels: number of image channels to read
+        p_size: size of the small images
+        p_overlap: overlap between adjacent patches; the training patch size is a good choice
+        p_max: images smaller than (p_max)x(p_max) are kept unchanged.
+    """
+ paths = get_image_paths(original_dataroot)
+ for img_path in paths:
+ # img_name, ext = os.path.splitext(os.path.basename(img_path))
+ img = imread_uint(img_path, n_channels=n_channels)
+ patches = patches_from_image(img, p_size, p_overlap, p_max)
+        imssave(patches, os.path.join(target_dataroot, os.path.basename(img_path)))
+ #if original_dataroot == taget_dataroot:
+ #del img_path
+
+'''
+# --------------------------------------------
+# makedir
+# --------------------------------------------
+'''
+
+
+def mkdir(path):
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+
+def mkdirs(paths):
+ if isinstance(paths, str):
+ mkdir(paths)
+ else:
+ for path in paths:
+ mkdir(path)
+
+
+def mkdir_and_rename(path):
+ if os.path.exists(path):
+ new_name = path + '_archived_' + get_timestamp()
+ print('Path already exists. Rename it to [{:s}]'.format(new_name))
+ os.rename(path, new_name)
+ os.makedirs(path)
+
+
+'''
+# --------------------------------------------
+# read image from path
+# opencv is fast, but read BGR numpy image
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# get uint8 image of size HxWxn_channels (RGB)
+# --------------------------------------------
+def imread_uint(path, n_channels=3):
+ # input: path
+ # output: HxWx3(RGB or GGG), or HxWx1 (G)
+ if n_channels == 1:
+ img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE
+ img = np.expand_dims(img, axis=2) # HxWx1
+ elif n_channels == 3:
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G
+ if img.ndim == 2:
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG
+ else:
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB
+ return img
+
+
+# --------------------------------------------
+# matlab's imwrite
+# --------------------------------------------
+def imsave(img, img_path):
+ img = np.squeeze(img)
+ if img.ndim == 3:
+ img = img[:, :, [2, 1, 0]]
+ cv2.imwrite(img_path, img)
+
+def imwrite(img, img_path):
+ img = np.squeeze(img)
+ if img.ndim == 3:
+ img = img[:, :, [2, 1, 0]]
+ cv2.imwrite(img_path, img)
+
+
+
+# --------------------------------------------
+# get single image of size HxWxn_channels (BGR)
+# --------------------------------------------
+def read_img(path):
+ # read image by cv2
+ # return: Numpy float32, HWC, BGR, [0,1]
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE
+ img = img.astype(np.float32) / 255.
+ if img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ # some images have 4 channels
+ if img.shape[2] > 3:
+ img = img[:, :, :3]
+ return img
+
+
+'''
+# --------------------------------------------
+# image format conversion
+# --------------------------------------------
+# numpy(single) <---> numpy(uint)
+# numpy(single) <---> tensor
+# numpy(uint) <---> tensor
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# numpy(single) [0, 1] <---> numpy(uint)
+# --------------------------------------------
+
+
+def uint2single(img):
+
+ return np.float32(img/255.)
+
+
+def single2uint(img):
+
+ return np.uint8((img.clip(0, 1)*255.).round())
+
+
+def uint162single(img):
+
+ return np.float32(img/65535.)
+
+
+def single2uint16(img):
+
+ return np.uint16((img.clip(0, 1)*65535.).round())
+
+
+# --------------------------------------------
+# numpy(uint) (HxWxC or HxW) <---> tensor
+# --------------------------------------------
+
+
+# convert uint to 4-dimensional torch tensor
+def uint2tensor4(img):
+ if img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0)
+
+
+# convert uint to 3-dimensional torch tensor
+def uint2tensor3(img):
+ if img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.)
+
+
+# convert 2/3/4-dimensional torch tensor to uint
+def tensor2uint(img):
+ img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
+ if img.ndim == 3:
+ img = np.transpose(img, (1, 2, 0))
+ return np.uint8((img*255.0).round())
+
+
+# --------------------------------------------
+# numpy(single) (HxWxC) <---> tensor
+# --------------------------------------------
+
+
+# convert single (HxWxC) to 3-dimensional torch tensor
+def single2tensor3(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()
+
+
+# convert single (HxWxC) to 4-dimensional torch tensor
+def single2tensor4(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)
+
+
+# convert torch tensor to single
+def tensor2single(img):
+ img = img.data.squeeze().float().cpu().numpy()
+ if img.ndim == 3:
+ img = np.transpose(img, (1, 2, 0))
+
+ return img
+
+# convert torch tensor to single
+def tensor2single3(img):
+ img = img.data.squeeze().float().cpu().numpy()
+ if img.ndim == 3:
+ img = np.transpose(img, (1, 2, 0))
+ elif img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ return img
+
+
+def single2tensor5(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)
+
+
+def single32tensor5(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)
+
+
+def single42tensor4(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
+
+
+# from skimage.io import imread, imsave
+def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
+ '''
+ Converts a torch Tensor into an image Numpy array of BGR channel order
+ Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
+ Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
+ '''
+ tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp
+ tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
+ n_dim = tensor.dim()
+ if n_dim == 4:
+ n_img = len(tensor)
+ img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
+ img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
+ elif n_dim == 3:
+ img_np = tensor.numpy()
+ img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
+ elif n_dim == 2:
+ img_np = tensor.numpy()
+ else:
+ raise TypeError(
+ 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
+ if out_type == np.uint8:
+ img_np = (img_np * 255.0).round()
+ # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
+ return img_np.astype(out_type)
+
+
+'''
+# --------------------------------------------
+# Augmentation, flip and/or rotate
+# --------------------------------------------
+# The following two are enough.
+# (1) augment_img: numpy image of WxHxC or WxH
+# (2) augment_img_tensor4: tensor image 1xCxWxH
+# --------------------------------------------
+'''
+
+
+def augment_img(img, mode=0):
+ '''Kai Zhang (github: https://github.com/cszn)
+ '''
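+    # The 8 modes enumerate the dihedral group D4: the identity, the three 90-degree rotations,
+    # and each of those composed with a vertical flip.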
+ if mode == 0:
+ return img
+ elif mode == 1:
+ return np.flipud(np.rot90(img))
+ elif mode == 2:
+ return np.flipud(img)
+ elif mode == 3:
+ return np.rot90(img, k=3)
+ elif mode == 4:
+ return np.flipud(np.rot90(img, k=2))
+ elif mode == 5:
+ return np.rot90(img)
+ elif mode == 6:
+ return np.rot90(img, k=2)
+ elif mode == 7:
+ return np.flipud(np.rot90(img, k=3))
+
+
+def augment_img_tensor4(img, mode=0):
+ '''Kai Zhang (github: https://github.com/cszn)
+ '''
+ if mode == 0:
+ return img
+ elif mode == 1:
+ return img.rot90(1, [2, 3]).flip([2])
+ elif mode == 2:
+ return img.flip([2])
+ elif mode == 3:
+ return img.rot90(3, [2, 3])
+ elif mode == 4:
+ return img.rot90(2, [2, 3]).flip([2])
+ elif mode == 5:
+ return img.rot90(1, [2, 3])
+ elif mode == 6:
+ return img.rot90(2, [2, 3])
+ elif mode == 7:
+ return img.rot90(3, [2, 3]).flip([2])
+
+
+def augment_img_tensor(img, mode=0):
+ '''Kai Zhang (github: https://github.com/cszn)
+ '''
+ img_size = img.size()
+ img_np = img.data.cpu().numpy()
+ if len(img_size) == 3:
+ img_np = np.transpose(img_np, (1, 2, 0))
+ elif len(img_size) == 4:
+ img_np = np.transpose(img_np, (2, 3, 1, 0))
+ img_np = augment_img(img_np, mode=mode)
+ img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
+ if len(img_size) == 3:
+ img_tensor = img_tensor.permute(2, 0, 1)
+ elif len(img_size) == 4:
+ img_tensor = img_tensor.permute(3, 2, 0, 1)
+
+ return img_tensor.type_as(img)
+
+
+def augment_img_np3(img, mode=0):
+ if mode == 0:
+ return img
+ elif mode == 1:
+ return img.transpose(1, 0, 2)
+ elif mode == 2:
+ return img[::-1, :, :]
+ elif mode == 3:
+ img = img[::-1, :, :]
+ img = img.transpose(1, 0, 2)
+ return img
+ elif mode == 4:
+ return img[:, ::-1, :]
+ elif mode == 5:
+ img = img[:, ::-1, :]
+ img = img.transpose(1, 0, 2)
+ return img
+ elif mode == 6:
+ img = img[:, ::-1, :]
+ img = img[::-1, :, :]
+ return img
+ elif mode == 7:
+ img = img[:, ::-1, :]
+ img = img[::-1, :, :]
+ img = img.transpose(1, 0, 2)
+ return img
+
+
+def augment_imgs(img_list, hflip=True, rot=True):
+ # horizontal flip OR rotate
+ hflip = hflip and random.random() < 0.5
+ vflip = rot and random.random() < 0.5
+ rot90 = rot and random.random() < 0.5
+
+ def _augment(img):
+ if hflip:
+ img = img[:, ::-1, :]
+ if vflip:
+ img = img[::-1, :, :]
+ if rot90:
+ img = img.transpose(1, 0, 2)
+ return img
+
+ return [_augment(img) for img in img_list]
+
+
+'''
+# --------------------------------------------
+# modcrop and shave
+# --------------------------------------------
+'''
+
+
+def modcrop(img_in, scale):
+ # img_in: Numpy, HWC or HW
+ img = np.copy(img_in)
+ if img.ndim == 2:
+ H, W = img.shape
+ H_r, W_r = H % scale, W % scale
+ img = img[:H - H_r, :W - W_r]
+ elif img.ndim == 3:
+ H, W, C = img.shape
+ H_r, W_r = H % scale, W % scale
+ img = img[:H - H_r, :W - W_r, :]
+ else:
+ raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
+ return img
+
+
+def shave(img_in, border=0):
+ # img_in: Numpy, HWC or HW
+ img = np.copy(img_in)
+ h, w = img.shape[:2]
+ img = img[border:h-border, border:w-border]
+ return img
+
+
+'''
+# --------------------------------------------
+# image processing process on numpy image
+# channel_convert(in_c, tar_type, img_list):
+# rgb2ycbcr(img, only_y=True):
+# bgr2ycbcr(img, only_y=True):
+# ycbcr2rgb(img):
+# --------------------------------------------
+'''
+
+
+def rgb2ycbcr(img, only_y=True):
+ '''same as matlab rgb2ycbcr
+ only_y: only return Y channel
+ Input:
+ uint8, [0, 255]
+ float, [0, 1]
+ '''
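+    # Coefficients follow the ITU-R BT.601 (studio-swing) conversion used by MATLAB's rgb2ycbcr.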
+ in_img_type = img.dtype
+    img = img.astype(np.float32)
+ if in_img_type != np.uint8:
+ img *= 255.
+ # convert
+ if only_y:
+ rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
+ else:
+ rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
+ [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
+ if in_img_type == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(in_img_type)
+
+
+def ycbcr2rgb(img):
+ '''same as matlab ycbcr2rgb
+ Input:
+ uint8, [0, 255]
+ float, [0, 1]
+ '''
+ in_img_type = img.dtype
+    img = img.astype(np.float32)
+ if in_img_type != np.uint8:
+ img *= 255.
+ # convert
+ rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
+ [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
+ rlt = np.clip(rlt, 0, 255)
+ if in_img_type == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(in_img_type)
+
+
+def bgr2ycbcr(img, only_y=True):
+ '''bgr version of rgb2ycbcr
+ only_y: only return Y channel
+ Input:
+ uint8, [0, 255]
+ float, [0, 1]
+ '''
+ in_img_type = img.dtype
+    img = img.astype(np.float32)
+ if in_img_type != np.uint8:
+ img *= 255.
+ # convert
+ if only_y:
+ rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
+ else:
+ rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
+ [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
+ if in_img_type == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(in_img_type)
+
+
+def channel_convert(in_c, tar_type, img_list):
+ # conversion among BGR, gray and y
+ if in_c == 3 and tar_type == 'gray': # BGR to gray
+ gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
+ return [np.expand_dims(img, axis=2) for img in gray_list]
+ elif in_c == 3 and tar_type == 'y': # BGR to y
+ y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
+ return [np.expand_dims(img, axis=2) for img in y_list]
+ elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR
+ return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
+ else:
+ return img_list
+
+
+'''
+# --------------------------------------------
+# metric, PSNR, SSIM and PSNRB
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# PSNR
+# --------------------------------------------
+def calculate_psnr(img1, img2, border=0):
+ # img1 and img2 have range [0, 255]
+ #img1 = img1.squeeze()
+ #img2 = img2.squeeze()
+ if not img1.shape == img2.shape:
+ raise ValueError('Input images must have the same dimensions.')
+ h, w = img1.shape[:2]
+ img1 = img1[border:h-border, border:w-border]
+ img2 = img2[border:h-border, border:w-border]
+
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+ mse = np.mean((img1 - img2)**2)
+ if mse == 0:
+ return float('inf')
+ return 20 * math.log10(255.0 / math.sqrt(mse))
+
+
+# --------------------------------------------
+# SSIM
+# --------------------------------------------
+def calculate_ssim(img1, img2, border=0):
+ '''calculate SSIM
+ the same outputs as MATLAB's
+ img1, img2: [0, 255]
+ '''
+ #img1 = img1.squeeze()
+ #img2 = img2.squeeze()
+ if not img1.shape == img2.shape:
+ raise ValueError('Input images must have the same dimensions.')
+ h, w = img1.shape[:2]
+ img1 = img1[border:h-border, border:w-border]
+ img2 = img2[border:h-border, border:w-border]
+
+ if img1.ndim == 2:
+ return ssim(img1, img2)
+ elif img1.ndim == 3:
+ if img1.shape[2] == 3:
+ ssims = []
+ for i in range(3):
+ ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
+ return np.array(ssims).mean()
+ elif img1.shape[2] == 1:
+ return ssim(np.squeeze(img1), np.squeeze(img2))
+ else:
+ raise ValueError('Wrong input image dimensions.')
+
+
+def ssim(img1, img2):
+ C1 = (0.01 * 255)**2
+ C2 = (0.03 * 255)**2
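+    # C1 = (K1*L)^2, C2 = (K2*L)^2 with K1=0.01, K2=0.03 and dynamic range L=255;
+    # they stabilise the luminance/contrast ratios when the denominators are near zero.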
+
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+ kernel = cv2.getGaussianKernel(11, 1.5)
+ window = np.outer(kernel, kernel.transpose())
+
+ mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
+ mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
+ mu1_sq = mu1**2
+ mu2_sq = mu2**2
+ mu1_mu2 = mu1 * mu2
+ sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
+ sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
+ sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
+
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
+ (sigma1_sq + sigma2_sq + C2))
+ return ssim_map.mean()
+
+
+def _blocking_effect_factor(im):
+ block_size = 8
+
+ block_horizontal_positions = torch.arange(7, im.shape[3] - 1, 8)
+ block_vertical_positions = torch.arange(7, im.shape[2] - 1, 8)
+
+ horizontal_block_difference = (
+ (im[:, :, :, block_horizontal_positions] - im[:, :, :, block_horizontal_positions + 1]) ** 2).sum(
+ 3).sum(2).sum(1)
+ vertical_block_difference = (
+ (im[:, :, block_vertical_positions, :] - im[:, :, block_vertical_positions + 1, :]) ** 2).sum(3).sum(
+ 2).sum(1)
+
+ nonblock_horizontal_positions = np.setdiff1d(torch.arange(0, im.shape[3] - 1), block_horizontal_positions)
+ nonblock_vertical_positions = np.setdiff1d(torch.arange(0, im.shape[2] - 1), block_vertical_positions)
+
+ horizontal_nonblock_difference = (
+ (im[:, :, :, nonblock_horizontal_positions] - im[:, :, :, nonblock_horizontal_positions + 1]) ** 2).sum(
+ 3).sum(2).sum(1)
+ vertical_nonblock_difference = (
+ (im[:, :, nonblock_vertical_positions, :] - im[:, :, nonblock_vertical_positions + 1, :]) ** 2).sum(
+ 3).sum(2).sum(1)
+
+ n_boundary_horiz = im.shape[2] * (im.shape[3] // block_size - 1)
+ n_boundary_vert = im.shape[3] * (im.shape[2] // block_size - 1)
+ boundary_difference = (horizontal_block_difference + vertical_block_difference) / (
+ n_boundary_horiz + n_boundary_vert)
+
+ n_nonboundary_horiz = im.shape[2] * (im.shape[3] - 1) - n_boundary_horiz
+ n_nonboundary_vert = im.shape[3] * (im.shape[2] - 1) - n_boundary_vert
+ nonboundary_difference = (horizontal_nonblock_difference + vertical_nonblock_difference) / (
+ n_nonboundary_horiz + n_nonboundary_vert)
+
+ scaler = np.log2(block_size) / np.log2(min([im.shape[2], im.shape[3]]))
+ bef = scaler * (boundary_difference - nonboundary_difference)
+
+ bef[boundary_difference <= nonboundary_difference] = 0
+ return bef
+
+
+def calculate_psnrb(img1, img2, border=0):
+ """Calculate PSNR-B (Peak Signal-to-Noise Ratio).
+ Ref: Quality assessment of deblocked images, for JPEG image deblocking evaluation
+ # https://gitlab.com/Queuecumber/quantization-guided-ac/-/blob/master/metrics/psnrb.py
+ Args:
+ img1 (ndarray): Images with range [0, 255].
+ img2 (ndarray): Images with range [0, 255].
+ border (int): Cropped pixels in each edge of an image. These
+ pixels are not involved in the PSNR calculation.
+ Returns:
+ float: psnr result.
+ """
+
+ if not img1.shape == img2.shape:
+ raise ValueError('Input images must have the same dimensions.')
+
+ if img1.ndim == 2:
+ img1, img2 = np.expand_dims(img1, 2), np.expand_dims(img2, 2)
+
+ h, w = img1.shape[:2]
+ img1 = img1[border:h-border, border:w-border]
+ img2 = img2[border:h-border, border:w-border]
+
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+
+ # follow https://gitlab.com/Queuecumber/quantization-guided-ac/-/blob/master/metrics/psnrb.py
+ img1 = torch.from_numpy(img1).permute(2, 0, 1).unsqueeze(0) / 255.
+ img2 = torch.from_numpy(img2).permute(2, 0, 1).unsqueeze(0) / 255.
+
+ total = 0
+ for c in range(img1.shape[1]):
+ mse = torch.nn.functional.mse_loss(img1[:, c:c + 1, :, :], img2[:, c:c + 1, :, :], reduction='none')
+ bef = _blocking_effect_factor(img1[:, c:c + 1, :, :])
+
+ mse = mse.view(mse.shape[0], -1).mean(1)
+ total += 10 * torch.log10(1 / (mse + bef))
+
+ return float(total) / img1.shape[1]
+
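+# Illustrative sketch (an assumption, not original code): PSNR-B between a JPEG
+# deblocking result and its ground truth, both uint8 arrays in [0, 255]. Paths are
+# hypothetical; left commented out to keep module import side-effect free.
+#   gt = cv2.imread('gt.png')          # BGR, uint8
+#   deblocked = cv2.imread('out.png')
+#   print('PSNR-B: {:.2f} dB'.format(calculate_psnrb(deblocked, gt, border=4)))
+#   # note: the blocking-effect factor is measured on the first argument
+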
+'''
+# --------------------------------------------
+# matlab's bicubic imresize (numpy and torch) [0, 1]
+# --------------------------------------------
+'''
+
+
+# matlab 'imresize' function, currently only supports 'bicubic'
+def cubic(x):
+ absx = torch.abs(x)
+ absx2 = absx**2
+ absx3 = absx**3
+ return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
+ (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))
+
+
+def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
+ if (scale < 1) and (antialiasing):
+ # Use a modified kernel to simultaneously interpolate and antialias - larger kernel width
+ kernel_width = kernel_width / scale
+
+ # Output-space coordinates
+ x = torch.linspace(1, out_length, out_length)
+
+ # Input-space coordinates. Calculate the inverse mapping such that 0.5
+ # in output space maps to 0.5 in input space, and 0.5+scale in output
+ # space maps to 1.5 in input space.
+ u = x / scale + 0.5 * (1 - 1 / scale)
+
+ # What is the left-most pixel that can be involved in the computation?
+ left = torch.floor(u - kernel_width / 2)
+
+ # What is the maximum number of pixels that can be involved in the
+ # computation? Note: it's OK to use an extra pixel here; if the
+ # corresponding weights are all zero, it will be eliminated at the end
+ # of this function.
+ P = math.ceil(kernel_width) + 2
+
+ # The indices of the input pixels involved in computing the k-th output
+ # pixel are in row k of the indices matrix.
+ indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
+ 1, P).expand(out_length, P)
+
+ # The weights used to compute the k-th output pixel are in row k of the
+ # weights matrix.
+ distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
+ # apply cubic kernel
+ if (scale < 1) and (antialiasing):
+ weights = scale * cubic(distance_to_center * scale)
+ else:
+ weights = cubic(distance_to_center)
+ # Normalize the weights matrix so that each row sums to 1.
+ weights_sum = torch.sum(weights, 1).view(out_length, 1)
+ weights = weights / weights_sum.expand(out_length, P)
+
+ # If a column in weights is all zero, get rid of it. only consider the first and last column.
+ weights_zero_tmp = torch.sum((weights == 0), 0)
+ if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 1, P - 2)
+ weights = weights.narrow(1, 1, P - 2)
+ if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 0, P - 2)
+ weights = weights.narrow(1, 0, P - 2)
+ weights = weights.contiguous()
+ indices = indices.contiguous()
+ sym_len_s = -indices.min() + 1
+ sym_len_e = indices.max() - in_length
+ indices = indices + sym_len_s - 1
+ return weights, indices, int(sym_len_s), int(sym_len_e)
+
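+# Illustrative sketch (assumption): inspecting the interpolation table for a 0.5x
+# resize of a length-16 axis. Commented out; the numbers are only indicative.
+#   w, idx, s_start, s_end = calculate_weights_indices(16, 8, 0.5, 'cubic', 4, True)
+#   # w and idx have shape (out_length, P) with P ~= ceil(4 / 0.5) + 2 = 10 (possibly
+#   # trimmed by 2 if the border columns are all zero); s_start/s_end give how many
+#   # symmetric-padding pixels imresize()/imresize_np() must add on each side.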
+
+# --------------------------------------------
+# imresize for tensor image [0, 1]
+# --------------------------------------------
+def imresize(img, scale, antialiasing=True):
+ # Now the scale should be the same for H and W
+ # input: img: pytorch tensor, CHW or HW [0,1]
+ # output: CHW or HW [0,1] w/o round
+ need_squeeze = True if img.dim() == 2 else False
+ if need_squeeze:
+ img.unsqueeze_(0)
+ in_C, in_H, in_W = img.size()
+ out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
+ kernel_width = 4
+ kernel = 'cubic'
+
+ # Return the desired dimension order for performing the resize. The
+ # strategy is to perform the resize first along the dimension with the
+ # smallest scale factor.
+ # Now we do not support this.
+
+ # get weights and indices
+ weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
+ in_H, out_H, scale, kernel, kernel_width, antialiasing)
+ weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
+ in_W, out_W, scale, kernel, kernel_width, antialiasing)
+ # process H dimension
+ # symmetric copying
+ img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
+ img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
+
+ sym_patch = img[:, :sym_len_Hs, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
+
+ sym_patch = img[:, -sym_len_He:, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
+
+ out_1 = torch.FloatTensor(in_C, out_H, in_W)
+ kernel_width = weights_H.size(1)
+ for i in range(out_H):
+ idx = int(indices_H[i][0])
+ for j in range(out_C):
+ out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
+
+ # process W dimension
+ # symmetric copying
+ out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
+ out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
+
+ sym_patch = out_1[:, :, :sym_len_Ws]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
+
+ sym_patch = out_1[:, :, -sym_len_We:]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
+
+ out_2 = torch.FloatTensor(in_C, out_H, out_W)
+ kernel_width = weights_W.size(1)
+ for i in range(out_W):
+ idx = int(indices_W[i][0])
+ for j in range(out_C):
+ out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
+ if need_squeeze:
+ out_2.squeeze_()
+ return out_2
+
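+# Illustrative sketch (assumption): MATLAB-style bicubic downscaling of a CHW
+# tensor in [0, 1]. Commented out so the module stays import-safe.
+#   x = torch.rand(3, 64, 48)          # CHW, float in [0, 1]
+#   y = imresize(x, 1 / 2)             # -> shape (3, 32, 24), values unrounded
+#   z = imresize(x, 1 / 3)             # -> shape (3, ceil(64/3), ceil(48/3)) = (3, 22, 16)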
+
+# --------------------------------------------
+# imresize for numpy image [0, 1]
+# --------------------------------------------
+def imresize_np(img, scale, antialiasing=True):
+ # Now the scale should be the same for H and W
+ # input: img: Numpy, HWC or HW [0,1]
+ # output: HWC or HW [0,1] w/o round
+ img = torch.from_numpy(img)
+ need_squeeze = True if img.dim() == 2 else False
+ if need_squeeze:
+ img.unsqueeze_(2)
+
+ in_H, in_W, in_C = img.size()
+ out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
+ kernel_width = 4
+ kernel = 'cubic'
+
+ # Return the desired dimension order for performing the resize. The
+ # strategy is to perform the resize first along the dimension with the
+ # smallest scale factor.
+ # Now we do not support this.
+
+ # get weights and indices
+ weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
+ in_H, out_H, scale, kernel, kernel_width, antialiasing)
+ weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
+ in_W, out_W, scale, kernel, kernel_width, antialiasing)
+ # process H dimension
+ # symmetric copying
+ img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
+ img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
+
+ sym_patch = img[:sym_len_Hs, :, :]
+ inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(0, inv_idx)
+ img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
+
+ sym_patch = img[-sym_len_He:, :, :]
+ inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(0, inv_idx)
+ img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
+
+ out_1 = torch.FloatTensor(out_H, in_W, in_C)
+ kernel_width = weights_H.size(1)
+ for i in range(out_H):
+ idx = int(indices_H[i][0])
+ for j in range(out_C):
+ out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
+
+ # process W dimension
+ # symmetric copying
+ out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
+ out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
+
+ sym_patch = out_1[:, :sym_len_Ws, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
+
+ sym_patch = out_1[:, -sym_len_We:, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
+
+ out_2 = torch.FloatTensor(out_H, out_W, in_C)
+ kernel_width = weights_W.size(1)
+ for i in range(out_W):
+ idx = int(indices_W[i][0])
+ for j in range(out_C):
+ out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
+ if need_squeeze:
+ out_2.squeeze_()
+
+ return out_2.numpy()
+
+
+if __name__ == '__main__':
+ img = imread_uint('test.bmp', 3)
+# img = uint2single(img)
+# img_bicubic = imresize_np(img, 1/4)
+# imshow(single2uint(img_bicubic))
+#
+# img_tensor = single2tensor4(img)
+# for i in range(8):
+# imshow(np.concatenate((augment_img(img, i), tensor2single(augment_img_tensor4(img_tensor, i))), 1))
+
+# patches = patches_from_image(img, p_size=128, p_overlap=0, p_max=200)
+# imssave(patches,'a.png')
+
+
+
+
+
+
+
diff --git a/core/data/deg_kair_utils/utils_lmdb.py b/core/data/deg_kair_utils/utils_lmdb.py
new file mode 100644
index 0000000000000000000000000000000000000000..75192c346bb9c0b96f8b09635ed548bd6e797d89
--- /dev/null
+++ b/core/data/deg_kair_utils/utils_lmdb.py
@@ -0,0 +1,205 @@
+import cv2
+import lmdb
+import sys
+from multiprocessing import Pool
+from os import path as osp
+from tqdm import tqdm
+
+
+def make_lmdb_from_imgs(data_path,
+ lmdb_path,
+ img_path_list,
+ keys,
+ batch=5000,
+ compress_level=1,
+ multiprocessing_read=False,
+ n_thread=40,
+ map_size=None):
+ """Make lmdb from images.
+
+ Contents of lmdb. The file structure is:
+ example.lmdb
+ ├── data.mdb
+ ├── lock.mdb
+ ├── meta_info.txt
+
+ The data.mdb and lock.mdb are standard lmdb files and you can refer to
+ https://lmdb.readthedocs.io/en/release/ for more details.
+
+ The meta_info.txt is a specified txt file to record the meta information
+ of our datasets. It will be automatically created when preparing
+ datasets by our provided dataset tools.
+ Each line in the txt file records 1)image name (with extension),
+ 2)image shape, and 3)compression level, separated by a white space.
+
+ For example, the meta information could be:
+ `000_00000000.png (720,1280,3) 1`, which means:
+ 1) image name (with extension): 000_00000000.png;
+ 2) image shape: (720,1280,3);
+ 3) compression level: 1
+
+ We use the image name without extension as the lmdb key.
+
+ If `multiprocessing_read` is True, it will read all the images to memory
+ using multiprocessing. Thus, your server needs to have enough memory.
+
+ Args:
+ data_path (str): Data path for reading images.
+ lmdb_path (str): Lmdb save path.
+ img_path_list (str): Image path list.
+ keys (str): Used for lmdb keys.
+ batch (int): After processing batch images, lmdb commits.
+ Default: 5000.
+ compress_level (int): Compress level when encoding images. Default: 1.
+ multiprocessing_read (bool): Whether use multiprocessing to read all
+ the images to memory. Default: False.
+ n_thread (int): For multiprocessing.
+ map_size (int | None): Map size for lmdb env. If None, use the
+ estimated size from images. Default: None
+ """
+
+ assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, '
+ f'but got {len(img_path_list)} and {len(keys)}')
+ print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
+ print(f'Total images: {len(img_path_list)}')
+ if not lmdb_path.endswith('.lmdb'):
+ raise ValueError("lmdb_path must end with '.lmdb'.")
+ if osp.exists(lmdb_path):
+ print(f'Folder {lmdb_path} already exists. Exit.')
+ sys.exit(1)
+
+ if multiprocessing_read:
+ # read all the images to memory (multiprocessing)
+ dataset = {} # use dict to keep the order for multiprocessing
+ shapes = {}
+ print(f'Read images with multiprocessing, #thread: {n_thread} ...')
+ pbar = tqdm(total=len(img_path_list), unit='image')
+
+ def callback(arg):
+ """get the image data and update pbar."""
+ key, dataset[key], shapes[key] = arg
+ pbar.update(1)
+ pbar.set_description(f'Read {key}')
+
+ pool = Pool(n_thread)
+ for path, key in zip(img_path_list, keys):
+ pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback)
+ pool.close()
+ pool.join()
+ pbar.close()
+ print(f'Finish reading {len(img_path_list)} images.')
+
+ # create lmdb environment
+ if map_size is None:
+ # obtain data size for one image
+ img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
+ _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
+ data_size_per_img = img_byte.nbytes
+ print('Data size per image is: ', data_size_per_img)
+ data_size = data_size_per_img * len(img_path_list)
+ map_size = data_size * 10
+
+ env = lmdb.open(lmdb_path, map_size=map_size)
+
+ # write data to lmdb
+ pbar = tqdm(total=len(img_path_list), unit='chunk')
+ txn = env.begin(write=True)
+ txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
+ for idx, (path, key) in enumerate(zip(img_path_list, keys)):
+ pbar.update(1)
+ pbar.set_description(f'Write {key}')
+ key_byte = key.encode('ascii')
+ if multiprocessing_read:
+ img_byte = dataset[key]
+ h, w, c = shapes[key]
+ else:
+ _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level)
+ h, w, c = img_shape
+
+ txn.put(key_byte, img_byte)
+ # write meta information
+ txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
+ if idx % batch == 0:
+ txn.commit()
+ txn = env.begin(write=True)
+ pbar.close()
+ txn.commit()
+ env.close()
+ txt_file.close()
+ print('\nFinish writing lmdb.')
+
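+# Illustrative sketch (assumption, not original code): building an lmdb from a
+# folder of PNGs. Folder name and key scheme are hypothetical; commented out so
+# importing this module never creates files.
+#   import os
+#   data_path = 'trainsets/DIV2K_HR'
+#   img_names = sorted(x for x in os.listdir(data_path) if x.endswith('.png'))
+#   keys = [os.path.splitext(name)[0] for name in img_names]  # key = name w/o ext
+#   make_lmdb_from_imgs(data_path, 'trainsets/DIV2K_HR.lmdb', img_names, keys,
+#                       multiprocessing_read=True, n_thread=8)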
+
+def read_img_worker(path, key, compress_level):
+ """Read image worker.
+
+ Args:
+ path (str): Image path.
+ key (str): Image key.
+ compress_level (int): Compress level when encoding images.
+
+ Returns:
+ str: Image key.
+ byte: Image byte.
+ tuple[int]: Image shape.
+ """
+
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+ # deal with `libpng error: Read Error`
+ if img is None:
+ print(f'To deal with `libpng error: Read Error`, use PIL to load {path}')
+ from PIL import Image
+ import numpy as np
+ img = Image.open(path)
+ img = np.asanyarray(img)
+ img = img[:, :, [2, 1, 0]]
+
+ if img.ndim == 2:
+ h, w = img.shape
+ c = 1
+ else:
+ h, w, c = img.shape
+ _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
+ return (key, img_byte, (h, w, c))
+
+
+class LmdbMaker():
+ """LMDB Maker.
+
+ Args:
+ lmdb_path (str): Lmdb save path.
+ map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
+ batch (int): After processing batch images, lmdb commits.
+ Default: 5000.
+ compress_level (int): Compress level when encoding images. Default: 1.
+ """
+
+ def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1):
+ if not lmdb_path.endswith('.lmdb'):
+ raise ValueError("lmdb_path must end with '.lmdb'.")
+ if osp.exists(lmdb_path):
+ print(f'Folder {lmdb_path} already exists. Exit.')
+ sys.exit(1)
+
+ self.lmdb_path = lmdb_path
+ self.batch = batch
+ self.compress_level = compress_level
+ self.env = lmdb.open(lmdb_path, map_size=map_size)
+ self.txn = self.env.begin(write=True)
+ self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
+ self.counter = 0
+
+ def put(self, img_byte, key, img_shape):
+ self.counter += 1
+ key_byte = key.encode('ascii')
+ self.txn.put(key_byte, img_byte)
+ # write meta information
+ h, w, c = img_shape
+ self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n')
+ if self.counter % self.batch == 0:
+ self.txn.commit()
+ self.txn = self.env.begin(write=True)
+
+ def close(self):
+ self.txn.commit()
+ self.env.close()
+ self.txt_file.close()
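+
+# Illustrative sketch (assumption): incremental variant of make_lmdb_from_imgs()
+# using LmdbMaker, e.g. when image bytes are produced on the fly. Commented out
+# to keep the module import side-effect free; paths/keys are hypothetical.
+#   maker = LmdbMaker('trainsets/patches.lmdb')
+#   key, img_byte, shape = read_img_worker('trainsets/patches/0001.png', '0001', 1)
+#   maker.put(img_byte, key, shape)
+#   maker.close()  # commits the pending transaction and closes meta_info.txt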
diff --git a/core/data/deg_kair_utils/utils_logger.py b/core/data/deg_kair_utils/utils_logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..3067190e1b09b244814e0ccc4496b18f06e22b54
--- /dev/null
+++ b/core/data/deg_kair_utils/utils_logger.py
@@ -0,0 +1,66 @@
+import sys
+import datetime
+import logging
+
+
+'''
+# --------------------------------------------
+# Kai Zhang (github: https://github.com/cszn)
+# 03/Mar/2019
+# --------------------------------------------
+# https://github.com/xinntao/BasicSR
+# --------------------------------------------
+'''
+
+
+def log(*args, **kwargs):
+ print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"), *args, **kwargs)
+
+
+'''
+# --------------------------------------------
+# logger
+# --------------------------------------------
+'''
+
+
+def logger_info(logger_name, log_path='default_logger.log'):
+ ''' set up logger
+ modified by Kai Zhang (github: https://github.com/cszn)
+ '''
+ log = logging.getLogger(logger_name)
+ if log.hasHandlers():
+ print('LogHandlers exist!')
+ else:
+ print('LogHandlers setup!')
+ level = logging.INFO
+ formatter = logging.Formatter('%(asctime)s.%(msecs)03d : %(message)s', datefmt='%y-%m-%d %H:%M:%S')
+ fh = logging.FileHandler(log_path, mode='a')
+ fh.setFormatter(formatter)
+ log.setLevel(level)
+ log.addHandler(fh)
+ # print(len(log.handlers))
+
+ sh = logging.StreamHandler()
+ sh.setFormatter(formatter)
+ log.addHandler(sh)
+
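+# Illustrative sketch (assumption): typical training-script setup. Commented out
+# so importing this module configures nothing; the names are hypothetical.
+#   logger_info('train', log_path='train.log')
+#   logger = logging.getLogger('train')
+#   logger.info('epoch 1 started')   # goes to both train.log and stdout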
+
+'''
+# --------------------------------------------
+# print to file and std_out simultaneously
+# --------------------------------------------
+'''
+
+
+class logger_print(object):
+ def __init__(self, log_path="default.log"):
+ self.terminal = sys.stdout
+ self.log = open(log_path, 'a')
+
+ def write(self, message):
+ self.terminal.write(message)
+ self.log.write(message) # write the message
+
+ def flush(self):
+ pass
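+
+# Illustrative sketch (assumption): mirroring every print() to a file by swapping
+# sys.stdout for logger_print. Commented out; restore sys.stdout when done.
+#   _stdout = sys.stdout
+#   sys.stdout = logger_print('run.log')
+#   print('this line goes to the terminal and to run.log')
+#   sys.stdout = _stdout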
diff --git a/core/data/deg_kair_utils/utils_mat.py b/core/data/deg_kair_utils/utils_mat.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd25d500c0eae77a3b815b8e956205b737ee43d4
--- /dev/null
+++ b/core/data/deg_kair_utils/utils_mat.py
@@ -0,0 +1,88 @@
+import os
+import json
+import scipy.io as spio
+import pandas as pd
+
+
+def loadmat(filename):
+ '''
+ this function should be called instead of direct spio.loadmat
+ as it cures the problem of not properly recovering python dictionaries
+ from mat files. It calls the function check keys to cure all entries
+ which are still mat-objects
+ '''
+ data = spio.loadmat(filename, struct_as_record=False, squeeze_me=True)
+ return dict_to_nonedict(_check_keys(data))
+
+def _check_keys(dict):
+ '''
+ checks if entries in dictionary are mat-objects. If yes
+ todict is called to change them to nested dictionaries
+ '''
+ for key in dict:
+ if isinstance(dict[key], spio.matlab.mio5_params.mat_struct):
+ dict[key] = _todict(dict[key])
+ return dict
+
+def _todict(matobj):
+ '''
+ A recursive function which constructs from matobjects nested dictionaries
+ '''
+ dict = {}
+ for strg in matobj._fieldnames:
+ elem = matobj.__dict__[strg]
+ if isinstance(elem, spio.matlab.mio5_params.mat_struct):
+ dict[strg] = _todict(elem)
+ else:
+ dict[strg] = elem
+ return dict
+
+
+def dict_to_nonedict(opt):
+ if isinstance(opt, dict):
+ new_opt = dict()
+ for key, sub_opt in opt.items():
+ new_opt[key] = dict_to_nonedict(sub_opt)
+ return NoneDict(**new_opt)
+ elif isinstance(opt, list):
+ return [dict_to_nonedict(sub_opt) for sub_opt in opt]
+ else:
+ return opt
+
+
+class NoneDict(dict):
+ def __missing__(self, key):
+ return None
+
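+# Illustrative sketch (assumption): loadmat() returns nested NoneDicts, so missing
+# keys read as None instead of raising KeyError. File and key names are hypothetical.
+#   data = loadmat('results.mat')
+#   print(data['psnr'])        # value stored in the .mat file
+#   print(data['not_a_key'])   # None, no KeyError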
+
+def mat2json(mat_path=None, filepath=None):
+ """
+ Converts a .mat file to JSON and optionally writes it to a new file.
+ Parameters
+ ----------
+ mat_path: str
+ path/filename of the .mat file to convert
+ filepath: str
+ if truthy, the JSON is also written to '<mat basename>.json' in the
+ current working directory; otherwise nothing is saved
+ Returns
+ -------
+ str
+ the converted JSON string
+ Examples
+ --------
+ >>> mat2json(blah blah)
+ """
+
+ matlabFile = loadmat(mat_path)
+ #pop all those dumb fields that don't let you jsonize file
+ matlabFile.pop('__header__')
+ matlabFile.pop('__version__')
+ matlabFile.pop('__globals__')
+ #jsonize the file - orientation is 'index'
+ matlabFile = pd.Series(matlabFile).to_json()
+
+ if filepath:
+ json_path = os.path.splitext(os.path.split(mat_path)[1])[0] + '.json'
+ with open(json_path, 'w') as f:
+ f.write(matlabFile)
+ return matlabFile
\ No newline at end of file
diff --git a/core/data/deg_kair_utils/utils_matconvnet.py b/core/data/deg_kair_utils/utils_matconvnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..506dc47805ae07976022b236ca64c98e9a6f78b3
--- /dev/null
+++ b/core/data/deg_kair_utils/utils_matconvnet.py
@@ -0,0 +1,197 @@
+# -*- coding: utf-8 -*-
+import numpy as np
+import torch
+from collections import OrderedDict
+
+# import scipy.io as io
+import hdf5storage
+
+"""
+# --------------------------------------------
+# Convert matconvnet SimpleNN model into pytorch model
+# --------------------------------------------
+# Kai Zhang (cskaizhang@gmail.com)
+# https://github.com/cszn
+# 28/Nov/2019
+# --------------------------------------------
+"""
+
+
+def weights2tensor(x, squeeze=False, in_features=None, out_features=None):
+ """Modified version of https://github.com/albanie/pytorch-mcn
+ Adjust memory layout and load weights as torch tensor
+ Args:
+ x (ndaray): a numpy array, corresponding to a set of network weights
+ stored in column major order
+ squeeze (bool) [False]: whether to squeeze the tensor (i.e. remove
+ singletons from the trailing dimensions. So after converting to
+ pytorch layout (C_out, C_in, H, W), if the shape is (A, B, 1, 1)
+ it will be reshaped to a matrix with shape (A,B).
+ in_features (int :: None): used to reshape weights for a linear block.
+ out_features (int :: None): used to reshape weights for a linear block.
+ Returns:
+ torch.tensor: a permuted sets of weights, matching the pytorch layout
+ convention
+ """
+ if x.ndim == 4:
+ x = x.transpose((3, 2, 0, 1))
+# for FFDNet, pixel-shuffle layer
+# if x.shape[1]==13:
+# x=x[:,[0,2,1,3, 4,6,5,7, 8,10,9,11, 12],:,:]
+# if x.shape[0]==12:
+# x=x[[0,2,1,3, 4,6,5,7, 8,10,9,11],:,:,:]
+# if x.shape[1]==5:
+# x=x[:,[0,2,1,3, 4],:,:]
+# if x.shape[0]==4:
+# x=x[[0,2,1,3],:,:,:]
+## for SRMD, pixel-shuffle layer
+# if x.shape[0]==12:
+# x=x[[0,2,1,3, 4,6,5,7, 8,10,9,11],:,:,:]
+# if x.shape[0]==27:
+# x=x[[0,3,6,1,4,7,2,5,8, 0+9,3+9,6+9,1+9,4+9,7+9,2+9,5+9,8+9, 0+18,3+18,6+18,1+18,4+18,7+18,2+18,5+18,8+18],:,:,:]
+# if x.shape[0]==48:
+# x=x[[0,4,8,12,1,5,9,13,2,6,10,14,3,7,11,15, 0+16,4+16,8+16,12+16,1+16,5+16,9+16,13+16,2+16,6+16,10+16,14+16,3+16,7+16,11+16,15+16, 0+32,4+32,8+32,12+32,1+32,5+32,9+32,13+32,2+32,6+32,10+32,14+32,3+32,7+32,11+32,15+32],:,:,:]
+
+ elif x.ndim == 3: # add by Kai
+ x = x[:,:,:,None]
+ x = x.transpose((3, 2, 0, 1))
+ elif x.ndim == 2:
+ if x.shape[1] == 1:
+ x = x.flatten()
+ if squeeze:
+ if in_features and out_features:
+ x = x.reshape((out_features, in_features))
+ x = np.squeeze(x)
+ return torch.from_numpy(np.ascontiguousarray(x))
+
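+# Illustrative sketch (assumption): a matconvnet conv filter bank stored as
+# H x W x C_in x C_out becomes a pytorch C_out x C_in x H x W tensor.
+#   w_mcn = np.random.randn(3, 3, 64, 128)      # HWIO layout from matconvnet
+#   w_pt = weights2tensor(w_mcn)                # -> torch.Size([128, 64, 3, 3])
+#   b_mcn = np.random.randn(128, 1)
+#   b_pt = weights2tensor(b_mcn)                # -> flattened torch.Size([128])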
+
+def save_model(network, save_path):
+ state_dict = network.state_dict()
+ for key, param in state_dict.items():
+ state_dict[key] = param.cpu()
+ torch.save(state_dict, save_path)
+
+
+if __name__ == '__main__':
+
+
+# from utils import utils_logger
+# import logging
+# utils_logger.logger_info('a', 'a.log')
+# logger = logging.getLogger('a')
+#
+ # mcn = hdf5storage.loadmat('/model_zoo/matfile/FFDNet_Clip_gray.mat')
+ mcn = hdf5storage.loadmat('models/modelcolor.mat')
+
+
+ #logger.info(mcn['CNNdenoiser'][0][0][0][1][0][0][0][0])
+
+ mat_net = OrderedDict()
+ for idx in range(25):
+ mat_net[str(idx)] = OrderedDict()
+ count = -1
+
+ print(idx)
+ for i in range(13):
+
+ if mcn['CNNdenoiser'][0][idx][0][i][0][0][0][0] == 'conv':
+
+ count += 1
+ w = mcn['CNNdenoiser'][0][idx][0][i][0][1][0][0]
+ # print(w.shape)
+ w = weights2tensor(w)
+ # print(w.shape)
+
+ b = mcn['CNNdenoiser'][0][idx][0][i][0][1][0][1]
+ b = weights2tensor(b)
+ print(b.shape)
+
+ mat_net[str(idx)]['model.{:d}.weight'.format(count*2)] = w
+ mat_net[str(idx)]['model.{:d}.bias'.format(count*2)] = b
+
+ torch.save(mat_net, 'model_zoo/modelcolor.pth')
+
+
+
+# from models.network_dncnn import IRCNN as net
+# network = net(in_nc=3, out_nc=3, nc=64)
+# state_dict = network.state_dict()
+#
+# #show_kv(state_dict)
+#
+# for i in range(len(mcn['net'][0][0][0])):
+# print(mcn['net'][0][0][0][i][0][0][0][0])
+#
+# count = -1
+# mat_net = OrderedDict()
+# for i in range(len(mcn['net'][0][0][0])):
+# if mcn['net'][0][0][0][i][0][0][0][0] == 'conv':
+#
+# count += 1
+# w = mcn['net'][0][0][0][i][0][1][0][0]
+# print(w.shape)
+# w = weights2tensor(w)
+# print(w.shape)
+#
+# b = mcn['net'][0][0][0][i][0][1][0][1]
+# b = weights2tensor(b)
+# print(b.shape)
+#
+# mat_net['model.{:d}.weight'.format(count*2)] = w
+# mat_net['model.{:d}.bias'.format(count*2)] = b
+#
+# torch.save(mat_net, 'E:/pytorch/KAIR_ongoing/model_zoo/ffdnet_gray_clip.pth')
+#
+#
+#
+# crt_net = torch.load('E:/pytorch/KAIR_ongoing/model_zoo/imdn_x4.pth')
+# def show_kv(net):
+# for k, v in net.items():
+# print(k)
+#
+# show_kv(crt_net)
+
+
+# from models.network_dncnn import DnCNN as net
+# network = net(in_nc=2, out_nc=1, nc=64, nb=20, act_mode='R')
+
+# from models.network_srmd import SRMD as net
+# #network = net(in_nc=1, out_nc=1, nc=64, nb=15, act_mode='R')
+# network = net(in_nc=19, out_nc=3, nc=128, nb=12, upscale=4, act_mode='R', upsample_mode='pixelshuffle')
+#
+# from models.network_rrdb import RRDB as net
+# network = net(in_nc=3, out_nc=3, nc=64, nb=23, gc=32, upscale=4, act_mode='L', upsample_mode='upconv')
+#
+# state_dict = network.state_dict()
+# for key, param in state_dict.items():
+# print(key)
+# from models.network_imdn import IMDN as net
+# network = net(in_nc=3, out_nc=3, nc=64, nb=8, upscale=4, act_mode='L', upsample_mode='pixelshuffle')
+# state_dict = network.state_dict()
+# mat_net = OrderedDict()
+# for ((key, param),(key2, param2)) in zip(state_dict.items(), crt_net.items()):
+# mat_net[key] = param2
+# torch.save(mat_net, 'model_zoo/imdn_x4_1.pth')
+#
+
+# net_old = torch.load('net_old.pth')
+# def show_kv(net):
+# for k, v in net.items():
+# print(k)
+#
+# show_kv(net_old)
+# from models.network_dpsr import MSRResNet_prior as net
+# model = net(in_nc=4, out_nc=3, nc=96, nb=16, upscale=4, act_mode='R', upsample_mode='pixelshuffle')
+# state_dict = network.state_dict()
+# net_new = OrderedDict()
+# for ((key, param),(key_old, param_old)) in zip(state_dict.items(), net_old.items()):
+# net_new[key] = param_old
+# torch.save(net_new, 'net_new.pth')
+
+
+ # print(key)
+ # print(param.size())
+
+
+
+ # run utils/utils_matconvnet.py
diff --git a/core/data/deg_kair_utils/utils_model.py b/core/data/deg_kair_utils/utils_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4d9e6ac651784c7ed36e623c3a6175883123c2b
--- /dev/null
+++ b/core/data/deg_kair_utils/utils_model.py
@@ -0,0 +1,330 @@
+# -*- coding: utf-8 -*-
+import numpy as np
+import torch
+from utils import utils_image as util
+import re
+import glob
+import os
+
+
+'''
+# --------------------------------------------
+# Model
+# --------------------------------------------
+# Kai Zhang (github: https://github.com/cszn)
+# 03/Mar/2019
+# --------------------------------------------
+'''
+
+
+def find_last_checkpoint(save_dir, net_type='G', pretrained_path=None):
+ """
+ # ---------------------------------------
+ # Kai Zhang (github: https://github.com/cszn)
+ # 03/Mar/2019
+ # ---------------------------------------
+ Args:
+ save_dir: model folder
+ net_type: 'G' or 'D' or 'optimizerG' or 'optimizerD'
+ pretrained_path: pretrained model path. If save_dir does not have any model, load from pretrained_path
+
+ Return:
+ init_iter: iteration number
+ init_path: model path
+ # ---------------------------------------
+ """
+
+ file_list = glob.glob(os.path.join(save_dir, '*_{}.pth'.format(net_type)))
+ if file_list:
+ iter_exist = []
+ for file_ in file_list:
+ iter_current = re.findall(r"(\d+)_{}.pth".format(net_type), file_)
+ iter_exist.append(int(iter_current[0]))
+ init_iter = max(iter_exist)
+ init_path = os.path.join(save_dir, '{}_{}.pth'.format(init_iter, net_type))
+ else:
+ init_iter = 0
+ init_path = pretrained_path
+ return init_iter, init_path
+
+
+def test_mode(model, L, mode=0, refield=32, min_size=256, sf=1, modulo=1):
+ '''
+ # ---------------------------------------
+ # Kai Zhang (github: https://github.com/cszn)
+ # 03/Mar/2019
+ # ---------------------------------------
+ Args:
+ model: trained model
+ L: input Low-quality image
+ mode:
+ (0) normal: test(model, L)
+ (1) pad: test_pad(model, L, modulo=16)
+ (2) split: test_split(model, L, refield=32, min_size=256, sf=1, modulo=1)
+ (3) x8: test_x8(model, L, modulo=1) ^_^
+ (4) split and x8: test_split_x8(model, L, refield=32, min_size=256, sf=1, modulo=1)
+ refield: effective receptive field of the network, 32 is enough
+ useful when split, i.e., mode=2, 4
+ min_size: images larger than min_size x min_size (e.g., 256x256) are split
+ useful when split, i.e., mode=2, 4
+ sf: scale factor for super-resolution, otherwise 1
+ modulo: 1 if split
+ useful when pad, i.e., mode=1
+
+ Returns:
+ E: estimated image
+ # ---------------------------------------
+ '''
+ if mode == 0:
+ E = test(model, L)
+ elif mode == 1:
+ E = test_pad(model, L, modulo, sf)
+ elif mode == 2:
+ E = test_split(model, L, refield, min_size, sf, modulo)
+ elif mode == 3:
+ E = test_x8(model, L, modulo, sf)
+ elif mode == 4:
+ E = test_split_x8(model, L, refield, min_size, sf, modulo)
+ return E
+
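+# Illustrative sketch (assumption): running a 4x SR model on a large image with
+# splitting (mode=2) to bound memory. `model` and `img_L` are placeholders for a
+# trained network and a 1xCxHxW input tensor in [0, 1].
+#   with torch.no_grad():
+#       img_E = test_mode(model, img_L, mode=2, refield=32, min_size=256, sf=4, modulo=1)
+#   # img_E has spatial size 4*H x 4*W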
+
+'''
+# --------------------------------------------
+# normal (0)
+# --------------------------------------------
+'''
+
+
+def test(model, L):
+ E = model(L)
+ return E
+
+
+'''
+# --------------------------------------------
+# pad (1)
+# --------------------------------------------
+'''
+
+
+def test_pad(model, L, modulo=16, sf=1):
+ h, w = L.size()[-2:]
+ paddingBottom = int(np.ceil(h/modulo)*modulo-h)
+ paddingRight = int(np.ceil(w/modulo)*modulo-w)
+ L = torch.nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(L)
+ E = model(L)
+ E = E[..., :h*sf, :w*sf]
+ return E
+
+
+'''
+# --------------------------------------------
+# split (function)
+# --------------------------------------------
+'''
+
+
+def test_split_fn(model, L, refield=32, min_size=256, sf=1, modulo=1):
+ """
+ Args:
+ model: trained model
+ L: input Low-quality image
+ refield: effective receptive field of the network, 32 is enough
+ min_size: images larger than min_size x min_size (e.g., 256x256) are split
+ sf: scale factor for super-resolution, otherwise 1
+ modulo: 1 if split
+
+ Returns:
+ E: estimated result
+ """
+ h, w = L.size()[-2:]
+ if h*w <= min_size**2:
+ L = torch.nn.ReplicationPad2d((0, int(np.ceil(w/modulo)*modulo-w), 0, int(np.ceil(h/modulo)*modulo-h)))(L)
+ E = model(L)
+ E = E[..., :h*sf, :w*sf]
+ else:
+ top = slice(0, (h//2//refield+1)*refield)
+ bottom = slice(h - (h//2//refield+1)*refield, h)
+ left = slice(0, (w//2//refield+1)*refield)
+ right = slice(w - (w//2//refield+1)*refield, w)
+ Ls = [L[..., top, left], L[..., top, right], L[..., bottom, left], L[..., bottom, right]]
+
+ if h * w <= 4*(min_size**2):
+ Es = [model(Ls[i]) for i in range(4)]
+ else:
+ Es = [test_split_fn(model, Ls[i], refield=refield, min_size=min_size, sf=sf, modulo=modulo) for i in range(4)]
+
+ b, c = Es[0].size()[:2]
+ E = torch.zeros(b, c, sf * h, sf * w).type_as(L)
+
+ E[..., :h//2*sf, :w//2*sf] = Es[0][..., :h//2*sf, :w//2*sf]
+ E[..., :h//2*sf, w//2*sf:w*sf] = Es[1][..., :h//2*sf, (-w + w//2)*sf:]
+ E[..., h//2*sf:h*sf, :w//2*sf] = Es[2][..., (-h + h//2)*sf:, :w//2*sf]
+ E[..., h//2*sf:h*sf, w//2*sf:w*sf] = Es[3][..., (-h + h//2)*sf:, (-w + w//2)*sf:]
+ return E
+
+
+'''
+# --------------------------------------------
+# split (2)
+# --------------------------------------------
+'''
+
+
+def test_split(model, L, refield=32, min_size=256, sf=1, modulo=1):
+ E = test_split_fn(model, L, refield=refield, min_size=min_size, sf=sf, modulo=modulo)
+ return E
+
+
+'''
+# --------------------------------------------
+# x8 (3)
+# --------------------------------------------
+'''
+
+
+def test_x8(model, L, modulo=1, sf=1):
+ E_list = [test_pad(model, util.augment_img_tensor4(L, mode=i), modulo=modulo, sf=sf) for i in range(8)]
+ for i in range(len(E_list)):
+ if i == 3 or i == 5:
+ E_list[i] = util.augment_img_tensor4(E_list[i], mode=8 - i)
+ else:
+ E_list[i] = util.augment_img_tensor4(E_list[i], mode=i)
+ output_cat = torch.stack(E_list, dim=0)
+ E = output_cat.mean(dim=0, keepdim=False)
+ return E
+
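+# Note (assumption, added for clarity): in util.augment_img_tensor4, modes 3 and 5
+# appear to be rotation-based variants that are not their own inverse, hence the
+# mode=8-i de-augmentation above; the remaining modes undo themselves when re-applied.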
+
+'''
+# --------------------------------------------
+# split and x8 (4)
+# --------------------------------------------
+'''
+
+
+def test_split_x8(model, L, refield=32, min_size=256, sf=1, modulo=1):
+ E_list = [test_split_fn(model, util.augment_img_tensor4(L, mode=i), refield=refield, min_size=min_size, sf=sf, modulo=modulo) for i in range(8)]
+ for i in range(len(E_list)):
+ if i == 3 or i == 5:
+ E_list[i] = util.augment_img_tensor4(E_list[i], mode=8 - i)
+ else:
+ E_list[i] = util.augment_img_tensor4(E_list[i], mode=i)
+ output_cat = torch.stack(E_list, dim=0)
+ E = output_cat.mean(dim=0, keepdim=False)
+ return E
+
+
+'''
+# ^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-
+# _^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^
+# ^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-
+'''
+
+
+'''
+# --------------------------------------------
+# print
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# print model
+# --------------------------------------------
+def print_model(model):
+ msg = describe_model(model)
+ print(msg)
+
+
+# --------------------------------------------
+# print params
+# --------------------------------------------
+def print_params(model):
+ msg = describe_params(model)
+ print(msg)
+
+
+'''
+# --------------------------------------------
+# information
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# model information
+# --------------------------------------------
+def info_model(model):
+ msg = describe_model(model)
+ return msg
+
+
+# --------------------------------------------
+# params information
+# --------------------------------------------
+def info_params(model):
+ msg = describe_params(model)
+ return msg
+
+
+'''
+# --------------------------------------------
+# description
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# model name and total number of parameters
+# --------------------------------------------
+def describe_model(model):
+ if isinstance(model, torch.nn.DataParallel):
+ model = model.module
+ msg = '\n'
+ msg += 'models name: {}'.format(model.__class__.__name__) + '\n'
+ msg += 'Params number: {}'.format(sum(map(lambda x: x.numel(), model.parameters()))) + '\n'
+ msg += 'Net structure:\n{}'.format(str(model)) + '\n'
+ return msg
+
+
+# --------------------------------------------
+# parameters description
+# --------------------------------------------
+def describe_params(model):
+ if isinstance(model, torch.nn.DataParallel):
+ model = model.module
+ msg = '\n'
+ msg += ' | {:^6s} | {:^6s} | {:^6s} | {:^6s} | {:^20s} || {:s}'.format('mean', 'min', 'max', 'std', 'shape', 'param_name') + '\n'
+ for name, param in model.state_dict().items():
+ if not 'num_batches_tracked' in name:
+ v = param.data.clone().float()
+ msg += ' | {:>6.3f} | {:>6.3f} | {:>6.3f} | {:>6.3f} | {} || {:s}'.format(v.mean(), v.min(), v.max(), v.std(), v.shape, name) + '\n'
+ return msg
+
+
+if __name__ == '__main__':
+
+ class Net(torch.nn.Module):
+ def __init__(self, in_channels=3, out_channels=3):
+ super(Net, self).__init__()
+ self.conv = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)
+
+ def forward(self, x):
+ x = self.conv(x)
+ return x
+
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+
+ model = Net()
+ model = model.eval()
+ print_model(model)
+ print_params(model)
+ x = torch.randn((2,3,401,401))
+ torch.cuda.empty_cache()
+ with torch.no_grad():
+ for mode in range(5):
+ y = test_mode(model, x, mode, refield=32, min_size=256, sf=1, modulo=1)
+ print(y.shape)
+
+ # run utils/utils_model.py
diff --git a/core/data/deg_kair_utils/utils_modelsummary.py b/core/data/deg_kair_utils/utils_modelsummary.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e040e31d8ddffbb8b7b2e2dc4ddf0b9cdca6a23
--- /dev/null
+++ b/core/data/deg_kair_utils/utils_modelsummary.py
@@ -0,0 +1,485 @@
+import torch.nn as nn
+import torch
+import numpy as np
+
+'''
+---- 1) FLOPs: floating point operations
+---- 2) #Activations: the number of elements of all ‘Conv2d’ outputs
+---- 3) #Conv2d: the number of ‘Conv2d’ layers
+# --------------------------------------------
+# Kai Zhang (github: https://github.com/cszn)
+# 21/July/2020
+# --------------------------------------------
+# Reference
+https://github.com/sovrasov/flops-counter.pytorch.git
+
+# If you use this code, please consider the following citation:
+
+@inproceedings{zhang2020aim, %
+ title={AIM 2020 Challenge on Efficient Super-Resolution: Methods and Results},
+ author={Kai Zhang and Martin Danelljan and Yawei Li and Radu Timofte and others},
+ booktitle={European Conference on Computer Vision Workshops},
+ year={2020}
+}
+# --------------------------------------------
+'''
+
+def get_model_flops(model, input_res, print_per_layer_stat=True,
+ input_constructor=None):
+ assert type(input_res) is tuple, 'Please provide the size of the input image.'
+ assert len(input_res) >= 3, 'Input image should have 3 dimensions.'
+ flops_model = add_flops_counting_methods(model)
+ flops_model.eval().start_flops_count()
+ if input_constructor:
+ input = input_constructor(input_res)
+ _ = flops_model(**input)
+ else:
+ device = list(flops_model.parameters())[-1].device
+ batch = torch.FloatTensor(1, *input_res).to(device)
+ _ = flops_model(batch)
+
+ if print_per_layer_stat:
+ print_model_with_flops(flops_model)
+ flops_count = flops_model.compute_average_flops_cost()
+ flops_model.stop_flops_count()
+
+ return flops_count
+
+def get_model_activation(model, input_res, input_constructor=None):
+ assert type(input_res) is tuple, 'Please provide the size of the input image.'
+ assert len(input_res) >= 3, 'Input image should have 3 dimensions.'
+ activation_model = add_activation_counting_methods(model)
+ activation_model.eval().start_activation_count()
+ if input_constructor:
+ input = input_constructor(input_res)
+ _ = activation_model(**input)
+ else:
+ device = list(activation_model.parameters())[-1].device
+ batch = torch.FloatTensor(1, *input_res).to(device)
+ _ = activation_model(batch)
+
+ activation_count, num_conv = activation_model.compute_average_activation_cost()
+ activation_model.stop_activation_count()
+
+ return activation_count, num_conv
+
+
+def get_model_complexity_info(model, input_res, print_per_layer_stat=True, as_strings=True,
+ input_constructor=None):
+ assert type(input_res) is tuple
+ assert len(input_res) >= 3
+ flops_model = add_flops_counting_methods(model)
+ flops_model.eval().start_flops_count()
+ if input_constructor:
+ input = input_constructor(input_res)
+ _ = flops_model(**input)
+ else:
+ batch = torch.FloatTensor(1, *input_res)
+ _ = flops_model(batch)
+
+ if print_per_layer_stat:
+ print_model_with_flops(flops_model)
+ flops_count = flops_model.compute_average_flops_cost()
+ params_count = get_model_parameters_number(flops_model)
+ flops_model.stop_flops_count()
+
+ if as_strings:
+ return flops_to_string(flops_count), params_to_string(params_count)
+
+ return flops_count, params_count
+
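+# Illustrative sketch (assumption): FLOPs/params/activations for a torchvision
+# ResNet-18 at 224x224 input. Commented out; any nn.Module works the same way.
+#   import torchvision
+#   net = torchvision.models.resnet18()
+#   flops, params = get_model_complexity_info(net, (3, 224, 224),
+#                                             print_per_layer_stat=False)
+#   acts, n_conv = get_model_activation(net, (3, 224, 224))
+#   print(flops, params, acts, n_conv)  # FLOPs as a 'GMac' string, params as an 'M' string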
+
+def flops_to_string(flops, units='GMac', precision=2):
+ if units is None:
+ if flops // 10**9 > 0:
+ return str(round(flops / 10.**9, precision)) + ' GMac'
+ elif flops // 10**6 > 0:
+ return str(round(flops / 10.**6, precision)) + ' MMac'
+ elif flops // 10**3 > 0:
+ return str(round(flops / 10.**3, precision)) + ' KMac'
+ else:
+ return str(flops) + ' Mac'
+ else:
+ if units == 'GMac':
+ return str(round(flops / 10.**9, precision)) + ' ' + units
+ elif units == 'MMac':
+ return str(round(flops / 10.**6, precision)) + ' ' + units
+ elif units == 'KMac':
+ return str(round(flops / 10.**3, precision)) + ' ' + units
+ else:
+ return str(flops) + ' Mac'
+
+
+def params_to_string(params_num):
+ if params_num // 10 ** 6 > 0:
+ return str(round(params_num / 10 ** 6, 2)) + ' M'
+ elif params_num // 10 ** 3:
+ return str(round(params_num / 10 ** 3, 2)) + ' k'
+ else:
+ return str(params_num)
+
+
+def print_model_with_flops(model, units='GMac', precision=3):
+ total_flops = model.compute_average_flops_cost()
+
+ def accumulate_flops(self):
+ if is_supported_instance(self):
+ return self.__flops__ / model.__batch_counter__
+ else:
+ sum = 0
+ for m in self.children():
+ sum += m.accumulate_flops()
+ return sum
+
+ def flops_repr(self):
+ accumulated_flops_cost = self.accumulate_flops()
+ return ', '.join([flops_to_string(accumulated_flops_cost, units=units, precision=precision),
+ '{:.3%} MACs'.format(accumulated_flops_cost / total_flops),
+ self.original_extra_repr()])
+
+ def add_extra_repr(m):
+ m.accumulate_flops = accumulate_flops.__get__(m)
+ flops_extra_repr = flops_repr.__get__(m)
+ if m.extra_repr != flops_extra_repr:
+ m.original_extra_repr = m.extra_repr
+ m.extra_repr = flops_extra_repr
+ assert m.extra_repr != m.original_extra_repr
+
+ def del_extra_repr(m):
+ if hasattr(m, 'original_extra_repr'):
+ m.extra_repr = m.original_extra_repr
+ del m.original_extra_repr
+ if hasattr(m, 'accumulate_flops'):
+ del m.accumulate_flops
+
+ model.apply(add_extra_repr)
+ print(model)
+ model.apply(del_extra_repr)
+
+
+def get_model_parameters_number(model):
+ params_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ return params_num
+
+
+def add_flops_counting_methods(net_main_module):
+ # adding additional methods to the existing module object,
+ # this is done this way so that each function has access to self object
+ # embed()
+ net_main_module.start_flops_count = start_flops_count.__get__(net_main_module)
+ net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module)
+ net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module)
+ net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(net_main_module)
+
+ net_main_module.reset_flops_count()
+ return net_main_module
+
+
+def compute_average_flops_cost(self):
+ """
+ A method that will be available after add_flops_counting_methods() is called
+ on a desired net object.
+
+ Returns current mean flops consumption per image.
+
+ """
+
+ flops_sum = 0
+ for module in self.modules():
+ if is_supported_instance(module):
+ flops_sum += module.__flops__
+
+ return flops_sum
+
+
+def start_flops_count(self):
+ """
+ A method that will be available after add_flops_counting_methods() is called
+ on a desired net object.
+
+ Activates the computation of mean flops consumption per image.
+ Call it before you run the network.
+
+ """
+ self.apply(add_flops_counter_hook_function)
+
+
+def stop_flops_count(self):
+ """
+ A method that will be available after add_flops_counting_methods() is called
+ on a desired net object.
+
+ Stops computing the mean flops consumption per image.
+ Call whenever you want to pause the computation.
+
+ """
+ self.apply(remove_flops_counter_hook_function)
+
+
+def reset_flops_count(self):
+ """
+ A method that will be available after add_flops_counting_methods() is called
+ on a desired net object.
+
+ Resets statistics computed so far.
+
+ """
+ self.apply(add_flops_counter_variable_or_reset)
+
+
+def add_flops_counter_hook_function(module):
+ if is_supported_instance(module):
+ if hasattr(module, '__flops_handle__'):
+ return
+
+ if isinstance(module, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d)):
+ handle = module.register_forward_hook(conv_flops_counter_hook)
+ elif isinstance(module, (nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6)):
+ handle = module.register_forward_hook(relu_flops_counter_hook)
+ elif isinstance(module, nn.Linear):
+ handle = module.register_forward_hook(linear_flops_counter_hook)
+ elif isinstance(module, (nn.BatchNorm2d)):
+ handle = module.register_forward_hook(bn_flops_counter_hook)
+ else:
+ handle = module.register_forward_hook(empty_flops_counter_hook)
+ module.__flops_handle__ = handle
+
+
+def remove_flops_counter_hook_function(module):
+ if is_supported_instance(module):
+ if hasattr(module, '__flops_handle__'):
+ module.__flops_handle__.remove()
+ del module.__flops_handle__
+
+
+def add_flops_counter_variable_or_reset(module):
+ if is_supported_instance(module):
+ module.__flops__ = 0
+
+
+# ---- Internal functions
+def is_supported_instance(module):
+ if isinstance(module,
+ (
+ nn.Conv2d, nn.ConvTranspose2d,
+ nn.BatchNorm2d,
+ nn.Linear,
+ nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6,
+ )):
+ return True
+
+ return False
+
+
+def conv_flops_counter_hook(conv_module, input, output):
+ # Can have multiple inputs, getting the first one
+ # input = input[0]
+
+ batch_size = output.shape[0]
+ output_dims = list(output.shape[2:])
+
+ kernel_dims = list(conv_module.kernel_size)
+ in_channels = conv_module.in_channels
+ out_channels = conv_module.out_channels
+ groups = conv_module.groups
+
+ filters_per_channel = out_channels // groups
+ conv_per_position_flops = np.prod(kernel_dims) * in_channels * filters_per_channel
+
+ active_elements_count = batch_size * np.prod(output_dims)
+ overall_conv_flops = int(conv_per_position_flops) * int(active_elements_count)
+
+ # overall_flops = overall_conv_flops
+
+ conv_module.__flops__ += int(overall_conv_flops)
+ # conv_module.__output_dims__ = output_dims
+
+
+def relu_flops_counter_hook(module, input, output):
+ active_elements_count = output.numel()
+ module.__flops__ += int(active_elements_count)
+ # print(module.__flops__, id(module))
+ # print(module)
+
+
+def linear_flops_counter_hook(module, input, output):
+ input = input[0]
+ if len(input.shape) == 1:
+ batch_size = 1
+ module.__flops__ += int(batch_size * input.shape[0] * output.shape[0])
+ else:
+ batch_size = input.shape[0]
+ module.__flops__ += int(batch_size * input.shape[1] * output.shape[1])
+
+
+def bn_flops_counter_hook(module, input, output):
+ # input = input[0]
+ # TODO: need to check here
+ # batch_flops = np.prod(input.shape)
+ # if module.affine:
+ # batch_flops *= 2
+ # module.__flops__ += int(batch_flops)
+ batch = output.shape[0]
+ output_dims = output.shape[2:]
+ channels = module.num_features
+ batch_flops = batch * channels * np.prod(output_dims)
+ if module.affine:
+ batch_flops *= 2
+ module.__flops__ += int(batch_flops)
+
+
+# ---- Count the number of convolutional layers and the activation
+def add_activation_counting_methods(net_main_module):
+ # adding additional methods to the existing module object,
+ # this is done this way so that each function has access to self object
+ # embed()
+ net_main_module.start_activation_count = start_activation_count.__get__(net_main_module)
+ net_main_module.stop_activation_count = stop_activation_count.__get__(net_main_module)
+ net_main_module.reset_activation_count = reset_activation_count.__get__(net_main_module)
+ net_main_module.compute_average_activation_cost = compute_average_activation_cost.__get__(net_main_module)
+
+ net_main_module.reset_activation_count()
+ return net_main_module
+
+
+def compute_average_activation_cost(self):
+ """
+ A method that will be available after add_activation_counting_methods() is called
+ on a desired net object.
+
+ Returns current mean activation consumption per image.
+
+ """
+
+ activation_sum = 0
+ num_conv = 0
+ for module in self.modules():
+ if is_supported_instance_for_activation(module):
+ activation_sum += module.__activation__
+ num_conv += module.__num_conv__
+ return activation_sum, num_conv
+
+
+def start_activation_count(self):
+ """
+ A method that will be available after add_activation_counting_methods() is called
+ on a desired net object.
+
+ Activates the computation of mean activation consumption per image.
+ Call it before you run the network.
+
+ """
+ self.apply(add_activation_counter_hook_function)
+
+
+def stop_activation_count(self):
+ """
+ A method that will be available after add_activation_counting_methods() is called
+ on a desired net object.
+
+ Stops computing the mean activation consumption per image.
+ Call whenever you want to pause the computation.
+
+ """
+ self.apply(remove_activation_counter_hook_function)
+
+
+def reset_activation_count(self):
+ """
+ A method that will be available after add_activation_counting_methods() is called
+ on a desired net object.
+
+ Resets statistics computed so far.
+
+ """
+ self.apply(add_activation_counter_variable_or_reset)
+
+
+def add_activation_counter_hook_function(module):
+ if is_supported_instance_for_activation(module):
+ if hasattr(module, '__activation_handle__'):
+ return
+
+ if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
+ handle = module.register_forward_hook(conv_activation_counter_hook)
+ module.__activation_handle__ = handle
+
+
+def remove_activation_counter_hook_function(module):
+ if is_supported_instance_for_activation(module):
+ if hasattr(module, '__activation_handle__'):
+ module.__activation_handle__.remove()
+ del module.__activation_handle__
+
+
+def add_activation_counter_variable_or_reset(module):
+ if is_supported_instance_for_activation(module):
+ module.__activation__ = 0
+ module.__num_conv__ = 0
+
+
+def is_supported_instance_for_activation(module):
+ if isinstance(module,
+ (
+ nn.Conv2d, nn.ConvTranspose2d,
+ )):
+ return True
+
+ return False
+
+def conv_activation_counter_hook(module, input, output):
+ """
+ Calculate the activations in the convolutional operation.
+ Reference: Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár, Designing Network Design Spaces.
+ :param module:
+ :param input:
+ :param output:
+ :return:
+ """
+ module.__activation__ += output.numel()
+ module.__num_conv__ += 1
+
+
+def empty_flops_counter_hook(module, input, output):
+ module.__flops__ += 0
+
+
+def upsample_flops_counter_hook(module, input, output):
+ output_size = output[0]
+ batch_size = output_size.shape[0]
+ output_elements_count = batch_size
+ for val in output_size.shape[1:]:
+ output_elements_count *= val
+ module.__flops__ += int(output_elements_count)
+
+
+def pool_flops_counter_hook(module, input, output):
+ input = input[0]
+ module.__flops__ += int(np.prod(input.shape))
+
+
+def dconv_flops_counter_hook(dconv_module, input, output):
+ input = input[0]
+
+ batch_size = input.shape[0]
+ output_dims = list(output.shape[2:])
+
+ m_channels, in_channels, kernel_dim1, _, = dconv_module.weight.shape
+ out_channels, _, kernel_dim2, _, = dconv_module.projection.shape
+ # groups = dconv_module.groups
+
+ # filters_per_channel = out_channels // groups
+ conv_per_position_flops1 = kernel_dim1 ** 2 * in_channels * m_channels
+ conv_per_position_flops2 = kernel_dim2 ** 2 * out_channels * m_channels
+ active_elements_count = batch_size * np.prod(output_dims)
+
+ overall_conv_flops = (conv_per_position_flops1 + conv_per_position_flops2) * active_elements_count
+ overall_flops = overall_conv_flops
+
+ dconv_module.__flops__ += int(overall_flops)
+ # dconv_module.__output_dims__ = output_dims
+
+
+
+
+
diff --git a/core/data/deg_kair_utils/utils_option.py b/core/data/deg_kair_utils/utils_option.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf096210e2d8ea553b06a91ac5cdaa21127d837c
--- /dev/null
+++ b/core/data/deg_kair_utils/utils_option.py
@@ -0,0 +1,255 @@
+import os
+from collections import OrderedDict
+from datetime import datetime
+import json
+import re
+import glob
+
+
+'''
+# --------------------------------------------
+# Kai Zhang (github: https://github.com/cszn)
+# 03/Mar/2019
+# --------------------------------------------
+# https://github.com/xinntao/BasicSR
+# --------------------------------------------
+'''
+
+
+def get_timestamp():
+ return datetime.now().strftime('_%y%m%d_%H%M%S')
+
+
+def parse(opt_path, is_train=True):
+
+ # ----------------------------------------
+ # remove comments starting with '//'
+ # ----------------------------------------
+ json_str = ''
+ with open(opt_path, 'r') as f:
+ for line in f:
+ line = line.split('//')[0] + '\n'
+ json_str += line
+
+ # ----------------------------------------
+ # initialize opt
+ # ----------------------------------------
+ opt = json.loads(json_str, object_pairs_hook=OrderedDict)
+
+ opt['opt_path'] = opt_path
+ opt['is_train'] = is_train
+
+ # ----------------------------------------
+ # set default
+ # ----------------------------------------
+ if 'merge_bn' not in opt:
+ opt['merge_bn'] = False
+ opt['merge_bn_startpoint'] = -1
+
+ if 'scale' not in opt:
+ opt['scale'] = 1
+
+ # ----------------------------------------
+ # datasets
+ # ----------------------------------------
+ for phase, dataset in opt['datasets'].items():
+ phase = phase.split('_')[0]
+ dataset['phase'] = phase
+ dataset['scale'] = opt['scale'] # broadcast
+ dataset['n_channels'] = opt['n_channels'] # broadcast
+ if 'dataroot_H' in dataset and dataset['dataroot_H'] is not None:
+ dataset['dataroot_H'] = os.path.expanduser(dataset['dataroot_H'])
+ if 'dataroot_L' in dataset and dataset['dataroot_L'] is not None:
+ dataset['dataroot_L'] = os.path.expanduser(dataset['dataroot_L'])
+
+ # ----------------------------------------
+ # path
+ # ----------------------------------------
+ for key, path in opt['path'].items():
+ if path and key in opt['path']:
+ opt['path'][key] = os.path.expanduser(path)
+
+ path_task = os.path.join(opt['path']['root'], opt['task'])
+ opt['path']['task'] = path_task
+ opt['path']['log'] = path_task
+ opt['path']['options'] = os.path.join(path_task, 'options')
+
+ if is_train:
+ opt['path']['models'] = os.path.join(path_task, 'models')
+ opt['path']['images'] = os.path.join(path_task, 'images')
+ else: # test
+ opt['path']['images'] = os.path.join(path_task, 'test_images')
+
+ # ----------------------------------------
+ # network
+ # ----------------------------------------
+ opt['netG']['scale'] = opt['scale'] if 'scale' in opt else 1
+
+ # ----------------------------------------
+ # GPU devices
+ # ----------------------------------------
+ gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
+ os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
+ print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
+
+ # ----------------------------------------
+ # default setting for distributeddataparallel
+ # ----------------------------------------
+ if 'find_unused_parameters' not in opt:
+ opt['find_unused_parameters'] = True
+ if 'use_static_graph' not in opt:
+ opt['use_static_graph'] = False
+ if 'dist' not in opt:
+ opt['dist'] = False
+ opt['num_gpu'] = len(opt['gpu_ids'])
+ print('number of GPUs is: ' + str(opt['num_gpu']))
+
+ # ----------------------------------------
+ # default setting for perceptual loss
+ # ----------------------------------------
+ if 'F_feature_layer' not in opt['train']:
+ opt['train']['F_feature_layer'] = 34 # 25; [2,7,16,25,34]
+ if 'F_weights' not in opt['train']:
+ opt['train']['F_weights'] = 1.0 # 1.0; [0.1,0.1,1.0,1.0,1.0]
+ if 'F_lossfn_type' not in opt['train']:
+ opt['train']['F_lossfn_type'] = 'l1'
+ if 'F_use_input_norm' not in opt['train']:
+ opt['train']['F_use_input_norm'] = True
+ if 'F_use_range_norm' not in opt['train']:
+ opt['train']['F_use_range_norm'] = False
+
+ # ----------------------------------------
+ # default setting for optimizer
+ # ----------------------------------------
+ if 'G_optimizer_type' not in opt['train']:
+ opt['train']['G_optimizer_type'] = "adam"
+ if 'G_optimizer_betas' not in opt['train']:
+ opt['train']['G_optimizer_betas'] = [0.9,0.999]
+ if 'G_scheduler_restart_weights' not in opt['train']:
+ opt['train']['G_scheduler_restart_weights'] = 1
+ if 'G_optimizer_wd' not in opt['train']:
+ opt['train']['G_optimizer_wd'] = 0
+ if 'G_optimizer_reuse' not in opt['train']:
+ opt['train']['G_optimizer_reuse'] = False
+ if 'netD' in opt and 'D_optimizer_reuse' not in opt['train']:
+ opt['train']['D_optimizer_reuse'] = False
+
+ # ----------------------------------------
+ # default setting of strict for model loading
+ # ----------------------------------------
+ if 'G_param_strict' not in opt['train']:
+ opt['train']['G_param_strict'] = True
+ if 'netD' in opt and 'D_param_strict' not in opt['train']:
+ opt['train']['D_param_strict'] = True
+ if 'E_param_strict' not in opt['train']:
+ opt['train']['E_param_strict'] = True
+
+ # ----------------------------------------
+ # Exponential Moving Average
+ # ----------------------------------------
+ if 'E_decay' not in opt['train']:
+ opt['train']['E_decay'] = 0
+
+ # ----------------------------------------
+ # default setting for discriminator
+ # ----------------------------------------
+ if 'netD' in opt:
+ if 'net_type' not in opt['netD']:
+ opt['netD']['net_type'] = 'discriminator_patchgan' # discriminator_unet
+ if 'in_nc' not in opt['netD']:
+ opt['netD']['in_nc'] = 3
+ if 'base_nc' not in opt['netD']:
+ opt['netD']['base_nc'] = 64
+ if 'n_layers' not in opt['netD']:
+ opt['netD']['n_layers'] = 3
+ if 'norm_type' not in opt['netD']:
+ opt['netD']['norm_type'] = 'spectral'
+
+
+ return opt
+
+
+def find_last_checkpoint(save_dir, net_type='G', pretrained_path=None):
+ """
+ Args:
+ save_dir: model folder
+ net_type: 'G' or 'D' or 'optimizerG' or 'optimizerD'
+ pretrained_path: pretrained model path. If save_dir does not have any model, load from pretrained_path
+
+ Return:
+ init_iter: iteration number
+ init_path: model path
+ """
+ file_list = glob.glob(os.path.join(save_dir, '*_{}.pth'.format(net_type)))
+ if file_list:
+ iter_exist = []
+ for file_ in file_list:
+ iter_current = re.findall(r"(\d+)_{}.pth".format(net_type), file_)
+ iter_exist.append(int(iter_current[0]))
+ init_iter = max(iter_exist)
+ init_path = os.path.join(save_dir, '{}_{}.pth'.format(init_iter, net_type))
+ else:
+ init_iter = 0
+ init_path = pretrained_path
+ return init_iter, init_path
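+
+# Usage sketch (illustrative): given a models folder containing checkpoints saved as
+# '<iter>_G.pth' (e.g. '8000_G.pth', '16000_G.pth'),
+#   init_iter, init_path = find_last_checkpoint('experiments/task/models', net_type='G')
+# returns (16000, 'experiments/task/models/16000_G.pth'); if the folder holds no
+# checkpoint, it falls back to (0, pretrained_path).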
+
+
+'''
+# --------------------------------------------
+# convert the opt into json file
+# --------------------------------------------
+'''
+
+
+def save(opt):
+ opt_path = opt['opt_path']
+ opt_path_copy = opt['path']['options']
+ dirname, filename_ext = os.path.split(opt_path)
+ filename, ext = os.path.splitext(filename_ext)
+ dump_path = os.path.join(opt_path_copy, filename+get_timestamp()+ext)
+ with open(dump_path, 'w') as dump_file:
+ json.dump(opt, dump_file, indent=2)
+
+
+'''
+# --------------------------------------------
+# dict to string for logger
+# --------------------------------------------
+'''
+
+
+def dict2str(opt, indent_l=1):
+ msg = ''
+ for k, v in opt.items():
+ if isinstance(v, dict):
+ msg += ' ' * (indent_l * 2) + k + ':[\n'
+ msg += dict2str(v, indent_l + 1)
+ msg += ' ' * (indent_l * 2) + ']\n'
+ else:
+ msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n'
+ return msg
+
+
+'''
+# --------------------------------------------
+# convert OrderedDict to NoneDict,
+# return None for missing key
+# --------------------------------------------
+'''
+
+
+def dict_to_nonedict(opt):
+ if isinstance(opt, dict):
+ new_opt = dict()
+ for key, sub_opt in opt.items():
+ new_opt[key] = dict_to_nonedict(sub_opt)
+ return NoneDict(**new_opt)
+ elif isinstance(opt, list):
+ return [dict_to_nonedict(sub_opt) for sub_opt in opt]
+ else:
+ return opt
+
+
+class NoneDict(dict):
+ def __missing__(self, key):
+ return None
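+
+
+# Behaviour sketch (illustrative): dict_to_nonedict recursively wraps plain dicts so that
+# missing keys read as None instead of raising KeyError, e.g.
+#   opt = dict_to_nonedict({'train': {'lr': 1e-4}})
+#   opt['train']['lr']       -> 0.0001
+#   opt['train']['missing']  -> None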
diff --git a/core/data/deg_kair_utils/utils_params.py b/core/data/deg_kair_utils/utils_params.py
new file mode 100644
index 0000000000000000000000000000000000000000..def1cb79e11472b9b8ebbaae4bd83e7216af2ccb
--- /dev/null
+++ b/core/data/deg_kair_utils/utils_params.py
@@ -0,0 +1,135 @@
+import torch
+
+import torchvision
+
+from models import basicblock as B
+
+def show_kv(net):
+ for k, v in net.items():
+ print(k)
+
+# should run train debug mode first to get an initial model
+#crt_net = torch.load('../../experiments/debug_SRResNet_bicx4_in3nf64nb16/models/8_G.pth')
+#
+#for k, v in crt_net.items():
+# print(k)
+#for k, v in crt_net.items():
+# if k in pretrained_net:
+# crt_net[k] = pretrained_net[k]
+# print('replace ... ', k)
+
+# x2 -> x4
+#crt_net['model.5.weight'] = pretrained_net['model.2.weight']
+#crt_net['model.5.bias'] = pretrained_net['model.2.bias']
+#crt_net['model.8.weight'] = pretrained_net['model.5.weight']
+#crt_net['model.8.bias'] = pretrained_net['model.5.bias']
+#crt_net['model.10.weight'] = pretrained_net['model.7.weight']
+#crt_net['model.10.bias'] = pretrained_net['model.7.bias']
+#torch.save(crt_net, '../pretrained_tmp.pth')
+
+# x2 -> x3
+'''
+in_filter = pretrained_net['model.2.weight'] # 256, 64, 3, 3
+new_filter = torch.Tensor(576, 64, 3, 3)
+new_filter[0:256, :, :, :] = in_filter
+new_filter[256:512, :, :, :] = in_filter
+new_filter[512:, :, :, :] = in_filter[0:576-512, :, :, :]
+crt_net['model.2.weight'] = new_filter
+
+in_bias = pretrained_net['model.2.bias'] # 256, 64, 3, 3
+new_bias = torch.Tensor(576)
+new_bias[0:256] = in_bias
+new_bias[256:512] = in_bias
+new_bias[512:] = in_bias[0:576 - 512]
+crt_net['model.2.bias'] = new_bias
+
+torch.save(crt_net, '../pretrained_tmp.pth')
+'''
+
+# x2 -> x8
+'''
+crt_net['model.5.weight'] = pretrained_net['model.2.weight']
+crt_net['model.5.bias'] = pretrained_net['model.2.bias']
+crt_net['model.8.weight'] = pretrained_net['model.2.weight']
+crt_net['model.8.bias'] = pretrained_net['model.2.bias']
+crt_net['model.11.weight'] = pretrained_net['model.5.weight']
+crt_net['model.11.bias'] = pretrained_net['model.5.bias']
+crt_net['model.13.weight'] = pretrained_net['model.7.weight']
+crt_net['model.13.bias'] = pretrained_net['model.7.bias']
+torch.save(crt_net, '../pretrained_tmp.pth')
+'''
+
+# x3/4/8 RGB -> Y
+
+def rgb2gray_net(net, only_input=True):
+
+ if only_input:
+ in_filter = net['0.weight']
+ in_new_filter = in_filter[:,0,:,:]*0.2989 + in_filter[:,1,:,:]*0.587 + in_filter[:,2,:,:]*0.114
+ in_new_filter.unsqueeze_(1)
+ net['0.weight'] = in_new_filter
+
+# out_filter = pretrained_net['model.13.weight']
+# out_new_filter = out_filter[0, :, :, :] * 0.2989 + out_filter[1, :, :, :] * 0.587 + \
+# out_filter[2, :, :, :] * 0.114
+# out_new_filter.unsqueeze_(0)
+# crt_net['model.13.weight'] = out_new_filter
+# out_bias = pretrained_net['model.13.bias']
+# out_new_bias = out_bias[0] * 0.2989 + out_bias[1] * 0.587 + out_bias[2] * 0.114
+# out_new_bias = torch.Tensor(1).fill_(out_new_bias)
+# crt_net['model.13.bias'] = out_new_bias
+
+# torch.save(crt_net, '../pretrained_tmp.pth')
+
+ return net
+
+
+
+if __name__ == '__main__':
+
+ net = torchvision.models.vgg19(pretrained=True)
+ for k,v in net.features.named_parameters():
+ if k=='0.weight':
+ in_new_filter = v[:,0,:,:]*0.2989 + v[:,1,:,:]*0.587 + v[:,2,:,:]*0.114
+ in_new_filter.unsqueeze_(1)
+ v = in_new_filter
+ print(v.shape)
+ print(v[0,0,0,0])
+ if k=='0.bias':
+ in_new_bias = v
+ print(v[0])
+
+ print(net.features[0])
+
+ net.features[0] = B.conv(1, 64, mode='C')
+
+ print(net.features[0])
+ net.features[0].weight.data=in_new_filter
+ net.features[0].bias.data=in_new_bias
+
+ for k,v in net.features.named_parameters():
+ if k=='0.weight':
+ print(v[0,0,0,0])
+ if k=='0.bias':
+ print(v[0])
+
+ # transfer parameters of an old model to a new one
+ # (illustrative template: `model_path` and `model` are assumed to be defined by the caller)
+ model_old = torch.load(model_path)
+ state_dict = model.state_dict()
+ for ((key, param),(key2, param2)) in zip(model_old.items(), state_dict.items()):
+ state_dict[key2] = param
+ print([key, key2])
+ # print([param.size(), param2.size()])
+ torch.save(state_dict, 'model_new.pth')
+
+
+ # rgb2gray_net(net)
+
+
+
+
+
+
+
+
+
diff --git a/core/data/deg_kair_utils/utils_receptivefield.py b/core/data/deg_kair_utils/utils_receptivefield.py
new file mode 100644
index 0000000000000000000000000000000000000000..82ad613b9e744189e13b721a558dbc0f42c57b30
--- /dev/null
+++ b/core/data/deg_kair_utils/utils_receptivefield.py
@@ -0,0 +1,62 @@
+# -*- coding: utf-8 -*-
+
+# online calculation: https://fomoro.com/research/article/receptive-field-calculator#
+
+# [filter size, stride, padding]
+# Assume the two dimensions are the same
+# Each kernel requires the following parameters:
+# - k_i: kernel size
+# - s_i: stride
+# - p_i: padding (if padding is uneven, right padding is larger than left padding; "SAME" option in tensorflow)
+#
+# Each layer i requires the following parameters to be fully represented:
+# - n_i: number of features (the data layer has n_1 = image size)
+# - j_i: distance (projected to image pixel distance) between the centers of two adjacent features
+# - r_i: receptive field of a feature in layer i
+# - start_i: position of the first feature's receptive field in layer i (index starts from 0; negative means the center falls into the padding)
+
+import math
+
+def outFromIn(conv, layerIn):
+ n_in = layerIn[0]
+ j_in = layerIn[1]
+ r_in = layerIn[2]
+ start_in = layerIn[3]
+ k = conv[0]
+ s = conv[1]
+ p = conv[2]
+
+ n_out = math.floor((n_in - k + 2*p)/s) + 1
+ actualP = (n_out-1)*s - n_in + k
+ pR = math.ceil(actualP/2)
+ pL = math.floor(actualP/2)
+
+ j_out = j_in * s
+ r_out = r_in + (k - 1)*j_in
+ start_out = start_in + ((k-1)/2 - pL)*j_in
+ return n_out, j_out, r_out, start_out
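+
+# Worked example (illustrative): starting from the input layer [n=128, j=1, r=1, start=0.5]
+# and a 3x3 conv with stride 1 and padding 1 ([3, 1, 1]), outFromIn gives
+#   n_out = floor((128 - 3 + 2)/1) + 1 = 128, j_out = 1*1 = 1,
+#   r_out = 1 + (3 - 1)*1 = 3, start_out = 0.5 + ((3 - 1)/2 - 1)*1 = 0.5,
+# i.e. the receptive field grows from 1 to 3 pixels while the feature grid keeps its stride.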
+
+def printLayer(layer, layer_name):
+ print(layer_name + ":")
+ print(" n features: %s jump: %s receptive size: %s start: %s " % (layer[0], layer[1], layer[2], layer[3]))
+
+
+
+layerInfos = []
+if __name__ == '__main__':
+
+ convnet = [[3,1,1],[3,1,1],[3,1,1],[4,2,1],[2,2,0],[3,1,1]]
+ layer_names = ['conv1','conv2','conv3','conv4','conv5','conv6','conv7','conv8','conv9','conv10','conv11','conv12']
+ imsize = 128
+
+ print ("-------Net summary------")
+ currentLayer = [imsize, 1, 1, 0.5]
+ printLayer(currentLayer, "input image")
+ for i in range(len(convnet)):
+ currentLayer = outFromIn(convnet[i], currentLayer)
+ layerInfos.append(currentLayer)
+ printLayer(currentLayer, layer_names[i])
+
+
+# run utils/utils_receptivefield.py
+
\ No newline at end of file
diff --git a/core/data/deg_kair_utils/utils_regularizers.py b/core/data/deg_kair_utils/utils_regularizers.py
new file mode 100644
index 0000000000000000000000000000000000000000..17e7c8524b716f36e10b41d72fee2e375af69454
--- /dev/null
+++ b/core/data/deg_kair_utils/utils_regularizers.py
@@ -0,0 +1,104 @@
+import torch
+import torch.nn as nn
+
+
+'''
+# --------------------------------------------
+# Kai Zhang (github: https://github.com/cszn)
+# 03/Mar/2019
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# SVD Orthogonal Regularization
+# --------------------------------------------
+def regularizer_orth(m):
+ """
+ # ----------------------------------------
+ # SVD Orthogonal Regularization
+ # ----------------------------------------
+ # Applies regularization to the training by performing the
+ # orthogonalization technique described in the paper
+ # This function is to be called by the torch.nn.Module.apply() method,
+ # which applies svd_orthogonalization() to every layer of the model.
+ # usage: net.apply(regularizer_orth)
+ # ----------------------------------------
+ """
+ classname = m.__class__.__name__
+ if classname.find('Conv') != -1:
+ w = m.weight.data.clone()
+ c_out, c_in, f1, f2 = w.size()
+ # dtype = m.weight.data.type()
+ w = w.permute(2, 3, 1, 0).contiguous().view(f1*f2*c_in, c_out)
+ # self.netG.apply(svd_orthogonalization)
+ u, s, v = torch.svd(w)
+ s[s > 1.5] = s[s > 1.5] - 1e-4
+ s[s < 0.5] = s[s < 0.5] + 1e-4
+ w = torch.mm(torch.mm(u, torch.diag(s)), v.t())
+ m.weight.data = w.view(f1, f2, c_in, c_out).permute(3, 2, 0, 1) # .type(dtype)
+ else:
+ pass
+
+
+# --------------------------------------------
+# SVD Orthogonal Regularization
+# --------------------------------------------
+def regularizer_orth2(m):
+ """
+ # ----------------------------------------
+ # Applies regularization to the training by performing the
+ # orthogonalization technique described in the paper
+ # This function is to be called by the torch.nn.Module.apply() method,
+ # which applies svd_orthogonalization() to every layer of the model.
+ # usage: net.apply(regularizer_orth2)
+ # ----------------------------------------
+ """
+ classname = m.__class__.__name__
+ if classname.find('Conv') != -1:
+ w = m.weight.data.clone()
+ c_out, c_in, f1, f2 = w.size()
+ # dtype = m.weight.data.type()
+ w = w.permute(2, 3, 1, 0).contiguous().view(f1*f2*c_in, c_out)
+ u, s, v = torch.svd(w)
+ s_mean = s.mean()
+ s[s > 1.5*s_mean] = s[s > 1.5*s_mean] - 1e-4
+ s[s < 0.5*s_mean] = s[s < 0.5*s_mean] + 1e-4
+ w = torch.mm(torch.mm(u, torch.diag(s)), v.t())
+ m.weight.data = w.view(f1, f2, c_in, c_out).permute(3, 2, 0, 1) # .type(dtype)
+ else:
+ pass
+
+
+
+def regularizer_clip(m):
+ """
+ # ----------------------------------------
+ # usage: net.apply(regularizer_clip)
+ # ----------------------------------------
+ """
+ eps = 1e-4
+ c_min = -1.5
+ c_max = 1.5
+
+ classname = m.__class__.__name__
+ if classname.find('Conv') != -1 or classname.find('Linear') != -1:
+ w = m.weight.data.clone()
+ w[w > c_max] -= eps
+ w[w < c_min] += eps
+ m.weight.data = w
+
+ if m.bias is not None:
+ b = m.bias.data.clone()
+ b[b > c_max] -= eps
+ b[b < c_min] += eps
+ m.bias.data = b
+
+# elif classname.find('BatchNorm2d') != -1:
+#
+# rv = m.running_var.data.clone()
+# rm = m.running_mean.data.clone()
+#
+# if m.affine:
+# m.weight.data
+# m.bias.data
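+
+
+# Usage sketch (illustrative): these helpers are intended to be applied to a network every
+# few optimizer steps inside a training loop (netG and current_step are assumed here), e.g.
+#   if current_step % 100 == 0:
+#       netG.apply(regularizer_orth2)   # or netG.apply(regularizer_clip)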
diff --git a/core/data/deg_kair_utils/utils_sisr.py b/core/data/deg_kair_utils/utils_sisr.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9edbd72ce53351d9e306c9774073a0e2eb0bdb3
--- /dev/null
+++ b/core/data/deg_kair_utils/utils_sisr.py
@@ -0,0 +1,848 @@
+# -*- coding: utf-8 -*-
+from utils import utils_image as util
+import random
+
+import scipy
+import scipy.stats as ss
+import scipy.io as io
+from scipy import ndimage
+from scipy.interpolate import interp2d
+
+import numpy as np
+import torch
+
+
+"""
+# --------------------------------------------
+# Super-Resolution
+# --------------------------------------------
+#
+# Kai Zhang (cskaizhang@gmail.com)
+# https://github.com/cszn
+# modified by Kai Zhang (github: https://github.com/cszn)
+# 03/03/2020
+# --------------------------------------------
+"""
+
+
+"""
+# --------------------------------------------
+# anisotropic Gaussian kernels
+# --------------------------------------------
+"""
+
+
+def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
+ """ generate an anisotropic Gaussian kernel
+ Args:
+ ksize : e.g., 15, kernel size
+ theta : [0, pi], rotation angle range
+ l1 : [0.1,50], scaling of eigenvalues
+ l2 : [0.1,l1], scaling of eigenvalues
+ If l1 = l2, will get an isotropic Gaussian kernel.
+ Returns:
+ k : kernel
+ """
+
+ v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
+ V = np.array([[v[0], v[1]], [v[1], -v[0]]])
+ D = np.array([[l1, 0], [0, l2]])
+ Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
+ k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
+
+ return k
+
+
+def gm_blur_kernel(mean, cov, size=15):
+ center = size / 2.0 + 0.5
+ k = np.zeros([size, size])
+ for y in range(size):
+ for x in range(size):
+ cy = y - center + 1
+ cx = x - center + 1
+ k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
+
+ k = k / np.sum(k)
+ return k
+
+
+"""
+# --------------------------------------------
+# calculate PCA projection matrix
+# --------------------------------------------
+"""
+
+
+def get_pca_matrix(x, dim_pca=15):
+ """
+ Args:
+ x: 225x10000 matrix
+ dim_pca: 15
+ Returns:
+ pca_matrix: 15x225
+ """
+ C = np.dot(x, x.T)
+ w, v = scipy.linalg.eigh(C)
+ pca_matrix = v[:, -dim_pca:].T
+
+ return pca_matrix
+
+
+def show_pca(x):
+ """
+ x: PCA projection matrix, e.g., 15x225
+ """
+ for i in range(x.shape[0]):
+ xc = np.reshape(x[i, :], (int(np.sqrt(x.shape[1])), -1), order="F")
+ util.surf(xc)
+
+
+def cal_pca_matrix(path='PCA_matrix.mat', ksize=15, l_max=12.0, dim_pca=15, num_samples=500):
+ kernels = np.zeros([ksize*ksize, num_samples], dtype=np.float32)
+ for i in range(num_samples):
+
+ theta = np.pi*np.random.rand(1)
+ l1 = 0.1+l_max*np.random.rand(1)
+ l2 = 0.1+(l1-0.1)*np.random.rand(1)
+
+ k = anisotropic_Gaussian(ksize=ksize, theta=theta[0], l1=l1[0], l2=l2[0])
+
+ # util.imshow(k)
+
+ kernels[:, i] = np.reshape(k, (-1), order="F") # k.flatten(order='F')
+
+ # io.savemat('k.mat', {'k': kernels})
+
+ pca_matrix = get_pca_matrix(kernels, dim_pca=dim_pca)
+
+ io.savemat(path, {'p': pca_matrix})
+
+ return pca_matrix
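+
+# Usage sketch (illustrative): once the dim_pca x (ksize*ksize) projection matrix is
+# available, a 15x15 blur kernel can be reduced to a 15-dim code by column-wise
+# flattening and projecting:
+#   pca_matrix = cal_pca_matrix(ksize=15, dim_pca=15, num_samples=500)
+#   k = anisotropic_Gaussian(ksize=15, theta=0.5*np.pi, l1=4, l2=2)
+#   k_code = pca_matrix @ np.reshape(k, (-1), order="F")   # shape: (15,)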
+
+
+"""
+# --------------------------------------------
+# shifted anisotropic Gaussian kernels
+# --------------------------------------------
+"""
+
+
+def shifted_anisotropic_Gaussian(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
+ """"
+ # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
+ # Kai Zhang
+ # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
+ # max_var = 2.5 * sf
+ """
+ # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
+ lambda_1 = min_var + np.random.rand() * (max_var - min_var)
+ lambda_2 = min_var + np.random.rand() * (max_var - min_var)
+ theta = np.random.rand() * np.pi # random theta
+ noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
+
+ # Set COV matrix using Lambdas and Theta
+ LAMBDA = np.diag([lambda_1, lambda_2])
+ Q = np.array([[np.cos(theta), -np.sin(theta)],
+ [np.sin(theta), np.cos(theta)]])
+ SIGMA = Q @ LAMBDA @ Q.T
+ INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
+
+ # Set expectation position (shifting kernel for aligned image)
+ MU = k_size // 2 - 0.5*(scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
+ MU = MU[None, None, :, None]
+
+ # Create meshgrid for Gaussian
+ [X,Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
+ Z = np.stack([X, Y], 2)[:, :, :, None]
+
+ # Calculate Gaussian for every pixel of the kernel
+ ZZ = Z-MU
+ ZZ_t = ZZ.transpose(0,1,3,2)
+ raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
+
+ # shift the kernel so it will be centered
+ #raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
+
+ # Normalize the kernel and return
+ #kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
+ kernel = raw_kernel / np.sum(raw_kernel)
+ return kernel
+
+
+def gen_kernel(k_size=np.array([25, 25]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=12., noise_level=0):
+ """"
+ # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
+ # Kai Zhang
+ # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
+ # max_var = 2.5 * sf
+ """
+ sf = random.choice([1, 2, 3, 4])
+ scale_factor = np.array([sf, sf])
+ # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
+ lambda_1 = min_var + np.random.rand() * (max_var - min_var)
+ lambda_2 = min_var + np.random.rand() * (max_var - min_var)
+ theta = np.random.rand() * np.pi # random theta
+ noise = 0#-noise_level + np.random.rand(*k_size) * noise_level * 2
+
+ # Set COV matrix using Lambdas and Theta
+ LAMBDA = np.diag([lambda_1, lambda_2])
+ Q = np.array([[np.cos(theta), -np.sin(theta)],
+ [np.sin(theta), np.cos(theta)]])
+ SIGMA = Q @ LAMBDA @ Q.T
+ INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
+
+ # Set expectation position (shifting kernel for aligned image)
+ MU = k_size // 2 - 0.5*(scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
+ MU = MU[None, None, :, None]
+
+ # Create meshgrid for Gaussian
+ [X,Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
+ Z = np.stack([X, Y], 2)[:, :, :, None]
+
+ # Calculate Gaussian for every pixel of the kernel
+ ZZ = Z-MU
+ ZZ_t = ZZ.transpose(0,1,3,2)
+ raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
+
+ # shift the kernel so it will be centered
+ #raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
+
+ # Normalize the kernel and return
+ #kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
+ kernel = raw_kernel / np.sum(raw_kernel)
+ return kernel
+
+
+"""
+# --------------------------------------------
+# degradation models
+# --------------------------------------------
+"""
+
+
+def bicubic_degradation(x, sf=3):
+ '''
+ Args:
+ x: HxWxC image, [0, 1]
+ sf: down-scale factor
+ Return:
+ bicubicly downsampled LR image
+ '''
+ x = util.imresize_np(x, scale=1/sf)
+ return x
+
+
+def srmd_degradation(x, k, sf=3):
+ ''' blur + bicubic downsampling
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2018learning,
+ title={Learning a single convolutional super-resolution network for multiple degradations},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={3262--3271},
+ year={2018}
+ }
+ '''
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
+ x = bicubic_degradation(x, sf=sf)
+ return x
+
+
+def dpsr_degradation(x, k, sf=3):
+
+ ''' bicubic downsampling + blur
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2019deep,
+ title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={1671--1681},
+ year={2019}
+ }
+ '''
+ x = bicubic_degradation(x, sf=sf)
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ return x
+
+
+def classical_degradation(x, k, sf=3):
+ ''' blur + downsampling
+
+ Args:
+ x: HxWxC image, [0, 1]/[0, 255]
+ k: hxw, double
+ sf: down-scale factor
+
+ Return:
+ downsampled LR image
+ '''
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ #x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
+ st = 0
+ return x[st::sf, st::sf, ...]
+
+
+def modcrop_np(img, sf):
+ '''
+ Args:
+ img: numpy image, WxH or WxHxC
+ sf: scale factor
+ Return:
+ cropped image
+ '''
+ w, h = img.shape[:2]
+ im = np.copy(img)
+ return im[:w - w % sf, :h - h % sf, ...]
+
+
+'''
+# =================
+# Numpy
+# =================
+'''
+
+
+def shift_pixel(x, sf, upper_left=True):
+ """shift pixel for super-resolution with different scale factors
+ Args:
+ x: WxHxC or WxH, image or kernel
+ sf: scale factor
+ upper_left: shift direction
+ """
+ h, w = x.shape[:2]
+ shift = (sf-1)*0.5
+ xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
+ if upper_left:
+ x1 = xv + shift
+ y1 = yv + shift
+ else:
+ x1 = xv - shift
+ y1 = yv - shift
+
+ x1 = np.clip(x1, 0, w-1)
+ y1 = np.clip(y1, 0, h-1)
+
+ if x.ndim == 2:
+ x = interp2d(xv, yv, x)(x1, y1)
+ if x.ndim == 3:
+ for i in range(x.shape[-1]):
+ x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
+
+ return x
+
+
+'''
+# =================
+# pytorch
+# =================
+'''
+
+
+def splits(a, sf):
+ '''
+ a: tensor NxCxWxHx2
+ sf: scale factor
+ out: tensor NxCx(W/sf)x(H/sf)x2x(sf^2)
+ '''
+ b = torch.stack(torch.chunk(a, sf, dim=2), dim=5)
+ b = torch.cat(torch.chunk(b, sf, dim=3), dim=5)
+ return b
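+
+# Shape sketch (illustrative): for a of shape 1x3x8x8x2 and sf=2, splits(a, 2) returns
+# a tensor of shape 1x3x4x4x2x4, i.e. the sf^2 polyphase components are stacked along a
+# new trailing dimension.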
+
+
+def c2c(x):
+ return torch.from_numpy(np.stack([np.float32(x.real), np.float32(x.imag)], axis=-1))
+
+
+def r2c(x):
+ return torch.stack([x, torch.zeros_like(x)], -1)
+
+
+def cdiv(x, y):
+ a, b = x[..., 0], x[..., 1]
+ c, d = y[..., 0], y[..., 1]
+ cd2 = c**2 + d**2
+ return torch.stack([(a*c+b*d)/cd2, (b*c-a*d)/cd2], -1)
+
+
+def csum(x, y):
+ return torch.stack([x[..., 0] + y, x[..., 1]], -1)
+
+
+def cabs(x):
+ return torch.pow(x[..., 0]**2+x[..., 1]**2, 0.5)
+
+
+def cmul(t1, t2):
+ '''
+ complex multiplication
+ t1: NxCxHxWx2
+ output: NxCxHxWx2
+ '''
+ real1, imag1 = t1[..., 0], t1[..., 1]
+ real2, imag2 = t2[..., 0], t2[..., 1]
+ return torch.stack([real1 * real2 - imag1 * imag2, real1 * imag2 + imag1 * real2], dim=-1)
+
+
+def cconj(t, inplace=False):
+ '''
+ # complex's conjugation
+ t: NxCxHxWx2
+ output: NxCxHxWx2
+ '''
+ c = t.clone() if not inplace else t
+ c[..., 1] *= -1
+ return c
+
+
+def rfft(t):
+ return torch.rfft(t, 2, onesided=False)
+
+
+def irfft(t):
+ return torch.irfft(t, 2, onesided=False)
+
+
+def fft(t):
+ return torch.fft(t, 2)
+
+
+def ifft(t):
+ return torch.ifft(t, 2)
+
+
+def p2o(psf, shape):
+ '''
+ Args:
+ psf: NxCxhxw
+ shape: [H,W]
+
+ Returns:
+ otf: NxCxHxWx2
+ '''
+ otf = torch.zeros(psf.shape[:-2] + shape).type_as(psf)
+ otf[...,:psf.shape[2],:psf.shape[3]].copy_(psf)
+ for axis, axis_size in enumerate(psf.shape[2:]):
+ otf = torch.roll(otf, -int(axis_size / 2), dims=axis+2)
+ otf = torch.rfft(otf, 2, onesided=False)
+ n_ops = torch.sum(torch.tensor(psf.shape).type_as(psf) * torch.log2(torch.tensor(psf.shape).type_as(psf)))
+ otf[..., 1][torch.abs(otf[..., 1]) < n_ops*2.22e-16] = torch.tensor(0).type_as(psf)
+ return otf
+
+
+def upsample(x, sf=3, center=False):
+ '''s-fold upsampler
+ Upsampling the spatial size by filling the new entries with zeros
+ x: tensor image, NxCxWxH
+ '''
+ st = (sf-1)//2 if center else 0
+ z = torch.zeros((x.shape[0], x.shape[1], x.shape[2]*sf, x.shape[3]*sf)).type_as(x)
+ z[..., st::sf, st::sf].copy_(x)
+ return z
+
+
+def downsample(x, sf=3, center=False):
+ '''s-fold downsampler
+ Keeping the upper-left pixel of each distinct sfxsf patch and discarding the others
+ x: tensor image, NxCxWxH
+ '''
+ st = (sf-1)//2 if center else 0
+ return x[..., st::sf, st::sf]
+
+
+def circular_pad(x, pad):
+ '''
+ # x[N, 1, W, H] -> x[N, 1, W + 2 pad, H + 2 pad] (periodic padding)
+ '''
+ x = torch.cat([x, x[:, :, 0:pad, :]], dim=2)
+ x = torch.cat([x, x[:, :, :, 0:pad]], dim=3)
+ x = torch.cat([x[:, :, -2 * pad:-pad, :], x], dim=2)
+ x = torch.cat([x[:, :, :, -2 * pad:-pad], x], dim=3)
+ return x
+
+
+def pad_circular(input, padding):
+ # type: (Tensor, List[int]) -> Tensor
+ """
+ Arguments
+ :param input: tensor of shape :math:`(N, C_{\text{in}}, H, [W, D]))`
+ :param padding: (tuple): m-elem tuple where m is the degree of convolution
+ Returns
+ :return: tensor of shape :math:`(N, C_{\text{in}}, [D + 2 * padding[0],
+ H + 2 * padding[1]], W + 2 * padding[2]))`
+ """
+ offset = 3
+ for dimension in range(input.dim() - offset + 1):
+ input = dim_pad_circular(input, padding[dimension], dimension + offset)
+ return input
+
+
+def dim_pad_circular(input, padding, dimension):
+ # type: (Tensor, int, int) -> Tensor
+ input = torch.cat([input, input[[slice(None)] * (dimension - 1) +
+ [slice(0, padding)]]], dim=dimension - 1)
+ input = torch.cat([input[[slice(None)] * (dimension - 1) +
+ [slice(-2 * padding, -padding)]], input], dim=dimension - 1)
+ return input
+
+
+def imfilter(x, k):
+ '''
+ x: image, NxcxHxW
+ k: kernel, cx1xhxw
+ '''
+ x = pad_circular(x, padding=((k.shape[-2]-1)//2, (k.shape[-1]-1)//2))
+ x = torch.nn.functional.conv2d(x, k, groups=x.shape[1])
+ return x
+
+
+def G(x, k, sf=3, center=False):
+ '''
+ x: image, NxcxHxW
+ k: kernel, cx1xhxw
+ sf: scale factor
+ center: the first one or the middle one
+
+ Matlab function:
+ tmp = imfilter(x,h,'circular');
+ y = downsample2(tmp,K);
+ '''
+ x = downsample(imfilter(x, k), sf=sf, center=center)
+ return x
+
+
+def Gt(x, k, sf=3, center=False):
+ '''
+ x: image, NxcxHxW
+ k: kernel, cx1xhxw
+ sf: scale factor
+ center: the first one or the middle one
+
+ Matlab function:
+ tmp = upsample2(x,K);
+ y = imfilter(tmp,h,'circular');
+ '''
+ x = imfilter(upsample(x, sf=sf, center=center), k)
+ return x
+
+
+def interpolation_down(x, sf, center=False):
+ mask = torch.zeros_like(x)
+ if center:
+ start = torch.tensor((sf-1)//2)
+ mask[..., start::sf, start::sf] = torch.tensor(1).type_as(x)
+ LR = x[..., start::sf, start::sf]
+ else:
+ mask[..., ::sf, ::sf] = torch.tensor(1).type_as(x)
+ LR = x[..., ::sf, ::sf]
+ y = x.mul(mask)
+
+ return LR, y, mask
+
+
+'''
+# =================
+# Numpy
+# =================
+'''
+
+
+def blockproc(im, blocksize, fun):
+ xblocks = np.split(im, range(blocksize[0], im.shape[0], blocksize[0]), axis=0)
+ xblocks_proc = []
+ for xb in xblocks:
+ yblocks = np.split(xb, range(blocksize[1], im.shape[1], blocksize[1]), axis=1)
+ yblocks_proc = []
+ for yb in yblocks:
+ yb_proc = fun(yb)
+ yblocks_proc.append(yb_proc)
+ xblocks_proc.append(np.concatenate(yblocks_proc, axis=1))
+
+ proc = np.concatenate(xblocks_proc, axis=0)
+
+ return proc
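+
+# Behaviour sketch (illustrative): blockproc mirrors MATLAB's blockproc; for a 4x4 array
+# and blocksize=(2, 2), fun is applied to each of the four 2x2 blocks and the processed
+# blocks are stitched back together in their original layout.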
+
+
+def fun_reshape(a):
+ return np.reshape(a, (-1,1,a.shape[-1]), order='F')
+
+
+def fun_mul(a, b):
+ return a*b
+
+
+def BlockMM(nr, nc, Nb, m, x1):
+ '''
+ myfun = @(block_struct) reshape(block_struct.data,m,1);
+ x1 = blockproc(x1,[nr nc],myfun);
+ x1 = reshape(x1,m,Nb);
+ x1 = sum(x1,2);
+ x = reshape(x1,nr,nc);
+ '''
+ fun = fun_reshape
+ x1 = blockproc(x1, blocksize=(nr, nc), fun=fun)
+ x1 = np.reshape(x1, (m, Nb, x1.shape[-1]), order='F')
+ x1 = np.sum(x1, 1)
+ x = np.reshape(x1, (nr, nc, x1.shape[-1]), order='F')
+ return x
+
+
+def INVLS(FB, FBC, F2B, FR, tau, Nb, nr, nc, m):
+ '''
+ x1 = FB.*FR;
+ FBR = BlockMM(nr,nc,Nb,m,x1);
+ invW = BlockMM(nr,nc,Nb,m,F2B);
+ invWBR = FBR./(invW + tau*Nb);
+ fun = @(block_struct) block_struct.data.*invWBR;
+ FCBinvWBR = blockproc(FBC,[nr,nc],fun);
+ FX = (FR-FCBinvWBR)/tau;
+ Xest = real(ifft2(FX));
+ '''
+ x1 = FB*FR
+ FBR = BlockMM(nr, nc, Nb, m, x1)
+ invW = BlockMM(nr, nc, Nb, m, F2B)
+ invWBR = FBR/(invW + tau*Nb)
+ FCBinvWBR = blockproc(FBC, [nr, nc], lambda im: fun_mul(im, invWBR))
+ FX = (FR-FCBinvWBR)/tau
+ Xest = np.real(np.fft.ifft2(FX, axes=(0, 1)))
+ return Xest
+
+
+def psf2otf(psf, shape=None):
+ """
+ Convert point-spread function to optical transfer function.
+ Compute the Fast Fourier Transform (FFT) of the point-spread
+ function (PSF) array and creates the optical transfer function (OTF)
+ array that is not influenced by the PSF off-centering.
+ By default, the OTF array is the same size as the PSF array.
+ To ensure that the OTF is not altered due to PSF off-centering, PSF2OTF
+ post-pads the PSF array (down or to the right) with zeros to match
+ dimensions specified in OUTSIZE, then circularly shifts the values of
+ the PSF array up (or to the left) until the central pixel reaches (1,1)
+ position.
+ Parameters
+ ----------
+ psf : `numpy.ndarray`
+ PSF array
+ shape : int
+ Output shape of the OTF array
+ Returns
+ -------
+ otf : `numpy.ndarray`
+ OTF array
+ Notes
+ -----
+ Adapted from MATLAB psf2otf function
+ """
+ if shape is None:
+ shape = psf.shape
+ shape = np.array(shape)
+ if np.all(psf == 0):
+ # return np.zeros_like(psf)
+ return np.zeros(shape)
+ if len(psf.shape) == 1:
+ psf = psf.reshape((1, psf.shape[0]))
+ inshape = psf.shape
+ psf = zero_pad(psf, shape, position='corner')
+ for axis, axis_size in enumerate(inshape):
+ psf = np.roll(psf, -int(axis_size / 2), axis=axis)
+ # Compute the OTF
+ otf = np.fft.fft2(psf, axes=(0, 1))
+ # Estimate the rough number of operations involved in the FFT
+ # and discard the PSF imaginary part if within roundoff error
+ # roundoff error = machine epsilon = sys.float_info.epsilon
+ # or np.finfo().eps
+ n_ops = np.sum(psf.size * np.log2(psf.shape))
+ otf = np.real_if_close(otf, tol=n_ops)
+ return otf
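+
+# Sanity-check sketch (illustrative): for a normalized 3x3 PSF embedded into an 8x8 OTF,
+# the DC component equals the kernel sum (i.e. 1):
+#   psf = np.ones((3, 3)) / 9.0
+#   otf = psf2otf(psf, (8, 8))
+#   np.allclose(otf[0, 0], 1.0)   # True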
+
+
+def zero_pad(image, shape, position='corner'):
+ """
+ Extends image to a certain size with zeros
+ Parameters
+ ----------
+ image: real 2d `numpy.ndarray`
+ Input image
+ shape: tuple of int
+ Desired output shape of the image
+ position : str, optional
+ The position of the input image in the output one:
+ * 'corner'
+ top-left corner (default)
+ * 'center'
+ centered
+ Returns
+ -------
+ padded_img: real `numpy.ndarray`
+ The zero-padded image
+ """
+ shape = np.asarray(shape, dtype=int)
+ imshape = np.asarray(image.shape, dtype=int)
+ if np.alltrue(imshape == shape):
+ return image
+ if np.any(shape <= 0):
+ raise ValueError("ZERO_PAD: null or negative shape given")
+ dshape = shape - imshape
+ if np.any(dshape < 0):
+ raise ValueError("ZERO_PAD: target size smaller than source one")
+ pad_img = np.zeros(shape, dtype=image.dtype)
+ idx, idy = np.indices(imshape)
+ if position == 'center':
+ if np.any(dshape % 2 != 0):
+ raise ValueError("ZERO_PAD: source and target shapes "
+ "have different parity.")
+ offx, offy = dshape // 2
+ else:
+ offx, offy = (0, 0)
+ pad_img[idx + offx, idy + offy] = image
+ return pad_img
+
+
+def upsample_np(x, sf=3, center=False):
+ st = (sf-1)//2 if center else 0
+ z = np.zeros((x.shape[0]*sf, x.shape[1]*sf, x.shape[2]))
+ z[st::sf, st::sf, ...] = x
+ return z
+
+
+def downsample_np(x, sf=3, center=False):
+ st = (sf-1)//2 if center else 0
+ return x[st::sf, st::sf, ...]
+
+
+def imfilter_np(x, k):
+ '''
+ x: image, NxcxHxW
+ k: kernel, cx1xhxw
+ '''
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ return x
+
+
+def G_np(x, k, sf=3, center=False):
+ '''
+ x: image, NxcxHxW
+ k: kernel, cx1xhxw
+
+ Matlab function:
+ tmp = imfilter(x,h,'circular');
+ y = downsample2(tmp,K);
+ '''
+ x = downsample_np(imfilter_np(x, k), sf=sf, center=center)
+ return x
+
+
+def Gt_np(x, k, sf=3, center=False):
+ '''
+ x: image, NxcxHxW
+ k: kernel, cx1xhxw
+
+ Matlab function:
+ tmp = upsample2(x,K);
+ y = imfilter(tmp,h,'circular');
+ '''
+ x = imfilter_np(upsample_np(x, sf=sf, center=center), k)
+ return x
+
+
+if __name__ == '__main__':
+ img = util.imread_uint('test.bmp', 3)
+
+ img = util.uint2single(img)
+ k = anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6)
+ util.imshow(k*10)
+
+
+ for sf in [2, 3, 4]:
+
+ # modcrop
+ img = modcrop_np(img, sf=sf)
+
+ # 1) bicubic degradation
+ img_b = bicubic_degradation(img, sf=sf)
+ print(img_b.shape)
+
+ # 2) srmd degradation
+ img_s = srmd_degradation(img, k, sf=sf)
+ print(img_s.shape)
+
+ # 3) dpsr degradation
+ img_d = dpsr_degradation(img, k, sf=sf)
+ print(img_d.shape)
+
+ # 4) classical degradation
+ img_d = classical_degradation(img, k, sf=sf)
+ print(img_d.shape)
+
+ k = anisotropic_Gaussian(ksize=7, theta=0.25*np.pi, l1=0.01, l2=0.01)
+ #print(k)
+# util.imshow(k*10)
+
+ k = shifted_anisotropic_Gaussian(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.8, max_var=10.8, noise_level=0.0)
+# util.imshow(k*10)
+
+
+ # PCA
+# pca_matrix = cal_pca_matrix(ksize=15, l_max=10.0, dim_pca=15, num_samples=12500)
+# print(pca_matrix.shape)
+# show_pca(pca_matrix)
+ # run utils/utils_sisr.py
+ # run utils_sisr.py
+
+
+
+
+
+
+
diff --git a/core/data/deg_kair_utils/utils_video.py b/core/data/deg_kair_utils/utils_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..596dd4203098cf7b36f3d8499ccbf299623381ae
--- /dev/null
+++ b/core/data/deg_kair_utils/utils_video.py
@@ -0,0 +1,493 @@
+import os
+import cv2
+import numpy as np
+import torch
+import random
+from os import path as osp
+from torch.nn import functional as F
+from torchvision.utils import make_grid  # used by tensor2img for 4D batches
+from abc import ABCMeta, abstractmethod
+import math  # used by tensor2img (make_grid nrow)
+
+
+def scandir(dir_path, suffix=None, recursive=False, full_path=False):
+ """Scan a directory to find the interested files.
+
+ Args:
+ dir_path (str): Path of the directory.
+ suffix (str | tuple(str), optional): File suffix that we are
+ interested in. Default: None.
+ recursive (bool, optional): If set to True, recursively scan the
+ directory. Default: False.
+ full_path (bool, optional): If set to True, include the dir_path.
+ Default: False.
+
+ Returns:
+ A generator for all the interested files with relative paths.
+ """
+
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+ raise TypeError('"suffix" must be a string or tuple of strings')
+
+ root = dir_path
+
+ def _scandir(dir_path, suffix, recursive):
+ for entry in os.scandir(dir_path):
+ if not entry.name.startswith('.') and entry.is_file():
+ if full_path:
+ return_path = entry.path
+ else:
+ return_path = osp.relpath(entry.path, root)
+
+ if suffix is None:
+ yield return_path
+ elif return_path.endswith(suffix):
+ yield return_path
+ else:
+ if recursive:
+ yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
+ else:
+ continue
+
+ return _scandir(dir_path, suffix=suffix, recursive=recursive)
+
+
+def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):
+ """Read a sequence of images from a given folder path.
+
+ Args:
+ path (list[str] | str): List of image paths or image folder path.
+ require_mod_crop (bool): Require mod crop for each image.
+ Default: False.
+ scale (int): Scale factor for mod_crop. Default: 1.
+ return_imgname(bool): Whether return image names. Default False.
+
+ Returns:
+ Tensor: size (t, c, h, w), RGB, [0, 1].
+ list[str]: Returned image name list.
+ """
+ if isinstance(path, list):
+ img_paths = path
+ else:
+ img_paths = sorted(list(scandir(path, full_path=True)))
+ imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
+
+ if require_mod_crop:
+ imgs = [mod_crop(img, scale) for img in imgs]
+ imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
+ imgs = torch.stack(imgs, dim=0)
+
+ if return_imgname:
+ imgnames = [osp.splitext(osp.basename(path))[0] for path in img_paths]
+ return imgs, imgnames
+ else:
+ return imgs
+
+
+def img2tensor(imgs, bgr2rgb=True, float32=True):
+ """Numpy array to tensor.
+
+ Args:
+ imgs (list[ndarray] | ndarray): Input images.
+ bgr2rgb (bool): Whether to change bgr to rgb.
+ float32 (bool): Whether to change to float32.
+
+ Returns:
+ list[tensor] | tensor: Tensor images. If returned results only have
+ one element, just return tensor.
+ """
+
+ def _totensor(img, bgr2rgb, float32):
+ if img.shape[2] == 3 and bgr2rgb:
+ if img.dtype == 'float64':
+ img = img.astype('float32')
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img = torch.from_numpy(img.transpose(2, 0, 1))
+ if float32:
+ img = img.float()
+ return img
+
+ if isinstance(imgs, list):
+ return [_totensor(img, bgr2rgb, float32) for img in imgs]
+ else:
+ return _totensor(imgs, bgr2rgb, float32)
+
+
+def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
+ """Convert torch Tensors into image numpy arrays.
+
+ After clamping to [min, max], values will be normalized to [0, 1].
+
+ Args:
+ tensor (Tensor or list[Tensor]): Accept shapes:
+ 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
+ 2) 3D Tensor of shape (3/1 x H x W);
+ 3) 2D Tensor of shape (H x W).
+ Tensor channel should be in RGB order.
+ rgb2bgr (bool): Whether to change rgb to bgr.
+ out_type (numpy type): output types. If ``np.uint8``, transform outputs
+ to uint8 type with range [0, 255]; otherwise, float type with
+ range [0, 1]. Default: ``np.uint8``.
+ min_max (tuple[int]): min and max values for clamp.
+
+ Returns:
+ (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
+ shape (H x W). The channel order is BGR.
+ """
+ if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
+ raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
+
+ if torch.is_tensor(tensor):
+ tensor = [tensor]
+ result = []
+ for _tensor in tensor:
+ _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
+ _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
+
+ n_dim = _tensor.dim()
+ if n_dim == 4:
+ img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
+ img_np = img_np.transpose(1, 2, 0)
+ if rgb2bgr:
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+ elif n_dim == 3:
+ img_np = _tensor.numpy()
+ img_np = img_np.transpose(1, 2, 0)
+ if img_np.shape[2] == 1: # gray image
+ img_np = np.squeeze(img_np, axis=2)
+ else:
+ if rgb2bgr:
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+ elif n_dim == 2:
+ img_np = _tensor.numpy()
+ else:
+ raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
+ if out_type == np.uint8:
+ # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
+ img_np = (img_np * 255.0).round()
+ img_np = img_np.astype(out_type)
+ result.append(img_np)
+ if len(result) == 1:
+ result = result[0]
+ return result
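+
+
+# Round-trip sketch (illustrative, img assumed to be a BGR uint8 frame read with cv2):
+#   t = img2tensor(img.astype(np.float32) / 255., bgr2rgb=True, float32=True)
+#   img_back = tensor2img(t, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1))
+# img_back approximately reproduces img (up to rounding).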
+
+
+def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
+ """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
+
+ We use vertical flip and transpose for rotation implementation.
+ All the images in the list use the same augmentation.
+
+ Args:
+ imgs (list[ndarray] | ndarray): Images to be augmented. If the input
+ is an ndarray, it will be transformed to a list.
+ hflip (bool): Horizontal flip. Default: True.
+ rotation (bool): Rotation. Default: True.
+ flows (list[ndarray]): Flows to be augmented. If the input is an
+ ndarray, it will be transformed to a list.
+ Dimension is (h, w, 2). Default: None.
+ return_status (bool): Return the status of flip and rotation.
+ Default: False.
+
+ Returns:
+ list[ndarray] | ndarray: Augmented images and flows. If returned
+ results only have one element, just return ndarray.
+
+ """
+ hflip = hflip and random.random() < 0.5
+ vflip = rotation and random.random() < 0.5
+ rot90 = rotation and random.random() < 0.5
+
+ def _augment(img):
+ if hflip: # horizontal
+ cv2.flip(img, 1, img)
+ if vflip: # vertical
+ cv2.flip(img, 0, img)
+ if rot90:
+ img = img.transpose(1, 0, 2)
+ return img
+
+ def _augment_flow(flow):
+ if hflip: # horizontal
+ cv2.flip(flow, 1, flow)
+ flow[:, :, 0] *= -1
+ if vflip: # vertical
+ cv2.flip(flow, 0, flow)
+ flow[:, :, 1] *= -1
+ if rot90:
+ flow = flow.transpose(1, 0, 2)
+ flow = flow[:, :, [1, 0]]
+ return flow
+
+ if not isinstance(imgs, list):
+ imgs = [imgs]
+ imgs = [_augment(img) for img in imgs]
+ if len(imgs) == 1:
+ imgs = imgs[0]
+
+ if flows is not None:
+ if not isinstance(flows, list):
+ flows = [flows]
+ flows = [_augment_flow(flow) for flow in flows]
+ if len(flows) == 1:
+ flows = flows[0]
+ return imgs, flows
+ else:
+ if return_status:
+ return imgs, (hflip, vflip, rot90)
+ else:
+ return imgs
+
+
+def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
+ """Paired random crop. Support Numpy array and Tensor inputs.
+
+ It crops lists of lq and gt images with corresponding locations.
+
+ Args:
+ img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images
+ should have the same shape. If the input is an ndarray, it will
+ be transformed to a list containing itself.
+ img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
+ should have the same shape. If the input is an ndarray, it will
+ be transformed to a list containing itself.
+ gt_patch_size (int): GT patch size.
+ scale (int): Scale factor.
+ gt_path (str): Path to ground-truth. Default: None.
+
+ Returns:
+ list[ndarray] | ndarray: GT images and LQ images. If returned results
+ only have one element, just return ndarray.
+ """
+
+ if not isinstance(img_gts, list):
+ img_gts = [img_gts]
+ if not isinstance(img_lqs, list):
+ img_lqs = [img_lqs]
+
+ # determine input type: Numpy array or Tensor
+ input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'
+
+ if input_type == 'Tensor':
+ h_lq, w_lq = img_lqs[0].size()[-2:]
+ h_gt, w_gt = img_gts[0].size()[-2:]
+ else:
+ h_lq, w_lq = img_lqs[0].shape[0:2]
+ h_gt, w_gt = img_gts[0].shape[0:2]
+ lq_patch_size = gt_patch_size // scale
+
+ if h_gt != h_lq * scale or w_gt != w_lq * scale:
+ raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
+ f'multiplication of LQ ({h_lq}, {w_lq}).')
+ if h_lq < lq_patch_size or w_lq < lq_patch_size:
+ raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
+ f'({lq_patch_size}, {lq_patch_size}). '
+ f'Please remove {gt_path}.')
+
+ # randomly choose top and left coordinates for lq patch
+ top = random.randint(0, h_lq - lq_patch_size)
+ left = random.randint(0, w_lq - lq_patch_size)
+
+ # crop lq patch
+ if input_type == 'Tensor':
+ img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
+ else:
+ img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
+
+ # crop corresponding gt patch
+ top_gt, left_gt = int(top * scale), int(left * scale)
+ if input_type == 'Tensor':
+ img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
+ else:
+ img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
+ if len(img_gts) == 1:
+ img_gts = img_gts[0]
+ if len(img_lqs) == 1:
+ img_lqs = img_lqs[0]
+ return img_gts, img_lqs
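+
+# Usage sketch (illustrative, img_gt/img_lq assumed to be matching HR/LR ndarrays):
+# with scale=4 and gt_patch_size=256, the LQ patch size is 256 // 4 = 64, so a 64x64
+# crop is taken from img_lq and the corresponding 256x256 crop from img_gt:
+#   gt_patch, lq_patch = paired_random_crop(img_gt, img_lq, gt_patch_size=256, scale=4)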
+
+
+# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501
+class BaseStorageBackend(metaclass=ABCMeta):
+ """Abstract class of storage backends.
+
+ All backends need to implement two apis: ``get()`` and ``get_text()``.
+ ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
+ as texts.
+ """
+
+ @abstractmethod
+ def get(self, filepath):
+ pass
+
+ @abstractmethod
+ def get_text(self, filepath):
+ pass
+
+
+class MemcachedBackend(BaseStorageBackend):
+ """Memcached storage backend.
+
+ Attributes:
+ server_list_cfg (str): Config file for memcached server list.
+ client_cfg (str): Config file for memcached client.
+ sys_path (str | None): Additional path to be appended to `sys.path`.
+ Default: None.
+ """
+
+ def __init__(self, server_list_cfg, client_cfg, sys_path=None):
+ if sys_path is not None:
+ import sys
+ sys.path.append(sys_path)
+ try:
+ import mc
+ except ImportError:
+ raise ImportError('Please install memcached to enable MemcachedBackend.')
+
+ self.server_list_cfg = server_list_cfg
+ self.client_cfg = client_cfg
+ self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
+ # mc.pyvector serves as a pointer to a memory cache
+ self._mc_buffer = mc.pyvector()
+
+ def get(self, filepath):
+ filepath = str(filepath)
+ import mc
+ self._client.Get(filepath, self._mc_buffer)
+ value_buf = mc.ConvertBuffer(self._mc_buffer)
+ return value_buf
+
+ def get_text(self, filepath):
+ raise NotImplementedError
+
+
+class HardDiskBackend(BaseStorageBackend):
+ """Raw hard disks storage backend."""
+
+ def get(self, filepath):
+ filepath = str(filepath)
+ with open(filepath, 'rb') as f:
+ value_buf = f.read()
+ return value_buf
+
+ def get_text(self, filepath):
+ filepath = str(filepath)
+ with open(filepath, 'r') as f:
+ value_buf = f.read()
+ return value_buf
+
+
+class LmdbBackend(BaseStorageBackend):
+ """Lmdb storage backend.
+
+ Args:
+ db_paths (str | list[str]): Lmdb database paths.
+ client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
+ readonly (bool, optional): Lmdb environment parameter. If True,
+ disallow any write operations. Default: True.
+ lock (bool, optional): Lmdb environment parameter. If False, when
+ concurrent access occurs, do not lock the database. Default: False.
+ readahead (bool, optional): Lmdb environment parameter. If False,
+ disable the OS filesystem readahead mechanism, which may improve
+ random read performance when a database is larger than RAM.
+ Default: False.
+
+ Attributes:
+ db_paths (list): Lmdb database path.
+ _client (list): A list of several lmdb envs.
+ """
+
+ def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs):
+ try:
+ import lmdb
+ except ImportError:
+ raise ImportError('Please install lmdb to enable LmdbBackend.')
+
+ if isinstance(client_keys, str):
+ client_keys = [client_keys]
+
+ if isinstance(db_paths, list):
+ self.db_paths = [str(v) for v in db_paths]
+ elif isinstance(db_paths, str):
+ self.db_paths = [str(db_paths)]
+ assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, '
+ f'but received {len(client_keys)} and {len(self.db_paths)}.')
+
+ self._client = {}
+ for client, path in zip(client_keys, self.db_paths):
+ self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)
+
+ def get(self, filepath, client_key):
+ """Get values according to the filepath from one lmdb named client_key.
+
+ Args:
+ filepath (str | obj:`Path`): Here, filepath is the lmdb key.
+ client_key (str): Used for distinguishing different lmdb envs.
+ """
+ filepath = str(filepath)
+ assert client_key in self._client, (f'client_key {client_key} is not ' 'in lmdb clients.')
+ client = self._client[client_key]
+ with client.begin(write=False) as txn:
+ value_buf = txn.get(filepath.encode('ascii'))
+ return value_buf
+
+ def get_text(self, filepath):
+ raise NotImplementedError
+
+
+class FileClient(object):
+ """A general file client to access files in different backend.
+
+ The client loads a file or text in a specified backend from its path
+ and returns it as a binary file. It can also register other backend
+ accessors with a given name and backend class.
+
+ Attributes:
+ backend (str): The storage backend type. Options are "disk",
+ "memcached" and "lmdb".
+ client (:obj:`BaseStorageBackend`): The backend object.
+ """
+
+ _backends = {
+ 'disk': HardDiskBackend,
+ 'memcached': MemcachedBackend,
+ 'lmdb': LmdbBackend,
+ }
+
+ def __init__(self, backend='disk', **kwargs):
+ if backend not in self._backends:
+ raise ValueError(f'Backend {backend} is not supported. Currently supported ones'
+ f' are {list(self._backends.keys())}')
+ self.backend = backend
+ self.client = self._backends[backend](**kwargs)
+
+ def get(self, filepath, client_key='default'):
+ # client_key is used only for lmdb, where different fileclients have
+ # different lmdb environments.
+ if self.backend == 'lmdb':
+ return self.client.get(filepath, client_key)
+ else:
+ return self.client.get(filepath)
+
+ def get_text(self, filepath):
+ return self.client.get_text(filepath)
+
+
+def imfrombytes(content, flag='color', float32=False):
+ """Read an image from bytes.
+
+ Args:
+ content (bytes): Image bytes got from files or other streams.
+ flag (str): Flags specifying the color type of a loaded image,
+ candidates are `color`, `grayscale` and `unchanged`.
+ float32 (bool): Whether to change to float32., If True, will also norm
+ to [0, 1]. Default: False.
+
+ Returns:
+ ndarray: Loaded image array.
+ """
+ img_np = np.frombuffer(content, np.uint8)
+ imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
+ img = cv2.imdecode(img_np, imread_flags[flag])
+ if float32:
+ img = img.astype(np.float32) / 255.
+ return img
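+
+
+# Usage sketch (illustrative, assuming 'datasets/000001.png' exists on disk):
+#   file_client = FileClient('disk')
+#   img_bytes = file_client.get('datasets/000001.png')
+#   img = imfrombytes(img_bytes, flag='color', float32=True)   # HxWx3 float32 in [0, 1]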
+
diff --git a/core/data/deg_kair_utils/utils_videoio.py b/core/data/deg_kair_utils/utils_videoio.py
new file mode 100644
index 0000000000000000000000000000000000000000..5be8c7f06802d5aaa7155a1cdcb27d2838a0882c
--- /dev/null
+++ b/core/data/deg_kair_utils/utils_videoio.py
@@ -0,0 +1,555 @@
+import os
+import cv2
+import numpy as np
+import torch
+import random
+from os import path as osp
+from torchvision.utils import make_grid
+import sys
+from pathlib import Path
+import six
+from collections import OrderedDict
+import math
+import glob
+import av
+import io
+from cv2 import (CAP_PROP_FOURCC, CAP_PROP_FPS, CAP_PROP_FRAME_COUNT,
+ CAP_PROP_FRAME_HEIGHT, CAP_PROP_FRAME_WIDTH,
+ CAP_PROP_POS_FRAMES, VideoWriter_fourcc)
+
+if sys.version_info <= (3, 3):
+ FileNotFoundError = IOError
+else:
+ FileNotFoundError = FileNotFoundError
+
+
+def is_str(x):
+ """Whether the input is an string instance."""
+ return isinstance(x, six.string_types)
+
+
+def is_filepath(x):
+ return is_str(x) or isinstance(x, Path)
+
+
+def fopen(filepath, *args, **kwargs):
+ if is_str(filepath):
+ return open(filepath, *args, **kwargs)
+ elif isinstance(filepath, Path):
+ return filepath.open(*args, **kwargs)
+ raise ValueError('`filepath` should be a string or a Path')
+
+
+def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
+ if not osp.isfile(filename):
+ raise FileNotFoundError(msg_tmpl.format(filename))
+
+
+def mkdir_or_exist(dir_name, mode=0o777):
+ if dir_name == '':
+ return
+ dir_name = osp.expanduser(dir_name)
+ os.makedirs(dir_name, mode=mode, exist_ok=True)
+
+
+def symlink(src, dst, overwrite=True, **kwargs):
+ if os.path.lexists(dst) and overwrite:
+ os.remove(dst)
+ os.symlink(src, dst, **kwargs)
+
+
+def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True):
+ """Scan a directory to find the interested files.
+ Args:
+ dir_path (str | :obj:`Path`): Path of the directory.
+ suffix (str | tuple(str), optional): File suffix that we are
+ interested in. Default: None.
+ recursive (bool, optional): If set to True, recursively scan the
+ directory. Default: False.
+ case_sensitive (bool, optional) : If set to False, ignore the case of
+ suffix. Default: True.
+ Returns:
+ A generator for all the interested files with relative paths.
+ """
+ if isinstance(dir_path, (str, Path)):
+ dir_path = str(dir_path)
+ else:
+ raise TypeError('"dir_path" must be a string or Path object')
+
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+ raise TypeError('"suffix" must be a string or tuple of strings')
+
+ if suffix is not None and not case_sensitive:
+ suffix = suffix.lower() if isinstance(suffix, str) else tuple(
+ item.lower() for item in suffix)
+
+ root = dir_path
+
+ def _scandir(dir_path, suffix, recursive, case_sensitive):
+ for entry in os.scandir(dir_path):
+ if not entry.name.startswith('.') and entry.is_file():
+ rel_path = osp.relpath(entry.path, root)
+ _rel_path = rel_path if case_sensitive else rel_path.lower()
+ if suffix is None or _rel_path.endswith(suffix):
+ yield rel_path
+ elif recursive and os.path.isdir(entry.path):
+ # scan recursively if entry.path is a directory
+ yield from _scandir(entry.path, suffix, recursive,
+ case_sensitive)
+
+ return _scandir(dir_path, suffix, recursive, case_sensitive)
+
+
+class Cache:
+
+ def __init__(self, capacity):
+ self._cache = OrderedDict()
+ self._capacity = int(capacity)
+ if capacity <= 0:
+ raise ValueError('capacity must be a positive integer')
+
+ @property
+ def capacity(self):
+ return self._capacity
+
+ @property
+ def size(self):
+ return len(self._cache)
+
+ def put(self, key, val):
+ if key in self._cache:
+ return
+ if len(self._cache) >= self.capacity:
+ self._cache.popitem(last=False)
+ self._cache[key] = val
+
+ def get(self, key, default=None):
+ val = self._cache[key] if key in self._cache else default
+ return val
+
+
+class VideoReader:
+ """Video class with similar usage to a list object.
+
+ This video wrapper class provides convenient APIs to access frames.
+ OpenCV's VideoCapture class has an issue where jumping to a certain
+ frame may be inaccurate; this class works around it by checking the
+ position after each jump.
+ Frames are cached during decoding, so revisiting a frame that is still
+ in the cache does not require decoding it again.
+
+ """
+
+ def __init__(self, filename, cache_capacity=10):
+ # Check whether the video path is a url
+ if not filename.startswith(('https://', 'http://')):
+ check_file_exist(filename, 'Video file not found: ' + filename)
+ self._vcap = cv2.VideoCapture(filename)
+ assert cache_capacity > 0
+ self._cache = Cache(cache_capacity)
+ self._position = 0
+ # get basic info
+ self._width = int(self._vcap.get(CAP_PROP_FRAME_WIDTH))
+ self._height = int(self._vcap.get(CAP_PROP_FRAME_HEIGHT))
+ self._fps = self._vcap.get(CAP_PROP_FPS)
+ self._frame_cnt = int(self._vcap.get(CAP_PROP_FRAME_COUNT))
+ self._fourcc = self._vcap.get(CAP_PROP_FOURCC)
+
+ @property
+ def vcap(self):
+ """:obj:`cv2.VideoCapture`: The raw VideoCapture object."""
+ return self._vcap
+
+ @property
+ def opened(self):
+ """bool: Indicate whether the video is opened."""
+ return self._vcap.isOpened()
+
+ @property
+ def width(self):
+ """int: Width of video frames."""
+ return self._width
+
+ @property
+ def height(self):
+ """int: Height of video frames."""
+ return self._height
+
+ @property
+ def resolution(self):
+ """tuple: Video resolution (width, height)."""
+ return (self._width, self._height)
+
+ @property
+ def fps(self):
+ """float: FPS of the video."""
+ return self._fps
+
+ @property
+ def frame_cnt(self):
+ """int: Total frames of the video."""
+ return self._frame_cnt
+
+ @property
+ def fourcc(self):
+ """str: "Four character code" of the video."""
+ return self._fourcc
+
+ @property
+ def position(self):
+ """int: Current cursor position, indicating frame decoded."""
+ return self._position
+
+ def _get_real_position(self):
+ return int(round(self._vcap.get(CAP_PROP_POS_FRAMES)))
+
+ def _set_real_position(self, frame_id):
+ self._vcap.set(CAP_PROP_POS_FRAMES, frame_id)
+ pos = self._get_real_position()
+ for _ in range(frame_id - pos):
+ self._vcap.read()
+ self._position = frame_id
+
+ def read(self):
+ """Read the next frame.
+
+ If the next frame have been decoded before and in the cache, then
+ return it directly, otherwise decode, cache and return it.
+
+ Returns:
+ ndarray or None: Return the frame if successful, otherwise None.
+ """
+ # pos = self._position
+ if self._cache:
+ img = self._cache.get(self._position)
+ if img is not None:
+ ret = True
+ else:
+ if self._position != self._get_real_position():
+ self._set_real_position(self._position)
+ ret, img = self._vcap.read()
+ if ret:
+ self._cache.put(self._position, img)
+ else:
+ ret, img = self._vcap.read()
+ if ret:
+ self._position += 1
+ return img
+
+ def get_frame(self, frame_id):
+ """Get frame by index.
+
+ Args:
+ frame_id (int): Index of the expected frame, 0-based.
+
+ Returns:
+ ndarray or None: Return the frame if successful, otherwise None.
+ """
+ if frame_id < 0 or frame_id >= self._frame_cnt:
+ raise IndexError(
+ f'"frame_id" must be between 0 and {self._frame_cnt - 1}')
+ if frame_id == self._position:
+ return self.read()
+ if self._cache:
+ img = self._cache.get(frame_id)
+ if img is not None:
+ self._position = frame_id + 1
+ return img
+ self._set_real_position(frame_id)
+ ret, img = self._vcap.read()
+ if ret:
+ if self._cache:
+ self._cache.put(self._position, img)
+ self._position += 1
+ return img
+
+ def current_frame(self):
+ """Get the current frame (frame that is just visited).
+
+ Returns:
+ ndarray or None: None if no frame has been read yet (the video is
+ fresh), otherwise the most recently decoded frame.
+ """
+ if self._position == 0:
+ return None
+ return self._cache.get(self._position - 1)
+
+ def cvt2frames(self,
+ frame_dir,
+ file_start=0,
+ filename_tmpl='{:06d}.jpg',
+ start=0,
+ max_num=0,
+ show_progress=False):
+ """Convert a video to frame images.
+
+ Args:
+ frame_dir (str): Output directory to store all the frame images.
+ file_start (int): Filenames will start from the specified number.
+ filename_tmpl (str): Filename template with the index as the
+ placeholder.
+ start (int): The starting frame index.
+ max_num (int): Maximum number of frames to be written.
+ show_progress (bool): Whether to show a progress bar.
+ """
+ mkdir_or_exist(frame_dir)
+ if max_num == 0:
+ task_num = self.frame_cnt - start
+ else:
+ task_num = min(self.frame_cnt - start, max_num)
+ if task_num <= 0:
+ raise ValueError('start must be less than total frame number')
+ if start > 0:
+ self._set_real_position(start)
+
+ def write_frame(file_idx):
+ img = self.read()
+ if img is None:
+ return
+ filename = osp.join(frame_dir, filename_tmpl.format(file_idx))
+ cv2.imwrite(filename, img)
+
+ # NOTE: the original track_progress helper is not available here, so a plain
+ # tqdm progress bar is used instead (assumed to be installed, as it is used
+ # elsewhere in this repository).
+ if show_progress:
+     from tqdm import tqdm
+     indices = tqdm(range(task_num))
+ else:
+     indices = range(task_num)
+ for i in indices:
+     write_frame(file_start + i)
+
+ def __len__(self):
+ return self.frame_cnt
+
+ def __getitem__(self, index):
+ if isinstance(index, slice):
+ return [
+ self.get_frame(i)
+ for i in range(*index.indices(self.frame_cnt))
+ ]
+ # support negative indexing
+ if index < 0:
+ index += self.frame_cnt
+ if index < 0:
+ raise IndexError('index out of range')
+ return self.get_frame(index)
+
+ def __iter__(self):
+ self._set_real_position(0)
+ return self
+
+ def __next__(self):
+ img = self.read()
+ if img is not None:
+ return img
+ else:
+ raise StopIteration
+
+ next = __next__
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self._vcap.release()
+
+
+def frames2video(frame_dir,
+ video_file,
+ fps=30,
+ fourcc='XVID',
+ filename_tmpl='{:06d}.jpg',
+ start=0,
+ end=0,
+ show_progress=False):
+ """Read the frame images from a directory and join them as a video.
+
+ Args:
+ frame_dir (str): The directory containing video frames.
+ video_file (str): Output filename.
+ fps (float): FPS of the output video.
+ fourcc (str): Fourcc of the output video, this should be compatible
+ with the output file type.
+ filename_tmpl (str): Filename template with the index as the variable.
+ start (int): Starting frame index.
+ end (int): Ending frame index.
+ show_progress (bool): Whether to show a progress bar.
+ """
+ if end == 0:
+ ext = filename_tmpl.split('.')[-1]
+ end = len([name for name in scandir(frame_dir, ext)])
+ first_file = osp.join(frame_dir, filename_tmpl.format(start))
+ check_file_exist(first_file, 'The start frame not found: ' + first_file)
+ img = cv2.imread(first_file)
+ height, width = img.shape[:2]
+ resolution = (width, height)
+ vwriter = cv2.VideoWriter(video_file, VideoWriter_fourcc(*fourcc), fps,
+ resolution)
+
+ def write_frame(file_idx):
+ filename = osp.join(frame_dir, filename_tmpl.format(file_idx))
+ img = cv2.imread(filename)
+ vwriter.write(img)
+
+ # NOTE: the original track_progress helper is not available here, so a plain
+ # tqdm progress bar is used instead (assumed to be installed).
+ if show_progress:
+     from tqdm import tqdm
+     indices = tqdm(range(start, end))
+ else:
+     indices = range(start, end)
+ for i in indices:
+     write_frame(i)
+ vwriter.release()
+
+
+def video2images(video_path, output_dir):
+ vidcap = cv2.VideoCapture(video_path)
+ in_fps = vidcap.get(cv2.CAP_PROP_FPS)
+ print('video fps:', in_fps)
+ if not os.path.isdir(output_dir):
+ os.makedirs(output_dir)
+ loaded, frame = vidcap.read()
+ total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+ print(f'number of total frames is: {total_frames:06}')
+ for i_frame in range(total_frames):
+ if i_frame % 100 == 0:
+ print(f'{i_frame:06} / {total_frames:06}')
+ frame_name = os.path.join(output_dir, f'{i_frame:06}' + '.png')
+ cv2.imwrite(frame_name, frame)
+ loaded, frame = vidcap.read()
+
+
+def images2video(image_dir, video_path, fps=24, image_ext='png'):
+ '''Read images from a directory and join them into a video.
+
+     Notes on commonly used fourcc codecs (the codec actually used is set below):
+     * 'I420': raw YUV 4:2:0, very compatible but produces large .avi files
+     * 'PIM1': MPEG-1 encoding, .avi
+     * 'XVID': MPEG-4 encoding, average file size, .avi
+     * 'THEO': Ogg video, .ogv
+     * 'FLV1': Flash video, .flv
+     Other options seen in practice: 'MJPG', 'MP42', 'DIVX', 'AVC1', 'H264', ...
+     '''
+ image_files = sorted(glob.glob(os.path.join(image_dir, '*.{}'.format(image_ext))))
+ print(len(image_files))
+ height, width, _ = cv2.imread(image_files[0]).shape
+ out_fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G') # cv2.VideoWriter_fourcc(*'MP4V')
+ out_video = cv2.VideoWriter(video_path, out_fourcc, fps, (width, height))
+
+ for image_file in image_files:
+ img = cv2.imread(image_file)
+ img = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)  # INTER_AREA == 3
+ out_video.write(img)
+ out_video.release()
+
+
+def add_video_compression(imgs):
+ codec_type = ['libx264', 'h264', 'mpeg4']
+ codec_prob = [1 / 3., 1 / 3., 1 / 3.]
+ codec = random.choices(codec_type, codec_prob)[0]
+ # codec = 'mpeg4'
+ bitrate = [1e4, 1e5]
+ bitrate = np.random.randint(bitrate[0], bitrate[1] + 1)
+
+ buf = io.BytesIO()
+ with av.open(buf, 'w', 'mp4') as container:
+ stream = container.add_stream(codec, rate=1)
+ stream.height = imgs[0].shape[0]
+ stream.width = imgs[0].shape[1]
+ stream.pix_fmt = 'yuv420p'
+ stream.bit_rate = bitrate
+
+ for img in imgs:
+ img = np.uint8((img.clip(0, 1)*255.).round())
+ frame = av.VideoFrame.from_ndarray(img, format='rgb24')
+ frame.pict_type = 'NONE'
+ # pdb.set_trace()
+ for packet in stream.encode(frame):
+ container.mux(packet)
+
+ # Flush stream
+ for packet in stream.encode():
+ container.mux(packet)
+
+ outputs = []
+ with av.open(buf, 'r', 'mp4') as container:
+ if container.streams.video:
+ for frame in container.decode(**{'video': 0}):
+ outputs.append(
+ frame.to_rgb().to_ndarray().astype(np.float32) / 255.)
+
+ #outputs = np.stack(outputs, axis=0)
+ return outputs
+
+
+if __name__ == '__main__':
+
+ # -----------------------------------
+ # test VideoReader(filename, cache_capacity=10)
+ # -----------------------------------
+# video_reader = VideoReader('utils/test.mp4')
+# from utils import utils_image as util
+# inputs = []
+# for frame in video_reader:
+# print(frame.dtype)
+# util.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+# #util.imshow(np.flip(frame, axis=2))
+
+ # -----------------------------------
+ # test video2images(video_path, output_dir)
+ # -----------------------------------
+# video2images('utils/test.mp4', 'frames')
+
+ # -----------------------------------
+ # test images2video(image_dir, video_path, fps=24, image_ext='png')
+ # -----------------------------------
+# images2video('frames', 'video_02.mp4', fps=30, image_ext='png')
+
+
+ # -----------------------------------
+ # test frames2video(frame_dir, video_file, fps=30, fourcc='XVID', filename_tmpl='{:06d}.png')
+ # -----------------------------------
+# frames2video('frames', 'video_01.mp4', filename_tmpl='{:06d}.png')
+
+
+ # -----------------------------------
+ # test add_video_compression(imgs)
+ # -----------------------------------
+# imgs = []
+# image_ext = 'png'
+# frames = 'frames'
+# from utils import utils_image as util
+# image_files = sorted(glob.glob(os.path.join(frames, '*.{}'.format(image_ext))))
+# for i, image_file in enumerate(image_files):
+# if i < 7:
+# img = util.imread_uint(image_file, 3)
+# img = util.uint2single(img)
+# imgs.append(img)
+#
+# results = add_video_compression(imgs)
+# for i, img in enumerate(results):
+# util.imshow(util.single2uint(img))
+# util.imsave(util.single2uint(img),f'{i:05}.png')
+
+ # run utils/utils_video.py
+
diff --git a/core/scripts/__init__.py b/core/scripts/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/core/scripts/cli.py b/core/scripts/cli.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfe3ecc330ecf9f0b3af1e7dc6b3758673712cc7
--- /dev/null
+++ b/core/scripts/cli.py
@@ -0,0 +1,41 @@
+import sys
+import argparse
+from .. import WarpCore
+from .. import templates
+
+
+def template_init(args):
+ # Placeholder: return an (empty) template string for the selected template.
+ return '''
+
+
+ '''.strip()
+
+
+def init_template(args):
+ parser = argparse.ArgumentParser(description='WarpCore template init tool')
+ parser.add_argument('-t', '--template', type=str, default='WarpCore')
+ args = parser.parse_args(args)
+
+ if args.template == 'WarpCore':
+ template_cls = WarpCore
+ else:
+ try:
+ template_cls = __import__(args.template)
+ except ModuleNotFoundError:
+ template_cls = getattr(templates, args.template)
+ print(template_cls)
+
+
+def main():
+ if len(sys.argv) < 2:
+ print('Usage: core <command> [args...]')
+ sys.exit(1)
+ if sys.argv[1] == 'init':
+ init_template(sys.argv[2:])
+ else:
+ print('Unknown command')
+ sys.exit(1)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/core/templates/__init__.py b/core/templates/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..570f16de78bcce68aa49ff0a5d0fad63284f6948
--- /dev/null
+++ b/core/templates/__init__.py
@@ -0,0 +1 @@
+from .diffusion import DiffusionCore
\ No newline at end of file
diff --git a/core/templates/diffusion.py b/core/templates/diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..f36dc3f5efa14669cc36cc3c0cffcc8def037289
--- /dev/null
+++ b/core/templates/diffusion.py
@@ -0,0 +1,236 @@
+from .. import WarpCore
+from ..utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary
+from abc import abstractmethod
+from dataclasses import dataclass
+import torch
+from torch import nn
+from torch.utils.data import DataLoader
+from gdf import GDF
+import numpy as np
+from tqdm import tqdm
+import wandb
+
+import webdataset as wds
+from webdataset.handlers import warn_and_continue
+from torch.distributed import barrier
+from enum import Enum
+
+class TargetReparametrization(Enum):
+ EPSILON = 'epsilon'
+ X0 = 'x0'
+
+class DiffusionCore(WarpCore):
+ @dataclass(frozen=True)
+ class Config(WarpCore.Config):
+ # TRAINING PARAMS
+ lr: float = EXPECTED_TRAIN
+ grad_accum_steps: int = EXPECTED_TRAIN
+ batch_size: int = EXPECTED_TRAIN
+ updates: int = EXPECTED_TRAIN
+ warmup_updates: int = EXPECTED_TRAIN
+ save_every: int = 500
+ backup_every: int = 20000
+ use_fsdp: bool = True
+
+ # EMA UPDATE
+ ema_start_iters: int = None
+ ema_iters: int = None
+ ema_beta: float = None
+
+ # GDF setting
+ gdf_target_reparametrization: TargetReparametrization = None # epsilon or x0
+
+ @dataclass() # not frozen, means that fields are mutable. Doesn't support EXPECTED
+ class Info(WarpCore.Info):
+ ema_loss: float = None
+
+ @dataclass(frozen=True)
+ class Models(WarpCore.Models):
+ generator : nn.Module = EXPECTED
+ generator_ema : nn.Module = None # optional
+
+ @dataclass(frozen=True)
+ class Optimizers(WarpCore.Optimizers):
+ generator : any = EXPECTED
+
+ @dataclass(frozen=True)
+ class Schedulers(WarpCore.Schedulers):
+ generator: any = None
+
+ @dataclass(frozen=True)
+ class Extras(WarpCore.Extras):
+ gdf: GDF = EXPECTED
+ sampling_configs: dict = EXPECTED
+
+ # --------------------------------------------
+ info: Info
+ config: Config
+
+ @abstractmethod
+ def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
+ raise NotImplementedError("This method needs to be overriden")
+
+ @abstractmethod
+ def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
+ raise NotImplementedError("This method needs to be overriden")
+
+ @abstractmethod
+ def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False):
+ raise NotImplementedError("This method needs to be overriden")
+
+ @abstractmethod
+ def webdataset_path(self, extras: Extras):
+ raise NotImplementedError("This method needs to be overriden")
+
+ @abstractmethod
+ def webdataset_filters(self, extras: Extras):
+ raise NotImplementedError("This method needs to be overriden")
+
+ @abstractmethod
+ def webdataset_preprocessors(self, extras: Extras):
+ raise NotImplementedError("This method needs to be overriden")
+
+ @abstractmethod
+ def sample(self, models: Models, data: WarpCore.Data, extras: Extras):
+ raise NotImplementedError("This method needs to be overriden")
+ # -------------
+
+ def setup_data(self, extras: Extras) -> WarpCore.Data:
+ # SETUP DATASET
+ dataset_path = self.webdataset_path(extras)
+ preprocessors = self.webdataset_preprocessors(extras)
+ filters = self.webdataset_filters(extras)
+
+ handler = warn_and_continue # None
+ # handler = None
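+ # Pipeline: filter samples, shuffle, decode images to PIL RGB, pull out the raw fields,
+ # apply each preprocessor, then pack the results into a dict keyed by each preprocessor's output name.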
+ dataset = wds.WebDataset(
+ dataset_path, resampled=True, handler=handler
+ ).select(filters).shuffle(690, handler=handler).decode(
+ "pilrgb", handler=handler
+ ).to_tuple(
+ *[p[0] for p in preprocessors], handler=handler
+ ).map_tuple(
+ *[p[1] for p in preprocessors], handler=handler
+ ).map(lambda x: {p[2]:x[i] for i, p in enumerate(preprocessors)})
+
+ # SETUP DATALOADER
+ real_batch_size = self.config.batch_size//(self.world_size*self.config.grad_accum_steps)
+ dataloader = DataLoader(
+ dataset, batch_size=real_batch_size, num_workers=8, pin_memory=True
+ )
+
+ return self.Data(dataset=dataset, dataloader=dataloader, iterator=iter(dataloader))
+
+ def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models):
+ batch = next(data.iterator)
+
+ with torch.no_grad():
+ conditions = self.get_conditions(batch, models, extras)
+ latents = self.encode_latents(batch, models, extras)
+ noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1)
+
+ # FORWARD PASS
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+ pred = models.generator(noised, noise_cond, **conditions)
+ if self.config.gdf_target_reparametrization == TargetReparametrization.EPSILON:
+ pred = extras.gdf.undiffuse(noised, logSNR, pred)[1] # transform whatever prediction to epsilon to use in the loss
+ target = noise
+ elif self.config.gdf_target_reparametrization == TargetReparametrization.X0:
+ pred = extras.gdf.undiffuse(noised, logSNR, pred)[0] # transform whatever prediction to x0 to use in the loss
+ target = latents
+ loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3])
+ loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps
+
+ return loss, loss_adjusted
+
+ def train(self, data: WarpCore.Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers):
+ start_iter = self.info.iter+1
+ max_iters = self.config.updates * self.config.grad_accum_steps
+ if self.is_main_node:
+ print(f"STARTING AT STEP: {start_iter}/{max_iters}")
+
+ pbar = tqdm(range(start_iter, max_iters+1)) if self.is_main_node else range(start_iter, max_iters+1) # <--- DDP
+ models.generator.train()
+ for i in pbar:
+ # FORWARD PASS
+ loss, loss_adjusted = self.forward_pass(data, extras, models)
+
+ # BACKWARD PASS
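+ # Step the optimizer/scheduler only every grad_accum_steps iterations (or on the last one);
+ # otherwise accumulate gradients locally without the DDP all-reduce (no_sync).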
+ if i % self.config.grad_accum_steps == 0 or i == max_iters:
+ loss_adjusted.backward()
+ grad_norm = nn.utils.clip_grad_norm_(models.generator.parameters(), 1.0)
+ optimizers_dict = optimizers.to_dict()
+ for k in optimizers_dict:
+ optimizers_dict[k].step()
+ schedulers_dict = schedulers.to_dict()
+ for k in schedulers_dict:
+ schedulers_dict[k].step()
+ models.generator.zero_grad(set_to_none=True)
+ self.info.total_steps += 1
+ else:
+ with models.generator.no_sync():
+ loss_adjusted.backward()
+ self.info.iter = i
+
+ # UPDATE EMA
+ if models.generator_ema is not None and i % self.config.ema_iters == 0:
+ update_weights_ema(
+ models.generator_ema, models.generator,
+ beta=(self.config.ema_beta if i > self.config.ema_start_iters else 0)
+ )
+
+ # UPDATE LOSS METRICS
+ self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01
+
+ if self.is_main_node and self.config.wandb_project is not None and (np.isnan(loss.mean().item()) or np.isnan(grad_norm.item())):
+ wandb.alert(
+ title=f"NaN value encountered in training run {self.info.wandb_run_id}",
+ text=f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}",
+ wait_duration=60*30
+ )
+
+ if self.is_main_node:
+ logs = {
+ 'loss': self.info.ema_loss,
+ 'raw_loss': loss.mean().item(),
+ 'grad_norm': grad_norm.item(),
+ 'lr': optimizers.generator.param_groups[0]['lr'],
+ 'total_steps': self.info.total_steps,
+ }
+
+ pbar.set_postfix(logs)
+ if self.config.wandb_project is not None:
+ wandb.log(logs)
+
+ if i == 1 or i % (self.config.save_every*self.config.grad_accum_steps) == 0 or i == max_iters:
+ # SAVE AND CHECKPOINT STUFF
+ if np.isnan(loss.mean().item()):
+ if self.is_main_node and self.config.wandb_project is not None:
+ tqdm.write("Skipping sampling & checkpoint because the loss is NaN")
+ wandb.alert(title=f"Skipping sampling & checkpoint for training run {self.config.run_id}", text=f"Skipping sampling & checkpoint at {self.info.total_steps} for training run {self.info.wandb_run_id} iters because loss is NaN")
+ else:
+ self.save_checkpoints(models, optimizers)
+ if self.is_main_node:
+ create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/')
+ self.sample(models, data, extras)
+
+ def models_to_save(self):
+ return ['generator', 'generator_ema']
+
+ def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None):
+ barrier()
+ suffix = '' if suffix is None else suffix
+ self.save_info(self.info, suffix=suffix)
+ models_dict = models.to_dict()
+ optimizers_dict = optimizers.to_dict()
+ for key in self.models_to_save():
+ model = models_dict[key]
+ if model is not None:
+ self.save_model(model, f"{key}{suffix}", is_fsdp=self.config.use_fsdp)
+ for key in optimizers_dict:
+ optimizer = optimizers_dict[key]
+ if optimizer is not None:
+ self.save_optimizer(optimizer, f'{key}_optim{suffix}', fsdp_model=models.generator if self.config.use_fsdp else None)
+ if suffix == '' and self.info.total_steps > 1 and self.info.total_steps % self.config.backup_every == 0:
+ self.save_checkpoints(models, optimizers, suffix=f"_{self.info.total_steps//1000}k")
+ torch.cuda.empty_cache()
diff --git a/core/utils/__init__.py b/core/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e71b37e8d1690a00ab1e0958320775bc822b6f5
--- /dev/null
+++ b/core/utils/__init__.py
@@ -0,0 +1,9 @@
+from .base_dto import Base, nested_dto, EXPECTED, EXPECTED_TRAIN
+from .save_and_load import create_folder_if_necessary, safe_save, load_or_fail
+
+# MOVE IT SOMEWHERE ELSE
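+# Exponential moving average update: blend the target model's parameters and buffers toward the source model's.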
+def update_weights_ema(tgt_model, src_model, beta=0.999):
+ for self_params, src_params in zip(tgt_model.parameters(), src_model.parameters()):
+ self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1-beta)
+ for self_buffers, src_buffers in zip(tgt_model.buffers(), src_model.buffers()):
+ self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1-beta)
\ No newline at end of file
diff --git a/core/utils/__pycache__/__init__.cpython-310.pyc b/core/utils/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..63c0a7e0fbf358f557d6bea755a0f550b4010a48
Binary files /dev/null and b/core/utils/__pycache__/__init__.cpython-310.pyc differ
diff --git a/core/utils/__pycache__/__init__.cpython-39.pyc b/core/utils/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6f18d6921da3c9d93087c1b6d8eacd7a5e46a8e5
Binary files /dev/null and b/core/utils/__pycache__/__init__.cpython-39.pyc differ
diff --git a/core/utils/__pycache__/base_dto.cpython-310.pyc b/core/utils/__pycache__/base_dto.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..de093eb65813d4abf69edfbb6923f2cabab21ad7
Binary files /dev/null and b/core/utils/__pycache__/base_dto.cpython-310.pyc differ
diff --git a/core/utils/__pycache__/base_dto.cpython-39.pyc b/core/utils/__pycache__/base_dto.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b80d348c7959338709ec24c3ac24dfc4f6dab3dc
Binary files /dev/null and b/core/utils/__pycache__/base_dto.cpython-39.pyc differ
diff --git a/core/utils/__pycache__/save_and_load.cpython-310.pyc b/core/utils/__pycache__/save_and_load.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a7a0f63ac8bbaf073dcd8a046ed112cec181d33a
Binary files /dev/null and b/core/utils/__pycache__/save_and_load.cpython-310.pyc differ
diff --git a/core/utils/__pycache__/save_and_load.cpython-39.pyc b/core/utils/__pycache__/save_and_load.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ec04e9aba6f83ab76f0bbc243bb95fda07ad8d16
Binary files /dev/null and b/core/utils/__pycache__/save_and_load.cpython-39.pyc differ
diff --git a/core/utils/base_dto.py b/core/utils/base_dto.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cf185f00e5c6f56d23774cec8591b8d4554971e
--- /dev/null
+++ b/core/utils/base_dto.py
@@ -0,0 +1,56 @@
+import dataclasses
+from dataclasses import dataclass, _MISSING_TYPE
+from munch import Munch
+
+EXPECTED = "___REQUIRED___"
+EXPECTED_TRAIN = "___REQUIRED_TRAIN___"
+
+# pylint: disable=invalid-field-call
+def nested_dto(x, raw=False):
+ return dataclasses.field(default_factory=lambda: x if raw else Munch.fromDict(x))
+
+@dataclass(frozen=True)
+class Base:
+ training: bool = None
+ def __new__(cls, **kwargs):
+ training = kwargs.get('training', True)
+ setteable_fields = cls.setteable_fields(**kwargs)
+ mandatory_fields = cls.mandatory_fields(**kwargs)
+ invalid_kwargs = [
+ {k: v} for k, v in kwargs.items() if k not in setteable_fields or v == EXPECTED or (v == EXPECTED_TRAIN and training is not False)
+ ]
+ print(mandatory_fields)
+ assert (
+ len(invalid_kwargs) == 0
+ ), f"Invalid fields detected when initializing this DTO: {invalid_kwargs}.\nDeclare this field and set it to None or EXPECTED in order to make it setteable."
+ missing_kwargs = [f for f in mandatory_fields if f not in kwargs]
+ assert (
+ len(missing_kwargs) == 0
+ ), f"Required fields missing initializing this DTO: {missing_kwargs}."
+ return object.__new__(cls)
+
+
+ @classmethod
+ def setteable_fields(cls, **kwargs):
+ return [f.name for f in dataclasses.fields(cls) if f.default is None or isinstance(f.default, _MISSING_TYPE) or f.default == EXPECTED or f.default == EXPECTED_TRAIN]
+
+ @classmethod
+ def mandatory_fields(cls, **kwargs):
+ training = kwargs.get('training', True)
+ return [f.name for f in dataclasses.fields(cls) if isinstance(f.default, _MISSING_TYPE) and isinstance(f.default_factory, _MISSING_TYPE) or f.default == EXPECTED or (f.default == EXPECTED_TRAIN and training is not False)]
+
+ @classmethod
+ def from_dict(cls, kwargs):
+ for k in kwargs:
+ if isinstance(kwargs[k], (dict, list, tuple)):
+ kwargs[k] = Munch.fromDict(kwargs[k])
+ return cls(**kwargs)
+
+ def to_dict(self):
+ # selfdict = dataclasses.asdict(self) # needs to pickle stuff, doesn't support some more complex classes
+ selfdict = {}
+ for k in dataclasses.fields(self):
+ selfdict[k.name] = getattr(self, k.name)
+ if isinstance(selfdict[k.name], Munch):
+ selfdict[k.name] = selfdict[k.name].toDict()
+ return selfdict
diff --git a/core/utils/save_and_load.py b/core/utils/save_and_load.py
new file mode 100644
index 0000000000000000000000000000000000000000..0215f664f5a8e738147d0828b6a7e65b9c3a8507
--- /dev/null
+++ b/core/utils/save_and_load.py
@@ -0,0 +1,59 @@
+import os
+import torch
+import json
+from pathlib import Path
+import safetensors
+import wandb
+
+
+def create_folder_if_necessary(path):
+ path = "/".join(path.split("/")[:-1])
+ Path(path).mkdir(parents=True, exist_ok=True)
+
+
+def safe_save(ckpt, path):
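+ # Rotate any existing file to "<path>.bak" first, then save based on the file extension.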
+ try:
+ os.remove(f"{path}.bak")
+ except OSError:
+ pass
+ try:
+ os.rename(path, f"{path}.bak")
+ except OSError:
+ pass
+ if path.endswith(".pt") or path.endswith(".ckpt"):
+ torch.save(ckpt, path)
+ elif path.endswith(".json"):
+ with open(path, "w", encoding="utf-8") as f:
+ json.dump(ckpt, f, indent=4)
+ elif path.endswith(".safetensors"):
+ safetensors.torch.save_file(ckpt, path)
+ else:
+ raise ValueError(f"File extension not supported: {path}")
+
+
+def load_or_fail(path, wandb_run_id=None):
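+ # Load a checkpoint based on its extension; returns None if the file does not exist,
+ # and raises (optionally sending a wandb alert) if loading fails.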
+ accepted_extensions = [".pt", ".ckpt", ".json", ".safetensors"]
+ try:
+ assert any(
+ [path.endswith(ext) for ext in accepted_extensions]
+ ), f"Automatic loading not supported for this extension: {path}"
+ if not os.path.exists(path):
+ checkpoint = None
+ elif path.endswith(".pt") or path.endswith(".ckpt"):
+ checkpoint = torch.load(path, map_location="cpu")
+ elif path.endswith(".json"):
+ with open(path, "r", encoding="utf-8") as f:
+ checkpoint = json.load(f)
+ elif path.endswith(".safetensors"):
+ checkpoint = {}
+ with safetensors.safe_open(path, framework="pt", device="cpu") as f:
+ for key in f.keys():
+ checkpoint[key] = f.get_tensor(key)
+ return checkpoint
+ except Exception as e:
+ if wandb_run_id is not None:
+ wandb.alert(
+ title=f"Corrupt checkpoint for run {wandb_run_id}",
+ text=f"Training {wandb_run_id} tried to load checkpoint {path} and failed",
+ )
+ raise e
diff --git a/gdf/__init__.py b/gdf/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..753b52e2e07e2540385594627a6faf4f6091b0a0
--- /dev/null
+++ b/gdf/__init__.py
@@ -0,0 +1,205 @@
+import torch
+from .scalers import *
+from .targets import *
+from .schedulers import *
+from .noise_conditions import *
+from .loss_weights import *
+from .samplers import *
+import torch.nn.functional as F
+import math
+class GDF():
+ def __init__(self, schedule, input_scaler, target, noise_cond, loss_weight, offset_noise=0):
+ self.schedule = schedule
+ self.input_scaler = input_scaler
+ self.target = target
+ self.noise_cond = noise_cond
+ self.loss_weight = loss_weight
+ self.offset_noise = offset_noise
+
+ def setup_limits(self, stretch_max=True, stretch_min=True, shift=1):
+ stretched_limits = self.input_scaler.setup_limits(self.schedule, self.input_scaler, stretch_max, stretch_min, shift)
+ return stretched_limits
+
+ def diffuse(self, x0, epsilon=None, t=None, shift=1, loss_shift=1, offset=None):
+ if epsilon is None:
+ epsilon = torch.randn_like(x0)
+ if self.offset_noise > 0:
+ if offset is None:
+ offset = torch.randn([x0.size(0), x0.size(1)] + [1]*(len(x0.shape)-2)).to(x0.device)
+ epsilon = epsilon + offset * self.offset_noise
+ logSNR = self.schedule(x0.size(0) if t is None else t, shift=shift).to(x0.device)
+ a, b = self.input_scaler(logSNR) # B
+ if len(a.shape) == 1:
+ a, b = a.view(-1, *[1]*(len(x0.shape)-1)), b.view(-1, *[1]*(len(x0.shape)-1)) # BxCxHxW
+ #print('in line 33 a b', a.shape, b.shape, x0.shape, logSNR.shape, logSNR, self.noise_cond(logSNR))
+ target = self.target(x0, epsilon, logSNR, a, b)
+
+ # noised, noise, logSNR, t_cond
+ #noised, noise, target, logSNR, noise_cond, loss_weight
+ return x0 * a + epsilon * b, epsilon, target, logSNR, self.noise_cond(logSNR), self.loss_weight(logSNR, shift=loss_shift)
+
+ def undiffuse(self, x, logSNR, pred):
+ a, b = self.input_scaler(logSNR)
+ if len(a.shape) == 1:
+ a, b = a.view(-1, *[1]*(len(x.shape)-1)), b.view(-1, *[1]*(len(x.shape)-1))
+ return self.target.x0(x, pred, logSNR, a, b), self.target.epsilon(x, pred, logSNR, a, b)
+
+ def sample(self, model, model_inputs, shape, unconditional_inputs=None, sampler=None, schedule=None, t_start=1.0, t_end=0.0, timesteps=20, x_init=None, cfg=3.0, cfg_t_stop=None, cfg_t_start=None, cfg_rho=0.7, sampler_params=None, shift=1, device="cpu"):
+ sampler_params = {} if sampler_params is None else sampler_params
+ if sampler is None:
+ sampler = DDPMSampler(self)
+ r_range = torch.linspace(t_start, t_end, timesteps+1)
+ schedule = self.schedule if schedule is None else schedule
+ logSNR_range = schedule(r_range, shift=shift)[:, None].expand(
+ -1, shape[0] if x_init is None else x_init.size(0)
+ ).to(device)
+
+ x = sampler.init_x(shape).to(device) if x_init is None else x_init.clone()
+
+ if cfg is not None:
+ if unconditional_inputs is None:
+ unconditional_inputs = {k: torch.zeros_like(v) for k, v in model_inputs.items()}
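+ # Concatenate conditional and unconditional inputs along the batch dimension so a single
+ # forward pass serves both (tensors, lists of tensors and dicts of tensors are handled).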
+ model_inputs = {
+ k: torch.cat([v, v_u], dim=0) if isinstance(v, torch.Tensor)
+ else [torch.cat([vi, vi_u], dim=0) if isinstance(vi, torch.Tensor) and isinstance(vi_u, torch.Tensor) else None for vi, vi_u in zip(v, v_u)] if isinstance(v, list)
+ else {vk: torch.cat([v[vk], v_u.get(vk, torch.zeros_like(v[vk]))], dim=0) for vk in v} if isinstance(v, dict)
+ else None for (k, v), (k_u, v_u) in zip(model_inputs.items(), unconditional_inputs.items())
+ }
+
+ for i in range(0, timesteps):
+ noise_cond = self.noise_cond(logSNR_range[i])
+ if cfg is not None and (cfg_t_stop is None or r_range[i].item() >= cfg_t_stop) and (cfg_t_start is None or r_range[i].item() <= cfg_t_start):
+ cfg_val = cfg
+ if isinstance(cfg_val, (list, tuple)):
+ assert len(cfg_val) == 2, "cfg must be a float or a list/tuple of length 2"
+ cfg_val = cfg_val[0] * r_range[i].item() + cfg_val[1] * (1-r_range[i].item())
+
+ pred, pred_unconditional = model(torch.cat([x, x], dim=0), noise_cond.repeat(2), **model_inputs).chunk(2)
+
+ pred_cfg = torch.lerp(pred_unconditional, pred, cfg_val)
+ if cfg_rho > 0:
+ std_pos, std_cfg = pred.std(), pred_cfg.std()
+ pred = cfg_rho * (pred_cfg * std_pos/(std_cfg+1e-9)) + pred_cfg * (1-cfg_rho)
+ else:
+ pred = pred_cfg
+ else:
+ pred = model(x, noise_cond, **model_inputs)
+ x0, epsilon = self.undiffuse(x, logSNR_range[i], pred)
+ x = sampler(x, x0, epsilon, logSNR_range[i], logSNR_range[i+1], **sampler_params)
+ #print('in line 86', x0.shape, x.shape, i, )
+ altered_vars = yield (x0, x, pred)
+
+ # Update some running variables if the user wants
+ if altered_vars is not None:
+ cfg = altered_vars.get('cfg', cfg)
+ cfg_rho = altered_vars.get('cfg_rho', cfg_rho)
+ sampler = altered_vars.get('sampler', sampler)
+ model_inputs = altered_vars.get('model_inputs', model_inputs)
+ x = altered_vars.get('x', x)
+ x_init = altered_vars.get('x_init', x_init)
+
+class GDF_dual_fixlrt(GDF):
+ def ref_noise(self, noised, x0, logSNR):
+ a, b = self.input_scaler(logSNR)
+ if len(a.shape) == 1:
+ a, b = a.view(-1, *[1]*(len(x0.shape)-1)), b.view(-1, *[1]*(len(x0.shape)-1))
+ #print('in line 210', a.shape, b.shape, x0.shape, noised.shape)
+ return self.target.noise_givenx0_noised(x0, noised, logSNR, a, b)
+
+ def sample(self, model, model_inputs, shape, shape_lr, unconditional_inputs=None, sampler=None,
+ schedule=None, t_start=1.0, t_end=0.0, timesteps=20, x_init=None, cfg=3.0, cfg_t_stop=None,
+ cfg_t_start=None, cfg_rho=0.7, sampler_params=None, shift=1, device="cpu"):
+ sampler_params = {} if sampler_params is None else sampler_params
+ if sampler is None:
+ sampler = DDPMSampler(self)
+ r_range = torch.linspace(t_start, t_end, timesteps+1)
+ schedule = self.schedule if schedule is None else schedule
+ logSNR_range = schedule(r_range, shift=shift)[:, None].expand(
+ -1, shape[0] if x_init is None else x_init.size(0)
+ ).to(device)
+
+ x = sampler.init_x(shape).to(device) if x_init is None else x_init.clone()
+ x_lr = sampler.init_x(shape_lr).to(device) if x_init is None else x_init.clone()
+ if cfg is not None:
+ if unconditional_inputs is None:
+ unconditional_inputs = {k: torch.zeros_like(v) for k, v in model_inputs.items()}
+ model_inputs = {
+ k: torch.cat([v, v_u], dim=0) if isinstance(v, torch.Tensor)
+ else [torch.cat([vi, vi_u], dim=0) if isinstance(vi, torch.Tensor) and isinstance(vi_u, torch.Tensor) else None for vi, vi_u in zip(v, v_u)] if isinstance(v, list)
+ else {vk: torch.cat([v[vk], v_u.get(vk, torch.zeros_like(v[vk]))], dim=0) for vk in v} if isinstance(v, dict)
+ else None for (k, v), (k_u, v_u) in zip(model_inputs.items(), unconditional_inputs.items())
+ }
+
+ ############################################### LR (low-resolution) sampling
+
+ guide_feas = [None] * timesteps
+
+ for i in range(0, timesteps):
+ noise_cond = self.noise_cond(logSNR_range[i])
+ if cfg is not None and (cfg_t_stop is None or r_range[i].item() >= cfg_t_stop) and (cfg_t_start is None or r_range[i].item() <= cfg_t_start):
+ cfg_val = cfg
+ if isinstance(cfg_val, (list, tuple)):
+ assert len(cfg_val) == 2, "cfg must be a float or a list/tuple of length 2"
+ cfg_val = cfg_val[0] * r_range[i].item() + cfg_val[1] * (1-r_range[i].item())
+
+
+
+ if i == timesteps - 1:
+ output, guide_lr_enc, guide_lr_dec = model(torch.cat([x_lr, x_lr], dim=0), noise_cond.repeat(2), reuire_f=True, **model_inputs)
+ guide_feas[i] = ([f.chunk(2)[0].repeat(2, 1, 1, 1) for f in guide_lr_enc], [f.chunk(2)[0].repeat(2, 1, 1, 1) for f in guide_lr_dec])
+ else:
+ output, _, _ = model(torch.cat([x_lr, x_lr], dim=0), noise_cond.repeat(2), reuire_f=True, **model_inputs)
+
+ pred, pred_unconditional = output.chunk(2)
+
+
+ pred_cfg = torch.lerp(pred_unconditional, pred, cfg_val)
+ if cfg_rho > 0:
+ std_pos, std_cfg = pred.std(), pred_cfg.std()
+ pred = cfg_rho * (pred_cfg * std_pos/(std_cfg+1e-9)) + pred_cfg * (1-cfg_rho)
+ else:
+ pred = pred_cfg
+ else:
+ pred = model(x_lr, noise_cond, **model_inputs)
+ x0_lr, epsilon_lr = self.undiffuse(x_lr, logSNR_range[i], pred)
+ x_lr = sampler(x_lr, x0_lr, epsilon_lr, logSNR_range[i], logSNR_range[i+1], **sampler_params)
+
+ ############################################### HR (high-resolution) sampling
+ for i in range(0, timesteps):
+ noise_cond = self.noise_cond(logSNR_range[i])
+ if cfg is not None and (cfg_t_stop is None or r_range[i].item() >= cfg_t_stop) and (cfg_t_start is None or r_range[i].item() <= cfg_t_start):
+ cfg_val = cfg
+ if isinstance(cfg_val, (list, tuple)):
+ assert len(cfg_val) == 2, "cfg must be a float or a list/tuple of length 2"
+ cfg_val = cfg_val[0] * r_range[i].item() + cfg_val[1] * (1-r_range[i].item())
+
+ out_pred, t_emb = model(torch.cat([x, x], dim=0), noise_cond.repeat(2), \
+ lr_guide=guide_feas[timesteps -1] if i <=19 else None , **model_inputs, require_t=True, guide_weight=1 - i/timesteps)
+ pred, pred_unconditional = out_pred.chunk(2)
+ pred_cfg = torch.lerp(pred_unconditional, pred, cfg_val)
+ if cfg_rho > 0:
+ std_pos, std_cfg = pred.std(), pred_cfg.std()
+ pred = cfg_rho * (pred_cfg * std_pos/(std_cfg+1e-9)) + pred_cfg * (1-cfg_rho)
+ else:
+ pred = pred_cfg
+ else:
+ pred = model(x, noise_cond, guide_lr=(guide_lr_enc, guide_lr_dec), **model_inputs)
+ x0, epsilon = self.undiffuse(x, logSNR_range[i], pred)
+
+ x = sampler(x, x0, epsilon, logSNR_range[i], logSNR_range[i+1], **sampler_params)
+ altered_vars = yield (x0, x, pred, x_lr)
+
+
+
+ # Update some running variables if the user wants
+ if altered_vars is not None:
+ cfg = altered_vars.get('cfg', cfg)
+ cfg_rho = altered_vars.get('cfg_rho', cfg_rho)
+ sampler = altered_vars.get('sampler', sampler)
+ model_inputs = altered_vars.get('model_inputs', model_inputs)
+ x = altered_vars.get('x', x)
+ x_init = altered_vars.get('x_init', x_init)
+
diff --git a/gdf/loss_weights.py b/gdf/loss_weights.py
new file mode 100644
index 0000000000000000000000000000000000000000..d14ddaadeeb3f8de6c68aea4c364d9b852f2f15c
--- /dev/null
+++ b/gdf/loss_weights.py
@@ -0,0 +1,101 @@
+import torch
+import numpy as np
+
+# --- Loss Weighting
+class BaseLossWeight():
+ def weight(self, logSNR):
+ raise NotImplementedError("this method needs to be overridden")
+
+ def __call__(self, logSNR, *args, shift=1, clamp_range=None, **kwargs):
+ clamp_range = [-1e9, 1e9] if clamp_range is None else clamp_range
+ if shift != 1:
+ logSNR = logSNR.clone() + 2 * np.log(shift)
+ return self.weight(logSNR, *args, **kwargs).clamp(*clamp_range)
+
+class ComposedLossWeight(BaseLossWeight):
+ def __init__(self, div, mul):
+ self.mul = [mul] if isinstance(mul, BaseLossWeight) else mul
+ self.div = [div] if isinstance(div, BaseLossWeight) else div
+
+ def weight(self, logSNR):
+ prod, div = 1, 1
+ for m in self.mul:
+ prod *= m.weight(logSNR)
+ for d in self.div:
+ div *= d.weight(logSNR)
+ return prod/div
+
+class ConstantLossWeight(BaseLossWeight):
+ def __init__(self, v=1):
+ self.v = v
+
+ def weight(self, logSNR):
+ return torch.ones_like(logSNR) * self.v
+
+class SNRLossWeight(BaseLossWeight):
+ def weight(self, logSNR):
+ return logSNR.exp()
+
+class P2LossWeight(BaseLossWeight):
+ def __init__(self, k=1.0, gamma=1.0, s=1.0):
+ self.k, self.gamma, self.s = k, gamma, s
+
+ def weight(self, logSNR):
+ return (self.k + (logSNR * self.s).exp()) ** -self.gamma
+
+class SNRPlusOneLossWeight(BaseLossWeight):
+ def weight(self, logSNR):
+ return logSNR.exp() + 1
+
+class MinSNRLossWeight(BaseLossWeight):
+ def __init__(self, max_snr=5):
+ self.max_snr = max_snr
+
+ def weight(self, logSNR):
+ return logSNR.exp().clamp(max=self.max_snr)
+
+class MinSNRPlusOneLossWeight(BaseLossWeight):
+ def __init__(self, max_snr=5):
+ self.max_snr = max_snr
+
+ def weight(self, logSNR):
+ return (logSNR.exp() + 1).clamp(max=self.max_snr)
+
+class TruncatedSNRLossWeight(BaseLossWeight):
+ def __init__(self, min_snr=1):
+ self.min_snr = min_snr
+
+ def weight(self, logSNR):
+ return logSNR.exp().clamp(min=self.min_snr)
+
+class SechLossWeight(BaseLossWeight):
+ def __init__(self, div=2):
+ self.div = div
+
+ def weight(self, logSNR):
+ return 1/(logSNR/self.div).cosh()
+
+class DebiasedLossWeight(BaseLossWeight):
+ def weight(self, logSNR):
+ return 1/logSNR.exp().sqrt()
+
+class SigmoidLossWeight(BaseLossWeight):
+ def __init__(self, s=1):
+ self.s = s
+
+ def weight(self, logSNR):
+ return (logSNR * self.s).sigmoid()
+
+class AdaptiveLossWeight(BaseLossWeight):
+ def __init__(self, logsnr_range=[-10, 10], buckets=300, weight_range=[1e-7, 1e7]):
+ self.bucket_ranges = torch.linspace(logsnr_range[0], logsnr_range[1], buckets-1)
+ self.bucket_losses = torch.ones(buckets)
+ self.weight_range = weight_range
+
+ def weight(self, logSNR):
+ indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR)
+ return (1/self.bucket_losses.to(logSNR.device)[indices]).clamp(*self.weight_range)
+
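+ # Keep an EMA of the loss observed in each logSNR bucket; weight() returns the inverse of
+ # this EMA (clamped), which balances the loss contribution across noise levels.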
+ def update_buckets(self, logSNR, loss, beta=0.99):
+ indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR).cpu()
+ self.bucket_losses[indices] = self.bucket_losses[indices]*beta + loss.detach().cpu() * (1-beta)
diff --git a/gdf/noise_conditions.py b/gdf/noise_conditions.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc2791f50a6f63eff8f9bed9b827f87517cc0be8
--- /dev/null
+++ b/gdf/noise_conditions.py
@@ -0,0 +1,102 @@
+import torch
+import numpy as np
+
+class BaseNoiseCond():
+ def __init__(self, *args, shift=1, clamp_range=None, **kwargs):
+ clamp_range = [-1e9, 1e9] if clamp_range is None else clamp_range
+ self.shift = shift
+ self.clamp_range = clamp_range
+ self.setup(*args, **kwargs)
+
+ def setup(self, *args, **kwargs):
+ pass # this method is optional, override it if required
+
+ def cond(self, logSNR):
+ raise NotImplementedError("this method needs to be overriden")
+
+ def __call__(self, logSNR):
+ if self.shift != 1:
+ logSNR = logSNR.clone() + 2 * np.log(self.shift)
+ return self.cond(logSNR).clamp(*self.clamp_range)
+
+class CosineTNoiseCond(BaseNoiseCond):
+ def setup(self, s=0.008, clamp_range=[0, 1]): # [0.0001, 0.9999]
+ self.s = torch.tensor([s])
+ self.clamp_range = clamp_range
+ self.min_var = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2
+
+ def cond(self, logSNR):
+ var = logSNR.sigmoid()
+ var = var.clamp(*self.clamp_range)
+ s, min_var = self.s.to(var.device), self.min_var.to(var.device)
+ t = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
+ return t
+
+class EDMNoiseCond(BaseNoiseCond):
+ def cond(self, logSNR):
+ return -logSNR/8
+
+class SigmoidNoiseCond(BaseNoiseCond):
+ def cond(self, logSNR):
+ return (-logSNR).sigmoid()
+
+class LogSNRNoiseCond(BaseNoiseCond):
+ def cond(self, logSNR):
+ return logSNR
+
+class EDMSigmaNoiseCond(BaseNoiseCond):
+ def setup(self, sigma_data=1):
+ self.sigma_data = sigma_data
+
+ def cond(self, logSNR):
+ return torch.exp(-logSNR / 2) * self.sigma_data
+
+class RectifiedFlowsNoiseCond(BaseNoiseCond):
+ def cond(self, logSNR):
+ _a = logSNR.exp() - 1
+ _a[_a == 0] = 1e-3 # Avoid division by zero
+ a = 1 + (2-(2**2 + 4*_a)**0.5) / (2*_a)
+ return a
+
+# Any NoiseCond that cannot be described easily as a continuous function of t
+# It needs to define self.x and self.y in the setup() method
+class PiecewiseLinearNoiseCond(BaseNoiseCond):
+ def setup(self):
+ self.x = None
+ self.y = None
+
+ def piecewise_linear(self, y, xs, ys):
+ indices = (len(xs)-2) - torch.searchsorted(ys.flip(dims=(-1,))[:-2], y)
+ x_min, x_max = xs[indices], xs[indices+1]
+ y_min, y_max = ys[indices], ys[indices+1]
+ x = x_min + (x_max - x_min) * (y - y_min) / (y_max - y_min)
+ return x
+
+ def cond(self, logSNR):
+ var = logSNR.sigmoid()
+ t = self.piecewise_linear(var, self.x.to(var.device), self.y.to(var.device)) # .mul(1000).round().clamp(min=0)
+ return t
+
+class StableDiffusionNoiseCond(PiecewiseLinearNoiseCond):
+ def setup(self, linear_range=[0.00085, 0.012], total_steps=1000):
+ self.total_steps = total_steps
+ linear_range_sqrt = [r**0.5 for r in linear_range]
+ self.x = torch.linspace(0, 1, total_steps+1)
+
+ alphas = 1-(linear_range_sqrt[0]*(1-self.x) + linear_range_sqrt[1]*self.x)**2
+ self.y = alphas.cumprod(dim=-1)
+
+ def cond(self, logSNR):
+ return super().cond(logSNR).clamp(0, 1)
+
+class DiscreteNoiseCond(BaseNoiseCond):
+ def setup(self, noise_cond, steps=1000, continuous_range=[0, 1]):
+ self.noise_cond = noise_cond
+ self.steps = steps
+ self.continuous_range = continuous_range
+
+ def cond(self, logSNR):
+ cond = self.noise_cond(logSNR)
+ cond = (cond-self.continuous_range[0]) / (self.continuous_range[1]-self.continuous_range[0])
+ return cond.mul(self.steps).long()
+
\ No newline at end of file
diff --git a/gdf/readme.md b/gdf/readme.md
new file mode 100644
index 0000000000000000000000000000000000000000..9a63691513c9da6804fba53e36acc8e0cd7f5d7f
--- /dev/null
+++ b/gdf/readme.md
@@ -0,0 +1,86 @@
+# Generic Diffusion Framework (GDF)
+
+# Basic usage
+GDF is a simple framework for working with diffusion models. It implements the most common diffusion frameworks
+(DDPM / DDIM, EDM, Rectified Flows, etc.) and makes it very easy to switch between them or combine parts of
+different frameworks.
+
+Using GDF is very straightforward. First of all, just define an instance of the GDF class:
+
+```python
+from gdf import GDF
+from gdf import CosineSchedule
+from gdf import VPScaler, EpsilonTarget, CosineTNoiseCond, P2LossWeight
+
+gdf = GDF(
+ schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]),
+ input_scaler=VPScaler(), target=EpsilonTarget(),
+ noise_cond=CosineTNoiseCond(),
+ loss_weight=P2LossWeight(),
+)
+```
+
+You need to define the following components:
+* **Train Schedule**: This will return the logSNR schedule that will be used during training, some of the schedulers can be configured. A train schedule will then be called with a batch size and will randomly sample some values from the defined distribution.
+* **Sample Schedule**: This is the schedule that will be used later on when sampling. It might be different from the training schedule.
+* **Input Scaler**: How the input is scaled at each noise level, e.g. Variance Preserving (VP) or LERP (rectified flows)
+* **Target**: What the model is trained to predict, usually epsilon, x0 or v
+* **Noise Conditioning**: You could directly pass the logSNR to your model but usually a normalized value is used instead, for example the EDM framework proposes to use `-logSNR/8`
+* **Loss Weight**: There are many proposed loss weighting strategies, here you define which one you'll use
+
+All of these classes are very simple logSNR-centric definitions; for example, the VPScaler is defined as just:
+```python
+class VPScaler():
+ def __call__(self, logSNR):
+ a_squared = logSNR.sigmoid()
+ a = a_squared.sqrt()
+ b = (1-a_squared).sqrt()
+ return a, b
+
+```
+
+So it's very easy to extend this framework with custom schedulers, scalers, targets, loss weights, etc. For example, a custom loss weight only needs to implement the small interface shown in the sketch below.
+
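+A minimal sketch (the `ClampedSNRWeight` class below is hypothetical; the built-in `MinSNRLossWeight` implements the same idea), reusing the imports from the snippet above:
+
+```python
+from gdf import GDF, BaseLossWeight
+
+class ClampedSNRWeight(BaseLossWeight):
+    """Clamp the SNR at a maximum value ("min-SNR"-style weighting)."""
+    def __init__(self, max_snr=5):
+        self.max_snr = max_snr
+
+    def weight(self, logSNR):
+        # BaseLossWeight.__call__ takes care of shift and clamp_range for us.
+        return logSNR.exp().clamp(max=self.max_snr)
+
+gdf = GDF(
+    schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]),
+    input_scaler=VPScaler(), target=EpsilonTarget(),
+    noise_cond=CosineTNoiseCond(),
+    loss_weight=ClampedSNRWeight(),
+)
+```
+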
+### Training
+
+When you define your training loop you can get all you need by just doing:
+```python
+shift, loss_shift = 1, 1 # these can be set to higher values for high-resolution training, as suggested by the Simple Diffusion paper
+for inputs, extra_conditions in dataloader_iterator:
+ noised, noise, target, logSNR, noise_cond, loss_weight = gdf.diffuse(inputs, shift=shift, loss_shift=loss_shift)
+ pred = diffusion_model(noised, noise_cond, extra_conditions)
+
+ loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3])
+ loss_adjusted = (loss * loss_weight).mean()
+
+ loss_adjusted.backward()
+ optimizer.step()
+ optimizer.zero_grad(set_to_none=True)
+```
+
+And that's all: you have a diffusion model training loop where it's very easy to customize the different elements of
+the training via the GDF class.
+
+### Sampling
+
+The other important part is sampling. When you want to use this framework to sample, you can just do the following:
+
+```python
+from gdf import DDPMSampler
+
+shift = 1
+sampling_configs = {
+ "timesteps": 30, "cfg": 7, "sampler": DDPMSampler(gdf), "shift": shift,
+ "schedule": CosineSchedule(clamp_range=[0.0001, 0.9999])
+}
+
+*_, (sampled, _, _) = gdf.sample(
+ diffusion_model, {"cond": extra_conditions}, latents.shape,
+ unconditional_inputs={"cond": torch.zeros_like(extra_conditions)},
+ device=device, **sampling_configs
+)
+```
+
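+Note that `gdf.sample` is a Python generator that yields `(x0, x, pred)` at every sampling step, so (as a small sketch reusing the variables from the snippet above) you can also iterate it to inspect intermediate predictions:
+
+```python
+sampled = None
+for x0, x, pred in gdf.sample(
+    diffusion_model, {"cond": extra_conditions}, latents.shape,
+    unconditional_inputs={"cond": torch.zeros_like(extra_conditions)},
+    device=device, **sampling_configs
+):
+    sampled = x0  # current denoised estimate; e.g. log or preview it here
+```
+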
+# Available modules
+
+TODO
diff --git a/gdf/samplers.py b/gdf/samplers.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6048c86a261d53d0440a3b2c1591a03d9978c4f
--- /dev/null
+++ b/gdf/samplers.py
@@ -0,0 +1,43 @@
+import torch
+
+class SimpleSampler():
+ def __init__(self, gdf):
+ self.gdf = gdf
+ self.current_step = -1
+
+ def __call__(self, *args, **kwargs):
+ self.current_step += 1
+ return self.step(*args, **kwargs)
+
+ def init_x(self, shape):
+ return torch.randn(*shape)
+
+ def step(self, x, x0, epsilon, logSNR, logSNR_prev):
+ raise NotImplementedError("You should override the 'apply' function.")
+
+class DDIMSampler(SimpleSampler):
+ def step(self, x, x0, epsilon, logSNR, logSNR_prev, eta=0):
+ a, b = self.gdf.input_scaler(logSNR)
+ if len(a.shape) == 1:
+ a, b = a.view(-1, *[1]*(len(x0.shape)-1)), b.view(-1, *[1]*(len(x0.shape)-1))
+
+ a_prev, b_prev = self.gdf.input_scaler(logSNR_prev)
+ if len(a_prev.shape) == 1:
+ a_prev, b_prev = a_prev.view(-1, *[1]*(len(x0.shape)-1)), b_prev.view(-1, *[1]*(len(x0.shape)-1))
+
+ sigma_tau = eta * (b_prev**2 / b**2).sqrt() * (1 - a**2 / a_prev**2).sqrt() if eta > 0 else 0
+ # x = a_prev * x0 + (1 - a_prev**2 - sigma_tau ** 2).sqrt() * epsilon + sigma_tau * torch.randn_like(x0)
+ x = a_prev * x0 + (b_prev**2 - sigma_tau**2).sqrt() * epsilon + sigma_tau * torch.randn_like(x0)
+ return x
+
+class DDPMSampler(DDIMSampler):
+ def step(self, x, x0, epsilon, logSNR, logSNR_prev, eta=1):
+ return super().step(x, x0, epsilon, logSNR, logSNR_prev, eta)
+
+class LCMSampler(SimpleSampler):
+ def step(self, x, x0, epsilon, logSNR, logSNR_prev):
+ a_prev, b_prev = self.gdf.input_scaler(logSNR_prev)
+ if len(a_prev.shape) == 1:
+ a_prev, b_prev = a_prev.view(-1, *[1]*(len(x0.shape)-1)), b_prev.view(-1, *[1]*(len(x0.shape)-1))
+ return x0 * a_prev + torch.randn_like(epsilon) * b_prev
+
\ No newline at end of file
diff --git a/gdf/scalers.py b/gdf/scalers.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1adb8b0269667f3d006c7d7d17cbf2b7ef56ca9
--- /dev/null
+++ b/gdf/scalers.py
@@ -0,0 +1,42 @@
+import torch
+
+class BaseScaler():
+ def __init__(self):
+ self.stretched_limits = None
+
+ def setup_limits(self, schedule, input_scaler, stretch_max=True, stretch_min=True, shift=1):
+ min_logSNR = schedule(torch.ones(1), shift=shift)
+ max_logSNR = schedule(torch.zeros(1), shift=shift)
+
+ min_a, max_b = [v.item() for v in input_scaler(min_logSNR)] if stretch_max else [0, 1]
+ max_a, min_b = [v.item() for v in input_scaler(max_logSNR)] if stretch_min else [1, 0]
+ self.stretched_limits = [min_a, max_a, min_b, max_b]
+ return self.stretched_limits
+
+ def stretch_limits(self, a, b):
+ min_a, max_a, min_b, max_b = self.stretched_limits
+ return (a - min_a) / (max_a - min_a), (b - min_b) / (max_b - min_b)
+
+ def scalers(self, logSNR):
+ raise NotImplementedError("this method needs to be overridden")
+
+ def __call__(self, logSNR):
+ a, b = self.scalers(logSNR)
+ if self.stretched_limits is not None:
+ a, b = self.stretch_limits(a, b)
+ return a, b
+
+class VPScaler(BaseScaler):
+ def scalers(self, logSNR):
+ a_squared = logSNR.sigmoid()
+ a = a_squared.sqrt()
+ b = (1-a_squared).sqrt()
+ return a, b
+
+class LERPScaler(BaseScaler):
+ def scalers(self, logSNR):
+ _a = logSNR.exp() - 1
+ _a[_a == 0] = 1e-3 # Avoid division by zero
+ a = 1 + (2-(2**2 + 4*_a)**0.5) / (2*_a)
+ b = 1-a
+ return a, b
diff --git a/gdf/schedulers.py b/gdf/schedulers.py
new file mode 100644
index 0000000000000000000000000000000000000000..caa6e174da1d766ea5828616bb8113865106b628
--- /dev/null
+++ b/gdf/schedulers.py
@@ -0,0 +1,200 @@
+import torch
+import numpy as np
+
+class BaseSchedule():
+ def __init__(self, *args, force_limits=True, discrete_steps=None, shift=1, **kwargs):
+ self.setup(*args, **kwargs)
+ self.limits = None
+ self.discrete_steps = discrete_steps
+ self.shift = shift
+ if force_limits:
+ self.reset_limits()
+
+ def reset_limits(self, shift=1, disable=False):
+ try:
+ self.limits = None if disable else self(torch.tensor([1.0, 0.0]), shift=shift).tolist() # min, max
+ return self.limits
+ except Exception:
+ print("WARNING: this schedule doesn't support t and will be unbounded")
+ return None
+
+ def setup(self, *args, **kwargs):
+ raise NotImplementedError("this method needs to be overriden")
+
+ def schedule(self, *args, **kwargs):
+ raise NotImplementedError("this method needs to be overriden")
+
+ def __call__(self, t, *args, shift=1, **kwargs):
+ if isinstance(t, torch.Tensor):
+ batch_size = None
+ if self.discrete_steps is not None:
+ if t.dtype != torch.long:
+ t = (t * (self.discrete_steps-1)).round().long()
+ t = t / (self.discrete_steps-1)
+ t = t.clamp(0, 1)
+ else:
+ batch_size = t
+ t = None
+ logSNR = self.schedule(t, batch_size, *args, **kwargs)
+ if shift*self.shift != 1:
+ logSNR += 2 * np.log(1/(shift*self.shift))
+ if self.limits is not None:
+ logSNR = logSNR.clamp(*self.limits)
+ return logSNR
+
+class CosineSchedule(BaseSchedule):
+ def setup(self, s=0.008, clamp_range=[0.0001, 0.9999], norm_instead=False):
+ self.s = torch.tensor([s])
+ self.clamp_range = clamp_range
+ self.norm_instead = norm_instead
+ self.min_var = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2
+
+ def schedule(self, t, batch_size):
+ if t is None:
+ t = (1-torch.rand(batch_size)).add(0.001).clamp(0.001, 1.0)
+ s, min_var = self.s.to(t.device), self.min_var.to(t.device)
+ var = torch.cos((s + t)/(1+s) * torch.pi * 0.5).clamp(0, 1) ** 2 / min_var
+ if self.norm_instead:
+ var = var * (self.clamp_range[1]-self.clamp_range[0]) + self.clamp_range[0]
+ else:
+ var = var.clamp(*self.clamp_range)
+ logSNR = (var/(1-var)).log()
+ return logSNR
+
+class CosineSchedule2(BaseSchedule):
+ def setup(self, logsnr_range=[-15, 15]):
+ self.t_min = np.arctan(np.exp(-0.5 * logsnr_range[1]))
+ self.t_max = np.arctan(np.exp(-0.5 * logsnr_range[0]))
+
+ def schedule(self, t, batch_size):
+ if t is None:
+ t = 1-torch.rand(batch_size)
+ return -2 * (self.t_min + t*(self.t_max-self.t_min)).tan().log()
+
+class SqrtSchedule(BaseSchedule):
+ def setup(self, s=1e-4, clamp_range=[0.0001, 0.9999], norm_instead=False):
+ self.s = s
+ self.clamp_range = clamp_range
+ self.norm_instead = norm_instead
+
+ def schedule(self, t, batch_size):
+ if t is None:
+ t = 1-torch.rand(batch_size)
+ var = 1 - (t + self.s)**0.5
+ if self.norm_instead:
+ var = var * (self.clamp_range[1]-self.clamp_range[0]) + self.clamp_range[0]
+ else:
+ var = var.clamp(*self.clamp_range)
+ logSNR = (var/(1-var)).log()
+ return logSNR
+
+class RectifiedFlowsSchedule(BaseSchedule):
+ def setup(self, logsnr_range=[-15, 15]):
+ self.logsnr_range = logsnr_range
+
+ def schedule(self, t, batch_size):
+ if t is None:
+ t = 1-torch.rand(batch_size)
+ logSNR = (((1-t)**2)/(t**2)).log()
+ logSNR = logSNR.clamp(*self.logsnr_range)
+ return logSNR
+
+class EDMSampleSchedule(BaseSchedule):
+ def setup(self, sigma_range=[0.002, 80], p=7):
+ self.sigma_range = sigma_range
+ self.p = p
+
+ def schedule(self, t, batch_size):
+ if t is None:
+ t = 1-torch.rand(batch_size)
+ smin, smax, p = *self.sigma_range, self.p
+ sigma = (smax ** (1/p) + (1-t) * (smin ** (1/p) - smax ** (1/p))) ** p
+ logSNR = (1/sigma**2).log()
+ return logSNR
+
+class EDMTrainSchedule(BaseSchedule):
+ def setup(self, mu=-1.2, std=1.2):
+ self.mu = mu
+ self.std = std
+
+ def schedule(self, t, batch_size):
+ if t is not None:
+ raise Exception("EDMTrainSchedule doesn't support passing timesteps: t")
+ logSNR = -2*(torch.randn(batch_size) * self.std - self.mu)
+ return logSNR
+
+class LinearSchedule(BaseSchedule):
+ def setup(self, logsnr_range=[-10, 10]):
+ self.logsnr_range = logsnr_range
+
+ def schedule(self, t, batch_size):
+ if t is None:
+ t = 1-torch.rand(batch_size)
+ logSNR = t * (self.logsnr_range[0]-self.logsnr_range[1]) + self.logsnr_range[1]
+ return logSNR
+
+# Any schedule that cannot be described easily as a continuous function of t
+# It needs to define self.x and self.y in the setup() method
+class PiecewiseLinearSchedule(BaseSchedule):
+ def setup(self):
+ self.x = None
+ self.y = None
+
+ def piecewise_linear(self, x, xs, ys):
+ indices = torch.searchsorted(xs[:-1], x) - 1
+ x_min, x_max = xs[indices], xs[indices+1]
+ y_min, y_max = ys[indices], ys[indices+1]
+ var = y_min + (y_max - y_min) * (x - x_min) / (x_max - x_min)
+ return var
+
+ def schedule(self, t, batch_size):
+ if t is None:
+ t = 1-torch.rand(batch_size)
+ var = self.piecewise_linear(t, self.x.to(t.device), self.y.to(t.device))
+ logSNR = (var/(1-var)).log()
+ return logSNR
+
+class StableDiffusionSchedule(PiecewiseLinearSchedule):
+ def setup(self, linear_range=[0.00085, 0.012], total_steps=1000):
+ linear_range_sqrt = [r**0.5 for r in linear_range]
+ self.x = torch.linspace(0, 1, total_steps+1)
+
+ alphas = 1-(linear_range_sqrt[0]*(1-self.x) + linear_range_sqrt[1]*self.x)**2
+ self.y = alphas.cumprod(dim=-1)
+
+class AdaptiveTrainSchedule(BaseSchedule):
+ def setup(self, logsnr_range=[-10, 10], buckets=100, min_probs=0.0):
+ th = torch.linspace(logsnr_range[0], logsnr_range[1], buckets+1)
+ self.bucket_ranges = torch.tensor([(th[i], th[i+1]) for i in range(buckets)])
+ self.bucket_probs = torch.ones(buckets)
+ self.min_probs = min_probs
+
+ def schedule(self, t, batch_size):
+ if t is not None:
+ raise Exception("AdaptiveTrainSchedule doesn't support passing timesteps: t")
+ norm_probs = ((self.bucket_probs+self.min_probs) / (self.bucket_probs+self.min_probs).sum())
+ buckets = torch.multinomial(norm_probs, batch_size, replacement=True)
+ ranges = self.bucket_ranges[buckets]
+ logSNR = torch.rand(batch_size) * (ranges[:, 1]-ranges[:, 0]) + ranges[:, 0]
+ return logSNR
+
+ def update_buckets(self, logSNR, loss, beta=0.99):
+ range_mtx = self.bucket_ranges.unsqueeze(0).expand(logSNR.size(0), -1, -1).to(logSNR.device)
+ range_mask = (range_mtx[:, :, 0] <= logSNR[:, None]) * (range_mtx[:, :, 1] > logSNR[:, None]).float()
+ range_idx = range_mask.argmax(-1).cpu()
+ self.bucket_probs[range_idx] = self.bucket_probs[range_idx] * beta + loss.detach().cpu() * (1-beta)
+
+class InterpolatedSchedule(BaseSchedule):
+ def setup(self, scheduler1, scheduler2, shifts=[1.0, 1.0]):
+ self.scheduler1 = scheduler1
+ self.scheduler2 = scheduler2
+ self.shifts = shifts
+
+ def schedule(self, t, batch_size):
+ if t is None:
+ t = 1-torch.rand(batch_size)
+ t = t.clamp(1e-7, 1-1e-7) # avoid infinities multiplied by 0 which cause nan
+ low_logSNR = self.scheduler1(t, shift=self.shifts[0])
+ high_logSNR = self.scheduler2(t, shift=self.shifts[1])
+ return low_logSNR * t + high_logSNR * (1-t)
+
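+# Usage note: every schedule maps a timestep t in [0, 1] to a logSNR value; when t is None a
+# fresh t ~ U(0, 1) is drawn per sample, and the train-only schedules (EDMTrainSchedule,
+# AdaptiveTrainSchedule) refuse an explicit t. A minimal sketch, assuming BaseSchedule's
+# constructor forwards its kwargs to setup():
+#
+#     schedule = LinearSchedule(logsnr_range=[-10, 10])
+#     logSNR = schedule.schedule(None, 8)   # 8 uniformly sampled noise levels
+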
diff --git a/gdf/targets.py b/gdf/targets.py
new file mode 100644
index 0000000000000000000000000000000000000000..115062b6001f93082fa836e1f3742723e5972efe
--- /dev/null
+++ b/gdf/targets.py
@@ -0,0 +1,46 @@
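+# Shared convention for all targets in this file (visible in EpsilonTarget.xt below):
+# the noised sample is xt = a * x0 + b * epsilon, where a and b are assumed to come from
+# the gdf input scaler for a given logSNR (the scaler itself is defined elsewhere in gdf).
+# __call__ returns the quantity the network is trained to predict, while x0()/epsilon()
+# recover the clean sample / the noise from such a prediction.
+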
+class EpsilonTarget():
+ def __call__(self, x0, epsilon, logSNR, a, b):
+ return epsilon
+
+ def x0(self, noised, pred, logSNR, a, b):
+ return (noised - pred * b) / a
+
+    def epsilon(self, noised, pred, logSNR, a, b):
+        return pred
+
+    def noise_givenx0_noised(self, x0, noised, logSNR, a, b):
+        return (noised - a * x0) / b
+
+    def xt(self, x0, noise, logSNR, a, b):
+        return x0 * a + noise * b
+
+
+class X0Target():
+ def __call__(self, x0, epsilon, logSNR, a, b):
+ return x0
+
+ def x0(self, noised, pred, logSNR, a, b):
+ return pred
+
+ def epsilon(self, noised, pred, logSNR, a, b):
+ return (noised - pred * a) / b
+
+class VTarget():
+ def __call__(self, x0, epsilon, logSNR, a, b):
+ return a * epsilon - b * x0
+
+ def x0(self, noised, pred, logSNR, a, b):
+ squared_sum = a**2 + b**2
+ return a/squared_sum * noised - b/squared_sum * pred
+
+ def epsilon(self, noised, pred, logSNR, a, b):
+ squared_sum = a**2 + b**2
+ return b/squared_sum * noised + a/squared_sum * pred
+
+class RectifiedFlowsTarget():
+ def __call__(self, x0, epsilon, logSNR, a, b):
+ return epsilon - x0
+
+ def x0(self, noised, pred, logSNR, a, b):
+ return noised - pred * b
+
+ def epsilon(self, noised, pred, logSNR, a, b):
+ return noised + pred * a
+
\ No newline at end of file
diff --git a/inference/__init__.py b/inference/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/inference/test_controlnet.py b/inference/test_controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..250578262d2a118ece8b5a706aba1cd8115c62f5
--- /dev/null
+++ b/inference/test_controlnet.py
@@ -0,0 +1,166 @@
+import os
+import yaml
+import torch
+import torchvision
+from tqdm import tqdm
+import sys
+sys.path.append(os.path.abspath('./'))
+
+from inference.utils import *
+from core.utils import load_or_fail
+from train import WurstCore_control_lrguide, WurstCoreB
+from PIL import Image
+import math
+import argparse
+import time
+import random
+import numpy as np
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+    parser.add_argument('--height', type=int, default=3840, help='image height')
+    parser.add_argument('--width', type=int, default=2160, help='image width')
+    parser.add_argument('--control_weight', type=float, default=0.70, help='control strength, recommended range [0.3, 0.8]')
+    parser.add_argument('--dtype', type=str, default='bf16', help='bf16 or float32 (fall back to float32 if bf16 is not supported)')
+    parser.add_argument('--seed', type=int, default=123, help='random seed')
+    parser.add_argument('--config_c', type=str,
+                        default='configs/training/cfg_control_lr.yaml', help='config file for stage c, latent generation')
+    parser.add_argument('--config_b', type=str,
+                        default='configs/inference/stage_b_1b.yaml', help='config file for stage b, latent decoding')
+    parser.add_argument('--prompt', type=str,
+                        default='A peaceful lake surrounded by mountain, white cloud in the sky, high quality,', help='text prompt')
+    parser.add_argument('--num_image', type=int, default=4, help='number of images to generate')
+    parser.add_argument('--output_dir', type=str, default='figures/controlnet_results/', help='output directory for generated images')
+    parser.add_argument('--stage_a_tiled', action='store_true', help='whether or not to use tiled decoding for stage a to save memory')
+    parser.add_argument('--pretrained_path', type=str, default='models/ultrapixel_t2i.safetensors', help='path to the pretrained weights of the newly added UltraPixel parameters')
+    parser.add_argument('--canny_source_url', type=str, default="figures/California_000490.jpg", help='image used to extract the canny edge map')
+
+ args = parser.parse_args()
+ return args
+
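+# Example invocation (values are illustrative, not prescriptive):
+#   python inference/test_controlnet.py --height 2048 --width 2048 --control_weight 0.6 \
+#       --canny_source_url figures/California_000490.jpg --num_image 2 --stage_a_tiled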
+
+if __name__ == "__main__":
+
+ args = parse_args()
+ width = args.width
+ height = args.height
+ torch.manual_seed(args.seed)
+ random.seed(args.seed)
+ np.random.seed(args.seed)
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ dtype = torch.bfloat16 if args.dtype == 'bf16' else torch.float
+
+
+ # SETUP STAGE C
+ with open(args.config_c, "r", encoding="utf-8") as file:
+ loaded_config = yaml.safe_load(file)
+ core = WurstCore_control_lrguide(config_dict=loaded_config, device=device, training=False)
+
+ # SETUP STAGE B
+ with open(args.config_b, "r", encoding="utf-8") as file:
+ config_file_b = yaml.safe_load(file)
+
+ core_b = WurstCoreB(config_dict=config_file_b, device=device, training=False)
+
+ extras = core.setup_extras_pre()
+ models = core.setup_models(extras)
+ models.generator.eval().requires_grad_(False)
+ print("CONTROLNET READY")
+
+ extras_b = core_b.setup_extras_pre()
+ models_b = core_b.setup_models(extras_b, skip_clip=True)
+ models_b = WurstCoreB.Models(
+ **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model}
+ )
+ models_b.generator.eval().requires_grad_(False)
+ print("STAGE B READY")
+
+ batch_size = 1
+ save_dir = args.output_dir
+ url = args.canny_source_url
+ images = resize_image(Image.open(url).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1)
+ batch = {'images': images}
+
+    cnet_multiplier = args.control_weight  # control strength, e.g. 0.3 / 0.6 / 0.8
+ caption_list = [args.prompt] * args.num_image
+ height_lr, width_lr = get_target_lr_size(height / width, std_size=32)
+ stage_c_latent_shape_lr, stage_b_latent_shape_lr = calculate_latent_sizes(height_lr, width_lr, batch_size=batch_size)
+ stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
+
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir)
+
+
+ sdd = torch.load(args.pretrained_path, map_location='cpu')
+ collect_sd = {}
+ for k, v in sdd.items():
+ collect_sd[k[7:]] = v
+ models.train_norm.load_state_dict(collect_sd, strict=True)
+
+ models.controlnet.load_state_dict(load_or_fail(core.config.controlnet_checkpoint_path), strict=True)
+ # Stage C Parameters
+ extras.sampling_configs['cfg'] = 1
+ extras.sampling_configs['shift'] = 2
+ extras.sampling_configs['timesteps'] = 20
+ extras.sampling_configs['t_start'] = 1.0
+
+ # Stage B Parameters
+ extras_b.sampling_configs['cfg'] = 1.1
+ extras_b.sampling_configs['shift'] = 1
+ extras_b.sampling_configs['timesteps'] = 10
+ extras_b.sampling_configs['t_start'] = 1.0
+
+ # PREPARE CONDITIONS
+
+ for out_cnt, caption in enumerate(caption_list):
+ with torch.no_grad():
+
+ batch['captions'] = [caption + ' high quality'] * batch_size
+ conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False)
+ unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
+
+ cnet, cnet_input = core.get_cnet(batch, models, extras)
+ cnet_uncond = cnet
+ conditions['cnet'] = [c.clone() * cnet_multiplier if c is not None else c for c in cnet]
+ unconditions['cnet'] = [c.clone() * cnet_multiplier if c is not None else c for c in cnet_uncond]
+ edge_images = show_images(cnet_input)
+ models.generator.cuda()
+ for idx, img in enumerate(edge_images):
+ img.save(os.path.join(save_dir, f"edge_{url.split('/')[-1]}"))
+
+
+ print('STAGE C GENERATION***************************')
+ with torch.cuda.amp.autocast(dtype=dtype):
+ sampled_c = generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device, conditions, unconditions)
+ models.generator.cpu()
+ torch.cuda.empty_cache()
+
+ conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
+ unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
+
+ conditions_b['effnet'] = sampled_c
+ unconditions_b['effnet'] = torch.zeros_like(sampled_c)
+ print('STAGE B + A DECODING***************************')
+ with torch.cuda.amp.autocast(dtype=dtype):
+ sampled = decode_b(conditions_b, unconditions_b, models_b, stage_b_latent_shape, extras_b, device, stage_a_tiled=args.stage_a_tiled)
+
+ torch.cuda.empty_cache()
+ imgs = show_images(sampled)
+
+ for idx, img in enumerate(imgs):
+ img.save(os.path.join(save_dir, args.prompt[:20]+'_' + str(out_cnt).zfill(5) + '.jpg'))
+    print('Finished! Results at', save_dir)
diff --git a/inference/test_personalized.py b/inference/test_personalized.py
new file mode 100644
index 0000000000000000000000000000000000000000..840d52d0ef3b026e73c34f715b7b18ec3537e62a
--- /dev/null
+++ b/inference/test_personalized.py
@@ -0,0 +1,180 @@
+
+import os
+import yaml
+import torch
+from tqdm import tqdm
+import sys
+sys.path.append(os.path.abspath('./'))
+from inference.utils import *
+from train import WurstCoreB
+from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
+from train import WurstCore_personalized as WurstCoreC
+import torch.nn.functional as F
+import numpy as np
+import random
+import math
+import argparse
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+    parser.add_argument('--height', type=int, default=3072, help='image height')
+    parser.add_argument('--width', type=int, default=4096, help='image width')
+    parser.add_argument('--dtype', type=str, default='bf16', help='bf16 or float32 (fall back to float32 if bf16 is not supported)')
+    parser.add_argument('--seed', type=int, default=23, help='random seed')
+    parser.add_argument('--config_c', type=str,
+                        default="configs/training/lora_personalization.yaml", help='config file for stage c, latent generation')
+    parser.add_argument('--config_b', type=str,
+                        default='configs/inference/stage_b_1b.yaml', help='config file for stage b, latent decoding')
+    parser.add_argument('--prompt', type=str,
+                        default='A photo of cat [roubaobao] with sunglasses, Time Square in the background, high quality, detail rich, 8k', help='text prompt')
+    parser.add_argument('--num_image', type=int, default=4, help='number of images to generate')
+    parser.add_argument('--output_dir', type=str, default='figures/personalized/', help='output directory for generated images')
+    parser.add_argument('--stage_a_tiled', action='store_true', help='whether or not to use tiled decoding for stage a to save memory')
+    parser.add_argument('--pretrained_path_lora', type=str, default='models/lora_cat.safetensors', help='path to the pretrained personalized LoRA parameters')
+    parser.add_argument('--pretrained_path', type=str, default='models/ultrapixel_t2i.safetensors', help='path to the pretrained weights of the newly added UltraPixel parameters')
+ args = parser.parse_args()
+ return args
+
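+# Example invocation (values are illustrative; the LoRA checkpoint and the trigger token in the
+# prompt, e.g. [roubaobao], must match the personalization that was trained):
+#   python inference/test_personalized.py --height 3072 --width 4096 \
+#       --pretrained_path_lora models/lora_cat.safetensors --num_image 2
+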
+if __name__ == "__main__":
+ args = parse_args()
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ torch.manual_seed(args.seed)
+ random.seed(args.seed)
+ np.random.seed(args.seed)
+ dtype = torch.bfloat16 if args.dtype == 'bf16' else torch.float
+
+
+ # SETUP STAGE C
+ with open(args.config_c, "r", encoding="utf-8") as file:
+ loaded_config = yaml.safe_load(file)
+ core = WurstCoreC(config_dict=loaded_config, device=device, training=False)
+
+ # SETUP STAGE B
+ with open(args.config_b, "r", encoding="utf-8") as file:
+ config_file_b = yaml.safe_load(file)
+ core_b = WurstCoreB(config_dict=config_file_b, device=device, training=False)
+
+ extras = core.setup_extras_pre()
+ models = core.setup_models(extras)
+ models.generator.eval().requires_grad_(False)
+ print("STAGE C READY")
+
+ extras_b = core_b.setup_extras_pre()
+ models_b = core_b.setup_models(extras_b, skip_clip=True)
+ models_b = WurstCoreB.Models(
+ **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model}
+ )
+ models_b.generator.bfloat16().eval().requires_grad_(False)
+ print("STAGE B READY")
+
+
+ batch_size = 1
+ captions = [args.prompt] * args.num_image
+ height, width = args.height, args.width
+ save_dir = args.output_dir
+
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir)
+
+
+ pretrained_pth = args.pretrained_path
+ sdd = torch.load(pretrained_pth, map_location='cpu')
+ collect_sd = {}
+ for k, v in sdd.items():
+ collect_sd[k[7:]] = v
+
+ models.train_norm.load_state_dict(collect_sd)
+
+
+ pretrained_pth_lora = args.pretrained_path_lora
+ sdd = torch.load(pretrained_pth_lora, map_location='cpu')
+ collect_sd = {}
+ for k, v in sdd.items():
+ collect_sd[k[7:]] = v
+
+ models.train_lora.load_state_dict(collect_sd)
+
+
+ models.generator.eval()
+ models.train_norm.eval()
+
+
+ height_lr, width_lr = get_target_lr_size(height / width, std_size=32)
+ stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
+ stage_c_latent_shape_lr, stage_b_latent_shape_lr = calculate_latent_sizes(height_lr, width_lr, batch_size=batch_size)
+
+ # Stage C Parameters
+
+ extras.sampling_configs['cfg'] = 4
+ extras.sampling_configs['shift'] = 1
+ extras.sampling_configs['timesteps'] = 20
+ extras.sampling_configs['t_start'] = 1.0
+ extras.sampling_configs['sampler'] = DDPMSampler(extras.gdf)
+
+
+
+ # Stage B Parameters
+
+ extras_b.sampling_configs['cfg'] = 1.1
+ extras_b.sampling_configs['shift'] = 1
+ extras_b.sampling_configs['timesteps'] = 10
+ extras_b.sampling_configs['t_start'] = 1.0
+
+
+ for cnt, caption in enumerate(captions):
+
+
+ batch = {'captions': [caption] * batch_size}
+ conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False)
+ unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
+
+ conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
+ unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
+
+
+ with torch.no_grad():
+
+
+ models.generator.cuda()
+ print('STAGE C GENERATION***************************')
+ with torch.cuda.amp.autocast(dtype=dtype):
+ sampled_c = generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device)
+
+
+
+ models.generator.cpu()
+ torch.cuda.empty_cache()
+
+ conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
+ unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
+ conditions_b['effnet'] = sampled_c
+ unconditions_b['effnet'] = torch.zeros_like(sampled_c)
+ print('STAGE B + A DECODING***************************')
+
+ with torch.cuda.amp.autocast(dtype=dtype):
+ sampled = decode_b(conditions_b, unconditions_b, models_b, stage_b_latent_shape, extras_b, device, stage_a_tiled=args.stage_a_tiled)
+
+ torch.cuda.empty_cache()
+ imgs = show_images(sampled)
+ for idx, img in enumerate(imgs):
+ print(os.path.join(save_dir, args.prompt[:20]+'_' + str(cnt).zfill(5) + '.jpg'), idx)
+ img.save(os.path.join(save_dir, args.prompt[:20]+'_' + str(cnt).zfill(5) + '.jpg'))
+
+    print('Finished! Results at', save_dir)
diff --git a/inference/test_t2i.py b/inference/test_t2i.py
new file mode 100644
index 0000000000000000000000000000000000000000..3478f95e4c706d88a8c73688ed4e990adc9ea8d4
--- /dev/null
+++ b/inference/test_t2i.py
@@ -0,0 +1,170 @@
+
+import os
+import yaml
+import torch
+from tqdm import tqdm
+import sys
+sys.path.append(os.path.abspath('./'))
+from inference.utils import *
+from core.utils import load_or_fail
+from train import WurstCoreB
+from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
+from train import WurstCore_t2i as WurstCoreC
+import torch.nn.functional as F
+import numpy as np
+import random
+import math
+import argparse
+from einops import rearrange
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+    parser.add_argument('--height', type=int, default=2560, help='image height')
+    parser.add_argument('--width', type=int, default=5120, help='image width')
+    parser.add_argument('--seed', type=int, default=123, help='random seed')
+    parser.add_argument('--dtype', type=str, default='bf16', help='bf16 or float32 (fall back to float32 if bf16 is not supported)')
+    parser.add_argument('--config_c', type=str,
+                        default='configs/training/t2i.yaml', help='config file for stage c, latent generation')
+    parser.add_argument('--config_b', type=str,
+                        default='configs/inference/stage_b_1b.yaml', help='config file for stage b, latent decoding')
+    parser.add_argument('--prompt', type=str,
+                        default='A photo-realistic image of a west highland white terrier in the garden, high quality, detail rich, 8K', help='text prompt')
+    parser.add_argument('--num_image', type=int, default=10, help='number of images to generate')
+    parser.add_argument('--output_dir', type=str, default='figures/output_results/', help='output directory for generated images')
+    parser.add_argument('--stage_a_tiled', action='store_true', help='whether or not to use tiled decoding for stage a to save memory')
+    parser.add_argument('--pretrained_path', type=str, default='models/ultrapixel_t2i.safetensors', help='path to the pretrained weights of the newly added UltraPixel parameters')
+ args = parser.parse_args()
+ return args
+
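+# Example invocation (values are illustrative only):
+#   python inference/test_t2i.py --height 2560 --width 5120 --num_image 2 --stage_a_tiled \
+#       --prompt "A photo-realistic image of a west highland white terrier in the garden"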
+
+
+if __name__ == "__main__":
+
+ args = parse_args()
+ print(args)
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ print(device)
+ torch.manual_seed(args.seed)
+ random.seed(args.seed)
+ np.random.seed(args.seed)
+ dtype = torch.bfloat16 if args.dtype == 'bf16' else torch.float
+ # SETUP STAGE C
+ config_file = args.config_c
+ with open(config_file, "r", encoding="utf-8") as file:
+ loaded_config = yaml.safe_load(file)
+
+ core = WurstCoreC(config_dict=loaded_config, device=device, training=False)
+
+ # SETUP STAGE B
+ config_file_b = args.config_b
+ with open(config_file_b, "r", encoding="utf-8") as file:
+ config_file_b = yaml.safe_load(file)
+
+ core_b = WurstCoreB(config_dict=config_file_b, device=device, training=False)
+
+ extras = core.setup_extras_pre()
+ models = core.setup_models(extras)
+ models.generator.eval().requires_grad_(False)
+ print("STAGE C READY")
+
+ extras_b = core_b.setup_extras_pre()
+ models_b = core_b.setup_models(extras_b, skip_clip=True)
+ models_b = WurstCoreB.Models(
+ **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model}
+ )
+ models_b.generator.bfloat16().eval().requires_grad_(False)
+ print("STAGE B READY")
+
+ captions = [args.prompt] * args.num_image
+
+
+ height, width = args.height, args.width
+ save_dir = args.output_dir
+
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir)
+
+ pretrained_path = args.pretrained_path
+ sdd = torch.load(pretrained_path, map_location='cpu')
+ collect_sd = {}
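+    # checkpoint keys carry a 7-character prefix (presumably 'module.' from DataParallel training); strip it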
+ for k, v in sdd.items():
+ collect_sd[k[7:]] = v
+
+ models.train_norm.load_state_dict(collect_sd)
+
+
+ models.generator.eval()
+ models.train_norm.eval()
+
+ batch_size=1
+ height_lr, width_lr = get_target_lr_size(height / width, std_size=32)
+ stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
+ stage_c_latent_shape_lr, stage_b_latent_shape_lr = calculate_latent_sizes(height_lr, width_lr, batch_size=batch_size)
+
+ # Stage C Parameters
+ extras.sampling_configs['cfg'] = 4
+ extras.sampling_configs['shift'] = 1
+ extras.sampling_configs['timesteps'] = 20
+ extras.sampling_configs['t_start'] = 1.0
+ extras.sampling_configs['sampler'] = DDPMSampler(extras.gdf)
+
+
+
+ # Stage B Parameters
+ extras_b.sampling_configs['cfg'] = 1.1
+ extras_b.sampling_configs['shift'] = 1
+ extras_b.sampling_configs['timesteps'] = 10
+ extras_b.sampling_configs['t_start'] = 1.0
+
+ for cnt, caption in enumerate(captions):
+
+
+ batch = {'captions': [caption] * batch_size}
+ conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False)
+ unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
+
+ conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
+ unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
+
+
+ with torch.no_grad():
+
+
+ models.generator.cuda()
+ print('STAGE C GENERATION***************************')
+ with torch.cuda.amp.autocast(dtype=dtype):
+ sampled_c = generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device)
+
+
+
+ models.generator.cpu()
+ torch.cuda.empty_cache()
+
+ conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
+ unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
+ conditions_b['effnet'] = sampled_c
+ unconditions_b['effnet'] = torch.zeros_like(sampled_c)
+ print('STAGE B + A DECODING***************************')
+
+ with torch.cuda.amp.autocast(dtype=dtype):
+ sampled = decode_b(conditions_b, unconditions_b, models_b, stage_b_latent_shape, extras_b, device, stage_a_tiled=args.stage_a_tiled)
+
+ torch.cuda.empty_cache()
+ imgs = show_images(sampled)
+ for idx, img in enumerate(imgs):
+ print(os.path.join(save_dir, args.prompt[:20]+'_' + str(cnt).zfill(5) + '.jpg'), idx)
+ img.save(os.path.join(save_dir, args.prompt[:20]+'_' + str(cnt).zfill(5) + '.jpg'))
+
+
+    print('Finished! Results at', save_dir)
diff --git a/inference/utils.py b/inference/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab5af277069ec7803d53ff8f5fa29bed41fde29b
--- /dev/null
+++ b/inference/utils.py
@@ -0,0 +1,131 @@
+import PIL
+import torch
+import requests
+import torchvision
+from math import ceil
+from io import BytesIO
+import matplotlib.pyplot as plt
+import torchvision.transforms.functional as F
+import math
+from tqdm import tqdm
+
+
+def download_image(url):
+ return PIL.Image.open(requests.get(url, stream=True).raw).convert("RGB")
+
+
+def resize_image(image, size=768):
+ tensor_image = F.to_tensor(image)
+ resized_image = F.resize(tensor_image, size, antialias=True)
+ return resized_image
+
+
+def downscale_images(images, factor=3/4):
+ scaled_height, scaled_width = int(((images.size(-2)*factor)//32)*32), int(((images.size(-1)*factor)//32)*32)
+ scaled_image = torchvision.transforms.functional.resize(images, (scaled_height, scaled_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST)
+ return scaled_image
+
+
+
+def calculate_latent_sizes(height=1024, width=1024, batch_size=4, compression_factor_b=42.67, compression_factor_a=4.0):
+ latent_height = ceil(height / compression_factor_b)
+ latent_width = ceil(width / compression_factor_b)
+ stage_c_latent_shape = (batch_size, 16, latent_height, latent_width)
+
+ latent_height = ceil(height / compression_factor_a)
+ latent_width = ceil(width / compression_factor_a)
+ stage_b_latent_shape = (batch_size, 4, latent_height, latent_width)
+
+ return stage_c_latent_shape, stage_b_latent_shape
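+# Worked example: height=2560, width=5120, batch_size=1 gives
+# stage_c_latent_shape = (1, 16, 60, 120) and stage_b_latent_shape = (1, 4, 640, 1280).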
+
+
+def get_views(H, W, window_size=64, stride=16):
+    '''
+    Split an H x W latent into overlapping windows.
+    - H, W: height and width of the latent
+    - returns a list of (h_start, h_end, w_start, w_end) index tuples
+    '''
+ num_blocks_height = (H - window_size) // stride + 1
+ num_blocks_width = (W - window_size) // stride + 1
+ total_num_blocks = int(num_blocks_height * num_blocks_width)
+ views = []
+ for i in range(total_num_blocks):
+ h_start = int((i // num_blocks_width) * stride)
+ h_end = h_start + window_size
+ w_start = int((i % num_blocks_width) * stride)
+ w_end = w_start + window_size
+ views.append((h_start, h_end, w_start, w_end))
+ return views
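+# Worked example: H = W = 96 with window_size=64 and stride=16 yields a 3 x 3 grid of
+# 9 overlapping (h_start, h_end, w_start, w_end) windows.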
+
+
+
+def show_images(images, rows=None, cols=None, **kwargs):
+ if images.size(1) == 1:
+ images = images.repeat(1, 3, 1, 1)
+ elif images.size(1) > 3:
+ images = images[:, :3]
+
+ if rows is None:
+ rows = 1
+ if cols is None:
+ cols = images.size(0) // rows
+
+ _, _, h, w = images.shape
+
+ imgs = []
+ for i, img in enumerate(images):
+ imgs.append( torchvision.transforms.functional.to_pil_image(img.clamp(0, 1)))
+
+ return imgs
+
+
+
+def decode_b(conditions_b, unconditions_b, models_b, bshape, extras_b, device, \
+ stage_a_tiled=False, num_instance=4, patch_size=256, stride=24):
+
+
+ sampling_b = extras_b.gdf.sample(
+ models_b.generator.half(), conditions_b, bshape,
+ unconditions_b, device=device,
+ **extras_b.sampling_configs,
+ )
+ models_b.generator.cuda()
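+    # the sampler is consumed step by step here; sampled_b holds the final stage-B latent
+    # once the loop finishes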
+ for (sampled_b, _, _) in tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']):
+ sampled_b = sampled_b
+ models_b.generator.cpu()
+ torch.cuda.empty_cache()
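+    # Tiled stage-A decoding: reflect-pad the latent, decode overlapping windows, accumulate
+    # the decoded pixels together with a hit-count map, then average and crop the padding away.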
+ if stage_a_tiled:
+ with torch.cuda.amp.autocast(dtype=torch.float16):
+ padding = (stride*2, stride*2, stride*2, stride*2)
+ sampled_b = torch.nn.functional.pad(sampled_b, padding, mode='reflect')
+ count = torch.zeros((sampled_b.shape[0], 3, sampled_b.shape[-2]*4, sampled_b.shape[-1]*4), requires_grad=False, device=sampled_b.device)
+ sampled = torch.zeros((sampled_b.shape[0], 3, sampled_b.shape[-2]*4, sampled_b.shape[-1]*4), requires_grad=False, device=sampled_b.device)
+ views = get_views(sampled_b.shape[-2], sampled_b.shape[-1], window_size=patch_size, stride=stride)
+
+ for view_idx, (h_start, h_end, w_start, w_end) in enumerate(tqdm(views, total=len(views))):
+
+ sampled[:, :, h_start*4:h_end*4, w_start*4:w_end*4] += models_b.stage_a.decode(sampled_b[:, :, h_start:h_end, w_start:w_end]).float()
+ count[:, :, h_start*4:h_end*4, w_start*4:w_end*4] += 1
+ sampled /= count
+ sampled = sampled[:, :, stride*4*2:-stride*4*2, stride*4*2:-stride*4*2]
+ else:
+
+ sampled = models_b.stage_a.decode(sampled_b, tiled_decoding=stage_a_tiled)
+
+ return sampled.float()
+
+
+def generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device, conditions=None, unconditions=None):
+ if conditions is None:
+ conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False)
+ if unconditions is None:
+ unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
+ sampling_c = extras.gdf.sample(
+ models.generator, conditions, stage_c_latent_shape, stage_c_latent_shape_lr,
+ unconditions, device=device, **extras.sampling_configs,
+ )
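+    # iterate the sampler to completion; sampled_c holds the final stage-C latent afterwards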
+ for idx, (sampled_c, sampled_c_curr, _, _) in enumerate(tqdm(sampling_c, total=extras.sampling_configs['timesteps'])):
+ sampled_c = sampled_c
+ return sampled_c
+
+def get_target_lr_size(ratio, std_size=24):
+ w, h = int(std_size / math.sqrt(ratio)), int(std_size * math.sqrt(ratio))
+    return (h * 32, w * 32)
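+# Worked example: ratio = 2560 / 5120 = 0.5 with std_size=32 (as the test scripts use)
+# gives w, h = 45, 22 and a low-resolution (height, width) target of (704, 1440).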
+
diff --git a/models/models_checklist.txt b/models/models_checklist.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2fdec27a72db473c51893abc64826514b1d9d065
--- /dev/null
+++ b/models/models_checklist.txt
@@ -0,0 +1,7 @@
+https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_a.safetensors
+https://huggingface.co/stabilityai/StableWurst/resolve/main/previewer.safetensors
+https://huggingface.co/stabilityai/StableWurst/resolve/main/effnet_encoder.safetensors
+https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b_lite_bf16.safetensors
+https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c_bf16.safetensors
+https://huggingface.co/roubaofeipi/UltraPixel/blob/main/ultrapixel_t2i.safetensors
+https://huggingface.co/roubaofeipi/UltraPixel/blob/main/lora_cat.safetensors (only required for personalization)
\ No newline at end of file
diff --git a/modules/__init__.py b/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6fcf5aa2a39061c3f4f82dde6ff063411223cb3
--- /dev/null
+++ b/modules/__init__.py
@@ -0,0 +1,6 @@
+from .effnet import EfficientNetEncoder
+from .stage_c import StageC
+from .stage_c import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock
+from .previewer import Previewer
+from .controlnet import ControlNet, ControlNetDeliverer
+from . import controlnet as controlnet_filters
diff --git a/modules/cnet_modules/face_id/__pycache__/arcface.cpython-310.pyc b/modules/cnet_modules/face_id/__pycache__/arcface.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8c74bb92cb0db0876acda8aa3d102141526fd428
Binary files /dev/null and b/modules/cnet_modules/face_id/__pycache__/arcface.cpython-310.pyc differ
diff --git a/modules/cnet_modules/face_id/arcface.py b/modules/cnet_modules/face_id/arcface.py
new file mode 100644
index 0000000000000000000000000000000000000000..64e918bb90437f6f193a7ec384bea1fcd73c7abb
--- /dev/null
+++ b/modules/cnet_modules/face_id/arcface.py
@@ -0,0 +1,276 @@
+import numpy as np
+import onnx, onnx2torch, cv2
+import torch
+from insightface.utils import face_align
+
+
+class ArcFaceRecognizer:
+ def __init__(self, model_file=None, device='cpu', dtype=torch.float32):
+ assert model_file is not None
+ self.model_file = model_file
+
+ self.device = device
+ self.dtype = dtype
+ self.model = onnx2torch.convert(onnx.load(model_file)).to(device=device, dtype=dtype)
+ for param in self.model.parameters():
+ param.requires_grad = False
+ self.model.eval()
+
+ self.input_mean = 127.5
+ self.input_std = 127.5
+ self.input_size = (112, 112)
+ self.input_shape = ['None', 3, 112, 112]
+
+ def get(self, img, face):
+ aimg = face_align.norm_crop(img, landmark=face.kps, image_size=self.input_size[0])
+ face.embedding = self.get_feat(aimg).flatten()
+ return face.embedding
+
+ def compute_sim(self, feat1, feat2):
+ from numpy.linalg import norm
+ feat1 = feat1.ravel()
+ feat2 = feat2.ravel()
+ sim = np.dot(feat1, feat2) / (norm(feat1) * norm(feat2))
+ return sim
+
+ def get_feat(self, imgs):
+ if not isinstance(imgs, list):
+ imgs = [imgs]
+ input_size = self.input_size
+
+ blob = cv2.dnn.blobFromImages(imgs, 1.0 / self.input_std, input_size,
+ (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
+
+ blob_torch = torch.tensor(blob).to(device=self.device, dtype=self.dtype)
+ net_out = self.model(blob_torch)
+ return net_out[0].float().cpu()
+
+
+def distance2bbox(points, distance, max_shape=None):
+ """Decode distance prediction to bounding box.
+
+ Args:
+ points (Tensor): Shape (n, 2), [x, y].
+ distance (Tensor): Distance from the given point to 4
+ boundaries (left, top, right, bottom).
+ max_shape (tuple): Shape of the image.
+
+ Returns:
+ Tensor: Decoded bboxes.
+ """
+ x1 = points[:, 0] - distance[:, 0]
+ y1 = points[:, 1] - distance[:, 1]
+ x2 = points[:, 0] + distance[:, 2]
+ y2 = points[:, 1] + distance[:, 3]
+ if max_shape is not None:
+ x1 = x1.clamp(min=0, max=max_shape[1])
+ y1 = y1.clamp(min=0, max=max_shape[0])
+ x2 = x2.clamp(min=0, max=max_shape[1])
+ y2 = y2.clamp(min=0, max=max_shape[0])
+ return np.stack([x1, y1, x2, y2], axis=-1)
+
+
+def distance2kps(points, distance, max_shape=None):
+    """Decode distance predictions to keypoints.
+
+    Args:
+        points (Tensor): Shape (n, 2), [x, y].
+        distance (Tensor): Offsets from the given point to each
+            keypoint, interleaved as (dx, dy) pairs.
+        max_shape (tuple): Shape of the image.
+
+    Returns:
+        Tensor: Decoded keypoints.
+    """
+ preds = []
+ for i in range(0, distance.shape[1], 2):
+ px = points[:, i % 2] + distance[:, i]
+ py = points[:, i % 2 + 1] + distance[:, i + 1]
+ if max_shape is not None:
+ px = px.clamp(min=0, max=max_shape[1])
+ py = py.clamp(min=0, max=max_shape[0])
+ preds.append(px)
+ preds.append(py)
+ return np.stack(preds, axis=-1)
+
+
+class FaceDetector:
+ def __init__(self, model_file=None, dtype=torch.float32, device='cuda'):
+ self.model_file = model_file
+ self.taskname = 'detection'
+ self.center_cache = {}
+ self.nms_thresh = 0.4
+ self.det_thresh = 0.5
+
+ self.device = device
+ self.dtype = dtype
+ self.model = onnx2torch.convert(onnx.load(model_file)).to(device=device, dtype=dtype)
+ for param in self.model.parameters():
+ param.requires_grad = False
+ self.model.eval()
+
+ input_shape = (320, 320)
+ self.input_size = input_shape
+ self.input_shape = input_shape
+
+ self.input_mean = 127.5
+ self.input_std = 128.0
+ self._anchor_ratio = 1.0
+ self._num_anchors = 1
+ self.fmc = 3
+ self._feat_stride_fpn = [8, 16, 32]
+ self._num_anchors = 2
+ self.use_kps = True
+
+ def forward(self, img, threshold):
+ scores_list = []
+ bboxes_list = []
+ kpss_list = []
+ input_size = tuple(img.shape[0:2][::-1])
+ blob = cv2.dnn.blobFromImage(img, 1.0 / self.input_std, input_size,
+ (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
+ blob_torch = torch.tensor(blob).to(device=self.device, dtype=self.dtype)
+ net_outs_torch = self.model(blob_torch)
+ # print(list(map(lambda x: x.shape, net_outs_torch)))
+ net_outs = list(map(lambda x: x.float().cpu().numpy(), net_outs_torch))
+
+ input_height = blob.shape[2]
+ input_width = blob.shape[3]
+ fmc = self.fmc
+ for idx, stride in enumerate(self._feat_stride_fpn):
+ scores = net_outs[idx]
+ bbox_preds = net_outs[idx + fmc]
+ bbox_preds = bbox_preds * stride
+ if self.use_kps:
+ kps_preds = net_outs[idx + fmc * 2] * stride
+ height = input_height // stride
+ width = input_width // stride
+ K = height * width
+ key = (height, width, stride)
+ if key in self.center_cache:
+ anchor_centers = self.center_cache[key]
+ else:
+ # solution-1, c style:
+ # anchor_centers = np.zeros( (height, width, 2), dtype=np.float32 )
+ # for i in range(height):
+ # anchor_centers[i, :, 1] = i
+ # for i in range(width):
+ # anchor_centers[:, i, 0] = i
+
+ # solution-2:
+ # ax = np.arange(width, dtype=np.float32)
+ # ay = np.arange(height, dtype=np.float32)
+ # xv, yv = np.meshgrid(np.arange(width), np.arange(height))
+ # anchor_centers = np.stack([xv, yv], axis=-1).astype(np.float32)
+
+ # solution-3:
+ anchor_centers = np.stack(np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32)
+ # print(anchor_centers.shape)
+
+ anchor_centers = (anchor_centers * stride).reshape((-1, 2))
+ if self._num_anchors > 1:
+ anchor_centers = np.stack([anchor_centers] * self._num_anchors, axis=1).reshape((-1, 2))
+ if len(self.center_cache) < 100:
+ self.center_cache[key] = anchor_centers
+
+ pos_inds = np.where(scores >= threshold)[0]
+ bboxes = distance2bbox(anchor_centers, bbox_preds)
+ pos_scores = scores[pos_inds]
+ pos_bboxes = bboxes[pos_inds]
+ scores_list.append(pos_scores)
+ bboxes_list.append(pos_bboxes)
+ if self.use_kps:
+ kpss = distance2kps(anchor_centers, kps_preds)
+ # kpss = kps_preds
+ kpss = kpss.reshape((kpss.shape[0], -1, 2))
+ pos_kpss = kpss[pos_inds]
+ kpss_list.append(pos_kpss)
+ return scores_list, bboxes_list, kpss_list
+
+ def detect(self, img, input_size=None, max_num=0, metric='default'):
+ assert input_size is not None or self.input_size is not None
+ input_size = self.input_size if input_size is None else input_size
+
+ im_ratio = float(img.shape[0]) / img.shape[1]
+ model_ratio = float(input_size[1]) / input_size[0]
+ if im_ratio > model_ratio:
+ new_height = input_size[1]
+ new_width = int(new_height / im_ratio)
+ else:
+ new_width = input_size[0]
+ new_height = int(new_width * im_ratio)
+ det_scale = float(new_height) / img.shape[0]
+ resized_img = cv2.resize(img, (new_width, new_height))
+ det_img = np.zeros((input_size[1], input_size[0], 3), dtype=np.uint8)
+ det_img[:new_height, :new_width, :] = resized_img
+
+ scores_list, bboxes_list, kpss_list = self.forward(det_img, self.det_thresh)
+
+ scores = np.vstack(scores_list)
+ scores_ravel = scores.ravel()
+ order = scores_ravel.argsort()[::-1]
+ bboxes = np.vstack(bboxes_list) / det_scale
+ if self.use_kps:
+ kpss = np.vstack(kpss_list) / det_scale
+ pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False)
+ pre_det = pre_det[order, :]
+ keep = self.nms(pre_det)
+ det = pre_det[keep, :]
+ if self.use_kps:
+ kpss = kpss[order, :, :]
+ kpss = kpss[keep, :, :]
+ else:
+ kpss = None
+ if max_num > 0 and det.shape[0] > max_num:
+ area = (det[:, 2] - det[:, 0]) * (det[:, 3] -
+ det[:, 1])
+ img_center = img.shape[0] // 2, img.shape[1] // 2
+ offsets = np.vstack([
+ (det[:, 0] + det[:, 2]) / 2 - img_center[1],
+ (det[:, 1] + det[:, 3]) / 2 - img_center[0]
+ ])
+ offset_dist_squared = np.sum(np.power(offsets, 2.0), 0)
+ if metric == 'max':
+ values = area
+ else:
+ values = area - offset_dist_squared * 2.0 # some extra weight on the centering
+ bindex = np.argsort(
+ values)[::-1] # some extra weight on the centering
+ bindex = bindex[0:max_num]
+ det = det[bindex, :]
+ if kpss is not None:
+ kpss = kpss[bindex, :]
+ return det, kpss
+
+ def nms(self, dets):
+ thresh = self.nms_thresh
+ x1 = dets[:, 0]
+ y1 = dets[:, 1]
+ x2 = dets[:, 2]
+ y2 = dets[:, 3]
+ scores = dets[:, 4]
+
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+ order = scores.argsort()[::-1]
+
+ keep = []
+ while order.size > 0:
+ i = order[0]
+ keep.append(i)
+ xx1 = np.maximum(x1[i], x1[order[1:]])
+ yy1 = np.maximum(y1[i], y1[order[1:]])
+ xx2 = np.minimum(x2[i], x2[order[1:]])
+ yy2 = np.minimum(y2[i], y2[order[1:]])
+
+ w = np.maximum(0.0, xx2 - xx1 + 1)
+ h = np.maximum(0.0, yy2 - yy1 + 1)
+ inter = w * h
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
+
+ inds = np.where(ovr <= thresh)[0]
+ order = order[inds + 1]
+
+ return keep
diff --git a/modules/cnet_modules/inpainting/__pycache__/saliency_model.cpython-310.pyc b/modules/cnet_modules/inpainting/__pycache__/saliency_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8200104d6d66a1084685c76373c38d752ed9c3d4
Binary files /dev/null and b/modules/cnet_modules/inpainting/__pycache__/saliency_model.cpython-310.pyc differ
diff --git a/modules/cnet_modules/inpainting/__pycache__/saliency_model.cpython-39.pyc b/modules/cnet_modules/inpainting/__pycache__/saliency_model.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ca432e5c5eed7ba17fc6cafb06a3ebe16002f67e
Binary files /dev/null and b/modules/cnet_modules/inpainting/__pycache__/saliency_model.cpython-39.pyc differ
diff --git a/modules/cnet_modules/inpainting/saliency_model.pt b/modules/cnet_modules/inpainting/saliency_model.pt
new file mode 100644
index 0000000000000000000000000000000000000000..e1b02cc60b2999a8f9ff90557182e3dafab63db7
--- /dev/null
+++ b/modules/cnet_modules/inpainting/saliency_model.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:225a602e1f2a5d159424be011a63b27d83b56343a4379a90710eca9a26bab920
+size 451123
diff --git a/modules/cnet_modules/inpainting/saliency_model.py b/modules/cnet_modules/inpainting/saliency_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..82355a02baead47f50fe643e57b81f8caca78f79
--- /dev/null
+++ b/modules/cnet_modules/inpainting/saliency_model.py
@@ -0,0 +1,81 @@
+import torch
+import torchvision
+from torch import nn
+from PIL import Image
+import numpy as np
+import os
+
+
+# MICRO RESNET
+class ResBlock(nn.Module):
+ def __init__(self, channels):
+ super(ResBlock, self).__init__()
+
+ self.resblock = nn.Sequential(
+ nn.ReflectionPad2d(1),
+ nn.Conv2d(channels, channels, kernel_size=3),
+ nn.InstanceNorm2d(channels, affine=True),
+ nn.ReLU(),
+ nn.ReflectionPad2d(1),
+ nn.Conv2d(channels, channels, kernel_size=3),
+ nn.InstanceNorm2d(channels, affine=True),
+ )
+
+ def forward(self, x):
+ out = self.resblock(x)
+ return out + x
+
+
+class Upsample2d(nn.Module):
+ def __init__(self, scale_factor):
+ super(Upsample2d, self).__init__()
+
+ self.interp = nn.functional.interpolate
+ self.scale_factor = scale_factor
+
+ def forward(self, x):
+ x = self.interp(x, scale_factor=self.scale_factor, mode='nearest')
+ return x
+
+
+class MicroResNet(nn.Module):
+ def __init__(self):
+ super(MicroResNet, self).__init__()
+
+ self.downsampler = nn.Sequential(
+ nn.ReflectionPad2d(4),
+ nn.Conv2d(3, 8, kernel_size=9, stride=4),
+ nn.InstanceNorm2d(8, affine=True),
+ nn.ReLU(),
+ nn.ReflectionPad2d(1),
+ nn.Conv2d(8, 16, kernel_size=3, stride=2),
+ nn.InstanceNorm2d(16, affine=True),
+ nn.ReLU(),
+ nn.ReflectionPad2d(1),
+ nn.Conv2d(16, 32, kernel_size=3, stride=2),
+ nn.InstanceNorm2d(32, affine=True),
+ nn.ReLU(),
+ )
+
+ self.residual = nn.Sequential(
+ ResBlock(32),
+ nn.Conv2d(32, 64, kernel_size=1, bias=False, groups=32),
+ ResBlock(64),
+ )
+
+ self.segmentator = nn.Sequential(
+ nn.ReflectionPad2d(1),
+ nn.Conv2d(64, 16, kernel_size=3),
+ nn.InstanceNorm2d(16, affine=True),
+ nn.ReLU(),
+ Upsample2d(scale_factor=2),
+ nn.ReflectionPad2d(4),
+ nn.Conv2d(16, 1, kernel_size=9),
+ nn.Sigmoid()
+ )
+
+ def forward(self, x):
+ out = self.downsampler(x)
+ out = self.residual(out)
+ out = self.segmentator(out)
+ return out
diff --git a/modules/cnet_modules/pidinet/__init__.py b/modules/cnet_modules/pidinet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2b4625bf915cc6c4053b7d7861a22ff371bc641
--- /dev/null
+++ b/modules/cnet_modules/pidinet/__init__.py
@@ -0,0 +1,37 @@
+# Pidinet
+# https://github.com/hellozhuo/pidinet
+
+import os
+import torch
+import numpy as np
+from einops import rearrange
+from .model import pidinet
+from .util import annotator_ckpts_path, safe_step
+
+
+class PidiNetDetector:
+ def __init__(self, device):
+ remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/table5_pidinet.pth"
+ modelpath = os.path.join(annotator_ckpts_path, "table5_pidinet.pth")
+ if not os.path.exists(modelpath):
+ from basicsr.utils.download_util import load_file_from_url
+ load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
+ self.netNetwork = pidinet()
+ self.netNetwork.load_state_dict(
+ {k.replace('module.', ''): v for k, v in torch.load(modelpath)['state_dict'].items()})
+ self.netNetwork.to(device).eval().requires_grad_(False)
+
+ def __call__(self, input_image): # , safe=False):
+ return self.netNetwork(input_image)[-1]
+ # assert input_image.ndim == 3
+ # input_image = input_image[:, :, ::-1].copy()
+ # with torch.no_grad():
+ # image_pidi = torch.from_numpy(input_image).float().cuda()
+ # image_pidi = image_pidi / 255.0
+ # image_pidi = rearrange(image_pidi, 'h w c -> 1 c h w')
+ # edge = self.netNetwork(image_pidi)[-1]
+
+ # if safe:
+ # edge = safe_step(edge)
+ # edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
+ # return edge[0][0]
diff --git a/modules/cnet_modules/pidinet/__pycache__/__init__.cpython-310.pyc b/modules/cnet_modules/pidinet/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..07fca0abb9c90b7b40746b4044c4000ae69e00c7
Binary files /dev/null and b/modules/cnet_modules/pidinet/__pycache__/__init__.cpython-310.pyc differ
diff --git a/modules/cnet_modules/pidinet/__pycache__/__init__.cpython-39.pyc b/modules/cnet_modules/pidinet/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5a060aa2baa87a3670aa0bf8276e2f34bafe9451
Binary files /dev/null and b/modules/cnet_modules/pidinet/__pycache__/__init__.cpython-39.pyc differ
diff --git a/modules/cnet_modules/pidinet/__pycache__/model.cpython-310.pyc b/modules/cnet_modules/pidinet/__pycache__/model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2243c853d18e2a404ced3eb4ac6a95a7a9ee6874
Binary files /dev/null and b/modules/cnet_modules/pidinet/__pycache__/model.cpython-310.pyc differ
diff --git a/modules/cnet_modules/pidinet/__pycache__/model.cpython-39.pyc b/modules/cnet_modules/pidinet/__pycache__/model.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7f70342fc64759bc7459abf0f7986ee3b7fd2126
Binary files /dev/null and b/modules/cnet_modules/pidinet/__pycache__/model.cpython-39.pyc differ
diff --git a/modules/cnet_modules/pidinet/__pycache__/util.cpython-310.pyc b/modules/cnet_modules/pidinet/__pycache__/util.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b2e7ab031924860f1262f4d44bf2eaf57ca78edd
Binary files /dev/null and b/modules/cnet_modules/pidinet/__pycache__/util.cpython-310.pyc differ
diff --git a/modules/cnet_modules/pidinet/__pycache__/util.cpython-39.pyc b/modules/cnet_modules/pidinet/__pycache__/util.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4da8564d03f99caa7a45d9ccb1358cb282cd2711
Binary files /dev/null and b/modules/cnet_modules/pidinet/__pycache__/util.cpython-39.pyc differ
diff --git a/modules/cnet_modules/pidinet/ckpts/table5_pidinet.pth b/modules/cnet_modules/pidinet/ckpts/table5_pidinet.pth
new file mode 100644
index 0000000000000000000000000000000000000000..1ceba1de87e7bb3c81961b80acbb3a106ca249c0
--- /dev/null
+++ b/modules/cnet_modules/pidinet/ckpts/table5_pidinet.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:80860ac267258b5f27486e0ef152a211d0b08120f62aeb185a050acc30da486c
+size 2871148
diff --git a/modules/cnet_modules/pidinet/model.py b/modules/cnet_modules/pidinet/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..26644c6f6174c3b5407bd10c914045758cbadefe
--- /dev/null
+++ b/modules/cnet_modules/pidinet/model.py
@@ -0,0 +1,654 @@
+"""
+Author: Zhuo Su, Wenzhe Liu
+Date: Feb 18, 2021
+"""
+
+import math
+
+import cv2
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+nets = {
+ 'baseline': {
+ 'layer0': 'cv',
+ 'layer1': 'cv',
+ 'layer2': 'cv',
+ 'layer3': 'cv',
+ 'layer4': 'cv',
+ 'layer5': 'cv',
+ 'layer6': 'cv',
+ 'layer7': 'cv',
+ 'layer8': 'cv',
+ 'layer9': 'cv',
+ 'layer10': 'cv',
+ 'layer11': 'cv',
+ 'layer12': 'cv',
+ 'layer13': 'cv',
+ 'layer14': 'cv',
+ 'layer15': 'cv',
+ },
+ 'c-v15': {
+ 'layer0': 'cd',
+ 'layer1': 'cv',
+ 'layer2': 'cv',
+ 'layer3': 'cv',
+ 'layer4': 'cv',
+ 'layer5': 'cv',
+ 'layer6': 'cv',
+ 'layer7': 'cv',
+ 'layer8': 'cv',
+ 'layer9': 'cv',
+ 'layer10': 'cv',
+ 'layer11': 'cv',
+ 'layer12': 'cv',
+ 'layer13': 'cv',
+ 'layer14': 'cv',
+ 'layer15': 'cv',
+ },
+ 'a-v15': {
+ 'layer0': 'ad',
+ 'layer1': 'cv',
+ 'layer2': 'cv',
+ 'layer3': 'cv',
+ 'layer4': 'cv',
+ 'layer5': 'cv',
+ 'layer6': 'cv',
+ 'layer7': 'cv',
+ 'layer8': 'cv',
+ 'layer9': 'cv',
+ 'layer10': 'cv',
+ 'layer11': 'cv',
+ 'layer12': 'cv',
+ 'layer13': 'cv',
+ 'layer14': 'cv',
+ 'layer15': 'cv',
+ },
+ 'r-v15': {
+ 'layer0': 'rd',
+ 'layer1': 'cv',
+ 'layer2': 'cv',
+ 'layer3': 'cv',
+ 'layer4': 'cv',
+ 'layer5': 'cv',
+ 'layer6': 'cv',
+ 'layer7': 'cv',
+ 'layer8': 'cv',
+ 'layer9': 'cv',
+ 'layer10': 'cv',
+ 'layer11': 'cv',
+ 'layer12': 'cv',
+ 'layer13': 'cv',
+ 'layer14': 'cv',
+ 'layer15': 'cv',
+ },
+ 'cvvv4': {
+ 'layer0': 'cd',
+ 'layer1': 'cv',
+ 'layer2': 'cv',
+ 'layer3': 'cv',
+ 'layer4': 'cd',
+ 'layer5': 'cv',
+ 'layer6': 'cv',
+ 'layer7': 'cv',
+ 'layer8': 'cd',
+ 'layer9': 'cv',
+ 'layer10': 'cv',
+ 'layer11': 'cv',
+ 'layer12': 'cd',
+ 'layer13': 'cv',
+ 'layer14': 'cv',
+ 'layer15': 'cv',
+ },
+ 'avvv4': {
+ 'layer0': 'ad',
+ 'layer1': 'cv',
+ 'layer2': 'cv',
+ 'layer3': 'cv',
+ 'layer4': 'ad',
+ 'layer5': 'cv',
+ 'layer6': 'cv',
+ 'layer7': 'cv',
+ 'layer8': 'ad',
+ 'layer9': 'cv',
+ 'layer10': 'cv',
+ 'layer11': 'cv',
+ 'layer12': 'ad',
+ 'layer13': 'cv',
+ 'layer14': 'cv',
+ 'layer15': 'cv',
+ },
+ 'rvvv4': {
+ 'layer0': 'rd',
+ 'layer1': 'cv',
+ 'layer2': 'cv',
+ 'layer3': 'cv',
+ 'layer4': 'rd',
+ 'layer5': 'cv',
+ 'layer6': 'cv',
+ 'layer7': 'cv',
+ 'layer8': 'rd',
+ 'layer9': 'cv',
+ 'layer10': 'cv',
+ 'layer11': 'cv',
+ 'layer12': 'rd',
+ 'layer13': 'cv',
+ 'layer14': 'cv',
+ 'layer15': 'cv',
+ },
+ 'cccv4': {
+ 'layer0': 'cd',
+ 'layer1': 'cd',
+ 'layer2': 'cd',
+ 'layer3': 'cv',
+ 'layer4': 'cd',
+ 'layer5': 'cd',
+ 'layer6': 'cd',
+ 'layer7': 'cv',
+ 'layer8': 'cd',
+ 'layer9': 'cd',
+ 'layer10': 'cd',
+ 'layer11': 'cv',
+ 'layer12': 'cd',
+ 'layer13': 'cd',
+ 'layer14': 'cd',
+ 'layer15': 'cv',
+ },
+ 'aaav4': {
+ 'layer0': 'ad',
+ 'layer1': 'ad',
+ 'layer2': 'ad',
+ 'layer3': 'cv',
+ 'layer4': 'ad',
+ 'layer5': 'ad',
+ 'layer6': 'ad',
+ 'layer7': 'cv',
+ 'layer8': 'ad',
+ 'layer9': 'ad',
+ 'layer10': 'ad',
+ 'layer11': 'cv',
+ 'layer12': 'ad',
+ 'layer13': 'ad',
+ 'layer14': 'ad',
+ 'layer15': 'cv',
+ },
+ 'rrrv4': {
+ 'layer0': 'rd',
+ 'layer1': 'rd',
+ 'layer2': 'rd',
+ 'layer3': 'cv',
+ 'layer4': 'rd',
+ 'layer5': 'rd',
+ 'layer6': 'rd',
+ 'layer7': 'cv',
+ 'layer8': 'rd',
+ 'layer9': 'rd',
+ 'layer10': 'rd',
+ 'layer11': 'cv',
+ 'layer12': 'rd',
+ 'layer13': 'rd',
+ 'layer14': 'rd',
+ 'layer15': 'cv',
+ },
+ 'c16': {
+ 'layer0': 'cd',
+ 'layer1': 'cd',
+ 'layer2': 'cd',
+ 'layer3': 'cd',
+ 'layer4': 'cd',
+ 'layer5': 'cd',
+ 'layer6': 'cd',
+ 'layer7': 'cd',
+ 'layer8': 'cd',
+ 'layer9': 'cd',
+ 'layer10': 'cd',
+ 'layer11': 'cd',
+ 'layer12': 'cd',
+ 'layer13': 'cd',
+ 'layer14': 'cd',
+ 'layer15': 'cd',
+ },
+ 'a16': {
+ 'layer0': 'ad',
+ 'layer1': 'ad',
+ 'layer2': 'ad',
+ 'layer3': 'ad',
+ 'layer4': 'ad',
+ 'layer5': 'ad',
+ 'layer6': 'ad',
+ 'layer7': 'ad',
+ 'layer8': 'ad',
+ 'layer9': 'ad',
+ 'layer10': 'ad',
+ 'layer11': 'ad',
+ 'layer12': 'ad',
+ 'layer13': 'ad',
+ 'layer14': 'ad',
+ 'layer15': 'ad',
+ },
+ 'r16': {
+ 'layer0': 'rd',
+ 'layer1': 'rd',
+ 'layer2': 'rd',
+ 'layer3': 'rd',
+ 'layer4': 'rd',
+ 'layer5': 'rd',
+ 'layer6': 'rd',
+ 'layer7': 'rd',
+ 'layer8': 'rd',
+ 'layer9': 'rd',
+ 'layer10': 'rd',
+ 'layer11': 'rd',
+ 'layer12': 'rd',
+ 'layer13': 'rd',
+ 'layer14': 'rd',
+ 'layer15': 'rd',
+ },
+ 'carv4': {
+ 'layer0': 'cd',
+ 'layer1': 'ad',
+ 'layer2': 'rd',
+ 'layer3': 'cv',
+ 'layer4': 'cd',
+ 'layer5': 'ad',
+ 'layer6': 'rd',
+ 'layer7': 'cv',
+ 'layer8': 'cd',
+ 'layer9': 'ad',
+ 'layer10': 'rd',
+ 'layer11': 'cv',
+ 'layer12': 'cd',
+ 'layer13': 'ad',
+ 'layer14': 'rd',
+ 'layer15': 'cv',
+ },
+}
+
+
+def createConvFunc(op_type):
+ assert op_type in ['cv', 'cd', 'ad', 'rd'], 'unknown op type: %s' % str(op_type)
+ if op_type == 'cv':
+ return F.conv2d
+
+ if op_type == 'cd':
+ def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
+ assert dilation in [1, 2], 'dilation for cd_conv should be in 1 or 2'
+ assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for cd_conv should be 3x3'
+ assert padding == dilation, 'padding for cd_conv set wrong'
+
+ weights_c = weights.sum(dim=[2, 3], keepdim=True)
+ yc = F.conv2d(x, weights_c, stride=stride, padding=0, groups=groups)
+ y = F.conv2d(x, weights, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
+ return y - yc
+
+ return func
+ elif op_type == 'ad':
+ def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
+ assert dilation in [1, 2], 'dilation for ad_conv should be in 1 or 2'
+ assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for ad_conv should be 3x3'
+ assert padding == dilation, 'padding for ad_conv set wrong'
+
+ shape = weights.shape
+ weights = weights.view(shape[0], shape[1], -1)
+ weights_conv = (weights - weights[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]).view(shape) # clock-wise
+ y = F.conv2d(x, weights_conv, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
+ return y
+
+ return func
+ elif op_type == 'rd':
+ def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
+ assert dilation in [1, 2], 'dilation for rd_conv should be in 1 or 2'
+ assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for rd_conv should be 3x3'
+ padding = 2 * dilation
+
+ shape = weights.shape
+ if weights.is_cuda:
+ buffer = torch.cuda.FloatTensor(shape[0], shape[1], 5 * 5).fill_(0)
+ else:
+ buffer = torch.zeros(shape[0], shape[1], 5 * 5)
+ weights = weights.view(shape[0], shape[1], -1)
+ buffer[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = weights[:, :, 1:]
+ buffer[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -weights[:, :, 1:]
+ buffer[:, :, 12] = 0
+ buffer = buffer.view(shape[0], shape[1], 5, 5)
+ y = F.conv2d(x, buffer, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
+ return y
+
+ return func
+ else:
+ print('impossible to be here unless you force that')
+ return None
+
+
+class Conv2d(nn.Module):
+ def __init__(self, pdc, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,
+ bias=False):
+ super(Conv2d, self).__init__()
+ if in_channels % groups != 0:
+ raise ValueError('in_channels must be divisible by groups')
+ if out_channels % groups != 0:
+ raise ValueError('out_channels must be divisible by groups')
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.padding = padding
+ self.dilation = dilation
+ self.groups = groups
+ self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, kernel_size, kernel_size))
+ if bias:
+ self.bias = nn.Parameter(torch.Tensor(out_channels))
+ else:
+ self.register_parameter('bias', None)
+ self.reset_parameters()
+ self.pdc = pdc
+
+ def reset_parameters(self):
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+ if self.bias is not None:
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
+ bound = 1 / math.sqrt(fan_in)
+ nn.init.uniform_(self.bias, -bound, bound)
+
+ def forward(self, input):
+
+ return self.pdc(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+
+
+class CSAM(nn.Module):
+ """
+ Compact Spatial Attention Module
+ """
+
+ def __init__(self, channels):
+ super(CSAM, self).__init__()
+
+ mid_channels = 4
+ self.relu1 = nn.ReLU()
+ self.conv1 = nn.Conv2d(channels, mid_channels, kernel_size=1, padding=0)
+ self.conv2 = nn.Conv2d(mid_channels, 1, kernel_size=3, padding=1, bias=False)
+ self.sigmoid = nn.Sigmoid()
+ nn.init.constant_(self.conv1.bias, 0)
+
+ def forward(self, x):
+ y = self.relu1(x)
+ y = self.conv1(y)
+ y = self.conv2(y)
+ y = self.sigmoid(y)
+
+ return x * y
+
+
+class CDCM(nn.Module):
+ """
+ Compact Dilation Convolution based Module
+ """
+
+ def __init__(self, in_channels, out_channels):
+ super(CDCM, self).__init__()
+
+ self.relu1 = nn.ReLU()
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
+ self.conv2_1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=5, padding=5, bias=False)
+ self.conv2_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=7, padding=7, bias=False)
+ self.conv2_3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=9, padding=9, bias=False)
+ self.conv2_4 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=11, padding=11, bias=False)
+ nn.init.constant_(self.conv1.bias, 0)
+
+ def forward(self, x):
+ x = self.relu1(x)
+ x = self.conv1(x)
+ x1 = self.conv2_1(x)
+ x2 = self.conv2_2(x)
+ x3 = self.conv2_3(x)
+ x4 = self.conv2_4(x)
+ return x1 + x2 + x3 + x4
+
+
+class MapReduce(nn.Module):
+ """
+ Reduce feature maps into a single edge map
+ """
+
+ def __init__(self, channels):
+ super(MapReduce, self).__init__()
+ self.conv = nn.Conv2d(channels, 1, kernel_size=1, padding=0)
+ nn.init.constant_(self.conv.bias, 0)
+
+ def forward(self, x):
+ return self.conv(x)
+
+
+class PDCBlock(nn.Module):
+ def __init__(self, pdc, inplane, ouplane, stride=1):
+ super(PDCBlock, self).__init__()
+        self.stride = stride
+
+ if self.stride > 1:
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+ self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0)
+ self.conv1 = Conv2d(pdc, inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False)
+ self.relu2 = nn.ReLU()
+ self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False)
+
+ def forward(self, x):
+ if self.stride > 1:
+ x = self.pool(x)
+ y = self.conv1(x)
+ y = self.relu2(y)
+ y = self.conv2(y)
+ if self.stride > 1:
+ x = self.shortcut(x)
+ y = y + x
+ return y
+
+
+class PDCBlock_converted(nn.Module):
+ """
+ CPDC, APDC can be converted to vanilla 3x3 convolution
+ RPDC can be converted to vanilla 5x5 convolution
+ """
+
+ def __init__(self, pdc, inplane, ouplane, stride=1):
+ super(PDCBlock_converted, self).__init__()
+ self.stride = stride
+
+ if self.stride > 1:
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+ self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0)
+ if pdc == 'rd':
+ self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=5, padding=2, groups=inplane, bias=False)
+ else:
+ self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False)
+ self.relu2 = nn.ReLU()
+ self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False)
+
+ def forward(self, x):
+ if self.stride > 1:
+ x = self.pool(x)
+ y = self.conv1(x)
+ y = self.relu2(y)
+ y = self.conv2(y)
+ if self.stride > 1:
+ x = self.shortcut(x)
+ y = y + x
+ return y
+
+
+class PiDiNet(nn.Module):
+ def __init__(self, inplane, pdcs, dil=None, sa=False, convert=False):
+ super(PiDiNet, self).__init__()
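+ # Four stages of PDC blocks (16 pdcs in total) at widths C, 2C, 4C, 4C; each stage output is
+ # optionally refined by CDCM (dil) and CSAM (sa) before MapReduce turns it into an edge map,
+ # and a 1x1 classifier fuses the four maps.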
+ self.sa = sa
+ if dil is not None:
+ assert isinstance(dil, int), 'dil should be an int'
+ self.dil = dil
+
+ self.fuseplanes = []
+
+ self.inplane = inplane
+ if convert:
+ if pdcs[0] == 'rd':
+ init_kernel_size = 5
+ init_padding = 2
+ else:
+ init_kernel_size = 3
+ init_padding = 1
+ self.init_block = nn.Conv2d(3, self.inplane,
+ kernel_size=init_kernel_size, padding=init_padding, bias=False)
+ block_class = PDCBlock_converted
+ else:
+ self.init_block = Conv2d(pdcs[0], 3, self.inplane, kernel_size=3, padding=1)
+ block_class = PDCBlock
+
+ self.block1_1 = block_class(pdcs[1], self.inplane, self.inplane)
+ self.block1_2 = block_class(pdcs[2], self.inplane, self.inplane)
+ self.block1_3 = block_class(pdcs[3], self.inplane, self.inplane)
+ self.fuseplanes.append(self.inplane) # C
+
+ inplane = self.inplane
+ self.inplane = self.inplane * 2
+ self.block2_1 = block_class(pdcs[4], inplane, self.inplane, stride=2)
+ self.block2_2 = block_class(pdcs[5], self.inplane, self.inplane)
+ self.block2_3 = block_class(pdcs[6], self.inplane, self.inplane)
+ self.block2_4 = block_class(pdcs[7], self.inplane, self.inplane)
+ self.fuseplanes.append(self.inplane) # 2C
+
+ inplane = self.inplane
+ self.inplane = self.inplane * 2
+ self.block3_1 = block_class(pdcs[8], inplane, self.inplane, stride=2)
+ self.block3_2 = block_class(pdcs[9], self.inplane, self.inplane)
+ self.block3_3 = block_class(pdcs[10], self.inplane, self.inplane)
+ self.block3_4 = block_class(pdcs[11], self.inplane, self.inplane)
+ self.fuseplanes.append(self.inplane) # 4C
+
+ self.block4_1 = block_class(pdcs[12], self.inplane, self.inplane, stride=2)
+ self.block4_2 = block_class(pdcs[13], self.inplane, self.inplane)
+ self.block4_3 = block_class(pdcs[14], self.inplane, self.inplane)
+ self.block4_4 = block_class(pdcs[15], self.inplane, self.inplane)
+ self.fuseplanes.append(self.inplane) # 4C
+
+ self.conv_reduces = nn.ModuleList()
+ if self.sa and self.dil is not None:
+ self.attentions = nn.ModuleList()
+ self.dilations = nn.ModuleList()
+ for i in range(4):
+ self.dilations.append(CDCM(self.fuseplanes[i], self.dil))
+ self.attentions.append(CSAM(self.dil))
+ self.conv_reduces.append(MapReduce(self.dil))
+ elif self.sa:
+ self.attentions = nn.ModuleList()
+ for i in range(4):
+ self.attentions.append(CSAM(self.fuseplanes[i]))
+ self.conv_reduces.append(MapReduce(self.fuseplanes[i]))
+ elif self.dil is not None:
+ self.dilations = nn.ModuleList()
+ for i in range(4):
+ self.dilations.append(CDCM(self.fuseplanes[i], self.dil))
+ self.conv_reduces.append(MapReduce(self.dil))
+ else:
+ for i in range(4):
+ self.conv_reduces.append(MapReduce(self.fuseplanes[i]))
+
+ self.classifier = nn.Conv2d(4, 1, kernel_size=1) # has bias
+ nn.init.constant_(self.classifier.weight, 0.25)
+ nn.init.constant_(self.classifier.bias, 0)
+
+ # print('initialization done')
+
+ def get_weights(self):
+ conv_weights = []
+ bn_weights = []
+ relu_weights = []
+ for pname, p in self.named_parameters():
+ if 'bn' in pname:
+ bn_weights.append(p)
+ elif 'relu' in pname:
+ relu_weights.append(p)
+ else:
+ conv_weights.append(p)
+
+ return conv_weights, bn_weights, relu_weights
+
+ def forward(self, x):
+ H, W = x.size()[2:]
+
+ x = self.init_block(x)
+
+ x1 = self.block1_1(x)
+ x1 = self.block1_2(x1)
+ x1 = self.block1_3(x1)
+
+ x2 = self.block2_1(x1)
+ x2 = self.block2_2(x2)
+ x2 = self.block2_3(x2)
+ x2 = self.block2_4(x2)
+
+ x3 = self.block3_1(x2)
+ x3 = self.block3_2(x3)
+ x3 = self.block3_3(x3)
+ x3 = self.block3_4(x3)
+
+ x4 = self.block4_1(x3)
+ x4 = self.block4_2(x4)
+ x4 = self.block4_3(x4)
+ x4 = self.block4_4(x4)
+
+ x_fuses = []
+ if self.sa and self.dil is not None:
+ for i, xi in enumerate([x1, x2, x3, x4]):
+ x_fuses.append(self.attentions[i](self.dilations[i](xi)))
+ elif self.sa:
+ for i, xi in enumerate([x1, x2, x3, x4]):
+ x_fuses.append(self.attentions[i](xi))
+ elif self.dil is not None:
+ for i, xi in enumerate([x1, x2, x3, x4]):
+ x_fuses.append(self.dilations[i](xi))
+ else:
+ x_fuses = [x1, x2, x3, x4]
+
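+ # Reduce each (fused) stage to a single-channel map and upsample it back to the input resolution.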
+ e1 = self.conv_reduces[0](x_fuses[0])
+ e1 = F.interpolate(e1, (H, W), mode="bilinear", align_corners=False)
+
+ e2 = self.conv_reduces[1](x_fuses[1])
+ e2 = F.interpolate(e2, (H, W), mode="bilinear", align_corners=False)
+
+ e3 = self.conv_reduces[2](x_fuses[2])
+ e3 = F.interpolate(e3, (H, W), mode="bilinear", align_corners=False)
+
+ e4 = self.conv_reduces[3](x_fuses[3])
+ e4 = F.interpolate(e4, (H, W), mode="bilinear", align_corners=False)
+
+ outputs = [e1, e2, e3, e4]
+
+ output = self.classifier(torch.cat(outputs, dim=1))
+ # if not self.training:
+ # return torch.sigmoid(output)
+
+ outputs.append(output)
+ outputs = [torch.sigmoid(r) for r in outputs]
+ return outputs
+
+
+def config_model(model):
+ model_options = list(nets.keys())
+ assert model in model_options, \
+ 'unrecognized model, please choose from %s' % str(model_options)
+
+ # print(str(nets[model]))
+
+ pdcs = []
+ for i in range(16):
+ layer_name = 'layer%d' % i
+ op = nets[model][layer_name]
+ pdcs.append(createConvFunc(op))
+
+ return pdcs
+
+
+def pidinet():
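+ # Factory used by the detector: 'carv4' pixel-difference config, 60 base channels, CDCM width 24, CSAM enabled.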
+ pdcs = config_model('carv4')
+ dil = 24 # if args.dil else None
+ return PiDiNet(60, pdcs, dil=dil, sa=True)
diff --git a/modules/cnet_modules/pidinet/util.py b/modules/cnet_modules/pidinet/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..aec00770c7706f95abf3a0b9b02dbe3232930596
--- /dev/null
+++ b/modules/cnet_modules/pidinet/util.py
@@ -0,0 +1,97 @@
+import random
+
+import numpy as np
+import cv2
+import os
+
+annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')
+
+
+def HWC3(x):
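+ # Coerce a uint8 image to 3-channel RGB: grayscale is replicated across channels and
+ # RGBA is alpha-composited over a white background.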
+ assert x.dtype == np.uint8
+ if x.ndim == 2:
+ x = x[:, :, None]
+ assert x.ndim == 3
+ H, W, C = x.shape
+ assert C == 1 or C == 3 or C == 4
+ if C == 3:
+ return x
+ if C == 1:
+ return np.concatenate([x, x, x], axis=2)
+ if C == 4:
+ color = x[:, :, 0:3].astype(np.float32)
+ alpha = x[:, :, 3:4].astype(np.float32) / 255.0
+ y = color * alpha + 255.0 * (1.0 - alpha)
+ y = y.clip(0, 255).astype(np.uint8)
+ return y
+
+
+def resize_image(input_image, resolution):
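+ # Scale so the short side equals resolution, then round both sides to multiples of 64
+ # (Lanczos when upscaling, area interpolation when downscaling).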
+ H, W, C = input_image.shape
+ H = float(H)
+ W = float(W)
+ k = float(resolution) / min(H, W)
+ H *= k
+ W *= k
+ H = int(np.round(H / 64.0)) * 64
+ W = int(np.round(W / 64.0)) * 64
+ img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
+ return img
+
+
+def nms(x, t, s):
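+ # Edge thinning: Gaussian-blur with sigma s, keep pixels that are local maxima along four
+ # directions (dilation with directional kernels), then binarize at threshold t.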
+ x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
+
+ f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
+ f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
+ f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
+ f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
+
+ y = np.zeros_like(x)
+
+ for f in [f1, f2, f3, f4]:
+ np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
+
+ z = np.zeros_like(y, dtype=np.uint8)
+ z[y > t] = 255
+ return z
+
+
+def make_noise_disk(H, W, C, F):
+ noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C))
+ noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC)
+ noise = noise[F: F + H, F: F + W]
+ noise -= np.min(noise)
+ noise /= np.max(noise)
+ if C == 1:
+ noise = noise[:, :, None]
+ return noise
+
+
+def min_max_norm(x):
+ x -= np.min(x)
+ x /= np.maximum(np.max(x), 1e-5)
+ return x
+
+
+def safe_step(x, step=2):
+ y = x.astype(np.float32) * float(step + 1)
+ y = y.astype(np.int32).astype(np.float32) / float(step)
+ return y
+
+
+def img2mask(img, H, W, low=10, high=90):
+ assert img.ndim == 3 or img.ndim == 2
+ assert img.dtype == np.uint8
+
+ if img.ndim == 3:
+ y = img[:, :, random.randrange(0, img.shape[2])]
+ else:
+ y = img
+
+ y = cv2.resize(y, (W, H), interpolation=cv2.INTER_CUBIC)
+
+ if random.uniform(0, 1) < 0.5:
+ y = 255 - y
+
+ return y < np.percentile(y, random.randrange(low, high))
diff --git a/modules/common.py b/modules/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e4ad71649f60f2dd38947c9ebc23bc51db2b544
--- /dev/null
+++ b/modules/common.py
@@ -0,0 +1,131 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+from einops import rearrange
+import torch.fft as fft
+class Linear(torch.nn.Linear):
+ def reset_parameters(self):
+ return None
+
+class Conv2d(torch.nn.Conv2d):
+ def reset_parameters(self):
+ return None
+
+
+
+class Attention2D(nn.Module):
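+ # Multi-head attention over spatial tokens: (B, C, H, W) is flattened to (B, H*W, C); with
+ # self_attn=True the image tokens are prepended to kv, giving joint self- and cross-attention.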
+ def __init__(self, c, nhead, dropout=0.0):
+ super().__init__()
+ self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True)
+
+ def forward(self, x, kv, self_attn=False):
+ orig_shape = x.shape
+ x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
+ if self_attn:
+ #print('in line 23 algong self att ', kv.shape, x.shape)
+ kv = torch.cat([x, kv], dim=1)
+ #if x.shape[1] >= 72 * 72:
+ # x = x * math.sqrt(math.log(64*64, 24*24))
+
+ x = self.attn(x, kv, kv, need_weights=False)[0]
+ x = x.permute(0, 2, 1).view(*orig_shape)
+ return x
+
+
+class LayerNorm2d(nn.LayerNorm):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, x):
+ return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+
+class GlobalResponseNorm(nn.Module):
+ "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
+ def __init__(self, dim):
+ super().__init__()
+ self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
+ self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
+
+ def forward(self, x):
+ Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
+ Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
+ return self.gamma * (x * Nx) + self.beta + x
+
+
+class ResBlock(nn.Module):
+ def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0): # , num_heads=4, expansion=2):
+ super().__init__()
+ self.depthwise = Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
+ # self.depthwise = SAMBlock(c, num_heads, expansion)
+ self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
+ self.channelwise = nn.Sequential(
+ Linear(c + c_skip, c * 4),
+ nn.GELU(),
+ GlobalResponseNorm(c * 4),
+ nn.Dropout(dropout),
+ Linear(c * 4, c)
+ )
+
+ def forward(self, x, x_skip=None):
+ x_res = x
+ x = self.norm(self.depthwise(x))
+ if x_skip is not None:
+ x = torch.cat([x, x_skip], dim=1)
+ x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+ return x + x_res
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
+ super().__init__()
+ self.self_attn = self_attn
+ self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
+ self.attention = Attention2D(c, nhead, dropout)
+ self.kv_mapper = nn.Sequential(
+ nn.SiLU(),
+ Linear(c_cond, c)
+ )
+
+ def forward(self, x, kv):
+ kv = self.kv_mapper(kv)
+ res = self.attention(self.norm(x), kv, self_attn=self.self_attn)
+
+ #print(torch.unique(res), torch.unique(x), self.self_attn)
+ #scale = math.sqrt(math.log(x.shape[-2] * x.shape[-1], 24*24))
+ x = x + res
+
+ return x
+
+class FeedForwardBlock(nn.Module):
+ def __init__(self, c, dropout=0.0):
+ super().__init__()
+ self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
+ self.channelwise = nn.Sequential(
+ Linear(c, c * 4),
+ nn.GELU(),
+ GlobalResponseNorm(c * 4),
+ nn.Dropout(dropout),
+ Linear(c * 4, c)
+ )
+
+ def forward(self, x):
+ x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+ return x
+
+
+class TimestepBlock(nn.Module):
+ def __init__(self, c, c_timestep, conds=['sca']):
+ super().__init__()
+ self.mapper = Linear(c_timestep, c * 2)
+ self.conds = conds
+ for cname in conds:
+ setattr(self, f"mapper_{cname}", Linear(c_timestep, c * 2))
+
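+ # t carries one embedding per condition (plus the timestep) packed along dim 1; each mapper
+ # produces a scale/shift pair (a, b), summed over conditions and applied as x * (1 + a) + b.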
+ def forward(self, x, t):
+ t = t.chunk(len(self.conds) + 1, dim=1)
+ a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1)
+ for i, c in enumerate(self.conds):
+ ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1)
+ a, b = a + ac, b + bc
+ return x * (1 + a) + b
diff --git a/modules/common_ckpt.py b/modules/common_ckpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..f64cf11790bdd2a83ca0744629336d81464b3ed0
--- /dev/null
+++ b/modules/common_ckpt.py
@@ -0,0 +1,360 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+from einops import rearrange
+from modules.speed_util import checkpoint
+class Linear(torch.nn.Linear):
+ def reset_parameters(self):
+ return None
+
+class Conv2d(torch.nn.Conv2d):
+ def reset_parameters(self):
+ return None
+
+class AttnBlock_lrfuse_backup(nn.Module):
+ def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, use_checkpoint=True):
+ super().__init__()
+ self.self_attn = self_attn
+ self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
+ self.attention = Attention2D(c, nhead, dropout)
+ self.kv_mapper = nn.Sequential(
+ nn.SiLU(),
+ Linear(c_cond, c)
+ )
+ self.fuse_mapper = nn.Sequential(
+ nn.SiLU(),
+ Linear(c_cond, c)
+ )
+ self.use_checkpoint = use_checkpoint
+
+ def forward(self, hr, lr):
+ return checkpoint(self._forward, (hr, lr), self.parameters(), self.use_checkpoint)
+
+ def _forward(self, hr, lr):
+ res = hr
+ hr = self.kv_mapper(rearrange(hr, 'b c h w -> b (h w ) c'))
+ lr_fuse = self.attention(self.norm(lr), hr, self_attn=False) + lr
+
+ lr_fuse = self.fuse_mapper(rearrange(lr_fuse, 'b c h w -> b (h w ) c'))
+ hr = self.attention(self.norm(res), lr_fuse, self_attn=False) + res
+ return hr
+
+
+class AttnBlock_lrfuse(nn.Module):
+ def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, kernel_size=3, use_checkpoint=True):
+ super().__init__()
+ self.self_attn = self_attn
+ self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
+ self.attention = Attention2D(c, nhead, dropout)
+ self.kv_mapper = nn.Sequential(
+ nn.SiLU(),
+ Linear(c_cond, c)
+ )
+
+
+ self.depthwise = Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
+
+ self.channelwise = nn.Sequential(
+ Linear(c + c, c),
+ nn.GELU(),
+ GlobalResponseNorm(c),
+ nn.Dropout(dropout),
+ Linear(c, c)
+ )
+ self.use_checkpoint = use_checkpoint
+
+
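+ # Fuses a low-resolution feature map into the high-resolution one: lr cross-attends to hr
+ # tokens, the result is upsampled, passed through the depthwise conv, concatenated with hr
+ # and mixed by the channelwise MLP, all residually.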
+ def forward(self, hr, lr):
+ return checkpoint(self._forward, (hr, lr), self.parameters(), self.use_checkpoint)
+
+ def _forward(self, hr, lr):
+ res = hr
+ hr = self.kv_mapper(rearrange(hr, 'b c h w -> b (h w ) c'))
+ lr_fuse = self.attention(self.norm(lr), hr, self_attn=False) + lr
+
+ lr_fuse = torch.nn.functional.interpolate(lr_fuse.float(), res.shape[2:])
+ #print('in line 65', lr_fuse.shape, res.shape)
+ media = torch.cat((self.depthwise(lr_fuse), res), dim=1)
+ out = self.channelwise(media.permute(0,2,3,1)).permute(0,3,1,2) + res
+
+ return out
+
+
+
+
+class Attention2D(nn.Module):
+ def __init__(self, c, nhead, dropout=0.0):
+ super().__init__()
+ self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True)
+
+ def forward(self, x, kv, self_attn=False):
+ orig_shape = x.shape
+ x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
+ if self_attn:
+ #print('in line 23 algong self att ', kv.shape, x.shape)
+
+ kv = torch.cat([x, kv], dim=1)
+ #if x.shape[1] > 48 * 48 and not self.training:
+ # x = x * math.sqrt(math.log(x.shape[1] , 24*24))
+
+ x = self.attn(x, kv, kv, need_weights=False)[0]
+ x = x.permute(0, 2, 1).view(*orig_shape)
+ return x
+class Attention2D_splitpatch(nn.Module):
+ def __init__(self, c, nhead, dropout=0.0):
+ super().__init__()
+ self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True)
+
+ def forward(self, x, kv, self_attn=False):
+ orig_shape = x.shape
+
+ #x = rearrange(x, 'b c h w -> b c (nh wh) (nw ww)', wh=24, ww=24, nh=orig_shape[-2] // 24, nh=orig_shape[-1] // 24,)
+ x = rearrange(x, 'b c (nh wh) (nw ww) -> (b nh nw) (wh ww) c', wh=24, ww=24, nh=orig_shape[-2] // 24, nw=orig_shape[-1] // 24,)
+ #print('in line 168', x.shape)
+ #x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
+ if self_attn:
+ #print('in line 23 algong self att ', kv.shape, x.shape)
+ num = (orig_shape[-2] // 24) * (orig_shape[-1] // 24)
+ kv = torch.cat([x, kv.repeat(num, 1, 1)], dim=1)
+ #if x.shape[1] > 48 * 48 and not self.training:
+ # x = x * math.sqrt(math.log(x.shape[1] / math.sqrt(16), 24*24))
+
+ x = self.attn(x, kv, kv, need_weights=False)[0]
+ x = rearrange(x, ' (b nh nw) (wh ww) c -> b c (nh wh) (nw ww)', b=orig_shape[0], wh=24, ww=24, nh=orig_shape[-2] // 24, nw=orig_shape[-1] // 24)
+ #x = x.permute(0, 2, 1).view(*orig_shape)
+
+ return x
+class Attention2D_extra(nn.Module):
+ def __init__(self, c, nhead, dropout=0.0):
+ super().__init__()
+ self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True)
+
+ def forward(self, x, kv, extra_emb=None, self_attn=False):
+ orig_shape = x.shape
+ x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
+ num_x = x.shape[1]
+
+
+ if extra_emb is not None:
+ ori_extra_shape = extra_emb.shape
+ extra_emb = extra_emb.view(extra_emb.size(0), extra_emb.size(1), -1).permute(0, 2, 1)
+ x = torch.cat((x, extra_emb), dim=1)
+ if self_attn:
+ #print('in line 23 algong self att ', kv.shape, x.shape)
+ kv = torch.cat([x, kv], dim=1)
+ x = self.attn(x, kv, kv, need_weights=False)[0]
+ img = x[:, :num_x, :].permute(0, 2, 1).view(*orig_shape)
+ if extra_emb is not None:
+ fix = x[:, num_x:, :].permute(0, 2, 1).view(*ori_extra_shape)
+ return img, fix
+ else:
+ return img
+class AttnBlock_extraq(nn.Module):
+ def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
+ super().__init__()
+ self.self_attn = self_attn
+ self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
+ #self.norm2 = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
+ self.attention = Attention2D_extra(c, nhead, dropout)
+ self.kv_mapper = nn.Sequential(
+ nn.SiLU(),
+ Linear(c_cond, c)
+ )
+ # norm2 is not created here; it is registered externally (by the generator) when the extra-embedding parameters are initialized
+ def forward(self, x, kv, extra_emb=None):
+ #print('in line 84', x.shape, kv.shape, self.self_attn, extra_emb if extra_emb is None else extra_emb.shape)
+ #in line 84 torch.Size([1, 1536, 32, 32]) torch.Size([1, 85, 1536]) True None
+ #if extra_emb is not None:
+
+ kv = self.kv_mapper(kv)
+ if extra_emb is not None:
+ res_x, res_extra = self.attention(self.norm(x), kv, extra_emb=self.norm2(extra_emb), self_attn=self.self_attn)
+ x = x + res_x
+ extra_emb = extra_emb + res_extra
+ return x, extra_emb
+ else:
+ x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
+ return x
+class AttnBlock_latent2ex(nn.Module):
+ def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
+ super().__init__()
+ self.self_attn = self_attn
+ self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
+ self.attention = Attention2D(c, nhead, dropout)
+ self.kv_mapper = nn.Sequential(
+ nn.SiLU(),
+ Linear(c_cond, c)
+ )
+
+ def forward(self, x, kv):
+ #print('in line 84', x.shape, kv.shape, self.self_attn)
+ kv = F.interpolate(kv.float(), x.shape[2:])
+ kv = kv.view(kv.size(0), kv.size(1), -1).permute(0, 2, 1)
+ kv = self.kv_mapper(kv)
+ x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
+ return x
+
+class LayerNorm2d(nn.LayerNorm):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, x):
+ return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+class AttnBlock_crossbranch(nn.Module):
+ def __init__(self, attnmodule, c, c_cond, nhead, self_attn=True, dropout=0.0):
+ super().__init__()
+ self.attn = AttnBlock(c, c_cond, nhead, self_attn, dropout)
+ #print('in line 108', attnmodule.device)
+ self.attn.load_state_dict(attnmodule.state_dict())
+ self.norm1 = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
+
+ self.channelwise1 = nn.Sequential(
+ Linear(c *2, c ),
+ nn.GELU(),
+ GlobalResponseNorm(c ),
+ nn.Dropout(dropout),
+ Linear(c, c)
+ )
+ self.channelwise2 = nn.Sequential(
+ Linear(c *2, c ),
+ nn.GELU(),
+ GlobalResponseNorm(c ),
+ nn.Dropout(dropout),
+ Linear(c, c)
+ )
+ self.c = c
+ def forward(self, x, kv, main_x):
+ #print('in line 84', x.shape, kv.shape, main_x.shape, self.c)
+
+ x = self.channelwise1(torch.cat((x, F.interpolate(main_x.float(), x.shape[2:])), dim=1).permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + x
+ x = self.attn(x, kv)
+ main_x = self.channelwise2(torch.cat((main_x, F.interpolate(x.float(), main_x.shape[2:])), dim=1).permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + main_x
+ return main_x, x
+
+class GlobalResponseNorm(nn.Module):
+ "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
+ def __init__(self, dim):
+ super().__init__()
+ self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
+ self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
+
+ def forward(self, x):
+ Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
+ Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
+ return self.gamma * (x * Nx) + self.beta + x
+
+
+class ResBlock(nn.Module):
+ def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0, use_checkpoint =True): # , num_heads=4, expansion=2):
+ super().__init__()
+ self.depthwise = Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
+ # self.depthwise = SAMBlock(c, num_heads, expansion)
+ self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
+ self.channelwise = nn.Sequential(
+ Linear(c + c_skip, c * 4),
+ nn.GELU(),
+ GlobalResponseNorm(c * 4),
+ nn.Dropout(dropout),
+ Linear(c * 4, c)
+ )
+ self.use_checkpoint = use_checkpoint
+ def forward(self, x, x_skip=None):
+
+ if x_skip is not None:
+ return checkpoint(self._forward_skip, (x, x_skip), self.parameters(), self.use_checkpoint)
+ else:
+ #print('in line 298', x.shape)
+ return checkpoint(self._forward_woskip, (x, ), self.parameters(), self.use_checkpoint)
+
+
+
+ def _forward_skip(self, x, x_skip):
+ x_res = x
+ x = self.norm(self.depthwise(x))
+ if x_skip is not None:
+ x = torch.cat([x, x_skip], dim=1)
+ x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+ return x + x_res
+ def _forward_woskip(self, x):
+ x_res = x
+ x = self.norm(self.depthwise(x))
+
+ x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+ return x + x_res
+
+class AttnBlock(nn.Module):
+ def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, use_checkpoint=True):
+ super().__init__()
+ self.self_attn = self_attn
+ self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
+ self.attention = Attention2D(c, nhead, dropout)
+ self.kv_mapper = nn.Sequential(
+ nn.SiLU(),
+ Linear(c_cond, c)
+ )
+ self.use_checkpoint = use_checkpoint
+ def forward(self, x, kv):
+ return checkpoint(self._forward, (x, kv), self.parameters(), self.use_checkpoint)
+ def _forward(self, x, kv):
+ kv = self.kv_mapper(kv)
+ res = self.attention(self.norm(x), kv, self_attn=self.self_attn)
+
+ #print(torch.unique(res), torch.unique(x), self.self_attn)
+ #scale = math.sqrt(math.log(x.shape[-2] * x.shape[-1], 24*24))
+ x = x + res
+
+ return x
+class AttnBlock_mytest(nn.Module):
+ def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
+ super().__init__()
+ self.self_attn = self_attn
+ self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
+ self.attention = Attention2D(c, nhead, dropout)
+ self.kv_mapper = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(c_cond, c)
+ )
+
+ def forward(self, x, kv):
+ kv = self.kv_mapper(kv)
+ x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
+ return x
+
+class FeedForwardBlock(nn.Module):
+ def __init__(self, c, dropout=0.0):
+ super().__init__()
+ self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
+ self.channelwise = nn.Sequential(
+ Linear(c, c * 4),
+ nn.GELU(),
+ GlobalResponseNorm(c * 4),
+ nn.Dropout(dropout),
+ Linear(c * 4, c)
+ )
+
+ def forward(self, x):
+ x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+ return x
+
+
+class TimestepBlock(nn.Module):
+ def __init__(self, c, c_timestep, conds=['sca'], use_checkpoint=True):
+ super().__init__()
+ self.mapper = Linear(c_timestep, c * 2)
+ self.conds = conds
+ for cname in conds:
+ setattr(self, f"mapper_{cname}", Linear(c_timestep, c * 2))
+
+ self.use_checkpoint = use_checkpoint
+ def forward(self, x, t):
+ return checkpoint(self._forward, (x, t), self.parameters(), self.use_checkpoint)
+
+ def _forward(self, x, t):
+ #print('in line 284', x.shape, t.shape, self.conds)
+ #in line 284 torch.Size([4, 2048, 19, 29]) torch.Size([4, 192]) ['sca', 'crp']
+ t = t.chunk(len(self.conds) + 1, dim=1)
+ a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1)
+ for i, c in enumerate(self.conds):
+ ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1)
+ a, b = a + ac, b + bc
+ return x * (1 + a) + b
diff --git a/modules/controlnet.py b/modules/controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..c187aecb725e00e19924ae308e3aac401acfdf06
--- /dev/null
+++ b/modules/controlnet.py
@@ -0,0 +1,349 @@
+import torchvision
+import torch
+from torch import nn
+import numpy as np
+import kornia
+import cv2
+from core.utils import load_or_fail
+#from insightface.app.common import Face
+from .effnet import EfficientNetEncoder
+from .cnet_modules.pidinet import PidiNetDetector
+from .cnet_modules.inpainting.saliency_model import MicroResNet
+#from .cnet_modules.face_id.arcface import FaceDetector, ArcFaceRecognizer
+from .common import LayerNorm2d
+
+
+class CNetResBlock(nn.Module):
+ def __init__(self, c):
+ super().__init__()
+ self.blocks = nn.Sequential(
+ LayerNorm2d(c),
+ nn.GELU(),
+ nn.Conv2d(c, c, kernel_size=3, padding=1),
+ LayerNorm2d(c),
+ nn.GELU(),
+ nn.Conv2d(c, c, kernel_size=3, padding=1),
+ )
+
+ def forward(self, x):
+ return x + self.blocks(x)
+
+
+class ControlNet(nn.Module):
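+ # Encodes the conditioning image with a backbone ('effnet' = EfficientNetV2-S features,
+ # 'simple' = a shallow conv stack, 'large' = 1x1 convs with residual blocks) and emits one
+ # 1x1 projection per entry in proj_blocks; each output projection is zero-initialized.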
+ def __init__(self, c_in=3, c_proj=2048, proj_blocks=None, bottleneck_mode=None):
+ super().__init__()
+ if bottleneck_mode is None:
+ bottleneck_mode = 'effnet'
+ self.proj_blocks = proj_blocks
+ if bottleneck_mode == 'effnet':
+ embd_channels = 1280
+ #self.backbone = torchvision.models.efficientnet_v2_s(weights='DEFAULT').features.eval()
+ self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
+ if c_in != 3:
+ in_weights = self.backbone[0][0].weight.data
+ self.backbone[0][0] = nn.Conv2d(c_in, 24, kernel_size=3, stride=2, bias=False)
+ if c_in > 3:
+ nn.init.constant_(self.backbone[0][0].weight, 0)
+ self.backbone[0][0].weight.data[:, :3] = in_weights[:, :3].clone()
+ else:
+ self.backbone[0][0].weight.data = in_weights[:, :c_in].clone()
+ elif bottleneck_mode == 'simple':
+ embd_channels = c_in
+ self.backbone = nn.Sequential(
+ nn.Conv2d(embd_channels, embd_channels * 4, kernel_size=3, padding=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(embd_channels * 4, embd_channels, kernel_size=3, padding=1),
+ )
+ elif bottleneck_mode == 'large':
+ self.backbone = nn.Sequential(
+ nn.Conv2d(c_in, 4096 * 4, kernel_size=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(4096 * 4, 1024, kernel_size=1),
+ *[CNetResBlock(1024) for _ in range(8)],
+ nn.Conv2d(1024, 1280, kernel_size=1),
+ )
+ embd_channels = 1280
+ else:
+ raise ValueError(f'Unknown bottleneck mode: {bottleneck_mode}')
+ self.projections = nn.ModuleList()
+ for _ in range(len(proj_blocks)):
+ self.projections.append(nn.Sequential(
+ nn.Conv2d(embd_channels, embd_channels, kernel_size=1, bias=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(embd_channels, c_proj, kernel_size=1, bias=False),
+ ))
+ nn.init.constant_(self.projections[-1][-1].weight, 0) # zero output projection
+
+ def forward(self, x):
+ x = self.backbone(x)
+ proj_outputs = [None for _ in range(max(self.proj_blocks) + 1)]
+ for i, idx in enumerate(self.proj_blocks):
+ proj_outputs[idx] = self.projections[i](x)
+ return proj_outputs
+
+
+class ControlNetDeliverer():
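+ # Hands the projections out one per call, in block order; entries may be None for blocks
+ # without an injection, and None is returned once the list is exhausted.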
+ def __init__(self, controlnet_projections):
+ self.controlnet_projections = controlnet_projections
+ self.restart()
+
+ def restart(self):
+ self.idx = 0
+ return self
+
+ def __call__(self):
+ if self.idx < len(self.controlnet_projections):
+ output = self.controlnet_projections[self.idx]
+ else:
+ output = None
+ self.idx += 1
+ return output
+
+
+# CONTROLNET FILTERS ----------------------------------------------------
+
+class BaseFilter():
+ def __init__(self, device):
+ self.device = device
+
+ def num_channels(self):
+ return 3
+
+ def __call__(self, x):
+ return x
+
+
+class CannyFilter(BaseFilter):
+ def __init__(self, device, resize=224):
+ super().__init__(device)
+ self.resize = resize
+
+ def num_channels(self):
+ return 1
+
+ def __call__(self, x):
+ orig_size = x.shape[-2:]
+ if self.resize is not None:
+ x = nn.functional.interpolate(x, size=(self.resize, self.resize), mode='bilinear')
+ edges = [cv2.Canny(x[i].mul(255).permute(1, 2, 0).cpu().numpy().astype(np.uint8), 100, 200) for i in range(len(x))]
+ edges = torch.stack([torch.tensor(e).div(255).unsqueeze(0) for e in edges], dim=0)
+ if self.resize is not None:
+ edges = nn.functional.interpolate(edges, size=orig_size, mode='bilinear')
+ return edges
+
+
+class QRFilter(BaseFilter):
+ def __init__(self, device, resize=224, blobify=True, dilation_kernels=[3, 5, 7], blur_kernels=[15]):
+ super().__init__(device)
+ self.resize = resize
+ self.blobify = blobify
+ self.dilation_kernels = dilation_kernels
+ self.blur_kernels = blur_kernels
+
+ def num_channels(self):
+ return 1
+
+ def __call__(self, x):
+ x = x.to(self.device)
+ orig_size = x.shape[-2:]
+ if self.resize is not None:
+ x = nn.functional.interpolate(x, size=(self.resize, self.resize), mode='bilinear')
+
+ x = kornia.color.rgb_to_hsv(x)[:, -1:]
+ # blobify
+ if self.blobify:
+ d_kernel = np.random.choice(self.dilation_kernels)
+ d_blur = np.random.choice(self.blur_kernels)
+ if d_blur > 0:
+ x = torchvision.transforms.GaussianBlur(d_blur)(x)
+ if d_kernel > 0:
+ blob_mask = ((torch.linspace(-0.5, 0.5, d_kernel).pow(2)[None] + torch.linspace(-0.5, 0.5,
+ d_kernel).pow(2)[:,
+ None]) < 0.3).float().to(self.device)
+ x = kornia.morphology.dilation(x, blob_mask)
+ x = kornia.morphology.erosion(x, blob_mask)
+ # mask
+ vmax, vmin = x.amax(dim=[2, 3], keepdim=True)[0], x.amin(dim=[2, 3], keepdim=True)[0]
+ th = (vmax - vmin) * 0.33
+ high_brightness, low_brightness = (x > (vmax - th)).float(), (x < (vmin + th)).float()
+ mask = (torch.ones_like(x) - low_brightness + high_brightness) * 0.5
+
+ if self.resize is not None:
+ mask = nn.functional.interpolate(mask, size=orig_size, mode='bilinear')
+ return mask.cpu()
+
+
+class PidiFilter(BaseFilter):
+ def __init__(self, device, resize=224, dilation_kernels=[0, 3, 5, 7, 9], binarize=True):
+ super().__init__(device)
+ self.resize = resize
+ self.model = PidiNetDetector(device)
+ self.dilation_kernels = dilation_kernels
+ self.binarize = binarize
+
+ def num_channels(self):
+ return 1
+
+ def __call__(self, x):
+ x = x.to(self.device)
+ orig_size = x.shape[-2:]
+ if self.resize is not None:
+ x = nn.functional.interpolate(x, size=(self.resize, self.resize), mode='bilinear')
+
+ x = self.model(x)
+ d_kernel = np.random.choice(self.dilation_kernels)
+ if d_kernel > 0:
+ blob_mask = ((torch.linspace(-0.5, 0.5, d_kernel).pow(2)[None] + torch.linspace(-0.5, 0.5, d_kernel).pow(2)[
+ :, None]) < 0.3).float().to(self.device)
+ x = kornia.morphology.dilation(x, blob_mask)
+ if self.binarize:
+ th = np.random.uniform(0.05, 0.7)
+ x = (x > th).float()
+
+ if self.resize is not None:
+ x = nn.functional.interpolate(x, size=orig_size, mode='bilinear')
+ return x.cpu()
+
+
+class SRFilter(BaseFilter):
+ def __init__(self, device, scale_factor=1 / 4):
+ super().__init__(device)
+ self.scale_factor = scale_factor
+
+ def num_channels(self):
+ return 3
+
+ def __call__(self, x):
+ x = torch.nn.functional.interpolate(x.clone(), scale_factor=self.scale_factor, mode="nearest")
+ return torch.nn.functional.interpolate(x, scale_factor=1 / self.scale_factor, mode="nearest")
+
+
+class SREffnetFilter(BaseFilter):
+ def __init__(self, device, scale_factor=1/2):
+ super().__init__(device)
+ self.scale_factor = scale_factor
+
+ self.effnet_preprocess = torchvision.transforms.Compose([
+ torchvision.transforms.Normalize(
+ mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
+ )
+ ])
+
+ self.effnet = EfficientNetEncoder().to(self.device)
+ effnet_checkpoint = load_or_fail("models/effnet_encoder.safetensors")
+ self.effnet.load_state_dict(effnet_checkpoint)
+ self.effnet.eval().requires_grad_(False)
+
+ def num_channels(self):
+ return 16
+
+ def __call__(self, x):
+ x = torch.nn.functional.interpolate(x.clone(), scale_factor=self.scale_factor, mode="nearest")
+ with torch.no_grad():
+ effnet_embedding = self.effnet(self.effnet_preprocess(x.to(self.device))).cpu()
+ effnet_embedding = torch.nn.functional.interpolate(effnet_embedding, scale_factor=1/self.scale_factor, mode="nearest")
+ upscaled_image = torch.nn.functional.interpolate(x, scale_factor=1/self.scale_factor, mode="nearest")
+ return effnet_embedding, upscaled_image
+
+
+class InpaintFilter(BaseFilter):
+ def __init__(self, device, thresold=[0.04, 0.4], p_outpaint=0.4):
+ super().__init__(device)
+ self.saliency_model = MicroResNet().eval().requires_grad_(False).to(device)
+ self.saliency_model.load_state_dict(load_or_fail("modules/cnet_modules/inpainting/saliency_model.pt"))
+ self.thresold = thresold
+ self.p_outpaint = p_outpaint
+
+ def num_channels(self):
+ return 4
+
+ def __call__(self, x, mask=None, threshold=None, outpaint=None):
+ x = x.to(self.device)
+ resized_x = torchvision.transforms.functional.resize(x, 240, antialias=True)
+ if threshold is None:
+ threshold = np.random.uniform(self.thresold[0], self.thresold[1])
+ if mask is None:
+ saliency_map = self.saliency_model(resized_x) > threshold
+ if outpaint is None:
+ if np.random.rand() < self.p_outpaint:
+ saliency_map = ~saliency_map
+ else:
+ if outpaint:
+ saliency_map = ~saliency_map
+ interpolated_saliency_map = torch.nn.functional.interpolate(saliency_map.float(), size=x.shape[2:], mode="nearest")
+ saliency_map = torchvision.transforms.functional.gaussian_blur(interpolated_saliency_map, 141) > 0.5
+ inpainted_images = torch.where(saliency_map, torch.ones_like(x), x)
+ mask = torch.nn.functional.interpolate(saliency_map.float(), size=inpainted_images.shape[2:], mode="nearest")
+ else:
+ mask = mask.to(self.device)
+ inpainted_images = torch.where(mask, torch.ones_like(x), x)
+ c_inpaint = torch.cat([inpainted_images, mask], dim=1)
+ return c_inpaint.cpu()
+
+
+# IDENTITY
+'''
+class IdentityFilter(BaseFilter):
+ def __init__(self, device, max_faces=4, p_drop=0.05, p_full=0.3):
+ detector_path = 'modules/cnet_modules/face_id/models/buffalo_l/det_10g.onnx'
+ recognizer_path = 'modules/cnet_modules/face_id/models/buffalo_l/w600k_r50.onnx'
+
+ super().__init__(device)
+ self.max_faces = max_faces
+ self.p_drop = p_drop
+ self.p_full = p_full
+
+ self.detector = FaceDetector(detector_path, device=device)
+ self.recognizer = ArcFaceRecognizer(recognizer_path, device=device)
+
+ self.id_colors = torch.tensor([
+ [1.0, 0.0, 0.0], # RED
+ [0.0, 1.0, 0.0], # GREEN
+ [0.0, 0.0, 1.0], # BLUE
+ [1.0, 0.0, 1.0], # PURPLE
+ [0.0, 1.0, 1.0], # CYAN
+ [1.0, 1.0, 0.0], # YELLOW
+ [0.5, 0.0, 0.0], # DARK RED
+ [0.0, 0.5, 0.0], # DARK GREEN
+ [0.0, 0.0, 0.5], # DARK BLUE
+ [0.5, 0.0, 0.5], # DARK PURPLE
+ [0.0, 0.5, 0.5], # DARK CYAN
+ [0.5, 0.5, 0.0], # DARK YELLOW
+ ])
+
+ def num_channels(self):
+ return 512
+
+ def get_faces(self, image):
+ npimg = image.permute(1, 2, 0).mul(255).to(device="cpu", dtype=torch.uint8).cpu().numpy()
+ bgr = cv2.cvtColor(npimg, cv2.COLOR_RGB2BGR)
+ bboxes, kpss = self.detector.detect(bgr, max_num=self.max_faces)
+ N = len(bboxes)
+ ids = torch.zeros((N, 512), dtype=torch.float32)
+ for i in range(N):
+ face = Face(bbox=bboxes[i, :4], kps=kpss[i], det_score=bboxes[i, 4])
+ ids[i, :] = self.recognizer.get(bgr, face)
+ tbboxes = torch.tensor(bboxes[:, :4], dtype=torch.int)
+
+ ids = ids / torch.linalg.norm(ids, dim=1, keepdim=True)
+ return tbboxes, ids # returns bounding boxes (N x 4) and ID vectors (N x 512)
+
+ def __call__(self, x):
+ visual_aid = x.clone().cpu()
+ face_mtx = torch.zeros(x.size(0), 512, x.size(-2) // 32, x.size(-1) // 32)
+
+ for i in range(x.size(0)):
+ bounding_boxes, ids = self.get_faces(x[i])
+ for j in range(bounding_boxes.size(0)):
+ if np.random.rand() > self.p_drop:
+ sx, sy, ex, ey = (bounding_boxes[j] / 32).clamp(min=0).round().int().tolist()
+ ex, ey = max(ex, sx + 1), max(ey, sy + 1)
+ if bounding_boxes.size(0) == 1 and np.random.rand() < self.p_full:
+ sx, sy, ex, ey = 0, 0, x.size(-1) // 32, x.size(-2) // 32
+ face_mtx[i, :, sy:ey, sx:ex] = ids[j:j + 1, :, None, None]
+ visual_aid[i, :, int(sy * 32):int(ey * 32), int(sx * 32):int(ex * 32)] += self.id_colors[j % 13, :,
+ None, None]
+ visual_aid[i, :, int(sy * 32):int(ey * 32), int(sx * 32):int(ex * 32)] *= 0.5
+
+ return face_mtx.to(x.device), visual_aid.to(x.device)
+'''
diff --git a/modules/effnet.py b/modules/effnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..0eb2690c2547c8c7553aec8a9f9e838241f8f61c
--- /dev/null
+++ b/modules/effnet.py
@@ -0,0 +1,17 @@
+import torchvision
+from torch import nn
+
+
+# EfficientNet
+class EfficientNetEncoder(nn.Module):
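+ # EfficientNetV2-S features (1280 channels) followed by a 1x1 conv to c_latent and a
+ # non-affine BatchNorm that standardizes the latent.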
+ def __init__(self, c_latent=16):
+ super().__init__()
+ self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
+ self.mapper = nn.Sequential(
+ nn.Conv2d(1280, c_latent, kernel_size=1, bias=False),
+ nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1
+ )
+
+ def forward(self, x):
+ return self.mapper(self.backbone(x))
+
diff --git a/modules/inr_fea_res_lite.py b/modules/inr_fea_res_lite.py
new file mode 100644
index 0000000000000000000000000000000000000000..41ddfb09937f26e2c7d0193b4a65607efabde5e5
--- /dev/null
+++ b/modules/inr_fea_res_lite.py
@@ -0,0 +1,435 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import einops
+import numpy as np
+import models
+from modules.common_ckpt import Linear, Conv2d, AttnBlock, ResBlock, LayerNorm2d
+#from modules.common_ckpt import AttnBlock,
+from einops import rearrange
+import torch.fft as fft
+from modules.speed_util import checkpoint
+def batched_linear_mm(x, wb):
+ # x: (B, N, D1); wb: (B, D1 + 1, D2) or (D1 + 1, D2)
+ one = torch.ones(*x.shape[:-1], 1, device=x.device)
+ return torch.matmul(torch.cat([x, one], dim=-1), wb)
+def make_coord_grid(shape, range, device=None):
+ """
+ Args:
+ shape: tuple
+ range: [minv, maxv] or [[minv_1, maxv_1], ..., [minv_d, maxv_d]] for each dim
+ Returns:
+ grid: shape (*shape, )
+ """
+ l_lst = []
+ for i, s in enumerate(shape):
+ l = (0.5 + torch.arange(s, device=device)) / s
+ if isinstance(range[0], list) or isinstance(range[0], tuple):
+ minv, maxv = range[i]
+ else:
+ minv, maxv = range
+ l = minv + (maxv - minv) * l
+ l_lst.append(l)
+ grid = torch.meshgrid(*l_lst, indexing='ij')
+ grid = torch.stack(grid, dim=-1)
+ return grid
+def init_wb(shape):
+ weight = torch.empty(shape[1], shape[0] - 1)
+ nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
+
+ bias = torch.empty(shape[1], 1)
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weight)
+ bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+ nn.init.uniform_(bias, -bound, bound)
+
+ return torch.cat([weight, bias], dim=1).t().detach()
+
+def init_wb_rewrite(shape):
+ weight = torch.empty(shape[1], shape[0] - 1)
+
+ torch.nn.init.xavier_uniform_(weight)
+
+ bias = torch.empty(shape[1], 1)
+ torch.nn.init.xavier_uniform_(bias)
+
+
+ return torch.cat([weight, bias], dim=1).t().detach()
+class HypoMlp(nn.Module):
+
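+ # An MLP whose weights are supplied per sample via set_params(); input coordinates are
+ # optionally lifted with a sin/cos positional encoding of width in_dim * pe_dim.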
+ def __init__(self, depth, in_dim, out_dim, hidden_dim, use_pe, pe_dim, out_bias=0, pe_sigma=1024):
+ super().__init__()
+ self.use_pe = use_pe
+ self.pe_dim = pe_dim
+ self.pe_sigma = pe_sigma
+ self.depth = depth
+ self.param_shapes = dict()
+ if use_pe:
+ last_dim = in_dim * pe_dim
+ else:
+ last_dim = in_dim
+ for i in range(depth): # for each layer the weight
+ cur_dim = hidden_dim if i < depth - 1 else out_dim
+ self.param_shapes[f'wb{i}'] = (last_dim + 1, cur_dim)
+ last_dim = cur_dim
+ self.relu = nn.ReLU()
+ self.params = None
+ self.out_bias = out_bias
+
+ def set_params(self, params):
+ self.params = params
+
+ def convert_posenc(self, x):
+ w = torch.exp(torch.linspace(0, np.log(self.pe_sigma), self.pe_dim // 2, device=x.device))
+ x = torch.matmul(x.unsqueeze(-1), w.unsqueeze(0)).view(*x.shape[:-1], -1)
+ x = torch.cat([torch.cos(np.pi * x), torch.sin(np.pi * x)], dim=-1)
+ return x
+
+ def forward(self, x):
+ B, query_shape = x.shape[0], x.shape[1: -1]
+ x = x.view(B, -1, x.shape[-1])
+ if self.use_pe:
+ x = self.convert_posenc(x)
+ #print('in line 79 after pos embedding', x.shape)
+ for i in range(self.depth):
+ x = batched_linear_mm(x, self.params[f'wb{i}'])
+ if i < self.depth - 1:
+ x = self.relu(x)
+ else:
+ x = x + self.out_bias
+ x = x.view(B, *query_shape, -1)
+ return x
+
+
+
+class Attention(nn.Module):
+
+ def __init__(self, dim, n_head, head_dim, dropout=0.):
+ super().__init__()
+ self.n_head = n_head
+ inner_dim = n_head * head_dim
+ self.to_q = nn.Sequential(
+ nn.SiLU(),
+ Linear(dim, inner_dim ))
+ self.to_kv = nn.Sequential(
+ nn.SiLU(),
+ Linear(dim, inner_dim * 2))
+ self.scale = head_dim ** -0.5
+ # self.to_out = nn.Sequential(
+ # Linear(inner_dim, dim),
+ # nn.Dropout(dropout),
+ # )
+
+ def forward(self, fr, to=None):
+ if to is None:
+ to = fr
+ q = self.to_q(fr)
+ k, v = self.to_kv(to).chunk(2, dim=-1)
+ q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> b h n d', h=self.n_head), [q, k, v])
+
+ dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
+ attn = F.softmax(dots, dim=-1) # b h n n
+ out = torch.matmul(attn, v)
+ out = einops.rearrange(out, 'b h n d -> b n (h d)')
+ return out
+
+
+class FeedForward(nn.Module):
+
+ def __init__(self, dim, ff_dim, dropout=0.):
+ super().__init__()
+
+ self.net = nn.Sequential(
+ Linear(dim, ff_dim),
+ nn.GELU(),
+ #GlobalResponseNorm(ff_dim),
+ nn.Dropout(dropout),
+ Linear(ff_dim, dim)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+class PreNorm(nn.Module):
+
+ def __init__(self, dim, fn):
+ super().__init__()
+ self.norm = nn.LayerNorm(dim)
+ self.fn = fn
+
+ def forward(self, x):
+ return self.fn(self.norm(x))
+
+
+#TransInr(ind=2048, ch=256, n_head=16, head_dim=16, n_groups=64, f_dim=256, time_dim=self.c_r, t_conds = [])
+class TransformerEncoder(nn.Module):
+
+ def __init__(self, dim, depth, n_head, head_dim, ff_dim, dropout=0.):
+ super().__init__()
+ self.layers = nn.ModuleList()
+ for _ in range(depth):
+ self.layers.append(nn.ModuleList([
+ PreNorm(dim, Attention(dim, n_head, head_dim, dropout=dropout)),
+ PreNorm(dim, FeedForward(dim, ff_dim, dropout=dropout)),
+ ]))
+
+ def forward(self, x):
+ for norm_attn, norm_ff in self.layers:
+ x = x + norm_attn(x)
+ x = x + norm_ff(x)
+ return x
+class ImgrecTokenizer(nn.Module):
+
+ def __init__(self, input_size=32*32, patch_size=1, dim=768, padding=0, img_channels=16):
+ super().__init__()
+
+ if isinstance(patch_size, int):
+ patch_size = (patch_size, patch_size)
+ if isinstance(padding, int):
+ padding = (padding, padding)
+ self.patch_size = patch_size
+ self.padding = padding
+ self.prefc = nn.Linear(patch_size[0] * patch_size[1] * img_channels, dim)
+
+ self.posemb = nn.Parameter(torch.randn(input_size, dim))
+
+ def forward(self, x):
+ #print(x.shape)
+ p = self.patch_size
+ x = F.unfold(x, p, stride=p, padding=self.padding) # (B, C * p * p, L)
+ #print('in line 185 after unfoding', x.shape)
+ x = x.permute(0, 2, 1).contiguous()
+ x = self.prefc(x) + self.posemb[:x.shape[1]].unsqueeze(0)
+ return x
+
+class SpatialAttention(nn.Module):
+ def __init__(self, kernel_size=7):
+ super(SpatialAttention, self).__init__()
+
+ self.conv1 = Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, x):
+ avg_out = torch.mean(x, dim=1, keepdim=True)
+ max_out, _ = torch.max(x, dim=1, keepdim=True)
+ x = torch.cat([avg_out, max_out], dim=1)
+ x = self.conv1(x)
+ return self.sigmoid(x)
+
+class TimestepBlock_res(nn.Module):
+ def __init__(self, c, c_timestep, conds=['sca']):
+ super().__init__()
+
+ self.mapper = Linear(c_timestep, c * 2)
+ self.conds = conds
+ for cname in conds:
+ setattr(self, f"mapper_{cname}", Linear(c_timestep, c * 2))
+
+
+
+
+ def forward(self, x, t):
+ #print(x.shape, t.shape, self.conds, 'in line 269')
+ t = t.chunk(len(self.conds) + 1, dim=1)
+ a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1)
+
+ for i, c in enumerate(self.conds):
+ ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1)
+ a, b = a + ac, b + bc
+ return x * (1 + a) + b
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+
+class ScaleNormalize_res(nn.Module):
+ def __init__(self, c, scale_c, conds=['sca']):
+ super().__init__()
+ self.c_r = scale_c
+ self.mapping = TimestepBlock_res(c, scale_c, conds=conds)
+ self.t_conds = conds
+ self.alpha = nn.Conv2d(c, c, kernel_size=1)
+ self.gamma = nn.Conv2d(c, c, kernel_size=1)
+ self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
+
+
+ def gen_r_embedding(self, r, max_positions=10000):
+ r = r * max_positions
+ half_dim = self.c_r // 2
+ emb = math.log(max_positions) / (half_dim - 1)
+ emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
+ emb = r[:, None] * emb[None, :]
+ emb = torch.cat([emb.sin(), emb.cos()], dim=1)
+ if self.c_r % 2 == 1: # zero pad
+ emb = nn.functional.pad(emb, (0, 1), mode='constant')
+ return emb
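+ # Embeds the resolution-dependent scale sqrt(log_{std_size}(H*W)) as a sinusoidal vector and
+ # maps it to a per-channel scale/shift (TimestepBlock_res) that is added residually to x.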
+ def forward(self, x, std_size=24*24):
+ scale_val = math.sqrt(math.log(x.shape[-2] * x.shape[-1], std_size))
+ scale_val = torch.ones(x.shape[0]).to(x.device)*scale_val
+ scale_val_f = self.gen_r_embedding(scale_val)
+ for c in self.t_conds:
+ t_cond = torch.zeros_like(scale_val)
+ scale_val_f = torch.cat([scale_val_f, self.gen_r_embedding(t_cond)], dim=1)
+
+ f = self.mapping(x, scale_val_f)
+
+ return f + x
+
+
+class TransInr_withnorm(nn.Module):
+
+ def __init__(self, ind=2048, ch=16, n_head=12, head_dim=64, n_groups=64, f_dim=768, time_dim=2048, t_conds=[]):
+ super().__init__()
+ self.input_layer= nn.Conv2d(ind, ch, 1)
+ self.tokenizer = ImgrecTokenizer(dim=ch, img_channels=ch)
+ #self.hyponet = HypoMlp(depth=12, in_dim=2, out_dim=ch, hidden_dim=f_dim, use_pe=True, pe_dim=128)
+ #self.transformer_encoder = TransformerEncoder(dim=f_dim, depth=12, n_head=n_head, head_dim=f_dim // n_head, ff_dim=3*f_dim, )
+
+ self.hyponet = HypoMlp(depth=2, in_dim=2, out_dim=ch, hidden_dim=f_dim, use_pe=True, pe_dim=128)
+ self.transformer_encoder = TransformerEncoder(dim=f_dim, depth=1, n_head=n_head, head_dim=f_dim // n_head, ff_dim=f_dim)
+ #self.transformer_encoder = TransInr( ch=ch, n_head=16, head_dim=16, n_groups=64, f_dim=ch, time_dim=time_dim, t_conds = [])
+ self.base_params = nn.ParameterDict()
+ n_wtokens = 0
+ self.wtoken_postfc = nn.ModuleDict()
+ self.wtoken_rng = dict()
+ for name, shape in self.hyponet.param_shapes.items():
+ self.base_params[name] = nn.Parameter(init_wb(shape))
+ g = min(n_groups, shape[1])
+ assert shape[1] % g == 0
+ self.wtoken_postfc[name] = nn.Sequential(
+ nn.LayerNorm(f_dim),
+ nn.Linear(f_dim, shape[0] - 1),
+ )
+ self.wtoken_rng[name] = (n_wtokens, n_wtokens + g)
+ n_wtokens += g
+ self.wtokens = nn.Parameter(torch.randn(n_wtokens, f_dim))
+ self.output_layer= nn.Conv2d(ch, ind, 1)
+
+
+ self.mapp_t = TimestepBlock_res( ind, time_dim, conds = t_conds)
+
+
+ self.hr_norm = ScaleNormalize_res(ind, 64, conds=[])
+
+ self.normalize_final = nn.Sequential(
+ LayerNorm2d(ind, elementwise_affine=False, eps=1e-6),
+ )
+
+ self.toout = nn.Sequential(
+ Linear( ind*2, ind // 4),
+ nn.GELU(),
+ Linear( ind // 4, ind)
+ )
+ self.apply(self._init_weights)
+
+ mask = torch.zeros((1, 1, 32, 32))
+ h, w = 32, 32
+ center_h, center_w = h // 2, w // 2
+ low_freq_h, low_freq_w = h // 4, w // 4
+ mask[:, :, center_h-low_freq_h:center_h+low_freq_h, center_w-low_freq_w:center_w+low_freq_w] = 1
+ self.mask = mask
+
+
+ def _init_weights(self, m):
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
+ torch.nn.init.xavier_uniform_(m.weight)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ #nn.init.constant_(self.last.weight, 0)
+ def adain(self, feature_a, feature_b):
+ norm_mean = torch.mean(feature_a, dim=(2, 3), keepdim=True)
+ norm_std = torch.std(feature_a, dim=(2, 3), keepdim=True)
+ #feature_a = F.interpolate(feature_a, feature_b.shape[2:])
+ feature_b = (feature_b - feature_b.mean(dim=(2, 3), keepdim=True)) / (1e-8 + feature_b.std(dim=(2, 3), keepdim=True)) * norm_std + norm_mean
+ return feature_b
+ def forward(self, target_shape, target, dtokens, t_emb):
+ #print(target.shape, dtokens.shape, 'in line 290')
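+ # Pipeline: project and tokenize the low-res features, run the transformer over
+ # [image tokens, weight tokens], decode the weight tokens into per-sample hyponet weights,
+ # evaluate the hyponet on a coordinate grid at the target resolution (plus a plain upsample
+ # of the input), then fuse with target via toout and apply timestep modulation and normalization.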
+ hlr, wlr = dtokens.shape[2:]
+ original = dtokens
+
+ dtokens = self.input_layer(dtokens)
+ dtokens = self.tokenizer(dtokens)
+ B = dtokens.shape[0]
+ wtokens = einops.repeat(self.wtokens, 'n d -> b n d', b=B)
+ #print(wtokens.shape, dtokens.shape)
+ trans_out = self.transformer_encoder(torch.cat([dtokens, wtokens], dim=1))
+ trans_out = trans_out[:, -len(self.wtokens):, :]
+
+ params = dict()
+ for name, shape in self.hyponet.param_shapes.items():
+ wb = einops.repeat(self.base_params[name], 'n m -> b n m', b=B)
+ w, b = wb[:, :-1, :], wb[:, -1:, :]
+
+ l, r = self.wtoken_rng[name]
+ x = self.wtoken_postfc[name](trans_out[:, l: r, :])
+ x = x.transpose(-1, -2) # (B, shape[0] - 1, g)
+ w = F.normalize(w * x.repeat(1, 1, w.shape[2] // x.shape[2]), dim=1)
+
+ wb = torch.cat([w, b], dim=1)
+ params[name] = wb
+ coord = make_coord_grid(target_shape[2:], (-1, 1), device=dtokens.device)
+ coord = einops.repeat(coord, 'h w d -> b h w d', b=dtokens.shape[0])
+ self.hyponet.set_params(params)
+ ori_up = F.interpolate(original.float(), target_shape[2:])
+ hr_rec = self.output_layer(rearrange(self.hyponet(coord), 'b h w c -> b c h w')) + ori_up
+ #print(hr_rec.shape, target.shape, torch.cat((hr_rec, target), dim=1).permute(0, 2, 3, 1).shape, 'in line 537')
+
+ output = self.toout(torch.cat((hr_rec, target), dim=1).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+ #print(output.shape, 'in line 540')
+ #output = self.last(output.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)* 0.3
+ output = self.mapp_t(output, t_emb)
+ output = self.normalize_final(output)
+ output = self.hr_norm(output)
+ #output = self.last(output.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+ #output = self.mapp_t(output, t_emb)
+ #output = self.weight(output) * output
+
+ return output
+
+
+
+
+
+
+class LayerNorm2d(nn.LayerNorm):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, x):
+ return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+
+class GlobalResponseNorm(nn.Module):
+ "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
+ def __init__(self, dim):
+ super().__init__()
+ self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
+ self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
+
+ def forward(self, x):
+ Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
+ Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
+ return self.gamma * (x * Nx) + self.beta + x
+
+
+
+if __name__ == '__main__':
+ # Minimal smoke test with illustrative sizes only (assumes a CUDA device and the repo root on sys.path).
+ # The transformer concatenates image tokens (width ch) with weight tokens (width f_dim),
+ # so ch must equal f_dim for the concatenation to be valid.
+ trans_inr = TransInr_withnorm(ind=16, ch=768, n_head=12, n_groups=64, f_dim=768, time_dim=128, t_conds=[]).cuda()
+ target = torch.randn((1, 16, 24, 24)).cuda() # features at the target resolution
+ dtokens = torch.randn((1, 16, 16, 16)).cuda() # low-resolution features to tokenize
+ t_emb = torch.randn((1, 128)).cuda() # timestep embedding (matches time_dim)
+ output = trans_inr(target.shape, target, dtokens, t_emb)
+
+ total_up = sum(param.nelement() for param in trans_inr.parameters())
+ print(output.shape, total_up / 1e6)
+
diff --git a/modules/lora.py b/modules/lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc0a2bd797f3669a465f6c2c4255b52fe1bda7a7
--- /dev/null
+++ b/modules/lora.py
@@ -0,0 +1,71 @@
+import torch
+from torch import nn
+
+
+class LoRA(nn.Module):
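+ # Weight parametrization: adds (alpha / rank) * (lora_up @ lora_down), broadcast to the
+ # weight's shape; lora_down starts at zero so the initial delta is zero.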
+ def __init__(self, layer, name='weight', rank=16, alpha=1):
+ super().__init__()
+ weight = getattr(layer, name)
+ self.lora_down = nn.Parameter(torch.zeros((rank, weight.size(1))))
+ self.lora_up = nn.Parameter(torch.zeros((weight.size(0), rank)))
+ nn.init.normal_(self.lora_up, mean=0, std=1)
+
+ self.scale = alpha / rank
+ self.enabled = True
+
+ def forward(self, original_weights):
+ if self.enabled:
+ lora_shape = list(original_weights.shape[:2]) + [1] * (len(original_weights.shape) - 2)
+ lora_weights = torch.matmul(self.lora_up.clone(), self.lora_down.clone()).view(*lora_shape) * self.scale
+ return original_weights + lora_weights
+ else:
+ return original_weights
+
+
+def apply_lora(model, filters=None, rank=16):
+ def check_parameter(module, name):
+ return hasattr(module, name) and not torch.nn.utils.parametrize.is_parametrized(module, name) and isinstance(
+ getattr(module, name), nn.Parameter)
+
+ for name, module in model.named_modules():
+ if filters is None or any([f in name for f in filters]):
+ if check_parameter(module, "weight"):
+ device, dtype = module.weight.device, module.weight.dtype
+ torch.nn.utils.parametrize.register_parametrization(module, 'weight', LoRA(module, "weight", rank=rank).to(dtype).to(device))
+ elif check_parameter(module, "in_proj_weight"):
+ device, dtype = module.in_proj_weight.device, module.in_proj_weight.dtype
+ torch.nn.utils.parametrize.register_parametrization(module, 'in_proj_weight', LoRA(module, "in_proj_weight", rank=rank).to(dtype).to(device))
+
+
+class ReToken(nn.Module):
+ def __init__(self, indices=None):
+ super().__init__()
+ assert indices is not None
+ self.embeddings = nn.Parameter(torch.zeros(len(indices), 1280))
+ self.register_buffer('indices', torch.tensor(indices))
+ self.enabled = True
+
+ def forward(self, embeddings):
+ if self.enabled:
+ embeddings = embeddings.clone()
+ for i, idx in enumerate(self.indices):
+ embeddings[idx] += self.embeddings[i]
+ return embeddings
+
+
+def apply_retoken(module, indices=None):
+ def check_parameter(module, name):
+ return hasattr(module, name) and not torch.nn.utils.parametrize.is_parametrized(module, name) and isinstance(
+ getattr(module, name), nn.Parameter)
+
+ if check_parameter(module, "weight"):
+ device, dtype = module.weight.device, module.weight.dtype
+ torch.nn.utils.parametrize.register_parametrization(module, 'weight', ReToken(indices=indices).to(dtype).to(device))
+
+
+def remove_lora(model, leave_parametrized=True):
+ for module in model.modules():
+ if torch.nn.utils.parametrize.is_parametrized(module, "weight"):
+ nn.utils.parametrize.remove_parametrizations(module, "weight", leave_parametrized=leave_parametrized)
+ elif torch.nn.utils.parametrize.is_parametrized(module, "in_proj_weight"):
+ nn.utils.parametrize.remove_parametrizations(module, "in_proj_weight", leave_parametrized=leave_parametrized)
diff --git a/modules/model_4stage_lite.py b/modules/model_4stage_lite.py
new file mode 100644
index 0000000000000000000000000000000000000000..e77cc5d73ccda882774f447f5a8bb86fe71fe755
--- /dev/null
+++ b/modules/model_4stage_lite.py
@@ -0,0 +1,458 @@
+import torch
+from torch import nn
+import numpy as np
+import math
+from modules.common_ckpt import AttnBlock, LayerNorm2d, ResBlock, FeedForwardBlock, TimestepBlock
+from .controlnet import ControlNetDeliverer
+import torch.nn.functional as F
+from modules.inr_fea_res_lite import TransInr_withnorm as TransInr
+from modules.inr_fea_res_lite import ScaleNormalize_res
+from einops import rearrange
+import torch.fft as fft
+import random
+class UpDownBlock2d(nn.Module):
+ def __init__(self, c_in, c_out, mode, enabled=True):
+ super().__init__()
+ assert mode in ['up', 'down']
+ interpolation = nn.Upsample(scale_factor=2 if mode == 'up' else 0.5, mode='bilinear',
+ align_corners=True) if enabled else nn.Identity()
+ mapping = nn.Conv2d(c_in, c_out, kernel_size=1)
+ self.blocks = nn.ModuleList([interpolation, mapping] if mode == 'up' else [mapping, interpolation])
+
+ def forward(self, x):
+ for block in self.blocks:
+ x = block(x.float())
+ return x
+def ada_in(a, b):
+ mean_a = torch.mean(a, dim=(2, 3), keepdim=True)
+ std_a = torch.std(a, dim=(2, 3), keepdim=True)
+
+ mean_b = torch.mean(b, dim=(2, 3), keepdim=True)
+ std_b = torch.std(b, dim=(2, 3), keepdim=True)
+
+ return (b - mean_b) / (1e-8 + std_b) * std_a + mean_a
+def feature_dist_loss(x1, x2):
+ mu1 = torch.mean(x1, dim=(2, 3))
+ mu2 = torch.mean(x2, dim=(2, 3))
+
+ std1 = torch.std(x1, dim=(2, 3))
+ std2 = torch.std(x2, dim=(2, 3))
+ std_loss = torch.mean(torch.abs(torch.log(std1+ 1e-8) - torch.log(std2+ 1e-8)))
+ mean_loss = torch.mean(torch.abs(mu1 - mu2))
+ #print('in line 36', std_loss, mean_loss)
+ return std_loss + mean_loss*0.1
+class StageC(nn.Module):
+ def __init__(self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32],
+ blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=['CTA', 'CTA'],
+ c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3,
+ dropout=[0.1, 0.1], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False],
+ lr_h=24, lr_w=24):
+ super().__init__()
+
+ self.lr_h, self.lr_w = lr_h, lr_w
+ self.block_repeat = block_repeat
+ self.c_in = c_in
+ self.c_cond = c_cond
+ self.patch_size = patch_size
+ self.c_hidden = c_hidden
+ self.nhead = nhead
+ self.blocks = blocks
+ self.level_config = level_config
+ self.kernel_size = kernel_size
+ self.c_r = c_r
+ self.t_conds = t_conds
+ self.c_clip_seq = c_clip_seq
+ if not isinstance(dropout, list):
+ dropout = [dropout] * len(c_hidden)
+ if not isinstance(self_attn, list):
+ self_attn = [self_attn] * len(c_hidden)
+ self.self_attn = self_attn
+ self.dropout = dropout
+ self.switch_level = switch_level
+ # CONDITIONING
+ self.clip_txt_mapper = nn.Linear(c_clip_text, c_cond)
+ self.clip_txt_pooled_mapper = nn.Linear(c_clip_text_pooled, c_cond * c_clip_seq)
+ self.clip_img_mapper = nn.Linear(c_clip_img, c_cond * c_clip_seq)
+ self.clip_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6)
+
+ self.embedding = nn.Sequential(
+ nn.PixelUnshuffle(patch_size),
+ nn.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1),
+ LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6)
+ )
+
+ def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
+ if block_type == 'C':
+ return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout)
+ elif block_type == 'A':
+ return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout)
+ elif block_type == 'F':
+ return FeedForwardBlock(c_hidden, dropout=dropout)
+ elif block_type == 'T':
+ return TimestepBlock(c_hidden, c_r, conds=t_conds)
+ else:
+ raise Exception(f'Block type {block_type} not supported')
+
+ # BLOCKS
+ # -- down blocks
+ self.down_blocks = nn.ModuleList()
+ self.down_downscalers = nn.ModuleList()
+ self.down_repeat_mappers = nn.ModuleList()
+ for i in range(len(c_hidden)):
+ if i > 0:
+ self.down_downscalers.append(nn.Sequential(
+ LayerNorm2d(c_hidden[i - 1], elementwise_affine=False, eps=1e-6),
+ UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode='down', enabled=switch_level[i - 1])
+ ))
+ else:
+ self.down_downscalers.append(nn.Identity())
+ down_block = nn.ModuleList()
+ for _ in range(blocks[0][i]):
+ for block_type in level_config[i]:
+ block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
+ down_block.append(block)
+ self.down_blocks.append(down_block)
+ if block_repeat is not None:
+ block_repeat_mappers = nn.ModuleList()
+ for _ in range(block_repeat[0][i] - 1):
+ block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1))
+ self.down_repeat_mappers.append(block_repeat_mappers)
+
+
+
+ #extra down blocks
+
+
+ # -- up blocks
+ self.up_blocks = nn.ModuleList()
+ self.up_upscalers = nn.ModuleList()
+ self.up_repeat_mappers = nn.ModuleList()
+ for i in reversed(range(len(c_hidden))):
+ if i > 0:
+ self.up_upscalers.append(nn.Sequential(
+ LayerNorm2d(c_hidden[i], elementwise_affine=False, eps=1e-6),
+ UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode='up', enabled=switch_level[i - 1])
+ ))
+ else:
+ self.up_upscalers.append(nn.Identity())
+ up_block = nn.ModuleList()
+ for j in range(blocks[1][::-1][i]):
+ for k, block_type in enumerate(level_config[i]):
+ c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
+ block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
+ self_attn=self_attn[i])
+ up_block.append(block)
+ self.up_blocks.append(up_block)
+ if block_repeat is not None:
+ block_repeat_mappers = nn.ModuleList()
+ for _ in range(block_repeat[1][::-1][i] - 1):
+ block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1))
+ self.up_repeat_mappers.append(block_repeat_mappers)
+
+ # OUTPUT
+ self.clf = nn.Sequential(
+ LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6),
+ nn.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1),
+ nn.PixelShuffle(patch_size),
+ )
+
+ # --- WEIGHT INIT ---
+ self.apply(self._init_weights) # General init
+ nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) # conditionings
+ nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02) # conditionings
+ nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings
+ torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
+ nn.init.constant_(self.clf[1].weight, 0) # outputs
+
+ # blocks
+ for level_block in self.down_blocks + self.up_blocks:
+ for block in level_block:
+ if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
+ block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
+ elif isinstance(block, TimestepBlock):
+ for layer in block.modules():
+ if isinstance(layer, nn.Linear):
+ nn.init.constant_(layer.weight, 0)
+
+ def _init_weights(self, m):
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
+ torch.nn.init.xavier_uniform_(m.weight)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+
+    def _init_extra_parameter(self):
+        self.agg_net = nn.ModuleList()
+        for _ in range(2):
+            self.agg_net.append(TransInr(ind=2048, ch=1024, n_head=32, head_dim=32, n_groups=64, f_dim=1024, time_dim=self.c_r, t_conds=[]))
+
+        self.agg_net_up = nn.ModuleList()
+        for _ in range(2):
+            self.agg_net_up.append(TransInr(ind=2048, ch=1024, n_head=32, head_dim=32, n_groups=64, f_dim=1024, time_dim=self.c_r, t_conds=[]))
+
+        self.norm_down_blocks = nn.ModuleList()
+        for i in range(len(self.c_hidden)):
+            norm_blocks = nn.ModuleList()
+            for j in range(self.blocks[0][i]):
+                if j % 4 == 0:
+                    norm_blocks.append(ScaleNormalize_res(self.c_hidden[0], self.c_r, conds=[]))
+            self.norm_down_blocks.append(norm_blocks)
+
+        self.norm_up_blocks = nn.ModuleList()
+        for i in reversed(range(len(self.c_hidden))):
+            norm_blocks = nn.ModuleList()
+            for j in range(self.blocks[1][::-1][i]):
+                if j % 4 == 0:
+                    norm_blocks.append(ScaleNormalize_res(self.c_hidden[0], self.c_r, conds=[]))
+            self.norm_up_blocks.append(norm_blocks)
+
+        self.agg_net.apply(self._init_weights)
+        self.agg_net_up.apply(self._init_weights)
+        self.norm_up_blocks.apply(self._init_weights)
+        self.norm_down_blocks.apply(self._init_weights)
+        for block in self.agg_net + self.agg_net_up:
+            if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
+                block.channelwise[-1].weight.data *= np.sqrt(1 / sum(self.blocks[0]))
+            elif isinstance(block, TimestepBlock):
+                for layer in block.modules():
+                    if isinstance(layer, nn.Linear):
+                        nn.init.constant_(layer.weight, 0)
+
+ def gen_r_embedding(self, r, max_positions=10000):
+ r = r * max_positions
+ half_dim = self.c_r // 2
+ emb = math.log(max_positions) / (half_dim - 1)
+ emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
+ emb = r[:, None] * emb[None, :]
+ emb = torch.cat([emb.sin(), emb.cos()], dim=1)
+ if self.c_r % 2 == 1: # zero pad
+ emb = nn.functional.pad(emb, (0, 1), mode='constant')
+ return emb
+
+ def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img):
+ clip_txt = self.clip_txt_mapper(clip_txt)
+        if len(clip_txt_pooled.shape) == 2:
+            clip_txt_pooled = clip_txt_pooled.unsqueeze(1)
+ if len(clip_img.shape) == 2:
+ clip_img = clip_img.unsqueeze(1)
+ clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1)
+ clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1)
+ clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1)
+ clip = self.clip_norm(clip)
+ return clip
+
+ def _down_encode(self, x, r_embed, clip, cnet=None, require_q=False, lr_guide=None, r_emb_lite=None, guide_weight=1):
+ level_outputs = []
+ if require_q:
+ qs = []
+ block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
+ for stage_cnt, (down_block, downscaler, repmap) in enumerate(block_group):
+ x = downscaler(x)
+ for i in range(len(repmap) + 1):
+ for inner_cnt, block in enumerate(down_block):
+
+
+ if isinstance(block, ResBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ ResBlock)):
+ if cnet is not None and lr_guide is None:
+ #if cnet is not None :
+ next_cnet = cnet()
+ if next_cnet is not None:
+
+ x = x + nn.functional.interpolate(next_cnet.float(), size=x.shape[-2:], mode='bilinear',
+ align_corners=True)
+ x = block(x)
+ elif isinstance(block, AttnBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ AttnBlock)):
+
+                        x = block(x, clip)
+                        if require_q and inner_cnt == 2:
+                            qs.append(x.clone())
+                        if lr_guide is not None and inner_cnt == 2:
+                            guide = self.agg_net[stage_cnt](x.shape, x, lr_guide[stage_cnt], r_emb_lite)
+                            x = x + guide
+
+ elif isinstance(block, TimestepBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ TimestepBlock)):
+ x = block(x, r_embed)
+ else:
+ x = block(x)
+ if i < len(repmap):
+ x = repmap[i](x)
+ level_outputs.insert(0, x) # 0 indicate last output
+ if require_q:
+ return level_outputs, qs
+ return level_outputs
+
+
+ def _up_decode(self, level_outputs, r_embed, clip, cnet=None, require_ff=False, agg_f=None, r_emb_lite=None, guide_weight=1):
+ if require_ff:
+ agg_feas = []
+ x = level_outputs[0]
+ block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
+ for i, (up_block, upscaler, repmap) in enumerate(block_group):
+ for j in range(len(repmap) + 1):
+ for k, block in enumerate(up_block):
+
+ if isinstance(block, ResBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ ResBlock)):
+ skip = level_outputs[i] if k == 0 and i > 0 else None
+
+
+ if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
+ x = torch.nn.functional.interpolate(x.float(), skip.shape[-2:], mode='bilinear',
+ align_corners=True)
+
+ if cnet is not None and agg_f is None:
+ next_cnet = cnet()
+ if next_cnet is not None:
+
+ x = x + nn.functional.interpolate(next_cnet.float(), size=x.shape[-2:], mode='bilinear',
+ align_corners=True)
+
+
+ x = block(x, skip)
+ elif isinstance(block, AttnBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ AttnBlock)):
+
+
+                        x = block(x, clip)
+                        if require_ff and k == 2:
+                            agg_feas.append(x.clone())
+                        if agg_f is not None and k == 2:
+                            guide = self.agg_net_up[i](x.shape, x, agg_f[i], r_emb_lite)  # training 1 test 4k 0.8 2k 0.7
+                            if not self.training:
+                                hw = x.shape[-2] * x.shape[-1]
+                                if hw >= 96 * 96:
+                                    guide = 0.7 * guide
+                                elif hw >= 72 * 72:
+                                    guide = 0.5 * guide
+                                else:
+                                    guide = 0.3 * guide
+                            x = x + guide
+
+
+ elif isinstance(block, TimestepBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ TimestepBlock)):
+ x = block(x, r_embed)
+ #if require_ff:
+ # agg_feas.append(x.clone())
+ else:
+ x = block(x)
+ if j < len(repmap):
+ x = repmap[j](x)
+ x = upscaler(x)
+
+
+ if require_ff:
+ return x, agg_feas
+
+ return x
+
+
+
+
+ def forward(self, x, r, clip_text, clip_text_pooled, clip_img, lr_guide=None, reuire_f=False, cnet=None, require_t=False, guide_weight=0.5, **kwargs):
+
+ r_embed = self.gen_r_embedding(r)
+
+ for c in self.t_conds:
+ t_cond = kwargs.get(c, torch.zeros_like(r))
+ r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond)], dim=1)
+ clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img)
+
+ # Model Blocks
+
+ x = self.embedding(x)
+
+
+
+ if cnet is not None:
+ cnet = ControlNetDeliverer(cnet)
+
+ if not reuire_f:
+ level_outputs = self._down_encode(x, r_embed, clip, cnet, lr_guide= lr_guide[0] if lr_guide is not None else None, \
+ require_q=reuire_f, r_emb_lite=self.gen_r_embedding(r), guide_weight=guide_weight)
+ x = self._up_decode(level_outputs, r_embed, clip, cnet, agg_f=lr_guide[1] if lr_guide is not None else None, \
+ require_ff=reuire_f, r_emb_lite=self.gen_r_embedding(r), guide_weight=guide_weight)
+ else:
+ level_outputs, lr_enc = self._down_encode(x, r_embed, clip, cnet, lr_guide= lr_guide[0] if lr_guide is not None else None, require_q=True)
+ x, lr_dec = self._up_decode(level_outputs, r_embed, clip, cnet, agg_f=lr_guide[1] if lr_guide is not None else None, require_ff=True)
+
+ if reuire_f and require_t:
+ return self.clf(x), r_embed, lr_enc, lr_dec
+ if reuire_f:
+ return self.clf(x), lr_enc, lr_dec
+ if require_t:
+ return self.clf(x), r_embed
+ return self.clf(x)
+
+
+ def update_weights_ema(self, src_model, beta=0.999):
+ for self_params, src_params in zip(self.parameters(), src_model.parameters()):
+ self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
+ for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
+ self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
+
+
+
+if __name__ == '__main__':
+    generator = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]])
+    total_ori = sum(param.nelement() for param in generator.parameters())
+    generator._init_extra_parameter()
+    generator = generator.cuda()
+    total = sum(param.nelement() for param in generator.parameters())
+    total_down = sum(param.nelement() for param in generator.down_blocks.parameters())
+    total_up = sum(param.nelement() for param in generator.up_blocks.parameters())
+    print(total_ori / 1e6, total / 1e6, total_up / 1e6, total_down / 1e6)
+
+    output = generator(
+        x=torch.randn(1, 16, 24, 24).cuda(),
+        r=torch.tensor([0.7056]).cuda(),
+        clip_text=torch.randn(1, 77, 1280).cuda(),
+        clip_text_pooled=torch.randn(1, 1, 1280).cuda(),
+        clip_img=torch.randn(1, 1, 768).cuda()
+    )
+    print(output.shape)
diff --git a/modules/previewer.py b/modules/previewer.py
new file mode 100644
index 0000000000000000000000000000000000000000..51ab24292d8ac0da8d24b17d8fc0ac9e1419a3d7
--- /dev/null
+++ b/modules/previewer.py
@@ -0,0 +1,45 @@
+from torch import nn
+
+
+# Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192
+class Previewer(nn.Module):
+ def __init__(self, c_in=16, c_hidden=512, c_out=3):
+ super().__init__()
+ self.blocks = nn.Sequential(
+ nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels
+ nn.GELU(),
+ nn.BatchNorm2d(c_hidden),
+
+ nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
+ nn.GELU(),
+ nn.BatchNorm2d(c_hidden),
+
+ nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32
+ nn.GELU(),
+ nn.BatchNorm2d(c_hidden // 2),
+
+ nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
+ nn.GELU(),
+ nn.BatchNorm2d(c_hidden // 2),
+
+ nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64
+ nn.GELU(),
+ nn.BatchNorm2d(c_hidden // 4),
+
+ nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
+ nn.GELU(),
+ nn.BatchNorm2d(c_hidden // 4),
+
+ nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128
+ nn.GELU(),
+ nn.BatchNorm2d(c_hidden // 4),
+
+ nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
+ nn.GELU(),
+ nn.BatchNorm2d(c_hidden // 4),
+
+ nn.Conv2d(c_hidden // 4, c_out, kernel_size=1),
+ )
+
+ def forward(self, x):
+ return self.blocks(x)
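+
+
+# Editor's sketch: a quick shape check (assumes a CPU run is acceptable; eval() so the
+# BatchNorm layers fall back to their default running statistics).
+if __name__ == '__main__':
+    import torch
+    previewer = Previewer().eval()
+    latents = torch.randn(1, 16, 24, 24)      # a Stage C latent
+    with torch.no_grad():
+        rgb = previewer(latents)
+    print(rgb.shape)                           # torch.Size([1, 3, 192, 192])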
diff --git a/modules/resnet.py b/modules/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..460a808942be147d76b8b1f3baf29fec1e2a7b8d
--- /dev/null
+++ b/modules/resnet.py
@@ -0,0 +1,415 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+import numpy as np  # used by ResNet.__init__ (np.prod over per-block strides)
+#import fvcore.nn.weight_init as weight_init
+
+"""
+Functions for building the BottleneckBlock from Detectron2.
+# https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/resnet.py
+"""
+
+def get_norm(norm, out_channels, num_norm_groups=32):
+ """
+ Args:
+ norm (str or callable): either one of BN, SyncBN, FrozenBN, GN;
+ or a callable that takes a channel number and returns
+ the normalization layer as a nn.Module.
+ Returns:
+ nn.Module or None: the normalization layer
+ """
+ if norm is None:
+ return None
+ if isinstance(norm, str):
+ if len(norm) == 0:
+ return None
+ norm = {
+ "GN": lambda channels: nn.GroupNorm(num_norm_groups, channels),
+ }[norm]
+ return norm(out_channels)
+
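+# For example (editor's note), get_norm("GN", 256) returns nn.GroupNorm(32, 256), while
+# get_norm(None, 256) and get_norm("", 256) both return None, i.e. "no normalization".
+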
+class Conv2d(nn.Conv2d):
+ """
+ A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features.
+ """
+
+ def __init__(self, *args, **kwargs):
+ """
+ Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:
+ Args:
+ norm (nn.Module, optional): a normalization layer
+ activation (callable(Tensor) -> Tensor): a callable activation function
+ It assumes that norm layer is used before activation.
+ """
+ norm = kwargs.pop("norm", None)
+ activation = kwargs.pop("activation", None)
+ super().__init__(*args, **kwargs)
+
+ self.norm = norm
+ self.activation = activation
+
+ def forward(self, x):
+ x = F.conv2d(
+ x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
+ )
+ if self.norm is not None:
+ x = self.norm(x)
+ if self.activation is not None:
+ x = self.activation(x)
+ return x
+
+class CNNBlockBase(nn.Module):
+ """
+ A CNN block is assumed to have input channels, output channels and a stride.
+ The input and output of `forward()` method must be NCHW tensors.
+ The method can perform arbitrary computation but must match the given
+ channels and stride specification.
+ Attribute:
+ in_channels (int):
+ out_channels (int):
+ stride (int):
+ """
+
+ def __init__(self, in_channels, out_channels, stride):
+ """
+ The `__init__` method of any subclass should also contain these arguments.
+ Args:
+ in_channels (int):
+ out_channels (int):
+ stride (int):
+ """
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.stride = stride
+
+class BottleneckBlock(CNNBlockBase):
+ """
+ The standard bottleneck residual block used by ResNet-50, 101 and 152
+ defined in :paper:`ResNet`. It contains 3 conv layers with kernels
+ 1x1, 3x3, 1x1, and a projection shortcut if needed.
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ *,
+ bottleneck_channels,
+ stride=1,
+ num_groups=1,
+ norm="GN",
+ stride_in_1x1=False,
+ dilation=1,
+ num_norm_groups=32
+ ):
+ """
+ Args:
+ bottleneck_channels (int): number of output channels for the 3x3
+ "bottleneck" conv layers.
+ num_groups (int): number of groups for the 3x3 conv layer.
+ norm (str or callable): normalization for all conv layers.
+ See :func:`layers.get_norm` for supported format.
+ stride_in_1x1 (bool): when stride>1, whether to put stride in the
+ first 1x1 convolution or the bottleneck 3x3 convolution.
+ dilation (int): the dilation rate of the 3x3 conv layer.
+ """
+ super().__init__(in_channels, out_channels, stride)
+
+ if in_channels != out_channels:
+ self.shortcut = Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=stride,
+ bias=False,
+ norm=get_norm(norm, out_channels, num_norm_groups),
+ )
+ else:
+ self.shortcut = None
+
+ # The original MSRA ResNet models have stride in the first 1x1 conv
+ # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have
+ # stride in the 3x3 conv
+ stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)
+
+ self.conv1 = Conv2d(
+ in_channels,
+ bottleneck_channels,
+ kernel_size=1,
+ stride=stride_1x1,
+ bias=False,
+ norm=get_norm(norm, bottleneck_channels, num_norm_groups),
+ )
+
+ self.conv2 = Conv2d(
+ bottleneck_channels,
+ bottleneck_channels,
+ kernel_size=3,
+ stride=stride_3x3,
+ padding=1 * dilation,
+ bias=False,
+ groups=num_groups,
+ dilation=dilation,
+ norm=get_norm(norm, bottleneck_channels, num_norm_groups),
+ )
+
+ self.conv3 = Conv2d(
+ bottleneck_channels,
+ out_channels,
+ kernel_size=1,
+ bias=False,
+ norm=get_norm(norm, out_channels, num_norm_groups),
+ )
+
+ #for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]:
+ # if layer is not None: # shortcut can be None
+ # weight_init.c2_msra_fill(layer)
+
+ # Zero-initialize the last normalization in each residual branch,
+ # so that at the beginning, the residual branch starts with zeros,
+ # and each residual block behaves like an identity.
+ # See Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
+ # "For BN layers, the learnable scaling coefficient ¦Ã is initialized
+ # to be 1, except for each residual block's last BN
+ # where ¦Ã is initialized to be 0."
+
+ # nn.init.constant_(self.conv3.norm.weight, 0)
+ # TODO this somehow hurts performance when training GN models from scratch.
+ # Add it as an option when we need to use this code to train a backbone.
+
+ def forward(self, x):
+ out = self.conv1(x)
+ out = F.relu_(out)
+
+ out = self.conv2(out)
+ out = F.relu_(out)
+
+ out = self.conv3(out)
+
+ if self.shortcut is not None:
+ shortcut = self.shortcut(x)
+ else:
+ shortcut = x
+
+ out += shortcut
+ out = F.relu_(out)
+ return out
+
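+# Editor's illustration: BottleneckBlock(256, 256, bottleneck_channels=64) builds the usual
+# 1x1 -> 3x3 -> 1x1 residual block with GroupNorm(32, c) after each conv; because
+# in_channels == out_channels and stride == 1, the spatial size is preserved and no
+# projection shortcut is created.
+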
+class ResNet(nn.Module):
+ """
+ Implement :paper:`ResNet`.
+ """
+
+ def __init__(self, stem, stages, num_classes=None, out_features=None, freeze_at=0):
+ """
+ Args:
+ stem (nn.Module): a stem module
+ stages (list[list[CNNBlockBase]]): several (typically 4) stages,
+ each contains multiple :class:`CNNBlockBase`.
+ num_classes (None or int): if None, will not perform classification.
+ Otherwise, will create a linear layer.
+ out_features (list[str]): name of the layers whose outputs should
+ be returned in forward. Can be anything in "stem", "linear", or "res2" ...
+ If None, will return the output of the last layer.
+ freeze_at (int): The number of stages at the beginning to freeze.
+ see :meth:`freeze` for detailed explanation.
+ """
+ super().__init__()
+ self.stem = stem
+ self.num_classes = num_classes
+
+ current_stride = self.stem.stride
+ self._out_feature_strides = {"stem": current_stride}
+ self._out_feature_channels = {"stem": self.stem.out_channels}
+
+ self.stage_names, self.stages = [], []
+
+ if out_features is not None:
+ # Avoid keeping unused layers in this module. They consume extra memory
+ # and may cause allreduce to fail
+ num_stages = max(
+ [{"res2": 1, "res3": 2, "res4": 3, "res5": 4}.get(f, 0) for f in out_features]
+ )
+ stages = stages[:num_stages]
+ for i, blocks in enumerate(stages):
+ assert len(blocks) > 0, len(blocks)
+ for block in blocks:
+ assert isinstance(block, CNNBlockBase), block
+
+ name = "res" + str(i + 2)
+ stage = nn.Sequential(*blocks)
+
+ self.add_module(name, stage)
+ self.stage_names.append(name)
+ self.stages.append(stage)
+
+ self._out_feature_strides[name] = current_stride = int(
+ current_stride * np.prod([k.stride for k in blocks])
+ )
+ self._out_feature_channels[name] = curr_channels = blocks[-1].out_channels
+ self.stage_names = tuple(self.stage_names) # Make it static for scripting
+
+ if num_classes is not None:
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+ self.linear = nn.Linear(curr_channels, num_classes)
+
+ # Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
+ # "The 1000-way fully-connected layer is initialized by
+ # drawing weights from a zero-mean Gaussian with standard deviation of 0.01."
+ nn.init.normal_(self.linear.weight, std=0.01)
+ name = "linear"
+
+ if out_features is None:
+ out_features = [name]
+ self._out_features = out_features
+ assert len(self._out_features)
+ children = [x[0] for x in self.named_children()]
+ for out_feature in self._out_features:
+ assert out_feature in children, "Available children: {}".format(", ".join(children))
+ self.freeze(freeze_at)
+
+ def forward(self, x):
+ """
+ Args:
+ x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
+ Returns:
+ dict[str->Tensor]: names and the corresponding features
+ """
+ assert x.dim() == 4, f"ResNet takes an input of shape (N, C, H, W). Got {x.shape} instead!"
+ outputs = {}
+ x = self.stem(x)
+ if "stem" in self._out_features:
+ outputs["stem"] = x
+ for name, stage in zip(self.stage_names, self.stages):
+ x = stage(x)
+ if name in self._out_features:
+ outputs[name] = x
+ if self.num_classes is not None:
+ x = self.avgpool(x)
+ x = torch.flatten(x, 1)
+ x = self.linear(x)
+ if "linear" in self._out_features:
+ outputs["linear"] = x
+ return outputs
+
+ def freeze(self, freeze_at=0):
+ """
+ Freeze the first several stages of the ResNet. Commonly used in
+ fine-tuning.
+ Layers that produce the same feature map spatial size are defined as one
+ "stage" by :paper:`FPN`.
+ Args:
+ freeze_at (int): number of stages to freeze.
+ `1` means freezing the stem. `2` means freezing the stem and
+ one residual stage, etc.
+ Returns:
+ nn.Module: this ResNet itself
+ """
+ if freeze_at >= 1:
+ self.stem.freeze()
+ for idx, stage in enumerate(self.stages, start=2):
+ if freeze_at >= idx:
+ for block in stage.children():
+ block.freeze()
+ return self
+
+ @staticmethod
+ def make_stage(block_class, num_blocks, *, in_channels, out_channels, **kwargs):
+ """
+ Create a list of blocks of the same type that forms one ResNet stage.
+ Args:
+ block_class (type): a subclass of CNNBlockBase that's used to create all blocks in this
+ stage. A module of this type must not change spatial resolution of inputs unless its
+ stride != 1.
+ num_blocks (int): number of blocks in this stage
+ in_channels (int): input channels of the entire stage.
+ out_channels (int): output channels of **every block** in the stage.
+ kwargs: other arguments passed to the constructor of
+ `block_class`. If the argument name is "xx_per_block", the
+ argument is a list of values to be passed to each block in the
+ stage. Otherwise, the same argument is passed to every block
+ in the stage.
+ Returns:
+ list[CNNBlockBase]: a list of block module.
+ Examples:
+ ::
+ stage = ResNet.make_stage(
+ BottleneckBlock, 3, in_channels=16, out_channels=64,
+ bottleneck_channels=16, num_groups=1,
+ stride_per_block=[2, 1, 1],
+ dilations_per_block=[1, 1, 2]
+ )
+ Usually, layers that produce the same feature map spatial size are defined as one
+ "stage" (in :paper:`FPN`). Under such definition, ``stride_per_block[1:]`` should
+ all be 1.
+ """
+ blocks = []
+ for i in range(num_blocks):
+ curr_kwargs = {}
+ for k, v in kwargs.items():
+ if k.endswith("_per_block"):
+ assert len(v) == num_blocks, (
+ f"Argument '{k}' of make_stage should have the "
+ f"same length as num_blocks={num_blocks}."
+ )
+ newk = k[: -len("_per_block")]
+ assert newk not in kwargs, f"Cannot call make_stage with both {k} and {newk}!"
+ curr_kwargs[newk] = v[i]
+ else:
+ curr_kwargs[k] = v
+
+ blocks.append(
+ block_class(in_channels=in_channels, out_channels=out_channels, **curr_kwargs)
+ )
+ in_channels = out_channels
+ return blocks
+
+ @staticmethod
+ def make_default_stages(depth, block_class=None, **kwargs):
+ """
+ Created list of ResNet stages from pre-defined depth (one of 18, 34, 50, 101, 152).
+ If it doesn't create the ResNet variant you need, please use :meth:`make_stage`
+ instead for fine-grained customization.
+ Args:
+ depth (int): depth of ResNet
+ block_class (type): the CNN block class. Has to accept
+ `bottleneck_channels` argument for depth > 50.
+ By default it is BasicBlock or BottleneckBlock, based on the
+ depth.
+ kwargs:
+ other arguments to pass to `make_stage`. Should not contain
+ stride and channels, as they are predefined for each depth.
+ Returns:
+ list[list[CNNBlockBase]]: modules in all stages; see arguments of
+ :class:`ResNet.__init__`.
+ """
+ num_blocks_per_stage = {
+ 18: [2, 2, 2, 2],
+ 34: [3, 4, 6, 3],
+ 50: [3, 4, 6, 3],
+ 101: [3, 4, 23, 3],
+ 152: [3, 8, 36, 3],
+ }[depth]
+        if block_class is None:
+            # Note: BasicBlock is not defined in this module, so only depths >= 50
+            # (i.e. BottleneckBlock) can be built through this default path.
+            block_class = BasicBlock if depth < 50 else BottleneckBlock
+ if depth < 50:
+ in_channels = [64, 64, 128, 256]
+ out_channels = [64, 128, 256, 512]
+ else:
+ in_channels = [64, 256, 512, 1024]
+ out_channels = [256, 512, 1024, 2048]
+ ret = []
+ for (n, s, i, o) in zip(num_blocks_per_stage, [1, 2, 2, 2], in_channels, out_channels):
+ if depth >= 50:
+ kwargs["bottleneck_channels"] = o // 4
+ ret.append(
+ ResNet.make_stage(
+ block_class=block_class,
+ num_blocks=n,
+ stride_per_block=[s] + [1] * (n - 1),
+ in_channels=i,
+ out_channels=o,
+ **kwargs,
+ )
+ )
+ return ret
\ No newline at end of file
diff --git a/modules/speed_util.py b/modules/speed_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b9507c74833bec270b00bd252a3c76fcc09fab3
--- /dev/null
+++ b/modules/speed_util.py
@@ -0,0 +1,55 @@
+import os
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import repeat
+class CheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+ ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
+ "dtype": torch.get_autocast_gpu_dtype(),
+ "cache_enabled": torch.is_autocast_cache_enabled()}
+ with torch.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with torch.enable_grad(), \
+ torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
+
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
\ No newline at end of file
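+
+
+# Editor's sketch: exercising checkpoint() on a tiny function. With flag=True the forward
+# pass runs under no_grad and activations are recomputed during backward; with flag=False
+# it reduces to a plain call.
+if __name__ == '__main__':
+    lin = nn.Linear(4, 4)
+    x = torch.randn(2, 4, requires_grad=True)
+    y = checkpoint(lambda t: lin(t).relu(), (x,), tuple(lin.parameters()), flag=True)
+    y.sum().backward()
+    print(x.grad.shape, lin.weight.grad.shape)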
diff --git a/modules/stage_a.py b/modules/stage_a.py
new file mode 100644
index 0000000000000000000000000000000000000000..2840ef71d30e3da74954ab4a05e724fd7fef86cf
--- /dev/null
+++ b/modules/stage_a.py
@@ -0,0 +1,183 @@
+import torch
+from torch import nn
+from torchtools.nn import VectorQuantize
+from einops import rearrange
+import torch.nn.functional as F
+import math
+class ResBlock(nn.Module):
+ def __init__(self, c, c_hidden):
+ super().__init__()
+ # depthwise/attention
+ self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
+ self.depthwise = nn.Sequential(
+ nn.ReplicationPad2d(1),
+ nn.Conv2d(c, c, kernel_size=3, groups=c)
+ )
+
+ # channelwise
+ self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
+ self.channelwise = nn.Sequential(
+ nn.Linear(c, c_hidden),
+ nn.GELU(),
+ nn.Linear(c_hidden, c),
+ )
+
+ self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
+
+ # Init weights
+ def _basic_init(module):
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+
+ self.apply(_basic_init)
+
+ def _norm(self, x, norm):
+ return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+
+ def forward(self, x):
+
+ mods = self.gammas
+
+ x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1]
+
+ #x = x.to(torch.float64)
+ x = x + self.depthwise(x_temp) * mods[2]
+
+ x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4]
+ x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5]
+
+ return x
+
+
+def extract_patches(tensor, patch_size, stride):
+ b, c, H, W = tensor.shape
+ pad_h = (patch_size - (H - patch_size) % stride) % stride
+ pad_w = (patch_size - (W - patch_size) % stride) % stride
+ tensor = F.pad(tensor, (0, pad_w, 0, pad_h), mode='reflect')
+
+
+ patches = tensor.unfold(2, patch_size, stride).unfold(3, patch_size, stride)
+ patches = patches.contiguous().view(b, c, -1, patch_size, patch_size)
+ patches = patches.permute(0, 2, 1, 3, 4)
+ return patches, (H, W)
+
+def fuse_patches(patches, patch_size, stride, H, W):
+
+ b, num_patches, c, _, _ = patches.shape
+ patches = patches.permute(0, 2, 1, 3, 4)
+
+
+
+ pad_h = (patch_size - (H - patch_size) % stride) % stride
+ pad_w = (patch_size - (W - patch_size) % stride) % stride
+ out_h = H + pad_h
+ out_w = W + pad_w
+ patches = patches.contiguous().view(b, c , -1, patch_size*patch_size ).permute(0, 1, 3, 2)
+ patches = patches.contiguous().view(b, c*patch_size*patch_size, -1)
+
+ tensor = F.fold(patches, output_size=(out_h, out_w), kernel_size=patch_size, stride=stride)
+ overlap_cnt = F.fold(torch.ones_like(patches), output_size=(out_h, out_w), kernel_size=patch_size, stride=stride)
+ tensor = tensor / overlap_cnt
+    # print('end fuse patch', tensor.shape, tensor.dtype)  # debug output, disabled
+ return tensor[:, :, :H, :W]
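+
+# Editor's note, an illustrative round trip: for x = torch.randn(1, 3, 100, 100),
+# extract_patches(x, patch_size=64, stride=32) reflect-pads to 128x128 and returns
+# patches of shape (1, 9, 3, 64, 64) plus the original (H, W); fuse_patches(patches, 64, 32,
+# 100, 100) folds them back, averaging the overlapping regions, and crops to (1, 3, 100, 100).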
+
+
+
+class StageA(nn.Module):
+ def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192,
+ scale_factor=0.43): # 0.3764
+ super().__init__()
+ self.c_latent = c_latent
+ self.scale_factor = scale_factor
+ c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))]
+
+ # Encoder blocks
+ self.in_block = nn.Sequential(
+ nn.PixelUnshuffle(2),
+ nn.Conv2d(3 * 4, c_levels[0], kernel_size=1)
+ )
+ down_blocks = []
+ for i in range(levels):
+ if i > 0:
+ down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
+ block = ResBlock(c_levels[i], c_levels[i] * 4)
+ down_blocks.append(block)
+ down_blocks.append(nn.Sequential(
+ nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
+ nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1
+ ))
+ self.down_blocks = nn.Sequential(*down_blocks)
+ self.down_blocks[0]
+
+ self.codebook_size = codebook_size
+ self.vquantizer = VectorQuantize(c_latent, k=codebook_size)
+
+ # Decoder blocks
+ up_blocks = [nn.Sequential(
+ nn.Conv2d(c_latent, c_levels[-1], kernel_size=1)
+ )]
+ for i in range(levels):
+ for j in range(bottleneck_blocks if i == 0 else 1):
+ block = ResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4)
+ up_blocks.append(block)
+ if i < levels - 1:
+ up_blocks.append(
+ nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
+ padding=1))
+ self.up_blocks = nn.Sequential(*up_blocks)
+ self.out_block = nn.Sequential(
+ nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
+ nn.PixelShuffle(2),
+ )
+
+ def encode(self, x, quantize=False):
+ x = self.in_block(x)
+ x = self.down_blocks(x)
+ if quantize:
+ qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1)
+ return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25
+ else:
+ return x / self.scale_factor, None, None, None
+
+
+
+ def decode(self, x, tiled_decoding=False):
+ x = x * self.scale_factor
+ x = self.up_blocks(x)
+ x = self.out_block(x)
+ return x
+
+ def forward(self, x, quantize=False):
+ qe, x, _, vq_loss = self.encode(x, quantize)
+ x = self.decode(qe)
+ return x, vq_loss
+
+
+class Discriminator(nn.Module):
+ def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6):
+ super().__init__()
+ d = max(depth - 3, 3)
+ layers = [
+ nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
+ nn.LeakyReLU(0.2),
+ ]
+ for i in range(depth - 1):
+ c_in = c_hidden // (2 ** max((d - i), 0))
+ c_out = c_hidden // (2 ** max((d - 1 - i), 0))
+ layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
+ layers.append(nn.InstanceNorm2d(c_out))
+ layers.append(nn.LeakyReLU(0.2))
+ self.encoder = nn.Sequential(*layers)
+ self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
+ self.logits = nn.Sigmoid()
+
+ def forward(self, x, cond=None):
+ x = self.encoder(x)
+ if cond is not None:
+ cond = cond.view(cond.size(0), cond.size(1), 1, 1, ).expand(-1, -1, x.size(-2), x.size(-1))
+ x = torch.cat([x, cond], dim=1)
+ x = self.shuffle(x)
+ x = self.logits(x)
+ return x
diff --git a/modules/stage_b.py b/modules/stage_b.py
new file mode 100644
index 0000000000000000000000000000000000000000..f89b42d61327278820e164b1c093cbf8d1048ee1
--- /dev/null
+++ b/modules/stage_b.py
@@ -0,0 +1,239 @@
+import math
+import numpy as np
+import torch
+from torch import nn
+from .common import AttnBlock, LayerNorm2d, ResBlock, FeedForwardBlock, TimestepBlock
+
+
+class StageB(nn.Module):
+ def __init__(self, c_in=4, c_out=4, c_r=64, patch_size=2, c_cond=1280, c_hidden=[320, 640, 1280, 1280],
+ nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]],
+ block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]], level_config=['CT', 'CT', 'CTA', 'CTA'], c_clip=1280,
+ c_clip_seq=4, c_effnet=16, c_pixels=3, kernel_size=3, dropout=[0, 0, 0.1, 0.1], self_attn=True,
+ t_conds=['sca']):
+ super().__init__()
+ self.c_r = c_r
+ self.t_conds = t_conds
+ self.c_clip_seq = c_clip_seq
+ if not isinstance(dropout, list):
+ dropout = [dropout] * len(c_hidden)
+ if not isinstance(self_attn, list):
+ self_attn = [self_attn] * len(c_hidden)
+
+ # CONDITIONING
+ self.effnet_mapper = nn.Sequential(
+ nn.Conv2d(c_effnet, c_hidden[0] * 4, kernel_size=1),
+ nn.GELU(),
+ nn.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1),
+ LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6)
+ )
+ self.pixels_mapper = nn.Sequential(
+ nn.Conv2d(c_pixels, c_hidden[0] * 4, kernel_size=1),
+ nn.GELU(),
+ nn.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1),
+ LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6)
+ )
+ self.clip_mapper = nn.Linear(c_clip, c_cond * c_clip_seq)
+ self.clip_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6)
+
+ self.embedding = nn.Sequential(
+ nn.PixelUnshuffle(patch_size),
+ nn.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1),
+ LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6)
+ )
+
+ def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
+ if block_type == 'C':
+ return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout)
+ elif block_type == 'A':
+ return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout)
+ elif block_type == 'F':
+ return FeedForwardBlock(c_hidden, dropout=dropout)
+ elif block_type == 'T':
+ return TimestepBlock(c_hidden, c_r, conds=t_conds)
+ else:
+ raise Exception(f'Block type {block_type} not supported')
+
+ # BLOCKS
+ # -- down blocks
+ self.down_blocks = nn.ModuleList()
+ self.down_downscalers = nn.ModuleList()
+ self.down_repeat_mappers = nn.ModuleList()
+ for i in range(len(c_hidden)):
+ if i > 0:
+ self.down_downscalers.append(nn.Sequential(
+ LayerNorm2d(c_hidden[i - 1], elementwise_affine=False, eps=1e-6),
+ nn.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2),
+ ))
+ else:
+ self.down_downscalers.append(nn.Identity())
+ down_block = nn.ModuleList()
+ for _ in range(blocks[0][i]):
+ for block_type in level_config[i]:
+ block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
+ down_block.append(block)
+ self.down_blocks.append(down_block)
+ if block_repeat is not None:
+ block_repeat_mappers = nn.ModuleList()
+ for _ in range(block_repeat[0][i] - 1):
+ block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1))
+ self.down_repeat_mappers.append(block_repeat_mappers)
+
+ # -- up blocks
+ self.up_blocks = nn.ModuleList()
+ self.up_upscalers = nn.ModuleList()
+ self.up_repeat_mappers = nn.ModuleList()
+ for i in reversed(range(len(c_hidden))):
+ if i > 0:
+ self.up_upscalers.append(nn.Sequential(
+ LayerNorm2d(c_hidden[i], elementwise_affine=False, eps=1e-6),
+ nn.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2),
+ ))
+ else:
+ self.up_upscalers.append(nn.Identity())
+ up_block = nn.ModuleList()
+ for j in range(blocks[1][::-1][i]):
+ for k, block_type in enumerate(level_config[i]):
+ c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
+ block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
+ self_attn=self_attn[i])
+ up_block.append(block)
+ self.up_blocks.append(up_block)
+ if block_repeat is not None:
+ block_repeat_mappers = nn.ModuleList()
+ for _ in range(block_repeat[1][::-1][i] - 1):
+ block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1))
+ self.up_repeat_mappers.append(block_repeat_mappers)
+
+ # OUTPUT
+ self.clf = nn.Sequential(
+ LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6),
+ nn.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1),
+ nn.PixelShuffle(patch_size),
+ )
+
+ # --- WEIGHT INIT ---
+ self.apply(self._init_weights) # General init
+ nn.init.normal_(self.clip_mapper.weight, std=0.02) # conditionings
+ nn.init.normal_(self.effnet_mapper[0].weight, std=0.02) # conditionings
+ nn.init.normal_(self.effnet_mapper[2].weight, std=0.02) # conditionings
+ nn.init.normal_(self.pixels_mapper[0].weight, std=0.02) # conditionings
+ nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings
+ torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
+ nn.init.constant_(self.clf[1].weight, 0) # outputs
+
+ # blocks
+ for level_block in self.down_blocks + self.up_blocks:
+ for block in level_block:
+ if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
+ block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
+ elif isinstance(block, TimestepBlock):
+ for layer in block.modules():
+ if isinstance(layer, nn.Linear):
+ nn.init.constant_(layer.weight, 0)
+
+ def _init_weights(self, m):
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
+ torch.nn.init.xavier_uniform_(m.weight)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def gen_r_embedding(self, r, max_positions=10000):
+ r = r * max_positions
+ half_dim = self.c_r // 2
+ emb = math.log(max_positions) / (half_dim - 1)
+ emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
+ emb = r[:, None] * emb[None, :]
+ emb = torch.cat([emb.sin(), emb.cos()], dim=1)
+ if self.c_r % 2 == 1: # zero pad
+ emb = nn.functional.pad(emb, (0, 1), mode='constant')
+ return emb
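+
+    # Editor's note: with the default c_r=64 and r of shape (B,), this returns a (B, 64)
+    # sinusoidal embedding; forward() concatenates one additional embedding per entry of
+    # t_conds, so the TimestepBlocks receive (B, 64 * (1 + len(t_conds))) features.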
+
+ def gen_c_embeddings(self, clip):
+ if len(clip.shape) == 2:
+ clip = clip.unsqueeze(1)
+ clip = self.clip_mapper(clip).view(clip.size(0), clip.size(1) * self.c_clip_seq, -1)
+ clip = self.clip_norm(clip)
+ return clip
+
+ def _down_encode(self, x, r_embed, clip):
+ level_outputs = []
+ block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
+ for down_block, downscaler, repmap in block_group:
+ x = downscaler(x)
+ for i in range(len(repmap) + 1):
+ for block in down_block:
+ if isinstance(block, ResBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ ResBlock)):
+ x = block(x)
+ elif isinstance(block, AttnBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ AttnBlock)):
+ x = block(x, clip)
+ elif isinstance(block, TimestepBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ TimestepBlock)):
+ x = block(x, r_embed)
+ else:
+ x = block(x)
+ if i < len(repmap):
+ x = repmap[i](x)
+ level_outputs.insert(0, x)
+ return level_outputs
+
+ def _up_decode(self, level_outputs, r_embed, clip):
+ x = level_outputs[0]
+ block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
+ for i, (up_block, upscaler, repmap) in enumerate(block_group):
+ for j in range(len(repmap) + 1):
+ for k, block in enumerate(up_block):
+ if isinstance(block, ResBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ ResBlock)):
+ skip = level_outputs[i] if k == 0 and i > 0 else None
+ if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
+ x = torch.nn.functional.interpolate(x.float(), skip.shape[-2:], mode='bilinear',
+ align_corners=True)
+ x = block(x, skip)
+ elif isinstance(block, AttnBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ AttnBlock)):
+ x = block(x, clip)
+ elif isinstance(block, TimestepBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ TimestepBlock)):
+ x = block(x, r_embed)
+ else:
+ x = block(x)
+ if j < len(repmap):
+ x = repmap[j](x)
+ x = upscaler(x)
+ return x
+
+ def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
+ if pixels is None:
+ pixels = x.new_zeros(x.size(0), 3, 8, 8)
+
+ # Process the conditioning embeddings
+ r_embed = self.gen_r_embedding(r)
+ for c in self.t_conds:
+ t_cond = kwargs.get(c, torch.zeros_like(r))
+ r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond)], dim=1)
+ clip = self.gen_c_embeddings(clip)
+
+ # Model Blocks
+ x = self.embedding(x)
+ x = x + self.effnet_mapper(
+ nn.functional.interpolate(effnet.float(), size=x.shape[-2:], mode='bilinear', align_corners=True))
+ x = x + nn.functional.interpolate(self.pixels_mapper(pixels).float(), size=x.shape[-2:], mode='bilinear',
+ align_corners=True)
+ level_outputs = self._down_encode(x, r_embed, clip)
+ x = self._up_decode(level_outputs, r_embed, clip)
+ return self.clf(x)
+
+ def update_weights_ema(self, src_model, beta=0.999):
+ for self_params, src_params in zip(self.parameters(), src_model.parameters()):
+ self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
+ for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
+ self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
diff --git a/modules/stage_c.py b/modules/stage_c.py
new file mode 100644
index 0000000000000000000000000000000000000000..53b73d0197712b981ec1a154428c21af2149646a
--- /dev/null
+++ b/modules/stage_c.py
@@ -0,0 +1,252 @@
+import torch
+from torch import nn
+import numpy as np
+import math
+from .common import AttnBlock, LayerNorm2d, ResBlock, FeedForwardBlock, TimestepBlock
+from .controlnet import ControlNetDeliverer  # re-enabled: forward() instantiates it when cnet is provided
+
+
+class UpDownBlock2d(nn.Module):
+ def __init__(self, c_in, c_out, mode, enabled=True):
+ super().__init__()
+ assert mode in ['up', 'down']
+ interpolation = nn.Upsample(scale_factor=2 if mode == 'up' else 0.5, mode='bilinear',
+ align_corners=True) if enabled else nn.Identity()
+ mapping = nn.Conv2d(c_in, c_out, kernel_size=1)
+ self.blocks = nn.ModuleList([interpolation, mapping] if mode == 'up' else [mapping, interpolation])
+
+ def forward(self, x):
+ for block in self.blocks:
+ x = block(x.float())
+ return x
+
+
+class StageC(nn.Module):
+ def __init__(self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32],
+ blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=['CTA', 'CTA'],
+ c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3,
+ dropout=[0.1, 0.1], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False]):
+ super().__init__()
+ self.c_r = c_r
+ self.t_conds = t_conds
+ self.c_clip_seq = c_clip_seq
+ if not isinstance(dropout, list):
+ dropout = [dropout] * len(c_hidden)
+ if not isinstance(self_attn, list):
+ self_attn = [self_attn] * len(c_hidden)
+
+ # CONDITIONING
+ self.clip_txt_mapper = nn.Linear(c_clip_text, c_cond)
+ self.clip_txt_pooled_mapper = nn.Linear(c_clip_text_pooled, c_cond * c_clip_seq)
+ self.clip_img_mapper = nn.Linear(c_clip_img, c_cond * c_clip_seq)
+ self.clip_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6)
+
+ self.embedding = nn.Sequential(
+ nn.PixelUnshuffle(patch_size),
+ nn.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1),
+ LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6)
+ )
+
+ def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
+ if block_type == 'C':
+ return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout)
+ elif block_type == 'A':
+ return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout)
+ elif block_type == 'F':
+ return FeedForwardBlock(c_hidden, dropout=dropout)
+ elif block_type == 'T':
+ return TimestepBlock(c_hidden, c_r, conds=t_conds)
+ else:
+ raise Exception(f'Block type {block_type} not supported')
+
+ # BLOCKS
+ # -- down blocks
+ self.down_blocks = nn.ModuleList()
+ self.down_downscalers = nn.ModuleList()
+ self.down_repeat_mappers = nn.ModuleList()
+ for i in range(len(c_hidden)):
+ if i > 0:
+ self.down_downscalers.append(nn.Sequential(
+ LayerNorm2d(c_hidden[i - 1], elementwise_affine=False, eps=1e-6),
+ UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode='down', enabled=switch_level[i - 1])
+ ))
+ else:
+ self.down_downscalers.append(nn.Identity())
+ down_block = nn.ModuleList()
+ for _ in range(blocks[0][i]):
+ for block_type in level_config[i]:
+ block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
+ down_block.append(block)
+ self.down_blocks.append(down_block)
+ if block_repeat is not None:
+ block_repeat_mappers = nn.ModuleList()
+ for _ in range(block_repeat[0][i] - 1):
+ block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1))
+ self.down_repeat_mappers.append(block_repeat_mappers)
+
+ # -- up blocks
+ self.up_blocks = nn.ModuleList()
+ self.up_upscalers = nn.ModuleList()
+ self.up_repeat_mappers = nn.ModuleList()
+ for i in reversed(range(len(c_hidden))):
+ if i > 0:
+ self.up_upscalers.append(nn.Sequential(
+ LayerNorm2d(c_hidden[i], elementwise_affine=False, eps=1e-6),
+ UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode='up', enabled=switch_level[i - 1])
+ ))
+ else:
+ self.up_upscalers.append(nn.Identity())
+ up_block = nn.ModuleList()
+ for j in range(blocks[1][::-1][i]):
+ for k, block_type in enumerate(level_config[i]):
+ c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
+ block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
+ self_attn=self_attn[i])
+ up_block.append(block)
+ self.up_blocks.append(up_block)
+ if block_repeat is not None:
+ block_repeat_mappers = nn.ModuleList()
+ for _ in range(block_repeat[1][::-1][i] - 1):
+ block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1))
+ self.up_repeat_mappers.append(block_repeat_mappers)
+
+ # OUTPUT
+ self.clf = nn.Sequential(
+ LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6),
+ nn.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1),
+ nn.PixelShuffle(patch_size),
+ )
+
+ # --- WEIGHT INIT ---
+ self.apply(self._init_weights) # General init
+ nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) # conditionings
+ nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02) # conditionings
+ nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings
+ torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
+ nn.init.constant_(self.clf[1].weight, 0) # outputs
+
+ # blocks
+ for level_block in self.down_blocks + self.up_blocks:
+ for block in level_block:
+ if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
+ block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
+ elif isinstance(block, TimestepBlock):
+ for layer in block.modules():
+ if isinstance(layer, nn.Linear):
+ nn.init.constant_(layer.weight, 0)
+
+ def _init_weights(self, m):
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
+ torch.nn.init.xavier_uniform_(m.weight)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def gen_r_embedding(self, r, max_positions=10000):
+ r = r * max_positions
+ half_dim = self.c_r // 2
+ emb = math.log(max_positions) / (half_dim - 1)
+ emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
+ emb = r[:, None] * emb[None, :]
+ emb = torch.cat([emb.sin(), emb.cos()], dim=1)
+ if self.c_r % 2 == 1: # zero pad
+ emb = nn.functional.pad(emb, (0, 1), mode='constant')
+ return emb
+
+ def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img):
+ clip_txt = self.clip_txt_mapper(clip_txt)
+        if len(clip_txt_pooled.shape) == 2:
+            clip_txt_pooled = clip_txt_pooled.unsqueeze(1)
+ if len(clip_img.shape) == 2:
+ clip_img = clip_img.unsqueeze(1)
+ clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1)
+ clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1)
+ clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1)
+ clip = self.clip_norm(clip)
+ return clip
+
+ def _down_encode(self, x, r_embed, clip, cnet=None):
+ level_outputs = []
+ block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
+ for down_block, downscaler, repmap in block_group:
+ x = downscaler(x)
+ for i in range(len(repmap) + 1):
+ for block in down_block:
+ if isinstance(block, ResBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ ResBlock)):
+ if cnet is not None:
+ next_cnet = cnet()
+ if next_cnet is not None:
+ x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear',
+ align_corners=True)
+ x = block(x)
+ elif isinstance(block, AttnBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ AttnBlock)):
+ x = block(x, clip)
+ elif isinstance(block, TimestepBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ TimestepBlock)):
+ x = block(x, r_embed)
+ else:
+ x = block(x)
+ if i < len(repmap):
+ x = repmap[i](x)
+ level_outputs.insert(0, x)
+ return level_outputs
+
+ def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
+ x = level_outputs[0]
+ block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
+ for i, (up_block, upscaler, repmap) in enumerate(block_group):
+ for j in range(len(repmap) + 1):
+ for k, block in enumerate(up_block):
+ if isinstance(block, ResBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ ResBlock)):
+ skip = level_outputs[i] if k == 0 and i > 0 else None
+ if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
+ x = torch.nn.functional.interpolate(x.float(), skip.shape[-2:], mode='bilinear',
+ align_corners=True)
+ if cnet is not None:
+ next_cnet = cnet()
+ if next_cnet is not None:
+ x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear',
+ align_corners=True)
+ x = block(x, skip)
+ elif isinstance(block, AttnBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ AttnBlock)):
+ x = block(x, clip)
+ elif isinstance(block, TimestepBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ TimestepBlock)):
+ x = block(x, r_embed)
+ else:
+ x = block(x)
+ if j < len(repmap):
+ x = repmap[j](x)
+ x = upscaler(x)
+ return x
+
+ def forward(self, x, r, clip_text, clip_text_pooled, clip_img, cnet=None, **kwargs):
+ # Process the conditioning embeddings
+ r_embed = self.gen_r_embedding(r)
+ for c in self.t_conds:
+ t_cond = kwargs.get(c, torch.zeros_like(r))
+ r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond)], dim=1)
+ clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img)
+
+ # Model Blocks
+ x = self.embedding(x)
+ if cnet is not None:
+ cnet = ControlNetDeliverer(cnet)
+ level_outputs = self._down_encode(x, r_embed, clip, cnet)
+ x = self._up_decode(level_outputs, r_embed, clip, cnet)
+ return self.clf(x)
+
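+    # Exponential moving average update: pulls this model's parameters and buffers towards
+    # `src_model` with decay `beta` (closer to 1 = slower, smoother EMA).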
+ def update_weights_ema(self, src_model, beta=0.999):
+ for self_params, src_params in zip(self.parameters(), src_model.parameters()):
+ self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
+ for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
+ self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
diff --git a/prompt_list.txt b/prompt_list.txt
new file mode 100644
index 0000000000000000000000000000000000000000..27cd31b4750d2f15fdb6f2a3f4bdd117a7377267
--- /dev/null
+++ b/prompt_list.txt
@@ -0,0 +1,32 @@
+A close-up of a blooming peony, with layers of soft, pink petals, a delicate fragrance, and dewdrops glistening
+in the early morning light.
+
+A detailed view of a blooming magnolia tree, with large, white flowers and dark green leaves, set against a
+clear blue sky.
+
+A close-up portrait of a young woman with flawless skin, vibrant red lipstick, and wavy brown hair, wearing
+a vintage floral dress and standing in front of a blooming garden.
+
+The image features a snow-covered mountain range with a large, snow-covered mountain in the background.
+The mountain is surrounded by a forest of trees, and the sky is filled with clouds. The scene is set during the
+winter season, with snow covering the ground and the trees.
+
+Crocodile in a sweater.
+
+A vibrant anime scene of a young girl with long, flowing pink hair, big sparkling blue eyes, and a school
+uniform, standing under a cherry blossom tree with petals falling around her. The background shows a
+traditional Japanese school with cherry blossoms in full bloom.
+
+A playful Labrador retriever puppy with a shiny, golden coat, chasing a red ball in a spacious backyard, with
+green grass and a wooden fence.
+
+A cozy, rustic log cabin nestled in a snow-covered forest, with smoke rising from the stone chimney, warm
+lights glowing from the windows, and a path of footprints leading to the front door.
+
+A highly detailed, high-quality image of the Banff National Park in Canada. The turquoise waters of Lake
+Louise are surrounded by snow-capped mountains and dense pine forests. A wooden canoe is docked at the
+edge of the lake. The sky is a clear, bright blue, and the air is crisp and fresh.
+
+A highly detailed, high-quality image of a Shih Tzu receiving a bath in a home bathroom. The dog is standing
+in a tub, covered in suds, with a slightly wet and adorable look. The background includes bathroom fixtures,
+towels, and a clean, tiled floor.
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 7647fc562eb544209951a1163bdd619dbbed29b1..0a2397e70942601b5d37bd0b370cb842e1feb7bb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,19 @@
-huggingface_hub==0.22.2
-scikit-build-core
-https://github.com/abetlen/llama-cpp-python/releases/download/v0.2.81-cu124/llama_cpp_python-0.2.81-cp310-cp310-linux_x86_64.whl
\ No newline at end of file
+--find-links https://download.pytorch.org/whl/torch_stable.html
+accelerate>=0.25.0
+torch==2.1.2+cu118
+torchvision==0.16.2+cu118
+transformers>=4.30.0
+numpy>=1.23.5
+kornia>=0.7.0
+insightface>=0.7.3
+opencv-python>=4.8.1.78
+tqdm>=4.66.1
+matplotlib>=3.7.4
+webdataset>=0.2.79
+wandb>=0.16.2
+munch>=4.0.0
+onnxruntime>=1.16.3
+einops>=0.7.0
+onnx2torch>=1.5.13
+warmup-scheduler @ git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git
+torchtools @ git+https://github.com/pabloppp/pytorch-tools
diff --git a/train/__init__.py b/train/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea1331f6b933f63c99a6bdf074201fdb4b8f78c2
--- /dev/null
+++ b/train/__init__.py
@@ -0,0 +1,5 @@
+from .train_b import WurstCore as WurstCoreB
+from .train_c import WurstCore as WurstCoreC
+from .train_t2i import WurstCore as WurstCore_t2i
+from .train_ultrapixel_control import WurstCore as WurstCore_control_lrguide
+from .train_personalized import WurstCore as WurstCore_personalized
\ No newline at end of file
diff --git a/train/base.py b/train/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e8a6ef306e40da8c9d8db33ceba2f8b2982a9a9
--- /dev/null
+++ b/train/base.py
@@ -0,0 +1,402 @@
+import yaml
+import json
+import torch
+import wandb
+import torchvision
+import numpy as np
+from torch import nn
+from tqdm import tqdm
+from abc import abstractmethod
+from fractions import Fraction
+import matplotlib.pyplot as plt
+from dataclasses import dataclass
+from torch.distributed import barrier
+from torch.utils.data import DataLoader
+
+from gdf import GDF
+from gdf import AdaptiveLossWeight
+
+from core import WarpCore
+from core.data import setup_webdataset_path, MultiGetter, MultiFilter, Bucketeer
+from core.utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary
+
+import webdataset as wds
+from webdataset.handlers import warn_and_continue
+
+import transformers
+transformers.utils.logging.set_verbosity_error()
+
+
+class DataCore(WarpCore):
+ @dataclass(frozen=True)
+ class Config(WarpCore.Config):
+ image_size: int = EXPECTED_TRAIN
+ webdataset_path: str = EXPECTED_TRAIN
+ grad_accum_steps: int = EXPECTED_TRAIN
+ batch_size: int = EXPECTED_TRAIN
+ multi_aspect_ratio: list = None
+
+ captions_getter: list = None
+ dataset_filters: list = None
+
+ bucketeer_random_ratio: float = 0.05
+
+ @dataclass(frozen=True)
+ class Extras(WarpCore.Extras):
+ transforms: torchvision.transforms.Compose = EXPECTED
+ clip_preprocess: torchvision.transforms.Compose = EXPECTED
+
+ @dataclass(frozen=True)
+ class Models(WarpCore.Models):
+ tokenizer: nn.Module = EXPECTED
+ text_model: nn.Module = EXPECTED
+ image_model: nn.Module = None
+
+ config: Config
+
+ def webdataset_path(self):
+ if isinstance(self.config.webdataset_path, str) and (self.config.webdataset_path.strip().startswith(
+ 'pipe:') or self.config.webdataset_path.strip().startswith('file:')):
+ return self.config.webdataset_path
+ else:
+ dataset_path = self.config.webdataset_path
+ if isinstance(self.config.webdataset_path, str) and self.config.webdataset_path.strip().endswith('.yml'):
+ with open(self.config.webdataset_path, 'r', encoding='utf-8') as file:
+ dataset_path = yaml.safe_load(file)
+ return setup_webdataset_path(dataset_path, cache_path=f"{self.config.experiment_id}_webdataset_cache.yml")
+
+ def webdataset_preprocessors(self, extras: Extras):
+ def identity(x):
+ if isinstance(x, bytes):
+ x = x.decode('utf-8')
+ return x
+
+ # CUSTOM CAPTIONS GETTER -----
+        def get_caption(oc, c, p_og=0.05): # cog_contextual, cog_caption
+ if p_og > 0 and np.random.rand() < p_og and len(oc) > 0:
+ return identity(oc)
+ else:
+ return identity(c)
+
+ captions_getter = MultiGetter(rules={
+ ('old_caption', 'caption'): lambda oc, c: get_caption(json.loads(oc)['og_caption'], c, p_og=0.05)
+ })
+
+ return [
+ ('jpg;png',
+ torchvision.transforms.ToTensor() if self.config.multi_aspect_ratio is not None else extras.transforms,
+ 'images'),
+ ('txt', identity, 'captions') if self.config.captions_getter is None else (
+ self.config.captions_getter[0], eval(self.config.captions_getter[1]), 'captions'),
+ ]
+
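+    # Builds the WebDataset pipeline (optional filters, shuffling, image decoding, per-key
+    # preprocessors) and wraps it in a DataLoader whose per-GPU batch size accounts for world
+    # size and gradient accumulation; with multi_aspect_ratio set, a Bucketeer groups samples
+    # into aspect-ratio buckets instead of the default collate.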
+ def setup_data(self, extras: Extras) -> WarpCore.Data:
+ # SETUP DATASET
+ dataset_path = self.webdataset_path()
+ preprocessors = self.webdataset_preprocessors(extras)
+
+ handler = warn_and_continue
+ dataset = wds.WebDataset(
+ dataset_path, resampled=True, handler=handler
+ ).select(
+ MultiFilter(rules={
+ f[0]: eval(f[1]) for f in self.config.dataset_filters
+ }) if self.config.dataset_filters is not None else lambda _: True
+ ).shuffle(690, handler=handler).decode(
+ "pilrgb", handler=handler
+ ).to_tuple(
+ *[p[0] for p in preprocessors], handler=handler
+ ).map_tuple(
+ *[p[1] for p in preprocessors], handler=handler
+ ).map(lambda x: {p[2]: x[i] for i, p in enumerate(preprocessors)})
+
+ def identity(x):
+ return x
+
+ # SETUP DATALOADER
+ real_batch_size = self.config.batch_size // (self.world_size * self.config.grad_accum_steps)
+ dataloader = DataLoader(
+ dataset, batch_size=real_batch_size, num_workers=8, pin_memory=True,
+ collate_fn=identity if self.config.multi_aspect_ratio is not None else None
+ )
+ if self.is_main_node:
+ print(f"Training with batch size {self.config.batch_size} ({real_batch_size}/GPU)")
+
+ if self.config.multi_aspect_ratio is not None:
+ aspect_ratios = [float(Fraction(f)) for f in self.config.multi_aspect_ratio]
+ dataloader_iterator = Bucketeer(dataloader, density=self.config.image_size ** 2, factor=32,
+ ratios=aspect_ratios, p_random_ratio=self.config.bucketeer_random_ratio,
+ interpolate_nearest=False) # , use_smartcrop=True)
+ else:
+ dataloader_iterator = iter(dataloader)
+
+ return self.Data(dataset=dataset, dataloader=dataloader, iterator=dataloader_iterator)
+
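+    # Builds the CLIP conditioning dict. During training, captions are blanked for ~5% of the
+    # batch and image embeddings are only filled in for ~10% of samples, providing the
+    # unconditional examples needed for classifier-free guidance at sampling time.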
+ def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False,
+ eval_image_embeds=False, return_fields=None):
+ if return_fields is None:
+ return_fields = ['clip_text', 'clip_text_pooled', 'clip_img']
+
+ captions = batch.get('captions', None)
+ images = batch.get('images', None)
+ batch_size = len(captions)
+
+ text_embeddings = None
+ text_pooled_embeddings = None
+ if 'clip_text' in return_fields or 'clip_text_pooled' in return_fields:
+ if is_eval:
+ if is_unconditional:
+ captions_unpooled = ["" for _ in range(batch_size)]
+ else:
+ captions_unpooled = captions
+ else:
+ rand_idx = np.random.rand(batch_size) > 0.05
+ captions_unpooled = [str(c) if keep else "" for c, keep in zip(captions, rand_idx)]
+ clip_tokens_unpooled = models.tokenizer(captions_unpooled, truncation=True, padding="max_length",
+ max_length=models.tokenizer.model_max_length,
+ return_tensors="pt").to(self.device)
+ text_encoder_output = models.text_model(**clip_tokens_unpooled, output_hidden_states=True)
+ if 'clip_text' in return_fields:
+ text_embeddings = text_encoder_output.hidden_states[-1]
+ if 'clip_text_pooled' in return_fields:
+ text_pooled_embeddings = text_encoder_output.text_embeds.unsqueeze(1)
+
+ image_embeddings = None
+ if 'clip_img' in return_fields:
+ image_embeddings = torch.zeros(batch_size, 768, device=self.device)
+ if images is not None:
+ images = images.to(self.device)
+ if is_eval:
+ if not is_unconditional and eval_image_embeds:
+ image_embeddings = models.image_model(extras.clip_preprocess(images)).image_embeds
+ else:
+ rand_idx = np.random.rand(batch_size) > 0.9
+ if any(rand_idx):
+ image_embeddings[rand_idx] = models.image_model(extras.clip_preprocess(images[rand_idx])).image_embeds
+ image_embeddings = image_embeddings.unsqueeze(1)
+ return {
+ 'clip_text': text_embeddings,
+ 'clip_text_pooled': text_pooled_embeddings,
+ 'clip_img': image_embeddings
+ }
+
+
+class TrainingCore(DataCore, WarpCore):
+ @dataclass(frozen=True)
+ class Config(DataCore.Config, WarpCore.Config):
+ updates: int = EXPECTED_TRAIN
+ backup_every: int = EXPECTED_TRAIN
+ save_every: int = EXPECTED_TRAIN
+
+ # EMA UPDATE
+ ema_start_iters: int = None
+ ema_iters: int = None
+ ema_beta: float = None
+
+ use_fsdp: bool = None
+
+    @dataclass()  # not frozen, so fields are mutable; doesn't support EXPECTED
+ class Info(WarpCore.Info):
+ ema_loss: float = None
+ adaptive_loss: dict = None
+
+ @dataclass(frozen=True)
+ class Models(WarpCore.Models):
+ generator: nn.Module = EXPECTED
+ generator_ema: nn.Module = None # optional
+
+ @dataclass(frozen=True)
+ class Optimizers(WarpCore.Optimizers):
+ generator: any = EXPECTED
+
+ @dataclass(frozen=True)
+ class Extras(WarpCore.Extras):
+ gdf: GDF = EXPECTED
+ sampling_configs: dict = EXPECTED
+
+ info: Info
+ config: Config
+
+ @abstractmethod
+ def forward_pass(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models):
+        raise NotImplementedError("This method needs to be overridden")
+
+ @abstractmethod
+ def backward_pass(self, update, loss, loss_adjusted, models: Models, optimizers: Optimizers,
+ schedulers: WarpCore.Schedulers):
+        raise NotImplementedError("This method needs to be overridden")
+
+ @abstractmethod
+ def models_to_save(self) -> list:
+        raise NotImplementedError("This method needs to be overridden")
+
+ @abstractmethod
+ def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
+        raise NotImplementedError("This method needs to be overridden")
+
+ @abstractmethod
+ def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
+        raise NotImplementedError("This method needs to be overridden")
+
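+    # Main training loop: forward/backward with gradient accumulation, optional EMA updates,
+    # EMA-smoothed loss logging, wandb NaN alerts, and periodic sampling/checkpointing.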
+ def train(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models, optimizers: Optimizers,
+ schedulers: WarpCore.Schedulers):
+ start_iter = self.info.iter + 1
+ max_iters = self.config.updates * self.config.grad_accum_steps
+ if self.is_main_node:
+ print(f"STARTING AT STEP: {start_iter}/{max_iters}")
+
+ pbar = tqdm(range(start_iter, max_iters + 1)) if self.is_main_node else range(start_iter,
+ max_iters + 1) # <--- DDP
+ if 'generator' in self.models_to_save():
+ models.generator.train()
+ for i in pbar:
+ # FORWARD PASS
+ loss, loss_adjusted = self.forward_pass(data, extras, models)
+
+            # BACKWARD PASS
+ grad_norm = self.backward_pass(
+ i % self.config.grad_accum_steps == 0 or i == max_iters, loss, loss_adjusted,
+ models, optimizers, schedulers
+ )
+ self.info.iter = i
+
+ # UPDATE EMA
+ if models.generator_ema is not None and i % self.config.ema_iters == 0:
+ update_weights_ema(
+ models.generator_ema, models.generator,
+ beta=(self.config.ema_beta if i > self.config.ema_start_iters else 0)
+ )
+
+ # UPDATE LOSS METRICS
+ self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01
+
+            if self.is_main_node and self.config.wandb_project is not None and (
+                    np.isnan(loss.mean().item()) or np.isnan(grad_norm.item())):
+ wandb.alert(
+ title=f"NaN value encountered in training run {self.info.wandb_run_id}",
+ text=f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}",
+ wait_duration=60 * 30
+ )
+
+ if self.is_main_node:
+ logs = {
+ 'loss': self.info.ema_loss,
+ 'raw_loss': loss.mean().item(),
+ 'grad_norm': grad_norm.item(),
+ 'lr': optimizers.generator.param_groups[0]['lr'] if optimizers.generator is not None else 0,
+ 'total_steps': self.info.total_steps,
+ }
+
+ pbar.set_postfix(logs)
+ if self.config.wandb_project is not None:
+ wandb.log(logs)
+
+ if i == 1 or i % (self.config.save_every * self.config.grad_accum_steps) == 0 or i == max_iters:
+ # SAVE AND CHECKPOINT STUFF
+ if np.isnan(loss.mean().item()):
+ if self.is_main_node and self.config.wandb_project is not None:
+ tqdm.write("Skipping sampling & checkpoint because the loss is NaN")
+                        wandb.alert(title=f"Skipping sampling & checkpoint for training run {self.info.wandb_run_id}",
+                                    text=f"Skipping sampling & checkpoint at step {self.info.total_steps} for training run {self.info.wandb_run_id} because the loss is NaN")
+ else:
+ if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight):
+ self.info.adaptive_loss = {
+ 'bucket_ranges': extras.gdf.loss_weight.bucket_ranges.tolist(),
+ 'bucket_losses': extras.gdf.loss_weight.bucket_losses.tolist(),
+ }
+ self.save_checkpoints(models, optimizers)
+ if self.is_main_node:
+ create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/')
+ self.sample(models, data, extras)
+
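+    # Saves info, the models listed by models_to_save(), and all optimizers (FSDP-aware);
+    # every `backup_every` steps it additionally writes a suffixed backup snapshot.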
+ def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None):
+ barrier()
+ suffix = '' if suffix is None else suffix
+ self.save_info(self.info, suffix=suffix)
+ models_dict = models.to_dict()
+ optimizers_dict = optimizers.to_dict()
+ for key in self.models_to_save():
+ model = models_dict[key]
+ if model is not None:
+ self.save_model(model, f"{key}{suffix}", is_fsdp=self.config.use_fsdp)
+ for key in optimizers_dict:
+ optimizer = optimizers_dict[key]
+ if optimizer is not None:
+ self.save_optimizer(optimizer, f'{key}_optim{suffix}',
+ fsdp_model=models_dict[key] if self.config.use_fsdp else None)
+ if suffix == '' and self.info.total_steps > 1 and self.info.total_steps % self.config.backup_every == 0:
+ self.save_checkpoints(models, optimizers, suffix=f"_{self.info.total_steps // 1000}k")
+ torch.cuda.empty_cache()
+
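+    # Qualitative evaluation: diffuses a batch, predicts and un-diffuses it, samples with and
+    # without EMA weights, saves a collage (originals / noised / predicted / sampled / EMA),
+    # and optionally logs a wandb table plus the adaptive loss-weight curve.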
+ def sample(self, models: Models, data: WarpCore.Data, extras: Extras):
+ if 'generator' in self.models_to_save():
+ models.generator.eval()
+ with torch.no_grad():
+ batch = next(data.iterator)
+
+ conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False)
+ unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
+
+ latents = self.encode_latents(batch, models, extras)
+ noised, _, _, logSNR, noise_cond, _ = extras.gdf.diffuse(latents, shift=1, loss_shift=1)
+
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+ pred = models.generator(noised, noise_cond, **conditions)
+ pred = extras.gdf.undiffuse(noised, logSNR, pred)[0]
+
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+ *_, (sampled, _, _) = extras.gdf.sample(
+ models.generator, conditions,
+ latents.shape, unconditions, device=self.device, **extras.sampling_configs
+ )
+
+ if models.generator_ema is not None:
+ *_, (sampled_ema, _, _) = extras.gdf.sample(
+ models.generator_ema, conditions,
+ latents.shape, unconditions, device=self.device, **extras.sampling_configs
+ )
+ else:
+ sampled_ema = sampled
+
+ if self.is_main_node:
+ noised_images = torch.cat(
+ [self.decode_latents(noised[i:i + 1], batch, models, extras) for i in range(len(noised))], dim=0)
+ pred_images = torch.cat(
+ [self.decode_latents(pred[i:i + 1], batch, models, extras) for i in range(len(pred))], dim=0)
+ sampled_images = torch.cat(
+ [self.decode_latents(sampled[i:i + 1], batch, models, extras) for i in range(len(sampled))], dim=0)
+ sampled_images_ema = torch.cat(
+ [self.decode_latents(sampled_ema[i:i + 1], batch, models, extras) for i in range(len(sampled_ema))],
+ dim=0)
+
+ images = batch['images']
+ if images.size(-1) != noised_images.size(-1) or images.size(-2) != noised_images.size(-2):
+ images = nn.functional.interpolate(images, size=noised_images.shape[-2:], mode='bicubic')
+
+ collage_img = torch.cat([
+ torch.cat([i for i in images.cpu()], dim=-1),
+ torch.cat([i for i in noised_images.cpu()], dim=-1),
+ torch.cat([i for i in pred_images.cpu()], dim=-1),
+ torch.cat([i for i in sampled_images.cpu()], dim=-1),
+ torch.cat([i for i in sampled_images_ema.cpu()], dim=-1),
+ ], dim=-2)
+
+ torchvision.utils.save_image(collage_img, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}.jpg')
+ torchvision.utils.save_image(collage_img, f'{self.config.experiment_id}_latest_output.jpg')
+
+ captions = batch['captions']
+ if self.config.wandb_project is not None:
+ log_data = [
+ [captions[i]] + [wandb.Image(sampled_images[i])] + [wandb.Image(sampled_images_ema[i])] + [
+ wandb.Image(images[i])] for i in range(len(images))]
+ log_table = wandb.Table(data=log_data, columns=["Captions", "Sampled", "Sampled EMA", "Orig"])
+ wandb.log({"Log": log_table})
+
+ if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight):
+                        plt.plot(extras.gdf.loss_weight.bucket_ranges, extras.gdf.loss_weight.bucket_losses[:-1])
+                        plt.ylabel('Raw Loss')
+                        plt.xlabel('LogSNR')
+                        wandb.log({"Loss/LogSNR": plt})
+
+ if 'generator' in self.models_to_save():
+ models.generator.train()
diff --git a/train/dist_core.py b/train/dist_core.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e4e9e670a3b853fac345618d3557d648d813902
--- /dev/null
+++ b/train/dist_core.py
@@ -0,0 +1,47 @@
+import os
+import torch
+
+
+def get_world_size():
+ """Find OMPI world size without calling mpi functions
+ :rtype: int
+ """
+ if os.environ.get('PMI_SIZE') is not None:
+ return int(os.environ.get('PMI_SIZE') or 1)
+ elif os.environ.get('OMPI_COMM_WORLD_SIZE') is not None:
+ return int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1)
+ else:
+ return torch.cuda.device_count()
+
+
+def get_global_rank():
+ """Find OMPI world rank without calling mpi functions
+ :rtype: int
+ """
+ if os.environ.get('PMI_RANK') is not None:
+ return int(os.environ.get('PMI_RANK') or 0)
+ elif os.environ.get('OMPI_COMM_WORLD_RANK') is not None:
+ return int(os.environ.get('OMPI_COMM_WORLD_RANK') or 0)
+ else:
+ return 0
+
+
+def get_local_rank():
+ """Find OMPI local rank without calling mpi functions
+ :rtype: int
+ """
+ if os.environ.get('MPI_LOCALRANKID') is not None:
+ return int(os.environ.get('MPI_LOCALRANKID') or 0)
+ elif os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') is not None:
+ return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') or 0)
+ else:
+ return 0
+
+
+def get_master_ip():
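+    """Find the master node address from Azure Batch environment variables, falling back to localhost
+    :rtype: str
+    """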
+ if os.environ.get('AZ_BATCH_MASTER_NODE') is not None:
+ return os.environ.get('AZ_BATCH_MASTER_NODE').split(':')[0]
+ elif os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') is not None:
+ return os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE')
+ else:
+ return "127.0.0.1"
diff --git a/train/train_b.py b/train/train_b.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3441a5841750a7c33b49756d2d60064a68d82d8
--- /dev/null
+++ b/train/train_b.py
@@ -0,0 +1,305 @@
+import torch
+import torchvision
+from torch import nn, optim
+from transformers import AutoTokenizer, CLIPTextModelWithProjection
+from warmup_scheduler import GradualWarmupScheduler
+import numpy as np
+
+import sys
+import os
+from dataclasses import dataclass
+
+from gdf import GDF, EpsilonTarget, CosineSchedule
+from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
+from torchtools.transforms import SmartCrop
+
+from modules.effnet import EfficientNetEncoder
+from modules.stage_a import StageA
+
+from modules.stage_b import StageB
+from modules.stage_b import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock
+
+from train.base import DataCore, TrainingCore
+
+from core import WarpCore
+from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail
+
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp.wrap import ModuleWrapPolicy
+from accelerate import init_empty_weights
+from accelerate.utils import set_module_tensor_to_device
+from contextlib import contextmanager
+
+class WurstCore(TrainingCore, DataCore, WarpCore):
+ @dataclass(frozen=True)
+ class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config):
+ # TRAINING PARAMS
+ lr: float = EXPECTED_TRAIN
+ warmup_updates: int = EXPECTED_TRAIN
+ shift: float = EXPECTED_TRAIN
+ dtype: str = None
+
+ # MODEL VERSION
+        model_version: str = EXPECTED # 3B or 700M
+ clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k'
+
+ # CHECKPOINT PATHS
+ stage_a_checkpoint_path: str = EXPECTED
+ effnet_checkpoint_path: str = EXPECTED
+ generator_checkpoint_path: str = None
+
+ # gdf customization
+ adaptive_loss_weight: str = None
+
+ @dataclass(frozen=True)
+ class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models):
+ effnet: nn.Module = EXPECTED
+ stage_a: nn.Module = EXPECTED
+
+ @dataclass(frozen=True)
+ class Schedulers(WarpCore.Schedulers):
+ generator: any = None
+
+ @dataclass(frozen=True)
+ class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras):
+ gdf: GDF = EXPECTED
+ sampling_configs: dict = EXPECTED
+ effnet_preprocess: torchvision.transforms.Compose = EXPECTED
+
+ info: TrainingCore.Info
+ config: Config
+
+ def setup_extras_pre(self) -> Extras:
+ gdf = GDF(
+ schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]),
+ input_scaler=VPScaler(), target=EpsilonTarget(),
+ noise_cond=CosineTNoiseCond(),
+ loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(),
+ )
+ sampling_configs = {"cfg": 1.5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 10}
+
+ if self.info.adaptive_loss is not None:
+ gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges'])
+ gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses'])
+
+ effnet_preprocess = torchvision.transforms.Compose([
+ torchvision.transforms.Normalize(
+ mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
+ )
+ ])
+
+ transforms = torchvision.transforms.Compose([
+ torchvision.transforms.ToTensor(),
+ torchvision.transforms.Resize(self.config.image_size,
+ interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
+ antialias=True),
+ SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2) if self.config.training else torchvision.transforms.CenterCrop(self.config.image_size)
+ ])
+
+ return self.Extras(
+ gdf=gdf,
+ sampling_configs=sampling_configs,
+ transforms=transforms,
+ effnet_preprocess=effnet_preprocess,
+ clip_preprocess=None
+ )
+
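+    # Stage B conditioning: EfficientNet latents of (randomly downscaled) images plus the
+    # pooled CLIP text embedding; during training the EfficientNet features are zeroed for
+    # ~10% of samples for classifier-free guidance.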
+ def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False, eval_image_embeds=False, return_fields=None):
+ images = batch.get('images', None)
+
+ if images is not None:
+ images = images.to(self.device)
+ if is_eval and not is_unconditional:
+ effnet_embeddings = models.effnet(extras.effnet_preprocess(images))
+ else:
+ if is_eval:
+ effnet_factor = 1
+ else:
+ effnet_factor = np.random.uniform(0.5, 1) # f64 to f32
+ effnet_height, effnet_width = int(((images.size(-2)*effnet_factor)//32)*32), int(((images.size(-1)*effnet_factor)//32)*32)
+
+ effnet_embeddings = torch.zeros(images.size(0), 16, effnet_height//32, effnet_width//32, device=self.device)
+ if not is_eval:
+ effnet_images = torchvision.transforms.functional.resize(images, (effnet_height, effnet_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST)
+ rand_idx = np.random.rand(len(images)) <= 0.9
+ if any(rand_idx):
+ effnet_embeddings[rand_idx] = models.effnet(extras.effnet_preprocess(effnet_images[rand_idx]))
+ else:
+ effnet_embeddings = None
+
+ conditions = super().get_conditions(
+ batch, models, extras, is_eval, is_unconditional,
+ eval_image_embeds, return_fields=return_fields or ['clip_text_pooled']
+ )
+
+ return {'effnet': effnet_embeddings, 'clip': conditions['clip_text_pooled']}
+
+ def setup_models(self, extras: Extras, skip_clip: bool = False) -> Models:
+ dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.float32
+
+ # EfficientNet encoder
+ effnet = EfficientNetEncoder().to(self.device)
+ effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path)
+
+ effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict'])
+ effnet.eval().requires_grad_(False)
+ del effnet_checkpoint
+
+ # vqGAN
+ stage_a = StageA().to(self.device)
+ stage_a_checkpoint = load_or_fail(self.config.stage_a_checkpoint_path)
+ stage_a.load_state_dict(stage_a_checkpoint if 'state_dict' not in stage_a_checkpoint else stage_a_checkpoint['state_dict'])
+ stage_a.eval().requires_grad_(False)
+ del stage_a_checkpoint
+
+ @contextmanager
+ def dummy_context():
+ yield None
+
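+        # During training, modules are built normally; otherwise accelerate's init_empty_weights
+        # defers allocation and checkpoint tensors are materialized one by one via
+        # set_module_tensor_to_device below.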
+ loading_context = dummy_context if self.config.training else init_empty_weights
+
+ # Diffusion models
+ with loading_context():
+ generator_ema = None
+ if self.config.model_version == '3B':
+ generator = StageB(c_hidden=[320, 640, 1280, 1280], nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]], block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]])
+ if self.config.ema_start_iters is not None:
+ generator_ema = StageB(c_hidden=[320, 640, 1280, 1280], nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]], block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]])
+ elif self.config.model_version == '700M':
+ generator = StageB(c_hidden=[320, 576, 1152, 1152], nhead=[-1, 9, 18, 18], blocks=[[2, 4, 14, 4], [4, 14, 4, 2]], block_repeat=[[1, 1, 1, 1], [2, 2, 2, 2]])
+ if self.config.ema_start_iters is not None:
+ generator_ema = StageB(c_hidden=[320, 576, 1152, 1152], nhead=[-1, 9, 18, 18], blocks=[[2, 4, 14, 4], [4, 14, 4, 2]], block_repeat=[[1, 1, 1, 1], [2, 2, 2, 2]])
+ else:
+ raise ValueError(f"Unknown model version {self.config.model_version}")
+
+ if self.config.generator_checkpoint_path is not None:
+ if loading_context is dummy_context:
+ generator.load_state_dict(load_or_fail(self.config.generator_checkpoint_path))
+ else:
+ for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items():
+ set_module_tensor_to_device(generator, param_name, "cpu", value=param)
+ generator = generator.to(dtype).to(self.device)
+ generator = self.load_model(generator, 'generator')
+
+ if generator_ema is not None:
+ if loading_context is dummy_context:
+ generator_ema.load_state_dict(generator.state_dict())
+ else:
+ for param_name, param in generator.state_dict().items():
+ set_module_tensor_to_device(generator_ema, param_name, "cpu", value=param)
+ generator_ema = self.load_model(generator_ema, 'generator_ema')
+ generator_ema.to(dtype).to(self.device).eval().requires_grad_(False)
+
+ if self.config.use_fsdp:
+ fsdp_auto_wrap_policy = ModuleWrapPolicy([ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock])
+ generator = FSDP(generator, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device)
+ if generator_ema is not None:
+ generator_ema = FSDP(generator_ema, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device)
+
+ if skip_clip:
+ tokenizer = None
+ text_model = None
+ else:
+ tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name)
+ text_model = CLIPTextModelWithProjection.from_pretrained(self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device)
+
+ return self.Models(
+ effnet=effnet, stage_a=stage_a,
+ generator=generator, generator_ema=generator_ema,
+ tokenizer=tokenizer, text_model=text_model
+ )
+
+ def setup_optimizers(self, extras: Extras, models: Models) -> TrainingCore.Optimizers:
+ optimizer = optim.AdamW(models.generator.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95))
+ optimizer = self.load_optimizer(optimizer, 'generator_optim',
+ fsdp_model=models.generator if self.config.use_fsdp else None)
+ return self.Optimizers(generator=optimizer)
+
+ def setup_schedulers(self, extras: Extras, models: Models,
+ optimizers: TrainingCore.Optimizers) -> Schedulers:
+ scheduler = GradualWarmupScheduler(optimizers.generator, multiplier=1, total_epoch=self.config.warmup_updates)
+ scheduler.last_epoch = self.info.total_steps
+ return self.Schedulers(generator=scheduler)
+
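+    # Pyramid noise: adds progressively coarser Gaussian noise (upsampled to full resolution
+    # and scaled by 0.75**i) on top of the base epsilon, then renormalizes so the result keeps
+    # roughly unit variance.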
+ def _pyramid_noise(self, epsilon, size_range=None, levels=10, scale_mode='nearest'):
+ epsilon = epsilon.clone()
+ multipliers = [1]
+ for i in range(1, levels):
+ m = 0.75 ** i
+            h, w = epsilon.size(-2) // (2 ** i), epsilon.size(-1) // (2 ** i)
+ if size_range is None or (size_range[0] <= h <= size_range[1] or size_range[0] <= w <= size_range[1]):
+ offset = torch.randn(epsilon.size(0), epsilon.size(1), h, w, device=self.device)
+ epsilon = epsilon + torch.nn.functional.interpolate(offset, size=epsilon.shape[-2:],
+ mode=scale_mode) * m
+ multipliers.append(m)
+ if h <= 1 or w <= 1:
+ break
+ epsilon = epsilon / sum([m ** 2 for m in multipliers]) ** 0.5
+ # epsilon = epsilon / epsilon.std()
+ return epsilon
+
+ def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models):
+ batch = next(data.iterator)
+
+ with torch.no_grad():
+ conditions = self.get_conditions(batch, models, extras)
+ latents = self.encode_latents(batch, models, extras)
+ epsilon = torch.randn_like(latents)
+ epsilon = self._pyramid_noise(epsilon, size_range=[1, 16])
+ noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1,
+ epsilon=epsilon)
+
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+ pred = models.generator(noised, noise_cond, **conditions)
+ loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3])
+ loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps
+
+ if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight):
+ extras.gdf.loss_weight.update_buckets(logSNR, loss)
+
+ return loss, loss_adjusted
+
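+    # Gradient accumulation: gradients are accumulated on every call; gradient clipping, the
+    # optimizer/scheduler steps and zero_grad only run when `update` is True, and intermediate
+    # accumulation steps report a zero grad norm.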
+ def backward_pass(self, update, loss, loss_adjusted, models: Models, optimizers: TrainingCore.Optimizers,
+ schedulers: Schedulers):
+ if update:
+ loss_adjusted.backward()
+ grad_norm = nn.utils.clip_grad_norm_(models.generator.parameters(), 1.0)
+ optimizers_dict = optimizers.to_dict()
+ for k in optimizers_dict:
+ if k != 'training':
+ optimizers_dict[k].step()
+ schedulers_dict = schedulers.to_dict()
+ for k in schedulers_dict:
+ if k != 'training':
+ schedulers_dict[k].step()
+ for k in optimizers_dict:
+ if k != 'training':
+ optimizers_dict[k].zero_grad(set_to_none=True)
+ self.info.total_steps += 1
+ else:
+ loss_adjusted.backward()
+ grad_norm = torch.tensor(0.0).to(self.device)
+
+ return grad_norm
+
+ def models_to_save(self):
+ return ['generator', 'generator_ema']
+
+ def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
+ images = batch['images'].to(self.device)
+ return models.stage_a.encode(images)[0]
+
+ def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
+ return models.stage_a.decode(latents.float()).clamp(0, 1)
+
+
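+# Entry point: expects a config file path (typically YAML) as the first CLI argument and assumes
+# a SLURM-style launch where SLURM_LOCALID selects the local device, e.g.:
+#   SLURM_LOCALID=0 python train/train_b.py path/to/config.yaml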
+if __name__ == '__main__':
+ print("Launching Script")
+ warpcore = WurstCore(
+ config_file_path=sys.argv[1] if len(sys.argv) > 1 else None,
+ device=torch.device(int(os.environ.get("SLURM_LOCALID")))
+ )
+ # core.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD
+
+ # RUN TRAINING
+ warpcore()
diff --git a/train/train_c.py b/train/train_c.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4490c6eebc3e1c5126dd13c53603872f1459a3e
--- /dev/null
+++ b/train/train_c.py
@@ -0,0 +1,266 @@
+import torch
+import torchvision
+from torch import nn, optim
+from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPVisionModelWithProjection
+from warmup_scheduler import GradualWarmupScheduler
+
+import sys
+import os
+from dataclasses import dataclass
+
+from gdf import GDF, EpsilonTarget, CosineSchedule
+from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
+from torchtools.transforms import SmartCrop
+
+from modules.effnet import EfficientNetEncoder
+from modules.stage_c import StageC
+from modules.stage_c import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock
+from modules.previewer import Previewer
+
+from train.base import DataCore, TrainingCore
+
+from core import WarpCore
+from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail
+
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp.wrap import ModuleWrapPolicy
+from accelerate import init_empty_weights
+from accelerate.utils import set_module_tensor_to_device
+from contextlib import contextmanager
+
+class WurstCore(TrainingCore, DataCore, WarpCore):
+ @dataclass(frozen=True)
+ class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config):
+ # TRAINING PARAMS
+ lr: float = EXPECTED_TRAIN
+ warmup_updates: int = EXPECTED_TRAIN
+ dtype: str = None
+
+ # MODEL VERSION
+ model_version: str = EXPECTED # 3.6B or 1B
+ clip_image_model_name: str = 'openai/clip-vit-large-patch14'
+ clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k'
+
+ # CHECKPOINT PATHS
+ effnet_checkpoint_path: str = EXPECTED
+ previewer_checkpoint_path: str = EXPECTED
+ generator_checkpoint_path: str = None
+
+ # gdf customization
+ adaptive_loss_weight: str = None
+
+ @dataclass(frozen=True)
+ class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models):
+ effnet: nn.Module = EXPECTED
+ previewer: nn.Module = EXPECTED
+
+ @dataclass(frozen=True)
+ class Schedulers(WarpCore.Schedulers):
+ generator: any = None
+
+ @dataclass(frozen=True)
+ class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras):
+ gdf: GDF = EXPECTED
+ sampling_configs: dict = EXPECTED
+ effnet_preprocess: torchvision.transforms.Compose = EXPECTED
+
+ info: TrainingCore.Info
+ config: Config
+
+ def setup_extras_pre(self) -> Extras:
+ gdf = GDF(
+ schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]),
+ input_scaler=VPScaler(), target=EpsilonTarget(),
+ noise_cond=CosineTNoiseCond(),
+ loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(),
+ )
+ sampling_configs = {"cfg": 5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 20}
+
+ if self.info.adaptive_loss is not None:
+ gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges'])
+ gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses'])
+
+ effnet_preprocess = torchvision.transforms.Compose([
+ torchvision.transforms.Normalize(
+ mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
+ )
+ ])
+
+ clip_preprocess = torchvision.transforms.Compose([
+ torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
+ torchvision.transforms.CenterCrop(224),
+ torchvision.transforms.Normalize(
+ mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)
+ )
+ ])
+
+ if self.config.training:
+ transforms = torchvision.transforms.Compose([
+ torchvision.transforms.ToTensor(),
+ torchvision.transforms.Resize(self.config.image_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True),
+ SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2)
+ ])
+ else:
+ transforms = None
+
+ return self.Extras(
+ gdf=gdf,
+ sampling_configs=sampling_configs,
+ transforms=transforms,
+ effnet_preprocess=effnet_preprocess,
+ clip_preprocess=clip_preprocess
+ )
+
+ def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False,
+ eval_image_embeds=False, return_fields=None):
+ conditions = super().get_conditions(
+ batch, models, extras, is_eval, is_unconditional,
+ eval_image_embeds, return_fields=return_fields or ['clip_text', 'clip_text_pooled', 'clip_img']
+ )
+ return conditions
+
+ def setup_models(self, extras: Extras) -> Models:
+ dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.float32
+
+ # EfficientNet encoder
+ effnet = EfficientNetEncoder()
+ effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path)
+ effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict'])
+ effnet.eval().requires_grad_(False).to(self.device)
+ del effnet_checkpoint
+
+ # Previewer
+ previewer = Previewer()
+ previewer_checkpoint = load_or_fail(self.config.previewer_checkpoint_path)
+ previewer.load_state_dict(previewer_checkpoint if 'state_dict' not in previewer_checkpoint else previewer_checkpoint['state_dict'])
+ previewer.eval().requires_grad_(False).to(self.device)
+ del previewer_checkpoint
+
+ @contextmanager
+ def dummy_context():
+ yield None
+
+ loading_context = dummy_context if self.config.training else init_empty_weights
+
+ # Diffusion models
+ with loading_context():
+ generator_ema = None
+ if self.config.model_version == '3.6B':
+ generator = StageC()
+ if self.config.ema_start_iters is not None:
+ generator_ema = StageC()
+ elif self.config.model_version == '1B':
+ generator = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]])
+ if self.config.ema_start_iters is not None:
+ generator_ema = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]])
+ else:
+ raise ValueError(f"Unknown model version {self.config.model_version}")
+
+ if self.config.generator_checkpoint_path is not None:
+ if loading_context is dummy_context:
+ generator.load_state_dict(load_or_fail(self.config.generator_checkpoint_path))
+ else:
+
+ for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items():
+ set_module_tensor_to_device(generator, param_name, "cpu", value=param)
+ generator = generator.to(dtype).to(self.device)
+ generator = self.load_model(generator, 'generator')
+
+ if generator_ema is not None:
+ if loading_context is dummy_context:
+ generator_ema.load_state_dict(generator.state_dict())
+ else:
+ for param_name, param in generator.state_dict().items():
+ set_module_tensor_to_device(generator_ema, param_name, "cpu", value=param)
+ generator_ema = self.load_model(generator_ema, 'generator_ema')
+ generator_ema.to(dtype).to(self.device).eval().requires_grad_(False)
+
+ if self.config.use_fsdp:
+ fsdp_auto_wrap_policy = ModuleWrapPolicy([ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock])
+ generator = FSDP(generator, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device)
+ if generator_ema is not None:
+ generator_ema = FSDP(generator_ema, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device)
+
+ tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name)
+ text_model = CLIPTextModelWithProjection.from_pretrained(self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device)
+ image_model = CLIPVisionModelWithProjection.from_pretrained(self.config.clip_image_model_name).requires_grad_(False).to(dtype).to(self.device)
+
+ return self.Models(
+ effnet=effnet, previewer=previewer,
+ generator=generator, generator_ema=generator_ema,
+ tokenizer=tokenizer, text_model=text_model, image_model=image_model
+ )
+
+ def setup_optimizers(self, extras: Extras, models: Models) -> TrainingCore.Optimizers:
+ optimizer = optim.AdamW(models.generator.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95))
+ optimizer = self.load_optimizer(optimizer, 'generator_optim',
+ fsdp_model=models.generator if self.config.use_fsdp else None)
+ return self.Optimizers(generator=optimizer)
+
+ def setup_schedulers(self, extras: Extras, models: Models, optimizers: TrainingCore.Optimizers) -> Schedulers:
+ scheduler = GradualWarmupScheduler(optimizers.generator, multiplier=1, total_epoch=self.config.warmup_updates)
+ scheduler.last_epoch = self.info.total_steps
+ return self.Schedulers(generator=scheduler)
+
+ # Training loop --------------------------------
+ def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models):
+ batch = next(data.iterator)
+
+ with torch.no_grad():
+ conditions = self.get_conditions(batch, models, extras)
+ latents = self.encode_latents(batch, models, extras)
+ noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1)
+
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+ pred = models.generator(noised, noise_cond, **conditions)
+ loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3])
+ loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps
+
+ if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight):
+ extras.gdf.loss_weight.update_buckets(logSNR, loss)
+
+ return loss, loss_adjusted
+
+ def backward_pass(self, update, loss, loss_adjusted, models: Models, optimizers: TrainingCore.Optimizers, schedulers: Schedulers):
+ if update:
+ loss_adjusted.backward()
+ grad_norm = nn.utils.clip_grad_norm_(models.generator.parameters(), 1.0)
+ optimizers_dict = optimizers.to_dict()
+ for k in optimizers_dict:
+ if k != 'training':
+ optimizers_dict[k].step()
+ schedulers_dict = schedulers.to_dict()
+ for k in schedulers_dict:
+ if k != 'training':
+ schedulers_dict[k].step()
+ for k in optimizers_dict:
+ if k != 'training':
+ optimizers_dict[k].zero_grad(set_to_none=True)
+ self.info.total_steps += 1
+ else:
+ loss_adjusted.backward()
+ grad_norm = torch.tensor(0.0).to(self.device)
+
+ return grad_norm
+
+ def models_to_save(self):
+ return ['generator', 'generator_ema']
+
+ def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
+ images = batch['images'].to(self.device)
+ return models.effnet(extras.effnet_preprocess(images))
+
+ def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
+ return models.previewer(latents)
+
+
+if __name__ == '__main__':
+ print("Launching Script")
+ warpcore = WurstCore(
+ config_file_path=sys.argv[1] if len(sys.argv) > 1 else None,
+ device=torch.device(int(os.environ.get("SLURM_LOCALID")))
+ )
+ # core.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD
+
+ # RUN TRAINING
+ warpcore()
diff --git a/train/train_c_lora.py b/train/train_c_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b83eee0f250e5359901d39b8d4052254cfff4fa
--- /dev/null
+++ b/train/train_c_lora.py
@@ -0,0 +1,330 @@
+import torch
+import torchvision
+from torch import nn, optim
+from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPVisionModelWithProjection
+from warmup_scheduler import GradualWarmupScheduler
+
+import sys
+import os
+import re
+from dataclasses import dataclass
+
+from gdf import GDF, EpsilonTarget, CosineSchedule
+from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
+from torchtools.transforms import SmartCrop
+
+from modules.effnet import EfficientNetEncoder
+from modules.stage_c import StageC
+from modules.stage_c import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock
+from modules.previewer import Previewer
+from modules.lora import apply_lora, apply_retoken, LoRA, ReToken
+
+from train.base import DataCore, TrainingCore
+
+from core import WarpCore
+from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail
+
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
+from torch.distributed.fsdp.wrap import ModuleWrapPolicy
+from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
+import functools
+from accelerate import init_empty_weights
+from accelerate.utils import set_module_tensor_to_device
+from contextlib import contextmanager
+
+
+class WurstCore(TrainingCore, DataCore, WarpCore):
+ @dataclass(frozen=True)
+ class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config):
+ # TRAINING PARAMS
+ lr: float = EXPECTED_TRAIN
+ warmup_updates: int = EXPECTED_TRAIN
+ dtype: str = None
+
+ # MODEL VERSION
+ model_version: str = EXPECTED # 3.6B or 1B
+ clip_image_model_name: str = 'openai/clip-vit-large-patch14'
+ clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k'
+
+ # CHECKPOINT PATHS
+ effnet_checkpoint_path: str = EXPECTED
+ previewer_checkpoint_path: str = EXPECTED
+ generator_checkpoint_path: str = None
+ lora_checkpoint_path: str = None
+
+ # LoRA STUFF
+ module_filters: list = EXPECTED
+ rank: int = EXPECTED
+ train_tokens: list = EXPECTED
+
+ # gdf customization
+ adaptive_loss_weight: str = None
+
+ @dataclass(frozen=True)
+ class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models):
+ effnet: nn.Module = EXPECTED
+ previewer: nn.Module = EXPECTED
+ lora: nn.Module = EXPECTED
+
+ @dataclass(frozen=True)
+ class Schedulers(WarpCore.Schedulers):
+ lora: any = None
+
+ @dataclass(frozen=True)
+ class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras):
+ gdf: GDF = EXPECTED
+ sampling_configs: dict = EXPECTED
+ effnet_preprocess: torchvision.transforms.Compose = EXPECTED
+
+    @dataclass()  # not frozen, so fields are mutable; doesn't support EXPECTED
+ class Info(TrainingCore.Info):
+ train_tokens: list = None
+
+ @dataclass(frozen=True)
+ class Optimizers(TrainingCore.Optimizers, WarpCore.Optimizers):
+ generator: any = None
+ lora: any = EXPECTED
+
+ # --------------------------------------------
+ info: Info
+ config: Config
+
+ # Extras: gdf, transforms and preprocessors --------------------------------
+ def setup_extras_pre(self) -> Extras:
+ gdf = GDF(
+ schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]),
+ input_scaler=VPScaler(), target=EpsilonTarget(),
+ noise_cond=CosineTNoiseCond(),
+ loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(),
+ )
+ sampling_configs = {"cfg": 5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 20}
+
+ if self.info.adaptive_loss is not None:
+ gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges'])
+ gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses'])
+
+ effnet_preprocess = torchvision.transforms.Compose([
+ torchvision.transforms.Normalize(
+ mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
+ )
+ ])
+
+ clip_preprocess = torchvision.transforms.Compose([
+ torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
+ torchvision.transforms.CenterCrop(224),
+ torchvision.transforms.Normalize(
+ mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)
+ )
+ ])
+
+ if self.config.training:
+ transforms = torchvision.transforms.Compose([
+ torchvision.transforms.ToTensor(),
+ torchvision.transforms.Resize(self.config.image_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True),
+ SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2)
+ ])
+ else:
+ transforms = None
+
+ return self.Extras(
+ gdf=gdf,
+ sampling_configs=sampling_configs,
+ transforms=transforms,
+ effnet_preprocess=effnet_preprocess,
+ clip_preprocess=clip_preprocess
+ )
+
+ # Data --------------------------------
+ def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False,
+ eval_image_embeds=False, return_fields=None):
+ conditions = super().get_conditions(
+ batch, models, extras, is_eval, is_unconditional,
+ eval_image_embeds, return_fields=return_fields or ['clip_text', 'clip_text_pooled', 'clip_img']
+ )
+ return conditions
+
+ # Models, Optimizers & Schedulers setup --------------------------------
+ def setup_models(self, extras: Extras) -> Models:
+ dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.float32
+
+ # EfficientNet encoder
+ effnet = EfficientNetEncoder().to(self.device)
+ effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path)
+ effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict'])
+ effnet.eval().requires_grad_(False)
+ del effnet_checkpoint
+
+ # Previewer
+ previewer = Previewer().to(self.device)
+ previewer_checkpoint = load_or_fail(self.config.previewer_checkpoint_path)
+ previewer.load_state_dict(previewer_checkpoint if 'state_dict' not in previewer_checkpoint else previewer_checkpoint['state_dict'])
+ previewer.eval().requires_grad_(False)
+ del previewer_checkpoint
+
+ @contextmanager
+ def dummy_context():
+ yield None
+
+ loading_context = dummy_context if self.config.training else init_empty_weights
+
+ with loading_context():
+ # Diffusion models
+ if self.config.model_version == '3.6B':
+ generator = StageC()
+ elif self.config.model_version == '1B':
+ generator = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]])
+ else:
+ raise ValueError(f"Unknown model version {self.config.model_version}")
+
+ if self.config.generator_checkpoint_path is not None:
+ if loading_context is dummy_context:
+ generator.load_state_dict(load_or_fail(self.config.generator_checkpoint_path))
+ else:
+ for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items():
+ set_module_tensor_to_device(generator, param_name, "cpu", value=param)
+ generator = generator.to(dtype).to(self.device)
+ generator = self.load_model(generator, 'generator')
+
+ # if self.config.use_fsdp:
+ # fsdp_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=3000)
+ # generator = FSDP(generator, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device)
+
+ # CLIP encoders
+ tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name)
+ text_model = CLIPTextModelWithProjection.from_pretrained(self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device)
+ image_model = CLIPVisionModelWithProjection.from_pretrained(self.config.clip_image_model_name).requires_grad_(False).to(dtype).to(self.device)
+
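+        # Placeholder tokens written as [token] or <token> are appended to the tokenizer with an
+        # embedding initialized to zeros, or to the mean embedding of the vocabulary entries
+        # matching aggr_regex; plain regexes instead select existing tokens to fine-tune.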
+ # PREPARE LORA
+ update_tokens = []
+ for tkn_regex, aggr_regex in self.config.train_tokens:
+ if (tkn_regex.startswith('[') and tkn_regex.endswith(']')) or (tkn_regex.startswith('<') and tkn_regex.endswith('>')):
+ # Insert new token
+ tokenizer.add_tokens([tkn_regex])
+ # add new zeros embedding
+ new_embedding = torch.zeros_like(text_model.text_model.embeddings.token_embedding.weight.data)[:1]
+ if aggr_regex is not None: # aggregate embeddings to provide an interesting baseline
+ aggr_tokens = [v for k, v in tokenizer.vocab.items() if re.search(aggr_regex, k) is not None]
+ if len(aggr_tokens) > 0:
+ new_embedding = text_model.text_model.embeddings.token_embedding.weight.data[aggr_tokens].mean(dim=0, keepdim=True)
+ elif self.is_main_node:
+ print(f"WARNING: No tokens found for aggregation regex {aggr_regex}. It will be initialized as zeros.")
+ text_model.text_model.embeddings.token_embedding.weight.data = torch.cat([
+ text_model.text_model.embeddings.token_embedding.weight.data, new_embedding
+ ], dim=0)
+ selected_tokens = [len(tokenizer.vocab) - 1]
+ else:
+ selected_tokens = [v for k, v in tokenizer.vocab.items() if re.search(tkn_regex, k) is not None]
+ update_tokens += selected_tokens
+ update_tokens = list(set(update_tokens)) # remove duplicates
+
+ apply_retoken(text_model.text_model.embeddings.token_embedding, update_tokens)
+ apply_lora(generator, filters=self.config.module_filters, rank=self.config.rank)
+ text_model.text_model.to(self.device)
+ generator.to(self.device)
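+        # Collect the trainable pieces: the re-tokenized embedding parametrization plus every
+        # LoRA module that apply_lora injected into the generator.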
+ lora = nn.ModuleDict()
+ lora['embeddings'] = text_model.text_model.embeddings.token_embedding.parametrizations.weight[0]
+ lora['weights'] = nn.ModuleList()
+ for module in generator.modules():
+ if isinstance(module, LoRA) or (hasattr(module, '_fsdp_wrapped_module') and isinstance(module._fsdp_wrapped_module, LoRA)):
+ lora['weights'].append(module)
+
+ self.info.train_tokens = [(i, tokenizer.decode(i)) for i in update_tokens]
+ if self.is_main_node:
+ print("Updating tokens:", self.info.train_tokens)
+ print(f"LoRA training {len(lora['weights'])} layers")
+
+ if self.config.lora_checkpoint_path is not None:
+ lora_checkpoint = load_or_fail(self.config.lora_checkpoint_path)
+ lora.load_state_dict(lora_checkpoint if 'state_dict' not in lora_checkpoint else lora_checkpoint['state_dict'])
+
+ lora = self.load_model(lora, 'lora')
+ lora.to(self.device).train().requires_grad_(True)
+ if self.config.use_fsdp:
+ # fsdp_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=3000)
+ fsdp_auto_wrap_policy = ModuleWrapPolicy([LoRA, ReToken])
+ lora = FSDP(lora, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device)
+
+ return self.Models(
+ effnet=effnet, previewer=previewer,
+ generator=generator, generator_ema=None,
+ lora=lora,
+ tokenizer=tokenizer, text_model=text_model, image_model=image_model
+ )
+
+ def setup_optimizers(self, extras: Extras, models: Models) -> Optimizers:
+ optimizer = optim.AdamW(models.lora.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95))
+ optimizer = self.load_optimizer(optimizer, 'lora_optim',
+ fsdp_model=models.lora if self.config.use_fsdp else None)
+ return self.Optimizers(generator=None, lora=optimizer)
+
+ def setup_schedulers(self, extras: Extras, models: Models, optimizers: Optimizers) -> Schedulers:
+ scheduler = GradualWarmupScheduler(optimizers.lora, multiplier=1, total_epoch=self.config.warmup_updates)
+ scheduler.last_epoch = self.info.total_steps
+ return self.Schedulers(lora=scheduler)
+
+ def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models):
+ batch = next(data.iterator)
+
+ conditions = self.get_conditions(batch, models, extras)
+ with torch.no_grad():
+ latents = self.encode_latents(batch, models, extras)
+ noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1)
+
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+ pred = models.generator(noised, noise_cond, **conditions)
+ loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3])
+ loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps
+
+ if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight):
+ extras.gdf.loss_weight.update_buckets(logSNR, loss)
+
+ return loss, loss_adjusted
+
+ def backward_pass(self, update, loss, loss_adjusted, models: Models, optimizers: TrainingCore.Optimizers, schedulers: Schedulers):
+ if update:
+ loss_adjusted.backward()
+ grad_norm = nn.utils.clip_grad_norm_(models.lora.parameters(), 1.0)
+ optimizers_dict = optimizers.to_dict()
+ for k in optimizers_dict:
+ if optimizers_dict[k] is not None and k != 'training':
+ optimizers_dict[k].step()
+ schedulers_dict = schedulers.to_dict()
+ for k in schedulers_dict:
+ if k != 'training':
+ schedulers_dict[k].step()
+ for k in optimizers_dict:
+ if optimizers_dict[k] is not None and k != 'training':
+ optimizers_dict[k].zero_grad(set_to_none=True)
+ self.info.total_steps += 1
+ else:
+ loss_adjusted.backward()
+ grad_norm = torch.tensor(0.0).to(self.device)
+
+ return grad_norm
+
+ def models_to_save(self):
+ return ['lora']
+
+ def sample(self, models: Models, data: WarpCore.Data, extras: Extras):
+ models.lora.eval()
+ super().sample(models, data, extras)
+        models.lora.train()
+        models.generator.eval()
+
+ def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
+ images = batch['images'].to(self.device)
+ return models.effnet(extras.effnet_preprocess(images))
+
+ def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
+ return models.previewer(latents)
+
+
+if __name__ == '__main__':
+ print("Launching Script")
+ warpcore = WurstCore(
+ config_file_path=sys.argv[1] if len(sys.argv) > 1 else None,
+        device=torch.device(int(os.environ.get("SLURM_LOCALID", 0)))  # fall back to GPU 0 when not launched via SLURM
+ )
+ warpcore.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD
+
+ # RUN TRAINING
+ warpcore()
diff --git a/train/train_personalized.py b/train/train_personalized.py
new file mode 100644
index 0000000000000000000000000000000000000000..5161b7c621a0eb9daf9d0f0566322bbeed646284
--- /dev/null
+++ b/train/train_personalized.py
@@ -0,0 +1,899 @@
+import torch
+import json
+import yaml
+import torchvision
+from torch import nn, optim
+from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPVisionModelWithProjection
+from warmup_scheduler import GradualWarmupScheduler
+import torch.multiprocessing as mp
+import os
+import numpy as np
+import re
+import sys
+sys.path.append(os.path.abspath('./'))
+
+from dataclasses import dataclass
+from torch.distributed import init_process_group, destroy_process_group, barrier
+from gdf import GDF_dual_fixlrt as GDF
+from gdf import EpsilonTarget, CosineSchedule
+from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
+from torchtools.transforms import SmartCrop
+from fractions import Fraction
+from modules.effnet import EfficientNetEncoder
+from modules.model_4stage_lite import StageC, ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock
+from modules.common_ckpt import GlobalResponseNorm
+from modules.previewer import Previewer
+from core.data import Bucketeer
+from train.base import DataCore, TrainingCore
+from tqdm import tqdm
+from core import WarpCore
+from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail
+
+from accelerate import init_empty_weights
+from accelerate.utils import set_module_tensor_to_device
+from contextlib import contextmanager
+from train.dist_core import *
+import glob
+from torch.utils.data import DataLoader, Dataset
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.data.distributed import DistributedSampler
+from PIL import Image
+from core.utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary
+from core.utils import Base
+import torch.nn.functional as F
+import functools
+import math
+import copy
+import random
+from modules.lora import apply_lora, apply_retoken, LoRA, ReToken
+# wandb and matplotlib are referenced by the optional experiment logging / adaptive-loss plot in sample()
+import wandb
+import matplotlib.pyplot as plt
+
+Image.MAX_IMAGE_PIXELS = None
+torch.manual_seed(23)
+random.seed(23)
+np.random.seed(23)
+#7978026
+
+class Null_Model(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ def forward(self, x):
+ pass
+
+
+
+
+def identity(x):
+ if isinstance(x, bytes):
+ x = x.decode('utf-8')
+ return x
+def check_nan_inmodel(model, meta=''):
+ for name, param in model.named_parameters():
+ if torch.isnan(param).any():
+ print(f"nan detected in {name}", meta)
+ return True
+ print('no nan', meta)
+ return False
+class mydist_dataset(Dataset):
+ def __init__(self, rootpath, tmp_prompt, img_processor=None):
+
+ self.img_pathlist = glob.glob(os.path.join(rootpath, '*.jpg'))
+ self.img_pathlist = self.img_pathlist * 100000
+ self.img_processor = img_processor
+ self.length = len( self.img_pathlist)
+ self.caption = tmp_prompt
+
+
+ def __getitem__(self, idx):
+
+ imgpath = self.img_pathlist[idx]
+ txt = self.caption
+
+ try:
+ img = Image.open(imgpath).convert('RGB')
+ w, h = img.size
+ if self.img_processor is not None:
+ img = self.img_processor(img)
+
+        except Exception as e:
+            # Fall back to a random sample if the image is missing or corrupted
+            print('exception loading', imgpath, e)
+            return self.__getitem__(random.randint(0, self.length - 1))
+ return dict(captions=txt, images=img)
+ def __len__(self):
+ return self.length
+class WurstCore(TrainingCore, DataCore, WarpCore):
+ @dataclass(frozen=True)
+ class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config):
+ # TRAINING PARAMS
+ lr: float = EXPECTED_TRAIN
+ warmup_updates: int = EXPECTED_TRAIN
+ dtype: str = None
+
+ # MODEL VERSION
+ model_version: str = EXPECTED # 3.6B or 1B
+ clip_image_model_name: str = 'openai/clip-vit-large-patch14'
+ clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k'
+
+ # CHECKPOINT PATHS
+ effnet_checkpoint_path: str = EXPECTED
+ previewer_checkpoint_path: str = EXPECTED
+ generator_checkpoint_path: str = None
+ ultrapixel_path: str = EXPECTED
+
+ # gdf customization
+ adaptive_loss_weight: str = None
+
+ # LoRA STUFF
+ module_filters: list = EXPECTED
+ rank: int = EXPECTED
+ train_tokens: list = EXPECTED
+ use_ddp: bool=EXPECTED
+ tmp_prompt: str=EXPECTED
+ @dataclass(frozen=True)
+ class Data(Base):
+ dataset: Dataset = EXPECTED
+ dataloader: DataLoader = EXPECTED
+ iterator: any = EXPECTED
+ sampler: DistributedSampler = EXPECTED
+
+ @dataclass(frozen=True)
+ class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models):
+ effnet: nn.Module = EXPECTED
+ previewer: nn.Module = EXPECTED
+ train_norm: nn.Module = EXPECTED
+ train_lora: nn.Module = EXPECTED
+
+ @dataclass(frozen=True)
+ class Schedulers(WarpCore.Schedulers):
+ generator: any = None
+
+ @dataclass(frozen=True)
+ class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras):
+ gdf: GDF = EXPECTED
+ sampling_configs: dict = EXPECTED
+ effnet_preprocess: torchvision.transforms.Compose = EXPECTED
+
+ info: TrainingCore.Info
+ config: Config
+
+ def setup_extras_pre(self) -> Extras:
+ gdf = GDF(
+ schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]),
+ input_scaler=VPScaler(), target=EpsilonTarget(),
+ noise_cond=CosineTNoiseCond(),
+ loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(),
+ )
+ sampling_configs = {"cfg": 5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 20}
+
+ if self.info.adaptive_loss is not None:
+ gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges'])
+ gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses'])
+
+ effnet_preprocess = torchvision.transforms.Compose([
+ torchvision.transforms.Normalize(
+ mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
+ )
+ ])
+
+ clip_preprocess = torchvision.transforms.Compose([
+ torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
+ torchvision.transforms.CenterCrop(224),
+ torchvision.transforms.Normalize(
+ mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)
+ )
+ ])
+
+ if self.config.training:
+ transforms = torchvision.transforms.Compose([
+ torchvision.transforms.ToTensor(),
+ torchvision.transforms.Resize(self.config.image_size[-1], interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True),
+ SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2)
+ ])
+ else:
+ transforms = None
+
+ return self.Extras(
+ gdf=gdf,
+ sampling_configs=sampling_configs,
+ transforms=transforms,
+ effnet_preprocess=effnet_preprocess,
+ clip_preprocess=clip_preprocess
+ )
+
+ def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False,
+ eval_image_embeds=False, return_fields=None):
+ conditions = super().get_conditions(
+ batch, models, extras, is_eval, is_unconditional,
+ eval_image_embeds, return_fields=return_fields or ['clip_text', 'clip_text_pooled', 'clip_img']
+ )
+ return conditions
+
+ def setup_models(self, extras: Extras) -> Models: # configure model
+
+
+ dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.bfloat16
+
+        # EfficientNet encoder
+ effnet = EfficientNetEncoder()
+ effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path)
+ effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict'])
+ effnet.eval().requires_grad_(False).to(self.device)
+ del effnet_checkpoint
+
+ # Previewer
+ previewer = Previewer()
+ previewer_checkpoint = load_or_fail(self.config.previewer_checkpoint_path)
+ previewer.load_state_dict(previewer_checkpoint if 'state_dict' not in previewer_checkpoint else previewer_checkpoint['state_dict'])
+ previewer.eval().requires_grad_(False).to(self.device)
+ del previewer_checkpoint
+
+ @contextmanager
+ def dummy_context():
+ yield None
+
+ loading_context = dummy_context if self.config.training else init_empty_weights
+
+ # Diffusion models
+ with loading_context():
+ generator_ema = None
+ if self.config.model_version == '3.6B':
+ generator = StageC()
+ if self.config.ema_start_iters is not None: # default setting
+ generator_ema = StageC()
+ elif self.config.model_version == '1B':
+                print('Using the 1B light model version:', self.config.model_version)
+ generator = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]])
+
+ if self.config.ema_start_iters is not None and self.config.training:
+ generator_ema = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]])
+ else:
+ raise ValueError(f"Unknown model version {self.config.model_version}")
+
+
+
+ if loading_context is dummy_context:
+ generator.load_state_dict( load_or_fail(self.config.generator_checkpoint_path))
+ else:
+ for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items():
+ set_module_tensor_to_device(generator, param_name, "cpu", value=param)
+
+ generator._init_extra_parameter()
+ generator = generator.to(torch.bfloat16).to(self.device)
+
+ train_norm = nn.ModuleList()
+
+
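+        # Null_Model placeholders (which hold no parameters) are appended for every
+        # GlobalResponseNorm module, presumably to keep the ModuleList indices aligned
+        # with the layout of the UltraPixel checkpoint loaded below.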
+ cnt_norm = 0
+ for mm in generator.modules():
+ if isinstance(mm, GlobalResponseNorm):
+
+ train_norm.append(Null_Model())
+ cnt_norm += 1
+
+
+
+
+ train_norm.append(generator.agg_net)
+ train_norm.append(generator.agg_net_up)
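+        # Load the UltraPixel weights; the checkpoint keys carry a 'module.' prefix
+        # (7 characters, from DDP wrapping) which is stripped before loading.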
+ sdd = torch.load(self.config.ultrapixel_path, map_location='cpu')
+ collect_sd = {}
+ for k, v in sdd.items():
+ collect_sd[k[7:]] = v
+ train_norm.load_state_dict(collect_sd)
+
+
+
+ # CLIP encoders
+ tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name)
+ text_model = CLIPTextModelWithProjection.from_pretrained( self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device)
+ image_model = CLIPVisionModelWithProjection.from_pretrained(self.config.clip_image_model_name).requires_grad_(False).to(dtype).to(self.device)
+
+ # PREPARE LORA
+ train_lora = nn.ModuleList()
+ update_tokens = []
+ for tkn_regex, aggr_regex in self.config.train_tokens:
+ if (tkn_regex.startswith('[') and tkn_regex.endswith(']')) or (tkn_regex.startswith('<') and tkn_regex.endswith('>')):
+ # Insert new token
+ tokenizer.add_tokens([tkn_regex])
+ # add new zeros embedding
+ new_embedding = torch.zeros_like(text_model.text_model.embeddings.token_embedding.weight.data)[:1]
+ if aggr_regex is not None: # aggregate embeddings to provide an interesting baseline
+ aggr_tokens = [v for k, v in tokenizer.vocab.items() if re.search(aggr_regex, k) is not None]
+ if len(aggr_tokens) > 0:
+ new_embedding = text_model.text_model.embeddings.token_embedding.weight.data[aggr_tokens].mean(dim=0, keepdim=True)
+ elif self.is_main_node:
+ print(f"WARNING: No tokens found for aggregation regex {aggr_regex}. It will be initialized as zeros.")
+ text_model.text_model.embeddings.token_embedding.weight.data = torch.cat([
+ text_model.text_model.embeddings.token_embedding.weight.data, new_embedding
+ ], dim=0)
+ selected_tokens = [len(tokenizer.vocab) - 1]
+ else:
+ selected_tokens = [v for k, v in tokenizer.vocab.items() if re.search(tkn_regex, k) is not None]
+ update_tokens += selected_tokens
+ update_tokens = list(set(update_tokens)) # remove duplicates
+
+ apply_retoken(text_model.text_model.embeddings.token_embedding, update_tokens)
+
+ apply_lora(generator, filters=self.config.module_filters, rank=self.config.rank)
+ for module in generator.modules():
+ if isinstance(module, LoRA) or (hasattr(module, '_fsdp_wrapped_module') and isinstance(module._fsdp_wrapped_module, LoRA)):
+ train_lora.append(module)
+
+
+ train_lora.append(text_model.text_model.embeddings.token_embedding.parametrizations.weight[0])
+
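+        # Resume LoRA weights if a previous run left a checkpoint behind (written with
+        # torch.save despite the .safetensors extension); the 'module.' prefix is stripped.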
+ if os.path.exists(os.path.join(self.config.output_path, self.config.experiment_id, 'train_lora.safetensors')):
+ sdd = torch.load(os.path.join(self.config.output_path, self.config.experiment_id, 'train_lora.safetensors'), map_location='cpu')
+ collect_sd = {}
+ for k, v in sdd.items():
+ collect_sd[k[7:]] = v
+ train_lora.load_state_dict(collect_sd, strict=True)
+
+
+ train_norm.to(self.device).train().requires_grad_(True)
+
+ if generator_ema is not None:
+
+ generator_ema.load_state_dict(load_or_fail(self.config.generator_checkpoint_path))
+ generator_ema._init_extra_parameter()
+ pretrained_pth = os.path.join(self.config.output_path, self.config.experiment_id, 'generator.safetensors')
+ if os.path.exists(pretrained_pth):
+ generator_ema.load_state_dict(torch.load(pretrained_pth, map_location='cpu'))
+
+ generator_ema.eval().requires_grad_(False)
+
+ check_nan_inmodel(generator, 'generator')
+
+
+
+ if self.config.use_ddp and self.config.training:
+
+ train_lora = DDP(train_lora, device_ids=[self.device], find_unused_parameters=True)
+
+
+
+ return self.Models(
+ effnet=effnet, previewer=previewer, train_norm = train_norm,
+ generator=generator, generator_ema=generator_ema,
+ tokenizer=tokenizer, text_model=text_model, image_model=image_model,
+ train_lora=train_lora
+ )
+
+ def setup_optimizers(self, extras: Extras, models: Models) -> TrainingCore.Optimizers:
+
+
+ params = []
+ params += list(models.train_lora.module.parameters())
+ optimizer = optim.AdamW(params, lr=self.config.lr)
+
+ return self.Optimizers(generator=optimizer)
+
+ def ema_update(self, ema_model, source_model, beta):
+ for param_src, param_ema in zip(source_model.parameters(), ema_model.parameters()):
+ param_ema.data.mul_(beta).add_(param_src.data, alpha = 1 - beta)
+
+ def sync_ema(self, ema_model):
+ print('sync ema', torch.distributed.get_world_size())
+ for param in ema_model.parameters():
+ torch.distributed.all_reduce(param.data, op=torch.distributed.ReduceOp.SUM)
+ param.data /= torch.distributed.get_world_size()
+ def setup_optimizers_backup(self, extras: Extras, models: Models) -> TrainingCore.Optimizers:
+
+
+ optimizer = optim.AdamW(
+ models.generator.up_blocks.parameters() ,
+ lr=self.config.lr)
+ optimizer = self.load_optimizer(optimizer, 'generator_optim',
+ fsdp_model=models.generator if self.config.use_fsdp else None)
+ return self.Optimizers(generator=optimizer)
+
+ def setup_schedulers(self, extras: Extras, models: Models, optimizers: TrainingCore.Optimizers) -> Schedulers:
+ scheduler = GradualWarmupScheduler(optimizers.generator, multiplier=1, total_epoch=self.config.warmup_updates)
+ scheduler.last_epoch = self.info.total_steps
+ return self.Schedulers(generator=scheduler)
+
+ def setup_data(self, extras: Extras) -> WarpCore.Data:
+ # SETUP DATASET
+ dataset_path = self.config.webdataset_path
+
+
+ dataset = mydist_dataset(dataset_path, self.config.tmp_prompt, \
+ torchvision.transforms.ToTensor() if self.config.multi_aspect_ratio is not None \
+ else extras.transforms)
+
+ # SETUP DATALOADER
+ real_batch_size = self.config.batch_size // (self.world_size * self.config.grad_accum_steps)
+
+ sampler = DistributedSampler(dataset, rank=self.process_id, num_replicas = self.world_size, shuffle=True)
+ dataloader = DataLoader(
+ dataset, batch_size=real_batch_size, num_workers=4, pin_memory=True,
+ collate_fn=identity if self.config.multi_aspect_ratio is not None else None,
+ sampler = sampler
+ )
+ if self.is_main_node:
+ print(f"Training with batch size {self.config.batch_size} ({real_batch_size}/GPU)")
+
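+        # With multi-aspect-ratio training enabled, wrap the dataloader in a Bucketeer that
+        # groups samples by aspect ratio (resolutions snapped to multiples of 32).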
+ if self.config.multi_aspect_ratio is not None:
+ aspect_ratios = [float(Fraction(f)) for f in self.config.multi_aspect_ratio]
+ dataloader_iterator = Bucketeer(dataloader, density=[ss*ss for ss in self.config.image_size] , factor=32,
+ ratios=aspect_ratios, p_random_ratio=self.config.bucketeer_random_ratio,
+ interpolate_nearest=False) # , use_smartcrop=True)
+ else:
+
+ dataloader_iterator = iter(dataloader)
+
+ return self.Data(dataset=dataset, dataloader=dataloader, iterator=dataloader_iterator, sampler=sampler)
+
+
+
+
+
+ def setup_ddp(self, experiment_id, single_gpu=False, rank=0):
+
+ if not single_gpu:
+ local_rank = rank
+ process_id = rank
+ world_size = get_world_size()
+
+ self.process_id = process_id
+ self.is_main_node = process_id == 0
+ self.device = torch.device(local_rank)
+ self.world_size = world_size
+
+ os.environ['MASTER_ADDR'] = 'localhost'
+ os.environ['MASTER_PORT'] = '14443'
+ torch.cuda.set_device(local_rank)
+ init_process_group(
+ backend="nccl",
+ rank=local_rank,
+ world_size=world_size,
+ # init_method=init_method,
+ )
+ print(f"[GPU {process_id}] READY")
+ else:
+ self.is_main_node = rank == 0
+ self.process_id = rank
+ self.device = torch.device('cuda:0')
+ self.world_size = 1
+ print("Running in single thread, DDP not enabled.")
+ # Training loop --------------------------------
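+    # Pick a low-resolution target whose latent grid is roughly std_size x std_size
+    # (24 latents -> 768 px per side at ratio 1), preserving the input aspect ratio.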
+ def get_target_lr_size(self, ratio, std_size=24):
+ w, h = int(std_size / math.sqrt(ratio)), int(std_size * math.sqrt(ratio))
+ return (h * 32 , w * 32)
+ def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models):
+
+ batch = data
+ ratio = batch['images'].shape[-2] / batch['images'].shape[-1]
+ shape_lr = self.get_target_lr_size(ratio)
+ with torch.no_grad():
+ conditions = self.get_conditions(batch, models, extras)
+
+ latents = self.encode_latents(batch, models, extras)
+ latents_lr = self.encode_latents(batch, models, extras,target_size=shape_lr)
+
+
+
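+        # For the first 5000 iterations (and half of the time afterwards) train directly on the
+        # low-resolution latents; otherwise train on the high-resolution latents, with the
+        # low-resolution branch (diffused at a very low noise level, t=0.05) acting as guidance.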
+        flag_lr = random.random() < 0.5 or self.info.iter < 5000
+
+ if flag_lr:
+ noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents_lr, shift=1, loss_shift=1)
+ else:
+ noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1)
+ if not flag_lr:
+ noised_lr, noise_lr, target_lr, logSNR_lr, noise_cond_lr, loss_weight_lr = \
+ extras.gdf.diffuse(latents_lr, shift=1, loss_shift=1, t=torch.ones(latents.shape[0]).to(latents.device)*0.05, )
+
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+
+
+ if not flag_lr:
+ with torch.no_grad():
+ _, lr_enc_guide, lr_dec_guide = models.generator(noised_lr, noise_cond_lr, reuire_f=True, **conditions)
+
+
+ pred = models.generator(noised, noise_cond, reuire_f=False, lr_guide=(lr_enc_guide, lr_dec_guide) if not flag_lr else None , **conditions)
+ loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3])
+
+ loss_adjusted = (loss * loss_weight ).mean() / self.config.grad_accum_steps
+
+
+ if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight):
+ extras.gdf.loss_weight.update_buckets(logSNR, loss)
+ return loss, loss_adjusted
+
+ def backward_pass(self, update, loss_adjusted, models: Models, optimizers: TrainingCore.Optimizers, schedulers: Schedulers):
+
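+        # Optimizers/schedulers are stepped only on gradient-accumulation boundaries;
+        # otherwise gradients are accumulated and a zero grad norm is reported.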
+ if update:
+
+ torch.distributed.barrier()
+ loss_adjusted.backward()
+
+ grad_norm = nn.utils.clip_grad_norm_(models.train_lora.module.parameters(), 1.0)
+ optimizers_dict = optimizers.to_dict()
+ for k in optimizers_dict:
+ if k != 'training':
+ optimizers_dict[k].step()
+ schedulers_dict = schedulers.to_dict()
+ for k in schedulers_dict:
+ if k != 'training':
+ schedulers_dict[k].step()
+ for k in optimizers_dict:
+ if k != 'training':
+ optimizers_dict[k].zero_grad(set_to_none=True)
+ self.info.total_steps += 1
+ else:
+
+ loss_adjusted.backward()
+ grad_norm = torch.tensor(0.0).to(self.device)
+
+ return grad_norm
+
+ def models_to_save(self):
+ return ['generator', 'generator_ema', 'trans_inr', 'trans_inr_ema']
+
+ def encode_latents(self, batch: dict, models: Models, extras: Extras, target_size=None) -> torch.Tensor:
+
+ images = batch['images'].to(self.device)
+ if target_size is not None:
+ images = F.interpolate(images, target_size)
+
+ return models.effnet(extras.effnet_preprocess(images))
+
+ def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
+ return models.previewer(latents)
+
+ def __init__(self, rank=0, config_file_path=None, config_dict=None, device="cpu", training=True, world_size=1, ):
+
+ self.is_main_node = (rank == 0)
+ self.config: self.Config = self.setup_config(config_file_path, config_dict, training)
+ self.setup_ddp(self.config.experiment_id, single_gpu=world_size <= 1, rank=rank)
+ self.info: self.Info = self.setup_info()
+        print('Initialized experiment', self.config.experiment_id, '| rank', rank, '| single GPU:', world_size <= 1)
+ p = [i for i in range( 2 * 768 // 32)]
+ p = [num / sum(p) for num in p]
+ self.rand_pro = p
+ self.res_list = [o for o in range(800, 2336, 32)]
+
+
+
+ def __call__(self, single_gpu=False):
+
+ if self.config.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+
+ if self.is_main_node:
+ print()
+            print("**STARTING JOB WITH CONFIG:**")
+ print(yaml.dump(self.config.to_dict(), default_flow_style=False))
+ print("------------------------------------")
+ print()
+ print("**INFO:**")
+ print(yaml.dump(vars(self.info), default_flow_style=False))
+ print("------------------------------------")
+ print()
+            print('Main node:', self.is_main_node, '| process id:', self.process_id, '| device:', self.device)
+ # SETUP STUFF
+ extras = self.setup_extras_pre()
+ assert extras is not None, "setup_extras_pre() must return a DTO"
+
+
+
+ data = self.setup_data(extras)
+ assert data is not None, "setup_data() must return a DTO"
+ if self.is_main_node:
+ print("**DATA:**")
+ print(yaml.dump({k:type(v).__name__ for k, v in data.to_dict().items()}, default_flow_style=False))
+ print("------------------------------------")
+ print()
+
+ models = self.setup_models(extras)
+ assert models is not None, "setup_models() must return a DTO"
+ if self.is_main_node:
+ print("**MODELS:**")
+ print(yaml.dump({
+ k:f"{type(v).__name__} - {f'trainable params {sum(p.numel() for p in v.parameters() if p.requires_grad)}' if isinstance(v, nn.Module) else 'Not a nn.Module'}" for k, v in models.to_dict().items()
+ }, default_flow_style=False))
+ print("------------------------------------")
+ print()
+
+
+
+ optimizers = self.setup_optimizers(extras, models)
+ assert optimizers is not None, "setup_optimizers() must return a DTO"
+ if self.is_main_node:
+ print("**OPTIMIZERS:**")
+ print(yaml.dump({k:type(v).__name__ for k, v in optimizers.to_dict().items()}, default_flow_style=False))
+ print("------------------------------------")
+ print()
+
+ schedulers = self.setup_schedulers(extras, models, optimizers)
+ assert schedulers is not None, "setup_schedulers() must return a DTO"
+ if self.is_main_node:
+ print("**SCHEDULERS:**")
+ print(yaml.dump({k:type(v).__name__ for k, v in schedulers.to_dict().items()}, default_flow_style=False))
+ print("------------------------------------")
+ print()
+
+        post_extras = self.setup_extras_post(extras, models, optimizers, schedulers)
+ assert post_extras is not None, "setup_extras_post() must return a DTO"
+ extras = self.Extras.from_dict({ **extras.to_dict(),**post_extras.to_dict() })
+ if self.is_main_node:
+ print("**EXTRAS:**")
+ print(yaml.dump({k:f"{v}" for k, v in extras.to_dict().items()}, default_flow_style=False))
+ print("------------------------------------")
+ print()
+ # -------
+
+ # TRAIN
+ if self.is_main_node:
+ print("**TRAINING STARTING...**")
+ self.train(data, extras, models, optimizers, schedulers)
+
+ if single_gpu is False:
+ barrier()
+ destroy_process_group()
+ if self.is_main_node:
+ print()
+ print("------------------------------------")
+ print()
+ print("**TRAINING COMPLETE**")
+ if self.config.wandb_project is not None:
+ wandb.alert(title=f"Training {self.info.wandb_run_id} finished", text=f"Training {self.info.wandb_run_id} finished")
+
+
+ def train(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models, optimizers: TrainingCore.Optimizers,
+ schedulers: WarpCore.Schedulers):
+ start_iter = self.info.iter + 1
+ max_iters = self.config.updates * self.config.grad_accum_steps
+ if self.is_main_node:
+ print(f"STARTING AT STEP: {start_iter}/{max_iters}")
+
+
+ if self.is_main_node:
+ create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/')
+ if 'generator' in self.models_to_save():
+ models.generator.train()
+
+ iter_cnt = 0
+ epoch_cnt = 0
+ models.train_norm.train()
+ while True:
+ epoch_cnt += 1
+ if self.world_size > 1:
+
+ data.sampler.set_epoch(epoch_cnt)
+ for ggg in range(len(data.dataloader)):
+ iter_cnt += 1
+ # FORWARD PASS
+
+ loss, loss_adjusted = self.forward_pass(next(data.iterator), extras, models)
+
+
+ # # BACKWARD PASS
+
+ grad_norm = self.backward_pass(
+ iter_cnt % self.config.grad_accum_steps == 0 or iter_cnt == max_iters, loss_adjusted,
+ models, optimizers, schedulers
+ )
+
+
+
+ self.info.iter = iter_cnt
+
+
+ self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01
+
+
+                if self.is_main_node and (np.isnan(loss.mean().item()) or np.isnan(grad_norm.item())):
+                    print(f"NaN value encountered in training run {self.info.wandb_run_id}",
+                          f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}")
+
+ if self.is_main_node:
+ logs = {
+ 'loss': self.info.ema_loss,
+ 'backward_loss': loss_adjusted.mean().item(),
+
+ 'ema_loss': self.info.ema_loss,
+ 'raw_ori_loss': loss.mean().item(),
+
+ 'grad_norm': grad_norm.item(),
+ 'lr': optimizers.generator.param_groups[0]['lr'] if optimizers.generator is not None else 0,
+ 'total_steps': self.info.total_steps,
+ }
+
+
+ print(iter_cnt, max_iters, logs, epoch_cnt, )
+
+
+
+
+
+
+ if iter_cnt == 1 or iter_cnt % (self.config.save_every ) == 0 or iter_cnt == max_iters:
+
+ if np.isnan(loss.mean().item()):
+ if self.is_main_node and self.config.wandb_project is not None:
+ print(f"NaN value encountered in training run {self.info.wandb_run_id}", \
+ f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}")
+
+ else:
+ if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight):
+ self.info.adaptive_loss = {
+ 'bucket_ranges': extras.gdf.loss_weight.bucket_ranges.tolist(),
+ 'bucket_losses': extras.gdf.loss_weight.bucket_losses.tolist(),
+ }
+
+
+ if self.is_main_node and iter_cnt % (self.config.save_every * self.config.grad_accum_steps) == 0:
+ print('save model', iter_cnt, iter_cnt % (self.config.save_every * self.config.grad_accum_steps), self.config.save_every, self.config.grad_accum_steps )
+ torch.save(models.train_lora.state_dict(), \
+ f'{self.config.output_path}/{self.config.experiment_id}/train_lora.safetensors')
+
+
+ torch.save(models.train_lora.state_dict(), \
+ f'{self.config.output_path}/{self.config.experiment_id}/train_lora_{iter_cnt}.safetensors')
+
+
+ if iter_cnt == 1 or iter_cnt % (self.config.save_every* self.config.grad_accum_steps) == 0 or iter_cnt == max_iters:
+
+ if self.is_main_node:
+
+ self.sample(models, data, extras)
+
+
+ if self.info.iter >= max_iters:
+ break
+
+ def sample(self, models: Models, data: WarpCore.Data, extras: Extras):
+
+
+ models.generator.eval()
+ models.train_norm.eval()
+ with torch.no_grad():
+ batch = next(data.iterator)
+ ratio = batch['images'].shape[-2] / batch['images'].shape[-1]
+
+ shape_lr = self.get_target_lr_size(ratio)
+ conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False)
+ unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
+
+ latents = self.encode_latents(batch, models, extras)
+ latents_lr = self.encode_latents(batch, models, extras, target_size = shape_lr)
+
+ if self.is_main_node:
+
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+
+ *_, (sampled, _, _, sampled_lr) = extras.gdf.sample(
+ models.generator, conditions,
+ latents.shape, latents_lr.shape,
+ unconditions, device=self.device, **extras.sampling_configs
+ )
+
+
+ sampled_ema = sampled
+ sampled_ema_lr = sampled_lr
+
+
+ if self.is_main_node:
+ print('sampling results hr latent shape ', latents.shape, 'lr latent shape', latents_lr.shape, )
+ noised_images = torch.cat(
+ [self.decode_latents(latents[i:i + 1].float(), batch, models, extras) for i in range(len(latents))], dim=0)
+
+ sampled_images = torch.cat(
+ [self.decode_latents(sampled[i:i + 1].float(), batch, models, extras) for i in range(len(sampled))], dim=0)
+ sampled_images_ema = torch.cat(
+ [self.decode_latents(sampled_ema[i:i + 1].float(), batch, models, extras) for i in range(len(sampled_ema))],
+ dim=0)
+
+ noised_images_lr = torch.cat(
+ [self.decode_latents(latents_lr[i:i + 1].float(), batch, models, extras) for i in range(len(latents_lr))], dim=0)
+
+ sampled_images_lr = torch.cat(
+ [self.decode_latents(sampled_lr[i:i + 1].float(), batch, models, extras) for i in range(len(sampled_lr))], dim=0)
+ sampled_images_ema_lr = torch.cat(
+ [self.decode_latents(sampled_ema_lr[i:i + 1].float(), batch, models, extras) for i in range(len(sampled_ema_lr))],
+ dim=0)
+
+ images = batch['images']
+ if images.size(-1) != noised_images.size(-1) or images.size(-2) != noised_images.size(-2):
+ images = nn.functional.interpolate(images, size=noised_images.shape[-2:], mode='bicubic')
+                # images_lr is needed below regardless of whether the high-res images were resized
+                images_lr = nn.functional.interpolate(images, size=noised_images_lr.shape[-2:], mode='bicubic')
+
+ collage_img = torch.cat([
+ torch.cat([i for i in images.cpu()], dim=-1),
+ torch.cat([i for i in noised_images.cpu()], dim=-1),
+ torch.cat([i for i in sampled_images.cpu()], dim=-1),
+ torch.cat([i for i in sampled_images_ema.cpu()], dim=-1),
+ ], dim=-2)
+
+ collage_img_lr = torch.cat([
+ torch.cat([i for i in images_lr.cpu()], dim=-1),
+ torch.cat([i for i in noised_images_lr.cpu()], dim=-1),
+ torch.cat([i for i in sampled_images_lr.cpu()], dim=-1),
+ torch.cat([i for i in sampled_images_ema_lr.cpu()], dim=-1),
+ ], dim=-2)
+
+ torchvision.utils.save_image(collage_img, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}.jpg')
+ torchvision.utils.save_image(collage_img_lr, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}_lr.jpg')
+
+ captions = batch['captions']
+ if self.config.wandb_project is not None:
+ log_data = [
+ [captions[i]] + [wandb.Image(sampled_images[i])] + [wandb.Image(sampled_images_ema[i])] + [
+ wandb.Image(images[i])] for i in range(len(images))]
+ log_table = wandb.Table(data=log_data, columns=["Captions", "Sampled", "Sampled EMA", "Orig"])
+ wandb.log({"Log": log_table})
+
+ if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight):
+ plt.plot(extras.gdf.loss_weight.bucket_ranges, extras.gdf.loss_weight.bucket_losses[:-1])
+                        plt.xlabel('LogSNR')
+                        plt.ylabel('Raw Loss')
+                        wandb.log({"Loss/LogSNR": plt})
+
+
+ models.generator.train()
+ models.train_norm.train()
+ print('finish sampling')
+
+
+
+ def sample_fortest(self, models: Models, extras: Extras, hr_shape, lr_shape, batch, eval_image_embeds=False):
+
+
+ models.generator.eval()
+ models.trans_inr.eval()
+ with torch.no_grad():
+
+ if self.is_main_node:
+ conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=eval_image_embeds)
+ unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
+
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+
+ *_, (sampled, _, _, sampled_lr) = extras.gdf.sample(
+ models.generator, conditions,
+ hr_shape, lr_shape,
+ unconditions, device=self.device, **extras.sampling_configs
+ )
+
+ if models.generator_ema is not None:
+
+ *_, (sampled_ema, _, _, sampled_ema_lr) = extras.gdf.sample(
+ models.generator_ema, conditions,
+                        hr_shape, lr_shape,
+ unconditions, device=self.device, **extras.sampling_configs
+ )
+
+ else:
+ sampled_ema = sampled
+ sampled_ema_lr = sampled_lr
+
+
+ return sampled, sampled_lr
+def main_worker(rank, cfg):
+ print("Launching Script in main worker")
+ warpcore = WurstCore(
+ config_file_path=cfg, rank=rank, world_size = get_world_size()
+ )
+ # core.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD
+
+ # RUN TRAINING
+ warpcore(get_world_size()==1)
+
+if __name__ == '__main__':
+
+ if get_master_ip() == "127.0.0.1":
+
+ mp.spawn(main_worker, nprocs=get_world_size(), args=(sys.argv[1] if len(sys.argv) > 1 else None, ))
+ else:
+ main_worker(0, sys.argv[1] if len(sys.argv) > 1 else None, )
diff --git a/train/train_t2i.py b/train/train_t2i.py
new file mode 100644
index 0000000000000000000000000000000000000000..456ca4b0dd1fe8e1fc18e3e5c940797439071d1f
--- /dev/null
+++ b/train/train_t2i.py
@@ -0,0 +1,807 @@
+import torch
+import json
+import yaml
+import torchvision
+from torch import nn, optim
+from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPVisionModelWithProjection
+from warmup_scheduler import GradualWarmupScheduler
+import torch.multiprocessing as mp
+import numpy as np
+import os
+import sys
+sys.path.append(os.path.abspath('./'))
+from dataclasses import dataclass
+from torch.distributed import init_process_group, destroy_process_group, barrier
+from gdf import GDF_dual_fixlrt as GDF
+from gdf import EpsilonTarget, CosineSchedule
+from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
+from torchtools.transforms import SmartCrop
+from fractions import Fraction
+from modules.effnet import EfficientNetEncoder
+
+from modules.model_4stage_lite import StageC, ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock
+from modules.previewer import Previewer
+from core.data import Bucketeer
+from train.base import DataCore, TrainingCore
+from tqdm import tqdm
+from core import WarpCore
+from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail
+
+from accelerate import init_empty_weights
+from accelerate.utils import set_module_tensor_to_device
+from contextlib import contextmanager
+from train.dist_core import *
+import glob
+from torch.utils.data import DataLoader, Dataset
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.data.distributed import DistributedSampler
+from PIL import Image
+from core.utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary
+from core.utils import Base
+from modules.common_ckpt import LayerNorm2d, GlobalResponseNorm
+import torch.nn.functional as F
+import functools
+import math
+import copy
+import random
+from modules.lora import apply_lora, apply_retoken, LoRA, ReToken
+Image.MAX_IMAGE_PIXELS = None
+torch.manual_seed(23)
+random.seed(23)
+np.random.seed(23)
+#7978026
+
+class Null_Model(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ def forward(self, x):
+ pass
+
+
+
+
+def identity(x):
+ if isinstance(x, bytes):
+ x = x.decode('utf-8')
+ return x
+def check_nan_inmodel(model, meta=''):
+ for name, param in model.named_parameters():
+ if torch.isnan(param).any():
+ print(f"nan detected in {name}", meta)
+ return True
+ print('no nan', meta)
+ return False
+class mydist_dataset(Dataset):
+ def __init__(self, rootpath, img_processor=None):
+
+ self.img_pathlist = glob.glob(os.path.join(rootpath, '*', '*.jpg'))
+ self.img_processor = img_processor
+ self.length = len( self.img_pathlist)
+
+
+
+ def __getitem__(self, idx):
+
+ imgpath = self.img_pathlist[idx]
+ json_file = imgpath.replace('.jpg', '.json')
+
+ with open(json_file, 'r') as file:
+ info = json.load(file)
+ txt = info['caption']
+ if txt is None:
+ txt = ' '
+ try:
+ img = Image.open(imgpath).convert('RGB')
+ w, h = img.size
+ if self.img_processor is not None:
+ img = self.img_processor(img)
+        except Exception as e:
+            # Fall back to a random sample if the image is missing or corrupted
+            print('exception loading', imgpath, e)
+            return self.__getitem__(random.randint(0, self.length - 1))
+ return dict(captions=txt, images=img)
+ def __len__(self):
+ return self.length
+
+class WurstCore(TrainingCore, DataCore, WarpCore):
+ @dataclass(frozen=True)
+ class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config):
+ # TRAINING PARAMS
+ lr: float = EXPECTED_TRAIN
+ warmup_updates: int = EXPECTED_TRAIN
+ dtype: str = None
+
+ # MODEL VERSION
+ model_version: str = EXPECTED # 3.6B or 1B
+ clip_image_model_name: str = 'openai/clip-vit-large-patch14'
+ clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k'
+
+ # CHECKPOINT PATHS
+ effnet_checkpoint_path: str = EXPECTED
+ previewer_checkpoint_path: str = EXPECTED
+
+ generator_checkpoint_path: str = None
+
+ # gdf customization
+ adaptive_loss_weight: str = None
+ use_ddp: bool=EXPECTED
+
+
+ @dataclass(frozen=True)
+ class Data(Base):
+ dataset: Dataset = EXPECTED
+ dataloader: DataLoader = EXPECTED
+ iterator: any = EXPECTED
+ sampler: DistributedSampler = EXPECTED
+
+ @dataclass(frozen=True)
+ class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models):
+ effnet: nn.Module = EXPECTED
+ previewer: nn.Module = EXPECTED
+ train_norm: nn.Module = EXPECTED
+
+
+ @dataclass(frozen=True)
+ class Schedulers(WarpCore.Schedulers):
+ generator: any = None
+
+ @dataclass(frozen=True)
+ class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras):
+ gdf: GDF = EXPECTED
+ sampling_configs: dict = EXPECTED
+ effnet_preprocess: torchvision.transforms.Compose = EXPECTED
+
+ info: TrainingCore.Info
+ config: Config
+
+ def setup_extras_pre(self) -> Extras:
+ gdf = GDF(
+ schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]),
+ input_scaler=VPScaler(), target=EpsilonTarget(),
+ noise_cond=CosineTNoiseCond(),
+ loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(),
+ )
+ sampling_configs = {"cfg": 5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 20}
+
+ if self.info.adaptive_loss is not None:
+ gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges'])
+ gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses'])
+
+ effnet_preprocess = torchvision.transforms.Compose([
+ torchvision.transforms.Normalize(
+ mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
+ )
+ ])
+
+ clip_preprocess = torchvision.transforms.Compose([
+ torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
+ torchvision.transforms.CenterCrop(224),
+ torchvision.transforms.Normalize(
+ mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)
+ )
+ ])
+
+ if self.config.training:
+ transforms = torchvision.transforms.Compose([
+ torchvision.transforms.ToTensor(),
+ torchvision.transforms.Resize(self.config.image_size[-1], interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True),
+ SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2)
+ ])
+ else:
+ transforms = None
+
+ return self.Extras(
+ gdf=gdf,
+ sampling_configs=sampling_configs,
+ transforms=transforms,
+ effnet_preprocess=effnet_preprocess,
+ clip_preprocess=clip_preprocess
+ )
+
+ def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False,
+ eval_image_embeds=False, return_fields=None):
+ conditions = super().get_conditions(
+ batch, models, extras, is_eval, is_unconditional,
+ eval_image_embeds, return_fields=return_fields or ['clip_text', 'clip_text_pooled', 'clip_img']
+ )
+ return conditions
+
+ def setup_models(self, extras: Extras) -> Models: # configure model
+
+ dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.bfloat16
+
+        # EfficientNet encoder
+ effnet = EfficientNetEncoder()
+ effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path)
+ effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict'])
+ effnet.eval().requires_grad_(False).to(self.device)
+ del effnet_checkpoint
+
+ # Previewer
+ previewer = Previewer()
+ previewer_checkpoint = load_or_fail(self.config.previewer_checkpoint_path)
+ previewer.load_state_dict(previewer_checkpoint if 'state_dict' not in previewer_checkpoint else previewer_checkpoint['state_dict'])
+ previewer.eval().requires_grad_(False).to(self.device)
+ del previewer_checkpoint
+
+ @contextmanager
+ def dummy_context():
+ yield None
+
+ loading_context = dummy_context if self.config.training else init_empty_weights
+
+ # Diffusion models
+ with loading_context():
+ generator_ema = None
+ if self.config.model_version == '3.6B':
+ generator = StageC()
+ if self.config.ema_start_iters is not None: # default setting
+ generator_ema = StageC()
+ elif self.config.model_version == '1B':
+                print('Using the 1B light model version:', self.config.model_version)
+ generator = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]])
+
+ if self.config.ema_start_iters is not None and self.config.training:
+ generator_ema = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]])
+ else:
+ raise ValueError(f"Unknown model version {self.config.model_version}")
+
+
+
+ if loading_context is dummy_context:
+ generator.load_state_dict( load_or_fail(self.config.generator_checkpoint_path))
+ else:
+ for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items():
+ set_module_tensor_to_device(generator, param_name, "cpu", value=param)
+
+ generator._init_extra_parameter()
+ generator = generator.to(torch.bfloat16).to(self.device)
+
+
+ train_norm = nn.ModuleList()
+ cnt_norm = 0
+ for mm in generator.modules():
+ if isinstance(mm, GlobalResponseNorm):
+
+ train_norm.append(Null_Model())
+ cnt_norm += 1
+
+ train_norm.append(generator.agg_net)
+ train_norm.append(generator.agg_net_up)
+        total = sum(param.nelement() for param in train_norm.parameters())
+        print('Trainable parameters (M):', total / 1048576)
+
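+        # Resume the trainable norm/aggregation modules if a previous checkpoint exists
+        # (written with torch.save); the DDP 'module.' prefix is stripped from its keys.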
+ if os.path.exists(os.path.join(self.config.output_path, self.config.experiment_id, 'train_norm.safetensors')):
+ sdd = torch.load(os.path.join(self.config.output_path, self.config.experiment_id, 'train_norm.safetensors'), map_location='cpu')
+ collect_sd = {}
+ for k, v in sdd.items():
+ collect_sd[k[7:]] = v
+ train_norm.load_state_dict(collect_sd, strict=True)
+
+
+ train_norm.to(self.device).train().requires_grad_(True)
+ train_norm_ema = copy.deepcopy(train_norm)
+ train_norm_ema.to(self.device).eval().requires_grad_(False)
+ if generator_ema is not None:
+
+ generator_ema.load_state_dict(load_or_fail(self.config.generator_checkpoint_path))
+ generator_ema._init_extra_parameter()
+
+
+ pretrained_pth = os.path.join(self.config.output_path, self.config.experiment_id, 'generator.safetensors')
+ if os.path.exists(pretrained_pth):
+ print(pretrained_pth, 'exists')
+ generator_ema.load_state_dict(torch.load(pretrained_pth, map_location='cpu'))
+
+
+ generator_ema.eval().requires_grad_(False)
+
+
+
+
+ check_nan_inmodel(generator, 'generator')
+
+
+
+ if self.config.use_ddp and self.config.training:
+
+ train_norm = DDP(train_norm, device_ids=[self.device], find_unused_parameters=True)
+
+ # CLIP encoders
+ tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name)
+ text_model = CLIPTextModelWithProjection.from_pretrained( self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device)
+ image_model = CLIPVisionModelWithProjection.from_pretrained(self.config.clip_image_model_name).requires_grad_(False).to(dtype).to(self.device)
+
+ return self.Models(
+ effnet=effnet, previewer=previewer, train_norm = train_norm,
+ generator=generator, tokenizer=tokenizer, text_model=text_model, image_model=image_model,
+ )
+
+ def setup_optimizers(self, extras: Extras, models: Models) -> TrainingCore.Optimizers:
+
+
+ params = []
+ params += list(models.train_norm.module.parameters())
+
+ optimizer = optim.AdamW(params, lr=self.config.lr)
+
+ return self.Optimizers(generator=optimizer)
+
+ def ema_update(self, ema_model, source_model, beta):
+ for param_src, param_ema in zip(source_model.parameters(), ema_model.parameters()):
+ param_ema.data.mul_(beta).add_(param_src.data, alpha = 1 - beta)
+
+ def sync_ema(self, ema_model):
+ for param in ema_model.parameters():
+ torch.distributed.all_reduce(param.data, op=torch.distributed.ReduceOp.SUM)
+ param.data /= torch.distributed.get_world_size()
+ def setup_optimizers_backup(self, extras: Extras, models: Models) -> TrainingCore.Optimizers:
+
+
+ optimizer = optim.AdamW(
+ models.generator.up_blocks.parameters() ,
+ lr=self.config.lr)
+ optimizer = self.load_optimizer(optimizer, 'generator_optim',
+ fsdp_model=models.generator if self.config.use_fsdp else None)
+ return self.Optimizers(generator=optimizer)
+
+ def setup_schedulers(self, extras: Extras, models: Models, optimizers: TrainingCore.Optimizers) -> Schedulers:
+ scheduler = GradualWarmupScheduler(optimizers.generator, multiplier=1, total_epoch=self.config.warmup_updates)
+ scheduler.last_epoch = self.info.total_steps
+ return self.Schedulers(generator=scheduler)
+
+ def setup_data(self, extras: Extras) -> WarpCore.Data:
+ # SETUP DATASET
+ dataset_path = self.config.webdataset_path
+ dataset = mydist_dataset(dataset_path, \
+ torchvision.transforms.ToTensor() if self.config.multi_aspect_ratio is not None \
+ else extras.transforms)
+
+ # SETUP DATALOADER
+ real_batch_size = self.config.batch_size // (self.world_size * self.config.grad_accum_steps)
+
+ sampler = DistributedSampler(dataset, rank=self.process_id, num_replicas = self.world_size, shuffle=True)
+ dataloader = DataLoader(
+ dataset, batch_size=real_batch_size, num_workers=8, pin_memory=True,
+ collate_fn=identity if self.config.multi_aspect_ratio is not None else None,
+ sampler = sampler
+ )
+ if self.is_main_node:
+ print(f"Training with batch size {self.config.batch_size} ({real_batch_size}/GPU)")
+
+ if self.config.multi_aspect_ratio is not None:
+ aspect_ratios = [float(Fraction(f)) for f in self.config.multi_aspect_ratio]
+ dataloader_iterator = Bucketeer(dataloader, density=[ss*ss for ss in self.config.image_size] , factor=32,
+ ratios=aspect_ratios, p_random_ratio=self.config.bucketeer_random_ratio,
+ interpolate_nearest=False) # , use_smartcrop=True)
+ else:
+
+ dataloader_iterator = iter(dataloader)
+
+ return self.Data(dataset=dataset, dataloader=dataloader, iterator=dataloader_iterator, sampler=sampler)
+
+
+ def models_to_save(self):
+ pass
+ def setup_ddp(self, experiment_id, single_gpu=False, rank=0):
+
+ if not single_gpu:
+ local_rank = rank
+ process_id = rank
+ world_size = get_world_size()
+
+ self.process_id = process_id
+ self.is_main_node = process_id == 0
+ self.device = torch.device(local_rank)
+ self.world_size = world_size
+
+ os.environ['MASTER_ADDR'] = 'localhost'
+ os.environ['MASTER_PORT'] = '41443'
+ torch.cuda.set_device(local_rank)
+ init_process_group(
+ backend="nccl",
+ rank=local_rank,
+ world_size=world_size,
+ )
+ print(f"[GPU {process_id}] READY")
+ else:
+ self.is_main_node = rank == 0
+ self.process_id = rank
+ self.device = torch.device('cuda:0')
+ self.world_size = 1
+ print("Running in single thread, DDP not enabled.")
+ # Training loop --------------------------------
+ def get_target_lr_size(self, ratio, std_size=24):
+ w, h = int(std_size / math.sqrt(ratio)), int(std_size * math.sqrt(ratio))
+ return (h * 32 , w * 32)
+ def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models):
+ #batch = next(data.iterator)
+ batch = data
+ ratio = batch['images'].shape[-2] / batch['images'].shape[-1]
+ shape_lr = self.get_target_lr_size(ratio)
+ #print('in line 485', shape_lr, ratio, batch['images'].shape)
+ with torch.no_grad():
+ conditions = self.get_conditions(batch, models, extras)
+
+ latents = self.encode_latents(batch, models, extras)
+ latents_lr = self.encode_latents(batch, models, extras,target_size=shape_lr)
+
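+        # The high-resolution latents are diffused normally, while the low-resolution latents
+        # are diffused at a near-zero noise level (t=0.05) so they act as an almost clean
+        # guidance signal for the high-resolution branch.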
+ noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1)
+ noised_lr, noise_lr, target_lr, logSNR_lr, noise_cond_lr, loss_weight_lr = extras.gdf.diffuse(latents_lr, shift=1, loss_shift=1, t=torch.ones(latents.shape[0]).to(latents.device)*0.05, )
+
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+ # 768 1536
+ require_cond = True
+
+ with torch.no_grad():
+ _, lr_enc_guide, lr_dec_guide = models.generator(noised_lr, noise_cond_lr, reuire_f=True, **conditions)
+
+
+ pred = models.generator(noised, noise_cond, reuire_f=False, lr_guide=(lr_enc_guide, lr_dec_guide) if require_cond else None , **conditions)
+ loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3])
+
+ loss_adjusted = (loss * loss_weight ).mean() / self.config.grad_accum_steps
+
+
+ if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight):
+ extras.gdf.loss_weight.update_buckets(logSNR, loss)
+
+ return loss, loss_adjusted
+
+ def backward_pass(self, update, loss_adjusted, models: Models, optimizers: TrainingCore.Optimizers, schedulers: Schedulers):
+
+
+ if update:
+
+ torch.distributed.barrier()
+ loss_adjusted.backward()
+
+ grad_norm = nn.utils.clip_grad_norm_(models.train_norm.module.parameters(), 1.0)
+
+ optimizers_dict = optimizers.to_dict()
+ for k in optimizers_dict:
+ if k != 'training':
+ optimizers_dict[k].step()
+ schedulers_dict = schedulers.to_dict()
+ for k in schedulers_dict:
+ if k != 'training':
+ schedulers_dict[k].step()
+ for k in optimizers_dict:
+ if k != 'training':
+ optimizers_dict[k].zero_grad(set_to_none=True)
+ self.info.total_steps += 1
+ else:
+
+ loss_adjusted.backward()
+
+ grad_norm = torch.tensor(0.0).to(self.device)
+
+ return grad_norm
+
+
+ def encode_latents(self, batch: dict, models: Models, extras: Extras, target_size=None) -> torch.Tensor:
+
+ images = batch['images'].to(self.device)
+ if target_size is not None:
+ images = F.interpolate(images, target_size)
+
+ return models.effnet(extras.effnet_preprocess(images))
+
+ def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
+ return models.previewer(latents)
+
+ def __init__(self, rank=0, config_file_path=None, config_dict=None, device="cpu", training=True, world_size=1, ):
+
+ self.is_main_node = (rank == 0)
+ self.config: self.Config = self.setup_config(config_file_path, config_dict, training)
+ self.setup_ddp(self.config.experiment_id, single_gpu=world_size <= 1, rank=rank)
+ self.info: self.Info = self.setup_info()
+
+
+
+ def __call__(self, single_gpu=False):
+
+ if self.config.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+
+ if self.is_main_node:
+ print()
+            print("**STARTING JOB WITH CONFIG:**")
+ print(yaml.dump(self.config.to_dict(), default_flow_style=False))
+ print("------------------------------------")
+ print()
+ print("**INFO:**")
+ print(yaml.dump(vars(self.info), default_flow_style=False))
+ print("------------------------------------")
+ print()
+
+ # SETUP STUFF
+ extras = self.setup_extras_pre()
+ assert extras is not None, "setup_extras_pre() must return a DTO"
+
+
+
+ data = self.setup_data(extras)
+ assert data is not None, "setup_data() must return a DTO"
+ if self.is_main_node:
+ print("**DATA:**")
+ print(yaml.dump({k:type(v).__name__ for k, v in data.to_dict().items()}, default_flow_style=False))
+ print("------------------------------------")
+ print()
+
+ models = self.setup_models(extras)
+ assert models is not None, "setup_models() must return a DTO"
+ if self.is_main_node:
+ print("**MODELS:**")
+ print(yaml.dump({
+ k:f"{type(v).__name__} - {f'trainable params {sum(p.numel() for p in v.parameters() if p.requires_grad)}' if isinstance(v, nn.Module) else 'Not a nn.Module'}" for k, v in models.to_dict().items()
+ }, default_flow_style=False))
+ print("------------------------------------")
+ print()
+
+
+
+ optimizers = self.setup_optimizers(extras, models)
+ assert optimizers is not None, "setup_optimizers() must return a DTO"
+ if self.is_main_node:
+ print("**OPTIMIZERS:**")
+ print(yaml.dump({k:type(v).__name__ for k, v in optimizers.to_dict().items()}, default_flow_style=False))
+ print("------------------------------------")
+ print()
+
+ schedulers = self.setup_schedulers(extras, models, optimizers)
+ assert schedulers is not None, "setup_schedulers() must return a DTO"
+ if self.is_main_node:
+ print("**SCHEDULERS:**")
+ print(yaml.dump({k:type(v).__name__ for k, v in schedulers.to_dict().items()}, default_flow_style=False))
+ print("------------------------------------")
+ print()
+
+        post_extras = self.setup_extras_post(extras, models, optimizers, schedulers)
+ assert post_extras is not None, "setup_extras_post() must return a DTO"
+ extras = self.Extras.from_dict({ **extras.to_dict(),**post_extras.to_dict() })
+ if self.is_main_node:
+ print("**EXTRAS:**")
+ print(yaml.dump({k:f"{v}" for k, v in extras.to_dict().items()}, default_flow_style=False))
+ print("------------------------------------")
+ print()
+ # -------
+
+ # TRAIN
+ if self.is_main_node:
+ print("**TRAINING STARTING...**")
+ self.train(data, extras, models, optimizers, schedulers)
+
+ if single_gpu is False:
+ barrier()
+ destroy_process_group()
+ if self.is_main_node:
+ print()
+ print("------------------------------------")
+ print()
+ print("**TRAINING COMPLETE**")
+
+
+
+ def train(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models, optimizers: TrainingCore.Optimizers,
+ schedulers: WarpCore.Schedulers):
+ start_iter = self.info.iter + 1
+ max_iters = self.config.updates * self.config.grad_accum_steps
+ if self.is_main_node:
+ print(f"STARTING AT STEP: {start_iter}/{max_iters}")
+
+
+ if self.is_main_node:
+ create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/')
+
+ models.generator.train()
+
+ iter_cnt = 0
+ epoch_cnt = 0
+ models.train_norm.train()
+ while True:
+ epoch_cnt += 1
+ if self.world_size > 1:
+
+ data.sampler.set_epoch(epoch_cnt)
+ for ggg in range(len(data.dataloader)):
+ iter_cnt += 1
+ loss, loss_adjusted = self.forward_pass(next(data.iterator), extras, models)
+ grad_norm = self.backward_pass(
+ iter_cnt % self.config.grad_accum_steps == 0 or iter_cnt == max_iters, loss_adjusted,
+ models, optimizers, schedulers
+ )
+
+ self.info.iter = iter_cnt
+
+
+ # UPDATE LOSS METRICS
+ self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01
+
+ #print('in line 666 after ema loss', grad_norm, loss.mean().item(), iter_cnt, self.info.ema_loss)
+                if self.is_main_node and (np.isnan(loss.mean().item()) or np.isnan(grad_norm.item())):
+                    print(f"NaN value encountered in training run {self.info.wandb_run_id}",
+                          f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}")
+
+ if self.is_main_node:
+ logs = {
+ 'loss': self.info.ema_loss,
+ 'backward_loss': loss_adjusted.mean().item(),
+ 'ema_loss': self.info.ema_loss,
+ 'raw_ori_loss': loss.mean().item(),
+ 'grad_norm': grad_norm.item(),
+ 'lr': optimizers.generator.param_groups[0]['lr'] if optimizers.generator is not None else 0,
+ 'total_steps': self.info.total_steps,
+ }
+ if iter_cnt % (self.config.save_every) == 0:
+
+ print(iter_cnt, max_iters, logs, epoch_cnt, )
+
+
+
+ if iter_cnt == 1 or iter_cnt % (self.config.save_every ) == 0 or iter_cnt == max_iters:
+
+ # SAVE AND CHECKPOINT STUFF
+ if np.isnan(loss.mean().item()):
+ if self.is_main_node and self.config.wandb_project is not None:
+ print(f"NaN value encountered in training run {self.info.wandb_run_id}", \
+ f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}")
+
+ else:
+ if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight):
+ self.info.adaptive_loss = {
+ 'bucket_ranges': extras.gdf.loss_weight.bucket_ranges.tolist(),
+ 'bucket_losses': extras.gdf.loss_weight.bucket_losses.tolist(),
+ }
+
+
+
+ if self.is_main_node and iter_cnt % (self.config.save_every * self.config.grad_accum_steps) == 0:
+ print('save model', iter_cnt, iter_cnt % (self.config.save_every * self.config.grad_accum_steps), self.config.save_every, self.config.grad_accum_steps )
+ torch.save(models.train_norm.state_dict(), \
+ f'{self.config.output_path}/{self.config.experiment_id}/train_norm.safetensors')
+
+ torch.save(models.train_norm.state_dict(), \
+ f'{self.config.output_path}/{self.config.experiment_id}/train_norm_{iter_cnt}.safetensors')
+
+
+ if iter_cnt == 1 or iter_cnt % (self.config.save_every* self.config.grad_accum_steps) == 0 or iter_cnt == max_iters:
+
+ if self.is_main_node:
+
+ self.sample(models, data, extras)
+
+
+ if self.info.iter >= max_iters:
+ break
+
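+ # periodic qualitative check: decode the reference latents and the freshly sampled latents at
+ # both resolutions and save them as side-by-side collage images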
+ def sample(self, models: Models, data: WarpCore.Data, extras: Extras):
+
+
+ models.generator.eval()
+ models.train_norm.eval()
+ with torch.no_grad():
+ batch = next(data.iterator)
+ ratio = batch['images'].shape[-2] / batch['images'].shape[-1]
+
+ shape_lr = self.get_target_lr_size(ratio)
+ conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False)
+ unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
+
+ latents = self.encode_latents(batch, models, extras)
+ latents_lr = self.encode_latents(batch, models, extras, target_size = shape_lr)
+
+
+ if self.is_main_node:
+
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+
+ *_, (sampled, _, _, sampled_lr) = extras.gdf.sample(
+ models.generator, conditions,
+ latents.shape, latents_lr.shape,
+ unconditions, device=self.device, **extras.sampling_configs
+ )
+
+
+
+
+ if self.is_main_node:
+ print('sampling results hr latent shape', latents.shape, 'lr latent shape', latents_lr.shape, )
+ noised_images = torch.cat(
+ [self.decode_latents(latents[i:i + 1].float(), batch, models, extras) for i in range(len(latents))], dim=0)
+
+ sampled_images = torch.cat(
+ [self.decode_latents(sampled[i:i + 1].float(), batch, models, extras) for i in range(len(sampled))], dim=0)
+
+
+ noised_images_lr = torch.cat(
+ [self.decode_latents(latents_lr[i:i + 1].float(), batch, models, extras) for i in range(len(latents_lr))], dim=0)
+
+ sampled_images_lr = torch.cat(
+ [self.decode_latents(sampled_lr[i:i + 1].float(), batch, models, extras) for i in range(len(sampled_lr))], dim=0)
+
+ images = batch['images']
+ if images.size(-1) != noised_images.size(-1) or images.size(-2) != noised_images.size(-2):
+ images = nn.functional.interpolate(images, size=noised_images.shape[-2:], mode='bicubic')
+ images_lr = nn.functional.interpolate(images, size=noised_images_lr.shape[-2:], mode='bicubic')
+
+ collage_img = torch.cat([
+ torch.cat([i for i in images.cpu()], dim=-1),
+ torch.cat([i for i in noised_images.cpu()], dim=-1),
+ torch.cat([i for i in sampled_images.cpu()], dim=-1),
+ ], dim=-2)
+
+ collage_img_lr = torch.cat([
+ torch.cat([i for i in images_lr.cpu()], dim=-1),
+ torch.cat([i for i in noised_images_lr.cpu()], dim=-1),
+ torch.cat([i for i in sampled_images_lr.cpu()], dim=-1),
+ ], dim=-2)
+
+ torchvision.utils.save_image(collage_img, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}.jpg')
+ torchvision.utils.save_image(collage_img_lr, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}_lr.jpg')
+
+
+ models.generator.train()
+ models.train_norm.train()
+ print('finished sampling')
+
+
+
+ def sample_fortest(self, models: Models, extras: Extras, hr_shape, lr_shape, batch, eval_image_embeds=False):
+
+
+ models.generator.eval()
+
+ with torch.no_grad():
+
+ if self.is_main_node:
+ conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=eval_image_embeds)
+ unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
+
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+
+ *_, (sampled, _, _, sampled_lr) = extras.gdf.sample(
+ models.generator, conditions,
+ hr_shape, lr_shape,
+ unconditions, device=self.device, **extras.sampling_configs
+ )
+
+ if models.generator_ema is not None:
+
+ *_, (sampled_ema, _, _, sampled_ema_lr) = extras.gdf.sample(
+ models.generator_ema, conditions,
+ hr_shape, lr_shape,
+ unconditions, device=self.device, **extras.sampling_configs
+ )
+
+ else:
+ sampled_ema = sampled
+ sampled_ema_lr = sampled_lr
+
+ return sampled, sampled_lr
+def main_worker(rank, cfg):
+ print("Launching Script in main worker")
+
+ warpcore = WurstCore(
+ config_file_path=cfg, rank=rank, world_size = get_world_size()
+ )
+ # core.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD
+
+ # RUN TRAINING
+ warpcore(get_world_size()==1)
+
+if __name__ == '__main__':
+ print('launch multi process')
+ # os.environ["OMP_NUM_THREADS"] = "1"
+ # os.environ["MKL_NUM_THREADS"] = "1"
+ #dist.init_process_group(backend="nccl")
+ #torch.backends.cudnn.benchmark = True
+#train/train_c_my.py
+ #mp.set_sharing_strategy('file_system')
+
+ if get_master_ip() == "127.0.0.1":
+ # manually launch distributed processes
+ mp.spawn(main_worker, nprocs=get_world_size(), args=(sys.argv[1] if len(sys.argv) > 1 else None, ))
+ else:
+ main_worker(0, sys.argv[1] if len(sys.argv) > 1 else None, )
diff --git a/train/train_ultrapixel_control.py b/train/train_ultrapixel_control.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd67965973a85ed1d72c164dd0e8970f8b5ce277
--- /dev/null
+++ b/train/train_ultrapixel_control.py
@@ -0,0 +1,928 @@
+import torch
+import json
+import yaml
+import torchvision
+from torch import nn, optim
+from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPVisionModelWithProjection
+from warmup_scheduler import GradualWarmupScheduler
+import torch.multiprocessing as mp
+import numpy as np
+import sys
+
+import os
+from dataclasses import dataclass
+from torch.distributed import init_process_group, destroy_process_group, barrier
+from gdf import GDF_dual_fixlrt as GDF
+from gdf import EpsilonTarget, CosineSchedule
+from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
+from torchtools.transforms import SmartCrop
+from fractions import Fraction
+from modules.effnet import EfficientNetEncoder
+
+from modules.model_4stage_lite import StageC
+
+from modules.model_4stage_lite import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock
+from modules.common_ckpt import GlobalResponseNorm
+from modules.previewer import Previewer
+from core.data import Bucketeer
+from train.base import DataCore, TrainingCore
+from tqdm import tqdm
+from core import WarpCore
+from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail
+from torch.distributed.fsdp.wrap import ModuleWrapPolicy, size_based_auto_wrap_policy
+from accelerate import init_empty_weights
+from accelerate.utils import set_module_tensor_to_device
+from contextlib import contextmanager
+from train.dist_core import *
+import glob
+from torch.utils.data import DataLoader, Dataset
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.data.distributed import DistributedSampler
+from PIL import Image
+from core.utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary
+from core.utils import Base
+from modules.common import LayerNorm2d
+import torch.nn.functional as F
+import functools
+import math
+import copy
+import random
+from modules.lora import apply_lora, apply_retoken, LoRA, ReToken
+from modules import ControlNet, ControlNetDeliverer
+from modules import controlnet_filters
+import wandb
+import matplotlib.pyplot as plt
+# assumed module path for the distributed dataset helper used in setup_data()
+from train.dataset import mydist_dataset
+
+Image.MAX_IMAGE_PIXELS = None
+torch.manual_seed(8432)
+random.seed(8432)
+np.random.seed(8432)
+#7978026
+
+class Null_Model(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ def forward(self, x):
+ pass
+
+
+def identity(x):
+ if isinstance(x, bytes):
+ x = x.decode('utf-8')
+ return x
+def check_nan_inmodel(model, meta=''):
+ for name, param in model.named_parameters():
+ if torch.isnan(param).any():
+ print(f"nan detected in {name}", meta)
+ return True
+ print('no nan', meta)
+ return False
+
+
+class WurstCore(TrainingCore, DataCore, WarpCore):
+ @dataclass(frozen=True)
+ class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config):
+ # TRAINING PARAMS
+ lr: float = EXPECTED_TRAIN
+ warmup_updates: int = EXPECTED_TRAIN
+ dtype: str = None
+
+ # MODEL VERSION
+ model_version: str = EXPECTED # 3.6B or 1B
+ clip_image_model_name: str = 'openai/clip-vit-large-patch14'
+ clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k'
+
+ # CHECKPOINT PATHS
+ effnet_checkpoint_path: str = EXPECTED
+ previewer_checkpoint_path: str = EXPECTED
+ #trans_inr_ckpt: str = EXPECTED
+ generator_checkpoint_path: str = None
+ controlnet_checkpoint_path: str = EXPECTED
+
+ # controlnet settings
+ controlnet_blocks: list = EXPECTED
+ controlnet_filter: str = EXPECTED
+ controlnet_filter_params: dict = None
+ controlnet_bottleneck_mode: str = None
+
+
+ # gdf customization
+ adaptive_loss_weight: str = None
+
+ #module_filters: list = EXPECTED
+ #rank: int = EXPECTED
+ @dataclass(frozen=True)
+ class Data(Base):
+ dataset: Dataset = EXPECTED
+ dataloader: DataLoader = EXPECTED
+ iterator: any = EXPECTED
+ sampler: DistributedSampler = EXPECTED
+
+ @dataclass(frozen=True)
+ class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models):
+ effnet: nn.Module = EXPECTED
+ previewer: nn.Module = EXPECTED
+ train_norm: nn.Module = EXPECTED
+ train_norm_ema: nn.Module = EXPECTED
+ controlnet: nn.Module = EXPECTED
+
+ @dataclass(frozen=True)
+ class Schedulers(WarpCore.Schedulers):
+ generator: any = None
+
+ @dataclass(frozen=True)
+ class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras):
+ gdf: GDF = EXPECTED
+ sampling_configs: dict = EXPECTED
+ effnet_preprocess: torchvision.transforms.Compose = EXPECTED
+ controlnet_filter: controlnet_filters.BaseFilter = EXPECTED
+
+ info: TrainingCore.Info
+ config: Config
+
+ def setup_extras_pre(self) -> Extras:
+ gdf = GDF(
+ schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]),
+ input_scaler=VPScaler(), target=EpsilonTarget(),
+ noise_cond=CosineTNoiseCond(),
+ loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(),
+ )
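+ # default evaluation sampling settings: classifier-free guidance scale 5, DDPM sampler, 20 timesteps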
+ sampling_configs = {"cfg": 5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 20}
+
+ if self.info.adaptive_loss is not None:
+ gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges'])
+ gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses'])
+
+ effnet_preprocess = torchvision.transforms.Compose([
+ torchvision.transforms.Normalize(
+ mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
+ )
+ ])
+
+ clip_preprocess = torchvision.transforms.Compose([
+ torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
+ torchvision.transforms.CenterCrop(224),
+ torchvision.transforms.Normalize(
+ mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)
+ )
+ ])
+
+ if self.config.training:
+ transforms = torchvision.transforms.Compose([
+ torchvision.transforms.ToTensor(),
+ torchvision.transforms.Resize(self.config.image_size[-1], interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True),
+ SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2)
+ ])
+ else:
+ transforms = None
+ controlnet_filter = getattr(controlnet_filters, self.config.controlnet_filter)(
+ self.device,
+ **(self.config.controlnet_filter_params if self.config.controlnet_filter_params is not None else {})
+ )
+
+ return self.Extras(
+ gdf=gdf,
+ sampling_configs=sampling_configs,
+ transforms=transforms,
+ effnet_preprocess=effnet_preprocess,
+ clip_preprocess=clip_preprocess,
+ controlnet_filter=controlnet_filter
+ )
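+
+ # run the configured ControlNet filter on the batch images (unless a pre-computed cnet_input is
+ # given) and encode the result with the ControlNet to obtain conditioning features and a preview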
+ def get_cnet(self, batch: dict, models: Models, extras: Extras, cnet_input=None, target_size=None, **kwargs):
+ images = batch['images']
+ if target_size is not None:
+ images = F.interpolate(images, target_size)
+ with torch.no_grad():
+ if cnet_input is None:
+ cnet_input = extras.controlnet_filter(images, **kwargs)
+ if isinstance(cnet_input, tuple):
+ cnet_input, cnet_input_preview = cnet_input
+ else:
+ cnet_input_preview = cnet_input
+ cnet_input, cnet_input_preview = cnet_input.to(self.device), cnet_input_preview.to(self.device)
+ cnet = models.controlnet(cnet_input)
+ return cnet, cnet_input_preview
+
+ def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False,
+ eval_image_embeds=False, return_fields=None):
+ conditions = super().get_conditions(
+ batch, models, extras, is_eval, is_unconditional,
+ eval_image_embeds, return_fields=return_fields or ['clip_text', 'clip_text_pooled', 'clip_img']
+ )
+ return conditions
+
+ def setup_models(self, extras: Extras) -> Models: # configure model
+
+
+ dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.bfloat16
+
+ # EfficientNet encoder
+ effnet = EfficientNetEncoder()
+ effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path)
+ effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict'])
+ effnet.eval().requires_grad_(False).to(self.device)
+ del effnet_checkpoint
+
+ # Previewer
+ previewer = Previewer()
+ previewer_checkpoint = load_or_fail(self.config.previewer_checkpoint_path)
+ previewer.load_state_dict(previewer_checkpoint if 'state_dict' not in previewer_checkpoint else previewer_checkpoint['state_dict'])
+ previewer.eval().requires_grad_(False).to(self.device)
+ del previewer_checkpoint
+
+ @contextmanager
+ def dummy_context():
+ yield None
+
+ loading_context = dummy_context if self.config.training else init_empty_weights
+
+ # Diffusion models
+ with loading_context():
+ generator_ema = None
+ if self.config.model_version == '3.6B':
+ generator = StageC()
+ if self.config.ema_start_iters is not None: # default setting
+ generator_ema = StageC()
+ elif self.config.model_version == '1B':
+
+ generator = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]])
+
+ if self.config.ema_start_iters is not None and self.config.training:
+ generator_ema = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]])
+ else:
+ raise ValueError(f"Unknown model version {self.config.model_version}")
+
+
+
+ if loading_context is dummy_context:
+ generator.load_state_dict( load_or_fail(self.config.generator_checkpoint_path))
+ else:
+ for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items():
+ set_module_tensor_to_device(generator, param_name, "cpu", value=param)
+
+ generator._init_extra_parameter()
+
+
+
+
+ generator = generator.to(torch.bfloat16).to(self.device)
+
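+ # collect the modules that stay trainable: a Null_Model placeholder per GlobalResponseNorm
+ # (presumably to keep checkpoint key indices stable) plus the generator's two aggregation nets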
+ train_norm = nn.ModuleList()
+
+
+ cnt_norm = 0
+ for mm in generator.modules():
+ if isinstance(mm, GlobalResponseNorm):
+
+ train_norm.append(Null_Model())
+ cnt_norm += 1
+
+
+
+
+ train_norm.append(generator.agg_net)
+ train_norm.append(generator.agg_net_up)
+
+
+
+
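+ # resume train_norm if a checkpoint from a previous run exists; keys were saved from the
+ # DDP-wrapped module, so the leading 'module.' prefix (7 characters) is stripped before loading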
+ if os.path.exists(os.path.join(self.config.output_path, self.config.experiment_id, 'train_norm.safetensors')):
+ sdd = torch.load(os.path.join(self.config.output_path, self.config.experiment_id, 'train_norm.safetensors'), map_location='cpu')
+ collect_sd = {}
+ for k, v in sdd.items():
+ collect_sd[k[7:]] = v
+ train_norm.load_state_dict(collect_sd, strict=True)
+
+
+ train_norm.to(self.device).train().requires_grad_(True)
+ train_norm_ema = copy.deepcopy(train_norm)
+ train_norm_ema.to(self.device).eval().requires_grad_(False)
+ if generator_ema is not None:
+
+ generator_ema.load_state_dict(load_or_fail(self.config.generator_checkpoint_path))
+ generator_ema._init_extra_parameter()
+
+ pretrained_pth = os.path.join(self.config.output_path, self.config.experiment_id, 'generator.safetensors')
+ if os.path.exists(pretrained_pth):
+ print(pretrained_pth, 'exists')
+ generator_ema.load_state_dict(torch.load(pretrained_pth, map_location='cpu'))
+
+ generator_ema.eval().requires_grad_(False)
+
+ check_nan_inmodel(generator, 'generator')
+
+
+
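+ # despite the use_fsdp flag name, the trainable modules are wrapped with plain DDP here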
+ if self.config.use_fsdp and self.config.training:
+ train_norm = DDP(train_norm, device_ids=[self.device], find_unused_parameters=True)
+
+
+ # CLIP encoders
+ tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name)
+ text_model = CLIPTextModelWithProjection.from_pretrained(self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device)
+ image_model = CLIPVisionModelWithProjection.from_pretrained(self.config.clip_image_model_name).requires_grad_(False).to(dtype).to(self.device)
+
+ controlnet = ControlNet(
+ c_in=extras.controlnet_filter.num_channels(),
+ proj_blocks=self.config.controlnet_blocks,
+ bottleneck_mode=self.config.controlnet_bottleneck_mode
+ )
+ controlnet = controlnet.to(dtype).to(self.device)
+ controlnet = self.load_model(controlnet, 'controlnet')
+ controlnet.backbone.eval().requires_grad_(True)
+
+
+ return self.Models(
+ effnet=effnet, previewer=previewer, train_norm=train_norm,
+ generator=generator, generator_ema=generator_ema,
+ tokenizer=tokenizer, text_model=text_model, image_model=image_model,
+ train_norm_ema=train_norm_ema, controlnet=controlnet
+ )
+
+ def setup_optimizers(self, extras: Extras, models: Models) -> TrainingCore.Optimizers:
+
+ # only the train_norm modules are optimized; setup_models wraps train_norm in DDP during
+ # training, so its parameters are reached through .module
+ params = list(models.train_norm.module.parameters())
+
+ optimizer = optim.AdamW(params, lr=self.config.lr)
+
+ return self.Optimizers(generator=optimizer)
+
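+ # exponential moving average of the trainable weights: ema <- beta * ema + (1 - beta) * source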
+ def ema_update(self, ema_model, source_model, beta):
+ for param_src, param_ema in zip(source_model.parameters(), ema_model.parameters()):
+ param_ema.data.mul_(beta).add_(param_src.data, alpha = 1 - beta)
+
+ def sync_ema(self, ema_model):
+ print('sync ema', torch.distributed.get_world_size())
+ for param in ema_model.parameters():
+ torch.distributed.all_reduce(param.data, op=torch.distributed.ReduceOp.SUM)
+ param.data /= torch.distributed.get_world_size()
+ def setup_optimizers_backup(self, extras: Extras, models: Models) -> TrainingCore.Optimizers:
+
+
+ optimizer = optim.AdamW(
+ models.generator.up_blocks.parameters() ,
+ lr=self.config.lr)
+ optimizer = self.load_optimizer(optimizer, 'generator_optim',
+ fsdp_model=models.generator if self.config.use_fsdp else None)
+ return self.Optimizers(generator=optimizer)
+
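+ # warm the learning rate up over warmup_updates steps; last_epoch is restored from info so a
+ # resumed run continues the schedule where it left off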
+ def setup_schedulers(self, extras: Extras, models: Models, optimizers: TrainingCore.Optimizers) -> Schedulers:
+ scheduler = GradualWarmupScheduler(optimizers.generator, multiplier=1, total_epoch=self.config.warmup_updates)
+ scheduler.last_epoch = self.info.total_steps
+ return self.Schedulers(generator=scheduler)
+
+ def setup_data(self, extras: Extras) -> WarpCore.Data:
+ # SETUP DATASET
+ dataset_path = self.config.webdataset_path
+ print('dataset path:', dataset_path, type(dataset_path))
+
+ dataset = mydist_dataset(dataset_path, \
+ torchvision.transforms.ToTensor() if self.config.multi_aspect_ratio is not None \
+ else extras.transforms)
+
+ # SETUP DATALOADER
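+ # config.batch_size is the effective global batch; each GPU draws
+ # batch_size / (world_size * grad_accum_steps) samples per dataloader step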
+ real_batch_size = self.config.batch_size // (self.world_size * self.config.grad_accum_steps)
+ print(f'process {self.process_id}: per-GPU batch size {real_batch_size}')
+ sampler = DistributedSampler(dataset, rank=self.process_id, num_replicas = self.world_size, shuffle=True)
+ dataloader = DataLoader(
+ dataset, batch_size=real_batch_size, num_workers=4, pin_memory=True,
+ collate_fn=identity if self.config.multi_aspect_ratio is not None else None,
+ sampler = sampler
+ )
+ if self.is_main_node:
+ print(f"Training with batch size {self.config.batch_size} ({real_batch_size}/GPU)")
+
+ if self.config.multi_aspect_ratio is not None:
+ aspect_ratios = [float(Fraction(f)) for f in self.config.multi_aspect_ratio]
+ dataloader_iterator = Bucketeer(dataloader, density=[ss*ss for ss in self.config.image_size] , factor=32,
+ ratios=aspect_ratios, p_random_ratio=self.config.bucketeer_random_ratio,
+ interpolate_nearest=False) # , use_smartcrop=True)
+ else:
+
+ dataloader_iterator = iter(dataloader)
+
+ return self.Data(dataset=dataset, dataloader=dataloader, iterator=dataloader_iterator, sampler=sampler)
+
+
+
+
+
+ def setup_ddp(self, experiment_id, single_gpu=False, rank=0):
+
+ if not single_gpu:
+ local_rank = rank
+ process_id = rank
+ world_size = get_world_size()
+
+ self.process_id = process_id
+ self.is_main_node = process_id == 0
+ self.device = torch.device(local_rank)
+ self.world_size = world_size
+
+
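+ # single-node assumption: all ranks rendezvous over localhost on a hard-coded port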
+ os.environ['MASTER_ADDR'] = 'localhost'
+ os.environ['MASTER_PORT'] = '41443'
+ torch.cuda.set_device(local_rank)
+ init_process_group(
+ backend="nccl",
+ rank=local_rank,
+ world_size=world_size,
+ # init_method=init_method,
+ )
+ print(f"[GPU {process_id}] READY")
+ else:
+ self.is_main_node = rank == 0
+ self.process_id = rank
+ self.device = torch.device('cuda:0')
+ self.world_size = 1
+ print("Running in single thread, DDP not enabled.")
+ # Training loop --------------------------------
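+ # map the batch aspect ratio (H/W) to a low-resolution target size on a 32-pixel grid,
+ # keeping the LR area roughly (std_size * 32)^2 while preserving the aspect ratio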
+ def get_target_lr_size(self, ratio, std_size=24):
+ w, h = int(std_size / math.sqrt(ratio)), int(std_size * math.sqrt(ratio))
+ return (h * 32 , w * 32)
+ def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models):
+ #batch = next(data.iterator)
+ batch = data
+ ratio = batch['images'].shape[-2] / batch['images'].shape[-1]
+ shape_lr = self.get_target_lr_size(ratio)
+
+ with torch.no_grad():
+ conditions = self.get_conditions(batch, models, extras)
+
+ latents = self.encode_latents(batch, models, extras)
+ latents_lr = self.encode_latents(batch, models, extras,target_size=shape_lr)
+
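+ # diffuse both branches: HR latents get a random timestep, while LR latents are noised at a
+ # small fixed t (0.05) and passed through the generator once (no_grad) to produce
+ # encoder/decoder guidance features for the HR prediction below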
+ noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1)
+ noised_lr, noise_lr, target_lr, logSNR_lr, noise_cond_lr, loss_weight_lr = extras.gdf.diffuse(latents_lr, shift=1, loss_shift=1, t=torch.ones(latents.shape[0]).to(latents.device)*0.05, )
+
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+
+ require_cond = True
+
+ with torch.no_grad():
+ _, lr_enc_guide, lr_dec_guide = models.generator(noised_lr, noise_cond_lr, reuire_f=True, **conditions)
+
+
+ pred = models.generator(noised, noise_cond, reuire_f=False, lr_guide=(lr_enc_guide, lr_dec_guide) if require_cond else None , **conditions)
+ loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3])
+
+ loss_adjusted = (loss * loss_weight ).mean() / self.config.grad_accum_steps
+ #
+ if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight):
+ extras.gdf.loss_weight.update_buckets(logSNR, loss)
+
+ return loss, loss_adjusted
+
+ def backward_pass(self, update, loss_adjusted, models: Models, optimizers: TrainingCore.Optimizers, schedulers: Schedulers):
+
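+ # loss_adjusted is already divided by grad_accum_steps; gradients accumulate across
+ # micro-batches and the optimizer/scheduler only step when `update` is True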
+ if update:
+
+ torch.distributed.barrier()
+ loss_adjusted.backward()
+
+
+ grad_norm = nn.utils.clip_grad_norm_(models.train_norm.module.parameters(), 1.0)
+
+ optimizers_dict = optimizers.to_dict()
+ for k in optimizers_dict:
+ if k != 'training':
+ optimizers_dict[k].step()
+ schedulers_dict = schedulers.to_dict()
+ for k in schedulers_dict:
+ if k != 'training':
+ schedulers_dict[k].step()
+ for k in optimizers_dict:
+ if k != 'training':
+ optimizers_dict[k].zero_grad(set_to_none=True)
+ self.info.total_steps += 1
+ else:
+ #print('in line 457', loss_adjusted)
+ loss_adjusted.backward()
+ #torch.distributed.barrier()
+ grad_norm = torch.tensor(0.0).to(self.device)
+
+ return grad_norm
+
+ def models_to_save(self):
+ return ['generator', 'generator_ema', 'trans_inr', 'trans_inr_ema']
+
+ def encode_latents(self, batch: dict, models: Models, extras: Extras, target_size=None) -> torch.Tensor:
+
+ images = batch['images'].to(self.device)
+ if target_size is not None:
+ images = F.interpolate(images, target_size)
+ #images = apply_degradations(images)
+ return models.effnet(extras.effnet_preprocess(images))
+
+ def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
+ return models.previewer(latents)
+
+ def __init__(self, rank=0, config_file_path=None, config_dict=None, device="cpu", training=True, world_size=1, ):
+ # Temporary setup, will be overriden by setup_ddp if required
+ # self.device = device
+ # self.process_id = 0
+ # self.is_main_node = True
+ # self.world_size = 1
+ # ----
+ # self.world_size = world_size
+ # self.process_id = rank
+ # self.device=device
+ self.is_main_node = (rank == 0)
+ self.config: self.Config = self.setup_config(config_file_path, config_dict, training)
+ self.setup_ddp(self.config.experiment_id, single_gpu=world_size <= 1, rank=rank)
+ self.info: self.Info = self.setup_info()
+ print('experiment:', self.config.experiment_id, 'rank:', rank, 'single_gpu:', world_size <= 1)
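+ # rand_pro holds linearly increasing weights over res_list (multiples of 32 from 800 to 2304),
+ # presumably used to bias random training-resolution sampling toward larger sizes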
+ p = [i for i in range( 2 * 768 // 32)]
+ p = [num / sum(p) for num in p]
+ self.rand_pro = p
+ self.res_list = [o for o in range(800, 2336, 32)]
+
+ #[32, 40, 48]
+ #in line 292 stage_c_3b_finetuning False
+
+ def __call__(self, single_gpu=False):
+ # this will change the device to the CUDA rank
+ #self.setup_wandb()
+ if self.config.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+
+ if self.is_main_node:
+ print()
+ print("**STARTING JOB WITH CONFIG:**")
+ print(yaml.dump(self.config.to_dict(), default_flow_style=False))
+ print("------------------------------------")
+ print()
+ print("**INFO:**")
+ print(yaml.dump(vars(self.info), default_flow_style=False))
+ print("------------------------------------")
+ print()
+ print('main node:', self.is_main_node, 'process:', self.process_id, 'device:', self.device)
+ # SETUP STUFF
+ extras = self.setup_extras_pre()
+ assert extras is not None, "setup_extras_pre() must return a DTO"
+
+
+
+ data = self.setup_data(extras)
+ assert data is not None, "setup_data() must return a DTO"
+ if self.is_main_node:
+ print("**DATA:**")
+ print(yaml.dump({k:type(v).__name__ for k, v in data.to_dict().items()}, default_flow_style=False))
+ print("------------------------------------")
+ print()
+
+ models = self.setup_models(extras)
+ assert models is not None, "setup_models() must return a DTO"
+ if self.is_main_node:
+ print("**MODELS:**")
+ print(yaml.dump({
+ k:f"{type(v).__name__} - {f'trainable params {sum(p.numel() for p in v.parameters() if p.requires_grad)}' if isinstance(v, nn.Module) else 'Not a nn.Module'}" for k, v in models.to_dict().items()
+ }, default_flow_style=False))
+ print("------------------------------------")
+ print()
+
+
+
+ optimizers = self.setup_optimizers(extras, models)
+ assert optimizers is not None, "setup_optimizers() must return a DTO"
+ if self.is_main_node:
+ print("**OPTIMIZERS:**")
+ print(yaml.dump({k:type(v).__name__ for k, v in optimizers.to_dict().items()}, default_flow_style=False))
+ print("------------------------------------")
+ print()
+
+ schedulers = self.setup_schedulers(extras, models, optimizers)
+ assert schedulers is not None, "setup_schedulers() must return a DTO"
+ if self.is_main_node:
+ print("**SCHEDULERS:**")
+ print(yaml.dump({k:type(v).__name__ for k, v in schedulers.to_dict().items()}, default_flow_style=False))
+ print("------------------------------------")
+ print()
+
+ post_extras = self.setup_extras_post(extras, models, optimizers, schedulers)
+ assert post_extras is not None, "setup_extras_post() must return a DTO"
+ extras = self.Extras.from_dict({**extras.to_dict(), **post_extras.to_dict()})
+ if self.is_main_node:
+ print("**EXTRAS:**")
+ print(yaml.dump({k:f"{v}" for k, v in extras.to_dict().items()}, default_flow_style=False))
+ print("------------------------------------")
+ print()
+ # -------
+
+ # TRAIN
+ if self.is_main_node:
+ print("**TRAINING STARTING...**")
+ self.train(data, extras, models, optimizers, schedulers)
+
+ if single_gpu is False:
+ barrier()
+ destroy_process_group()
+ if self.is_main_node:
+ print()
+ print("------------------------------------")
+ print()
+ print("**TRAINING COMPLETE**")
+ if self.config.wandb_project is not None:
+ wandb.alert(title=f"Training {self.info.wandb_run_id} finished", text=f"Training {self.info.wandb_run_id} finished")
+
+
+ def train(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models, optimizers: TrainingCore.Optimizers,
+ schedulers: WarpCore.Schedulers):
+ start_iter = self.info.iter + 1
+ max_iters = self.config.updates * self.config.grad_accum_steps
+ if self.is_main_node:
+ print(f"STARTING AT STEP: {start_iter}/{max_iters}")
+
+
+ if self.is_main_node:
+ create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/')
+ if 'generator' in self.models_to_save():
+ models.generator.train()
+ #initial_params = {name: param.clone() for name, param in models.train_norm.named_parameters()}
+ iter_cnt = 0
+ epoch_cnt = 0
+ models.train_norm.train()
+ while True:
+ epoch_cnt += 1
+ if self.world_size > 1:
+ print('sampler set epoch', epoch_cnt)
+ data.sampler.set_epoch(epoch_cnt)
+ for ggg in range(len(data.dataloader)):
+ iter_cnt += 1
+ # FORWARD PASS
+ #print('in line 414 before forward', iter_cnt, batch['captions'][0], self.process_id)
+ #loss, loss_adjusted, loss_extra = self.forward_pass(batch, extras, models)
+ loss, loss_adjusted = self.forward_pass(next(data.iterator), extras, models)
+
+ #print('in line 416', loss, iter_cnt)
+ # # BACKWARD PASS
+
+ grad_norm = self.backward_pass(
+ iter_cnt % self.config.grad_accum_steps == 0 or iter_cnt == max_iters, loss_adjusted,
+ models, optimizers, schedulers
+ )
+
+
+
+ self.info.iter = iter_cnt
+
+ # UPDATE EMA
+ if iter_cnt % self.config.ema_iters == 0:
+
+ with torch.no_grad():
+ print('ema update at iter', iter_cnt, 'every', self.config.ema_iters, 'iters')
+ self.ema_update(models.train_norm_ema, models.train_norm, self.config.ema_beta)
+ #generator.module.agg_net.
+ #self.ema_update(models.generator_ema.agg_net, models.generator.module.agg_net, self.config.ema_beta)
+ #self.ema_update(models.generator_ema.agg_net_up, models.generator.module.agg_net_up, self.config.ema_beta)
+
+ # UPDATE LOSS METRICS
+ self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01
+
+ #print('in line 666 after ema loss', grad_norm, loss.mean().item(), iter_cnt, self.info.ema_loss)
+ if self.is_main_node and (np.isnan(loss.mean().item()) or np.isnan(grad_norm.item())):
+ print(f"NaN value encountered in training run {self.info.wandb_run_id}",
+ f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}")
+
+ if self.is_main_node:
+ logs = {
+ 'loss': self.info.ema_loss,
+ 'backward_loss': loss_adjusted.mean().item(),
+ #'raw_extra_loss': loss_extra.mean().item(),
+ 'ema_loss': self.info.ema_loss,
+ 'raw_ori_loss': loss.mean().item(),
+ #'raw_rec_loss': loss_rec.mean().item(),
+ #'raw_lr_loss': loss_lr.mean().item(),
+ #'reg_loss':loss_reg.item(),
+ 'grad_norm': grad_norm.item(),
+ 'lr': optimizers.generator.param_groups[0]['lr'] if optimizers.generator is not None else 0,
+ 'total_steps': self.info.total_steps,
+ }
+ if iter_cnt % (self.config.save_every) == 0:
+
+ print(iter_cnt, max_iters, logs, epoch_cnt, )
+ #pbar.set_postfix(logs)
+
+
+ #if iter_cnt % 10 == 0:
+
+
+ if iter_cnt == 1 or iter_cnt % (self.config.save_every ) == 0 or iter_cnt == max_iters:
+ #if True:
+ # SAVE AND CHECKPOINT STUFF
+ if np.isnan(loss.mean().item()):
+ if self.is_main_node and self.config.wandb_project is not None:
+ print(f"NaN value encountered in training run {self.info.wandb_run_id}", \
+ f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}")
+
+ else:
+ if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight):
+ self.info.adaptive_loss = {
+ 'bucket_ranges': extras.gdf.loss_weight.bucket_ranges.tolist(),
+ 'bucket_losses': extras.gdf.loss_weight.bucket_losses.tolist(),
+ }
+ #self.save_checkpoints(models, optimizers)
+
+ #torch.save(models.trans_inr.module.state_dict(), \
+ #f'{self.config.output_path}/{self.config.experiment_id}/trans_inr.safetensors')
+ #torch.save(models.trans_inr_ema.state_dict(), \
+ #f'{self.config.output_path}/{self.config.experiment_id}/trans_inr_ema.safetensors')
+
+
+ if self.is_main_node and iter_cnt % (self.config.save_every * self.config.grad_accum_steps) == 0:
+ print('save model', iter_cnt, iter_cnt % (self.config.save_every * self.config.grad_accum_steps), self.config.save_every, self.config.grad_accum_steps )
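+ # NOTE: torch.save writes ordinary pickle checkpoints here; the .safetensors extension is
+ # just this codebase's naming convention (setup_models reloads the file with torch.load)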
+ torch.save(models.train_norm.state_dict(), \
+ f'{self.config.output_path}/{self.config.experiment_id}/train_norm.safetensors')
+
+ #self.sync_ema(models.train_norm_ema)
+ torch.save(models.train_norm_ema.state_dict(), \
+ f'{self.config.output_path}/{self.config.experiment_id}/train_norm_ema.safetensors')
+ #if self.is_main_node and iter_cnt % (4 * self.config.save_every * self.config.grad_accum_steps) == 0:
+ torch.save(models.train_norm.state_dict(), \
+ f'{self.config.output_path}/{self.config.experiment_id}/train_norm_{iter_cnt}.safetensors')
+
+
+ if iter_cnt == 1 or iter_cnt % (self.config.save_every* self.config.grad_accum_steps) == 0 or iter_cnt == max_iters:
+
+ if self.is_main_node:
+ #check_nan_inmodel(models.generator, 'generator')
+ #check_nan_inmodel(models.generator_ema, 'generator_ema')
+ self.sample(models, data, extras)
+ if False:
+ param_changes = {name: (param - initial_params[name]).norm().item() for name, param in models.train_norm.named_parameters()}
+ threshold = sorted(param_changes.values(), reverse=True)[int(len(param_changes) * 0.1)] # top 10%
+ important_params = [name for name, change in param_changes.items() if change > threshold]
+ print(important_params, threshold, len(param_changes), self.process_id)
+ json.dump(important_params, open(f'{self.config.output_path}/{self.config.experiment_id}/param.json', 'w'), indent=4)
+
+
+ if self.info.iter >= max_iters:
+ break
+
+ def sample(self, models: Models, data: WarpCore.Data, extras: Extras):
+
+ #if 'generator' in self.models_to_save():
+ models.generator.eval()
+ models.train_norm.eval()
+ with torch.no_grad():
+ batch = next(data.iterator)
+ ratio = batch['images'].shape[-2] / batch['images'].shape[-1]
+ #batch['images'] = batch['images'].to(torch.float16)
+ shape_lr = self.get_target_lr_size(ratio)
+ conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False)
+ unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
+ cnet, cnet_input = self.get_cnet(batch, models, extras)
+ conditions, unconditions = {**conditions, 'cnet': cnet}, {**unconditions, 'cnet': cnet}
+
+ latents = self.encode_latents(batch, models, extras)
+ latents_lr = self.encode_latents(batch, models, extras, target_size = shape_lr)
+
+ if self.is_main_node:
+
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+ #print('in line 366 on v100 switch to tf16')
+ *_, (sampled, _, _, sampled_lr) = extras.gdf.sample(
+ models.generator, conditions,
+ latents.shape, latents_lr.shape,
+ unconditions, device=self.device, **extras.sampling_configs
+ )
+
+
+
+ #else:
+ sampled_ema = sampled
+ sampled_ema_lr = sampled_lr
+
+
+ if self.is_main_node:
+ print('sampling results', latents.shape, latents_lr.shape, )
+ noised_images = torch.cat(
+ [self.decode_latents(latents[i:i + 1].float(), batch, models, extras) for i in range(len(latents))], dim=0)
+
+ sampled_images = torch.cat(
+ [self.decode_latents(sampled[i:i + 1].float(), batch, models, extras) for i in range(len(sampled))], dim=0)
+ sampled_images_ema = torch.cat(
+ [self.decode_latents(sampled_ema[i:i + 1].float(), batch, models, extras) for i in range(len(sampled_ema))],
+ dim=0)
+
+ noised_images_lr = torch.cat(
+ [self.decode_latents(latents_lr[i:i + 1].float(), batch, models, extras) for i in range(len(latents_lr))], dim=0)
+
+ sampled_images_lr = torch.cat(
+ [self.decode_latents(sampled_lr[i:i + 1].float(), batch, models, extras) for i in range(len(sampled_lr))], dim=0)
+ sampled_images_ema_lr = torch.cat(
+ [self.decode_latents(sampled_ema_lr[i:i + 1].float(), batch, models, extras) for i in range(len(sampled_ema_lr))],
+ dim=0)
+
+ images = batch['images']
+ if images.size(-1) != noised_images.size(-1) or images.size(-2) != noised_images.size(-2):
+ images = nn.functional.interpolate(images, size=noised_images.shape[-2:], mode='bicubic')
+ images_lr = nn.functional.interpolate(images, size=noised_images_lr.shape[-2:], mode='bicubic')
+
+ collage_img = torch.cat([
+ torch.cat([i for i in images.cpu()], dim=-1),
+ torch.cat([i for i in noised_images.cpu()], dim=-1),
+ torch.cat([i for i in sampled_images.cpu()], dim=-1),
+ torch.cat([i for i in sampled_images_ema.cpu()], dim=-1),
+ ], dim=-2)
+
+ collage_img_lr = torch.cat([
+ torch.cat([i for i in images_lr.cpu()], dim=-1),
+ torch.cat([i for i in noised_images_lr.cpu()], dim=-1),
+ torch.cat([i for i in sampled_images_lr.cpu()], dim=-1),
+ torch.cat([i for i in sampled_images_ema_lr.cpu()], dim=-1),
+ ], dim=-2)
+
+ torchvision.utils.save_image(collage_img, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}.jpg')
+ torchvision.utils.save_image(collage_img_lr, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}_lr.jpg')
+ #torchvision.utils.save_image(collage_img, f'{self.config.experiment_id}_latest_output.jpg')
+
+ captions = batch['captions']
+ if self.config.wandb_project is not None:
+ log_data = [
+ [captions[i]] + [wandb.Image(sampled_images[i])] + [wandb.Image(sampled_images_ema[i])] + [
+ wandb.Image(images[i])] for i in range(len(images))]
+ log_table = wandb.Table(data=log_data, columns=["Captions", "Sampled", "Sampled EMA", "Orig"])
+ wandb.log({"Log": log_table})
+
+ if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight):
+ plt.plot(extras.gdf.loss_weight.bucket_ranges, extras.gdf.loss_weight.bucket_losses[:-1])
+ plt.ylabel('Raw Loss')
+ plt.xlabel('LogSNR')
+ wandb.log({"Loss/LogSNR": plt})
+
+ #if 'generator' in self.models_to_save():
+ models.generator.train()
+ models.train_norm.train()
+ print('finished sampling')
+
+
+
+ def sample_fortest(self, models: Models, extras: Extras, hr_shape, lr_shape, batch, eval_image_embeds=False):
+
+ #if 'generator' in self.models_to_save():
+ models.generator.eval()
+ models.train_norm.eval()
+ models.controlnet.eval()
+ with torch.no_grad():
+
+ if self.is_main_node:
+ conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=eval_image_embeds)
+ unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
+ cnet, cnet_input = self.get_cnet(batch, models, extras, target_size = lr_shape)
+ conditions, unconditions = {**conditions, 'cnet': cnet}, {**unconditions, 'cnet': cnet}
+
+ #print('in line 885', self.is_main_node)
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+ #print('in line 366 on v100 switch to tf16')
+ *_, (sampled, _, _, sampled_lr) = extras.gdf.sample(
+ models.generator, conditions,
+ hr_shape, lr_shape,
+ unconditions, device=self.device, **extras.sampling_configs
+ )
+
+ if models.generator_ema is not None:
+
+ *_, (sampled_ema, _, _, sampled_ema_lr) = extras.gdf.sample(
+ models.generator_ema, conditions,
+ hr_shape, lr_shape,
+ unconditions, device=self.device, **extras.sampling_configs
+ )
+
+ else:
+ sampled_ema = sampled
+ sampled_ema_lr = sampled_lr
+ #x0, x, epsilon, x0_lr, x_lr, pred_lr)
+ #sampled, _ = models.trans_inr(sampled, None, sampled)
+ #sampled_lr, _ = models.trans_inr(sampled, None, sampled_lr)
+
+ return sampled, sampled_lr
+def main_worker(rank, cfg):
+ print("Launching Script in main worker")
+ print('worker rank:', rank)
+ warpcore = WurstCore(
+ config_file_path=cfg, rank=rank, world_size = get_world_size()
+ )
+ # core.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD
+
+ # RUN TRAINING
+ warpcore(get_world_size()==1)
+
+if __name__ == '__main__':
+ print('launch multi process')
+ # os.environ["OMP_NUM_THREADS"] = "1"
+ # os.environ["MKL_NUM_THREADS"] = "1"
+ #dist.init_process_group(backend="nccl")
+ #torch.backends.cudnn.benchmark = True
+#train/train_c_my.py
+ #mp.set_sharing_strategy('file_system')
+ print('config path:', sys.argv[1] if len(sys.argv) > 1 else None)
+ print('master ip:', get_master_ip(), 'world size:', get_world_size())
+ if get_master_ip() == "127.0.0.1":
+ # manually launch distributed processes
+ mp.spawn(main_worker, nprocs=get_world_size(), args=(sys.argv[1] if len(sys.argv) > 1 else None, ))
+ else:
+ main_worker(0, sys.argv[1] if len(sys.argv) > 1 else None, )