diff --git a/.DS_Store b/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..25f24893d9cb3e67b090e2c7b2bd2c712c25e393
Binary files /dev/null and b/.DS_Store differ
diff --git a/.gitattributes b/.gitattributes
index 64f23e0770da589d2949e1c24149405f5eda3d68..fd430f0759307b1fadfdf816996d98e99acdf80e 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -14,6 +14,7 @@
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
@@ -21,7 +22,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zstandard filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..7db0e1756c662bd53b44122ca8566ee0ed24048e
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+web.sh
+*__pycache__
+test_512_old/
\ No newline at end of file
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..3251a81ed3f9f4889d6334cd4282602f776fe2b5
--- /dev/null
+++ b/app.py
@@ -0,0 +1,146 @@
+from typing import Tuple
+import dnnlib
+from PIL import Image
+import numpy as np
+import torch
+import legacy
+import cv2
+import paddlehub as hub
+
+u2net = hub.Module(name='U2Net')
+
+# gradio app imports
+import gradio as gr
+from torchvision.transforms import ToTensor, ToPILImage
+image_to_tensor = ToTensor()
+tensor_to_image = ToPILImage()
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+class_idx = None
+truncation_psi = 0.1
+
+def create_model(network_pkl):
+ print('Loading networks from "%s"...' % network_pkl)
+ with dnnlib.util.open_url(network_pkl) as f:
+ G = legacy.load_network_pkl(f)['G_ema'] # type: ignore
+
+ G = G.eval().to(device)
+ netG_params = sum(p.numel() for p in G.parameters())
+ print("Generator Params: {} M".format(netG_params/1e6))
+ return G
+
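+# fcf_inpaint composites the generator's prediction into the hole region.
+# Shapes (as built in inpaint() below): org_img/erased_img are [1, 3, H, W]
+# in [-1, 1], mask is [1, 1, H, W] with 1 marking pixels to fill; the
+# generator sees the mask (shifted to [-0.5, 0.5]) concatenated with the
+# erased image, and original pixels are kept wherever mask == 0.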
+def fcf_inpaint(G, org_img, erased_img, mask):
+ label = torch.zeros([1, G.c_dim], device=device)
+ if G.c_dim != 0:
+ if class_idx is None:
+            raise ValueError("class_idx can't be None.")
+ label[:, class_idx] = 1
+ else:
+ if class_idx is not None:
+            print('warn: --class=lbl ignored when running on an unconditional network')
+
+ pred_img = G(img=torch.cat([0.5 - mask, erased_img], dim=1), c=label, truncation_psi=truncation_psi, noise_mode='const')
+ comp_img = mask.to(device) * pred_img + (1 - mask).to(device) * org_img.to(device)
+ return comp_img
+
+def show_images(img):
+    """ Convert a uint8 image array to a PIL image for display. """
+    return Image.fromarray(img)
+
+def denorm(img):
+ img = np.asarray(img[0].cpu(), dtype=np.float32).transpose(1, 2, 0)
+    img = (img + 1) * 127.5
+ img = np.rint(img).clip(0, 255).astype(np.uint8)
+ return img
+
+def pil_to_numpy(pil_img: Image.Image) -> torch.Tensor:
+    # Inverse of denorm(): HWC uint8 in [0, 255] -> NCHW float tensor in [-1, 1].
+    img = np.array(pil_img)
+    return torch.from_numpy(img)[None].permute(0, 3, 1, 2).float() / 127.5 - 1
+
+def inpaint(input_img, mask, option):
+ width, height = input_img.size
+
+ if option == "Automatic":
+ result = u2net.Segmentation(
+ images=[cv2.cvtColor(np.array(input_img), cv2.COLOR_RGB2BGR)],
+ paths=None,
+ batch_size=1,
+ input_size=320,
+ output_dir='output',
+ visualization=True)
+ mask = Image.fromarray(result[0]['mask'])
+ else:
+        mask = mask.resize((width, height))
+
+ mask = mask.convert('L')
+ mask = np.array(mask) / 255.
+ mask = cv2.resize(mask,
+ (512, 512), interpolation=cv2.INTER_NEAREST)
+ mask_tensor = torch.from_numpy(mask).to(torch.float32)
+ mask_tensor = mask_tensor.unsqueeze(0)
+ mask_tensor = mask_tensor.unsqueeze(0).to(device)
+
+ rgb = input_img.convert('RGB')
+ rgb = np.array(rgb)
+ rgb = cv2.resize(rgb,
+ (512, 512), interpolation=cv2.INTER_AREA)
+ rgb = rgb.transpose(2,0,1)
+ rgb = torch.from_numpy(rgb.astype(np.float32)).unsqueeze(0)
+ rgb = (rgb.to(torch.float32) / 127.5 - 1).to(device)
+ rgb_erased = rgb.clone()
+ rgb_erased = rgb_erased * (1 - mask_tensor) # erase rgb
+ rgb_erased = rgb_erased.to(torch.float32)
+
+    # NOTE: the FcF generator call is currently disabled, so both outputs show
+    # the erased image. Uncomment the two lines below to run actual inpainting:
+    # model = create_model("models/places_512.pkl")
+    # comp_img = fcf_inpaint(G=model, org_img=rgb.to(torch.float32), erased_img=rgb_erased.to(torch.float32), mask=mask_tensor.to(torch.float32))
+    rgb_erased = denorm(rgb_erased)
+    # comp_img = denorm(comp_img)
+
+    return show_images(rgb_erased), show_images(rgb_erased)
+
+gradio_inputs = [gr.inputs.Image(type='pil',
+ tool=None,
+ label="Input Image"),
+ gr.inputs.Image(type='pil',source="canvas", label="Mask", invert_colors=True),
+ gr.inputs.Radio(choices=["Automatic", "Manual"], type="value", default="Manual", label="Masking Choice")
+ # gr.inputs.Image(type='pil',
+ # tool=None,
+ # label="Mask")]
+ ]
+
+# gradio_outputs = [gr.outputs.Image(label='Auto-Detected Mask (From drawn black pixels)')]
+
+gradio_outputs = [gr.outputs.Image(label='Image with Hole'),
+ gr.outputs.Image(label='Inpainted Image')]
+
+examples = [['test_512/person512.png', 'test_512/mask_auto.png', 'Automatic'],
+ ['test_512/a_org.png', 'test_512/a_mask.png', 'Manual'],
+ ['test_512/c_org.png', 'test_512/b_mask.png', 'Manual'],
+ ['test_512/b_org.png', 'test_512/c_mask.png', 'Manual'],
+ ['test_512/d_org.png', 'test_512/d_mask.png', 'Manual'],
+ ['test_512/e_org.png', 'test_512/e_mask.png', 'Manual'],
+ ['test_512/f_org.png', 'test_512/f_mask.png', 'Manual'],
+ ['test_512/g_org.png', 'test_512/g_mask.png', 'Manual'],
+ ['test_512/h_org.png', 'test_512/h_mask.png', 'Manual'],
+ ['test_512/i_org.png', 'test_512/i_mask.png', 'Manual']]
+
+title = "FcF-Inpainting"
+description = "[Note: Queue time may take up to 20 seconds! The image and mask are resized to 512x512 before inpainting.] To use FcF-Inpainting: \n \
+                (1) Upload an image; \n \
+                (2) Draw a mask on the white canvas (Manual), or generate one with U2Net by selecting the Automatic option; \n \
+                (3) Click Submit and witness the MAGIC! 🪄 ✨ ✨"
+article = "Keys to Better Image Inpainting: Structure and Texture Go Hand in Hand | Github Repo"
+
+css = ".image-preview {height: 32rem; width: auto;} .output-image {height: 32rem; width: auto;} .panel-buttons { display: flex; flex-direction: row;}"
+
+iface = gr.Interface(fn=inpaint, inputs=gradio_inputs,
+ outputs=gradio_outputs,
+ css=css,
+ layout="vertical",
+ examples_per_page=5,
+ thumbnail="fcf_gan.png",
+ allow_flagging="never",
+ examples=examples, title=title,
+ description=description, article=article)
+iface.launch(enable_queue=True,
+ share=True, server_name="0.0.0.0")
\ No newline at end of file
diff --git a/dnnlib/__init__.py b/dnnlib/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..2f08cf36f11f9b0fd94c1b7caeadf69b98375b04
--- /dev/null
+++ b/dnnlib/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+from .util import EasyDict, make_cache_dir_path
diff --git a/dnnlib/util.py b/dnnlib/util.py
new file mode 100755
index 0000000000000000000000000000000000000000..76725336d01e75e1c68daa88be47f4fde0bbc63b
--- /dev/null
+++ b/dnnlib/util.py
@@ -0,0 +1,477 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Miscellaneous utility classes and functions."""
+
+import ctypes
+import fnmatch
+import importlib
+import inspect
+import numpy as np
+import os
+import shutil
+import sys
+import types
+import io
+import pickle
+import re
+import requests
+import html
+import hashlib
+import glob
+import tempfile
+import urllib
+import urllib.request
+import uuid
+
+from distutils.util import strtobool
+from typing import Any, List, Tuple, Union
+
+
+# Util classes
+# ------------------------------------------------------------------------------------------
+
+
+class EasyDict(dict):
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
+
+ def __getattr__(self, name: str) -> Any:
+ try:
+ return self[name]
+ except KeyError:
+ raise AttributeError(name)
+
+ def __setattr__(self, name: str, value: Any) -> None:
+ self[name] = value
+
+ def __delattr__(self, name: str) -> None:
+ del self[name]
+
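+# Example:
+#   cfg = EasyDict(lr=1e-3); cfg.batch = 8
+#   assert cfg['batch'] == 8 and cfg.lr == 1e-3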
+
+class Logger(object):
+ """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
+
+ def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
+ self.file = None
+
+ if file_name is not None:
+ self.file = open(file_name, file_mode)
+
+ self.should_flush = should_flush
+ self.stdout = sys.stdout
+ self.stderr = sys.stderr
+
+ sys.stdout = self
+ sys.stderr = self
+
+ def __enter__(self) -> "Logger":
+ return self
+
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+ self.close()
+
+ def write(self, text: Union[str, bytes]) -> None:
+ """Write text to stdout (and a file) and optionally flush."""
+ if isinstance(text, bytes):
+ text = text.decode()
+ if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
+ return
+
+ if self.file is not None:
+ self.file.write(text)
+
+ self.stdout.write(text)
+
+ if self.should_flush:
+ self.flush()
+
+ def flush(self) -> None:
+ """Flush written text to both stdout and a file, if open."""
+ if self.file is not None:
+ self.file.flush()
+
+ self.stdout.flush()
+
+ def close(self) -> None:
+ """Flush, close possible files, and remove stdout/stderr mirroring."""
+ self.flush()
+
+ # if using multiple loggers, prevent closing in wrong order
+ if sys.stdout is self:
+ sys.stdout = self.stdout
+ if sys.stderr is self:
+ sys.stderr = self.stderr
+
+ if self.file is not None:
+ self.file.close()
+ self.file = None
+
+
+# Cache directories
+# ------------------------------------------------------------------------------------------
+
+_dnnlib_cache_dir = None
+
+def set_cache_dir(path: str) -> None:
+ global _dnnlib_cache_dir
+ _dnnlib_cache_dir = path
+
+def make_cache_dir_path(*paths: str) -> str:
+ if _dnnlib_cache_dir is not None:
+ return os.path.join(_dnnlib_cache_dir, *paths)
+ if 'DNNLIB_CACHE_DIR' in os.environ:
+ return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
+ if 'HOME' in os.environ:
+ return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
+ if 'USERPROFILE' in os.environ:
+ return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
+ return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
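+# Resolution order: set_cache_dir() > $DNNLIB_CACHE_DIR > $HOME/.cache/dnnlib >
+# %USERPROFILE%/.cache/dnnlib > system temp dir. For example, with
+# HOME=/home/user and no overrides, make_cache_dir_path('downloads') returns
+# '/home/user/.cache/dnnlib/downloads'.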
+
+# Small util functions
+# ------------------------------------------------------------------------------------------
+
+
+def format_time(seconds: Union[int, float]) -> str:
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
+ s = int(np.rint(seconds))
+
+ if s < 60:
+ return "{0}s".format(s)
+ elif s < 60 * 60:
+ return "{0}m {1:02}s".format(s // 60, s % 60)
+ elif s < 24 * 60 * 60:
+ return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
+ else:
+ return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
+
+
+def ask_yes_no(question: str) -> bool:
+ """Ask the user the question until the user inputs a valid answer."""
+ while True:
+ try:
+ print("{0} [y/n]".format(question))
+ return strtobool(input().lower())
+ except ValueError:
+ pass
+
+
+def tuple_product(t: Tuple) -> Any:
+ """Calculate the product of the tuple elements."""
+ result = 1
+
+ for v in t:
+ result *= v
+
+ return result
+
+
+_str_to_ctype = {
+ "uint8": ctypes.c_ubyte,
+ "uint16": ctypes.c_uint16,
+ "uint32": ctypes.c_uint32,
+ "uint64": ctypes.c_uint64,
+ "int8": ctypes.c_byte,
+ "int16": ctypes.c_int16,
+ "int32": ctypes.c_int32,
+ "int64": ctypes.c_int64,
+ "float32": ctypes.c_float,
+ "float64": ctypes.c_double
+}
+
+
+def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
+ """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
+ type_str = None
+
+ if isinstance(type_obj, str):
+ type_str = type_obj
+ elif hasattr(type_obj, "__name__"):
+ type_str = type_obj.__name__
+ elif hasattr(type_obj, "name"):
+ type_str = type_obj.name
+ else:
+ raise RuntimeError("Cannot infer type name from input")
+
+ assert type_str in _str_to_ctype.keys()
+
+ my_dtype = np.dtype(type_str)
+ my_ctype = _str_to_ctype[type_str]
+
+ assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
+
+ return my_dtype, my_ctype
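+# e.g. get_dtype_and_ctype('float32') -> (np.dtype('float32'), ctypes.c_float);
+# the asserts above guarantee both types have the same size in bytes.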
+
+
+def is_pickleable(obj: Any) -> bool:
+ try:
+ with io.BytesIO() as stream:
+ pickle.dump(obj, stream)
+ return True
+ except:
+ return False
+
+
+# Functionality to import modules/objects by name, and call functions by name
+# ------------------------------------------------------------------------------------------
+
+def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
+ """Searches for the underlying module behind the name to some python object.
+ Returns the module and the object name (original name with module part removed)."""
+
+ # allow convenience shorthands, substitute them by full names
+ obj_name = re.sub("^np.", "numpy.", obj_name)
+ obj_name = re.sub("^tf.", "tensorflow.", obj_name)
+
+ # list alternatives for (module_name, local_obj_name)
+ parts = obj_name.split(".")
+ name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
+
+ # try each alternative in turn
+ for module_name, local_obj_name in name_pairs:
+ try:
+ module = importlib.import_module(module_name) # may raise ImportError
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
+ return module, local_obj_name
+ except:
+ pass
+
+ # maybe some of the modules themselves contain errors?
+ for module_name, _local_obj_name in name_pairs:
+ try:
+ importlib.import_module(module_name) # may raise ImportError
+ except ImportError:
+ if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
+ raise
+
+ # maybe the requested attribute is missing?
+ for module_name, local_obj_name in name_pairs:
+ try:
+ module = importlib.import_module(module_name) # may raise ImportError
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
+ except ImportError:
+ pass
+
+ # we are out of luck, but we have no idea why
+ raise ImportError(obj_name)
+
+
+def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
+ """Traverses the object name and returns the last (rightmost) python object."""
+ if obj_name == '':
+ return module
+ obj = module
+ for part in obj_name.split("."):
+ obj = getattr(obj, part)
+ return obj
+
+
+def get_obj_by_name(name: str) -> Any:
+ """Finds the python object with the given name."""
+ module, obj_name = get_module_from_obj_name(name)
+ return get_obj_from_module(module, obj_name)
+
+
+def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
+ """Finds the python object with the given name and calls it as a function."""
+ assert func_name is not None
+ func_obj = get_obj_by_name(func_name)
+ assert callable(func_obj)
+ return func_obj(*args, **kwargs)
+
+
+def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
+ """Finds the python class with the given name and constructs it with the given arguments."""
+ return call_func_by_name(*args, func_name=class_name, **kwargs)
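+# Example: construct_class_by_name(class_name='collections.OrderedDict', a=1)
+# imports 'collections', resolves 'OrderedDict', and calls it with a=1.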
+
+
+def get_module_dir_by_obj_name(obj_name: str) -> str:
+ """Get the directory path of the module containing the given object name."""
+ module, _ = get_module_from_obj_name(obj_name)
+ return os.path.dirname(inspect.getfile(module))
+
+
+def is_top_level_function(obj: Any) -> bool:
+ """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
+ return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
+
+
+def get_top_level_function_name(obj: Any) -> str:
+ """Return the fully-qualified name of a top-level function."""
+ assert is_top_level_function(obj)
+ module = obj.__module__
+ if module == '__main__':
+ module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
+ return module + "." + obj.__name__
+
+
+# File system helpers
+# ------------------------------------------------------------------------------------------
+
+def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
+ """List all files recursively in a given directory while ignoring given file and directory names.
+ Returns list of tuples containing both absolute and relative paths."""
+ assert os.path.isdir(dir_path)
+ base_name = os.path.basename(os.path.normpath(dir_path))
+
+ if ignores is None:
+ ignores = []
+
+ result = []
+
+ for root, dirs, files in os.walk(dir_path, topdown=True):
+ for ignore_ in ignores:
+ dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
+
+ # dirs need to be edited in-place
+ for d in dirs_to_remove:
+ dirs.remove(d)
+
+ files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
+
+ absolute_paths = [os.path.join(root, f) for f in files]
+ relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
+
+ if add_base_to_relative:
+ relative_paths = [os.path.join(base_name, p) for p in relative_paths]
+
+ assert len(absolute_paths) == len(relative_paths)
+ result += zip(absolute_paths, relative_paths)
+
+ return result
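+# Example: list_dir_recursively_with_ignore('src', ignores=['__pycache__'])
+# -> [('src/a.py', 'a.py'), ...]; with add_base_to_relative=True the relative
+# part becomes 'src/a.py'.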
+
+
+def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
+ """Takes in a list of tuples of (src, dst) paths and copies files.
+ Will create all necessary directories."""
+ for file in files:
+ target_dir_name = os.path.dirname(file[1])
+
+ # will create all intermediate-level directories
+ if not os.path.exists(target_dir_name):
+ os.makedirs(target_dir_name)
+
+ shutil.copyfile(file[0], file[1])
+
+
+# URL helpers
+# ------------------------------------------------------------------------------------------
+
+def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
+ """Determine whether the given object is a valid URL string."""
+ if not isinstance(obj, str) or not "://" in obj:
+ return False
+ if allow_file_urls and obj.startswith('file://'):
+ return True
+ try:
+ res = requests.compat.urlparse(obj)
+ if not res.scheme or not res.netloc or not "." in res.netloc:
+ return False
+ res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
+ if not res.scheme or not res.netloc or not "." in res.netloc:
+ return False
+ except:
+ return False
+ return True
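+# e.g. is_url('https://example.com/net.pkl') -> True; is_url('/tmp/net.pkl')
+# -> False; is_url('file:///tmp/net.pkl', allow_file_urls=True) -> True.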
+
+
+def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
+ """Download the given URL and return a binary-mode file object to access the data."""
+ assert num_attempts >= 1
+ assert not (return_filename and (not cache))
+
+    # Doesn't look like a URL scheme, so interpret it as a local filename.
+ if not re.match('^[a-z]+://', url):
+ return url if return_filename else open(url, "rb")
+
+ # Handle file URLs. This code handles unusual file:// patterns that
+ # arise on Windows:
+ #
+ # file:///c:/foo.txt
+ #
+ # which would translate to a local '/c:/foo.txt' filename that's
+ # invalid. Drop the forward slash for such pathnames.
+ #
+ # If you touch this code path, you should test it on both Linux and
+ # Windows.
+ #
+    # Some internet resources suggest using urllib.request.url2pathname(),
+    # but that converts forward slashes to backslashes and this causes
+ # its own set of problems.
+ if url.startswith('file://'):
+ filename = urllib.parse.urlparse(url).path
+ if re.match(r'^/[a-zA-Z]:', filename):
+ filename = filename[1:]
+ return filename if return_filename else open(filename, "rb")
+
+ assert is_url(url)
+
+ # Lookup from cache.
+ if cache_dir is None:
+ cache_dir = make_cache_dir_path('downloads')
+
+ url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
+ if cache:
+ cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
+ if len(cache_files) == 1:
+ filename = cache_files[0]
+ return filename if return_filename else open(filename, "rb")
+
+ # Download.
+ url_name = None
+ url_data = None
+ with requests.Session() as session:
+ if verbose:
+ print("Downloading %s ..." % url, end="", flush=True)
+ for attempts_left in reversed(range(num_attempts)):
+ try:
+ with session.get(url) as res:
+ res.raise_for_status()
+ if len(res.content) == 0:
+ raise IOError("No data received")
+
+ if len(res.content) < 8192:
+ content_str = res.content.decode("utf-8")
+ if "download_warning" in res.headers.get("Set-Cookie", ""):
+ links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
+ if len(links) == 1:
+ url = requests.compat.urljoin(url, links[0])
+ raise IOError("Google Drive virus checker nag")
+ if "Google Drive - Quota exceeded" in content_str:
+ raise IOError("Google Drive download quota exceeded -- please try again later")
+
+ match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
+ url_name = match[1] if match else url
+ url_data = res.content
+ if verbose:
+ print(" done")
+ break
+ except KeyboardInterrupt:
+ raise
+ except:
+ if not attempts_left:
+ if verbose:
+ print(" failed")
+ raise
+ if verbose:
+ print(".", end="", flush=True)
+
+ # Save to cache.
+ if cache:
+ safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
+ cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
+ temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
+ os.makedirs(cache_dir, exist_ok=True)
+ with open(temp_file, "wb") as f:
+ f.write(url_data)
+ os.replace(temp_file, cache_file) # atomic
+ if return_filename:
+ return cache_file
+
+ # Return data as file object.
+ assert not return_filename
+ return io.BytesIO(url_data)
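+# Usage sketch (hypothetical URL): the return value is a binary file object
+# (or a filename when return_filename=True), so a network pickle can be read
+# directly:
+#
+#   with open_url('https://example.com/net.pkl') as f:
+#       data = pickle.load(f)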
diff --git a/fcf_gan.png b/fcf_gan.png
new file mode 100644
index 0000000000000000000000000000000000000000..5336259b8b11ef8cf4a6b340a7f211b0aa9abaf4
Binary files /dev/null and b/fcf_gan.png differ
diff --git a/legacy.py b/legacy.py
new file mode 100755
index 0000000000000000000000000000000000000000..9387d79f23224642ca316399de2f0258f72de79b
--- /dev/null
+++ b/legacy.py
@@ -0,0 +1,320 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import click
+import pickle
+import re
+import copy
+import numpy as np
+import torch
+import dnnlib
+from torch_utils import misc
+
+#----------------------------------------------------------------------------
+
+def load_network_pkl(f, force_fp16=False):
+ data = _LegacyUnpickler(f).load()
+
+ # Legacy TensorFlow pickle => convert.
+ if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
+ tf_G, tf_D, tf_Gs = data
+ G = convert_tf_generator(tf_G)
+ D = convert_tf_discriminator(tf_D)
+ G_ema = convert_tf_generator(tf_Gs)
+ data = dict(G=G, D=D, G_ema=G_ema)
+
+ # Add missing fields.
+ if 'training_set_kwargs' not in data:
+ data['training_set_kwargs'] = None
+ if 'augment_pipe' not in data:
+ data['augment_pipe'] = None
+
+ # Validate contents.
+ assert isinstance(data['G'], torch.nn.Module)
+ assert isinstance(data['D'], torch.nn.Module)
+ assert isinstance(data['G_ema'], torch.nn.Module)
+ assert isinstance(data['training_set_kwargs'], (dict, type(None)))
+ assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))
+
+ # Force FP16.
+ if force_fp16:
+ for key in ['G', 'D', 'G_ema']:
+ old = data[key]
+ kwargs = copy.deepcopy(old.init_kwargs)
+ if key.startswith('G'):
+ kwargs.synthesis_kwargs = dnnlib.EasyDict(kwargs.get('synthesis_kwargs', {}))
+ kwargs.synthesis_kwargs.num_fp16_res = 4
+ kwargs.synthesis_kwargs.conv_clamp = 256
+ if key.startswith('D'):
+ kwargs.num_fp16_res = 4
+ kwargs.conv_clamp = 256
+ if kwargs != old.init_kwargs:
+ new = type(old)(**kwargs).eval().requires_grad_(False)
+ misc.copy_params_and_buffers(old, new, require_all=True)
+ data[key] = new
+ return data
+
+#----------------------------------------------------------------------------
+
+class _TFNetworkStub(dnnlib.EasyDict):
+ pass
+
+class _LegacyUnpickler(pickle.Unpickler):
+ def find_class(self, module, name):
+ if module == 'dnnlib.tflib.network' and name == 'Network':
+ return _TFNetworkStub
+ return super().find_class(module, name)
+
+#----------------------------------------------------------------------------
+
+def _collect_tf_params(tf_net):
+ # pylint: disable=protected-access
+ tf_params = dict()
+ def recurse(prefix, tf_net):
+ for name, value in tf_net.variables:
+ tf_params[prefix + name] = value
+ for name, comp in tf_net.components.items():
+ recurse(prefix + name + '/', comp)
+ recurse('', tf_net)
+ return tf_params
+
+#----------------------------------------------------------------------------
+
+def _populate_module_params(module, *patterns):
+ for name, tensor in misc.named_params_and_buffers(module):
+ found = False
+ value = None
+ for pattern, value_fn in zip(patterns[0::2], patterns[1::2]):
+ match = re.fullmatch(pattern, name)
+ if match:
+ found = True
+ if value_fn is not None:
+ value = value_fn(*match.groups())
+ break
+ try:
+ assert found
+ if value is not None:
+ tensor.copy_(torch.from_numpy(np.array(value)))
+ except:
+ print(name, list(tensor.shape))
+ raise
+
+#----------------------------------------------------------------------------
+
+def convert_tf_generator(tf_G):
+ if tf_G.version < 4:
+ raise ValueError('TensorFlow pickle version too low')
+
+ # Collect kwargs.
+ tf_kwargs = tf_G.static_kwargs
+ known_kwargs = set()
+ def kwarg(tf_name, default=None, none=None):
+ known_kwargs.add(tf_name)
+ val = tf_kwargs.get(tf_name, default)
+ return val if val is not None else none
+
+ # Convert kwargs.
+ kwargs = dnnlib.EasyDict(
+ z_dim = kwarg('latent_size', 512),
+ c_dim = kwarg('label_size', 0),
+ w_dim = kwarg('dlatent_size', 512),
+ img_resolution = kwarg('resolution', 1024),
+ img_channels = kwarg('num_channels', 3),
+ mapping_kwargs = dnnlib.EasyDict(
+ num_layers = kwarg('mapping_layers', 8),
+ embed_features = kwarg('label_fmaps', None),
+ layer_features = kwarg('mapping_fmaps', None),
+ activation = kwarg('mapping_nonlinearity', 'lrelu'),
+ lr_multiplier = kwarg('mapping_lrmul', 0.01),
+ w_avg_beta = kwarg('w_avg_beta', 0.995, none=1),
+ ),
+ synthesis_kwargs = dnnlib.EasyDict(
+ channel_base = kwarg('fmap_base', 16384) * 2,
+ channel_max = kwarg('fmap_max', 512),
+ num_fp16_res = kwarg('num_fp16_res', 0),
+ conv_clamp = kwarg('conv_clamp', None),
+ architecture = kwarg('architecture', 'skip'),
+ resample_filter = kwarg('resample_kernel', [1,3,3,1]),
+ use_noise = kwarg('use_noise', True),
+ activation = kwarg('nonlinearity', 'lrelu'),
+ ),
+ )
+
+ # Check for unknown kwargs.
+ kwarg('truncation_psi')
+ kwarg('truncation_cutoff')
+ kwarg('style_mixing_prob')
+ kwarg('structure')
+ unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
+ if len(unknown_kwargs) > 0:
+ raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
+
+ # Collect params.
+ tf_params = _collect_tf_params(tf_G)
+ for name, value in list(tf_params.items()):
+ match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
+ if match:
+ r = kwargs.img_resolution // (2 ** int(match.group(1)))
+ tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
+            kwargs.synthesis_kwargs.architecture = 'orig'
+ #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
+
+ # Convert params.
+ from training import networks
+ G = networks.Generator(**kwargs).eval().requires_grad_(False)
+ # pylint: disable=unnecessary-lambda
+ _populate_module_params(G,
+ r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'],
+ r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(),
+ r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'],
+ r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
+ r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'],
+ r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0],
+ r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
+ r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'],
+ r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0],
+ r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'],
+ r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
+ r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
+ r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
+ r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
+ r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
+ r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
+ r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
+ r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
+ r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
+ r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
+ r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
+ r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
+ r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
+ r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
+ r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
+ r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
+ r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
+ r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
+ r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
+ r'.*\.resample_filter', None,
+ )
+ return G
+
+#----------------------------------------------------------------------------
+
+def convert_tf_discriminator(tf_D):
+ if tf_D.version < 4:
+ raise ValueError('TensorFlow pickle version too low')
+
+ # Collect kwargs.
+ tf_kwargs = tf_D.static_kwargs
+ known_kwargs = set()
+ def kwarg(tf_name, default=None):
+ known_kwargs.add(tf_name)
+ return tf_kwargs.get(tf_name, default)
+
+ # Convert kwargs.
+ kwargs = dnnlib.EasyDict(
+ c_dim = kwarg('label_size', 0),
+ img_resolution = kwarg('resolution', 1024),
+ img_channels = kwarg('num_channels', 3),
+ architecture = kwarg('architecture', 'resnet'),
+ channel_base = kwarg('fmap_base', 16384) * 2,
+ channel_max = kwarg('fmap_max', 512),
+ num_fp16_res = kwarg('num_fp16_res', 0),
+ conv_clamp = kwarg('conv_clamp', None),
+ cmap_dim = kwarg('mapping_fmaps', None),
+ block_kwargs = dnnlib.EasyDict(
+ activation = kwarg('nonlinearity', 'lrelu'),
+ resample_filter = kwarg('resample_kernel', [1,3,3,1]),
+ freeze_layers = kwarg('freeze_layers', 0),
+ ),
+ mapping_kwargs = dnnlib.EasyDict(
+ num_layers = kwarg('mapping_layers', 0),
+ embed_features = kwarg('mapping_fmaps', None),
+ layer_features = kwarg('mapping_fmaps', None),
+ activation = kwarg('nonlinearity', 'lrelu'),
+ lr_multiplier = kwarg('mapping_lrmul', 0.1),
+ ),
+ epilogue_kwargs = dnnlib.EasyDict(
+ mbstd_group_size = kwarg('mbstd_group_size', None),
+ mbstd_num_channels = kwarg('mbstd_num_features', 1),
+ activation = kwarg('nonlinearity', 'lrelu'),
+ ),
+ )
+
+ # Check for unknown kwargs.
+ kwarg('structure')
+ unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
+ if len(unknown_kwargs) > 0:
+ raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
+
+ # Collect params.
+ tf_params = _collect_tf_params(tf_D)
+ for name, value in list(tf_params.items()):
+ match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
+ if match:
+ r = kwargs.img_resolution // (2 ** int(match.group(1)))
+ tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
+ kwargs.architecture = 'orig'
+ #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
+
+ # Convert params.
+ from training import networks
+ D = networks.Discriminator(**kwargs).eval().requires_grad_(False)
+ # pylint: disable=unnecessary-lambda
+ _populate_module_params(D,
+ r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
+ r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'],
+ r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1),
+ r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
+ r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
+ r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(),
+ r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'],
+ r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(),
+ r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'],
+ r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1),
+ r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'],
+ r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(),
+ r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'],
+ r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(),
+ r'b4\.out\.bias', lambda: tf_params[f'Output/bias'],
+ r'.*\.resample_filter', None,
+ )
+ return D
+
+#----------------------------------------------------------------------------
+
+@click.command()
+@click.option('--source', help='Input pickle', required=True, metavar='PATH')
+@click.option('--dest', help='Output pickle', required=True, metavar='PATH')
+@click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True)
+def convert_network_pickle(source, dest, force_fp16):
+ """Convert legacy network pickle into the native PyTorch format.
+
+ The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
+ It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.
+
+ Example:
+
+ \b
+ python legacy.py \\
+ --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
+ --dest=stylegan2-cat-config-f.pkl
+ """
+ print(f'Loading "{source}"...')
+ with dnnlib.util.open_url(source) as f:
+ data = load_network_pkl(f, force_fp16=force_fp16)
+ print(f'Saving "{dest}"...')
+ with open(dest, 'wb') as f:
+ pickle.dump(data, f)
+ print('Done.')
+
+#----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+ convert_network_pickle() # pylint: disable=no-value-for-parameter
+
+#----------------------------------------------------------------------------
diff --git a/output/result_0.png b/output/result_0.png
new file mode 100644
index 0000000000000000000000000000000000000000..6565e454a4ed3f81b9178fd2261c0b7497ab05c8
Binary files /dev/null and b/output/result_0.png differ
diff --git a/output/result_mask_0.png b/output/result_mask_0.png
new file mode 100644
index 0000000000000000000000000000000000000000..16e6ecb56e817f7a9d20bac64b8b898c234c9d96
Binary files /dev/null and b/output/result_mask_0.png differ
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..31edb20b905059f6593f3c91d0e118869294a885
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,25 @@
+icecream
+psutil
+click
+requests
+matplotlib
+tqdm
+ninja
+imageio-ffmpeg==0.4.3
+scipy
+termcolor>=1.1
+colorama
+cvbase
+opencv-python
+etaprogress
+scikit-learn
+pandas
+tensorboard
+pydrive2
+easydict
+kornia==0.5.0
+gradio
+ipython
+Jinja2
+paddlepaddle
+paddlehub
\ No newline at end of file
diff --git a/setup.sh b/setup.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c952da9c6fdf6129894a605d1071bf12a58d9ea2
--- /dev/null
+++ b/setup.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+eval "$(conda shell.bash hook)"
+conda create --name fcf -y python=3.7
+conda activate fcf
+conda env list
+conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=11.0 -c pytorch
+pip3 install -r requirements.txt
\ No newline at end of file
diff --git a/test_512/.DS_Store b/test_512/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..dd758c4dc70ad3bf9d644894c55d0a49a93b4af8
Binary files /dev/null and b/test_512/.DS_Store differ
diff --git a/test_512/a_mask.png b/test_512/a_mask.png
new file mode 100644
index 0000000000000000000000000000000000000000..64b0b252c80d5243dc68d3e37ec79b7989c7ac88
Binary files /dev/null and b/test_512/a_mask.png differ
diff --git a/test_512/a_org.png b/test_512/a_org.png
new file mode 100644
index 0000000000000000000000000000000000000000..a641309ca6e79f41288f1550dcdc4ea78aa4c585
Binary files /dev/null and b/test_512/a_org.png differ
diff --git a/test_512/b_mask.png b/test_512/b_mask.png
new file mode 100644
index 0000000000000000000000000000000000000000..4d69f8e97a458fba76498be14c68887f6097c1bf
Binary files /dev/null and b/test_512/b_mask.png differ
diff --git a/test_512/b_org.png b/test_512/b_org.png
new file mode 100644
index 0000000000000000000000000000000000000000..6c8a3da5680a45b507be6c868ff6619ab1ea9865
Binary files /dev/null and b/test_512/b_org.png differ
diff --git a/test_512/c_mask.png b/test_512/c_mask.png
new file mode 100644
index 0000000000000000000000000000000000000000..e9d1add616423eb857567c8026cf9b97f0d295da
Binary files /dev/null and b/test_512/c_mask.png differ
diff --git a/test_512/c_org.png b/test_512/c_org.png
new file mode 100644
index 0000000000000000000000000000000000000000..3be3809f8e375d270dcc1753b917b77c17f5ce9d
Binary files /dev/null and b/test_512/c_org.png differ
diff --git a/test_512/d_mask.png b/test_512/d_mask.png
new file mode 100644
index 0000000000000000000000000000000000000000..c17e4be34dfaa781f24b399cb397dc5f712a438c
Binary files /dev/null and b/test_512/d_mask.png differ
diff --git a/test_512/d_org.png b/test_512/d_org.png
new file mode 100644
index 0000000000000000000000000000000000000000..b322cdb1e0077e71d9b418160efe1d1e33d07deb
Binary files /dev/null and b/test_512/d_org.png differ
diff --git a/test_512/e_mask.png b/test_512/e_mask.png
new file mode 100644
index 0000000000000000000000000000000000000000..d3523ee2924f64db40fdb36f9e1704980a16a018
Binary files /dev/null and b/test_512/e_mask.png differ
diff --git a/test_512/e_org.png b/test_512/e_org.png
new file mode 100644
index 0000000000000000000000000000000000000000..be2131112df4a19f69c156ce1bee0342385391b3
Binary files /dev/null and b/test_512/e_org.png differ
diff --git a/test_512/f_mask.png b/test_512/f_mask.png
new file mode 100644
index 0000000000000000000000000000000000000000..4192c238e8207276734c8eaeb691d5c379b85288
Binary files /dev/null and b/test_512/f_mask.png differ
diff --git a/test_512/f_org.png b/test_512/f_org.png
new file mode 100644
index 0000000000000000000000000000000000000000..0fb021a223a70a6f74c6f5be0eaf65e0ffda8fec
Binary files /dev/null and b/test_512/f_org.png differ
diff --git a/test_512/g_mask.png b/test_512/g_mask.png
new file mode 100644
index 0000000000000000000000000000000000000000..4d994d2b017d4eff45235ec839557af2704e104b
Binary files /dev/null and b/test_512/g_mask.png differ
diff --git a/test_512/g_org.png b/test_512/g_org.png
new file mode 100644
index 0000000000000000000000000000000000000000..05abfb6ad712ac1b0ae856480822c82041c5a1e3
Binary files /dev/null and b/test_512/g_org.png differ
diff --git a/test_512/h_mask.png b/test_512/h_mask.png
new file mode 100644
index 0000000000000000000000000000000000000000..f072f74862ac0f8be820c139f19b070badf68264
Binary files /dev/null and b/test_512/h_mask.png differ
diff --git a/test_512/h_org.png b/test_512/h_org.png
new file mode 100644
index 0000000000000000000000000000000000000000..5fc126793f7baf8d5d1912440eed5bcbab860f36
Binary files /dev/null and b/test_512/h_org.png differ
diff --git a/test_512/i_mask.png b/test_512/i_mask.png
new file mode 100644
index 0000000000000000000000000000000000000000..40a9e843b45802aed38ad3e3f16df693beaef933
Binary files /dev/null and b/test_512/i_mask.png differ
diff --git a/test_512/i_org.png b/test_512/i_org.png
new file mode 100644
index 0000000000000000000000000000000000000000..14f98eedfb5b789c863bee3c8f421c5ac28ed2f2
Binary files /dev/null and b/test_512/i_org.png differ
diff --git a/test_512/mask_auto.png b/test_512/mask_auto.png
new file mode 100644
index 0000000000000000000000000000000000000000..f033fc6e8a618fd940c3a7a871fc9a7585a13fc5
Binary files /dev/null and b/test_512/mask_auto.png differ
diff --git a/test_512/person512.png b/test_512/person512.png
new file mode 100644
index 0000000000000000000000000000000000000000..57af615ac3124800f124a72be78a8cbca1997110
Binary files /dev/null and b/test_512/person512.png differ
diff --git a/torch_utils/__init__.py b/torch_utils/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..ece0ea08fe2e939cc260a1dafc0ab5b391b773d9
--- /dev/null
+++ b/torch_utils/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+# empty
diff --git a/torch_utils/custom_ops.py b/torch_utils/custom_ops.py
new file mode 100755
index 0000000000000000000000000000000000000000..4cc4e43fc6f6ce79f2bd68a44ba87990b9b8564e
--- /dev/null
+++ b/torch_utils/custom_ops.py
@@ -0,0 +1,126 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import os
+import glob
+import torch
+import torch.utils.cpp_extension
+import importlib
+import hashlib
+import shutil
+from pathlib import Path
+
+from torch.utils.file_baton import FileBaton
+
+#----------------------------------------------------------------------------
+# Global options.
+
+verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
+
+#----------------------------------------------------------------------------
+# Internal helper funcs.
+
+def _find_compiler_bindir():
+ patterns = [
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
+ 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
+ ]
+ for pattern in patterns:
+ matches = sorted(glob.glob(pattern))
+ if len(matches):
+ return matches[-1]
+ return None
+
+#----------------------------------------------------------------------------
+# Main entry point for compiling and loading C++/CUDA plugins.
+
+_cached_plugins = dict()
+
+def get_plugin(module_name, sources, **build_kwargs):
+ assert verbosity in ['none', 'brief', 'full']
+
+ # Already cached?
+ if module_name in _cached_plugins:
+ return _cached_plugins[module_name]
+
+ # Print status.
+ if verbosity == 'full':
+ print(f'Setting up PyTorch plugin "{module_name}"...')
+ elif verbosity == 'brief':
+ print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
+
+ try: # pylint: disable=too-many-nested-blocks
+ # Make sure we can find the necessary compiler binaries.
+ if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
+ compiler_bindir = _find_compiler_bindir()
+ if compiler_bindir is None:
+ raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
+ os.environ['PATH'] += ';' + compiler_bindir
+
+ # Compile and load.
+ verbose_build = (verbosity == 'full')
+
+ # Incremental build md5sum trickery. Copies all the input source files
+ # into a cached build directory under a combined md5 digest of the input
+ # source files. Copying is done only if the combined digest has changed.
+ # This keeps input file timestamps and filenames the same as in previous
+ # extension builds, allowing for fast incremental rebuilds.
+ #
+ # This optimization is done only in case all the source files reside in
+ # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
+ # environment variable is set (we take this as a signal that the user
+ # actually cares about this.)
+ source_dirs_set = set(os.path.dirname(source) for source in sources)
+ if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
+ all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))
+
+ # Compute a combined hash digest for all source files in the same
+ # custom op directory (usually .cu, .cpp, .py and .h files).
+ hash_md5 = hashlib.md5()
+ for src in all_source_files:
+ with open(src, 'rb') as f:
+ hash_md5.update(f.read())
+ build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
+ digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())
+
+ if not os.path.isdir(digest_build_dir):
+ os.makedirs(digest_build_dir, exist_ok=True)
+ baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
+ if baton.try_acquire():
+ try:
+ for src in all_source_files:
+ shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
+ finally:
+ baton.release()
+ else:
+ # Someone else is copying source files under the digest dir,
+ # wait until done and continue.
+ baton.wait()
+ digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
+ torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
+ verbose=verbose_build, sources=digest_sources, **build_kwargs)
+ else:
+ torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
+ module = importlib.import_module(module_name)
+
+ except:
+ if verbosity == 'brief':
+ print('Failed!')
+ raise
+
+ # Print status and add to cache.
+ if verbosity == 'full':
+ print(f'Done setting up PyTorch plugin "{module_name}".')
+ elif verbosity == 'brief':
+ print('Done.')
+ _cached_plugins[module_name] = module
+ return module
+
+#----------------------------------------------------------------------------
diff --git a/torch_utils/misc.py b/torch_utils/misc.py
new file mode 100755
index 0000000000000000000000000000000000000000..c3508142a959da461ae91feb60abc775ee4fcc35
--- /dev/null
+++ b/torch_utils/misc.py
@@ -0,0 +1,263 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import re
+import contextlib
+import numpy as np
+import torch
+import warnings
+import dnnlib
+
+#----------------------------------------------------------------------------
+# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
+# same constant is used multiple times.
+
+_constant_cache = dict()
+
+def constant(value, shape=None, dtype=None, device=None, memory_format=None):
+ value = np.asarray(value)
+ if shape is not None:
+ shape = tuple(shape)
+ if dtype is None:
+ dtype = torch.get_default_dtype()
+ if device is None:
+ device = torch.device('cpu')
+ if memory_format is None:
+ memory_format = torch.contiguous_format
+
+ key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
+ tensor = _constant_cache.get(key, None)
+ if tensor is None:
+ tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
+ if shape is not None:
+ tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
+ tensor = tensor.contiguous(memory_format=memory_format)
+ _constant_cache[key] = tensor
+ return tensor
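+# Repeated calls with identical (value, shape, dtype, device, memory_format)
+# return the same cached tensor, so e.g. a resample filter built as
+# constant([1,3,3,1], device=device) is uploaded to the GPU only once.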
+
+#----------------------------------------------------------------------------
+# Replace NaN/Inf with specified numerical values.
+
+try:
+ nan_to_num = torch.nan_to_num # 1.8.0a0
+except AttributeError:
+ def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
+ assert isinstance(input, torch.Tensor)
+ input = input.to(torch.float32)
+ if posinf is None:
+ posinf = torch.finfo(input.dtype).max
+ if neginf is None:
+ neginf = torch.finfo(input.dtype).min
+ assert nan == 0
+ return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
+
+#----------------------------------------------------------------------------
+# Symbolic assert.
+
+try:
+ symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
+except AttributeError:
+ symbolic_assert = torch.Assert # 1.7.0
+
+#----------------------------------------------------------------------------
+# Context manager to suppress known warnings in torch.jit.trace().
+
+class suppress_tracer_warnings(warnings.catch_warnings):
+ def __enter__(self):
+ super().__enter__()
+ warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
+ return self
+
+#----------------------------------------------------------------------------
+# Assert that the shape of a tensor matches the given list of integers.
+# None indicates that the size of a dimension is allowed to vary.
+# Performs symbolic assertion when used in torch.jit.trace().
+
+def assert_shape(tensor, ref_shape):
+ if tensor.ndim != len(ref_shape):
+ raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
+ for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
+ if ref_size is None:
+ pass
+ elif isinstance(ref_size, torch.Tensor):
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
+ symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
+ elif isinstance(size, torch.Tensor):
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
+ symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
+ elif size != ref_size:
+ raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
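+# e.g. assert_shape(x, [None, 3, 512, 512]) accepts any batch size but fixes
+# the channel and spatial dimensions; under torch.jit.trace() the check is
+# emitted as a symbolic assert instead of a Python exception.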
+
+#----------------------------------------------------------------------------
+# Function decorator that calls torch.autograd.profiler.record_function().
+
+def profiled_function(fn):
+ def decorator(*args, **kwargs):
+ with torch.autograd.profiler.record_function(fn.__name__):
+ return fn(*args, **kwargs)
+ decorator.__name__ = fn.__name__
+ return decorator
+
+#----------------------------------------------------------------------------
+# Sampler for torch.utils.data.DataLoader that loops over the dataset
+# indefinitely, shuffling items as it goes.
+
+class InfiniteSampler(torch.utils.data.Sampler):
+ def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
+ assert len(dataset) > 0
+ assert num_replicas > 0
+ assert 0 <= rank < num_replicas
+ assert 0 <= window_size <= 1
+ super().__init__(dataset)
+ self.dataset = dataset
+ self.rank = rank
+ self.num_replicas = num_replicas
+ self.shuffle = shuffle
+ self.seed = seed
+ self.window_size = window_size
+
+ def __iter__(self):
+ order = np.arange(len(self.dataset))
+ rnd = None
+ window = 0
+ if self.shuffle:
+ rnd = np.random.RandomState(self.seed)
+ rnd.shuffle(order)
+ window = int(np.rint(order.size * self.window_size))
+
+ idx = 0
+ while True:
+ i = idx % order.size
+ if idx % self.num_replicas == self.rank:
+ yield order[i]
+ if window >= 2:
+ j = (i - rnd.randint(window)) % order.size
+ order[i], order[j] = order[j], order[i]
+ idx += 1
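+# The window swap keeps the shuffle quasi-local: each yielded index may be
+# exchanged with one up to `window` positions behind it, so the order keeps
+# mutating as iteration proceeds instead of cycling a fixed permutation.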
+
+#----------------------------------------------------------------------------
+# Utilities for operating with torch.nn.Module parameters and buffers.
+
+def params_and_buffers(module):
+ assert isinstance(module, torch.nn.Module)
+ return list(module.parameters()) + list(module.buffers())
+
+def named_params_and_buffers(module):
+ assert isinstance(module, torch.nn.Module)
+ return list(module.named_parameters()) + list(module.named_buffers())
+
+def copy_params_and_buffers(src_module, dst_module, require_all=False):
+ assert isinstance(src_module, torch.nn.Module)
+ assert isinstance(dst_module, torch.nn.Module)
+ src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)}
+ for name, tensor in named_params_and_buffers(dst_module):
+ assert (name in src_tensors) or (not require_all)
+ if name in src_tensors:
+ tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
+
+#----------------------------------------------------------------------------
+# Context manager for easily enabling/disabling DistributedDataParallel
+# synchronization.
+
+@contextlib.contextmanager
+def ddp_sync(module, sync):
+ assert isinstance(module, torch.nn.Module)
+ if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
+ yield
+ else:
+ with module.no_sync():
+ yield
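+# Usage sketch (assumed gradient-accumulation loop): skip gradient all-reduce
+# on every round except the last one:
+#
+#   for round_idx in range(num_rounds):
+#       with ddp_sync(ddp_module, sync=(round_idx == num_rounds - 1)):
+#           loss.backward()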
+
+#----------------------------------------------------------------------------
+# Check DistributedDataParallel consistency across processes.
+
+def check_ddp_consistency(module, ignore_regex=None):
+ assert isinstance(module, torch.nn.Module)
+ for name, tensor in named_params_and_buffers(module):
+ fullname = type(module).__name__ + '.' + name
+ if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
+ continue
+ tensor = tensor.detach()
+ other = tensor.clone()
+ torch.distributed.broadcast(tensor=other, src=0)
+ assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname
+
+#----------------------------------------------------------------------------
+# Print summary table of module hierarchy.
+
+def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
+ assert isinstance(module, torch.nn.Module)
+ assert not isinstance(module, torch.jit.ScriptModule)
+ assert isinstance(inputs, (tuple, list))
+
+ # Register hooks.
+ entries = []
+ nesting = [0]
+ def pre_hook(_mod, _inputs):
+ nesting[0] += 1
+ def post_hook(mod, _inputs, outputs):
+ nesting[0] -= 1
+ if nesting[0] <= max_nesting:
+ outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
+ outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
+ entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
+ hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
+ hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
+
+ # Run module.
+ outputs = module(*inputs)
+ for hook in hooks:
+ hook.remove()
+
+ # Identify unique outputs, parameters, and buffers.
+ tensors_seen = set()
+ for e in entries:
+ e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
+ e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
+ e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
+ tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
+
+ # Filter out redundant entries.
+ if skip_redundant:
+ entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
+
+ # Construct table.
+ rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
+ rows += [['---'] * len(rows[0])]
+ param_total = 0
+ buffer_total = 0
+ submodule_names = {mod: name for name, mod in module.named_modules()}
+ for e in entries:
+ name = '' if e.mod is module else submodule_names[e.mod]
+ param_size = sum(t.numel() for t in e.unique_params)
+ buffer_size = sum(t.numel() for t in e.unique_buffers)
+ output_shapes = [str(list(t.shape)) for t in e.outputs]
+ output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
+ rows += [[
+ name + (':0' if len(e.outputs) >= 2 else ''),
+ str(param_size) if param_size else '-',
+ str(buffer_size) if buffer_size else '-',
+ (output_shapes + ['-'])[0],
+ (output_dtypes + ['-'])[0],
+ ]]
+ for idx in range(1, len(e.outputs)):
+ rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
+ param_total += param_size
+ buffer_total += buffer_size
+ rows += [['---'] * len(rows[0])]
+ rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
+
+ # Print table.
+ widths = [max(len(cell) for cell in column) for column in zip(*rows)]
+ print()
+ for row in rows:
+ print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
+ print()
+ return outputs
+
+#----------------------------------------------------------------------------
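InfiniteSampler never raises StopIteration, so a training loop can keep pulling batches from a single iterator. A minimal usage sketch, not part of the diff (the TensorDataset is a stand-in for the repo's real image dataset; assumes the repo root is importable):

```python
import torch
from torch_utils.misc import InfiniteSampler

dataset = torch.utils.data.TensorDataset(torch.randn(100, 3, 64, 64))
sampler = InfiniteSampler(dataset=dataset, rank=0, num_replicas=1, seed=0)
loader = iter(torch.utils.data.DataLoader(dataset=dataset, sampler=sampler, batch_size=8))
for _ in range(50):           # more iterations than len(dataset) / batch_size
    (images,) = next(loader)  # never raises StopIteration; shuffling is windowed
```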
diff --git a/torch_utils/ops/__init__.py b/torch_utils/ops/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..ece0ea08fe2e939cc260a1dafc0ab5b391b773d9
--- /dev/null
+++ b/torch_utils/ops/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+# empty
diff --git a/torch_utils/ops/bias_act.cpp b/torch_utils/ops/bias_act.cpp
new file mode 100755
index 0000000000000000000000000000000000000000..5d2425d8054991a8e8b6f7a940fd0ff7fa0bb330
--- /dev/null
+++ b/torch_utils/ops/bias_act.cpp
@@ -0,0 +1,99 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include "bias_act.h"
+
+//------------------------------------------------------------------------
+
+static bool has_same_layout(torch::Tensor x, torch::Tensor y)
+{
+ if (x.dim() != y.dim())
+ return false;
+ for (int64_t i = 0; i < x.dim(); i++)
+ {
+ if (x.size(i) != y.size(i))
+ return false;
+ if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
+ return false;
+ }
+ return true;
+}
+
+//------------------------------------------------------------------------
+
+static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
+{
+ // Validate arguments.
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+ TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
+ TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
+ TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
+ TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same shape, dtype, and device as x");
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
+ TORCH_CHECK(b.dim() == 1, "b must have rank 1");
+ TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
+ TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
+ TORCH_CHECK(grad >= 0, "grad must be non-negative");
+
+ // Validate layout.
+ TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
+ TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
+ TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
+ TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
+ TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
+
+ // Create output tensor.
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+ torch::Tensor y = torch::empty_like(x);
+ TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
+
+ // Initialize CUDA kernel parameters.
+ bias_act_kernel_params p;
+ p.x = x.data_ptr();
+ p.b = (b.numel()) ? b.data_ptr() : NULL;
+ p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
+ p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
+ p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
+ p.y = y.data_ptr();
+ p.grad = grad;
+ p.act = act;
+ p.alpha = alpha;
+ p.gain = gain;
+ p.clamp = clamp;
+ p.sizeX = (int)x.numel();
+ p.sizeB = (int)b.numel();
+ p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
+
+ // Choose CUDA kernel.
+ void* kernel;
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "bias_act_cuda", [&]
+ {
+ kernel = choose_bias_act_kernel(p);
+ });
+ TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
+
+ // Launch CUDA kernel.
+ p.loopX = 4;
+ int blockSize = 4 * 32;
+ int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
+ void* args[] = {&p};
+ AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
+ return y;
+}
+
+//------------------------------------------------------------------------
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ m.def("bias_act", &bias_act);
+}
+
+//------------------------------------------------------------------------
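The launch configuration above gives each thread up to loopX = 4 elements across 128-thread blocks, so gridSize is the smallest block count that covers sizeX. A plain-Python sanity check of that arithmetic, not part of the diff:

```python
# Mirrors the gridSize computation in bias_act(): ceil(sizeX / (loopX * blockSize)).
def launch_dims(size_x, loop_x=4, block_size=4 * 32):
    grid_size = (size_x - 1) // (loop_x * block_size) + 1
    return grid_size, block_size

for n in (1, 127, 512, 513, 10**6):
    grid, block = launch_dims(n)
    assert grid * block * 4 >= n       # every element of x gets a slot
    assert (grid - 1) * block * 4 < n  # and no trailing block is fully idle
```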
diff --git a/torch_utils/ops/bias_act.cu b/torch_utils/ops/bias_act.cu
new file mode 100755
index 0000000000000000000000000000000000000000..dd8fc4756d7d94727f94af738665b68d9c518880
--- /dev/null
+++ b/torch_utils/ops/bias_act.cu
@@ -0,0 +1,173 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include <c10/util/Half.h>
+#include "bias_act.h"
+
+//------------------------------------------------------------------------
+// Helpers.
+
+template <class T> struct InternalType;
+template <> struct InternalType<double>    { typedef double scalar_t; };
+template <> struct InternalType<float>     { typedef float  scalar_t; };
+template <> struct InternalType<c10::Half> { typedef float  scalar_t; };
+
+//------------------------------------------------------------------------
+// CUDA kernel.
+
+template <class T, int A>
+__global__ void bias_act_kernel(bias_act_kernel_params p)
+{
+ typedef typename InternalType::scalar_t scalar_t;
+ int G = p.grad;
+ scalar_t alpha = (scalar_t)p.alpha;
+ scalar_t gain = (scalar_t)p.gain;
+ scalar_t clamp = (scalar_t)p.clamp;
+ scalar_t one = (scalar_t)1;
+ scalar_t two = (scalar_t)2;
+ scalar_t expRange = (scalar_t)80;
+ scalar_t halfExpRange = (scalar_t)40;
+ scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
+ scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
+
+ // Loop over elements.
+ int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
+ for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
+ {
+ // Load.
+ scalar_t x = (scalar_t)((const T*)p.x)[xi];
+ scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
+ scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
+ scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
+ scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
+ scalar_t yy = (gain != 0) ? yref / gain : 0;
+ scalar_t y = 0;
+
+ // Apply bias.
+ ((G == 0) ? x : xref) += b;
+
+ // linear
+ if (A == 1)
+ {
+ if (G == 0) y = x;
+ if (G == 1) y = x;
+ }
+
+ // relu
+ if (A == 2)
+ {
+ if (G == 0) y = (x > 0) ? x : 0;
+ if (G == 1) y = (yy > 0) ? x : 0;
+ }
+
+ // lrelu
+ if (A == 3)
+ {
+ if (G == 0) y = (x > 0) ? x : x * alpha;
+ if (G == 1) y = (yy > 0) ? x : x * alpha;
+ }
+
+ // tanh
+ if (A == 4)
+ {
+ if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
+ if (G == 1) y = x * (one - yy * yy);
+ if (G == 2) y = x * (one - yy * yy) * (-two * yy);
+ }
+
+ // sigmoid
+ if (A == 5)
+ {
+ if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
+ if (G == 1) y = x * yy * (one - yy);
+ if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
+ }
+
+ // elu
+ if (A == 6)
+ {
+ if (G == 0) y = (x >= 0) ? x : exp(x) - one;
+ if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
+ }
+
+ // selu
+ if (A == 7)
+ {
+ if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
+ if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
+ }
+
+ // softplus
+ if (A == 8)
+ {
+ if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
+ if (G == 1) y = x * (one - exp(-yy));
+ if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
+ }
+
+ // swish
+ if (A == 9)
+ {
+ if (G == 0)
+ y = (x < -expRange) ? 0 : x / (exp(-x) + one);
+ else
+ {
+ scalar_t c = exp(xref);
+ scalar_t d = c + one;
+ if (G == 1)
+ y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
+ else
+ y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
+ yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
+ }
+ }
+
+ // Apply gain.
+ y *= gain * dy;
+
+ // Clamp.
+ if (clamp >= 0)
+ {
+ if (G == 0)
+ y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
+ else
+ y = (yref > -clamp & yref < clamp) ? y : 0;
+ }
+
+ // Store.
+ ((T*)p.y)[xi] = (T)y;
+ }
+}
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
+{
+ if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
+ if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
+ if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
+ if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
+ if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
+ if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
+ if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
+ if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
+ if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
+ return NULL;
+}
+
+//------------------------------------------------------------------------
+// Template specializations.
+
+template void* choose_bias_act_kernel<double>    (const bias_act_kernel_params& p);
+template void* choose_bias_act_kernel<float>     (const bias_act_kernel_params& p);
+template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
+
+//------------------------------------------------------------------------
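Each G == 1 branch above is the analytic first derivative of its activation, written in terms of the saved reference output yy. A quick spot check of the tanh branch (A == 4) against autograd, using only plain PyTorch:

```python
import torch

# d/dx tanh(x) = 1 - tanh(x)^2, matching "y = x * (one - yy * yy)" with yy = tanh(x).
x = torch.linspace(-3, 3, 101, dtype=torch.float64, requires_grad=True)
(dy_dx,) = torch.autograd.grad(torch.tanh(x).sum(), x)
assert torch.allclose(dy_dx, 1 - torch.tanh(x.detach()) ** 2)
```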
diff --git a/torch_utils/ops/bias_act.h b/torch_utils/ops/bias_act.h
new file mode 100755
index 0000000000000000000000000000000000000000..a32187e1fb7e3bae509d4eceaf900866866875a4
--- /dev/null
+++ b/torch_utils/ops/bias_act.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+//------------------------------------------------------------------------
+// CUDA kernel parameters.
+
+struct bias_act_kernel_params
+{
+ const void* x; // [sizeX]
+ const void* b; // [sizeB] or NULL
+ const void* xref; // [sizeX] or NULL
+ const void* yref; // [sizeX] or NULL
+ const void* dy; // [sizeX] or NULL
+ void* y; // [sizeX]
+
+ int grad;
+ int act;
+ float alpha;
+ float gain;
+ float clamp;
+
+ int sizeX;
+ int sizeB;
+ int stepB;
+ int loopX;
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/torch_utils/ops/bias_act.py b/torch_utils/ops/bias_act.py
new file mode 100755
index 0000000000000000000000000000000000000000..4bcb409a89ccf6c6f6ecfca5962683df2d280b1f
--- /dev/null
+++ b/torch_utils/ops/bias_act.py
@@ -0,0 +1,212 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Custom PyTorch ops for efficient bias and activation."""
+
+import os
+import warnings
+import numpy as np
+import torch
+import dnnlib
+import traceback
+
+from .. import custom_ops
+from .. import misc
+
+#----------------------------------------------------------------------------
+
+activation_funcs = {
+ 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
+ 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
+ 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
+ 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
+ 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
+ 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
+ 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
+ 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
+ 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
+}
+
+#----------------------------------------------------------------------------
+
+_inited = False
+_plugin = None
+_null_tensor = torch.empty([0])
+
+def _init():
+ global _inited, _plugin
+ if not _inited:
+ _inited = True
+ sources = ['bias_act.cpp', 'bias_act.cu']
+ sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
+ try:
+ _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
+ except:
+ warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
+ return _plugin is not None
+
+#----------------------------------------------------------------------------
+
+def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
+ r"""Fused bias and activation function.
+
+ Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
+ and scales the result by `gain`. Each of the steps is optional. In most cases,
+ the fused op is considerably more efficient than performing the same calculation
+ using standard PyTorch ops. It supports first and second order gradients,
+ but not third order gradients.
+
+ Args:
+ x: Input activation tensor. Can be of any shape.
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
+ as `x`. The shape must be known, and it must match the dimension of `x`
+ corresponding to `dim`.
+ dim: The dimension in `x` corresponding to the elements of `b`.
+ The value of `dim` is ignored if `b` is not specified.
+ act: Name of the activation function to evaluate, or `"linear"` to disable.
+ Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
+ See `activation_funcs` for a full list. `None` is not allowed.
+ alpha: Shape parameter for the activation function, or `None` to use the default.
+ gain: Scaling factor for the output tensor, or `None` to use default.
+ See `activation_funcs` for the default scaling of each activation function.
+ If unsure, consider specifying 1.
+ clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
+ the clamping (default).
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
+
+ Returns:
+ Tensor of the same shape and datatype as `x`.
+ """
+ assert isinstance(x, torch.Tensor)
+ assert impl in ['ref', 'cuda']
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
+ return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
+ return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
+ """Slow reference implementation of `bias_act()` using standard TensorFlow ops.
+ """
+ assert isinstance(x, torch.Tensor)
+ assert clamp is None or clamp >= 0
+ spec = activation_funcs[act]
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
+ gain = float(gain if gain is not None else spec.def_gain)
+ clamp = float(clamp if clamp is not None else -1)
+
+ # Add bias.
+ if b is not None:
+ assert isinstance(b, torch.Tensor) and b.ndim == 1
+ assert 0 <= dim < x.ndim
+ assert b.shape[0] == x.shape[dim]
+ x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
+
+ # Evaluate activation function.
+ alpha = float(alpha)
+ x = spec.func(x, alpha=alpha)
+
+ # Scale by gain.
+ gain = float(gain)
+ if gain != 1:
+ x = x * gain
+
+ # Clamp.
+ if clamp >= 0:
+ x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
+ return x
+
+#----------------------------------------------------------------------------
+
+_bias_act_cuda_cache = dict()
+
+def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
+ """Fast CUDA implementation of `bias_act()` using custom ops.
+ """
+ # Parse arguments.
+ assert clamp is None or clamp >= 0
+ spec = activation_funcs[act]
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
+ gain = float(gain if gain is not None else spec.def_gain)
+ clamp = float(clamp if clamp is not None else -1)
+
+ # Lookup from cache.
+ key = (dim, act, alpha, gain, clamp)
+ if key in _bias_act_cuda_cache:
+ return _bias_act_cuda_cache[key]
+
+ # Forward op.
+ class BiasActCuda(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x, b): # pylint: disable=arguments-differ
+ ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format
+ x = x.contiguous(memory_format=ctx.memory_format)
+ b = b.contiguous() if b is not None else _null_tensor
+ y = x
+ if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
+ y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
+ ctx.save_for_backward(
+ x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
+ b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
+ y if 'y' in spec.ref else _null_tensor)
+ return y
+
+ @staticmethod
+ def backward(ctx, dy): # pylint: disable=arguments-differ
+ dy = dy.contiguous(memory_format=ctx.memory_format)
+ x, b, y = ctx.saved_tensors
+ dx = None
+ db = None
+
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
+ dx = dy
+ if act != 'linear' or gain != 1 or clamp >= 0:
+ dx = BiasActCudaGrad.apply(dy, x, b, y)
+
+ if ctx.needs_input_grad[1]:
+ db = dx.sum([i for i in range(dx.ndim) if i != dim])
+
+ return dx, db
+
+ # Backward op.
+ class BiasActCudaGrad(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
+ ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format
+ dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
+ ctx.save_for_backward(
+ dy if spec.has_2nd_grad else _null_tensor,
+ x, b, y)
+ return dx
+
+ @staticmethod
+ def backward(ctx, d_dx): # pylint: disable=arguments-differ
+ d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
+ dy, x, b, y = ctx.saved_tensors
+ d_dy = None
+ d_x = None
+ d_b = None
+ d_y = None
+
+ if ctx.needs_input_grad[0]:
+ d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
+
+ if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
+ d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
+
+ if spec.has_2nd_grad and ctx.needs_input_grad[2]:
+ d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
+
+ return d_dy, d_x, d_b, d_y
+
+ # Add to cache.
+ _bias_act_cuda_cache[key] = BiasActCuda
+ return BiasActCuda
+
+#----------------------------------------------------------------------------
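A usage sketch of the fused op through the reference path, checked against the equivalent unfused PyTorch ops. Not part of the diff; assumes the repo's torch_utils package is importable, and impl='ref' avoids building the CUDA extension:

```python
import numpy as np
import torch
from torch_utils.ops import bias_act

x = torch.randn(4, 16, 32, 32)
b = torch.randn(16)
y = bias_act.bias_act(x, b, dim=1, act='lrelu', impl='ref')

# Same computation spelled out: bias add, leaky ReLU (def_alpha=0.2), gain sqrt(2).
y_manual = torch.nn.functional.leaky_relu(x + b.reshape(1, -1, 1, 1), 0.2) * np.sqrt(2)
assert torch.allclose(y, y_manual, atol=1e-6)
```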
diff --git a/torch_utils/ops/conv2d_gradfix.py b/torch_utils/ops/conv2d_gradfix.py
new file mode 100755
index 0000000000000000000000000000000000000000..e95e10d0b1d0315a63a76446fd4c5c293c8bbc6d
--- /dev/null
+++ b/torch_utils/ops/conv2d_gradfix.py
@@ -0,0 +1,170 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Custom replacement for `torch.nn.functional.conv2d` that supports
+arbitrarily high order gradients with zero performance penalty."""
+
+import warnings
+import contextlib
+import torch
+
+# pylint: disable=redefined-builtin
+# pylint: disable=arguments-differ
+# pylint: disable=protected-access
+
+#----------------------------------------------------------------------------
+
+enabled = False # Enable the custom op by setting this to true.
+weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
+
+@contextlib.contextmanager
+def no_weight_gradients():
+ global weight_gradients_disabled
+ old = weight_gradients_disabled
+ weight_gradients_disabled = True
+ yield
+ weight_gradients_disabled = old
+
+#----------------------------------------------------------------------------
+
+def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
+ if _should_use_custom_op(input):
+ return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
+ return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
+
+def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
+ if _should_use_custom_op(input):
+ return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
+ return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
+
+#----------------------------------------------------------------------------
+
+def _should_use_custom_op(input):
+ assert isinstance(input, torch.Tensor)
+ if (not enabled) or (not torch.backends.cudnn.enabled):
+ return False
+ if input.device.type != 'cuda':
+ return False
+ if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9.']):
+ return True
+ warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
+ return False
+
+def _tuple_of_ints(xs, ndim):
+ xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
+ assert len(xs) == ndim
+ assert all(isinstance(x, int) for x in xs)
+ return xs
+
+#----------------------------------------------------------------------------
+
+_conv2d_gradfix_cache = dict()
+
+def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
+ # Parse arguments.
+ ndim = 2
+ weight_shape = tuple(weight_shape)
+ stride = _tuple_of_ints(stride, ndim)
+ padding = _tuple_of_ints(padding, ndim)
+ output_padding = _tuple_of_ints(output_padding, ndim)
+ dilation = _tuple_of_ints(dilation, ndim)
+
+ # Lookup from cache.
+ key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
+ if key in _conv2d_gradfix_cache:
+ return _conv2d_gradfix_cache[key]
+
+ # Validate arguments.
+ assert groups >= 1
+ assert len(weight_shape) == ndim + 2
+ assert all(stride[i] >= 1 for i in range(ndim))
+ assert all(padding[i] >= 0 for i in range(ndim))
+ assert all(dilation[i] >= 0 for i in range(ndim))
+ if not transpose:
+ assert all(output_padding[i] == 0 for i in range(ndim))
+ else: # transpose
+ assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
+
+ # Helpers.
+ common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
+ def calc_output_padding(input_shape, output_shape):
+ if transpose:
+ return [0, 0]
+ return [
+ input_shape[i + 2]
+ - (output_shape[i + 2] - 1) * stride[i]
+ - (1 - 2 * padding[i])
+ - dilation[i] * (weight_shape[i + 2] - 1)
+ for i in range(ndim)
+ ]
+
+ # Forward & backward.
+ class Conv2d(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input, weight, bias):
+ assert weight.shape == weight_shape
+ if not transpose:
+ output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
+ else: # transpose
+ output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
+ ctx.save_for_backward(input, weight)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, weight = ctx.saved_tensors
+ grad_input = None
+ grad_weight = None
+ grad_bias = None
+
+ if ctx.needs_input_grad[0]:
+ p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
+ grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None)
+ assert grad_input.shape == input.shape
+
+ if ctx.needs_input_grad[1] and not weight_gradients_disabled:
+ grad_weight = Conv2dGradWeight.apply(grad_output, input)
+ assert grad_weight.shape == weight_shape
+
+ if ctx.needs_input_grad[2]:
+ grad_bias = grad_output.sum([0, 2, 3])
+
+ return grad_input, grad_weight, grad_bias
+
+ # Gradient with respect to the weights.
+ class Conv2dGradWeight(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, grad_output, input):
+ op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight')
+ flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
+ grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
+ assert grad_weight.shape == weight_shape
+ ctx.save_for_backward(grad_output, input)
+ return grad_weight
+
+ @staticmethod
+ def backward(ctx, grad2_grad_weight):
+ grad_output, input = ctx.saved_tensors
+ grad2_grad_output = None
+ grad2_input = None
+
+ if ctx.needs_input_grad[0]:
+ grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
+ assert grad2_grad_output.shape == grad_output.shape
+
+ if ctx.needs_input_grad[1]:
+ p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
+ grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None)
+ assert grad2_input.shape == input.shape
+
+ return grad2_grad_output, grad2_input
+
+ _conv2d_gradfix_cache[key] = Conv2d
+ return Conv2d
+
+#----------------------------------------------------------------------------
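A sketch of the opt-in switch in use. With enabled = True, a gradient-of-gradient through conv2d works; on CPU (or unsupported PyTorch versions) the call falls back to torch.nn.functional.conv2d, which also supports double backward, so the sketch runs either way:

```python
import torch
from torch_utils.ops import conv2d_gradfix

conv2d_gradfix.enabled = True  # module-level switch defined above

x = torch.randn(1, 3, 8, 8, requires_grad=True)
w = torch.randn(4, 3, 3, 3, requires_grad=True)
y = conv2d_gradfix.conv2d(x, w, padding=1)
(g,) = torch.autograd.grad(y.square().sum(), x, create_graph=True)
(gg,) = torch.autograd.grad(g.sum(), w)  # second-order gradient w.r.t. the weights
```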
diff --git a/torch_utils/ops/conv2d_resample.py b/torch_utils/ops/conv2d_resample.py
new file mode 100755
index 0000000000000000000000000000000000000000..cd4750744c83354bab78704d4ef51ad1070fcc4a
--- /dev/null
+++ b/torch_utils/ops/conv2d_resample.py
@@ -0,0 +1,156 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""2D convolution with optional up/downsampling."""
+
+import torch
+
+from .. import misc
+from . import conv2d_gradfix
+from . import upfirdn2d
+from .upfirdn2d import _parse_padding
+from .upfirdn2d import _get_filter_size
+
+#----------------------------------------------------------------------------
+
+def _get_weight_shape(w):
+ with misc.suppress_tracer_warnings(): # this value will be treated as a constant
+ shape = [int(sz) for sz in w.shape]
+ misc.assert_shape(w, shape)
+ return shape
+
+#----------------------------------------------------------------------------
+
+def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
+ """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
+ """
+ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
+
+ # Flip weight if requested.
+ if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
+ w = w.flip([2, 3])
+
+ # Workaround performance pitfall in cuDNN 8.0.5, triggered when using
+ # 1x1 kernel + memory_format=channels_last + less than 64 channels.
+ if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose:
+ if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
+ if out_channels <= 4 and groups == 1:
+ in_shape = x.shape
+ x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1])
+ x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
+ else:
+ x = x.to(memory_format=torch.contiguous_format)
+ w = w.to(memory_format=torch.contiguous_format)
+ x = conv2d_gradfix.conv2d(x, w, groups=groups)
+ return x.to(memory_format=torch.channels_last)
+
+ # Otherwise => execute using conv2d_gradfix.
+ op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
+ return op(x, w, stride=stride, padding=padding, groups=groups)
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
+ r"""2D convolution with optional up/downsampling.
+
+ Padding is performed only once at the beginning, not between the operations.
+
+ Args:
+ x: Input tensor of shape
+ `[batch_size, in_channels, in_height, in_width]`.
+ w: Weight tensor of shape
+ `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
+ f: Low-pass filter for up/downsampling. Must be prepared beforehand by
+ calling upfirdn2d.setup_filter(). None = identity (default).
+ up: Integer upsampling factor (default: 1).
+ down: Integer downsampling factor (default: 1).
+ padding: Padding with respect to the upsampled image. Can be a single number
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+ (default: 0).
+ groups: Split input channels into N groups (default: 1).
+ flip_weight: False = convolution, True = correlation (default: True).
+ flip_filter: False = convolution, True = correlation (default: False).
+
+ Returns:
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+ """
+ # Validate arguments.
+ assert isinstance(x, torch.Tensor) and (x.ndim == 4)
+ assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
+ assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
+ assert isinstance(up, int) and (up >= 1)
+ assert isinstance(down, int) and (down >= 1)
+ assert isinstance(groups, int) and (groups >= 1)
+ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
+ fw, fh = _get_filter_size(f)
+ px0, px1, py0, py1 = _parse_padding(padding)
+
+ # Adjust padding to account for up/downsampling.
+ if up > 1:
+ px0 += (fw + up - 1) // 2
+ px1 += (fw - up) // 2
+ py0 += (fh + up - 1) // 2
+ py1 += (fh - up) // 2
+ if down > 1:
+ px0 += (fw - down + 1) // 2
+ px1 += (fw - down) // 2
+ py0 += (fh - down + 1) // 2
+ py1 += (fh - down) // 2
+
+ # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
+ if kw == 1 and kh == 1 and (down > 1 and up == 1):
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
+ return x
+
+ # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
+ if kw == 1 and kh == 1 and (up > 1 and down == 1):
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
+ x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
+ return x
+
+ # Fast path: downsampling only => use strided convolution.
+ if down > 1 and up == 1:
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
+ x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
+ return x
+
+ # Fast path: upsampling with optional downsampling => use transpose strided convolution.
+ if up > 1:
+ if groups == 1:
+ w = w.transpose(0, 1)
+ else:
+ w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
+ w = w.transpose(1, 2)
+ w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
+ px0 -= kw - 1
+ px1 -= kw - up
+ py0 -= kh - 1
+ py1 -= kh - up
+ pxt = max(min(-px0, -px1), 0)
+ pyt = max(min(-py0, -py1), 0)
+ x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
+ if down > 1:
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
+ return x
+
+ # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
+ if up == 1 and down == 1:
+ if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
+ return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
+
+ # Fallback: Generic reference implementation.
+ x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
+ if down > 1:
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
+ return x
+
+#----------------------------------------------------------------------------
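A usage sketch of a 2x upsampling convolution, mirroring how a synthesis layer would call this function; per the docstring, the filter must be prepared with upfirdn2d.setup_filter() beforehand. Not part of the diff, and the shapes are illustrative:

```python
import torch
from torch_utils.ops import conv2d_resample, upfirdn2d

x = torch.randn(1, 8, 16, 16)              # [batch, in_channels, H, W]
w = torch.randn(4, 8, 3, 3)                # [out_channels, in_channels, kh, kw]
f = upfirdn2d.setup_filter([1, 3, 3, 1])   # normalized low-pass filter
y = conv2d_resample.conv2d_resample(x=x, w=w, f=f, up=2, padding=1)
print(y.shape)                             # expected: [1, 4, 32, 32]
```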
diff --git a/torch_utils/ops/fma.py b/torch_utils/ops/fma.py
new file mode 100755
index 0000000000000000000000000000000000000000..2eeac58a626c49231e04122b93e321ada954c5d3
--- /dev/null
+++ b/torch_utils/ops/fma.py
@@ -0,0 +1,60 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
+
+import torch
+
+#----------------------------------------------------------------------------
+
+def fma(a, b, c): # => a * b + c
+ return _FusedMultiplyAdd.apply(a, b, c)
+
+#----------------------------------------------------------------------------
+
+class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
+ @staticmethod
+ def forward(ctx, a, b, c): # pylint: disable=arguments-differ
+ out = torch.addcmul(c, a, b)
+ ctx.save_for_backward(a, b)
+ ctx.c_shape = c.shape
+ return out
+
+ @staticmethod
+ def backward(ctx, dout): # pylint: disable=arguments-differ
+ a, b = ctx.saved_tensors
+ c_shape = ctx.c_shape
+ da = None
+ db = None
+ dc = None
+
+ if ctx.needs_input_grad[0]:
+ da = _unbroadcast(dout * b, a.shape)
+
+ if ctx.needs_input_grad[1]:
+ db = _unbroadcast(dout * a, b.shape)
+
+ if ctx.needs_input_grad[2]:
+ dc = _unbroadcast(dout, c_shape)
+
+ return da, db, dc
+
+#----------------------------------------------------------------------------
+
+def _unbroadcast(x, shape):
+ extra_dims = x.ndim - len(shape)
+ assert extra_dims >= 0
+ dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
+ if len(dim):
+ x = x.sum(dim=dim, keepdim=True)
+ if extra_dims:
+ x = x.reshape(-1, *x.shape[extra_dims+1:])
+ assert x.shape == shape
+ return x
+
+#----------------------------------------------------------------------------
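A sketch, not part of the diff, showing that fma() matches a * b + c under broadcasting and that the custom backward unbroadcasts gradients back to each input's shape:

```python
import torch
from torch_utils.ops import fma

a = torch.randn(4, 1, 8, requires_grad=True)
b = torch.randn(1, 6, 8, requires_grad=True)
c = torch.randn(6, 8, requires_grad=True)

out = fma.fma(a, b, c)                  # broadcasts to [4, 6, 8]
assert torch.allclose(out, a * b + c)

out.sum().backward()
assert a.grad.shape == a.shape          # gradients unbroadcast per input
assert c.grad.shape == c.shape
```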
diff --git a/torch_utils/ops/grid_sample_gradfix.py b/torch_utils/ops/grid_sample_gradfix.py
new file mode 100755
index 0000000000000000000000000000000000000000..ca6b3413ea72a734703c34382c023b84523601fd
--- /dev/null
+++ b/torch_utils/ops/grid_sample_gradfix.py
@@ -0,0 +1,83 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Custom replacement for `torch.nn.functional.grid_sample` that
+supports arbitrarily high order gradients between the input and output.
+Only works on 2D images and assumes
+`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
+
+import warnings
+import torch
+
+# pylint: disable=redefined-builtin
+# pylint: disable=arguments-differ
+# pylint: disable=protected-access
+
+#----------------------------------------------------------------------------
+
+enabled = False # Enable the custom op by setting this to true.
+
+#----------------------------------------------------------------------------
+
+def grid_sample(input, grid):
+ if _should_use_custom_op():
+ return _GridSample2dForward.apply(input, grid)
+ return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
+
+#----------------------------------------------------------------------------
+
+def _should_use_custom_op():
+ if not enabled:
+ return False
+ if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9.']):
+ return True
+ warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
+ return False
+
+#----------------------------------------------------------------------------
+
+class _GridSample2dForward(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input, grid):
+ assert input.ndim == 4
+ assert grid.ndim == 4
+ output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
+ ctx.save_for_backward(input, grid)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, grid = ctx.saved_tensors
+ grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
+ return grad_input, grad_grid
+
+#----------------------------------------------------------------------------
+
+class _GridSample2dBackward(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, grad_output, input, grid):
+ op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
+ grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
+ ctx.save_for_backward(grid)
+ return grad_input, grad_grid
+
+ @staticmethod
+ def backward(ctx, grad2_grad_input, grad2_grad_grid):
+ _ = grad2_grad_grid # unused
+ grid, = ctx.saved_tensors
+ grad2_grad_output = None
+ grad2_input = None
+ grad2_grid = None
+
+ if ctx.needs_input_grad[0]:
+ grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
+
+ assert not ctx.needs_input_grad[2]
+ return grad2_grad_output, grad2_input, grad2_grid
+
+#----------------------------------------------------------------------------
diff --git a/torch_utils/ops/upfirdn2d.cpp b/torch_utils/ops/upfirdn2d.cpp
new file mode 100755
index 0000000000000000000000000000000000000000..2d7177fc60040751d20e9a8da0301fa3ab64968a
--- /dev/null
+++ b/torch_utils/ops/upfirdn2d.cpp
@@ -0,0 +1,103 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include "upfirdn2d.h"
+
+//------------------------------------------------------------------------
+
+static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
+{
+ // Validate arguments.
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+ TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
+ TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
+ TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
+ TORCH_CHECK(f.dim() == 2, "f must be rank 2");
+ TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
+ TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
+ TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
+
+ // Create output tensor.
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+ int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
+ int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
+ TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
+ torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
+ TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
+
+ // Initialize CUDA kernel parameters.
+ upfirdn2d_kernel_params p;
+ p.x = x.data_ptr();
+ p.f = f.data_ptr();
+ p.y = y.data_ptr();
+ p.up = make_int2(upx, upy);
+ p.down = make_int2(downx, downy);
+ p.pad0 = make_int2(padx0, pady0);
+ p.flip = (flip) ? 1 : 0;
+ p.gain = gain;
+ p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
+ p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
+ p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
+ p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
+ p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
+ p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
+ p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
+ p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
+
+ // Choose CUDA kernel.
+ upfirdn2d_kernel_spec spec;
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
+ {
+ spec = choose_upfirdn2d_kernel(p);
+ });
+
+ // Set looping options.
+ p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
+ p.loopMinor = spec.loopMinor;
+ p.loopX = spec.loopX;
+ p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
+ p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
+
+ // Compute grid size.
+ dim3 blockSize, gridSize;
+ if (spec.tileOutW < 0) // large
+ {
+ blockSize = dim3(4, 32, 1);
+ gridSize = dim3(
+ ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
+ (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
+ p.launchMajor);
+ }
+ else // small
+ {
+ blockSize = dim3(256, 1, 1);
+ gridSize = dim3(
+ ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
+ (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
+ p.launchMajor);
+ }
+
+ // Launch CUDA kernel.
+ void* args[] = {&p};
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
+ return y;
+}
+
+//------------------------------------------------------------------------
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ m.def("upfirdn2d", &upfirdn2d);
+}
+
+//------------------------------------------------------------------------
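The outW/outH computation above is the standard upfirdn size relation, out = (in * up + pad0 + pad1 - filter + down) // down. A quick plain-Python check, not part of the diff, using the padding that conv2d_resample chooses for 2x resampling with a 4-tap filter:

```python
def out_size(in_size, up, down, pad0, pad1, f_size):
    return (in_size * up + pad0 + pad1 - f_size + down) // down

# 2x upsampling: pads (f + up - 1) // 2 = 2 and (f - up) // 2 = 1 keep out = in * up.
assert out_size(16, up=2, down=1, pad0=2, pad1=1, f_size=4) == 32
# 2x downsampling: pads (f - down + 1) // 2 = 1 and (f - down) // 2 = 1 give out = in // down.
assert out_size(16, up=1, down=2, pad0=1, pad1=1, f_size=4) == 8
```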
diff --git a/torch_utils/ops/upfirdn2d.cu b/torch_utils/ops/upfirdn2d.cu
new file mode 100755
index 0000000000000000000000000000000000000000..ebdd9879f4bb16fc57a23cbc81f9de8ef54e4916
--- /dev/null
+++ b/torch_utils/ops/upfirdn2d.cu
@@ -0,0 +1,350 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include <c10/util/Half.h>
+#include "upfirdn2d.h"
+
+//------------------------------------------------------------------------
+// Helpers.
+
+template <class T> struct InternalType;
+template <> struct InternalType<double>    { typedef double scalar_t; };
+template <> struct InternalType<float>     { typedef float  scalar_t; };
+template <> struct InternalType<c10::Half> { typedef float  scalar_t; };
+
+static __device__ __forceinline__ int floor_div(int a, int b)
+{
+ int t = 1 - a / b;
+ return (a + t * b) / b - t;
+}
+
+//------------------------------------------------------------------------
+// Generic CUDA implementation for large filters.
+
+template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
+{
+ typedef typename InternalType::scalar_t scalar_t;
+
+ // Calculate thread index.
+ int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
+ int outY = minorBase / p.launchMinor;
+ minorBase -= outY * p.launchMinor;
+ int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
+ int majorBase = blockIdx.z * p.loopMajor;
+ if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
+ return;
+
+ // Setup Y receptive field.
+ int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
+ int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
+ int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
+ int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
+ if (p.flip)
+ filterY = p.filterSize.y - 1 - filterY;
+
+ // Loop over major, minor, and X.
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
+ for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
+ {
+ int nc = major * p.sizeMinor + minor;
+ int n = nc / p.inSize.z;
+ int c = nc - n * p.inSize.z;
+ for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
+ {
+ // Setup X receptive field.
+ int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
+ int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
+ int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
+ int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
+ if (p.flip)
+ filterX = p.filterSize.x - 1 - filterX;
+
+ // Initialize pointers.
+ const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
+ const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
+ int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
+ int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
+
+ // Inner loop.
+ scalar_t v = 0;
+ for (int y = 0; y < h; y++)
+ {
+ for (int x = 0; x < w; x++)
+ {
+ v += (scalar_t)(*xp) * (scalar_t)(*fp);
+ xp += p.inStride.x;
+ fp += filterStepX;
+ }
+ xp += p.inStride.y - w * p.inStride.x;
+ fp += filterStepY - w * filterStepX;
+ }
+
+ // Store result.
+ v *= p.gain;
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
+ }
+ }
+}
+
+//------------------------------------------------------------------------
+// Specialized CUDA implementation for small filters.
+
+template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>
+static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
+{
+ typedef typename InternalType::scalar_t scalar_t;
+ const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
+ const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
+ __shared__ volatile scalar_t sf[filterH][filterW];
+ __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
+
+ // Calculate tile index.
+ int minorBase = blockIdx.x;
+ int tileOutY = minorBase / p.launchMinor;
+ minorBase -= tileOutY * p.launchMinor;
+ minorBase *= loopMinor;
+ tileOutY *= tileOutH;
+ int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
+ int majorBase = blockIdx.z * p.loopMajor;
+ if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
+ return;
+
+ // Load filter (flipped).
+ for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
+ {
+ int fy = tapIdx / filterW;
+ int fx = tapIdx - fy * filterW;
+ scalar_t v = 0;
+ if (fx < p.filterSize.x & fy < p.filterSize.y)
+ {
+ int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
+ int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
+ v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
+ }
+ sf[fy][fx] = v;
+ }
+
+ // Loop over major and X.
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
+ {
+ int baseNC = major * p.sizeMinor + minorBase;
+ int n = baseNC / p.inSize.z;
+ int baseC = baseNC - n * p.inSize.z;
+ for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
+ {
+ // Load input pixels.
+ int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
+ int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
+ int tileInX = floor_div(tileMidX, upx);
+ int tileInY = floor_div(tileMidY, upy);
+ __syncthreads();
+ for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
+ {
+ int relC = inIdx;
+ int relInX = relC / loopMinor;
+ int relInY = relInX / tileInW;
+ relC -= relInX * loopMinor;
+ relInX -= relInY * tileInW;
+ int c = baseC + relC;
+ int inX = tileInX + relInX;
+ int inY = tileInY + relInY;
+ scalar_t v = 0;
+ if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
+ v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
+ sx[relInY][relInX][relC] = v;
+ }
+
+ // Loop over output pixels.
+ __syncthreads();
+ for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
+ {
+ int relC = outIdx;
+ int relOutX = relC / loopMinor;
+ int relOutY = relOutX / tileOutW;
+ relC -= relOutX * loopMinor;
+ relOutX -= relOutY * tileOutW;
+ int c = baseC + relC;
+ int outX = tileOutX + relOutX;
+ int outY = tileOutY + relOutY;
+
+ // Setup receptive field.
+ int midX = tileMidX + relOutX * downx;
+ int midY = tileMidY + relOutY * downy;
+ int inX = floor_div(midX, upx);
+ int inY = floor_div(midY, upy);
+ int relInX = inX - tileInX;
+ int relInY = inY - tileInY;
+ int filterX = (inX + 1) * upx - midX - 1; // flipped
+ int filterY = (inY + 1) * upy - midY - 1; // flipped
+
+ // Inner loop.
+ if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
+ {
+ scalar_t v = 0;
+ #pragma unroll
+ for (int y = 0; y < filterH / upy; y++)
+ #pragma unroll
+ for (int x = 0; x < filterW / upx; x++)
+ v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
+ v *= p.gain;
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
+ }
+ }
+ }
+ }
+}
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
+{
+ int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
+
+ upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
+ if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last
+
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
+ {
+ if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 64,16,1>, 64,16,1, 1};
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
+ if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 64,16,1>, 64,16,1, 1};
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
+ if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 64,16,1>, 64,16,1, 1};
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
+ }
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
+ {
+ if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ }
+ if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
+ {
+ if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ }
+ if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
+ {
+ if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ }
+ if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
+ {
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ }
+ if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
+ {
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ }
+ if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
+ {
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ }
+ if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
+ {
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ }
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // contiguous
+ {
+ if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1};
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1};
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1};
+ if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1};
+ }
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last
+ {
+ if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1};
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1};
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1};
+ if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1};
+ }
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous
+ {
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1};
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1};
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1};
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1};
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1};
+ }
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last
+ {
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1};
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1};
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1};
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1};
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1};
+ }
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous
+ {
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1};
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1};
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1};
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1};
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1};
+ }
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last
+ {
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1};
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1};
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1};
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1};
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1};
+ }
+ return spec;
+}
+
+//------------------------------------------------------------------------
+// Template specializations.
+
+template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double>   (const upfirdn2d_kernel_params& p);
+template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float>    (const upfirdn2d_kernel_params& p);
+template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/torch_utils/ops/upfirdn2d.h b/torch_utils/ops/upfirdn2d.h
new file mode 100755
index 0000000000000000000000000000000000000000..c9e2032bcac9d2abde7a75eea4d812da348afadd
--- /dev/null
+++ b/torch_utils/ops/upfirdn2d.h
@@ -0,0 +1,59 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include <cuda_runtime.h>
+
+//------------------------------------------------------------------------
+// CUDA kernel parameters.
+
+struct upfirdn2d_kernel_params
+{
+ const void* x;
+ const float* f;
+ void* y;
+
+ int2 up;
+ int2 down;
+ int2 pad0;
+ int flip;
+ float gain;
+
+ int4 inSize; // [width, height, channel, batch]
+ int4 inStride;
+ int2 filterSize; // [width, height]
+ int2 filterStride;
+ int4 outSize; // [width, height, channel, batch]
+ int4 outStride;
+ int sizeMinor;
+ int sizeMajor;
+
+ int loopMinor;
+ int loopMajor;
+ int loopX;
+ int launchMinor;
+ int launchMajor;
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel specialization.
+
+struct upfirdn2d_kernel_spec
+{
+    void* kernel;   // Entry point of the selected CUDA kernel.
+    int tileOutW;   // Width of the output tile per thread block (-1 = large kernel).
+    int tileOutH;   // Height of the output tile per thread block (-1 = large kernel).
+    int loopMinor;  // Minor-axis elements processed per loop iteration.
+    int loopX;      // Horizontal output tiles processed per thread block.
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/torch_utils/ops/upfirdn2d.py b/torch_utils/ops/upfirdn2d.py
new file mode 100755
index 0000000000000000000000000000000000000000..ceeac2b9834e33b7c601c28bf27f32aa91c69256
--- /dev/null
+++ b/torch_utils/ops/upfirdn2d.py
@@ -0,0 +1,384 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Custom PyTorch ops for efficient resampling of 2D images."""
+
+import os
+import warnings
+import numpy as np
+import torch
+import traceback
+
+from .. import custom_ops
+from .. import misc
+from . import conv2d_gradfix
+
+#----------------------------------------------------------------------------
+
+_inited = False
+_plugin = None
+
+def _init():
+ global _inited, _plugin
+ if not _inited:
+ sources = ['upfirdn2d.cpp', 'upfirdn2d.cu']
+ sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
+ try:
+ _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
+        except Exception:
+ warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
+ return _plugin is not None
+
+def _parse_scaling(scaling):
+ if isinstance(scaling, int):
+ scaling = [scaling, scaling]
+ assert isinstance(scaling, (list, tuple))
+ assert all(isinstance(x, int) for x in scaling)
+ sx, sy = scaling
+ assert sx >= 1 and sy >= 1
+ return sx, sy
+
+def _parse_padding(padding):
+ if isinstance(padding, int):
+ padding = [padding, padding]
+ assert isinstance(padding, (list, tuple))
+ assert all(isinstance(x, int) for x in padding)
+ if len(padding) == 2:
+ padx, pady = padding
+ padding = [padx, padx, pady, pady]
+ padx0, padx1, pady0, pady1 = padding
+ return padx0, padx1, pady0, pady1
+
+def _get_filter_size(f):
+ if f is None:
+ return 1, 1
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
+ fw = f.shape[-1]
+ fh = f.shape[0]
+ with misc.suppress_tracer_warnings():
+ fw = int(fw)
+ fh = int(fh)
+ misc.assert_shape(f, [fh, fw][:f.ndim])
+ assert fw >= 1 and fh >= 1
+ return fw, fh
+
+#----------------------------------------------------------------------------
+
+def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
+ r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
+
+ Args:
+ f: Torch tensor, numpy array, or python list of the shape
+ `[filter_height, filter_width]` (non-separable),
+ `[filter_taps]` (separable),
+ `[]` (impulse), or
+ `None` (identity).
+ device: Result device (default: cpu).
+ normalize: Normalize the filter so that it retains the magnitude
+ for constant input signal (DC)? (default: True).
+ flip_filter: Flip the filter? (default: False).
+ gain: Overall scaling factor for signal magnitude (default: 1).
+ separable: Return a separable filter? (default: select automatically).
+
+ Returns:
+ Float32 tensor of the shape
+ `[filter_height, filter_width]` (non-separable) or
+ `[filter_taps]` (separable).
+ """
+ # Validate.
+ if f is None:
+ f = 1
+ f = torch.as_tensor(f, dtype=torch.float32)
+ assert f.ndim in [0, 1, 2]
+ assert f.numel() > 0
+ if f.ndim == 0:
+ f = f[np.newaxis]
+
+ # Separable?
+ if separable is None:
+ separable = (f.ndim == 1 and f.numel() >= 8)
+ if f.ndim == 1 and not separable:
+ f = f.ger(f)
+ assert f.ndim == (1 if separable else 2)
+
+ # Apply normalize, flip, gain, and device.
+ if normalize:
+ f /= f.sum()
+ if flip_filter:
+ f = f.flip(list(range(f.ndim)))
+ f = f * (gain ** (f.ndim / 2))
+ f = f.to(device=device)
+ return f
+
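+# Example: a minimal usage sketch. A short tap list is treated as non-separable
+# (separability is auto-selected only for >= 8 taps), so the 1D kernel below is
+# expanded to a dense outer product; both results are normalized to sum to 1:
+#
+#   f = setup_filter([1, 3, 3, 1])                      # -> shape [4, 4]
+#   f_sep = setup_filter([1, 3, 3, 1], separable=True)  # -> shape [4]
+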
+#----------------------------------------------------------------------------
+
+def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
+ r"""Pad, upsample, filter, and downsample a batch of 2D images.
+
+ Performs the following sequence of operations for each channel:
+
+ 1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
+
+ 2. Pad the image with the specified number of zeros on each side (`padding`).
+ Negative padding corresponds to cropping the image.
+
+ 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
+ so that the footprint of all output pixels lies within the input image.
+
+ 4. Downsample the image by keeping every Nth pixel (`down`).
+
+ This sequence of operations bears close resemblance to scipy.signal.upfirdn().
+ The fused op is considerably more efficient than performing the same calculation
+ using standard PyTorch ops. It supports gradients of arbitrary order.
+
+ Args:
+ x: Float32/float64/float16 input tensor of the shape
+ `[batch_size, num_channels, in_height, in_width]`.
+ f: Float32 FIR filter of the shape
+ `[filter_height, filter_width]` (non-separable),
+ `[filter_taps]` (separable), or
+ `None` (identity).
+ up: Integer upsampling factor. Can be a single int or a list/tuple
+ `[x, y]` (default: 1).
+ down: Integer downsampling factor. Can be a single int or a list/tuple
+ `[x, y]` (default: 1).
+ padding: Padding with respect to the upsampled image. Can be a single number
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+ (default: 0).
+ flip_filter: False = convolution, True = correlation (default: False).
+ gain: Overall scaling factor for signal magnitude (default: 1).
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
+
+ Returns:
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+ """
+ assert isinstance(x, torch.Tensor)
+ assert impl in ['ref', 'cuda']
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
+ return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
+ return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
+
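+# Shape relation implied by the four steps above, per spatial axis:
+#
+#   out = (in * up + pad_before + pad_after - filter_taps) // down + 1
+#
+# e.g. upfirdn2d(x, f, up=2, padding=[2, 1, 2, 1]) with a 4-tap filter maps a
+# 64-pixel axis to (64*2 + 2 + 1 - 4) // 1 + 1 = 128 pixels.
+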
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
+ """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
+ """
+ # Validate arguments.
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
+ if f is None:
+ f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
+ assert f.dtype == torch.float32 and not f.requires_grad
+ batch_size, num_channels, in_height, in_width = x.shape
+ upx, upy = _parse_scaling(up)
+ downx, downy = _parse_scaling(down)
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
+
+ # Upsample by inserting zeros.
+ x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
+ x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
+ x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
+
+ # Pad or crop.
+ x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
+ x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]
+
+ # Setup filter.
+ f = f * (gain ** (f.ndim / 2))
+ f = f.to(x.dtype)
+ if not flip_filter:
+ f = f.flip(list(range(f.ndim)))
+
+ # Convolve with the filter.
+ f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
+ if f.ndim == 4:
+ x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
+ else:
+ x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
+ x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
+
+ # Downsample by throwing away pixels.
+ x = x[:, :, ::downy, ::downx]
+ return x
+
+#----------------------------------------------------------------------------
+
+_upfirdn2d_cuda_cache = dict()
+
+def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
+ """Fast CUDA implementation of `upfirdn2d()` using custom ops.
+ """
+ # Parse arguments.
+ upx, upy = _parse_scaling(up)
+ downx, downy = _parse_scaling(down)
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
+
+ # Lookup from cache.
+ key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
+ if key in _upfirdn2d_cuda_cache:
+ return _upfirdn2d_cuda_cache[key]
+
+ # Forward op.
+ class Upfirdn2dCuda(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x, f): # pylint: disable=arguments-differ
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
+ if f is None:
+ f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
+ y = x
+ if f.ndim == 2:
+ y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
+ else:
+ y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain))
+ y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain))
+ ctx.save_for_backward(f)
+ ctx.x_shape = x.shape
+ return y
+
+ @staticmethod
+ def backward(ctx, dy): # pylint: disable=arguments-differ
+ f, = ctx.saved_tensors
+ _, _, ih, iw = ctx.x_shape
+ _, _, oh, ow = dy.shape
+ fw, fh = _get_filter_size(f)
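+            # The transpose of upfirdn2d is another upfirdn2d with `up` and
+            # `down` swapped and the filter flipped; `p` below is the padding
+            # that makes the gradient shapes line up with the forward pass.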
+ p = [
+ fw - padx0 - 1,
+ iw * upx - ow * downx + padx0 - upx + 1,
+ fh - pady0 - 1,
+ ih * upy - oh * downy + pady0 - upy + 1,
+ ]
+ dx = None
+ df = None
+
+ if ctx.needs_input_grad[0]:
+ dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)
+
+ assert not ctx.needs_input_grad[1]
+ return dx, df
+
+ # Add to cache.
+ _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
+ return Upfirdn2dCuda
+
+#----------------------------------------------------------------------------
+
+def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
+ r"""Filter a batch of 2D images using the given 2D FIR filter.
+
+ By default, the result is padded so that its shape matches the input.
+ User-specified padding is applied on top of that, with negative values
+ indicating cropping. Pixels outside the image are assumed to be zero.
+
+ Args:
+ x: Float32/float64/float16 input tensor of the shape
+ `[batch_size, num_channels, in_height, in_width]`.
+ f: Float32 FIR filter of the shape
+ `[filter_height, filter_width]` (non-separable),
+ `[filter_taps]` (separable), or
+ `None` (identity).
+ padding: Padding with respect to the output. Can be a single number or a
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+ (default: 0).
+ flip_filter: False = convolution, True = correlation (default: False).
+ gain: Overall scaling factor for signal magnitude (default: 1).
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
+
+ Returns:
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+ """
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
+ fw, fh = _get_filter_size(f)
+ p = [
+ padx0 + fw // 2,
+ padx1 + (fw - 1) // 2,
+ pady0 + fh // 2,
+ pady1 + (fh - 1) // 2,
+ ]
+ return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
+
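+# Example: a minimal same-size blur sketch. For a filter of width fw, the
+# padding above works out to [fw//2, (fw-1)//2] per axis, so the output keeps
+# the input resolution:
+#
+#   f = setup_filter([1, 2, 1])   # 3-tap binomial blur, normalized
+#   y = filter2d(x, f)            # y.shape == x.shape
+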
+#----------------------------------------------------------------------------
+
+def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
+ r"""Upsample a batch of 2D images using the given 2D FIR filter.
+
+ By default, the result is padded so that its shape is a multiple of the input.
+ User-specified padding is applied on top of that, with negative values
+ indicating cropping. Pixels outside the image are assumed to be zero.
+
+ Args:
+ x: Float32/float64/float16 input tensor of the shape
+ `[batch_size, num_channels, in_height, in_width]`.
+ f: Float32 FIR filter of the shape
+ `[filter_height, filter_width]` (non-separable),
+ `[filter_taps]` (separable), or
+ `None` (identity).
+ up: Integer upsampling factor. Can be a single int or a list/tuple
+                     `[x, y]` (default: 2).
+ padding: Padding with respect to the output. Can be a single number or a
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+ (default: 0).
+ flip_filter: False = convolution, True = correlation (default: False).
+ gain: Overall scaling factor for signal magnitude (default: 1).
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
+
+ Returns:
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+ """
+ upx, upy = _parse_scaling(up)
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
+ fw, fh = _get_filter_size(f)
+ p = [
+ padx0 + (fw + upx - 1) // 2,
+ padx1 + (fw - upx) // 2,
+ pady0 + (fh + upy - 1) // 2,
+ pady1 + (fh - upy) // 2,
+ ]
+ return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)
+
+#----------------------------------------------------------------------------
+
+def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
+ r"""Downsample a batch of 2D images using the given 2D FIR filter.
+
+ By default, the result is padded so that its shape is a fraction of the input.
+ User-specified padding is applied on top of that, with negative values
+ indicating cropping. Pixels outside the image are assumed to be zero.
+
+ Args:
+ x: Float32/float64/float16 input tensor of the shape
+ `[batch_size, num_channels, in_height, in_width]`.
+ f: Float32 FIR filter of the shape
+ `[filter_height, filter_width]` (non-separable),
+ `[filter_taps]` (separable), or
+ `None` (identity).
+ down: Integer downsampling factor. Can be a single int or a list/tuple
+                     `[x, y]` (default: 2).
+ padding: Padding with respect to the input. Can be a single number or a
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+ (default: 0).
+ flip_filter: False = convolution, True = correlation (default: False).
+ gain: Overall scaling factor for signal magnitude (default: 1).
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
+
+ Returns:
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+ """
+ downx, downy = _parse_scaling(down)
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
+ fw, fh = _get_filter_size(f)
+ p = [
+ padx0 + (fw - downx + 1) // 2,
+ padx1 + (fw - downx) // 2,
+ pady0 + (fh - downy + 1) // 2,
+ pady1 + (fh - downy) // 2,
+ ]
+ return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
+
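+# Example: a minimal antialiased 2x round trip using the two helpers above,
+# assuming the common 4-tap binomial design filter:
+#
+#   f = setup_filter([1, 3, 3, 1])
+#   y = upsample2d(x, f)     # [N, C, H, W]   -> [N, C, 2H, 2W]
+#   z = downsample2d(y, f)   # [N, C, 2H, 2W] -> [N, C, H, W]
+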
+#----------------------------------------------------------------------------
diff --git a/torch_utils/persistence.py b/torch_utils/persistence.py
new file mode 100755
index 0000000000000000000000000000000000000000..0186cfd97bca0fcb397a7b73643520c1d1105a02
--- /dev/null
+++ b/torch_utils/persistence.py
@@ -0,0 +1,251 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Facilities for pickling Python code alongside other data.
+
+The pickled code is automatically imported into a separate Python module
+during unpickling. This way, any previously exported pickles will remain
+usable even if the original code is no longer available, or if the current
+version of the code is not consistent with what was originally pickled."""
+
+import sys
+import pickle
+import io
+import inspect
+import copy
+import uuid
+import types
+import dnnlib
+
+#----------------------------------------------------------------------------
+
+_version = 6 # internal version number
+_decorators = set() # {decorator_class, ...}
+_import_hooks = [] # [hook_function, ...]
+_module_to_src_dict = dict() # {module: src, ...}
+_src_to_module_dict = dict() # {src: module, ...}
+
+#----------------------------------------------------------------------------
+
+def persistent_class(orig_class):
+ r"""Class decorator that extends a given class to save its source code
+ when pickled.
+
+ Example:
+
+ from torch_utils import persistence
+
+ @persistence.persistent_class
+ class MyNetwork(torch.nn.Module):
+ def __init__(self, num_inputs, num_outputs):
+ super().__init__()
+ self.fc = MyLayer(num_inputs, num_outputs)
+ ...
+
+ @persistence.persistent_class
+ class MyLayer(torch.nn.Module):
+ ...
+
+ When pickled, any instance of `MyNetwork` and `MyLayer` will save its
+ source code alongside other internal state (e.g., parameters, buffers,
+ and submodules). This way, any previously exported pickle will remain
+ usable even if the class definitions have been modified or are no
+ longer available.
+
+ The decorator saves the source code of the entire Python module
+ containing the decorated class. It does *not* save the source code of
+ any imported modules. Thus, the imported modules must be available
+    during unpickling, including `torch_utils.persistence` itself.
+
+ It is ok to call functions defined in the same module from the
+ decorated class. However, if the decorated class depends on other
+ classes defined in the same module, they must be decorated as well.
+ This is illustrated in the above example in the case of `MyLayer`.
+
+ It is also possible to employ the decorator just-in-time before
+ calling the constructor. For example:
+
+ cls = MyLayer
+ if want_to_make_it_persistent:
+ cls = persistence.persistent_class(cls)
+ layer = cls(num_inputs, num_outputs)
+
+ As an additional feature, the decorator also keeps track of the
+ arguments that were used to construct each instance of the decorated
+ class. The arguments can be queried via `obj.init_args` and
+ `obj.init_kwargs`, and they are automatically pickled alongside other
+ object state. A typical use case is to first unpickle a previous
+ instance of a persistent class, and then upgrade it to use the latest
+ version of the source code:
+
+ with open('old_pickle.pkl', 'rb') as f:
+ old_net = pickle.load(f)
+        new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
+ misc.copy_params_and_buffers(old_net, new_net, require_all=True)
+ """
+ assert isinstance(orig_class, type)
+ if is_persistent(orig_class):
+ return orig_class
+
+ assert orig_class.__module__ in sys.modules
+ orig_module = sys.modules[orig_class.__module__]
+ orig_module_src = _module_to_src(orig_module)
+
+ class Decorator(orig_class):
+ _orig_module_src = orig_module_src
+ _orig_class_name = orig_class.__name__
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._init_args = copy.deepcopy(args)
+ self._init_kwargs = copy.deepcopy(kwargs)
+ assert orig_class.__name__ in orig_module.__dict__
+ _check_pickleable(self.__reduce__())
+
+ @property
+ def init_args(self):
+ return copy.deepcopy(self._init_args)
+
+ @property
+ def init_kwargs(self):
+ return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
+
+ def __reduce__(self):
+ fields = list(super().__reduce__())
+ fields += [None] * max(3 - len(fields), 0)
+ if fields[0] is not _reconstruct_persistent_obj:
+ meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
+ fields[0] = _reconstruct_persistent_obj # reconstruct func
+ fields[1] = (meta,) # reconstruct args
+ fields[2] = None # state dict
+ return tuple(fields)
+
+ Decorator.__name__ = orig_class.__name__
+ _decorators.add(Decorator)
+ return Decorator
+
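+# Example: a minimal round-trip sketch (`MyNetwork` refers to the docstring
+# example above). Because the decorator embeds the module source in the
+# pickle, the load works even if this file has changed in the meantime:
+#
+#   net = MyNetwork(num_inputs=8, num_outputs=4)
+#   with open('net.pkl', 'wb') as f:
+#       pickle.dump(net, f)     # stores source + init args + state
+#   with open('net.pkl', 'rb') as f:
+#       net2 = pickle.load(f)   # re-imports the stored source on load
+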
+#----------------------------------------------------------------------------
+
+def is_persistent(obj):
+ r"""Test whether the given object or class is persistent, i.e.,
+ whether it will save its source code when pickled.
+ """
+ try:
+ if obj in _decorators:
+ return True
+ except TypeError:
+ pass
+ return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
+
+#----------------------------------------------------------------------------
+
+def import_hook(hook):
+ r"""Register an import hook that is called whenever a persistent object
+ is being unpickled. A typical use case is to patch the pickled source
+ code to avoid errors and inconsistencies when the API of some imported
+ module has changed.
+
+ The hook should have the following signature:
+
+ hook(meta) -> modified meta
+
+ `meta` is an instance of `dnnlib.EasyDict` with the following fields:
+
+ type: Type of the persistent object, e.g. `'class'`.
+ version: Internal version number of `torch_utils.persistence`.
+        module_src: Original source code of the Python module.
+ class_name: Class name in the original Python module.
+ state: Internal state of the object.
+
+ Example:
+
+ @persistence.import_hook
+ def wreck_my_network(meta):
+ if meta.class_name == 'MyNetwork':
+ print('MyNetwork is being imported. I will wreck it!')
+ meta.module_src = meta.module_src.replace("True", "False")
+ return meta
+ """
+ assert callable(hook)
+ _import_hooks.append(hook)
+
+#----------------------------------------------------------------------------
+
+def _reconstruct_persistent_obj(meta):
+ r"""Hook that is called internally by the `pickle` module to unpickle
+ a persistent object.
+ """
+ meta = dnnlib.EasyDict(meta)
+ meta.state = dnnlib.EasyDict(meta.state)
+ for hook in _import_hooks:
+ meta = hook(meta)
+ assert meta is not None
+
+ assert meta.version == _version
+ module = _src_to_module(meta.module_src)
+
+ assert meta.type == 'class'
+ orig_class = module.__dict__[meta.class_name]
+ decorator_class = persistent_class(orig_class)
+ obj = decorator_class.__new__(decorator_class)
+
+ setstate = getattr(obj, '__setstate__', None)
+ if callable(setstate):
+ setstate(meta.state) # pylint: disable=not-callable
+ else:
+ obj.__dict__.update(meta.state)
+ return obj
+
+#----------------------------------------------------------------------------
+
+def _module_to_src(module):
+ r"""Query the source code of a given Python module.
+ """
+ src = _module_to_src_dict.get(module, None)
+ if src is None:
+ src = inspect.getsource(module)
+ _module_to_src_dict[module] = src
+ _src_to_module_dict[src] = module
+ return src
+
+def _src_to_module(src):
+ r"""Get or create a Python module for the given source code.
+ """
+ module = _src_to_module_dict.get(src, None)
+ if module is None:
+ module_name = "_imported_module_" + uuid.uuid4().hex
+ module = types.ModuleType(module_name)
+ sys.modules[module_name] = module
+ _module_to_src_dict[module] = src
+ _src_to_module_dict[src] = module
+ exec(src, module.__dict__) # pylint: disable=exec-used
+ return module
+
+#----------------------------------------------------------------------------
+
+def _check_pickleable(obj):
+ r"""Check that the given object is pickleable, raising an exception if
+ it is not. This function is expected to be considerably more efficient
+ than actually pickling the object.
+ """
+ def recurse(obj):
+ if isinstance(obj, (list, tuple, set)):
+ return [recurse(x) for x in obj]
+ if isinstance(obj, dict):
+ return [[recurse(x), recurse(y)] for x, y in obj.items()]
+ if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
+ return None # Python primitive types are pickleable.
+ if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']:
+ return None # NumPy arrays and PyTorch tensors are pickleable.
+ if is_persistent(obj):
+ return None # Persistent objects are pickleable, by virtue of the constructor check.
+ return obj
+ with io.BytesIO() as f:
+ pickle.dump(recurse(obj), f)
+
+#----------------------------------------------------------------------------
diff --git a/torch_utils/training_stats.py b/torch_utils/training_stats.py
new file mode 100755
index 0000000000000000000000000000000000000000..26f467f9eaa074ee13de1cf2625cd7da44880847
--- /dev/null
+++ b/torch_utils/training_stats.py
@@ -0,0 +1,268 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Facilities for reporting and collecting training statistics across
+multiple processes and devices. The interface is designed to minimize
+synchronization overhead as well as the amount of boilerplate in user
+code."""
+
+import re
+import numpy as np
+import torch
+import dnnlib
+
+from . import misc
+
+#----------------------------------------------------------------------------
+
+_num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
+_reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
+_counter_dtype = torch.float64 # Data type to use for the internal counters.
+_rank = 0 # Rank of the current process.
+_sync_device = None # Device to use for multiprocess communication. None = single-process.
+_sync_called = False # Has _sync() been called yet?
+_counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
+_cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
+
+#----------------------------------------------------------------------------
+
+def init_multiprocessing(rank, sync_device):
+ r"""Initializes `torch_utils.training_stats` for collecting statistics
+ across multiple processes.
+
+ This function must be called after
+ `torch.distributed.init_process_group()` and before `Collector.update()`.
+ The call is not necessary if multi-process collection is not needed.
+
+ Args:
+ rank: Rank of the current process.
+ sync_device: PyTorch device to use for inter-process
+ communication, or None to disable multi-process
+ collection. Typically `torch.device('cuda', rank)`.
+ """
+ global _rank, _sync_device
+ assert not _sync_called
+ _rank = rank
+ _sync_device = sync_device
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def report(name, value):
+ r"""Broadcasts the given set of scalars to all interested instances of
+ `Collector`, across device and process boundaries.
+
+ This function is expected to be extremely cheap and can be safely
+ called from anywhere in the training loop, loss function, or inside a
+ `torch.nn.Module`.
+
+ Warning: The current implementation expects the set of unique names to
+ be consistent across processes. Please make sure that `report()` is
+ called at least once for each unique name by each process, and in the
+ same order. If a given process has no scalars to broadcast, it can do
+ `report(name, [])` (empty list).
+
+ Args:
+ name: Arbitrary string specifying the name of the statistic.
+ Averages are accumulated separately for each unique name.
+ value: Arbitrary set of scalars. Can be a list, tuple,
+ NumPy array, PyTorch tensor, or Python scalar.
+
+ Returns:
+ The same `value` that was passed in.
+ """
+ if name not in _counters:
+ _counters[name] = dict()
+
+ elems = torch.as_tensor(value)
+ if elems.numel() == 0:
+ return value
+
+ elems = elems.detach().flatten().to(_reduce_dtype)
+ moments = torch.stack([
+ torch.ones_like(elems).sum(),
+ elems.sum(),
+ elems.square().sum(),
+ ])
+ assert moments.ndim == 1 and moments.shape[0] == _num_moments
+ moments = moments.to(_counter_dtype)
+
+ device = moments.device
+ if device not in _counters[name]:
+ _counters[name][device] = torch.zeros_like(moments)
+ _counters[name][device].add_(moments)
+ return value
+
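+# Example: a minimal sketch of reporting from a training loop. Every process
+# must report each unique name in the same order; `report0()` handles the
+# empty-list case for non-zero ranks on the caller's behalf:
+#
+#   training_stats.report('Loss/G/total', loss_g)   # all ranks
+#   training_stats.report0('Progress/lr', lr)       # scalars from rank 0 only
+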
+#----------------------------------------------------------------------------
+
+def report0(name, value):
+ r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
+ but ignores any scalars provided by the other processes.
+ See `report()` for further details.
+ """
+ report(name, value if _rank == 0 else [])
+ return value
+
+#----------------------------------------------------------------------------
+
+class Collector:
+ r"""Collects the scalars broadcasted by `report()` and `report0()` and
+ computes their long-term averages (mean and standard deviation) over
+ user-defined periods of time.
+
+ The averages are first collected into internal counters that are not
+ directly visible to the user. They are then copied to the user-visible
+ state as a result of calling `update()` and can then be queried using
+ `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
+ internal counters for the next round, so that the user-visible state
+ effectively reflects averages collected between the last two calls to
+ `update()`.
+
+ Args:
+ regex: Regular expression defining which statistics to
+ collect. The default is to collect everything.
+ keep_previous: Whether to retain the previous averages if no
+ scalars were collected on a given round
+ (default: True).
+ """
+ def __init__(self, regex='.*', keep_previous=True):
+ self._regex = re.compile(regex)
+ self._keep_previous = keep_previous
+ self._cumulative = dict()
+ self._moments = dict()
+ self.update()
+ self._moments.clear()
+
+ def names(self):
+ r"""Returns the names of all statistics broadcasted so far that
+ match the regular expression specified at construction time.
+ """
+ return [name for name in _counters if self._regex.fullmatch(name)]
+
+ def update(self):
+ r"""Copies current values of the internal counters to the
+ user-visible state and resets them for the next round.
+
+ If `keep_previous=True` was specified at construction time, the
+ operation is skipped for statistics that have received no scalars
+ since the last update, retaining their previous averages.
+
+ This method performs a number of GPU-to-CPU transfers and one
+ `torch.distributed.all_reduce()`. It is intended to be called
+ periodically in the main training loop, typically once every
+ N training steps.
+ """
+ if not self._keep_previous:
+ self._moments.clear()
+ for name, cumulative in _sync(self.names()):
+ if name not in self._cumulative:
+ self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
+ delta = cumulative - self._cumulative[name]
+ self._cumulative[name].copy_(cumulative)
+ if float(delta[0]) != 0:
+ self._moments[name] = delta
+
+ def _get_delta(self, name):
+ r"""Returns the raw moments that were accumulated for the given
+ statistic between the last two calls to `update()`, or zero if
+ no scalars were collected.
+ """
+ assert self._regex.fullmatch(name)
+ if name not in self._moments:
+ self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
+ return self._moments[name]
+
+ def num(self, name):
+ r"""Returns the number of scalars that were accumulated for the given
+ statistic between the last two calls to `update()`, or zero if
+ no scalars were collected.
+ """
+ delta = self._get_delta(name)
+ return int(delta[0])
+
+ def mean(self, name):
+ r"""Returns the mean of the scalars that were accumulated for the
+ given statistic between the last two calls to `update()`, or NaN if
+ no scalars were collected.
+ """
+ delta = self._get_delta(name)
+ if int(delta[0]) == 0:
+ return float('nan')
+ return float(delta[1] / delta[0])
+
+ def std(self, name):
+ r"""Returns the standard deviation of the scalars that were
+ accumulated for the given statistic between the last two calls to
+ `update()`, or NaN if no scalars were collected.
+ """
+ delta = self._get_delta(name)
+ if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
+ return float('nan')
+ if int(delta[0]) == 1:
+ return float(0)
+ mean = float(delta[1] / delta[0])
+ raw_var = float(delta[2] / delta[0])
+ return np.sqrt(max(raw_var - np.square(mean), 0))
+
+ def as_dict(self):
+ r"""Returns the averages accumulated between the last two calls to
+    `update()` as a `dnnlib.EasyDict`. The contents are as follows:
+
+ dnnlib.EasyDict(
+ NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
+ ...
+ )
+ """
+ stats = dnnlib.EasyDict()
+ for name in self.names():
+ stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
+ return stats
+
+ def __getitem__(self, name):
+ r"""Convenience getter.
+ `collector[name]` is a synonym for `collector.mean(name)`.
+ """
+ return self.mean(name)
+
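+# Example: a minimal collection sketch (`num_steps` and `loss_g` are assumed
+# to come from the surrounding training loop):
+#
+#   collector = Collector(regex='Loss/.*')
+#   for step in range(num_steps):
+#       training_stats.report('Loss/G/total', loss_g)
+#       if step % 100 == 0:
+#           collector.update()
+#           print(collector.mean('Loss/G/total'), collector.std('Loss/G/total'))
+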
+#----------------------------------------------------------------------------
+
+def _sync(names):
+ r"""Synchronize the global cumulative counters across devices and
+ processes. Called internally by `Collector.update()`.
+ """
+ if len(names) == 0:
+ return []
+ global _sync_called
+ _sync_called = True
+
+ # Collect deltas within current rank.
+ deltas = []
+ device = _sync_device if _sync_device is not None else torch.device('cpu')
+ for name in names:
+ delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
+ for counter in _counters[name].values():
+ delta.add_(counter.to(device))
+ counter.copy_(torch.zeros_like(counter))
+ deltas.append(delta)
+ deltas = torch.stack(deltas)
+
+ # Sum deltas across ranks.
+ if _sync_device is not None:
+ torch.distributed.all_reduce(deltas)
+
+ # Update cumulative values.
+ deltas = deltas.cpu()
+ for idx, name in enumerate(names):
+ if name not in _cumulative:
+ _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
+ _cumulative[name].add_(deltas[idx])
+
+ # Return name-value pairs.
+ return [(name, _cumulative[name]) for name in names]
+
+#----------------------------------------------------------------------------
diff --git a/training/__init__.py b/training/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e1e1a5ba99e56a56ecaa14f7d4fa41777789c0cf
--- /dev/null
+++ b/training/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+# empty
diff --git a/training/data/augment.py b/training/data/augment.py
new file mode 100755
index 0000000000000000000000000000000000000000..22d4ff22e3c01de0b05d5fdf9503be555263b92b
--- /dev/null
+++ b/training/data/augment.py
@@ -0,0 +1,423 @@
+import numpy as np
+import scipy.signal
+import torch
+from torch_utils import persistence
+from torch_utils import misc
+from torch_utils.ops import upfirdn2d
+from torch_utils.ops import grid_sample_gradfix
+from torch_utils.ops import conv2d_gradfix
+
+#----------------------------------------------------------------------------
+# Coefficients of various wavelet decomposition low-pass filters.
+
+wavelets = {
+ 'haar': [0.7071067811865476, 0.7071067811865476],
+ 'db1': [0.7071067811865476, 0.7071067811865476],
+ 'db2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
+ 'db3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
+ 'db4': [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523],
+ 'db5': [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125],
+ 'db6': [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017],
+ 'db7': [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236],
+ 'db8': [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161],
+ 'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
+ 'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
+ 'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427],
+ 'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728],
+ 'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148],
+ 'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255],
+ 'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609],
+}
+
+#----------------------------------------------------------------------------
+# Helpers for constructing transformation matrices.
+
+def matrix(*rows, device=None):
+ assert all(len(row) == len(rows[0]) for row in rows)
+ elems = [x for row in rows for x in row]
+ ref = [x for x in elems if isinstance(x, torch.Tensor)]
+ if len(ref) == 0:
+ return misc.constant(np.asarray(rows), device=device)
+ assert device is None or device == ref[0].device
+ elems = [x if isinstance(x, torch.Tensor) else misc.constant(x, shape=ref[0].shape, device=ref[0].device) for x in elems]
+ return torch.stack(elems, dim=-1).reshape(ref[0].shape + (len(rows), -1))
+
+def translate2d(tx, ty, **kwargs):
+ return matrix(
+ [1, 0, tx],
+ [0, 1, ty],
+ [0, 0, 1],
+ **kwargs)
+
+def translate3d(tx, ty, tz, **kwargs):
+ return matrix(
+ [1, 0, 0, tx],
+ [0, 1, 0, ty],
+ [0, 0, 1, tz],
+ [0, 0, 0, 1],
+ **kwargs)
+
+def scale2d(sx, sy, **kwargs):
+ return matrix(
+ [sx, 0, 0],
+ [0, sy, 0],
+ [0, 0, 1],
+ **kwargs)
+
+def scale3d(sx, sy, sz, **kwargs):
+ return matrix(
+ [sx, 0, 0, 0],
+ [0, sy, 0, 0],
+ [0, 0, sz, 0],
+ [0, 0, 0, 1],
+ **kwargs)
+
+def rotate2d(theta, **kwargs):
+ return matrix(
+ [torch.cos(theta), torch.sin(-theta), 0],
+ [torch.sin(theta), torch.cos(theta), 0],
+ [0, 0, 1],
+ **kwargs)
+
+def rotate3d(v, theta, **kwargs):
+ vx = v[..., 0]; vy = v[..., 1]; vz = v[..., 2]
+ s = torch.sin(theta); c = torch.cos(theta); cc = 1 - c
+ return matrix(
+ [vx*vx*cc+c, vx*vy*cc-vz*s, vx*vz*cc+vy*s, 0],
+ [vy*vx*cc+vz*s, vy*vy*cc+c, vy*vz*cc-vx*s, 0],
+ [vz*vx*cc-vy*s, vz*vy*cc+vx*s, vz*vz*cc+c, 0],
+ [0, 0, 0, 1],
+ **kwargs)
+
+def translate2d_inv(tx, ty, **kwargs):
+ return translate2d(-tx, -ty, **kwargs)
+
+def scale2d_inv(sx, sy, **kwargs):
+ return scale2d(1 / sx, 1 / sy, **kwargs)
+
+def rotate2d_inv(theta, **kwargs):
+ return rotate2d(-theta, **kwargs)
+
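+# Example: a minimal sketch of composing the inverse transforms above, in the
+# same pixel_out -> pixel_in convention that `AugmentPipe.forward()` uses to
+# accumulate G_inv (`theta` is a hypothetical batch of rotation angles; the
+# batched [8, 3, 3] result comes from broadcasting inside `matrix()`):
+#
+#   theta = torch.rand([8]) * 2 * np.pi
+#   G_inv = scale2d_inv(2.0, 2.0) @ rotate2d_inv(theta)   # -> [8, 3, 3]
+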
+#----------------------------------------------------------------------------
+# Versatile image augmentation pipeline from the paper
+# "Training Generative Adversarial Networks with Limited Data".
+#
+# All augmentations are disabled by default; individual augmentations can
+# be enabled by setting their probability multipliers to 1.
+
+@persistence.persistent_class
+class AugmentPipe(torch.nn.Module):
+ def __init__(self,
+ xflip=0, rotate90=0, xint=0, xint_max=0.125,
+ scale=0, rotate=0, aniso=0, xfrac=0, scale_std=0.2, rotate_max=1, aniso_std=0.2, xfrac_std=0.125,
+ brightness=0, contrast=0, lumaflip=0, hue=0, saturation=0, brightness_std=0.2, contrast_std=0.5, hue_max=1, saturation_std=1,
+ imgfilter=0, imgfilter_bands=[1,1,1,1], imgfilter_std=1,
+ noise=0, cutout=0, noise_std=0.1, cutout_size=0.5,
+ ):
+ super().__init__()
+ self.register_buffer('p', torch.ones([])) # Overall multiplier for augmentation probability.
+
+ # Pixel blitting.
+ self.xflip = float(xflip) # Probability multiplier for x-flip.
+ self.rotate90 = float(rotate90) # Probability multiplier for 90 degree rotations.
+ self.xint = float(xint) # Probability multiplier for integer translation.
+ self.xint_max = float(xint_max) # Range of integer translation, relative to image dimensions.
+
+ # General geometric transformations.
+ self.scale = float(scale) # Probability multiplier for isotropic scaling.
+ self.rotate = float(rotate) # Probability multiplier for arbitrary rotation.
+ self.aniso = float(aniso) # Probability multiplier for anisotropic scaling.
+ self.xfrac = float(xfrac) # Probability multiplier for fractional translation.
+ self.scale_std = float(scale_std) # Log2 standard deviation of isotropic scaling.
+ self.rotate_max = float(rotate_max) # Range of arbitrary rotation, 1 = full circle.
+ self.aniso_std = float(aniso_std) # Log2 standard deviation of anisotropic scaling.
+        self.xfrac_std = float(xfrac_std)       # Standard deviation of fractional translation, relative to image dimensions.
+
+ # Color transformations.
+ self.brightness = float(brightness) # Probability multiplier for brightness.
+ self.contrast = float(contrast) # Probability multiplier for contrast.
+ self.lumaflip = float(lumaflip) # Probability multiplier for luma flip.
+ self.hue = float(hue) # Probability multiplier for hue rotation.
+ self.saturation = float(saturation) # Probability multiplier for saturation.
+ self.brightness_std = float(brightness_std) # Standard deviation of brightness.
+ self.contrast_std = float(contrast_std) # Log2 standard deviation of contrast.
+ self.hue_max = float(hue_max) # Range of hue rotation, 1 = full circle.
+ self.saturation_std = float(saturation_std) # Log2 standard deviation of saturation.
+
+ # Image-space filtering.
+ self.imgfilter = float(imgfilter) # Probability multiplier for image-space filtering.
+ self.imgfilter_bands = list(imgfilter_bands) # Probability multipliers for individual frequency bands.
+ self.imgfilter_std = float(imgfilter_std) # Log2 standard deviation of image-space filter amplification.
+
+ # Image-space corruptions.
+ self.noise = float(noise) # Probability multiplier for additive RGB noise.
+ self.cutout = float(cutout) # Probability multiplier for cutout.
+ self.noise_std = float(noise_std) # Standard deviation of additive RGB noise.
+ self.cutout_size = float(cutout_size) # Size of the cutout rectangle, relative to image dimensions.
+
+ # Setup orthogonal lowpass filter for geometric augmentations.
+ self.register_buffer('Hz_geom', upfirdn2d.setup_filter(wavelets['sym6']))
+
+ # Construct filter bank for image-space filtering.
+ Hz_lo = np.asarray(wavelets['sym2']) # H(z)
+ Hz_hi = Hz_lo * ((-1) ** np.arange(Hz_lo.size)) # H(-z)
+ Hz_lo2 = np.convolve(Hz_lo, Hz_lo[::-1]) / 2 # H(z) * H(z^-1) / 2
+ Hz_hi2 = np.convolve(Hz_hi, Hz_hi[::-1]) / 2 # H(-z) * H(-z^-1) / 2
+ Hz_fbank = np.eye(4, 1) # Bandpass(H(z), b_i)
+ for i in range(1, Hz_fbank.shape[0]):
+ Hz_fbank = np.dstack([Hz_fbank, np.zeros_like(Hz_fbank)]).reshape(Hz_fbank.shape[0], -1)[:, :-1]
+ Hz_fbank = scipy.signal.convolve(Hz_fbank, [Hz_lo2])
+ Hz_fbank[i, (Hz_fbank.shape[1] - Hz_hi2.size) // 2 : (Hz_fbank.shape[1] + Hz_hi2.size) // 2] += Hz_hi2
+ self.register_buffer('Hz_fbank', torch.as_tensor(Hz_fbank, dtype=torch.float32))
+
+ def forward(self, images, debug_percentile=None):
+ assert isinstance(images, torch.Tensor) and images.ndim == 4
+ batch_size, num_channels, height, width = images.shape
+ device = images.device
+ if debug_percentile is not None:
+ debug_percentile = torch.as_tensor(debug_percentile, dtype=torch.float32, device=device)
+
+ # -------------------------------------
+ # Select parameters for pixel blitting.
+ # -------------------------------------
+
+ # Initialize inverse homogeneous 2D transform: G_inv @ pixel_out ==> pixel_in
+ I_3 = torch.eye(3, device=device)
+ G_inv = I_3
+
+ # Apply x-flip with probability (xflip * strength).
+ if self.xflip > 0:
+ i = torch.floor(torch.rand([batch_size], device=device) * 2)
+ i = torch.where(torch.rand([batch_size], device=device) < self.xflip * self.p, i, torch.zeros_like(i))
+ if debug_percentile is not None:
+ i = torch.full_like(i, torch.floor(debug_percentile * 2))
+ G_inv = G_inv @ scale2d_inv(1 - 2 * i, 1)
+
+ # Apply 90 degree rotations with probability (rotate90 * strength).
+ if self.rotate90 > 0:
+ i = torch.floor(torch.rand([batch_size], device=device) * 4)
+ i = torch.where(torch.rand([batch_size], device=device) < self.rotate90 * self.p, i, torch.zeros_like(i))
+ if debug_percentile is not None:
+ i = torch.full_like(i, torch.floor(debug_percentile * 4))
+ G_inv = G_inv @ rotate2d_inv(-np.pi / 2 * i)
+
+ # Apply integer translation with probability (xint * strength).
+ if self.xint > 0:
+ t = (torch.rand([batch_size, 2], device=device) * 2 - 1) * self.xint_max
+ t = torch.where(torch.rand([batch_size, 1], device=device) < self.xint * self.p, t, torch.zeros_like(t))
+ if debug_percentile is not None:
+ t = torch.full_like(t, (debug_percentile * 2 - 1) * self.xint_max)
+ G_inv = G_inv @ translate2d_inv(torch.round(t[:,0] * width), torch.round(t[:,1] * height))
+
+ # --------------------------------------------------------
+ # Select parameters for general geometric transformations.
+ # --------------------------------------------------------
+
+ # Apply isotropic scaling with probability (scale * strength).
+ if self.scale > 0:
+ s = torch.exp2(torch.randn([batch_size], device=device) * self.scale_std)
+ s = torch.where(torch.rand([batch_size], device=device) < self.scale * self.p, s, torch.ones_like(s))
+ if debug_percentile is not None:
+ s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.scale_std))
+ G_inv = G_inv @ scale2d_inv(s, s)
+
+ # Apply pre-rotation with probability p_rot.
+ p_rot = 1 - torch.sqrt((1 - self.rotate * self.p).clamp(0, 1)) # P(pre OR post) = p
+ if self.rotate > 0:
+ theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
+ theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
+ if debug_percentile is not None:
+ theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.rotate_max)
+ G_inv = G_inv @ rotate2d_inv(-theta) # Before anisotropic scaling.
+
+ # Apply anisotropic scaling with probability (aniso * strength).
+ if self.aniso > 0:
+ s = torch.exp2(torch.randn([batch_size], device=device) * self.aniso_std)
+ s = torch.where(torch.rand([batch_size], device=device) < self.aniso * self.p, s, torch.ones_like(s))
+ if debug_percentile is not None:
+ s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.aniso_std))
+ G_inv = G_inv @ scale2d_inv(s, 1 / s)
+
+ # Apply post-rotation with probability p_rot.
+ if self.rotate > 0:
+ theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
+ theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
+ if debug_percentile is not None:
+ theta = torch.zeros_like(theta)
+ G_inv = G_inv @ rotate2d_inv(-theta) # After anisotropic scaling.
+
+ # Apply fractional translation with probability (xfrac * strength).
+ if self.xfrac > 0:
+ t = torch.randn([batch_size, 2], device=device) * self.xfrac_std
+ t = torch.where(torch.rand([batch_size, 1], device=device) < self.xfrac * self.p, t, torch.zeros_like(t))
+ if debug_percentile is not None:
+ t = torch.full_like(t, torch.erfinv(debug_percentile * 2 - 1) * self.xfrac_std)
+ G_inv = G_inv @ translate2d_inv(t[:,0] * width, t[:,1] * height)
+
+ # ----------------------------------
+ # Execute geometric transformations.
+ # ----------------------------------
+
+ # Execute if the transform is not identity.
+ if G_inv is not I_3:
+
+ # Calculate padding.
+ cx = (width - 1) / 2
+ cy = (height - 1) / 2
+ cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device) # [idx, xyz]
+ cp = G_inv @ cp.t() # [batch, xyz, idx]
+ Hz_pad = self.Hz_geom.shape[0] // 4
+ margin = cp[:, :2, :].permute(1, 0, 2).flatten(1) # [xy, batch * idx]
+ margin = torch.cat([-margin, margin]).max(dim=1).values # [x0, y0, x1, y1]
+ margin = margin + misc.constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device)
+ margin = margin.max(misc.constant([0, 0] * 2, device=device))
+ margin = margin.min(misc.constant([width-1, height-1] * 2, device=device))
+ mx0, my0, mx1, my1 = margin.ceil().to(torch.int32)
+
+ # Pad image and adjust origin.
+ images = torch.nn.functional.pad(input=images, pad=[mx0,mx1,my0,my1], mode='reflect')
+ G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv
+
+ # Upsample.
+ images = upfirdn2d.upsample2d(x=images, f=self.Hz_geom, up=2)
+ G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device)
+ G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device)
+
+ # Execute transformation.
+ shape = [batch_size, num_channels, (height + Hz_pad * 2) * 2, (width + Hz_pad * 2) * 2]
+ G_inv = scale2d(2 / images.shape[3], 2 / images.shape[2], device=device) @ G_inv @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device)
+ grid = torch.nn.functional.affine_grid(theta=G_inv[:,:2,:], size=shape, align_corners=False)
+ images = grid_sample_gradfix.grid_sample(images, grid)
+
+ # Downsample and crop.
+ images = upfirdn2d.downsample2d(x=images, f=self.Hz_geom, down=2, padding=-Hz_pad*2, flip_filter=True)
+
+ # --------------------------------------------
+ # Select parameters for color transformations.
+ # --------------------------------------------
+
+ # Initialize homogeneous 3D transformation matrix: C @ color_in ==> color_out
+ I_4 = torch.eye(4, device=device)
+ C = I_4
+
+ # Apply brightness with probability (brightness * strength).
+ if self.brightness > 0:
+ b = torch.randn([batch_size], device=device) * self.brightness_std
+ b = torch.where(torch.rand([batch_size], device=device) < self.brightness * self.p, b, torch.zeros_like(b))
+ if debug_percentile is not None:
+ b = torch.full_like(b, torch.erfinv(debug_percentile * 2 - 1) * self.brightness_std)
+ C = translate3d(b, b, b) @ C
+
+ # Apply contrast with probability (contrast * strength).
+ if self.contrast > 0:
+ c = torch.exp2(torch.randn([batch_size], device=device) * self.contrast_std)
+ c = torch.where(torch.rand([batch_size], device=device) < self.contrast * self.p, c, torch.ones_like(c))
+ if debug_percentile is not None:
+ c = torch.full_like(c, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.contrast_std))
+ C = scale3d(c, c, c) @ C
+
+ # Apply luma flip with probability (lumaflip * strength).
+ v = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device) # Luma axis.
+ if self.lumaflip > 0:
+ i = torch.floor(torch.rand([batch_size, 1, 1], device=device) * 2)
+ i = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.lumaflip * self.p, i, torch.zeros_like(i))
+ if debug_percentile is not None:
+ i = torch.full_like(i, torch.floor(debug_percentile * 2))
+ C = (I_4 - 2 * v.ger(v) * i) @ C # Householder reflection.
+
+ # Apply hue rotation with probability (hue * strength).
+ if self.hue > 0 and num_channels > 1:
+ theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.hue_max
+ theta = torch.where(torch.rand([batch_size], device=device) < self.hue * self.p, theta, torch.zeros_like(theta))
+ if debug_percentile is not None:
+ theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.hue_max)
+ C = rotate3d(v, theta) @ C # Rotate around v.
+
+ # Apply saturation with probability (saturation * strength).
+ if self.saturation > 0 and num_channels > 1:
+ s = torch.exp2(torch.randn([batch_size, 1, 1], device=device) * self.saturation_std)
+ s = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.saturation * self.p, s, torch.ones_like(s))
+ if debug_percentile is not None:
+ s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.saturation_std))
+ C = (v.ger(v) + (I_4 - v.ger(v)) * s) @ C
+
+ # ------------------------------
+ # Execute color transformations.
+ # ------------------------------
+
+ # Execute if the transform is not identity.
+ if C is not I_4:
+ images = images.reshape([batch_size, num_channels, height * width])
+ if num_channels == 3:
+ images = C[:, :3, :3] @ images + C[:, :3, 3:]
+ elif num_channels == 1:
+ C = C[:, :3, :].mean(dim=1, keepdims=True)
+ images = images * C[:, :, :3].sum(dim=2, keepdims=True) + C[:, :, 3:]
+ else:
+ raise ValueError('Image must be RGB (3 channels) or L (1 channel)')
+ images = images.reshape([batch_size, num_channels, height, width])
+
+ # ----------------------
+ # Image-space filtering.
+ # ----------------------
+
+ if self.imgfilter > 0:
+ num_bands = self.Hz_fbank.shape[0]
+ assert len(self.imgfilter_bands) == num_bands
+ expected_power = misc.constant(np.array([10, 1, 1, 1]) / 13, device=device) # Expected power spectrum (1/f).
+
+ # Apply amplification for each band with probability (imgfilter * strength * band_strength).
+ g = torch.ones([batch_size, num_bands], device=device) # Global gain vector (identity).
+ for i, band_strength in enumerate(self.imgfilter_bands):
+ t_i = torch.exp2(torch.randn([batch_size], device=device) * self.imgfilter_std)
+ t_i = torch.where(torch.rand([batch_size], device=device) < self.imgfilter * self.p * band_strength, t_i, torch.ones_like(t_i))
+ if debug_percentile is not None:
+ t_i = torch.full_like(t_i, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.imgfilter_std)) if band_strength > 0 else torch.ones_like(t_i)
+ t = torch.ones([batch_size, num_bands], device=device) # Temporary gain vector.
+ t[:, i] = t_i # Replace i'th element.
+ t = t / (expected_power * t.square()).sum(dim=-1, keepdims=True).sqrt() # Normalize power.
+ g = g * t # Accumulate into global gain.
+
+ # Construct combined amplification filter.
+ Hz_prime = g @ self.Hz_fbank # [batch, tap]
+ Hz_prime = Hz_prime.unsqueeze(1).repeat([1, num_channels, 1]) # [batch, channels, tap]
+ Hz_prime = Hz_prime.reshape([batch_size * num_channels, 1, -1]) # [batch * channels, 1, tap]
+
+ # Apply filter.
+ p = self.Hz_fbank.shape[1] // 2
+ images = images.reshape([1, batch_size * num_channels, height, width])
+ images = torch.nn.functional.pad(input=images, pad=[p,p,p,p], mode='reflect')
+ images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(2), groups=batch_size*num_channels)
+ images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(3), groups=batch_size*num_channels)
+ images = images.reshape([batch_size, num_channels, height, width])
+
+ # ------------------------
+ # Image-space corruptions.
+ # ------------------------
+
+ # Apply additive RGB noise with probability (noise * strength).
+ if self.noise > 0:
+ sigma = torch.randn([batch_size, 1, 1, 1], device=device).abs() * self.noise_std
+ sigma = torch.where(torch.rand([batch_size, 1, 1, 1], device=device) < self.noise * self.p, sigma, torch.zeros_like(sigma))
+ if debug_percentile is not None:
+ sigma = torch.full_like(sigma, torch.erfinv(debug_percentile) * self.noise_std)
+ images = images + torch.randn([batch_size, num_channels, height, width], device=device) * sigma
+
+ # Apply cutout with probability (cutout * strength).
+ if self.cutout > 0:
+ size = torch.full([batch_size, 2, 1, 1, 1], self.cutout_size, device=device)
+ size = torch.where(torch.rand([batch_size, 1, 1, 1, 1], device=device) < self.cutout * self.p, size, torch.zeros_like(size))
+ center = torch.rand([batch_size, 2, 1, 1, 1], device=device)
+ if debug_percentile is not None:
+ size = torch.full_like(size, self.cutout_size)
+ center = torch.full_like(center, debug_percentile)
+ coord_x = torch.arange(width, device=device).reshape([1, 1, 1, -1])
+ coord_y = torch.arange(height, device=device).reshape([1, 1, -1, 1])
+ mask_x = (((coord_x + 0.5) / width - center[:, 0]).abs() >= size[:, 0] / 2)
+ mask_y = (((coord_y + 0.5) / height - center[:, 1]).abs() >= size[:, 1] / 2)
+ mask = torch.logical_or(mask_x, mask_y).to(torch.float32)
+ images = images * mask
+
+ return images
+
+#----------------------------------------------------------------------------
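The geometry section of AugmentPipe.forward never warps the image once per op: each branch multiplies another inverse transform onto G_inv, so G_inv always maps output pixel coordinates back to input coordinates, and a single grid_sample does all the resampling at the end. Below is a minimal numpy sketch of that composition; the scale2d_inv/translate2d_inv stand-ins are written out for illustration and are only assumed to match the repo's torch helpers.

import numpy as np

def scale2d_inv(sx, sy):
    # Inverse of scaling by (sx, sy) about the origin.
    return np.array([[1.0 / sx, 0, 0], [0, 1.0 / sy, 0], [0, 0, 1]])

def translate2d_inv(tx, ty):
    # Inverse of translating by (tx, ty).
    return np.array([[1, 0, -tx], [0, 1, -ty], [0, 0, 1]], dtype=float)

G_inv = np.eye(3)
G_inv = G_inv @ scale2d_inv(-1, 1)     # x-flip, as in the xflip branch
G_inv = G_inv @ translate2d_inv(3, 0)  # integer translation, as in the xint branch

out_pixel = np.array([10.0, 5.0, 1.0])   # homogeneous output-pixel coordinate
print((G_inv @ out_pixel)[:2])           # [-7.  5.]: where this output pixel samples the input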
diff --git a/training/data/configs/medium_256.yaml b/training/data/configs/medium_256.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..02e8be55bf9ae56a3316f882432eb50b16abdb93
--- /dev/null
+++ b/training/data/configs/medium_256.yaml
@@ -0,0 +1,24 @@
+kind: 'mixed'
+
+mask_gen_kwargs:
+ irregular_proba: 1
+ hole_range:
+ - 0.0
+ - 0.7
+ irregular_kwargs:
+ min_times: 4
+ max_times: 5
+ max_width: 50
+ max_angle: 4
+ max_len: 100
+
+ box_proba: 0.3
+ box_kwargs:
+ margin: 0
+ bbox_min_size: 10
+ bbox_max_size: 50
+ max_times: 5
+ min_times: 1
+
+ segm_proba: 0
+ squares_proba: 0
\ No newline at end of file
diff --git a/training/data/configs/medium_512.yaml b/training/data/configs/medium_512.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7572770b6620af2c93780f0aeffb0730288f90e6
--- /dev/null
+++ b/training/data/configs/medium_512.yaml
@@ -0,0 +1,24 @@
+kind: 'mixed'
+
+mask_gen_kwargs:
+ irregular_proba: 1
+ hole_range:
+ - 0.0
+ - 0.7
+ irregular_kwargs:
+ min_times: 4
+ max_times: 10
+ max_width: 100
+ max_angle: 4
+ max_len: 200
+
+ box_proba: 0.3
+ box_kwargs:
+ margin: 0
+ bbox_min_size: 30
+ bbox_max_size: 150
+ max_times: 5
+ min_times: 1
+
+ segm_proba: 0
+ squares_proba: 0
\ No newline at end of file
diff --git a/training/data/configs/segm_256.yaml b/training/data/configs/segm_256.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6c617c80be626f9ba5ff488e6338d25cc99bf3da
--- /dev/null
+++ b/training/data/configs/segm_256.yaml
@@ -0,0 +1,14 @@
+kind: segmentation
+
+mask_gen_kwargs:
+ confidence_threshold: 0.5
+
+max_masks_per_image: 1
+
+cropping:
+ out_min_size: 256
+ handle_small_mode: upscale
+ out_square_crop: True
+ crop_min_overlap: 1
+
+max_tamper_area: 0.5
\ No newline at end of file
diff --git a/training/data/configs/segm_512.yaml b/training/data/configs/segm_512.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..25070d739296c71adbb8f362bc64c6be74a8cbdb
--- /dev/null
+++ b/training/data/configs/segm_512.yaml
@@ -0,0 +1,14 @@
+kind: segmentation
+
+mask_gen_kwargs:
+ confidence_threshold: 0.5
+
+max_masks_per_image: 1
+
+cropping:
+ out_min_size: 512
+ handle_small_mode: upscale
+ out_square_crop: True
+ crop_min_overlap: 1
+
+max_tamper_area: 0.5
\ No newline at end of file
diff --git a/training/data/configs/thick_256.yaml b/training/data/configs/thick_256.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e55f3412c56a3e686cef97ce555c567e1ff96243
--- /dev/null
+++ b/training/data/configs/thick_256.yaml
@@ -0,0 +1,24 @@
+kind: 'mixed'
+
+mask_gen_kwargs:
+ irregular_proba: 1
+ hole_range:
+ - 0.0
+ - 0.7
+ irregular_kwargs:
+ min_times: 1
+ max_times: 5
+ max_width: 100
+ max_angle: 4
+ max_len: 200
+
+ box_proba: 0.3
+ box_kwargs:
+ margin: 10
+ bbox_min_size: 30
+ bbox_max_size: 150
+ max_times: 3
+ min_times: 1
+
+ segm_proba: 0
+ squares_proba: 0
diff --git a/training/data/configs/thick_512.yaml b/training/data/configs/thick_512.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f04bf9853088c3be50891032cd83d11276c27423
--- /dev/null
+++ b/training/data/configs/thick_512.yaml
@@ -0,0 +1,24 @@
+kind: 'mixed'
+
+mask_gen_kwargs:
+ irregular_proba: 1
+ hole_range:
+ - 0.0
+ - 0.7
+ irregular_kwargs:
+ min_times: 1
+ max_times: 5
+ max_width: 250
+ max_angle: 4
+ max_len: 450
+
+ box_proba: 0.3
+ box_kwargs:
+ margin: 10
+ bbox_min_size: 30
+ bbox_max_size: 300
+ max_times: 4
+ min_times: 1
+
+ segm_proba: 0
+ squares_proba: 0
diff --git a/training/data/configs/thin_256.yaml b/training/data/configs/thin_256.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f4a91a3826f6e63fcfd4705ec9da66e7d96f86e5
--- /dev/null
+++ b/training/data/configs/thin_256.yaml
@@ -0,0 +1,20 @@
+kind: 'mixed'
+
+mask_gen_kwargs:
+ irregular_proba: 1
+ hole_range:
+ - 0.0
+ - 0.7
+ irregular_kwargs:
+ min_times: 4
+ max_times: 50
+ max_width: 10
+ max_angle: 4
+ max_len: 40
+
+  box_proba: 0
+  segm_proba: 0
+  squares_proba: 0
diff --git a/training/data/configs/thin_512.yaml b/training/data/configs/thin_512.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3659dc2a61dbaf2db3896b7dfa14e1814c51a92b
--- /dev/null
+++ b/training/data/configs/thin_512.yaml
@@ -0,0 +1,23 @@
+kind: 'mixed'
+
+mask_gen_kwargs:
+ irregular_proba: 1
+ hole_range:
+ - 0.0
+ - 0.7
+ irregular_kwargs:
+ min_times: 4
+ max_times: 70
+ max_width: 20
+ max_angle: 4
+ max_len: 100
+  box_proba: 0
+  segm_proba: 0
+  squares_proba: 0
\ No newline at end of file
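The YAML files above are the mask-generation presets consumed by the loaders later in this diff. A hedged sketch of how one might be turned into a generator, mirroring the call in training/data/gen_loader.py (assumes PyYAML is installed and the repo root is on PYTHONPATH):

import yaml

from training.data import lama_mask_generator_test as lama_mask_generator

with open('training/data/configs/thick_256.yaml', 'r') as f:
    cfg = yaml.safe_load(f)

mask_gen = lama_mask_generator.get_mask_generator(kind=cfg['kind'],
                                                  cfg=cfg['mask_gen_kwargs'])
mask = mask_gen(shape=(256, 256), iter_i=0)
print(mask.shape, mask.mean())  # (1, 256, 256); the mean (hole ratio) stays inside hole_range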
diff --git a/training/data/dataset.py b/training/data/dataset.py
new file mode 100755
index 0000000000000000000000000000000000000000..eb5deb6fc27b1b068ba4cb6d6e632a08b3d558f4
--- /dev/null
+++ b/training/data/dataset.py
@@ -0,0 +1,242 @@
+import os
+import numpy as np
+import PIL.Image
+import json
+import torch
+import dnnlib
+import cv2
+from icecream import ic
+from . import mask_generator
+import os.path as osp
+import matplotlib.pyplot as plt
+import matplotlib.cm as cm
+import copy
+import albumentations as A
+try:
+    import pyspng
+except ImportError:
+    pyspng = None
+
+#----------------------------------------------------------------------------
+
+class Dataset(torch.utils.data.Dataset):
+ def __init__(self,
+ name, # Name of the dataset.
+ raw_shape, # Shape of the raw image data (NCHW).
+ max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
+ use_labels = False, # Enable conditioning labels? False = label dimension is zero.
+ xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size.
+ random_seed = 0, # Random seed to use when applying max_size.
+ ):
+ self._name = name
+ self._raw_shape = list(raw_shape)
+ self._use_labels = use_labels
+ self._raw_labels = None
+ self._label_shape = None
+
+ # Apply max_size.
+ self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
+ if (max_size is not None) and (self._raw_idx.size > max_size):
+ np.random.RandomState(random_seed).shuffle(self._raw_idx)
+ self._raw_idx = np.sort(self._raw_idx[:max_size])
+
+ # Apply xflip.
+ self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
+ if xflip:
+ self._raw_idx = np.tile(self._raw_idx, 2)
+ self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])
+
+ def _get_raw_labels(self):
+ if self._raw_labels is None:
+ self._raw_labels = self._load_raw_labels() if self._use_labels else None
+ if self._raw_labels is None:
+ self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
+ assert isinstance(self._raw_labels, np.ndarray)
+ assert self._raw_labels.shape[0] == self._raw_shape[0]
+ assert self._raw_labels.dtype in [np.float32, np.int64]
+ if self._raw_labels.dtype == np.int64:
+ assert self._raw_labels.ndim == 1
+ assert np.all(self._raw_labels >= 0)
+ return self._raw_labels
+
+ def close(self): # to be overridden by subclass
+ pass
+
+ def _load_raw_image(self, raw_idx): # to be overridden by subclass
+ raise NotImplementedError
+
+ def _load_raw_labels(self): # to be overridden by subclass
+ raise NotImplementedError
+
+ def __getstate__(self):
+ return dict(self.__dict__, _raw_labels=None)
+
+ def __del__(self):
+ try:
+ self.close()
+ except:
+ pass
+
+ def __len__(self):
+ return self._raw_idx.size
+
+ def __getitem__(self, idx):
+ image = self._load_raw_image(self._raw_idx[idx])
+ assert isinstance(image, np.ndarray)
+ assert list(image.shape) == self.image_shape
+ assert image.dtype == np.uint8
+ if self._xflip[idx]:
+ assert image.ndim == 3 # CHW
+ image = image[:, :, ::-1]
+ return image.copy(), self.get_label(idx)
+
+ def get_label(self, idx):
+ label = self._get_raw_labels()[self._raw_idx[idx]]
+ if label.dtype == np.int64:
+ onehot = np.zeros(self.label_shape, dtype=np.float32)
+ onehot[label] = 1
+ label = onehot
+ return label.copy()
+
+ def get_details(self, idx):
+ d = dnnlib.EasyDict()
+ d.raw_idx = int(self._raw_idx[idx])
+ d.xflip = (int(self._xflip[idx]) != 0)
+ d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
+ return d
+
+ @property
+ def name(self):
+ return self._name
+
+ @property
+ def image_shape(self):
+ return list(self._raw_shape[1:])
+
+ @property
+ def num_channels(self):
+ assert len(self.image_shape) == 3 # CHW
+ return self.image_shape[0]
+
+ @property
+ def resolution(self):
+ assert len(self.image_shape) == 3 # CHW
+ assert self.image_shape[1] == self.image_shape[2]
+ return self.image_shape[1]
+
+ @property
+ def label_shape(self):
+ if self._label_shape is None:
+ raw_labels = self._get_raw_labels()
+ if raw_labels.dtype == np.int64:
+ self._label_shape = [int(np.max(raw_labels)) + 1]
+ else:
+ self._label_shape = raw_labels.shape[1:]
+ return list(self._label_shape)
+
+ @property
+ def label_dim(self):
+ assert len(self.label_shape) == 1
+ return self.label_shape[0]
+
+ @property
+ def has_labels(self):
+ return any(x != 0 for x in self.label_shape)
+
+ @property
+ def has_onehot_labels(self):
+ return self._get_raw_labels().dtype == np.int64
+
+#----------------------------------------------------------------------------
+
+class ImageDataset(Dataset):
+
+ def __init__(self,
+ img_path, # Path to images.
+ resolution = None, # Ensure specific resolution, None = highest available.
+ **super_kwargs, # Additional arguments for the Dataset base class.
+ ):
+ self.sz = resolution
+ self.img_path = img_path
+ self._type = 'dir'
+ self.files = []
+
+ self._all_fnames = [os.path.relpath(os.path.join(root, fname), start=self.img_path) for root, _dirs, files in os.walk(self.img_path) for fname in files]
+ PIL.Image.init()
+ self._image_fnames = sorted(os.path.join(self.img_path,fname) for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION)
+ if len(self._image_fnames) == 0:
+ raise IOError('No image files found in the specified path')
+
+ self.files = []
+
+ for f in self._image_fnames:
+            if '_mask' not in f:
+ self.files.append(f)
+
+ self.files = sorted(self.files)
+
+ self.transform = A.Compose([
+ A.PadIfNeeded(min_height=self.sz, min_width=self.sz),
+ A.OpticalDistortion(),
+ A.RandomCrop(height=self.sz, width=self.sz),
+ A.HorizontalFlip(),
+ A.CLAHE(),
+ A.ToFloat()
+ ])
+
+ name = os.path.splitext(os.path.basename(self.img_path))[0]
+ raw_shape = [len(self.files)] + list(self._load_raw_image(0).shape)
+ if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
+ raise IOError('Image files do not match the specified resolution')
+ super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)
+
+ def __len__(self):
+ return len(self.files)
+
+ def _load_image(self, fn):
+ return PIL.Image.open(fn).convert('RGB')
+
+ @staticmethod
+ def _file_ext(fname):
+ return os.path.splitext(fname)[1].lower()
+
+ def _load_raw_image(self, raw_idx):
+ fname = self.files[raw_idx]
+ image = np.array(PIL.Image.open(fname).convert('RGB'))
+ image = self.transform(image=image)['image']
+ if image.ndim == 2:
+ image = image[:, :, np.newaxis] # HW => HWC
+ image = image.transpose(2, 0, 1) # HWC => CHW
+ return image
+
+    def _load_raw_labels(self):
+        fname = 'dataset.json'
+        if fname not in self._all_fnames:
+            return None
+        # Open relative to the image directory; this class has no _open_file helper.
+        with open(os.path.join(self.img_path, fname), 'r') as f:
+ labels = json.load(f)['labels']
+ if labels is None:
+ return None
+ labels = dict(labels)
+ labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames]
+ labels = np.array(labels)
+ labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
+ return labels
+
+ def _get_image(self, idx):
+ fname = self.files[idx]
+ mask = mask_generator.generate_random_mask(s=self.sz, hole_range=[0.1,0.7])
+
+ rgb = np.array(self._load_image(fname)) # uint8
+ rgb = self.transform(image=rgb)['image']
+ rgb = np.rint(rgb * 255).clip(0, 255).astype(np.uint8)
+
+ return rgb, mask
+
+ def __getitem__(self, idx):
+ rgb, mask = self._get_image(idx) # modal, uint8 {0, 1}
+ rgb = rgb.transpose(2,0,1)
+
+ return rgb, mask, super().get_label(idx)
\ No newline at end of file
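For reference, a usage sketch of the training ImageDataset above. The folder path is a placeholder, and the repo's requirements (albumentations, icecream, dnnlib, ...) are assumed installed:

from training.data.dataset import ImageDataset

ds = ImageDataset(img_path='/path/to/train_images', resolution=256)  # placeholder path
rgb, mask, label = ds[0]
print(rgb.shape, rgb.dtype)     # (3, 256, 256) uint8, augmented and square-cropped
print(mask.shape, mask.mean())  # (1, 256, 256) float32, hole ratio in (0.1, 0.7)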
diff --git a/training/data/demo_loader.py b/training/data/demo_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a5de4b15e90e9024a1976cbd57de2df046ca2d5
--- /dev/null
+++ b/training/data/demo_loader.py
@@ -0,0 +1,109 @@
+import numpy as np
+import cv2
+import os
+import PIL
+import torch
+from .dataset import Dataset
+
+class ImageDataset(Dataset):
+
+ def __init__(self,
+ img_path, # Path to images.
+ resolution = None, # Ensure specific resolution, None = highest available.
+ **super_kwargs, # Additional arguments for the Dataset base class.
+ ):
+ self.sz = resolution
+ self.img_path = img_path
+ self._type = 'dir'
+ self.files = []
+ self.idx = 0
+
+ self._all_fnames = [os.path.relpath(os.path.join(root, fname), start=self.img_path) for root, _dirs, files in os.walk(self.img_path) for fname in files]
+ PIL.Image.init()
+ self._image_fnames = sorted(os.path.join(self.img_path,fname) for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION)
+ if len(self._image_fnames) == 0:
+ raise IOError('No image files found in the specified path')
+
+ self.files = []
+
+ for f in self._image_fnames:
+            if '_mask' not in f:
+ self.files.append(f)
+
+ self.files = sorted(self.files)
+
+ def __len__(self):
+ return len(self.files)
+
+ @staticmethod
+ def _file_ext(fname):
+ return os.path.splitext(fname)[1].lower()
+
+ def _load_image(self, fn):
+ return PIL.Image.open(fn).convert('RGB')
+
+ def _get_image(self, idx):
+ # imgfn, seg_map, img_id = self.data_reader.get_image(idx)
+
+ fname = self.files[idx]
+ ext = self._file_ext(fname)
+
+ mask = np.array(self._load_image(fname.replace(ext, f'_mask{ext}')).convert('L')) / 255
+ mask = cv2.resize(mask,
+ (self.sz, self.sz), interpolation=cv2.INTER_NEAREST)
+
+ rgb = np.array(self._load_image(fname)) # uint8
+ rgb = cv2.resize(rgb,
+ (self.sz, self.sz), interpolation=cv2.INTER_AREA)
+
+ return rgb, fname.split('/')[-1].replace(ext, ''), mask
+
+ def __getitem__(self, idx):
+ rgb, fname, mask = self._get_image(idx) # modal, uint8 {0, 1}
+ rgb = rgb.transpose(2,0,1)
+
+ mask_tensor = torch.from_numpy(mask).to(torch.float32)
+ mask_tensor = mask_tensor.unsqueeze(0)
+ rgb = torch.from_numpy(rgb.astype(np.float32))
+ rgb = (rgb.to(torch.float32) / 127.5 - 1)
+ rgb_erased = rgb.clone()
+ rgb_erased = rgb_erased * (1 - mask_tensor) # erase rgb
+ rgb_erased = rgb_erased.to(torch.float32)
+
+ return rgb, rgb_erased, mask_tensor, fname
+
+def collate_fn(data):
+    """Creates mini-batch tensors from a list of (rgb, rgb_erased, mask, fname) tuples.
+
+    A custom collate_fn is needed because the default collate cannot merge the
+    string filenames into a tensor; images are stacked and filenames are
+    returned as a plain list.
+    Args:
+        data: list of tuples
+            - rgb: torch tensor of shape (3, H, W).
+
+    Returns:
+        rgb, rgb_erased, and mask tensors of shape (batch_size, ...), plus the list of filenames.
+    """
+
+ rgbs, rgbs_erased, mask_tensors, fnames = zip(*data)
+
+ rgbs = list(rgbs)
+ rgbs_erased = list(rgbs_erased)
+ mask_tensors = list(mask_tensors)
+ fnames = list(fnames)
+
+ return torch.stack(rgbs, dim=0), torch.stack(rgbs_erased, dim=0), torch.stack(mask_tensors, dim=0), fnames
+
+def get_loader(img_path, resolution):
+    """Returns a torch.utils.data.DataLoader over the demo image folder (image + matching _mask file)."""
+
+ ds = ImageDataset(img_path=img_path, resolution=resolution)
+
+ data_loader = torch.utils.data.DataLoader(dataset=ds,
+ batch_size=1,
+ shuffle=False,
+ num_workers=1,
+ collate_fn=collate_fn)
+ return data_loader
\ No newline at end of file
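demo_loader pairs every image in a folder with a sibling file named <name>_mask<ext>. A hedged usage sketch with placeholder names:

# Expected layout (placeholder names):
#   demo_dir/
#     photo1.png         input image
#     photo1_mask.png    matching mask, white = region to inpaint
from training.data.demo_loader import get_loader

loader = get_loader(img_path='demo_dir', resolution=512)
for rgb, rgb_erased, mask, fnames in loader:
    # rgb, rgb_erased: (1, 3, 512, 512) in [-1, 1]; mask: (1, 1, 512, 512); fnames: ['photo1']
    print(rgb.shape, rgb_erased.shape, mask.shape, fnames)
    break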
diff --git a/training/data/gen_loader.py b/training/data/gen_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..18b2771ec970da4aeae217b63d8f8aee40504d25
--- /dev/null
+++ b/training/data/gen_loader.py
@@ -0,0 +1,119 @@
+import numpy as np
+import cv2
+import os
+import PIL
+import torch
+from .dataset import Dataset
+from . import mask_generator
+from . import lama_mask_generator_test as lama_mask_generator
+import os.path as osp
+
+class ImageDataset(Dataset):
+
+ def __init__(self,
+ img_path, # Path to images.
+ resolution = 256, # Ensure specific resolution, None = highest available.
+ msk_ratio = None, # Masked ratio for freeform masks
+ lama_cfg = None, # Lama masks config file
+ **super_kwargs, # Additional arguments for the Dataset base class.
+ ):
+ self.sz = resolution
+ self.img_path = img_path
+ self._type = 'dir'
+ self.files = []
+ self.idx = 0
+ self.is_comod = msk_ratio is not None
+ self.mask_ratio = msk_ratio
+
+ if not self.is_comod:
+ self.lama_mask_generator = lama_mask_generator.get_mask_generator(kind=lama_cfg['kind'], cfg=lama_cfg['mask_gen_kwargs'])
+ self.iter = 0
+
+ self._all_fnames = [os.path.relpath(os.path.join(root, fname), start=self.img_path) for root, _dirs, files in os.walk(self.img_path) for fname in files]
+ PIL.Image.init()
+ self._image_fnames = sorted(os.path.join(self.img_path,fname) for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION)
+ if len(self._image_fnames) == 0:
+ raise IOError('No image files found in the specified path')
+
+ self.files = []
+
+ for f in self._image_fnames:
+            if '_mask' not in f:
+ self.files.append(f)
+
+ self.files = sorted(self.files)
+
+ def __len__(self):
+ return len(self.files)
+
+ @staticmethod
+ def _file_ext(fname):
+ return os.path.splitext(fname)[1].lower()
+
+ def _load_image(self, fn):
+ return PIL.Image.open(fn).convert('RGB')
+
+ def _get_image(self, idx):
+
+ fname = self.files[idx]
+ ext = self._file_ext(fname)
+
+ rgb = np.array(self._load_image(fname)) # uint8
+ rgb = cv2.resize(rgb,
+ (self.sz, self.sz), interpolation=cv2.INTER_AREA)
+
+ if self.is_comod:
+ mask = mask_generator.generate_random_mask(s=self.sz, hole_range=self.mask_ratio)
+ else:
+ mask = self.lama_mask_generator(shape=(self.sz, self.sz), iter_i=self.iter)
+ self.iter += 1
+
+ return rgb, fname.split('/')[-1].replace(ext, ''), mask
+
+ def __getitem__(self, idx):
+ rgb, fname, mask = self._get_image(idx) # modal, uint8 {0, 1}
+ rgb = rgb.transpose(2,0,1)
+
+ mask_tensor = torch.from_numpy(mask).to(torch.float32)
+ rgb = torch.from_numpy(rgb.astype(np.float32))
+ rgb = (rgb.to(torch.float32) / 127.5 - 1)
+ rgb_erased = rgb.clone()
+ rgb_erased = rgb_erased * (1 - mask_tensor) # erase rgb
+ rgb_erased = rgb_erased.to(torch.float32)
+
+ return rgb, rgb_erased, mask_tensor, fname
+
+def collate_fn(data):
+    """Creates mini-batch tensors from a list of (rgb, rgb_erased, mask, fname) tuples.
+
+    A custom collate_fn is needed because the default collate cannot merge the
+    string filenames into a tensor; images are stacked and filenames are
+    returned as a plain list.
+    Args:
+        data: list of tuples
+            - rgb: torch tensor of shape (3, H, W).
+
+    Returns:
+        rgb, rgb_erased, and mask tensors of shape (batch_size, ...), plus the list of filenames.
+    """
+
+ rgbs, rgbs_erased, mask_tensors, fnames = zip(*data)
+
+ rgbs = list(rgbs)
+ rgbs_erased = list(rgbs_erased)
+ mask_tensors = list(mask_tensors)
+ fnames = list(fnames)
+
+ return torch.stack(rgbs, dim=0), torch.stack(rgbs_erased, dim=0), torch.stack(mask_tensors, dim=0), fnames
+
+def get_loader(img_path, resolution, msk_ratio, lama_cfg):
+    """Returns a torch.utils.data.DataLoader over an image folder, pairing each image with a generated mask."""
+
+ ds = ImageDataset(img_path=img_path, resolution=resolution, msk_ratio=msk_ratio, lama_cfg=lama_cfg)
+
+ data_loader = torch.utils.data.DataLoader(dataset=ds,
+ batch_size=1,
+ shuffle=False,
+ num_workers=1,
+ collate_fn=collate_fn)
+ return data_loader
\ No newline at end of file
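gen_loader chooses between two mask sources: a msk_ratio hole range selects co-mod-gan-style random masks, while a lama config dict (such as one of the YAML presets above) selects lama-style masks. A sketch under the same path assumptions as the earlier examples:

import yaml
from training.data.gen_loader import get_loader

# co-mod-gan style random masks with hole ratio sampled in [0.1, 0.5]
loader_comod = get_loader(img_path='val_images', resolution=256,
                          msk_ratio=[0.1, 0.5], lama_cfg=None)

# lama-style masks driven by one of the YAML presets above
with open('training/data/configs/thin_256.yaml', 'r') as f:
    lama_cfg = yaml.safe_load(f)
loader_lama = get_loader(img_path='val_images', resolution=256,
                         msk_ratio=None, lama_cfg=lama_cfg)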
diff --git a/training/data/lama_mask_generator_test.py b/training/data/lama_mask_generator_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc00c757ca49686514ce5c75c7f2a4420697e503
--- /dev/null
+++ b/training/data/lama_mask_generator_test.py
@@ -0,0 +1,307 @@
+import math
+import random
+import hashlib
+import logging
+from enum import Enum
+
+import cv2
+import numpy as np
+
+from utils.data_utils import LinearRamp
+from metrics.evaluation.masks.mask import SegmentationMask
+
+LOGGER = logging.getLogger(__name__)
+
+
+class DrawMethod(Enum):
+ LINE = 'line'
+ CIRCLE = 'circle'
+ SQUARE = 'square'
+
+
+def make_random_irregular_mask(shape, max_angle=4, max_len=60, max_width=20, min_times=0, max_times=10,
+ draw_method=DrawMethod.LINE):
+ draw_method = DrawMethod(draw_method)
+
+ height, width = shape
+ mask = np.zeros((height, width), np.float32)
+ times = np.random.randint(min_times, max_times + 1)
+ for i in range(times):
+ start_x = np.random.randint(width)
+ start_y = np.random.randint(height)
+ for j in range(1 + np.random.randint(5)):
+ angle = 0.01 + np.random.randint(max_angle)
+ if i % 2 == 0:
+ angle = 2 * 3.1415926 - angle
+ length = 10 + np.random.randint(max_len)
+ brush_w = 5 + np.random.randint(max_width)
+ end_x = np.clip((start_x + length * np.sin(angle)).astype(np.int32), 0, width)
+ end_y = np.clip((start_y + length * np.cos(angle)).astype(np.int32), 0, height)
+ if draw_method == DrawMethod.LINE:
+ cv2.line(mask, (start_x, start_y), (end_x, end_y), 1.0, brush_w)
+ elif draw_method == DrawMethod.CIRCLE:
+ cv2.circle(mask, (start_x, start_y), radius=brush_w, color=1., thickness=-1)
+ elif draw_method == DrawMethod.SQUARE:
+ radius = brush_w // 2
+ mask[start_y - radius:start_y + radius, start_x - radius:start_x + radius] = 1
+ start_x, start_y = end_x, end_y
+ return mask[None, ...]
+
+
+class RandomIrregularMaskGenerator:
+ def __init__(self, max_angle=4, max_len=60, max_width=20, min_times=0, max_times=10, ramp_kwargs=None,
+ draw_method=DrawMethod.LINE):
+ self.max_angle = max_angle
+ self.max_len = max_len
+ self.max_width = max_width
+ self.min_times = min_times
+ self.max_times = max_times
+ self.draw_method = draw_method
+ self.ramp = LinearRamp(**ramp_kwargs) if ramp_kwargs is not None else None
+
+ def __call__(self, shape, iter_i=None, raw_image=None):
+ coef = self.ramp(iter_i) if (self.ramp is not None) and (iter_i is not None) else 1
+ cur_max_len = int(max(1, self.max_len * coef))
+ cur_max_width = int(max(1, self.max_width * coef))
+ cur_max_times = int(self.min_times + 1 + (self.max_times - self.min_times) * coef)
+ return make_random_irregular_mask(shape, max_angle=self.max_angle, max_len=cur_max_len,
+ max_width=cur_max_width, min_times=self.min_times, max_times=cur_max_times,
+ draw_method=self.draw_method)
+
+
+def make_random_rectangle_mask(shape, margin=10, bbox_min_size=30, bbox_max_size=100, min_times=0, max_times=3):
+ height, width = shape
+ mask = np.zeros((height, width), np.float32)
+ bbox_max_size = min(bbox_max_size, height - margin * 2, width - margin * 2)
+ times = np.random.randint(min_times, max_times + 1)
+ for i in range(times):
+ box_width = np.random.randint(bbox_min_size, bbox_max_size)
+ box_height = np.random.randint(bbox_min_size, bbox_max_size)
+ start_x = np.random.randint(margin, width - margin - box_width + 1)
+ start_y = np.random.randint(margin, height - margin - box_height + 1)
+ mask[start_y:start_y + box_height, start_x:start_x + box_width] = 1
+ return mask[None, ...]
+
+
+class RandomRectangleMaskGenerator:
+ def __init__(self, margin=10, bbox_min_size=30, bbox_max_size=100, min_times=0, max_times=3, ramp_kwargs=None):
+ self.margin = margin
+ self.bbox_min_size = bbox_min_size
+ self.bbox_max_size = bbox_max_size
+ self.min_times = min_times
+ self.max_times = max_times
+ self.ramp = LinearRamp(**ramp_kwargs) if ramp_kwargs is not None else None
+
+ def __call__(self, shape, iter_i=None, raw_image=None):
+ coef = self.ramp(iter_i) if (self.ramp is not None) and (iter_i is not None) else 1
+ cur_bbox_max_size = int(self.bbox_min_size + 1 + (self.bbox_max_size - self.bbox_min_size) * coef)
+ cur_max_times = int(self.min_times + (self.max_times - self.min_times) * coef)
+ return make_random_rectangle_mask(shape, margin=self.margin, bbox_min_size=self.bbox_min_size,
+ bbox_max_size=cur_bbox_max_size, min_times=self.min_times,
+ max_times=cur_max_times)
+
+
+def make_random_superres_mask(shape, min_step=2, max_step=4, min_width=1, max_width=3):
+ height, width = shape
+ mask = np.zeros((height, width), np.float32)
+ step_x = np.random.randint(min_step, max_step + 1)
+ width_x = np.random.randint(min_width, min(step_x, max_width + 1))
+ offset_x = np.random.randint(0, step_x)
+
+ step_y = np.random.randint(min_step, max_step + 1)
+ width_y = np.random.randint(min_width, min(step_y, max_width + 1))
+ offset_y = np.random.randint(0, step_y)
+
+ for dy in range(width_y):
+ mask[offset_y + dy::step_y] = 1
+ for dx in range(width_x):
+ mask[:, offset_x + dx::step_x] = 1
+ return mask[None, ...]
+
+
+class RandomSuperresMaskGenerator:
+ def __init__(self, **kwargs):
+ self.kwargs = kwargs
+
+ def __call__(self, shape, iter_i=None):
+ return make_random_superres_mask(shape, **self.kwargs)
+
+
+class MixedMaskGenerator:
+    def __init__(self, irregular_proba=1/3, hole_range=[0, 0.7], irregular_kwargs=None,
+ box_proba=1/3, box_kwargs=None,
+ segm_proba=1/3, segm_kwargs=None,
+ squares_proba=0, squares_kwargs=None,
+ superres_proba=0, superres_kwargs=None,
+ outpainting_proba=0, outpainting_kwargs=None,
+ invert_proba=0):
+ self.probas = []
+ self.gens = []
+ self.hole_range = hole_range
+
+ if irregular_proba > 0:
+ self.probas.append(irregular_proba)
+ if irregular_kwargs is None:
+ irregular_kwargs = {}
+ else:
+ irregular_kwargs = dict(irregular_kwargs)
+ irregular_kwargs['draw_method'] = DrawMethod.LINE
+ self.gens.append(RandomIrregularMaskGenerator(**irregular_kwargs))
+
+ if box_proba > 0:
+ self.probas.append(box_proba)
+ if box_kwargs is None:
+ box_kwargs = {}
+ self.gens.append(RandomRectangleMaskGenerator(**box_kwargs))
+
+ if squares_proba > 0:
+ self.probas.append(squares_proba)
+ if squares_kwargs is None:
+ squares_kwargs = {}
+ else:
+ squares_kwargs = dict(squares_kwargs)
+ squares_kwargs['draw_method'] = DrawMethod.SQUARE
+ self.gens.append(RandomIrregularMaskGenerator(**squares_kwargs))
+
+ if superres_proba > 0:
+ self.probas.append(superres_proba)
+ if superres_kwargs is None:
+ superres_kwargs = {}
+ self.gens.append(RandomSuperresMaskGenerator(**superres_kwargs))
+
+ self.probas = np.array(self.probas, dtype='float32')
+ self.probas /= self.probas.sum()
+ self.invert_proba = invert_proba
+
+ def __call__(self, shape, iter_i=None, raw_image=None):
+ kind = np.random.choice(len(self.probas), p=self.probas)
+ gen = self.gens[kind]
+ result = gen(shape, iter_i=iter_i, raw_image=raw_image)
+ if self.invert_proba > 0 and random.random() < self.invert_proba:
+ result = 1 - result
+ if np.mean(result) <= self.hole_range[0] or np.mean(result) >= self.hole_range[1]:
+ return self.__call__(shape, iter_i=iter_i, raw_image=raw_image)
+ else:
+ return result
+
+
+class RandomSegmentationMaskGenerator:
+ def __init__(self, **kwargs):
+ self.kwargs = kwargs
+ self.impl = SegmentationMask(**self.kwargs)
+
+ def __call__(self, img, iter_i=None, raw_image=None, hole_range=[0.0, 0.3]):
+
+ masks = self.impl.get_masks(img)
+ fil_masks = []
+ for cur_mask in masks:
+            if len(np.unique(cur_mask)) < 2 or cur_mask.mean() > hole_range[1]:  # skip empty (constant) masks
+ continue
+ fil_masks.append(cur_mask)
+
+        mask_index = np.random.choice(len(fil_masks))  # scalar index; size=1 would return an array
+        mask = fil_masks[mask_index]
+
+ return mask
+
+
+class SegMaskGenerator:
+ def __init__(self, hole_range=[0.1, 0.2], segm_kwargs=None):
+ if segm_kwargs is None:
+ segm_kwargs = {}
+ self.gen = RandomSegmentationMaskGenerator(**segm_kwargs)
+ self.hole_range = hole_range
+
+ def __call__(self, img, iter_i=None, raw_image=None):
+ result = self.gen(img=img, iter_i=iter_i, raw_image=raw_image, hole_range=self.hole_range)
+ return result
+
+class FGSegmentationMaskGenerator:
+ def __init__(self, **kwargs):
+ self.kwargs = kwargs
+ self.impl = SegmentationMask(**self.kwargs)
+
+ def __call__(self, img, iter_i=None, raw_image=None, hole_range=[0.0, 0.3]):
+
+ masks = self.impl.get_masks(img)
+        mask = np.zeros_like(masks[0])  # start empty; seeding with masks[0] double-counted it
+        for cur_mask in masks:
+            if len(np.unique(cur_mask)) < 2 or cur_mask.mean() > hole_range[1]:  # skip empty masks
+                continue
+            mask += cur_mask
+
+ mask = mask > 0
+ return mask
+
+class SegBGMaskGenerator:
+ def __init__(self, hole_range=[0.1, 0.2], segm_kwargs=None):
+ if segm_kwargs is None:
+ segm_kwargs = {}
+ self.gen = FGSegmentationMaskGenerator(**segm_kwargs)
+ self.hole_range = hole_range
+ self.cfg = {
+ 'irregular_proba': 1,
+ 'hole_range': [0.0, 1.0],
+ 'irregular_kwargs': {
+ 'max_angle': 4,
+ 'max_len': 250,
+ 'max_width': 150,
+ 'max_times': 3,
+ 'min_times': 1,
+ },
+ 'box_proba': 0,
+ 'box_kwargs': {
+ 'margin': 10,
+ 'bbox_min_size': 30,
+ 'bbox_max_size': 150,
+ 'max_times': 4,
+ 'min_times': 1,
+ }
+ }
+ self.bg_mask_gen = MixedMaskGenerator(**self.cfg)
+
+ def __call__(self, img, iter_i=None, raw_image=None):
+ shape = img.shape[:2]
+ mask_fg = self.gen(img=img, iter_i=iter_i, raw_image=raw_image, hole_range=self.hole_range)
+ bg_ratio = 1 - np.mean(mask_fg)
+ result = self.bg_mask_gen(shape, iter_i=iter_i, raw_image=raw_image)
+        result = np.clip(result - mask_fg, 0, 1)  # keep the hole off the foreground
+        if np.mean(result) <= self.hole_range[0]*bg_ratio or np.mean(result) >= self.hole_range[1]*bg_ratio:
+            return self.__call__(img, iter_i=iter_i, raw_image=raw_image)  # resample; __call__ takes the image, not its shape
+ return result
+
+
+def get_mask_generator(kind, cfg=None):
+ if kind is None:
+ kind = "mixed"
+
+ if cfg is None:
+ cfg = {
+ 'irregular_proba': 1,
+ 'hole_range': [0.0, 0.7],
+ 'irregular_kwargs': {
+ 'max_angle': 4,
+ 'max_len': 200,
+ 'max_width': 100,
+ 'max_times': 5,
+ 'min_times': 1,
+ },
+ 'box_proba': 1,
+ 'box_kwargs': {
+ 'margin': 10,
+ 'bbox_min_size': 30,
+ 'bbox_max_size': 150,
+ 'max_times': 4,
+ 'min_times': 1,
+ },
+ 'segm_proba': 0,}
+
+ if kind == "mixed":
+ cl = MixedMaskGenerator
+ elif kind =="segmentation":
+ cl = SegBGMaskGenerator
+ else:
+ raise NotImplementedError(f"No such generator kind = {kind}")
+ return cl(**cfg)
\ No newline at end of file
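A standalone check of the irregular-mask primitive above, useful for eyeballing the effect of the YAML parameters (same import assumptions as the earlier sketches):

from training.data.lama_mask_generator_test import DrawMethod, make_random_irregular_mask

mask = make_random_irregular_mask((256, 256), max_angle=4, max_len=60, max_width=20,
                                  min_times=1, max_times=5, draw_method=DrawMethod.LINE)
print(mask.shape, mask.dtype, mask.mean())  # (1, 256, 256) float32; mean = masked fraction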
diff --git a/training/data/mask_generator.py b/training/data/mask_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..3dd426a7068d950d12584d008a55bec948878a78
--- /dev/null
+++ b/training/data/mask_generator.py
@@ -0,0 +1,80 @@
+import numpy as np
+from PIL import Image, ImageDraw
+import math
+
+def RandomBrush(
+ max_tries,
+ s,
+ min_num_vertex = 4,
+ max_num_vertex = 18,
+ mean_angle = 2*math.pi / 5,
+ angle_range = 2*math.pi / 15,
+ min_width = 12,
+ max_width = 48):
+ H, W = s, s
+ average_radius = math.sqrt(H*H+W*W) / 8
+ mask = Image.new('L', (W, H), 0)
+ for _ in range(np.random.randint(max_tries)):
+ num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
+ angle_min = mean_angle - np.random.uniform(0, angle_range)
+ angle_max = mean_angle + np.random.uniform(0, angle_range)
+ angles = []
+ vertex = []
+ for i in range(num_vertex):
+ if i % 2 == 0:
+ angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
+ else:
+ angles.append(np.random.uniform(angle_min, angle_max))
+
+ h, w = mask.size
+ vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
+ for i in range(num_vertex):
+ r = np.clip(
+ np.random.normal(loc=average_radius, scale=average_radius//2),
+ 0, 2*average_radius)
+ new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
+ new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
+ vertex.append((int(new_x), int(new_y)))
+
+ draw = ImageDraw.Draw(mask)
+ width = int(np.random.uniform(min_width, max_width))
+ draw.line(vertex, fill=1, width=width)
+ for v in vertex:
+ draw.ellipse((v[0] - width//2,
+ v[1] - width//2,
+ v[0] + width//2,
+ v[1] + width//2),
+ fill=1)
+        if np.random.random() > 0.5:
+            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)  # transpose() returns a new image; keep the result
+        if np.random.random() > 0.5:
+            mask = mask.transpose(Image.FLIP_TOP_BOTTOM)
+ mask = np.asarray(mask, np.uint8)
+ if np.random.random() > 0.5:
+ mask = np.flip(mask, 0)
+ if np.random.random() > 0.5:
+ mask = np.flip(mask, 1)
+ return mask
+
+def RandomMask(s, hole_range=[0,1]):
+ coef = min(hole_range[0] + hole_range[1], 1.0)
+ while True:
+ mask = np.ones((s, s), np.uint8)
+ def Fill(max_size):
+ w, h = np.random.randint(max_size), np.random.randint(max_size)
+ ww, hh = w // 2, h // 2
+ x, y = np.random.randint(-ww, s - w + ww), np.random.randint(-hh, s - h + hh)
+ mask[max(y, 0): min(y + h, s), max(x, 0): min(x + w, s)] = 0
+ def MultiFill(max_tries, max_size):
+ for _ in range(np.random.randint(max_tries)):
+ Fill(max_size)
+ MultiFill(int(10 * coef), s // 2)
+ MultiFill(int(5 * coef), s)
+ mask = 1 - np.logical_and(mask, 1 - RandomBrush(int(20 * coef), s))
+ hole_ratio = np.mean(mask)
+ if hole_range is not None and (hole_ratio <= hole_range[0] or hole_ratio >= hole_range[1]):
+ continue
+ return mask[np.newaxis, ...].astype(np.float32)
+
+def generate_random_mask(s=256, hole_range=[0.1,1]):
+ return RandomMask(s, hole_range)
\ No newline at end of file
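generate_random_mask rejection-samples until the hole ratio falls strictly inside hole_range, so a quick sanity check can assert on the mean; a minimal sketch:

from training.data.mask_generator import generate_random_mask

mask = generate_random_mask(s=256, hole_range=[0.1, 0.7])
assert mask.shape == (1, 256, 256)
assert 0.1 < mask.mean() < 0.7  # rejection sampling keeps the hole ratio strictly inside hole_range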
diff --git a/training/data/pred_loader.py b/training/data/pred_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..703ba042f8e268124186a47b65196bb51095e73d
--- /dev/null
+++ b/training/data/pred_loader.py
@@ -0,0 +1,104 @@
+import numpy as np
+import cv2
+import os
+import PIL
+import torch
+from .dataset import Dataset
+
+class ImageDataset(Dataset):
+
+ def __init__(self,
+ img_path, # Path to images.
+ resolution = None, # Ensure specific resolution, None = highest available.
+ **super_kwargs, # Additional arguments for the Dataset base class.
+ ):
+ self.sz = resolution
+ self.img_path = img_path
+ self._type = 'dir'
+ self.files = []
+ self.idx = 0
+
+ self._all_fnames = [os.path.relpath(os.path.join(root, fname), start=self.img_path) for root, _dirs, files in os.walk(self.img_path) for fname in files]
+ PIL.Image.init()
+ self._image_fnames = sorted(os.path.join(self.img_path,fname) for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION)
+ if len(self._image_fnames) == 0:
+ raise IOError('No image files found in the specified path')
+
+ self.files = []
+
+ for f in self._image_fnames:
+            if '_mask' not in f:
+ self.files.append(f)
+
+ self.files = sorted(self.files)
+
+ def __len__(self):
+ return len(self.files)
+
+ @staticmethod
+ def _file_ext(fname):
+ return os.path.splitext(fname)[1].lower()
+
+ def _load_image(self, fn):
+ return PIL.Image.open(fn).convert('RGB')
+
+ def _get_image(self, idx):
+ # imgfn, seg_map, img_id = self.data_reader.get_image(idx)
+
+ fname = self.files[idx]
+ ext = self._file_ext(fname)
+
+ mask = np.array(self._load_image(fname.replace(ext, f'_mask000{ext}')).convert('L')) / 255
+ rgb = np.array(self._load_image(fname)) # uint8
+
+ return rgb, fname.split('/')[-1].replace(ext, ''), mask
+
+ def __getitem__(self, idx):
+ rgb, fname, mask = self._get_image(idx) # modal, uint8 {0, 1}
+ rgb = rgb.transpose(2,0,1)
+
+ mask_tensor = torch.from_numpy(mask).to(torch.float32)
+ mask_tensor = mask_tensor.unsqueeze(0)
+ rgb = torch.from_numpy(rgb.astype(np.float32))
+ rgb = (rgb.to(torch.float32) / 127.5 - 1)
+ rgb_erased = rgb.clone()
+ rgb_erased = rgb_erased * (1 - mask_tensor) # erase rgb
+ rgb_erased = rgb_erased.to(torch.float32)
+
+ return rgb, rgb_erased, mask_tensor, fname
+
+def collate_fn(data):
+    """Creates mini-batch tensors from a list of (rgb, rgb_erased, mask, fname) tuples.
+
+    A custom collate_fn is needed because the default collate cannot merge the
+    string filenames into a tensor; images are stacked and filenames are
+    returned as a plain list.
+    Args:
+        data: list of tuples
+            - rgb: torch tensor of shape (3, H, W).
+
+    Returns:
+        rgb, rgb_erased, and mask tensors of shape (batch_size, ...), plus the list of filenames.
+    """
+
+ rgbs, rgbs_erased, mask_tensors, fnames = zip(*data)
+
+ rgbs = list(rgbs)
+ rgbs_erased = list(rgbs_erased)
+ mask_tensors = list(mask_tensors)
+ fnames = list(fnames)
+
+ return torch.stack(rgbs, dim=0), torch.stack(rgbs_erased, dim=0), torch.stack(mask_tensors, dim=0), fnames
+
+def get_loader(img_path, resolution):
+    """Returns a torch.utils.data.DataLoader over an image folder with precomputed _mask000 files."""
+
+ ds = ImageDataset(img_path=img_path, resolution=resolution)
+
+ data_loader = torch.utils.data.DataLoader(dataset=ds,
+ batch_size=1,
+ shuffle=False,
+ num_workers=1,
+ collate_fn=collate_fn)
+ return data_loader
\ No newline at end of file
diff --git a/training/ffc.py b/training/ffc.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6bbd3f430f1665bac3f43b933be61892a8b0ddf
--- /dev/null
+++ b/training/ffc.py
@@ -0,0 +1,380 @@
+# Fast Fourier Convolution NeurIPS 2020
+# original implementation https://github.com/pkumivision/FFC/blob/main/model_zoo/ffc.py
+# paper https://proceedings.neurips.cc/paper/2020/file/2fd5d41ec6cfab47e32164d5624269b1-Paper.pdf
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.fft as fft
+from kornia.geometry.transform import rotate
+from icecream import ic
+import PIL
+
+def save_image_grid(feats, fname, gridsize):
+    """Debug helper: tile the first gw * gh feature maps into one grayscale PNG."""
+    gw, gh = gridsize
+    idx = gw * gh
+
+    max_num = torch.max(feats[:idx]).item()
+    min_num = torch.min(feats[:idx]).item()
+    feats = (feats[:idx].cpu() - min_num) * 255 / (max_num - min_num)  # normalize to [0, 255]
+    feats = np.asarray(feats, dtype=np.float32)
+    feats = np.rint(feats).clip(0, 255).astype(np.uint8)
+
+    C, H, W = feats.shape
+
+    feats = feats.reshape(gh, gw, 1, H, W)
+    feats = feats.transpose(0, 3, 1, 4, 2)
+    feats = feats.reshape(gh * H, gw * W, 1)
+    feats = np.stack([feats] * 3, axis=2).squeeze() * 10  # replicate to 3 channels; x10 brightens faint activations
+    feats = np.rint(feats).clip(0, 255).astype(np.uint8)
+
+    feats = PIL.Image.fromarray(feats)
+    feats.save(fname + '.png')
+
+def _conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
+ return F.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
+
+class LearnableSpatialTransformWrapper(nn.Module):
+ def __init__(self, impl, pad_coef=0.5, angle_init_range=80, train_angle=True):
+ super().__init__()
+ self.impl = impl
+ self.angle = torch.rand(1) * angle_init_range
+ if train_angle:
+ self.angle = nn.Parameter(self.angle, requires_grad=True)
+ self.pad_coef = pad_coef
+
+ def forward(self, x):
+ if torch.is_tensor(x):
+ return self.inverse_transform(self.impl(self.transform(x)), x)
+ elif isinstance(x, tuple):
+ x_trans = tuple(self.transform(elem) for elem in x)
+ y_trans = self.impl(x_trans)
+ return tuple(self.inverse_transform(elem, orig_x) for elem, orig_x in zip(y_trans, x))
+ else:
+ raise ValueError(f'Unexpected input type {type(x)}')
+
+ def transform(self, x):
+ height, width = x.shape[2:]
+ pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef)
+ x_padded = F.pad(x, [pad_w, pad_w, pad_h, pad_h], mode='reflect')
+ x_padded_rotated = rotate(x_padded, angle=self.angle.to(x_padded))
+ return x_padded_rotated
+
+ def inverse_transform(self, y_padded_rotated, orig_x):
+ height, width = orig_x.shape[2:]
+ pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef)
+
+ y_padded = rotate(y_padded_rotated, angle=-self.angle.to(y_padded_rotated))
+ y_height, y_width = y_padded.shape[2:]
+ y = y_padded[:, :, pad_h : y_height - pad_h, pad_w : y_width - pad_w]
+ return y
+
+
+class SELayer(nn.Module):
+ def __init__(self, channel, reduction=16):
+ super(SELayer, self).__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ self.fc = nn.Sequential(
+ nn.Linear(channel, channel // reduction, bias=False),
+ nn.ReLU(inplace=False),
+ nn.Linear(channel // reduction, channel, bias=False),
+ nn.Sigmoid()
+ )
+
+ def forward(self, x):
+ b, c, _, _ = x.size()
+ y = self.avg_pool(x).view(b, c)
+ y = self.fc(y).view(b, c, 1, 1)
+ res = x * y.expand_as(x)
+ return res
+
+
+class FourierUnit(nn.Module):
+
+ def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear',
+ spectral_pos_encoding=False, use_se=False, se_kwargs=None, ffc3d=False, fft_norm='ortho'):
+ # bn_layer not used
+ super(FourierUnit, self).__init__()
+ self.groups = groups
+
+ self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0),
+ out_channels=out_channels * 2,
+ kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False)
+ self.relu = torch.nn.ReLU(inplace=False)
+
+ # squeeze and excitation block
+ self.use_se = use_se
+ if use_se:
+ if se_kwargs is None:
+ se_kwargs = {}
+ self.se = SELayer(self.conv_layer.in_channels, **se_kwargs)
+
+ self.spatial_scale_factor = spatial_scale_factor
+ self.spatial_scale_mode = spatial_scale_mode
+ self.spectral_pos_encoding = spectral_pos_encoding
+ self.ffc3d = ffc3d
+ self.fft_norm = fft_norm
+
+ def forward(self, x):
+ batch = x.shape[0]
+
+ if self.spatial_scale_factor is not None:
+ orig_size = x.shape[-2:]
+ x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode, align_corners=False)
+
+ r_size = x.size()
+ # (batch, c, h, w/2+1, 2)
+ fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
+ ffted = fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
+ ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
+ ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1)
+ ffted = ffted.view((batch, -1,) + ffted.size()[3:])
+
+ if self.spectral_pos_encoding:
+ height, width = ffted.shape[-2:]
+ coords_vert = torch.linspace(0, 1, height)[None, None, :, None].expand(batch, 1, height, width).to(ffted)
+ coords_hor = torch.linspace(0, 1, width)[None, None, None, :].expand(batch, 1, height, width).to(ffted)
+ ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1)
+
+ if self.use_se:
+ ffted = self.se(ffted)
+
+ ffted = self.conv_layer(ffted) # (batch, c*2, h, w/2+1)
+ ffted = self.relu(ffted)
+
+ ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
+ 0, 1, 3, 4, 2).contiguous() # (batch,c, t, h, w/2+1, 2)
+ ffted = torch.complex(ffted[..., 0], ffted[..., 1])
+
+ ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
+ output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)
+
+ if self.spatial_scale_factor is not None:
+ output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False)
+
+ return output
+
+
+class SpectralTransform(nn.Module):
+
+ def __init__(self, in_channels, out_channels, stride=1, groups=1, enable_lfu=True, **fu_kwargs):
+ # bn_layer not used
+ super(SpectralTransform, self).__init__()
+ self.enable_lfu = enable_lfu
+ if stride == 2:
+ self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
+ else:
+ self.downsample = nn.Identity()
+
+ self.stride = stride
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(in_channels, out_channels //
+ 2, kernel_size=1, groups=groups, bias=False),
+ # nn.BatchNorm2d(out_channels // 2),
+ nn.ReLU(inplace=True)
+ )
+ self.fu = FourierUnit(
+ out_channels // 2, out_channels // 2, groups, **fu_kwargs)
+ if self.enable_lfu:
+ self.lfu = FourierUnit(
+ out_channels // 2, out_channels // 2, groups)
+ self.conv2 = torch.nn.Conv2d(
+ out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False)
+
+ def forward(self, x):
+
+ x = self.downsample(x)
+ x = self.conv1(x)
+ output = self.fu(x)
+
+ if self.enable_lfu:
+ n, c, h, w = x.shape
+ split_no = 2
+ split_s = h // split_no
+ xs = torch.cat(torch.split(
+ x[:, :c // 4], split_s, dim=-2), dim=1).contiguous()
+ xs = torch.cat(torch.split(xs, split_s, dim=-1),
+ dim=1).contiguous()
+ xs = self.lfu(xs)
+ xs = xs.repeat(1, 1, split_no, split_no).contiguous()
+ else:
+ xs = 0
+
+ output = self.conv2(x + output + xs)
+
+ return output
+
+class FFC(nn.Module):
+
+ def __init__(self, in_channels, out_channels, kernel_size,
+ ratio_gin, ratio_gout, stride=1, padding=0,
+ dilation=1, groups=1, bias=False, enable_lfu=True,
+ padding_type='reflect', gated=False, **spectral_kwargs):
+ super(FFC, self).__init__()
+
+ assert stride == 1 or stride == 2, "Stride should be 1 or 2."
+ self.stride = stride
+
+ in_cg = int(in_channels * ratio_gin)
+ in_cl = in_channels - in_cg
+ out_cg = int(out_channels * ratio_gout)
+ out_cl = out_channels - out_cg
+ #groups_g = 1 if groups == 1 else int(groups * ratio_gout)
+ #groups_l = 1 if groups == 1 else groups - groups_g
+
+ self.ratio_gin = ratio_gin
+ self.ratio_gout = ratio_gout
+ self.global_in_num = in_cg
+
+ module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d
+ self.convl2l = module(in_cl, out_cl, kernel_size,
+ stride, padding, dilation, groups, bias, padding_mode=padding_type)
+ module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d
+ self.convl2g = module(in_cl, out_cg, kernel_size,
+ stride, padding, dilation, groups, bias, padding_mode=padding_type)
+ module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d
+ self.convg2l = module(in_cg, out_cl, kernel_size,
+ stride, padding, dilation, groups, bias, padding_mode=padding_type)
+ module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
+ self.convg2g = module(
+ in_cg, out_cg, stride, 1 if groups == 1 else groups // 2, enable_lfu, **spectral_kwargs)
+
+ self.gated = gated
+ module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d
+ self.gate = module(in_channels, 2, 1)
+
+ def forward(self, x, fname=None):
+ x_l, x_g = x if type(x) is tuple else (x, 0)
+ out_xl, out_xg = 0, 0
+
+ if self.gated:
+ total_input_parts = [x_l]
+ if torch.is_tensor(x_g):
+ total_input_parts.append(x_g)
+ total_input = torch.cat(total_input_parts, dim=1)
+
+ gates = torch.sigmoid(self.gate(total_input))
+ g2l_gate, l2g_gate = gates.chunk(2, dim=1)
+ else:
+ g2l_gate, l2g_gate = 1, 1
+
+
+ spec_x = self.convg2g(x_g)
+
+ if self.ratio_gout != 1:
+ out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
+ if self.ratio_gout != 0:
+ out_xg = self.convl2g(x_l) * l2g_gate + spec_x
+
+ return out_xl, out_xg
+
+class FFC_BN_ACT(nn.Module):
+
+ def __init__(self, in_channels, out_channels,
+ kernel_size, ratio_gin, ratio_gout,
+ stride=1, padding=0, dilation=1, groups=1, bias=False,
+ norm_layer=nn.SyncBatchNorm, activation_layer=nn.Identity,
+ padding_type='reflect',
+ enable_lfu=True, **kwargs):
+ super(FFC_BN_ACT, self).__init__()
+ self.ffc = FFC(in_channels, out_channels, kernel_size,
+ ratio_gin, ratio_gout, stride, padding, dilation,
+ groups, bias, enable_lfu, padding_type=padding_type, **kwargs)
+ lnorm = nn.Identity if ratio_gout == 1 else norm_layer
+ gnorm = nn.Identity if ratio_gout == 0 else norm_layer
+ global_channels = int(out_channels * ratio_gout)
+ # self.bn_l = lnorm(out_channels - global_channels)
+ # self.bn_g = gnorm(global_channels)
+
+ lact = nn.Identity if ratio_gout == 1 else activation_layer
+ gact = nn.Identity if ratio_gout == 0 else activation_layer
+ self.act_l = lact(inplace=True)
+ self.act_g = gact(inplace=True)
+
+ def forward(self, x, fname=None):
+        x_l, x_g = self.ffc(x, fname=fname)
+ x_l = self.act_l(x_l)
+ x_g = self.act_g(x_g)
+ return x_l, x_g
+
+
+class FFCResnetBlock(nn.Module):
+ def __init__(self, dim, padding_type, norm_layer, activation_layer=nn.ReLU, dilation=1,
+ spatial_transform_kwargs=None, inline=False, ratio_gin=0.75, ratio_gout=0.75):
+ super().__init__()
+ self.conv1 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation,
+ norm_layer=norm_layer,
+ activation_layer=activation_layer,
+ padding_type=padding_type,
+ ratio_gin=ratio_gin, ratio_gout=ratio_gout)
+ self.conv2 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation,
+ norm_layer=norm_layer,
+ activation_layer=activation_layer,
+ padding_type=padding_type,
+ ratio_gin=ratio_gin, ratio_gout=ratio_gout)
+ if spatial_transform_kwargs is not None:
+ self.conv1 = LearnableSpatialTransformWrapper(self.conv1, **spatial_transform_kwargs)
+ self.conv2 = LearnableSpatialTransformWrapper(self.conv2, **spatial_transform_kwargs)
+ self.inline = inline
+
+ def forward(self, x, fname=None):
+ if self.inline:
+ x_l, x_g = x[:, :-self.conv1.ffc.global_in_num], x[:, -self.conv1.ffc.global_in_num:]
+ else:
+ x_l, x_g = x if type(x) is tuple else (x, 0)
+
+ id_l, id_g = x_l, x_g
+
+ x_l, x_g = self.conv1((x_l, x_g), fname=fname)
+ x_l, x_g = self.conv2((x_l, x_g), fname=fname)
+
+ x_l, x_g = id_l + x_l, id_g + x_g
+ out = x_l, x_g
+ if self.inline:
+ out = torch.cat(out, dim=1)
+ return out
+
+class ConcatTupleLayer(nn.Module):
+ def forward(self, x):
+ assert isinstance(x, tuple)
+ x_l, x_g = x
+ assert torch.is_tensor(x_l) or torch.is_tensor(x_g)
+ if not torch.is_tensor(x_g):
+ return x_l
+ return torch.cat(x, dim=1)
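+
+# Minimal usage sketch (illustrative only, not part of the training pipeline):
+#   block = FFCResnetBlock(dim=64, padding_type='reflect', norm_layer=nn.BatchNorm2d,
+#                          inline=True)
+#   y = block(torch.randn(1, 64, 32, 32))   # residual output keeps the (1, 64, 32, 32) shape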
diff --git a/training/losses/ade20k/__init__.py b/training/losses/ade20k/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..773cfc4664eef45a4f6fe05bd3fe2aa2143fdb5c
--- /dev/null
+++ b/training/losses/ade20k/__init__.py
@@ -0,0 +1 @@
+from .base import *
\ No newline at end of file
diff --git a/training/losses/ade20k/base.py b/training/losses/ade20k/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cdbe2d3e7dbadf4ed5e5a7cf2d248761ef25d9c
--- /dev/null
+++ b/training/losses/ade20k/base.py
@@ -0,0 +1,627 @@
+"""Modified from https://github.com/CSAILVision/semantic-segmentation-pytorch"""
+
+import os
+
+import pandas as pd
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from scipy.io import loadmat
+from torch.nn.modules import BatchNorm2d
+
+from . import resnet
+from . import mobilenet
+
+
+NUM_CLASS = 150
+base_path = os.path.dirname(os.path.abspath(__file__)) # current file path
+colors_path = os.path.join(base_path, 'color150.mat')
+classes_path = os.path.join(base_path, 'object150_info.csv')
+
+segm_options = dict(colors=loadmat(colors_path)['colors'],
+ classes=pd.read_csv(classes_path),)
+
+
+class NormalizeTensor:
+ def __init__(self, mean, std, inplace=False):
+ """Normalize a tensor image with mean and standard deviation.
+ .. note::
+ This transform acts out of place by default, i.e., it does not mutates the input tensor.
+ See :class:`~torchvision.transforms.Normalize` for more details.
+ Args:
+ tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
+ mean (sequence): Sequence of means for each channel.
+ std (sequence): Sequence of standard deviations for each channel.
+ inplace(bool,optional): Bool to make this operation inplace.
+ Returns:
+ Tensor: Normalized Tensor image.
+ """
+
+ self.mean = mean
+ self.std = std
+ self.inplace = inplace
+
+ def __call__(self, tensor):
+ if not self.inplace:
+ tensor = tensor.clone()
+
+ dtype = tensor.dtype
+ mean = torch.as_tensor(self.mean, dtype=dtype, device=tensor.device)
+ std = torch.as_tensor(self.std, dtype=dtype, device=tensor.device)
+ tensor.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])
+ return tensor
+
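+# Example (values match the ImageNet statistics used by SegmentationModule below):
+#   norm = NormalizeTensor(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+#   out = norm(torch.rand(1, 3, 224, 224))   # expects a BCHW tensor in [0, 1]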
+
+# Model Builder
+class ModelBuilder:
+ # custom weights initialization
+ @staticmethod
+ def weights_init(m):
+ classname = m.__class__.__name__
+ if classname.find('Conv') != -1:
+ nn.init.kaiming_normal_(m.weight.data)
+ elif classname.find('BatchNorm') != -1:
+ m.weight.data.fill_(1.)
+ m.bias.data.fill_(1e-4)
+
+ @staticmethod
+ def build_encoder(arch='resnet50dilated', fc_dim=512, weights=''):
+        pretrained = len(weights) == 0
+ arch = arch.lower()
+ if arch == 'mobilenetv2dilated':
+ orig_mobilenet = mobilenet.__dict__['mobilenetv2'](pretrained=pretrained)
+ net_encoder = MobileNetV2Dilated(orig_mobilenet, dilate_scale=8)
+ elif arch == 'resnet18':
+ orig_resnet = resnet.__dict__['resnet18'](pretrained=pretrained)
+ net_encoder = Resnet(orig_resnet)
+ elif arch == 'resnet18dilated':
+ orig_resnet = resnet.__dict__['resnet18'](pretrained=pretrained)
+ net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)
+ elif arch == 'resnet50dilated':
+ orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained)
+ net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)
+ elif arch == 'resnet50':
+ orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained)
+ net_encoder = Resnet(orig_resnet)
+ else:
+ raise Exception('Architecture undefined!')
+
+ # encoders are usually pretrained
+ # net_encoder.apply(ModelBuilder.weights_init)
+ if len(weights) > 0:
+ print('Loading weights for net_encoder')
+ net_encoder.load_state_dict(
+ torch.load(weights, map_location=lambda storage, loc: storage), strict=False)
+ return net_encoder
+
+ @staticmethod
+ def build_decoder(arch='ppm_deepsup',
+ fc_dim=512, num_class=NUM_CLASS,
+ weights='', use_softmax=False, drop_last_conv=False):
+ arch = arch.lower()
+ if arch == 'ppm_deepsup':
+ net_decoder = PPMDeepsup(
+ num_class=num_class,
+ fc_dim=fc_dim,
+ use_softmax=use_softmax,
+ drop_last_conv=drop_last_conv)
+ elif arch == 'c1_deepsup':
+ net_decoder = C1DeepSup(
+ num_class=num_class,
+ fc_dim=fc_dim,
+ use_softmax=use_softmax,
+ drop_last_conv=drop_last_conv)
+ else:
+ raise Exception('Architecture undefined!')
+
+ net_decoder.apply(ModelBuilder.weights_init)
+ if len(weights) > 0:
+ print('Loading weights for net_decoder')
+ net_decoder.load_state_dict(
+ torch.load(weights, map_location=lambda storage, loc: storage), strict=False)
+ return net_decoder
+
+ @staticmethod
+    def get_decoder(weights_path, arch_encoder, arch_decoder, fc_dim, drop_last_conv, *args, **kwargs):
+ path = os.path.join(weights_path, 'ade20k', f'ade20k-{arch_encoder}-{arch_decoder}/decoder_epoch_20.pth')
+ return ModelBuilder.build_decoder(arch=arch_decoder, fc_dim=fc_dim, weights=path, use_softmax=True, drop_last_conv=drop_last_conv)
+
+ @staticmethod
+    def get_encoder(weights_path, arch_encoder, arch_decoder, fc_dim, segmentation,
+                    *args, **kwargs):
+ if segmentation:
+ path = os.path.join(weights_path, 'ade20k', f'ade20k-{arch_encoder}-{arch_decoder}/encoder_epoch_20.pth')
+ else:
+ path = ''
+ return ModelBuilder.build_encoder(arch=arch_encoder, fc_dim=fc_dim, weights=path)
+
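+# Minimal usage sketch (assumes the pretrained ade20k checkpoints live under
+# weights_path, e.g. weights/ade20k/ade20k-resnet50dilated-ppm_deepsup/):
+#   enc = ModelBuilder.get_encoder('weights', 'resnet50dilated', 'ppm_deepsup',
+#                                  fc_dim=2048, segmentation=True)
+#   dec = ModelBuilder.get_decoder('weights', 'resnet50dilated', 'ppm_deepsup',
+#                                  fc_dim=2048, drop_last_conv=False)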
+
+def conv3x3_bn_relu(in_planes, out_planes, stride=1):
+ return nn.Sequential(
+ nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False),
+ BatchNorm2d(out_planes),
+ nn.ReLU(inplace=True),
+ )
+
+
+class SegmentationModule(nn.Module):
+ def __init__(self,
+ weights_path,
+ num_classes=150,
+ arch_encoder="resnet50dilated",
+ drop_last_conv=False,
+ net_enc=None, # None for Default encoder
+ net_dec=None, # None for Default decoder
+ encode=None, # {None, 'binary', 'color', 'sky'}
+ use_default_normalization=False,
+ return_feature_maps=False,
+ return_feature_maps_level=3, # {0, 1, 2, 3}
+ return_feature_maps_only=True,
+ **kwargs,
+ ):
+ super().__init__()
+ self.weights_path = weights_path
+ self.drop_last_conv = drop_last_conv
+ self.arch_encoder = arch_encoder
+ if self.arch_encoder == "resnet50dilated":
+ self.arch_decoder = "ppm_deepsup"
+ self.fc_dim = 2048
+ elif self.arch_encoder == "mobilenetv2dilated":
+ self.arch_decoder = "c1_deepsup"
+ self.fc_dim = 320
+ else:
+ raise NotImplementedError(f"No such arch_encoder={self.arch_encoder}")
+ model_builder_kwargs = dict(arch_encoder=self.arch_encoder,
+ arch_decoder=self.arch_decoder,
+ fc_dim=self.fc_dim,
+ drop_last_conv=drop_last_conv,
+ weights_path=self.weights_path)
+
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ self.encoder = ModelBuilder.get_encoder(**model_builder_kwargs) if net_enc is None else net_enc
+ self.decoder = ModelBuilder.get_decoder(**model_builder_kwargs) if net_dec is None else net_dec
+ self.use_default_normalization = use_default_normalization
+ self.default_normalization = NormalizeTensor(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])
+
+ self.encode = encode
+
+ self.return_feature_maps = return_feature_maps
+
+ assert 0 <= return_feature_maps_level <= 3
+ self.return_feature_maps_level = return_feature_maps_level
+
+ def normalize_input(self, tensor):
+ if tensor.min() < 0 or tensor.max() > 1:
+ raise ValueError("Tensor should be 0..1 before using normalize_input")
+ return self.default_normalization(tensor)
+
+ @property
+ def feature_maps_channels(self):
+ return 256 * 2**(self.return_feature_maps_level) # 256, 512, 1024, 2048
+
+ def forward(self, img_data, segSize=None):
+        if segSize is None:
+            raise NotImplementedError("Please pass the segSize parameter, e.g. (300, 300)")
+
+ fmaps = self.encoder(img_data, return_feature_maps=True)
+ pred = self.decoder(fmaps, segSize=segSize)
+
+ if self.return_feature_maps:
+ return pred, fmaps
+ # print("BINARY", img_data.shape, pred.shape)
+ return pred
+
+ def multi_mask_from_multiclass(self, pred, classes):
+ def isin(ar1, ar2):
+ return (ar1[..., None] == ar2).any(-1).float()
+ return isin(pred, torch.LongTensor(classes).to(self.device))
+
+ @staticmethod
+ def multi_mask_from_multiclass_probs(scores, classes):
+ res = None
+ for c in classes:
+ if res is None:
+ res = scores[:, c]
+ else:
+ res += scores[:, c]
+ return res
+
+ def predict(self, tensor, imgSizes=(-1,), # (300, 375, 450, 525, 600)
+ segSize=None):
+ """Entry-point for segmentation. Use this methods instead of forward
+ Arguments:
+ tensor {torch.Tensor} -- BCHW
+ Keyword Arguments:
+ imgSizes {tuple or list} -- imgSizes for segmentation input.
+ default: (300, 450)
+ original implementation: (300, 375, 450, 525, 600)
+
+ """
+        if segSize is None:
+            segSize = (tensor.shape[2], tensor.shape[3])
+ with torch.no_grad():
+ if self.use_default_normalization:
+ tensor = self.normalize_input(tensor)
+ scores = torch.zeros(1, NUM_CLASS, segSize[0], segSize[1]).to(self.device)
+ features = torch.zeros(1, self.feature_maps_channels, segSize[0], segSize[1]).to(self.device)
+
+ result = []
+ for img_size in imgSizes:
+ if img_size != -1:
+ img_data = F.interpolate(tensor.clone(), size=img_size)
+ else:
+ img_data = tensor.clone()
+
+ if self.return_feature_maps:
+ pred_current, fmaps = self.forward(img_data, segSize=segSize)
+ else:
+ pred_current = self.forward(img_data, segSize=segSize)
+
+ result.append(pred_current)
+ scores = scores + pred_current / len(imgSizes)
+
+                # Note: only fmaps[self.return_feature_maps_level] is used and aggregated
+ if self.return_feature_maps:
+ features = features + F.interpolate(fmaps[self.return_feature_maps_level], size=segSize) / len(imgSizes)
+
+ _, pred = torch.max(scores, dim=1)
+
+ if self.return_feature_maps:
+ return features
+
+ return pred, result
+
+ def get_edges(self, t):
+        edge = torch.zeros(t.size(), dtype=torch.uint8, device=t.device)
+ edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1])
+ edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1])
+ edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
+ edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
+
+        return edge.half()
+
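+# Minimal usage sketch (assumes the pretrained ade20k checkpoints are available
+# under weights_path; illustrative only):
+#   seg = SegmentationModule(weights_path='weights')
+#   seg = seg.to(seg.device).eval()
+#   pred, per_scale = seg.predict(torch.rand(1, 3, 300, 300).to(seg.device))
+#   # pred: (1, 300, 300) map of class indices in [0, 150)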
+
+# pyramid pooling, deep supervision
+class PPMDeepsup(nn.Module):
+ def __init__(self, num_class=NUM_CLASS, fc_dim=4096,
+ use_softmax=False, pool_scales=(1, 2, 3, 6),
+ drop_last_conv=False):
+ super().__init__()
+ self.use_softmax = use_softmax
+ self.drop_last_conv = drop_last_conv
+
+ self.ppm = []
+ for scale in pool_scales:
+ self.ppm.append(nn.Sequential(
+ nn.AdaptiveAvgPool2d(scale),
+ nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
+ BatchNorm2d(512),
+ nn.ReLU(inplace=True)
+ ))
+ self.ppm = nn.ModuleList(self.ppm)
+ self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1)
+
+ self.conv_last = nn.Sequential(
+ nn.Conv2d(fc_dim + len(pool_scales) * 512, 512,
+ kernel_size=3, padding=1, bias=False),
+ BatchNorm2d(512),
+ nn.ReLU(inplace=True),
+ nn.Dropout2d(0.1),
+ nn.Conv2d(512, num_class, kernel_size=1)
+ )
+ self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
+ self.dropout_deepsup = nn.Dropout2d(0.1)
+
+ def forward(self, conv_out, segSize=None):
+ conv5 = conv_out[-1]
+
+ input_size = conv5.size()
+ ppm_out = [conv5]
+ for pool_scale in self.ppm:
+ ppm_out.append(nn.functional.interpolate(
+ pool_scale(conv5),
+ (input_size[2], input_size[3]),
+ mode='bilinear', align_corners=False))
+ ppm_out = torch.cat(ppm_out, 1)
+
+ if self.drop_last_conv:
+ return ppm_out
+ else:
+ x = self.conv_last(ppm_out)
+
+ if self.use_softmax: # is True during inference
+ x = nn.functional.interpolate(
+ x, size=segSize, mode='bilinear', align_corners=False)
+ x = nn.functional.softmax(x, dim=1)
+ return x
+
+ # deep sup
+ conv4 = conv_out[-2]
+ _ = self.cbr_deepsup(conv4)
+ _ = self.dropout_deepsup(_)
+ _ = self.conv_last_deepsup(_)
+
+ x = nn.functional.log_softmax(x, dim=1)
+ _ = nn.functional.log_softmax(_, dim=1)
+
+ return (x, _)
+
+
+class Resnet(nn.Module):
+ def __init__(self, orig_resnet):
+ super(Resnet, self).__init__()
+
+ # take pretrained resnet, except AvgPool and FC
+ self.conv1 = orig_resnet.conv1
+ self.bn1 = orig_resnet.bn1
+ self.relu1 = orig_resnet.relu1
+ self.conv2 = orig_resnet.conv2
+ self.bn2 = orig_resnet.bn2
+ self.relu2 = orig_resnet.relu2
+ self.conv3 = orig_resnet.conv3
+ self.bn3 = orig_resnet.bn3
+ self.relu3 = orig_resnet.relu3
+ self.maxpool = orig_resnet.maxpool
+ self.layer1 = orig_resnet.layer1
+ self.layer2 = orig_resnet.layer2
+ self.layer3 = orig_resnet.layer3
+ self.layer4 = orig_resnet.layer4
+
+ def forward(self, x, return_feature_maps=False):
+ conv_out = []
+
+ x = self.relu1(self.bn1(self.conv1(x)))
+ x = self.relu2(self.bn2(self.conv2(x)))
+ x = self.relu3(self.bn3(self.conv3(x)))
+ x = self.maxpool(x)
+
+        x = self.layer1(x)
+        conv_out.append(x)
+        x = self.layer2(x)
+        conv_out.append(x)
+        x = self.layer3(x)
+        conv_out.append(x)
+        x = self.layer4(x)
+        conv_out.append(x)
+
+ if return_feature_maps:
+ return conv_out
+ return [x]
+
+# Resnet Dilated
+class ResnetDilated(nn.Module):
+ def __init__(self, orig_resnet, dilate_scale=8):
+ super().__init__()
+ from functools import partial
+
+ if dilate_scale == 8:
+ orig_resnet.layer3.apply(
+ partial(self._nostride_dilate, dilate=2))
+ orig_resnet.layer4.apply(
+ partial(self._nostride_dilate, dilate=4))
+ elif dilate_scale == 16:
+ orig_resnet.layer4.apply(
+ partial(self._nostride_dilate, dilate=2))
+
+ # take pretrained resnet, except AvgPool and FC
+ self.conv1 = orig_resnet.conv1
+ self.bn1 = orig_resnet.bn1
+ self.relu1 = orig_resnet.relu1
+ self.conv2 = orig_resnet.conv2
+ self.bn2 = orig_resnet.bn2
+ self.relu2 = orig_resnet.relu2
+ self.conv3 = orig_resnet.conv3
+ self.bn3 = orig_resnet.bn3
+ self.relu3 = orig_resnet.relu3
+ self.maxpool = orig_resnet.maxpool
+ self.layer1 = orig_resnet.layer1
+ self.layer2 = orig_resnet.layer2
+ self.layer3 = orig_resnet.layer3
+ self.layer4 = orig_resnet.layer4
+
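+    # Replaces stride-2 convolutions with stride-1 dilated ones so the encoder
+    # keeps a higher-resolution output. E.g. with dilate=2, the strided 3x3 conv
+    # becomes stride 1 with dilation=padding=1, while the remaining 3x3 convs get
+    # dilation=padding=2.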
+ def _nostride_dilate(self, m, dilate):
+ classname = m.__class__.__name__
+ if classname.find('Conv') != -1:
+ # the convolution with stride
+ if m.stride == (2, 2):
+ m.stride = (1, 1)
+ if m.kernel_size == (3, 3):
+ m.dilation = (dilate // 2, dilate // 2)
+ m.padding = (dilate // 2, dilate // 2)
+            # other convolutions
+ else:
+ if m.kernel_size == (3, 3):
+ m.dilation = (dilate, dilate)
+ m.padding = (dilate, dilate)
+
+ def forward(self, x, return_feature_maps=False):
+ conv_out = []
+
+ x = self.relu1(self.bn1(self.conv1(x)))
+ x = self.relu2(self.bn2(self.conv2(x)))
+ x = self.relu3(self.bn3(self.conv3(x)))
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ conv_out.append(x)
+ x = self.layer2(x)
+ conv_out.append(x)
+ x = self.layer3(x)
+ conv_out.append(x)
+ x = self.layer4(x)
+ conv_out.append(x)
+
+ if return_feature_maps:
+ return conv_out
+ return [x]
+
+class MobileNetV2Dilated(nn.Module):
+ def __init__(self, orig_net, dilate_scale=8):
+ super(MobileNetV2Dilated, self).__init__()
+ from functools import partial
+
+ # take pretrained mobilenet features
+ self.features = orig_net.features[:-1]
+
+ self.total_idx = len(self.features)
+ self.down_idx = [2, 4, 7, 14]
+
+ if dilate_scale == 8:
+ for i in range(self.down_idx[-2], self.down_idx[-1]):
+ self.features[i].apply(
+ partial(self._nostride_dilate, dilate=2)
+ )
+ for i in range(self.down_idx[-1], self.total_idx):
+ self.features[i].apply(
+ partial(self._nostride_dilate, dilate=4)
+ )
+ elif dilate_scale == 16:
+ for i in range(self.down_idx[-1], self.total_idx):
+ self.features[i].apply(
+ partial(self._nostride_dilate, dilate=2)
+ )
+
+ def _nostride_dilate(self, m, dilate):
+ classname = m.__class__.__name__
+ if classname.find('Conv') != -1:
+ # the convolution with stride
+ if m.stride == (2, 2):
+ m.stride = (1, 1)
+ if m.kernel_size == (3, 3):
+ m.dilation = (dilate//2, dilate//2)
+ m.padding = (dilate//2, dilate//2)
+            # other convolutions
+ else:
+ if m.kernel_size == (3, 3):
+ m.dilation = (dilate, dilate)
+ m.padding = (dilate, dilate)
+
+ def forward(self, x, return_feature_maps=False):
+ if return_feature_maps:
+ conv_out = []
+ for i in range(self.total_idx):
+ x = self.features[i](x)
+ if i in self.down_idx:
+ conv_out.append(x)
+ conv_out.append(x)
+ return conv_out
+
+ else:
+ return [self.features(x)]
+
+
+# last conv, deep supervision
+class C1DeepSup(nn.Module):
+ def __init__(self, num_class=150, fc_dim=2048, use_softmax=False, drop_last_conv=False):
+ super(C1DeepSup, self).__init__()
+ self.use_softmax = use_softmax
+ self.drop_last_conv = drop_last_conv
+
+ self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1)
+ self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1)
+
+ # last conv
+ self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
+ self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
+
+ def forward(self, conv_out, segSize=None):
+ conv5 = conv_out[-1]
+
+ x = self.cbr(conv5)
+
+ if self.drop_last_conv:
+ return x
+ else:
+ x = self.conv_last(x)
+
+ if self.use_softmax: # is True during inference
+ x = nn.functional.interpolate(
+ x, size=segSize, mode='bilinear', align_corners=False)
+ x = nn.functional.softmax(x, dim=1)
+ return x
+
+ # deep sup
+ conv4 = conv_out[-2]
+ _ = self.cbr_deepsup(conv4)
+ _ = self.conv_last_deepsup(_)
+
+ x = nn.functional.log_softmax(x, dim=1)
+ _ = nn.functional.log_softmax(_, dim=1)
+
+ return (x, _)
+
+
+# last conv
+class C1(nn.Module):
+ def __init__(self, num_class=150, fc_dim=2048, use_softmax=False):
+ super(C1, self).__init__()
+ self.use_softmax = use_softmax
+
+ self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1)
+
+ # last conv
+ self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
+
+ def forward(self, conv_out, segSize=None):
+ conv5 = conv_out[-1]
+ x = self.cbr(conv5)
+ x = self.conv_last(x)
+
+ if self.use_softmax: # is True during inference
+ x = nn.functional.interpolate(
+ x, size=segSize, mode='bilinear', align_corners=False)
+ x = nn.functional.softmax(x, dim=1)
+ else:
+ x = nn.functional.log_softmax(x, dim=1)
+
+ return x
+
+
+# pyramid pooling
+class PPM(nn.Module):
+ def __init__(self, num_class=150, fc_dim=4096,
+ use_softmax=False, pool_scales=(1, 2, 3, 6)):
+ super(PPM, self).__init__()
+ self.use_softmax = use_softmax
+
+ self.ppm = []
+ for scale in pool_scales:
+ self.ppm.append(nn.Sequential(
+ nn.AdaptiveAvgPool2d(scale),
+ nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
+ BatchNorm2d(512),
+ nn.ReLU(inplace=True)
+ ))
+ self.ppm = nn.ModuleList(self.ppm)
+
+ self.conv_last = nn.Sequential(
+ nn.Conv2d(fc_dim+len(pool_scales)*512, 512,
+ kernel_size=3, padding=1, bias=False),
+ BatchNorm2d(512),
+ nn.ReLU(inplace=True),
+ nn.Dropout2d(0.1),
+ nn.Conv2d(512, num_class, kernel_size=1)
+ )
+
+ def forward(self, conv_out, segSize=None):
+ conv5 = conv_out[-1]
+
+ input_size = conv5.size()
+ ppm_out = [conv5]
+ for pool_scale in self.ppm:
+ ppm_out.append(nn.functional.interpolate(
+ pool_scale(conv5),
+ (input_size[2], input_size[3]),
+ mode='bilinear', align_corners=False))
+ ppm_out = torch.cat(ppm_out, 1)
+
+ x = self.conv_last(ppm_out)
+
+ if self.use_softmax: # is True during inference
+ x = nn.functional.interpolate(
+ x, size=segSize, mode='bilinear', align_corners=False)
+ x = nn.functional.softmax(x, dim=1)
+ else:
+ x = nn.functional.log_softmax(x, dim=1)
+ return x
diff --git a/training/losses/ade20k/color150.mat b/training/losses/ade20k/color150.mat
new file mode 100644
index 0000000000000000000000000000000000000000..c518b64fbbe899d4a8b2705f012eeba795339892
Binary files /dev/null and b/training/losses/ade20k/color150.mat differ
diff --git a/training/losses/ade20k/mobilenet.py b/training/losses/ade20k/mobilenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..f501266e56ee71cdf455744020f8fc1a58ec9fff
--- /dev/null
+++ b/training/losses/ade20k/mobilenet.py
@@ -0,0 +1,154 @@
+"""
+This MobileNetV2 implementation is modified from the following repository:
+https://github.com/tonylins/pytorch-mobilenet-v2
+"""
+
+import torch.nn as nn
+import math
+from .utils import load_url
+from .segm_lib.nn import SynchronizedBatchNorm2d
+
+BatchNorm2d = SynchronizedBatchNorm2d
+
+
+__all__ = ['mobilenetv2']
+
+
+model_urls = {
+ 'mobilenetv2': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/mobilenet_v2.pth.tar',
+}
+
+
+def conv_bn(inp, oup, stride):
+ return nn.Sequential(
+ nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
+ BatchNorm2d(oup),
+ nn.ReLU6(inplace=True)
+ )
+
+
+def conv_1x1_bn(inp, oup):
+ return nn.Sequential(
+ nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
+ BatchNorm2d(oup),
+ nn.ReLU6(inplace=True)
+ )
+
+
+class InvertedResidual(nn.Module):
+ def __init__(self, inp, oup, stride, expand_ratio):
+ super(InvertedResidual, self).__init__()
+ self.stride = stride
+ assert stride in [1, 2]
+
+ hidden_dim = round(inp * expand_ratio)
+ self.use_res_connect = self.stride == 1 and inp == oup
+
+ if expand_ratio == 1:
+ self.conv = nn.Sequential(
+ # dw
+ nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
+ BatchNorm2d(hidden_dim),
+ nn.ReLU6(inplace=True),
+ # pw-linear
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+ BatchNorm2d(oup),
+ )
+ else:
+ self.conv = nn.Sequential(
+ # pw
+ nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
+ BatchNorm2d(hidden_dim),
+ nn.ReLU6(inplace=True),
+ # dw
+ nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
+ BatchNorm2d(hidden_dim),
+ nn.ReLU6(inplace=True),
+ # pw-linear
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+ BatchNorm2d(oup),
+ )
+
+ def forward(self, x):
+ if self.use_res_connect:
+ return x + self.conv(x)
+ else:
+ return self.conv(x)
+
+
+class MobileNetV2(nn.Module):
+ def __init__(self, n_class=1000, input_size=224, width_mult=1.):
+ super(MobileNetV2, self).__init__()
+ block = InvertedResidual
+ input_channel = 32
+ last_channel = 1280
+        inverted_residual_setting = [
+            # t (expansion factor), c (output channels), n (repeats), s (stride of first block)
+ [1, 16, 1, 1],
+ [6, 24, 2, 2],
+ [6, 32, 3, 2],
+ [6, 64, 4, 2],
+ [6, 96, 3, 1],
+ [6, 160, 3, 2],
+ [6, 320, 1, 1],
+ ]
+
+ # building first layer
+ assert input_size % 32 == 0
+ input_channel = int(input_channel * width_mult)
+ self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
+ self.features = [conv_bn(3, input_channel, 2)]
+ # building inverted residual blocks
+        for t, c, n, s in inverted_residual_setting:
+ output_channel = int(c * width_mult)
+ for i in range(n):
+ if i == 0:
+ self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
+ else:
+ self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
+ input_channel = output_channel
+ # building last several layers
+ self.features.append(conv_1x1_bn(input_channel, self.last_channel))
+ # make it nn.Sequential
+ self.features = nn.Sequential(*self.features)
+
+ # building classifier
+ self.classifier = nn.Sequential(
+ nn.Dropout(0.2),
+ nn.Linear(self.last_channel, n_class),
+ )
+
+ self._initialize_weights()
+
+ def forward(self, x):
+ x = self.features(x)
+        x = x.mean(3).mean(2)  # global average pooling over H and W
+ x = self.classifier(x)
+ return x
+
+ def _initialize_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+ elif isinstance(m, nn.Linear):
+ n = m.weight.size(1)
+ m.weight.data.normal_(0, 0.01)
+ m.bias.data.zero_()
+
+
+def mobilenetv2(pretrained=False, **kwargs):
+ """Constructs a MobileNet_V2 model.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model = MobileNetV2(n_class=1000, **kwargs)
+ if pretrained:
+ model.load_state_dict(load_url(model_urls['mobilenetv2']), strict=False)
+ return model
\ No newline at end of file
diff --git a/training/losses/ade20k/object150_info.csv b/training/losses/ade20k/object150_info.csv
new file mode 100644
index 0000000000000000000000000000000000000000..8b34d8f3874a38b96894863c5458a7c3c2b0e2e6
--- /dev/null
+++ b/training/losses/ade20k/object150_info.csv
@@ -0,0 +1,151 @@
+Idx,Ratio,Train,Val,Stuff,Name
+1,0.1576,11664,1172,1,wall
+2,0.1072,6046,612,1,building;edifice
+3,0.0878,8265,796,1,sky
+4,0.0621,9336,917,1,floor;flooring
+5,0.0480,6678,641,0,tree
+6,0.0450,6604,643,1,ceiling
+7,0.0398,4023,408,1,road;route
+8,0.0231,1906,199,0,bed
+9,0.0198,4688,460,0,windowpane;window
+10,0.0183,2423,225,1,grass
+11,0.0181,2874,294,0,cabinet
+12,0.0166,3068,310,1,sidewalk;pavement
+13,0.0160,5075,526,0,person;individual;someone;somebody;mortal;soul
+14,0.0151,1804,190,1,earth;ground
+15,0.0118,6666,796,0,door;double;door
+16,0.0110,4269,411,0,table
+17,0.0109,1691,160,1,mountain;mount
+18,0.0104,3999,441,0,plant;flora;plant;life
+19,0.0104,2149,217,0,curtain;drape;drapery;mantle;pall
+20,0.0103,3261,318,0,chair
+21,0.0098,3164,306,0,car;auto;automobile;machine;motorcar
+22,0.0074,709,75,1,water
+23,0.0067,3296,315,0,painting;picture
+24,0.0065,1191,106,0,sofa;couch;lounge
+25,0.0061,1516,162,0,shelf
+26,0.0060,667,69,1,house
+27,0.0053,651,57,1,sea
+28,0.0052,1847,224,0,mirror
+29,0.0046,1158,128,1,rug;carpet;carpeting
+30,0.0044,480,44,1,field
+31,0.0044,1172,98,0,armchair
+32,0.0044,1292,184,0,seat
+33,0.0033,1386,138,0,fence;fencing
+34,0.0031,698,61,0,desk
+35,0.0030,781,73,0,rock;stone
+36,0.0027,380,43,0,wardrobe;closet;press
+37,0.0026,3089,302,0,lamp
+38,0.0024,404,37,0,bathtub;bathing;tub;bath;tub
+39,0.0024,804,99,0,railing;rail
+40,0.0023,1453,153,0,cushion
+41,0.0023,411,37,0,base;pedestal;stand
+42,0.0022,1440,162,0,box
+43,0.0022,800,77,0,column;pillar
+44,0.0020,2650,298,0,signboard;sign
+45,0.0019,549,46,0,chest;of;drawers;chest;bureau;dresser
+46,0.0019,367,36,0,counter
+47,0.0018,311,30,1,sand
+48,0.0018,1181,122,0,sink
+49,0.0018,287,23,1,skyscraper
+50,0.0018,468,38,0,fireplace;hearth;open;fireplace
+51,0.0018,402,43,0,refrigerator;icebox
+52,0.0018,130,12,1,grandstand;covered;stand
+53,0.0018,561,64,1,path
+54,0.0017,880,102,0,stairs;steps
+55,0.0017,86,12,1,runway
+56,0.0017,172,11,0,case;display;case;showcase;vitrine
+57,0.0017,198,18,0,pool;table;billiard;table;snooker;table
+58,0.0017,930,109,0,pillow
+59,0.0015,139,18,0,screen;door;screen
+60,0.0015,564,52,1,stairway;staircase
+61,0.0015,320,26,1,river
+62,0.0015,261,29,1,bridge;span
+63,0.0014,275,22,0,bookcase
+64,0.0014,335,60,0,blind;screen
+65,0.0014,792,75,0,coffee;table;cocktail;table
+66,0.0014,395,49,0,toilet;can;commode;crapper;pot;potty;stool;throne
+67,0.0014,1309,138,0,flower
+68,0.0013,1112,113,0,book
+69,0.0013,266,27,1,hill
+70,0.0013,659,66,0,bench
+71,0.0012,331,31,0,countertop
+72,0.0012,531,56,0,stove;kitchen;stove;range;kitchen;range;cooking;stove
+73,0.0012,369,36,0,palm;palm;tree
+74,0.0012,144,9,0,kitchen;island
+75,0.0011,265,29,0,computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system
+76,0.0010,324,33,0,swivel;chair
+77,0.0009,304,27,0,boat
+78,0.0009,170,20,0,bar
+79,0.0009,68,6,0,arcade;machine
+80,0.0009,65,8,1,hovel;hut;hutch;shack;shanty
+81,0.0009,248,25,0,bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle
+82,0.0008,492,49,0,towel
+83,0.0008,2510,269,0,light;light;source
+84,0.0008,440,39,0,truck;motortruck
+85,0.0008,147,18,1,tower
+86,0.0008,583,56,0,chandelier;pendant;pendent
+87,0.0007,533,61,0,awning;sunshade;sunblind
+88,0.0007,1989,239,0,streetlight;street;lamp
+89,0.0007,71,5,0,booth;cubicle;stall;kiosk
+90,0.0007,618,53,0,television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box
+91,0.0007,135,12,0,airplane;aeroplane;plane
+92,0.0007,83,5,1,dirt;track
+93,0.0007,178,17,0,apparel;wearing;apparel;dress;clothes
+94,0.0006,1003,104,0,pole
+95,0.0006,182,12,1,land;ground;soil
+96,0.0006,452,50,0,bannister;banister;balustrade;balusters;handrail
+97,0.0006,42,6,1,escalator;moving;staircase;moving;stairway
+98,0.0006,307,31,0,ottoman;pouf;pouffe;puff;hassock
+99,0.0006,965,114,0,bottle
+100,0.0006,117,13,0,buffet;counter;sideboard
+101,0.0006,354,35,0,poster;posting;placard;notice;bill;card
+102,0.0006,108,9,1,stage
+103,0.0006,557,55,0,van
+104,0.0006,52,4,0,ship
+105,0.0005,99,5,0,fountain
+106,0.0005,57,4,1,conveyer;belt;conveyor;belt;conveyer;conveyor;transporter
+107,0.0005,292,31,0,canopy
+108,0.0005,77,9,0,washer;automatic;washer;washing;machine
+109,0.0005,340,38,0,plaything;toy
+110,0.0005,66,3,1,swimming;pool;swimming;bath;natatorium
+111,0.0005,465,49,0,stool
+112,0.0005,50,4,0,barrel;cask
+113,0.0005,622,75,0,basket;handbasket
+114,0.0005,80,9,1,waterfall;falls
+115,0.0005,59,3,0,tent;collapsible;shelter
+116,0.0005,531,72,0,bag
+117,0.0005,282,30,0,minibike;motorbike
+118,0.0005,73,7,0,cradle
+119,0.0005,435,44,0,oven
+120,0.0005,136,25,0,ball
+121,0.0005,116,24,0,food;solid;food
+122,0.0004,266,31,0,step;stair
+123,0.0004,58,12,0,tank;storage;tank
+124,0.0004,418,83,0,trade;name;brand;name;brand;marque
+125,0.0004,319,43,0,microwave;microwave;oven
+126,0.0004,1193,139,0,pot;flowerpot
+127,0.0004,97,23,0,animal;animate;being;beast;brute;creature;fauna
+128,0.0004,347,36,0,bicycle;bike;wheel;cycle
+129,0.0004,52,5,1,lake
+130,0.0004,246,22,0,dishwasher;dish;washer;dishwashing;machine
+131,0.0004,108,13,0,screen;silver;screen;projection;screen
+132,0.0004,201,30,0,blanket;cover
+133,0.0004,285,21,0,sculpture
+134,0.0004,268,27,0,hood;exhaust;hood
+135,0.0003,1020,108,0,sconce
+136,0.0003,1282,122,0,vase
+137,0.0003,528,65,0,traffic;light;traffic;signal;stoplight
+138,0.0003,453,57,0,tray
+139,0.0003,671,100,0,ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin
+140,0.0003,397,44,0,fan
+141,0.0003,92,8,1,pier;wharf;wharfage;dock
+142,0.0003,228,18,0,crt;screen
+143,0.0003,570,59,0,plate
+144,0.0003,217,22,0,monitor;monitoring;device
+145,0.0003,206,19,0,bulletin;board;notice;board
+146,0.0003,130,14,0,shower
+147,0.0003,178,28,0,radiator
+148,0.0002,504,57,0,glass;drinking;glass
+149,0.0002,775,96,0,clock
+150,0.0002,421,56,0,flag
diff --git a/training/losses/ade20k/resnet.py b/training/losses/ade20k/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e1d521f171c984cf6a7ff3dcebd96f8c5faf908
--- /dev/null
+++ b/training/losses/ade20k/resnet.py
@@ -0,0 +1,181 @@
+"""Modified from https://github.com/CSAILVision/semantic-segmentation-pytorch"""
+
+import math
+
+import torch.nn as nn
+from torch.nn import BatchNorm2d
+
+from .utils import load_url
+
+__all__ = ['ResNet', 'resnet18', 'resnet50']
+
+
+model_urls = {
+    # The resnet18 entry is assumed to follow the same CSAIL naming scheme as
+    # resnet50; it is needed because resnet18(pretrained=True) looks it up below.
+    'resnet18': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet18-imagenet.pth',
+    'resnet50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet50-imagenet.pth',
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ "3x3 convolution with padding"
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(BasicBlock, self).__init__()
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = BatchNorm2d(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(Bottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+ self.bn1 = BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+ self.bn2 = BatchNorm2d(planes)
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+ self.bn3 = BatchNorm2d(planes * 4)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet(nn.Module):
+
+ def __init__(self, block, layers, num_classes=1000):
+ self.inplanes = 128
+ super(ResNet, self).__init__()
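+        # "Deep stem": three 3x3 convolutions replace torchvision's single 7x7
+        # stem, which is why weights come from the CSAIL mirror above rather
+        # than from torchvision.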
+ self.conv1 = conv3x3(3, 64, stride=2)
+ self.bn1 = BatchNorm2d(64)
+ self.relu1 = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(64, 64)
+ self.bn2 = BatchNorm2d(64)
+ self.relu2 = nn.ReLU(inplace=True)
+ self.conv3 = conv3x3(64, 128)
+ self.bn3 = BatchNorm2d(128)
+ self.relu3 = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+ self.avgpool = nn.AvgPool2d(7, stride=1)
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+ elif isinstance(m, BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+
+ def _make_layer(self, block, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ BatchNorm2d(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.relu1(self.bn1(self.conv1(x)))
+ x = self.relu2(self.bn2(self.conv2(x)))
+ x = self.relu3(self.bn3(self.conv3(x)))
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+
+ x = self.avgpool(x)
+ x = x.view(x.size(0), -1)
+ x = self.fc(x)
+
+ return x
+
+
+def resnet50(pretrained=False, **kwargs):
+ """Constructs a ResNet-50 model.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
+ if pretrained:
+ model.load_state_dict(load_url(model_urls['resnet50']), strict=False)
+ return model
+
+
+def resnet18(pretrained=False, **kwargs):
+ """Constructs a ResNet-18 model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
+ if pretrained:
+ model.load_state_dict(load_url(model_urls['resnet18']))
+ return model
\ No newline at end of file
diff --git a/training/losses/ade20k/segm_lib.zip b/training/losses/ade20k/segm_lib.zip
new file mode 100644
index 0000000000000000000000000000000000000000..8743b4225eb3e6910f101ba27b1ba461adf55a77
--- /dev/null
+++ b/training/losses/ade20k/segm_lib.zip
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fe01e5fde56d2adf7690c11eaed3a306f87af69314de6278dd814b770f8f666f
+size 32158
diff --git a/training/losses/ade20k/segm_lib/nn/__init__.py b/training/losses/ade20k/segm_lib/nn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..98a96370ef04570f516052bb73f568d0ebc346c3
--- /dev/null
+++ b/training/losses/ade20k/segm_lib/nn/__init__.py
@@ -0,0 +1,2 @@
+from .modules import *
+from .parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to
diff --git a/training/losses/ade20k/segm_lib/nn/modules/__init__.py b/training/losses/ade20k/segm_lib/nn/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc8709d92c610b36e0bcbd7da20c1eb41dc8cfcf
--- /dev/null
+++ b/training/losses/ade20k/segm_lib/nn/modules/__init__.py
@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+# File : __init__.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
+from .replicate import DataParallelWithCallback, patch_replication_callback
diff --git a/training/losses/ade20k/segm_lib/nn/modules/batchnorm.py b/training/losses/ade20k/segm_lib/nn/modules/batchnorm.py
new file mode 100644
index 0000000000000000000000000000000000000000..18318965335b37cc671004a6aceda3229dc7b477
--- /dev/null
+++ b/training/losses/ade20k/segm_lib/nn/modules/batchnorm.py
@@ -0,0 +1,329 @@
+# -*- coding: utf-8 -*-
+# File : batchnorm.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import collections
+
+import torch
+import torch.nn.functional as F
+
+from torch.nn.modules.batchnorm import _BatchNorm
+from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
+
+from .comm import SyncMaster
+
+__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
+
+
+def _sum_ft(tensor):
+ """sum over the first and last dimention"""
+ return tensor.sum(dim=0).sum(dim=-1)
+
+
+def _unsqueeze_ft(tensor):
+ """add new dementions at the front and the tail"""
+ return tensor.unsqueeze(0).unsqueeze(-1)
+
+
+_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
+_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
+
+
+class _SynchronizedBatchNorm(_BatchNorm):
+ def __init__(self, num_features, eps=1e-5, momentum=0.001, affine=True):
+ super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
+
+ self._sync_master = SyncMaster(self._data_parallel_master)
+
+ self._is_parallel = False
+ self._parallel_id = None
+ self._slave_pipe = None
+
+        # custom batch norm statistics
+ self._moving_average_fraction = 1. - momentum
+ self.register_buffer('_tmp_running_mean', torch.zeros(self.num_features))
+ self.register_buffer('_tmp_running_var', torch.ones(self.num_features))
+ self.register_buffer('_running_iter', torch.ones(1))
+ self._tmp_running_mean = self.running_mean.clone() * self._running_iter
+ self._tmp_running_var = self.running_var.clone() * self._running_iter
+
+ def forward(self, input):
+ # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
+ if not (self._is_parallel and self.training):
+ return F.batch_norm(
+ input, self.running_mean, self.running_var, self.weight, self.bias,
+ self.training, self.momentum, self.eps)
+
+ # Resize the input to (B, C, -1).
+ input_shape = input.size()
+ input = input.view(input.size(0), self.num_features, -1)
+
+ # Compute the sum and square-sum.
+ sum_size = input.size(0) * input.size(2)
+ input_sum = _sum_ft(input)
+ input_ssum = _sum_ft(input ** 2)
+
+ # Reduce-and-broadcast the statistics.
+ if self._parallel_id == 0:
+ mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
+ else:
+ mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
+
+ # Compute the output.
+ if self.affine:
+ # MJY:: Fuse the multiplication for speed.
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
+ else:
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
+
+ # Reshape it.
+ return output.view(input_shape)
+
+ def __data_parallel_replicate__(self, ctx, copy_id):
+ self._is_parallel = True
+ self._parallel_id = copy_id
+
+ # parallel_id == 0 means master device.
+ if self._parallel_id == 0:
+ ctx.sync_master = self._sync_master
+ else:
+ self._slave_pipe = ctx.sync_master.register_slave(copy_id)
+
+ def _data_parallel_master(self, intermediates):
+ """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
+ intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
+
+ to_reduce = [i[1][:2] for i in intermediates]
+ to_reduce = [j for i in to_reduce for j in i] # flatten
+ target_gpus = [i[1].sum.get_device() for i in intermediates]
+
+ sum_size = sum([i[1].sum_size for i in intermediates])
+ sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
+
+ mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
+
+ broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
+
+ outputs = []
+ for i, rec in enumerate(intermediates):
+ outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
+
+ return outputs
+
+ def _add_weighted(self, dest, delta, alpha=1, beta=1, bias=0):
+ """return *dest* by `dest := dest*alpha + delta*beta + bias`"""
+ return dest * alpha + delta * beta + bias
+
+ def _compute_mean_std(self, sum_, ssum, size):
+ """Compute the mean and standard-deviation with sum and square-sum. This method
+ also maintains the moving average on the master device."""
+ assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
+ mean = sum_ / size
+ sumvar = ssum - sum_ * mean
+ unbias_var = sumvar / (size - 1)
+ bias_var = sumvar / size
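+        # The unbiased estimate feeds the running statistics below, while the
+        # biased one is returned for normalizing the current batch (matching
+        # the behaviour of torch.nn.functional.batch_norm).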
+
+ self._tmp_running_mean = self._add_weighted(self._tmp_running_mean, mean.data, alpha=self._moving_average_fraction)
+ self._tmp_running_var = self._add_weighted(self._tmp_running_var, unbias_var.data, alpha=self._moving_average_fraction)
+ self._running_iter = self._add_weighted(self._running_iter, 1, alpha=self._moving_average_fraction)
+
+ self.running_mean = self._tmp_running_mean / self._running_iter
+ self.running_var = self._tmp_running_var / self._running_iter
+
+ return mean, bias_var.clamp(self.eps) ** -0.5
+
+
+class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
+ r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
+ mini-batch.
+
+ .. math::
+
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
+
+ This module differs from the built-in PyTorch BatchNorm1d as the mean and
+ standard-deviation are reduced across all devices during training.
+
+ For example, when one uses `nn.DataParallel` to wrap the network during
+    training, PyTorch's implementation normalizes the tensor on each device using
+    the statistics only on that device, which accelerates the computation and
+ is also easy to implement, but the statistics might be inaccurate.
+ Instead, in this synchronized version, the statistics will be computed
+ over all training samples distributed on multiple devices.
+
+    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
+ as the built-in PyTorch implementation.
+
+ The mean and standard-deviation are calculated per-dimension over
+ the mini-batches and gamma and beta are learnable parameter vectors
+ of size C (where C is the input size).
+
+ During training, this layer keeps a running estimate of its computed mean
+ and variance. The running sum is kept with a default momentum of 0.1.
+
+ During evaluation, this running mean/variance is used for normalization.
+
+ Because the BatchNorm is done over the `C` dimension, computing statistics
+ on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
+
+ Args:
+ num_features: num_features from an expected input of size
+ `batch_size x num_features [x width]`
+ eps: a value added to the denominator for numerical stability.
+ Default: 1e-5
+ momentum: the value used for the running_mean and running_var
+ computation. Default: 0.1
+ affine: a boolean value that when set to ``True``, gives the layer learnable
+ affine parameters. Default: ``True``
+
+ Shape:
+ - Input: :math:`(N, C)` or :math:`(N, C, L)`
+ - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
+
+ Examples:
+ >>> # With Learnable Parameters
+ >>> m = SynchronizedBatchNorm1d(100)
+ >>> # Without Learnable Parameters
+ >>> m = SynchronizedBatchNorm1d(100, affine=False)
+ >>> input = torch.autograd.Variable(torch.randn(20, 100))
+ >>> output = m(input)
+ """
+
+ def _check_input_dim(self, input):
+ if input.dim() != 2 and input.dim() != 3:
+ raise ValueError('expected 2D or 3D input (got {}D input)'
+ .format(input.dim()))
+ super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
+
+
+class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
+ r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
+ of 3d inputs
+
+ .. math::
+
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
+
+ This module differs from the built-in PyTorch BatchNorm2d as the mean and
+ standard-deviation are reduced across all devices during training.
+
+ For example, when one uses `nn.DataParallel` to wrap the network during
+    training, PyTorch's implementation normalizes the tensor on each device using
+    the statistics only on that device, which accelerates the computation and
+ is also easy to implement, but the statistics might be inaccurate.
+ Instead, in this synchronized version, the statistics will be computed
+ over all training samples distributed on multiple devices.
+
+    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
+ as the built-in PyTorch implementation.
+
+ The mean and standard-deviation are calculated per-dimension over
+ the mini-batches and gamma and beta are learnable parameter vectors
+ of size C (where C is the input size).
+
+ During training, this layer keeps a running estimate of its computed mean
+ and variance. The running sum is kept with a default momentum of 0.1.
+
+ During evaluation, this running mean/variance is used for normalization.
+
+ Because the BatchNorm is done over the `C` dimension, computing statistics
+ on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
+
+ Args:
+ num_features: num_features from an expected input of
+ size batch_size x num_features x height x width
+ eps: a value added to the denominator for numerical stability.
+ Default: 1e-5
+ momentum: the value used for the running_mean and running_var
+ computation. Default: 0.1
+ affine: a boolean value that when set to ``True``, gives the layer learnable
+ affine parameters. Default: ``True``
+
+ Shape:
+ - Input: :math:`(N, C, H, W)`
+ - Output: :math:`(N, C, H, W)` (same shape as input)
+
+ Examples:
+ >>> # With Learnable Parameters
+ >>> m = SynchronizedBatchNorm2d(100)
+ >>> # Without Learnable Parameters
+ >>> m = SynchronizedBatchNorm2d(100, affine=False)
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
+ >>> output = m(input)
+ """
+
+ def _check_input_dim(self, input):
+ if input.dim() != 4:
+ raise ValueError('expected 4D input (got {}D input)'
+ .format(input.dim()))
+ super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
+
+
+class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
+ r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
+ of 4d inputs
+
+ .. math::
+
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
+
+ This module differs from the built-in PyTorch BatchNorm3d as the mean and
+ standard-deviation are reduced across all devices during training.
+
+ For example, when one uses `nn.DataParallel` to wrap the network during
+    training, PyTorch's implementation normalizes the tensor on each device using
+    the statistics only on that device, which accelerates the computation and
+ is also easy to implement, but the statistics might be inaccurate.
+ Instead, in this synchronized version, the statistics will be computed
+ over all training samples distributed on multiple devices.
+
+    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
+ as the built-in PyTorch implementation.
+
+ The mean and standard-deviation are calculated per-dimension over
+ the mini-batches and gamma and beta are learnable parameter vectors
+ of size C (where C is the input size).
+
+ During training, this layer keeps a running estimate of its computed mean
+ and variance. The running sum is kept with a default momentum of 0.1.
+
+ During evaluation, this running mean/variance is used for normalization.
+
+ Because the BatchNorm is done over the `C` dimension, computing statistics
+ on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
+ or Spatio-temporal BatchNorm
+
+ Args:
+ num_features: num_features from an expected input of
+ size batch_size x num_features x depth x height x width
+ eps: a value added to the denominator for numerical stability.
+ Default: 1e-5
+ momentum: the value used for the running_mean and running_var
+ computation. Default: 0.1
+ affine: a boolean value that when set to ``True``, gives the layer learnable
+ affine parameters. Default: ``True``
+
+ Shape:
+ - Input: :math:`(N, C, D, H, W)`
+ - Output: :math:`(N, C, D, H, W)` (same shape as input)
+
+ Examples:
+ >>> # With Learnable Parameters
+ >>> m = SynchronizedBatchNorm3d(100)
+ >>> # Without Learnable Parameters
+ >>> m = SynchronizedBatchNorm3d(100, affine=False)
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
+ >>> output = m(input)
+ """
+
+ def _check_input_dim(self, input):
+ if input.dim() != 5:
+ raise ValueError('expected 5D input (got {}D input)'
+ .format(input.dim()))
+ super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
diff --git a/training/losses/ade20k/segm_lib/nn/modules/comm.py b/training/losses/ade20k/segm_lib/nn/modules/comm.py
new file mode 100644
index 0000000000000000000000000000000000000000..b64bf6ba3b3e7abbab375c6dd4a87d8239e62138
--- /dev/null
+++ b/training/losses/ade20k/segm_lib/nn/modules/comm.py
@@ -0,0 +1,131 @@
+# -*- coding: utf-8 -*-
+# File : comm.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import queue
+import collections
+import threading
+
+__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
+
+
+class FutureResult(object):
+ """A thread-safe future implementation. Used only as one-to-one pipe."""
+
+ def __init__(self):
+ self._result = None
+ self._lock = threading.Lock()
+ self._cond = threading.Condition(self._lock)
+
+ def put(self, result):
+ with self._lock:
+            assert self._result is None, 'Previous result hasn\'t been fetched.'
+ self._result = result
+ self._cond.notify()
+
+ def get(self):
+ with self._lock:
+ if self._result is None:
+ self._cond.wait()
+
+ res = self._result
+ self._result = None
+ return res
+
+
+_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
+_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
+
+
+class SlavePipe(_SlavePipeBase):
+ """Pipe for master-slave communication."""
+
+ def run_slave(self, msg):
+ self.queue.put((self.identifier, msg))
+ ret = self.result.get()
+ self.queue.put(True)
+ return ret
+
+
+class SyncMaster(object):
+ """An abstract `SyncMaster` object.
+
+    - During the replication, as the data parallel will trigger a callback on each module, all slave devices should
+    call `register(id)` and obtain a `SlavePipe` to communicate with the master.
+    - During the forward pass, the master device invokes `run_master`, all messages from slave devices will be collected,
+    and passed to a registered callback.
+    - After receiving the messages, the master device should gather the information and determine the message passed
+    back to each slave device.
+ """
+
+ def __init__(self, master_callback):
+ """
+
+ Args:
+ master_callback: a callback to be invoked after having collected messages from slave devices.
+ """
+ self._master_callback = master_callback
+ self._queue = queue.Queue()
+ self._registry = collections.OrderedDict()
+ self._activated = False
+
+ def register_slave(self, identifier):
+ """
+        Register a slave device.
+
+ Args:
+ identifier: an identifier, usually is the device id.
+
+ Returns: a `SlavePipe` object which can be used to communicate with the master device.
+
+ """
+ if self._activated:
+ assert self._queue.empty(), 'Queue is not clean before next initialization.'
+ self._activated = False
+ self._registry.clear()
+ future = FutureResult()
+ self._registry[identifier] = _MasterRegistry(future)
+ return SlavePipe(identifier, self._queue, future)
+
+ def run_master(self, master_msg):
+ """
+ Main entry for the master device in each forward pass.
+        The messages are first collected from each device (including the master device), and then
+        a callback is invoked to compute the message to be sent back to each device
+        (including the master device).
+
+ Args:
+            master_msg: the message that the master wants to send to itself. This will be placed as the first
+ message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
+
+ Returns: the message to be sent back to the master device.
+
+ """
+ self._activated = True
+
+ intermediates = [(0, master_msg)]
+ for i in range(self.nr_slaves):
+ intermediates.append(self._queue.get())
+
+ results = self._master_callback(intermediates)
+        assert results[0][0] == 0, 'The first result should belong to the master.'
+
+ for i, res in results:
+ if i == 0:
+ continue
+ self._registry[i].result.put(res)
+
+ for i in range(self.nr_slaves):
+ assert self._queue.get() is True
+
+ return results[0][1]
+
+ @property
+ def nr_slaves(self):
+ return len(self._registry)
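+
+
+# Protocol sketch (illustrative; _SynchronizedBatchNorm is the real client):
+#   master = SyncMaster(callback)        # created once, lives on the master copy
+#   pipe = master.register_slave(1)      # each replica registers after replication
+#   out0 = master.run_master(msg0)       # master thread: waits for all slave msgs
+#   out1 = pipe.run_slave(msg1)          # slave thread: sends msg1, gets its reply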
diff --git a/training/losses/ade20k/segm_lib/nn/modules/replicate.py b/training/losses/ade20k/segm_lib/nn/modules/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..b71c7b8ed51a1d6c55b1f753bdd8d90bad79bd06
--- /dev/null
+++ b/training/losses/ade20k/segm_lib/nn/modules/replicate.py
@@ -0,0 +1,94 @@
+# -*- coding: utf-8 -*-
+# File : replicate.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import functools
+
+from torch.nn.parallel.data_parallel import DataParallel
+
+__all__ = [
+ 'CallbackContext',
+ 'execute_replication_callbacks',
+ 'DataParallelWithCallback',
+ 'patch_replication_callback'
+]
+
+
+class CallbackContext(object):
+ pass
+
+
+def execute_replication_callbacks(modules):
+ """
+ Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
+
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
+
+ Note that, as all replicas are isomorphic, we assign each sub-module a context
+ (shared among multiple copies of this module on different devices).
+ Through this context, different copies can share some information.
+
+ We guarantee that the callback on the master copy (the first copy) is invoked before the callbacks
+ of any slave copies.
+ """
+ master_copy = modules[0]
+ nr_modules = len(list(master_copy.modules()))
+ ctxs = [CallbackContext() for _ in range(nr_modules)]
+
+ for i, module in enumerate(modules):
+ for j, m in enumerate(module.modules()):
+ if hasattr(m, '__data_parallel_replicate__'):
+ m.__data_parallel_replicate__(ctxs[j], i)
+
+
+class DataParallelWithCallback(DataParallel):
+ """
+ Data Parallel with a replication callback.
+
+ A replication callback `__data_parallel_replicate__` is invoked on each module after it is created by the
+ original `replicate` function.
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
+
+ Examples:
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+ > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
+ # sync_bn.__data_parallel_replicate__ will be invoked.
+ """
+
+ def replicate(self, module, device_ids):
+ modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
+ execute_replication_callbacks(modules)
+ return modules
+
+
+def patch_replication_callback(data_parallel):
+ """
+ Monkey-patch an existing `DataParallel` object. Adds the replication callback.
+ Useful when you have a customized `DataParallel` implementation.
+
+ Examples:
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+ > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
+ > patch_replication_callback(sync_bn)
+ # this is equivalent to
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+ > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
+ """
+
+ assert isinstance(data_parallel, DataParallel)
+
+ old_replicate = data_parallel.replicate
+
+ @functools.wraps(old_replicate)
+ def new_replicate(module, device_ids):
+ modules = old_replicate(module, device_ids)
+ execute_replication_callbacks(modules)
+ return modules
+
+ data_parallel.replicate = new_replicate
diff --git a/training/losses/ade20k/segm_lib/nn/modules/tests/test_numeric_batchnorm.py b/training/losses/ade20k/segm_lib/nn/modules/tests/test_numeric_batchnorm.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bd45a930d3dc84912e58659ee575be08e9038f0
--- /dev/null
+++ b/training/losses/ade20k/segm_lib/nn/modules/tests/test_numeric_batchnorm.py
@@ -0,0 +1,56 @@
+# -*- coding: utf-8 -*-
+# File : test_numeric_batchnorm.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+
+import unittest
+
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+
+from sync_batchnorm.unittest import TorchTestCase
+
+
+def handy_var(a, unbias=True):
+ n = a.size(0)
+ asum = a.sum(dim=0)
+ as_sum = (a ** 2).sum(dim=0) # a square sum
+ sumvar = as_sum - asum * asum / n
+ if unbias:
+ return sumvar / (n - 1)
+ else:
+ return sumvar / n
+
+
+class NumericTestCase(TorchTestCase):
+ def testNumericBatchNorm(self):
+ a = torch.rand(16, 10)
+ bn = nn.BatchNorm1d(10, momentum=1, eps=1e-5, affine=False)  # 1d: the input below is 2D [N, C]
+ bn.train()
+
+ a_var1 = Variable(a, requires_grad=True)
+ b_var1 = bn(a_var1)
+ loss1 = b_var1.sum()
+ loss1.backward()
+
+ a_var2 = Variable(a, requires_grad=True)
+ a_mean2 = a_var2.mean(dim=0, keepdim=True)
+ a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5))
+ # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5)
+ b_var2 = (a_var2 - a_mean2) / a_std2
+ loss2 = b_var2.sum()
+ loss2.backward()
+
+ self.assertTensorClose(bn.running_mean, a.mean(dim=0))
+ self.assertTensorClose(bn.running_var, handy_var(a))
+ self.assertTensorClose(a_var1.data, a_var2.data)
+ self.assertTensorClose(b_var1.data, b_var2.data)
+ self.assertTensorClose(a_var1.grad, a_var2.grad)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/training/losses/ade20k/segm_lib/nn/modules/tests/test_sync_batchnorm.py b/training/losses/ade20k/segm_lib/nn/modules/tests/test_sync_batchnorm.py
new file mode 100644
index 0000000000000000000000000000000000000000..45bb3c8cfd36d8f668e6fde756b17587eab72082
--- /dev/null
+++ b/training/losses/ade20k/segm_lib/nn/modules/tests/test_sync_batchnorm.py
@@ -0,0 +1,111 @@
+# -*- coding: utf-8 -*-
+# File : test_sync_batchnorm.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+
+import unittest
+
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+
+from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback
+from sync_batchnorm.unittest import TorchTestCase
+
+
+def handy_var(a, unbias=True):
+ n = a.size(0)
+ asum = a.sum(dim=0)
+ as_sum = (a ** 2).sum(dim=0) # a square sum
+ sumvar = as_sum - asum * asum / n
+ if unbias:
+ return sumvar / (n - 1)
+ else:
+ return sumvar / n
+
+
+def _find_bn(module):
+ for m in module.modules():
+ if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)):
+ return m
+
+
+class SyncTestCase(TorchTestCase):
+ def _syncParameters(self, bn1, bn2):
+ bn1.reset_parameters()
+ bn2.reset_parameters()
+ if bn1.affine and bn2.affine:
+ bn2.weight.data.copy_(bn1.weight.data)
+ bn2.bias.data.copy_(bn1.bias.data)
+
+ def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False):
+ """Check the forward and backward for the customized batch normalization."""
+ bn1.train(mode=is_train)
+ bn2.train(mode=is_train)
+
+ if cuda:
+ input = input.cuda()
+
+ self._syncParameters(_find_bn(bn1), _find_bn(bn2))
+
+ input1 = Variable(input, requires_grad=True)
+ output1 = bn1(input1)
+ output1.sum().backward()
+ input2 = Variable(input, requires_grad=True)
+ output2 = bn2(input2)
+ output2.sum().backward()
+
+ self.assertTensorClose(input1.data, input2.data)
+ self.assertTensorClose(output1.data, output2.data)
+ self.assertTensorClose(input1.grad, input2.grad)
+ self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean)
+ self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var)
+
+ def testSyncBatchNormNormalTrain(self):
+ bn = nn.BatchNorm1d(10)
+ sync_bn = SynchronizedBatchNorm1d(10)
+
+ self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True)
+
+ def testSyncBatchNormNormalEval(self):
+ bn = nn.BatchNorm1d(10)
+ sync_bn = SynchronizedBatchNorm1d(10)
+
+ self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False)
+
+ def testSyncBatchNormSyncTrain(self):
+ bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
+ sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+ sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
+
+ bn.cuda()
+ sync_bn.cuda()
+
+ self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True)
+
+ def testSyncBatchNormSyncEval(self):
+ bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
+ sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+ sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
+
+ bn.cuda()
+ sync_bn.cuda()
+
+ self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True)
+
+ def testSyncBatchNorm2DSyncTrain(self):
+ bn = nn.BatchNorm2d(10)
+ sync_bn = SynchronizedBatchNorm2d(10)
+ sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
+
+ bn.cuda()
+ sync_bn.cuda()
+
+ self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/training/losses/ade20k/segm_lib/nn/modules/unittest.py b/training/losses/ade20k/segm_lib/nn/modules/unittest.py
new file mode 100644
index 0000000000000000000000000000000000000000..0675c022e4ba85d38d1f813490f6740150909524
--- /dev/null
+++ b/training/losses/ade20k/segm_lib/nn/modules/unittest.py
@@ -0,0 +1,29 @@
+# -*- coding: utf-8 -*-
+# File : unittest.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import unittest
+
+import numpy as np
+from torch.autograd import Variable
+
+
+def as_numpy(v):
+ if isinstance(v, Variable):
+ v = v.data
+ return v.cpu().numpy()
+
+
+class TorchTestCase(unittest.TestCase):
+ def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
+ npa, npb = as_numpy(a), as_numpy(b)
+ self.assertTrue(
+ np.allclose(npa, npb, atol=atol, rtol=rtol),
+ 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
+ )
diff --git a/training/losses/ade20k/segm_lib/nn/parallel/__init__.py b/training/losses/ade20k/segm_lib/nn/parallel/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b52f49cc0755562218a460483cbf02514ddd773
--- /dev/null
+++ b/training/losses/ade20k/segm_lib/nn/parallel/__init__.py
@@ -0,0 +1 @@
+from .data_parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to
diff --git a/training/losses/ade20k/segm_lib/nn/parallel/data_parallel.py b/training/losses/ade20k/segm_lib/nn/parallel/data_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..376fc038919aa2a5bd696141e7bb6025d4981306
--- /dev/null
+++ b/training/losses/ade20k/segm_lib/nn/parallel/data_parallel.py
@@ -0,0 +1,112 @@
+# -*- coding: utf8 -*-
+
+import torch.cuda as cuda
+import torch.nn as nn
+import torch
+import collections.abc
+from torch.nn.parallel._functions import Gather
+
+
+__all__ = ['UserScatteredDataParallel', 'user_scattered_collate', 'async_copy_to']
+
+
+def async_copy_to(obj, dev, main_stream=None):
+ if torch.is_tensor(obj):
+ v = obj.cuda(dev, non_blocking=True)
+ if main_stream is not None:
+ v.data.record_stream(main_stream)
+ return v
+ elif isinstance(obj, collections.abc.Mapping):
+ return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
+ elif isinstance(obj, collections.abc.Sequence):
+ return [async_copy_to(o, dev, main_stream) for o in obj]
+ else:
+ return obj
+
+
+def dict_gather(outputs, target_device, dim=0):
+ """
+ Gathers variables from different GPUs on a specified device
+ (-1 means the CPU), with dictionary support.
+ """
+ def gather_map(outputs):
+ out = outputs[0]
+ if torch.is_tensor(out):
+ # MJY(20180330) HACK:: force nr_dims > 0
+ if out.dim() == 0:
+ outputs = [o.unsqueeze(0) for o in outputs]
+ return Gather.apply(target_device, dim, *outputs)
+ elif out is None:
+ return None
+ elif isinstance(out, collections.abc.Mapping):
+ return {k: gather_map([o[k] for o in outputs]) for k in out}
+ elif isinstance(out, collections.abc.Sequence):
+ return type(out)(map(gather_map, zip(*outputs)))
+ return gather_map(outputs)
+
+
+class DictGatherDataParallel(nn.DataParallel):
+ def gather(self, outputs, output_device):
+ return dict_gather(outputs, output_device, dim=self.dim)
+
+
+class UserScatteredDataParallel(DictGatherDataParallel):
+ def scatter(self, inputs, kwargs, device_ids):
+ assert len(inputs) == 1
+ inputs = inputs[0]
+ inputs = _async_copy_stream(inputs, device_ids)
+ inputs = [[i] for i in inputs]
+ assert len(kwargs) == 0
+ kwargs = [{} for _ in range(len(inputs))]
+
+ return inputs, kwargs
+
+
+def user_scattered_collate(batch):
+ return batch
+
+
+def _async_copy(inputs, device_ids):
+ nr_devs = len(device_ids)
+ assert type(inputs) in (tuple, list)
+ assert len(inputs) == nr_devs
+
+ outputs = []
+ for i, dev in zip(inputs, device_ids):
+ with cuda.device(dev):
+ outputs.append(async_copy_to(i, dev))
+
+ return tuple(outputs)
+
+
+def _async_copy_stream(inputs, device_ids):
+ nr_devs = len(device_ids)
+ assert type(inputs) in (tuple, list)
+ assert len(inputs) == nr_devs
+
+ outputs = []
+ streams = [_get_stream(d) for d in device_ids]
+ for i, dev, stream in zip(inputs, device_ids, streams):
+ with cuda.device(dev):
+ main_stream = cuda.current_stream()
+ with cuda.stream(stream):
+ outputs.append(async_copy_to(i, dev, main_stream=main_stream))
+ main_stream.wait_stream(stream)
+
+ return outputs
+
+
+"""Adapted from: torch/nn/parallel/_functions.py"""
+# background streams used for copying
+_streams = None
+
+
+def _get_stream(device):
+ """Gets a background stream for copying between CPU and GPU"""
+ global _streams
+ if device == -1:
+ return None
+ if _streams is None:
+ _streams = [None] * cuda.device_count()
+ if _streams[device] is None: _streams[device] = cuda.Stream(device)
+ return _streams[device]
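+
+
+# Usage sketch (added for illustration; not part of the original module). With
+# user-scattered inputs the caller pre-splits the batch, the collate function
+# is the identity, and scatter() only copies each chunk to its device:
+#
+#   loader = torch.utils.data.DataLoader(dataset, batch_size=n_gpus,
+#                                        collate_fn=user_scattered_collate)
+#   model = UserScatteredDataParallel(net, device_ids=[0, 1])
+#   for batch in loader:      # batch is a list with one element per GPU
+#       out = model(batch)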
diff --git a/training/losses/ade20k/segm_lib/utils/__init__.py b/training/losses/ade20k/segm_lib/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..abe3cbe49477fe37d4fc16249de8a10f4fb4a013
--- /dev/null
+++ b/training/losses/ade20k/segm_lib/utils/__init__.py
@@ -0,0 +1 @@
+from .th import *
diff --git a/training/losses/ade20k/segm_lib/utils/data/__init__.py b/training/losses/ade20k/segm_lib/utils/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3b008fb13c5e8a84b1b785056e8c4f5226dc976
--- /dev/null
+++ b/training/losses/ade20k/segm_lib/utils/data/__init__.py
@@ -0,0 +1,3 @@
+
+from .dataset import Dataset, TensorDataset, ConcatDataset
+from .dataloader import DataLoader
diff --git a/training/losses/ade20k/segm_lib/utils/data/dataloader.py b/training/losses/ade20k/segm_lib/utils/data/dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..039b9ec3645b2a4626ff47c221e372f32a6ad339
--- /dev/null
+++ b/training/losses/ade20k/segm_lib/utils/data/dataloader.py
@@ -0,0 +1,425 @@
+import torch
+import torch.multiprocessing as multiprocessing
+from torch._C import _set_worker_signal_handlers, \
+ _remove_worker_pids, _error_if_any_worker_fails
+try:
+ from torch._C import _set_worker_pids
+except ImportError:
+ from torch._C import _update_worker_pids as _set_worker_pids
+from .sampler import SequentialSampler, RandomSampler, BatchSampler
+import signal
+import collections.abc
+import re
+import sys
+import threading
+import traceback
+from torch._six import string_classes, int_classes
+import numpy as np
+
+if sys.version_info[0] == 2:
+ import Queue as queue
+else:
+ import queue
+
+
+class ExceptionWrapper(object):
+ r"Wraps an exception plus traceback to communicate across threads"
+
+ def __init__(self, exc_info):
+ self.exc_type = exc_info[0]
+ self.exc_msg = "".join(traceback.format_exception(*exc_info))
+
+
+_use_shared_memory = False
+"""Whether to use shared memory in default_collate"""
+
+
+def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
+ global _use_shared_memory
+ _use_shared_memory = True
+
+ # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
+ # module's handlers are executed after Python returns from C low-level
+ # handlers, likely when the same fatal signal happened again already.
+ # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
+ _set_worker_signal_handlers()
+
+ torch.set_num_threads(1)
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+
+ if init_fn is not None:
+ init_fn(worker_id)
+
+ while True:
+ r = index_queue.get()
+ if r is None:
+ break
+ idx, batch_indices = r
+ try:
+ samples = collate_fn([dataset[i] for i in batch_indices])
+ except Exception:
+ data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
+ else:
+ data_queue.put((idx, samples))
+
+
+def _worker_manager_loop(in_queue, out_queue, done_event, pin_memory, device_id):
+ if pin_memory:
+ torch.cuda.set_device(device_id)
+
+ while True:
+ try:
+ r = in_queue.get()
+ except Exception:
+ if done_event.is_set():
+ return
+ raise
+ if r is None:
+ break
+ if isinstance(r[1], ExceptionWrapper):
+ out_queue.put(r)
+ continue
+ idx, batch = r
+ try:
+ if pin_memory:
+ batch = pin_memory_batch(batch)
+ except Exception:
+ out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
+ else:
+ out_queue.put((idx, batch))
+
+numpy_type_map = {
+ 'float64': torch.DoubleTensor,
+ 'float32': torch.FloatTensor,
+ 'float16': torch.HalfTensor,
+ 'int64': torch.LongTensor,
+ 'int32': torch.IntTensor,
+ 'int16': torch.ShortTensor,
+ 'int8': torch.CharTensor,
+ 'uint8': torch.ByteTensor,
+}
+
+
+def default_collate(batch):
+ "Puts each data field into a tensor with outer dimension batch size"
+
+ error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
+ elem_type = type(batch[0])
+ if torch.is_tensor(batch[0]):
+ out = None
+ if _use_shared_memory:
+ # If we're in a background process, concatenate directly into a
+ # shared memory tensor to avoid an extra copy
+ numel = sum([x.numel() for x in batch])
+ storage = batch[0].storage()._new_shared(numel)
+ out = batch[0].new(storage)
+ return torch.stack(batch, 0, out=out)
+ elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
+ and elem_type.__name__ != 'string_':
+ elem = batch[0]
+ if elem_type.__name__ == 'ndarray':
+ # array of string classes and object
+ if re.search('[SaUO]', elem.dtype.str) is not None:
+ raise TypeError(error_msg.format(elem.dtype))
+
+ return torch.stack([torch.from_numpy(b) for b in batch], 0)
+ if elem.shape == (): # scalars
+ py_type = float if elem.dtype.name.startswith('float') else int
+ return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
+ elif isinstance(batch[0], int_classes):
+ return torch.LongTensor(batch)
+ elif isinstance(batch[0], float):
+ return torch.DoubleTensor(batch)
+ elif isinstance(batch[0], string_classes):
+ return batch
+ elif isinstance(batch[0], collections.abc.Mapping):
+ return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
+ elif isinstance(batch[0], collections.abc.Sequence):
+ transposed = zip(*batch)
+ return [default_collate(samples) for samples in transposed]
+
+ raise TypeError((error_msg.format(type(batch[0]))))
+
+
+def pin_memory_batch(batch):
+ if torch.is_tensor(batch):
+ return batch.pin_memory()
+ elif isinstance(batch, string_classes):
+ return batch
+ elif isinstance(batch, collections.abc.Mapping):
+ return {k: pin_memory_batch(sample) for k, sample in batch.items()}
+ elif isinstance(batch, collections.abc.Sequence):
+ return [pin_memory_batch(sample) for sample in batch]
+ else:
+ return batch
+
+
+_SIGCHLD_handler_set = False
+"""Whether SIGCHLD handler is set for DataLoader worker failures. Only one
+handler needs to be set for all DataLoaders in a process."""
+
+
+def _set_SIGCHLD_handler():
+ # Windows doesn't support SIGCHLD handler
+ if sys.platform == 'win32':
+ return
+ # can't set signal in child threads
+ if not isinstance(threading.current_thread(), threading._MainThread):
+ return
+ global _SIGCHLD_handler_set
+ if _SIGCHLD_handler_set:
+ return
+ previous_handler = signal.getsignal(signal.SIGCHLD)
+ if not callable(previous_handler):
+ previous_handler = None
+
+ def handler(signum, frame):
+ # The following call uses `waitid` with WNOHANG from the C side. Therefore,
+ # Python can still get and update the process status successfully.
+ _error_if_any_worker_fails()
+ if previous_handler is not None:
+ previous_handler(signum, frame)
+
+ signal.signal(signal.SIGCHLD, handler)
+ _SIGCHLD_handler_set = True
+
+
+class DataLoaderIter(object):
+ "Iterates once over the DataLoader's dataset, as specified by the sampler"
+
+ def __init__(self, loader):
+ self.dataset = loader.dataset
+ self.collate_fn = loader.collate_fn
+ self.batch_sampler = loader.batch_sampler
+ self.num_workers = loader.num_workers
+ self.pin_memory = loader.pin_memory and torch.cuda.is_available()
+ self.timeout = loader.timeout
+ self.done_event = threading.Event()
+
+ self.sample_iter = iter(self.batch_sampler)
+
+ if self.num_workers > 0:
+ self.worker_init_fn = loader.worker_init_fn
+ self.index_queue = multiprocessing.SimpleQueue()
+ self.worker_result_queue = multiprocessing.SimpleQueue()
+ self.batches_outstanding = 0
+ self.worker_pids_set = False
+ self.shutdown = False
+ self.send_idx = 0
+ self.rcvd_idx = 0
+ self.reorder_dict = {}
+
+ base_seed = torch.LongTensor(1).random_(0, 2**31-1)[0]
+ self.workers = [
+ multiprocessing.Process(
+ target=_worker_loop,
+ args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn,
+ base_seed + i, self.worker_init_fn, i))
+ for i in range(self.num_workers)]
+
+ if self.pin_memory or self.timeout > 0:
+ self.data_queue = queue.Queue()
+ if self.pin_memory:
+ maybe_device_id = torch.cuda.current_device()
+ else:
+ # do not initialize cuda context if not necessary
+ maybe_device_id = None
+ self.worker_manager_thread = threading.Thread(
+ target=_worker_manager_loop,
+ args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
+ maybe_device_id))
+ self.worker_manager_thread.daemon = True
+ self.worker_manager_thread.start()
+ else:
+ self.data_queue = self.worker_result_queue
+
+ for w in self.workers:
+ w.daemon = True # ensure that the worker exits on process exit
+ w.start()
+
+ _set_worker_pids(id(self), tuple(w.pid for w in self.workers))
+ _set_SIGCHLD_handler()
+ self.worker_pids_set = True
+
+ # prime the prefetch loop
+ for _ in range(2 * self.num_workers):
+ self._put_indices()
+
+ def __len__(self):
+ return len(self.batch_sampler)
+
+ def _get_batch(self):
+ if self.timeout > 0:
+ try:
+ return self.data_queue.get(timeout=self.timeout)
+ except queue.Empty:
+ raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
+ else:
+ return self.data_queue.get()
+
+ def __next__(self):
+ if self.num_workers == 0: # same-process loading
+ indices = next(self.sample_iter) # may raise StopIteration
+ batch = self.collate_fn([self.dataset[i] for i in indices])
+ if self.pin_memory:
+ batch = pin_memory_batch(batch)
+ return batch
+
+ # check if the next sample has already been generated
+ if self.rcvd_idx in self.reorder_dict:
+ batch = self.reorder_dict.pop(self.rcvd_idx)
+ return self._process_next_batch(batch)
+
+ if self.batches_outstanding == 0:
+ self._shutdown_workers()
+ raise StopIteration
+
+ while True:
+ assert (not self.shutdown and self.batches_outstanding > 0)
+ idx, batch = self._get_batch()
+ self.batches_outstanding -= 1
+ if idx != self.rcvd_idx:
+ # store out-of-order samples
+ self.reorder_dict[idx] = batch
+ continue
+ return self._process_next_batch(batch)
+
+ next = __next__ # Python 2 compatibility
+
+ def __iter__(self):
+ return self
+
+ def _put_indices(self):
+ assert self.batches_outstanding < 2 * self.num_workers
+ indices = next(self.sample_iter, None)
+ if indices is None:
+ return
+ self.index_queue.put((self.send_idx, indices))
+ self.batches_outstanding += 1
+ self.send_idx += 1
+
+ def _process_next_batch(self, batch):
+ self.rcvd_idx += 1
+ self._put_indices()
+ if isinstance(batch, ExceptionWrapper):
+ raise batch.exc_type(batch.exc_msg)
+ return batch
+
+ def __getstate__(self):
+ # TODO: add limited pickling support for sharing an iterator
+ # across multiple threads for HOGWILD.
+ # Probably the best way to do this is by moving the sample pushing
+ # to a separate thread and then just sharing the data queue
+ # but signalling the end is tricky without a non-blocking API
+ raise NotImplementedError("DataLoaderIter cannot be pickled")
+
+ def _shutdown_workers(self):
+ try:
+ if not self.shutdown:
+ self.shutdown = True
+ self.done_event.set()
+ # if worker_manager_thread is waiting to put
+ while not self.data_queue.empty():
+ self.data_queue.get()
+ for _ in self.workers:
+ self.index_queue.put(None)
+ # done_event should be sufficient to exit worker_manager_thread,
+ # but be safe here and put another None
+ self.worker_result_queue.put(None)
+ finally:
+ # removes pids no matter what
+ if self.worker_pids_set:
+ _remove_worker_pids(id(self))
+ self.worker_pids_set = False
+
+ def __del__(self):
+ if self.num_workers > 0:
+ self._shutdown_workers()
+
+
+class DataLoader(object):
+ """
+ Data loader. Combines a dataset and a sampler, and provides
+ single- or multi-process iterators over the dataset.
+
+ Arguments:
+ dataset (Dataset): dataset from which to load the data.
+ batch_size (int, optional): how many samples per batch to load
+ (default: 1).
+ shuffle (bool, optional): set to ``True`` to have the data reshuffled
+ at every epoch (default: False).
+ sampler (Sampler, optional): defines the strategy to draw samples from
+ the dataset. If specified, ``shuffle`` must be False.
+ batch_sampler (Sampler, optional): like sampler, but returns a batch of
+ indices at a time. Mutually exclusive with batch_size, shuffle,
+ sampler, and drop_last.
+ num_workers (int, optional): how many subprocesses to use for data
+ loading. 0 means that the data will be loaded in the main process.
+ (default: 0)
+ collate_fn (callable, optional): merges a list of samples to form a mini-batch.
+ pin_memory (bool, optional): If ``True``, the data loader will copy tensors
+ into CUDA pinned memory before returning them.
+ drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
+ if the dataset size is not divisible by the batch size. If ``False`` and
+ the size of dataset is not divisible by the batch size, then the last batch
+ will be smaller. (default: False)
+ timeout (numeric, optional): if positive, the timeout value for collecting a batch
+ from workers. Should always be non-negative. (default: 0)
+ worker_init_fn (callable, optional): If not None, this will be called on each
+ worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
+ input, after seeding and before data loading. (default: None)
+
+ .. note:: By default, each worker will have its PyTorch seed set to
+ ``base_seed + worker_id``, where ``base_seed`` is a long generated
+ by the main process using its RNG. You may use ``torch.initial_seed()`` to access
+ this value in :attr:`worker_init_fn`, which can be used to set other seeds
+ (e.g. NumPy) before data loading.
+
+ .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
+ unpicklable object, e.g., a lambda function.
+ """
+
+ def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
+ num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
+ timeout=0, worker_init_fn=None):
+ self.dataset = dataset
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+ self.collate_fn = collate_fn
+ self.pin_memory = pin_memory
+ self.drop_last = drop_last
+ self.timeout = timeout
+ self.worker_init_fn = worker_init_fn
+
+ if timeout < 0:
+ raise ValueError('timeout option should be non-negative')
+
+ if batch_sampler is not None:
+ if batch_size > 1 or shuffle or sampler is not None or drop_last:
+ raise ValueError('batch_sampler is mutually exclusive with '
+ 'batch_size, shuffle, sampler, and drop_last')
+
+ if sampler is not None and shuffle:
+ raise ValueError('sampler is mutually exclusive with shuffle')
+
+ if self.num_workers < 0:
+ raise ValueError('num_workers cannot be negative; '
+ 'use num_workers=0 to disable multiprocessing.')
+
+ if batch_sampler is None:
+ if sampler is None:
+ if shuffle:
+ sampler = RandomSampler(dataset)
+ else:
+ sampler = SequentialSampler(dataset)
+ batch_sampler = BatchSampler(sampler, batch_size, drop_last)
+
+ self.sampler = sampler
+ self.batch_sampler = batch_sampler
+
+ def __iter__(self):
+ return DataLoaderIter(self)
+
+ def __len__(self):
+ return len(self.batch_sampler)
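+
+
+# Usage sketch (added for illustration; not part of the original module):
+#
+#   from .dataset import TensorDataset
+#   ds = TensorDataset(torch.randn(8, 3), torch.arange(8))
+#   loader = DataLoader(ds, batch_size=3, shuffle=True, num_workers=0)
+#   for data, target in loader:   # 3 batches: sizes 3, 3, 2
+#       ...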
diff --git a/training/losses/ade20k/segm_lib/utils/data/dataset.py b/training/losses/ade20k/segm_lib/utils/data/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..605aa877f7031a5cd2b98c0f831410aa80fddefa
--- /dev/null
+++ b/training/losses/ade20k/segm_lib/utils/data/dataset.py
@@ -0,0 +1,118 @@
+import bisect
+import warnings
+
+from torch._utils import _accumulate
+from torch import randperm
+
+
+class Dataset(object):
+ """An abstract class representing a Dataset.
+
+ All other datasets should subclass it. All subclasses should override
+ ``__len__``, that provides the size of the dataset, and ``__getitem__``,
+ supporting integer indexing in range from 0 to len(self) exclusive.
+ """
+
+ def __getitem__(self, index):
+ raise NotImplementedError
+
+ def __len__(self):
+ raise NotImplementedError
+
+ def __add__(self, other):
+ return ConcatDataset([self, other])
+
+
+class TensorDataset(Dataset):
+ """Dataset wrapping data and target tensors.
+
+ Each sample will be retrieved by indexing both tensors along the first
+ dimension.
+
+ Arguments:
+ data_tensor (Tensor): contains sample data.
+ target_tensor (Tensor): contains sample targets (labels).
+ """
+
+ def __init__(self, data_tensor, target_tensor):
+ assert data_tensor.size(0) == target_tensor.size(0)
+ self.data_tensor = data_tensor
+ self.target_tensor = target_tensor
+
+ def __getitem__(self, index):
+ return self.data_tensor[index], self.target_tensor[index]
+
+ def __len__(self):
+ return self.data_tensor.size(0)
+
+
+class ConcatDataset(Dataset):
+ """
+ Dataset to concatenate multiple datasets.
+ Purpose: useful to assemble different existing datasets, possibly
+ large-scale datasets as the concatenation operation is done in an
+ on-the-fly manner.
+
+ Arguments:
+ datasets (iterable): List of datasets to be concatenated
+ """
+
+ @staticmethod
+ def cumsum(sequence):
+ r, s = [], 0
+ for e in sequence:
+ l = len(e)
+ r.append(l + s)
+ s += l
+ return r
+
+ def __init__(self, datasets):
+ super(ConcatDataset, self).__init__()
+ assert len(datasets) > 0, 'datasets should not be an empty iterable'
+ self.datasets = list(datasets)
+ self.cumulative_sizes = self.cumsum(self.datasets)
+
+ def __len__(self):
+ return self.cumulative_sizes[-1]
+
+ def __getitem__(self, idx):
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+ if dataset_idx == 0:
+ sample_idx = idx
+ else:
+ sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+ return self.datasets[dataset_idx][sample_idx]
+
+ @property
+ def cummulative_sizes(self):
+ warnings.warn("cummulative_sizes attribute is renamed to "
+ "cumulative_sizes", DeprecationWarning, stacklevel=2)
+ return self.cumulative_sizes
+
+
+class Subset(Dataset):
+ def __init__(self, dataset, indices):
+ self.dataset = dataset
+ self.indices = indices
+
+ def __getitem__(self, idx):
+ return self.dataset[self.indices[idx]]
+
+ def __len__(self):
+ return len(self.indices)
+
+
+def random_split(dataset, lengths):
+ """
+ Randomly split a dataset into non-overlapping new datasets of given lengths.
+
+ Arguments:
+ dataset (Dataset): Dataset to be split
+ lengths (iterable): lengths of splits to be produced
+ """
+ if sum(lengths) != len(dataset):
+ raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
+
+ indices = randperm(sum(lengths))
+ return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]
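+
+
+# Usage sketch (added for illustration; not part of the original module):
+#
+#   full = TensorDataset(torch.randn(10, 3), torch.zeros(10))
+#   train, val = random_split(full, [8, 2])   # lengths must sum to len(full)
+#   assert len(train) == 8 and len(val) == 2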
diff --git a/training/losses/ade20k/segm_lib/utils/data/distributed.py b/training/losses/ade20k/segm_lib/utils/data/distributed.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3d890e28fd2b9e044bdd9494de4a43ad2471eed
--- /dev/null
+++ b/training/losses/ade20k/segm_lib/utils/data/distributed.py
@@ -0,0 +1,58 @@
+import math
+import torch
+from .sampler import Sampler
+from torch.distributed import get_world_size, get_rank
+
+
+class DistributedSampler(Sampler):
+ """Sampler that restricts data loading to a subset of the dataset.
+
+ It is especially useful in conjunction with
+ :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
+ process can pass a DistributedSampler instance as a DataLoader sampler,
+ and load a subset of the original dataset that is exclusive to it.
+
+ .. note::
+ Dataset is assumed to be of constant size.
+
+ Arguments:
+ dataset: Dataset used for sampling.
+ num_replicas (optional): Number of processes participating in
+ distributed training.
+ rank (optional): Rank of the current process within num_replicas.
+ """
+
+ def __init__(self, dataset, num_replicas=None, rank=None):
+ if num_replicas is None:
+ num_replicas = get_world_size()
+ if rank is None:
+ rank = get_rank()
+ self.dataset = dataset
+ self.num_replicas = num_replicas
+ self.rank = rank
+ self.epoch = 0
+ self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
+ self.total_size = self.num_samples * self.num_replicas
+
+ def __iter__(self):
+ # deterministically shuffle based on epoch
+ g = torch.Generator()
+ g.manual_seed(self.epoch)
+ indices = list(torch.randperm(len(self.dataset), generator=g))
+
+ # add extra samples to make it evenly divisible
+ indices += indices[:(self.total_size - len(indices))]
+ assert len(indices) == self.total_size
+
+ # subsample
+ offset = self.num_samples * self.rank
+ indices = indices[offset:offset + self.num_samples]
+ assert len(indices) == self.num_samples
+
+ return iter(indices)
+
+ def __len__(self):
+ return self.num_samples
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
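+
+
+# Usage sketch (added for illustration; not part of the original module). Call
+# set_epoch() before each epoch so that all processes draw the same shuffled
+# permutation and each keeps its own disjoint shard:
+#
+#   sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
+#   loader = DataLoader(dataset, batch_size=32, sampler=sampler)
+#   for epoch in range(num_epochs):
+#       sampler.set_epoch(epoch)
+#       for batch in loader:
+#           ...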
diff --git a/training/losses/ade20k/segm_lib/utils/data/sampler.py b/training/losses/ade20k/segm_lib/utils/data/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..62a9a43bd1d4c21fbdcb262db7da8d4fe27b26de
--- /dev/null
+++ b/training/losses/ade20k/segm_lib/utils/data/sampler.py
@@ -0,0 +1,131 @@
+import torch
+
+
+class Sampler(object):
+ """Base class for all Samplers.
+
+ Every Sampler subclass has to provide an __iter__ method, providing a way
+ to iterate over indices of dataset elements, and a __len__ method that
+ returns the length of the returned iterator.
+ """
+
+ def __init__(self, data_source):
+ pass
+
+ def __iter__(self):
+ raise NotImplementedError
+
+ def __len__(self):
+ raise NotImplementedError
+
+
+class SequentialSampler(Sampler):
+ """Samples elements sequentially, always in the same order.
+
+ Arguments:
+ data_source (Dataset): dataset to sample from
+ """
+
+ def __init__(self, data_source):
+ self.data_source = data_source
+
+ def __iter__(self):
+ return iter(range(len(self.data_source)))
+
+ def __len__(self):
+ return len(self.data_source)
+
+
+class RandomSampler(Sampler):
+ """Samples elements randomly, without replacement.
+
+ Arguments:
+ data_source (Dataset): dataset to sample from
+ """
+
+ def __init__(self, data_source):
+ self.data_source = data_source
+
+ def __iter__(self):
+ return iter(torch.randperm(len(self.data_source)).long())
+
+ def __len__(self):
+ return len(self.data_source)
+
+
+class SubsetRandomSampler(Sampler):
+ """Samples elements randomly from a given list of indices, without replacement.
+
+ Arguments:
+ indices (list): a list of indices
+ """
+
+ def __init__(self, indices):
+ self.indices = indices
+
+ def __iter__(self):
+ return (self.indices[i] for i in torch.randperm(len(self.indices)))
+
+ def __len__(self):
+ return len(self.indices)
+
+
+class WeightedRandomSampler(Sampler):
+ """Samples elements from [0,..,len(weights)-1] with given probabilities (weights).
+
+ Arguments:
+ weights (list) : a list of weights, not necessarily summing to one
+ num_samples (int): number of samples to draw
+ replacement (bool): if ``True``, samples are drawn with replacement.
+ If not, they are drawn without replacement, which means that when a
+ sample index is drawn for a row, it cannot be drawn again for that row.
+ """
+
+ def __init__(self, weights, num_samples, replacement=True):
+ self.weights = torch.DoubleTensor(weights)
+ self.num_samples = num_samples
+ self.replacement = replacement
+
+ def __iter__(self):
+ return iter(torch.multinomial(self.weights, self.num_samples, self.replacement))
+
+ def __len__(self):
+ return self.num_samples
+
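+# Illustrative example (added for clarity; not part of the original module):
+# with the weights below, index 2 is drawn about half of the time.
+#
+#   sampler = WeightedRandomSampler([0.1, 0.4, 0.5], num_samples=4)
+#   list(sampler)   # e.g. [2, 1, 2, 0]
+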
+
+class BatchSampler(object):
+ """Wraps another sampler to yield a mini-batch of indices.
+
+ Args:
+ sampler (Sampler): Base sampler.
+ batch_size (int): Size of mini-batch.
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
+ its size would be less than ``batch_size``
+
+ Example:
+ >>> list(BatchSampler(range(10), batch_size=3, drop_last=False))
+ [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
+ >>> list(BatchSampler(range(10), batch_size=3, drop_last=True))
+ [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
+ """
+
+ def __init__(self, sampler, batch_size, drop_last):
+ self.sampler = sampler
+ self.batch_size = batch_size
+ self.drop_last = drop_last
+
+ def __iter__(self):
+ batch = []
+ for idx in self.sampler:
+ batch.append(idx)
+ if len(batch) == self.batch_size:
+ yield batch
+ batch = []
+ if len(batch) > 0 and not self.drop_last:
+ yield batch
+
+ def __len__(self):
+ if self.drop_last:
+ return len(self.sampler) // self.batch_size
+ else:
+ return (len(self.sampler) + self.batch_size - 1) // self.batch_size
diff --git a/training/losses/ade20k/segm_lib/utils/th.py b/training/losses/ade20k/segm_lib/utils/th.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca6ef9385e3b5c0a439579d3fd7aa73b5dc62758
--- /dev/null
+++ b/training/losses/ade20k/segm_lib/utils/th.py
@@ -0,0 +1,41 @@
+import torch
+from torch.autograd import Variable
+import numpy as np
+import collections.abc
+
+__all__ = ['as_variable', 'as_numpy', 'mark_volatile']
+
+def as_variable(obj):
+ if isinstance(obj, Variable):
+ return obj
+ if isinstance(obj, collections.abc.Sequence):
+ return [as_variable(v) for v in obj]
+ elif isinstance(obj, collections.abc.Mapping):
+ return {k: as_variable(v) for k, v in obj.items()}
+ else:
+ return Variable(obj)
+
+def as_numpy(obj):
+ if isinstance(obj, collections.abc.Sequence):
+ return [as_numpy(v) for v in obj]
+ elif isinstance(obj, collections.abc.Mapping):
+ return {k: as_numpy(v) for k, v in obj.items()}
+ elif isinstance(obj, Variable):
+ return obj.data.cpu().numpy()
+ elif torch.is_tensor(obj):
+ return obj.cpu().numpy()
+ else:
+ return np.array(obj)
+
+def mark_volatile(obj):
+ if torch.is_tensor(obj):
+ obj = Variable(obj)
+ if isinstance(obj, Variable):
+ obj.no_grad = True
+ return obj
+ elif isinstance(obj, collections.abc.Mapping):
+ return {k: mark_volatile(o) for k, o in obj.items()}
+ elif isinstance(obj, collections.abc.Sequence):
+ return [mark_volatile(o) for o in obj]
+ else:
+ return obj
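+
+
+# Illustrative example (added for clarity; not part of the original module):
+# these helpers recurse through nested containers, e.g.
+#
+#   as_numpy({'a': torch.zeros(2), 'b': [torch.ones(1)]})
+#   # -> {'a': array([0., 0.], dtype=float32), 'b': [array([1.], dtype=float32)]}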
diff --git a/training/losses/ade20k/utils.py b/training/losses/ade20k/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f337db7db54c82be041698d694e1403e8918c4c0
--- /dev/null
+++ b/training/losses/ade20k/utils.py
@@ -0,0 +1,40 @@
+"""Modified from https://github.com/CSAILVision/semantic-segmentation-pytorch"""
+
+import os
+import sys
+
+import numpy as np
+import torch
+
+try:
+ from urllib import urlretrieve
+except ImportError:
+ from urllib.request import urlretrieve
+
+
+def load_url(url, model_dir='./pretrained', map_location=None):
+ if not os.path.exists(model_dir):
+ os.makedirs(model_dir)
+ filename = url.split('/')[-1]
+ cached_file = os.path.join(model_dir, filename)
+ if not os.path.exists(cached_file):
+ sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
+ urlretrieve(url, cached_file)
+ return torch.load(cached_file, map_location=map_location)
+
+
+def color_encode(labelmap, colors, mode='RGB'):
+ labelmap = labelmap.astype('int')
+ labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3),
+ dtype=np.uint8)
+ for label in np.unique(labelmap):
+ if label < 0:
+ continue
+ labelmap_rgb += (labelmap == label)[:, :, np.newaxis] * \
+ np.tile(colors[label],
+ (labelmap.shape[0], labelmap.shape[1], 1))
+
+ if mode == 'BGR':
+ return labelmap_rgb[:, :, ::-1]
+ else:
+ return labelmap_rgb
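+
+
+# Illustrative example (added for clarity; not part of the original module):
+# encode a 2x2 label map with a hypothetical 2-color palette.
+#
+#   colors = np.array([[255, 0, 0], [0, 255, 0]], dtype=np.uint8)
+#   labelmap = np.array([[0, 1], [1, 0]])
+#   color_encode(labelmap, colors).shape   # (2, 2, 3)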
diff --git a/training/losses/high_receptive_pl.py b/training/losses/high_receptive_pl.py
new file mode 100644
index 0000000000000000000000000000000000000000..df48f2037d868c395b188adbfe1308b93ad7c139
--- /dev/null
+++ b/training/losses/high_receptive_pl.py
@@ -0,0 +1,43 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from training.losses.ade20k import ModelBuilder
+
+
+IMAGENET_MEAN = torch.FloatTensor([0.485, 0.456, 0.406])[None, :, None, None]
+IMAGENET_STD = torch.FloatTensor([0.229, 0.224, 0.225])[None, :, None, None]
+
+
+class HRFPL(nn.Module):
+ def __init__(self, weight=1,
+ weights_path=None, arch_encoder='resnet50dilated', segmentation=True):
+ super().__init__()
+ self.impl = ModelBuilder.get_encoder(weights_path=weights_path,
+ arch_encoder=arch_encoder,
+ arch_decoder='ppm_deepsup',
+ fc_dim=2048,
+ segmentation=segmentation)
+ self.impl.eval()
+ for w in self.impl.parameters():
+ w.requires_grad_(False)
+
+ self.weight = weight
+
+ def forward(self, pred, target):
+
+ target = (target + 1) / 2
+ pred = (pred + 1) / 2
+ pred = torch.clamp(pred, 0., 1.)
+
+ pred = (pred - IMAGENET_MEAN.to(pred)) / IMAGENET_STD.to(pred)
+ target = (target - IMAGENET_MEAN.to(target)) / IMAGENET_STD.to(target)
+
+ self.impl = self.impl.to(pred.device)
+ pred_feats = self.impl(pred, return_feature_maps=True)
+ target_feats = self.impl(target, return_feature_maps=True)
+
+ result = torch.stack([F.mse_loss(cur_pred, cur_target)
+ for cur_pred, cur_target
+ in zip(pred_feats, target_feats)]).sum() * self.weight
+ return result
\ No newline at end of file
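+# Usage sketch (added for illustration; not part of the original module). The
+# loss expects images in [-1, 1]; weights_path below is a hypothetical
+# directory containing the ADE20K encoder checkpoint.
+#
+#   crit = HRFPL(weight=5, weights_path='/path/to/weights')
+#   loss = crit(pred, target)   # pred, target: [N, 3, H, W] in [-1, 1]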
diff --git a/training/losses/loss.py b/training/losses/loss.py
new file mode 100755
index 0000000000000000000000000000000000000000..0f05c3a2705ce5e8fd33c2d4273c36a709ad843f
--- /dev/null
+++ b/training/losses/loss.py
@@ -0,0 +1,129 @@
+import numpy as np
+import torch
+from torch_utils import training_stats
+from torch_utils import misc
+from torch_utils.ops import conv2d_gradfix
+from icecream import ic
+from .high_receptive_pl import HRFPL
+import os
+
+#----------------------------------------------------------------------------
+
+class Loss:
+ def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain): # to be overridden by subclass
+ raise NotImplementedError()
+
+#----------------------------------------------------------------------------
+
+class StyleGAN2Loss(Loss):
+ def __init__(self, device, G_encoder, G_mapping, G_synthesis, D, augment_pipe=None, style_mixing_prob=0.9, r1_gamma=10, pl_batch_shrink=2, pl_decay=0.01, pl_weight=2):
+ super().__init__()
+ self.device = device
+ self.G_encoder = G_encoder
+ self.G_mapping = G_mapping
+ self.G_synthesis = G_synthesis
+ self.D = D
+ self.augment_pipe = augment_pipe
+ self.style_mixing_prob = style_mixing_prob
+ self.r1_gamma = r1_gamma
+ self.pl_batch_shrink = pl_batch_shrink
+ self.pl_decay = pl_decay
+ self.pl_weight = pl_weight
+ self.pl_mean = torch.zeros([], device=device)
+ self.run_hrfpl = HRFPL(weight=5, weights_path=os.getcwd())
+
+ def run_G(self, r_img, c, sync):
+ with misc.ddp_sync(self.G_encoder, sync):
+ x_global, z, feats = self.G_encoder(r_img, c)
+ with misc.ddp_sync(self.G_mapping, sync):
+ ws = self.G_mapping(z, c)
+ if self.style_mixing_prob > 0:
+ with torch.autograd.profiler.record_function('style_mixing'):
+ cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
+ cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
+ ws[:, cutoff:] = self.G_mapping(torch.randn_like(z), c, skip_w_avg_update=True)[:, cutoff:]
+ with misc.ddp_sync(self.G_synthesis, sync):
+ img = self.G_synthesis(x_global, feats, ws)
+ return img, ws
+
+ def run_D(self, img, c, sync):
+ with misc.ddp_sync(self.D, sync):
+ logits = self.D(img, c)
+ return logits
+
+
+ def accumulate_gradients(self, phase, erased_img, real_img, mask, real_c, gen_c, sync, gain):
+ assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth']
+ do_Gmain = (phase in ['Gmain', 'Gboth'])
+ do_Dmain = (phase in ['Dmain', 'Dboth'])
+ do_Dr1 = (phase in ['Dreg', 'Dboth']) and (self.r1_gamma != 0)
+
+ # Gmain: Maximize logits for generated images.
+ if do_Gmain:
+ with torch.autograd.profiler.record_function('Gmain_forward'):
+ g_inputs = torch.cat([0.5 - mask, erased_img], dim=1)
+ gen_img, _ = self.run_G(g_inputs, gen_c, sync=sync) # May get synced by Gpl.
+ gen_img = gen_img * mask + real_img * (1 - mask)
+ loss_rec = 10 * torch.nn.functional.l1_loss(gen_img, real_img)
+ loss_pl = self.run_hrfpl(gen_img, real_img)
+
+ if self.augment_pipe is not None:
+ gen_img = self.augment_pipe(gen_img)
+ d_inputs = torch.cat([0.5 - mask, gen_img], dim=1)
+ gen_logits = self.run_D(d_inputs, gen_c, sync=False)
+
+ loss_G = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits))
+ loss_Gmain = loss_G.mean() + loss_rec + loss_pl
+ training_stats.report('Loss/G/loss', loss_G)
+ training_stats.report('Loss/G/rec_loss', loss_rec)
+ training_stats.report('Loss/G/main_loss', loss_Gmain)
+ training_stats.report('Loss/G/pl_loss', loss_pl)
+ with torch.autograd.profiler.record_function('Gmain_backward'):
+ loss_Gmain.mul(gain).backward()
+
+ # Dmain: Minimize logits for generated images.
+ loss_Dgen = 0
+ if do_Dmain:
+ with torch.autograd.profiler.record_function('Dgen_forward'):
+ g_inputs = torch.cat([0.5 - mask, erased_img], dim=1)
+ gen_img, _ = self.run_G(g_inputs, gen_c, sync=sync) # May get synced by Gpl.
+ gen_img = gen_img * mask + real_img * (1 - mask)
+ if self.augment_pipe is not None:
+ gen_img = self.augment_pipe(gen_img)
+ d_inputs = torch.cat([0.5 - mask, gen_img], dim=1)
+
+ gen_logits = self.run_D(d_inputs, gen_c, sync=False) # Gets synced by loss_Dreal.
+ loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits))
+
+ with torch.autograd.profiler.record_function('Dgen_backward'):
+ loss_Dgen.mean().mul(gain).backward()
+
+ # Dmain: Maximize logits for real images.
+ # Dr1: Apply R1 regularization.
+ if do_Dmain or do_Dr1:
+ name = 'Dreal_Dr1' if do_Dmain and do_Dr1 else 'Dreal' if do_Dmain else 'Dr1'
+ with torch.autograd.profiler.record_function(name + '_forward'):
+ real_img_tmp = real_img.detach().requires_grad_(do_Dr1)
+ if self.augment_pipe is not None:
+ real_img_tmp = self.augment_pipe(real_img_tmp)
+ d_inputs = torch.cat([0.5 - mask, real_img_tmp], dim=1)
+ real_logits = self.run_D(d_inputs, real_c, sync=sync)
+
+ loss_Dreal = 0
+ if do_Dmain:
+ loss_Dreal = torch.nn.functional.softplus(-real_logits) # -log(sigmoid(real_logits))
+ training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal)
+
+ loss_Dr1 = 0
+ if do_Dr1:
+ with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
+ r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0]
+ r1_penalty = r1_grads.square().sum([1,2,3])
+ loss_Dr1 = r1_penalty * (self.r1_gamma / 2)
+ training_stats.report('Loss/r1_penalty', r1_penalty)
+ training_stats.report('Loss/D/reg', loss_Dr1)
+
+ with torch.autograd.profiler.record_function(name + '_backward'):
+ (real_logits * 0 + loss_Dreal + loss_Dr1).mean().mul(gain).backward()
+
+#----------------------------------------------------------------------------
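+
+# Numerical note (added for clarity): the loss comments above are exact
+# identities, since softplus(-z) = -log(sigmoid(z)) and
+# softplus(z) = -log(1 - sigmoid(z)); e.g.
+#
+#   z = torch.tensor([0.3])
+#   torch.allclose(torch.nn.functional.softplus(-z),
+#                  -torch.log(torch.sigmoid(z)))   # True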
diff --git a/training/losses/perceptual.py b/training/losses/perceptual.py
new file mode 100644
index 0000000000000000000000000000000000000000..de22a11b1b3b20df1a835bb5c494a82b5b2972a2
--- /dev/null
+++ b/training/losses/perceptual.py
@@ -0,0 +1,74 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .vgg import VGG19, VGG16
+
+class Perceptual16Loss(nn.Module):
+ def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 1.0]):
+ super(Perceptual16Loss, self).__init__()
+ self.vgg = VGG16()
+ self.criterion = torch.nn.L1Loss()
+ self.weights = weights
+
+ def calculate_pl(self, x, y):
+ feat_output = self.vgg(x)
+ feat_gt = self.vgg(y)
+
+ content_loss = 0.0
+
+ for i in range(3):
+ content_loss += self.criterion(feat_output[i], feat_gt[i])
+ return content_loss.to(device=x.device)
+
+ def compute_gram(self, x):
+ b, c, h, w = x.size()
+ f = x.view(b, c, w * h)
+ f_T = f.transpose(1, 2)
+ G = f.bmm(f_T) / (h * w * c)
+ return G
+
+ def calc_style(self, x, y):
+ feat_output = self.vgg(x)  # was self.extractor, which this class never defines
+ feat_gt = self.vgg(y)
+
+ style_loss = 0.0
+
+ for i in range(3):
+ style_loss += self.criterion(
+ self.compute_gram(feat_output[i]), self.compute_gram(feat_gt[i]))
+ return style_loss
+
+class Perceptual19Loss(nn.Module):
+ def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 1.0]):
+ super(Perceptual19Loss, self).__init__()
+ self.vgg = VGG19()
+ self.criterion = torch.nn.L1Loss()
+ self.weights = weights
+
+ def calculate_pl(self, x, y):
+ x_vgg, y_vgg = self.vgg(x), self.vgg(y)
+ content_loss = 0.0
+ prefix = [1, 2, 3, 4, 5]
+ for i in range(5):
+ content_loss += self.weights[i] * self.criterion(
+ x_vgg[f'relu{prefix[i]}_1'], y_vgg[f'relu{prefix[i]}_1'])
+ return content_loss.to(device=x.device)
+
+ def compute_gram(self, x):
+ b, c, h, w = x.size()
+ f = x.view(b, c, w * h)
+ f_T = f.transpose(1, 2)
+ G = f.bmm(f_T) / (h * w * c)
+ return G
+
+ def calc_style(self, x, y):
+ x_vgg, y_vgg = self.vgg(x), self.vgg(y)
+ style_loss = 0.0
+ prefix = [2, 3, 4, 5]
+ posfix = [2, 4, 4, 2]
+ for pre, pos in list(zip(prefix, posfix)):
+ style_loss += self.criterion(
+ self.compute_gram(x_vgg[f'relu{pre}_{pos}']), self.compute_gram(y_vgg[f'relu{pre}_{pos}']))
+ return style_loss
+
\ No newline at end of file
diff --git a/training/losses/vgg.py b/training/losses/vgg.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddfbf3eb2956a9044ebc4d3d03bb4ec393bfe2ce
--- /dev/null
+++ b/training/losses/vgg.py
@@ -0,0 +1,70 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.models as models
+from torch.nn.functional import conv2d
+
+
+class VGG16(nn.Module):
+ def __init__(self):
+ super(VGG16, self).__init__()
+ vgg16 = models.vgg16(pretrained=True)
+ self.enc_1 = nn.Sequential(*vgg16.features[:5])
+ self.enc_2 = nn.Sequential(*vgg16.features[5:10])
+ self.enc_3 = nn.Sequential(*vgg16.features[10:17])
+
+ # fix the encoder
+ for i in range(3):
+ for param in getattr(self, 'enc_{:d}'.format(i + 1)).parameters():
+ param.requires_grad = False
+
+ def forward(self, image):
+ results = [image]
+ for i in range(3):
+ func = getattr(self, 'enc_{:d}'.format(i + 1)).to(image.device)
+ results.append(func(results[-1]))
+ return results[1:]
+
+class VGG19(nn.Module):
+ def __init__(self, resize_input=False):
+ super(VGG19, self).__init__()
+ features = models.vgg19(pretrained=True).features
+
+ self.resize_input = resize_input
+ self.mean = torch.Tensor([0.485, 0.456, 0.406])  # moved to x.device in forward()
+ self.std = torch.Tensor([0.229, 0.224, 0.225])
+ prefix = [1, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5]
+ posfix = [1, 2, 1, 2, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4]
+ names = list(zip(prefix, posfix))
+ self.relus = []
+ for pre, pos in names:
+ self.relus.append('relu{}_{}'.format(pre, pos))
+ self.__setattr__('relu{}_{}'.format(
+ pre, pos), torch.nn.Sequential())
+
+ nums = [[0, 1], [2, 3], [4, 5, 6], [7, 8],
+ [9, 10, 11], [12, 13], [14, 15], [16, 17],
+ [18, 19, 20], [21, 22], [23, 24], [25, 26],
+ [27, 28, 29], [30, 31], [32, 33], [34, 35]]
+
+ for i, layer in enumerate(self.relus):
+ for num in nums[i]:
+ self.__getattr__(layer).add_module(str(num), features[num])
+
+ # don't need the gradients, just want the features
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, x):
+ # resize and normalize input for pretrained vgg19
+ x = (x + 1.0) / 2.0
+ x = (x - self.mean.view(1, 3, 1, 1).to(x.device)) / (self.std.view(1, 3, 1, 1).to(x.device))
+ if self.resize_input:
+ x = F.interpolate(
+ x, size=(256, 256), mode='bilinear', align_corners=True)
+ features = []
+ for layer in self.relus:
+ x = self.__getattr__(layer).to(x.device)(x)
+ features.append(x)
+ out = {key: value for (key, value) in list(zip(self.relus, features))}
+ return out
\ No newline at end of file
diff --git a/training/models.py b/training/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7fa885df6fe55515b820bdf375515ba2eb1eee6
--- /dev/null
+++ b/training/models.py
@@ -0,0 +1,853 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import numpy as np
+from numpy.lib.type_check import imag
+import torch
+import torch.nn as nn
+from torch_utils import misc
+from torch_utils import persistence
+from torch_utils.ops import conv2d_resample
+from torch_utils.ops import upfirdn2d
+from torch_utils.ops import bias_act
+from torch_utils.ops import fma
+from icecream import ic
+import torch.nn.functional as F
+from training.ffc import FFCResnetBlock, ConcatTupleLayer
+import matplotlib.pyplot as plt
+import PIL
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def normalize_2nd_moment(x, dim=1, eps=1e-8):
+ return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt()
+
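+# Illustrative note (added for clarity): this rescales x so that its mean
+# square along `dim` is ~1, e.g.
+#
+#   z = normalize_2nd_moment(torch.randn(4, 512))
+#   z.square().mean(dim=1)   # approximately ones(4)
+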
+def save_image_grid(feats, fname, gridsize):
+ gw, gh = gridsize
+ idx = gw * gh
+
+ max_num = torch.max(feats[:idx]).item()
+ min_num = torch.min(feats[:idx]).item()
+ feats = feats[:idx].cpu() * 255 / (max_num - min_num)
+ feats = np.asarray(feats, dtype=np.float32)
+ feats = np.rint(feats).clip(0, 255).astype(np.uint8)
+
+ C, H, W = feats.shape
+
+ feats = feats.reshape(gh, gw, 1, H, W)
+ feats = feats.transpose(0, 3, 1, 4, 2)
+ feats = feats.reshape(gh * H, gw * W, 1)
+ feats = np.stack([feats]*3, axis=2).squeeze() * 10
+ feats = np.rint(feats).clip(0, 255).astype(np.uint8)
+
+ from icecream import ic
+ ic(feats.shape)
+
+ feats = PIL.Image.fromarray(feats)
+ feats.save(fname + '.png')
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def modulated_conv2d(
+ x, # Input tensor of shape [batch_size, in_channels, in_height, in_width].
+ weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].
+ styles, # Modulation coefficients of shape [batch_size, in_channels].
+ noise = None, # Optional noise tensor to add to the output activations.
+ up = 1, # Integer upsampling factor.
+ down = 1, # Integer downsampling factor.
+ padding = 0, # Padding with respect to the upsampled image.
+ resample_filter = None, # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter().
+ demodulate = True, # Apply weight demodulation?
+ flip_weight = True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
+ fused_modconv = True, # Perform modulation, convolution, and demodulation as a single fused operation?
+):
+ batch_size = x.shape[0]
+ out_channels, in_channels, kh, kw = weight.shape
+ misc.assert_shape(weight, [out_channels, in_channels, kh, kw]) # [OIkk]
+ misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW]
+ misc.assert_shape(styles, [batch_size, in_channels]) # [NI]
+
+ # Pre-normalize inputs to avoid FP16 overflow.
+ if x.dtype == torch.float16 and demodulate:
+ weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1,2,3], keepdim=True)) # max_Ikk
+ styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I
+
+ # Calculate per-sample weights and demodulation coefficients.
+ w = None
+ dcoefs = None
+ if demodulate or fused_modconv:
+ w = weight.unsqueeze(0) # [NOIkk]
+ w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk]
+ if demodulate:
+ dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO]
+ if demodulate and fused_modconv:
+ w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk]
+ # Execute by scaling the activations before and after the convolution.
+ if not fused_modconv:
+ x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1)
+ x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, flip_weight=flip_weight)
+ if demodulate and noise is not None:
+ x = fma.fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype))
+ elif demodulate:
+ x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1)
+ elif noise is not None:
+ x = x.add_(noise.to(x.dtype))
+ return x
+
+ # Execute as one fused op using grouped convolution.
+ with misc.suppress_tracer_warnings(): # this value will be treated as a constant
+ batch_size = int(batch_size)
+ misc.assert_shape(x, [batch_size, in_channels, None, None])
+ x = x.reshape(1, -1, *x.shape[2:])
+ w = w.reshape(-1, in_channels, kh, kw)
+ x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, groups=batch_size, flip_weight=flip_weight)
+ x = x.reshape(batch_size, -1, *x.shape[2:])
+ if noise is not None:
+ x = x.add_(noise)
+ return x
+
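+# Reference for the fused path above (illustrative, assuming up=down=1 and no
+# noise): per-sample weights are modulated by the styles and then demodulated
+# to unit per-output-channel norm,
+#   w = weight[None] * styles[:, None, :, None, None]   # [N, O, I, k, k]
+#   d = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt()  # [N, O]
+#   w = w * d[:, :, None, None, None]
+# which is StyleGAN2's modulated convolution, executed as a grouped conv.
+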
+#----------------------------------------------------------------------------
+
+@persistence.persistent_class
+class FullyConnectedLayer(torch.nn.Module):
+ def __init__(self,
+ in_features, # Number of input features.
+ out_features, # Number of output features.
+ bias = True, # Apply additive bias before the activation function?
+ activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
+ lr_multiplier = 1, # Learning rate multiplier.
+ bias_init = 0, # Initial value for the additive bias.
+ ):
+ super().__init__()
+ self.activation = activation
+ self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
+ self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
+ self.weight_gain = lr_multiplier / np.sqrt(in_features)
+ self.bias_gain = lr_multiplier
+
+ def forward(self, x):
+ w = self.weight.to(x.dtype) * self.weight_gain
+ b = self.bias
+ if b is not None:
+ b = b.to(x.dtype)
+ if self.bias_gain != 1:
+ b = b * self.bias_gain
+
+ if self.activation == 'linear' and b is not None:
+ x = torch.addmm(b.unsqueeze(0), x, w.t())
+ else:
+ x = x.matmul(w.t())
+ x = bias_act.bias_act(x, b, act=self.activation)
+ return x
+
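+# Note on the scaling above (equalized learning rate, as in StyleGAN): weights
+# are stored as randn()/lr_multiplier and rescaled at run time by
+# weight_gain = lr_multiplier / sqrt(in_features), so the effective weight std
+# is 1/sqrt(in_features) while, under Adam, the layer's effective learning
+# rate scales with lr_multiplier.
+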
+#----------------------------------------------------------------------------
+
+@persistence.persistent_class
+class Conv2dLayer(torch.nn.Module):
+ def __init__(self,
+ in_channels, # Number of input channels.
+ out_channels, # Number of output channels.
+ kernel_size, # Width and height of the convolution kernel.
+ bias = True, # Apply additive bias before the activation function?
+ activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
+ up = 1, # Integer upsampling factor.
+ down = 1, # Integer downsampling factor.
+ resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
+ conv_clamp = None, # Clamp the output to +-X, None = disable clamping.
+ channels_last = False, # Expect the input to have memory_format=channels_last?
+ trainable = True, # Update the weights of this layer during training?
+ ):
+ super().__init__()
+ self.activation = activation
+ self.up = up
+ self.down = down
+ self.conv_clamp = conv_clamp
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
+ self.padding = kernel_size // 2
+ self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
+ self.act_gain = bias_act.activation_funcs[activation].def_gain
+
+ memory_format = torch.channels_last if channels_last else torch.contiguous_format
+ weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)
+ bias = torch.zeros([out_channels]) if bias else None
+ if trainable:
+ self.weight = torch.nn.Parameter(weight)
+ self.bias = torch.nn.Parameter(bias) if bias is not None else None
+ else:
+ self.register_buffer('weight', weight)
+ if bias is not None:
+ self.register_buffer('bias', bias)
+ else:
+ self.bias = None
+
+ def forward(self, x, gain=1):
+ w = self.weight * self.weight_gain
+ b = self.bias.to(x.dtype) if self.bias is not None else None
+ flip_weight = (self.up == 1) # slightly faster
+ x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=self.resample_filter, up=self.up, down=self.down, padding=self.padding, flip_weight=flip_weight)
+
+ act_gain = self.act_gain * gain
+ act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
+ x = bias_act.bias_act(x, b, act=self.activation, gain=act_gain, clamp=act_clamp)
+ return x
+
+#----------------------------------------------------------------------------
+
+@persistence.persistent_class
+class FFCBlock(torch.nn.Module):
+ def __init__(self,
+ dim, # Number of output/input channels.
+ kernel_size, # Width and height of the convolution kernel.
+ padding,
+ ratio_gin=0.75,
+ ratio_gout=0.75,
+ activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
+ ):
+ super().__init__()
+ if activation == 'linear':
+ self.activation = nn.Identity
+ else:
+ self.activation = nn.ReLU
+ self.padding = padding
+ self.kernel_size = kernel_size
+ self.ffc_block = FFCResnetBlock(dim=dim,
+ padding_type='reflect',
+ norm_layer=nn.SyncBatchNorm,
+ activation_layer=self.activation,
+ dilation=1,
+ ratio_gin=ratio_gin,
+ ratio_gout=ratio_gout)
+
+ self.concat_layer = ConcatTupleLayer()
+
+ def forward(self, gen_ft, mask, fname=None):
+ x = gen_ft.float()
+# x = mask*enc_ft + (1-mask)*gen_ft
+ x_l, x_g = x[:, :-self.ffc_block.conv1.ffc.global_in_num], x[:, -self.ffc_block.conv1.ffc.global_in_num:]
+
+ id_l, id_g = x_l, x_g
+
+ x_l, x_g = self.ffc_block((x_l, x_g), fname=fname)
+
+ x_l, x_g = id_l + x_l, id_g + x_g
+
+ x = self.concat_layer((x_l, x_g))
+ return x + gen_ft.float()
+
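+# Shape sketch for FFCBlock.forward (illustrative): with dim=512 and
+# ratio_gin=0.75, the input splits into 128 "local" channels and 384 "global"
+# (spectral) channels; both pass through the FFC residual block, are re-added
+# to their identities, concatenated back to 512 channels, and the result is
+# applied residually to gen_ft.
+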
+#----------------------------------------------------------------------------
+
+@persistence.persistent_class
+class EncoderEpilogue(torch.nn.Module):
+ def __init__(self,
+ in_channels, # Number of input channels.
+ cmap_dim, # Dimensionality of mapped conditioning label, 0 = no label.
+ z_dim, # Output Latent (Z) dimensionality.
+ resolution, # Resolution of this block.
+ img_channels, # Number of input color channels.
+ architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
+ mbstd_group_size = 4, # Group size for the minibatch standard deviation layer, None = entire minibatch.
+ mbstd_num_channels = 1, # Number of features for the minibatch standard deviation layer, 0 = disable.
+ activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
+ conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ ):
+ assert architecture in ['orig', 'skip', 'resnet']
+ super().__init__()
+ self.in_channels = in_channels
+ self.cmap_dim = cmap_dim
+ self.resolution = resolution
+ self.img_channels = img_channels
+ self.architecture = architecture
+
+ if architecture == 'skip':
+ self.fromrgb = Conv2dLayer(self.img_channels, in_channels, kernel_size=1, activation=activation)
+ self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None
+ self.conv = Conv2dLayer(in_channels + mbstd_num_channels, in_channels, kernel_size=3, activation=activation, conv_clamp=conv_clamp)
+ self.fc = FullyConnectedLayer(in_channels * (resolution ** 2), z_dim, activation=activation)
+ # self.out = FullyConnectedLayer(in_channels, z_dim)
+ self.dropout = torch.nn.Dropout(p=0.5)
+
+ def forward(self, x, cmap, force_fp32=False):
+ misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution]) # [NCHW]
+ _ = force_fp32 # unused
+ dtype = torch.float32
+ memory_format = torch.contiguous_format
+
+        # Input.
+ x = x.to(dtype=dtype, memory_format=memory_format)
+
+ # Main layers.
+ if self.mbstd is not None:
+ x = self.mbstd(x)
+ const_e = self.conv(x)
+ x = self.fc(const_e.flatten(1))
+ # x = self.out(x)
+ x = self.dropout(x)
+
+ # Conditioning.
+ if self.cmap_dim > 0:
+ misc.assert_shape(cmap, [None, self.cmap_dim])
+ x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
+
+ assert x.dtype == dtype
+ return x, const_e
+
+#----------------------------------------------------------------------------
+
+@persistence.persistent_class
+class EncoderBlock(torch.nn.Module):
+ def __init__(self,
+ in_channels, # Number of input channels, 0 = first block.
+ tmp_channels, # Number of intermediate channels.
+ out_channels, # Number of output channels.
+ resolution, # Resolution of this block.
+ img_channels, # Number of input color channels.
+ first_layer_idx, # Index of the first layer.
+ architecture = 'skip', # Architecture: 'orig', 'skip', 'resnet'.
+ activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
+ resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
+ conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ use_fp16 = False, # Use FP16 for this block?
+ fp16_channels_last = False, # Use channels-last memory format with FP16?
+ freeze_layers = 0, # Freeze-D: Number of layers to freeze.
+ ):
+ assert in_channels in [0, tmp_channels]
+ assert architecture in ['orig', 'skip', 'resnet']
+ super().__init__()
+ self.in_channels = in_channels
+ self.resolution = resolution
+        self.img_channels = img_channels + 1  # extra input channel for the inpainting mask
+ self.first_layer_idx = first_layer_idx
+ self.architecture = architecture
+ self.use_fp16 = use_fp16
+ self.channels_last = (use_fp16 and fp16_channels_last)
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
+
+ self.num_layers = 0
+ def trainable_gen():
+ while True:
+ layer_idx = self.first_layer_idx + self.num_layers
+ trainable = (layer_idx >= freeze_layers)
+ self.num_layers += 1
+ yield trainable
+ trainable_iter = trainable_gen()
+
+ if in_channels == 0:
+ self.fromrgb = Conv2dLayer(self.img_channels, tmp_channels, kernel_size=1, activation=activation,
+ trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last)
+
+ self.conv0 = Conv2dLayer(tmp_channels, tmp_channels, kernel_size=3, activation=activation,
+ trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last)
+
+ self.conv1 = Conv2dLayer(tmp_channels, out_channels, kernel_size=3, activation=activation, down=2,
+ trainable=next(trainable_iter), resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last)
+
+ if architecture == 'resnet':
+ self.skip = Conv2dLayer(tmp_channels, out_channels, kernel_size=1, bias=False, down=2,
+ trainable=next(trainable_iter), resample_filter=resample_filter, channels_last=self.channels_last)
+
+ def forward(self, x, img, force_fp32=False):
+ dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
+ memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
+
+ # Input.
+ if x is not None:
+ misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution])
+ x = x.to(dtype=dtype, memory_format=memory_format)
+
+ # FromRGB.
+ if self.in_channels == 0:
+ misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution])
+ img = img.to(dtype=dtype, memory_format=memory_format)
+ y = self.fromrgb(img)
+ x = x + y if x is not None else y
+ img = upfirdn2d.downsample2d(img, self.resample_filter) if self.architecture == 'skip' else None
+
+ # Main layers.
+ if self.architecture == 'resnet':
+ y = self.skip(x, gain=np.sqrt(0.5))
+ x = self.conv0(x)
+ feat = x.clone()
+ x = self.conv1(x, gain=np.sqrt(0.5))
+ x = y.add_(x)
+ else:
+ x = self.conv0(x)
+ feat = x.clone()
+ x = self.conv1(x)
+
+ assert x.dtype == dtype
+ return x, img, feat
+
+#----------------------------------------------------------------------------
+
+@persistence.persistent_class
+class SynthesisLayer(torch.nn.Module):
+ def __init__(self,
+ in_channels, # Number of input channels.
+ out_channels, # Number of output channels.
+ w_dim, # Intermediate latent (W) dimensionality.
+ resolution, # Resolution of this layer.
+ kernel_size = 3, # Convolution kernel size.
+ up = 1, # Integer upsampling factor.
+ use_noise = True, # Enable noise input?
+ activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
+ resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
+ conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ channels_last = False, # Use channels_last format for the weights?
+ ):
+ super().__init__()
+ self.resolution = resolution
+ self.up = up
+ self.use_noise = use_noise
+ self.activation = activation
+ self.conv_clamp = conv_clamp
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
+ self.padding = kernel_size // 2
+ self.act_gain = bias_act.activation_funcs[activation].def_gain
+
+ self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
+ memory_format = torch.channels_last if channels_last else torch.contiguous_format
+ self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
+ if use_noise:
+ self.register_buffer('noise_const', torch.randn([resolution, resolution]))
+ self.noise_strength = torch.nn.Parameter(torch.zeros([]))
+ self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
+
+ def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1):
+ assert noise_mode in ['random', 'const', 'none']
+ in_resolution = self.resolution // self.up
+ misc.assert_shape(x, [None, self.weight.shape[1], in_resolution, in_resolution])
+ styles = self.affine(w)
+
+ noise = None
+ if self.use_noise and noise_mode == 'random':
+ noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength
+ if self.use_noise and noise_mode == 'const':
+ noise = self.noise_const * self.noise_strength
+
+ flip_weight = (self.up == 1) # slightly faster
+ x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up,
+ padding=self.padding, resample_filter=self.resample_filter, flip_weight=flip_weight, fused_modconv=fused_modconv)
+
+ act_gain = self.act_gain * gain
+ act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
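+        # NOTE: bias_act is bypassed below in favor of a plain leaky ReLU, so
+        # self.bias is currently unused in this forward pass (see the
+        # commented-out line at the end of the function).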
+ x = F.leaky_relu(x, negative_slope=0.2, inplace=False)
+ if act_gain != 1:
+ x = x * act_gain
+ if act_clamp is not None:
+ x = x.clamp(-act_clamp, act_clamp)
+ # x = bias_act.bias_act(x.clone(), self.bias.to(x.dtype), act=self.activation, gain=act_gain, clamp=act_clamp)
+ return x
+
+#----------------------------------------------------------------------------
+
+@persistence.persistent_class
+class FFCSkipLayer(torch.nn.Module):
+ def __init__(self,
+ dim, # Number of input/output channels.
+ kernel_size = 3, # Convolution kernel size.
+ ratio_gin=0.75,
+ ratio_gout=0.75,
+ ):
+ super().__init__()
+ self.padding = kernel_size // 2
+
+ self.ffc_act = FFCBlock(dim=dim, kernel_size=kernel_size, activation=nn.ReLU,
+ padding=self.padding, ratio_gin=ratio_gin, ratio_gout=ratio_gout)
+
+ def forward(self, gen_ft, mask, fname=None):
+ x = self.ffc_act(gen_ft, mask, fname=fname)
+ return x
+
+#----------------------------------------------------------------------------
+
+@persistence.persistent_class
+class ToRGBLayer(torch.nn.Module):
+ def __init__(self, in_channels, out_channels, w_dim, kernel_size=1, conv_clamp=None, channels_last=False):
+ super().__init__()
+ self.conv_clamp = conv_clamp
+ self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
+ memory_format = torch.channels_last if channels_last else torch.contiguous_format
+ self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
+ self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
+ self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
+
+ def forward(self, x, w, fused_modconv=True):
+ styles = self.affine(w) * self.weight_gain
+ x = modulated_conv2d(x=x, weight=self.weight, styles=styles, demodulate=False, fused_modconv=fused_modconv)
+ x = bias_act.bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp)
+ return x
+
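+# In ToRGBLayer the runtime weight scaling is folded into the styles
+# (styles = affine(w) * weight_gain); with demodulate=False this is equivalent
+# to scaling the weights themselves, since modulated_conv2d is linear in the
+# per-channel style coefficients.
+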
+#----------------------------------------------------------------------------
+
+@persistence.persistent_class
+class SynthesisBlock(torch.nn.Module):
+ def __init__(self,
+ in_channels, # Number of input channels, 0 = first block.
+ out_channels, # Number of output channels.
+ w_dim, # Intermediate latent (W) dimensionality.
+ resolution, # Resolution of this block.
+ img_channels, # Number of output color channels.
+ is_last, # Is this the last block?
+ architecture = 'skip', # Architecture: 'orig', 'skip', 'resnet'.
+ resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
+ conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ use_fp16 = False, # Use FP16 for this block?
+ fp16_channels_last = False, # Use channels-last memory format with FP16?
+ **layer_kwargs, # Arguments for SynthesisLayer.
+ ):
+ assert architecture in ['orig', 'skip', 'resnet']
+ super().__init__()
+ self.in_channels = in_channels
+ self.w_dim = w_dim
+ self.resolution = resolution
+ self.img_channels = img_channels
+ self.is_last = is_last
+ self.architecture = architecture
+ self.use_fp16 = use_fp16
+ self.channels_last = (use_fp16 and fp16_channels_last)
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
+ self.num_conv = 0
+ self.num_torgb = 0
+        self.res_ffc = {4: 0, 8: 0, 16: 0, 32: 1, 64: 1, 128: 1, 256: 1, 512: 1}
+
+ if in_channels != 0 and resolution >= 8:
+ self.ffc_skip = nn.ModuleList()
+ for _ in range(self.res_ffc[resolution]):
+ self.ffc_skip.append(FFCSkipLayer(dim=out_channels))
+
+ if in_channels == 0:
+ self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution]))
+
+ if in_channels != 0:
+ self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim*3, resolution=resolution, up=2,
+ resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
+ self.num_conv += 1
+
+ self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim*3, resolution=resolution,
+ conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
+ self.num_conv += 1
+
+ if is_last or architecture == 'skip':
+ self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim*3,
+ conv_clamp=conv_clamp, channels_last=self.channels_last)
+ self.num_torgb += 1
+
+ if in_channels != 0 and architecture == 'resnet':
+ self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2,
+ resample_filter=resample_filter, channels_last=self.channels_last)
+
+ def forward(self, x, mask, feats, img, ws, fname=None, force_fp32=False, fused_modconv=None, **layer_kwargs):
+ # misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
+ # w_iter = iter(ws.unbind(dim=1))
+ dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
+ memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
+ if fused_modconv is None:
+ with misc.suppress_tracer_warnings(): # this value will be treated as a constant
+ fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1)
+
+ x = x.to(dtype=dtype, memory_format=memory_format)
+ x_skip = feats[self.resolution].clone().to(dtype=dtype, memory_format=memory_format)
+
+ # Main layers.
+ if self.in_channels == 0:
+ x = self.conv1(x, ws[1], fused_modconv=fused_modconv, **layer_kwargs)
+ elif self.architecture == 'resnet':
+ y = self.skip(x, gain=np.sqrt(0.5))
+ x = self.conv0(x, ws[0].clone(), fused_modconv=fused_modconv, **layer_kwargs)
+ if len(self.ffc_skip) > 0:
+                mask = F.interpolate(mask, size=x_skip.shape[2:])
+ z = x + x_skip
+ for fres in self.ffc_skip:
+ z = fres(z, mask)
+ x = x + z
+ else:
+ x = x + x_skip
+ x = self.conv1(x, ws[1].clone(), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
+ x = y.add_(x)
+ else:
+ x = self.conv0(x, ws[0].clone(), fused_modconv=fused_modconv, **layer_kwargs)
+ if len(self.ffc_skip) > 0:
+                mask = F.interpolate(mask, size=x_skip.shape[2:])
+ z = x + x_skip
+ for fres in self.ffc_skip:
+ z = fres(z, mask)
+ x = x + z
+ else:
+ x = x + x_skip
+ x = self.conv1(x, ws[1].clone(), fused_modconv=fused_modconv, **layer_kwargs)
+ # ToRGB.
+ if img is not None:
+ misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])
+ img = upfirdn2d.upsample2d(img, self.resample_filter)
+ if self.is_last or self.architecture == 'skip':
+ y = self.torgb(x, ws[2].clone(), fused_modconv=fused_modconv)
+ y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
+ img = img.add_(y) if img is not None else y
+
+ x = x.to(dtype=dtype)
+ assert x.dtype == dtype
+ assert img is None or img.dtype == torch.float32
+ return x, img
+
+#----------------------------------------------------------------------------
+
+@persistence.persistent_class
+class SynthesisForeword(torch.nn.Module):
+ def __init__(self,
+ z_dim, # Output Latent (Z) dimensionality.
+ resolution, # Resolution of this block.
+ in_channels,
+ img_channels, # Number of input color channels.
+ architecture = 'skip', # Architecture: 'orig', 'skip', 'resnet'.
+ activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
+
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.z_dim = z_dim
+ self.resolution = resolution
+ self.img_channels = img_channels
+ self.architecture = architecture
+
+ self.fc = FullyConnectedLayer(self.z_dim, (self.z_dim // 2) * 4 * 4, activation=activation)
+ self.conv = SynthesisLayer(self.in_channels, self.in_channels, w_dim=(z_dim // 2) * 3, resolution=4)
+
+ if architecture == 'skip':
+ self.torgb = ToRGBLayer(self.in_channels, self.img_channels, kernel_size=1, w_dim = (z_dim // 2) * 3)
+
+ def forward(self, x, ws, feats, img, force_fp32=False):
+ misc.assert_shape(x, [None, self.z_dim]) # [NC]
+ _ = force_fp32 # unused
+ dtype = torch.float32
+ memory_format = torch.contiguous_format
+
+ x_global = x.clone()
+        # Map the latent to a 4x4 feature grid.
+ x = self.fc(x)
+ x = x.view(-1, self.z_dim // 2, 4, 4)
+ x = x.to(dtype=dtype, memory_format=memory_format)
+
+ # Main layers.
+ x_skip = feats[4].clone()
+ x = x + x_skip
+
+        mod_vector = []
+        mod_vector.append(ws[:, 0])
+        mod_vector.append(x_global.clone())
+        mod_vector = torch.cat(mod_vector, dim=1)
+
+        x = self.conv(x, mod_vector)
+
+        mod_vector = []
+        mod_vector.append(ws[:, 2 * 2 - 3])  # index 1, matching the log2(res)*2-3 ToRGB indexing at res=4
+        mod_vector.append(x_global.clone())
+        mod_vector = torch.cat(mod_vector, dim=1)
+
+ if self.architecture == 'skip':
+ img = self.torgb(x, mod_vector)
+ img = img.to(dtype=torch.float32, memory_format=torch.contiguous_format)
+
+ assert x.dtype == dtype
+ return x, img
+
+#----------------------------------------------------------------------------
+
+@persistence.persistent_class
+class DiscriminatorBlock(torch.nn.Module):
+ def __init__(self,
+ in_channels, # Number of input channels, 0 = first block.
+ tmp_channels, # Number of intermediate channels.
+ out_channels, # Number of output channels.
+ resolution, # Resolution of this block.
+ img_channels, # Number of input color channels.
+ first_layer_idx, # Index of the first layer.
+ architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
+ activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
+ resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
+ conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ use_fp16 = False, # Use FP16 for this block?
+ fp16_channels_last = False, # Use channels-last memory format with FP16?
+ freeze_layers = 0, # Freeze-D: Number of layers to freeze.
+ ):
+ assert in_channels in [0, tmp_channels]
+ assert architecture in ['orig', 'skip', 'resnet']
+ super().__init__()
+ self.in_channels = in_channels
+ self.resolution = resolution
+        self.img_channels = img_channels + 1  # extra input channel for the inpainting mask
+ self.first_layer_idx = first_layer_idx
+ self.architecture = architecture
+ self.use_fp16 = use_fp16
+ self.channels_last = (use_fp16 and fp16_channels_last)
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
+
+ self.num_layers = 0
+ def trainable_gen():
+ while True:
+ layer_idx = self.first_layer_idx + self.num_layers
+ trainable = (layer_idx >= freeze_layers)
+ self.num_layers += 1
+ yield trainable
+ trainable_iter = trainable_gen()
+
+ if in_channels == 0 or architecture == 'skip':
+ self.fromrgb = Conv2dLayer(self.img_channels, tmp_channels, kernel_size=1, activation=activation,
+ trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last)
+
+ self.conv0 = Conv2dLayer(tmp_channels, tmp_channels, kernel_size=3, activation=activation,
+ trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last)
+
+ self.conv1 = Conv2dLayer(tmp_channels, out_channels, kernel_size=3, activation=activation, down=2,
+ trainable=next(trainable_iter), resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last)
+
+ if architecture == 'resnet':
+ self.skip = Conv2dLayer(tmp_channels, out_channels, kernel_size=1, bias=False, down=2,
+ trainable=next(trainable_iter), resample_filter=resample_filter, channels_last=self.channels_last)
+
+ def forward(self, x, img, force_fp32=False):
+ dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
+ memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
+
+ # Input.
+ if x is not None:
+ misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution])
+ x = x.to(dtype=dtype, memory_format=memory_format)
+
+ # FromRGB.
+ if self.in_channels == 0 or self.architecture == 'skip':
+ misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution])
+ img = img.to(dtype=dtype, memory_format=memory_format)
+ y = self.fromrgb(img)
+ x = x + y if x is not None else y
+ img = upfirdn2d.downsample2d(img, self.resample_filter) if self.architecture == 'skip' else None
+
+ # Main layers.
+ if self.architecture == 'resnet':
+ y = self.skip(x, gain=np.sqrt(0.5))
+ x = self.conv0(x)
+ x = self.conv1(x, gain=np.sqrt(0.5))
+ x = y.add_(x)
+ else:
+ x = self.conv0(x)
+ x = self.conv1(x)
+
+ assert x.dtype == dtype
+ return x, img
+
+#----------------------------------------------------------------------------
+
+@persistence.persistent_class
+class MinibatchStdLayer(torch.nn.Module):
+ def __init__(self, group_size, num_channels=1):
+ super().__init__()
+ self.group_size = group_size
+ self.num_channels = num_channels
+
+ def forward(self, x):
+ N, C, H, W = x.shape
+ with misc.suppress_tracer_warnings(): # as_tensor results are registered as constants
+ G = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N)) if self.group_size is not None else N
+ F = self.num_channels
+ c = C // F
+
+ y = x.reshape(G, -1, F, c, H, W) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c.
+ y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group.
+ y = y.square().mean(dim=0) # [nFcHW] Calc variance over group.
+ y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group.
+ y = y.mean(dim=[2,3,4]) # [nF] Take average over channels and pixels.
+ y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions.
+ y = y.repeat(G, 1, H, W) # [NFHW] Replicate over group and pixels.
+ x = torch.cat([x, y], dim=1) # [NCHW] Append to input as new channels.
+ return x
+
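+# Shape example (illustrative): with group_size=4 and num_channels=1, an input
+# of shape [8, 512, 4, 4] produces one stddev statistic per group of 4
+# samples, broadcast back to [8, 1, 4, 4] and concatenated to give
+# [8, 513, 4, 4].
+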
+#----------------------------------------------------------------------------
+
+@persistence.persistent_class
+class DiscriminatorEpilogue(torch.nn.Module):
+ def __init__(self,
+ in_channels, # Number of input channels.
+ cmap_dim, # Dimensionality of mapped conditioning label, 0 = no label.
+ resolution, # Resolution of this block.
+ img_channels, # Number of input color channels.
+ architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
+ mbstd_group_size = 4, # Group size for the minibatch standard deviation layer, None = entire minibatch.
+ mbstd_num_channels = 1, # Number of features for the minibatch standard deviation layer, 0 = disable.
+ activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
+ conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ ):
+ assert architecture in ['orig', 'skip', 'resnet']
+ super().__init__()
+ self.in_channels = in_channels
+ self.cmap_dim = cmap_dim
+ self.resolution = resolution
+ self.img_channels = img_channels
+ self.architecture = architecture
+
+ if architecture == 'skip':
+ self.fromrgb = Conv2dLayer(img_channels, in_channels, kernel_size=1, activation=activation)
+ self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None
+ self.conv = Conv2dLayer(in_channels + mbstd_num_channels, in_channels, kernel_size=3, activation=activation, conv_clamp=conv_clamp)
+ self.fc = FullyConnectedLayer(in_channels * (resolution ** 2), in_channels, activation=activation)
+ self.out = FullyConnectedLayer(in_channels, 1 if cmap_dim == 0 else cmap_dim)
+
+ def forward(self, x, img, cmap, force_fp32=False):
+ misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution]) # [NCHW]
+ _ = force_fp32 # unused
+ dtype = torch.float32
+ memory_format = torch.contiguous_format
+
+ # FromRGB.
+ x = x.to(dtype=dtype, memory_format=memory_format)
+ if self.architecture == 'skip':
+ misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution])
+ img = img.to(dtype=dtype, memory_format=memory_format)
+ x = x + self.fromrgb(img)
+
+ # Main layers.
+ if self.mbstd is not None:
+ x = self.mbstd(x)
+ x = self.conv(x)
+ x = self.fc(x.flatten(1))
+ x = self.out(x)
+
+ # Conditioning.
+ if self.cmap_dim > 0:
+ misc.assert_shape(cmap, [None, self.cmap_dim])
+ x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
+
+ assert x.dtype == dtype
+ return x
+
+#----------------------------------------------------------------------------
\ No newline at end of file
diff --git a/training/networks.py b/training/networks.py
new file mode 100755
index 0000000000000000000000000000000000000000..a1ff3e6071ede56a9cfa42238a9d03398b41682e
--- /dev/null
+++ b/training/networks.py
@@ -0,0 +1,321 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import numpy as np
+import torch
+from torch_utils import misc
+from torch_utils import persistence
+from training.models import *
+
+#----------------------------------------------------------------------------
+
+@persistence.persistent_class
+class MappingNetwork(torch.nn.Module):
+ def __init__(self,
+ z_dim, # Input latent (Z) dimensionality, 0 = no latent.
+ c_dim, # Conditioning label (C) dimensionality, 0 = no label.
+ w_dim, # Intermediate latent (W) dimensionality.
+ num_ws, # Number of intermediate latents to output, None = do not broadcast.
+ num_layers = 8, # Number of mapping layers.
+ embed_features = None, # Label embedding dimensionality, None = same as w_dim.
+ layer_features = None, # Number of intermediate features in the mapping layers, None = same as w_dim.
+ activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
+ lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers.
+ w_avg_beta = 0.995, # Decay for tracking the moving average of W during training, None = do not track.
+ ):
+ super().__init__()
+ self.z_dim = z_dim
+ self.c_dim = c_dim
+ self.w_dim = w_dim
+ self.num_ws = num_ws
+ self.num_layers = num_layers
+ self.w_avg_beta = w_avg_beta
+
+ if embed_features is None:
+ embed_features = w_dim
+ if c_dim == 0:
+ embed_features = 0
+ if layer_features is None:
+ layer_features = w_dim
+ features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]
+
+ if c_dim > 0:
+ self.embed = FullyConnectedLayer(c_dim, embed_features)
+ for idx in range(num_layers):
+ in_features = features_list[idx]
+ out_features = features_list[idx + 1]
+ layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
+ setattr(self, f'fc{idx}', layer)
+
+ if num_ws is not None and w_avg_beta is not None:
+ self.register_buffer('w_avg', torch.zeros([w_dim]))
+
+ def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):
+ # Embed, normalize, and concat inputs.
+ x = None
+ with torch.autograd.profiler.record_function('input'):
+ if self.z_dim > 0:
+ misc.assert_shape(z, [None, self.z_dim])
+ x = normalize_2nd_moment(z.to(torch.float32))
+ if self.c_dim > 0:
+ misc.assert_shape(c, [None, self.c_dim])
+ y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
+ x = torch.cat([x, y], dim=1) if x is not None else y
+
+ # Main layers.
+ for idx in range(self.num_layers):
+ layer = getattr(self, f'fc{idx}')
+ x = layer(x)
+
+ # Update moving average of W.
+ if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
+ with torch.autograd.profiler.record_function('update_w_avg'):
+ self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
+
+ # Broadcast.
+ if self.num_ws is not None:
+ with torch.autograd.profiler.record_function('broadcast'):
+ x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
+
+ # Apply truncation.
+ if truncation_psi != 1:
+ with torch.autograd.profiler.record_function('truncate'):
+ assert self.w_avg_beta is not None
+ if self.num_ws is None or truncation_cutoff is None:
+ x = self.w_avg.lerp(x, truncation_psi)
+ else:
+ x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
+ return x
+
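+# Truncation sketch (illustrative): with truncation_psi=0.7 and no cutoff,
+# every w is pulled toward the tracked average,
+#   w_trunc = w_avg + 0.7 * (w - w_avg),
+# which is exactly self.w_avg.lerp(x, truncation_psi) above.
+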
+#----------------------------------------------------------------------------
+
+@persistence.persistent_class
+class EncoderNetwork(torch.nn.Module):
+ def __init__(self,
+ c_dim, # Conditioning label (C) dimensionality.
+ z_dim, # Input latent (Z) dimensionality.
+ img_resolution, # Input resolution.
+ img_channels, # Number of input color channels.
+ architecture = 'orig', # Architecture: 'orig', 'skip', 'resnet'.
+ channel_base = 16384, # Overall multiplier for the number of channels.
+ channel_max = 512, # Maximum number of channels in any layer.
+ num_fp16_res = 0, # Use FP16 for the N highest resolutions.
+ conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ cmap_dim = None, # Dimensionality of mapped conditioning label, None = default.
+ block_kwargs = {}, # Arguments for DiscriminatorBlock.
+ mapping_kwargs = {}, # Arguments for MappingNetwork.
+ epilogue_kwargs = {}, # Arguments for EncoderEpilogue.
+ ):
+ super().__init__()
+ self.c_dim = c_dim
+ self.z_dim = z_dim
+ self.img_resolution = img_resolution
+ self.img_resolution_log2 = int(np.log2(img_resolution))
+ self.img_channels = img_channels
+ self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
+ channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
+ fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
+
+ if cmap_dim is None:
+ cmap_dim = channels_dict[4]
+ if c_dim == 0:
+ cmap_dim = 0
+
+ common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
+ cur_layer_idx = 0
+ for res in self.block_resolutions:
+ in_channels = channels_dict[res] if res < img_resolution else 0
+ tmp_channels = channels_dict[res]
+ out_channels = channels_dict[res // 2]
+ use_fp16 = (res >= fp16_resolution)
+ block = EncoderBlock(in_channels, tmp_channels, out_channels, resolution=res,
+ first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
+ setattr(self, f'b{res}', block)
+ cur_layer_idx += block.num_layers
+ if c_dim > 0:
+ self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
+ self.b4 = EncoderEpilogue(channels_dict[4], cmap_dim=cmap_dim, z_dim=z_dim * 2, resolution=4, **epilogue_kwargs, **common_kwargs)
+
+ def forward(self, img, c, **block_kwargs):
+ x = None
+ feats = {}
+ for res in self.block_resolutions:
+ block = getattr(self, f'b{res}')
+ x, img, feat = block(x, img, **block_kwargs)
+ feats[res] = feat
+
+ cmap = None
+ if self.c_dim > 0:
+ cmap = self.mapping(None, c)
+ x, const_e = self.b4(x, cmap)
+ feats[4] = const_e
+
+ B, _ = x.shape
+ z = torch.randn((B, self.z_dim), requires_grad=False, dtype=x.dtype, device=x.device) ## Noise for Co-Modulation
+ return x, z, feats ## 1/2, 1/4, 1/8, 1/16, 1/32, 1/64
+
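+# The encoder returns three things: a global code x of size 2*z_dim (used for
+# co-modulation), a fresh Gaussian z of size z_dim, and a dict of skip
+# features keyed by resolution (img_resolution down to 8, plus 4) that the
+# synthesis blocks re-inject at matching scales.
+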
+#----------------------------------------------------------------------------
+
+@persistence.persistent_class
+class SynthesisNetwork(torch.nn.Module):
+ def __init__(self,
+ w_dim, # Intermediate latent (W) dimensionality.
+ z_dim, # Output Latent (Z) dimensionality.
+ img_resolution, # Output image resolution.
+ img_channels, # Number of color channels.
+ channel_base = 16384, # Overall multiplier for the number of channels.
+ channel_max = 512, # Maximum number of channels in any layer.
+ num_fp16_res = 0, # Use FP16 for the N highest resolutions.
+ **block_kwargs, # Arguments for SynthesisBlock.
+ ):
+ assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
+ super().__init__()
+ self.w_dim = w_dim
+ self.img_resolution = img_resolution
+        self.img_resolution_log2 = int(np.log2(img_resolution))
+ self.img_channels = img_channels
+ self.block_resolutions = [2 ** i for i in range(3, self.img_resolution_log2 + 1)]
+ channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions}
+ fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
+
+ self.foreword = SynthesisForeword(img_channels=img_channels, in_channels=min(channel_base // 4, channel_max), z_dim=z_dim*2, resolution=4)
+
+ self.num_ws = self.img_resolution_log2 * 2 - 2
+ for res in self.block_resolutions:
+ if res // 2 in channels_dict.keys():
+ in_channels = channels_dict[res // 2] if res > 4 else 0
+ else:
+                in_channels = min(channel_base // (res // 2), channel_max)
+ out_channels = channels_dict[res]
+ use_fp16 = (res >= fp16_resolution)
+ is_last = (res == self.img_resolution)
+ block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res,
+ img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, **block_kwargs)
+ setattr(self, f'b{res}', block)
+
+ def forward(self, x_global, mask, feats, ws, fname=None, **block_kwargs):
+
+ img = None
+
+ x, img = self.foreword(x_global, ws, feats, img)
+
+ for res in self.block_resolutions:
+ block = getattr(self, f'b{res}')
+            mod_vector0 = []
+            mod_vector0.append(ws[:, int(np.log2(res)) * 2 - 5])
+            mod_vector0.append(x_global.clone())
+            mod_vector0 = torch.cat(mod_vector0, dim=1)
+
+            mod_vector1 = []
+            mod_vector1.append(ws[:, int(np.log2(res)) * 2 - 4])
+            mod_vector1.append(x_global.clone())
+            mod_vector1 = torch.cat(mod_vector1, dim=1)
+
+            mod_vector_rgb = []
+            mod_vector_rgb.append(ws[:, int(np.log2(res)) * 2 - 3])
+            mod_vector_rgb.append(x_global.clone())
+            mod_vector_rgb = torch.cat(mod_vector_rgb, dim=1)
+ x, img = block(x, mask, feats, img, (mod_vector0, mod_vector1, mod_vector_rgb), fname=fname, **block_kwargs)
+ return img
+
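+# ws indexing sketch (illustrative): num_ws = 2*log2(img_resolution) - 2, and
+# the block at resolution res consumes three slices, each concatenated with
+# the global code x_global:
+#   conv0 -> ws[:, 2*log2(res) - 5]
+#   conv1 -> ws[:, 2*log2(res) - 4]
+#   torgb -> ws[:, 2*log2(res) - 3]
+# e.g. res=8 uses indices 1, 2, 3 and res=16 uses 3, 4, 5, while the foreword
+# consumes indices 0 and 1.
+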
+#----------------------------------------------------------------------------
+
+@persistence.persistent_class
+class Generator(torch.nn.Module):
+ def __init__(self,
+ z_dim, # Input latent (Z) dimensionality.
+ c_dim, # Conditioning label (C) dimensionality.
+ w_dim, # Intermediate latent (W) dimensionality.
+ img_resolution, # Output resolution.
+ img_channels, # Number of output color channels.
+ encoder_kwargs = {}, # Arguments for EncoderNetwork.
+ mapping_kwargs = {}, # Arguments for MappingNetwork.
+ synthesis_kwargs = {}, # Arguments for SynthesisNetwork.
+ ):
+ super().__init__()
+ self.z_dim = z_dim
+ self.c_dim = c_dim
+ self.w_dim = w_dim
+ self.img_resolution = img_resolution
+ self.img_channels = img_channels
+ self.encoder = EncoderNetwork(c_dim=c_dim, z_dim=z_dim, img_resolution=img_resolution, img_channels=img_channels, **encoder_kwargs)
+ self.synthesis = SynthesisNetwork(z_dim=z_dim, w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs)
+ self.num_ws = self.synthesis.num_ws
+ self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)
+
+ def forward(self, img, c, fname=None, truncation_psi=1, truncation_cutoff=None, **synthesis_kwargs):
+ mask = img[:, -1].unsqueeze(1)
+ x_global, z, feats = self.encoder(img, c)
+ ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff)
+ img = self.synthesis(x_global, mask, feats, ws, fname=fname, **synthesis_kwargs)
+ return img
+
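+# Usage sketch (illustrative; mirrors how the training loop below calls it):
+# the generator expects the mask stacked as an extra input channel,
+#   inp = torch.cat([0.5 - mask, erased_img], dim=1)   # [N, img_channels+1, H, W]
+#   out = G(img=inp, c=label, noise_mode='const')
+# where mask is 1 inside holes and erased_img = real_img * (1 - mask).
+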
+#----------------------------------------------------------------------------
+
+@persistence.persistent_class
+class Discriminator(torch.nn.Module):
+ def __init__(self,
+ c_dim, # Conditioning label (C) dimensionality.
+ img_resolution, # Input resolution.
+ img_channels, # Number of input color channels.
+ architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
+ channel_base = 16384, # Overall multiplier for the number of channels.
+ channel_max = 512, # Maximum number of channels in any layer.
+ num_fp16_res = 0, # Use FP16 for the N highest resolutions.
+ conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ cmap_dim = None, # Dimensionality of mapped conditioning label, None = default.
+ block_kwargs = {}, # Arguments for DiscriminatorBlock.
+ mapping_kwargs = {}, # Arguments for MappingNetwork.
+ epilogue_kwargs = {}, # Arguments for DiscriminatorEpilogue.
+ ):
+ super().__init__()
+ self.c_dim = c_dim
+ self.img_resolution = img_resolution
+ self.img_resolution_log2 = int(np.log2(img_resolution))
+ self.img_channels = img_channels
+ self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
+ channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
+ fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
+
+ if cmap_dim is None:
+ cmap_dim = channels_dict[4]
+ if c_dim == 0:
+ cmap_dim = 0
+
+ common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
+ cur_layer_idx = 0
+ for res in self.block_resolutions:
+ in_channels = channels_dict[res] if res < img_resolution else 0
+ tmp_channels = channels_dict[res]
+ out_channels = channels_dict[res // 2]
+ use_fp16 = (res >= fp16_resolution)
+ block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
+ first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
+ setattr(self, f'b{res}', block)
+ cur_layer_idx += block.num_layers
+ if c_dim > 0:
+ self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
+ self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
+
+ def forward(self, img, c, **block_kwargs):
+ x = None
+ for res in self.block_resolutions:
+ block = getattr(self, f'b{res}')
+ x, img = block(x, img, **block_kwargs)
+
+ cmap = None
+ if self.c_dim > 0:
+ cmap = self.mapping(None, c)
+ x = self.b4(x, img, cmap)
+ return x
+
+#----------------------------------------------------------------------------
diff --git a/training/training_loop.py b/training/training_loop.py
new file mode 100755
index 0000000000000000000000000000000000000000..2e3da0dd9234be2fa01461092a594ae72f0000d7
--- /dev/null
+++ b/training/training_loop.py
@@ -0,0 +1,487 @@
+import os
+import time
+import copy
+import json
+import pickle
+import psutil
+import PIL.Image
+import numpy as np
+import torch
+import dnnlib
+from torch_utils import misc
+from torch_utils import training_stats
+from torch_utils.ops import conv2d_gradfix
+from torch_utils.ops import grid_sample_gradfix
+
+import legacy
+import warnings
+warnings.filterwarnings("ignore")
+from colorama import init
+from colorama import Fore, Style
+from icecream import ic
+init(autoreset=True)
+from etaprogress.progress import ProgressBar
+import sys
+import matplotlib.pyplot as plt
+from evaluate import save_gen, create_folders
+
+from metrics.evaluation.data import PrecomputedInpaintingResultsDataset
+from metrics.evaluation.evaluator import InpaintingEvaluator
+from metrics.evaluation.losses.base_loss import FIDScore
+from metrics.evaluation.utils import load_yaml
+
+#----------------------------------------------------------------------------
+
+def setup_snapshot_image_grid(training_set, random_seed=0):
+ rnd = np.random.RandomState(random_seed)
+ gw = np.clip(5120 // training_set.image_shape[2], 0, 1)
+ gh = np.clip(5120 // training_set.image_shape[1], 10, 30)
+
+ # No labels => show random subset of training samples.
+ if not training_set.has_labels:
+ all_indices = list(range(len(training_set)))
+ rnd.shuffle(all_indices)
+ grid_indices = [all_indices[i % len(all_indices)] for i in range(gw * gh)]
+
+ else:
+ # Group training samples by label.
+ label_groups = dict() # label => [idx, ...]
+ for idx in range(len(training_set)):
+ label = tuple(training_set.get_details(idx).raw_label.flat[::-1])
+ if label not in label_groups:
+ label_groups[label] = []
+ label_groups[label].append(idx)
+
+ # Reorder.
+ label_order = sorted(label_groups.keys())
+ for label in label_order:
+ rnd.shuffle(label_groups[label])
+
+ # Organize into grid.
+ grid_indices = []
+ for y in range(gh):
+ label = label_order[y % len(label_order)]
+ indices = label_groups[label]
+ grid_indices += [indices[x % len(indices)] for x in range(gw)]
+ label_groups[label] = [indices[(i + gw) % len(indices)] for i in range(len(indices))]
+
+ # Load data.
+ images, masks, labels = zip(*[training_set[i] for i in grid_indices])
+ return (gw, gh), np.stack(images), np.stack(masks), np.stack(labels)
+
+#----------------------------------------------------------------------------
+
+def save_image_grid(img, erased_img, inv_mask, pred_img, fname, drange, grid_size):
+ lo, hi = (0, 255)
+
+ model_lo, model_hi = drange
+
+ img = np.asarray(img, dtype=np.float32)
+ img = (img - lo) * (255 / (hi - lo))
+ img = np.rint(img).clip(0, 255).astype(np.uint8)
+
+ inv_mask = np.squeeze(np.stack([inv_mask]*3, axis=1))
+ inv_mask = np.asarray(inv_mask, dtype=np.float32)
+ inv_mask = np.rint(inv_mask).clip(0, 1).astype(np.uint8)
+
+ erased_img = np.asarray(erased_img, dtype=np.float32)
+ erased_img = (erased_img - lo) * (255 / (hi - lo))
+ erased_img = np.rint(erased_img).clip(0, 255).astype(np.uint8)
+
+ pred_img = np.asarray(pred_img, dtype=np.float32)
+ pred_img = (pred_img - model_lo) * (255 / (model_hi - model_lo))
+ pred_img = np.rint(pred_img).clip(0, 255).astype(np.uint8)
+
+ comp_img = img * (1 - inv_mask) + pred_img * inv_mask
+ f_img = np.concatenate((img, inv_mask * 255, erased_img, pred_img, comp_img), axis=1)
+
+    gw, gh = grid_size
+    _N, C, H, W = img.shape
+    gw *= f_img.shape[1] // C  # five panels per sample are stacked along the channel axis
+    f_img = f_img.reshape(gh, gw, C, H, W)
+    f_img = f_img.transpose(0, 3, 1, 4, 2)
+    f_img = f_img.reshape(gh * H, gw * W, C)
+
+ assert C in [1, 3]
+ if C == 1:
+ PIL.Image.fromarray(f_img[:, :, 0], 'L').save(fname + '.png')
+ if C == 3:
+ PIL.Image.fromarray(f_img, 'RGB').save(fname + '.png')
+
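+# The saved grid stacks five panels per sample along the width: original
+# image, mask, erased input, raw prediction, and the composite
+# img * (1 - inv_mask) + pred_img * inv_mask.
+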
+#----------------------------------------------------------------------------
+
+def training_loop(
+ run_dir = '.', # Output directory.
+    eval_img_data           = None,     # Evaluation image data.
+    resolution              = 256,      # Resolution of evaluation images.
+ training_set_kwargs = {}, # Options for training set.
+ data_loader_kwargs = {}, # Options for torch.utils.data.DataLoader.
+ G_kwargs = {}, # Options for generator network.
+ D_kwargs = {}, # Options for discriminator network.
+ G_opt_kwargs = {}, # Options for generator optimizer.
+ D_opt_kwargs = {}, # Options for discriminator optimizer.
+ augment_kwargs = None, # Options for augmentation pipeline. None = disable.
+ loss_kwargs = {}, # Options for loss function.
+ metrics = [], # Metrics to evaluate during training.
+ random_seed = 0, # Global random seed.
+ num_gpus = 1, # Number of GPUs participating in the training.
+ rank = 0, # Rank of the current process in [0, num_gpus[.
+ batch_size = 4, # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus.
+ batch_gpu = 4, # Number of samples processed at a time by one GPU.
+ ema_kimg = 10, # Half-life of the exponential moving average (EMA) of generator weights.
+ ema_rampup = None, # EMA ramp-up coefficient.
+ G_reg_interval = None, # How often to perform regularization for G? None = disable lazy regularization.
+ D_reg_interval = 16, # How often to perform regularization for D? None = disable lazy regularization.
+ augment_p = 0, # Initial value of augmentation probability.
+ ada_target = None, # ADA target value. None = fixed p.
+ ada_interval = 4, # How often to perform ADA adjustment?
+ ada_kimg = 500, # ADA adjustment speed, measured in how many kimg it takes for p to increase/decrease by one unit.
+ total_kimg = 25000, # Total length of the training, measured in thousands of real images.
+ kimg_per_tick = 4, # Progress snapshot interval.
+ image_snapshot_ticks = 50, # How often to save image snapshots? None = disable.
+ network_snapshot_ticks = 50, # How often to save network snapshots? None = disable.
+ resume_pkl = None, # Network pickle to resume training from.
+ cudnn_benchmark = True, # Enable torch.backends.cudnn.benchmark?
+ allow_tf32 = False, # Enable torch.backends.cuda.matmul.allow_tf32 and torch.backends.cudnn.allow_tf32?
+ abort_fn = None, # Callback function for determining whether to abort training. Must return consistent results across ranks.
+ progress_fn = None, # Callback function for updating training progress. Called for all ranks.
+):
+ # Initialize.
+ start_time = time.time()
+ device = torch.device('cuda', rank)
+ np.random.seed(random_seed * num_gpus + rank)
+ torch.manual_seed(random_seed * num_gpus + rank)
+ torch.backends.cudnn.benchmark = cudnn_benchmark # Improves training speed.
+ torch.backends.cuda.matmul.allow_tf32 = allow_tf32 # Allow PyTorch to internally use tf32 for matmul
+ torch.backends.cudnn.allow_tf32 = allow_tf32 # Allow PyTorch to internally use tf32 for convolutions
+ conv2d_gradfix.enabled = True # Improves training speed.
+ grid_sample_gradfix.enabled = True # Avoids errors with the augmentation pipe.
+
+ eval_config = load_yaml('metrics/configs/eval2_gpu.yaml')
+
+ # Load training set.
+ if rank == 0:
+ print(Fore.GREEN + 'Loading training set...')
+ training_set = dnnlib.util.construct_class_by_name(**training_set_kwargs) # subclass of training.dataset.Dataset
+ training_set_sampler = misc.InfiniteSampler(dataset=training_set, rank=rank, num_replicas=num_gpus, seed=random_seed)
+ training_loader = torch.utils.data.DataLoader(dataset=training_set, sampler=training_set_sampler, batch_size=batch_size//num_gpus, **data_loader_kwargs)
+
+ training_set_iterator = iter(training_loader)
+ if rank == 0:
+ print()
+ print(Fore.GREEN + 'Num images: ', len(training_set))
+ print(Fore.GREEN + 'Image shape:', training_set.image_shape)
+ print(Fore.GREEN + 'Label shape:', training_set.label_shape)
+ print()
+
+ # Construct networks.
+ if rank == 0:
+ print('Constructing networks...')
+ common_kwargs = dict(c_dim=training_set.label_dim, img_resolution=training_set.resolution, img_channels=training_set.num_channels)
+ G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
+    D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
+ G_ema = copy.deepcopy(G).eval()
+
+ # Resume from existing pickle.
+ if (resume_pkl is not None) and (rank == 0):
+ print(f'Resuming from "{resume_pkl}"')
+ with dnnlib.util.open_url(resume_pkl) as f:
+ resume_data = legacy.load_network_pkl(f)
+ for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]:
+ misc.copy_params_and_buffers(resume_data[name], module, require_all=False)
+
+    # Print network parameters.
+    if rank == 0:
+        netG_params = sum(p.numel() for p in G.parameters())
+        print(Fore.GREEN + "Generator Params: {} M".format(netG_params/1e6))
+
+        netD_params = sum(p.numel() for p in D.parameters())
+        print(Fore.GREEN + "Discriminator Params: {} M".format(netD_params/1e6))
+
+ # Setup augmentation.
+ if rank == 0:
+ print(Fore.YELLOW + 'Setting up augmentation...')
+ augment_pipe = None
+ ada_stats = None
+ if (augment_kwargs is not None) and (augment_p > 0 or ada_target is not None):
+ augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
+ augment_pipe.p.copy_(torch.as_tensor(augment_p))
+ if ada_target is not None:
+ ada_stats = training_stats.Collector(regex='Loss/signs/real')
+
+ # Distribute across GPUs.
+ if rank == 0:
+ print(Fore.CYAN + f'Distributing across {num_gpus} GPUs...')
+ ddp_modules = dict()
+ for name, module in [('G_encoder', G.encoder), ('G_mapping', G.mapping), ('G_synthesis', G.synthesis), ('D', D), (None, G_ema), ('augment_pipe', augment_pipe)]:
+ if (num_gpus > 1) and (module is not None) and len(list(module.parameters())) != 0:
+ module.requires_grad_(True)
+ module = torch.nn.parallel.DistributedDataParallel(module, device_ids=[device], broadcast_buffers=False, find_unused_parameters=True)
+ module.requires_grad_(False)
+ if name is not None:
+ ddp_modules[name] = module
+
+ # Setup training phases.
+ if rank == 0:
+ print('Setting up training phases...')
+ loss = dnnlib.util.construct_class_by_name(device=device, **ddp_modules, **loss_kwargs) # subclass of training.losses.loss.Loss
+ phases = []
+ for name, module, opt_kwargs, reg_interval in [('G', G, G_opt_kwargs, G_reg_interval), ('D', D, D_opt_kwargs, D_reg_interval)]:
+ if reg_interval is None:
+ opt = dnnlib.util.construct_class_by_name(params=module.parameters(), **opt_kwargs) # subclass of torch.optim.Optimizer
+ phases += [dnnlib.EasyDict(name=name+'both', module=module, opt=opt, interval=1)]
+ else: # Lazy regularization.
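+            # Lazy regularization (StyleGAN2): the reg pass only runs every
+            # reg_interval minibatches, so lr and the Adam betas are rescaled
+            # by mb_ratio = reg_interval / (reg_interval + 1) to keep the
+            # combined main+reg dynamics close to regularizing every step.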
+ mb_ratio = reg_interval / (reg_interval + 1)
+ opt_kwargs = dnnlib.EasyDict(opt_kwargs)
+ opt_kwargs.lr = opt_kwargs.lr * mb_ratio
+ opt_kwargs.betas = [beta ** mb_ratio for beta in opt_kwargs.betas]
+ opt = dnnlib.util.construct_class_by_name(module.parameters(), **opt_kwargs) # subclass of torch.optim.Optimizer
+ phases += [dnnlib.EasyDict(name=name+'main', module=module, opt=opt, interval=1)]
+ phases += [dnnlib.EasyDict(name=name+'reg', module=module, opt=opt, interval=reg_interval)]
+ for phase in phases:
+ phase.start_event = None
+ phase.end_event = None
+ if rank == 0:
+ phase.start_event = torch.cuda.Event(enable_timing=True)
+ phase.end_event = torch.cuda.Event(enable_timing=True)
+
+ # Export sample images.
+ grid_size = None
+ grid_c = None
+ if rank == 0:
+ print('Exporting sample images...')
+ grid_size, images, masks, labels = setup_snapshot_image_grid(training_set=training_set)
+ erased_images = images * (1 - masks)
+ grid_img = (torch.from_numpy(images).to(torch.float32) / 127.5 - 1).to(device)
+ grid_mask = torch.from_numpy(masks).to(torch.float32).to(device)
+ grid_erased_img = grid_img * (1 - grid_mask)
+ grid_img = grid_img.split(batch_gpu)
+ grid_mask = grid_mask.split(batch_gpu)
+ grid_erased_img = grid_erased_img.split(batch_gpu)
+ grid_c = torch.from_numpy(labels).to(torch.float32).to(device).split(batch_gpu)
+ pred_images = torch.cat([G_ema(img=torch.cat([0.5 - mask, erased_img], dim=1), c=c, noise_mode='const').cpu() for erased_img, mask, c in zip(grid_erased_img, grid_mask, grid_c)])
+ save_image_grid(images, erased_images, masks, pred_images.detach().numpy(), os.path.join(run_dir, 'run_init'), drange=[-1,1], grid_size=grid_size)
+
+ # Initialize logs.
+ if rank == 0:
+ print('Initializing logs...')
+ stats_collector = training_stats.Collector(regex='.*')
+ stats_metrics = dict()
+ stats_jsonl = None
+ stats_tfevents = None
+ if rank == 0:
+ stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'wt')
+ try:
+ import torch.utils.tensorboard as tensorboard
+ stats_tfevents = tensorboard.SummaryWriter(run_dir)
+ except ImportError as err:
+ print('Skipping tfevents export:', err)
+
+ # Train.
+ if rank == 0:
+ print(Fore.GREEN + Style.BRIGHT + f'Training for {total_kimg} kimg...')
+ print()
+ total = total_kimg * 1000
+ bar = ProgressBar(total, max_width=80)
+
+ cur_nimg = 0
+ cur_tick = 0
+ tick_start_nimg = cur_nimg
+ tick_start_time = time.time()
+ maintenance_time = tick_start_time - start_time
+ batch_idx = 0
+ if progress_fn is not None:
+ progress_fn(0, total_kimg)
+
+ while True:
+ # Fetch training data.
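+ # Images arrive as uint8 in [0, 255] and are normalized to [-1, 1]; the mask
+ # is 1 inside the hole, so multiplying by (1 - mask) erases that region.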
+ with torch.autograd.profiler.record_function('data_fetch'):
+ phase_real_imgs, phase_masks, phase_real_cs = next(training_set_iterator)
+ phase_real_img = (phase_real_imgs.to(device).to(torch.float32) / 127.5 - 1)
+ phase_inv_mask = (phase_masks.to(device).to(torch.float32))
+ phase_erased_img = phase_real_img * (1 - phase_inv_mask)
+ phase_erased_img = phase_erased_img.split(batch_gpu)
+ phase_real_img = phase_real_img.split(batch_gpu)
+ phase_inv_mask = phase_inv_mask.split(batch_gpu)
+ phase_real_c = phase_real_cs.to(device).split(batch_gpu)
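+ # Random generator labels are sampled once per iteration: one batch per
+ # training phase, pinned for the host-to-device copy, then split into
+ # per-GPU chunks of batch_gpu.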
+ all_gen_c = [training_set.get_label(np.random.randint(len(training_set))) for _ in range(len(phases) * batch_size)]
+ all_gen_c = torch.from_numpy(np.stack(all_gen_c)).pin_memory().to(device)
+ all_gen_c = [phase_gen_c.split(batch_gpu) for phase_gen_c in all_gen_c.split(batch_size)]
+
+ # Execute training phases.
+ for phase, phase_gen_c in zip(phases, all_gen_c):
+ if batch_idx % phase.interval != 0:
+ continue
+
+ # Initialize gradient accumulation.
+ if phase.start_event is not None:
+ phase.start_event.record(torch.cuda.current_stream(device))
+ phase.opt.zero_grad(set_to_none=True)
+ phase.module.requires_grad_(True)
+
+ # Accumulate gradients over multiple rounds.
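+ # sync is True only on the final accumulation round, so DDP all-reduces
+ # gradients across GPUs once per phase; gain rescales the loss to compensate
+ # for phases that only execute every phase.interval batches.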
+ for round_idx, (erased_img, real_img, mask, real_c, gen_c) in enumerate(zip(phase_erased_img, phase_real_img, phase_inv_mask, phase_real_c, phase_gen_c)):
+ sync = (round_idx == batch_size // (batch_gpu * num_gpus) - 1)
+ gain = phase.interval
+ loss.accumulate_gradients(phase=phase.name, erased_img=erased_img, real_img=real_img, mask=mask, real_c=real_c, gen_c=gen_c, sync=sync, gain=gain)
+
+ # Update weights.
+ phase.module.requires_grad_(False)
+ with torch.autograd.profiler.record_function(phase.name + '_opt'):
+ for param in phase.module.parameters():
+ if param.grad is not None:
+ misc.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
+ phase.opt.step()
+ if phase.end_event is not None:
+ phase.end_event.record(torch.cuda.current_stream(device))
+
+ # Update G_ema.
+ with torch.autograd.profiler.record_function('Gema'):
+ ema_nimg = ema_kimg * 1000
+ if ema_rampup is not None:
+ ema_nimg = min(ema_nimg, cur_nimg * ema_rampup)
+ ema_beta = 0.5 ** (batch_size / max(ema_nimg, 1e-8))
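+ # p_ema <- lerp(p, p_ema, ema_beta) keeps an exponential moving average with
+ # a half-life of ema_nimg images: e.g. ema_kimg=10 and batch_size=32 give
+ # ema_beta = 0.5 ** (32 / 10000) ≈ 0.9978.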
+ for p_ema, p in zip(G_ema.parameters(), G.parameters()):
+ p_ema.copy_(p.lerp(p_ema, ema_beta))
+ for b_ema, b in zip(G_ema.buffers(), G.buffers()):
+ b_ema.copy_(b)
+
+ # Update state.
+ cur_nimg += batch_size
+ batch_idx += 1
+
+ if rank == 0:
+ bar.numerator = cur_nimg
+ print(bar, end='\r')
+
+ # Execute ADA heuristic.
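+ # StyleGAN2-ADA heuristic: a mean sign of D's real outputs above ada_target
+ # indicates discriminator overfitting, so augment_pipe.p is nudged up (and
+ # down otherwise); the step size lets p move by 1.0 over ada_kimg thousand
+ # images, with a floor of 0.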
+ if (ada_stats is not None) and (batch_idx % ada_interval == 0):
+ ada_stats.update()
+ adjust = np.sign(ada_stats['Loss/signs/real'] - ada_target) * (batch_size * ada_interval) / (ada_kimg * 1000)
+ augment_pipe.p.copy_((augment_pipe.p + adjust).max(misc.constant(0, device=device)))
+
+ # Perform maintenance tasks once per tick.
+ done = (cur_nimg >= total_kimg * 1000)
+ if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
+ continue
+
+ # Print status line, accumulating the same information in stats_collector.
+ tick_end_time = time.time()
+ fields = []
+ fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"]
+ fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<8.1f}"]
+ fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"]
+ fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"]
+ fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"]
+ fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"]
+ fields += [f"cpumem GB {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"]
+ fields += [f"gpumem GB {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"]
+ torch.cuda.reset_peak_memory_stats()
+ fields += [f"augment {training_stats.report0('Progress/augment', float(augment_pipe.p.cpu()) if augment_pipe is not None else 0):.4f}"]
+ training_stats.report0('Timing/total_hours', (tick_end_time - start_time) / (60 * 60))
+ training_stats.report0('Timing/total_days', (tick_end_time - start_time) / (24 * 60 * 60))
+ if rank == 0:
+ print(Fore.CYAN + Style.BRIGHT + ' '.join(fields))
+
+ # Check for abort.
+ if (not done) and (abort_fn is not None) and abort_fn():
+ done = True
+ if rank == 0:
+ print()
+ print(Fore.RED + 'Aborting...')
+
+ # Save network snapshot.
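+ # Each module is deep-copied to the CPU with gradients disabled so the pickle
+ # is self-contained; with multiple GPUs, check_ddp_consistency first verifies
+ # that all replicas hold identical weights before rank 0 serializes them.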
+ snapshot_pkl = None
+ snapshot_data = None
+ if (network_snapshot_ticks is not None) and (done or cur_tick % network_snapshot_ticks == 0) and (cur_tick != 0):
+ snapshot_data = dict(training_set_kwargs=dict(training_set_kwargs))
+ for name, module in [('G', G), ('D', D), ('G_ema', G_ema), ('augment_pipe', augment_pipe)]:
+ if module is not None:
+ if num_gpus > 1:
+ misc.check_ddp_consistency(module, ignore_regex=r'.*\.w_avg')
+ module = copy.deepcopy(module).eval().requires_grad_(False).cpu()
+ snapshot_data[name] = module
+ del module # conserve memory
+ snapshot_pkl = os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl')
+ if rank == 0:
+ with open(snapshot_pkl, 'wb') as f:
+ pickle.dump(snapshot_data, f)
+
+ if (snapshot_data is not None) and metrics and (done or cur_tick % network_snapshot_ticks == 0) and (cur_tick != 0):
+ msk_type = eval_img_data.split('/')[-1]
+ if rank == 0:
+ create_folders(msk_type)
+ label = torch.zeros([1, snapshot_data['G_ema'].c_dim]).to(device)
+ save_gen(snapshot_data['G_ema'], rank, num_gpus, device, eval_img_data, resolution, label, 1, msk_type)
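+ # Each rank writes its generated completions to fid_gens/<msk_type> (save_gen
+ # receives rank and num_gpus, presumably to shard the work); rank 0 then
+ # scores them against the evaluation images with FID.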
+ if rank == 0:
+ eval_dataset = PrecomputedInpaintingResultsDataset(eval_img_data, f'fid_gens/{msk_type}', **eval_config.dataset_kwargs)
+ eval_scores = {'fid': FIDScore()} # distinct name: do not shadow the 'metrics' argument
+ evaluator = InpaintingEvaluator(eval_dataset, scores=eval_scores, area_grouping=False,
+ integral_title='lpips_fid100_f1', integral_func=None,
+ **eval_config.evaluator_kwargs)
+ results = evaluator.dist_evaluate(device, num_gpus=1, rank=0)
+ fid_score = round(results[('fid', 'total')]['mean'], 5)
+ stats_metrics.update({'fid': fid_score})
+ print(Fore.GREEN + Style.BRIGHT + f' FID Score: {fid_score}')
+
+ del snapshot_data # conserve memory
+
+ # Save image snapshot.
+ if (rank == 0) and (image_snapshot_ticks is not None) and (done or cur_tick % image_snapshot_ticks == 0):
+ pred_images = torch.cat([G_ema(img=torch.cat([0.5 - mask, erased_img], dim=1), c=c, noise_mode='const').cpu() for erased_img, mask, c in zip(grid_erased_img, grid_mask, grid_c)])
+ save_image_grid(images, erased_images, masks, pred_images.detach().numpy(), os.path.join(run_dir, f'run_{cur_nimg//1000:06d}'), drange=[-1,1], grid_size=grid_size)
+
+ # Collect statistics.
+ for phase in phases:
+ value = []
+ if (phase.start_event is not None) and (phase.end_event is not None):
+ phase.end_event.synchronize()
+ value = phase.start_event.elapsed_time(phase.end_event)
+ training_stats.report0('Timing/' + phase.name, value)
+ stats_collector.update()
+ stats_dict = stats_collector.as_dict()
+
+ if rank == 0:
+ losses = []
+ for key in stats_dict.keys():
+ if 'Loss/D' in key or 'Loss/G' in key:
+ losses += [f"{key}: {(stats_dict[key]['mean']):<.4f}"]
+ print(Fore.MAGENTA + Style.BRIGHT + ' '.join(losses))
+
+ # Update logs.
+ timestamp = time.time()
+ if stats_jsonl is not None:
+ fields = dict(stats_dict, timestamp=timestamp)
+ stats_jsonl.write(json.dumps(fields) + '\n')
+ stats_jsonl.flush()
+ if stats_tfevents is not None:
+ global_step = int(cur_nimg / 1e3)
+ walltime = timestamp - start_time
+ for name, value in stats_dict.items():
+ stats_tfevents.add_scalar(name, value.mean, global_step=global_step, walltime=walltime)
+ for name, value in stats_metrics.items():
+ stats_tfevents.add_scalar(f'Metrics/{name}', value, global_step=global_step, walltime=walltime)
+ stats_tfevents.flush()
+ if progress_fn is not None:
+ progress_fn(cur_nimg // 1000, total_kimg)
+
+ # Update state.
+ cur_tick += 1
+ tick_start_nimg = cur_nimg
+ tick_start_time = time.time()
+ maintenance_time = tick_start_time - tick_end_time
+ if rank == 0:
+ sys.stdout.flush()
+ if done:
+ break
+
+ # Done.
+ if rank == 0:
+ print()
+ print(Fore.YELLOW + 'Exiting...')
+
+#----------------------------------------------------------------------------