diff --git a/app.py b/app.py index bf08cc43b90980c3b7c915b94c841af31e36963e..3c4ac8cf1a8d4a201305effc5566332bf0acf82d 100644 --- a/app.py +++ b/app.py @@ -1,213 +1,158 @@ import spaces -import json -import subprocess import os -import sys - -def run_command(command): - process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True) - output, error = process.communicate() - if process.returncode != 0: - print(f"Error executing command: {command}") - print(error.decode('utf-8')) - exit(1) - return output.decode('utf-8') - -# Download CUDA installer -download_command = "wget https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run" -result = run_command(download_command) -if result is None: - print("Failed to download CUDA installer.") - exit(1) - -# Run CUDA installer in silent mode -install_command = "sh cuda_12.2.0_535.54.03_linux.run --silent --toolkit --samples --override" -result = run_command(install_command) -if result is None: - print("Failed to run CUDA installer.") - exit(1) - -print("CUDA installation process completed.") - -def install_packages(): - - # Clone the repository with submodules - run_command("git clone --recurse-submodules https://github.com/abetlen/llama-cpp-python.git") - - # Change to the cloned directory - os.chdir("llama-cpp-python") - - # Checkout the specific commit in the llama.cpp submodule - os.chdir("vendor/llama.cpp") - run_command("git checkout 50e0535") - os.chdir("../..") - - # Upgrade pip - run_command("pip install --upgrade pip") - - - - # Install all optional dependencies with CUDA support - run_command('CMAKE_ARGS="-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS -DCUDA_PATH=/usr/local/cuda-12.2 -DCUDAToolkit_ROOT=/usr/local/cuda-12.2 -DCUDAToolkit_INCLUDE_DIR=/usr/local/cuda-12/include -DCUDAToolkit_LIBRARY_DIR=/usr/local/cuda-12.2/lib64" FORCE_CMAKE=1 pip install -e .') - - run_command("make clean && GGML_OPENBLAS=1 make -j") - - # Reinstall the package with CUDA support - run_command('CMAKE_ARGS="-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS -DCUDA_PATH=/usr/local/cuda-12.2 -DCUDAToolkit_ROOT=/usr/local/cuda-12.2 -DCUDAToolkit_INCLUDE_DIR=/usr/local/cuda-12/include -DCUDAToolkit_LIBRARY_DIR=/usr/local/cuda-12.2/lib64" FORCE_CMAKE=1 pip install -e .') - - # Install llama-cpp-agent - run_command("pip install llama-cpp-agent") - - run_command("export PYTHONPATH=$PYTHONPATH:$(pwd)") - - print("Installation complete!") - -try: - install_packages() - - # Add a delay to allow for package registration - import time - time.sleep(5) - - # Force Python to reload the site packages - import site - import importlib - importlib.reload(site) - - # Now try to import the libraries - from llama_cpp import Llama - from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType - from llama_cpp_agent.providers import LlamaCppPythonProvider - from llama_cpp_agent.chat_history import BasicChatHistory - from llama_cpp_agent.chat_history.messages import Roles - - print("Libraries imported successfully!") -except Exception as e: - print(f"Installation failed or libraries couldn't be imported: {str(e)}") - sys.exit(1) - +import requests +import yaml +import torch import gradio as gr +from PIL import Image +import sys +sys.path.append(os.path.abspath('./')) +from inference.utils import * +from core.utils import load_or_fail +from train import WurstCoreB +from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight +from train import WurstCore_t2i as WurstCoreC +import 
torch.nn.functional as F +from core.utils import load_or_fail +import numpy as np +import random +import math +from einops import rearrange from huggingface_hub import hf_hub_download -hf_hub_download( - repo_id="MaziyarPanahi/Mistral-Nemo-Instruct-2407-GGUF", - filename="Mistral-Nemo-Instruct-2407.Q5_K_M.gguf", - local_dir="./models" -) -# Initialize LLM outside the respond function -llm = Llama( - model_path="models/Mistral-Nemo-Instruct-2407.Q5_K_M.gguf", - flash_attn=True, - n_gpu_layers=81, - n_batch=1024, - n_ctx=32768, +def download_file(url, folder_path, filename): + if not os.path.exists(folder_path): + os.makedirs(folder_path) + file_path = os.path.join(folder_path, filename) + + if os.path.isfile(file_path): + print(f"File already exists: {file_path}") + else: + response = requests.get(url, stream=True) + if response.status_code == 200: + with open(file_path, 'wb') as file: + for chunk in response.iter_content(chunk_size=1024): + file.write(chunk) + print(f"File successfully downloaded and saved: {file_path}") + else: + print(f"Error downloading the file. Status code: {response.status_code}") + +def download_models(): + models = { + "STABLEWURST_A": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_a.safetensors?download=true", "models/StableWurst", "stage_a.safetensors"), + "STABLEWURST_PREVIEWER": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/previewer.safetensors?download=true", "models/StableWurst", "previewer.safetensors"), + "STABLEWURST_EFFNET": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/effnet_encoder.safetensors?download=true", "models/StableWurst", "effnet_encoder.safetensors"), + "STABLEWURST_B_LITE": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b_lite_bf16.safetensors?download=true", "models/StableWurst", "stage_b_lite_bf16.safetensors"), + "STABLEWURST_C": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c_bf16.safetensors?download=true", "models/StableWurst", "stage_c_bf16.safetensors"), + "ULTRAPIXEL_T2I": ("https://huggingface.co/roubaofeipi/UltraPixel/resolve/main/ultrapixel_t2i.safetensors?download=true", "models/UltraPixel", "ultrapixel_t2i.safetensors"), + "ULTRAPIXEL_LORA_CAT": ("https://huggingface.co/roubaofeipi/UltraPixel/resolve/main/lora_cat.safetensors?download=true", "models/UltraPixel", "lora_cat.safetensors"), + } + + for model, (url, folder, filename) in models.items(): + download_file(url, folder, filename) + +download_models() + +# Global variables +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +dtype = torch.bfloat16 + +# Load configs and setup models +with open("configs/training/t2i.yaml", "r", encoding="utf-8") as file: + config_c = yaml.safe_load(file) + +with open("configs/inference/stage_b_1b.yaml", "r", encoding="utf-8") as file: + config_b = yaml.safe_load(file) + +core = WurstCoreC(config_dict=config_c, device=device, training=False) +core_b = WurstCoreB(config_dict=config_b, device=device, training=False) + +extras = core.setup_extras_pre() +models = core.setup_models(extras) +models.generator.eval().requires_grad_(False) + +extras_b = core_b.setup_extras_pre() +models_b = core_b.setup_models(extras_b, skip_clip=True) +models_b = WurstCoreB.Models( + **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model} ) - -provider = LlamaCppPythonProvider(llm) - -@spaces.GPU(duration=120) -def respond( - message, - history: list[tuple[str, str]], - system_message, - max_tokens, - temperature, - 
top_p, - top_k, - repeat_penalty, -): - chat_template = MessagesFormatterType.MISTRAL - - agent = LlamaCppAgent( - provider, - system_prompt=f"{system_message}", - predefined_messages_formatter_type=chat_template, - debug_output=True - ) +models_b.generator.bfloat16().eval().requires_grad_(False) + +# Load pretrained model +pretrained_path = "models/ultrapixel_t2i.safetensors" +sdd = torch.load(pretrained_path, map_location='cpu') +collect_sd = {k[7:]: v for k, v in sdd.items()} +models.train_norm.load_state_dict(collect_sd) +models.generator.eval() +models.train_norm.eval() + +# Set up sampling configurations +extras.sampling_configs.update({ + 'cfg': 4, + 'shift': 1, + 'timesteps': 20, + 't_start': 1.0, + 'sampler': DDPMSampler(extras.gdf) +}) + +extras_b.sampling_configs.update({ + 'cfg': 1.1, + 'shift': 1, + 'timesteps': 10, + 't_start': 1.0 +}) + +@spaces.GPU +def generate_image(prompt, height, width, seed): + torch.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + + batch_size = 1 + height_lr, width_lr = get_target_lr_size(height / width, std_size=32) + stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size) + stage_c_latent_shape_lr, stage_b_latent_shape_lr = calculate_latent_sizes(height_lr, width_lr, batch_size=batch_size) + + batch = {'captions': [prompt] * batch_size} + conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False) + unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) - settings = provider.get_provider_default_settings() - settings.temperature = temperature - settings.top_k = top_k - settings.top_p = top_p - settings.max_tokens = max_tokens - settings.repeat_penalty = repeat_penalty - settings.stream = True - - messages = BasicChatHistory() - - for msn in history: - user = { - 'role': Roles.user, - 'content': msn[0] - } - assistant = { - 'role': Roles.assistant, - 'content': msn[1] - } - messages.add_message(user) - messages.add_message(assistant) + conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False) + unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True) - stream = agent.get_chat_response( - message, - llm_sampling_settings=settings, - chat_history=messages, - returns_streaming_generator=True, - print_output=False - ) - - outputs = "" - for output in stream: - outputs += output - yield outputs - -description = """

-[Instruct Model]
-[Base Model]
-[GGUF Version]
-

-""" - -demo = gr.ChatInterface( - respond, - additional_inputs=[ - gr.Textbox(value="You are a helpful assistant.", label="System message"), - gr.Slider(minimum=1, maximum=4096, value=2048, step=1, label="Max tokens"), - gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"), - gr.Slider( - minimum=0.1, - maximum=1.0, - value=0.95, - step=0.05, - label="Top-p", - ), - gr.Slider( - minimum=0, - maximum=100, - value=40, - step=1, - label="Top-k", - ), - gr.Slider( - minimum=0.0, - maximum=2.0, - value=1.1, - step=0.1, - label="Repetition penalty", - ), + with torch.no_grad(): + models.generator.cuda() + with torch.cuda.amp.autocast(dtype=dtype): + sampled_c = generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device) + + models.generator.cpu() + torch.cuda.empty_cache() + + conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False) + unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True) + conditions_b['effnet'] = sampled_c + unconditions_b['effnet'] = torch.zeros_like(sampled_c) + + with torch.cuda.amp.autocast(dtype=dtype): + sampled = decode_b(conditions_b, unconditions_b, models_b, stage_b_latent_shape, extras_b, device, stage_a_tiled=True) + + torch.cuda.empty_cache() + imgs = show_images(sampled) + return imgs[0] + +iface = gr.Interface( + fn=generate_image, + inputs=[ + gr.Textbox(label="Prompt"), + gr.Slider(minimum=256, maximum=2560, step=32, label="Height", value=1024), + gr.Slider(minimum=256, maximum=5120, step=32, label="Width", value=1024), + gr.Number(label="Seed", value=42) ], - retry_btn="Retry", - undo_btn="Undo", - clear_btn="Clear", - submit_btn="Send", - title="Chat with Mistral-NeMo using llama.cpp", - description=description, - chatbot=gr.Chatbot( - scale=1, - likeable=False, - show_copy_button=True - ) + outputs=gr.Image(type="pil"), + title="UltraPixel Image Generation", + description="Generate high-resolution images using UltraPixel model.", + theme='bethecloud/storj_theme' ) -if __name__ == "__main__": - demo.launch(debug=True) \ No newline at end of file +iface.launch() \ No newline at end of file diff --git a/configs/inference/controlnet_c_3b_canny.yaml b/configs/inference/controlnet_c_3b_canny.yaml new file mode 100644 index 0000000000000000000000000000000000000000..286d7a6c8017e922a020d6ae5633cc3e27f9b702 --- /dev/null +++ b/configs/inference/controlnet_c_3b_canny.yaml @@ -0,0 +1,14 @@ +# GLOBAL STUFF +model_version: 3.6B +dtype: bfloat16 + +# ControlNet specific +controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63] +controlnet_filter: CannyFilter +controlnet_filter_params: + resize: 224 + +effnet_checkpoint_path: models/effnet_encoder.safetensors +previewer_checkpoint_path: models/previewer.safetensors +generator_checkpoint_path: models/stage_c_bf16.safetensors +controlnet_checkpoint_path: models/canny.safetensors diff --git a/configs/inference/controlnet_c_3b_identity.yaml b/configs/inference/controlnet_c_3b_identity.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8a20fa860fed5f6eea1d33113535c2633205e327 --- /dev/null +++ b/configs/inference/controlnet_c_3b_identity.yaml @@ -0,0 +1,17 @@ +# GLOBAL STUFF +model_version: 3.6B +dtype: bfloat16 + +# ControlNet specific +controlnet_bottleneck_mode: 'simple' +controlnet_blocks: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 
44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] +controlnet_filter: IdentityFilter +controlnet_filter_params: + max_faces: 4 + p_drop: 0.00 + p_full: 0.0 + +effnet_checkpoint_path: models/effnet_encoder.safetensors +previewer_checkpoint_path: models/previewer.safetensors +generator_checkpoint_path: models/stage_c_bf16.safetensors +controlnet_checkpoint_path: diff --git a/configs/inference/controlnet_c_3b_inpainting.yaml b/configs/inference/controlnet_c_3b_inpainting.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a94bd7953dfa407184d9094b481a56cdbbb73549 --- /dev/null +++ b/configs/inference/controlnet_c_3b_inpainting.yaml @@ -0,0 +1,15 @@ +# GLOBAL STUFF +model_version: 3.6B +dtype: bfloat16 + +# ControlNet specific +controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63] +controlnet_filter: InpaintFilter +controlnet_filter_params: + thresold: [0.04, 0.4] + p_outpaint: 0.4 + +effnet_checkpoint_path: models/effnet_encoder.safetensors +previewer_checkpoint_path: models/previewer.safetensors +generator_checkpoint_path: models/stage_c_bf16.safetensors +controlnet_checkpoint_path: models/inpainting.safetensors diff --git a/configs/inference/controlnet_c_3b_sr.yaml b/configs/inference/controlnet_c_3b_sr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..13c4a2cd2dcd2a3cf87fb32bd6e34269e796a747 --- /dev/null +++ b/configs/inference/controlnet_c_3b_sr.yaml @@ -0,0 +1,15 @@ +# GLOBAL STUFF +model_version: 3.6B +dtype: bfloat16 + +# ControlNet specific +controlnet_bottleneck_mode: 'large' +controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63] +controlnet_filter: SREffnetFilter +controlnet_filter_params: + scale_factor: 0.5 + +effnet_checkpoint_path: models/effnet_encoder.safetensors +previewer_checkpoint_path: models/previewer.safetensors +generator_checkpoint_path: models/stage_c_bf16.safetensors +controlnet_checkpoint_path: models/super_resolution.safetensors diff --git a/configs/inference/lora_c_3b.yaml b/configs/inference/lora_c_3b.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7468078c657c1f569c6c052a14b265d69082ab25 --- /dev/null +++ b/configs/inference/lora_c_3b.yaml @@ -0,0 +1,15 @@ +# GLOBAL STUFF +model_version: 3.6B +dtype: bfloat16 + +# LoRA specific +module_filters: ['.attn'] +rank: 4 +train_tokens: + # - ['^snail', null] # token starts with "snail" -> "snail" & "snails", don't need to be reinitialized + - ['[fernando]', '^dog'] # custom token [snail], initialize as avg of snail & snails + +effnet_checkpoint_path: models/effnet_encoder.safetensors +previewer_checkpoint_path: models/previewer.safetensors +generator_checkpoint_path: models/stage_c_bf16.safetensors +lora_checkpoint_path: models/lora_fernando_10k.safetensors diff --git a/configs/inference/stage_b_1b.yaml b/configs/inference/stage_b_1b.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0811cae75622614e91de6532262acb2c062bf344 --- /dev/null +++ b/configs/inference/stage_b_1b.yaml @@ -0,0 +1,13 @@ +# GLOBAL STUFF +model_version: 700M +dtype: bfloat16 + +# For demonstration purposes in reconstruct_images.ipynb +webdataset_path: path to your dataset +batch_size: 1 +image_size: 2048 +grad_accum_steps: 1 + +effnet_checkpoint_path: models/effnet_encoder.safetensors +stage_a_checkpoint_path: models/stage_a.safetensors +generator_checkpoint_path: models/stage_b_lite_bf16.safetensors \ No newline at end of file diff --git a/configs/inference/stage_b_3b.yaml b/configs/inference/stage_b_3b.yaml new file mode 100644 
index 0000000000000000000000000000000000000000..840268980103e0c629599b966705043d6a616578 --- /dev/null +++ b/configs/inference/stage_b_3b.yaml @@ -0,0 +1,13 @@ +# GLOBAL STUFF +model_version: 3B +dtype: bfloat16 + +# For demonstration purposes in reconstruct_images.ipynb +webdataset_path: path to your dataset +batch_size: 4 +image_size: 1024 +grad_accum_steps: 1 + +effnet_checkpoint_path: models/effnet_encoder.safetensors +stage_a_checkpoint_path: models/stage_a.safetensors +generator_checkpoint_path: models/stage_b_lite_bf16.safetensors \ No newline at end of file diff --git a/configs/inference/stage_c_1b.yaml b/configs/inference/stage_c_1b.yaml new file mode 100644 index 0000000000000000000000000000000000000000..781886e515d80e7870abb89bf8fd0ce7c7c8d4b6 --- /dev/null +++ b/configs/inference/stage_c_1b.yaml @@ -0,0 +1,7 @@ +# GLOBAL STUFF +model_version: 1B +dtype: bfloat16 + +effnet_checkpoint_path: models/effnet_encoder.safetensors +previewer_checkpoint_path: models/previewer.safetensors +generator_checkpoint_path: models/stage_c_lite_bf16.safetensors \ No newline at end of file diff --git a/configs/inference/stage_c_3b.yaml b/configs/inference/stage_c_3b.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b22897e71996ad78f3832af78f5bc44ca06d206d --- /dev/null +++ b/configs/inference/stage_c_3b.yaml @@ -0,0 +1,7 @@ +# GLOBAL STUFF +model_version: 3.6B +dtype: bfloat16 + +effnet_checkpoint_path: models/effnet_encoder.safetensors +previewer_checkpoint_path: models/previewer.safetensors +generator_checkpoint_path: models/stage_c_bf16.safetensors \ No newline at end of file diff --git a/configs/training/cfg_control_lr.yaml b/configs/training/cfg_control_lr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2955b6a925504525b981e7004b65a33573c08aef --- /dev/null +++ b/configs/training/cfg_control_lr.yaml @@ -0,0 +1,47 @@ +# GLOBAL STUFF +experiment_id: Ultrapixel_controlnet + +checkpoint_path: checkpoint output path +output_path: visual results output path +model_version: 3.6B +dtype: float32 +# # WandB +# wandb_project: StableCascade +# wandb_entity: wandb_username +#module_filters: ['.depthwise', '.mapper', '.attn', '.channelwise' ] +#rank: 32 +# TRAINING PARAMS +lr: 1.0e-4 +batch_size: 12 +#image_size: [1536, 2048, 2560, 3072, 4096] +image_size: [1024, 2048, 2560, 3072, 3584, 3840, 4096, 4608] +#image_size: [ 1024, 1536, 2048, 2560, 3072, 3584, 3840, 4096, 4608] +#image_size: [ 1024, 1280] +multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] +grad_accum_steps: 2 +updates: 40000 +backup_every: 5000 +save_every: 256 +warmup_updates: 1 +use_fsdp: True + +# ControlNet specific +controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63] +controlnet_filter: CannyFilter +controlnet_filter_params: + resize: 224 +# offset_noise: 0.1 + +# GDF +adaptive_loss_weight: True + +ema_start_iters: 10 +ema_iters: 50 +ema_beta: 0.9 + +webdataset_path: path to your training dataset +effnet_checkpoint_path: models/effnet_encoder.safetensors +previewer_checkpoint_path: models/previewer.safetensors +generator_checkpoint_path: models/stage_c_bf16.safetensors +controlnet_checkpoint_path: pretrained controlnet path + diff --git a/configs/training/lora_personalization.yaml b/configs/training/lora_personalization.yaml new file mode 100644 index 0000000000000000000000000000000000000000..857795e6d37e9cb61bd76aa588f432978ed90ad2 --- /dev/null +++ b/configs/training/lora_personalization.yaml @@ -0,0 +1,37 @@ +# GLOBAL STUFF +experiment_id: 
roubao_cat_personalized + +checkpoint_path: checkpoint output path +output_path: visual results output path +model_version: 3.6B +dtype: float32 + +module_filters: [ '.attn'] +rank: 4 +train_tokens: + # - ['^snail', null] # token starts with "snail" -> "snail" & "snails", don't need to be reinitialized + - ['[roubaobao]', '^cat'] # custom token [snail], initialize as avg of snail & snails +# TRAINING PARAMS +lr: 1.0e-4 +batch_size: 4 + +image_size: [1024, 2048, 2560, 3072, 3584, 3840, 4096, 4608] +multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] +grad_accum_steps: 2 +updates: 40000 +backup_every: 5000 +save_every: 512 +warmup_updates: 1 +use_ddp: True + +# GDF +adaptive_loss_weight: True + + +tmp_prompt: a photo of a cat [roubaobao] +webdataset_path: path to your personalized training dataset +effnet_checkpoint_path: models/effnet_encoder.safetensors +previewer_checkpoint_path: models/previewer.safetensors +generator_checkpoint_path: models/stage_c_bf16.safetensors +ultrapixel_path: models/ultrapixel_t2i.safetensors + diff --git a/configs/training/t2i.yaml b/configs/training/t2i.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8a0ceaca0ad8813e3c9b998661ac3e9b3c0937fd --- /dev/null +++ b/configs/training/t2i.yaml @@ -0,0 +1,29 @@ +# GLOBAL STUFF +experiment_id: ultrapixel_t2i +#strc_fixlrt_norm3_lite_1024_hrft_newdata +checkpoint_path: checkpoint output path #output model directory +output_path: visual results output path #experiment output directory +model_version: 3.6B # finetune large stage c model of stablecascade +dtype: float32 + + +# TRAINING PARAMS +lr: 1.0e-4 +batch_size: 4 # gpu_number * num_per_gpu * grad_accum_steps +image_size: [1024, 2048, 2560, 3072, 3584, 3840, 4096, 4608] # possible image resolution +multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] +grad_accum_steps: 2 +updates: 40000 +backup_every: 5000 +save_every: 256 +warmup_updates: 1 +use_ddp: True + +# GDF +adaptive_loss_weight: True + + +webdataset_path: path to your personalized training dataset +effnet_checkpoint_path: models/effnet_encoder.safetensors +previewer_checkpoint_path: models/previewer.safetensors +generator_checkpoint_path: models/stage_c_bf16.safetensors \ No newline at end of file diff --git a/core/__init__.py b/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ed382f1907ddc86c7e9a9618c21441755a6221a9 --- /dev/null +++ b/core/__init__.py @@ -0,0 +1,372 @@ +import os +import yaml +import torch +from torch import nn +import wandb +import json +from abc import ABC, abstractmethod +from dataclasses import dataclass +from torch.utils.data import Dataset, DataLoader + +from torch.distributed import init_process_group, destroy_process_group, barrier +from torch.distributed.fsdp import ( + FullyShardedDataParallel as FSDP, + FullStateDictConfig, + MixedPrecision, + ShardingStrategy, + StateDictType +) + +from .utils import Base, EXPECTED, EXPECTED_TRAIN +from .utils import create_folder_if_necessary, safe_save, load_or_fail + +# pylint: disable=unused-argument +class WarpCore(ABC): + @dataclass(frozen=True) + class Config(Base): + experiment_id: str = EXPECTED_TRAIN + checkpoint_path: str = EXPECTED_TRAIN + output_path: str = EXPECTED_TRAIN + checkpoint_extension: str = "safetensors" + dist_file_subfolder: str = "" + allow_tf32: bool = True + + wandb_project: str = None + wandb_entity: str = None + + @dataclass() # not frozen, means that fields are mutable + class Info(): # not 
inheriting from Base, because we don't want to enforce the default fields + wandb_run_id: str = None + total_steps: int = 0 + iter: int = 0 + + @dataclass(frozen=True) + class Data(Base): + dataset: Dataset = EXPECTED + dataloader: DataLoader = EXPECTED + iterator: any = EXPECTED + + @dataclass(frozen=True) + class Models(Base): + pass + + @dataclass(frozen=True) + class Optimizers(Base): + pass + + @dataclass(frozen=True) + class Schedulers(Base): + pass + + @dataclass(frozen=True) + class Extras(Base): + pass + # --------------------------------------- + info: Info + config: Config + + # FSDP stuff + fsdp_defaults = { + "sharding_strategy": ShardingStrategy.SHARD_GRAD_OP, + "cpu_offload": None, + "mixed_precision": MixedPrecision( + param_dtype=torch.bfloat16, + reduce_dtype=torch.bfloat16, + buffer_dtype=torch.bfloat16, + ), + "limit_all_gathers": True, + } + fsdp_fullstate_save_policy = FullStateDictConfig( + offload_to_cpu=True, rank0_only=True + ) + # ------------ + + # OVERRIDEABLE METHODS + + # [optionally] setup extra stuff, will be called BEFORE the models & optimizers are setup + def setup_extras_pre(self) -> Extras: + return self.Extras() + + # setup dataset & dataloader, return a dict contained dataser, dataloader and/or iterator + @abstractmethod + def setup_data(self, extras: Extras) -> Data: + raise NotImplementedError("This method needs to be overriden") + + # return a dict with all models that are going to be used in the training + @abstractmethod + def setup_models(self, extras: Extras) -> Models: + raise NotImplementedError("This method needs to be overriden") + + # return a dict with all optimizers that are going to be used in the training + @abstractmethod + def setup_optimizers(self, extras: Extras, models: Models) -> Optimizers: + raise NotImplementedError("This method needs to be overriden") + + # [optionally] return a dict with all schedulers that are going to be used in the training + def setup_schedulers(self, extras: Extras, models: Models, optimizers: Optimizers) -> Schedulers: + return self.Schedulers() + + # [optionally] setup extra stuff, will be called AFTER the models & optimizers are setup + def setup_extras_post(self, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers) -> Extras: + return self.Extras.from_dict(extras.to_dict()) + + # perform the training here + @abstractmethod + def train(self, data: Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers): + raise NotImplementedError("This method needs to be overriden") + # ------------ + + def setup_info(self, full_path=None) -> Info: + if full_path is None: + full_path = (f"{self.config.checkpoint_path}/{self.config.experiment_id}/info.json") + info_dict = load_or_fail(full_path, wandb_run_id=None) or {} + info_dto = self.Info(**info_dict) + if info_dto.total_steps > 0 and self.is_main_node: + print(">>> RESUMING TRAINING FROM ITER ", info_dto.total_steps) + return info_dto + + def setup_config(self, config_file_path=None, config_dict=None, training=True) -> Config: + if config_file_path is not None: + if config_file_path.endswith(".yml") or config_file_path.endswith(".yaml"): + with open(config_file_path, "r", encoding="utf-8") as file: + loaded_config = yaml.safe_load(file) + elif config_file_path.endswith(".json"): + with open(config_file_path, "r", encoding="utf-8") as file: + loaded_config = json.load(file) + else: + raise ValueError("Config file must be either a .yml|.yaml or .json file") + return self.Config.from_dict({**loaded_config, 
'training': training}) + if config_dict is not None: + return self.Config.from_dict({**config_dict, 'training': training}) + return self.Config(training=training) + + def setup_ddp(self, experiment_id, single_gpu=False): + if not single_gpu: + local_rank = int(os.environ.get("SLURM_LOCALID")) + process_id = int(os.environ.get("SLURM_PROCID")) + world_size = int(os.environ.get("SLURM_NNODES")) * torch.cuda.device_count() + + self.process_id = process_id + self.is_main_node = process_id == 0 + self.device = torch.device(local_rank) + self.world_size = world_size + + dist_file_path = f"{os.getcwd()}/{self.config.dist_file_subfolder}dist_file_{experiment_id}" + # if os.path.exists(dist_file_path) and self.is_main_node: + # os.remove(dist_file_path) + + torch.cuda.set_device(local_rank) + init_process_group( + backend="nccl", + rank=process_id, + world_size=world_size, + init_method=f"file://{dist_file_path}", + ) + print(f"[GPU {process_id}] READY") + else: + print("Running in single thread, DDP not enabled.") + + def setup_wandb(self): + if self.is_main_node and self.config.wandb_project is not None: + self.info.wandb_run_id = self.info.wandb_run_id or wandb.util.generate_id() + wandb.init(project=self.config.wandb_project, entity=self.config.wandb_entity, name=self.config.experiment_id, id=self.info.wandb_run_id, resume="allow", config=self.config.to_dict()) + + if self.info.total_steps > 0: + wandb.alert(title=f"Training {self.info.wandb_run_id} resumed", text=f"Training {self.info.wandb_run_id} resumed from step {self.info.total_steps}") + else: + wandb.alert(title=f"Training {self.info.wandb_run_id} started", text=f"Training {self.info.wandb_run_id} started") + + # LOAD UTILITIES ---------- + def load_model(self, model, model_id=None, full_path=None, strict=True): + print('in line 181 load model', type(model), model_id, full_path, strict) + if model_id is not None and full_path is None: + full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{model_id}.{self.config.checkpoint_extension}" + elif full_path is None and model_id is None: + raise ValueError( + "This method expects either 'model_id' or 'full_path' to be defined" + ) + + checkpoint = load_or_fail(full_path, wandb_run_id=self.info.wandb_run_id if self.is_main_node else None) + if checkpoint is not None: + model.load_state_dict(checkpoint, strict=strict) + del checkpoint + + return model + + def load_optimizer(self, optim, optim_id=None, full_path=None, fsdp_model=None): + if optim_id is not None and full_path is None: + full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{optim_id}.pt" + elif full_path is None and optim_id is None: + raise ValueError( + "This method expects either 'optim_id' or 'full_path' to be defined" + ) + + checkpoint = load_or_fail(full_path, wandb_run_id=self.info.wandb_run_id if self.is_main_node else None) + if checkpoint is not None: + try: + if fsdp_model is not None: + sharded_optimizer_state_dict = ( + FSDP.scatter_full_optim_state_dict( # <---- FSDP + checkpoint + if ( + self.is_main_node + or self.fsdp_defaults["sharding_strategy"] + == ShardingStrategy.NO_SHARD + ) + else None, + fsdp_model, + ) + ) + optim.load_state_dict(sharded_optimizer_state_dict) + del checkpoint, sharded_optimizer_state_dict + else: + optim.load_state_dict(checkpoint) + # pylint: disable=broad-except + except Exception as e: + print("!!! Failed loading optimizer, skipping... 
Exception:", e) + + return optim + + # SAVE UTILITIES ---------- + def save_info(self, info, suffix=""): + full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/info{suffix}.json" + create_folder_if_necessary(full_path) + if self.is_main_node: + safe_save(vars(self.info), full_path) + + def save_model(self, model, model_id=None, full_path=None, is_fsdp=False): + if model_id is not None and full_path is None: + full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{model_id}.{self.config.checkpoint_extension}" + elif full_path is None and model_id is None: + raise ValueError( + "This method expects either 'model_id' or 'full_path' to be defined" + ) + create_folder_if_necessary(full_path) + if is_fsdp: + with FSDP.summon_full_params(model): + pass + with FSDP.state_dict_type( + model, StateDictType.FULL_STATE_DICT, self.fsdp_fullstate_save_policy + ): + checkpoint = model.state_dict() + if self.is_main_node: + safe_save(checkpoint, full_path) + del checkpoint + else: + if self.is_main_node: + checkpoint = model.state_dict() + safe_save(checkpoint, full_path) + del checkpoint + + def save_optimizer(self, optim, optim_id=None, full_path=None, fsdp_model=None): + if optim_id is not None and full_path is None: + full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{optim_id}.pt" + elif full_path is None and optim_id is None: + raise ValueError( + "This method expects either 'optim_id' or 'full_path' to be defined" + ) + create_folder_if_necessary(full_path) + if fsdp_model is not None: + optim_statedict = FSDP.full_optim_state_dict(fsdp_model, optim) + if self.is_main_node: + safe_save(optim_statedict, full_path) + del optim_statedict + else: + if self.is_main_node: + checkpoint = optim.state_dict() + safe_save(checkpoint, full_path) + del checkpoint + # ----- + + def __init__(self, config_file_path=None, config_dict=None, device="cpu", training=True): + # Temporary setup, will be overriden by setup_ddp if required + self.device = device + self.process_id = 0 + self.is_main_node = True + self.world_size = 1 + # ---- + + self.config: self.Config = self.setup_config(config_file_path, config_dict, training) + self.info: self.Info = self.setup_info() + + def __call__(self, single_gpu=False): + self.setup_ddp(self.config.experiment_id, single_gpu=single_gpu) # this will change the device to the CUDA rank + self.setup_wandb() + if self.config.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + if self.is_main_node: + print() + print("**STARTIG JOB WITH CONFIG:**") + print(yaml.dump(self.config.to_dict(), default_flow_style=False)) + print("------------------------------------") + print() + print("**INFO:**") + print(yaml.dump(vars(self.info), default_flow_style=False)) + print("------------------------------------") + print() + + # SETUP STUFF + extras = self.setup_extras_pre() + assert extras is not None, "setup_extras_pre() must return a DTO" + + data = self.setup_data(extras) + assert data is not None, "setup_data() must return a DTO" + if self.is_main_node: + print("**DATA:**") + print(yaml.dump({k:type(v).__name__ for k, v in data.to_dict().items()}, default_flow_style=False)) + print("------------------------------------") + print() + + models = self.setup_models(extras) + assert models is not None, "setup_models() must return a DTO" + if self.is_main_node: + print("**MODELS:**") + print(yaml.dump({ + k:f"{type(v).__name__} - {f'trainable params {sum(p.numel() for p in v.parameters() if 
p.requires_grad)}' if isinstance(v, nn.Module) else 'Not a nn.Module'}" for k, v in models.to_dict().items() + }, default_flow_style=False)) + print("------------------------------------") + print() + + optimizers = self.setup_optimizers(extras, models) + assert optimizers is not None, "setup_optimizers() must return a DTO" + if self.is_main_node: + print("**OPTIMIZERS:**") + print(yaml.dump({k:type(v).__name__ for k, v in optimizers.to_dict().items()}, default_flow_style=False)) + print("------------------------------------") + print() + + schedulers = self.setup_schedulers(extras, models, optimizers) + assert schedulers is not None, "setup_schedulers() must return a DTO" + if self.is_main_node: + print("**SCHEDULERS:**") + print(yaml.dump({k:type(v).__name__ for k, v in schedulers.to_dict().items()}, default_flow_style=False)) + print("------------------------------------") + print() + + post_extras =self.setup_extras_post(extras, models, optimizers, schedulers) + assert post_extras is not None, "setup_extras_post() must return a DTO" + extras = self.Extras.from_dict({ **extras.to_dict(),**post_extras.to_dict() }) + if self.is_main_node: + print("**EXTRAS:**") + print(yaml.dump({k:f"{v}" for k, v in extras.to_dict().items()}, default_flow_style=False)) + print("------------------------------------") + print() + # ------- + + # TRAIN + if self.is_main_node: + print("**TRAINING STARTING...**") + self.train(data, extras, models, optimizers, schedulers) + + if single_gpu is False: + barrier() + destroy_process_group() + if self.is_main_node: + print() + print("------------------------------------") + print() + print("**TRAINING COMPLETE**") + if self.config.wandb_project is not None: + wandb.alert(title=f"Training {self.info.wandb_run_id} finished", text=f"Training {self.info.wandb_run_id} finished") diff --git a/core/data/__init__.py b/core/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b687719914b2e303909f7c280347e4bdee607d13 --- /dev/null +++ b/core/data/__init__.py @@ -0,0 +1,69 @@ +import json +import subprocess +import yaml +import os +from .bucketeer import Bucketeer + +class MultiFilter(): + def __init__(self, rules, default=False): + self.rules = rules + self.default = default + + def __call__(self, x): + try: + x_json = x['json'] + if isinstance(x_json, bytes): + x_json = json.loads(x_json) + validations = [] + for k, r in self.rules.items(): + if isinstance(k, tuple): + v = r(*[x_json[kv] for kv in k]) + else: + v = r(x_json[k]) + validations.append(v) + return all(validations) + except Exception: + return False + +class MultiGetter(): + def __init__(self, rules): + self.rules = rules + + def __call__(self, x_json): + if isinstance(x_json, bytes): + x_json = json.loads(x_json) + outputs = [] + for k, r in self.rules.items(): + if isinstance(k, tuple): + v = r(*[x_json[kv] for kv in k]) + else: + v = r(x_json[k]) + outputs.append(v) + if len(outputs) == 1: + outputs = outputs[0] + return outputs + +def setup_webdataset_path(paths, cache_path=None): + if cache_path is None or not os.path.exists(cache_path): + tar_paths = [] + if isinstance(paths, str): + paths = [paths] + for path in paths: + if path.strip().endswith(".tar"): + # Avoid looking up s3 if we already have a tar file + tar_paths.append(path) + continue + bucket = "/".join(path.split("/")[:3]) + result = subprocess.run([f"aws s3 ls {path} --recursive | awk '{{print $4}}'"], stdout=subprocess.PIPE, shell=True, check=True) + files = result.stdout.decode('utf-8').split() + files = 
[f"{bucket}/{f}" for f in files if f.endswith(".tar")] + tar_paths += files + + with open(cache_path, 'w', encoding='utf-8') as outfile: + yaml.dump(tar_paths, outfile, default_flow_style=False) + else: + with open(cache_path, 'r', encoding='utf-8') as file: + tar_paths = yaml.safe_load(file) + + tar_paths_str = ",".join([f"{p}" for p in tar_paths]) + return f"pipe:aws s3 cp {{ {tar_paths_str} }} -" diff --git a/core/data/bucketeer.py b/core/data/bucketeer.py new file mode 100644 index 0000000000000000000000000000000000000000..131e6ba4293bd7c00399f08609aba184b712d5e8 --- /dev/null +++ b/core/data/bucketeer.py @@ -0,0 +1,88 @@ +import torch +import torchvision +import numpy as np +from torchtools.transforms import SmartCrop +import math + +class Bucketeer(): + def __init__(self, dataloader, density=256*256, factor=8, ratios=[1/1, 1/2, 3/4, 3/5, 4/5, 6/9, 9/16], reverse_list=True, randomize_p=0.3, randomize_q=0.2, crop_mode='random', p_random_ratio=0.0, interpolate_nearest=False): + assert crop_mode in ['center', 'random', 'smart'] + self.crop_mode = crop_mode + self.ratios = ratios + if reverse_list: + for r in list(ratios): + if 1/r not in self.ratios: + self.ratios.append(1/r) + self.sizes = {} + for dd in density: + self.sizes[dd]= [(int(((dd/r)**0.5//factor)*factor), int(((dd*r)**0.5//factor)*factor)) for r in ratios] + + self.batch_size = dataloader.batch_size + self.iterator = iter(dataloader) + all_sizes = [] + for k, vs in self.sizes.items(): + all_sizes += vs + self.buckets = {s: [] for s in all_sizes} + self.smartcrop = SmartCrop(int(density**0.5), randomize_p, randomize_q) if self.crop_mode=='smart' else None + self.p_random_ratio = p_random_ratio + self.interpolate_nearest = interpolate_nearest + + def get_available_batch(self): + for b in self.buckets: + if len(self.buckets[b]) >= self.batch_size: + batch = self.buckets[b][:self.batch_size] + self.buckets[b] = self.buckets[b][self.batch_size:] + return batch + return None + + def get_closest_size(self, x): + w, h = x.size(-1), x.size(-2) + + + best_size_idx = np.argmin([abs(w/h-r) for r in self.ratios]) + find_dict = {dd : abs(w*h - self.sizes[dd][best_size_idx][0]*self.sizes[dd][best_size_idx][1]) for dd, vv in self.sizes.items()} + min_ = find_dict[list(find_dict.keys())[0]] + find_size = self.sizes[list(find_dict.keys())[0]][best_size_idx] + for dd, val in find_dict.items(): + if val < min_: + min_ = val + find_size = self.sizes[dd][best_size_idx] + + return find_size + + def get_resize_size(self, orig_size, tgt_size): + if (tgt_size[1]/tgt_size[0] - 1) * (orig_size[1]/orig_size[0] - 1) >= 0: + alt_min = int(math.ceil(max(tgt_size)*min(orig_size)/max(orig_size))) + resize_size = max(alt_min, min(tgt_size)) + else: + alt_max = int(math.ceil(min(tgt_size)*max(orig_size)/min(orig_size))) + resize_size = max(alt_max, max(tgt_size)) + + return resize_size + + def __next__(self): + batch = self.get_available_batch() + while batch is None: + elements = next(self.iterator) + for dct in elements: + img = dct['images'] + size = self.get_closest_size(img) + resize_size = self.get_resize_size(img.shape[-2:], size) + + if self.interpolate_nearest: + img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST) + else: + img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True) + if self.crop_mode == 'center': + img = torchvision.transforms.functional.center_crop(img, size) + elif 
self.crop_mode == 'random': + img = torchvision.transforms.RandomCrop(size)(img) + elif self.crop_mode == 'smart': + self.smartcrop.output_size = size + img = self.smartcrop(img) + + self.buckets[size].append({**{'images': img}, **{k:dct[k] for k in dct if k != 'images'}}) + batch = self.get_available_batch() + + out = {k:[batch[i][k] for i in range(len(batch))] for k in batch[0]} + return {k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()} diff --git a/core/data/bucketeer_deg.py b/core/data/bucketeer_deg.py new file mode 100644 index 0000000000000000000000000000000000000000..7206ccf08932f617abb811221cc7bbe1d126f184 --- /dev/null +++ b/core/data/bucketeer_deg.py @@ -0,0 +1,91 @@ +import torch +import torchvision +import numpy as np +from torchtools.transforms import SmartCrop +import math + +class Bucketeer(): + def __init__(self, dataloader, density=256*256, factor=8, ratios=[1/1, 1/2, 3/4, 3/5, 4/5, 6/9, 9/16], reverse_list=True, randomize_p=0.3, randomize_q=0.2, crop_mode='random', p_random_ratio=0.0, interpolate_nearest=False): + assert crop_mode in ['center', 'random', 'smart'] + self.crop_mode = crop_mode + self.ratios = ratios + if reverse_list: + for r in list(ratios): + if 1/r not in self.ratios: + self.ratios.append(1/r) + self.sizes = {} + for dd in density: + self.sizes[dd]= [(int(((dd/r)**0.5//factor)*factor), int(((dd*r)**0.5//factor)*factor)) for r in ratios] + print('in line 17 buckteer', self.sizes) + self.batch_size = dataloader.batch_size + self.iterator = iter(dataloader) + all_sizes = [] + for k, vs in self.sizes.items(): + all_sizes += vs + self.buckets = {s: [] for s in all_sizes} + self.smartcrop = SmartCrop(int(density**0.5), randomize_p, randomize_q) if self.crop_mode=='smart' else None + self.p_random_ratio = p_random_ratio + self.interpolate_nearest = interpolate_nearest + + def get_available_batch(self): + for b in self.buckets: + if len(self.buckets[b]) >= self.batch_size: + batch = self.buckets[b][:self.batch_size] + self.buckets[b] = self.buckets[b][self.batch_size:] + return batch + return None + + def get_closest_size(self, x): + w, h = x.size(-1), x.size(-2) + #if self.p_random_ratio > 0 and np.random.rand() < self.p_random_ratio: + # best_size_idx = np.random.randint(len(self.ratios)) + #print('in line 41 get closes size', best_size_idx, x.shape, self.p_random_ratio) + #else: + + best_size_idx = np.argmin([abs(w/h-r) for r in self.ratios]) + find_dict = {dd : abs(w*h - self.sizes[dd][best_size_idx][0]*self.sizes[dd][best_size_idx][1]) for dd, vv in self.sizes.items()} + min_ = find_dict[list(find_dict.keys())[0]] + find_size = self.sizes[list(find_dict.keys())[0]][best_size_idx] + for dd, val in find_dict.items(): + if val < min_: + min_ = val + find_size = self.sizes[dd][best_size_idx] + + return find_size + + def get_resize_size(self, orig_size, tgt_size): + if (tgt_size[1]/tgt_size[0] - 1) * (orig_size[1]/orig_size[0] - 1) >= 0: + alt_min = int(math.ceil(max(tgt_size)*min(orig_size)/max(orig_size))) + resize_size = max(alt_min, min(tgt_size)) + else: + alt_max = int(math.ceil(min(tgt_size)*max(orig_size)/min(orig_size))) + resize_size = max(alt_max, max(tgt_size)) + #print('in line 50', orig_size, tgt_size, resize_size) + return resize_size + + def __next__(self): + batch = self.get_available_batch() + while batch is None: + elements = next(self.iterator) + for dct in elements: + img = dct['images'] + size = self.get_closest_size(img) + resize_size = self.get_resize_size(img.shape[-2:], size) + #print('in line 
74', img.size(), resize_size) + if self.interpolate_nearest: + img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST) + else: + img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True) + if self.crop_mode == 'center': + img = torchvision.transforms.functional.center_crop(img, size) + elif self.crop_mode == 'random': + img = torchvision.transforms.RandomCrop(size)(img) + elif self.crop_mode == 'smart': + self.smartcrop.output_size = size + img = self.smartcrop(img) + print('in line 86 bucketeer', type(img), img.shape, torch.max(img), torch.min(img)) + self.buckets[size].append({**{'images': img}, **{k:dct[k] for k in dct if k != 'images'}}) + batch = self.get_available_batch() + + out = {k:[batch[i][k] for i in range(len(batch))] for k in batch[0]} + return {k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()} diff --git a/core/data/deg_kair_utils/utils_alignfaces.py b/core/data/deg_kair_utils/utils_alignfaces.py new file mode 100644 index 0000000000000000000000000000000000000000..fa74e8a2e8984f5075d0cbd06afd494c9661a015 --- /dev/null +++ b/core/data/deg_kair_utils/utils_alignfaces.py @@ -0,0 +1,263 @@ +# -*- coding: utf-8 -*- +""" +Created on Mon Apr 24 15:43:29 2017 +@author: zhaoy +""" +import cv2 +import numpy as np +from skimage import transform as trans + +# reference facial points, a list of coordinates (x,y) +REFERENCE_FACIAL_POINTS = [ + [30.29459953, 51.69630051], + [65.53179932, 51.50139999], + [48.02519989, 71.73660278], + [33.54930115, 92.3655014], + [62.72990036, 92.20410156] +] + +DEFAULT_CROP_SIZE = (96, 112) + + +def _umeyama(src, dst, estimate_scale=True, scale=1.0): + """Estimate N-D similarity transformation with or without scaling. + Parameters + ---------- + src : (M, N) array + Source coordinates. + dst : (M, N) array + Destination coordinates. + estimate_scale : bool + Whether to estimate scaling factor. + Returns + ------- + T : (N + 1, N + 1) + The homogeneous similarity transformation matrix. The matrix contains + NaN values only if the problem is not well-conditioned. + References + ---------- + .. [1] "Least-squares estimation of transformation parameters between two + point patterns", Shinji Umeyama, PAMI 1991, :DOI:`10.1109/34.88573` + """ + + num = src.shape[0] + dim = src.shape[1] + + # Compute mean of src and dst. + src_mean = src.mean(axis=0) + dst_mean = dst.mean(axis=0) + + # Subtract mean from src and dst. + src_demean = src - src_mean + dst_demean = dst - dst_mean + + # Eq. (38). + A = dst_demean.T @ src_demean / num + + # Eq. (39). + d = np.ones((dim,), dtype=np.double) + if np.linalg.det(A) < 0: + d[dim - 1] = -1 + + T = np.eye(dim + 1, dtype=np.double) + + U, S, V = np.linalg.svd(A) + + # Eq. (40) and (43). + rank = np.linalg.matrix_rank(A) + if rank == 0: + return np.nan * T + elif rank == dim - 1: + if np.linalg.det(U) * np.linalg.det(V) > 0: + T[:dim, :dim] = U @ V + else: + s = d[dim - 1] + d[dim - 1] = -1 + T[:dim, :dim] = U @ np.diag(d) @ V + d[dim - 1] = s + else: + T[:dim, :dim] = U @ np.diag(d) @ V + + if estimate_scale: + # Eq. (41) and (42). 
+ scale = 1.0 / src_demean.var(axis=0).sum() * (S @ d) + else: + scale = scale + + T[:dim, dim] = dst_mean - scale * (T[:dim, :dim] @ src_mean.T) + T[:dim, :dim] *= scale + + return T, scale + + +class FaceWarpException(Exception): + def __str__(self): + return 'In File {}:{}'.format( + __file__, super.__str__(self)) + + +def get_reference_facial_points(output_size=None, + inner_padding_factor=0.0, + outer_padding=(0, 0), + default_square=False): + tmp_5pts = np.array(REFERENCE_FACIAL_POINTS) + tmp_crop_size = np.array(DEFAULT_CROP_SIZE) + + # 0) make the inner region a square + if default_square: + size_diff = max(tmp_crop_size) - tmp_crop_size + tmp_5pts += size_diff / 2 + tmp_crop_size += size_diff + + if (output_size and + output_size[0] == tmp_crop_size[0] and + output_size[1] == tmp_crop_size[1]): + print('output_size == DEFAULT_CROP_SIZE {}: return default reference points'.format(tmp_crop_size)) + return tmp_5pts + + if (inner_padding_factor == 0 and + outer_padding == (0, 0)): + if output_size is None: + print('No paddings to do: return default reference points') + return tmp_5pts + else: + raise FaceWarpException( + 'No paddings to do, output_size must be None or {}'.format(tmp_crop_size)) + + # check output size + if not (0 <= inner_padding_factor <= 1.0): + raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)') + + if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0) + and output_size is None): + output_size = tmp_crop_size * \ + (1 + inner_padding_factor * 2).astype(np.int32) + output_size += np.array(outer_padding) + print(' deduced from paddings, output_size = ', output_size) + + if not (outer_padding[0] < output_size[0] + and outer_padding[1] < output_size[1]): + raise FaceWarpException('Not (outer_padding[0] < output_size[0]' + 'and outer_padding[1] < output_size[1])') + + # 1) pad the inner region according inner_padding_factor + # print('---> STEP1: pad the inner region according inner_padding_factor') + if inner_padding_factor > 0: + size_diff = tmp_crop_size * inner_padding_factor * 2 + tmp_5pts += size_diff / 2 + tmp_crop_size += np.round(size_diff).astype(np.int32) + + # print(' crop_size = ', tmp_crop_size) + # print(' reference_5pts = ', tmp_5pts) + + # 2) resize the padded inner region + # print('---> STEP2: resize the padded inner region') + size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2 + # print(' crop_size = ', tmp_crop_size) + # print(' size_bf_outer_pad = ', size_bf_outer_pad) + + if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]: + raise FaceWarpException('Must have (output_size - outer_padding)' + '= some_scale * (crop_size * (1.0 + inner_padding_factor)') + + scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0] + # print(' resize scale_factor = ', scale_factor) + tmp_5pts = tmp_5pts * scale_factor + # size_diff = tmp_crop_size * (scale_factor - min(scale_factor)) + # tmp_5pts = tmp_5pts + size_diff / 2 + tmp_crop_size = size_bf_outer_pad + # print(' crop_size = ', tmp_crop_size) + # print(' reference_5pts = ', tmp_5pts) + + # 3) add outer_padding to make output_size + reference_5point = tmp_5pts + np.array(outer_padding) + tmp_crop_size = output_size + # print('---> STEP3: add outer_padding to make output_size') + # print(' crop_size = ', tmp_crop_size) + # print(' reference_5pts = ', tmp_5pts) + # + # print('===> end get_reference_facial_points\n') + + return reference_5point + + +def get_affine_transform_matrix(src_pts, dst_pts): + tfm 
= np.float32([[1, 0, 0], [0, 1, 0]]) + n_pts = src_pts.shape[0] + ones = np.ones((n_pts, 1), src_pts.dtype) + src_pts_ = np.hstack([src_pts, ones]) + dst_pts_ = np.hstack([dst_pts, ones]) + + A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_) + + if rank == 3: + tfm = np.float32([ + [A[0, 0], A[1, 0], A[2, 0]], + [A[0, 1], A[1, 1], A[2, 1]] + ]) + elif rank == 2: + tfm = np.float32([ + [A[0, 0], A[1, 0], 0], + [A[0, 1], A[1, 1], 0] + ]) + + return tfm + + +def warp_and_crop_face(src_img, + facial_pts, + reference_pts=None, + crop_size=(96, 112), + align_type='smilarity'): #smilarity cv2_affine affine + if reference_pts is None: + if crop_size[0] == 96 and crop_size[1] == 112: + reference_pts = REFERENCE_FACIAL_POINTS + else: + default_square = False + inner_padding_factor = 0 + outer_padding = (0, 0) + output_size = crop_size + + reference_pts = get_reference_facial_points(output_size, + inner_padding_factor, + outer_padding, + default_square) + + ref_pts = np.float32(reference_pts) + ref_pts_shp = ref_pts.shape + if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2: + raise FaceWarpException( + 'reference_pts.shape must be (K,2) or (2,K) and K>2') + + if ref_pts_shp[0] == 2: + ref_pts = ref_pts.T + + src_pts = np.float32(facial_pts) + src_pts_shp = src_pts.shape + if max(src_pts_shp) < 3 or min(src_pts_shp) != 2: + raise FaceWarpException( + 'facial_pts.shape must be (K,2) or (2,K) and K>2') + + if src_pts_shp[0] == 2: + src_pts = src_pts.T + + if src_pts.shape != ref_pts.shape: + raise FaceWarpException( + 'facial_pts and reference_pts must have the same shape') + + if align_type is 'cv2_affine': + tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3]) + tfm_inv = cv2.getAffineTransform(ref_pts[0:3], src_pts[0:3]) + elif align_type is 'affine': + tfm = get_affine_transform_matrix(src_pts, ref_pts) + tfm_inv = get_affine_transform_matrix(ref_pts, src_pts) + else: + params, scale = _umeyama(src_pts, ref_pts) + tfm = params[:2, :] + + params, _ = _umeyama(ref_pts, src_pts, False, scale=1.0/scale) + tfm_inv = params[:2, :] + + face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]), flags=3) + + return face_img, tfm_inv diff --git a/core/data/deg_kair_utils/utils_blindsr.py b/core/data/deg_kair_utils/utils_blindsr.py new file mode 100644 index 0000000000000000000000000000000000000000..9a1a7baf99473043e216c16f464f4e168cbd94ab --- /dev/null +++ b/core/data/deg_kair_utils/utils_blindsr.py @@ -0,0 +1,631 @@ +# -*- coding: utf-8 -*- +import numpy as np +import cv2 +import torch + +from core.data.deg_kair_utils import utils_image as util + +import random +from scipy import ndimage +import scipy +import scipy.stats as ss +from scipy.interpolate import interp2d +from scipy.linalg import orth + + + + +""" +# -------------------------------------------- +# Super-Resolution +# -------------------------------------------- +# +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# From 2019/03--2021/08 +# -------------------------------------------- +""" + +def modcrop_np(img, sf): + ''' + Args: + img: numpy image, WxH or WxHxC + sf: scale factor + + Return: + cropped image + ''' + w, h = img.shape[:2] + im = np.copy(img) + return im[:w - w % sf, :h - h % sf, ...] 
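+# --- Hedged usage sketch (editor's addition, not part of the upstream KAIR utilities) ---
+# modcrop_np trims H and W down to the nearest multiples of the scale factor `sf`, so a
+# later x`sf` degradation / upsampling round trip stays shape-consistent. The helper
+# below is hypothetical and purely illustrative; it is never called by this module.
+def _modcrop_np_example(sf=4):
+    demo = np.random.rand(257, 511, 3).astype(np.float32)   # hypothetical HxWxC input
+    cropped = modcrop_np(demo, sf)                           # -> shape (256, 508, 3)
+    assert cropped.shape[0] % sf == 0 and cropped.shape[1] % sf == 0
+    return cropped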
+ + +""" +# -------------------------------------------- +# anisotropic Gaussian kernels +# -------------------------------------------- +""" +def analytic_kernel(k): + """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" + k_size = k.shape[0] + # Calculate the big kernels size + big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) + # Loop over the small kernel to fill the big one + for r in range(k_size): + for c in range(k_size): + big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + # Crop the edges of the big kernel to ignore very small values and increase run time of SR + crop = k_size // 2 + cropped_big_k = big_k[crop:-crop, crop:-crop] + # Normalize to 1 + return cropped_big_k / cropped_big_k.sum() + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """ generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel. + + Returns: + k : kernel + """ + + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) + + k = k / np.sum(k) + return k + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf-1)*0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w-1) + y1 = np.clip(y1, 0, h-1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +def blur(x, k): + ''' + x: image, NxcxHxW + k: kernel, Nx1xhxw + ''' + n, c = x.shape[:2] + p1, p2 = (k.shape[-2]-1)//2, (k.shape[-1]-1)//2 + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + k = k.repeat(1,c,1,1) + k = k.view(-1, 1, k.shape[2], k.shape[3]) + x = x.view(1, -1, x.shape[2], x.shape[3]) + x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n*c) + x = x.view(n, c, x.shape[2], x.shape[3]) + + return x + + + +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): + """" + # modified version of https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, 
lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5*(scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X,Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z-MU + ZZ_t = ZZ.transpose(0,1,3,2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + #raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + #kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +def fspecial_gaussian(hsize, sigma): + hsize = [hsize, hsize] + siz = [(hsize[0]-1.0)/2.0, (hsize[1]-1.0)/2.0] + std = sigma + [x, y] = np.meshgrid(np.arange(-siz[1], siz[1]+1), np.arange(-siz[0], siz[0]+1)) + arg = -(x*x + y*y)/(2*std*std) + h = np.exp(arg) + h[h < scipy.finfo(float).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h = h/sumh + return h + + +def fspecial_laplacian(alpha): + alpha = max([0, min([alpha,1])]) + h1 = alpha/(alpha+1) + h2 = (1-alpha)/(alpha+1) + h = [[h1, h2, h1], [h2, -4/(alpha+1), h2], [h1, h2, h1]] + h = np.array(h) + return h + + +def fspecial(filter_type, *args, **kwargs): + ''' + python code from: + https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py + ''' + if filter_type == 'gaussian': + return fspecial_gaussian(*args, **kwargs) + if filter_type == 'laplacian': + return fspecial_laplacian(*args, **kwargs) + +""" +# -------------------------------------------- +# degradation models +# -------------------------------------------- +""" + + +def bicubic_degradation(x, sf=3): + ''' + Args: + x: HxWxC image, [0, 1] + sf: down-scale factor + + Return: + bicubicly downsampled LR image + ''' + x = util.imresize_np(x, scale=1/sf) + return x + + +def srmd_degradation(x, k, sf=3): + ''' blur + bicubic downsampling + + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + + Return: + downsampled LR image + + Reference: + @inproceedings{zhang2018learning, + title={Learning a single convolutional super-resolution network for multiple degradations}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={3262--3271}, + year={2018} + } + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + x = bicubic_degradation(x, sf=sf) + return x + + +def dpsr_degradation(x, k, sf=3): + + ''' bicubic downsampling + blur + + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + + Return: + downsampled LR image + + Reference: + @inproceedings{zhang2019deep, + title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={1671--1681}, + year={2019} + } + ''' + x = bicubic_degradation(x, sf=sf) + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + return x + + +def classical_degradation(x, k, sf=3): + ''' blur + downsampling + + 
Args: + x: HxWxC image, [0, 1]/[0, 255] + k: hxw, double + sf: down-scale factor + + Return: + downsampled LR image + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + #x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) + st = 0 + return x[st::sf, st::sf, ...] + + +def add_sharpening(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. borrowed from real-ESRGAN + Input image: I; Blurry image: B. + 1. K = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * K + (1 - Mask) * I + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. + threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + K = img + weight * residual + K = np.clip(K, 0, 1) + return soft_mask * K + (1 - soft_mask) * img + + +def add_blur(img, sf=4): + wd2 = 4.0 + sf + wd = 2.0 + 0.2*sf + if random.random() < 0.5: + l1 = wd2*random.random() + l2 = wd2*random.random() + k = anisotropic_Gaussian(ksize=2*random.randint(2,11)+3, theta=random.random()*np.pi, l1=l1, l2=l2) + else: + k = fspecial('gaussian', 2*random.randint(2,11)+3, wd*random.random()) + img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + + return img + + +def add_resize(img, sf=4): + rnum = np.random.rand() + if rnum > 0.8: # up + sf1 = random.uniform(1, 2) + elif rnum < 0.7: # down + sf1 = random.uniform(0.5/sf, 1) + else: + sf1 = 1.0 + img = cv2.resize(img, (int(sf1*img.shape[1]), int(sf1*img.shape[0])), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + return img + + +def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + rnum = np.random.rand() + if rnum > 0.6: # add color Gaussian noise + img += np.random.normal(0, noise_level/255.0, img.shape).astype(np.float32) + elif rnum < 0.4: # add grayscale Gaussian noise + img += np.random.normal(0, noise_level/255.0, (*img.shape[:2], 1)).astype(np.float32) + else: # add noise + L = noise_level2/255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3,3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img += np.random.multivariate_normal([0,0,0], np.abs(L**2*conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_speckle_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + img = np.clip(img, 0.0, 1.0) + rnum = random.random() + if rnum > 0.6: + img += img*np.random.normal(0, noise_level/255.0, img.shape).astype(np.float32) + elif rnum < 0.4: + img += img*np.random.normal(0, noise_level/255.0, (*img.shape[:2], 1)).astype(np.float32) + else: + L = noise_level2/255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3,3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img += img*np.random.multivariate_normal([0,0,0], np.abs(L**2*conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_Poisson_noise(img): + img = np.clip((img * 255.0).round(), 0, 255) / 255. 
+ vals = 10**(2*random.random()+2.0) # [2, 4] + if random.random() < 0.5: + img = np.random.poisson(img * vals).astype(np.float32) / vals + else: + img_gray = np.dot(img[...,:3], [0.299, 0.587, 0.114]) + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray + img += noise_gray[:, :, np.newaxis] + img = np.clip(img, 0.0, 1.0) + return img + + +def add_JPEG_noise(img): + quality_factor = random.randint(30, 95) + img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) + result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + img = cv2.imdecode(encimg, 1) + img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) + return img + + +def random_crop(lq, hq, sf=4, lq_patchsize=64): + h, w = lq.shape[:2] + rnd_h = random.randint(0, h-lq_patchsize) + rnd_w = random.randint(0, w-lq_patchsize) + lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + + rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) + hq = hq[rnd_h_H:rnd_h_H + lq_patchsize*sf, rnd_w_H:rnd_w_H + lq_patchsize*sf, :] + return lq, hq + + +def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + isp_model: camera ISP model + + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = img.shape[:2] + img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize*sf or w < lq_patchsize*sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + hq = img.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + img = cv2.resize(img, (int(1/2*img.shape[1]), int(1/2*img.shape[0])), interpolation=random.choice([1,2,3])) + else: + img = util.imresize_np(img, 1/2, True) + img = np.clip(img, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + img = add_blur(img, sf=sf) + + elif i == 1: + img = add_blur(img, sf=sf) + + elif i == 2: + a, b = img.shape[1], img.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1,2*sf) + img = cv2.resize(img, (int(1/sf1*img.shape[1]), int(1/sf1*img.shape[0])), interpolation=random.choice([1,2,3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6*sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted/k_shifted.sum() # blur with shifted kernel + img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = img[0::sf, 0::sf, ...] 
# nearest downsampling + img = np.clip(img, 0.0, 1.0) + + elif i == 3: + # downsample3 + img = cv2.resize(img, (int(1/sf*a), int(1/sf*b)), interpolation=random.choice([1,2,3])) + img = np.clip(img, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + img = add_JPEG_noise(img) + + elif i == 6: + # add processed camera sensor noise + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf_ori, lq_patchsize) + + return img, hq + + + + +def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=False, lq_patchsize=64, isp_model=None): + """ + This is an extended degradation model by combining + the degradation models of BSRGAN and Real-ESRGAN + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + use_shuffle: the degradation shuffle + use_sharp: sharpening the img + + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + + h1, w1 = img.shape[:2] + img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize*sf or w < lq_patchsize*sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + if use_sharp: + img = add_sharpening(img) + hq = img.copy() + + if random.random() < shuffle_prob: + shuffle_order = random.sample(range(13), 13) + else: + shuffle_order = list(range(13)) + # local shuffle for noise, JPEG is always the last one + shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6))) + shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13))) + + poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1 + + for i in shuffle_order: + if i == 0: + img = add_blur(img, sf=sf) + elif i == 1: + img = add_resize(img, sf=sf) + elif i == 2: + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + elif i == 3: + if random.random() < poisson_prob: + img = add_Poisson_noise(img) + elif i == 4: + if random.random() < speckle_prob: + img = add_speckle_noise(img) + elif i == 5: + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + elif i == 6: + img = add_JPEG_noise(img) + elif i == 7: + img = add_blur(img, sf=sf) + elif i == 8: + img = add_resize(img, sf=sf) + elif i == 9: + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + elif i == 10: + if random.random() < poisson_prob: + img = add_Poisson_noise(img) + elif i == 11: + if random.random() < speckle_prob: + img = add_speckle_noise(img) + elif i == 12: + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + else: + print('check the shuffle!') + + # resize to desired size + img = cv2.resize(img, (int(1/sf*hq.shape[1]), int(1/sf*hq.shape[0])), interpolation=random.choice([1, 2, 3])) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf, lq_patchsize) + + return img, hq + + + +if __name__ == '__main__': + img = util.imread_uint('utils/test.png', 3) + img = util.uint2single(img) + sf = 4 + + for i in 
range(20): + img_lq, img_hq = degradation_bsrgan(img, sf=sf, lq_patchsize=72) + print(i) + lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf*img_lq.shape[1]), int(sf*img_lq.shape[0])), interpolation=0) + img_concat = np.concatenate([lq_nearest, util.single2uint(img_hq)], axis=1) + util.imsave(img_concat, str(i)+'.png') + +# for i in range(10): +# img_lq, img_hq = degradation_bsrgan_plus(img, sf=sf, shuffle_prob=0.1, use_sharp=True, lq_patchsize=64) +# print(i) +# lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf*img_lq.shape[1]), int(sf*img_lq.shape[0])), interpolation=0) +# img_concat = np.concatenate([lq_nearest, util.single2uint(img_hq)], axis=1) +# util.imsave(img_concat, str(i)+'.png') + +# run utils/utils_blindsr.py diff --git a/core/data/deg_kair_utils/utils_bnorm.py b/core/data/deg_kair_utils/utils_bnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..9bd346e05b66efd074f81f1961068e2de45ac5da --- /dev/null +++ b/core/data/deg_kair_utils/utils_bnorm.py @@ -0,0 +1,91 @@ +import torch +import torch.nn as nn + + +""" +# -------------------------------------------- +# Batch Normalization +# -------------------------------------------- + +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# 01/Jan/2019 +# -------------------------------------------- +""" + + +# -------------------------------------------- +# remove/delete specified layer +# -------------------------------------------- +def deleteLayer(model, layer_type=nn.BatchNorm2d): + ''' Kai Zhang, 11/Jan/2019. + ''' + for k, m in list(model.named_children()): + if isinstance(m, layer_type): + del model._modules[k] + deleteLayer(m, layer_type) + + +# -------------------------------------------- +# merge bn, "conv+bn" --> "conv" +# -------------------------------------------- +def merge_bn(model): + ''' Kai Zhang, 11/Jan/2019. + merge all 'Conv+BN' (or 'TConv+BN') into 'Conv' (or 'TConv') + based on https://github.com/pytorch/pytorch/pull/901 + ''' + prev_m = None + for k, m in list(model.named_children()): + if (isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d)) and (isinstance(prev_m, nn.Conv2d) or isinstance(prev_m, nn.Linear) or isinstance(prev_m, nn.ConvTranspose2d)): + + w = prev_m.weight.data + + if prev_m.bias is None: + zeros = torch.Tensor(prev_m.out_channels).zero_().type(w.type()) + prev_m.bias = nn.Parameter(zeros) + b = prev_m.bias.data + + invstd = m.running_var.clone().add_(m.eps).pow_(-0.5) + if isinstance(prev_m, nn.ConvTranspose2d): + w.mul_(invstd.view(1, w.size(1), 1, 1).expand_as(w)) + else: + w.mul_(invstd.view(w.size(0), 1, 1, 1).expand_as(w)) + b.add_(-m.running_mean).mul_(invstd) + if m.affine: + if isinstance(prev_m, nn.ConvTranspose2d): + w.mul_(m.weight.data.view(1, w.size(1), 1, 1).expand_as(w)) + else: + w.mul_(m.weight.data.view(w.size(0), 1, 1, 1).expand_as(w)) + b.mul_(m.weight.data).add_(m.bias.data) + + del model._modules[k] + prev_m = m + merge_bn(m) + + +# -------------------------------------------- +# add bn, "conv" --> "conv+bn" +# -------------------------------------------- +def add_bn(model): + ''' Kai Zhang, 11/Jan/2019. 
+ ''' + for k, m in list(model.named_children()): + if (isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose2d)): + b = nn.BatchNorm2d(m.out_channels, momentum=0.1, affine=True) + b.weight.data.fill_(1) + new_m = nn.Sequential(model._modules[k], b) + model._modules[k] = new_m + add_bn(m) + + +# -------------------------------------------- +# tidy model after removing bn +# -------------------------------------------- +def tidy_sequential(model): + ''' Kai Zhang, 11/Jan/2019. + ''' + for k, m in list(model.named_children()): + if isinstance(m, nn.Sequential): + if m.__len__() == 1: + model._modules[k] = m.__getitem__(0) + tidy_sequential(m) diff --git a/core/data/deg_kair_utils/utils_deblur.py b/core/data/deg_kair_utils/utils_deblur.py new file mode 100644 index 0000000000000000000000000000000000000000..8ab5852d0cb334627abcd9476409d632740be389 --- /dev/null +++ b/core/data/deg_kair_utils/utils_deblur.py @@ -0,0 +1,655 @@ +# -*- coding: utf-8 -*- +import numpy as np +import scipy +from scipy import fftpack +import torch + +from math import cos, sin +from numpy import zeros, ones, prod, array, pi, log, min, mod, arange, sum, mgrid, exp, pad, round +from numpy.random import randn, rand +from scipy.signal import convolve2d +import cv2 +import random +# import utils_image as util + +''' +modified by Kai Zhang (github: https://github.com/cszn) +03/03/2019 +''' + + +def get_uperleft_denominator(img, kernel): + ''' + img: HxWxC + kernel: hxw + denominator: HxWx1 + upperleft: HxWxC + ''' + V = psf2otf(kernel, img.shape[:2]) + denominator = np.expand_dims(np.abs(V)**2, axis=2) + upperleft = np.expand_dims(np.conj(V), axis=2) * np.fft.fft2(img, axes=[0, 1]) + return upperleft, denominator + + +def get_uperleft_denominator_pytorch(img, kernel): + ''' + img: NxCxHxW + kernel: Nx1xhxw + denominator: Nx1xHxW + upperleft: NxCxHxWx2 + ''' + V = p2o(kernel, img.shape[-2:]) # Nx1xHxWx2 + denominator = V[..., 0]**2+V[..., 1]**2 # Nx1xHxW + upperleft = cmul(cconj(V), rfft(img)) # Nx1xHxWx2 * NxCxHxWx2 + return upperleft, denominator + + +def c2c(x): + return torch.from_numpy(np.stack([np.float32(x.real), np.float32(x.imag)], axis=-1)) + + +def r2c(x): + return torch.stack([x, torch.zeros_like(x)], -1) + + +def cdiv(x, y): + a, b = x[..., 0], x[..., 1] + c, d = y[..., 0], y[..., 1] + cd2 = c**2 + d**2 + return torch.stack([(a*c+b*d)/cd2, (b*c-a*d)/cd2], -1) + + +def cabs(x): + return torch.pow(x[..., 0]**2+x[..., 1]**2, 0.5) + + +def cmul(t1, t2): + ''' + complex multiplication + t1: NxCxHxWx2 + output: NxCxHxWx2 + ''' + real1, imag1 = t1[..., 0], t1[..., 1] + real2, imag2 = t2[..., 0], t2[..., 1] + return torch.stack([real1 * real2 - imag1 * imag2, real1 * imag2 + imag1 * real2], dim=-1) + + +def cconj(t, inplace=False): + ''' + # complex's conjugation + t: NxCxHxWx2 + output: NxCxHxWx2 + ''' + c = t.clone() if not inplace else t + c[..., 1] *= -1 + return c + + +def rfft(t): + return torch.rfft(t, 2, onesided=False) + + +def irfft(t): + return torch.irfft(t, 2, onesided=False) + + +def fft(t): + return torch.fft(t, 2) + + +def ifft(t): + return torch.ifft(t, 2) + + +def p2o(psf, shape): + ''' + # psf: NxCxhxw + # shape: [H,W] + # otf: NxCxHxWx2 + ''' + otf = torch.zeros(psf.shape[:-2] + shape).type_as(psf) + otf[...,:psf.shape[2],:psf.shape[3]].copy_(psf) + for axis, axis_size in enumerate(psf.shape[2:]): + otf = torch.roll(otf, -int(axis_size / 2), dims=axis+2) + otf = torch.rfft(otf, 2, onesided=False) + n_ops = torch.sum(torch.tensor(psf.shape).type_as(psf) * 
torch.log2(torch.tensor(psf.shape).type_as(psf))) + otf[...,1][torch.abs(otf[...,1])= abs(y)] = abs(x)[abs(x) >= abs(y)] + maxxy[abs(y) >= abs(x)] = abs(y)[abs(y) >= abs(x)] + minxy = np.zeros(x.shape) + minxy[abs(x) <= abs(y)] = abs(x)[abs(x) <= abs(y)] + minxy[abs(y) <= abs(x)] = abs(y)[abs(y) <= abs(x)] + m1 = (rad**2 < (maxxy+0.5)**2 + (minxy-0.5)**2)*(minxy-0.5) +\ + (rad**2 >= (maxxy+0.5)**2 + (minxy-0.5)**2)*\ + np.sqrt((rad**2 + 0j) - (maxxy + 0.5)**2) + m2 = (rad**2 > (maxxy-0.5)**2 + (minxy+0.5)**2)*(minxy+0.5) +\ + (rad**2 <= (maxxy-0.5)**2 + (minxy+0.5)**2)*\ + np.sqrt((rad**2 + 0j) - (maxxy - 0.5)**2) + h = None + return h + + +def fspecial_gaussian(hsize, sigma): + hsize = [hsize, hsize] + siz = [(hsize[0]-1.0)/2.0, (hsize[1]-1.0)/2.0] + std = sigma + [x, y] = np.meshgrid(np.arange(-siz[1], siz[1]+1), np.arange(-siz[0], siz[0]+1)) + arg = -(x*x + y*y)/(2*std*std) + h = np.exp(arg) + h[h < scipy.finfo(float).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h = h/sumh + return h + + +def fspecial_laplacian(alpha): + alpha = max([0, min([alpha,1])]) + h1 = alpha/(alpha+1) + h2 = (1-alpha)/(alpha+1) + h = [[h1, h2, h1], [h2, -4/(alpha+1), h2], [h1, h2, h1]] + h = np.array(h) + return h + + +def fspecial_log(hsize, sigma): + raise(NotImplemented) + + +def fspecial_motion(motion_len, theta): + raise(NotImplemented) + + +def fspecial_prewitt(): + return np.array([[1, 1, 1], [0, 0, 0], [-1, -1, -1]]) + + +def fspecial_sobel(): + return np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]) + + +def fspecial(filter_type, *args, **kwargs): + ''' + python code from: + https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py + ''' + if filter_type == 'average': + return fspecial_average(*args, **kwargs) + if filter_type == 'disk': + return fspecial_disk(*args, **kwargs) + if filter_type == 'gaussian': + return fspecial_gaussian(*args, **kwargs) + if filter_type == 'laplacian': + return fspecial_laplacian(*args, **kwargs) + if filter_type == 'log': + return fspecial_log(*args, **kwargs) + if filter_type == 'motion': + return fspecial_motion(*args, **kwargs) + if filter_type == 'prewitt': + return fspecial_prewitt(*args, **kwargs) + if filter_type == 'sobel': + return fspecial_sobel(*args, **kwargs) + + +def fspecial_gauss(size, sigma): + x, y = mgrid[-size // 2 + 1 : size // 2 + 1, -size // 2 + 1 : size // 2 + 1] + g = exp(-((x ** 2 + y ** 2) / (2.0 * sigma ** 2))) + return g / g.sum() + + +def blurkernel_synthesis(h=37, w=None): + # https://github.com/tkkcc/prior/blob/879a0b6c117c810776d8cc6b63720bf29f7d0cc4/util/gen_kernel.py + w = h if w is None else w + kdims = [h, w] + x = randomTrajectory(250) + k = None + while k is None: + k = kernelFromTrajectory(x) + + # center pad to kdims + pad_width = ((kdims[0] - k.shape[0]) // 2, (kdims[1] - k.shape[1]) // 2) + pad_width = [(pad_width[0],), (pad_width[1],)] + + if pad_width[0][0]<0 or pad_width[1][0]<0: + k = k[0:h, 0:h] + else: + k = pad(k, pad_width, "constant") + x1,x2 = k.shape + if np.random.randint(0, 4) == 1: + k = cv2.resize(k, (random.randint(x1, 5*x1), random.randint(x2, 5*x2)), interpolation=cv2.INTER_LINEAR) + y1, y2 = k.shape + k = k[(y1-x1)//2: (y1-x1)//2+x1, (y2-x2)//2: (y2-x2)//2+x2] + + if sum(k)<0.1: + k = fspecial_gaussian(h, 0.1+6*np.random.rand(1)) + k = k / sum(k) + # import matplotlib.pyplot as plt + # plt.imshow(k, interpolation="nearest", cmap="gray") + # plt.show() + return k + + +def kernelFromTrajectory(x): + h = 5 - log(rand()) / 
0.15 + h = round(min([h, 27])).astype(int) + h = h + 1 - h % 2 + w = h + k = zeros((h, w)) + + xmin = min(x[0]) + xmax = max(x[0]) + ymin = min(x[1]) + ymax = max(x[1]) + xthr = arange(xmin, xmax, (xmax - xmin) / w) + ythr = arange(ymin, ymax, (ymax - ymin) / h) + + for i in range(1, xthr.size): + for j in range(1, ythr.size): + idx = ( + (x[0, :] >= xthr[i - 1]) + & (x[0, :] < xthr[i]) + & (x[1, :] >= ythr[j - 1]) + & (x[1, :] < ythr[j]) + ) + k[i - 1, j - 1] = sum(idx) + if sum(k) == 0: + return + k = k / sum(k) + k = convolve2d(k, fspecial_gauss(3, 1), "same") + k = k / sum(k) + return k + + +def randomTrajectory(T): + x = zeros((3, T)) + v = randn(3, T) + r = zeros((3, T)) + trv = 1 / 1 + trr = 2 * pi / T + for t in range(1, T): + F_rot = randn(3) / (t + 1) + r[:, t - 1] + F_trans = randn(3) / (t + 1) + r[:, t] = r[:, t - 1] + trr * F_rot + v[:, t] = v[:, t - 1] + trv * F_trans + st = v[:, t] + st = rot3D(st, r[:, t]) + x[:, t] = x[:, t - 1] + st + return x + + +def rot3D(x, r): + Rx = array([[1, 0, 0], [0, cos(r[0]), -sin(r[0])], [0, sin(r[0]), cos(r[0])]]) + Ry = array([[cos(r[1]), 0, sin(r[1])], [0, 1, 0], [-sin(r[1]), 0, cos(r[1])]]) + Rz = array([[cos(r[2]), -sin(r[2]), 0], [sin(r[2]), cos(r[2]), 0], [0, 0, 1]]) + R = Rz @ Ry @ Rx + x = R @ x + return x + + +if __name__ == '__main__': + a = opt_fft_size([111]) + print(a) + + print(fspecial('gaussian', 5, 1)) + + print(p2o(torch.zeros(1,1,4,4).float(),(14,14)).shape) + + k = blurkernel_synthesis(11) + import matplotlib.pyplot as plt + plt.imshow(k, interpolation="nearest", cmap="gray") + plt.show() diff --git a/core/data/deg_kair_utils/utils_dist.py b/core/data/deg_kair_utils/utils_dist.py new file mode 100644 index 0000000000000000000000000000000000000000..88811737a8fc7cb6e12d9226a9242dbf8391d86b --- /dev/null +++ b/core/data/deg_kair_utils/utils_dist.py @@ -0,0 +1,201 @@ +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 +import functools +import os +import subprocess +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + + +# ---------------------------------- +# init +# ---------------------------------- +def init_dist(launcher, backend='nccl', **kwargs): + if mp.get_start_method(allow_none=True) is None: + mp.set_start_method('spawn') + if launcher == 'pytorch': + _init_dist_pytorch(backend, **kwargs) + elif launcher == 'slurm': + _init_dist_slurm(backend, **kwargs) + else: + raise ValueError(f'Invalid launcher type: {launcher}') + + +def _init_dist_pytorch(backend, **kwargs): + rank = int(os.environ['RANK']) + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + dist.init_process_group(backend=backend, **kwargs) + + +def _init_dist_slurm(backend, port=None): + """Initialize slurm distributed training environment. + If argument ``port`` is not specified, then the master port will be system + environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system + environment variable, then a default port ``29500`` will be used. + Args: + backend (str): Backend of torch.distributed. + port (int, optional): Master port. Defaults to None. 
+ """ + proc_id = int(os.environ['SLURM_PROCID']) + ntasks = int(os.environ['SLURM_NTASKS']) + node_list = os.environ['SLURM_NODELIST'] + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(proc_id % num_gpus) + addr = subprocess.getoutput( + f'scontrol show hostname {node_list} | head -n1') + # specify master port + if port is not None: + os.environ['MASTER_PORT'] = str(port) + elif 'MASTER_PORT' in os.environ: + pass # use MASTER_PORT in the environment variable + else: + # 29500 is torch.distributed default port + os.environ['MASTER_PORT'] = '29500' + os.environ['MASTER_ADDR'] = addr + os.environ['WORLD_SIZE'] = str(ntasks) + os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) + os.environ['RANK'] = str(proc_id) + dist.init_process_group(backend=backend) + + + +# ---------------------------------- +# get rank and world_size +# ---------------------------------- +def get_dist_info(): + if dist.is_available(): + initialized = dist.is_initialized() + else: + initialized = False + if initialized: + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + return rank, world_size + + +def get_rank(): + if not dist.is_available(): + return 0 + + if not dist.is_initialized(): + return 0 + + return dist.get_rank() + + +def get_world_size(): + if not dist.is_available(): + return 1 + + if not dist.is_initialized(): + return 1 + + return dist.get_world_size() + + +def master_only(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + rank, _ = get_dist_info() + if rank == 0: + return func(*args, **kwargs) + + return wrapper + + + + + + +# ---------------------------------- +# operation across ranks +# ---------------------------------- +def reduce_sum(tensor): + if not dist.is_available(): + return tensor + + if not dist.is_initialized(): + return tensor + + tensor = tensor.clone() + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + + return tensor + + +def gather_grad(params): + world_size = get_world_size() + + if world_size == 1: + return + + for param in params: + if param.grad is not None: + dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) + param.grad.data.div_(world_size) + + +def all_gather(data): + world_size = get_world_size() + + if world_size == 1: + return [data] + + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to('cuda') + + local_size = torch.IntTensor([tensor.numel()]).to('cuda') + size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda')) + + if local_size != max_size: + padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda') + tensor = torch.cat((tensor, padding), 0) + + dist.all_gather(tensor_list, tensor) + + data_list = [] + + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_loss_dict(loss_dict): + world_size = get_world_size() + + if world_size < 2: + return loss_dict + + with torch.no_grad(): + keys = [] + losses = [] + + for k in sorted(loss_dict.keys()): + keys.append(k) + losses.append(loss_dict[k]) + + losses = torch.stack(losses, 0) + dist.reduce(losses, dst=0) + + if dist.get_rank() == 0: + losses /= world_size + + reduced_losses = {k: v for k, v in 
zip(keys, losses)} + + return reduced_losses + diff --git a/core/data/deg_kair_utils/utils_googledownload.py b/core/data/deg_kair_utils/utils_googledownload.py new file mode 100644 index 0000000000000000000000000000000000000000..25533d4e0d90bac7519874a654ffd833d16ae289 --- /dev/null +++ b/core/data/deg_kair_utils/utils_googledownload.py @@ -0,0 +1,93 @@ +import math +import requests +from tqdm import tqdm + + +''' +borrowed from +https://github.com/xinntao/BasicSR/blob/28883e15eedc3381d23235ff3cf7c454c4be87e6/basicsr/utils/download_util.py +''' + + +def sizeof_fmt(size, suffix='B'): + """Get human readable file size. + Args: + size (int): File size. + suffix (str): Suffix. Default: 'B'. + Return: + str: Formated file siz. + """ + for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: + if abs(size) < 1024.0: + return f'{size:3.1f} {unit}{suffix}' + size /= 1024.0 + return f'{size:3.1f} Y{suffix}' + + +def download_file_from_google_drive(file_id, save_path): + """Download files from google drive. + Ref: + https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 + Args: + file_id (str): File id. + save_path (str): Save path. + """ + + session = requests.Session() + URL = 'https://docs.google.com/uc?export=download' + params = {'id': file_id} + + response = session.get(URL, params=params, stream=True) + token = get_confirm_token(response) + if token: + params['confirm'] = token + response = session.get(URL, params=params, stream=True) + + # get file size + response_file_size = session.get( + URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) + if 'Content-Range' in response_file_size.headers: + file_size = int( + response_file_size.headers['Content-Range'].split('/')[1]) + else: + file_size = None + + save_response_content(response, save_path, file_size) + + +def get_confirm_token(response): + for key, value in response.cookies.items(): + if key.startswith('download_warning'): + return value + return None + + +def save_response_content(response, + destination, + file_size=None, + chunk_size=32768): + if file_size is not None: + pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') + + readable_file_size = sizeof_fmt(file_size) + else: + pbar = None + + with open(destination, 'wb') as f: + downloaded_size = 0 + for chunk in response.iter_content(chunk_size): + downloaded_size += chunk_size + if pbar is not None: + pbar.update(1) + pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} ' + f'/ {readable_file_size}') + if chunk: # filter out keep-alive new chunks + f.write(chunk) + if pbar is not None: + pbar.close() + + +if __name__ == "__main__": + file_id = '1WNULM1e8gRNvsngVscsQ8tpaOqJ4mYtv' + save_path = 'BSRGAN.pth' + download_file_from_google_drive(file_id, save_path) diff --git a/core/data/deg_kair_utils/utils_image.py b/core/data/deg_kair_utils/utils_image.py new file mode 100644 index 0000000000000000000000000000000000000000..0e513a8bc1594c9ce2ba47ce3fe3b497269b7f16 --- /dev/null +++ b/core/data/deg_kair_utils/utils_image.py @@ -0,0 +1,1016 @@ +import os +import math +import random +import numpy as np +import torch +import cv2 +from torchvision.utils import make_grid +from datetime import datetime +# import torchvision.transforms as transforms +import matplotlib.pyplot as plt +from mpl_toolkits.mplot3d import Axes3D +os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" + + +''' +# -------------------------------------------- +# Kai Zhang (github: https://github.com/cszn) +# 03/Mar/2019 +# 
-------------------------------------------- +# https://github.com/twhui/SRGAN-pyTorch +# https://github.com/xinntao/BasicSR +# -------------------------------------------- +''' + + +IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif'] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def get_timestamp(): + return datetime.now().strftime('%y%m%d-%H%M%S') + + +def imshow(x, title=None, cbar=False, figsize=None): + plt.figure(figsize=figsize) + plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray') + if title: + plt.title(title) + if cbar: + plt.colorbar() + plt.show() + + +def surf(Z, cmap='rainbow', figsize=None): + plt.figure(figsize=figsize) + ax3 = plt.axes(projection='3d') + + w, h = Z.shape[:2] + xx = np.arange(0,w,1) + yy = np.arange(0,h,1) + X, Y = np.meshgrid(xx, yy) + ax3.plot_surface(X,Y,Z,cmap=cmap) + #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap) + plt.show() + + +''' +# -------------------------------------------- +# get image pathes +# -------------------------------------------- +''' + + +def get_image_paths(dataroot): + paths = None # return None if dataroot is None + if isinstance(dataroot, str): + paths = sorted(_get_paths_from_images(dataroot)) + elif isinstance(dataroot, list): + paths = [] + for i in dataroot: + paths += sorted(_get_paths_from_images(i)) + return paths + + +def _get_paths_from_images(path): + assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) + images = [] + for dirpath, _, fnames in sorted(os.walk(path)): + for fname in sorted(fnames): + if is_image_file(fname): + img_path = os.path.join(dirpath, fname) + images.append(img_path) + assert images, '{:s} has no valid image file'.format(path) + return images + + +''' +# -------------------------------------------- +# split large images into small images +# -------------------------------------------- +''' + + +def patches_from_image(img, p_size=512, p_overlap=64, p_max=800): + w, h = img.shape[:2] + patches = [] + if w > p_max and h > p_max: + w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int)) + h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int)) + w1.append(w-p_size) + h1.append(h-p_size) + # print(w1) + # print(h1) + for i in w1: + for j in h1: + patches.append(img[i:i+p_size, j:j+p_size,:]) + else: + patches.append(img) + + return patches + + +def imssave(imgs, img_path): + """ + imgs: list, N images of size WxHxC + """ + img_name, ext = os.path.splitext(os.path.basename(img_path)) + for i, img in enumerate(imgs): + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + new_path = os.path.join(os.path.dirname(img_path), img_name+str('_{:04d}'.format(i))+'.png') + cv2.imwrite(new_path, img) + + +def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=512, p_overlap=96, p_max=800): + """ + split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size), + and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max) + will be splitted. + + Args: + original_dataroot: + taget_dataroot: + p_size: size of small images + p_overlap: patch size in training is a good choice + p_max: images with smaller size than (p_max)x(p_max) keep unchanged. 
+ """ + paths = get_image_paths(original_dataroot) + for img_path in paths: + # img_name, ext = os.path.splitext(os.path.basename(img_path)) + img = imread_uint(img_path, n_channels=n_channels) + patches = patches_from_image(img, p_size, p_overlap, p_max) + imssave(patches, os.path.join(taget_dataroot, os.path.basename(img_path))) + #if original_dataroot == taget_dataroot: + #del img_path + +''' +# -------------------------------------------- +# makedir +# -------------------------------------------- +''' + + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + + +def mkdirs(paths): + if isinstance(paths, str): + mkdir(paths) + else: + for path in paths: + mkdir(path) + + +def mkdir_and_rename(path): + if os.path.exists(path): + new_name = path + '_archived_' + get_timestamp() + print('Path already exists. Rename it to [{:s}]'.format(new_name)) + os.rename(path, new_name) + os.makedirs(path) + + +''' +# -------------------------------------------- +# read image from path +# opencv is fast, but read BGR numpy image +# -------------------------------------------- +''' + + +# -------------------------------------------- +# get uint8 image of size HxWxn_channles (RGB) +# -------------------------------------------- +def imread_uint(path, n_channels=3): + # input: path + # output: HxWx3(RGB or GGG), or HxWx1 (G) + if n_channels == 1: + img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE + img = np.expand_dims(img, axis=2) # HxWx1 + elif n_channels == 3: + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG + else: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB + return img + + +# -------------------------------------------- +# matlab's imwrite +# -------------------------------------------- +def imsave(img, img_path): + img = np.squeeze(img) + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + cv2.imwrite(img_path, img) + +def imwrite(img, img_path): + img = np.squeeze(img) + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + cv2.imwrite(img_path, img) + + + +# -------------------------------------------- +# get single image of size HxWxn_channles (BGR) +# -------------------------------------------- +def read_img(path): + # read image by cv2 + # return: Numpy float32, HWC, BGR, [0,1] + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE + img = img.astype(np.float32) / 255. + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + # some images have 4 channels + if img.shape[2] > 3: + img = img[:, :, :3] + return img + + +''' +# -------------------------------------------- +# image format conversion +# -------------------------------------------- +# numpy(single) <---> numpy(uint) +# numpy(single) <---> tensor +# numpy(uint) <---> tensor +# -------------------------------------------- +''' + + +# -------------------------------------------- +# numpy(single) [0, 1] <---> numpy(uint) +# -------------------------------------------- + + +def uint2single(img): + + return np.float32(img/255.) + + +def single2uint(img): + + return np.uint8((img.clip(0, 1)*255.).round()) + + +def uint162single(img): + + return np.float32(img/65535.) 
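As a quick sanity check of the dtype helpers above, the sketch below (illustrative only; it assumes utils_image is importable as added by this patch) round-trips a uint8 image through uint2single and single2uint.

import numpy as np
from core.data.deg_kair_utils import utils_image as util  # module added by this patch

img_uint8 = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)

img_single = util.uint2single(img_uint8)   # float32, values in [0, 1]
img_back = util.single2uint(img_single)    # clip to [0, 1], rescale, round back to uint8

assert img_single.dtype == np.float32
assert img_back.dtype == np.uint8
assert np.array_equal(img_uint8, img_back)  # round trip is lossless for uint8 inputs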
+ + +def single2uint16(img): + + return np.uint16((img.clip(0, 1)*65535.).round()) + + +# -------------------------------------------- +# numpy(uint) (HxWxC or HxW) <---> tensor +# -------------------------------------------- + + +# convert uint to 4-dimensional torch tensor +def uint2tensor4(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0) + + +# convert uint to 3-dimensional torch tensor +def uint2tensor3(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.) + + +# convert 2/3/4-dimensional torch tensor to uint +def tensor2uint(img): + img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + return np.uint8((img*255.0).round()) + + +# -------------------------------------------- +# numpy(single) (HxWxC) <---> tensor +# -------------------------------------------- + + +# convert single (HxWxC) to 3-dimensional torch tensor +def single2tensor3(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float() + + +# convert single (HxWxC) to 4-dimensional torch tensor +def single2tensor4(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0) + + +# convert torch tensor to single +def tensor2single(img): + img = img.data.squeeze().float().cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + + return img + +# convert torch tensor to single +def tensor2single3(img): + img = img.data.squeeze().float().cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + elif img.ndim == 2: + img = np.expand_dims(img, axis=2) + return img + + +def single2tensor5(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0) + + +def single32tensor5(img): + return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0) + + +def single42tensor4(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float() + + +# from skimage.io import imread, imsave +def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): + ''' + Converts a torch Tensor into an image Numpy array of BGR channel order + Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order + Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) + ''' + tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp + tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] + n_dim = tensor.dim() + if n_dim == 4: + n_img = len(tensor) + img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 3: + img_np = tensor.numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 2: + img_np = tensor.numpy() + else: + raise TypeError( + 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) + if out_type == np.uint8: + img_np = (img_np * 255.0).round() + # Important. Unlike matlab, numpy.uint8() WILL NOT round by default. + return img_np.astype(out_type) + + +''' +# -------------------------------------------- +# Augmentation, flipe and/or rotate +# -------------------------------------------- +# The following two are enough. 
+# (1) augmet_img: numpy image of WxHxC or WxH +# (2) augment_img_tensor4: tensor image 1xCxWxH +# -------------------------------------------- +''' + + +def augment_img(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + if mode == 0: + return img + elif mode == 1: + return np.flipud(np.rot90(img)) + elif mode == 2: + return np.flipud(img) + elif mode == 3: + return np.rot90(img, k=3) + elif mode == 4: + return np.flipud(np.rot90(img, k=2)) + elif mode == 5: + return np.rot90(img) + elif mode == 6: + return np.rot90(img, k=2) + elif mode == 7: + return np.flipud(np.rot90(img, k=3)) + + +def augment_img_tensor4(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + if mode == 0: + return img + elif mode == 1: + return img.rot90(1, [2, 3]).flip([2]) + elif mode == 2: + return img.flip([2]) + elif mode == 3: + return img.rot90(3, [2, 3]) + elif mode == 4: + return img.rot90(2, [2, 3]).flip([2]) + elif mode == 5: + return img.rot90(1, [2, 3]) + elif mode == 6: + return img.rot90(2, [2, 3]) + elif mode == 7: + return img.rot90(3, [2, 3]).flip([2]) + + +def augment_img_tensor(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + img_size = img.size() + img_np = img.data.cpu().numpy() + if len(img_size) == 3: + img_np = np.transpose(img_np, (1, 2, 0)) + elif len(img_size) == 4: + img_np = np.transpose(img_np, (2, 3, 1, 0)) + img_np = augment_img(img_np, mode=mode) + img_tensor = torch.from_numpy(np.ascontiguousarray(img_np)) + if len(img_size) == 3: + img_tensor = img_tensor.permute(2, 0, 1) + elif len(img_size) == 4: + img_tensor = img_tensor.permute(3, 2, 0, 1) + + return img_tensor.type_as(img) + + +def augment_img_np3(img, mode=0): + if mode == 0: + return img + elif mode == 1: + return img.transpose(1, 0, 2) + elif mode == 2: + return img[::-1, :, :] + elif mode == 3: + img = img[::-1, :, :] + img = img.transpose(1, 0, 2) + return img + elif mode == 4: + return img[:, ::-1, :] + elif mode == 5: + img = img[:, ::-1, :] + img = img.transpose(1, 0, 2) + return img + elif mode == 6: + img = img[:, ::-1, :] + img = img[::-1, :, :] + return img + elif mode == 7: + img = img[:, ::-1, :] + img = img[::-1, :, :] + img = img.transpose(1, 0, 2) + return img + + +def augment_imgs(img_list, hflip=True, rot=True): + # horizontal flip OR rotate + hflip = hflip and random.random() < 0.5 + vflip = rot and random.random() < 0.5 + rot90 = rot and random.random() < 0.5 + + def _augment(img): + if hflip: + img = img[:, ::-1, :] + if vflip: + img = img[::-1, :, :] + if rot90: + img = img.transpose(1, 0, 2) + return img + + return [_augment(img) for img in img_list] + + +''' +# -------------------------------------------- +# modcrop and shave +# -------------------------------------------- +''' + + +def modcrop(img_in, scale): + # img_in: Numpy, HWC or HW + img = np.copy(img_in) + if img.ndim == 2: + H, W = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r] + elif img.ndim == 3: + H, W, C = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r, :] + else: + raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim)) + return img + + +def shave(img_in, border=0): + # img_in: Numpy, HWC or HW + img = np.copy(img_in) + h, w = img.shape[:2] + img = img[border:h-border, border:w-border] + return img + + +''' +# -------------------------------------------- +# image processing process on numpy image +# channel_convert(in_c, tar_type, img_list): +# rgb2ycbcr(img, only_y=True): +# bgr2ycbcr(img, only_y=True): +# 
ycbcr2rgb(img): +# -------------------------------------------- +''' + + +def rgb2ycbcr(img, only_y=True): + '''same as matlab rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], + [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def ycbcr2rgb(img): + '''same as matlab ycbcr2rgb + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] + rlt = np.clip(rlt, 0, 255) + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def bgr2ycbcr(img, only_y=True): + '''bgr version of rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def channel_convert(in_c, tar_type, img_list): + # conversion among BGR, gray and y + if in_c == 3 and tar_type == 'gray': # BGR to gray + gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list] + return [np.expand_dims(img, axis=2) for img in gray_list] + elif in_c == 3 and tar_type == 'y': # BGR to y + y_list = [bgr2ycbcr(img, only_y=True) for img in img_list] + return [np.expand_dims(img, axis=2) for img in y_list] + elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR + return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list] + else: + return img_list + + +''' +# -------------------------------------------- +# metric, PSNR, SSIM and PSNRB +# -------------------------------------------- +''' + + +# -------------------------------------------- +# PSNR +# -------------------------------------------- +def calculate_psnr(img1, img2, border=0): + # img1 and img2 have range [0, 255] + #img1 = img1.squeeze() + #img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1 = img1[border:h-border, border:w-border] + img2 = img2[border:h-border, border:w-border] + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + mse = np.mean((img1 - img2)**2) + if mse == 0: + return float('inf') + return 20 * math.log10(255.0 / math.sqrt(mse)) + + +# -------------------------------------------- +# SSIM +# -------------------------------------------- +def calculate_ssim(img1, img2, border=0): + '''calculate SSIM + the same outputs as MATLAB's + img1, img2: [0, 255] + ''' + #img1 = img1.squeeze() + #img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1 = img1[border:h-border, 
border:w-border] + img2 = img2[border:h-border, border:w-border] + + if img1.ndim == 2: + return ssim(img1, img2) + elif img1.ndim == 3: + if img1.shape[2] == 3: + ssims = [] + for i in range(3): + ssims.append(ssim(img1[:,:,i], img2[:,:,i])) + return np.array(ssims).mean() + elif img1.shape[2] == 1: + return ssim(np.squeeze(img1), np.squeeze(img2)) + else: + raise ValueError('Wrong input image dimensions.') + + +def ssim(img1, img2): + C1 = (0.01 * 255)**2 + C2 = (0.03 * 255)**2 + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * + (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean() + + +def _blocking_effect_factor(im): + block_size = 8 + + block_horizontal_positions = torch.arange(7, im.shape[3] - 1, 8) + block_vertical_positions = torch.arange(7, im.shape[2] - 1, 8) + + horizontal_block_difference = ( + (im[:, :, :, block_horizontal_positions] - im[:, :, :, block_horizontal_positions + 1]) ** 2).sum( + 3).sum(2).sum(1) + vertical_block_difference = ( + (im[:, :, block_vertical_positions, :] - im[:, :, block_vertical_positions + 1, :]) ** 2).sum(3).sum( + 2).sum(1) + + nonblock_horizontal_positions = np.setdiff1d(torch.arange(0, im.shape[3] - 1), block_horizontal_positions) + nonblock_vertical_positions = np.setdiff1d(torch.arange(0, im.shape[2] - 1), block_vertical_positions) + + horizontal_nonblock_difference = ( + (im[:, :, :, nonblock_horizontal_positions] - im[:, :, :, nonblock_horizontal_positions + 1]) ** 2).sum( + 3).sum(2).sum(1) + vertical_nonblock_difference = ( + (im[:, :, nonblock_vertical_positions, :] - im[:, :, nonblock_vertical_positions + 1, :]) ** 2).sum( + 3).sum(2).sum(1) + + n_boundary_horiz = im.shape[2] * (im.shape[3] // block_size - 1) + n_boundary_vert = im.shape[3] * (im.shape[2] // block_size - 1) + boundary_difference = (horizontal_block_difference + vertical_block_difference) / ( + n_boundary_horiz + n_boundary_vert) + + n_nonboundary_horiz = im.shape[2] * (im.shape[3] - 1) - n_boundary_horiz + n_nonboundary_vert = im.shape[3] * (im.shape[2] - 1) - n_boundary_vert + nonboundary_difference = (horizontal_nonblock_difference + vertical_nonblock_difference) / ( + n_nonboundary_horiz + n_nonboundary_vert) + + scaler = np.log2(block_size) / np.log2(min([im.shape[2], im.shape[3]])) + bef = scaler * (boundary_difference - nonboundary_difference) + + bef[boundary_difference <= nonboundary_difference] = 0 + return bef + + +def calculate_psnrb(img1, img2, border=0): + """Calculate PSNR-B (Peak Signal-to-Noise Ratio). + Ref: Quality assessment of deblocked images, for JPEG image deblocking evaluation + # https://gitlab.com/Queuecumber/quantization-guided-ac/-/blob/master/metrics/psnrb.py + Args: + img1 (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + border (int): Cropped pixels in each edge of an image. These + pixels are not involved in the PSNR calculation. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + Returns: + float: psnr result. 
+ """ + + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + + if img1.ndim == 2: + img1, img2 = np.expand_dims(img1, 2), np.expand_dims(img2, 2) + + h, w = img1.shape[:2] + img1 = img1[border:h-border, border:w-border] + img2 = img2[border:h-border, border:w-border] + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + + # follow https://gitlab.com/Queuecumber/quantization-guided-ac/-/blob/master/metrics/psnrb.py + img1 = torch.from_numpy(img1).permute(2, 0, 1).unsqueeze(0) / 255. + img2 = torch.from_numpy(img2).permute(2, 0, 1).unsqueeze(0) / 255. + + total = 0 + for c in range(img1.shape[1]): + mse = torch.nn.functional.mse_loss(img1[:, c:c + 1, :, :], img2[:, c:c + 1, :, :], reduction='none') + bef = _blocking_effect_factor(img1[:, c:c + 1, :, :]) + + mse = mse.view(mse.shape[0], -1).mean(1) + total += 10 * torch.log10(1 / (mse + bef)) + + return float(total) / img1.shape[1] + +''' +# -------------------------------------------- +# matlab's bicubic imresize (numpy and torch) [0, 1] +# -------------------------------------------- +''' + + +# matlab 'imresize' function, now only support 'bicubic' +def cubic(x): + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \ + (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx)) + + +def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): + if (scale < 1) and (antialiasing): + # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5+scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + P = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( + 1, P).expand(out_length, P) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices + # apply cubic kernel + if (scale < 1) and (antialiasing): + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, P) + + # If a column in weights is all zero, get rid of it. only consider the first and last column. 
+ weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, P - 2) + weights = weights.narrow(1, 1, P - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, P - 2) + weights = weights.narrow(1, 0, P - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +# -------------------------------------------- +# imresize for tensor image [0, 1] +# -------------------------------------------- +def imresize(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: pytorch tensor, CHW or HW [0,1] + # output: CHW or HW [0,1] w/o round + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(0) + in_C, in_H, in_W = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) + img_aug.narrow(1, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:, :sym_len_Hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_He:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_C, out_H, in_W) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We) + out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_Ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_We:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_C, out_H, out_W) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + return out_2 + + +# -------------------------------------------- +# imresize for numpy image [0, 1] +# -------------------------------------------- +def 
imresize_np(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: Numpy, HWC or HW [0,1] + # output: HWC or HW [0,1] w/o round + img = torch.from_numpy(img) + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(2) + + in_H, in_W, in_C = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) + img_aug.narrow(0, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:sym_len_Hs, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[-sym_len_He:, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(out_H, in_W, in_C) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C) + out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :sym_len_Ws, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, -sym_len_We:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(out_H, out_W, in_C) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + + return out_2.numpy() + + +if __name__ == '__main__': + img = imread_uint('test.bmp', 3) +# img = uint2single(img) +# img_bicubic = imresize_np(img, 1/4) +# imshow(single2uint(img_bicubic)) +# +# img_tensor = single2tensor4(img) +# for i in range(8): +# imshow(np.concatenate((augment_img(img, i), tensor2single(augment_img_tensor4(img_tensor, i))), 1)) + +# patches = patches_from_image(img, p_size=128, p_overlap=0, p_max=200) +# imssave(patches,'a.png') + + + + + + + diff --git a/core/data/deg_kair_utils/utils_lmdb.py b/core/data/deg_kair_utils/utils_lmdb.py new file mode 100644 index 0000000000000000000000000000000000000000..75192c346bb9c0b96f8b09635ed548bd6e797d89 --- /dev/null +++ b/core/data/deg_kair_utils/utils_lmdb.py @@ -0,0 +1,205 @@ +import cv2 +import lmdb +import sys +from multiprocessing import Pool +from os import path 
as osp +from tqdm import tqdm + + +def make_lmdb_from_imgs(data_path, + lmdb_path, + img_path_list, + keys, + batch=5000, + compress_level=1, + multiprocessing_read=False, + n_thread=40, + map_size=None): + """Make lmdb from images. + + Contents of lmdb. The file structure is: + example.lmdb + ├── data.mdb + ├── lock.mdb + ├── meta_info.txt + + The data.mdb and lock.mdb are standard lmdb files and you can refer to + https://lmdb.readthedocs.io/en/release/ for more details. + + The meta_info.txt is a specified txt file to record the meta information + of our datasets. It will be automatically created when preparing + datasets by our provided dataset tools. + Each line in the txt file records 1)image name (with extension), + 2)image shape, and 3)compression level, separated by a white space. + + For example, the meta information could be: + `000_00000000.png (720,1280,3) 1`, which means: + 1) image name (with extension): 000_00000000.png; + 2) image shape: (720,1280,3); + 3) compression level: 1 + + We use the image name without extension as the lmdb key. + + If `multiprocessing_read` is True, it will read all the images to memory + using multiprocessing. Thus, your server needs to have enough memory. + + Args: + data_path (str): Data path for reading images. + lmdb_path (str): Lmdb save path. + img_path_list (str): Image path list. + keys (str): Used for lmdb keys. + batch (int): After processing batch images, lmdb commits. + Default: 5000. + compress_level (int): Compress level when encoding images. Default: 1. + multiprocessing_read (bool): Whether use multiprocessing to read all + the images to memory. Default: False. + n_thread (int): For multiprocessing. + map_size (int | None): Map size for lmdb env. If None, use the + estimated size from images. Default: None + """ + + assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, ' + f'but got {len(img_path_list)} and {len(keys)}') + print(f'Create lmdb for {data_path}, save to {lmdb_path}...') + print(f'Totoal images: {len(img_path_list)}') + if not lmdb_path.endswith('.lmdb'): + raise ValueError("lmdb_path must end with '.lmdb'.") + if osp.exists(lmdb_path): + print(f'Folder {lmdb_path} already exists. 
Exit.') + sys.exit(1) + + if multiprocessing_read: + # read all the images to memory (multiprocessing) + dataset = {} # use dict to keep the order for multiprocessing + shapes = {} + print(f'Read images with multiprocessing, #thread: {n_thread} ...') + pbar = tqdm(total=len(img_path_list), unit='image') + + def callback(arg): + """get the image data and update pbar.""" + key, dataset[key], shapes[key] = arg + pbar.update(1) + pbar.set_description(f'Read {key}') + + pool = Pool(n_thread) + for path, key in zip(img_path_list, keys): + pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback) + pool.close() + pool.join() + pbar.close() + print(f'Finish reading {len(img_path_list)} images.') + + # create lmdb environment + if map_size is None: + # obtain data size for one image + img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED) + _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) + data_size_per_img = img_byte.nbytes + print('Data size per image is: ', data_size_per_img) + data_size = data_size_per_img * len(img_path_list) + map_size = data_size * 10 + + env = lmdb.open(lmdb_path, map_size=map_size) + + # write data to lmdb + pbar = tqdm(total=len(img_path_list), unit='chunk') + txn = env.begin(write=True) + txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') + for idx, (path, key) in enumerate(zip(img_path_list, keys)): + pbar.update(1) + pbar.set_description(f'Write {key}') + key_byte = key.encode('ascii') + if multiprocessing_read: + img_byte = dataset[key] + h, w, c = shapes[key] + else: + _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level) + h, w, c = img_shape + + txn.put(key_byte, img_byte) + # write meta information + txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n') + if idx % batch == 0: + txn.commit() + txn = env.begin(write=True) + pbar.close() + txn.commit() + env.close() + txt_file.close() + print('\nFinish writing lmdb.') + + +def read_img_worker(path, key, compress_level): + """Read image worker. + + Args: + path (str): Image path. + key (str): Image key. + compress_level (int): Compress level when encoding images. + + Returns: + str: Image key. + byte: Image byte. + tuple[int]: Image shape. + """ + + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) + # deal with `libpng error: Read Error` + if img is None: + print(f'To deal with `libpng error: Read Error`, use PIL to load {path}') + from PIL import Image + import numpy as np + img = Image.open(path) + img = np.asanyarray(img) + img = img[:, :, [2, 1, 0]] + + if img.ndim == 2: + h, w = img.shape + c = 1 + else: + h, w, c = img.shape + _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) + return (key, img_byte, (h, w, c)) + + +class LmdbMaker(): + """LMDB Maker. + + Args: + lmdb_path (str): Lmdb save path. + map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB. + batch (int): After processing batch images, lmdb commits. + Default: 5000. + compress_level (int): Compress level when encoding images. Default: 1. + """ + + def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1): + if not lmdb_path.endswith('.lmdb'): + raise ValueError("lmdb_path must end with '.lmdb'.") + if osp.exists(lmdb_path): + print(f'Folder {lmdb_path} already exists. 
Exit.') + sys.exit(1) + + self.lmdb_path = lmdb_path + self.batch = batch + self.compress_level = compress_level + self.env = lmdb.open(lmdb_path, map_size=map_size) + self.txn = self.env.begin(write=True) + self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') + self.counter = 0 + + def put(self, img_byte, key, img_shape): + self.counter += 1 + key_byte = key.encode('ascii') + self.txn.put(key_byte, img_byte) + # write meta information + h, w, c = img_shape + self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n') + if self.counter % self.batch == 0: + self.txn.commit() + self.txn = self.env.begin(write=True) + + def close(self): + self.txn.commit() + self.env.close() + self.txt_file.close() diff --git a/core/data/deg_kair_utils/utils_logger.py b/core/data/deg_kair_utils/utils_logger.py new file mode 100644 index 0000000000000000000000000000000000000000..3067190e1b09b244814e0ccc4496b18f06e22b54 --- /dev/null +++ b/core/data/deg_kair_utils/utils_logger.py @@ -0,0 +1,66 @@ +import sys +import datetime +import logging + + +''' +# -------------------------------------------- +# Kai Zhang (github: https://github.com/cszn) +# 03/Mar/2019 +# -------------------------------------------- +# https://github.com/xinntao/BasicSR +# -------------------------------------------- +''' + + +def log(*args, **kwargs): + print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"), *args, **kwargs) + + +''' +# -------------------------------------------- +# logger +# -------------------------------------------- +''' + + +def logger_info(logger_name, log_path='default_logger.log'): + ''' set up logger + modified by Kai Zhang (github: https://github.com/cszn) + ''' + log = logging.getLogger(logger_name) + if log.hasHandlers(): + print('LogHandlers exist!') + else: + print('LogHandlers setup!') + level = logging.INFO + formatter = logging.Formatter('%(asctime)s.%(msecs)03d : %(message)s', datefmt='%y-%m-%d %H:%M:%S') + fh = logging.FileHandler(log_path, mode='a') + fh.setFormatter(formatter) + log.setLevel(level) + log.addHandler(fh) + # print(len(log.handlers)) + + sh = logging.StreamHandler() + sh.setFormatter(formatter) + log.addHandler(sh) + + +''' +# -------------------------------------------- +# print to file and std_out simultaneously +# -------------------------------------------- +''' + + +class logger_print(object): + def __init__(self, log_path="default.log"): + self.terminal = sys.stdout + self.log = open(log_path, 'a') + + def write(self, message): + self.terminal.write(message) + self.log.write(message) # write the message + + def flush(self): + pass diff --git a/core/data/deg_kair_utils/utils_mat.py b/core/data/deg_kair_utils/utils_mat.py new file mode 100644 index 0000000000000000000000000000000000000000..cd25d500c0eae77a3b815b8e956205b737ee43d4 --- /dev/null +++ b/core/data/deg_kair_utils/utils_mat.py @@ -0,0 +1,88 @@ +import os +import json +import scipy.io as spio +import pandas as pd + + +def loadmat(filename): + ''' + this function should be called instead of direct spio.loadmat + as it cures the problem of not properly recovering python dictionaries + from mat files. It calls the function check keys to cure all entries + which are still mat-objects + ''' + data = spio.loadmat(filename, struct_as_record=False, squeeze_me=True) + return dict_to_nonedict(_check_keys(data)) + +def _check_keys(dict): + ''' + checks if entries in dictionary are mat-objects. 
If so,
+    _todict is called to change them into nested dictionaries
+    '''
+    for key in dict:
+        if isinstance(dict[key], spio.matlab.mio5_params.mat_struct):
+            dict[key] = _todict(dict[key])
+    return dict
+
+def _todict(matobj):
+    '''
+    A recursive function which constructs nested dictionaries from mat-objects
+    '''
+    dict = {}
+    for strg in matobj._fieldnames:
+        elem = matobj.__dict__[strg]
+        if isinstance(elem, spio.matlab.mio5_params.mat_struct):
+            dict[strg] = _todict(elem)
+        else:
+            dict[strg] = elem
+    return dict
+
+
+def dict_to_nonedict(opt):
+    if isinstance(opt, dict):
+        new_opt = dict()
+        for key, sub_opt in opt.items():
+            new_opt[key] = dict_to_nonedict(sub_opt)
+        return NoneDict(**new_opt)
+    elif isinstance(opt, list):
+        return [dict_to_nonedict(sub_opt) for sub_opt in opt]
+    else:
+        return opt
+
+
+class NoneDict(dict):
+    def __missing__(self, key):
+        return None
+
+
+def mat2json(mat_path=None, filepath=None):
+    """
+    Converts a .mat file to .json and writes the new file
+    Parameters
+    ----------
+    mat_path: Str
+        path/filename of the .mat file to load
+    filepath: Str
+        if provided (truthy), a .json file named after the .mat file is also written; otherwise nothing is saved
+    Returns
+    -------
+    str
+        the converted dictionary serialized as a JSON string
+    Examples
+    --------
+    >>> mat2json(blah blah)
+    """
+
+    matlabFile = loadmat(mat_path)
+    # pop the metadata fields that cannot be serialized to json
+    matlabFile.pop('__header__')
+    matlabFile.pop('__version__')
+    matlabFile.pop('__globals__')
+    # jsonize the file - orientation is 'index'
+    matlabFile = pd.Series(matlabFile).to_json()
+
+    if filepath:
+        json_path = os.path.splitext(os.path.split(mat_path)[1])[0] + '.json'
+        with open(json_path, 'w') as f:
+            f.write(matlabFile)
+    return matlabFile
\ No newline at end of file
diff --git a/core/data/deg_kair_utils/utils_matconvnet.py b/core/data/deg_kair_utils/utils_matconvnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..506dc47805ae07976022b236ca64c98e9a6f78b3
--- /dev/null
+++ b/core/data/deg_kair_utils/utils_matconvnet.py
@@ -0,0 +1,197 @@
+# -*- coding: utf-8 -*-
+import numpy as np
+import torch
+from collections import OrderedDict
+
+# import scipy.io as io
+import hdf5storage
+
+"""
+# --------------------------------------------
+# Convert matconvnet SimpleNN model into pytorch model
+# --------------------------------------------
+# Kai Zhang (cskaizhang@gmail.com)
+# https://github.com/cszn
+# 28/Nov/2019
+# --------------------------------------------
+"""
+
+
+def weights2tensor(x, squeeze=False, in_features=None, out_features=None):
+    """Modified version of https://github.com/albanie/pytorch-mcn
+    Adjust memory layout and load weights as torch tensor
+    Args:
+        x (ndarray): a numpy array, corresponding to a set of network weights
+           stored in column major order
+        squeeze (bool) [False]: whether to squeeze the tensor (i.e. remove
+           singletons from the trailing dimensions, so after converting to
+           pytorch layout (C_out, C_in, H, W), if the shape is (A, B, 1, 1)
+           it will be reshaped to a matrix with shape (A, B)).
+        in_features (int :: None): used to reshape weights for a linear block.
+        out_features (int :: None): used to reshape weights for a linear block.
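+        Note: 4-D MatConvNet weight arrays are stored as (H, W, C_in, C_out) and are
+           permuted below to PyTorch's (C_out, C_in, H, W) layout.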
+ Returns: + torch.tensor: a permuted sets of weights, matching the pytorch layout + convention + """ + if x.ndim == 4: + x = x.transpose((3, 2, 0, 1)) +# for FFDNet, pixel-shuffle layer +# if x.shape[1]==13: +# x=x[:,[0,2,1,3, 4,6,5,7, 8,10,9,11, 12],:,:] +# if x.shape[0]==12: +# x=x[[0,2,1,3, 4,6,5,7, 8,10,9,11],:,:,:] +# if x.shape[1]==5: +# x=x[:,[0,2,1,3, 4],:,:] +# if x.shape[0]==4: +# x=x[[0,2,1,3],:,:,:] +## for SRMD, pixel-shuffle layer +# if x.shape[0]==12: +# x=x[[0,2,1,3, 4,6,5,7, 8,10,9,11],:,:,:] +# if x.shape[0]==27: +# x=x[[0,3,6,1,4,7,2,5,8, 0+9,3+9,6+9,1+9,4+9,7+9,2+9,5+9,8+9, 0+18,3+18,6+18,1+18,4+18,7+18,2+18,5+18,8+18],:,:,:] +# if x.shape[0]==48: +# x=x[[0,4,8,12,1,5,9,13,2,6,10,14,3,7,11,15, 0+16,4+16,8+16,12+16,1+16,5+16,9+16,13+16,2+16,6+16,10+16,14+16,3+16,7+16,11+16,15+16, 0+32,4+32,8+32,12+32,1+32,5+32,9+32,13+32,2+32,6+32,10+32,14+32,3+32,7+32,11+32,15+32],:,:,:] + + elif x.ndim == 3: # add by Kai + x = x[:,:,:,None] + x = x.transpose((3, 2, 0, 1)) + elif x.ndim == 2: + if x.shape[1] == 1: + x = x.flatten() + if squeeze: + if in_features and out_features: + x = x.reshape((out_features, in_features)) + x = np.squeeze(x) + return torch.from_numpy(np.ascontiguousarray(x)) + + +def save_model(network, save_path): + state_dict = network.state_dict() + for key, param in state_dict.items(): + state_dict[key] = param.cpu() + torch.save(state_dict, save_path) + + +if __name__ == '__main__': + + +# from utils import utils_logger +# import logging +# utils_logger.logger_info('a', 'a.log') +# logger = logging.getLogger('a') +# + # mcn = hdf5storage.loadmat('/model_zoo/matfile/FFDNet_Clip_gray.mat') + mcn = hdf5storage.loadmat('models/modelcolor.mat') + + + #logger.info(mcn['CNNdenoiser'][0][0][0][1][0][0][0][0]) + + mat_net = OrderedDict() + for idx in range(25): + mat_net[str(idx)] = OrderedDict() + count = -1 + + print(idx) + for i in range(13): + + if mcn['CNNdenoiser'][0][idx][0][i][0][0][0][0] == 'conv': + + count += 1 + w = mcn['CNNdenoiser'][0][idx][0][i][0][1][0][0] + # print(w.shape) + w = weights2tensor(w) + # print(w.shape) + + b = mcn['CNNdenoiser'][0][idx][0][i][0][1][0][1] + b = weights2tensor(b) + print(b.shape) + + mat_net[str(idx)]['model.{:d}.weight'.format(count*2)] = w + mat_net[str(idx)]['model.{:d}.bias'.format(count*2)] = b + + torch.save(mat_net, 'model_zoo/modelcolor.pth') + + + +# from models.network_dncnn import IRCNN as net +# network = net(in_nc=3, out_nc=3, nc=64) +# state_dict = network.state_dict() +# +# #show_kv(state_dict) +# +# for i in range(len(mcn['net'][0][0][0])): +# print(mcn['net'][0][0][0][i][0][0][0][0]) +# +# count = -1 +# mat_net = OrderedDict() +# for i in range(len(mcn['net'][0][0][0])): +# if mcn['net'][0][0][0][i][0][0][0][0] == 'conv': +# +# count += 1 +# w = mcn['net'][0][0][0][i][0][1][0][0] +# print(w.shape) +# w = weights2tensor(w) +# print(w.shape) +# +# b = mcn['net'][0][0][0][i][0][1][0][1] +# b = weights2tensor(b) +# print(b.shape) +# +# mat_net['model.{:d}.weight'.format(count*2)] = w +# mat_net['model.{:d}.bias'.format(count*2)] = b +# +# torch.save(mat_net, 'E:/pytorch/KAIR_ongoing/model_zoo/ffdnet_gray_clip.pth') +# +# +# +# crt_net = torch.load('E:/pytorch/KAIR_ongoing/model_zoo/imdn_x4.pth') +# def show_kv(net): +# for k, v in net.items(): +# print(k) +# +# show_kv(crt_net) + + +# from models.network_dncnn import DnCNN as net +# network = net(in_nc=2, out_nc=1, nc=64, nb=20, act_mode='R') + +# from models.network_srmd import SRMD as net +# #network = net(in_nc=1, out_nc=1, nc=64, nb=15, act_mode='R') +# 
network = net(in_nc=19, out_nc=3, nc=128, nb=12, upscale=4, act_mode='R', upsample_mode='pixelshuffle') +# +# from models.network_rrdb import RRDB as net +# network = net(in_nc=3, out_nc=3, nc=64, nb=23, gc=32, upscale=4, act_mode='L', upsample_mode='upconv') +# +# state_dict = network.state_dict() +# for key, param in state_dict.items(): +# print(key) +# from models.network_imdn import IMDN as net +# network = net(in_nc=3, out_nc=3, nc=64, nb=8, upscale=4, act_mode='L', upsample_mode='pixelshuffle') +# state_dict = network.state_dict() +# mat_net = OrderedDict() +# for ((key, param),(key2, param2)) in zip(state_dict.items(), crt_net.items()): +# mat_net[key] = param2 +# torch.save(mat_net, 'model_zoo/imdn_x4_1.pth') +# + +# net_old = torch.load('net_old.pth') +# def show_kv(net): +# for k, v in net.items(): +# print(k) +# +# show_kv(net_old) +# from models.network_dpsr import MSRResNet_prior as net +# model = net(in_nc=4, out_nc=3, nc=96, nb=16, upscale=4, act_mode='R', upsample_mode='pixelshuffle') +# state_dict = network.state_dict() +# net_new = OrderedDict() +# for ((key, param),(key_old, param_old)) in zip(state_dict.items(), net_old.items()): +# net_new[key] = param_old +# torch.save(net_new, 'net_new.pth') + + + # print(key) + # print(param.size()) + + + + # run utils/utils_matconvnet.py diff --git a/core/data/deg_kair_utils/utils_model.py b/core/data/deg_kair_utils/utils_model.py new file mode 100644 index 0000000000000000000000000000000000000000..a4d9e6ac651784c7ed36e623c3a6175883123c2b --- /dev/null +++ b/core/data/deg_kair_utils/utils_model.py @@ -0,0 +1,330 @@ +# -*- coding: utf-8 -*- +import numpy as np +import torch +from utils import utils_image as util +import re +import glob +import os + + +''' +# -------------------------------------------- +# Model +# -------------------------------------------- +# Kai Zhang (github: https://github.com/cszn) +# 03/Mar/2019 +# -------------------------------------------- +''' + + +def find_last_checkpoint(save_dir, net_type='G', pretrained_path=None): + """ + # --------------------------------------- + # Kai Zhang (github: https://github.com/cszn) + # 03/Mar/2019 + # --------------------------------------- + Args: + save_dir: model folder + net_type: 'G' or 'D' or 'optimizerG' or 'optimizerD' + pretrained_path: pretrained model path. 
If save_dir does not have any model, load from pretrained_path + + Return: + init_iter: iteration number + init_path: model path + # --------------------------------------- + """ + + file_list = glob.glob(os.path.join(save_dir, '*_{}.pth'.format(net_type))) + if file_list: + iter_exist = [] + for file_ in file_list: + iter_current = re.findall(r"(\d+)_{}.pth".format(net_type), file_) + iter_exist.append(int(iter_current[0])) + init_iter = max(iter_exist) + init_path = os.path.join(save_dir, '{}_{}.pth'.format(init_iter, net_type)) + else: + init_iter = 0 + init_path = pretrained_path + return init_iter, init_path + + +def test_mode(model, L, mode=0, refield=32, min_size=256, sf=1, modulo=1): + ''' + # --------------------------------------- + # Kai Zhang (github: https://github.com/cszn) + # 03/Mar/2019 + # --------------------------------------- + Args: + model: trained model + L: input Low-quality image + mode: + (0) normal: test(model, L) + (1) pad: test_pad(model, L, modulo=16) + (2) split: test_split(model, L, refield=32, min_size=256, sf=1, modulo=1) + (3) x8: test_x8(model, L, modulo=1) ^_^ + (4) split and x8: test_split_x8(model, L, refield=32, min_size=256, sf=1, modulo=1) + refield: effective receptive filed of the network, 32 is enough + useful when split, i.e., mode=2, 4 + min_size: min_sizeXmin_size image, e.g., 256X256 image + useful when split, i.e., mode=2, 4 + sf: scale factor for super-resolution, otherwise 1 + modulo: 1 if split + useful when pad, i.e., mode=1 + + Returns: + E: estimated image + # --------------------------------------- + ''' + if mode == 0: + E = test(model, L) + elif mode == 1: + E = test_pad(model, L, modulo, sf) + elif mode == 2: + E = test_split(model, L, refield, min_size, sf, modulo) + elif mode == 3: + E = test_x8(model, L, modulo, sf) + elif mode == 4: + E = test_split_x8(model, L, refield, min_size, sf, modulo) + return E + + +''' +# -------------------------------------------- +# normal (0) +# -------------------------------------------- +''' + + +def test(model, L): + E = model(L) + return E + + +''' +# -------------------------------------------- +# pad (1) +# -------------------------------------------- +''' + + +def test_pad(model, L, modulo=16, sf=1): + h, w = L.size()[-2:] + paddingBottom = int(np.ceil(h/modulo)*modulo-h) + paddingRight = int(np.ceil(w/modulo)*modulo-w) + L = torch.nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(L) + E = model(L) + E = E[..., :h*sf, :w*sf] + return E + + +''' +# -------------------------------------------- +# split (function) +# -------------------------------------------- +''' + + +def test_split_fn(model, L, refield=32, min_size=256, sf=1, modulo=1): + """ + Args: + model: trained model + L: input Low-quality image + refield: effective receptive filed of the network, 32 is enough + min_size: min_sizeXmin_size image, e.g., 256X256 image + sf: scale factor for super-resolution, otherwise 1 + modulo: 1 if split + + Returns: + E: estimated result + """ + h, w = L.size()[-2:] + if h*w <= min_size**2: + L = torch.nn.ReplicationPad2d((0, int(np.ceil(w/modulo)*modulo-w), 0, int(np.ceil(h/modulo)*modulo-h)))(L) + E = model(L) + E = E[..., :h*sf, :w*sf] + else: + top = slice(0, (h//2//refield+1)*refield) + bottom = slice(h - (h//2//refield+1)*refield, h) + left = slice(0, (w//2//refield+1)*refield) + right = slice(w - (w//2//refield+1)*refield, w) + Ls = [L[..., top, left], L[..., top, right], L[..., bottom, left], L[..., bottom, right]] + + if h * w <= 4*(min_size**2): + Es = [model(Ls[i]) for i in 
range(4)] + else: + Es = [test_split_fn(model, Ls[i], refield=refield, min_size=min_size, sf=sf, modulo=modulo) for i in range(4)] + + b, c = Es[0].size()[:2] + E = torch.zeros(b, c, sf * h, sf * w).type_as(L) + + E[..., :h//2*sf, :w//2*sf] = Es[0][..., :h//2*sf, :w//2*sf] + E[..., :h//2*sf, w//2*sf:w*sf] = Es[1][..., :h//2*sf, (-w + w//2)*sf:] + E[..., h//2*sf:h*sf, :w//2*sf] = Es[2][..., (-h + h//2)*sf:, :w//2*sf] + E[..., h//2*sf:h*sf, w//2*sf:w*sf] = Es[3][..., (-h + h//2)*sf:, (-w + w//2)*sf:] + return E + + +''' +# -------------------------------------------- +# split (2) +# -------------------------------------------- +''' + + +def test_split(model, L, refield=32, min_size=256, sf=1, modulo=1): + E = test_split_fn(model, L, refield=refield, min_size=min_size, sf=sf, modulo=modulo) + return E + + +''' +# -------------------------------------------- +# x8 (3) +# -------------------------------------------- +''' + + +def test_x8(model, L, modulo=1, sf=1): + E_list = [test_pad(model, util.augment_img_tensor4(L, mode=i), modulo=modulo, sf=sf) for i in range(8)] + for i in range(len(E_list)): + if i == 3 or i == 5: + E_list[i] = util.augment_img_tensor4(E_list[i], mode=8 - i) + else: + E_list[i] = util.augment_img_tensor4(E_list[i], mode=i) + output_cat = torch.stack(E_list, dim=0) + E = output_cat.mean(dim=0, keepdim=False) + return E + + +''' +# -------------------------------------------- +# split and x8 (4) +# -------------------------------------------- +''' + + +def test_split_x8(model, L, refield=32, min_size=256, sf=1, modulo=1): + E_list = [test_split_fn(model, util.augment_img_tensor4(L, mode=i), refield=refield, min_size=min_size, sf=sf, modulo=modulo) for i in range(8)] + for k, i in enumerate(range(len(E_list))): + if i==3 or i==5: + E_list[k] = util.augment_img_tensor4(E_list[k], mode=8-i) + else: + E_list[k] = util.augment_img_tensor4(E_list[k], mode=i) + output_cat = torch.stack(E_list, dim=0) + E = output_cat.mean(dim=0, keepdim=False) + return E + + +''' +# ^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^- +# _^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^ +# ^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^- +''' + + +''' +# -------------------------------------------- +# print +# -------------------------------------------- +''' + + +# -------------------------------------------- +# print model +# -------------------------------------------- +def print_model(model): + msg = describe_model(model) + print(msg) + + +# -------------------------------------------- +# print params +# -------------------------------------------- +def print_params(model): + msg = describe_params(model) + print(msg) + + +''' +# -------------------------------------------- +# information +# -------------------------------------------- +''' + + +# -------------------------------------------- +# model inforation +# -------------------------------------------- +def info_model(model): + msg = describe_model(model) + return msg + + +# -------------------------------------------- +# params inforation +# -------------------------------------------- +def info_params(model): + msg = describe_params(model) + return msg + + +''' +# -------------------------------------------- +# description +# -------------------------------------------- +''' + + +# -------------------------------------------- +# model name and total number of parameters +# -------------------------------------------- +def describe_model(model): + if isinstance(model, torch.nn.DataParallel): + model = model.module + msg = '\n' + msg += 'models name: 
{}'.format(model.__class__.__name__) + '\n' + msg += 'Params number: {}'.format(sum(map(lambda x: x.numel(), model.parameters()))) + '\n' + msg += 'Net structure:\n{}'.format(str(model)) + '\n' + return msg + + +# -------------------------------------------- +# parameters description +# -------------------------------------------- +def describe_params(model): + if isinstance(model, torch.nn.DataParallel): + model = model.module + msg = '\n' + msg += ' | {:^6s} | {:^6s} | {:^6s} | {:^6s} || {:<20s}'.format('mean', 'min', 'max', 'std', 'shape', 'param_name') + '\n' + for name, param in model.state_dict().items(): + if not 'num_batches_tracked' in name: + v = param.data.clone().float() + msg += ' | {:>6.3f} | {:>6.3f} | {:>6.3f} | {:>6.3f} | {} || {:s}'.format(v.mean(), v.min(), v.max(), v.std(), v.shape, name) + '\n' + return msg + + +if __name__ == '__main__': + + class Net(torch.nn.Module): + def __init__(self, in_channels=3, out_channels=3): + super(Net, self).__init__() + self.conv = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1) + + def forward(self, x): + x = self.conv(x) + return x + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + model = Net() + model = model.eval() + print_model(model) + print_params(model) + x = torch.randn((2,3,401,401)) + torch.cuda.empty_cache() + with torch.no_grad(): + for mode in range(5): + y = test_mode(model, x, mode, refield=32, min_size=256, sf=1, modulo=1) + print(y.shape) + + # run utils/utils_model.py diff --git a/core/data/deg_kair_utils/utils_modelsummary.py b/core/data/deg_kair_utils/utils_modelsummary.py new file mode 100644 index 0000000000000000000000000000000000000000..5e040e31d8ddffbb8b7b2e2dc4ddf0b9cdca6a23 --- /dev/null +++ b/core/data/deg_kair_utils/utils_modelsummary.py @@ -0,0 +1,485 @@ +import torch.nn as nn +import torch +import numpy as np + +''' +---- 1) FLOPs: floating point operations +---- 2) #Activations: the number of elements of all ‘Conv2d’ outputs +---- 3) #Conv2d: the number of ‘Conv2d’ layers +# -------------------------------------------- +# Kai Zhang (github: https://github.com/cszn) +# 21/July/2020 +# -------------------------------------------- +# Reference +https://github.com/sovrasov/flops-counter.pytorch.git + +# If you use this code, please consider the following citation: + +@inproceedings{zhang2020aim, % + title={AIM 2020 Challenge on Efficient Super-Resolution: Methods and Results}, + author={Kai Zhang and Martin Danelljan and Yawei Li and Radu Timofte and others}, + booktitle={European Conference on Computer Vision Workshops}, + year={2020} +} +# -------------------------------------------- +''' + +def get_model_flops(model, input_res, print_per_layer_stat=True, + input_constructor=None): + assert type(input_res) is tuple, 'Please provide the size of the input image.' + assert len(input_res) >= 3, 'Input image should have 3 dimensions.' 
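+    # Illustrative usage (names are placeholders; `net` is any nn.Module that accepts a 3x224x224 input):
+    #     flops = get_model_flops(net, (3, 224, 224), print_per_layer_stat=False)
+    #     print(flops_to_string(flops))
+    # The model is wrapped with counting hooks, run once on a dummy batch of size 1 (unless
+    # input_constructor is given), and the accumulated multiply-accumulate count is returned.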
+ flops_model = add_flops_counting_methods(model) + flops_model.eval().start_flops_count() + if input_constructor: + input = input_constructor(input_res) + _ = flops_model(**input) + else: + device = list(flops_model.parameters())[-1].device + batch = torch.FloatTensor(1, *input_res).to(device) + _ = flops_model(batch) + + if print_per_layer_stat: + print_model_with_flops(flops_model) + flops_count = flops_model.compute_average_flops_cost() + flops_model.stop_flops_count() + + return flops_count + +def get_model_activation(model, input_res, input_constructor=None): + assert type(input_res) is tuple, 'Please provide the size of the input image.' + assert len(input_res) >= 3, 'Input image should have 3 dimensions.' + activation_model = add_activation_counting_methods(model) + activation_model.eval().start_activation_count() + if input_constructor: + input = input_constructor(input_res) + _ = activation_model(**input) + else: + device = list(activation_model.parameters())[-1].device + batch = torch.FloatTensor(1, *input_res).to(device) + _ = activation_model(batch) + + activation_count, num_conv = activation_model.compute_average_activation_cost() + activation_model.stop_activation_count() + + return activation_count, num_conv + + +def get_model_complexity_info(model, input_res, print_per_layer_stat=True, as_strings=True, + input_constructor=None): + assert type(input_res) is tuple + assert len(input_res) >= 3 + flops_model = add_flops_counting_methods(model) + flops_model.eval().start_flops_count() + if input_constructor: + input = input_constructor(input_res) + _ = flops_model(**input) + else: + batch = torch.FloatTensor(1, *input_res) + _ = flops_model(batch) + + if print_per_layer_stat: + print_model_with_flops(flops_model) + flops_count = flops_model.compute_average_flops_cost() + params_count = get_model_parameters_number(flops_model) + flops_model.stop_flops_count() + + if as_strings: + return flops_to_string(flops_count), params_to_string(params_count) + + return flops_count, params_count + + +def flops_to_string(flops, units='GMac', precision=2): + if units is None: + if flops // 10**9 > 0: + return str(round(flops / 10.**9, precision)) + ' GMac' + elif flops // 10**6 > 0: + return str(round(flops / 10.**6, precision)) + ' MMac' + elif flops // 10**3 > 0: + return str(round(flops / 10.**3, precision)) + ' KMac' + else: + return str(flops) + ' Mac' + else: + if units == 'GMac': + return str(round(flops / 10.**9, precision)) + ' ' + units + elif units == 'MMac': + return str(round(flops / 10.**6, precision)) + ' ' + units + elif units == 'KMac': + return str(round(flops / 10.**3, precision)) + ' ' + units + else: + return str(flops) + ' Mac' + + +def params_to_string(params_num): + if params_num // 10 ** 6 > 0: + return str(round(params_num / 10 ** 6, 2)) + ' M' + elif params_num // 10 ** 3: + return str(round(params_num / 10 ** 3, 2)) + ' k' + else: + return str(params_num) + + +def print_model_with_flops(model, units='GMac', precision=3): + total_flops = model.compute_average_flops_cost() + + def accumulate_flops(self): + if is_supported_instance(self): + return self.__flops__ / model.__batch_counter__ + else: + sum = 0 + for m in self.children(): + sum += m.accumulate_flops() + return sum + + def flops_repr(self): + accumulated_flops_cost = self.accumulate_flops() + return ', '.join([flops_to_string(accumulated_flops_cost, units=units, precision=precision), + '{:.3%} MACs'.format(accumulated_flops_cost / total_flops), + self.original_extra_repr()]) + + def add_extra_repr(m): + 
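+        # Bind the accounting helpers to each submodule so that printing the model also reports
+        # every module's share of the total MACs through the patched extra_repr().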
m.accumulate_flops = accumulate_flops.__get__(m) + flops_extra_repr = flops_repr.__get__(m) + if m.extra_repr != flops_extra_repr: + m.original_extra_repr = m.extra_repr + m.extra_repr = flops_extra_repr + assert m.extra_repr != m.original_extra_repr + + def del_extra_repr(m): + if hasattr(m, 'original_extra_repr'): + m.extra_repr = m.original_extra_repr + del m.original_extra_repr + if hasattr(m, 'accumulate_flops'): + del m.accumulate_flops + + model.apply(add_extra_repr) + print(model) + model.apply(del_extra_repr) + + +def get_model_parameters_number(model): + params_num = sum(p.numel() for p in model.parameters() if p.requires_grad) + return params_num + + +def add_flops_counting_methods(net_main_module): + # adding additional methods to the existing module object, + # this is done this way so that each function has access to self object + # embed() + net_main_module.start_flops_count = start_flops_count.__get__(net_main_module) + net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module) + net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module) + net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(net_main_module) + + net_main_module.reset_flops_count() + return net_main_module + + +def compute_average_flops_cost(self): + """ + A method that will be available after add_flops_counting_methods() is called + on a desired net object. + + Returns current mean flops consumption per image. + + """ + + flops_sum = 0 + for module in self.modules(): + if is_supported_instance(module): + flops_sum += module.__flops__ + + return flops_sum + + +def start_flops_count(self): + """ + A method that will be available after add_flops_counting_methods() is called + on a desired net object. + + Activates the computation of mean flops consumption per image. + Call it before you run the network. + + """ + self.apply(add_flops_counter_hook_function) + + +def stop_flops_count(self): + """ + A method that will be available after add_flops_counting_methods() is called + on a desired net object. + + Stops computing the mean flops consumption per image. + Call whenever you want to pause the computation. + + """ + self.apply(remove_flops_counter_hook_function) + + +def reset_flops_count(self): + """ + A method that will be available after add_flops_counting_methods() is called + on a desired net object. + + Resets statistics computed so far. 
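+    Useful when the same hooked model is reused to measure several different inputs.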
+ + """ + self.apply(add_flops_counter_variable_or_reset) + + +def add_flops_counter_hook_function(module): + if is_supported_instance(module): + if hasattr(module, '__flops_handle__'): + return + + if isinstance(module, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d)): + handle = module.register_forward_hook(conv_flops_counter_hook) + elif isinstance(module, (nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6)): + handle = module.register_forward_hook(relu_flops_counter_hook) + elif isinstance(module, nn.Linear): + handle = module.register_forward_hook(linear_flops_counter_hook) + elif isinstance(module, (nn.BatchNorm2d)): + handle = module.register_forward_hook(bn_flops_counter_hook) + else: + handle = module.register_forward_hook(empty_flops_counter_hook) + module.__flops_handle__ = handle + + +def remove_flops_counter_hook_function(module): + if is_supported_instance(module): + if hasattr(module, '__flops_handle__'): + module.__flops_handle__.remove() + del module.__flops_handle__ + + +def add_flops_counter_variable_or_reset(module): + if is_supported_instance(module): + module.__flops__ = 0 + + +# ---- Internal functions +def is_supported_instance(module): + if isinstance(module, + ( + nn.Conv2d, nn.ConvTranspose2d, + nn.BatchNorm2d, + nn.Linear, + nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6, + )): + return True + + return False + + +def conv_flops_counter_hook(conv_module, input, output): + # Can have multiple inputs, getting the first one + # input = input[0] + + batch_size = output.shape[0] + output_dims = list(output.shape[2:]) + + kernel_dims = list(conv_module.kernel_size) + in_channels = conv_module.in_channels + out_channels = conv_module.out_channels + groups = conv_module.groups + + filters_per_channel = out_channels // groups + conv_per_position_flops = np.prod(kernel_dims) * in_channels * filters_per_channel + + active_elements_count = batch_size * np.prod(output_dims) + overall_conv_flops = int(conv_per_position_flops) * int(active_elements_count) + + # overall_flops = overall_conv_flops + + conv_module.__flops__ += int(overall_conv_flops) + # conv_module.__output_dims__ = output_dims + + +def relu_flops_counter_hook(module, input, output): + active_elements_count = output.numel() + module.__flops__ += int(active_elements_count) + # print(module.__flops__, id(module)) + # print(module) + + +def linear_flops_counter_hook(module, input, output): + input = input[0] + if len(input.shape) == 1: + batch_size = 1 + module.__flops__ += int(batch_size * input.shape[0] * output.shape[0]) + else: + batch_size = input.shape[0] + module.__flops__ += int(batch_size * input.shape[1] * output.shape[1]) + + +def bn_flops_counter_hook(module, input, output): + # input = input[0] + # TODO: need to check here + # batch_flops = np.prod(input.shape) + # if module.affine: + # batch_flops *= 2 + # module.__flops__ += int(batch_flops) + batch = output.shape[0] + output_dims = output.shape[2:] + channels = module.num_features + batch_flops = batch * channels * np.prod(output_dims) + if module.affine: + batch_flops *= 2 + module.__flops__ += int(batch_flops) + + +# ---- Count the number of convolutional layers and the activation +def add_activation_counting_methods(net_main_module): + # adding additional methods to the existing module object, + # this is done this way so that each function has access to self object + # embed() + net_main_module.start_activation_count = start_activation_count.__get__(net_main_module) + net_main_module.stop_activation_count = 
stop_activation_count.__get__(net_main_module) + net_main_module.reset_activation_count = reset_activation_count.__get__(net_main_module) + net_main_module.compute_average_activation_cost = compute_average_activation_cost.__get__(net_main_module) + + net_main_module.reset_activation_count() + return net_main_module + + +def compute_average_activation_cost(self): + """ + A method that will be available after add_activation_counting_methods() is called + on a desired net object. + + Returns current mean activation consumption per image. + + """ + + activation_sum = 0 + num_conv = 0 + for module in self.modules(): + if is_supported_instance_for_activation(module): + activation_sum += module.__activation__ + num_conv += module.__num_conv__ + return activation_sum, num_conv + + +def start_activation_count(self): + """ + A method that will be available after add_activation_counting_methods() is called + on a desired net object. + + Activates the computation of mean activation consumption per image. + Call it before you run the network. + + """ + self.apply(add_activation_counter_hook_function) + + +def stop_activation_count(self): + """ + A method that will be available after add_activation_counting_methods() is called + on a desired net object. + + Stops computing the mean activation consumption per image. + Call whenever you want to pause the computation. + + """ + self.apply(remove_activation_counter_hook_function) + + +def reset_activation_count(self): + """ + A method that will be available after add_activation_counting_methods() is called + on a desired net object. + + Resets statistics computed so far. + + """ + self.apply(add_activation_counter_variable_or_reset) + + +def add_activation_counter_hook_function(module): + if is_supported_instance_for_activation(module): + if hasattr(module, '__activation_handle__'): + return + + if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)): + handle = module.register_forward_hook(conv_activation_counter_hook) + module.__activation_handle__ = handle + + +def remove_activation_counter_hook_function(module): + if is_supported_instance_for_activation(module): + if hasattr(module, '__activation_handle__'): + module.__activation_handle__.remove() + del module.__activation_handle__ + + +def add_activation_counter_variable_or_reset(module): + if is_supported_instance_for_activation(module): + module.__activation__ = 0 + module.__num_conv__ = 0 + + +def is_supported_instance_for_activation(module): + if isinstance(module, + ( + nn.Conv2d, nn.ConvTranspose2d, + )): + return True + + return False + +def conv_activation_counter_hook(module, input, output): + """ + Calculate the activations in the convolutional operation. + Reference: Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár, Designing Network Design Spaces. 
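+    Here an "activation" is counted as every element of the convolution output tensor, and
+    __num_conv__ tracks how many Conv2d / ConvTranspose2d layers produced output.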
+ :param module: + :param input: + :param output: + :return: + """ + module.__activation__ += output.numel() + module.__num_conv__ += 1 + + +def empty_flops_counter_hook(module, input, output): + module.__flops__ += 0 + + +def upsample_flops_counter_hook(module, input, output): + output_size = output[0] + batch_size = output_size.shape[0] + output_elements_count = batch_size + for val in output_size.shape[1:]: + output_elements_count *= val + module.__flops__ += int(output_elements_count) + + +def pool_flops_counter_hook(module, input, output): + input = input[0] + module.__flops__ += int(np.prod(input.shape)) + + +def dconv_flops_counter_hook(dconv_module, input, output): + input = input[0] + + batch_size = input.shape[0] + output_dims = list(output.shape[2:]) + + m_channels, in_channels, kernel_dim1, _, = dconv_module.weight.shape + out_channels, _, kernel_dim2, _, = dconv_module.projection.shape + # groups = dconv_module.groups + + # filters_per_channel = out_channels // groups + conv_per_position_flops1 = kernel_dim1 ** 2 * in_channels * m_channels + conv_per_position_flops2 = kernel_dim2 ** 2 * out_channels * m_channels + active_elements_count = batch_size * np.prod(output_dims) + + overall_conv_flops = (conv_per_position_flops1 + conv_per_position_flops2) * active_elements_count + overall_flops = overall_conv_flops + + dconv_module.__flops__ += int(overall_flops) + # dconv_module.__output_dims__ = output_dims + + + + + diff --git a/core/data/deg_kair_utils/utils_option.py b/core/data/deg_kair_utils/utils_option.py new file mode 100644 index 0000000000000000000000000000000000000000..cf096210e2d8ea553b06a91ac5cdaa21127d837c --- /dev/null +++ b/core/data/deg_kair_utils/utils_option.py @@ -0,0 +1,255 @@ +import os +from collections import OrderedDict +from datetime import datetime +import json +import re +import glob + + +''' +# -------------------------------------------- +# Kai Zhang (github: https://github.com/cszn) +# 03/Mar/2019 +# -------------------------------------------- +# https://github.com/xinntao/BasicSR +# -------------------------------------------- +''' + + +def get_timestamp(): + return datetime.now().strftime('_%y%m%d_%H%M%S') + + +def parse(opt_path, is_train=True): + + # ---------------------------------------- + # remove comments starting with '//' + # ---------------------------------------- + json_str = '' + with open(opt_path, 'r') as f: + for line in f: + line = line.split('//')[0] + '\n' + json_str += line + + # ---------------------------------------- + # initialize opt + # ---------------------------------------- + opt = json.loads(json_str, object_pairs_hook=OrderedDict) + + opt['opt_path'] = opt_path + opt['is_train'] = is_train + + # ---------------------------------------- + # set default + # ---------------------------------------- + if 'merge_bn' not in opt: + opt['merge_bn'] = False + opt['merge_bn_startpoint'] = -1 + + if 'scale' not in opt: + opt['scale'] = 1 + + # ---------------------------------------- + # datasets + # ---------------------------------------- + for phase, dataset in opt['datasets'].items(): + phase = phase.split('_')[0] + dataset['phase'] = phase + dataset['scale'] = opt['scale'] # broadcast + dataset['n_channels'] = opt['n_channels'] # broadcast + if 'dataroot_H' in dataset and dataset['dataroot_H'] is not None: + dataset['dataroot_H'] = os.path.expanduser(dataset['dataroot_H']) + if 'dataroot_L' in dataset and dataset['dataroot_L'] is not None: + dataset['dataroot_L'] = os.path.expanduser(dataset['dataroot_L']) + + # 
---------------------------------------- + # path + # ---------------------------------------- + for key, path in opt['path'].items(): + if path and key in opt['path']: + opt['path'][key] = os.path.expanduser(path) + + path_task = os.path.join(opt['path']['root'], opt['task']) + opt['path']['task'] = path_task + opt['path']['log'] = path_task + opt['path']['options'] = os.path.join(path_task, 'options') + + if is_train: + opt['path']['models'] = os.path.join(path_task, 'models') + opt['path']['images'] = os.path.join(path_task, 'images') + else: # test + opt['path']['images'] = os.path.join(path_task, 'test_images') + + # ---------------------------------------- + # network + # ---------------------------------------- + opt['netG']['scale'] = opt['scale'] if 'scale' in opt else 1 + + # ---------------------------------------- + # GPU devices + # ---------------------------------------- + gpu_list = ','.join(str(x) for x in opt['gpu_ids']) + os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list + print('export CUDA_VISIBLE_DEVICES=' + gpu_list) + + # ---------------------------------------- + # default setting for distributeddataparallel + # ---------------------------------------- + if 'find_unused_parameters' not in opt: + opt['find_unused_parameters'] = True + if 'use_static_graph' not in opt: + opt['use_static_graph'] = False + if 'dist' not in opt: + opt['dist'] = False + opt['num_gpu'] = len(opt['gpu_ids']) + print('number of GPUs is: ' + str(opt['num_gpu'])) + + # ---------------------------------------- + # default setting for perceptual loss + # ---------------------------------------- + if 'F_feature_layer' not in opt['train']: + opt['train']['F_feature_layer'] = 34 # 25; [2,7,16,25,34] + if 'F_weights' not in opt['train']: + opt['train']['F_weights'] = 1.0 # 1.0; [0.1,0.1,1.0,1.0,1.0] + if 'F_lossfn_type' not in opt['train']: + opt['train']['F_lossfn_type'] = 'l1' + if 'F_use_input_norm' not in opt['train']: + opt['train']['F_use_input_norm'] = True + if 'F_use_range_norm' not in opt['train']: + opt['train']['F_use_range_norm'] = False + + # ---------------------------------------- + # default setting for optimizer + # ---------------------------------------- + if 'G_optimizer_type' not in opt['train']: + opt['train']['G_optimizer_type'] = "adam" + if 'G_optimizer_betas' not in opt['train']: + opt['train']['G_optimizer_betas'] = [0.9,0.999] + if 'G_scheduler_restart_weights' not in opt['train']: + opt['train']['G_scheduler_restart_weights'] = 1 + if 'G_optimizer_wd' not in opt['train']: + opt['train']['G_optimizer_wd'] = 0 + if 'G_optimizer_reuse' not in opt['train']: + opt['train']['G_optimizer_reuse'] = False + if 'netD' in opt and 'D_optimizer_reuse' not in opt['train']: + opt['train']['D_optimizer_reuse'] = False + + # ---------------------------------------- + # default setting of strict for model loading + # ---------------------------------------- + if 'G_param_strict' not in opt['train']: + opt['train']['G_param_strict'] = True + if 'netD' in opt and 'D_param_strict' not in opt['path']: + opt['train']['D_param_strict'] = True + if 'E_param_strict' not in opt['path']: + opt['train']['E_param_strict'] = True + + # ---------------------------------------- + # Exponential Moving Average + # ---------------------------------------- + if 'E_decay' not in opt['train']: + opt['train']['E_decay'] = 0 + + # ---------------------------------------- + # default setting for discriminator + # ---------------------------------------- + if 'netD' in opt: + if 'net_type' not in opt['netD']: + 
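+            # fall back to a PatchGAN discriminator when the option file does not specify one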
opt['netD']['net_type'] = 'discriminator_patchgan' # discriminator_unet + if 'in_nc' not in opt['netD']: + opt['netD']['in_nc'] = 3 + if 'base_nc' not in opt['netD']: + opt['netD']['base_nc'] = 64 + if 'n_layers' not in opt['netD']: + opt['netD']['n_layers'] = 3 + if 'norm_type' not in opt['netD']: + opt['netD']['norm_type'] = 'spectral' + + + return opt + + +def find_last_checkpoint(save_dir, net_type='G', pretrained_path=None): + """ + Args: + save_dir: model folder + net_type: 'G' or 'D' or 'optimizerG' or 'optimizerD' + pretrained_path: pretrained model path. If save_dir does not have any model, load from pretrained_path + + Return: + init_iter: iteration number + init_path: model path + """ + file_list = glob.glob(os.path.join(save_dir, '*_{}.pth'.format(net_type))) + if file_list: + iter_exist = [] + for file_ in file_list: + iter_current = re.findall(r"(\d+)_{}.pth".format(net_type), file_) + iter_exist.append(int(iter_current[0])) + init_iter = max(iter_exist) + init_path = os.path.join(save_dir, '{}_{}.pth'.format(init_iter, net_type)) + else: + init_iter = 0 + init_path = pretrained_path + return init_iter, init_path + + +''' +# -------------------------------------------- +# convert the opt into json file +# -------------------------------------------- +''' + + +def save(opt): + opt_path = opt['opt_path'] + opt_path_copy = opt['path']['options'] + dirname, filename_ext = os.path.split(opt_path) + filename, ext = os.path.splitext(filename_ext) + dump_path = os.path.join(opt_path_copy, filename+get_timestamp()+ext) + with open(dump_path, 'w') as dump_file: + json.dump(opt, dump_file, indent=2) + + +''' +# -------------------------------------------- +# dict to string for logger +# -------------------------------------------- +''' + + +def dict2str(opt, indent_l=1): + msg = '' + for k, v in opt.items(): + if isinstance(v, dict): + msg += ' ' * (indent_l * 2) + k + ':[\n' + msg += dict2str(v, indent_l + 1) + msg += ' ' * (indent_l * 2) + ']\n' + else: + msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n' + return msg + + +''' +# -------------------------------------------- +# convert OrderedDict to NoneDict, +# return None for missing key +# -------------------------------------------- +''' + + +def dict_to_nonedict(opt): + if isinstance(opt, dict): + new_opt = dict() + for key, sub_opt in opt.items(): + new_opt[key] = dict_to_nonedict(sub_opt) + return NoneDict(**new_opt) + elif isinstance(opt, list): + return [dict_to_nonedict(sub_opt) for sub_opt in opt] + else: + return opt + + +class NoneDict(dict): + def __missing__(self, key): + return None diff --git a/core/data/deg_kair_utils/utils_params.py b/core/data/deg_kair_utils/utils_params.py new file mode 100644 index 0000000000000000000000000000000000000000..def1cb79e11472b9b8ebbaae4bd83e7216af2ccb --- /dev/null +++ b/core/data/deg_kair_utils/utils_params.py @@ -0,0 +1,135 @@ +import torch + +import torchvision + +from models import basicblock as B + +def show_kv(net): + for k, v in net.items(): + print(k) + +# should run train debug mode first to get an initial model +#crt_net = torch.load('../../experiments/debug_SRResNet_bicx4_in3nf64nb16/models/8_G.pth') +# +#for k, v in crt_net.items(): +# print(k) +#for k, v in crt_net.items(): +# if k in pretrained_net: +# crt_net[k] = pretrained_net[k] +# print('replace ... 
', k) + +# x2 -> x4 +#crt_net['model.5.weight'] = pretrained_net['model.2.weight'] +#crt_net['model.5.bias'] = pretrained_net['model.2.bias'] +#crt_net['model.8.weight'] = pretrained_net['model.5.weight'] +#crt_net['model.8.bias'] = pretrained_net['model.5.bias'] +#crt_net['model.10.weight'] = pretrained_net['model.7.weight'] +#crt_net['model.10.bias'] = pretrained_net['model.7.bias'] +#torch.save(crt_net, '../pretrained_tmp.pth') + +# x2 -> x3 +''' +in_filter = pretrained_net['model.2.weight'] # 256, 64, 3, 3 +new_filter = torch.Tensor(576, 64, 3, 3) +new_filter[0:256, :, :, :] = in_filter +new_filter[256:512, :, :, :] = in_filter +new_filter[512:, :, :, :] = in_filter[0:576-512, :, :, :] +crt_net['model.2.weight'] = new_filter + +in_bias = pretrained_net['model.2.bias'] # 256, 64, 3, 3 +new_bias = torch.Tensor(576) +new_bias[0:256] = in_bias +new_bias[256:512] = in_bias +new_bias[512:] = in_bias[0:576 - 512] +crt_net['model.2.bias'] = new_bias + +torch.save(crt_net, '../pretrained_tmp.pth') +''' + +# x2 -> x8 +''' +crt_net['model.5.weight'] = pretrained_net['model.2.weight'] +crt_net['model.5.bias'] = pretrained_net['model.2.bias'] +crt_net['model.8.weight'] = pretrained_net['model.2.weight'] +crt_net['model.8.bias'] = pretrained_net['model.2.bias'] +crt_net['model.11.weight'] = pretrained_net['model.5.weight'] +crt_net['model.11.bias'] = pretrained_net['model.5.bias'] +crt_net['model.13.weight'] = pretrained_net['model.7.weight'] +crt_net['model.13.bias'] = pretrained_net['model.7.bias'] +torch.save(crt_net, '../pretrained_tmp.pth') +''' + +# x3/4/8 RGB -> Y + +def rgb2gray_net(net, only_input=True): + + if only_input: + in_filter = net['0.weight'] + in_new_filter = in_filter[:,0,:,:]*0.2989 + in_filter[:,1,:,:]*0.587 + in_filter[:,2,:,:]*0.114 + in_new_filter.unsqueeze_(1) + net['0.weight'] = in_new_filter + +# out_filter = pretrained_net['model.13.weight'] +# out_new_filter = out_filter[0, :, :, :] * 0.2989 + out_filter[1, :, :, :] * 0.587 + \ +# out_filter[2, :, :, :] * 0.114 +# out_new_filter.unsqueeze_(0) +# crt_net['model.13.weight'] = out_new_filter +# out_bias = pretrained_net['model.13.bias'] +# out_new_bias = out_bias[0] * 0.2989 + out_bias[1] * 0.587 + out_bias[2] * 0.114 +# out_new_bias = torch.Tensor(1).fill_(out_new_bias) +# crt_net['model.13.bias'] = out_new_bias + +# torch.save(crt_net, '../pretrained_tmp.pth') + + return net + + + +if __name__ == '__main__': + + net = torchvision.models.vgg19(pretrained=True) + for k,v in net.features.named_parameters(): + if k=='0.weight': + in_new_filter = v[:,0,:,:]*0.2989 + v[:,1,:,:]*0.587 + v[:,2,:,:]*0.114 + in_new_filter.unsqueeze_(1) + v = in_new_filter + print(v.shape) + print(v[0,0,0,0]) + if k=='0.bias': + in_new_bias = v + print(v[0]) + + print(net.features[0]) + + net.features[0] = B.conv(1, 64, mode='C') + + print(net.features[0]) + net.features[0].weight.data=in_new_filter + net.features[0].bias.data=in_new_bias + + for k,v in net.features.named_parameters(): + if k=='0.weight': + print(v[0,0,0,0]) + if k=='0.bias': + print(v[0]) + + # transfer parameters of old model to new one + model_old = torch.load(model_path) + state_dict = model.state_dict() + for ((key, param),(key2, param2)) in zip(model_old.items(), state_dict.items()): + state_dict[key2] = param + print([key, key2]) + # print([param.size(), param2.size()]) + torch.save(state_dict, 'model_new.pth') + + + # rgb2gray_net(net) + + + + + + + + + diff --git a/core/data/deg_kair_utils/utils_receptivefield.py b/core/data/deg_kair_utils/utils_receptivefield.py new file 
mode 100644 index 0000000000000000000000000000000000000000..82ad613b9e744189e13b721a558dbc0f42c57b30 --- /dev/null +++ b/core/data/deg_kair_utils/utils_receptivefield.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- + +# online calculation: https://fomoro.com/research/article/receptive-field-calculator# + +# [filter size, stride, padding] +#Assume the two dimensions are the same +#Each kernel requires the following parameters: +# - k_i: kernel size +# - s_i: stride +# - p_i: padding (if padding is uneven, right padding will higher than left padding; "SAME" option in tensorflow) +# +#Each layer i requires the following parameters to be fully represented: +# - n_i: number of feature (data layer has n_1 = imagesize ) +# - j_i: distance (projected to image pixel distance) between center of two adjacent features +# - r_i: receptive field of a feature in layer i +# - start_i: position of the first feature's receptive field in layer i (idx start from 0, negative means the center fall into padding) + +import math + +def outFromIn(conv, layerIn): + n_in = layerIn[0] + j_in = layerIn[1] + r_in = layerIn[2] + start_in = layerIn[3] + k = conv[0] + s = conv[1] + p = conv[2] + + n_out = math.floor((n_in - k + 2*p)/s) + 1 + actualP = (n_out-1)*s - n_in + k + pR = math.ceil(actualP/2) + pL = math.floor(actualP/2) + + j_out = j_in * s + r_out = r_in + (k - 1)*j_in + start_out = start_in + ((k-1)/2 - pL)*j_in + return n_out, j_out, r_out, start_out + +def printLayer(layer, layer_name): + print(layer_name + ":") + print(" n features: %s jump: %s receptive size: %s start: %s " % (layer[0], layer[1], layer[2], layer[3])) + + + +layerInfos = [] +if __name__ == '__main__': + + convnet = [[3,1,1],[3,1,1],[3,1,1],[4,2,1],[2,2,0],[3,1,1]] + layer_names = ['conv1','conv2','conv3','conv4','conv5','conv6','conv7','conv8','conv9','conv10','conv11','conv12'] + imsize = 128 + + print ("-------Net summary------") + currentLayer = [imsize, 1, 1, 0.5] + printLayer(currentLayer, "input image") + for i in range(len(convnet)): + currentLayer = outFromIn(convnet[i], currentLayer) + layerInfos.append(currentLayer) + printLayer(currentLayer, layer_names[i]) + + +# run utils/utils_receptivefield.py + \ No newline at end of file diff --git a/core/data/deg_kair_utils/utils_regularizers.py b/core/data/deg_kair_utils/utils_regularizers.py new file mode 100644 index 0000000000000000000000000000000000000000..17e7c8524b716f36e10b41d72fee2e375af69454 --- /dev/null +++ b/core/data/deg_kair_utils/utils_regularizers.py @@ -0,0 +1,104 @@ +import torch +import torch.nn as nn + + +''' +# -------------------------------------------- +# Kai Zhang (github: https://github.com/cszn) +# 03/Mar/2019 +# -------------------------------------------- +''' + + +# -------------------------------------------- +# SVD Orthogonal Regularization +# -------------------------------------------- +def regularizer_orth(m): + """ + # ---------------------------------------- + # SVD Orthogonal Regularization + # ---------------------------------------- + # Applies regularization to the training by performing the + # orthogonalization technique described in the paper + # This function is to be called by the torch.nn.Module.apply() method, + # which applies svd_orthogonalization() to every layer of the model. 
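    # (illustrative note: in a typical training loop this is applied only
    #  every few hundred steps, e.g. if current_step % 500 == 0: net.apply(regularizer_orth))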
+ # usage: net.apply(regularizer_orth) + # ---------------------------------------- + """ + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + w = m.weight.data.clone() + c_out, c_in, f1, f2 = w.size() + # dtype = m.weight.data.type() + w = w.permute(2, 3, 1, 0).contiguous().view(f1*f2*c_in, c_out) + # self.netG.apply(svd_orthogonalization) + u, s, v = torch.svd(w) + s[s > 1.5] = s[s > 1.5] - 1e-4 + s[s < 0.5] = s[s < 0.5] + 1e-4 + w = torch.mm(torch.mm(u, torch.diag(s)), v.t()) + m.weight.data = w.view(f1, f2, c_in, c_out).permute(3, 2, 0, 1) # .type(dtype) + else: + pass + + +# -------------------------------------------- +# SVD Orthogonal Regularization +# -------------------------------------------- +def regularizer_orth2(m): + """ + # ---------------------------------------- + # Applies regularization to the training by performing the + # orthogonalization technique described in the paper + # This function is to be called by the torch.nn.Module.apply() method, + # which applies svd_orthogonalization() to every layer of the model. + # usage: net.apply(regularizer_orth2) + # ---------------------------------------- + """ + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + w = m.weight.data.clone() + c_out, c_in, f1, f2 = w.size() + # dtype = m.weight.data.type() + w = w.permute(2, 3, 1, 0).contiguous().view(f1*f2*c_in, c_out) + u, s, v = torch.svd(w) + s_mean = s.mean() + s[s > 1.5*s_mean] = s[s > 1.5*s_mean] - 1e-4 + s[s < 0.5*s_mean] = s[s < 0.5*s_mean] + 1e-4 + w = torch.mm(torch.mm(u, torch.diag(s)), v.t()) + m.weight.data = w.view(f1, f2, c_in, c_out).permute(3, 2, 0, 1) # .type(dtype) + else: + pass + + + +def regularizer_clip(m): + """ + # ---------------------------------------- + # usage: net.apply(regularizer_clip) + # ---------------------------------------- + """ + eps = 1e-4 + c_min = -1.5 + c_max = 1.5 + + classname = m.__class__.__name__ + if classname.find('Conv') != -1 or classname.find('Linear') != -1: + w = m.weight.data.clone() + w[w > c_max] -= eps + w[w < c_min] += eps + m.weight.data = w + + if m.bias is not None: + b = m.bias.data.clone() + b[b > c_max] -= eps + b[b < c_min] += eps + m.bias.data = b + +# elif classname.find('BatchNorm2d') != -1: +# +# rv = m.running_var.data.clone() +# rm = m.running_mean.data.clone() +# +# if m.affine: +# m.weight.data +# m.bias.data diff --git a/core/data/deg_kair_utils/utils_sisr.py b/core/data/deg_kair_utils/utils_sisr.py new file mode 100644 index 0000000000000000000000000000000000000000..e9edbd72ce53351d9e306c9774073a0e2eb0bdb3 --- /dev/null +++ b/core/data/deg_kair_utils/utils_sisr.py @@ -0,0 +1,848 @@ +# -*- coding: utf-8 -*- +from utils import utils_image as util +import random + +import scipy +import scipy.stats as ss +import scipy.io as io +from scipy import ndimage +from scipy.interpolate import interp2d + +import numpy as np +import torch + + +""" +# -------------------------------------------- +# Super-Resolution +# -------------------------------------------- +# +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# modified by Kai Zhang (github: https://github.com/cszn) +# 03/03/2020 +# -------------------------------------------- +""" + + +""" +# -------------------------------------------- +# anisotropic Gaussian kernels +# -------------------------------------------- +""" + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """ generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + 
l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel. + Returns: + k : kernel + """ + + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) + + k = k / np.sum(k) + return k + + +""" +# -------------------------------------------- +# calculate PCA projection matrix +# -------------------------------------------- +""" + + +def get_pca_matrix(x, dim_pca=15): + """ + Args: + x: 225x10000 matrix + dim_pca: 15 + Returns: + pca_matrix: 15x225 + """ + C = np.dot(x, x.T) + w, v = scipy.linalg.eigh(C) + pca_matrix = v[:, -dim_pca:].T + + return pca_matrix + + +def show_pca(x): + """ + x: PCA projection matrix, e.g., 15x225 + """ + for i in range(x.shape[0]): + xc = np.reshape(x[i, :], (int(np.sqrt(x.shape[1])), -1), order="F") + util.surf(xc) + + +def cal_pca_matrix(path='PCA_matrix.mat', ksize=15, l_max=12.0, dim_pca=15, num_samples=500): + kernels = np.zeros([ksize*ksize, num_samples], dtype=np.float32) + for i in range(num_samples): + + theta = np.pi*np.random.rand(1) + l1 = 0.1+l_max*np.random.rand(1) + l2 = 0.1+(l1-0.1)*np.random.rand(1) + + k = anisotropic_Gaussian(ksize=ksize, theta=theta[0], l1=l1[0], l2=l2[0]) + + # util.imshow(k) + + kernels[:, i] = np.reshape(k, (-1), order="F") # k.flatten(order='F') + + # io.savemat('k.mat', {'k': kernels}) + + pca_matrix = get_pca_matrix(kernels, dim_pca=dim_pca) + + io.savemat(path, {'p': pca_matrix}) + + return pca_matrix + + +""" +# -------------------------------------------- +# shifted anisotropic Gaussian kernels +# -------------------------------------------- +""" + + +def shifted_anisotropic_Gaussian(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): + """" + # modified version of https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5*(scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X,Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z-MU + ZZ_t = ZZ.transpose(0,1,3,2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + 
+ # shift the kernel so it will be centered + #raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + #kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +def gen_kernel(k_size=np.array([25, 25]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=12., noise_level=0): + """" + # modified version of https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + sf = random.choice([1, 2, 3, 4]) + scale_factor = np.array([sf, sf]) + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = 0#-noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5*(scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X,Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z-MU + ZZ_t = ZZ.transpose(0,1,3,2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + #raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + #kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +""" +# -------------------------------------------- +# degradation models +# -------------------------------------------- +""" + + +def bicubic_degradation(x, sf=3): + ''' + Args: + x: HxWxC image, [0, 1] + sf: down-scale factor + Return: + bicubicly downsampled LR image + ''' + x = util.imresize_np(x, scale=1/sf) + return x + + +def srmd_degradation(x, k, sf=3): + ''' blur + bicubic downsampling + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2018learning, + title={Learning a single convolutional super-resolution network for multiple degradations}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={3262--3271}, + year={2018} + } + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + x = bicubic_degradation(x, sf=sf) + return x + + +def dpsr_degradation(x, k, sf=3): + + ''' bicubic downsampling + blur + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2019deep, + title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={1671--1681}, + year={2019} + } + ''' + x = bicubic_degradation(x, sf=sf) + x = ndimage.filters.convolve(x, np.expand_dims(k, 
axis=2), mode='wrap') + return x + + +def classical_degradation(x, k, sf=3): + ''' blur + downsampling + + Args: + x: HxWxC image, [0, 1]/[0, 255] + k: hxw, double + sf: down-scale factor + + Return: + downsampled LR image + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + #x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) + st = 0 + return x[st::sf, st::sf, ...] + + +def modcrop_np(img, sf): + ''' + Args: + img: numpy image, WxH or WxHxC + sf: scale factor + Return: + cropped image + ''' + w, h = img.shape[:2] + im = np.copy(img) + return im[:w - w % sf, :h - h % sf, ...] + + +''' +# ================= +# Numpy +# ================= +''' + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH, image or kernel + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf-1)*0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w-1) + y1 = np.clip(y1, 0, h-1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +''' +# ================= +# pytorch +# ================= +''' + + +def splits(a, sf): + ''' + a: tensor NxCxWxHx2 + sf: scale factor + out: tensor NxCx(W/sf)x(H/sf)x2x(sf^2) + ''' + b = torch.stack(torch.chunk(a, sf, dim=2), dim=5) + b = torch.cat(torch.chunk(b, sf, dim=3), dim=5) + return b + + +def c2c(x): + return torch.from_numpy(np.stack([np.float32(x.real), np.float32(x.imag)], axis=-1)) + + +def r2c(x): + return torch.stack([x, torch.zeros_like(x)], -1) + + +def cdiv(x, y): + a, b = x[..., 0], x[..., 1] + c, d = y[..., 0], y[..., 1] + cd2 = c**2 + d**2 + return torch.stack([(a*c+b*d)/cd2, (b*c-a*d)/cd2], -1) + + +def csum(x, y): + return torch.stack([x[..., 0] + y, x[..., 1]], -1) + + +def cabs(x): + return torch.pow(x[..., 0]**2+x[..., 1]**2, 0.5) + + +def cmul(t1, t2): + ''' + complex multiplication + t1: NxCxHxWx2 + output: NxCxHxWx2 + ''' + real1, imag1 = t1[..., 0], t1[..., 1] + real2, imag2 = t2[..., 0], t2[..., 1] + return torch.stack([real1 * real2 - imag1 * imag2, real1 * imag2 + imag1 * real2], dim=-1) + + +def cconj(t, inplace=False): + ''' + # complex's conjugation + t: NxCxHxWx2 + output: NxCxHxWx2 + ''' + c = t.clone() if not inplace else t + c[..., 1] *= -1 + return c + + +def rfft(t): + return torch.rfft(t, 2, onesided=False) + + +def irfft(t): + return torch.irfft(t, 2, onesided=False) + + +def fft(t): + return torch.fft(t, 2) + + +def ifft(t): + return torch.ifft(t, 2) + + +def p2o(psf, shape): + ''' + Args: + psf: NxCxhxw + shape: [H,W] + + Returns: + otf: NxCxHxWx2 + ''' + otf = torch.zeros(psf.shape[:-2] + shape).type_as(psf) + otf[...,:psf.shape[2],:psf.shape[3]].copy_(psf) + for axis, axis_size in enumerate(psf.shape[2:]): + otf = torch.roll(otf, -int(axis_size / 2), dims=axis+2) + otf = torch.rfft(otf, 2, onesided=False) + n_ops = torch.sum(torch.tensor(psf.shape).type_as(psf) * torch.log2(torch.tensor(psf.shape).type_as(psf))) + otf[...,1][torch.abs(otf[...,1]) x[N, 1, W + 2 pad, H + 2 pad] (pariodic padding) + ''' + x = torch.cat([x, x[:, :, 0:pad, :]], dim=2) + x = torch.cat([x, x[:, :, :, 0:pad]], dim=3) + x = torch.cat([x[:, :, -2 * pad:-pad, :], x], dim=2) + x = torch.cat([x[:, :, :, -2 * pad:-pad], x], dim=3) + return x + + +def pad_circular(input, padding): 
+ # type: (Tensor, List[int]) -> Tensor + """ + Arguments + :param input: tensor of shape :math:`(N, C_{\text{in}}, H, [W, D]))` + :param padding: (tuple): m-elem tuple where m is the degree of convolution + Returns + :return: tensor of shape :math:`(N, C_{\text{in}}, [D + 2 * padding[0], + H + 2 * padding[1]], W + 2 * padding[2]))` + """ + offset = 3 + for dimension in range(input.dim() - offset + 1): + input = dim_pad_circular(input, padding[dimension], dimension + offset) + return input + + +def dim_pad_circular(input, padding, dimension): + # type: (Tensor, int, int) -> Tensor + input = torch.cat([input, input[[slice(None)] * (dimension - 1) + + [slice(0, padding)]]], dim=dimension - 1) + input = torch.cat([input[[slice(None)] * (dimension - 1) + + [slice(-2 * padding, -padding)]], input], dim=dimension - 1) + return input + + +def imfilter(x, k): + ''' + x: image, NxcxHxW + k: kernel, cx1xhxw + ''' + x = pad_circular(x, padding=((k.shape[-2]-1)//2, (k.shape[-1]-1)//2)) + x = torch.nn.functional.conv2d(x, k, groups=x.shape[1]) + return x + + +def G(x, k, sf=3, center=False): + ''' + x: image, NxcxHxW + k: kernel, cx1xhxw + sf: scale factor + center: the first one or the moddle one + + Matlab function: + tmp = imfilter(x,h,'circular'); + y = downsample2(tmp,K); + ''' + x = downsample(imfilter(x, k), sf=sf, center=center) + return x + + +def Gt(x, k, sf=3, center=False): + ''' + x: image, NxcxHxW + k: kernel, cx1xhxw + sf: scale factor + center: the first one or the moddle one + + Matlab function: + tmp = upsample2(x,K); + y = imfilter(tmp,h,'circular'); + ''' + x = imfilter(upsample(x, sf=sf, center=center), k) + return x + + +def interpolation_down(x, sf, center=False): + mask = torch.zeros_like(x) + if center: + start = torch.tensor((sf-1)//2) + mask[..., start::sf, start::sf] = torch.tensor(1).type_as(x) + LR = x[..., start::sf, start::sf] + else: + mask[..., ::sf, ::sf] = torch.tensor(1).type_as(x) + LR = x[..., ::sf, ::sf] + y = x.mul(mask) + + return LR, y, mask + + +''' +# ================= +Numpy +# ================= +''' + + +def blockproc(im, blocksize, fun): + xblocks = np.split(im, range(blocksize[0], im.shape[0], blocksize[0]), axis=0) + xblocks_proc = [] + for xb in xblocks: + yblocks = np.split(xb, range(blocksize[1], im.shape[1], blocksize[1]), axis=1) + yblocks_proc = [] + for yb in yblocks: + yb_proc = fun(yb) + yblocks_proc.append(yb_proc) + xblocks_proc.append(np.concatenate(yblocks_proc, axis=1)) + + proc = np.concatenate(xblocks_proc, axis=0) + + return proc + + +def fun_reshape(a): + return np.reshape(a, (-1,1,a.shape[-1]), order='F') + + +def fun_mul(a, b): + return a*b + + +def BlockMM(nr, nc, Nb, m, x1): + ''' + myfun = @(block_struct) reshape(block_struct.data,m,1); + x1 = blockproc(x1,[nr nc],myfun); + x1 = reshape(x1,m,Nb); + x1 = sum(x1,2); + x = reshape(x1,nr,nc); + ''' + fun = fun_reshape + x1 = blockproc(x1, blocksize=(nr, nc), fun=fun) + x1 = np.reshape(x1, (m, Nb, x1.shape[-1]), order='F') + x1 = np.sum(x1, 1) + x = np.reshape(x1, (nr, nc, x1.shape[-1]), order='F') + return x + + +def INVLS(FB, FBC, F2B, FR, tau, Nb, nr, nc, m): + ''' + x1 = FB.*FR; + FBR = BlockMM(nr,nc,Nb,m,x1); + invW = BlockMM(nr,nc,Nb,m,F2B); + invWBR = FBR./(invW + tau*Nb); + fun = @(block_struct) block_struct.data.*invWBR; + FCBinvWBR = blockproc(FBC,[nr,nc],fun); + FX = (FR-FCBinvWBR)/tau; + Xest = real(ifft2(FX)); + ''' + x1 = FB*FR + FBR = BlockMM(nr, nc, Nb, m, x1) + invW = BlockMM(nr, nc, Nb, m, F2B) + invWBR = FBR/(invW + tau*Nb) + FCBinvWBR = blockproc(FBC, [nr, nc], 
lambda im: fun_mul(im, invWBR)) + FX = (FR-FCBinvWBR)/tau + Xest = np.real(np.fft.ifft2(FX, axes=(0, 1))) + return Xest + + +def psf2otf(psf, shape=None): + """ + Convert point-spread function to optical transfer function. + Compute the Fast Fourier Transform (FFT) of the point-spread + function (PSF) array and creates the optical transfer function (OTF) + array that is not influenced by the PSF off-centering. + By default, the OTF array is the same size as the PSF array. + To ensure that the OTF is not altered due to PSF off-centering, PSF2OTF + post-pads the PSF array (down or to the right) with zeros to match + dimensions specified in OUTSIZE, then circularly shifts the values of + the PSF array up (or to the left) until the central pixel reaches (1,1) + position. + Parameters + ---------- + psf : `numpy.ndarray` + PSF array + shape : int + Output shape of the OTF array + Returns + ------- + otf : `numpy.ndarray` + OTF array + Notes + ----- + Adapted from MATLAB psf2otf function + """ + if type(shape) == type(None): + shape = psf.shape + shape = np.array(shape) + if np.all(psf == 0): + # return np.zeros_like(psf) + return np.zeros(shape) + if len(psf.shape) == 1: + psf = psf.reshape((1, psf.shape[0])) + inshape = psf.shape + psf = zero_pad(psf, shape, position='corner') + for axis, axis_size in enumerate(inshape): + psf = np.roll(psf, -int(axis_size / 2), axis=axis) + # Compute the OTF + otf = np.fft.fft2(psf, axes=(0, 1)) + # Estimate the rough number of operations involved in the FFT + # and discard the PSF imaginary part if within roundoff error + # roundoff error = machine epsilon = sys.float_info.epsilon + # or np.finfo().eps + n_ops = np.sum(psf.size * np.log2(psf.shape)) + otf = np.real_if_close(otf, tol=n_ops) + return otf + + +def zero_pad(image, shape, position='corner'): + """ + Extends image to a certain size with zeros + Parameters + ---------- + image: real 2d `numpy.ndarray` + Input image + shape: tuple of int + Desired output shape of the image + position : str, optional + The position of the input image in the output one: + * 'corner' + top-left corner (default) + * 'center' + centered + Returns + ------- + padded_img: real `numpy.ndarray` + The zero-padded image + """ + shape = np.asarray(shape, dtype=int) + imshape = np.asarray(image.shape, dtype=int) + if np.alltrue(imshape == shape): + return image + if np.any(shape <= 0): + raise ValueError("ZERO_PAD: null or negative shape given") + dshape = shape - imshape + if np.any(dshape < 0): + raise ValueError("ZERO_PAD: target size smaller than source one") + pad_img = np.zeros(shape, dtype=image.dtype) + idx, idy = np.indices(imshape) + if position == 'center': + if np.any(dshape % 2 != 0): + raise ValueError("ZERO_PAD: source and target shapes " + "have different parity.") + offx, offy = dshape // 2 + else: + offx, offy = (0, 0) + pad_img[idx + offx, idy + offy] = image + return pad_img + + +def upsample_np(x, sf=3, center=False): + st = (sf-1)//2 if center else 0 + z = np.zeros((x.shape[0]*sf, x.shape[1]*sf, x.shape[2])) + z[st::sf, st::sf, ...] = x + return z + + +def downsample_np(x, sf=3, center=False): + st = (sf-1)//2 if center else 0 + return x[st::sf, st::sf, ...] 
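# ----------------------------------------
# minimal illustrative sketch of how the degradation models above can be
# combined to synthesize low-resolution images from a clean HR image `img`
# (HxWxC, [0, 1]); the kernel settings below are arbitrary examples.
# ----------------------------------------
def demo_degradation(img, sf=4):
    k = anisotropic_Gaussian(ksize=15, theta=np.pi * np.random.rand(), l1=2.0, l2=1.0)
    lr_classical = classical_degradation(img, k, sf=sf)  # blur, then direct subsampling
    lr_srmd = srmd_degradation(img, k, sf=sf)            # blur, then bicubic downsampling
    lr_bicubic = bicubic_degradation(img, sf=sf)         # bicubic downsampling only
    return lr_classical, lr_srmd, lr_bicubic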
+ + +def imfilter_np(x, k): + ''' + x: image, NxcxHxW + k: kernel, cx1xhxw + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + return x + + +def G_np(x, k, sf=3, center=False): + ''' + x: image, NxcxHxW + k: kernel, cx1xhxw + + Matlab function: + tmp = imfilter(x,h,'circular'); + y = downsample2(tmp,K); + ''' + x = downsample_np(imfilter_np(x, k), sf=sf, center=center) + return x + + +def Gt_np(x, k, sf=3, center=False): + ''' + x: image, NxcxHxW + k: kernel, cx1xhxw + + Matlab function: + tmp = upsample2(x,K); + y = imfilter(tmp,h,'circular'); + ''' + x = imfilter_np(upsample_np(x, sf=sf, center=center), k) + return x + + +if __name__ == '__main__': + img = util.imread_uint('test.bmp', 3) + + img = util.uint2single(img) + k = anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6) + util.imshow(k*10) + + + for sf in [2, 3, 4]: + + # modcrop + img = modcrop_np(img, sf=sf) + + # 1) bicubic degradation + img_b = bicubic_degradation(img, sf=sf) + print(img_b.shape) + + # 2) srmd degradation + img_s = srmd_degradation(img, k, sf=sf) + print(img_s.shape) + + # 3) dpsr degradation + img_d = dpsr_degradation(img, k, sf=sf) + print(img_d.shape) + + # 4) classical degradation + img_d = classical_degradation(img, k, sf=sf) + print(img_d.shape) + + k = anisotropic_Gaussian(ksize=7, theta=0.25*np.pi, l1=0.01, l2=0.01) + #print(k) +# util.imshow(k*10) + + k = shifted_anisotropic_Gaussian(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.8, max_var=10.8, noise_level=0.0) +# util.imshow(k*10) + + + # PCA +# pca_matrix = cal_pca_matrix(ksize=15, l_max=10.0, dim_pca=15, num_samples=12500) +# print(pca_matrix.shape) +# show_pca(pca_matrix) + # run utils/utils_sisr.py + # run utils_sisr.py + + + + + + + diff --git a/core/data/deg_kair_utils/utils_video.py b/core/data/deg_kair_utils/utils_video.py new file mode 100644 index 0000000000000000000000000000000000000000..596dd4203098cf7b36f3d8499ccbf299623381ae --- /dev/null +++ b/core/data/deg_kair_utils/utils_video.py @@ -0,0 +1,493 @@ +import os +import cv2 +import numpy as np +import torch +import random +from os import path as osp +from torch.nn import functional as F +from abc import ABCMeta, abstractmethod + + +def scandir(dir_path, suffix=None, recursive=False, full_path=False): + """Scan a directory to find the interested files. + + Args: + dir_path (str): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + full_path (bool, optional): If set to True, include the dir_path. + Default: False. + + Returns: + A generator for all the interested files with relative paths. 
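    Example (illustrative sketch; the folder name is a placeholder):
        >>> frame_paths = sorted(scandir('datasets/sequence_000', suffix='.png'))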
+ """ + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + root = dir_path + + def _scandir(dir_path, suffix, recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + if full_path: + return_path = entry.path + else: + return_path = osp.relpath(entry.path, root) + + if suffix is None: + yield return_path + elif return_path.endswith(suffix): + yield return_path + else: + if recursive: + yield from _scandir(entry.path, suffix=suffix, recursive=recursive) + else: + continue + + return _scandir(dir_path, suffix=suffix, recursive=recursive) + + +def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False): + """Read a sequence of images from a given folder path. + + Args: + path (list[str] | str): List of image paths or image folder path. + require_mod_crop (bool): Require mod crop for each image. + Default: False. + scale (int): Scale factor for mod_crop. Default: 1. + return_imgname(bool): Whether return image names. Default False. + + Returns: + Tensor: size (t, c, h, w), RGB, [0, 1]. + list[str]: Returned image name list. + """ + if isinstance(path, list): + img_paths = path + else: + img_paths = sorted(list(scandir(path, full_path=True))) + imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths] + + if require_mod_crop: + imgs = [mod_crop(img, scale) for img in imgs] + imgs = img2tensor(imgs, bgr2rgb=True, float32=True) + imgs = torch.stack(imgs, dim=0) + + if return_imgname: + imgnames = [osp.splitext(osp.basename(path))[0] for path in img_paths] + return imgs, imgnames + else: + return imgs + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. + """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + if img.dtype == 'float64': + img = img.astype('float32') + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + else: + return _totensor(imgs, bgr2rgb, float32) + + +def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): + """Convert torch Tensors into image numpy arrays. + + After clamping to [min, max], values will be normalized to [0, 1]. + + Args: + tensor (Tensor or list[Tensor]): Accept shapes: + 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); + 2) 3D Tensor of shape (3/1 x H x W); + 3) 2D Tensor of shape (H x W). + Tensor channel should be in RGB order. + rgb2bgr (bool): Whether to change rgb to bgr. + out_type (numpy type): output types. If ``np.uint8``, transform outputs + to uint8 type with range [0, 255]; otherwise, float type with + range [0, 1]. Default: ``np.uint8``. + min_max (tuple[int]): min and max values for clamp. + + Returns: + (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of + shape (H x W). The channel order is BGR. 
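    Example (illustrative sketch; 'out.png' is a placeholder path):
        >>> t = torch.rand(1, 3, 64, 64)  # e.g. a network output in [0, 1]
        >>> img = tensor2img(t)           # HxWx3 uint8 array, BGR order
        >>> cv2.imwrite('out.png', img)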
+ """ + if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): + raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}') + + if torch.is_tensor(tensor): + tensor = [tensor] + result = [] + for _tensor in tensor: + _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) + _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) + + n_dim = _tensor.dim() + if n_dim == 4: + img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy() + img_np = img_np.transpose(1, 2, 0) + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 3: + img_np = _tensor.numpy() + img_np = img_np.transpose(1, 2, 0) + if img_np.shape[2] == 1: # gray image + img_np = np.squeeze(img_np, axis=2) + else: + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 2: + img_np = _tensor.numpy() + else: + raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}') + if out_type == np.uint8: + # Unlike MATLAB, numpy.unit8() WILL NOT round by default. + img_np = (img_np * 255.0).round() + img_np = img_np.astype(out_type) + result.append(img_np) + if len(result) == 1: + result = result[0] + return result + + +def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False): + """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees). + + We use vertical flip and transpose for rotation implementation. + All the images in the list use the same augmentation. + + Args: + imgs (list[ndarray] | ndarray): Images to be augmented. If the input + is an ndarray, it will be transformed to a list. + hflip (bool): Horizontal flip. Default: True. + rotation (bool): Ratotation. Default: True. + flows (list[ndarray]: Flows to be augmented. If the input is an + ndarray, it will be transformed to a list. + Dimension is (h, w, 2). Default: None. + return_status (bool): Return the status of flip and rotation. + Default: False. + + Returns: + list[ndarray] | ndarray: Augmented images and flows. If returned + results only have one element, just return ndarray. + + """ + hflip = hflip and random.random() < 0.5 + vflip = rotation and random.random() < 0.5 + rot90 = rotation and random.random() < 0.5 + + def _augment(img): + if hflip: # horizontal + cv2.flip(img, 1, img) + if vflip: # vertical + cv2.flip(img, 0, img) + if rot90: + img = img.transpose(1, 0, 2) + return img + + def _augment_flow(flow): + if hflip: # horizontal + cv2.flip(flow, 1, flow) + flow[:, :, 0] *= -1 + if vflip: # vertical + cv2.flip(flow, 0, flow) + flow[:, :, 1] *= -1 + if rot90: + flow = flow.transpose(1, 0, 2) + flow = flow[:, :, [1, 0]] + return flow + + if not isinstance(imgs, list): + imgs = [imgs] + imgs = [_augment(img) for img in imgs] + if len(imgs) == 1: + imgs = imgs[0] + + if flows is not None: + if not isinstance(flows, list): + flows = [flows] + flows = [_augment_flow(flow) for flow in flows] + if len(flows) == 1: + flows = flows[0] + return imgs, flows + else: + if return_status: + return imgs, (hflip, vflip, rot90) + else: + return imgs + + +def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None): + """Paired random crop. Support Numpy array and Tensor inputs. + + It crops lists of lq and gt images with corresponding locations. + + Args: + img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images + should have the same shape. 
If the input is an ndarray, it will + be transformed to a list containing itself. + img_lqs (list[ndarray] | ndarray): LQ images. Note that all images + should have the same shape. If the input is an ndarray, it will + be transformed to a list containing itself. + gt_patch_size (int): GT patch size. + scale (int): Scale factor. + gt_path (str): Path to ground-truth. Default: None. + + Returns: + list[ndarray] | ndarray: GT images and LQ images. If returned results + only have one element, just return ndarray. + """ + + if not isinstance(img_gts, list): + img_gts = [img_gts] + if not isinstance(img_lqs, list): + img_lqs = [img_lqs] + + # determine input type: Numpy array or Tensor + input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy' + + if input_type == 'Tensor': + h_lq, w_lq = img_lqs[0].size()[-2:] + h_gt, w_gt = img_gts[0].size()[-2:] + else: + h_lq, w_lq = img_lqs[0].shape[0:2] + h_gt, w_gt = img_gts[0].shape[0:2] + lq_patch_size = gt_patch_size // scale + + if h_gt != h_lq * scale or w_gt != w_lq * scale: + raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ', + f'multiplication of LQ ({h_lq}, {w_lq}).') + if h_lq < lq_patch_size or w_lq < lq_patch_size: + raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size ' + f'({lq_patch_size}, {lq_patch_size}). ' + f'Please remove {gt_path}.') + + # randomly choose top and left coordinates for lq patch + top = random.randint(0, h_lq - lq_patch_size) + left = random.randint(0, w_lq - lq_patch_size) + + # crop lq patch + if input_type == 'Tensor': + img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs] + else: + img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs] + + # crop corresponding gt patch + top_gt, left_gt = int(top * scale), int(left * scale) + if input_type == 'Tensor': + img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts] + else: + img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts] + if len(img_gts) == 1: + img_gts = img_gts[0] + if len(img_lqs) == 1: + img_lqs = img_lqs[0] + return img_gts, img_lqs + + +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501 +class BaseStorageBackend(metaclass=ABCMeta): + """Abstract class of storage backends. + + All backends need to implement two apis: ``get()`` and ``get_text()``. + ``get()`` reads the file as a byte stream and ``get_text()`` reads the file + as texts. + """ + + @abstractmethod + def get(self, filepath): + pass + + @abstractmethod + def get_text(self, filepath): + pass + + +class MemcachedBackend(BaseStorageBackend): + """Memcached storage backend. + + Attributes: + server_list_cfg (str): Config file for memcached server list. + client_cfg (str): Config file for memcached client. + sys_path (str | None): Additional path to be appended to `sys.path`. + Default: None. 
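    Example (illustrative sketch; both config paths are placeholders):
        >>> backend = MemcachedBackend('/etc/memcached/server_list.conf',
        ...                            '/etc/memcached/client.conf')
        >>> value_buf = backend.get('datasets/000/00000000.png')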
+ """ + + def __init__(self, server_list_cfg, client_cfg, sys_path=None): + if sys_path is not None: + import sys + sys.path.append(sys_path) + try: + import mc + except ImportError: + raise ImportError('Please install memcached to enable MemcachedBackend.') + + self.server_list_cfg = server_list_cfg + self.client_cfg = client_cfg + self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg) + # mc.pyvector servers as a point which points to a memory cache + self._mc_buffer = mc.pyvector() + + def get(self, filepath): + filepath = str(filepath) + import mc + self._client.Get(filepath, self._mc_buffer) + value_buf = mc.ConvertBuffer(self._mc_buffer) + return value_buf + + def get_text(self, filepath): + raise NotImplementedError + + +class HardDiskBackend(BaseStorageBackend): + """Raw hard disks storage backend.""" + + def get(self, filepath): + filepath = str(filepath) + with open(filepath, 'rb') as f: + value_buf = f.read() + return value_buf + + def get_text(self, filepath): + filepath = str(filepath) + with open(filepath, 'r') as f: + value_buf = f.read() + return value_buf + + +class LmdbBackend(BaseStorageBackend): + """Lmdb storage backend. + + Args: + db_paths (str | list[str]): Lmdb database paths. + client_keys (str | list[str]): Lmdb client keys. Default: 'default'. + readonly (bool, optional): Lmdb environment parameter. If True, + disallow any write operations. Default: True. + lock (bool, optional): Lmdb environment parameter. If False, when + concurrent access occurs, do not lock the database. Default: False. + readahead (bool, optional): Lmdb environment parameter. If False, + disable the OS filesystem readahead mechanism, which may improve + random read performance when a database is larger than RAM. + Default: False. + + Attributes: + db_paths (list): Lmdb database path. + _client (list): A list of several lmdb envs. + """ + + def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs): + try: + import lmdb + except ImportError: + raise ImportError('Please install lmdb to enable LmdbBackend.') + + if isinstance(client_keys, str): + client_keys = [client_keys] + + if isinstance(db_paths, list): + self.db_paths = [str(v) for v in db_paths] + elif isinstance(db_paths, str): + self.db_paths = [str(db_paths)] + assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, ' + f'but received {len(client_keys)} and {len(self.db_paths)}.') + + self._client = {} + for client, path in zip(client_keys, self.db_paths): + self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs) + + def get(self, filepath, client_key): + """Get values according to the filepath from one lmdb named client_key. + + Args: + filepath (str | obj:`Path`): Here, filepath is the lmdb key. + client_key (str): Used for distinguishing different lmdb envs. + """ + filepath = str(filepath) + assert client_key in self._client, (f'client_key {client_key} is not ' 'in lmdb clients.') + client = self._client[client_key] + with client.begin(write=False) as txn: + value_buf = txn.get(filepath.encode('ascii')) + return value_buf + + def get_text(self, filepath): + raise NotImplementedError + + +class FileClient(object): + """A general file client to access files in different backend. + + The client loads a file or text in a specified backend from its path + and return it as a binary file. it can also register other backend + accessor with a given name and backend class. 
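    Example (illustrative sketch; the image path is a placeholder):
        >>> file_client = FileClient(backend='disk')
        >>> img_bytes = file_client.get('datasets/0001.png')
        >>> img = imfrombytes(img_bytes, float32=True)  # decode with the helper defined below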
+ + Attributes: + backend (str): The storage backend type. Options are "disk", + "memcached" and "lmdb". + client (:obj:`BaseStorageBackend`): The backend object. + """ + + _backends = { + 'disk': HardDiskBackend, + 'memcached': MemcachedBackend, + 'lmdb': LmdbBackend, + } + + def __init__(self, backend='disk', **kwargs): + if backend not in self._backends: + raise ValueError(f'Backend {backend} is not supported. Currently supported ones' + f' are {list(self._backends.keys())}') + self.backend = backend + self.client = self._backends[backend](**kwargs) + + def get(self, filepath, client_key='default'): + # client_key is used only for lmdb, where different fileclients have + # different lmdb environments. + if self.backend == 'lmdb': + return self.client.get(filepath, client_key) + else: + return self.client.get(filepath) + + def get_text(self, filepath): + return self.client.get_text(filepath) + + +def imfrombytes(content, flag='color', float32=False): + """Read an image from bytes. + + Args: + content (bytes): Image bytes got from files or other streams. + flag (str): Flags specifying the color type of a loaded image, + candidates are `color`, `grayscale` and `unchanged`. + float32 (bool): Whether to change to float32., If True, will also norm + to [0, 1]. Default: False. + + Returns: + ndarray: Loaded image array. + """ + img_np = np.frombuffer(content, np.uint8) + imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED} + img = cv2.imdecode(img_np, imread_flags[flag]) + if float32: + img = img.astype(np.float32) / 255. + return img + diff --git a/core/data/deg_kair_utils/utils_videoio.py b/core/data/deg_kair_utils/utils_videoio.py new file mode 100644 index 0000000000000000000000000000000000000000..5be8c7f06802d5aaa7155a1cdcb27d2838a0882c --- /dev/null +++ b/core/data/deg_kair_utils/utils_videoio.py @@ -0,0 +1,555 @@ +import os +import cv2 +import numpy as np +import torch +import random +from os import path as osp +from torchvision.utils import make_grid +import sys +from pathlib import Path +import six +from collections import OrderedDict +import math +import glob +import av +import io +from cv2 import (CAP_PROP_FOURCC, CAP_PROP_FPS, CAP_PROP_FRAME_COUNT, + CAP_PROP_FRAME_HEIGHT, CAP_PROP_FRAME_WIDTH, + CAP_PROP_POS_FRAMES, VideoWriter_fourcc) + +if sys.version_info <= (3, 3): + FileNotFoundError = IOError +else: + FileNotFoundError = FileNotFoundError + + +def is_str(x): + """Whether the input is an string instance.""" + return isinstance(x, six.string_types) + + +def is_filepath(x): + return is_str(x) or isinstance(x, Path) + + +def fopen(filepath, *args, **kwargs): + if is_str(filepath): + return open(filepath, *args, **kwargs) + elif isinstance(filepath, Path): + return filepath.open(*args, **kwargs) + raise ValueError('`filepath` should be a string or a Path') + + +def check_file_exist(filename, msg_tmpl='file "{}" does not exist'): + if not osp.isfile(filename): + raise FileNotFoundError(msg_tmpl.format(filename)) + + +def mkdir_or_exist(dir_name, mode=0o777): + if dir_name == '': + return + dir_name = osp.expanduser(dir_name) + os.makedirs(dir_name, mode=mode, exist_ok=True) + + +def symlink(src, dst, overwrite=True, **kwargs): + if os.path.lexists(dst) and overwrite: + os.remove(dst) + os.symlink(src, dst, **kwargs) + + +def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True): + """Scan a directory to find the interested files. + Args: + dir_path (str | :obj:`Path`): Path of the directory. 
+ suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + case_sensitive (bool, optional) : If set to False, ignore the case of + suffix. Default: True. + Returns: + A generator for all the interested files with relative paths. + """ + if isinstance(dir_path, (str, Path)): + dir_path = str(dir_path) + else: + raise TypeError('"dir_path" must be a string or Path object') + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + if suffix is not None and not case_sensitive: + suffix = suffix.lower() if isinstance(suffix, str) else tuple( + item.lower() for item in suffix) + + root = dir_path + + def _scandir(dir_path, suffix, recursive, case_sensitive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + rel_path = osp.relpath(entry.path, root) + _rel_path = rel_path if case_sensitive else rel_path.lower() + if suffix is None or _rel_path.endswith(suffix): + yield rel_path + elif recursive and os.path.isdir(entry.path): + # scan recursively if entry.path is a directory + yield from _scandir(entry.path, suffix, recursive, + case_sensitive) + + return _scandir(dir_path, suffix, recursive, case_sensitive) + + +class Cache: + + def __init__(self, capacity): + self._cache = OrderedDict() + self._capacity = int(capacity) + if capacity <= 0: + raise ValueError('capacity must be a positive integer') + + @property + def capacity(self): + return self._capacity + + @property + def size(self): + return len(self._cache) + + def put(self, key, val): + if key in self._cache: + return + if len(self._cache) >= self.capacity: + self._cache.popitem(last=False) + self._cache[key] = val + + def get(self, key, default=None): + val = self._cache[key] if key in self._cache else default + return val + + +class VideoReader: + """Video class with similar usage to a list object. + + This video warpper class provides convenient apis to access frames. + There exists an issue of OpenCV's VideoCapture class that jumping to a + certain frame may be inaccurate. It is fixed in this class by checking + the position after jumping each time. + Cache is used when decoding videos. So if the same frame is visited for + the second time, there is no need to decode again if it is stored in the + cache. 
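    Example (illustrative sketch; 'test.mp4' is a placeholder file name):
        >>> v = VideoReader('test.mp4')
        >>> n_frames = len(v)   # total number of frames
        >>> frame = v[0]        # random access by index, cached after decoding
        >>> for frame in v:     # sequential decoding
        ...     pass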
+ + """ + + def __init__(self, filename, cache_capacity=10): + # Check whether the video path is a url + if not filename.startswith(('https://', 'http://')): + check_file_exist(filename, 'Video file not found: ' + filename) + self._vcap = cv2.VideoCapture(filename) + assert cache_capacity > 0 + self._cache = Cache(cache_capacity) + self._position = 0 + # get basic info + self._width = int(self._vcap.get(CAP_PROP_FRAME_WIDTH)) + self._height = int(self._vcap.get(CAP_PROP_FRAME_HEIGHT)) + self._fps = self._vcap.get(CAP_PROP_FPS) + self._frame_cnt = int(self._vcap.get(CAP_PROP_FRAME_COUNT)) + self._fourcc = self._vcap.get(CAP_PROP_FOURCC) + + @property + def vcap(self): + """:obj:`cv2.VideoCapture`: The raw VideoCapture object.""" + return self._vcap + + @property + def opened(self): + """bool: Indicate whether the video is opened.""" + return self._vcap.isOpened() + + @property + def width(self): + """int: Width of video frames.""" + return self._width + + @property + def height(self): + """int: Height of video frames.""" + return self._height + + @property + def resolution(self): + """tuple: Video resolution (width, height).""" + return (self._width, self._height) + + @property + def fps(self): + """float: FPS of the video.""" + return self._fps + + @property + def frame_cnt(self): + """int: Total frames of the video.""" + return self._frame_cnt + + @property + def fourcc(self): + """str: "Four character code" of the video.""" + return self._fourcc + + @property + def position(self): + """int: Current cursor position, indicating frame decoded.""" + return self._position + + def _get_real_position(self): + return int(round(self._vcap.get(CAP_PROP_POS_FRAMES))) + + def _set_real_position(self, frame_id): + self._vcap.set(CAP_PROP_POS_FRAMES, frame_id) + pos = self._get_real_position() + for _ in range(frame_id - pos): + self._vcap.read() + self._position = frame_id + + def read(self): + """Read the next frame. + + If the next frame have been decoded before and in the cache, then + return it directly, otherwise decode, cache and return it. + + Returns: + ndarray or None: Return the frame if successful, otherwise None. + """ + # pos = self._position + if self._cache: + img = self._cache.get(self._position) + if img is not None: + ret = True + else: + if self._position != self._get_real_position(): + self._set_real_position(self._position) + ret, img = self._vcap.read() + if ret: + self._cache.put(self._position, img) + else: + ret, img = self._vcap.read() + if ret: + self._position += 1 + return img + + def get_frame(self, frame_id): + """Get frame by index. + + Args: + frame_id (int): Index of the expected frame, 0-based. + + Returns: + ndarray or None: Return the frame if successful, otherwise None. + """ + if frame_id < 0 or frame_id >= self._frame_cnt: + raise IndexError( + f'"frame_id" must be between 0 and {self._frame_cnt - 1}') + if frame_id == self._position: + return self.read() + if self._cache: + img = self._cache.get(frame_id) + if img is not None: + self._position = frame_id + 1 + return img + self._set_real_position(frame_id) + ret, img = self._vcap.read() + if ret: + if self._cache: + self._cache.put(self._position, img) + self._position += 1 + return img + + def current_frame(self): + """Get the current frame (frame that is just visited). + + Returns: + ndarray or None: If the video is fresh, return None, otherwise + return the frame. 
+ """ + if self._position == 0: + return None + return self._cache.get(self._position - 1) + + def cvt2frames(self, + frame_dir, + file_start=0, + filename_tmpl='{:06d}.jpg', + start=0, + max_num=0, + show_progress=False): + """Convert a video to frame images. + + Args: + frame_dir (str): Output directory to store all the frame images. + file_start (int): Filenames will start from the specified number. + filename_tmpl (str): Filename template with the index as the + placeholder. + start (int): The starting frame index. + max_num (int): Maximum number of frames to be written. + show_progress (bool): Whether to show a progress bar. + """ + mkdir_or_exist(frame_dir) + if max_num == 0: + task_num = self.frame_cnt - start + else: + task_num = min(self.frame_cnt - start, max_num) + if task_num <= 0: + raise ValueError('start must be less than total frame number') + if start > 0: + self._set_real_position(start) + + def write_frame(file_idx): + img = self.read() + if img is None: + return + filename = osp.join(frame_dir, filename_tmpl.format(file_idx)) + cv2.imwrite(filename, img) + + if show_progress: + pass + #track_progress(write_frame, range(file_start,file_start + task_num)) + else: + for i in range(task_num): + write_frame(file_start + i) + + def __len__(self): + return self.frame_cnt + + def __getitem__(self, index): + if isinstance(index, slice): + return [ + self.get_frame(i) + for i in range(*index.indices(self.frame_cnt)) + ] + # support negative indexing + if index < 0: + index += self.frame_cnt + if index < 0: + raise IndexError('index out of range') + return self.get_frame(index) + + def __iter__(self): + self._set_real_position(0) + return self + + def __next__(self): + img = self.read() + if img is not None: + return img + else: + raise StopIteration + + next = __next__ + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self._vcap.release() + + +def frames2video(frame_dir, + video_file, + fps=30, + fourcc='XVID', + filename_tmpl='{:06d}.jpg', + start=0, + end=0, + show_progress=False): + """Read the frame images from a directory and join them as a video. + + Args: + frame_dir (str): The directory containing video frames. + video_file (str): Output filename. + fps (float): FPS of the output video. + fourcc (str): Fourcc of the output video, this should be compatible + with the output file type. + filename_tmpl (str): Filename template with the index as the variable. + start (int): Starting frame index. + end (int): Ending frame index. + show_progress (bool): Whether to show a progress bar. 
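    Example (illustrative sketch; folder and file names are placeholders):
        >>> # join frames/000000.jpg, frames/000001.jpg, ... into one video
        >>> frames2video('frames', 'video_01.avi', fps=30, fourcc='XVID')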
+ """ + if end == 0: + ext = filename_tmpl.split('.')[-1] + end = len([name for name in scandir(frame_dir, ext)]) + first_file = osp.join(frame_dir, filename_tmpl.format(start)) + check_file_exist(first_file, 'The start frame not found: ' + first_file) + img = cv2.imread(first_file) + height, width = img.shape[:2] + resolution = (width, height) + vwriter = cv2.VideoWriter(video_file, VideoWriter_fourcc(*fourcc), fps, + resolution) + + def write_frame(file_idx): + filename = osp.join(frame_dir, filename_tmpl.format(file_idx)) + img = cv2.imread(filename) + vwriter.write(img) + + if show_progress: + pass + # track_progress(write_frame, range(start, end)) + else: + for i in range(start, end): + write_frame(i) + vwriter.release() + + +def video2images(video_path, output_dir): + vidcap = cv2.VideoCapture(video_path) + in_fps = vidcap.get(cv2.CAP_PROP_FPS) + print('video fps:', in_fps) + if not os.path.isdir(output_dir): + os.makedirs(output_dir) + loaded, frame = vidcap.read() + total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) + print(f'number of total frames is: {total_frames:06}') + for i_frame in range(total_frames): + if i_frame % 100 == 0: + print(f'{i_frame:06} / {total_frames:06}') + frame_name = os.path.join(output_dir, f'{i_frame:06}' + '.png') + cv2.imwrite(frame_name, frame) + loaded, frame = vidcap.read() + + +def images2video(image_dir, video_path, fps=24, image_ext='png'): + ''' + #codec = cv2.VideoWriter_fourcc(*'XVID') + #codec = cv2.VideoWriter_fourcc('A','V','C','1') + #codec = cv2.VideoWriter_fourcc('Y','U','V','1') + #codec = cv2.VideoWriter_fourcc('P','I','M','1') + #codec = cv2.VideoWriter_fourcc('M','J','P','G') + codec = cv2.VideoWriter_fourcc('M','P','4','2') + #codec = cv2.VideoWriter_fourcc('D','I','V','3') + #codec = cv2.VideoWriter_fourcc('D','I','V','X') + #codec = cv2.VideoWriter_fourcc('U','2','6','3') + #codec = cv2.VideoWriter_fourcc('I','2','6','3') + #codec = cv2.VideoWriter_fourcc('F','L','V','1') + #codec = cv2.VideoWriter_fourcc('H','2','6','4') + #codec = cv2.VideoWriter_fourcc('A','Y','U','V') + #codec = cv2.VideoWriter_fourcc('I','U','Y','V') + ç¼–ç å™¨å¸¸ç”¨çš„几ç§ï¼š + cv2.VideoWriter_fourcc("I", "4", "2", "0") + 压缩的yuv颜色编ç å™¨ï¼Œ4:2:0色彩度å­é‡‡æ · 兼容性好,产生很大的视频 avi + cv2.VideoWriter_fourcc("P", I", "M", "1") + 采用mpeg-1ç¼–ç ï¼Œæ–‡ä»¶ä¸ºavi + cv2.VideoWriter_fourcc("X", "V", "T", "D") + 采用mpeg-4ç¼–ç ï¼Œå¾—到视频大å°å¹³å‡ 拓展åavi + cv2.VideoWriter_fourcc("T", "H", "E", "O") + Ogg Vorbis, 拓展å为ogv + cv2.VideoWriter_fourcc("F", "L", "V", "1") + FLASH视频,拓展å为.flv + ''' + image_files = sorted(glob.glob(os.path.join(image_dir, '*.{}'.format(image_ext)))) + print(len(image_files)) + height, width, _ = cv2.imread(image_files[0]).shape + out_fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G') # cv2.VideoWriter_fourcc(*'MP4V') + out_video = cv2.VideoWriter(video_path, out_fourcc, fps, (width, height)) + + for image_file in image_files: + img = cv2.imread(image_file) + img = cv2.resize(img, (width, height), interpolation=3) + out_video.write(img) + out_video.release() + + +def add_video_compression(imgs): + codec_type = ['libx264', 'h264', 'mpeg4'] + codec_prob = [1 / 3., 1 / 3., 1 / 3.] 
+ codec = random.choices(codec_type, codec_prob)[0] + # codec = 'mpeg4' + bitrate = [1e4, 1e5] + bitrate = np.random.randint(bitrate[0], bitrate[1] + 1) + + buf = io.BytesIO() + with av.open(buf, 'w', 'mp4') as container: + stream = container.add_stream(codec, rate=1) + stream.height = imgs[0].shape[0] + stream.width = imgs[0].shape[1] + stream.pix_fmt = 'yuv420p' + stream.bit_rate = bitrate + + for img in imgs: + img = np.uint8((img.clip(0, 1)*255.).round()) + frame = av.VideoFrame.from_ndarray(img, format='rgb24') + frame.pict_type = 'NONE' + # pdb.set_trace() + for packet in stream.encode(frame): + container.mux(packet) + + # Flush stream + for packet in stream.encode(): + container.mux(packet) + + outputs = [] + with av.open(buf, 'r', 'mp4') as container: + if container.streams.video: + for frame in container.decode(**{'video': 0}): + outputs.append( + frame.to_rgb().to_ndarray().astype(np.float32) / 255.) + + #outputs = np.stack(outputs, axis=0) + return outputs + + +if __name__ == '__main__': + + # ----------------------------------- + # test VideoReader(filename, cache_capacity=10) + # ----------------------------------- +# video_reader = VideoReader('utils/test.mp4') +# from utils import utils_image as util +# inputs = [] +# for frame in video_reader: +# print(frame.dtype) +# util.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) +# #util.imshow(np.flip(frame, axis=2)) + + # ----------------------------------- + # test video2images(video_path, output_dir) + # ----------------------------------- +# video2images('utils/test.mp4', 'frames') + + # ----------------------------------- + # test images2video(image_dir, video_path, fps=24, image_ext='png') + # ----------------------------------- +# images2video('frames', 'video_02.mp4', fps=30, image_ext='png') + + + # ----------------------------------- + # test frames2video(frame_dir, video_file, fps=30, fourcc='XVID', filename_tmpl='{:06d}.png') + # ----------------------------------- +# frames2video('frames', 'video_01.mp4', filename_tmpl='{:06d}.png') + + + # ----------------------------------- + # test add_video_compression(imgs) + # ----------------------------------- +# imgs = [] +# image_ext = 'png' +# frames = 'frames' +# from utils import utils_image as util +# image_files = sorted(glob.glob(os.path.join(frames, '*.{}'.format(image_ext)))) +# for i, image_file in enumerate(image_files): +# if i < 7: +# img = util.imread_uint(image_file, 3) +# img = util.uint2single(img) +# imgs.append(img) +# +# results = add_video_compression(imgs) +# for i, img in enumerate(results): +# util.imshow(util.single2uint(img)) +# util.imsave(util.single2uint(img),f'{i:05}.png') + + # run utils/utils_video.py + + + + + + + diff --git a/core/scripts/__init__.py b/core/scripts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/core/scripts/cli.py b/core/scripts/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..bfe3ecc330ecf9f0b3af1e7dc6b3758673712cc7 --- /dev/null +++ b/core/scripts/cli.py @@ -0,0 +1,41 @@ +import sys +import argparse +from .. import WarpCore +from .. 
import templates + + +def template_init(args): + return '''' + + + '''.strip() + + +def init_template(args): + parser = argparse.ArgumentParser(description='WarpCore template init tool') + parser.add_argument('-t', '--template', type=str, default='WarpCore') + args = parser.parse_args(args) + + if args.template == 'WarpCore': + template_cls = WarpCore + else: + try: + template_cls = __import__(args.template) + except ModuleNotFoundError: + template_cls = getattr(templates, args.template) + print(template_cls) + + +def main(): + if len(sys.argv) < 2: + print('Usage: core ') + sys.exit(1) + if sys.argv[1] == 'init': + init_template(sys.argv[2:]) + else: + print('Unknown command') + sys.exit(1) + + +if __name__ == '__main__': + main() diff --git a/core/templates/__init__.py b/core/templates/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..570f16de78bcce68aa49ff0a5d0fad63284f6948 --- /dev/null +++ b/core/templates/__init__.py @@ -0,0 +1 @@ +from .diffusion import DiffusionCore \ No newline at end of file diff --git a/core/templates/diffusion.py b/core/templates/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..f36dc3f5efa14669cc36cc3c0cffcc8def037289 --- /dev/null +++ b/core/templates/diffusion.py @@ -0,0 +1,236 @@ +from .. import WarpCore +from ..utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary +from abc import abstractmethod +from dataclasses import dataclass +import torch +from torch import nn +from torch.utils.data import DataLoader +from gdf import GDF +import numpy as np +from tqdm import tqdm +import wandb + +import webdataset as wds +from webdataset.handlers import warn_and_continue +from torch.distributed import barrier +from enum import Enum + +class TargetReparametrization(Enum): + EPSILON = 'epsilon' + X0 = 'x0' + +class DiffusionCore(WarpCore): + @dataclass(frozen=True) + class Config(WarpCore.Config): + # TRAINING PARAMS + lr: float = EXPECTED_TRAIN + grad_accum_steps: int = EXPECTED_TRAIN + batch_size: int = EXPECTED_TRAIN + updates: int = EXPECTED_TRAIN + warmup_updates: int = EXPECTED_TRAIN + save_every: int = 500 + backup_every: int = 20000 + use_fsdp: bool = True + + # EMA UPDATE + ema_start_iters: int = None + ema_iters: int = None + ema_beta: float = None + + # GDF setting + gdf_target_reparametrization: TargetReparametrization = None # epsilon or x0 + + @dataclass() # not frozen, means that fields are mutable. 
Doesn't support EXPECTED + class Info(WarpCore.Info): + ema_loss: float = None + + @dataclass(frozen=True) + class Models(WarpCore.Models): + generator : nn.Module = EXPECTED + generator_ema : nn.Module = None # optional + + @dataclass(frozen=True) + class Optimizers(WarpCore.Optimizers): + generator : any = EXPECTED + + @dataclass(frozen=True) + class Schedulers(WarpCore.Schedulers): + generator: any = None + + @dataclass(frozen=True) + class Extras(WarpCore.Extras): + gdf: GDF = EXPECTED + sampling_configs: dict = EXPECTED + + # -------------------------------------------- + info: Info + config: Config + + @abstractmethod + def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor: + raise NotImplementedError("This method needs to be overriden") + + @abstractmethod + def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor: + raise NotImplementedError("This method needs to be overriden") + + @abstractmethod + def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False): + raise NotImplementedError("This method needs to be overriden") + + @abstractmethod + def webdataset_path(self, extras: Extras): + raise NotImplementedError("This method needs to be overriden") + + @abstractmethod + def webdataset_filters(self, extras: Extras): + raise NotImplementedError("This method needs to be overriden") + + @abstractmethod + def webdataset_preprocessors(self, extras: Extras): + raise NotImplementedError("This method needs to be overriden") + + @abstractmethod + def sample(self, models: Models, data: WarpCore.Data, extras: Extras): + raise NotImplementedError("This method needs to be overriden") + # ------------- + + def setup_data(self, extras: Extras) -> WarpCore.Data: + # SETUP DATASET + dataset_path = self.webdataset_path(extras) + preprocessors = self.webdataset_preprocessors(extras) + filters = self.webdataset_filters(extras) + + handler = warn_and_continue # None + # handler = None + dataset = wds.WebDataset( + dataset_path, resampled=True, handler=handler + ).select(filters).shuffle(690, handler=handler).decode( + "pilrgb", handler=handler + ).to_tuple( + *[p[0] for p in preprocessors], handler=handler + ).map_tuple( + *[p[1] for p in preprocessors], handler=handler + ).map(lambda x: {p[2]:x[i] for i, p in enumerate(preprocessors)}) + + # SETUP DATALOADER + real_batch_size = self.config.batch_size//(self.world_size*self.config.grad_accum_steps) + dataloader = DataLoader( + dataset, batch_size=real_batch_size, num_workers=8, pin_memory=True + ) + + return self.Data(dataset=dataset, dataloader=dataloader, iterator=iter(dataloader)) + + def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models): + batch = next(data.iterator) + + with torch.no_grad(): + conditions = self.get_conditions(batch, models, extras) + latents = self.encode_latents(batch, models, extras) + noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1) + + # FORWARD PASS + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + pred = models.generator(noised, noise_cond, **conditions) + if self.config.gdf_target_reparametrization == TargetReparametrization.EPSILON: + pred = extras.gdf.undiffuse(noised, logSNR, pred)[1] # transform whatever prediction to epsilon to use in the loss + target = noise + elif self.config.gdf_target_reparametrization == TargetReparametrization.X0: + pred = extras.gdf.undiffuse(noised, logSNR, pred)[0] # 
transform whatever prediction to x0 to use in the loss
+                target = latents
+            loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3])
+            loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps
+
+        return loss, loss_adjusted
+
+    def train(self, data: WarpCore.Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers):
+        start_iter = self.info.iter+1
+        max_iters = self.config.updates * self.config.grad_accum_steps
+        if self.is_main_node:
+            print(f"STARTING AT STEP: {start_iter}/{max_iters}")
+
+        pbar = tqdm(range(start_iter, max_iters+1)) if self.is_main_node else range(start_iter, max_iters+1) # <--- DDP
+        models.generator.train()
+        for i in pbar:
+            # FORWARD PASS
+            loss, loss_adjusted = self.forward_pass(data, extras, models)
+
+            # BACKWARD PASS
+            if i % self.config.grad_accum_steps == 0 or i == max_iters:
+                loss_adjusted.backward()
+                grad_norm = nn.utils.clip_grad_norm_(models.generator.parameters(), 1.0)
+                optimizers_dict = optimizers.to_dict()
+                for k in optimizers_dict:
+                    optimizers_dict[k].step()
+                schedulers_dict = schedulers.to_dict()
+                for k in schedulers_dict:
+                    schedulers_dict[k].step()
+                models.generator.zero_grad(set_to_none=True)
+                self.info.total_steps += 1
+            else:
+                with models.generator.no_sync():
+                    loss_adjusted.backward()
+            self.info.iter = i
+
+            # UPDATE EMA
+            if models.generator_ema is not None and i % self.config.ema_iters == 0:
+                update_weights_ema(
+                    models.generator_ema, models.generator,
+                    beta=(self.config.ema_beta if i > self.config.ema_start_iters else 0)
+                )
+
+            # UPDATE LOSS METRICS
+            self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01
+
+            # Alert (main node only, with wandb configured) if either the loss or the gradient norm is NaN
+            if self.is_main_node and self.config.wandb_project is not None and (np.isnan(loss.mean().item()) or np.isnan(grad_norm.item())):
+                wandb.alert(
+                    title=f"NaN value encountered in training run {self.info.wandb_run_id}",
+                    text=f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. 
Run {self.info.wandb_run_id}", + wait_duration=60*30 + ) + + if self.is_main_node: + logs = { + 'loss': self.info.ema_loss, + 'raw_loss': loss.mean().item(), + 'grad_norm': grad_norm.item(), + 'lr': optimizers.generator.param_groups[0]['lr'], + 'total_steps': self.info.total_steps, + } + + pbar.set_postfix(logs) + if self.config.wandb_project is not None: + wandb.log(logs) + + if i == 1 or i % (self.config.save_every*self.config.grad_accum_steps) == 0 or i == max_iters: + # SAVE AND CHECKPOINT STUFF + if np.isnan(loss.mean().item()): + if self.is_main_node and self.config.wandb_project is not None: + tqdm.write("Skipping sampling & checkpoint because the loss is NaN") + wandb.alert(title=f"Skipping sampling & checkpoint for training run {self.config.run_id}", text=f"Skipping sampling & checkpoint at {self.info.total_steps} for training run {self.info.wandb_run_id} iters because loss is NaN") + else: + self.save_checkpoints(models, optimizers) + if self.is_main_node: + create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/') + self.sample(models, data, extras) + + def models_to_save(self): + return ['generator', 'generator_ema'] + + def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None): + barrier() + suffix = '' if suffix is None else suffix + self.save_info(self.info, suffix=suffix) + models_dict = models.to_dict() + optimizers_dict = optimizers.to_dict() + for key in self.models_to_save(): + model = models_dict[key] + if model is not None: + self.save_model(model, f"{key}{suffix}", is_fsdp=self.config.use_fsdp) + for key in optimizers_dict: + optimizer = optimizers_dict[key] + if optimizer is not None: + self.save_optimizer(optimizer, f'{key}_optim{suffix}', fsdp_model=models.generator if self.config.use_fsdp else None) + if suffix == '' and self.info.total_steps > 1 and self.info.total_steps % self.config.backup_every == 0: + self.save_checkpoints(models, optimizers, suffix=f"_{self.info.total_steps//1000}k") + torch.cuda.empty_cache() diff --git a/core/utils/__init__.py b/core/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2e71b37e8d1690a00ab1e0958320775bc822b6f5 --- /dev/null +++ b/core/utils/__init__.py @@ -0,0 +1,9 @@ +from .base_dto import Base, nested_dto, EXPECTED, EXPECTED_TRAIN +from .save_and_load import create_folder_if_necessary, safe_save, load_or_fail + +# MOVE IT SOMERWHERE ELSE +def update_weights_ema(tgt_model, src_model, beta=0.999): + for self_params, src_params in zip(tgt_model.parameters(), src_model.parameters()): + self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1-beta) + for self_buffers, src_buffers in zip(tgt_model.buffers(), src_model.buffers()): + self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1-beta) \ No newline at end of file diff --git a/core/utils/__pycache__/__init__.cpython-310.pyc b/core/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63c0a7e0fbf358f557d6bea755a0f550b4010a48 Binary files /dev/null and b/core/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/core/utils/__pycache__/__init__.cpython-39.pyc b/core/utils/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f18d6921da3c9d93087c1b6d8eacd7a5e46a8e5 Binary files /dev/null and b/core/utils/__pycache__/__init__.cpython-39.pyc differ diff --git 
a/core/utils/__pycache__/base_dto.cpython-310.pyc b/core/utils/__pycache__/base_dto.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de093eb65813d4abf69edfbb6923f2cabab21ad7 Binary files /dev/null and b/core/utils/__pycache__/base_dto.cpython-310.pyc differ diff --git a/core/utils/__pycache__/base_dto.cpython-39.pyc b/core/utils/__pycache__/base_dto.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b80d348c7959338709ec24c3ac24dfc4f6dab3dc Binary files /dev/null and b/core/utils/__pycache__/base_dto.cpython-39.pyc differ diff --git a/core/utils/__pycache__/save_and_load.cpython-310.pyc b/core/utils/__pycache__/save_and_load.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7a0f63ac8bbaf073dcd8a046ed112cec181d33a Binary files /dev/null and b/core/utils/__pycache__/save_and_load.cpython-310.pyc differ diff --git a/core/utils/__pycache__/save_and_load.cpython-39.pyc b/core/utils/__pycache__/save_and_load.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec04e9aba6f83ab76f0bbc243bb95fda07ad8d16 Binary files /dev/null and b/core/utils/__pycache__/save_and_load.cpython-39.pyc differ diff --git a/core/utils/base_dto.py b/core/utils/base_dto.py new file mode 100644 index 0000000000000000000000000000000000000000..7cf185f00e5c6f56d23774cec8591b8d4554971e --- /dev/null +++ b/core/utils/base_dto.py @@ -0,0 +1,56 @@ +import dataclasses +from dataclasses import dataclass, _MISSING_TYPE +from munch import Munch + +EXPECTED = "___REQUIRED___" +EXPECTED_TRAIN = "___REQUIRED_TRAIN___" + +# pylint: disable=invalid-field-call +def nested_dto(x, raw=False): + return dataclasses.field(default_factory=lambda: x if raw else Munch.fromDict(x)) + +@dataclass(frozen=True) +class Base: + training: bool = None + def __new__(cls, **kwargs): + training = kwargs.get('training', True) + setteable_fields = cls.setteable_fields(**kwargs) + mandatory_fields = cls.mandatory_fields(**kwargs) + invalid_kwargs = [ + {k: v} for k, v in kwargs.items() if k not in setteable_fields or v == EXPECTED or (v == EXPECTED_TRAIN and training is not False) + ] + print(mandatory_fields) + assert ( + len(invalid_kwargs) == 0 + ), f"Invalid fields detected when initializing this DTO: {invalid_kwargs}.\nDeclare this field and set it to None or EXPECTED in order to make it setteable." + missing_kwargs = [f for f in mandatory_fields if f not in kwargs] + assert ( + len(missing_kwargs) == 0 + ), f"Required fields missing initializing this DTO: {missing_kwargs}." 
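+        # All checks passed: every supplied field is settable and every mandatory
+        # (EXPECTED / EXPECTED_TRAIN) field was provided, so defer to the frozen
+        # dataclass __init__ for the actual field assignment.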
+ return object.__new__(cls) + + + @classmethod + def setteable_fields(cls, **kwargs): + return [f.name for f in dataclasses.fields(cls) if f.default is None or isinstance(f.default, _MISSING_TYPE) or f.default == EXPECTED or f.default == EXPECTED_TRAIN] + + @classmethod + def mandatory_fields(cls, **kwargs): + training = kwargs.get('training', True) + return [f.name for f in dataclasses.fields(cls) if isinstance(f.default, _MISSING_TYPE) and isinstance(f.default_factory, _MISSING_TYPE) or f.default == EXPECTED or (f.default == EXPECTED_TRAIN and training is not False)] + + @classmethod + def from_dict(cls, kwargs): + for k in kwargs: + if isinstance(kwargs[k], (dict, list, tuple)): + kwargs[k] = Munch.fromDict(kwargs[k]) + return cls(**kwargs) + + def to_dict(self): + # selfdict = dataclasses.asdict(self) # needs to pickle stuff, doesn't support some more complex classes + selfdict = {} + for k in dataclasses.fields(self): + selfdict[k.name] = getattr(self, k.name) + if isinstance(selfdict[k.name], Munch): + selfdict[k.name] = selfdict[k.name].toDict() + return selfdict diff --git a/core/utils/save_and_load.py b/core/utils/save_and_load.py new file mode 100644 index 0000000000000000000000000000000000000000..0215f664f5a8e738147d0828b6a7e65b9c3a8507 --- /dev/null +++ b/core/utils/save_and_load.py @@ -0,0 +1,59 @@ +import os +import torch +import json +from pathlib import Path +import safetensors +import wandb + + +def create_folder_if_necessary(path): + path = "/".join(path.split("/")[:-1]) + Path(path).mkdir(parents=True, exist_ok=True) + + +def safe_save(ckpt, path): + try: + os.remove(f"{path}.bak") + except OSError: + pass + try: + os.rename(path, f"{path}.bak") + except OSError: + pass + if path.endswith(".pt") or path.endswith(".ckpt"): + torch.save(ckpt, path) + elif path.endswith(".json"): + with open(path, "w", encoding="utf-8") as f: + json.dump(ckpt, f, indent=4) + elif path.endswith(".safetensors"): + safetensors.torch.save_file(ckpt, path) + else: + raise ValueError(f"File extension not supported: {path}") + + +def load_or_fail(path, wandb_run_id=None): + accepted_extensions = [".pt", ".ckpt", ".json", ".safetensors"] + try: + assert any( + [path.endswith(ext) for ext in accepted_extensions] + ), f"Automatic loading not supported for this extension: {path}" + if not os.path.exists(path): + checkpoint = None + elif path.endswith(".pt") or path.endswith(".ckpt"): + checkpoint = torch.load(path, map_location="cpu") + elif path.endswith(".json"): + with open(path, "r", encoding="utf-8") as f: + checkpoint = json.load(f) + elif path.endswith(".safetensors"): + checkpoint = {} + with safetensors.safe_open(path, framework="pt", device="cpu") as f: + for key in f.keys(): + checkpoint[key] = f.get_tensor(key) + return checkpoint + except Exception as e: + if wandb_run_id is not None: + wandb.alert( + title=f"Corrupt checkpoint for run {wandb_run_id}", + text=f"Training {wandb_run_id} tried to load checkpoint {path} and failed", + ) + raise e diff --git a/gdf/__init__.py b/gdf/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..753b52e2e07e2540385594627a6faf4f6091b0a0 --- /dev/null +++ b/gdf/__init__.py @@ -0,0 +1,205 @@ +import torch +from .scalers import * +from .targets import * +from .schedulers import * +from .noise_conditions import * +from .loss_weights import * +from .samplers import * +import torch.nn.functional as F +import math +class GDF(): + def __init__(self, schedule, input_scaler, target, noise_cond, loss_weight, offset_noise=0): + 
self.schedule = schedule + self.input_scaler = input_scaler + self.target = target + self.noise_cond = noise_cond + self.loss_weight = loss_weight + self.offset_noise = offset_noise + + def setup_limits(self, stretch_max=True, stretch_min=True, shift=1): + stretched_limits = self.input_scaler.setup_limits(self.schedule, self.input_scaler, stretch_max, stretch_min, shift) + return stretched_limits + + def diffuse(self, x0, epsilon=None, t=None, shift=1, loss_shift=1, offset=None): + if epsilon is None: + epsilon = torch.randn_like(x0) + if self.offset_noise > 0: + if offset is None: + offset = torch.randn([x0.size(0), x0.size(1)] + [1]*(len(x0.shape)-2)).to(x0.device) + epsilon = epsilon + offset * self.offset_noise + logSNR = self.schedule(x0.size(0) if t is None else t, shift=shift).to(x0.device) + a, b = self.input_scaler(logSNR) # B + if len(a.shape) == 1: + a, b = a.view(-1, *[1]*(len(x0.shape)-1)), b.view(-1, *[1]*(len(x0.shape)-1)) # BxCxHxW + #print('in line 33 a b', a.shape, b.shape, x0.shape, logSNR.shape, logSNR, self.noise_cond(logSNR)) + target = self.target(x0, epsilon, logSNR, a, b) + + # noised, noise, logSNR, t_cond + #noised, noise, target, logSNR, noise_cond, loss_weight + return x0 * a + epsilon * b, epsilon, target, logSNR, self.noise_cond(logSNR), self.loss_weight(logSNR, shift=loss_shift) + + def undiffuse(self, x, logSNR, pred): + a, b = self.input_scaler(logSNR) + if len(a.shape) == 1: + a, b = a.view(-1, *[1]*(len(x.shape)-1)), b.view(-1, *[1]*(len(x.shape)-1)) + return self.target.x0(x, pred, logSNR, a, b), self.target.epsilon(x, pred, logSNR, a, b) + + def sample(self, model, model_inputs, shape, unconditional_inputs=None, sampler=None, schedule=None, t_start=1.0, t_end=0.0, timesteps=20, x_init=None, cfg=3.0, cfg_t_stop=None, cfg_t_start=None, cfg_rho=0.7, sampler_params=None, shift=1, device="cpu"): + sampler_params = {} if sampler_params is None else sampler_params + if sampler is None: + sampler = DDPMSampler(self) + r_range = torch.linspace(t_start, t_end, timesteps+1) + schedule = self.schedule if schedule is None else schedule + logSNR_range = schedule(r_range, shift=shift)[:, None].expand( + -1, shape[0] if x_init is None else x_init.size(0) + ).to(device) + + x = sampler.init_x(shape).to(device) if x_init is None else x_init.clone() + + if cfg is not None: + if unconditional_inputs is None: + unconditional_inputs = {k: torch.zeros_like(v) for k, v in model_inputs.items()} + model_inputs = { + k: torch.cat([v, v_u], dim=0) if isinstance(v, torch.Tensor) + else [torch.cat([vi, vi_u], dim=0) if isinstance(vi, torch.Tensor) and isinstance(vi_u, torch.Tensor) else None for vi, vi_u in zip(v, v_u)] if isinstance(v, list) + else {vk: torch.cat([v[vk], v_u.get(vk, torch.zeros_like(v[vk]))], dim=0) for vk in v} if isinstance(v, dict) + else None for (k, v), (k_u, v_u) in zip(model_inputs.items(), unconditional_inputs.items()) + } + + for i in range(0, timesteps): + noise_cond = self.noise_cond(logSNR_range[i]) + if cfg is not None and (cfg_t_stop is None or r_range[i].item() >= cfg_t_stop) and (cfg_t_start is None or r_range[i].item() <= cfg_t_start): + cfg_val = cfg + if isinstance(cfg_val, (list, tuple)): + assert len(cfg_val) == 2, "cfg must be a float or a list/tuple of length 2" + cfg_val = cfg_val[0] * r_range[i].item() + cfg_val[1] * (1-r_range[i].item()) + + pred, pred_unconditional = model(torch.cat([x, x], dim=0), noise_cond.repeat(2), **model_inputs).chunk(2) + + pred_cfg = torch.lerp(pred_unconditional, pred, cfg_val) + if cfg_rho > 0: + std_pos, 
std_cfg = pred.std(), pred_cfg.std() + pred = cfg_rho * (pred_cfg * std_pos/(std_cfg+1e-9)) + pred_cfg * (1-cfg_rho) + else: + pred = pred_cfg + else: + pred = model(x, noise_cond, **model_inputs) + x0, epsilon = self.undiffuse(x, logSNR_range[i], pred) + x = sampler(x, x0, epsilon, logSNR_range[i], logSNR_range[i+1], **sampler_params) + #print('in line 86', x0.shape, x.shape, i, ) + altered_vars = yield (x0, x, pred) + + # Update some running variables if the user wants + if altered_vars is not None: + cfg = altered_vars.get('cfg', cfg) + cfg_rho = altered_vars.get('cfg_rho', cfg_rho) + sampler = altered_vars.get('sampler', sampler) + model_inputs = altered_vars.get('model_inputs', model_inputs) + x = altered_vars.get('x', x) + x_init = altered_vars.get('x_init', x_init) + +class GDF_dual_fixlrt(GDF): + def ref_noise(self, noised, x0, logSNR): + a, b = self.input_scaler(logSNR) + if len(a.shape) == 1: + a, b = a.view(-1, *[1]*(len(x0.shape)-1)), b.view(-1, *[1]*(len(x0.shape)-1)) + #print('in line 210', a.shape, b.shape, x0.shape, noised.shape) + return self.target.noise_givenx0_noised(x0, noised, logSNR, a, b) + + def sample(self, model, model_inputs, shape, shape_lr, unconditional_inputs=None, sampler=None, + schedule=None, t_start=1.0, t_end=0.0, timesteps=20, x_init=None, cfg=3.0, cfg_t_stop=None, + cfg_t_start=None, cfg_rho=0.7, sampler_params=None, shift=1, device="cpu"): + sampler_params = {} if sampler_params is None else sampler_params + if sampler is None: + sampler = DDPMSampler(self) + r_range = torch.linspace(t_start, t_end, timesteps+1) + schedule = self.schedule if schedule is None else schedule + logSNR_range = schedule(r_range, shift=shift)[:, None].expand( + -1, shape[0] if x_init is None else x_init.size(0) + ).to(device) + + x = sampler.init_x(shape).to(device) if x_init is None else x_init.clone() + x_lr = sampler.init_x(shape_lr).to(device) if x_init is None else x_init.clone() + if cfg is not None: + if unconditional_inputs is None: + unconditional_inputs = {k: torch.zeros_like(v) for k, v in model_inputs.items()} + model_inputs = { + k: torch.cat([v, v_u], dim=0) if isinstance(v, torch.Tensor) + else [torch.cat([vi, vi_u], dim=0) if isinstance(vi, torch.Tensor) and isinstance(vi_u, torch.Tensor) else None for vi, vi_u in zip(v, v_u)] if isinstance(v, list) + else {vk: torch.cat([v[vk], v_u.get(vk, torch.zeros_like(v[vk]))], dim=0) for vk in v} if isinstance(v, dict) + else None for (k, v), (k_u, v_u) in zip(model_inputs.items(), unconditional_inputs.items()) + } + + ###############################################lr sampling + + guide_feas = [None] * timesteps + + for i in range(0, timesteps): + noise_cond = self.noise_cond(logSNR_range[i]) + if cfg is not None and (cfg_t_stop is None or r_range[i].item() >= cfg_t_stop) and (cfg_t_start is None or r_range[i].item() <= cfg_t_start): + cfg_val = cfg + if isinstance(cfg_val, (list, tuple)): + assert len(cfg_val) == 2, "cfg must be a float or a list/tuple of length 2" + cfg_val = cfg_val[0] * r_range[i].item() + cfg_val[1] * (1-r_range[i].item()) + + + + if i == timesteps -1 : + output, guide_lr_enc, guide_lr_dec = model(torch.cat([x_lr, x_lr], dim=0), noise_cond.repeat(2), reuire_f=True, **model_inputs) + guide_feas[i] = ([f.chunk(2)[0].repeat(2, 1, 1, 1) for f in guide_lr_enc], [f.chunk(2)[0].repeat(2, 1, 1, 1) for f in guide_lr_dec]) + else: + output, _, _ = model(torch.cat([x_lr, x_lr], dim=0), noise_cond.repeat(2), reuire_f=True, **model_inputs) + + pred, pred_unconditional = output.chunk(2) + + + pred_cfg = 
torch.lerp(pred_unconditional, pred, cfg_val) + if cfg_rho > 0: + std_pos, std_cfg = pred.std(), pred_cfg.std() + pred = cfg_rho * (pred_cfg * std_pos/(std_cfg+1e-9)) + pred_cfg * (1-cfg_rho) + else: + pred = pred_cfg + else: + pred = model(x_lr, noise_cond, **model_inputs) + x0_lr, epsilon_lr = self.undiffuse(x_lr, logSNR_range[i], pred) + x_lr = sampler(x_lr, x0_lr, epsilon_lr, logSNR_range[i], logSNR_range[i+1], **sampler_params) + + ###############################################hr HR sampling + for i in range(0, timesteps): + noise_cond = self.noise_cond(logSNR_range[i]) + if cfg is not None and (cfg_t_stop is None or r_range[i].item() >= cfg_t_stop) and (cfg_t_start is None or r_range[i].item() <= cfg_t_start): + cfg_val = cfg + if isinstance(cfg_val, (list, tuple)): + assert len(cfg_val) == 2, "cfg must be a float or a list/tuple of length 2" + cfg_val = cfg_val[0] * r_range[i].item() + cfg_val[1] * (1-r_range[i].item()) + + out_pred, t_emb = model(torch.cat([x, x], dim=0), noise_cond.repeat(2), \ + lr_guide=guide_feas[timesteps -1] if i <=19 else None , **model_inputs, require_t=True, guide_weight=1 - i/timesteps) + pred, pred_unconditional = out_pred.chunk(2) + pred_cfg = torch.lerp(pred_unconditional, pred, cfg_val) + if cfg_rho > 0: + std_pos, std_cfg = pred.std(), pred_cfg.std() + pred = cfg_rho * (pred_cfg * std_pos/(std_cfg+1e-9)) + pred_cfg * (1-cfg_rho) + else: + pred = pred_cfg + else: + pred = model(x, noise_cond, guide_lr=(guide_lr_enc, guide_lr_dec), **model_inputs) + x0, epsilon = self.undiffuse(x, logSNR_range[i], pred) + + x = sampler(x, x0, epsilon, logSNR_range[i], logSNR_range[i+1], **sampler_params) + altered_vars = yield (x0, x, pred, x_lr) + + + + # Update some running variables if the user wants + if altered_vars is not None: + cfg = altered_vars.get('cfg', cfg) + cfg_rho = altered_vars.get('cfg_rho', cfg_rho) + sampler = altered_vars.get('sampler', sampler) + model_inputs = altered_vars.get('model_inputs', model_inputs) + x = altered_vars.get('x', x) + x_init = altered_vars.get('x_init', x_init) + + + + diff --git a/gdf/loss_weights.py b/gdf/loss_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..d14ddaadeeb3f8de6c68aea4c364d9b852f2f15c --- /dev/null +++ b/gdf/loss_weights.py @@ -0,0 +1,101 @@ +import torch +import numpy as np + +# --- Loss Weighting +class BaseLossWeight(): + def weight(self, logSNR): + raise NotImplementedError("this method needs to be overridden") + + def __call__(self, logSNR, *args, shift=1, clamp_range=None, **kwargs): + clamp_range = [-1e9, 1e9] if clamp_range is None else clamp_range + if shift != 1: + logSNR = logSNR.clone() + 2 * np.log(shift) + return self.weight(logSNR, *args, **kwargs).clamp(*clamp_range) + +class ComposedLossWeight(BaseLossWeight): + def __init__(self, div, mul): + self.mul = [mul] if isinstance(mul, BaseLossWeight) else mul + self.div = [div] if isinstance(div, BaseLossWeight) else div + + def weight(self, logSNR): + prod, div = 1, 1 + for m in self.mul: + prod *= m.weight(logSNR) + for d in self.div: + div *= d.weight(logSNR) + return prod/div + +class ConstantLossWeight(BaseLossWeight): + def __init__(self, v=1): + self.v = v + + def weight(self, logSNR): + return torch.ones_like(logSNR) * self.v + +class SNRLossWeight(BaseLossWeight): + def weight(self, logSNR): + return logSNR.exp() + +class P2LossWeight(BaseLossWeight): + def __init__(self, k=1.0, gamma=1.0, s=1.0): + self.k, self.gamma, self.s = k, gamma, s + + def weight(self, logSNR): + return (self.k + (logSNR * 
self.s).exp()) ** -self.gamma + +class SNRPlusOneLossWeight(BaseLossWeight): + def weight(self, logSNR): + return logSNR.exp() + 1 + +class MinSNRLossWeight(BaseLossWeight): + def __init__(self, max_snr=5): + self.max_snr = max_snr + + def weight(self, logSNR): + return logSNR.exp().clamp(max=self.max_snr) + +class MinSNRPlusOneLossWeight(BaseLossWeight): + def __init__(self, max_snr=5): + self.max_snr = max_snr + + def weight(self, logSNR): + return (logSNR.exp() + 1).clamp(max=self.max_snr) + +class TruncatedSNRLossWeight(BaseLossWeight): + def __init__(self, min_snr=1): + self.min_snr = min_snr + + def weight(self, logSNR): + return logSNR.exp().clamp(min=self.min_snr) + +class SechLossWeight(BaseLossWeight): + def __init__(self, div=2): + self.div = div + + def weight(self, logSNR): + return 1/(logSNR/self.div).cosh() + +class DebiasedLossWeight(BaseLossWeight): + def weight(self, logSNR): + return 1/logSNR.exp().sqrt() + +class SigmoidLossWeight(BaseLossWeight): + def __init__(self, s=1): + self.s = s + + def weight(self, logSNR): + return (logSNR * self.s).sigmoid() + +class AdaptiveLossWeight(BaseLossWeight): + def __init__(self, logsnr_range=[-10, 10], buckets=300, weight_range=[1e-7, 1e7]): + self.bucket_ranges = torch.linspace(logsnr_range[0], logsnr_range[1], buckets-1) + self.bucket_losses = torch.ones(buckets) + self.weight_range = weight_range + + def weight(self, logSNR): + indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR) + return (1/self.bucket_losses.to(logSNR.device)[indices]).clamp(*self.weight_range) + + def update_buckets(self, logSNR, loss, beta=0.99): + indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR).cpu() + self.bucket_losses[indices] = self.bucket_losses[indices]*beta + loss.detach().cpu() * (1-beta) diff --git a/gdf/noise_conditions.py b/gdf/noise_conditions.py new file mode 100644 index 0000000000000000000000000000000000000000..dc2791f50a6f63eff8f9bed9b827f87517cc0be8 --- /dev/null +++ b/gdf/noise_conditions.py @@ -0,0 +1,102 @@ +import torch +import numpy as np + +class BaseNoiseCond(): + def __init__(self, *args, shift=1, clamp_range=None, **kwargs): + clamp_range = [-1e9, 1e9] if clamp_range is None else clamp_range + self.shift = shift + self.clamp_range = clamp_range + self.setup(*args, **kwargs) + + def setup(self, *args, **kwargs): + pass # this method is optional, override it if required + + def cond(self, logSNR): + raise NotImplementedError("this method needs to be overriden") + + def __call__(self, logSNR): + if self.shift != 1: + logSNR = logSNR.clone() + 2 * np.log(self.shift) + return self.cond(logSNR).clamp(*self.clamp_range) + +class CosineTNoiseCond(BaseNoiseCond): + def setup(self, s=0.008, clamp_range=[0, 1]): # [0.0001, 0.9999] + self.s = torch.tensor([s]) + self.clamp_range = clamp_range + self.min_var = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2 + + def cond(self, logSNR): + var = logSNR.sigmoid() + var = var.clamp(*self.clamp_range) + s, min_var = self.s.to(var.device), self.min_var.to(var.device) + t = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s + return t + +class EDMNoiseCond(BaseNoiseCond): + def cond(self, logSNR): + return -logSNR/8 + +class SigmoidNoiseCond(BaseNoiseCond): + def cond(self, logSNR): + return (-logSNR).sigmoid() + +class LogSNRNoiseCond(BaseNoiseCond): + def cond(self, logSNR): + return logSNR + +class EDMSigmaNoiseCond(BaseNoiseCond): + def setup(self, sigma_data=1): + self.sigma_data = sigma_data + + def cond(self, logSNR): + 
return torch.exp(-logSNR / 2) * self.sigma_data + +class RectifiedFlowsNoiseCond(BaseNoiseCond): + def cond(self, logSNR): + _a = logSNR.exp() - 1 + _a[_a == 0] = 1e-3 # Avoid division by zero + a = 1 + (2-(2**2 + 4*_a)**0.5) / (2*_a) + return a + +# Any NoiseCond that cannot be described easily as a continuous function of t +# It needs to define self.x and self.y in the setup() method +class PiecewiseLinearNoiseCond(BaseNoiseCond): + def setup(self): + self.x = None + self.y = None + + def piecewise_linear(self, y, xs, ys): + indices = (len(xs)-2) - torch.searchsorted(ys.flip(dims=(-1,))[:-2], y) + x_min, x_max = xs[indices], xs[indices+1] + y_min, y_max = ys[indices], ys[indices+1] + x = x_min + (x_max - x_min) * (y - y_min) / (y_max - y_min) + return x + + def cond(self, logSNR): + var = logSNR.sigmoid() + t = self.piecewise_linear(var, self.x.to(var.device), self.y.to(var.device)) # .mul(1000).round().clamp(min=0) + return t + +class StableDiffusionNoiseCond(PiecewiseLinearNoiseCond): + def setup(self, linear_range=[0.00085, 0.012], total_steps=1000): + self.total_steps = total_steps + linear_range_sqrt = [r**0.5 for r in linear_range] + self.x = torch.linspace(0, 1, total_steps+1) + + alphas = 1-(linear_range_sqrt[0]*(1-self.x) + linear_range_sqrt[1]*self.x)**2 + self.y = alphas.cumprod(dim=-1) + + def cond(self, logSNR): + return super().cond(logSNR).clamp(0, 1) + +class DiscreteNoiseCond(BaseNoiseCond): + def setup(self, noise_cond, steps=1000, continuous_range=[0, 1]): + self.noise_cond = noise_cond + self.steps = steps + self.continuous_range = continuous_range + + def cond(self, logSNR): + cond = self.noise_cond(logSNR) + cond = (cond-self.continuous_range[0]) / (self.continuous_range[1]-self.continuous_range[0]) + return cond.mul(self.steps).long() + \ No newline at end of file diff --git a/gdf/readme.md b/gdf/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..9a63691513c9da6804fba53e36acc8e0cd7f5d7f --- /dev/null +++ b/gdf/readme.md @@ -0,0 +1,86 @@ +# Generic Diffusion Framework (GDF) + +# Basic usage +GDF is a simple framework for working with diffusion models. It implements most common diffusion frameworks (DDPM / DDIM +, EDM, Rectified Flows, etc.) and makes it very easy to switch between them or combine different parts of different +frameworks + +Using GDF is very straighforward, first of all just define an instance of the GDF class: + +```python +from gdf import GDF +from gdf import CosineSchedule +from gdf import VPScaler, EpsilonTarget, CosineTNoiseCond, P2LossWeight + +gdf = GDF( + schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]), + input_scaler=VPScaler(), target=EpsilonTarget(), + noise_cond=CosineTNoiseCond(), + loss_weight=P2LossWeight(), +) +``` + +You need to define the following components: +* **Train Schedule**: This will return the logSNR schedule that will be used during training, some of the schedulers can be configured. A train schedule will then be called with a batch size and will randomly sample some values from the defined distribution. +* **Sample Schedule**: This is the schedule that will be used later on when sampling. It might be different from the training schedule. 
+* **Input Scaler**: How the input is scaled, e.g. Variance Preserving (VP) or LERP (rectified flows)
+* **Target**: What the model is trained to predict, usually epsilon, x0 or v
+* **Noise Conditioning**: You could directly pass the logSNR to your model, but usually a normalized value is used instead; for example, the EDM framework proposes `-logSNR/8`
+* **Loss Weight**: There are many proposed loss weighting strategies; here you define which one you'll use
+
+All of these classes are very simple logSNR-centric definitions; for example, the VPScaler is defined as just:
+```python
+class VPScaler():
+    def __call__(self, logSNR):
+        a_squared = logSNR.sigmoid()
+        a = a_squared.sqrt()
+        b = (1-a_squared).sqrt()
+        return a, b
+
+```
+
+So it's very easy to extend this framework with custom schedulers, scalers, targets, loss weights, etc.
+
+### Training
+
+When you define your training loop you can get all you need by just doing:
+```python
+shift, loss_shift = 1, 1 # these can be set to higher values, as the Simple Diffusion paper suggested for high resolution
+for inputs, extra_conditions in dataloader_iterator:
+    noised, noise, target, logSNR, noise_cond, loss_weight = gdf.diffuse(inputs, shift=shift, loss_shift=loss_shift)
+    pred = diffusion_model(noised, noise_cond, extra_conditions)
+
+    loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3])
+    loss_adjusted = (loss * loss_weight).mean()
+
+    loss_adjusted.backward()
+    optimizer.step()
+    optimizer.zero_grad(set_to_none=True)
+```
+
+And that's all: you have a diffusion model training loop where it's very easy to customize the different elements of the
+training through the GDF class.
+
+### Sampling
+
+The other important part is sampling. When you want to sample with this framework, you can just do the following:
+
+```python
+from gdf import DDPMSampler
+
+shift = 1
+sampling_configs = {
+    "timesteps": 30, "cfg": 7, "sampler": DDPMSampler(gdf), "shift": shift,
+    "schedule": CosineSchedule(clamp_range=[0.0001, 0.9999])
+}
+
+*_, (sampled, _, _) = gdf.sample(
+    diffusion_model, {"cond": extra_conditions}, latents.shape,
+    unconditional_inputs={"cond": torch.zeros_like(extra_conditions)},
+    device=device, **sampling_configs
+)
+```
+
+# Available modules
+
+TODO
diff --git a/gdf/samplers.py b/gdf/samplers.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6048c86a261d53d0440a3b2c1591a03d9978c4f
--- /dev/null
+++ b/gdf/samplers.py
@@ -0,0 +1,43 @@
+import torch
+
+class SimpleSampler():
+    def __init__(self, gdf):
+        self.gdf = gdf
+        self.current_step = -1
+
+    def __call__(self, *args, **kwargs):
+        self.current_step += 1
+        return self.step(*args, **kwargs)
+
+    def init_x(self, shape):
+        return torch.randn(*shape)
+
+    def step(self, x, x0, epsilon, logSNR, logSNR_prev):
+        raise NotImplementedError("You should override the 'step' method.")
+
+class DDIMSampler(SimpleSampler):
+    def step(self, x, x0, epsilon, logSNR, logSNR_prev, eta=0):
+        a, b = self.gdf.input_scaler(logSNR)
+        if len(a.shape) == 1:
+            a, b = a.view(-1, *[1]*(len(x0.shape)-1)), b.view(-1, *[1]*(len(x0.shape)-1))
+
+        a_prev, b_prev = self.gdf.input_scaler(logSNR_prev)
+        if len(a_prev.shape) == 1:
+            a_prev, b_prev = a_prev.view(-1, *[1]*(len(x0.shape)-1)), b_prev.view(-1, *[1]*(len(x0.shape)-1))
+
+        sigma_tau = eta * (b_prev**2 / b**2).sqrt() * (1 - a**2 / a_prev**2).sqrt() if eta > 0 else 0
+        # x = a_prev * x0 + (1 - a_prev**2 - sigma_tau ** 2).sqrt() * epsilon + sigma_tau * torch.randn_like(x0)
+        x = a_prev * x0 + (b_prev**2 - 
sigma_tau**2).sqrt() * epsilon + sigma_tau * torch.randn_like(x0) + return x + +class DDPMSampler(DDIMSampler): + def step(self, x, x0, epsilon, logSNR, logSNR_prev, eta=1): + return super().step(x, x0, epsilon, logSNR, logSNR_prev, eta) + +class LCMSampler(SimpleSampler): + def step(self, x, x0, epsilon, logSNR, logSNR_prev): + a_prev, b_prev = self.gdf.input_scaler(logSNR_prev) + if len(a_prev.shape) == 1: + a_prev, b_prev = a_prev.view(-1, *[1]*(len(x0.shape)-1)), b_prev.view(-1, *[1]*(len(x0.shape)-1)) + return x0 * a_prev + torch.randn_like(epsilon) * b_prev + \ No newline at end of file diff --git a/gdf/scalers.py b/gdf/scalers.py new file mode 100644 index 0000000000000000000000000000000000000000..b1adb8b0269667f3d006c7d7d17cbf2b7ef56ca9 --- /dev/null +++ b/gdf/scalers.py @@ -0,0 +1,42 @@ +import torch + +class BaseScaler(): + def __init__(self): + self.stretched_limits = None + + def setup_limits(self, schedule, input_scaler, stretch_max=True, stretch_min=True, shift=1): + min_logSNR = schedule(torch.ones(1), shift=shift) + max_logSNR = schedule(torch.zeros(1), shift=shift) + + min_a, max_b = [v.item() for v in input_scaler(min_logSNR)] if stretch_max else [0, 1] + max_a, min_b = [v.item() for v in input_scaler(max_logSNR)] if stretch_min else [1, 0] + self.stretched_limits = [min_a, max_a, min_b, max_b] + return self.stretched_limits + + def stretch_limits(self, a, b): + min_a, max_a, min_b, max_b = self.stretched_limits + return (a - min_a) / (max_a - min_a), (b - min_b) / (max_b - min_b) + + def scalers(self, logSNR): + raise NotImplementedError("this method needs to be overridden") + + def __call__(self, logSNR): + a, b = self.scalers(logSNR) + if self.stretched_limits is not None: + a, b = self.stretch_limits(a, b) + return a, b + +class VPScaler(BaseScaler): + def scalers(self, logSNR): + a_squared = logSNR.sigmoid() + a = a_squared.sqrt() + b = (1-a_squared).sqrt() + return a, b + +class LERPScaler(BaseScaler): + def scalers(self, logSNR): + _a = logSNR.exp() - 1 + _a[_a == 0] = 1e-3 # Avoid division by zero + a = 1 + (2-(2**2 + 4*_a)**0.5) / (2*_a) + b = 1-a + return a, b diff --git a/gdf/schedulers.py b/gdf/schedulers.py new file mode 100644 index 0000000000000000000000000000000000000000..caa6e174da1d766ea5828616bb8113865106b628 --- /dev/null +++ b/gdf/schedulers.py @@ -0,0 +1,200 @@ +import torch +import numpy as np + +class BaseSchedule(): + def __init__(self, *args, force_limits=True, discrete_steps=None, shift=1, **kwargs): + self.setup(*args, **kwargs) + self.limits = None + self.discrete_steps = discrete_steps + self.shift = shift + if force_limits: + self.reset_limits() + + def reset_limits(self, shift=1, disable=False): + try: + self.limits = None if disable else self(torch.tensor([1.0, 0.0]), shift=shift).tolist() # min, max + return self.limits + except Exception: + print("WARNING: this schedule doesn't support t and will be unbounded") + return None + + def setup(self, *args, **kwargs): + raise NotImplementedError("this method needs to be overriden") + + def schedule(self, *args, **kwargs): + raise NotImplementedError("this method needs to be overriden") + + def __call__(self, t, *args, shift=1, **kwargs): + if isinstance(t, torch.Tensor): + batch_size = None + if self.discrete_steps is not None: + if t.dtype != torch.long: + t = (t * (self.discrete_steps-1)).round().long() + t = t / (self.discrete_steps-1) + t = t.clamp(0, 1) + else: + batch_size = t + t = None + logSNR = self.schedule(t, batch_size, *args, **kwargs) + if shift*self.shift != 1: + logSNR += 2 * 
np.log(1/(shift*self.shift)) + if self.limits is not None: + logSNR = logSNR.clamp(*self.limits) + return logSNR + +class CosineSchedule(BaseSchedule): + def setup(self, s=0.008, clamp_range=[0.0001, 0.9999], norm_instead=False): + self.s = torch.tensor([s]) + self.clamp_range = clamp_range + self.norm_instead = norm_instead + self.min_var = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2 + + def schedule(self, t, batch_size): + if t is None: + t = (1-torch.rand(batch_size)).add(0.001).clamp(0.001, 1.0) + s, min_var = self.s.to(t.device), self.min_var.to(t.device) + var = torch.cos((s + t)/(1+s) * torch.pi * 0.5).clamp(0, 1) ** 2 / min_var + if self.norm_instead: + var = var * (self.clamp_range[1]-self.clamp_range[0]) + self.clamp_range[0] + else: + var = var.clamp(*self.clamp_range) + logSNR = (var/(1-var)).log() + return logSNR + +class CosineSchedule2(BaseSchedule): + def setup(self, logsnr_range=[-15, 15]): + self.t_min = np.arctan(np.exp(-0.5 * logsnr_range[1])) + self.t_max = np.arctan(np.exp(-0.5 * logsnr_range[0])) + + def schedule(self, t, batch_size): + if t is None: + t = 1-torch.rand(batch_size) + return -2 * (self.t_min + t*(self.t_max-self.t_min)).tan().log() + +class SqrtSchedule(BaseSchedule): + def setup(self, s=1e-4, clamp_range=[0.0001, 0.9999], norm_instead=False): + self.s = s + self.clamp_range = clamp_range + self.norm_instead = norm_instead + + def schedule(self, t, batch_size): + if t is None: + t = 1-torch.rand(batch_size) + var = 1 - (t + self.s)**0.5 + if self.norm_instead: + var = var * (self.clamp_range[1]-self.clamp_range[0]) + self.clamp_range[0] + else: + var = var.clamp(*self.clamp_range) + logSNR = (var/(1-var)).log() + return logSNR + +class RectifiedFlowsSchedule(BaseSchedule): + def setup(self, logsnr_range=[-15, 15]): + self.logsnr_range = logsnr_range + + def schedule(self, t, batch_size): + if t is None: + t = 1-torch.rand(batch_size) + logSNR = (((1-t)**2)/(t**2)).log() + logSNR = logSNR.clamp(*self.logsnr_range) + return logSNR + +class EDMSampleSchedule(BaseSchedule): + def setup(self, sigma_range=[0.002, 80], p=7): + self.sigma_range = sigma_range + self.p = p + + def schedule(self, t, batch_size): + if t is None: + t = 1-torch.rand(batch_size) + smin, smax, p = *self.sigma_range, self.p + sigma = (smax ** (1/p) + (1-t) * (smin ** (1/p) - smax ** (1/p))) ** p + logSNR = (1/sigma**2).log() + return logSNR + +class EDMTrainSchedule(BaseSchedule): + def setup(self, mu=-1.2, std=1.2): + self.mu = mu + self.std = std + + def schedule(self, t, batch_size): + if t is not None: + raise Exception("EDMTrainSchedule doesn't support passing timesteps: t") + logSNR = -2*(torch.randn(batch_size) * self.std - self.mu) + return logSNR + +class LinearSchedule(BaseSchedule): + def setup(self, logsnr_range=[-10, 10]): + self.logsnr_range = logsnr_range + + def schedule(self, t, batch_size): + if t is None: + t = 1-torch.rand(batch_size) + logSNR = t * (self.logsnr_range[0]-self.logsnr_range[1]) + self.logsnr_range[1] + return logSNR + +# Any schedule that cannot be described easily as a continuous function of t +# It needs to define self.x and self.y in the setup() method +class PiecewiseLinearSchedule(BaseSchedule): + def setup(self): + self.x = None + self.y = None + + def piecewise_linear(self, x, xs, ys): + indices = torch.searchsorted(xs[:-1], x) - 1 + x_min, x_max = xs[indices], xs[indices+1] + y_min, y_max = ys[indices], ys[indices+1] + var = y_min + (y_max - y_min) * (x - x_min) / (x_max - x_min) + return var + + def schedule(self, t, batch_size): + 
if t is None: + t = 1-torch.rand(batch_size) + var = self.piecewise_linear(t, self.x.to(t.device), self.y.to(t.device)) + logSNR = (var/(1-var)).log() + return logSNR + +class StableDiffusionSchedule(PiecewiseLinearSchedule): + def setup(self, linear_range=[0.00085, 0.012], total_steps=1000): + linear_range_sqrt = [r**0.5 for r in linear_range] + self.x = torch.linspace(0, 1, total_steps+1) + + alphas = 1-(linear_range_sqrt[0]*(1-self.x) + linear_range_sqrt[1]*self.x)**2 + self.y = alphas.cumprod(dim=-1) + +class AdaptiveTrainSchedule(BaseSchedule): + def setup(self, logsnr_range=[-10, 10], buckets=100, min_probs=0.0): + th = torch.linspace(logsnr_range[0], logsnr_range[1], buckets+1) + self.bucket_ranges = torch.tensor([(th[i], th[i+1]) for i in range(buckets)]) + self.bucket_probs = torch.ones(buckets) + self.min_probs = min_probs + + def schedule(self, t, batch_size): + if t is not None: + raise Exception("AdaptiveTrainSchedule doesn't support passing timesteps: t") + norm_probs = ((self.bucket_probs+self.min_probs) / (self.bucket_probs+self.min_probs).sum()) + buckets = torch.multinomial(norm_probs, batch_size, replacement=True) + ranges = self.bucket_ranges[buckets] + logSNR = torch.rand(batch_size) * (ranges[:, 1]-ranges[:, 0]) + ranges[:, 0] + return logSNR + + def update_buckets(self, logSNR, loss, beta=0.99): + range_mtx = self.bucket_ranges.unsqueeze(0).expand(logSNR.size(0), -1, -1).to(logSNR.device) + range_mask = (range_mtx[:, :, 0] <= logSNR[:, None]) * (range_mtx[:, :, 1] > logSNR[:, None]).float() + range_idx = range_mask.argmax(-1).cpu() + self.bucket_probs[range_idx] = self.bucket_probs[range_idx] * beta + loss.detach().cpu() * (1-beta) + +class InterpolatedSchedule(BaseSchedule): + def setup(self, scheduler1, scheduler2, shifts=[1.0, 1.0]): + self.scheduler1 = scheduler1 + self.scheduler2 = scheduler2 + self.shifts = shifts + + def schedule(self, t, batch_size): + if t is None: + t = 1-torch.rand(batch_size) + t = t.clamp(1e-7, 1-1e-7) # avoid infinities multiplied by 0 which cause nan + low_logSNR = self.scheduler1(t, shift=self.shifts[0]) + high_logSNR = self.scheduler2(t, shift=self.shifts[1]) + return low_logSNR * t + high_logSNR * (1-t) + diff --git a/gdf/targets.py b/gdf/targets.py new file mode 100644 index 0000000000000000000000000000000000000000..115062b6001f93082fa836e1f3742723e5972efe --- /dev/null +++ b/gdf/targets.py @@ -0,0 +1,46 @@ +class EpsilonTarget(): + def __call__(self, x0, epsilon, logSNR, a, b): + return epsilon + + def x0(self, noised, pred, logSNR, a, b): + return (noised - pred * b) / a + + def epsilon(self, noised, pred, logSNR, a, b): + return pred + def noise_givenx0_noised(self, x0, noised , logSNR, a, b): + return (noised - a * x0) / b + def xt(self, x0, noise, logSNR, a, b): + + return x0 * a + noise*b +class X0Target(): + def __call__(self, x0, epsilon, logSNR, a, b): + return x0 + + def x0(self, noised, pred, logSNR, a, b): + return pred + + def epsilon(self, noised, pred, logSNR, a, b): + return (noised - pred * a) / b + +class VTarget(): + def __call__(self, x0, epsilon, logSNR, a, b): + return a * epsilon - b * x0 + + def x0(self, noised, pred, logSNR, a, b): + squared_sum = a**2 + b**2 + return a/squared_sum * noised - b/squared_sum * pred + + def epsilon(self, noised, pred, logSNR, a, b): + squared_sum = a**2 + b**2 + return b/squared_sum * noised + a/squared_sum * pred + +class RectifiedFlowsTarget(): + def __call__(self, x0, epsilon, logSNR, a, b): + return epsilon - x0 + + def x0(self, noised, pred, logSNR, a, b): + return noised 
- pred * b + + def epsilon(self, noised, pred, logSNR, a, b): + return noised + pred * a + \ No newline at end of file diff --git a/inference/__init__.py b/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inference/test_controlnet.py b/inference/test_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..250578262d2a118ece8b5a706aba1cd8115c62f5 --- /dev/null +++ b/inference/test_controlnet.py @@ -0,0 +1,166 @@ +import os +import yaml +import torch +import torchvision +from tqdm import tqdm +import sys +sys.path.append(os.path.abspath('./')) + +from inference.utils import * +from core.utils import load_or_fail +from train import WurstCore_control_lrguide, WurstCoreB +from PIL import Image +from core.utils import load_or_fail +import math +import argparse +import time +import random +import numpy as np +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( '--height', type=int, default=3840, help='image height') + parser.add_argument('--width', type=int, default=2160, help='image width') + parser.add_argument('--control_weight', type=float, default=0.70, help='[ 0.3, 0.8]') + parser.add_argument('--dtype', type=str, default='bf16', help=' if bf16 does not work, change it to float32 ') + parser.add_argument('--seed', type=int, default=123, help='random seed') + parser.add_argument('--config_c', type=str, + default='configs/training/cfg_control_lr.yaml' ,help='config file for stage c, latent generation') + parser.add_argument('--config_b', type=str, + default='configs/inference/stage_b_1b.yaml' ,help='config file for stage b, latent decoding') + parser.add_argument( '--prompt', type=str, + default='A peaceful lake surrounded by mountain, white cloud in the sky, high quality,', help='text prompt') + parser.add_argument( '--num_image', type=int, default=4, help='how many images generated') + parser.add_argument( '--output_dir', type=str, default='figures/controlnet_results/', help='output directory for generated image') + parser.add_argument( '--stage_a_tiled', action='store_true', help='whther or nor to use tiled decoding for stage a to save memory') + parser.add_argument( '--pretrained_path', type=str, default='models/ultrapixel_t2i.safetensors', help='pretrained path of newly added paramter of UltraPixel') + parser.add_argument( '--canny_source_url', type=str, default="figures/California_000490.jpg", help='image used to extract canny edge map') + + args = parser.parse_args() + return args + + +if __name__ == "__main__": + + args = parse_args() + width = args.width + height = args.height + torch.manual_seed(args.seed) + random.seed(args.seed) + np.random.seed(args.seed) + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + dtype = torch.bfloat16 if args.dtype == 'bf16' else torch.float + + + # SETUP STAGE C + with open(args.config_c, "r", encoding="utf-8") as file: + loaded_config = yaml.safe_load(file) + core = WurstCore_control_lrguide(config_dict=loaded_config, device=device, training=False) + + # SETUP STAGE B + with open(args.config_b, "r", encoding="utf-8") as file: + config_file_b = yaml.safe_load(file) + + core_b = WurstCoreB(config_dict=config_file_b, device=device, training=False) + + extras = core.setup_extras_pre() + models = core.setup_models(extras) + models.generator.eval().requires_grad_(False) + print("CONTROLNET READY") + + extras_b = core_b.setup_extras_pre() + models_b = core_b.setup_models(extras_b, 
skip_clip=True) + models_b = WurstCoreB.Models( + **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model} + ) + models_b.generator.eval().requires_grad_(False) + print("STAGE B READY") + + batch_size = 1 + save_dir = args.output_dir + url = args.canny_source_url + images = resize_image(Image.open(url).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1) + batch = {'images': images} + + + + + + + cnet_multiplier = args.control_weight # 0.8 0.6 0.3 control strength + caption_list = [args.prompt] * args.num_image + height_lr, width_lr = get_target_lr_size(height / width, std_size=32) + stage_c_latent_shape_lr, stage_b_latent_shape_lr = calculate_latent_sizes(height_lr, width_lr, batch_size=batch_size) + stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size) + + + + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + + sdd = torch.load(args.pretrained_path, map_location='cpu') + collect_sd = {} + for k, v in sdd.items(): + collect_sd[k[7:]] = v + models.train_norm.load_state_dict(collect_sd, strict=True) + + + + + models.controlnet.load_state_dict(load_or_fail(core.config.controlnet_checkpoint_path), strict=True) + # Stage C Parameters + extras.sampling_configs['cfg'] = 1 + extras.sampling_configs['shift'] = 2 + extras.sampling_configs['timesteps'] = 20 + extras.sampling_configs['t_start'] = 1.0 + + # Stage B Parameters + extras_b.sampling_configs['cfg'] = 1.1 + extras_b.sampling_configs['shift'] = 1 + extras_b.sampling_configs['timesteps'] = 10 + extras_b.sampling_configs['t_start'] = 1.0 + + # PREPARE CONDITIONS + + + + + for out_cnt, caption in enumerate(caption_list): + with torch.no_grad(): + + batch['captions'] = [caption + ' high quality'] * batch_size + conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False) + unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) + + cnet, cnet_input = core.get_cnet(batch, models, extras) + cnet_uncond = cnet + conditions['cnet'] = [c.clone() * cnet_multiplier if c is not None else c for c in cnet] + unconditions['cnet'] = [c.clone() * cnet_multiplier if c is not None else c for c in cnet_uncond] + edge_images = show_images(cnet_input) + models.generator.cuda() + for idx, img in enumerate(edge_images): + img.save(os.path.join(save_dir, f"edge_{url.split('/')[-1]}")) + + + print('STAGE C GENERATION***************************') + with torch.cuda.amp.autocast(dtype=dtype): + sampled_c = generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device, conditions, unconditions) + models.generator.cpu() + torch.cuda.empty_cache() + + conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False) + unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True) + + conditions_b['effnet'] = sampled_c + unconditions_b['effnet'] = torch.zeros_like(sampled_c) + print('STAGE B + A DECODING***************************') + with torch.cuda.amp.autocast(dtype=dtype): + sampled = decode_b(conditions_b, unconditions_b, models_b, stage_b_latent_shape, extras_b, device, stage_a_tiled=args.stage_a_tiled) + + torch.cuda.empty_cache() + imgs = show_images(sampled) + + for idx, img in enumerate(imgs): + img.save(os.path.join(save_dir, args.prompt[:20]+'_' + str(out_cnt).zfill(5) + '.jpg')) + print('finished! 
Results at ', save_dir ) diff --git a/inference/test_personalized.py b/inference/test_personalized.py new file mode 100644 index 0000000000000000000000000000000000000000..840d52d0ef3b026e73c34f715b7b18ec3537e62a --- /dev/null +++ b/inference/test_personalized.py @@ -0,0 +1,180 @@ + +import os +import yaml +import torch +from tqdm import tqdm +import sys +sys.path.append(os.path.abspath('./')) +from inference.utils import * +from train import WurstCoreB +from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight +from train import WurstCore_personalized as WurstCoreC +import torch.nn.functional as F +import numpy as np +import random +import math +import argparse + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( '--height', type=int, default=3072, help='image height') + parser.add_argument('--width', type=int, default=4096, help='image width') + parser.add_argument('--dtype', type=str, default='bf16', help=' if bf16 does not work, change it to float32 ') + parser.add_argument('--seed', type=int, default=23, help='random seed') + parser.add_argument('--config_c', type=str, + default="configs/training/lora_personalization.yaml" ,help='config file for stage c, latent generation') + parser.add_argument('--config_b', type=str, + default='configs/inference/stage_b_1b.yaml' ,help='config file for stage b, latent decoding') + parser.add_argument( '--prompt', type=str, + default='A photo of cat [roubaobao] with sunglasses, Time Square in the background, high quality, detail rich, 8k', help='text prompt') + parser.add_argument( '--num_image', type=int, default=4, help='how many images generated') + parser.add_argument( '--output_dir', type=str, default='figures/personalized/', help='output directory for generated image') + parser.add_argument( '--stage_a_tiled', action='store_true', help='whther or nor to use tiled decoding for stage a to save memory') + parser.add_argument( '--pretrained_path_lora', type=str, default='models/lora_cat.safetensors',help='pretrained path of personalized lora parameter') + parser.add_argument( '--pretrained_path', type=str, default='models/ultrapixel_t2i.safetensors', help='pretrained path of newly added paramter of UltraPixel') + args = parser.parse_args() + return args + +if __name__ == "__main__": + args = parse_args() + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + torch.manual_seed(args.seed) + random.seed(args.seed) + np.random.seed(args.seed) + dtype = torch.bfloat16 if args.dtype == 'bf16' else torch.float + + + # SETUP STAGE C + with open(args.config_c, "r", encoding="utf-8") as file: + loaded_config = yaml.safe_load(file) + core = WurstCoreC(config_dict=loaded_config, device=device, training=False) + + # SETUP STAGE B + with open(args.config_b, "r", encoding="utf-8") as file: + config_file_b = yaml.safe_load(file) + core_b = WurstCoreB(config_dict=config_file_b, device=device, training=False) + + extras = core.setup_extras_pre() + models = core.setup_models(extras) + models.generator.eval().requires_grad_(False) + print("STAGE C READY") + + extras_b = core_b.setup_extras_pre() + models_b = core_b.setup_models(extras_b, skip_clip=True) + models_b = WurstCoreB.Models( + **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model} + ) + models_b.generator.bfloat16().eval().requires_grad_(False) + print("STAGE B READY") + + + batch_size = 1 + captions = [args.prompt] * args.num_image + height, width = args.height, args.width + save_dir = 
args.output_dir + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + + pretrained_pth = args.pretrained_path + sdd = torch.load(pretrained_pth, map_location='cpu') + collect_sd = {} + for k, v in sdd.items(): + collect_sd[k[7:]] = v + + models.train_norm.load_state_dict(collect_sd) + + + pretrained_pth_lora = args.pretrained_path_lora + sdd = torch.load(pretrained_pth_lora, map_location='cpu') + collect_sd = {} + for k, v in sdd.items(): + collect_sd[k[7:]] = v + + models.train_lora.load_state_dict(collect_sd) + + + models.generator.eval() + models.train_norm.eval() + + + height_lr, width_lr = get_target_lr_size(height / width, std_size=32) + stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size) + stage_c_latent_shape_lr, stage_b_latent_shape_lr = calculate_latent_sizes(height_lr, width_lr, batch_size=batch_size) + + # Stage C Parameters + + extras.sampling_configs['cfg'] = 4 + extras.sampling_configs['shift'] = 1 + extras.sampling_configs['timesteps'] = 20 + extras.sampling_configs['t_start'] = 1.0 + extras.sampling_configs['sampler'] = DDPMSampler(extras.gdf) + + + + # Stage B Parameters + + extras_b.sampling_configs['cfg'] = 1.1 + extras_b.sampling_configs['shift'] = 1 + extras_b.sampling_configs['timesteps'] = 10 + extras_b.sampling_configs['t_start'] = 1.0 + + + for cnt, caption in enumerate(captions): + + batch = {'captions': [caption] * batch_size} + conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False) + unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) + + conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False) + unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True) + + + + + for cnt, caption in enumerate(captions): + + + batch = {'captions': [caption] * batch_size} + conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False) + unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) + + conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False) + unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True) + + + with torch.no_grad(): + + + models.generator.cuda() + print('STAGE C GENERATION***************************') + with torch.cuda.amp.autocast(dtype=dtype): + sampled_c = generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device) + + + + models.generator.cpu() + torch.cuda.empty_cache() + + conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False) + unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True) + conditions_b['effnet'] = sampled_c + unconditions_b['effnet'] = torch.zeros_like(sampled_c) + print('STAGE B + A DECODING***************************') + + with torch.cuda.amp.autocast(dtype=dtype): + sampled = decode_b(conditions_b, unconditions_b, models_b, stage_b_latent_shape, extras_b, device, stage_a_tiled=args.stage_a_tiled) + + torch.cuda.empty_cache() + imgs = show_images(sampled) + for idx, img in enumerate(imgs): + print(os.path.join(save_dir, args.prompt[:20]+'_' + str(cnt).zfill(5) + '.jpg'), idx) + img.save(os.path.join(save_dir, 
args.prompt[:20]+'_' + str(cnt).zfill(5) + '.jpg')) + + + print('finished! Results at ', save_dir ) + + + diff --git a/inference/test_t2i.py b/inference/test_t2i.py new file mode 100644 index 0000000000000000000000000000000000000000..3478f95e4c706d88a8c73688ed4e990adc9ea8d4 --- /dev/null +++ b/inference/test_t2i.py @@ -0,0 +1,170 @@ + +import os +import yaml +import torch +from tqdm import tqdm +import sys +sys.path.append(os.path.abspath('./')) +from inference.utils import * +from core.utils import load_or_fail +from train import WurstCoreB +from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight +from train import WurstCore_t2i as WurstCoreC +import torch.nn.functional as F +from core.utils import load_or_fail +import numpy as np +import random +import math +import argparse +from einops import rearrange +import math +#inrfft_3b_strc_WurstCore +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( '--height', type=int, default=2560, help='image height') + parser.add_argument('--width', type=int, default=5120, help='image width') + parser.add_argument('--seed', type=int, default=123, help='random seed') + parser.add_argument('--dtype', type=str, default='bf16', help=' if bf16 does not work, change it to float32 ') + parser.add_argument('--config_c', type=str, + default='configs/training/t2i.yaml' ,help='config file for stage c, latent generation') + parser.add_argument('--config_b', type=str, + default='configs/inference/stage_b_1b.yaml' ,help='config file for stage b, latent decoding') + parser.add_argument( '--prompt', type=str, + default='A photo-realistic image of a west highland white terrier in the garden, high quality, detail rich, 8K', help='text prompt') + parser.add_argument( '--num_image', type=int, default=10, help='how many images generated') + parser.add_argument( '--output_dir', type=str, default='figures/output_results/', help='output directory for generated image') + parser.add_argument( '--stage_a_tiled', action='store_true', help='whther or nor to use tiled decoding for stage a to save memory') + parser.add_argument( '--pretrained_path', type=str, default='models/ultrapixel_t2i.safetensors', help='pretrained path of newly added paramter of UltraPixel') + args = parser.parse_args() + return args + + + +if __name__ == "__main__": + + args = parse_args() + print(args) + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + print(device) + torch.manual_seed(args.seed) + random.seed(args.seed) + np.random.seed(args.seed) + dtype = torch.bfloat16 if args.dtype == 'bf16' else torch.float + #gdf = gdf_refine( + # schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]), + # input_scaler=VPScaler(), target=EpsilonTarget(), + # noise_cond=CosineTNoiseCond(), + # loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(), + # ) + # SETUP STAGE C + config_file = args.config_c + with open(config_file, "r", encoding="utf-8") as file: + loaded_config = yaml.safe_load(file) + + core = WurstCoreC(config_dict=loaded_config, device=device, training=False) + + # SETUP STAGE B + config_file_b = args.config_b + with open(config_file_b, "r", encoding="utf-8") as file: + config_file_b = yaml.safe_load(file) + + core_b = WurstCoreB(config_dict=config_file_b, device=device, training=False) + + extras = core.setup_extras_pre() + models = core.setup_models(extras) + models.generator.eval().requires_grad_(False) + print("STAGE C READY") + + extras_b = core_b.setup_extras_pre() + models_b 
= core_b.setup_models(extras_b, skip_clip=True) + models_b = WurstCoreB.Models( + **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model} + ) + models_b.generator.bfloat16().eval().requires_grad_(False) + print("STAGE B READY") + + captions = [args.prompt] * args.num_image + + + height, width = args.height, args.width + save_dir = args.output_dir + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + pretrained_path = args.pretrained_path + sdd = torch.load(pretrained_path, map_location='cpu') + collect_sd = {} + for k, v in sdd.items(): + collect_sd[k[7:]] = v + + models.train_norm.load_state_dict(collect_sd) + + + models.generator.eval() + models.train_norm.eval() + + batch_size=1 + height_lr, width_lr = get_target_lr_size(height / width, std_size=32) + stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size) + stage_c_latent_shape_lr, stage_b_latent_shape_lr = calculate_latent_sizes(height_lr, width_lr, batch_size=batch_size) + + # Stage C Parameters + extras.sampling_configs['cfg'] = 4 + extras.sampling_configs['shift'] = 1 + extras.sampling_configs['timesteps'] = 20 + extras.sampling_configs['t_start'] = 1.0 + extras.sampling_configs['sampler'] = DDPMSampler(extras.gdf) + + + + # Stage B Parameters + extras_b.sampling_configs['cfg'] = 1.1 + extras_b.sampling_configs['shift'] = 1 + extras_b.sampling_configs['timesteps'] = 10 + extras_b.sampling_configs['t_start'] = 1.0 + + + + + for cnt, caption in enumerate(captions): + + + batch = {'captions': [caption] * batch_size} + conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False) + unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) + + conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False) + unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True) + + + with torch.no_grad(): + + + models.generator.cuda() + print('STAGE C GENERATION***************************') + with torch.cuda.amp.autocast(dtype=dtype): + sampled_c = generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device) + + + + models.generator.cpu() + torch.cuda.empty_cache() + + conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False) + unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True) + conditions_b['effnet'] = sampled_c + unconditions_b['effnet'] = torch.zeros_like(sampled_c) + print('STAGE B + A DECODING***************************') + + with torch.cuda.amp.autocast(dtype=dtype): + sampled = decode_b(conditions_b, unconditions_b, models_b, stage_b_latent_shape, extras_b, device, stage_a_tiled=args.stage_a_tiled) + + torch.cuda.empty_cache() + imgs = show_images(sampled) + for idx, img in enumerate(imgs): + print(os.path.join(save_dir, args.prompt[:20]+'_' + str(cnt).zfill(5) + '.jpg'), idx) + img.save(os.path.join(save_dir, args.prompt[:20]+'_' + str(cnt).zfill(5) + '.jpg')) + + + print('finished! 
Results at ', save_dir ) diff --git a/inference/utils.py b/inference/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ab5af277069ec7803d53ff8f5fa29bed41fde29b --- /dev/null +++ b/inference/utils.py @@ -0,0 +1,131 @@ +import PIL +import torch +import requests +import torchvision +from math import ceil +from io import BytesIO +import matplotlib.pyplot as plt +import torchvision.transforms.functional as F +import math +from tqdm import tqdm +def download_image(url): + return PIL.Image.open(requests.get(url, stream=True).raw).convert("RGB") + + +def resize_image(image, size=768): + tensor_image = F.to_tensor(image) + resized_image = F.resize(tensor_image, size, antialias=True) + return resized_image + + +def downscale_images(images, factor=3/4): + scaled_height, scaled_width = int(((images.size(-2)*factor)//32)*32), int(((images.size(-1)*factor)//32)*32) + scaled_image = torchvision.transforms.functional.resize(images, (scaled_height, scaled_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST) + return scaled_image + + + +def calculate_latent_sizes(height=1024, width=1024, batch_size=4, compression_factor_b=42.67, compression_factor_a=4.0): + resolution_multiple = 42.67 + latent_height = ceil(height / compression_factor_b) + latent_width = ceil(width / compression_factor_b) + stage_c_latent_shape = (batch_size, 16, latent_height, latent_width) + + latent_height = ceil(height / compression_factor_a) + latent_width = ceil(width / compression_factor_a) + stage_b_latent_shape = (batch_size, 4, latent_height, latent_width) + + return stage_c_latent_shape, stage_b_latent_shape + + +def get_views(H, W, window_size=64, stride=16): + ''' + - H, W: height and width of the latent + ''' + num_blocks_height = (H - window_size) // stride + 1 + num_blocks_width = (W - window_size) // stride + 1 + total_num_blocks = int(num_blocks_height * num_blocks_width) + views = [] + for i in range(total_num_blocks): + h_start = int((i // num_blocks_width) * stride) + h_end = h_start + window_size + w_start = int((i % num_blocks_width) * stride) + w_end = w_start + window_size + views.append((h_start, h_end, w_start, w_end)) + return views + + + +def show_images(images, rows=None, cols=None, **kwargs): + if images.size(1) == 1: + images = images.repeat(1, 3, 1, 1) + elif images.size(1) > 3: + images = images[:, :3] + + if rows is None: + rows = 1 + if cols is None: + cols = images.size(0) // rows + + _, _, h, w = images.shape + + imgs = [] + for i, img in enumerate(images): + imgs.append( torchvision.transforms.functional.to_pil_image(img.clamp(0, 1))) + + return imgs + + + +def decode_b(conditions_b, unconditions_b, models_b, bshape, extras_b, device, \ + stage_a_tiled=False, num_instance=4, patch_size=256, stride=24): + + + sampling_b = extras_b.gdf.sample( + models_b.generator.half(), conditions_b, bshape, + unconditions_b, device=device, + **extras_b.sampling_configs, + ) + models_b.generator.cuda() + for (sampled_b, _, _) in tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']): + sampled_b = sampled_b + models_b.generator.cpu() + torch.cuda.empty_cache() + if stage_a_tiled: + with torch.cuda.amp.autocast(dtype=torch.float16): + padding = (stride*2, stride*2, stride*2, stride*2) + sampled_b = torch.nn.functional.pad(sampled_b, padding, mode='reflect') + count = torch.zeros((sampled_b.shape[0], 3, sampled_b.shape[-2]*4, sampled_b.shape[-1]*4), requires_grad=False, device=sampled_b.device) + sampled = torch.zeros((sampled_b.shape[0], 3, sampled_b.shape[-2]*4, 
sampled_b.shape[-1]*4), requires_grad=False, device=sampled_b.device) + views = get_views(sampled_b.shape[-2], sampled_b.shape[-1], window_size=patch_size, stride=stride) + + for view_idx, (h_start, h_end, w_start, w_end) in enumerate(tqdm(views, total=len(views))): + + sampled[:, :, h_start*4:h_end*4, w_start*4:w_end*4] += models_b.stage_a.decode(sampled_b[:, :, h_start:h_end, w_start:w_end]).float() + count[:, :, h_start*4:h_end*4, w_start*4:w_end*4] += 1 + sampled /= count + sampled = sampled[:, :, stride*4*2:-stride*4*2, stride*4*2:-stride*4*2] + else: + + sampled = models_b.stage_a.decode(sampled_b, tiled_decoding=stage_a_tiled) + + return sampled.float() + + +def generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device, conditions=None, unconditions=None): + if conditions is None: + conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False) + if unconditions is None: + unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) + sampling_c = extras.gdf.sample( + models.generator, conditions, stage_c_latent_shape, stage_c_latent_shape_lr, + unconditions, device=device, **extras.sampling_configs, + ) + for idx, (sampled_c, sampled_c_curr, _, _) in enumerate(tqdm(sampling_c, total=extras.sampling_configs['timesteps'])): + sampled_c = sampled_c + return sampled_c + +def get_target_lr_size(ratio, std_size=24): + w, h = int(std_size / math.sqrt(ratio)), int(std_size * math.sqrt(ratio)) + return (h * 32 , w *32 ) + diff --git a/models/models_checklist.txt b/models/models_checklist.txt new file mode 100644 index 0000000000000000000000000000000000000000..2fdec27a72db473c51893abc64826514b1d9d065 --- /dev/null +++ b/models/models_checklist.txt @@ -0,0 +1,7 @@ +https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_a.safetensors +https://huggingface.co/stabilityai/StableWurst/resolve/main/previewer.safetensors +https://huggingface.co/stabilityai/StableWurst/resolve/main/effnet_encoder.safetensors +https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b_lite_bf16.safetensors +https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c_bf16.safetensors +https://huggingface.co/roubaofeipi/UltraPixel/blob/main/ultrapixel_t2i.safetensors +https://huggingface.co/roubaofeipi/UltraPixel/blob/main/lora_cat.safetensors (only required for personalization) \ No newline at end of file diff --git a/modules/__init__.py b/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a6fcf5aa2a39061c3f4f82dde6ff063411223cb3 --- /dev/null +++ b/modules/__init__.py @@ -0,0 +1,6 @@ +from .effnet import EfficientNetEncoder +from .stage_c import StageC +from .stage_c import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock +from .previewer import Previewer +from .controlnet import ControlNet, ControlNetDeliverer +from . 
import controlnet as controlnet_filters diff --git a/modules/cnet_modules/face_id/__pycache__/arcface.cpython-310.pyc b/modules/cnet_modules/face_id/__pycache__/arcface.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c74bb92cb0db0876acda8aa3d102141526fd428 Binary files /dev/null and b/modules/cnet_modules/face_id/__pycache__/arcface.cpython-310.pyc differ diff --git a/modules/cnet_modules/face_id/arcface.py b/modules/cnet_modules/face_id/arcface.py new file mode 100644 index 0000000000000000000000000000000000000000..64e918bb90437f6f193a7ec384bea1fcd73c7abb --- /dev/null +++ b/modules/cnet_modules/face_id/arcface.py @@ -0,0 +1,276 @@ +import numpy as np +import onnx, onnx2torch, cv2 +import torch +from insightface.utils import face_align + + +class ArcFaceRecognizer: + def __init__(self, model_file=None, device='cpu', dtype=torch.float32): + assert model_file is not None + self.model_file = model_file + + self.device = device + self.dtype = dtype + self.model = onnx2torch.convert(onnx.load(model_file)).to(device=device, dtype=dtype) + for param in self.model.parameters(): + param.requires_grad = False + self.model.eval() + + self.input_mean = 127.5 + self.input_std = 127.5 + self.input_size = (112, 112) + self.input_shape = ['None', 3, 112, 112] + + def get(self, img, face): + aimg = face_align.norm_crop(img, landmark=face.kps, image_size=self.input_size[0]) + face.embedding = self.get_feat(aimg).flatten() + return face.embedding + + def compute_sim(self, feat1, feat2): + from numpy.linalg import norm + feat1 = feat1.ravel() + feat2 = feat2.ravel() + sim = np.dot(feat1, feat2) / (norm(feat1) * norm(feat2)) + return sim + + def get_feat(self, imgs): + if not isinstance(imgs, list): + imgs = [imgs] + input_size = self.input_size + + blob = cv2.dnn.blobFromImages(imgs, 1.0 / self.input_std, input_size, + (self.input_mean, self.input_mean, self.input_mean), swapRB=True) + + blob_torch = torch.tensor(blob).to(device=self.device, dtype=self.dtype) + net_out = self.model(blob_torch) + return net_out[0].float().cpu() + + +def distance2bbox(points, distance, max_shape=None): + """Decode distance prediction to bounding box. + + Args: + points (Tensor): Shape (n, 2), [x, y]. + distance (Tensor): Distance from the given point to 4 + boundaries (left, top, right, bottom). + max_shape (tuple): Shape of the image. + + Returns: + Tensor: Decoded bboxes. + """ + x1 = points[:, 0] - distance[:, 0] + y1 = points[:, 1] - distance[:, 1] + x2 = points[:, 0] + distance[:, 2] + y2 = points[:, 1] + distance[:, 3] + if max_shape is not None: + x1 = x1.clamp(min=0, max=max_shape[1]) + y1 = y1.clamp(min=0, max=max_shape[0]) + x2 = x2.clamp(min=0, max=max_shape[1]) + y2 = y2.clamp(min=0, max=max_shape[0]) + return np.stack([x1, y1, x2, y2], axis=-1) + + +def distance2kps(points, distance, max_shape=None): + """Decode distance prediction to bounding box. + + Args: + points (Tensor): Shape (n, 2), [x, y]. + distance (Tensor): Distance from the given point to 4 + boundaries (left, top, right, bottom). + max_shape (tuple): Shape of the image. + + Returns: + Tensor: Decoded bboxes. 
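+
+    Note:
+        Unlike distance2bbox above, this decodes per-point keypoint offsets; the
+        result has shape (n, 2 * num_keypoints) with x and y values interleaved.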
+ """ + preds = [] + for i in range(0, distance.shape[1], 2): + px = points[:, i % 2] + distance[:, i] + py = points[:, i % 2 + 1] + distance[:, i + 1] + if max_shape is not None: + px = px.clamp(min=0, max=max_shape[1]) + py = py.clamp(min=0, max=max_shape[0]) + preds.append(px) + preds.append(py) + return np.stack(preds, axis=-1) + + +class FaceDetector: + def __init__(self, model_file=None, dtype=torch.float32, device='cuda'): + self.model_file = model_file + self.taskname = 'detection' + self.center_cache = {} + self.nms_thresh = 0.4 + self.det_thresh = 0.5 + + self.device = device + self.dtype = dtype + self.model = onnx2torch.convert(onnx.load(model_file)).to(device=device, dtype=dtype) + for param in self.model.parameters(): + param.requires_grad = False + self.model.eval() + + input_shape = (320, 320) + self.input_size = input_shape + self.input_shape = input_shape + + self.input_mean = 127.5 + self.input_std = 128.0 + self._anchor_ratio = 1.0 + self._num_anchors = 1 + self.fmc = 3 + self._feat_stride_fpn = [8, 16, 32] + self._num_anchors = 2 + self.use_kps = True + + self.det_thresh = 0.5 + self.nms_thresh = 0.4 + + def forward(self, img, threshold): + scores_list = [] + bboxes_list = [] + kpss_list = [] + input_size = tuple(img.shape[0:2][::-1]) + blob = cv2.dnn.blobFromImage(img, 1.0 / self.input_std, input_size, + (self.input_mean, self.input_mean, self.input_mean), swapRB=True) + blob_torch = torch.tensor(blob).to(device=self.device, dtype=self.dtype) + net_outs_torch = self.model(blob_torch) + # print(list(map(lambda x: x.shape, net_outs_torch))) + net_outs = list(map(lambda x: x.float().cpu().numpy(), net_outs_torch)) + + input_height = blob.shape[2] + input_width = blob.shape[3] + fmc = self.fmc + for idx, stride in enumerate(self._feat_stride_fpn): + scores = net_outs[idx] + bbox_preds = net_outs[idx + fmc] + bbox_preds = bbox_preds * stride + if self.use_kps: + kps_preds = net_outs[idx + fmc * 2] * stride + height = input_height // stride + width = input_width // stride + K = height * width + key = (height, width, stride) + if key in self.center_cache: + anchor_centers = self.center_cache[key] + else: + # solution-1, c style: + # anchor_centers = np.zeros( (height, width, 2), dtype=np.float32 ) + # for i in range(height): + # anchor_centers[i, :, 1] = i + # for i in range(width): + # anchor_centers[:, i, 0] = i + + # solution-2: + # ax = np.arange(width, dtype=np.float32) + # ay = np.arange(height, dtype=np.float32) + # xv, yv = np.meshgrid(np.arange(width), np.arange(height)) + # anchor_centers = np.stack([xv, yv], axis=-1).astype(np.float32) + + # solution-3: + anchor_centers = np.stack(np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32) + # print(anchor_centers.shape) + + anchor_centers = (anchor_centers * stride).reshape((-1, 2)) + if self._num_anchors > 1: + anchor_centers = np.stack([anchor_centers] * self._num_anchors, axis=1).reshape((-1, 2)) + if len(self.center_cache) < 100: + self.center_cache[key] = anchor_centers + + pos_inds = np.where(scores >= threshold)[0] + bboxes = distance2bbox(anchor_centers, bbox_preds) + pos_scores = scores[pos_inds] + pos_bboxes = bboxes[pos_inds] + scores_list.append(pos_scores) + bboxes_list.append(pos_bboxes) + if self.use_kps: + kpss = distance2kps(anchor_centers, kps_preds) + # kpss = kps_preds + kpss = kpss.reshape((kpss.shape[0], -1, 2)) + pos_kpss = kpss[pos_inds] + kpss_list.append(pos_kpss) + return scores_list, bboxes_list, kpss_list + + def detect(self, img, input_size=None, max_num=0, metric='default'): + 
assert input_size is not None or self.input_size is not None + input_size = self.input_size if input_size is None else input_size + + im_ratio = float(img.shape[0]) / img.shape[1] + model_ratio = float(input_size[1]) / input_size[0] + if im_ratio > model_ratio: + new_height = input_size[1] + new_width = int(new_height / im_ratio) + else: + new_width = input_size[0] + new_height = int(new_width * im_ratio) + det_scale = float(new_height) / img.shape[0] + resized_img = cv2.resize(img, (new_width, new_height)) + det_img = np.zeros((input_size[1], input_size[0], 3), dtype=np.uint8) + det_img[:new_height, :new_width, :] = resized_img + + scores_list, bboxes_list, kpss_list = self.forward(det_img, self.det_thresh) + + scores = np.vstack(scores_list) + scores_ravel = scores.ravel() + order = scores_ravel.argsort()[::-1] + bboxes = np.vstack(bboxes_list) / det_scale + if self.use_kps: + kpss = np.vstack(kpss_list) / det_scale + pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False) + pre_det = pre_det[order, :] + keep = self.nms(pre_det) + det = pre_det[keep, :] + if self.use_kps: + kpss = kpss[order, :, :] + kpss = kpss[keep, :, :] + else: + kpss = None + if max_num > 0 and det.shape[0] > max_num: + area = (det[:, 2] - det[:, 0]) * (det[:, 3] - + det[:, 1]) + img_center = img.shape[0] // 2, img.shape[1] // 2 + offsets = np.vstack([ + (det[:, 0] + det[:, 2]) / 2 - img_center[1], + (det[:, 1] + det[:, 3]) / 2 - img_center[0] + ]) + offset_dist_squared = np.sum(np.power(offsets, 2.0), 0) + if metric == 'max': + values = area + else: + values = area - offset_dist_squared * 2.0 # some extra weight on the centering + bindex = np.argsort( + values)[::-1] # some extra weight on the centering + bindex = bindex[0:max_num] + det = det[bindex, :] + if kpss is not None: + kpss = kpss[bindex, :] + return det, kpss + + def nms(self, dets): + thresh = self.nms_thresh + x1 = dets[:, 0] + y1 = dets[:, 1] + x2 = dets[:, 2] + y2 = dets[:, 3] + scores = dets[:, 4] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= thresh)[0] + order = order[inds + 1] + + return keep diff --git a/modules/cnet_modules/inpainting/__pycache__/saliency_model.cpython-310.pyc b/modules/cnet_modules/inpainting/__pycache__/saliency_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8200104d6d66a1084685c76373c38d752ed9c3d4 Binary files /dev/null and b/modules/cnet_modules/inpainting/__pycache__/saliency_model.cpython-310.pyc differ diff --git a/modules/cnet_modules/inpainting/__pycache__/saliency_model.cpython-39.pyc b/modules/cnet_modules/inpainting/__pycache__/saliency_model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca432e5c5eed7ba17fc6cafb06a3ebe16002f67e Binary files /dev/null and b/modules/cnet_modules/inpainting/__pycache__/saliency_model.cpython-39.pyc differ diff --git a/modules/cnet_modules/inpainting/saliency_model.pt b/modules/cnet_modules/inpainting/saliency_model.pt new file mode 100644 index 0000000000000000000000000000000000000000..e1b02cc60b2999a8f9ff90557182e3dafab63db7 --- /dev/null +++ 
b/modules/cnet_modules/inpainting/saliency_model.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:225a602e1f2a5d159424be011a63b27d83b56343a4379a90710eca9a26bab920 +size 451123 diff --git a/modules/cnet_modules/inpainting/saliency_model.py b/modules/cnet_modules/inpainting/saliency_model.py new file mode 100644 index 0000000000000000000000000000000000000000..82355a02baead47f50fe643e57b81f8caca78f79 --- /dev/null +++ b/modules/cnet_modules/inpainting/saliency_model.py @@ -0,0 +1,81 @@ +import torch +import torchvision +from torch import nn +from PIL import Image +import numpy as np +import os + + +# MICRO RESNET +class ResBlock(nn.Module): + def __init__(self, channels): + super(ResBlock, self).__init__() + + self.resblock = nn.Sequential( + nn.ReflectionPad2d(1), + nn.Conv2d(channels, channels, kernel_size=3), + nn.InstanceNorm2d(channels, affine=True), + nn.ReLU(), + nn.ReflectionPad2d(1), + nn.Conv2d(channels, channels, kernel_size=3), + nn.InstanceNorm2d(channels, affine=True), + ) + + def forward(self, x): + out = self.resblock(x) + return out + x + + +class Upsample2d(nn.Module): + def __init__(self, scale_factor): + super(Upsample2d, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + + def forward(self, x): + x = self.interp(x, scale_factor=self.scale_factor, mode='nearest') + return x + + +class MicroResNet(nn.Module): + def __init__(self): + super(MicroResNet, self).__init__() + + self.downsampler = nn.Sequential( + nn.ReflectionPad2d(4), + nn.Conv2d(3, 8, kernel_size=9, stride=4), + nn.InstanceNorm2d(8, affine=True), + nn.ReLU(), + nn.ReflectionPad2d(1), + nn.Conv2d(8, 16, kernel_size=3, stride=2), + nn.InstanceNorm2d(16, affine=True), + nn.ReLU(), + nn.ReflectionPad2d(1), + nn.Conv2d(16, 32, kernel_size=3, stride=2), + nn.InstanceNorm2d(32, affine=True), + nn.ReLU(), + ) + + self.residual = nn.Sequential( + ResBlock(32), + nn.Conv2d(32, 64, kernel_size=1, bias=False, groups=32), + ResBlock(64), + ) + + self.segmentator = nn.Sequential( + nn.ReflectionPad2d(1), + nn.Conv2d(64, 16, kernel_size=3), + nn.InstanceNorm2d(16, affine=True), + nn.ReLU(), + Upsample2d(scale_factor=2), + nn.ReflectionPad2d(4), + nn.Conv2d(16, 1, kernel_size=9), + nn.Sigmoid() + ) + + def forward(self, x): + out = self.downsampler(x) + out = self.residual(out) + out = self.segmentator(out) + return out diff --git a/modules/cnet_modules/pidinet/__init__.py b/modules/cnet_modules/pidinet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a2b4625bf915cc6c4053b7d7861a22ff371bc641 --- /dev/null +++ b/modules/cnet_modules/pidinet/__init__.py @@ -0,0 +1,37 @@ +# Pidinet +# https://github.com/hellozhuo/pidinet + +import os +import torch +import numpy as np +from einops import rearrange +from .model import pidinet +from .util import annotator_ckpts_path, safe_step + + +class PidiNetDetector: + def __init__(self, device): + remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/table5_pidinet.pth" + modelpath = os.path.join(annotator_ckpts_path, "table5_pidinet.pth") + if not os.path.exists(modelpath): + from basicsr.utils.download_util import load_file_from_url + load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path) + self.netNetwork = pidinet() + self.netNetwork.load_state_dict( + {k.replace('module.', ''): v for k, v in torch.load(modelpath)['state_dict'].items()}) + self.netNetwork.to(device).eval().requires_grad_(False) + + def __call__(self, input_image): # , safe=False): 
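+        # Returns PiDiNet's fused edge map: the last of its multi-scale outputs,
+        # sigmoid-activated, so values lie in [0, 1]. The commented-out block
+        # below is the original array-based preprocessing, kept for reference.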
+ return self.netNetwork(input_image)[-1] + # assert input_image.ndim == 3 + # input_image = input_image[:, :, ::-1].copy() + # with torch.no_grad(): + # image_pidi = torch.from_numpy(input_image).float().cuda() + # image_pidi = image_pidi / 255.0 + # image_pidi = rearrange(image_pidi, 'h w c -> 1 c h w') + # edge = self.netNetwork(image_pidi)[-1] + + # if safe: + # edge = safe_step(edge) + # edge = (edge * 255.0).clip(0, 255).astype(np.uint8) + # return edge[0][0] diff --git a/modules/cnet_modules/pidinet/__pycache__/__init__.cpython-310.pyc b/modules/cnet_modules/pidinet/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07fca0abb9c90b7b40746b4044c4000ae69e00c7 Binary files /dev/null and b/modules/cnet_modules/pidinet/__pycache__/__init__.cpython-310.pyc differ diff --git a/modules/cnet_modules/pidinet/__pycache__/__init__.cpython-39.pyc b/modules/cnet_modules/pidinet/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a060aa2baa87a3670aa0bf8276e2f34bafe9451 Binary files /dev/null and b/modules/cnet_modules/pidinet/__pycache__/__init__.cpython-39.pyc differ diff --git a/modules/cnet_modules/pidinet/__pycache__/model.cpython-310.pyc b/modules/cnet_modules/pidinet/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2243c853d18e2a404ced3eb4ac6a95a7a9ee6874 Binary files /dev/null and b/modules/cnet_modules/pidinet/__pycache__/model.cpython-310.pyc differ diff --git a/modules/cnet_modules/pidinet/__pycache__/model.cpython-39.pyc b/modules/cnet_modules/pidinet/__pycache__/model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f70342fc64759bc7459abf0f7986ee3b7fd2126 Binary files /dev/null and b/modules/cnet_modules/pidinet/__pycache__/model.cpython-39.pyc differ diff --git a/modules/cnet_modules/pidinet/__pycache__/util.cpython-310.pyc b/modules/cnet_modules/pidinet/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2e7ab031924860f1262f4d44bf2eaf57ca78edd Binary files /dev/null and b/modules/cnet_modules/pidinet/__pycache__/util.cpython-310.pyc differ diff --git a/modules/cnet_modules/pidinet/__pycache__/util.cpython-39.pyc b/modules/cnet_modules/pidinet/__pycache__/util.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4da8564d03f99caa7a45d9ccb1358cb282cd2711 Binary files /dev/null and b/modules/cnet_modules/pidinet/__pycache__/util.cpython-39.pyc differ diff --git a/modules/cnet_modules/pidinet/ckpts/table5_pidinet.pth b/modules/cnet_modules/pidinet/ckpts/table5_pidinet.pth new file mode 100644 index 0000000000000000000000000000000000000000..1ceba1de87e7bb3c81961b80acbb3a106ca249c0 --- /dev/null +++ b/modules/cnet_modules/pidinet/ckpts/table5_pidinet.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:80860ac267258b5f27486e0ef152a211d0b08120f62aeb185a050acc30da486c +size 2871148 diff --git a/modules/cnet_modules/pidinet/model.py b/modules/cnet_modules/pidinet/model.py new file mode 100644 index 0000000000000000000000000000000000000000..26644c6f6174c3b5407bd10c914045758cbadefe --- /dev/null +++ b/modules/cnet_modules/pidinet/model.py @@ -0,0 +1,654 @@ +""" +Author: Zhuo Su, Wenzhe Liu +Date: Feb 18, 2021 +""" + +import math + +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +nets = { + 'baseline': { + 'layer0': 'cv', + 'layer1': 'cv', + 
'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'cv', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'cv', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'cv', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'c-v15': { + 'layer0': 'cd', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'cv', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'cv', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'cv', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'a-v15': { + 'layer0': 'ad', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'cv', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'cv', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'cv', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'r-v15': { + 'layer0': 'rd', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'cv', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'cv', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'cv', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'cvvv4': { + 'layer0': 'cd', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'cd', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'cd', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'cd', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'avvv4': { + 'layer0': 'ad', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'ad', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'ad', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'ad', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'rvvv4': { + 'layer0': 'rd', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'rd', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'rd', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'rd', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'cccv4': { + 'layer0': 'cd', + 'layer1': 'cd', + 'layer2': 'cd', + 'layer3': 'cv', + 'layer4': 'cd', + 'layer5': 'cd', + 'layer6': 'cd', + 'layer7': 'cv', + 'layer8': 'cd', + 'layer9': 'cd', + 'layer10': 'cd', + 'layer11': 'cv', + 'layer12': 'cd', + 'layer13': 'cd', + 'layer14': 'cd', + 'layer15': 'cv', + }, + 'aaav4': { + 'layer0': 'ad', + 'layer1': 'ad', + 'layer2': 'ad', + 'layer3': 'cv', + 'layer4': 'ad', + 'layer5': 'ad', + 'layer6': 'ad', + 'layer7': 'cv', + 'layer8': 'ad', + 'layer9': 'ad', + 'layer10': 'ad', + 'layer11': 'cv', + 'layer12': 'ad', + 'layer13': 'ad', + 'layer14': 'ad', + 'layer15': 'cv', + }, + 'rrrv4': { + 'layer0': 'rd', + 'layer1': 'rd', + 'layer2': 'rd', + 'layer3': 'cv', + 'layer4': 'rd', + 'layer5': 'rd', + 'layer6': 'rd', + 'layer7': 'cv', + 'layer8': 'rd', + 'layer9': 'rd', + 'layer10': 'rd', + 'layer11': 'cv', + 'layer12': 'rd', + 'layer13': 'rd', + 'layer14': 'rd', + 'layer15': 'cv', + }, + 'c16': { + 'layer0': 'cd', + 'layer1': 'cd', + 'layer2': 'cd', + 'layer3': 'cd', + 'layer4': 'cd', + 'layer5': 'cd', + 'layer6': 'cd', + 'layer7': 'cd', + 'layer8': 'cd', + 'layer9': 'cd', + 'layer10': 'cd', + 'layer11': 'cd', + 'layer12': 'cd', + 'layer13': 'cd', + 'layer14': 'cd', + 'layer15': 'cd', + }, + 'a16': { + 'layer0': 'ad', + 'layer1': 'ad', + 'layer2': 'ad', + 'layer3': 'ad', + 'layer4': 'ad', + 'layer5': 'ad', + 'layer6': 'ad', + 'layer7': 'ad', + 'layer8': 'ad', + 
'layer9': 'ad', + 'layer10': 'ad', + 'layer11': 'ad', + 'layer12': 'ad', + 'layer13': 'ad', + 'layer14': 'ad', + 'layer15': 'ad', + }, + 'r16': { + 'layer0': 'rd', + 'layer1': 'rd', + 'layer2': 'rd', + 'layer3': 'rd', + 'layer4': 'rd', + 'layer5': 'rd', + 'layer6': 'rd', + 'layer7': 'rd', + 'layer8': 'rd', + 'layer9': 'rd', + 'layer10': 'rd', + 'layer11': 'rd', + 'layer12': 'rd', + 'layer13': 'rd', + 'layer14': 'rd', + 'layer15': 'rd', + }, + 'carv4': { + 'layer0': 'cd', + 'layer1': 'ad', + 'layer2': 'rd', + 'layer3': 'cv', + 'layer4': 'cd', + 'layer5': 'ad', + 'layer6': 'rd', + 'layer7': 'cv', + 'layer8': 'cd', + 'layer9': 'ad', + 'layer10': 'rd', + 'layer11': 'cv', + 'layer12': 'cd', + 'layer13': 'ad', + 'layer14': 'rd', + 'layer15': 'cv', + }, +} + + +def createConvFunc(op_type): + assert op_type in ['cv', 'cd', 'ad', 'rd'], 'unknown op type: %s' % str(op_type) + if op_type == 'cv': + return F.conv2d + + if op_type == 'cd': + def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1): + assert dilation in [1, 2], 'dilation for cd_conv should be in 1 or 2' + assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for cd_conv should be 3x3' + assert padding == dilation, 'padding for cd_conv set wrong' + + weights_c = weights.sum(dim=[2, 3], keepdim=True) + yc = F.conv2d(x, weights_c, stride=stride, padding=0, groups=groups) + y = F.conv2d(x, weights, bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + return y - yc + + return func + elif op_type == 'ad': + def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1): + assert dilation in [1, 2], 'dilation for ad_conv should be in 1 or 2' + assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for ad_conv should be 3x3' + assert padding == dilation, 'padding for ad_conv set wrong' + + shape = weights.shape + weights = weights.view(shape[0], shape[1], -1) + weights_conv = (weights - weights[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]).view(shape) # clock-wise + y = F.conv2d(x, weights_conv, bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + return y + + return func + elif op_type == 'rd': + def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1): + assert dilation in [1, 2], 'dilation for rd_conv should be in 1 or 2' + assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for rd_conv should be 3x3' + padding = 2 * dilation + + shape = weights.shape + if weights.is_cuda: + buffer = torch.cuda.FloatTensor(shape[0], shape[1], 5 * 5).fill_(0) + else: + buffer = torch.zeros(shape[0], shape[1], 5 * 5) + weights = weights.view(shape[0], shape[1], -1) + buffer[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = weights[:, :, 1:] + buffer[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -weights[:, :, 1:] + buffer[:, :, 12] = 0 + buffer = buffer.view(shape[0], shape[1], 5, 5) + y = F.conv2d(x, buffer, bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + return y + + return func + else: + print('impossible to be here unless you force that') + return None + + +class Conv2d(nn.Module): + def __init__(self, pdc, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, + bias=False): + super(Conv2d, self).__init__() + if in_channels % groups != 0: + raise ValueError('in_channels must be divisible by groups') + if out_channels % groups != 0: + raise ValueError('out_channels must be divisible by groups') + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + 
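+        # `pdc` is the pixel-difference convolution operator returned by
+        # createConvFunc(): 'cv' is a plain F.conv2d, while 'cd', 'ad' and 'rd'
+        # reformulate the convolution as central, angular and radial differences.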
self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, kernel_size, kernel_size)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter('bias', None) + self.reset_parameters() + self.pdc = pdc + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(self.bias, -bound, bound) + + def forward(self, input): + + return self.pdc(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +class CSAM(nn.Module): + """ + Compact Spatial Attention Module + """ + + def __init__(self, channels): + super(CSAM, self).__init__() + + mid_channels = 4 + self.relu1 = nn.ReLU() + self.conv1 = nn.Conv2d(channels, mid_channels, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d(mid_channels, 1, kernel_size=3, padding=1, bias=False) + self.sigmoid = nn.Sigmoid() + nn.init.constant_(self.conv1.bias, 0) + + def forward(self, x): + y = self.relu1(x) + y = self.conv1(y) + y = self.conv2(y) + y = self.sigmoid(y) + + return x * y + + +class CDCM(nn.Module): + """ + Compact Dilation Convolution based Module + """ + + def __init__(self, in_channels, out_channels): + super(CDCM, self).__init__() + + self.relu1 = nn.ReLU() + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) + self.conv2_1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=5, padding=5, bias=False) + self.conv2_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=7, padding=7, bias=False) + self.conv2_3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=9, padding=9, bias=False) + self.conv2_4 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=11, padding=11, bias=False) + nn.init.constant_(self.conv1.bias, 0) + + def forward(self, x): + x = self.relu1(x) + x = self.conv1(x) + x1 = self.conv2_1(x) + x2 = self.conv2_2(x) + x3 = self.conv2_3(x) + x4 = self.conv2_4(x) + return x1 + x2 + x3 + x4 + + +class MapReduce(nn.Module): + """ + Reduce feature maps into a single edge map + """ + + def __init__(self, channels): + super(MapReduce, self).__init__() + self.conv = nn.Conv2d(channels, 1, kernel_size=1, padding=0) + nn.init.constant_(self.conv.bias, 0) + + def forward(self, x): + return self.conv(x) + + +class PDCBlock(nn.Module): + def __init__(self, pdc, inplane, ouplane, stride=1): + super(PDCBlock, self).__init__() + self.stride = stride + + self.stride = stride + if self.stride > 1: + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0) + self.conv1 = Conv2d(pdc, inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False) + self.relu2 = nn.ReLU() + self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False) + + def forward(self, x): + if self.stride > 1: + x = self.pool(x) + y = self.conv1(x) + y = self.relu2(y) + y = self.conv2(y) + if self.stride > 1: + x = self.shortcut(x) + y = y + x + return y + + +class PDCBlock_converted(nn.Module): + """ + CPDC, APDC can be converted to vanilla 3x3 convolution + RPDC can be converted to vanilla 5x5 convolution + """ + + def __init__(self, pdc, inplane, ouplane, stride=1): + super(PDCBlock_converted, self).__init__() + self.stride = stride + + if self.stride > 1: + 
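+            # Strided blocks halve the spatial resolution with max-pooling and add
+            # a 1x1 conv shortcut so the residual connection still matches shapes.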
self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0) + if pdc == 'rd': + self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=5, padding=2, groups=inplane, bias=False) + else: + self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False) + self.relu2 = nn.ReLU() + self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False) + + def forward(self, x): + if self.stride > 1: + x = self.pool(x) + y = self.conv1(x) + y = self.relu2(y) + y = self.conv2(y) + if self.stride > 1: + x = self.shortcut(x) + y = y + x + return y + + +class PiDiNet(nn.Module): + def __init__(self, inplane, pdcs, dil=None, sa=False, convert=False): + super(PiDiNet, self).__init__() + self.sa = sa + if dil is not None: + assert isinstance(dil, int), 'dil should be an int' + self.dil = dil + + self.fuseplanes = [] + + self.inplane = inplane + if convert: + if pdcs[0] == 'rd': + init_kernel_size = 5 + init_padding = 2 + else: + init_kernel_size = 3 + init_padding = 1 + self.init_block = nn.Conv2d(3, self.inplane, + kernel_size=init_kernel_size, padding=init_padding, bias=False) + block_class = PDCBlock_converted + else: + self.init_block = Conv2d(pdcs[0], 3, self.inplane, kernel_size=3, padding=1) + block_class = PDCBlock + + self.block1_1 = block_class(pdcs[1], self.inplane, self.inplane) + self.block1_2 = block_class(pdcs[2], self.inplane, self.inplane) + self.block1_3 = block_class(pdcs[3], self.inplane, self.inplane) + self.fuseplanes.append(self.inplane) # C + + inplane = self.inplane + self.inplane = self.inplane * 2 + self.block2_1 = block_class(pdcs[4], inplane, self.inplane, stride=2) + self.block2_2 = block_class(pdcs[5], self.inplane, self.inplane) + self.block2_3 = block_class(pdcs[6], self.inplane, self.inplane) + self.block2_4 = block_class(pdcs[7], self.inplane, self.inplane) + self.fuseplanes.append(self.inplane) # 2C + + inplane = self.inplane + self.inplane = self.inplane * 2 + self.block3_1 = block_class(pdcs[8], inplane, self.inplane, stride=2) + self.block3_2 = block_class(pdcs[9], self.inplane, self.inplane) + self.block3_3 = block_class(pdcs[10], self.inplane, self.inplane) + self.block3_4 = block_class(pdcs[11], self.inplane, self.inplane) + self.fuseplanes.append(self.inplane) # 4C + + self.block4_1 = block_class(pdcs[12], self.inplane, self.inplane, stride=2) + self.block4_2 = block_class(pdcs[13], self.inplane, self.inplane) + self.block4_3 = block_class(pdcs[14], self.inplane, self.inplane) + self.block4_4 = block_class(pdcs[15], self.inplane, self.inplane) + self.fuseplanes.append(self.inplane) # 4C + + self.conv_reduces = nn.ModuleList() + if self.sa and self.dil is not None: + self.attentions = nn.ModuleList() + self.dilations = nn.ModuleList() + for i in range(4): + self.dilations.append(CDCM(self.fuseplanes[i], self.dil)) + self.attentions.append(CSAM(self.dil)) + self.conv_reduces.append(MapReduce(self.dil)) + elif self.sa: + self.attentions = nn.ModuleList() + for i in range(4): + self.attentions.append(CSAM(self.fuseplanes[i])) + self.conv_reduces.append(MapReduce(self.fuseplanes[i])) + elif self.dil is not None: + self.dilations = nn.ModuleList() + for i in range(4): + self.dilations.append(CDCM(self.fuseplanes[i], self.dil)) + self.conv_reduces.append(MapReduce(self.dil)) + else: + for i in range(4): + self.conv_reduces.append(MapReduce(self.fuseplanes[i])) + + self.classifier = nn.Conv2d(4, 1, kernel_size=1) # has bias + nn.init.constant_(self.classifier.weight, 
0.25) + nn.init.constant_(self.classifier.bias, 0) + + # print('initialization done') + + def get_weights(self): + conv_weights = [] + bn_weights = [] + relu_weights = [] + for pname, p in self.named_parameters(): + if 'bn' in pname: + bn_weights.append(p) + elif 'relu' in pname: + relu_weights.append(p) + else: + conv_weights.append(p) + + return conv_weights, bn_weights, relu_weights + + def forward(self, x): + H, W = x.size()[2:] + + x = self.init_block(x) + + x1 = self.block1_1(x) + x1 = self.block1_2(x1) + x1 = self.block1_3(x1) + + x2 = self.block2_1(x1) + x2 = self.block2_2(x2) + x2 = self.block2_3(x2) + x2 = self.block2_4(x2) + + x3 = self.block3_1(x2) + x3 = self.block3_2(x3) + x3 = self.block3_3(x3) + x3 = self.block3_4(x3) + + x4 = self.block4_1(x3) + x4 = self.block4_2(x4) + x4 = self.block4_3(x4) + x4 = self.block4_4(x4) + + x_fuses = [] + if self.sa and self.dil is not None: + for i, xi in enumerate([x1, x2, x3, x4]): + x_fuses.append(self.attentions[i](self.dilations[i](xi))) + elif self.sa: + for i, xi in enumerate([x1, x2, x3, x4]): + x_fuses.append(self.attentions[i](xi)) + elif self.dil is not None: + for i, xi in enumerate([x1, x2, x3, x4]): + x_fuses.append(self.dilations[i](xi)) + else: + x_fuses = [x1, x2, x3, x4] + + e1 = self.conv_reduces[0](x_fuses[0]) + e1 = F.interpolate(e1, (H, W), mode="bilinear", align_corners=False) + + e2 = self.conv_reduces[1](x_fuses[1]) + e2 = F.interpolate(e2, (H, W), mode="bilinear", align_corners=False) + + e3 = self.conv_reduces[2](x_fuses[2]) + e3 = F.interpolate(e3, (H, W), mode="bilinear", align_corners=False) + + e4 = self.conv_reduces[3](x_fuses[3]) + e4 = F.interpolate(e4, (H, W), mode="bilinear", align_corners=False) + + outputs = [e1, e2, e3, e4] + + output = self.classifier(torch.cat(outputs, dim=1)) + # if not self.training: + # return torch.sigmoid(output) + + outputs.append(output) + outputs = [torch.sigmoid(r) for r in outputs] + return outputs + + +def config_model(model): + model_options = list(nets.keys()) + assert model in model_options, \ + 'unrecognized model, please choose from %s' % str(model_options) + + # print(str(nets[model])) + + pdcs = [] + for i in range(16): + layer_name = 'layer%d' % i + op = nets[model][layer_name] + pdcs.append(createConvFunc(op)) + + return pdcs + + +def pidinet(): + pdcs = config_model('carv4') + dil = 24 # if args.dil else None + return PiDiNet(60, pdcs, dil=dil, sa=True) diff --git a/modules/cnet_modules/pidinet/util.py b/modules/cnet_modules/pidinet/util.py new file mode 100644 index 0000000000000000000000000000000000000000..aec00770c7706f95abf3a0b9b02dbe3232930596 --- /dev/null +++ b/modules/cnet_modules/pidinet/util.py @@ -0,0 +1,97 @@ +import random + +import numpy as np +import cv2 +import os + +annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts') + + +def HWC3(x): + assert x.dtype == np.uint8 + if x.ndim == 2: + x = x[:, :, None] + assert x.ndim == 3 + H, W, C = x.shape + assert C == 1 or C == 3 or C == 4 + if C == 3: + return x + if C == 1: + return np.concatenate([x, x, x], axis=2) + if C == 4: + color = x[:, :, 0:3].astype(np.float32) + alpha = x[:, :, 3:4].astype(np.float32) / 255.0 + y = color * alpha + 255.0 * (1.0 - alpha) + y = y.clip(0, 255).astype(np.uint8) + return y + + +def resize_image(input_image, resolution): + H, W, C = input_image.shape + H = float(H) + W = float(W) + k = float(resolution) / min(H, W) + H *= k + W *= k + H = int(np.round(H / 64.0)) * 64 + W = int(np.round(W / 64.0)) * 64 + img = cv2.resize(input_image, (W, H), 
interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) + return img + + +def nms(x, t, s): + x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s) + + f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8) + f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8) + f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8) + f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8) + + y = np.zeros_like(x) + + for f in [f1, f2, f3, f4]: + np.putmask(y, cv2.dilate(x, kernel=f) == x, x) + + z = np.zeros_like(y, dtype=np.uint8) + z[y > t] = 255 + return z + + +def make_noise_disk(H, W, C, F): + noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C)) + noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC) + noise = noise[F: F + H, F: F + W] + noise -= np.min(noise) + noise /= np.max(noise) + if C == 1: + noise = noise[:, :, None] + return noise + + +def min_max_norm(x): + x -= np.min(x) + x /= np.maximum(np.max(x), 1e-5) + return x + + +def safe_step(x, step=2): + y = x.astype(np.float32) * float(step + 1) + y = y.astype(np.int32).astype(np.float32) / float(step) + return y + + +def img2mask(img, H, W, low=10, high=90): + assert img.ndim == 3 or img.ndim == 2 + assert img.dtype == np.uint8 + + if img.ndim == 3: + y = img[:, :, random.randrange(0, img.shape[2])] + else: + y = img + + y = cv2.resize(y, (W, H), interpolation=cv2.INTER_CUBIC) + + if random.uniform(0, 1) < 0.5: + y = 255 - y + + return y < np.percentile(y, random.randrange(low, high)) diff --git a/modules/common.py b/modules/common.py new file mode 100644 index 0000000000000000000000000000000000000000..5e4ad71649f60f2dd38947c9ebc23bc51db2b544 --- /dev/null +++ b/modules/common.py @@ -0,0 +1,131 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +from einops import rearrange +import torch.fft as fft +class Linear(torch.nn.Linear): + def reset_parameters(self): + return None + +class Conv2d(torch.nn.Conv2d): + def reset_parameters(self): + return None + + + +class Attention2D(nn.Module): + def __init__(self, c, nhead, dropout=0.0): + super().__init__() + self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True) + + def forward(self, x, kv, self_attn=False): + orig_shape = x.shape + x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4 + if self_attn: + #print('in line 23 algong self att ', kv.shape, x.shape) + kv = torch.cat([x, kv], dim=1) + #if x.shape[1] >= 72 * 72: + # x = x * math.sqrt(math.log(64*64, 24*24)) + + x = self.attn(x, kv, kv, need_weights=False)[0] + x = x.permute(0, 2, 1).view(*orig_shape) + return x + + +class LayerNorm2d(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x): + return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + +class GlobalResponseNorm(nn.Module): + "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105" + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x + + +class ResBlock(nn.Module): + def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0): # , num_heads=4, expansion=2): + super().__init__() + 
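+        # ConvNeXt-style residual block: depthwise conv -> LayerNorm2d -> channel
+        # MLP (Linear, GELU, GlobalResponseNorm, Dropout, Linear). If x_skip is
+        # passed to forward(), it is concatenated on the channel dim before the MLP.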
self.depthwise = Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) + # self.depthwise = SAMBlock(c, num_heads, expansion) + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + Linear(c + c_skip, c * 4), + nn.GELU(), + GlobalResponseNorm(c * 4), + nn.Dropout(dropout), + Linear(c * 4, c) + ) + + def forward(self, x, x_skip=None): + x_res = x + x = self.norm(self.depthwise(x)) + if x_skip is not None: + x = torch.cat([x, x_skip], dim=1) + x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + x_res + + +class AttnBlock(nn.Module): + def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): + super().__init__() + self.self_attn = self_attn + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.attention = Attention2D(c, nhead, dropout) + self.kv_mapper = nn.Sequential( + nn.SiLU(), + Linear(c_cond, c) + ) + + def forward(self, x, kv): + kv = self.kv_mapper(kv) + res = self.attention(self.norm(x), kv, self_attn=self.self_attn) + + #print(torch.unique(res), torch.unique(x), self.self_attn) + #scale = math.sqrt(math.log(x.shape[-2] * x.shape[-1], 24*24)) + x = x + res + + return x + +class FeedForwardBlock(nn.Module): + def __init__(self, c, dropout=0.0): + super().__init__() + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + Linear(c, c * 4), + nn.GELU(), + GlobalResponseNorm(c * 4), + nn.Dropout(dropout), + Linear(c * 4, c) + ) + + def forward(self, x): + x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + + +class TimestepBlock(nn.Module): + def __init__(self, c, c_timestep, conds=['sca']): + super().__init__() + self.mapper = Linear(c_timestep, c * 2) + self.conds = conds + for cname in conds: + setattr(self, f"mapper_{cname}", Linear(c_timestep, c * 2)) + + def forward(self, x, t): + t = t.chunk(len(self.conds) + 1, dim=1) + a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1) + for i, c in enumerate(self.conds): + ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1) + a, b = a + ac, b + bc + return x * (1 + a) + b diff --git a/modules/common_ckpt.py b/modules/common_ckpt.py new file mode 100644 index 0000000000000000000000000000000000000000..f64cf11790bdd2a83ca0744629336d81464b3ed0 --- /dev/null +++ b/modules/common_ckpt.py @@ -0,0 +1,360 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +from einops import rearrange +from modules.speed_util import checkpoint +class Linear(torch.nn.Linear): + def reset_parameters(self): + return None + +class Conv2d(torch.nn.Conv2d): + def reset_parameters(self): + return None + +class AttnBlock_lrfuse_backup(nn.Module): + def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, use_checkpoint=True): + super().__init__() + self.self_attn = self_attn + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.attention = Attention2D(c, nhead, dropout) + self.kv_mapper = nn.Sequential( + nn.SiLU(), + Linear(c_cond, c) + ) + self.fuse_mapper = nn.Sequential( + nn.SiLU(), + Linear(c_cond, c) + ) + self.use_checkpoint = use_checkpoint + + def forward(self, hr, lr): + return checkpoint(self._forward, (hr, lr), self.paramters(), self.use_checkpoint) + def _forward(self, hr, lr): + res = hr + hr = self.kv_mapper(rearrange(hr, 'b c h w -> b (h w ) c')) + lr_fuse = self.attention(self.norm(lr), hr, self_attn=False) + lr + + lr_fuse = 
self.fuse_mapper(rearrange(lr_fuse, 'b c h w -> b (h w ) c')) + hr = self.attention(self.norm(res), lr_fuse, self_attn=False) + res + return hr + + +class AttnBlock_lrfuse(nn.Module): + def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, kernel_size=3, use_checkpoint=True): + super().__init__() + self.self_attn = self_attn + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.attention = Attention2D(c, nhead, dropout) + self.kv_mapper = nn.Sequential( + nn.SiLU(), + Linear(c_cond, c) + ) + + + self.depthwise = Conv2d(c, c , kernel_size=kernel_size, padding=kernel_size // 2, groups=c) + + self.channelwise = nn.Sequential( + Linear(c + c, c ), + nn.GELU(), + GlobalResponseNorm(c ), + nn.Dropout(dropout), + Linear(c , c) + ) + self.use_checkpoint = use_checkpoint + + + def forward(self, hr, lr): + return checkpoint(self._forward, (hr, lr), self.parameters(), self.use_checkpoint) + + def _forward(self, hr, lr): + res = hr + hr = self.kv_mapper(rearrange(hr, 'b c h w -> b (h w ) c')) + lr_fuse = self.attention(self.norm(lr), hr, self_attn=False) + lr + + lr_fuse = torch.nn.functional.interpolate(lr_fuse.float(), res.shape[2:]) + #print('in line 65', lr_fuse.shape, res.shape) + media = torch.cat((self.depthwise(lr_fuse), res), dim=1) + out = self.channelwise(media.permute(0,2,3,1)).permute(0,3,1,2) + res + + return out + + + + +class Attention2D(nn.Module): + def __init__(self, c, nhead, dropout=0.0): + super().__init__() + self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True) + + def forward(self, x, kv, self_attn=False): + orig_shape = x.shape + x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4 + if self_attn: + #print('in line 23 algong self att ', kv.shape, x.shape) + + kv = torch.cat([x, kv], dim=1) + #if x.shape[1] > 48 * 48 and not self.training: + # x = x * math.sqrt(math.log(x.shape[1] , 24*24)) + + x = self.attn(x, kv, kv, need_weights=False)[0] + x = x.permute(0, 2, 1).view(*orig_shape) + return x +class Attention2D_splitpatch(nn.Module): + def __init__(self, c, nhead, dropout=0.0): + super().__init__() + self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True) + + def forward(self, x, kv, self_attn=False): + orig_shape = x.shape + + #x = rearrange(x, 'b c h w -> b c (nh wh) (nw ww)', wh=24, ww=24, nh=orig_shape[-2] // 24, nh=orig_shape[-1] // 24,) + x = rearrange(x, 'b c (nh wh) (nw ww) -> (b nh nw) (wh ww) c', wh=24, ww=24, nh=orig_shape[-2] // 24, nw=orig_shape[-1] // 24,) + #print('in line 168', x.shape) + #x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4 + if self_attn: + #print('in line 23 algong self att ', kv.shape, x.shape) + num = (orig_shape[-2] // 24) * (orig_shape[-1] // 24) + kv = torch.cat([x, kv.repeat(num, 1, 1)], dim=1) + #if x.shape[1] > 48 * 48 and not self.training: + # x = x * math.sqrt(math.log(x.shape[1] / math.sqrt(16), 24*24)) + + x = self.attn(x, kv, kv, need_weights=False)[0] + x = rearrange(x, ' (b nh nw) (wh ww) c -> b c (nh wh) (nw ww)', b=orig_shape[0], wh=24, ww=24, nh=orig_shape[-2] // 24, nw=orig_shape[-1] // 24) + #x = x.permute(0, 2, 1).view(*orig_shape) + + return x +class Attention2D_extra(nn.Module): + def __init__(self, c, nhead, dropout=0.0): + super().__init__() + self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True) + + def forward(self, x, kv, extra_emb=None, self_attn=False): + orig_shape = x.shape + x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 
1) # Bx4xHxW -> Bx(HxW)x4 + num_x = x.shape[1] + + + if extra_emb is not None: + ori_extra_shape = extra_emb.shape + extra_emb = extra_emb.view(extra_emb.size(0), extra_emb.size(1), -1).permute(0, 2, 1) + x = torch.cat((x, extra_emb), dim=1) + if self_attn: + #print('in line 23 algong self att ', kv.shape, x.shape) + kv = torch.cat([x, kv], dim=1) + x = self.attn(x, kv, kv, need_weights=False)[0] + img = x[:, :num_x, :].permute(0, 2, 1).view(*orig_shape) + if extra_emb is not None: + fix = x[:, num_x:, :].permute(0, 2, 1).view(*ori_extra_shape) + return img, fix + else: + return img +class AttnBlock_extraq(nn.Module): + def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): + super().__init__() + self.self_attn = self_attn + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + #self.norm2 = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.attention = Attention2D_extra(c, nhead, dropout) + self.kv_mapper = nn.Sequential( + nn.SiLU(), + Linear(c_cond, c) + ) + # norm2 initialization in generator in init extra parameter + def forward(self, x, kv, extra_emb=None): + #print('in line 84', x.shape, kv.shape, self.self_attn, extra_emb if extra_emb is None else extra_emb.shape) + #in line 84 torch.Size([1, 1536, 32, 32]) torch.Size([1, 85, 1536]) True None + #if extra_emb is not None: + + kv = self.kv_mapper(kv) + if extra_emb is not None: + res_x, res_extra = self.attention(self.norm(x), kv, extra_emb=self.norm2(extra_emb), self_attn=self.self_attn) + x = x + res_x + extra_emb = extra_emb + res_extra + return x, extra_emb + else: + x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn) + return x +class AttnBlock_latent2ex(nn.Module): + def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): + super().__init__() + self.self_attn = self_attn + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.attention = Attention2D(c, nhead, dropout) + self.kv_mapper = nn.Sequential( + nn.SiLU(), + Linear(c_cond, c) + ) + + def forward(self, x, kv): + #print('in line 84', x.shape, kv.shape, self.self_attn) + kv = F.interpolate(kv.float(), x.shape[2:]) + kv = kv.view(kv.size(0), kv.size(1), -1).permute(0, 2, 1) + kv = self.kv_mapper(kv) + x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn) + return x + +class LayerNorm2d(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x): + return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) +class AttnBlock_crossbranch(nn.Module): + def __init__(self, attnmodule, c, c_cond, nhead, self_attn=True, dropout=0.0): + super().__init__() + self.attn = AttnBlock(c, c_cond, nhead, self_attn, dropout) + #print('in line 108', attnmodule.device) + self.attn.load_state_dict(attnmodule.state_dict()) + self.norm1 = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + + self.channelwise1 = nn.Sequential( + Linear(c *2, c ), + nn.GELU(), + GlobalResponseNorm(c ), + nn.Dropout(dropout), + Linear(c, c) + ) + self.channelwise2 = nn.Sequential( + Linear(c *2, c ), + nn.GELU(), + GlobalResponseNorm(c ), + nn.Dropout(dropout), + Linear(c, c) + ) + self.c = c + def forward(self, x, kv, main_x): + #print('in line 84', x.shape, kv.shape, main_x.shape, self.c) + + x = self.channelwise1(torch.cat((x, F.interpolate(main_x.float(), x.shape[2:])), dim=1).permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + x + x = self.attn(x, kv) + main_x = self.channelwise2(torch.cat((main_x, F.interpolate(x.float(), main_x.shape[2:])), dim=1).permute(0, 2, 3, 
1)).permute(0, 3, 1, 2) + main_x + return main_x, x + +class GlobalResponseNorm(nn.Module): + "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105" + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x + + +class ResBlock(nn.Module): + def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0, use_checkpoint =True): # , num_heads=4, expansion=2): + super().__init__() + self.depthwise = Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) + # self.depthwise = SAMBlock(c, num_heads, expansion) + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + Linear(c + c_skip, c * 4), + nn.GELU(), + GlobalResponseNorm(c * 4), + nn.Dropout(dropout), + Linear(c * 4, c) + ) + self.use_checkpoint = use_checkpoint + def forward(self, x, x_skip=None): + + if x_skip is not None: + return checkpoint(self._forward_skip, (x, x_skip), self.parameters(), self.use_checkpoint) + else: + #print('in line 298', x.shape) + return checkpoint(self._forward_woskip, (x, ), self.parameters(), self.use_checkpoint) + + + + def _forward_skip(self, x, x_skip): + x_res = x + x = self.norm(self.depthwise(x)) + if x_skip is not None: + x = torch.cat([x, x_skip], dim=1) + x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + x_res + def _forward_woskip(self, x): + x_res = x + x = self.norm(self.depthwise(x)) + + x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + x_res + +class AttnBlock(nn.Module): + def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, use_checkpoint=True): + super().__init__() + self.self_attn = self_attn + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.attention = Attention2D(c, nhead, dropout) + self.kv_mapper = nn.Sequential( + nn.SiLU(), + Linear(c_cond, c) + ) + self.use_checkpoint = use_checkpoint + def forward(self, x, kv): + return checkpoint(self._forward, (x, kv), self.parameters(), self.use_checkpoint) + def _forward(self, x, kv): + kv = self.kv_mapper(kv) + res = self.attention(self.norm(x), kv, self_attn=self.self_attn) + + #print(torch.unique(res), torch.unique(x), self.self_attn) + #scale = math.sqrt(math.log(x.shape[-2] * x.shape[-1], 24*24)) + x = x + res + + return x +class AttnBlock_mytest(nn.Module): + def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): + super().__init__() + self.self_attn = self_attn + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.attention = Attention2D(c, nhead, dropout) + self.kv_mapper = nn.Sequential( + nn.SiLU(), + nn.Linear(c_cond, c) + ) + + def forward(self, x, kv): + kv = self.kv_mapper(kv) + x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn) + return x + +class FeedForwardBlock(nn.Module): + def __init__(self, c, dropout=0.0): + super().__init__() + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + Linear(c, c * 4), + nn.GELU(), + GlobalResponseNorm(c * 4), + nn.Dropout(dropout), + Linear(c * 4, c) + ) + + def forward(self, x): + x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + + +class TimestepBlock(nn.Module): + def 
__init__(self, c, c_timestep, conds=['sca'], use_checkpoint=True): + super().__init__() + self.mapper = Linear(c_timestep, c * 2) + self.conds = conds + for cname in conds: + setattr(self, f"mapper_{cname}", Linear(c_timestep, c * 2)) + + self.use_checkpoint = use_checkpoint + def forward(self, x, t): + return checkpoint(self._forward, (x, t), self.parameters(), self.use_checkpoint) + + def _forward(self, x, t): + #print('in line 284', x.shape, t.shape, self.conds) + #in line 284 torch.Size([4, 2048, 19, 29]) torch.Size([4, 192]) ['sca', 'crp'] + t = t.chunk(len(self.conds) + 1, dim=1) + a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1) + for i, c in enumerate(self.conds): + ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1) + a, b = a + ac, b + bc + return x * (1 + a) + b diff --git a/modules/controlnet.py b/modules/controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c187aecb725e00e19924ae308e3aac401acfdf06 --- /dev/null +++ b/modules/controlnet.py @@ -0,0 +1,349 @@ +import torchvision +import torch +from torch import nn +import numpy as np +import kornia +import cv2 +from core.utils import load_or_fail +#from insightface.app.common import Face +from .effnet import EfficientNetEncoder +from .cnet_modules.pidinet import PidiNetDetector +from .cnet_modules.inpainting.saliency_model import MicroResNet +#from .cnet_modules.face_id.arcface import FaceDetector, ArcFaceRecognizer +from .common import LayerNorm2d + + +class CNetResBlock(nn.Module): + def __init__(self, c): + super().__init__() + self.blocks = nn.Sequential( + LayerNorm2d(c), + nn.GELU(), + nn.Conv2d(c, c, kernel_size=3, padding=1), + LayerNorm2d(c), + nn.GELU(), + nn.Conv2d(c, c, kernel_size=3, padding=1), + ) + + def forward(self, x): + return x + self.blocks(x) + + +class ControlNet(nn.Module): + def __init__(self, c_in=3, c_proj=2048, proj_blocks=None, bottleneck_mode=None): + super().__init__() + if bottleneck_mode is None: + bottleneck_mode = 'effnet' + self.proj_blocks = proj_blocks + if bottleneck_mode == 'effnet': + embd_channels = 1280 + #self.backbone = torchvision.models.efficientnet_v2_s(weights='DEFAULT').features.eval() + self.backbone = torchvision.models.efficientnet_v2_s().features.eval() + if c_in != 3: + in_weights = self.backbone[0][0].weight.data + self.backbone[0][0] = nn.Conv2d(c_in, 24, kernel_size=3, stride=2, bias=False) + if c_in > 3: + nn.init.constant_(self.backbone[0][0].weight, 0) + self.backbone[0][0].weight.data[:, :3] = in_weights[:, :3].clone() + else: + self.backbone[0][0].weight.data = in_weights[:, :c_in].clone() + elif bottleneck_mode == 'simple': + embd_channels = c_in + self.backbone = nn.Sequential( + nn.Conv2d(embd_channels, embd_channels * 4, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(embd_channels * 4, embd_channels, kernel_size=3, padding=1), + ) + elif bottleneck_mode == 'large': + self.backbone = nn.Sequential( + nn.Conv2d(c_in, 4096 * 4, kernel_size=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(4096 * 4, 1024, kernel_size=1), + *[CNetResBlock(1024) for _ in range(8)], + nn.Conv2d(1024, 1280, kernel_size=1), + ) + embd_channels = 1280 + else: + raise ValueError(f'Unknown bottleneck mode: {bottleneck_mode}') + self.projections = nn.ModuleList() + for _ in range(len(proj_blocks)): + self.projections.append(nn.Sequential( + nn.Conv2d(embd_channels, embd_channels, kernel_size=1, bias=False), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(embd_channels, c_proj, kernel_size=1, 
bias=False), + )) + nn.init.constant_(self.projections[-1][-1].weight, 0) # zero output projection + + def forward(self, x): + x = self.backbone(x) + proj_outputs = [None for _ in range(max(self.proj_blocks) + 1)] + for i, idx in enumerate(self.proj_blocks): + proj_outputs[idx] = self.projections[i](x) + return proj_outputs + + +class ControlNetDeliverer(): + def __init__(self, controlnet_projections): + self.controlnet_projections = controlnet_projections + self.restart() + + def restart(self): + self.idx = 0 + return self + + def __call__(self): + if self.idx < len(self.controlnet_projections): + output = self.controlnet_projections[self.idx] + else: + output = None + self.idx += 1 + return output + + +# CONTROLNET FILTERS ---------------------------------------------------- + +class BaseFilter(): + def __init__(self, device): + self.device = device + + def num_channels(self): + return 3 + + def __call__(self, x): + return x + + +class CannyFilter(BaseFilter): + def __init__(self, device, resize=224): + super().__init__(device) + self.resize = resize + + def num_channels(self): + return 1 + + def __call__(self, x): + orig_size = x.shape[-2:] + if self.resize is not None: + x = nn.functional.interpolate(x, size=(self.resize, self.resize), mode='bilinear') + edges = [cv2.Canny(x[i].mul(255).permute(1, 2, 0).cpu().numpy().astype(np.uint8), 100, 200) for i in range(len(x))] + edges = torch.stack([torch.tensor(e).div(255).unsqueeze(0) for e in edges], dim=0) + if self.resize is not None: + edges = nn.functional.interpolate(edges, size=orig_size, mode='bilinear') + return edges + + +class QRFilter(BaseFilter): + def __init__(self, device, resize=224, blobify=True, dilation_kernels=[3, 5, 7], blur_kernels=[15]): + super().__init__(device) + self.resize = resize + self.blobify = blobify + self.dilation_kernels = dilation_kernels + self.blur_kernels = blur_kernels + + def num_channels(self): + return 1 + + def __call__(self, x): + x = x.to(self.device) + orig_size = x.shape[-2:] + if self.resize is not None: + x = nn.functional.interpolate(x, size=(self.resize, self.resize), mode='bilinear') + + x = kornia.color.rgb_to_hsv(x)[:, -1:] + # blobify + if self.blobify: + d_kernel = np.random.choice(self.dilation_kernels) + d_blur = np.random.choice(self.blur_kernels) + if d_blur > 0: + x = torchvision.transforms.GaussianBlur(d_blur)(x) + if d_kernel > 0: + blob_mask = ((torch.linspace(-0.5, 0.5, d_kernel).pow(2)[None] + torch.linspace(-0.5, 0.5, + d_kernel).pow(2)[:, + None]) < 0.3).float().to(self.device) + x = kornia.morphology.dilation(x, blob_mask) + x = kornia.morphology.erosion(x, blob_mask) + # mask + vmax, vmin = x.amax(dim=[2, 3], keepdim=True)[0], x.amin(dim=[2, 3], keepdim=True)[0] + th = (vmax - vmin) * 0.33 + high_brightness, low_brightness = (x > (vmax - th)).float(), (x < (vmin + th)).float() + mask = (torch.ones_like(x) - low_brightness + high_brightness) * 0.5 + + if self.resize is not None: + mask = nn.functional.interpolate(mask, size=orig_size, mode='bilinear') + return mask.cpu() + + +class PidiFilter(BaseFilter): + def __init__(self, device, resize=224, dilation_kernels=[0, 3, 5, 7, 9], binarize=True): + super().__init__(device) + self.resize = resize + self.model = PidiNetDetector(device) + self.dilation_kernels = dilation_kernels + self.binarize = binarize + + def num_channels(self): + return 1 + + def __call__(self, x): + x = x.to(self.device) + orig_size = x.shape[-2:] + if self.resize is not None: + x = nn.functional.interpolate(x, size=(self.resize, self.resize), 
mode='bilinear') + + x = self.model(x) + d_kernel = np.random.choice(self.dilation_kernels) + if d_kernel > 0: + blob_mask = ((torch.linspace(-0.5, 0.5, d_kernel).pow(2)[None] + torch.linspace(-0.5, 0.5, d_kernel).pow(2)[ + :, None]) < 0.3).float().to(self.device) + x = kornia.morphology.dilation(x, blob_mask) + if self.binarize: + th = np.random.uniform(0.05, 0.7) + x = (x > th).float() + + if self.resize is not None: + x = nn.functional.interpolate(x, size=orig_size, mode='bilinear') + return x.cpu() + + +class SRFilter(BaseFilter): + def __init__(self, device, scale_factor=1 / 4): + super().__init__(device) + self.scale_factor = scale_factor + + def num_channels(self): + return 3 + + def __call__(self, x): + x = torch.nn.functional.interpolate(x.clone(), scale_factor=self.scale_factor, mode="nearest") + return torch.nn.functional.interpolate(x, scale_factor=1 / self.scale_factor, mode="nearest") + + +class SREffnetFilter(BaseFilter): + def __init__(self, device, scale_factor=1/2): + super().__init__(device) + self.scale_factor = scale_factor + + self.effnet_preprocess = torchvision.transforms.Compose([ + torchvision.transforms.Normalize( + mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) + ) + ]) + + self.effnet = EfficientNetEncoder().to(self.device) + effnet_checkpoint = load_or_fail("models/effnet_encoder.safetensors") + self.effnet.load_state_dict(effnet_checkpoint) + self.effnet.eval().requires_grad_(False) + + def num_channels(self): + return 16 + + def __call__(self, x): + x = torch.nn.functional.interpolate(x.clone(), scale_factor=self.scale_factor, mode="nearest") + with torch.no_grad(): + effnet_embedding = self.effnet(self.effnet_preprocess(x.to(self.device))).cpu() + effnet_embedding = torch.nn.functional.interpolate(effnet_embedding, scale_factor=1/self.scale_factor, mode="nearest") + upscaled_image = torch.nn.functional.interpolate(x, scale_factor=1/self.scale_factor, mode="nearest") + return effnet_embedding, upscaled_image + + +class InpaintFilter(BaseFilter): + def __init__(self, device, thresold=[0.04, 0.4], p_outpaint=0.4): + super().__init__(device) + self.saliency_model = MicroResNet().eval().requires_grad_(False).to(device) + self.saliency_model.load_state_dict(load_or_fail("modules/cnet_modules/inpainting/saliency_model.pt")) + self.thresold = thresold + self.p_outpaint = p_outpaint + + def num_channels(self): + return 4 + + def __call__(self, x, mask=None, threshold=None, outpaint=None): + x = x.to(self.device) + resized_x = torchvision.transforms.functional.resize(x, 240, antialias=True) + if threshold is None: + threshold = np.random.uniform(self.thresold[0], self.thresold[1]) + if mask is None: + saliency_map = self.saliency_model(resized_x) > threshold + if outpaint is None: + if np.random.rand() < self.p_outpaint: + saliency_map = ~saliency_map + else: + if outpaint: + saliency_map = ~saliency_map + interpolated_saliency_map = torch.nn.functional.interpolate(saliency_map.float(), size=x.shape[2:], mode="nearest") + saliency_map = torchvision.transforms.functional.gaussian_blur(interpolated_saliency_map, 141) > 0.5 + inpainted_images = torch.where(saliency_map, torch.ones_like(x), x) + mask = torch.nn.functional.interpolate(saliency_map.float(), size=inpainted_images.shape[2:], mode="nearest") + else: + mask = mask.to(self.device) + inpainted_images = torch.where(mask, torch.ones_like(x), x) + c_inpaint = torch.cat([inpainted_images, mask], dim=1) + return c_inpaint.cpu() + + +# IDENTITY +''' +class IdentityFilter(BaseFilter): + def __init__(self, 
device, max_faces=4, p_drop=0.05, p_full=0.3): + detector_path = 'modules/cnet_modules/face_id/models/buffalo_l/det_10g.onnx' + recognizer_path = 'modules/cnet_modules/face_id/models/buffalo_l/w600k_r50.onnx' + + super().__init__(device) + self.max_faces = max_faces + self.p_drop = p_drop + self.p_full = p_full + + self.detector = FaceDetector(detector_path, device=device) + self.recognizer = ArcFaceRecognizer(recognizer_path, device=device) + + self.id_colors = torch.tensor([ + [1.0, 0.0, 0.0], # RED + [0.0, 1.0, 0.0], # GREEN + [0.0, 0.0, 1.0], # BLUE + [1.0, 0.0, 1.0], # PURPLE + [0.0, 1.0, 1.0], # CYAN + [1.0, 1.0, 0.0], # YELLOW + [0.5, 0.0, 0.0], # DARK RED + [0.0, 0.5, 0.0], # DARK GREEN + [0.0, 0.0, 0.5], # DARK BLUE + [0.5, 0.0, 0.5], # DARK PURPLE + [0.0, 0.5, 0.5], # DARK CYAN + [0.5, 0.5, 0.0], # DARK YELLOW + ]) + + def num_channels(self): + return 512 + + def get_faces(self, image): + npimg = image.permute(1, 2, 0).mul(255).to(device="cpu", dtype=torch.uint8).cpu().numpy() + bgr = cv2.cvtColor(npimg, cv2.COLOR_RGB2BGR) + bboxes, kpss = self.detector.detect(bgr, max_num=self.max_faces) + N = len(bboxes) + ids = torch.zeros((N, 512), dtype=torch.float32) + for i in range(N): + face = Face(bbox=bboxes[i, :4], kps=kpss[i], det_score=bboxes[i, 4]) + ids[i, :] = self.recognizer.get(bgr, face) + tbboxes = torch.tensor(bboxes[:, :4], dtype=torch.int) + + ids = ids / torch.linalg.norm(ids, dim=1, keepdim=True) + return tbboxes, ids # returns bounding boxes (N x 4) and ID vectors (N x 512) + + def __call__(self, x): + visual_aid = x.clone().cpu() + face_mtx = torch.zeros(x.size(0), 512, x.size(-2) // 32, x.size(-1) // 32) + + for i in range(x.size(0)): + bounding_boxes, ids = self.get_faces(x[i]) + for j in range(bounding_boxes.size(0)): + if np.random.rand() > self.p_drop: + sx, sy, ex, ey = (bounding_boxes[j] / 32).clamp(min=0).round().int().tolist() + ex, ey = max(ex, sx + 1), max(ey, sy + 1) + if bounding_boxes.size(0) == 1 and np.random.rand() < self.p_full: + sx, sy, ex, ey = 0, 0, x.size(-1) // 32, x.size(-2) // 32 + face_mtx[i, :, sy:ey, sx:ex] = ids[j:j + 1, :, None, None] + visual_aid[i, :, int(sy * 32):int(ey * 32), int(sx * 32):int(ex * 32)] += self.id_colors[j % 13, :, + None, None] + visual_aid[i, :, int(sy * 32):int(ey * 32), int(sx * 32):int(ex * 32)] *= 0.5 + + return face_mtx.to(x.device), visual_aid.to(x.device) +''' diff --git a/modules/effnet.py b/modules/effnet.py new file mode 100644 index 0000000000000000000000000000000000000000..0eb2690c2547c8c7553aec8a9f9e838241f8f61c --- /dev/null +++ b/modules/effnet.py @@ -0,0 +1,17 @@ +import torchvision +from torch import nn + + +# EfficientNet +class EfficientNetEncoder(nn.Module): + def __init__(self, c_latent=16): + super().__init__() + self.backbone = torchvision.models.efficientnet_v2_s().features.eval() + self.mapper = nn.Sequential( + nn.Conv2d(1280, c_latent, kernel_size=1, bias=False), + nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1 + ) + + def forward(self, x): + return self.mapper(self.backbone(x)) + diff --git a/modules/inr_fea_res_lite.py b/modules/inr_fea_res_lite.py new file mode 100644 index 0000000000000000000000000000000000000000..41ddfb09937f26e2c7d0193b4a65607efabde5e5 --- /dev/null +++ b/modules/inr_fea_res_lite.py @@ -0,0 +1,435 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +import einops +import numpy as np +import models +from modules.common_ckpt import Linear, Conv2d, AttnBlock, ResBlock, LayerNorm2d +#from 
modules.common_ckpt import AttnBlock, +from einops import rearrange +import torch.fft as fft +from modules.speed_util import checkpoint +def batched_linear_mm(x, wb): + # x: (B, N, D1); wb: (B, D1 + 1, D2) or (D1 + 1, D2) + one = torch.ones(*x.shape[:-1], 1, device=x.device) + return torch.matmul(torch.cat([x, one], dim=-1), wb) +def make_coord_grid(shape, range, device=None): + """ + Args: + shape: tuple + range: [minv, maxv] or [[minv_1, maxv_1], ..., [minv_d, maxv_d]] for each dim + Returns: + grid: shape (*shape, ) + """ + l_lst = [] + for i, s in enumerate(shape): + l = (0.5 + torch.arange(s, device=device)) / s + if isinstance(range[0], list) or isinstance(range[0], tuple): + minv, maxv = range[i] + else: + minv, maxv = range + l = minv + (maxv - minv) * l + l_lst.append(l) + grid = torch.meshgrid(*l_lst, indexing='ij') + grid = torch.stack(grid, dim=-1) + return grid +def init_wb(shape): + weight = torch.empty(shape[1], shape[0] - 1) + nn.init.kaiming_uniform_(weight, a=math.sqrt(5)) + + bias = torch.empty(shape[1], 1) + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(bias, -bound, bound) + + return torch.cat([weight, bias], dim=1).t().detach() + +def init_wb_rewrite(shape): + weight = torch.empty(shape[1], shape[0] - 1) + + torch.nn.init.xavier_uniform_(weight) + + bias = torch.empty(shape[1], 1) + torch.nn.init.xavier_uniform_(bias) + + + return torch.cat([weight, bias], dim=1).t().detach() +class HypoMlp(nn.Module): + + def __init__(self, depth, in_dim, out_dim, hidden_dim, use_pe, pe_dim, out_bias=0, pe_sigma=1024): + super().__init__() + self.use_pe = use_pe + self.pe_dim = pe_dim + self.pe_sigma = pe_sigma + self.depth = depth + self.param_shapes = dict() + if use_pe: + last_dim = in_dim * pe_dim + else: + last_dim = in_dim + for i in range(depth): # for each layer the weight + cur_dim = hidden_dim if i < depth - 1 else out_dim + self.param_shapes[f'wb{i}'] = (last_dim + 1, cur_dim) + last_dim = cur_dim + self.relu = nn.ReLU() + self.params = None + self.out_bias = out_bias + + def set_params(self, params): + self.params = params + + def convert_posenc(self, x): + w = torch.exp(torch.linspace(0, np.log(self.pe_sigma), self.pe_dim // 2, device=x.device)) + x = torch.matmul(x.unsqueeze(-1), w.unsqueeze(0)).view(*x.shape[:-1], -1) + x = torch.cat([torch.cos(np.pi * x), torch.sin(np.pi * x)], dim=-1) + return x + + def forward(self, x): + B, query_shape = x.shape[0], x.shape[1: -1] + x = x.view(B, -1, x.shape[-1]) + if self.use_pe: + x = self.convert_posenc(x) + #print('in line 79 after pos embedding', x.shape) + for i in range(self.depth): + x = batched_linear_mm(x, self.params[f'wb{i}']) + if i < self.depth - 1: + x = self.relu(x) + else: + x = x + self.out_bias + x = x.view(B, *query_shape, -1) + return x + + + +class Attention(nn.Module): + + def __init__(self, dim, n_head, head_dim, dropout=0.): + super().__init__() + self.n_head = n_head + inner_dim = n_head * head_dim + self.to_q = nn.Sequential( + nn.SiLU(), + Linear(dim, inner_dim )) + self.to_kv = nn.Sequential( + nn.SiLU(), + Linear(dim, inner_dim * 2)) + self.scale = head_dim ** -0.5 + # self.to_out = nn.Sequential( + # Linear(inner_dim, dim), + # nn.Dropout(dropout), + # ) + + def forward(self, fr, to=None): + if to is None: + to = fr + q = self.to_q(fr) + k, v = self.to_kv(to).chunk(2, dim=-1) + q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> b h n d', h=self.n_head), [q, k, v]) + + dots = torch.matmul(q, k.transpose(-1, -2)) 
* self.scale + attn = F.softmax(dots, dim=-1) # b h n n + out = torch.matmul(attn, v) + out = einops.rearrange(out, 'b h n d -> b n (h d)') + return out + + +class FeedForward(nn.Module): + + def __init__(self, dim, ff_dim, dropout=0.): + super().__init__() + + self.net = nn.Sequential( + Linear(dim, ff_dim), + nn.GELU(), + #GlobalResponseNorm(ff_dim), + nn.Dropout(dropout), + Linear(ff_dim, dim) + ) + + def forward(self, x): + return self.net(x) + + +class PreNorm(nn.Module): + + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + + def forward(self, x): + return self.fn(self.norm(x)) + + +#TransInr(ind=2048, ch=256, n_head=16, head_dim=16, n_groups=64, f_dim=256, time_dim=self.c_r, t_conds = []) +class TransformerEncoder(nn.Module): + + def __init__(self, dim, depth, n_head, head_dim, ff_dim, dropout=0.): + super().__init__() + self.layers = nn.ModuleList() + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNorm(dim, Attention(dim, n_head, head_dim, dropout=dropout)), + PreNorm(dim, FeedForward(dim, ff_dim, dropout=dropout)), + ])) + + def forward(self, x): + for norm_attn, norm_ff in self.layers: + x = x + norm_attn(x) + x = x + norm_ff(x) + return x +class ImgrecTokenizer(nn.Module): + + def __init__(self, input_size=32*32, patch_size=1, dim=768, padding=0, img_channels=16): + super().__init__() + + if isinstance(patch_size, int): + patch_size = (patch_size, patch_size) + if isinstance(padding, int): + padding = (padding, padding) + self.patch_size = patch_size + self.padding = padding + self.prefc = nn.Linear(patch_size[0] * patch_size[1] * img_channels, dim) + + self.posemb = nn.Parameter(torch.randn(input_size, dim)) + + def forward(self, x): + #print(x.shape) + p = self.patch_size + x = F.unfold(x, p, stride=p, padding=self.padding) # (B, C * p * p, L) + #print('in line 185 after unfoding', x.shape) + x = x.permute(0, 2, 1).contiguous() + ttt = self.prefc(x) + + x = self.prefc(x) + self.posemb[:x.shape[1]].unsqueeze(0) + return x + +class SpatialAttention(nn.Module): + def __init__(self, kernel_size=7): + super(SpatialAttention, self).__init__() + + self.conv1 = Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + avg_out = torch.mean(x, dim=1, keepdim=True) + max_out, _ = torch.max(x, dim=1, keepdim=True) + x = torch.cat([avg_out, max_out], dim=1) + x = self.conv1(x) + return self.sigmoid(x) + +class TimestepBlock_res(nn.Module): + def __init__(self, c, c_timestep, conds=['sca']): + super().__init__() + + self.mapper = Linear(c_timestep, c * 2) + self.conds = conds + for cname in conds: + setattr(self, f"mapper_{cname}", Linear(c_timestep, c * 2)) + + + + + def forward(self, x, t): + #print(x.shape, t.shape, self.conds, 'in line 269') + t = t.chunk(len(self.conds) + 1, dim=1) + a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1) + + for i, c in enumerate(self.conds): + ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1) + a, b = a + ac, b + bc + return x * (1 + a) + b + +def zero_module(module): + """ + Zero out the parameters of a module and return it. 
+ """ + for p in module.parameters(): + p.detach().zero_() + return module + + + +class ScaleNormalize_res(nn.Module): + def __init__(self, c, scale_c, conds=['sca']): + super().__init__() + self.c_r = scale_c + self.mapping = TimestepBlock_res(c, scale_c, conds=conds) + self.t_conds = conds + self.alpha = nn.Conv2d(c, c, kernel_size=1) + self.gamma = nn.Conv2d(c, c, kernel_size=1) + self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) + + + def gen_r_embedding(self, r, max_positions=10000): + r = r * max_positions + half_dim = self.c_r // 2 + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + if self.c_r % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode='constant') + return emb + def forward(self, x, std_size=24*24): + scale_val = math.sqrt(math.log(x.shape[-2] * x.shape[-1], std_size)) + scale_val = torch.ones(x.shape[0]).to(x.device)*scale_val + scale_val_f = self.gen_r_embedding(scale_val) + for c in self.t_conds: + t_cond = torch.zeros_like(scale_val) + scale_val_f = torch.cat([scale_val_f, self.gen_r_embedding(t_cond)], dim=1) + + f = self.mapping(x, scale_val_f) + + return f + x + + +class TransInr_withnorm(nn.Module): + + def __init__(self, ind=2048, ch=16, n_head=12, head_dim=64, n_groups=64, f_dim=768, time_dim=2048, t_conds=[]): + super().__init__() + self.input_layer= nn.Conv2d(ind, ch, 1) + self.tokenizer = ImgrecTokenizer(dim=ch, img_channels=ch) + #self.hyponet = HypoMlp(depth=12, in_dim=2, out_dim=ch, hidden_dim=f_dim, use_pe=True, pe_dim=128) + #self.transformer_encoder = TransformerEncoder(dim=f_dim, depth=12, n_head=n_head, head_dim=f_dim // n_head, ff_dim=3*f_dim, ) + + self.hyponet = HypoMlp(depth=2, in_dim=2, out_dim=ch, hidden_dim=f_dim, use_pe=True, pe_dim=128) + self.transformer_encoder = TransformerEncoder(dim=f_dim, depth=1, n_head=n_head, head_dim=f_dim // n_head, ff_dim=f_dim) + #self.transformer_encoder = TransInr( ch=ch, n_head=16, head_dim=16, n_groups=64, f_dim=ch, time_dim=time_dim, t_conds = []) + self.base_params = nn.ParameterDict() + n_wtokens = 0 + self.wtoken_postfc = nn.ModuleDict() + self.wtoken_rng = dict() + for name, shape in self.hyponet.param_shapes.items(): + self.base_params[name] = nn.Parameter(init_wb(shape)) + g = min(n_groups, shape[1]) + assert shape[1] % g == 0 + self.wtoken_postfc[name] = nn.Sequential( + nn.LayerNorm(f_dim), + nn.Linear(f_dim, shape[0] - 1), + ) + self.wtoken_rng[name] = (n_wtokens, n_wtokens + g) + n_wtokens += g + self.wtokens = nn.Parameter(torch.randn(n_wtokens, f_dim)) + self.output_layer= nn.Conv2d(ch, ind, 1) + + + self.mapp_t = TimestepBlock_res( ind, time_dim, conds = t_conds) + + + self.hr_norm = ScaleNormalize_res(ind, 64, conds=[]) + + self.normalize_final = nn.Sequential( + LayerNorm2d(ind, elementwise_affine=False, eps=1e-6), + ) + + self.toout = nn.Sequential( + Linear( ind*2, ind // 4), + nn.GELU(), + Linear( ind // 4, ind) + ) + self.apply(self._init_weights) + + mask = torch.zeros((1, 1, 32, 32)) + h, w = 32, 32 + center_h, center_w = h // 2, w // 2 + low_freq_h, low_freq_w = h // 4, w // 4 + mask[:, :, center_h-low_freq_h:center_h+low_freq_h, center_w-low_freq_w:center_w+low_freq_w] = 1 + self.mask = mask + + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + torch.nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + #nn.init.constant_(self.last.weight, 0) + 
def adain(self, feature_a, feature_b): + norm_mean = torch.mean(feature_a, dim=(2, 3), keepdim=True) + norm_std = torch.std(feature_a, dim=(2, 3), keepdim=True) + #feature_a = F.interpolate(feature_a, feature_b.shape[2:]) + feature_b = (feature_b - feature_b.mean(dim=(2, 3), keepdim=True)) / (1e-8 + feature_b.std(dim=(2, 3), keepdim=True)) * norm_std + norm_mean + return feature_b + def forward(self, target_shape, target, dtokens, t_emb): + #print(target.shape, dtokens.shape, 'in line 290') + hlr, wlr = dtokens.shape[2:] + original = dtokens + + dtokens = self.input_layer(dtokens) + dtokens = self.tokenizer(dtokens) + B = dtokens.shape[0] + wtokens = einops.repeat(self.wtokens, 'n d -> b n d', b=B) + #print(wtokens.shape, dtokens.shape) + trans_out = self.transformer_encoder(torch.cat([dtokens, wtokens], dim=1)) + trans_out = trans_out[:, -len(self.wtokens):, :] + + params = dict() + for name, shape in self.hyponet.param_shapes.items(): + wb = einops.repeat(self.base_params[name], 'n m -> b n m', b=B) + w, b = wb[:, :-1, :], wb[:, -1:, :] + + l, r = self.wtoken_rng[name] + x = self.wtoken_postfc[name](trans_out[:, l: r, :]) + x = x.transpose(-1, -2) # (B, shape[0] - 1, g) + w = F.normalize(w * x.repeat(1, 1, w.shape[2] // x.shape[2]), dim=1) + + wb = torch.cat([w, b], dim=1) + params[name] = wb + coord = make_coord_grid(target_shape[2:], (-1, 1), device=dtokens.device) + coord = einops.repeat(coord, 'h w d -> b h w d', b=dtokens.shape[0]) + self.hyponet.set_params(params) + ori_up = F.interpolate(original.float(), target_shape[2:]) + hr_rec = self.output_layer(rearrange(self.hyponet(coord), 'b h w c -> b c h w')) + ori_up + #print(hr_rec.shape, target.shape, torch.cat((hr_rec, target), dim=1).permute(0, 2, 3, 1).shape, 'in line 537') + + output = self.toout(torch.cat((hr_rec, target), dim=1).permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + #print(output.shape, 'in line 540') + #output = self.last(output.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)* 0.3 + output = self.mapp_t(output, t_emb) + output = self.normalize_final(output) + output = self.hr_norm(output) + #output = self.last(output.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + #output = self.mapp_t(output, t_emb) + #output = self.weight(output) * output + + return output + + + + + + +class LayerNorm2d(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x): + return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + +class GlobalResponseNorm(nn.Module): + "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105" + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x + + + +if __name__ == '__main__': + #ef __init__(self, ch, n_head, head_dim, n_groups): + trans_inr = TransInr(16, 24, 32, 64).cuda() + input = torch.randn((1, 16, 24, 24)).cuda() + source = torch.randn((1, 16, 16, 16)).cuda() + t = torch.randn((1, 128)).cuda() + output, hr = trans_inr(input, t, source) + + total_up = sum([ param.nelement() for param in trans_inr.parameters()]) + print(output.shape, hr.shape, total_up /1e6 ) + diff --git a/modules/lora.py b/modules/lora.py new file mode 100644 index 0000000000000000000000000000000000000000..bc0a2bd797f3669a465f6c2c4255b52fe1bda7a7 
--- /dev/null +++ b/modules/lora.py @@ -0,0 +1,71 @@ +import torch +from torch import nn + + +class LoRA(nn.Module): + def __init__(self, layer, name='weight', rank=16, alpha=1): + super().__init__() + weight = getattr(layer, name) + self.lora_down = nn.Parameter(torch.zeros((rank, weight.size(1)))) + self.lora_up = nn.Parameter(torch.zeros((weight.size(0), rank))) + nn.init.normal_(self.lora_up, mean=0, std=1) + + self.scale = alpha / rank + self.enabled = True + + def forward(self, original_weights): + if self.enabled: + lora_shape = list(original_weights.shape[:2]) + [1] * (len(original_weights.shape) - 2) + lora_weights = torch.matmul(self.lora_up.clone(), self.lora_down.clone()).view(*lora_shape) * self.scale + return original_weights + lora_weights + else: + return original_weights + + +def apply_lora(model, filters=None, rank=16): + def check_parameter(module, name): + return hasattr(module, name) and not torch.nn.utils.parametrize.is_parametrized(module, name) and isinstance( + getattr(module, name), nn.Parameter) + + for name, module in model.named_modules(): + if filters is None or any([f in name for f in filters]): + if check_parameter(module, "weight"): + device, dtype = module.weight.device, module.weight.dtype + torch.nn.utils.parametrize.register_parametrization(module, 'weight', LoRA(module, "weight", rank=rank).to(dtype).to(device)) + elif check_parameter(module, "in_proj_weight"): + device, dtype = module.in_proj_weight.device, module.in_proj_weight.dtype + torch.nn.utils.parametrize.register_parametrization(module, 'in_proj_weight', LoRA(module, "in_proj_weight", rank=rank).to(dtype).to(device)) + + +class ReToken(nn.Module): + def __init__(self, indices=None): + super().__init__() + assert indices is not None + self.embeddings = nn.Parameter(torch.zeros(len(indices), 1280)) + self.register_buffer('indices', torch.tensor(indices)) + self.enabled = True + + def forward(self, embeddings): + if self.enabled: + embeddings = embeddings.clone() + for i, idx in enumerate(self.indices): + embeddings[idx] += self.embeddings[i] + return embeddings + + +def apply_retoken(module, indices=None): + def check_parameter(module, name): + return hasattr(module, name) and not torch.nn.utils.parametrize.is_parametrized(module, name) and isinstance( + getattr(module, name), nn.Parameter) + + if check_parameter(module, "weight"): + device, dtype = module.weight.device, module.weight.dtype + torch.nn.utils.parametrize.register_parametrization(module, 'weight', ReToken(indices=indices).to(dtype).to(device)) + + +def remove_lora(model, leave_parametrized=True): + for module in model.modules(): + if torch.nn.utils.parametrize.is_parametrized(module, "weight"): + nn.utils.parametrize.remove_parametrizations(module, "weight", leave_parametrized=leave_parametrized) + elif torch.nn.utils.parametrize.is_parametrized(module, "in_proj_weight"): + nn.utils.parametrize.remove_parametrizations(module, "in_proj_weight", leave_parametrized=leave_parametrized) diff --git a/modules/model_4stage_lite.py b/modules/model_4stage_lite.py new file mode 100644 index 0000000000000000000000000000000000000000..e77cc5d73ccda882774f447f5a8bb86fe71fe755 --- /dev/null +++ b/modules/model_4stage_lite.py @@ -0,0 +1,458 @@ +import torch +from torch import nn +import numpy as np +import math +from modules.common_ckpt import AttnBlock, LayerNorm2d, ResBlock, FeedForwardBlock, TimestepBlock +from .controlnet import ControlNetDeliverer +import torch.nn.functional as F +from modules.inr_fea_res_lite import TransInr_withnorm as 
TransInr +from modules.inr_fea_res_lite import ScaleNormalize_res +from einops import rearrange +import torch.fft as fft +import random +class UpDownBlock2d(nn.Module): + def __init__(self, c_in, c_out, mode, enabled=True): + super().__init__() + assert mode in ['up', 'down'] + interpolation = nn.Upsample(scale_factor=2 if mode == 'up' else 0.5, mode='bilinear', + align_corners=True) if enabled else nn.Identity() + mapping = nn.Conv2d(c_in, c_out, kernel_size=1) + self.blocks = nn.ModuleList([interpolation, mapping] if mode == 'up' else [mapping, interpolation]) + + def forward(self, x): + for block in self.blocks: + x = block(x.float()) + return x +def ada_in(a, b): + mean_a = torch.mean(a, dim=(2, 3), keepdim=True) + std_a = torch.std(a, dim=(2, 3), keepdim=True) + + mean_b = torch.mean(b, dim=(2, 3), keepdim=True) + std_b = torch.std(b, dim=(2, 3), keepdim=True) + + return (b - mean_b) / (1e-8 + std_b) * std_a + mean_a +def feature_dist_loss(x1, x2): + mu1 = torch.mean(x1, dim=(2, 3)) + mu2 = torch.mean(x2, dim=(2, 3)) + + std1 = torch.std(x1, dim=(2, 3)) + std2 = torch.std(x2, dim=(2, 3)) + std_loss = torch.mean(torch.abs(torch.log(std1+ 1e-8) - torch.log(std2+ 1e-8))) + mean_loss = torch.mean(torch.abs(mu1 - mu2)) + #print('in line 36', std_loss, mean_loss) + return std_loss + mean_loss*0.1 +class StageC(nn.Module): + def __init__(self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32], + blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=['CTA', 'CTA'], + c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3, + dropout=[0.1, 0.1], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False], + lr_h=24, lr_w=24): + super().__init__() + + self.lr_h, self.lr_w = lr_h, lr_w + self.block_repeat = block_repeat + self.c_in = c_in + self.c_cond = c_cond + self.patch_size = patch_size + self.c_hidden = c_hidden + self.nhead = nhead + self.blocks = blocks + self.level_config = level_config + self.kernel_size = kernel_size + self.c_r = c_r + self.t_conds = t_conds + self.c_clip_seq = c_clip_seq + if not isinstance(dropout, list): + dropout = [dropout] * len(c_hidden) + if not isinstance(self_attn, list): + self_attn = [self_attn] * len(c_hidden) + self.self_attn = self_attn + self.dropout = dropout + self.switch_level = switch_level + # CONDITIONING + self.clip_txt_mapper = nn.Linear(c_clip_text, c_cond) + self.clip_txt_pooled_mapper = nn.Linear(c_clip_text_pooled, c_cond * c_clip_seq) + self.clip_img_mapper = nn.Linear(c_clip_img, c_cond * c_clip_seq) + self.clip_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6) + + self.embedding = nn.Sequential( + nn.PixelUnshuffle(patch_size), + nn.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1), + LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6) + ) + + def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True): + if block_type == 'C': + return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout) + elif block_type == 'A': + return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout) + elif block_type == 'F': + return FeedForwardBlock(c_hidden, dropout=dropout) + elif block_type == 'T': + return TimestepBlock(c_hidden, c_r, conds=t_conds) + else: + raise Exception(f'Block type {block_type} not supported') + + # BLOCKS + # -- down blocks + self.down_blocks = nn.ModuleList() + self.down_downscalers = nn.ModuleList() + self.down_repeat_mappers = nn.ModuleList() + for i in 
range(len(c_hidden)): + if i > 0: + self.down_downscalers.append(nn.Sequential( + LayerNorm2d(c_hidden[i - 1], elementwise_affine=False, eps=1e-6), + UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode='down', enabled=switch_level[i - 1]) + )) + else: + self.down_downscalers.append(nn.Identity()) + down_block = nn.ModuleList() + for _ in range(blocks[0][i]): + for block_type in level_config[i]: + block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i]) + down_block.append(block) + self.down_blocks.append(down_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[0][i] - 1): + block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) + self.down_repeat_mappers.append(block_repeat_mappers) + + + + #extra down blocks + + + # -- up blocks + self.up_blocks = nn.ModuleList() + self.up_upscalers = nn.ModuleList() + self.up_repeat_mappers = nn.ModuleList() + for i in reversed(range(len(c_hidden))): + if i > 0: + self.up_upscalers.append(nn.Sequential( + LayerNorm2d(c_hidden[i], elementwise_affine=False, eps=1e-6), + UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode='up', enabled=switch_level[i - 1]) + )) + else: + self.up_upscalers.append(nn.Identity()) + up_block = nn.ModuleList() + for j in range(blocks[1][::-1][i]): + for k, block_type in enumerate(level_config[i]): + c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0 + block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], + self_attn=self_attn[i]) + up_block.append(block) + self.up_blocks.append(up_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[1][::-1][i] - 1): + block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) + self.up_repeat_mappers.append(block_repeat_mappers) + + # OUTPUT + self.clf = nn.Sequential( + LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6), + nn.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1), + nn.PixelShuffle(patch_size), + ) + + # --- WEIGHT INIT --- + self.apply(self._init_weights) # General init + nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) # conditionings + nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02) # conditionings + nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings + torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs + nn.init.constant_(self.clf[1].weight, 0) # outputs + + # blocks + for level_block in self.down_blocks + self.up_blocks: + for block in level_block: + if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock): + block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0])) + elif isinstance(block, TimestepBlock): + for layer in block.modules(): + if isinstance(layer, nn.Linear): + nn.init.constant_(layer.weight, 0) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + torch.nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + + def _init_extra_parameter(self): + + + + self.agg_net = nn.ModuleList() + for _ in range(2): + + self.agg_net.append(TransInr(ind=2048, ch=1024, n_head=32, head_dim=32, n_groups=64, f_dim=1024, time_dim=self.c_r, t_conds = [])) # + + self.agg_net_up = nn.ModuleList() + for _ in range(2): + + self.agg_net_up.append(TransInr(ind=2048, ch=1024, n_head=32, head_dim=32, n_groups=64, f_dim=1024, time_dim=self.c_r, t_conds = [])) # + + + + + + 
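+        # Resolution-aware normalization: one ScaleNormalize_res module is attached
+        # for every fourth block in both the down (encoder) and up (decoder) paths,
+        # re-modulating features according to the current latent resolution.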
self.norm_down_blocks = nn.ModuleList() + for i in range(len(self.c_hidden)): + + up_blocks = nn.ModuleList() + for j in range(self.blocks[0][i]): + if j % 4 == 0: + up_blocks.append( + ScaleNormalize_res(self.c_hidden[0], self.c_r, conds=[])) + self.norm_down_blocks.append(up_blocks) + + + self.norm_up_blocks = nn.ModuleList() + for i in reversed(range(len(self.c_hidden))): + + up_block = nn.ModuleList() + for j in range(self.blocks[1][::-1][i]): + if j % 4 == 0: + up_block.append(ScaleNormalize_res(self.c_hidden[0], self.c_r, conds=[])) + self.norm_up_blocks.append(up_block) + + + + + self.agg_net.apply(self._init_weights) + self.agg_net_up.apply(self._init_weights) + self.norm_up_blocks.apply(self._init_weights) + self.norm_down_blocks.apply(self._init_weights) + for block in self.agg_net + self.agg_net_up: + #for block in level_block: + if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock): + block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0])) + elif isinstance(block, TimestepBlock): + for layer in block.modules(): + if isinstance(layer, nn.Linear): + nn.init.constant_(layer.weight, 0) + + + + + + def gen_r_embedding(self, r, max_positions=10000): + r = r * max_positions + half_dim = self.c_r // 2 + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + if self.c_r % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode='constant') + return emb + + def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img): + clip_txt = self.clip_txt_mapper(clip_txt) + if len(clip_txt_pooled.shape) == 2: + clip_txt_pool = clip_txt_pooled.unsqueeze(1) + if len(clip_img.shape) == 2: + clip_img = clip_img.unsqueeze(1) + clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1) + clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1) + clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1) + clip = self.clip_norm(clip) + return clip + + def _down_encode(self, x, r_embed, clip, cnet=None, require_q=False, lr_guide=None, r_emb_lite=None, guide_weight=1): + level_outputs = [] + if require_q: + qs = [] + block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) + for stage_cnt, (down_block, downscaler, repmap) in enumerate(block_group): + x = downscaler(x) + for i in range(len(repmap) + 1): + for inner_cnt, block in enumerate(down_block): + + + if isinstance(block, ResBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + ResBlock)): + if cnet is not None and lr_guide is None: + #if cnet is not None : + next_cnet = cnet() + if next_cnet is not None: + + x = x + nn.functional.interpolate(next_cnet.float(), size=x.shape[-2:], mode='bilinear', + align_corners=True) + x = block(x) + elif isinstance(block, AttnBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + AttnBlock)): + + x = block(x, clip) + if require_q and (inner_cnt == 2 ): + qs.append(x.clone()) + if lr_guide is not None and (inner_cnt == 2 ) : + + guide = self.agg_net[stage_cnt](x.shape, x, lr_guide[stage_cnt], r_emb_lite) + x = x + guide + + elif isinstance(block, TimestepBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + TimestepBlock)): + x = block(x, r_embed) + else: 
+ x = block(x) + if i < len(repmap): + x = repmap[i](x) + level_outputs.insert(0, x) # 0 indicate last output + if require_q: + return level_outputs, qs + return level_outputs + + + def _up_decode(self, level_outputs, r_embed, clip, cnet=None, require_ff=False, agg_f=None, r_emb_lite=None, guide_weight=1): + if require_ff: + agg_feas = [] + x = level_outputs[0] + block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) + for i, (up_block, upscaler, repmap) in enumerate(block_group): + for j in range(len(repmap) + 1): + for k, block in enumerate(up_block): + + if isinstance(block, ResBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + ResBlock)): + skip = level_outputs[i] if k == 0 and i > 0 else None + + + if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): + x = torch.nn.functional.interpolate(x.float(), skip.shape[-2:], mode='bilinear', + align_corners=True) + + if cnet is not None and agg_f is None: + next_cnet = cnet() + if next_cnet is not None: + + x = x + nn.functional.interpolate(next_cnet.float(), size=x.shape[-2:], mode='bilinear', + align_corners=True) + + + x = block(x, skip) + elif isinstance(block, AttnBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + AttnBlock)): + + + x = block(x, clip) + if require_ff and (k == 2 ): + agg_feas.append(x.clone()) + if agg_f is not None and (k == 2 ) : + + guide = self.agg_net_up[i](x.shape, x, agg_f[i], r_emb_lite) # training 1 test 4k 0.8 2k 0.7 + if not self.training: + hw = x.shape[-2] * x.shape[-1] + if hw >= 96*96: + guide = 0.7*guide + + else: + + if hw >= 72*72: + guide = 0.5* guide + else: + + guide = 0.3* guide + + x = x + guide + + + elif isinstance(block, TimestepBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + TimestepBlock)): + x = block(x, r_embed) + #if require_ff: + # agg_feas.append(x.clone()) + else: + x = block(x) + if j < len(repmap): + x = repmap[j](x) + x = upscaler(x) + + + if require_ff: + return x, agg_feas + + return x + + + + + def forward(self, x, r, clip_text, clip_text_pooled, clip_img, lr_guide=None, reuire_f=False, cnet=None, require_t=False, guide_weight=0.5, **kwargs): + + r_embed = self.gen_r_embedding(r) + + for c in self.t_conds: + t_cond = kwargs.get(c, torch.zeros_like(r)) + r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond)], dim=1) + clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img) + + # Model Blocks + + x = self.embedding(x) + + + + if cnet is not None: + cnet = ControlNetDeliverer(cnet) + + if not reuire_f: + level_outputs = self._down_encode(x, r_embed, clip, cnet, lr_guide= lr_guide[0] if lr_guide is not None else None, \ + require_q=reuire_f, r_emb_lite=self.gen_r_embedding(r), guide_weight=guide_weight) + x = self._up_decode(level_outputs, r_embed, clip, cnet, agg_f=lr_guide[1] if lr_guide is not None else None, \ + require_ff=reuire_f, r_emb_lite=self.gen_r_embedding(r), guide_weight=guide_weight) + else: + level_outputs, lr_enc = self._down_encode(x, r_embed, clip, cnet, lr_guide= lr_guide[0] if lr_guide is not None else None, require_q=True) + x, lr_dec = self._up_decode(level_outputs, r_embed, clip, cnet, agg_f=lr_guide[1] if lr_guide is not None else None, require_ff=True) + + if reuire_f and require_t: + return self.clf(x), r_embed, lr_enc, lr_dec + if reuire_f: + return self.clf(x), lr_enc, lr_dec + if require_t: + return self.clf(x), r_embed + return self.clf(x) + + + 
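+    # Exponential-moving-average update for an EMA copy of the generator: every
+    # parameter and buffer is blended as  ema = beta * ema + (1 - beta) * src.
+    # Minimal usage sketch (hypothetical names, assuming `ema_model` and `model`
+    # are StageC instances with identical architectures):
+    #
+    #   with torch.no_grad():
+    #       ema_model.update_weights_ema(model, beta=0.999)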
def update_weights_ema(self, src_model, beta=0.999): + for self_params, src_params in zip(self.parameters(), src_model.parameters()): + self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta) + for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()): + self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta) + + + +if __name__ == '__main__': + generator = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]]) + total_ori = sum([ param.nelement() for param in generator.parameters()]) + generator._init_extra_parameter() + generator = generator.cuda() + total = sum([ param.nelement() for param in generator.parameters()]) + total_down = sum([ param.nelement() for param in generator.down_blocks.parameters()]) + + total_up = sum([ param.nelement() for param in generator.up_blocks.parameters()]) + total_pro = sum([ param.nelement() for param in generator.project.parameters()]) + + + print(total_ori / 1e6, total / 1e6, total_up / 1e6, total_down / 1e6, total_pro / 1e6) + + # for name, module in generator.down_blocks.named_modules(): + # print(name, module) + output, out_lr = generator( + x=torch.randn(1, 16, 24, 24).cuda(), + x_lr=torch.randn(1, 16, 16, 16).cuda(), + r=torch.tensor([0.7056]).cuda(), + clip_text=torch.randn(1, 77, 1280).cuda(), + clip_text_pooled = torch.randn(1, 1, 1280).cuda(), + clip_img = torch.randn(1, 1, 768).cuda() + ) + print(output.shape, out_lr.shape) + # cnt diff --git a/modules/previewer.py b/modules/previewer.py new file mode 100644 index 0000000000000000000000000000000000000000..51ab24292d8ac0da8d24b17d8fc0ac9e1419a3d7 --- /dev/null +++ b/modules/previewer.py @@ -0,0 +1,45 @@ +from torch import nn + + +# Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192 +class Previewer(nn.Module): + def __init__(self, c_in=16, c_hidden=512, c_out=3): + super().__init__() + self.blocks = nn.Sequential( + nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels + nn.GELU(), + nn.BatchNorm2d(c_hidden), + + nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1), + nn.GELU(), + nn.BatchNorm2d(c_hidden), + + nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32 + nn.GELU(), + nn.BatchNorm2d(c_hidden // 2), + + nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1), + nn.GELU(), + nn.BatchNorm2d(c_hidden // 2), + + nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64 + nn.GELU(), + nn.BatchNorm2d(c_hidden // 4), + + nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), + nn.GELU(), + nn.BatchNorm2d(c_hidden // 4), + + nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128 + nn.GELU(), + nn.BatchNorm2d(c_hidden // 4), + + nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), + nn.GELU(), + nn.BatchNorm2d(c_hidden // 4), + + nn.Conv2d(c_hidden // 4, c_out, kernel_size=1), + ) + + def forward(self, x): + return self.blocks(x) diff --git a/modules/resnet.py b/modules/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..460a808942be147d76b8b1f3baf29fec1e2a7b8d --- /dev/null +++ b/modules/resnet.py @@ -0,0 +1,415 @@ +import torch +from torch import nn +import torch.nn.functional as F +#import fvcore.nn.weight_init as weight_init + +""" +Functions for building the BottleneckBlock from Detectron2. 
+# https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/resnet.py +""" + +def get_norm(norm, out_channels, num_norm_groups=32): + """ + Args: + norm (str or callable): either one of BN, SyncBN, FrozenBN, GN; + or a callable that takes a channel number and returns + the normalization layer as a nn.Module. + Returns: + nn.Module or None: the normalization layer + """ + if norm is None: + return None + if isinstance(norm, str): + if len(norm) == 0: + return None + norm = { + "GN": lambda channels: nn.GroupNorm(num_norm_groups, channels), + }[norm] + return norm(out_channels) + +class Conv2d(nn.Conv2d): + """ + A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features. + """ + + def __init__(self, *args, **kwargs): + """ + Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`: + Args: + norm (nn.Module, optional): a normalization layer + activation (callable(Tensor) -> Tensor): a callable activation function + It assumes that norm layer is used before activation. + """ + norm = kwargs.pop("norm", None) + activation = kwargs.pop("activation", None) + super().__init__(*args, **kwargs) + + self.norm = norm + self.activation = activation + + def forward(self, x): + x = F.conv2d( + x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups + ) + if self.norm is not None: + x = self.norm(x) + if self.activation is not None: + x = self.activation(x) + return x + +class CNNBlockBase(nn.Module): + """ + A CNN block is assumed to have input channels, output channels and a stride. + The input and output of `forward()` method must be NCHW tensors. + The method can perform arbitrary computation but must match the given + channels and stride specification. + Attribute: + in_channels (int): + out_channels (int): + stride (int): + """ + + def __init__(self, in_channels, out_channels, stride): + """ + The `__init__` method of any subclass should also contain these arguments. + Args: + in_channels (int): + out_channels (int): + stride (int): + """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + +class BottleneckBlock(CNNBlockBase): + """ + The standard bottleneck residual block used by ResNet-50, 101 and 152 + defined in :paper:`ResNet`. It contains 3 conv layers with kernels + 1x1, 3x3, 1x1, and a projection shortcut if needed. + """ + + def __init__( + self, + in_channels, + out_channels, + *, + bottleneck_channels, + stride=1, + num_groups=1, + norm="GN", + stride_in_1x1=False, + dilation=1, + num_norm_groups=32 + ): + """ + Args: + bottleneck_channels (int): number of output channels for the 3x3 + "bottleneck" conv layers. + num_groups (int): number of groups for the 3x3 conv layer. + norm (str or callable): normalization for all conv layers. + See :func:`layers.get_norm` for supported format. + stride_in_1x1 (bool): when stride>1, whether to put stride in the + first 1x1 convolution or the bottleneck 3x3 convolution. + dilation (int): the dilation rate of the 3x3 conv layer. 
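+            num_norm_groups (int): number of groups used for GroupNorm when
+                ``norm == "GN"`` (forwarded to :func:`get_norm`). Defaults to 32.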
+ """ + super().__init__(in_channels, out_channels, stride) + + if in_channels != out_channels: + self.shortcut = Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=stride, + bias=False, + norm=get_norm(norm, out_channels, num_norm_groups), + ) + else: + self.shortcut = None + + # The original MSRA ResNet models have stride in the first 1x1 conv + # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have + # stride in the 3x3 conv + stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride) + + self.conv1 = Conv2d( + in_channels, + bottleneck_channels, + kernel_size=1, + stride=stride_1x1, + bias=False, + norm=get_norm(norm, bottleneck_channels, num_norm_groups), + ) + + self.conv2 = Conv2d( + bottleneck_channels, + bottleneck_channels, + kernel_size=3, + stride=stride_3x3, + padding=1 * dilation, + bias=False, + groups=num_groups, + dilation=dilation, + norm=get_norm(norm, bottleneck_channels, num_norm_groups), + ) + + self.conv3 = Conv2d( + bottleneck_channels, + out_channels, + kernel_size=1, + bias=False, + norm=get_norm(norm, out_channels, num_norm_groups), + ) + + #for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]: + # if layer is not None: # shortcut can be None + # weight_init.c2_msra_fill(layer) + + # Zero-initialize the last normalization in each residual branch, + # so that at the beginning, the residual branch starts with zeros, + # and each residual block behaves like an identity. + # See Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour": + # "For BN layers, the learnable scaling coefficient ¦Ã is initialized + # to be 1, except for each residual block's last BN + # where ¦Ã is initialized to be 0." + + # nn.init.constant_(self.conv3.norm.weight, 0) + # TODO this somehow hurts performance when training GN models from scratch. + # Add it as an option when we need to use this code to train a backbone. + + def forward(self, x): + out = self.conv1(x) + out = F.relu_(out) + + out = self.conv2(out) + out = F.relu_(out) + + out = self.conv3(out) + + if self.shortcut is not None: + shortcut = self.shortcut(x) + else: + shortcut = x + + out += shortcut + out = F.relu_(out) + return out + +class ResNet(nn.Module): + """ + Implement :paper:`ResNet`. + """ + + def __init__(self, stem, stages, num_classes=None, out_features=None, freeze_at=0): + """ + Args: + stem (nn.Module): a stem module + stages (list[list[CNNBlockBase]]): several (typically 4) stages, + each contains multiple :class:`CNNBlockBase`. + num_classes (None or int): if None, will not perform classification. + Otherwise, will create a linear layer. + out_features (list[str]): name of the layers whose outputs should + be returned in forward. Can be anything in "stem", "linear", or "res2" ... + If None, will return the output of the last layer. + freeze_at (int): The number of stages at the beginning to freeze. + see :meth:`freeze` for detailed explanation. + """ + super().__init__() + self.stem = stem + self.num_classes = num_classes + + current_stride = self.stem.stride + self._out_feature_strides = {"stem": current_stride} + self._out_feature_channels = {"stem": self.stem.out_channels} + + self.stage_names, self.stages = [], [] + + if out_features is not None: + # Avoid keeping unused layers in this module. 
They consume extra memory + # and may cause allreduce to fail + num_stages = max( + [{"res2": 1, "res3": 2, "res4": 3, "res5": 4}.get(f, 0) for f in out_features] + ) + stages = stages[:num_stages] + for i, blocks in enumerate(stages): + assert len(blocks) > 0, len(blocks) + for block in blocks: + assert isinstance(block, CNNBlockBase), block + + name = "res" + str(i + 2) + stage = nn.Sequential(*blocks) + + self.add_module(name, stage) + self.stage_names.append(name) + self.stages.append(stage) + + self._out_feature_strides[name] = current_stride = int( + current_stride * np.prod([k.stride for k in blocks]) + ) + self._out_feature_channels[name] = curr_channels = blocks[-1].out_channels + self.stage_names = tuple(self.stage_names) # Make it static for scripting + + if num_classes is not None: + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.linear = nn.Linear(curr_channels, num_classes) + + # Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour": + # "The 1000-way fully-connected layer is initialized by + # drawing weights from a zero-mean Gaussian with standard deviation of 0.01." + nn.init.normal_(self.linear.weight, std=0.01) + name = "linear" + + if out_features is None: + out_features = [name] + self._out_features = out_features + assert len(self._out_features) + children = [x[0] for x in self.named_children()] + for out_feature in self._out_features: + assert out_feature in children, "Available children: {}".format(", ".join(children)) + self.freeze(freeze_at) + + def forward(self, x): + """ + Args: + x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``. + Returns: + dict[str->Tensor]: names and the corresponding features + """ + assert x.dim() == 4, f"ResNet takes an input of shape (N, C, H, W). Got {x.shape} instead!" + outputs = {} + x = self.stem(x) + if "stem" in self._out_features: + outputs["stem"] = x + for name, stage in zip(self.stage_names, self.stages): + x = stage(x) + if name in self._out_features: + outputs[name] = x + if self.num_classes is not None: + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.linear(x) + if "linear" in self._out_features: + outputs["linear"] = x + return outputs + + def freeze(self, freeze_at=0): + """ + Freeze the first several stages of the ResNet. Commonly used in + fine-tuning. + Layers that produce the same feature map spatial size are defined as one + "stage" by :paper:`FPN`. + Args: + freeze_at (int): number of stages to freeze. + `1` means freezing the stem. `2` means freezing the stem and + one residual stage, etc. + Returns: + nn.Module: this ResNet itself + """ + if freeze_at >= 1: + self.stem.freeze() + for idx, stage in enumerate(self.stages, start=2): + if freeze_at >= idx: + for block in stage.children(): + block.freeze() + return self + + @staticmethod + def make_stage(block_class, num_blocks, *, in_channels, out_channels, **kwargs): + """ + Create a list of blocks of the same type that forms one ResNet stage. + Args: + block_class (type): a subclass of CNNBlockBase that's used to create all blocks in this + stage. A module of this type must not change spatial resolution of inputs unless its + stride != 1. + num_blocks (int): number of blocks in this stage + in_channels (int): input channels of the entire stage. + out_channels (int): output channels of **every block** in the stage. + kwargs: other arguments passed to the constructor of + `block_class`. 
If the argument name is "xx_per_block", the + argument is a list of values to be passed to each block in the + stage. Otherwise, the same argument is passed to every block + in the stage. + Returns: + list[CNNBlockBase]: a list of block module. + Examples: + :: + stage = ResNet.make_stage( + BottleneckBlock, 3, in_channels=16, out_channels=64, + bottleneck_channels=16, num_groups=1, + stride_per_block=[2, 1, 1], + dilations_per_block=[1, 1, 2] + ) + Usually, layers that produce the same feature map spatial size are defined as one + "stage" (in :paper:`FPN`). Under such definition, ``stride_per_block[1:]`` should + all be 1. + """ + blocks = [] + for i in range(num_blocks): + curr_kwargs = {} + for k, v in kwargs.items(): + if k.endswith("_per_block"): + assert len(v) == num_blocks, ( + f"Argument '{k}' of make_stage should have the " + f"same length as num_blocks={num_blocks}." + ) + newk = k[: -len("_per_block")] + assert newk not in kwargs, f"Cannot call make_stage with both {k} and {newk}!" + curr_kwargs[newk] = v[i] + else: + curr_kwargs[k] = v + + blocks.append( + block_class(in_channels=in_channels, out_channels=out_channels, **curr_kwargs) + ) + in_channels = out_channels + return blocks + + @staticmethod + def make_default_stages(depth, block_class=None, **kwargs): + """ + Created list of ResNet stages from pre-defined depth (one of 18, 34, 50, 101, 152). + If it doesn't create the ResNet variant you need, please use :meth:`make_stage` + instead for fine-grained customization. + Args: + depth (int): depth of ResNet + block_class (type): the CNN block class. Has to accept + `bottleneck_channels` argument for depth > 50. + By default it is BasicBlock or BottleneckBlock, based on the + depth. + kwargs: + other arguments to pass to `make_stage`. Should not contain + stride and channels, as they are predefined for each depth. + Returns: + list[list[CNNBlockBase]]: modules in all stages; see arguments of + :class:`ResNet.__init__`. 
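+        Examples:
+        ::
+            # Illustrative only: builds the four ResNet-50 stages with the
+            # default "GN" norm (the only norm type get_norm supports in this file).
+            stages = ResNet.make_default_stages(50)
+        Note: only :class:`BottleneckBlock` is defined in this trimmed copy, so
+        depths below 50 (which select ``BasicBlock``) are not usable as-is.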
+ """ + num_blocks_per_stage = { + 18: [2, 2, 2, 2], + 34: [3, 4, 6, 3], + 50: [3, 4, 6, 3], + 101: [3, 4, 23, 3], + 152: [3, 8, 36, 3], + }[depth] + if block_class is None: + block_class = BasicBlock if depth < 50 else BottleneckBlock + if depth < 50: + in_channels = [64, 64, 128, 256] + out_channels = [64, 128, 256, 512] + else: + in_channels = [64, 256, 512, 1024] + out_channels = [256, 512, 1024, 2048] + ret = [] + for (n, s, i, o) in zip(num_blocks_per_stage, [1, 2, 2, 2], in_channels, out_channels): + if depth >= 50: + kwargs["bottleneck_channels"] = o // 4 + ret.append( + ResNet.make_stage( + block_class=block_class, + num_blocks=n, + stride_per_block=[s] + [1] * (n - 1), + in_channels=i, + out_channels=o, + **kwargs, + ) + ) + return ret \ No newline at end of file diff --git a/modules/speed_util.py b/modules/speed_util.py new file mode 100644 index 0000000000000000000000000000000000000000..3b9507c74833bec270b00bd252a3c76fcc09fab3 --- /dev/null +++ b/modules/speed_util.py @@ -0,0 +1,55 @@ +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), + "dtype": torch.get_autocast_gpu_dtype(), + "cache_enabled": torch.is_autocast_cache_enabled()} + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(), \ + torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. 
+ """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) \ No newline at end of file diff --git a/modules/stage_a.py b/modules/stage_a.py new file mode 100644 index 0000000000000000000000000000000000000000..2840ef71d30e3da74954ab4a05e724fd7fef86cf --- /dev/null +++ b/modules/stage_a.py @@ -0,0 +1,183 @@ +import torch +from torch import nn +from torchtools.nn import VectorQuantize +from einops import rearrange +import torch.nn.functional as F +import math +class ResBlock(nn.Module): + def __init__(self, c, c_hidden): + super().__init__() + # depthwise/attention + self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) + self.depthwise = nn.Sequential( + nn.ReplicationPad2d(1), + nn.Conv2d(c, c, kernel_size=3, groups=c) + ) + + # channelwise + self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + nn.Linear(c, c_hidden), + nn.GELU(), + nn.Linear(c_hidden, c), + ) + + self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) + + # Init weights + def _basic_init(module): + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + def _norm(self, x, norm): + return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + def forward(self, x): + + mods = self.gammas + + x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1] + + #x = x.to(torch.float64) + x = x + self.depthwise(x_temp) * mods[2] + + x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4] + x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5] + + return x + + +def extract_patches(tensor, patch_size, stride): + b, c, H, W = tensor.shape + pad_h = (patch_size - (H - patch_size) % stride) % stride + pad_w = (patch_size - (W - patch_size) % stride) % stride + tensor = F.pad(tensor, (0, pad_w, 0, pad_h), mode='reflect') + + + patches = tensor.unfold(2, patch_size, stride).unfold(3, patch_size, stride) + patches = patches.contiguous().view(b, c, -1, patch_size, patch_size) + patches = patches.permute(0, 2, 1, 3, 4) + return patches, (H, W) + +def fuse_patches(patches, patch_size, stride, H, W): + + b, num_patches, c, _, _ = patches.shape + patches = patches.permute(0, 2, 1, 3, 4) + + + + pad_h = (patch_size - (H - patch_size) % stride) % stride + pad_w = (patch_size - (W - patch_size) % stride) % stride + out_h = H + pad_h + out_w = W + pad_w + patches = patches.contiguous().view(b, c , -1, patch_size*patch_size ).permute(0, 1, 3, 2) + patches = patches.contiguous().view(b, c*patch_size*patch_size, -1) + + tensor = F.fold(patches, output_size=(out_h, out_w), kernel_size=patch_size, stride=stride) + overlap_cnt = F.fold(torch.ones_like(patches), output_size=(out_h, out_w), kernel_size=patch_size, stride=stride) + tensor = tensor / overlap_cnt + print('end fuse patch', tensor.shape, (tensor.dtype)) + return tensor[:, :, :H, :W] + + + +class StageA(nn.Module): + def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192, + scale_factor=0.43): # 0.3764 + super().__init__() + self.c_latent = c_latent + self.scale_factor = scale_factor + c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))] + + # Encoder blocks + self.in_block = nn.Sequential( + nn.PixelUnshuffle(2), + nn.Conv2d(3 * 4, c_levels[0], kernel_size=1) + ) + down_blocks = [] + for i in 
range(levels): + if i > 0: + down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1)) + block = ResBlock(c_levels[i], c_levels[i] * 4) + down_blocks.append(block) + down_blocks.append(nn.Sequential( + nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False), + nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1 + )) + self.down_blocks = nn.Sequential(*down_blocks) + self.down_blocks[0] + + self.codebook_size = codebook_size + self.vquantizer = VectorQuantize(c_latent, k=codebook_size) + + # Decoder blocks + up_blocks = [nn.Sequential( + nn.Conv2d(c_latent, c_levels[-1], kernel_size=1) + )] + for i in range(levels): + for j in range(bottleneck_blocks if i == 0 else 1): + block = ResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4) + up_blocks.append(block) + if i < levels - 1: + up_blocks.append( + nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2, + padding=1)) + self.up_blocks = nn.Sequential(*up_blocks) + self.out_block = nn.Sequential( + nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1), + nn.PixelShuffle(2), + ) + + def encode(self, x, quantize=False): + x = self.in_block(x) + x = self.down_blocks(x) + if quantize: + qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1) + return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25 + else: + return x / self.scale_factor, None, None, None + + + + def decode(self, x, tiled_decoding=False): + x = x * self.scale_factor + x = self.up_blocks(x) + x = self.out_block(x) + return x + + def forward(self, x, quantize=False): + qe, x, _, vq_loss = self.encode(x, quantize) + x = self.decode(qe) + return x, vq_loss + + +class Discriminator(nn.Module): + def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6): + super().__init__() + d = max(depth - 3, 3) + layers = [ + nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)), + nn.LeakyReLU(0.2), + ] + for i in range(depth - 1): + c_in = c_hidden // (2 ** max((d - i), 0)) + c_out = c_hidden // (2 ** max((d - 1 - i), 0)) + layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))) + layers.append(nn.InstanceNorm2d(c_out)) + layers.append(nn.LeakyReLU(0.2)) + self.encoder = nn.Sequential(*layers) + self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1) + self.logits = nn.Sigmoid() + + def forward(self, x, cond=None): + x = self.encoder(x) + if cond is not None: + cond = cond.view(cond.size(0), cond.size(1), 1, 1, ).expand(-1, -1, x.size(-2), x.size(-1)) + x = torch.cat([x, cond], dim=1) + x = self.shuffle(x) + x = self.logits(x) + return x diff --git a/modules/stage_b.py b/modules/stage_b.py new file mode 100644 index 0000000000000000000000000000000000000000..f89b42d61327278820e164b1c093cbf8d1048ee1 --- /dev/null +++ b/modules/stage_b.py @@ -0,0 +1,239 @@ +import math +import numpy as np +import torch +from torch import nn +from .common import AttnBlock, LayerNorm2d, ResBlock, FeedForwardBlock, TimestepBlock + + +class StageB(nn.Module): + def __init__(self, c_in=4, c_out=4, c_r=64, patch_size=2, c_cond=1280, c_hidden=[320, 640, 1280, 1280], + nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]], + block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]], level_config=['CT', 'CT', 'CTA', 'CTA'], c_clip=1280, + c_clip_seq=4, c_effnet=16, c_pixels=3, kernel_size=3, dropout=[0, 0, 0.1, 0.1], self_attn=True, + t_conds=['sca']): + 
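+        # Descriptive note: c_hidden, nhead and level_config hold one entry per U-Net
+        # level (dropout / self_attn may be scalars, which are broadcast below), while
+        # blocks and block_repeat hold [down, up] counts per level. Level-config letters
+        # map to modules via get_block below: 'C'=ResBlock, 'A'=AttnBlock,
+        # 'F'=FeedForwardBlock, 'T'=TimestepBlock.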
super().__init__() + self.c_r = c_r + self.t_conds = t_conds + self.c_clip_seq = c_clip_seq + if not isinstance(dropout, list): + dropout = [dropout] * len(c_hidden) + if not isinstance(self_attn, list): + self_attn = [self_attn] * len(c_hidden) + + # CONDITIONING + self.effnet_mapper = nn.Sequential( + nn.Conv2d(c_effnet, c_hidden[0] * 4, kernel_size=1), + nn.GELU(), + nn.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1), + LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6) + ) + self.pixels_mapper = nn.Sequential( + nn.Conv2d(c_pixels, c_hidden[0] * 4, kernel_size=1), + nn.GELU(), + nn.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1), + LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6) + ) + self.clip_mapper = nn.Linear(c_clip, c_cond * c_clip_seq) + self.clip_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6) + + self.embedding = nn.Sequential( + nn.PixelUnshuffle(patch_size), + nn.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1), + LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6) + ) + + def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True): + if block_type == 'C': + return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout) + elif block_type == 'A': + return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout) + elif block_type == 'F': + return FeedForwardBlock(c_hidden, dropout=dropout) + elif block_type == 'T': + return TimestepBlock(c_hidden, c_r, conds=t_conds) + else: + raise Exception(f'Block type {block_type} not supported') + + # BLOCKS + # -- down blocks + self.down_blocks = nn.ModuleList() + self.down_downscalers = nn.ModuleList() + self.down_repeat_mappers = nn.ModuleList() + for i in range(len(c_hidden)): + if i > 0: + self.down_downscalers.append(nn.Sequential( + LayerNorm2d(c_hidden[i - 1], elementwise_affine=False, eps=1e-6), + nn.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2), + )) + else: + self.down_downscalers.append(nn.Identity()) + down_block = nn.ModuleList() + for _ in range(blocks[0][i]): + for block_type in level_config[i]: + block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i]) + down_block.append(block) + self.down_blocks.append(down_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[0][i] - 1): + block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) + self.down_repeat_mappers.append(block_repeat_mappers) + + # -- up blocks + self.up_blocks = nn.ModuleList() + self.up_upscalers = nn.ModuleList() + self.up_repeat_mappers = nn.ModuleList() + for i in reversed(range(len(c_hidden))): + if i > 0: + self.up_upscalers.append(nn.Sequential( + LayerNorm2d(c_hidden[i], elementwise_affine=False, eps=1e-6), + nn.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2), + )) + else: + self.up_upscalers.append(nn.Identity()) + up_block = nn.ModuleList() + for j in range(blocks[1][::-1][i]): + for k, block_type in enumerate(level_config[i]): + c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0 + block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], + self_attn=self_attn[i]) + up_block.append(block) + self.up_blocks.append(up_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[1][::-1][i] - 1): + block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) + 
self.up_repeat_mappers.append(block_repeat_mappers) + + # OUTPUT + self.clf = nn.Sequential( + LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6), + nn.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1), + nn.PixelShuffle(patch_size), + ) + + # --- WEIGHT INIT --- + self.apply(self._init_weights) # General init + nn.init.normal_(self.clip_mapper.weight, std=0.02) # conditionings + nn.init.normal_(self.effnet_mapper[0].weight, std=0.02) # conditionings + nn.init.normal_(self.effnet_mapper[2].weight, std=0.02) # conditionings + nn.init.normal_(self.pixels_mapper[0].weight, std=0.02) # conditionings + nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings + torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs + nn.init.constant_(self.clf[1].weight, 0) # outputs + + # blocks + for level_block in self.down_blocks + self.up_blocks: + for block in level_block: + if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock): + block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0])) + elif isinstance(block, TimestepBlock): + for layer in block.modules(): + if isinstance(layer, nn.Linear): + nn.init.constant_(layer.weight, 0) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + torch.nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def gen_r_embedding(self, r, max_positions=10000): + r = r * max_positions + half_dim = self.c_r // 2 + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + if self.c_r % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode='constant') + return emb + + def gen_c_embeddings(self, clip): + if len(clip.shape) == 2: + clip = clip.unsqueeze(1) + clip = self.clip_mapper(clip).view(clip.size(0), clip.size(1) * self.c_clip_seq, -1) + clip = self.clip_norm(clip) + return clip + + def _down_encode(self, x, r_embed, clip): + level_outputs = [] + block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) + for down_block, downscaler, repmap in block_group: + x = downscaler(x) + for i in range(len(repmap) + 1): + for block in down_block: + if isinstance(block, ResBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + ResBlock)): + x = block(x) + elif isinstance(block, AttnBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + AttnBlock)): + x = block(x, clip) + elif isinstance(block, TimestepBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + TimestepBlock)): + x = block(x, r_embed) + else: + x = block(x) + if i < len(repmap): + x = repmap[i](x) + level_outputs.insert(0, x) + return level_outputs + + def _up_decode(self, level_outputs, r_embed, clip): + x = level_outputs[0] + block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) + for i, (up_block, upscaler, repmap) in enumerate(block_group): + for j in range(len(repmap) + 1): + for k, block in enumerate(up_block): + if isinstance(block, ResBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + ResBlock)): + skip = level_outputs[i] if k == 0 and i > 0 else None + if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): + x = torch.nn.functional.interpolate(x.float(), skip.shape[-2:], 
mode='bilinear', + align_corners=True) + x = block(x, skip) + elif isinstance(block, AttnBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + AttnBlock)): + x = block(x, clip) + elif isinstance(block, TimestepBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + TimestepBlock)): + x = block(x, r_embed) + else: + x = block(x) + if j < len(repmap): + x = repmap[j](x) + x = upscaler(x) + return x + + def forward(self, x, r, effnet, clip, pixels=None, **kwargs): + if pixels is None: + pixels = x.new_zeros(x.size(0), 3, 8, 8) + + # Process the conditioning embeddings + r_embed = self.gen_r_embedding(r) + for c in self.t_conds: + t_cond = kwargs.get(c, torch.zeros_like(r)) + r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond)], dim=1) + clip = self.gen_c_embeddings(clip) + + # Model Blocks + x = self.embedding(x) + x = x + self.effnet_mapper( + nn.functional.interpolate(effnet.float(), size=x.shape[-2:], mode='bilinear', align_corners=True)) + x = x + nn.functional.interpolate(self.pixels_mapper(pixels).float(), size=x.shape[-2:], mode='bilinear', + align_corners=True) + level_outputs = self._down_encode(x, r_embed, clip) + x = self._up_decode(level_outputs, r_embed, clip) + return self.clf(x) + + def update_weights_ema(self, src_model, beta=0.999): + for self_params, src_params in zip(self.parameters(), src_model.parameters()): + self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta) + for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()): + self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta) diff --git a/modules/stage_c.py b/modules/stage_c.py new file mode 100644 index 0000000000000000000000000000000000000000..53b73d0197712b981ec1a154428c21af2149646a --- /dev/null +++ b/modules/stage_c.py @@ -0,0 +1,252 @@ +import torch +from torch import nn +import numpy as np +import math +from .common import AttnBlock, LayerNorm2d, ResBlock, FeedForwardBlock, TimestepBlock +#from .controlnet import ControlNetDeliverer + + +class UpDownBlock2d(nn.Module): + def __init__(self, c_in, c_out, mode, enabled=True): + super().__init__() + assert mode in ['up', 'down'] + interpolation = nn.Upsample(scale_factor=2 if mode == 'up' else 0.5, mode='bilinear', + align_corners=True) if enabled else nn.Identity() + mapping = nn.Conv2d(c_in, c_out, kernel_size=1) + self.blocks = nn.ModuleList([interpolation, mapping] if mode == 'up' else [mapping, interpolation]) + + def forward(self, x): + for block in self.blocks: + x = block(x.float()) + return x + + +class StageC(nn.Module): + def __init__(self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32], + blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=['CTA', 'CTA'], + c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3, + dropout=[0.1, 0.1], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False]): + super().__init__() + self.c_r = c_r + self.t_conds = t_conds + self.c_clip_seq = c_clip_seq + if not isinstance(dropout, list): + dropout = [dropout] * len(c_hidden) + if not isinstance(self_attn, list): + self_attn = [self_attn] * len(c_hidden) + + # CONDITIONING + self.clip_txt_mapper = nn.Linear(c_clip_text, c_cond) + self.clip_txt_pooled_mapper = nn.Linear(c_clip_text_pooled, c_cond * c_clip_seq) + self.clip_img_mapper = nn.Linear(c_clip_img, c_cond * 
c_clip_seq) + self.clip_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6) + + self.embedding = nn.Sequential( + nn.PixelUnshuffle(patch_size), + nn.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1), + LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6) + ) + + def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True): + if block_type == 'C': + return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout) + elif block_type == 'A': + return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout) + elif block_type == 'F': + return FeedForwardBlock(c_hidden, dropout=dropout) + elif block_type == 'T': + return TimestepBlock(c_hidden, c_r, conds=t_conds) + else: + raise Exception(f'Block type {block_type} not supported') + + # BLOCKS + # -- down blocks + self.down_blocks = nn.ModuleList() + self.down_downscalers = nn.ModuleList() + self.down_repeat_mappers = nn.ModuleList() + for i in range(len(c_hidden)): + if i > 0: + self.down_downscalers.append(nn.Sequential( + LayerNorm2d(c_hidden[i - 1], elementwise_affine=False, eps=1e-6), + UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode='down', enabled=switch_level[i - 1]) + )) + else: + self.down_downscalers.append(nn.Identity()) + down_block = nn.ModuleList() + for _ in range(blocks[0][i]): + for block_type in level_config[i]: + block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i]) + down_block.append(block) + self.down_blocks.append(down_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[0][i] - 1): + block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) + self.down_repeat_mappers.append(block_repeat_mappers) + + # -- up blocks + self.up_blocks = nn.ModuleList() + self.up_upscalers = nn.ModuleList() + self.up_repeat_mappers = nn.ModuleList() + for i in reversed(range(len(c_hidden))): + if i > 0: + self.up_upscalers.append(nn.Sequential( + LayerNorm2d(c_hidden[i], elementwise_affine=False, eps=1e-6), + UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode='up', enabled=switch_level[i - 1]) + )) + else: + self.up_upscalers.append(nn.Identity()) + up_block = nn.ModuleList() + for j in range(blocks[1][::-1][i]): + for k, block_type in enumerate(level_config[i]): + c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0 + block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], + self_attn=self_attn[i]) + up_block.append(block) + self.up_blocks.append(up_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[1][::-1][i] - 1): + block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) + self.up_repeat_mappers.append(block_repeat_mappers) + + # OUTPUT + self.clf = nn.Sequential( + LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6), + nn.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1), + nn.PixelShuffle(patch_size), + ) + + # --- WEIGHT INIT --- + self.apply(self._init_weights) # General init + nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) # conditionings + nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02) # conditionings + nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings + torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs + nn.init.constant_(self.clf[1].weight, 0) # outputs + + # blocks + for level_block in self.down_blocks + self.up_blocks: + 
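+            # Scale the last channelwise layer of residual-style blocks by
+            # sqrt(1 / total number of down blocks) and zero-init TimestepBlock
+            # linears, so each residual branch starts with a small contribution
+            # (same init scheme as StageB).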
for block in level_block: + if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock): + block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0])) + elif isinstance(block, TimestepBlock): + for layer in block.modules(): + if isinstance(layer, nn.Linear): + nn.init.constant_(layer.weight, 0) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + torch.nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def gen_r_embedding(self, r, max_positions=10000): + r = r * max_positions + half_dim = self.c_r // 2 + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + if self.c_r % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode='constant') + return emb + + def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img): + clip_txt = self.clip_txt_mapper(clip_txt) + if len(clip_txt_pooled.shape) == 2: + clip_txt_pool = clip_txt_pooled.unsqueeze(1) + if len(clip_img.shape) == 2: + clip_img = clip_img.unsqueeze(1) + clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1) + clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1) + clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1) + clip = self.clip_norm(clip) + return clip + + def _down_encode(self, x, r_embed, clip, cnet=None): + level_outputs = [] + block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) + for down_block, downscaler, repmap in block_group: + x = downscaler(x) + for i in range(len(repmap) + 1): + for block in down_block: + if isinstance(block, ResBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + ResBlock)): + if cnet is not None: + next_cnet = cnet() + if next_cnet is not None: + x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear', + align_corners=True) + x = block(x) + elif isinstance(block, AttnBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + AttnBlock)): + x = block(x, clip) + elif isinstance(block, TimestepBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + TimestepBlock)): + x = block(x, r_embed) + else: + x = block(x) + if i < len(repmap): + x = repmap[i](x) + level_outputs.insert(0, x) + return level_outputs + + def _up_decode(self, level_outputs, r_embed, clip, cnet=None): + x = level_outputs[0] + block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) + for i, (up_block, upscaler, repmap) in enumerate(block_group): + for j in range(len(repmap) + 1): + for k, block in enumerate(up_block): + if isinstance(block, ResBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + ResBlock)): + skip = level_outputs[i] if k == 0 and i > 0 else None + if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): + x = torch.nn.functional.interpolate(x.float(), skip.shape[-2:], mode='bilinear', + align_corners=True) + if cnet is not None: + next_cnet = cnet() + if next_cnet is not None: + x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear', + align_corners=True) + x = block(x, skip) + elif isinstance(block, AttnBlock) or ( + hasattr(block, 
'_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + AttnBlock)): + x = block(x, clip) + elif isinstance(block, TimestepBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + TimestepBlock)): + x = block(x, r_embed) + else: + x = block(x) + if j < len(repmap): + x = repmap[j](x) + x = upscaler(x) + return x + + def forward(self, x, r, clip_text, clip_text_pooled, clip_img, cnet=None, **kwargs): + # Process the conditioning embeddings + r_embed = self.gen_r_embedding(r) + for c in self.t_conds: + t_cond = kwargs.get(c, torch.zeros_like(r)) + r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond)], dim=1) + clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img) + + # Model Blocks + x = self.embedding(x) + if cnet is not None: + cnet = ControlNetDeliverer(cnet) + level_outputs = self._down_encode(x, r_embed, clip, cnet) + x = self._up_decode(level_outputs, r_embed, clip, cnet) + return self.clf(x) + + def update_weights_ema(self, src_model, beta=0.999): + for self_params, src_params in zip(self.parameters(), src_model.parameters()): + self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta) + for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()): + self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta) diff --git a/prompt_list.txt b/prompt_list.txt new file mode 100644 index 0000000000000000000000000000000000000000..27cd31b4750d2f15fdb6f2a3f4bdd117a7377267 --- /dev/null +++ b/prompt_list.txt @@ -0,0 +1,32 @@ +A close-up of a blooming peony, with layers of soft, pink petals, a delicate fragrance, and dewdrops glistening +in the early morning light. + +A detailed view of a blooming magnolia tree, with large, white flowers and dark green leaves, set against a +clear blue sky. + +A close-up portrait of a young woman with flawless skin, vibrant red lipstick, and wavy brown hair, wearing +a vintage floral dress and standing in front of a blooming garden. + +The image features a snow-covered mountain range with a large, snow-covered mountain in the background. +The mountain is surrounded by a forest of trees, and the sky is filled with clouds. The scene is set during the +winter season, with snow covering the ground and the trees. + +Crocodile in a sweater. + +A vibrant anime scene of a young girl with long, flowing pink hair, big sparkling blue eyes, and a school +uniform, standing under a cherry blossom tree with petals falling around her. The background shows a +traditional Japanese school with cherry blossoms in full bloom. + +A playful Labrador retriever puppy with a shiny, golden coat, chasing a red ball in a spacious backyard, with +green grass and a wooden fence. + +A cozy, rustic log cabin nestled in a snow-covered forest, with smoke rising from the stone chimney, warm +lights glowing from the windows, and a path of footprints leading to the front door. + +A highly detailed, high-quality image of the Banff National Park in Canada. The turquoise waters of Lake +Louise are surrounded by snow-capped mountains and dense pine forests. A wooden canoe is docked at the +edge of the lake. The sky is a clear, bright blue, and the air is crisp and fresh. + +A highly detailed, high-quality image of a Shih Tzu receiving a bath in a home bathroom. The dog is standing +in a tub, covered in suds, with a slightly wet and adorable look. The background includes bathroom fixtures, +towels, and a clean, tiled floor. 
\ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 7647fc562eb544209951a1163bdd619dbbed29b1..0a2397e70942601b5d37bd0b370cb842e1feb7bb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,19 @@ -huggingface_hub==0.22.2 -scikit-build-core -https://github.com/abetlen/llama-cpp-python/releases/download/v0.2.81-cu124/llama_cpp_python-0.2.81-cp310-cp310-linux_x86_64.whl \ No newline at end of file +--find-links https://download.pytorch.org/whl/torch_stable.html +accelerate>=0.25.0 +torch==2.1.2+cu118 +torchvision==0.16.2+cu118 +transformers>=4.30.0 +numpy>=1.23.5 +kornia>=0.7.0 +insightface>=0.7.3 +opencv-python>=4.8.1.78 +tqdm>=4.66.1 +matplotlib>=3.7.4 +webdataset>=0.2.79 +wandb>=0.16.2 +munch>=4.0.0 +onnxruntime>=1.16.3 +einops>=0.7.0 +onnx2torch>=1.5.13 +warmup-scheduler @ git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git +torchtools @ git+https://github.com/pabloppp/pytorch-tools diff --git a/train/__init__.py b/train/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ea1331f6b933f63c99a6bdf074201fdb4b8f78c2 --- /dev/null +++ b/train/__init__.py @@ -0,0 +1,5 @@ +from .train_b import WurstCore as WurstCoreB +from .train_c import WurstCore as WurstCoreC +from .train_t2i import WurstCore as WurstCore_t2i +from .train_ultrapixel_control import WurstCore as WurstCore_control_lrguide +from .train_personalized import WurstCore as WurstCore_personalized \ No newline at end of file diff --git a/train/base.py b/train/base.py new file mode 100644 index 0000000000000000000000000000000000000000..4e8a6ef306e40da8c9d8db33ceba2f8b2982a9a9 --- /dev/null +++ b/train/base.py @@ -0,0 +1,402 @@ +import yaml +import json +import torch +import wandb +import torchvision +import numpy as np +from torch import nn +from tqdm import tqdm +from abc import abstractmethod +from fractions import Fraction +import matplotlib.pyplot as plt +from dataclasses import dataclass +from torch.distributed import barrier +from torch.utils.data import DataLoader + +from gdf import GDF +from gdf import AdaptiveLossWeight + +from core import WarpCore +from core.data import setup_webdataset_path, MultiGetter, MultiFilter, Bucketeer +from core.utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary + +import webdataset as wds +from webdataset.handlers import warn_and_continue + +import transformers +transformers.utils.logging.set_verbosity_error() + + +class DataCore(WarpCore): + @dataclass(frozen=True) + class Config(WarpCore.Config): + image_size: int = EXPECTED_TRAIN + webdataset_path: str = EXPECTED_TRAIN + grad_accum_steps: int = EXPECTED_TRAIN + batch_size: int = EXPECTED_TRAIN + multi_aspect_ratio: list = None + + captions_getter: list = None + dataset_filters: list = None + + bucketeer_random_ratio: float = 0.05 + + @dataclass(frozen=True) + class Extras(WarpCore.Extras): + transforms: torchvision.transforms.Compose = EXPECTED + clip_preprocess: torchvision.transforms.Compose = EXPECTED + + @dataclass(frozen=True) + class Models(WarpCore.Models): + tokenizer: nn.Module = EXPECTED + text_model: nn.Module = EXPECTED + image_model: nn.Module = None + + config: Config + + def webdataset_path(self): + if isinstance(self.config.webdataset_path, str) and (self.config.webdataset_path.strip().startswith( + 'pipe:') or self.config.webdataset_path.strip().startswith('file:')): + return self.config.webdataset_path + else: + dataset_path = self.config.webdataset_path + if isinstance(self.config.webdataset_path, str) and 
self.config.webdataset_path.strip().endswith('.yml'): + with open(self.config.webdataset_path, 'r', encoding='utf-8') as file: + dataset_path = yaml.safe_load(file) + return setup_webdataset_path(dataset_path, cache_path=f"{self.config.experiment_id}_webdataset_cache.yml") + + def webdataset_preprocessors(self, extras: Extras): + def identity(x): + if isinstance(x, bytes): + x = x.decode('utf-8') + return x + + # CUSTOM CAPTIONS GETTER ----- + def get_caption(oc, c, p_og=0.05): # cog_contexual, cog_caption + if p_og > 0 and np.random.rand() < p_og and len(oc) > 0: + return identity(oc) + else: + return identity(c) + + captions_getter = MultiGetter(rules={ + ('old_caption', 'caption'): lambda oc, c: get_caption(json.loads(oc)['og_caption'], c, p_og=0.05) + }) + + return [ + ('jpg;png', + torchvision.transforms.ToTensor() if self.config.multi_aspect_ratio is not None else extras.transforms, + 'images'), + ('txt', identity, 'captions') if self.config.captions_getter is None else ( + self.config.captions_getter[0], eval(self.config.captions_getter[1]), 'captions'), + ] + + def setup_data(self, extras: Extras) -> WarpCore.Data: + # SETUP DATASET + dataset_path = self.webdataset_path() + preprocessors = self.webdataset_preprocessors(extras) + + handler = warn_and_continue + dataset = wds.WebDataset( + dataset_path, resampled=True, handler=handler + ).select( + MultiFilter(rules={ + f[0]: eval(f[1]) for f in self.config.dataset_filters + }) if self.config.dataset_filters is not None else lambda _: True + ).shuffle(690, handler=handler).decode( + "pilrgb", handler=handler + ).to_tuple( + *[p[0] for p in preprocessors], handler=handler + ).map_tuple( + *[p[1] for p in preprocessors], handler=handler + ).map(lambda x: {p[2]: x[i] for i, p in enumerate(preprocessors)}) + + def identity(x): + return x + + # SETUP DATALOADER + real_batch_size = self.config.batch_size // (self.world_size * self.config.grad_accum_steps) + dataloader = DataLoader( + dataset, batch_size=real_batch_size, num_workers=8, pin_memory=True, + collate_fn=identity if self.config.multi_aspect_ratio is not None else None + ) + if self.is_main_node: + print(f"Training with batch size {self.config.batch_size} ({real_batch_size}/GPU)") + + if self.config.multi_aspect_ratio is not None: + aspect_ratios = [float(Fraction(f)) for f in self.config.multi_aspect_ratio] + dataloader_iterator = Bucketeer(dataloader, density=self.config.image_size ** 2, factor=32, + ratios=aspect_ratios, p_random_ratio=self.config.bucketeer_random_ratio, + interpolate_nearest=False) # , use_smartcrop=True) + else: + dataloader_iterator = iter(dataloader) + + return self.Data(dataset=dataset, dataloader=dataloader, iterator=dataloader_iterator) + + def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False, + eval_image_embeds=False, return_fields=None): + if return_fields is None: + return_fields = ['clip_text', 'clip_text_pooled', 'clip_img'] + + captions = batch.get('captions', None) + images = batch.get('images', None) + batch_size = len(captions) + + text_embeddings = None + text_pooled_embeddings = None + if 'clip_text' in return_fields or 'clip_text_pooled' in return_fields: + if is_eval: + if is_unconditional: + captions_unpooled = ["" for _ in range(batch_size)] + else: + captions_unpooled = captions + else: + rand_idx = np.random.rand(batch_size) > 0.05 + captions_unpooled = [str(c) if keep else "" for c, keep in zip(captions, rand_idx)] + clip_tokens_unpooled = models.tokenizer(captions_unpooled, 
truncation=True, padding="max_length", + max_length=models.tokenizer.model_max_length, + return_tensors="pt").to(self.device) + text_encoder_output = models.text_model(**clip_tokens_unpooled, output_hidden_states=True) + if 'clip_text' in return_fields: + text_embeddings = text_encoder_output.hidden_states[-1] + if 'clip_text_pooled' in return_fields: + text_pooled_embeddings = text_encoder_output.text_embeds.unsqueeze(1) + + image_embeddings = None + if 'clip_img' in return_fields: + image_embeddings = torch.zeros(batch_size, 768, device=self.device) + if images is not None: + images = images.to(self.device) + if is_eval: + if not is_unconditional and eval_image_embeds: + image_embeddings = models.image_model(extras.clip_preprocess(images)).image_embeds + else: + rand_idx = np.random.rand(batch_size) > 0.9 + if any(rand_idx): + image_embeddings[rand_idx] = models.image_model(extras.clip_preprocess(images[rand_idx])).image_embeds + image_embeddings = image_embeddings.unsqueeze(1) + return { + 'clip_text': text_embeddings, + 'clip_text_pooled': text_pooled_embeddings, + 'clip_img': image_embeddings + } + + +class TrainingCore(DataCore, WarpCore): + @dataclass(frozen=True) + class Config(DataCore.Config, WarpCore.Config): + updates: int = EXPECTED_TRAIN + backup_every: int = EXPECTED_TRAIN + save_every: int = EXPECTED_TRAIN + + # EMA UPDATE + ema_start_iters: int = None + ema_iters: int = None + ema_beta: float = None + + use_fsdp: bool = None + + @dataclass() # not frozen, means that fields are mutable. Doesn't support EXPECTED + class Info(WarpCore.Info): + ema_loss: float = None + adaptive_loss: dict = None + + @dataclass(frozen=True) + class Models(WarpCore.Models): + generator: nn.Module = EXPECTED + generator_ema: nn.Module = None # optional + + @dataclass(frozen=True) + class Optimizers(WarpCore.Optimizers): + generator: any = EXPECTED + + @dataclass(frozen=True) + class Extras(WarpCore.Extras): + gdf: GDF = EXPECTED + sampling_configs: dict = EXPECTED + + info: Info + config: Config + + @abstractmethod + def forward_pass(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models): + raise NotImplementedError("This method needs to be overriden") + + @abstractmethod + def backward_pass(self, update, loss, loss_adjusted, models: Models, optimizers: Optimizers, + schedulers: WarpCore.Schedulers): + raise NotImplementedError("This method needs to be overriden") + + @abstractmethod + def models_to_save(self) -> list: + raise NotImplementedError("This method needs to be overriden") + + @abstractmethod + def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor: + raise NotImplementedError("This method needs to be overriden") + + @abstractmethod + def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor: + raise NotImplementedError("This method needs to be overriden") + + def train(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models, optimizers: Optimizers, + schedulers: WarpCore.Schedulers): + start_iter = self.info.iter + 1 + max_iters = self.config.updates * self.config.grad_accum_steps + if self.is_main_node: + print(f"STARTING AT STEP: {start_iter}/{max_iters}") + + pbar = tqdm(range(start_iter, max_iters + 1)) if self.is_main_node else range(start_iter, + max_iters + 1) # <--- DDP + if 'generator' in self.models_to_save(): + models.generator.train() + for i in pbar: + # FORWARD PASS + loss, loss_adjusted = self.forward_pass(data, extras, models) + + # # BACKWARD PASS + grad_norm = 
self.backward_pass( + i % self.config.grad_accum_steps == 0 or i == max_iters, loss, loss_adjusted, + models, optimizers, schedulers + ) + self.info.iter = i + + # UPDATE EMA + if models.generator_ema is not None and i % self.config.ema_iters == 0: + update_weights_ema( + models.generator_ema, models.generator, + beta=(self.config.ema_beta if i > self.config.ema_start_iters else 0) + ) + + # UPDATE LOSS METRICS + self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01 + + if self.is_main_node and self.config.wandb_project is not None and np.isnan(loss.mean().item()) or np.isnan( + grad_norm.item()): + wandb.alert( + title=f"NaN value encountered in training run {self.info.wandb_run_id}", + text=f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}", + wait_duration=60 * 30 + ) + + if self.is_main_node: + logs = { + 'loss': self.info.ema_loss, + 'raw_loss': loss.mean().item(), + 'grad_norm': grad_norm.item(), + 'lr': optimizers.generator.param_groups[0]['lr'] if optimizers.generator is not None else 0, + 'total_steps': self.info.total_steps, + } + + pbar.set_postfix(logs) + if self.config.wandb_project is not None: + wandb.log(logs) + + if i == 1 or i % (self.config.save_every * self.config.grad_accum_steps) == 0 or i == max_iters: + # SAVE AND CHECKPOINT STUFF + if np.isnan(loss.mean().item()): + if self.is_main_node and self.config.wandb_project is not None: + tqdm.write("Skipping sampling & checkpoint because the loss is NaN") + wandb.alert(title=f"Skipping sampling & checkpoint for training run {self.config.wandb_run_id}", + text=f"Skipping sampling & checkpoint at {self.info.total_steps} for training run {self.info.wandb_run_id} iters because loss is NaN") + else: + if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): + self.info.adaptive_loss = { + 'bucket_ranges': extras.gdf.loss_weight.bucket_ranges.tolist(), + 'bucket_losses': extras.gdf.loss_weight.bucket_losses.tolist(), + } + self.save_checkpoints(models, optimizers) + if self.is_main_node: + create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/') + self.sample(models, data, extras) + + def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None): + barrier() + suffix = '' if suffix is None else suffix + self.save_info(self.info, suffix=suffix) + models_dict = models.to_dict() + optimizers_dict = optimizers.to_dict() + for key in self.models_to_save(): + model = models_dict[key] + if model is not None: + self.save_model(model, f"{key}{suffix}", is_fsdp=self.config.use_fsdp) + for key in optimizers_dict: + optimizer = optimizers_dict[key] + if optimizer is not None: + self.save_optimizer(optimizer, f'{key}_optim{suffix}', + fsdp_model=models_dict[key] if self.config.use_fsdp else None) + if suffix == '' and self.info.total_steps > 1 and self.info.total_steps % self.config.backup_every == 0: + self.save_checkpoints(models, optimizers, suffix=f"_{self.info.total_steps // 1000}k") + torch.cuda.empty_cache() + + def sample(self, models: Models, data: WarpCore.Data, extras: Extras): + if 'generator' in self.models_to_save(): + models.generator.eval() + with torch.no_grad(): + batch = next(data.iterator) + + conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False) + unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) + + latents = 
self.encode_latents(batch, models, extras) + noised, _, _, logSNR, noise_cond, _ = extras.gdf.diffuse(latents, shift=1, loss_shift=1) + + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + pred = models.generator(noised, noise_cond, **conditions) + pred = extras.gdf.undiffuse(noised, logSNR, pred)[0] + + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + *_, (sampled, _, _) = extras.gdf.sample( + models.generator, conditions, + latents.shape, unconditions, device=self.device, **extras.sampling_configs + ) + + if models.generator_ema is not None: + *_, (sampled_ema, _, _) = extras.gdf.sample( + models.generator_ema, conditions, + latents.shape, unconditions, device=self.device, **extras.sampling_configs + ) + else: + sampled_ema = sampled + + if self.is_main_node: + noised_images = torch.cat( + [self.decode_latents(noised[i:i + 1], batch, models, extras) for i in range(len(noised))], dim=0) + pred_images = torch.cat( + [self.decode_latents(pred[i:i + 1], batch, models, extras) for i in range(len(pred))], dim=0) + sampled_images = torch.cat( + [self.decode_latents(sampled[i:i + 1], batch, models, extras) for i in range(len(sampled))], dim=0) + sampled_images_ema = torch.cat( + [self.decode_latents(sampled_ema[i:i + 1], batch, models, extras) for i in range(len(sampled_ema))], + dim=0) + + images = batch['images'] + if images.size(-1) != noised_images.size(-1) or images.size(-2) != noised_images.size(-2): + images = nn.functional.interpolate(images, size=noised_images.shape[-2:], mode='bicubic') + + collage_img = torch.cat([ + torch.cat([i for i in images.cpu()], dim=-1), + torch.cat([i for i in noised_images.cpu()], dim=-1), + torch.cat([i for i in pred_images.cpu()], dim=-1), + torch.cat([i for i in sampled_images.cpu()], dim=-1), + torch.cat([i for i in sampled_images_ema.cpu()], dim=-1), + ], dim=-2) + + torchvision.utils.save_image(collage_img, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}.jpg') + torchvision.utils.save_image(collage_img, f'{self.config.experiment_id}_latest_output.jpg') + + captions = batch['captions'] + if self.config.wandb_project is not None: + log_data = [ + [captions[i]] + [wandb.Image(sampled_images[i])] + [wandb.Image(sampled_images_ema[i])] + [ + wandb.Image(images[i])] for i in range(len(images))] + log_table = wandb.Table(data=log_data, columns=["Captions", "Sampled", "Sampled EMA", "Orig"]) + wandb.log({"Log": log_table}) + + if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): + plt.plot(extras.gdf.loss_weight.bucket_ranges, extras.gdf.loss_weight.bucket_losses[:-1]) + plt.ylabel('Raw Loss') + plt.ylabel('LogSNR') + wandb.log({"Loss/LogSRN": plt}) + + if 'generator' in self.models_to_save(): + models.generator.train() diff --git a/train/dist_core.py b/train/dist_core.py new file mode 100644 index 0000000000000000000000000000000000000000..4e4e9e670a3b853fac345618d3557d648d813902 --- /dev/null +++ b/train/dist_core.py @@ -0,0 +1,47 @@ +import os +import torch + + +def get_world_size(): + """Find OMPI world size without calling mpi functions + :rtype: int + """ + if os.environ.get('PMI_SIZE') is not None: + return int(os.environ.get('PMI_SIZE') or 1) + elif os.environ.get('OMPI_COMM_WORLD_SIZE') is not None: + return int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1) + else: + return torch.cuda.device_count() + + +def get_global_rank(): + """Find OMPI world rank without calling mpi functions + :rtype: int + """ + if os.environ.get('PMI_RANK') is not None: + return int(os.environ.get('PMI_RANK') or 0) + elif 
os.environ.get('OMPI_COMM_WORLD_RANK') is not None: + return int(os.environ.get('OMPI_COMM_WORLD_RANK') or 0) + else: + return 0 + + +def get_local_rank(): + """Find OMPI local rank without calling mpi functions + :rtype: int + """ + if os.environ.get('MPI_LOCALRANKID') is not None: + return int(os.environ.get('MPI_LOCALRANKID') or 0) + elif os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') is not None: + return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') or 0) + else: + return 0 + + +def get_master_ip(): + if os.environ.get('AZ_BATCH_MASTER_NODE') is not None: + return os.environ.get('AZ_BATCH_MASTER_NODE').split(':')[0] + elif os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') is not None: + return os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') + else: + return "127.0.0.1" diff --git a/train/train_b.py b/train/train_b.py new file mode 100644 index 0000000000000000000000000000000000000000..c3441a5841750a7c33b49756d2d60064a68d82d8 --- /dev/null +++ b/train/train_b.py @@ -0,0 +1,305 @@ +import torch +import torchvision +from torch import nn, optim +from transformers import AutoTokenizer, CLIPTextModelWithProjection +from warmup_scheduler import GradualWarmupScheduler +import numpy as np + +import sys +import os +from dataclasses import dataclass + +from gdf import GDF, EpsilonTarget, CosineSchedule +from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight +from torchtools.transforms import SmartCrop + +from modules.effnet import EfficientNetEncoder +from modules.stage_a import StageA + +from modules.stage_b import StageB +from modules.stage_b import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock + +from train.base import DataCore, TrainingCore + +from core import WarpCore +from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail + +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.wrap import ModuleWrapPolicy +from accelerate import init_empty_weights +from accelerate.utils import set_module_tensor_to_device +from contextlib import contextmanager + +class WurstCore(TrainingCore, DataCore, WarpCore): + @dataclass(frozen=True) + class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config): + # TRAINING PARAMS + lr: float = EXPECTED_TRAIN + warmup_updates: int = EXPECTED_TRAIN + shift: float = EXPECTED_TRAIN + dtype: str = None + + # MODEL VERSION + model_version: str = EXPECTED # 3BB or 700M + clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' + + # CHECKPOINT PATHS + stage_a_checkpoint_path: str = EXPECTED + effnet_checkpoint_path: str = EXPECTED + generator_checkpoint_path: str = None + + # gdf customization + adaptive_loss_weight: str = None + + @dataclass(frozen=True) + class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models): + effnet: nn.Module = EXPECTED + stage_a: nn.Module = EXPECTED + + @dataclass(frozen=True) + class Schedulers(WarpCore.Schedulers): + generator: any = None + + @dataclass(frozen=True) + class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras): + gdf: GDF = EXPECTED + sampling_configs: dict = EXPECTED + effnet_preprocess: torchvision.transforms.Compose = EXPECTED + + info: TrainingCore.Info + config: Config + + def setup_extras_pre(self) -> Extras: + gdf = GDF( + schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]), + input_scaler=VPScaler(), target=EpsilonTarget(), + noise_cond=CosineTNoiseCond(), + loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(), + ) + sampling_configs = {"cfg": 1.5, 
"sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 10} + + if self.info.adaptive_loss is not None: + gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges']) + gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses']) + + effnet_preprocess = torchvision.transforms.Compose([ + torchvision.transforms.Normalize( + mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) + ) + ]) + + transforms = torchvision.transforms.Compose([ + torchvision.transforms.ToTensor(), + torchvision.transforms.Resize(self.config.image_size, + interpolation=torchvision.transforms.InterpolationMode.BILINEAR, + antialias=True), + SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2) if self.config.training else torchvision.transforms.CenterCrop(self.config.image_size) + ]) + + return self.Extras( + gdf=gdf, + sampling_configs=sampling_configs, + transforms=transforms, + effnet_preprocess=effnet_preprocess, + clip_preprocess=None + ) + + def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False, eval_image_embeds=False, return_fields=None): + images = batch.get('images', None) + + if images is not None: + images = images.to(self.device) + if is_eval and not is_unconditional: + effnet_embeddings = models.effnet(extras.effnet_preprocess(images)) + else: + if is_eval: + effnet_factor = 1 + else: + effnet_factor = np.random.uniform(0.5, 1) # f64 to f32 + effnet_height, effnet_width = int(((images.size(-2)*effnet_factor)//32)*32), int(((images.size(-1)*effnet_factor)//32)*32) + + effnet_embeddings = torch.zeros(images.size(0), 16, effnet_height//32, effnet_width//32, device=self.device) + if not is_eval: + effnet_images = torchvision.transforms.functional.resize(images, (effnet_height, effnet_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST) + rand_idx = np.random.rand(len(images)) <= 0.9 + if any(rand_idx): + effnet_embeddings[rand_idx] = models.effnet(extras.effnet_preprocess(effnet_images[rand_idx])) + else: + effnet_embeddings = None + + conditions = super().get_conditions( + batch, models, extras, is_eval, is_unconditional, + eval_image_embeds, return_fields=return_fields or ['clip_text_pooled'] + ) + + return {'effnet': effnet_embeddings, 'clip': conditions['clip_text_pooled']} + + def setup_models(self, extras: Extras, skip_clip: bool = False) -> Models: + dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.float32 + + # EfficientNet encoder + effnet = EfficientNetEncoder().to(self.device) + effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path) + + effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict']) + effnet.eval().requires_grad_(False) + del effnet_checkpoint + + # vqGAN + stage_a = StageA().to(self.device) + stage_a_checkpoint = load_or_fail(self.config.stage_a_checkpoint_path) + stage_a.load_state_dict(stage_a_checkpoint if 'state_dict' not in stage_a_checkpoint else stage_a_checkpoint['state_dict']) + stage_a.eval().requires_grad_(False) + del stage_a_checkpoint + + @contextmanager + def dummy_context(): + yield None + + loading_context = dummy_context if self.config.training else init_empty_weights + + # Diffusion models + with loading_context(): + generator_ema = None + if self.config.model_version == '3B': + generator = StageB(c_hidden=[320, 640, 1280, 1280], nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]], block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]]) 
+ if self.config.ema_start_iters is not None: + generator_ema = StageB(c_hidden=[320, 640, 1280, 1280], nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]], block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]]) + elif self.config.model_version == '700M': + generator = StageB(c_hidden=[320, 576, 1152, 1152], nhead=[-1, 9, 18, 18], blocks=[[2, 4, 14, 4], [4, 14, 4, 2]], block_repeat=[[1, 1, 1, 1], [2, 2, 2, 2]]) + if self.config.ema_start_iters is not None: + generator_ema = StageB(c_hidden=[320, 576, 1152, 1152], nhead=[-1, 9, 18, 18], blocks=[[2, 4, 14, 4], [4, 14, 4, 2]], block_repeat=[[1, 1, 1, 1], [2, 2, 2, 2]]) + else: + raise ValueError(f"Unknown model version {self.config.model_version}") + + if self.config.generator_checkpoint_path is not None: + if loading_context is dummy_context: + generator.load_state_dict(load_or_fail(self.config.generator_checkpoint_path)) + else: + for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items(): + set_module_tensor_to_device(generator, param_name, "cpu", value=param) + generator = generator.to(dtype).to(self.device) + generator = self.load_model(generator, 'generator') + + if generator_ema is not None: + if loading_context is dummy_context: + generator_ema.load_state_dict(generator.state_dict()) + else: + for param_name, param in generator.state_dict().items(): + set_module_tensor_to_device(generator_ema, param_name, "cpu", value=param) + generator_ema = self.load_model(generator_ema, 'generator_ema') + generator_ema.to(dtype).to(self.device).eval().requires_grad_(False) + + if self.config.use_fsdp: + fsdp_auto_wrap_policy = ModuleWrapPolicy([ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock]) + generator = FSDP(generator, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device) + if generator_ema is not None: + generator_ema = FSDP(generator_ema, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device) + + if skip_clip: + tokenizer = None + text_model = None + else: + tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name) + text_model = CLIPTextModelWithProjection.from_pretrained(self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device) + + return self.Models( + effnet=effnet, stage_a=stage_a, + generator=generator, generator_ema=generator_ema, + tokenizer=tokenizer, text_model=text_model + ) + + def setup_optimizers(self, extras: Extras, models: Models) -> TrainingCore.Optimizers: + optimizer = optim.AdamW(models.generator.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95)) + optimizer = self.load_optimizer(optimizer, 'generator_optim', + fsdp_model=models.generator if self.config.use_fsdp else None) + return self.Optimizers(generator=optimizer) + + def setup_schedulers(self, extras: Extras, models: Models, + optimizers: TrainingCore.Optimizers) -> Schedulers: + scheduler = GradualWarmupScheduler(optimizers.generator, multiplier=1, total_epoch=self.config.warmup_updates) + scheduler.last_epoch = self.info.total_steps + return self.Schedulers(generator=scheduler) + + def _pyramid_noise(self, epsilon, size_range=None, levels=10, scale_mode='nearest'): + epsilon = epsilon.clone() + multipliers = [1] + for i in range(1, levels): + m = 0.75 ** i + h, w = epsilon.size(-2) // (2 ** i), epsilon.size(-2) // (2 ** i) + if size_range is None or (size_range[0] <= h <= size_range[1] or size_range[0] <= w <= size_range[1]): + offset = torch.randn(epsilon.size(0), epsilon.size(1), h, w, device=self.device) + epsilon = 
epsilon + torch.nn.functional.interpolate(offset, size=epsilon.shape[-2:], + mode=scale_mode) * m + multipliers.append(m) + if h <= 1 or w <= 1: + break + epsilon = epsilon / sum([m ** 2 for m in multipliers]) ** 0.5 + # epsilon = epsilon / epsilon.std() + return epsilon + + def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models): + batch = next(data.iterator) + + with torch.no_grad(): + conditions = self.get_conditions(batch, models, extras) + latents = self.encode_latents(batch, models, extras) + epsilon = torch.randn_like(latents) + epsilon = self._pyramid_noise(epsilon, size_range=[1, 16]) + noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1, + epsilon=epsilon) + + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + pred = models.generator(noised, noise_cond, **conditions) + loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3]) + loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps + + if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): + extras.gdf.loss_weight.update_buckets(logSNR, loss) + + return loss, loss_adjusted + + def backward_pass(self, update, loss, loss_adjusted, models: Models, optimizers: TrainingCore.Optimizers, + schedulers: Schedulers): + if update: + loss_adjusted.backward() + grad_norm = nn.utils.clip_grad_norm_(models.generator.parameters(), 1.0) + optimizers_dict = optimizers.to_dict() + for k in optimizers_dict: + if k != 'training': + optimizers_dict[k].step() + schedulers_dict = schedulers.to_dict() + for k in schedulers_dict: + if k != 'training': + schedulers_dict[k].step() + for k in optimizers_dict: + if k != 'training': + optimizers_dict[k].zero_grad(set_to_none=True) + self.info.total_steps += 1 + else: + loss_adjusted.backward() + grad_norm = torch.tensor(0.0).to(self.device) + + return grad_norm + + def models_to_save(self): + return ['generator', 'generator_ema'] + + def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor: + images = batch['images'].to(self.device) + return models.stage_a.encode(images)[0] + + def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor: + return models.stage_a.decode(latents.float()).clamp(0, 1) + + +if __name__ == '__main__': + print("Launching Script") + warpcore = WurstCore( + config_file_path=sys.argv[1] if len(sys.argv) > 1 else None, + device=torch.device(int(os.environ.get("SLURM_LOCALID"))) + ) + # core.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD + + # RUN TRAINING + warpcore() diff --git a/train/train_c.py b/train/train_c.py new file mode 100644 index 0000000000000000000000000000000000000000..c4490c6eebc3e1c5126dd13c53603872f1459a3e --- /dev/null +++ b/train/train_c.py @@ -0,0 +1,266 @@ +import torch +import torchvision +from torch import nn, optim +from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPVisionModelWithProjection +from warmup_scheduler import GradualWarmupScheduler + +import sys +import os +from dataclasses import dataclass + +from gdf import GDF, EpsilonTarget, CosineSchedule +from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight +from torchtools.transforms import SmartCrop + +from modules.effnet import EfficientNetEncoder +from modules.stage_c import StageC +from modules.stage_c import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock +from modules.previewer import Previewer + +from train.base import 
DataCore, TrainingCore + +from core import WarpCore +from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail + +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.wrap import ModuleWrapPolicy +from accelerate import init_empty_weights +from accelerate.utils import set_module_tensor_to_device +from contextlib import contextmanager + +class WurstCore(TrainingCore, DataCore, WarpCore): + @dataclass(frozen=True) + class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config): + # TRAINING PARAMS + lr: float = EXPECTED_TRAIN + warmup_updates: int = EXPECTED_TRAIN + dtype: str = None + + # MODEL VERSION + model_version: str = EXPECTED # 3.6B or 1B + clip_image_model_name: str = 'openai/clip-vit-large-patch14' + clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' + + # CHECKPOINT PATHS + effnet_checkpoint_path: str = EXPECTED + previewer_checkpoint_path: str = EXPECTED + generator_checkpoint_path: str = None + + # gdf customization + adaptive_loss_weight: str = None + + @dataclass(frozen=True) + class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models): + effnet: nn.Module = EXPECTED + previewer: nn.Module = EXPECTED + + @dataclass(frozen=True) + class Schedulers(WarpCore.Schedulers): + generator: any = None + + @dataclass(frozen=True) + class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras): + gdf: GDF = EXPECTED + sampling_configs: dict = EXPECTED + effnet_preprocess: torchvision.transforms.Compose = EXPECTED + + info: TrainingCore.Info + config: Config + + def setup_extras_pre(self) -> Extras: + gdf = GDF( + schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]), + input_scaler=VPScaler(), target=EpsilonTarget(), + noise_cond=CosineTNoiseCond(), + loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(), + ) + sampling_configs = {"cfg": 5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 20} + + if self.info.adaptive_loss is not None: + gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges']) + gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses']) + + effnet_preprocess = torchvision.transforms.Compose([ + torchvision.transforms.Normalize( + mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) + ) + ]) + + clip_preprocess = torchvision.transforms.Compose([ + torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC), + torchvision.transforms.CenterCrop(224), + torchvision.transforms.Normalize( + mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711) + ) + ]) + + if self.config.training: + transforms = torchvision.transforms.Compose([ + torchvision.transforms.ToTensor(), + torchvision.transforms.Resize(self.config.image_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True), + SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2) + ]) + else: + transforms = None + + return self.Extras( + gdf=gdf, + sampling_configs=sampling_configs, + transforms=transforms, + effnet_preprocess=effnet_preprocess, + clip_preprocess=clip_preprocess + ) + + def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False, + eval_image_embeds=False, return_fields=None): + conditions = super().get_conditions( + batch, models, extras, is_eval, is_unconditional, + eval_image_embeds, return_fields=return_fields or ['clip_text', 'clip_text_pooled', 'clip_img'] + ) 
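+        # Stage C conditions on the full CLIP text sequence ('clip_text'), the pooled text embedding
+        # ('clip_text_pooled') and a CLIP image embedding ('clip_img'). The DataCore.get_conditions
+        # call above already applies the classifier-free-guidance dropout: during training, captions
+        # are blanked for roughly 5% of samples and real CLIP image embeddings are injected for only
+        # about 10% of them (zeros otherwise).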
+ return conditions + + def setup_models(self, extras: Extras) -> Models: + dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.float32 + + # EfficientNet encoder + effnet = EfficientNetEncoder() + effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path) + effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict']) + effnet.eval().requires_grad_(False).to(self.device) + del effnet_checkpoint + + # Previewer + previewer = Previewer() + previewer_checkpoint = load_or_fail(self.config.previewer_checkpoint_path) + previewer.load_state_dict(previewer_checkpoint if 'state_dict' not in previewer_checkpoint else previewer_checkpoint['state_dict']) + previewer.eval().requires_grad_(False).to(self.device) + del previewer_checkpoint + + @contextmanager + def dummy_context(): + yield None + + loading_context = dummy_context if self.config.training else init_empty_weights + + # Diffusion models + with loading_context(): + generator_ema = None + if self.config.model_version == '3.6B': + generator = StageC() + if self.config.ema_start_iters is not None: + generator_ema = StageC() + elif self.config.model_version == '1B': + generator = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]]) + if self.config.ema_start_iters is not None: + generator_ema = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]]) + else: + raise ValueError(f"Unknown model version {self.config.model_version}") + + if self.config.generator_checkpoint_path is not None: + if loading_context is dummy_context: + generator.load_state_dict(load_or_fail(self.config.generator_checkpoint_path)) + else: + + for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items(): + set_module_tensor_to_device(generator, param_name, "cpu", value=param) + generator = generator.to(dtype).to(self.device) + generator = self.load_model(generator, 'generator') + + if generator_ema is not None: + if loading_context is dummy_context: + generator_ema.load_state_dict(generator.state_dict()) + else: + for param_name, param in generator.state_dict().items(): + set_module_tensor_to_device(generator_ema, param_name, "cpu", value=param) + generator_ema = self.load_model(generator_ema, 'generator_ema') + generator_ema.to(dtype).to(self.device).eval().requires_grad_(False) + + if self.config.use_fsdp: + fsdp_auto_wrap_policy = ModuleWrapPolicy([ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock]) + generator = FSDP(generator, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device) + if generator_ema is not None: + generator_ema = FSDP(generator_ema, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device) + + tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name) + text_model = CLIPTextModelWithProjection.from_pretrained(self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device) + image_model = CLIPVisionModelWithProjection.from_pretrained(self.config.clip_image_model_name).requires_grad_(False).to(dtype).to(self.device) + + return self.Models( + effnet=effnet, previewer=previewer, + generator=generator, generator_ema=generator_ema, + tokenizer=tokenizer, text_model=text_model, image_model=image_model + ) + + def setup_optimizers(self, extras: Extras, models: Models) -> TrainingCore.Optimizers: + optimizer = optim.AdamW(models.generator.parameters(), lr=self.config.lr) # , eps=1e-7, 
betas=(0.9, 0.95)) + optimizer = self.load_optimizer(optimizer, 'generator_optim', + fsdp_model=models.generator if self.config.use_fsdp else None) + return self.Optimizers(generator=optimizer) + + def setup_schedulers(self, extras: Extras, models: Models, optimizers: TrainingCore.Optimizers) -> Schedulers: + scheduler = GradualWarmupScheduler(optimizers.generator, multiplier=1, total_epoch=self.config.warmup_updates) + scheduler.last_epoch = self.info.total_steps + return self.Schedulers(generator=scheduler) + + # Training loop -------------------------------- + def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models): + batch = next(data.iterator) + + with torch.no_grad(): + conditions = self.get_conditions(batch, models, extras) + latents = self.encode_latents(batch, models, extras) + noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1) + + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + pred = models.generator(noised, noise_cond, **conditions) + loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3]) + loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps + + if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): + extras.gdf.loss_weight.update_buckets(logSNR, loss) + + return loss, loss_adjusted + + def backward_pass(self, update, loss, loss_adjusted, models: Models, optimizers: TrainingCore.Optimizers, schedulers: Schedulers): + if update: + loss_adjusted.backward() + grad_norm = nn.utils.clip_grad_norm_(models.generator.parameters(), 1.0) + optimizers_dict = optimizers.to_dict() + for k in optimizers_dict: + if k != 'training': + optimizers_dict[k].step() + schedulers_dict = schedulers.to_dict() + for k in schedulers_dict: + if k != 'training': + schedulers_dict[k].step() + for k in optimizers_dict: + if k != 'training': + optimizers_dict[k].zero_grad(set_to_none=True) + self.info.total_steps += 1 + else: + loss_adjusted.backward() + grad_norm = torch.tensor(0.0).to(self.device) + + return grad_norm + + def models_to_save(self): + return ['generator', 'generator_ema'] + + def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor: + images = batch['images'].to(self.device) + return models.effnet(extras.effnet_preprocess(images)) + + def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor: + return models.previewer(latents) + + +if __name__ == '__main__': + print("Launching Script") + warpcore = WurstCore( + config_file_path=sys.argv[1] if len(sys.argv) > 1 else None, + device=torch.device(int(os.environ.get("SLURM_LOCALID"))) + ) + # core.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD + + # RUN TRAINING + warpcore() diff --git a/train/train_c_lora.py b/train/train_c_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..8b83eee0f250e5359901d39b8d4052254cfff4fa --- /dev/null +++ b/train/train_c_lora.py @@ -0,0 +1,330 @@ +import torch +import torchvision +from torch import nn, optim +from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPVisionModelWithProjection +from warmup_scheduler import GradualWarmupScheduler + +import sys +import os +import re +from dataclasses import dataclass + +from gdf import GDF, EpsilonTarget, CosineSchedule +from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight +from torchtools.transforms import SmartCrop + +from modules.effnet import EfficientNetEncoder 
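+# NOTE: this file is the LoRA fine-tuning variant of train_c.py; only the LoRA adapters and
+# re-token embeddings from modules.lora (imported below) are passed to the optimiser, while the
+# Stage C base weights and the CLIP text encoder are left untouched.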
+from modules.stage_c import StageC +from modules.stage_c import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock +from modules.previewer import Previewer +from modules.lora import apply_lora, apply_retoken, LoRA, ReToken + +from train.base import DataCore, TrainingCore + +from core import WarpCore +from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail + +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy +from torch.distributed.fsdp.wrap import ModuleWrapPolicy +from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy +import functools +from accelerate import init_empty_weights +from accelerate.utils import set_module_tensor_to_device +from contextlib import contextmanager + + +class WurstCore(TrainingCore, DataCore, WarpCore): + @dataclass(frozen=True) + class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config): + # TRAINING PARAMS + lr: float = EXPECTED_TRAIN + warmup_updates: int = EXPECTED_TRAIN + dtype: str = None + + # MODEL VERSION + model_version: str = EXPECTED # 3.6B or 1B + clip_image_model_name: str = 'openai/clip-vit-large-patch14' + clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' + + # CHECKPOINT PATHS + effnet_checkpoint_path: str = EXPECTED + previewer_checkpoint_path: str = EXPECTED + generator_checkpoint_path: str = None + lora_checkpoint_path: str = None + + # LoRA STUFF + module_filters: list = EXPECTED + rank: int = EXPECTED + train_tokens: list = EXPECTED + + # gdf customization + adaptive_loss_weight: str = None + + @dataclass(frozen=True) + class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models): + effnet: nn.Module = EXPECTED + previewer: nn.Module = EXPECTED + lora: nn.Module = EXPECTED + + @dataclass(frozen=True) + class Schedulers(WarpCore.Schedulers): + lora: any = None + + @dataclass(frozen=True) + class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras): + gdf: GDF = EXPECTED + sampling_configs: dict = EXPECTED + effnet_preprocess: torchvision.transforms.Compose = EXPECTED + + @dataclass() # not frozen, means that fields are mutable. 
Doesn't support EXPECTED + class Info(TrainingCore.Info): + train_tokens: list = None + + @dataclass(frozen=True) + class Optimizers(TrainingCore.Optimizers, WarpCore.Optimizers): + generator: any = None + lora: any = EXPECTED + + # -------------------------------------------- + info: Info + config: Config + + # Extras: gdf, transforms and preprocessors -------------------------------- + def setup_extras_pre(self) -> Extras: + gdf = GDF( + schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]), + input_scaler=VPScaler(), target=EpsilonTarget(), + noise_cond=CosineTNoiseCond(), + loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(), + ) + sampling_configs = {"cfg": 5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 20} + + if self.info.adaptive_loss is not None: + gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges']) + gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses']) + + effnet_preprocess = torchvision.transforms.Compose([ + torchvision.transforms.Normalize( + mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) + ) + ]) + + clip_preprocess = torchvision.transforms.Compose([ + torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC), + torchvision.transforms.CenterCrop(224), + torchvision.transforms.Normalize( + mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711) + ) + ]) + + if self.config.training: + transforms = torchvision.transforms.Compose([ + torchvision.transforms.ToTensor(), + torchvision.transforms.Resize(self.config.image_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True), + SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2) + ]) + else: + transforms = None + + return self.Extras( + gdf=gdf, + sampling_configs=sampling_configs, + transforms=transforms, + effnet_preprocess=effnet_preprocess, + clip_preprocess=clip_preprocess + ) + + # Data -------------------------------- + def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False, + eval_image_embeds=False, return_fields=None): + conditions = super().get_conditions( + batch, models, extras, is_eval, is_unconditional, + eval_image_embeds, return_fields=return_fields or ['clip_text', 'clip_text_pooled', 'clip_img'] + ) + return conditions + + # Models, Optimizers & Schedulers setup -------------------------------- + def setup_models(self, extras: Extras) -> Models: + dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.float32 + + # EfficientNet encoder + effnet = EfficientNetEncoder().to(self.device) + effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path) + effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict']) + effnet.eval().requires_grad_(False) + del effnet_checkpoint + + # Previewer + previewer = Previewer().to(self.device) + previewer_checkpoint = load_or_fail(self.config.previewer_checkpoint_path) + previewer.load_state_dict(previewer_checkpoint if 'state_dict' not in previewer_checkpoint else previewer_checkpoint['state_dict']) + previewer.eval().requires_grad_(False) + del previewer_checkpoint + + @contextmanager + def dummy_context(): + yield None + + loading_context = dummy_context if self.config.training else init_empty_weights + + with loading_context(): + # Diffusion models + if self.config.model_version == '3.6B': + generator 
= StageC() + elif self.config.model_version == '1B': + generator = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]]) + else: + raise ValueError(f"Unknown model version {self.config.model_version}") + + if self.config.generator_checkpoint_path is not None: + if loading_context is dummy_context: + generator.load_state_dict(load_or_fail(self.config.generator_checkpoint_path)) + else: + for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items(): + set_module_tensor_to_device(generator, param_name, "cpu", value=param) + generator = generator.to(dtype).to(self.device) + generator = self.load_model(generator, 'generator') + + # if self.config.use_fsdp: + # fsdp_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=3000) + # generator = FSDP(generator, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device) + + # CLIP encoders + tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name) + text_model = CLIPTextModelWithProjection.from_pretrained(self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device) + image_model = CLIPVisionModelWithProjection.from_pretrained(self.config.clip_image_model_name).requires_grad_(False).to(dtype).to(self.device) + + # PREPARE LORA + update_tokens = [] + for tkn_regex, aggr_regex in self.config.train_tokens: + if (tkn_regex.startswith('[') and tkn_regex.endswith(']')) or (tkn_regex.startswith('<') and tkn_regex.endswith('>')): + # Insert new token + tokenizer.add_tokens([tkn_regex]) + # add new zeros embedding + new_embedding = torch.zeros_like(text_model.text_model.embeddings.token_embedding.weight.data)[:1] + if aggr_regex is not None: # aggregate embeddings to provide an interesting baseline + aggr_tokens = [v for k, v in tokenizer.vocab.items() if re.search(aggr_regex, k) is not None] + if len(aggr_tokens) > 0: + new_embedding = text_model.text_model.embeddings.token_embedding.weight.data[aggr_tokens].mean(dim=0, keepdim=True) + elif self.is_main_node: + print(f"WARNING: No tokens found for aggregation regex {aggr_regex}. 
It will be initialized as zeros.") + text_model.text_model.embeddings.token_embedding.weight.data = torch.cat([ + text_model.text_model.embeddings.token_embedding.weight.data, new_embedding + ], dim=0) + selected_tokens = [len(tokenizer.vocab) - 1] + else: + selected_tokens = [v for k, v in tokenizer.vocab.items() if re.search(tkn_regex, k) is not None] + update_tokens += selected_tokens + update_tokens = list(set(update_tokens)) # remove duplicates + + apply_retoken(text_model.text_model.embeddings.token_embedding, update_tokens) + apply_lora(generator, filters=self.config.module_filters, rank=self.config.rank) + text_model.text_model.to(self.device) + generator.to(self.device) + lora = nn.ModuleDict() + lora['embeddings'] = text_model.text_model.embeddings.token_embedding.parametrizations.weight[0] + lora['weights'] = nn.ModuleList() + for module in generator.modules(): + if isinstance(module, LoRA) or (hasattr(module, '_fsdp_wrapped_module') and isinstance(module._fsdp_wrapped_module, LoRA)): + lora['weights'].append(module) + + self.info.train_tokens = [(i, tokenizer.decode(i)) for i in update_tokens] + if self.is_main_node: + print("Updating tokens:", self.info.train_tokens) + print(f"LoRA training {len(lora['weights'])} layers") + + if self.config.lora_checkpoint_path is not None: + lora_checkpoint = load_or_fail(self.config.lora_checkpoint_path) + lora.load_state_dict(lora_checkpoint if 'state_dict' not in lora_checkpoint else lora_checkpoint['state_dict']) + + lora = self.load_model(lora, 'lora') + lora.to(self.device).train().requires_grad_(True) + if self.config.use_fsdp: + # fsdp_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=3000) + fsdp_auto_wrap_policy = ModuleWrapPolicy([LoRA, ReToken]) + lora = FSDP(lora, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device) + + return self.Models( + effnet=effnet, previewer=previewer, + generator=generator, generator_ema=None, + lora=lora, + tokenizer=tokenizer, text_model=text_model, image_model=image_model + ) + + def setup_optimizers(self, extras: Extras, models: Models) -> Optimizers: + optimizer = optim.AdamW(models.lora.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95)) + optimizer = self.load_optimizer(optimizer, 'lora_optim', + fsdp_model=models.lora if self.config.use_fsdp else None) + return self.Optimizers(generator=None, lora=optimizer) + + def setup_schedulers(self, extras: Extras, models: Models, optimizers: Optimizers) -> Schedulers: + scheduler = GradualWarmupScheduler(optimizers.lora, multiplier=1, total_epoch=self.config.warmup_updates) + scheduler.last_epoch = self.info.total_steps + return self.Schedulers(lora=scheduler) + + def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models): + batch = next(data.iterator) + + conditions = self.get_conditions(batch, models, extras) + with torch.no_grad(): + latents = self.encode_latents(batch, models, extras) + noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1) + + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + pred = models.generator(noised, noise_cond, **conditions) + loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3]) + loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps + + if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): + extras.gdf.loss_weight.update_buckets(logSNR, loss) + + return loss, loss_adjusted + + def backward_pass(self, update, loss, 
loss_adjusted, models: Models, optimizers: TrainingCore.Optimizers, schedulers: Schedulers): + if update: + loss_adjusted.backward() + grad_norm = nn.utils.clip_grad_norm_(models.lora.parameters(), 1.0) + optimizers_dict = optimizers.to_dict() + for k in optimizers_dict: + if optimizers_dict[k] is not None and k != 'training': + optimizers_dict[k].step() + schedulers_dict = schedulers.to_dict() + for k in schedulers_dict: + if k != 'training': + schedulers_dict[k].step() + for k in optimizers_dict: + if optimizers_dict[k] is not None and k != 'training': + optimizers_dict[k].zero_grad(set_to_none=True) + self.info.total_steps += 1 + else: + loss_adjusted.backward() + grad_norm = torch.tensor(0.0).to(self.device) + + return grad_norm + + def models_to_save(self): + return ['lora'] + + def sample(self, models: Models, data: WarpCore.Data, extras: Extras): + models.lora.eval() + super().sample(models, data, extras) + models.lora.train(), models.generator.eval() + + def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor: + images = batch['images'].to(self.device) + return models.effnet(extras.effnet_preprocess(images)) + + def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor: + return models.previewer(latents) + + +if __name__ == '__main__': + print("Launching Script") + warpcore = WurstCore( + config_file_path=sys.argv[1] if len(sys.argv) > 1 else None, + device=torch.device(int(os.environ.get("SLURM_LOCALID"))) + ) + warpcore.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD + + # RUN TRAINING + warpcore() diff --git a/train/train_personalized.py b/train/train_personalized.py new file mode 100644 index 0000000000000000000000000000000000000000..5161b7c621a0eb9daf9d0f0566322bbeed646284 --- /dev/null +++ b/train/train_personalized.py @@ -0,0 +1,899 @@ +import torch +import json +import yaml +import torchvision +from torch import nn, optim +from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPVisionModelWithProjection +from warmup_scheduler import GradualWarmupScheduler +import torch.multiprocessing as mp +import os +import numpy as np +import re +import sys +sys.path.append(os.path.abspath('./')) + +from dataclasses import dataclass +from torch.distributed import init_process_group, destroy_process_group, barrier +from gdf import GDF_dual_fixlrt as GDF +from gdf import EpsilonTarget, CosineSchedule +from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight +from torchtools.transforms import SmartCrop +from fractions import Fraction +from modules.effnet import EfficientNetEncoder +from modules.model_4stage_lite import StageC, ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock +from modules.common_ckpt import GlobalResponseNorm +from modules.previewer import Previewer +from core.data import Bucketeer +from train.base import DataCore, TrainingCore +from tqdm import tqdm +from core import WarpCore +from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail + +from accelerate import init_empty_weights +from accelerate.utils import set_module_tensor_to_device +from contextlib import contextmanager +from train.dist_core import * +import glob +from torch.utils.data import DataLoader, Dataset +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data.distributed import DistributedSampler +from PIL import Image +from core.utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary +from core.utils 
import Base +import torch.nn.functional as F +import functools +import math +import copy +import random +from modules.lora import apply_lora, apply_retoken, LoRA, ReToken + +Image.MAX_IMAGE_PIXELS = None +torch.manual_seed(23) +random.seed(23) +np.random.seed(23) +#7978026 + +class Null_Model(torch.nn.Module): + def __init__(self): + super().__init__() + def forward(self, x): + pass + + + + +def identity(x): + if isinstance(x, bytes): + x = x.decode('utf-8') + return x +def check_nan_inmodel(model, meta=''): + for name, param in model.named_parameters(): + if torch.isnan(param).any(): + print(f"nan detected in {name}", meta) + return True + print('no nan', meta) + return False +class mydist_dataset(Dataset): + def __init__(self, rootpath, tmp_prompt, img_processor=None): + + self.img_pathlist = glob.glob(os.path.join(rootpath, '*.jpg')) + self.img_pathlist = self.img_pathlist * 100000 + self.img_processor = img_processor + self.length = len( self.img_pathlist) + self.caption = tmp_prompt + + + def __getitem__(self, idx): + + imgpath = self.img_pathlist[idx] + txt = self.caption + + + + + try: + img = Image.open(imgpath).convert('RGB') + w, h = img.size + if self.img_processor is not None: + img = self.img_processor(img) + + except: + print('exception', imgpath) + return self.__getitem__(random.randint(0, self.length -1 ) ) + return dict(captions=txt, images=img) + def __len__(self): + return self.length +class WurstCore(TrainingCore, DataCore, WarpCore): + @dataclass(frozen=True) + class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config): + # TRAINING PARAMS + lr: float = EXPECTED_TRAIN + warmup_updates: int = EXPECTED_TRAIN + dtype: str = None + + # MODEL VERSION + model_version: str = EXPECTED # 3.6B or 1B + clip_image_model_name: str = 'openai/clip-vit-large-patch14' + clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' + + # CHECKPOINT PATHS + effnet_checkpoint_path: str = EXPECTED + previewer_checkpoint_path: str = EXPECTED + generator_checkpoint_path: str = None + ultrapixel_path: str = EXPECTED + + # gdf customization + adaptive_loss_weight: str = None + + # LoRA STUFF + module_filters: list = EXPECTED + rank: int = EXPECTED + train_tokens: list = EXPECTED + use_ddp: bool=EXPECTED + tmp_prompt: str=EXPECTED + @dataclass(frozen=True) + class Data(Base): + dataset: Dataset = EXPECTED + dataloader: DataLoader = EXPECTED + iterator: any = EXPECTED + sampler: DistributedSampler = EXPECTED + + @dataclass(frozen=True) + class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models): + effnet: nn.Module = EXPECTED + previewer: nn.Module = EXPECTED + train_norm: nn.Module = EXPECTED + train_lora: nn.Module = EXPECTED + + @dataclass(frozen=True) + class Schedulers(WarpCore.Schedulers): + generator: any = None + + @dataclass(frozen=True) + class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras): + gdf: GDF = EXPECTED + sampling_configs: dict = EXPECTED + effnet_preprocess: torchvision.transforms.Compose = EXPECTED + + info: TrainingCore.Info + config: Config + + def setup_extras_pre(self) -> Extras: + gdf = GDF( + schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]), + input_scaler=VPScaler(), target=EpsilonTarget(), + noise_cond=CosineTNoiseCond(), + loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(), + ) + sampling_configs = {"cfg": 5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 20} + + if self.info.adaptive_loss is not None: + gdf.loss_weight.bucket_ranges = 
torch.tensor(self.info.adaptive_loss['bucket_ranges']) + gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses']) + + effnet_preprocess = torchvision.transforms.Compose([ + torchvision.transforms.Normalize( + mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) + ) + ]) + + clip_preprocess = torchvision.transforms.Compose([ + torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC), + torchvision.transforms.CenterCrop(224), + torchvision.transforms.Normalize( + mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711) + ) + ]) + + if self.config.training: + transforms = torchvision.transforms.Compose([ + torchvision.transforms.ToTensor(), + torchvision.transforms.Resize(self.config.image_size[-1], interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True), + SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2) + ]) + else: + transforms = None + + return self.Extras( + gdf=gdf, + sampling_configs=sampling_configs, + transforms=transforms, + effnet_preprocess=effnet_preprocess, + clip_preprocess=clip_preprocess + ) + + def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False, + eval_image_embeds=False, return_fields=None): + conditions = super().get_conditions( + batch, models, extras, is_eval, is_unconditional, + eval_image_embeds, return_fields=return_fields or ['clip_text', 'clip_text_pooled', 'clip_img'] + ) + return conditions + + def setup_models(self, extras: Extras) -> Models: # configure model + + + dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.bfloat16 + + # EfficientNet encoderin + effnet = EfficientNetEncoder() + effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path) + effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict']) + effnet.eval().requires_grad_(False).to(self.device) + del effnet_checkpoint + + # Previewer + previewer = Previewer() + previewer_checkpoint = load_or_fail(self.config.previewer_checkpoint_path) + previewer.load_state_dict(previewer_checkpoint if 'state_dict' not in previewer_checkpoint else previewer_checkpoint['state_dict']) + previewer.eval().requires_grad_(False).to(self.device) + del previewer_checkpoint + + @contextmanager + def dummy_context(): + yield None + + loading_context = dummy_context if self.config.training else init_empty_weights + + # Diffusion models + with loading_context(): + generator_ema = None + if self.config.model_version == '3.6B': + generator = StageC() + if self.config.ema_start_iters is not None: # default setting + generator_ema = StageC() + elif self.config.model_version == '1B': + print('in line 155 1b light model', self.config.model_version ) + generator = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]]) + + if self.config.ema_start_iters is not None and self.config.training: + generator_ema = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]]) + else: + raise ValueError(f"Unknown model version {self.config.model_version}") + + + + if loading_context is dummy_context: + generator.load_state_dict( load_or_fail(self.config.generator_checkpoint_path)) + else: + for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items(): + set_module_tensor_to_device(generator, param_name, "cpu", value=param) + + generator._init_extra_parameter() + generator = 
generator.to(torch.bfloat16).to(self.device) + + train_norm = nn.ModuleList() + + + cnt_norm = 0 + for mm in generator.modules(): + if isinstance(mm, GlobalResponseNorm): + + train_norm.append(Null_Model()) + cnt_norm += 1 + + + + + train_norm.append(generator.agg_net) + train_norm.append(generator.agg_net_up) + sdd = torch.load(self.config.ultrapixel_path, map_location='cpu') + collect_sd = {} + for k, v in sdd.items(): + collect_sd[k[7:]] = v + train_norm.load_state_dict(collect_sd) + + + + # CLIP encoders + tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name) + text_model = CLIPTextModelWithProjection.from_pretrained( self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device) + image_model = CLIPVisionModelWithProjection.from_pretrained(self.config.clip_image_model_name).requires_grad_(False).to(dtype).to(self.device) + + # PREPARE LORA + train_lora = nn.ModuleList() + update_tokens = [] + for tkn_regex, aggr_regex in self.config.train_tokens: + if (tkn_regex.startswith('[') and tkn_regex.endswith(']')) or (tkn_regex.startswith('<') and tkn_regex.endswith('>')): + # Insert new token + tokenizer.add_tokens([tkn_regex]) + # add new zeros embedding + new_embedding = torch.zeros_like(text_model.text_model.embeddings.token_embedding.weight.data)[:1] + if aggr_regex is not None: # aggregate embeddings to provide an interesting baseline + aggr_tokens = [v for k, v in tokenizer.vocab.items() if re.search(aggr_regex, k) is not None] + if len(aggr_tokens) > 0: + new_embedding = text_model.text_model.embeddings.token_embedding.weight.data[aggr_tokens].mean(dim=0, keepdim=True) + elif self.is_main_node: + print(f"WARNING: No tokens found for aggregation regex {aggr_regex}. It will be initialized as zeros.") + text_model.text_model.embeddings.token_embedding.weight.data = torch.cat([ + text_model.text_model.embeddings.token_embedding.weight.data, new_embedding + ], dim=0) + selected_tokens = [len(tokenizer.vocab) - 1] + else: + selected_tokens = [v for k, v in tokenizer.vocab.items() if re.search(tkn_regex, k) is not None] + update_tokens += selected_tokens + update_tokens = list(set(update_tokens)) # remove duplicates + + apply_retoken(text_model.text_model.embeddings.token_embedding, update_tokens) + + apply_lora(generator, filters=self.config.module_filters, rank=self.config.rank) + for module in generator.modules(): + if isinstance(module, LoRA) or (hasattr(module, '_fsdp_wrapped_module') and isinstance(module._fsdp_wrapped_module, LoRA)): + train_lora.append(module) + + + train_lora.append(text_model.text_model.embeddings.token_embedding.parametrizations.weight[0]) + + if os.path.exists(os.path.join(self.config.output_path, self.config.experiment_id, 'train_lora.safetensors')): + sdd = torch.load(os.path.join(self.config.output_path, self.config.experiment_id, 'train_lora.safetensors'), map_location='cpu') + collect_sd = {} + for k, v in sdd.items(): + collect_sd[k[7:]] = v + train_lora.load_state_dict(collect_sd, strict=True) + + + train_norm.to(self.device).train().requires_grad_(True) + + if generator_ema is not None: + + generator_ema.load_state_dict(load_or_fail(self.config.generator_checkpoint_path)) + generator_ema._init_extra_parameter() + pretrained_pth = os.path.join(self.config.output_path, self.config.experiment_id, 'generator.safetensors') + if os.path.exists(pretrained_pth): + generator_ema.load_state_dict(torch.load(pretrained_pth, map_location='cpu')) + + generator_ema.eval().requires_grad_(False) + + check_nan_inmodel(generator, 
'generator') + + + + if self.config.use_ddp and self.config.training: + + train_lora = DDP(train_lora, device_ids=[self.device], find_unused_parameters=True) + + + + return self.Models( + effnet=effnet, previewer=previewer, train_norm = train_norm, + generator=generator, generator_ema=generator_ema, + tokenizer=tokenizer, text_model=text_model, image_model=image_model, + train_lora=train_lora + ) + + def setup_optimizers(self, extras: Extras, models: Models) -> TrainingCore.Optimizers: + + + params = [] + params += list(models.train_lora.module.parameters()) + optimizer = optim.AdamW(params, lr=self.config.lr) + + return self.Optimizers(generator=optimizer) + + def ema_update(self, ema_model, source_model, beta): + for param_src, param_ema in zip(source_model.parameters(), ema_model.parameters()): + param_ema.data.mul_(beta).add_(param_src.data, alpha = 1 - beta) + + def sync_ema(self, ema_model): + print('sync ema', torch.distributed.get_world_size()) + for param in ema_model.parameters(): + torch.distributed.all_reduce(param.data, op=torch.distributed.ReduceOp.SUM) + param.data /= torch.distributed.get_world_size() + def setup_optimizers_backup(self, extras: Extras, models: Models) -> TrainingCore.Optimizers: + + + optimizer = optim.AdamW( + models.generator.up_blocks.parameters() , + lr=self.config.lr) + optimizer = self.load_optimizer(optimizer, 'generator_optim', + fsdp_model=models.generator if self.config.use_fsdp else None) + return self.Optimizers(generator=optimizer) + + def setup_schedulers(self, extras: Extras, models: Models, optimizers: TrainingCore.Optimizers) -> Schedulers: + scheduler = GradualWarmupScheduler(optimizers.generator, multiplier=1, total_epoch=self.config.warmup_updates) + scheduler.last_epoch = self.info.total_steps + return self.Schedulers(generator=scheduler) + + def setup_data(self, extras: Extras) -> WarpCore.Data: + # SETUP DATASET + dataset_path = self.config.webdataset_path + + + dataset = mydist_dataset(dataset_path, self.config.tmp_prompt, \ + torchvision.transforms.ToTensor() if self.config.multi_aspect_ratio is not None \ + else extras.transforms) + + # SETUP DATALOADER + real_batch_size = self.config.batch_size // (self.world_size * self.config.grad_accum_steps) + + sampler = DistributedSampler(dataset, rank=self.process_id, num_replicas = self.world_size, shuffle=True) + dataloader = DataLoader( + dataset, batch_size=real_batch_size, num_workers=4, pin_memory=True, + collate_fn=identity if self.config.multi_aspect_ratio is not None else None, + sampler = sampler + ) + if self.is_main_node: + print(f"Training with batch size {self.config.batch_size} ({real_batch_size}/GPU)") + + if self.config.multi_aspect_ratio is not None: + aspect_ratios = [float(Fraction(f)) for f in self.config.multi_aspect_ratio] + dataloader_iterator = Bucketeer(dataloader, density=[ss*ss for ss in self.config.image_size] , factor=32, + ratios=aspect_ratios, p_random_ratio=self.config.bucketeer_random_ratio, + interpolate_nearest=False) # , use_smartcrop=True) + else: + + dataloader_iterator = iter(dataloader) + + return self.Data(dataset=dataset, dataloader=dataloader, iterator=dataloader_iterator, sampler=sampler) + + + + + + def setup_ddp(self, experiment_id, single_gpu=False, rank=0): + + if not single_gpu: + local_rank = rank + process_id = rank + world_size = get_world_size() + + self.process_id = process_id + self.is_main_node = process_id == 0 + self.device = torch.device(local_rank) + self.world_size = world_size + + os.environ['MASTER_ADDR'] = 'localhost' + 
os.environ['MASTER_PORT'] = '14443' + torch.cuda.set_device(local_rank) + init_process_group( + backend="nccl", + rank=local_rank, + world_size=world_size, + # init_method=init_method, + ) + print(f"[GPU {process_id}] READY") + else: + self.is_main_node = rank == 0 + self.process_id = rank + self.device = torch.device('cuda:0') + self.world_size = 1 + print("Running in single thread, DDP not enabled.") + # Training loop -------------------------------- + def get_target_lr_size(self, ratio, std_size=24): + w, h = int(std_size / math.sqrt(ratio)), int(std_size * math.sqrt(ratio)) + return (h * 32 , w * 32) + def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models): + + batch = data + ratio = batch['images'].shape[-2] / batch['images'].shape[-1] + shape_lr = self.get_target_lr_size(ratio) + with torch.no_grad(): + conditions = self.get_conditions(batch, models, extras) + + latents = self.encode_latents(batch, models, extras) + latents_lr = self.encode_latents(batch, models, extras,target_size=shape_lr) + + + + flag_lr = random.random() < 0.5 or self.info.iter <5000 + + if flag_lr: + noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents_lr, shift=1, loss_shift=1) + else: + noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1) + if not flag_lr: + noised_lr, noise_lr, target_lr, logSNR_lr, noise_cond_lr, loss_weight_lr = \ + extras.gdf.diffuse(latents_lr, shift=1, loss_shift=1, t=torch.ones(latents.shape[0]).to(latents.device)*0.05, ) + + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + + + if not flag_lr: + with torch.no_grad(): + _, lr_enc_guide, lr_dec_guide = models.generator(noised_lr, noise_cond_lr, reuire_f=True, **conditions) + + + pred = models.generator(noised, noise_cond, reuire_f=False, lr_guide=(lr_enc_guide, lr_dec_guide) if not flag_lr else None , **conditions) + loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3]) + + loss_adjusted = (loss * loss_weight ).mean() / self.config.grad_accum_steps + + + if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): + extras.gdf.loss_weight.update_buckets(logSNR, loss) + return loss, loss_adjusted + + def backward_pass(self, update, loss_adjusted, models: Models, optimizers: TrainingCore.Optimizers, schedulers: Schedulers): + + if update: + + torch.distributed.barrier() + loss_adjusted.backward() + + grad_norm = nn.utils.clip_grad_norm_(models.train_lora.module.parameters(), 1.0) + optimizers_dict = optimizers.to_dict() + for k in optimizers_dict: + if k != 'training': + optimizers_dict[k].step() + schedulers_dict = schedulers.to_dict() + for k in schedulers_dict: + if k != 'training': + schedulers_dict[k].step() + for k in optimizers_dict: + if k != 'training': + optimizers_dict[k].zero_grad(set_to_none=True) + self.info.total_steps += 1 + else: + + loss_adjusted.backward() + grad_norm = torch.tensor(0.0).to(self.device) + + return grad_norm + + def models_to_save(self): + return ['generator', 'generator_ema', 'trans_inr', 'trans_inr_ema'] + + def encode_latents(self, batch: dict, models: Models, extras: Extras, target_size=None) -> torch.Tensor: + + images = batch['images'].to(self.device) + if target_size is not None: + images = F.interpolate(images, target_size) + + return models.effnet(extras.effnet_preprocess(images)) + + def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor: + return models.previewer(latents) + + def __init__(self, rank=0, 
config_file_path=None, config_dict=None, device="cpu", training=True, world_size=1, ): + + self.is_main_node = (rank == 0) + self.config: self.Config = self.setup_config(config_file_path, config_dict, training) + self.setup_ddp(self.config.experiment_id, single_gpu=world_size <= 1, rank=rank) + self.info: self.Info = self.setup_info() + print('in line 292', self.config.experiment_id, rank, world_size <= 1) + p = [i for i in range( 2 * 768 // 32)] + p = [num / sum(p) for num in p] + self.rand_pro = p + self.res_list = [o for o in range(800, 2336, 32)] + + + + def __call__(self, single_gpu=False): + + if self.config.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + if self.is_main_node: + print() + print("**STARTIG JOB WITH CONFIG:**") + print(yaml.dump(self.config.to_dict(), default_flow_style=False)) + print("------------------------------------") + print() + print("**INFO:**") + print(yaml.dump(vars(self.info), default_flow_style=False)) + print("------------------------------------") + print() + print('in line 308', self.is_main_node, self.is_main_node, self.process_id, self.device ) + # SETUP STUFF + extras = self.setup_extras_pre() + assert extras is not None, "setup_extras_pre() must return a DTO" + + + + data = self.setup_data(extras) + assert data is not None, "setup_data() must return a DTO" + if self.is_main_node: + print("**DATA:**") + print(yaml.dump({k:type(v).__name__ for k, v in data.to_dict().items()}, default_flow_style=False)) + print("------------------------------------") + print() + + models = self.setup_models(extras) + assert models is not None, "setup_models() must return a DTO" + if self.is_main_node: + print("**MODELS:**") + print(yaml.dump({ + k:f"{type(v).__name__} - {f'trainable params {sum(p.numel() for p in v.parameters() if p.requires_grad)}' if isinstance(v, nn.Module) else 'Not a nn.Module'}" for k, v in models.to_dict().items() + }, default_flow_style=False)) + print("------------------------------------") + print() + + + + optimizers = self.setup_optimizers(extras, models) + assert optimizers is not None, "setup_optimizers() must return a DTO" + if self.is_main_node: + print("**OPTIMIZERS:**") + print(yaml.dump({k:type(v).__name__ for k, v in optimizers.to_dict().items()}, default_flow_style=False)) + print("------------------------------------") + print() + + schedulers = self.setup_schedulers(extras, models, optimizers) + assert schedulers is not None, "setup_schedulers() must return a DTO" + if self.is_main_node: + print("**SCHEDULERS:**") + print(yaml.dump({k:type(v).__name__ for k, v in schedulers.to_dict().items()}, default_flow_style=False)) + print("------------------------------------") + print() + + post_extras =self.setup_extras_post(extras, models, optimizers, schedulers) + assert post_extras is not None, "setup_extras_post() must return a DTO" + extras = self.Extras.from_dict({ **extras.to_dict(),**post_extras.to_dict() }) + if self.is_main_node: + print("**EXTRAS:**") + print(yaml.dump({k:f"{v}" for k, v in extras.to_dict().items()}, default_flow_style=False)) + print("------------------------------------") + print() + # ------- + + # TRAIN + if self.is_main_node: + print("**TRAINING STARTING...**") + self.train(data, extras, models, optimizers, schedulers) + + if single_gpu is False: + barrier() + destroy_process_group() + if self.is_main_node: + print() + print("------------------------------------") + print() + print("**TRAINING COMPLETE**") + if self.config.wandb_project is not None: 
+                    wandb.alert(title=f"Training {self.info.wandb_run_id} finished", text=f"Training {self.info.wandb_run_id} finished")
+
+    def train(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models, optimizers: TrainingCore.Optimizers,
+              schedulers: WarpCore.Schedulers):
+        start_iter = self.info.iter + 1
+        max_iters = self.config.updates * self.config.grad_accum_steps
+        if self.is_main_node:
+            print(f"STARTING AT STEP: {start_iter}/{max_iters}")
+
+        if self.is_main_node:
+            create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/')
+        if 'generator' in self.models_to_save():
+            models.generator.train()
+
+        iter_cnt = 0
+        epoch_cnt = 0
+        models.train_norm.train()
+        while True:
+            epoch_cnt += 1
+            if self.world_size > 1:
+                data.sampler.set_epoch(epoch_cnt)
+            for ggg in range(len(data.dataloader)):
+                iter_cnt += 1
+                # FORWARD PASS
+                loss, loss_adjusted = self.forward_pass(next(data.iterator), extras, models)
+
+                # BACKWARD PASS
+                grad_norm = self.backward_pass(
+                    iter_cnt % self.config.grad_accum_steps == 0 or iter_cnt == max_iters, loss_adjusted,
+                    models, optimizers, schedulers
+                )
+
+                self.info.iter = iter_cnt
+
+                self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01
+
+                if self.is_main_node and (np.isnan(loss.mean().item()) or np.isnan(grad_norm.item())):
+                    print(f"NaN value encountered in training run {self.info.wandb_run_id}", \
+                          f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}")
+
+                if self.is_main_node:
+                    logs = {
+                        'loss': self.info.ema_loss,
+                        'backward_loss': loss_adjusted.mean().item(),
+                        'ema_loss': self.info.ema_loss,
+                        'raw_ori_loss': loss.mean().item(),
+                        'grad_norm': grad_norm.item(),
+                        'lr': optimizers.generator.param_groups[0]['lr'] if optimizers.generator is not None else 0,
+                        'total_steps': self.info.total_steps,
+                    }
+
+                    print(iter_cnt, max_iters, logs, epoch_cnt, )
+
+                if iter_cnt == 1 or iter_cnt % (self.config.save_every) == 0 or iter_cnt == max_iters:
+
+                    if np.isnan(loss.mean().item()):
+                        if self.is_main_node and self.config.wandb_project is not None:
+                            print(f"NaN value encountered in training run {self.info.wandb_run_id}", \
+                                  f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. 
Run {self.info.wandb_run_id}") + + else: + if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): + self.info.adaptive_loss = { + 'bucket_ranges': extras.gdf.loss_weight.bucket_ranges.tolist(), + 'bucket_losses': extras.gdf.loss_weight.bucket_losses.tolist(), + } + + + if self.is_main_node and iter_cnt % (self.config.save_every * self.config.grad_accum_steps) == 0: + print('save model', iter_cnt, iter_cnt % (self.config.save_every * self.config.grad_accum_steps), self.config.save_every, self.config.grad_accum_steps ) + torch.save(models.train_lora.state_dict(), \ + f'{self.config.output_path}/{self.config.experiment_id}/train_lora.safetensors') + + + torch.save(models.train_lora.state_dict(), \ + f'{self.config.output_path}/{self.config.experiment_id}/train_lora_{iter_cnt}.safetensors') + + + if iter_cnt == 1 or iter_cnt % (self.config.save_every* self.config.grad_accum_steps) == 0 or iter_cnt == max_iters: + + if self.is_main_node: + + self.sample(models, data, extras) + if False: + param_changes = {name: (param - initial_params[name]).norm().item() for name, param in models.train_norm.named_parameters()} + threshold = sorted(param_changes.values(), reverse=True)[int(len(param_changes) * 0.1)] # top 10% + important_params = [name for name, change in param_changes.items() if change > threshold] + print(important_params, threshold, len(param_changes), self.process_id) + json.dump(important_params, open(f'{self.config.output_path}/{self.config.experiment_id}/param.json', 'w'), indent=4) + + + if self.info.iter >= max_iters: + break + + def sample(self, models: Models, data: WarpCore.Data, extras: Extras): + + + models.generator.eval() + models.train_norm.eval() + with torch.no_grad(): + batch = next(data.iterator) + ratio = batch['images'].shape[-2] / batch['images'].shape[-1] + + shape_lr = self.get_target_lr_size(ratio) + conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False) + unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) + + latents = self.encode_latents(batch, models, extras) + latents_lr = self.encode_latents(batch, models, extras, target_size = shape_lr) + + if self.is_main_node: + + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + + *_, (sampled, _, _, sampled_lr) = extras.gdf.sample( + models.generator, conditions, + latents.shape, latents_lr.shape, + unconditions, device=self.device, **extras.sampling_configs + ) + + + sampled_ema = sampled + sampled_ema_lr = sampled_lr + + + if self.is_main_node: + print('sampling results hr latent shape ', latents.shape, 'lr latent shape', latents_lr.shape, ) + noised_images = torch.cat( + [self.decode_latents(latents[i:i + 1].float(), batch, models, extras) for i in range(len(latents))], dim=0) + + sampled_images = torch.cat( + [self.decode_latents(sampled[i:i + 1].float(), batch, models, extras) for i in range(len(sampled))], dim=0) + sampled_images_ema = torch.cat( + [self.decode_latents(sampled_ema[i:i + 1].float(), batch, models, extras) for i in range(len(sampled_ema))], + dim=0) + + noised_images_lr = torch.cat( + [self.decode_latents(latents_lr[i:i + 1].float(), batch, models, extras) for i in range(len(latents_lr))], dim=0) + + sampled_images_lr = torch.cat( + [self.decode_latents(sampled_lr[i:i + 1].float(), batch, models, extras) for i in range(len(sampled_lr))], dim=0) + sampled_images_ema_lr = torch.cat( + [self.decode_latents(sampled_ema_lr[i:i + 1].float(), batch, models, extras) for i 
in range(len(sampled_ema_lr))],
+                    dim=0)
+
+                images = batch['images']
+                if images.size(-1) != noised_images.size(-1) or images.size(-2) != noised_images.size(-2):
+                    images = nn.functional.interpolate(images, size=noised_images.shape[-2:], mode='bicubic')
+                images_lr = nn.functional.interpolate(images, size=noised_images_lr.shape[-2:], mode='bicubic')
+
+                collage_img = torch.cat([
+                    torch.cat([i for i in images.cpu()], dim=-1),
+                    torch.cat([i for i in noised_images.cpu()], dim=-1),
+                    torch.cat([i for i in sampled_images.cpu()], dim=-1),
+                    torch.cat([i for i in sampled_images_ema.cpu()], dim=-1),
+                ], dim=-2)
+
+                collage_img_lr = torch.cat([
+                    torch.cat([i for i in images_lr.cpu()], dim=-1),
+                    torch.cat([i for i in noised_images_lr.cpu()], dim=-1),
+                    torch.cat([i for i in sampled_images_lr.cpu()], dim=-1),
+                    torch.cat([i for i in sampled_images_ema_lr.cpu()], dim=-1),
+                ], dim=-2)
+
+                torchvision.utils.save_image(collage_img, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}.jpg')
+                torchvision.utils.save_image(collage_img_lr, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}_lr.jpg')
+
+                captions = batch['captions']
+                if self.config.wandb_project is not None:
+                    log_data = [
+                        [captions[i]] + [wandb.Image(sampled_images[i])] + [wandb.Image(sampled_images_ema[i])] + [
+                            wandb.Image(images[i])] for i in range(len(images))]
+                    log_table = wandb.Table(data=log_data, columns=["Captions", "Sampled", "Sampled EMA", "Orig"])
+                    wandb.log({"Log": log_table})
+
+                    if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight):
+                        plt.plot(extras.gdf.loss_weight.bucket_ranges, extras.gdf.loss_weight.bucket_losses[:-1])
+                        plt.ylabel('Raw Loss')
+                        plt.xlabel('LogSNR')
+                        wandb.log({"Loss/LogSNR": plt})
+
+        models.generator.train()
+        models.train_norm.train()
+        print('finish sampling')
+
+    def sample_fortest(self, models: Models, extras: Extras, hr_shape, lr_shape, batch, eval_image_embeds=False):
+
+        models.generator.eval()
+        models.trans_inr.eval()
+        with torch.no_grad():
+
+            if self.is_main_node:
+                conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=eval_image_embeds)
+                unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
+
+                with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+
+                    *_, (sampled, _, _, sampled_lr) = extras.gdf.sample(
+                        models.generator, conditions,
+                        hr_shape, lr_shape,
+                        unconditions, device=self.device, **extras.sampling_configs
+                    )
+
+                    if models.generator_ema is not None:
+
+                        *_, (sampled_ema, _, _, sampled_ema_lr) = extras.gdf.sample(
+                            models.generator_ema, conditions,
+                            hr_shape, lr_shape,
+                            unconditions, device=self.device, **extras.sampling_configs
+                        )
+
+                    else:
+                        sampled_ema = sampled
+                        sampled_ema_lr = sampled_lr
+
+        return sampled, sampled_lr
+def main_worker(rank, cfg):
+    print("Launching Script in main worker")
+    warpcore = WurstCore(
+        config_file_path=cfg, rank=rank, world_size = get_world_size()
+    )
+    # core.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD
+
+    # RUN TRAINING
+    warpcore(get_world_size()==1)
+
+if __name__ == '__main__':
+
+    if get_master_ip() == "127.0.0.1":
+
+        mp.spawn(main_worker, nprocs=get_world_size(), args=(sys.argv[1] if len(sys.argv) > 1 else None, ))
+    else:
+        main_worker(0, sys.argv[1] if len(sys.argv) > 1 else None, )
diff --git a/train/train_t2i.py b/train/train_t2i.py
new file mode 100644
index 
0000000000000000000000000000000000000000..456ca4b0dd1fe8e1fc18e3e5c940797439071d1f --- /dev/null +++ b/train/train_t2i.py @@ -0,0 +1,807 @@ +import torch +import json +import yaml +import torchvision +from torch import nn, optim +from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPVisionModelWithProjection +from warmup_scheduler import GradualWarmupScheduler +import torch.multiprocessing as mp +import numpy as np +import os +import sys +sys.path.append(os.path.abspath('./')) +from dataclasses import dataclass +from torch.distributed import init_process_group, destroy_process_group, barrier +from gdf import GDF_dual_fixlrt as GDF +from gdf import EpsilonTarget, CosineSchedule +from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight +from torchtools.transforms import SmartCrop +from fractions import Fraction +from modules.effnet import EfficientNetEncoder + +from modules.model_4stage_lite import StageC, ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock +from modules.previewer import Previewer +from core.data import Bucketeer +from train.base import DataCore, TrainingCore +from tqdm import tqdm +from core import WarpCore +from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail + +from accelerate import init_empty_weights +from accelerate.utils import set_module_tensor_to_device +from contextlib import contextmanager +from train.dist_core import * +import glob +from torch.utils.data import DataLoader, Dataset +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data.distributed import DistributedSampler +from PIL import Image +from core.utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary +from core.utils import Base +from modules.common_ckpt import LayerNorm2d, GlobalResponseNorm +import torch.nn.functional as F +import functools +import math +import copy +import random +from modules.lora import apply_lora, apply_retoken, LoRA, ReToken +Image.MAX_IMAGE_PIXELS = None +torch.manual_seed(23) +random.seed(23) +np.random.seed(23) +#7978026 + +class Null_Model(torch.nn.Module): + def __init__(self): + super().__init__() + def forward(self, x): + pass + + + + +def identity(x): + if isinstance(x, bytes): + x = x.decode('utf-8') + return x +def check_nan_inmodel(model, meta=''): + for name, param in model.named_parameters(): + if torch.isnan(param).any(): + print(f"nan detected in {name}", meta) + return True + print('no nan', meta) + return False +class mydist_dataset(Dataset): + def __init__(self, rootpath, img_processor=None): + + self.img_pathlist = glob.glob(os.path.join(rootpath, '*', '*.jpg')) + self.img_processor = img_processor + self.length = len( self.img_pathlist) + + + + def __getitem__(self, idx): + + imgpath = self.img_pathlist[idx] + json_file = imgpath.replace('.jpg', '.json') + + with open(json_file, 'r') as file: + info = json.load(file) + txt = info['caption'] + if txt is None: + txt = ' ' + try: + img = Image.open(imgpath).convert('RGB') + w, h = img.size + if self.img_processor is not None: + img = self.img_processor(img) + + except: + print('exception', imgpath) + return self.__getitem__(random.randint(0, self.length -1 ) ) + return dict(captions=txt, images=img) + def __len__(self): + return self.length + +class WurstCore(TrainingCore, DataCore, WarpCore): + @dataclass(frozen=True) + class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config): + # TRAINING PARAMS + lr: float = EXPECTED_TRAIN + warmup_updates: int = EXPECTED_TRAIN + dtype: str = 
None + + # MODEL VERSION + model_version: str = EXPECTED # 3.6B or 1B + clip_image_model_name: str = 'openai/clip-vit-large-patch14' + clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' + + # CHECKPOINT PATHS + effnet_checkpoint_path: str = EXPECTED + previewer_checkpoint_path: str = EXPECTED + + generator_checkpoint_path: str = None + + # gdf customization + adaptive_loss_weight: str = None + use_ddp: bool=EXPECTED + + + @dataclass(frozen=True) + class Data(Base): + dataset: Dataset = EXPECTED + dataloader: DataLoader = EXPECTED + iterator: any = EXPECTED + sampler: DistributedSampler = EXPECTED + + @dataclass(frozen=True) + class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models): + effnet: nn.Module = EXPECTED + previewer: nn.Module = EXPECTED + train_norm: nn.Module = EXPECTED + + + @dataclass(frozen=True) + class Schedulers(WarpCore.Schedulers): + generator: any = None + + @dataclass(frozen=True) + class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras): + gdf: GDF = EXPECTED + sampling_configs: dict = EXPECTED + effnet_preprocess: torchvision.transforms.Compose = EXPECTED + + info: TrainingCore.Info + config: Config + + def setup_extras_pre(self) -> Extras: + gdf = GDF( + schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]), + input_scaler=VPScaler(), target=EpsilonTarget(), + noise_cond=CosineTNoiseCond(), + loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(), + ) + sampling_configs = {"cfg": 5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 20} + + if self.info.adaptive_loss is not None: + gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges']) + gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses']) + + effnet_preprocess = torchvision.transforms.Compose([ + torchvision.transforms.Normalize( + mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) + ) + ]) + + clip_preprocess = torchvision.transforms.Compose([ + torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC), + torchvision.transforms.CenterCrop(224), + torchvision.transforms.Normalize( + mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711) + ) + ]) + + if self.config.training: + transforms = torchvision.transforms.Compose([ + torchvision.transforms.ToTensor(), + torchvision.transforms.Resize(self.config.image_size[-1], interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True), + SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2) + ]) + else: + transforms = None + + return self.Extras( + gdf=gdf, + sampling_configs=sampling_configs, + transforms=transforms, + effnet_preprocess=effnet_preprocess, + clip_preprocess=clip_preprocess + ) + + def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False, + eval_image_embeds=False, return_fields=None): + conditions = super().get_conditions( + batch, models, extras, is_eval, is_unconditional, + eval_image_embeds, return_fields=return_fields or ['clip_text', 'clip_text_pooled', 'clip_img'] + ) + return conditions + + def setup_models(self, extras: Extras) -> Models: # configure model + + dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.bfloat16 + + # EfficientNet encoderin + effnet = EfficientNetEncoder() + effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path) + effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in 
effnet_checkpoint else effnet_checkpoint['state_dict']) + effnet.eval().requires_grad_(False).to(self.device) + del effnet_checkpoint + + # Previewer + previewer = Previewer() + previewer_checkpoint = load_or_fail(self.config.previewer_checkpoint_path) + previewer.load_state_dict(previewer_checkpoint if 'state_dict' not in previewer_checkpoint else previewer_checkpoint['state_dict']) + previewer.eval().requires_grad_(False).to(self.device) + del previewer_checkpoint + + @contextmanager + def dummy_context(): + yield None + + loading_context = dummy_context if self.config.training else init_empty_weights + + # Diffusion models + with loading_context(): + generator_ema = None + if self.config.model_version == '3.6B': + generator = StageC() + if self.config.ema_start_iters is not None: # default setting + generator_ema = StageC() + elif self.config.model_version == '1B': + print('in line 155 1b light model', self.config.model_version ) + generator = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]]) + + if self.config.ema_start_iters is not None and self.config.training: + generator_ema = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]]) + else: + raise ValueError(f"Unknown model version {self.config.model_version}") + + + + if loading_context is dummy_context: + generator.load_state_dict( load_or_fail(self.config.generator_checkpoint_path)) + else: + for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items(): + set_module_tensor_to_device(generator, param_name, "cpu", value=param) + + generator._init_extra_parameter() + generator = generator.to(torch.bfloat16).to(self.device) + + + train_norm = nn.ModuleList() + cnt_norm = 0 + for mm in generator.modules(): + if isinstance(mm, GlobalResponseNorm): + + train_norm.append(Null_Model()) + cnt_norm += 1 + + train_norm.append(generator.agg_net) + train_norm.append(generator.agg_net_up) + total = sum([ param.nelement() for param in train_norm.parameters()]) + print('Trainable parameter', total / 1048576) + + if os.path.exists(os.path.join(self.config.output_path, self.config.experiment_id, 'train_norm.safetensors')): + sdd = torch.load(os.path.join(self.config.output_path, self.config.experiment_id, 'train_norm.safetensors'), map_location='cpu') + collect_sd = {} + for k, v in sdd.items(): + collect_sd[k[7:]] = v + train_norm.load_state_dict(collect_sd, strict=True) + + + train_norm.to(self.device).train().requires_grad_(True) + train_norm_ema = copy.deepcopy(train_norm) + train_norm_ema.to(self.device).eval().requires_grad_(False) + if generator_ema is not None: + + generator_ema.load_state_dict(load_or_fail(self.config.generator_checkpoint_path)) + generator_ema._init_extra_parameter() + + + pretrained_pth = os.path.join(self.config.output_path, self.config.experiment_id, 'generator.safetensors') + if os.path.exists(pretrained_pth): + print(pretrained_pth, 'exists') + generator_ema.load_state_dict(torch.load(pretrained_pth, map_location='cpu')) + + + generator_ema.eval().requires_grad_(False) + + + + + check_nan_inmodel(generator, 'generator') + + + + if self.config.use_ddp and self.config.training: + + train_norm = DDP(train_norm, device_ids=[self.device], find_unused_parameters=True) + + # CLIP encoders + tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name) + text_model = CLIPTextModelWithProjection.from_pretrained( self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device) + image_model = 
CLIPVisionModelWithProjection.from_pretrained(self.config.clip_image_model_name).requires_grad_(False).to(dtype).to(self.device) + + return self.Models( + effnet=effnet, previewer=previewer, train_norm = train_norm, + generator=generator, tokenizer=tokenizer, text_model=text_model, image_model=image_model, + ) + + def setup_optimizers(self, extras: Extras, models: Models) -> TrainingCore.Optimizers: + + + params = [] + params += list(models.train_norm.module.parameters()) + + optimizer = optim.AdamW(params, lr=self.config.lr) + + return self.Optimizers(generator=optimizer) + + def ema_update(self, ema_model, source_model, beta): + for param_src, param_ema in zip(source_model.parameters(), ema_model.parameters()): + param_ema.data.mul_(beta).add_(param_src.data, alpha = 1 - beta) + + def sync_ema(self, ema_model): + for param in ema_model.parameters(): + torch.distributed.all_reduce(param.data, op=torch.distributed.ReduceOp.SUM) + param.data /= torch.distributed.get_world_size() + def setup_optimizers_backup(self, extras: Extras, models: Models) -> TrainingCore.Optimizers: + + + optimizer = optim.AdamW( + models.generator.up_blocks.parameters() , + lr=self.config.lr) + optimizer = self.load_optimizer(optimizer, 'generator_optim', + fsdp_model=models.generator if self.config.use_fsdp else None) + return self.Optimizers(generator=optimizer) + + def setup_schedulers(self, extras: Extras, models: Models, optimizers: TrainingCore.Optimizers) -> Schedulers: + scheduler = GradualWarmupScheduler(optimizers.generator, multiplier=1, total_epoch=self.config.warmup_updates) + scheduler.last_epoch = self.info.total_steps + return self.Schedulers(generator=scheduler) + + def setup_data(self, extras: Extras) -> WarpCore.Data: + # SETUP DATASET + dataset_path = self.config.webdataset_path + dataset = mydist_dataset(dataset_path, \ + torchvision.transforms.ToTensor() if self.config.multi_aspect_ratio is not None \ + else extras.transforms) + + # SETUP DATALOADER + real_batch_size = self.config.batch_size // (self.world_size * self.config.grad_accum_steps) + + sampler = DistributedSampler(dataset, rank=self.process_id, num_replicas = self.world_size, shuffle=True) + dataloader = DataLoader( + dataset, batch_size=real_batch_size, num_workers=8, pin_memory=True, + collate_fn=identity if self.config.multi_aspect_ratio is not None else None, + sampler = sampler + ) + if self.is_main_node: + print(f"Training with batch size {self.config.batch_size} ({real_batch_size}/GPU)") + + if self.config.multi_aspect_ratio is not None: + aspect_ratios = [float(Fraction(f)) for f in self.config.multi_aspect_ratio] + dataloader_iterator = Bucketeer(dataloader, density=[ss*ss for ss in self.config.image_size] , factor=32, + ratios=aspect_ratios, p_random_ratio=self.config.bucketeer_random_ratio, + interpolate_nearest=False) # , use_smartcrop=True) + else: + + dataloader_iterator = iter(dataloader) + + return self.Data(dataset=dataset, dataloader=dataloader, iterator=dataloader_iterator, sampler=sampler) + + + def models_to_save(self): + pass + def setup_ddp(self, experiment_id, single_gpu=False, rank=0): + + if not single_gpu: + local_rank = rank + process_id = rank + world_size = get_world_size() + + self.process_id = process_id + self.is_main_node = process_id == 0 + self.device = torch.device(local_rank) + self.world_size = world_size + + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '41443' + torch.cuda.set_device(local_rank) + init_process_group( + backend="nccl", + rank=local_rank, + 
world_size=world_size, + ) + print(f"[GPU {process_id}] READY") + else: + self.is_main_node = rank == 0 + self.process_id = rank + self.device = torch.device('cuda:0') + self.world_size = 1 + print("Running in single thread, DDP not enabled.") + # Training loop -------------------------------- + def get_target_lr_size(self, ratio, std_size=24): + w, h = int(std_size / math.sqrt(ratio)), int(std_size * math.sqrt(ratio)) + return (h * 32 , w * 32) + def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models): + #batch = next(data.iterator) + batch = data + ratio = batch['images'].shape[-2] / batch['images'].shape[-1] + shape_lr = self.get_target_lr_size(ratio) + #print('in line 485', shape_lr, ratio, batch['images'].shape) + with torch.no_grad(): + conditions = self.get_conditions(batch, models, extras) + + latents = self.encode_latents(batch, models, extras) + latents_lr = self.encode_latents(batch, models, extras,target_size=shape_lr) + + noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1) + noised_lr, noise_lr, target_lr, logSNR_lr, noise_cond_lr, loss_weight_lr = extras.gdf.diffuse(latents_lr, shift=1, loss_shift=1, t=torch.ones(latents.shape[0]).to(latents.device)*0.05, ) + + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + # 768 1536 + require_cond = True + + with torch.no_grad(): + _, lr_enc_guide, lr_dec_guide = models.generator(noised_lr, noise_cond_lr, reuire_f=True, **conditions) + + + pred = models.generator(noised, noise_cond, reuire_f=False, lr_guide=(lr_enc_guide, lr_dec_guide) if require_cond else None , **conditions) + loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3]) + + loss_adjusted = (loss * loss_weight ).mean() / self.config.grad_accum_steps + + + if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): + extras.gdf.loss_weight.update_buckets(logSNR, loss) + + return loss, loss_adjusted + + def backward_pass(self, update, loss_adjusted, models: Models, optimizers: TrainingCore.Optimizers, schedulers: Schedulers): + + + if update: + + torch.distributed.barrier() + loss_adjusted.backward() + + grad_norm = nn.utils.clip_grad_norm_(models.train_norm.module.parameters(), 1.0) + + optimizers_dict = optimizers.to_dict() + for k in optimizers_dict: + if k != 'training': + optimizers_dict[k].step() + schedulers_dict = schedulers.to_dict() + for k in schedulers_dict: + if k != 'training': + schedulers_dict[k].step() + for k in optimizers_dict: + if k != 'training': + optimizers_dict[k].zero_grad(set_to_none=True) + self.info.total_steps += 1 + else: + + loss_adjusted.backward() + + grad_norm = torch.tensor(0.0).to(self.device) + + return grad_norm + + + def encode_latents(self, batch: dict, models: Models, extras: Extras, target_size=None) -> torch.Tensor: + + images = batch['images'].to(self.device) + if target_size is not None: + images = F.interpolate(images, target_size) + + return models.effnet(extras.effnet_preprocess(images)) + + def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor: + return models.previewer(latents) + + def __init__(self, rank=0, config_file_path=None, config_dict=None, device="cpu", training=True, world_size=1, ): + + self.is_main_node = (rank == 0) + self.config: self.Config = self.setup_config(config_file_path, config_dict, training) + self.setup_ddp(self.config.experiment_id, single_gpu=world_size <= 1, rank=rank) + self.info: self.Info = self.setup_info() + + + + def __call__(self, 
single_gpu=False): + + if self.config.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + if self.is_main_node: + print() + print("**STARTIG JOB WITH CONFIG:**") + print(yaml.dump(self.config.to_dict(), default_flow_style=False)) + print("------------------------------------") + print() + print("**INFO:**") + print(yaml.dump(vars(self.info), default_flow_style=False)) + print("------------------------------------") + print() + + # SETUP STUFF + extras = self.setup_extras_pre() + assert extras is not None, "setup_extras_pre() must return a DTO" + + + + data = self.setup_data(extras) + assert data is not None, "setup_data() must return a DTO" + if self.is_main_node: + print("**DATA:**") + print(yaml.dump({k:type(v).__name__ for k, v in data.to_dict().items()}, default_flow_style=False)) + print("------------------------------------") + print() + + models = self.setup_models(extras) + assert models is not None, "setup_models() must return a DTO" + if self.is_main_node: + print("**MODELS:**") + print(yaml.dump({ + k:f"{type(v).__name__} - {f'trainable params {sum(p.numel() for p in v.parameters() if p.requires_grad)}' if isinstance(v, nn.Module) else 'Not a nn.Module'}" for k, v in models.to_dict().items() + }, default_flow_style=False)) + print("------------------------------------") + print() + + + + optimizers = self.setup_optimizers(extras, models) + assert optimizers is not None, "setup_optimizers() must return a DTO" + if self.is_main_node: + print("**OPTIMIZERS:**") + print(yaml.dump({k:type(v).__name__ for k, v in optimizers.to_dict().items()}, default_flow_style=False)) + print("------------------------------------") + print() + + schedulers = self.setup_schedulers(extras, models, optimizers) + assert schedulers is not None, "setup_schedulers() must return a DTO" + if self.is_main_node: + print("**SCHEDULERS:**") + print(yaml.dump({k:type(v).__name__ for k, v in schedulers.to_dict().items()}, default_flow_style=False)) + print("------------------------------------") + print() + + post_extras =self.setup_extras_post(extras, models, optimizers, schedulers) + assert post_extras is not None, "setup_extras_post() must return a DTO" + extras = self.Extras.from_dict({ **extras.to_dict(),**post_extras.to_dict() }) + if self.is_main_node: + print("**EXTRAS:**") + print(yaml.dump({k:f"{v}" for k, v in extras.to_dict().items()}, default_flow_style=False)) + print("------------------------------------") + print() + # ------- + + # TRAIN + if self.is_main_node: + print("**TRAINING STARTING...**") + self.train(data, extras, models, optimizers, schedulers) + + if single_gpu is False: + barrier() + destroy_process_group() + if self.is_main_node: + print() + print("------------------------------------") + print() + print("**TRAINING COMPLETE**") + + + + def train(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models, optimizers: TrainingCore.Optimizers, + schedulers: WarpCore.Schedulers): + start_iter = self.info.iter + 1 + max_iters = self.config.updates * self.config.grad_accum_steps + if self.is_main_node: + print(f"STARTING AT STEP: {start_iter}/{max_iters}") + + + if self.is_main_node: + create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/') + + models.generator.train() + + iter_cnt = 0 + epoch_cnt = 0 + models.train_norm.train() + while True: + epoch_cnt += 1 + if self.world_size > 1: + + data.sampler.set_epoch(epoch_cnt) + for ggg in range(len(data.dataloader)): + iter_cnt += 1 + loss, loss_adjusted 
= self.forward_pass(next(data.iterator), extras, models)
+                grad_norm = self.backward_pass(
+                    iter_cnt % self.config.grad_accum_steps == 0 or iter_cnt == max_iters, loss_adjusted,
+                    models, optimizers, schedulers
+                )
+
+                self.info.iter = iter_cnt
+
+                # UPDATE LOSS METRICS
+                self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01
+
+                #print('in line 666 after ema loss', grad_norm, loss.mean().item(), iter_cnt, self.info.ema_loss)
+                if self.is_main_node and (np.isnan(loss.mean().item()) or np.isnan(grad_norm.item())):
+                    print(f"NaN value encountered in training run {self.info.wandb_run_id}", \
+                          f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}")
+
+                if self.is_main_node:
+                    logs = {
+                        'loss': self.info.ema_loss,
+                        'backward_loss': loss_adjusted.mean().item(),
+                        'ema_loss': self.info.ema_loss,
+                        'raw_ori_loss': loss.mean().item(),
+                        'grad_norm': grad_norm.item(),
+                        'lr': optimizers.generator.param_groups[0]['lr'] if optimizers.generator is not None else 0,
+                        'total_steps': self.info.total_steps,
+                    }
+                    if iter_cnt % (self.config.save_every) == 0:
+
+                        print(iter_cnt, max_iters, logs, epoch_cnt, )
+
+                if iter_cnt == 1 or iter_cnt % (self.config.save_every) == 0 or iter_cnt == max_iters:
+
+                    # SAVE AND CHECKPOINT STUFF
+                    if np.isnan(loss.mean().item()):
+                        if self.is_main_node and self.config.wandb_project is not None:
+                            print(f"NaN value encountered in training run {self.info.wandb_run_id}", \
+                                  f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}")
+
+                    else:
+                        if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight):
+                            self.info.adaptive_loss = {
+                                'bucket_ranges': extras.gdf.loss_weight.bucket_ranges.tolist(),
+                                'bucket_losses': extras.gdf.loss_weight.bucket_losses.tolist(),
+                            }
+
+                        if self.is_main_node and iter_cnt % (self.config.save_every * self.config.grad_accum_steps) == 0:
+                            print('save model', iter_cnt, iter_cnt % (self.config.save_every * self.config.grad_accum_steps), self.config.save_every, self.config.grad_accum_steps )
+                            torch.save(models.train_norm.state_dict(), \
+                                f'{self.config.output_path}/{self.config.experiment_id}/train_norm.safetensors')
+
+                            torch.save(models.train_norm.state_dict(), \
+                                f'{self.config.output_path}/{self.config.experiment_id}/train_norm_{iter_cnt}.safetensors')
+
+                if iter_cnt == 1 or iter_cnt % (self.config.save_every * self.config.grad_accum_steps) == 0 or iter_cnt == max_iters:
+
+                    if self.is_main_node:
+
+                        self.sample(models, data, extras)
+
+                if self.info.iter >= max_iters:
+                    break
+
+    def sample(self, models: Models, data: WarpCore.Data, extras: Extras):
+
+        models.generator.eval()
+        models.train_norm.eval()
+        with torch.no_grad():
+            batch = next(data.iterator)
+            ratio = batch['images'].shape[-2] / batch['images'].shape[-1]
+
+            shape_lr = self.get_target_lr_size(ratio)
+            conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False)
+            unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
+
+            latents = self.encode_latents(batch, models, extras)
+            latents_lr = self.encode_latents(batch, models, extras, target_size = shape_lr)
+
+            if self.is_main_node:
+
+                with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+
+                    *_, (sampled, _, _, sampled_lr) = extras.gdf.sample(
+                        models.generator, conditions,
+                        latents.shape, latents_lr.shape,
+                        unconditions, device=self.device, **extras.sampling_configs
+                    )
+
+            if self.is_main_node:
+                print('sampling results hr latent shape', latents.shape, 'lr latent shape', latents_lr.shape, )
+                noised_images = torch.cat(
+                    [self.decode_latents(latents[i:i + 1].float(), batch, models, extras) for i in range(len(latents))], dim=0)
+
+                sampled_images = torch.cat(
+                    [self.decode_latents(sampled[i:i + 1].float(), batch, models, extras) for i in range(len(sampled))], dim=0)
+
+                noised_images_lr = torch.cat(
+                    [self.decode_latents(latents_lr[i:i + 1].float(), batch, models, extras) for i in range(len(latents_lr))], dim=0)
+
+                sampled_images_lr = torch.cat(
+                    [self.decode_latents(sampled_lr[i:i + 1].float(), batch, models, extras) for i in range(len(sampled_lr))], dim=0)
+
+                images = batch['images']
+                if images.size(-1) != noised_images.size(-1) or images.size(-2) != noised_images.size(-2):
+                    images = nn.functional.interpolate(images, size=noised_images.shape[-2:], mode='bicubic')
+                images_lr = nn.functional.interpolate(images, size=noised_images_lr.shape[-2:], mode='bicubic')
+
+                collage_img = torch.cat([
+                    torch.cat([i for i in images.cpu()], dim=-1),
+                    torch.cat([i for i in noised_images.cpu()], dim=-1),
+                    torch.cat([i for i in sampled_images.cpu()], dim=-1),
+                ], dim=-2)
+
+                collage_img_lr = torch.cat([
+                    torch.cat([i for i in images_lr.cpu()], dim=-1),
+                    torch.cat([i for i in noised_images_lr.cpu()], dim=-1),
+                    torch.cat([i for i in sampled_images_lr.cpu()], dim=-1),
+                ], dim=-2)
+
+                torchvision.utils.save_image(collage_img, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}.jpg')
+                torchvision.utils.save_image(collage_img_lr, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}_lr.jpg')
+
+        models.generator.train()
+        models.train_norm.train()
+        print('finish sampling')
+
+    def sample_fortest(self, models: Models, extras: Extras, hr_shape, lr_shape, batch, eval_image_embeds=False):
+
+        models.generator.eval()
+
+        with torch.no_grad():
+
+            if self.is_main_node:
+                conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=eval_image_embeds)
+                unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
+
+                with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+
+                    *_, (sampled, _, _, sampled_lr) = extras.gdf.sample(
+                        models.generator, conditions,
+                        hr_shape, lr_shape,
+                        unconditions, device=self.device, **extras.sampling_configs
+                    )
+
+                    if models.generator_ema is not None:
+
+                        *_, (sampled_ema, _, _, sampled_ema_lr) = extras.gdf.sample(
+                            models.generator_ema, conditions,
+                            hr_shape, lr_shape,
+                            unconditions, device=self.device, **extras.sampling_configs
+                        )
+
+                    else:
+                        sampled_ema = sampled
+                        sampled_ema_lr = sampled_lr
+
+        return sampled, sampled_lr
+def main_worker(rank, cfg):
+    print("Launching Script in main worker")
+
+    warpcore = WurstCore(
+        config_file_path=cfg, rank=rank, world_size = get_world_size()
+    )
+    # core.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD
+
+    # RUN TRAINING
+    warpcore(get_world_size()==1)
+
+if __name__ == '__main__':
+    print('launch multi process')
+    # os.environ["OMP_NUM_THREADS"] = "1"
+    # os.environ["MKL_NUM_THREADS"] = "1"
+    #dist.init_process_group(backend="nccl")
+    #torch.backends.cudnn.benchmark = True
+#train/train_c_my.py
+    #mp.set_sharing_strategy('file_system')
+
+    if get_master_ip() == "127.0.0.1":
+        # manually launch distributed processes
+        mp.spawn(main_worker, 
nprocs=get_world_size(), args=(sys.argv[1] if len(sys.argv) > 1 else None, )) + else: + main_worker(0, sys.argv[1] if len(sys.argv) > 1 else None, ) diff --git a/train/train_ultrapixel_control.py b/train/train_ultrapixel_control.py new file mode 100644 index 0000000000000000000000000000000000000000..cd67965973a85ed1d72c164dd0e8970f8b5ce277 --- /dev/null +++ b/train/train_ultrapixel_control.py @@ -0,0 +1,928 @@ +import torch +import json +import yaml +import torchvision +from torch import nn, optim +from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPVisionModelWithProjection +from warmup_scheduler import GradualWarmupScheduler +import torch.multiprocessing as mp +import numpy as np +import sys + +import os +from dataclasses import dataclass +from torch.distributed import init_process_group, destroy_process_group, barrier +from gdf import GDF_dual_fixlrt as GDF +from gdf import EpsilonTarget, CosineSchedule +from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight +from torchtools.transforms import SmartCrop +from fractions import Fraction +from modules.effnet import EfficientNetEncoder + +from modules.model_4stage_lite import StageC + +from modules.model_4stage_lite import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock +from modules.common_ckpt import GlobalResponseNorm +from modules.previewer import Previewer +from core.data import Bucketeer +from train.base import DataCore, TrainingCore +from tqdm import tqdm +from core import WarpCore +from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail +from torch.distributed.fsdp.wrap import ModuleWrapPolicy, size_based_auto_wrap_policy +from accelerate import init_empty_weights +from accelerate.utils import set_module_tensor_to_device +from contextlib import contextmanager +from train.dist_core import * +import glob +from torch.utils.data import DataLoader, Dataset +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data.distributed import DistributedSampler +from PIL import Image +from core.utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary +from core.utils import Base +from modules.common import LayerNorm2d +import torch.nn.functional as F +import functools +import math +import copy +import random +from modules.lora import apply_lora, apply_retoken, LoRA, ReToken +from modules import ControlNet, ControlNetDeliverer +from modules import controlnet_filters + +Image.MAX_IMAGE_PIXELS = None +torch.manual_seed(8432) +random.seed(8432) +np.random.seed(8432) +#7978026 + +class Null_Model(torch.nn.Module): + def __init__(self): + super().__init__() + def forward(self, x): + pass + + +def identity(x): + if isinstance(x, bytes): + x = x.decode('utf-8') + return x +def check_nan_inmodel(model, meta=''): + for name, param in model.named_parameters(): + if torch.isnan(param).any(): + print(f"nan detected in {name}", meta) + return True + print('no nan', meta) + return False + + +class WurstCore(TrainingCore, DataCore, WarpCore): + @dataclass(frozen=True) + class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config): + # TRAINING PARAMS + lr: float = EXPECTED_TRAIN + warmup_updates: int = EXPECTED_TRAIN + dtype: str = None + + # MODEL VERSION + model_version: str = EXPECTED # 3.6B or 1B + clip_image_model_name: str = 'openai/clip-vit-large-patch14' + clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' + + # CHECKPOINT PATHS + effnet_checkpoint_path: str = EXPECTED + previewer_checkpoint_path: str = 
EXPECTED + #trans_inr_ckpt: str = EXPECTED + generator_checkpoint_path: str = None + controlnet_checkpoint_path: str = EXPECTED + + # controlnet settings + controlnet_blocks: list = EXPECTED + controlnet_filter: str = EXPECTED + controlnet_filter_params: dict = None + controlnet_bottleneck_mode: str = None + + + # gdf customization + adaptive_loss_weight: str = None + + #module_filters: list = EXPECTED + #rank: int = EXPECTED + @dataclass(frozen=True) + class Data(Base): + dataset: Dataset = EXPECTED + dataloader: DataLoader = EXPECTED + iterator: any = EXPECTED + sampler: DistributedSampler = EXPECTED + + @dataclass(frozen=True) + class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models): + effnet: nn.Module = EXPECTED + previewer: nn.Module = EXPECTED + train_norm: nn.Module = EXPECTED + train_norm_ema: nn.Module = EXPECTED + controlnet: nn.Module = EXPECTED + + @dataclass(frozen=True) + class Schedulers(WarpCore.Schedulers): + generator: any = None + + @dataclass(frozen=True) + class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras): + gdf: GDF = EXPECTED + sampling_configs: dict = EXPECTED + effnet_preprocess: torchvision.transforms.Compose = EXPECTED + controlnet_filter: controlnet_filters.BaseFilter = EXPECTED + + info: TrainingCore.Info + config: Config + + def setup_extras_pre(self) -> Extras: + gdf = GDF( + schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]), + input_scaler=VPScaler(), target=EpsilonTarget(), + noise_cond=CosineTNoiseCond(), + loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(), + ) + sampling_configs = {"cfg": 5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 20} + + if self.info.adaptive_loss is not None: + gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges']) + gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses']) + + effnet_preprocess = torchvision.transforms.Compose([ + torchvision.transforms.Normalize( + mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) + ) + ]) + + clip_preprocess = torchvision.transforms.Compose([ + torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC), + torchvision.transforms.CenterCrop(224), + torchvision.transforms.Normalize( + mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711) + ) + ]) + + if self.config.training: + transforms = torchvision.transforms.Compose([ + torchvision.transforms.ToTensor(), + torchvision.transforms.Resize(self.config.image_size[-1], interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True), + SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2) + ]) + else: + transforms = None + controlnet_filter = getattr(controlnet_filters, self.config.controlnet_filter)( + self.device, + **(self.config.controlnet_filter_params if self.config.controlnet_filter_params is not None else {}) + ) + + return self.Extras( + gdf=gdf, + sampling_configs=sampling_configs, + transforms=transforms, + effnet_preprocess=effnet_preprocess, + clip_preprocess=clip_preprocess, + controlnet_filter=controlnet_filter + ) + def get_cnet(self, batch: dict, models: Models, extras: Extras, cnet_input=None, target_size=None, **kwargs): + images = batch['images'] + if target_size is not None: + images = Image.resize(images, target_size) + with torch.no_grad(): + if cnet_input is None: + cnet_input = extras.controlnet_filter(images, **kwargs) + if isinstance(cnet_input, tuple): + cnet_input, 
cnet_input_preview = cnet_input + else: + cnet_input_preview = cnet_input + cnet_input, cnet_input_preview = cnet_input.to(self.device), cnet_input_preview.to(self.device) + cnet = models.controlnet(cnet_input) + return cnet, cnet_input_preview + + def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False, + eval_image_embeds=False, return_fields=None): + conditions = super().get_conditions( + batch, models, extras, is_eval, is_unconditional, + eval_image_embeds, return_fields=return_fields or ['clip_text', 'clip_text_pooled', 'clip_img'] + ) + return conditions + + def setup_models(self, extras: Extras) -> Models: # configure model + + + dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.bfloat16 + + # EfficientNet encoderin + effnet = EfficientNetEncoder() + effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path) + effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict']) + effnet.eval().requires_grad_(False).to(self.device) + del effnet_checkpoint + + # Previewer + previewer = Previewer() + previewer_checkpoint = load_or_fail(self.config.previewer_checkpoint_path) + previewer.load_state_dict(previewer_checkpoint if 'state_dict' not in previewer_checkpoint else previewer_checkpoint['state_dict']) + previewer.eval().requires_grad_(False).to(self.device) + del previewer_checkpoint + + @contextmanager + def dummy_context(): + yield None + + loading_context = dummy_context if self.config.training else init_empty_weights + + # Diffusion models + with loading_context(): + generator_ema = None + if self.config.model_version == '3.6B': + generator = StageC() + if self.config.ema_start_iters is not None: # default setting + generator_ema = StageC() + elif self.config.model_version == '1B': + + generator = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]]) + + if self.config.ema_start_iters is not None and self.config.training: + generator_ema = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]]) + else: + raise ValueError(f"Unknown model version {self.config.model_version}") + + + + if loading_context is dummy_context: + generator.load_state_dict( load_or_fail(self.config.generator_checkpoint_path)) + else: + for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items(): + set_module_tensor_to_device(generator, param_name, "cpu", value=param) + + generator._init_extra_parameter() + + + + + generator = generator.to(torch.bfloat16).to(self.device) + + train_norm = nn.ModuleList() + + + cnt_norm = 0 + for mm in generator.modules(): + if isinstance(mm, GlobalResponseNorm): + + train_norm.append(Null_Model()) + cnt_norm += 1 + + + + + train_norm.append(generator.agg_net) + train_norm.append(generator.agg_net_up) + + + + + if os.path.exists(os.path.join(self.config.output_path, self.config.experiment_id, 'train_norm.safetensors')): + sdd = torch.load(os.path.join(self.config.output_path, self.config.experiment_id, 'train_norm.safetensors'), map_location='cpu') + collect_sd = {} + for k, v in sdd.items(): + collect_sd[k[7:]] = v + train_norm.load_state_dict(collect_sd, strict=True) + + + train_norm.to(self.device).train().requires_grad_(True) + train_norm_ema = copy.deepcopy(train_norm) + train_norm_ema.to(self.device).eval().requires_grad_(False) + if generator_ema is not None: + + generator_ema.load_state_dict(load_or_fail(self.config.generator_checkpoint_path)) + 
generator_ema._init_extra_parameter() + + pretrained_pth = os.path.join(self.config.output_path, self.config.experiment_id, 'generator.safetensors') + if os.path.exists(pretrained_pth): + print(pretrained_pth, 'exists') + generator_ema.load_state_dict(torch.load(pretrained_pth, map_location='cpu')) + + generator_ema.eval().requires_grad_(False) + + check_nan_inmodel(generator, 'generator') + + + + if self.config.use_fsdp and self.config.training: + train_norm = DDP(train_norm, device_ids=[self.device], find_unused_parameters=True) + + + # CLIP encoders + tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name) + text_model = CLIPTextModelWithProjection.from_pretrained(self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device) + image_model = CLIPVisionModelWithProjection.from_pretrained(self.config.clip_image_model_name).requires_grad_(False).to(dtype).to(self.device) + + controlnet = ControlNet( + c_in=extras.controlnet_filter.num_channels(), + proj_blocks=self.config.controlnet_blocks, + bottleneck_mode=self.config.controlnet_bottleneck_mode + ) + controlnet = controlnet.to(dtype).to(self.device) + controlnet = self.load_model(controlnet, 'controlnet') + controlnet.backbone.eval().requires_grad_(True) + + + return self.Models( + effnet=effnet, previewer=previewer, train_norm = train_norm, + generator=generator, generator_ema=generator_ema, + tokenizer=tokenizer, text_model=text_model, image_model=image_model, + train_norm_ema=train_norm_ema, controlnet =controlnet + ) + + def setup_optimizers(self, extras: Extras, models: Models) -> TrainingCore.Optimizers: + +# + + params = [] + params += list(models.train_norm.module.parameters()) + + optimizer = optim.AdamW(params, lr=self.config.lr) + + return self.Optimizers(generator=optimizer) + + def ema_update(self, ema_model, source_model, beta): + for param_src, param_ema in zip(source_model.parameters(), ema_model.parameters()): + param_ema.data.mul_(beta).add_(param_src.data, alpha = 1 - beta) + + def sync_ema(self, ema_model): + print('sync ema', torch.distributed.get_world_size()) + for param in ema_model.parameters(): + torch.distributed.all_reduce(param.data, op=torch.distributed.ReduceOp.SUM) + param.data /= torch.distributed.get_world_size() + def setup_optimizers_backup(self, extras: Extras, models: Models) -> TrainingCore.Optimizers: + + + optimizer = optim.AdamW( + models.generator.up_blocks.parameters() , + lr=self.config.lr) + optimizer = self.load_optimizer(optimizer, 'generator_optim', + fsdp_model=models.generator if self.config.use_fsdp else None) + return self.Optimizers(generator=optimizer) + + def setup_schedulers(self, extras: Extras, models: Models, optimizers: TrainingCore.Optimizers) -> Schedulers: + scheduler = GradualWarmupScheduler(optimizers.generator, multiplier=1, total_epoch=self.config.warmup_updates) + scheduler.last_epoch = self.info.total_steps + return self.Schedulers(generator=scheduler) + + def setup_data(self, extras: Extras) -> WarpCore.Data: + # SETUP DATASET + dataset_path = self.config.webdataset_path + print('in line 96', dataset_path, type(dataset_path)) + + dataset = mydist_dataset(dataset_path, \ + torchvision.transforms.ToTensor() if self.config.multi_aspect_ratio is not None \ + else extras.transforms) + + # SETUP DATALOADER + real_batch_size = self.config.batch_size // (self.world_size * self.config.grad_accum_steps) + print('in line 119', self.process_id, real_batch_size) + sampler = DistributedSampler(dataset, rank=self.process_id, num_replicas = 
self.world_size, shuffle=True) + dataloader = DataLoader( + dataset, batch_size=real_batch_size, num_workers=4, pin_memory=True, + collate_fn=identity if self.config.multi_aspect_ratio is not None else None, + sampler = sampler + ) + if self.is_main_node: + print(f"Training with batch size {self.config.batch_size} ({real_batch_size}/GPU)") + + if self.config.multi_aspect_ratio is not None: + aspect_ratios = [float(Fraction(f)) for f in self.config.multi_aspect_ratio] + dataloader_iterator = Bucketeer(dataloader, density=[ss*ss for ss in self.config.image_size] , factor=32, + ratios=aspect_ratios, p_random_ratio=self.config.bucketeer_random_ratio, + interpolate_nearest=False) # , use_smartcrop=True) + else: + + dataloader_iterator = iter(dataloader) + + return self.Data(dataset=dataset, dataloader=dataloader, iterator=dataloader_iterator, sampler=sampler) + + + + + + def setup_ddp(self, experiment_id, single_gpu=False, rank=0): + + if not single_gpu: + local_rank = rank + process_id = rank + world_size = get_world_size() + + self.process_id = process_id + self.is_main_node = process_id == 0 + self.device = torch.device(local_rank) + self.world_size = world_size + + + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '41443' + torch.cuda.set_device(local_rank) + init_process_group( + backend="nccl", + rank=local_rank, + world_size=world_size, + # init_method=init_method, + ) + print(f"[GPU {process_id}] READY") + else: + self.is_main_node = rank == 0 + self.process_id = rank + self.device = torch.device('cuda:0') + self.world_size = 1 + print("Running in single thread, DDP not enabled.") + # Training loop -------------------------------- + def get_target_lr_size(self, ratio, std_size=24): + w, h = int(std_size / math.sqrt(ratio)), int(std_size * math.sqrt(ratio)) + return (h * 32 , w * 32) + def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models): + #batch = next(data.iterator) + batch = data + ratio = batch['images'].shape[-2] / batch['images'].shape[-1] + shape_lr = self.get_target_lr_size(ratio) + + with torch.no_grad(): + conditions = self.get_conditions(batch, models, extras) + + latents = self.encode_latents(batch, models, extras) + latents_lr = self.encode_latents(batch, models, extras,target_size=shape_lr) + + noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1) + noised_lr, noise_lr, target_lr, logSNR_lr, noise_cond_lr, loss_weight_lr = extras.gdf.diffuse(latents_lr, shift=1, loss_shift=1, t=torch.ones(latents.shape[0]).to(latents.device)*0.05, ) + + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + + require_cond = True + + with torch.no_grad(): + _, lr_enc_guide, lr_dec_guide = models.generator(noised_lr, noise_cond_lr, reuire_f=True, **conditions) + + + pred = models.generator(noised, noise_cond, reuire_f=False, lr_guide=(lr_enc_guide, lr_dec_guide) if require_cond else None , **conditions) + loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3]) + + loss_adjusted = (loss * loss_weight ).mean() / self.config.grad_accum_steps + # + if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): + extras.gdf.loss_weight.update_buckets(logSNR, loss) + + return loss, loss_adjusted + + def backward_pass(self, update, loss_adjusted, models: Models, optimizers: TrainingCore.Optimizers, schedulers: Schedulers): + + if update: + + torch.distributed.barrier() + loss_adjusted.backward() + + + grad_norm = nn.utils.clip_grad_norm_(models.train_norm.module.parameters(), 1.0) 
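+            # Gradient accumulation: backward_pass() runs every iteration, but the optimizers
+            # and schedulers below only step (and gradients are only cleared) when `update` is
+            # True, i.e. every grad_accum_steps iterations or on the final iteration.
+            # loss_adjusted is already divided by grad_accum_steps in forward_pass(), and the
+            # trainable train_norm parameters are clipped to a max grad norm of 1.0 before stepping.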
+ + optimizers_dict = optimizers.to_dict() + for k in optimizers_dict: + if k != 'training': + optimizers_dict[k].step() + schedulers_dict = schedulers.to_dict() + for k in schedulers_dict: + if k != 'training': + schedulers_dict[k].step() + for k in optimizers_dict: + if k != 'training': + optimizers_dict[k].zero_grad(set_to_none=True) + self.info.total_steps += 1 + else: + #print('in line 457', loss_adjusted) + loss_adjusted.backward() + #torch.distributed.barrier() + grad_norm = torch.tensor(0.0).to(self.device) + + return grad_norm + + def models_to_save(self): + return ['generator', 'generator_ema', 'trans_inr', 'trans_inr_ema'] + + def encode_latents(self, batch: dict, models: Models, extras: Extras, target_size=None) -> torch.Tensor: + + images = batch['images'].to(self.device) + if target_size is not None: + images = F.interpolate(images, target_size) + #images = apply_degradations(images) + return models.effnet(extras.effnet_preprocess(images)) + + def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor: + return models.previewer(latents) + + def __init__(self, rank=0, config_file_path=None, config_dict=None, device="cpu", training=True, world_size=1, ): + # Temporary setup, will be overriden by setup_ddp if required + # self.device = device + # self.process_id = 0 + # self.is_main_node = True + # self.world_size = 1 + # ---- + # self.world_size = world_size + # self.process_id = rank + # self.device=device + self.is_main_node = (rank == 0) + self.config: self.Config = self.setup_config(config_file_path, config_dict, training) + self.setup_ddp(self.config.experiment_id, single_gpu=world_size <= 1, rank=rank) + self.info: self.Info = self.setup_info() + print('in line 292', self.config.experiment_id, rank, world_size <= 1) + p = [i for i in range( 2 * 768 // 32)] + p = [num / sum(p) for num in p] + self.rand_pro = p + self.res_list = [o for o in range(800, 2336, 32)] + + #[32, 40, 48] + #in line 292 stage_c_3b_finetuning False + + def __call__(self, single_gpu=False): + # this will change the device to the CUDA rank + #self.setup_wandb() + if self.config.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + if self.is_main_node: + print() + print("**STARTIG JOB WITH CONFIG:**") + print(yaml.dump(self.config.to_dict(), default_flow_style=False)) + print("------------------------------------") + print() + print("**INFO:**") + print(yaml.dump(vars(self.info), default_flow_style=False)) + print("------------------------------------") + print() + print('in line 308', self.is_main_node, self.is_main_node, self.process_id, self.device ) + # SETUP STUFF + extras = self.setup_extras_pre() + assert extras is not None, "setup_extras_pre() must return a DTO" + + + + data = self.setup_data(extras) + assert data is not None, "setup_data() must return a DTO" + if self.is_main_node: + print("**DATA:**") + print(yaml.dump({k:type(v).__name__ for k, v in data.to_dict().items()}, default_flow_style=False)) + print("------------------------------------") + print() + + models = self.setup_models(extras) + assert models is not None, "setup_models() must return a DTO" + if self.is_main_node: + print("**MODELS:**") + print(yaml.dump({ + k:f"{type(v).__name__} - {f'trainable params {sum(p.numel() for p in v.parameters() if p.requires_grad)}' if isinstance(v, nn.Module) else 'Not a nn.Module'}" for k, v in models.to_dict().items() + }, default_flow_style=False)) + 
print("------------------------------------") + print() + + + + optimizers = self.setup_optimizers(extras, models) + assert optimizers is not None, "setup_optimizers() must return a DTO" + if self.is_main_node: + print("**OPTIMIZERS:**") + print(yaml.dump({k:type(v).__name__ for k, v in optimizers.to_dict().items()}, default_flow_style=False)) + print("------------------------------------") + print() + + schedulers = self.setup_schedulers(extras, models, optimizers) + assert schedulers is not None, "setup_schedulers() must return a DTO" + if self.is_main_node: + print("**SCHEDULERS:**") + print(yaml.dump({k:type(v).__name__ for k, v in schedulers.to_dict().items()}, default_flow_style=False)) + print("------------------------------------") + print() + + post_extras =self.setup_extras_post(extras, models, optimizers, schedulers) + assert post_extras is not None, "setup_extras_post() must return a DTO" + extras = self.Extras.from_dict({ **extras.to_dict(),**post_extras.to_dict() }) + if self.is_main_node: + print("**EXTRAS:**") + print(yaml.dump({k:f"{v}" for k, v in extras.to_dict().items()}, default_flow_style=False)) + print("------------------------------------") + print() + # ------- + + # TRAIN + if self.is_main_node: + print("**TRAINING STARTING...**") + self.train(data, extras, models, optimizers, schedulers) + + if single_gpu is False: + barrier() + destroy_process_group() + if self.is_main_node: + print() + print("------------------------------------") + print() + print("**TRAINING COMPLETE**") + if self.config.wandb_project is not None: + wandb.alert(title=f"Training {self.info.wandb_run_id} finished", text=f"Training {self.info.wandb_run_id} finished") + + + def train(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models, optimizers: TrainingCore.Optimizers, + schedulers: WarpCore.Schedulers): + start_iter = self.info.iter + 1 + max_iters = self.config.updates * self.config.grad_accum_steps + if self.is_main_node: + print(f"STARTING AT STEP: {start_iter}/{max_iters}") + + + if self.is_main_node: + create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/') + if 'generator' in self.models_to_save(): + models.generator.train() + #initial_params = {name: param.clone() for name, param in models.train_norm.named_parameters()} + iter_cnt = 0 + epoch_cnt = 0 + models.train_norm.train() + while True: + epoch_cnt += 1 + if self.world_size > 1: + print('sampler set epoch', epoch_cnt) + data.sampler.set_epoch(epoch_cnt) + for ggg in range(len(data.dataloader)): + iter_cnt += 1 + # FORWARD PASS + #print('in line 414 before forward', iter_cnt, batch['captions'][0], self.process_id) + #loss, loss_adjusted, loss_extra = self.forward_pass(batch, extras, models) + loss, loss_adjusted = self.forward_pass(next(data.iterator), extras, models) + + #print('in line 416', loss, iter_cnt) + # # BACKWARD PASS + + grad_norm = self.backward_pass( + iter_cnt % self.config.grad_accum_steps == 0 or iter_cnt == max_iters, loss_adjusted, + models, optimizers, schedulers + ) + + + + self.info.iter = iter_cnt + + # UPDATE EMA + if iter_cnt % self.config.ema_iters == 0: + + with torch.no_grad(): + print('in line 890 ema update', self.config.ema_iters, iter_cnt) + self.ema_update(models.train_norm_ema, models.train_norm, self.config.ema_beta) + #generator.module.agg_net. 
+                        #self.ema_update(models.generator_ema.agg_net, models.generator.module.agg_net, self.config.ema_beta)
+                        #self.ema_update(models.generator_ema.agg_net_up, models.generator.module.agg_net_up, self.config.ema_beta)
+
+                # UPDATE LOSS METRICS
+                self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01
+
+                if self.is_main_node and (np.isnan(loss.mean().item()) or np.isnan(grad_norm.item())):
+                    print(f"NaN value encountered in training run {self.info.wandb_run_id}",
+                          f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}")
+
+                if self.is_main_node:
+                    logs = {
+                        'loss': self.info.ema_loss,
+                        'backward_loss': loss_adjusted.mean().item(),
+                        'ema_loss': self.info.ema_loss,
+                        'raw_ori_loss': loss.mean().item(),
+                        'grad_norm': grad_norm.item(),
+                        'lr': optimizers.generator.param_groups[0]['lr'] if optimizers.generator is not None else 0,
+                        'total_steps': self.info.total_steps,
+                    }
+                    if iter_cnt % self.config.save_every == 0:
+                        print(iter_cnt, max_iters, logs, epoch_cnt)
+
+                if iter_cnt == 1 or iter_cnt % self.config.save_every == 0 or iter_cnt == max_iters:
+                    # SAVE AND CHECKPOINT STUFF
+                    if np.isnan(loss.mean().item()):
+                        if self.is_main_node and self.config.wandb_project is not None:
+                            print(f"NaN value encountered in training run {self.info.wandb_run_id}",
+                                  f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. 
Run {self.info.wandb_run_id}") + + else: + if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): + self.info.adaptive_loss = { + 'bucket_ranges': extras.gdf.loss_weight.bucket_ranges.tolist(), + 'bucket_losses': extras.gdf.loss_weight.bucket_losses.tolist(), + } + #self.save_checkpoints(models, optimizers) + + #torch.save(models.trans_inr.module.state_dict(), \ + #f'{self.config.output_path}/{self.config.experiment_id}/trans_inr.safetensors') + #torch.save(models.trans_inr_ema.state_dict(), \ + #f'{self.config.output_path}/{self.config.experiment_id}/trans_inr_ema.safetensors') + + + if self.is_main_node and iter_cnt % (self.config.save_every * self.config.grad_accum_steps) == 0: + print('save model', iter_cnt, iter_cnt % (self.config.save_every * self.config.grad_accum_steps), self.config.save_every, self.config.grad_accum_steps ) + torch.save(models.train_norm.state_dict(), \ + f'{self.config.output_path}/{self.config.experiment_id}/train_norm.safetensors') + + #self.sync_ema(models.train_norm_ema) + torch.save(models.train_norm_ema.state_dict(), \ + f'{self.config.output_path}/{self.config.experiment_id}/train_norm_ema.safetensors') + #if self.is_main_node and iter_cnt % (4 * self.config.save_every * self.config.grad_accum_steps) == 0: + torch.save(models.train_norm.state_dict(), \ + f'{self.config.output_path}/{self.config.experiment_id}/train_norm_{iter_cnt}.safetensors') + + + if iter_cnt == 1 or iter_cnt % (self.config.save_every* self.config.grad_accum_steps) == 0 or iter_cnt == max_iters: + + if self.is_main_node: + #check_nan_inmodel(models.generator, 'generator') + #check_nan_inmodel(models.generator_ema, 'generator_ema') + self.sample(models, data, extras) + if False: + param_changes = {name: (param - initial_params[name]).norm().item() for name, param in models.train_norm.named_parameters()} + threshold = sorted(param_changes.values(), reverse=True)[int(len(param_changes) * 0.1)] # top 10% + important_params = [name for name, change in param_changes.items() if change > threshold] + print(important_params, threshold, len(param_changes), self.process_id) + json.dump(important_params, open(f'{self.config.output_path}/{self.config.experiment_id}/param.json', 'w'), indent=4) + + + if self.info.iter >= max_iters: + break + + def sample(self, models: Models, data: WarpCore.Data, extras: Extras): + + #if 'generator' in self.models_to_save(): + models.generator.eval() + models.train_norm.eval() + with torch.no_grad(): + batch = next(data.iterator) + ratio = batch['images'].shape[-2] / batch['images'].shape[-1] + #batch['images'] = batch['images'].to(torch.float16) + shape_lr = self.get_target_lr_size(ratio) + conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False) + unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) + cnet, cnet_input = self.get_cnet(batch, models, extras) + conditions, unconditions = {**conditions, 'cnet': cnet}, {**unconditions, 'cnet': cnet} + + latents = self.encode_latents(batch, models, extras) + latents_lr = self.encode_latents(batch, models, extras, target_size = shape_lr) + + if self.is_main_node: + + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + #print('in line 366 on v100 switch to tf16') + *_, (sampled, _, _, sampled_lr) = extras.gdf.sample( + models.generator, models.trans_inr, conditions, + latents.shape, latents_lr.shape, + unconditions, device=self.device, **extras.sampling_configs + ) + + + + #else: + 
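+                # No separate EMA generator is sampled here, so the "EMA" rows in the preview
+                # collage below simply reuse the base samples.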
sampled_ema = sampled + sampled_ema_lr = sampled_lr + + + if self.is_main_node: + print('sampling results', latents.shape, latents_lr.shape, ) + noised_images = torch.cat( + [self.decode_latents(latents[i:i + 1].float(), batch, models, extras) for i in range(len(latents))], dim=0) + + sampled_images = torch.cat( + [self.decode_latents(sampled[i:i + 1].float(), batch, models, extras) for i in range(len(sampled))], dim=0) + sampled_images_ema = torch.cat( + [self.decode_latents(sampled_ema[i:i + 1].float(), batch, models, extras) for i in range(len(sampled_ema))], + dim=0) + + noised_images_lr = torch.cat( + [self.decode_latents(latents_lr[i:i + 1].float(), batch, models, extras) for i in range(len(latents_lr))], dim=0) + + sampled_images_lr = torch.cat( + [self.decode_latents(sampled_lr[i:i + 1].float(), batch, models, extras) for i in range(len(sampled_lr))], dim=0) + sampled_images_ema_lr = torch.cat( + [self.decode_latents(sampled_ema_lr[i:i + 1].float(), batch, models, extras) for i in range(len(sampled_ema_lr))], + dim=0) + + images = batch['images'] + if images.size(-1) != noised_images.size(-1) or images.size(-2) != noised_images.size(-2): + images = nn.functional.interpolate(images, size=noised_images.shape[-2:], mode='bicubic') + images_lr = nn.functional.interpolate(images, size=noised_images_lr.shape[-2:], mode='bicubic') + + collage_img = torch.cat([ + torch.cat([i for i in images.cpu()], dim=-1), + torch.cat([i for i in noised_images.cpu()], dim=-1), + torch.cat([i for i in sampled_images.cpu()], dim=-1), + torch.cat([i for i in sampled_images_ema.cpu()], dim=-1), + ], dim=-2) + + collage_img_lr = torch.cat([ + torch.cat([i for i in images_lr.cpu()], dim=-1), + torch.cat([i for i in noised_images_lr.cpu()], dim=-1), + torch.cat([i for i in sampled_images_lr.cpu()], dim=-1), + torch.cat([i for i in sampled_images_ema_lr.cpu()], dim=-1), + ], dim=-2) + + torchvision.utils.save_image(collage_img, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}.jpg') + torchvision.utils.save_image(collage_img_lr, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}_lr.jpg') + #torchvision.utils.save_image(collage_img, f'{self.config.experiment_id}_latest_output.jpg') + + captions = batch['captions'] + if self.config.wandb_project is not None: + log_data = [ + [captions[i]] + [wandb.Image(sampled_images[i])] + [wandb.Image(sampled_images_ema[i])] + [ + wandb.Image(images[i])] for i in range(len(images))] + log_table = wandb.Table(data=log_data, columns=["Captions", "Sampled", "Sampled EMA", "Orig"]) + wandb.log({"Log": log_table}) + + if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): + plt.plot(extras.gdf.loss_weight.bucket_ranges, extras.gdf.loss_weight.bucket_losses[:-1]) + plt.ylabel('Raw Loss') + plt.ylabel('LogSNR') + wandb.log({"Loss/LogSRN": plt}) + + #if 'generator' in self.models_to_save(): + models.generator.train() + models.train_norm.train() + print('finishe sampling in line 901') + + + + def sample_fortest(self, models: Models, extras: Extras, hr_shape, lr_shape, batch, eval_image_embeds=False): + + #if 'generator' in self.models_to_save(): + models.generator.eval() + models.trans_inr.eval() + models.controlnet.eval() + with torch.no_grad(): + + if self.is_main_node: + conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=eval_image_embeds) + unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, 
eval_image_embeds=False)
+            cnet, cnet_input = self.get_cnet(batch, models, extras, target_size=lr_shape)
+            conditions, unconditions = {**conditions, 'cnet': cnet}, {**unconditions, 'cnet': cnet}
+
+            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+                *_, (sampled, _, _, sampled_lr) = extras.gdf.sample(
+                    models.generator, models.trans_inr, conditions,
+                    hr_shape, lr_shape,
+                    unconditions, device=self.device, **extras.sampling_configs
+                )
+
+                if models.generator_ema is not None:
+                    *_, (sampled_ema, _, _, sampled_ema_lr) = extras.gdf.sample(
+                        models.generator_ema, models.trans_inr_ema, conditions,
+                        hr_shape, lr_shape,
+                        unconditions, device=self.device, **extras.sampling_configs
+                    )
+                else:
+                    sampled_ema = sampled
+                    sampled_ema_lr = sampled_lr
+
+        return sampled, sampled_lr
+
+
+def main_worker(rank, cfg):
+    print("Launching Script in main worker")
+    warpcore = WurstCore(
+        config_file_path=cfg, rank=rank, world_size=get_world_size()
+    )
+    # core.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD
+
+    # RUN TRAINING
+    warpcore(get_world_size() == 1)
+
+
+if __name__ == '__main__':
+    # train/train_c_my.py
+    print('launch multi process')
+    config_path = sys.argv[1] if len(sys.argv) > 1 else None
+    print(config_path, get_master_ip(), get_world_size())
+    if get_master_ip() == "127.0.0.1":
+        # manually launch distributed processes
+        mp.spawn(main_worker, nprocs=get_world_size(), args=(config_path,))
+    else:
+        main_worker(0, config_path)
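+# Hedged usage sketch: assuming this file is saved as train/train_c_my.py (per the comment
+# above) and takes a YAML config path as its only CLI argument:
+#   python train/train_c_my.py path/to/config.yaml
+# On a single machine (get_master_ip() == "127.0.0.1") mp.spawn launches get_world_size()
+# workers (presumably one per visible GPU) and setup_ddp() initialises NCCL; otherwise
+# rank 0 is run directly, e.g. under an external launcher.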