diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..82a820c091f959748aa23460cce8a357b14c2217 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+stylegan3-fun/docs/stylegan3-teaser-1920x1006.png filter=lfs diff=lfs merge=lfs -text
+stylegan3-fun/docs/visualizer_screen0.png filter=lfs diff=lfs merge=lfs -text
+stylegan3-fun/out/seed0002.png filter=lfs diff=lfs merge=lfs -text
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..71d335bf262d32038b81369e1a84a0c734d75760
--- /dev/null
+++ b/app.py
@@ -0,0 +1,87 @@
+import os
+import sys
+import re
+from typing import List, Optional, Tuple, Union
+import random
+
+sys.path.append('stylegan3-fun') # change this to the path where dnnlib is located
+
+import numpy as np
+import PIL.Image
+import torch
+import streamlit as st
+import dnnlib
+import legacy
+
+
+def parse_range(s: Union[str, List]) -> List[int]:
+ '''Parse a comma separated list of numbers or ranges and return a list of ints.
+
+ Example: '1,2,5-10' returns [1, 2, 5, 6, 7]
+ '''
+ if isinstance(s, list): return s
+ ranges = []
+ range_re = re.compile(r'^(\d+)-(\d+)$')
+ for p in s.split(','):
+ m = range_re.match(p)
+ if m:
+ ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
+ else:
+ ranges.append(int(p))
+ return ranges
+
+def make_transform(translate: Tuple[float,float], angle: float):
+ m = np.eye(3)
+ s = np.sin(angle/360.0*np.pi*2)
+ c = np.cos(angle/360.0*np.pi*2)
+ m[0][0] = c
+ m[0][1] = s
+ m[0][2] = translate[0]
+ m[1][0] = -s
+ m[1][1] = c
+ m[1][2] = translate[1]
+ return m
+
+def generate_image(network_pkl: str, seed: int, truncation_psi: float, noise_mode: str, translate: Tuple[float,float], rotate: float, class_idx: Optional[int]):
+ print('Loading networks from "%s"...' % network_pkl)
+ device = torch.device('cuda')
+ with open(network_pkl, 'rb') as f:
+ G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
+
+ # Labels.
+ label = torch.zeros([1, G.c_dim], device=device)
+ if G.c_dim != 0:
+ if class_idx is None:
+ raise Exception('Must specify class label when using a conditional network')
+ label[:, class_idx] = 1
+
+ z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
+
+ if hasattr(G.synthesis, 'input'):
+ m = make_transform(translate, rotate)
+ m = np.linalg.inv(m)
+ G.synthesis.input.transform.copy_(torch.from_numpy(m))
+
+ img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
+ img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+ img = PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
+ return img
+
+def main():
+ st.title('Kpop Face Generator')
+
+ st.write('Press the button below to generate a new image:')
+ if st.button('Generate'):
+ network_pkl = 'kpopGG.pkl'
+ seed = random.randint(0, 99999)
+ truncation_psi = 0.45
+ noise_mode = 'const'
+ translate = (0.0, 0.0)
+ rotate = 0.0
+ class_idx = None
+
+ image = generate_image(network_pkl, seed, truncation_psi, noise_mode, translate, rotate, class_idx)
+ st.image(image)
+
+if __name__ == "__main__":
+ main()
diff --git a/stylegan3-fun/.github/FUNDING.yml b/stylegan3-fun/.github/FUNDING.yml
new file mode 100644
index 0000000000000000000000000000000000000000..bc60ae164bd42b907539fc10e7cd36a36b889c2e
--- /dev/null
+++ b/stylegan3-fun/.github/FUNDING.yml
@@ -0,0 +1,3 @@
+# These are supported funding model platforms
+
+github: PDillis # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
diff --git a/stylegan3-fun/.github/ISSUE_TEMPLATE/bug_report.md b/stylegan3-fun/.github/ISSUE_TEMPLATE/bug_report.md
new file mode 100644
index 0000000000000000000000000000000000000000..de2966434faf1502d9e3ac93b79c7cf0b646fc7c
--- /dev/null
+++ b/stylegan3-fun/.github/ISSUE_TEMPLATE/bug_report.md
@@ -0,0 +1,35 @@
+---
+name: Bug report
+about: Create a report to help us improve
+title: ''
+labels: ''
+assignees: ''
+
+---
+
+**Describe the bug**
+A clear and concise description of what the bug is.
+
+**To Reproduce**
+Steps to reproduce the behavior:
+1. In '...' directory, run command '...'
+2. See error (copy&paste full log, including exceptions and **stacktraces**).
+
+Please copy&paste text instead of screenshots for better searchability.
+
+**Expected behavior**
+A clear and concise description of what you expected to happen.
+
+**Screenshots**
+If applicable, add screenshots to help explain your problem.
+
+**Desktop (please complete the following information):**
+ - OS: [e.g. Linux Ubuntu 20.04, Windows 10]
+ - PyTorch version (e.g., pytorch 1.9.0)
+ - CUDA toolkit version (e.g., CUDA 11.4)
+ - NVIDIA driver version
+ - GPU [e.g., Titan V, RTX 3090]
+ - Docker: did you use Docker? If yes, specify docker image URL (e.g., nvcr.io/nvidia/pytorch:21.08-py3)
+
+**Additional context**
+Add any other context about the problem here.
diff --git a/stylegan3-fun/.github/ISSUE_TEMPLATE/feature_request.md b/stylegan3-fun/.github/ISSUE_TEMPLATE/feature_request.md
new file mode 100644
index 0000000000000000000000000000000000000000..bbcbbe7d61558adde3cbfd0c7a63a67c27ed6d30
--- /dev/null
+++ b/stylegan3-fun/.github/ISSUE_TEMPLATE/feature_request.md
@@ -0,0 +1,20 @@
+---
+name: Feature request
+about: Suggest an idea for this project
+title: ''
+labels: ''
+assignees: ''
+
+---
+
+**Is your feature request related to a problem? Please describe.**
+A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
+
+**Describe the solution you'd like**
+A clear and concise description of what you want to happen.
+
+**Describe alternatives you've considered**
+A clear and concise description of any alternative solutions or features you've considered.
+
+**Additional context**
+Add any other context or screenshots about the feature request here.
diff --git a/stylegan3-fun/.gitignore b/stylegan3-fun/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..e9a26b31ed5cdb914211fbbb6c61a5a1b3a241f6
--- /dev/null
+++ b/stylegan3-fun/.gitignore
@@ -0,0 +1,173 @@
+# Byte-compiled / optimized / DLL files
+*/**/__pycache__/
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+.idea/
+
+# Conda temp
+.condatmp/
+
+# SGAN specific folders
+datasets/
+dlatents/
+out/
+training-runs/
+pretrained/
+video/
+_screenshots/
\ No newline at end of file
diff --git a/stylegan3-fun/Dockerfile b/stylegan3-fun/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..b00b17dfa711867f685be2004ee1905042a8cbd1
--- /dev/null
+++ b/stylegan3-fun/Dockerfile
@@ -0,0 +1,19 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+FROM nvcr.io/nvidia/pytorch:21.08-py3
+
+ENV PYTHONDONTWRITEBYTECODE 1
+ENV PYTHONUNBUFFERED 1
+
+RUN pip install imageio imageio-ffmpeg==0.4.4 pyspng==0.1.0
+
+WORKDIR /workspace
+
+RUN (printf '#!/bin/bash\nexec \"$@\"\n' >> /entry.sh) && chmod a+x /entry.sh
+ENTRYPOINT ["/entry.sh"]
diff --git a/stylegan3-fun/LICENSE.txt b/stylegan3-fun/LICENSE.txt
new file mode 100644
index 0000000000000000000000000000000000000000..6b5ee9bf994cc9441cb659c3527160b4ee5bcb33
--- /dev/null
+++ b/stylegan3-fun/LICENSE.txt
@@ -0,0 +1,97 @@
+Copyright (c) 2021, NVIDIA Corporation & affiliates. All rights reserved.
+
+
+NVIDIA Source Code License for StyleGAN3
+
+
+=======================================================================
+
+1. Definitions
+
+"Licensor" means any person or entity that distributes its Work.
+
+"Software" means the original work of authorship made available under
+this License.
+
+"Work" means the Software and any additions to or derivative works of
+the Software that are made available under this License.
+
+The terms "reproduce," "reproduction," "derivative works," and
+"distribution" have the meaning as provided under U.S. copyright law;
+provided, however, that for the purposes of this License, derivative
+works shall not include works that remain separable from, or merely
+link (or bind by name) to the interfaces of, the Work.
+
+Works, including the Software, are "made available" under this License
+by including in or with the Work either (a) a copyright notice
+referencing the applicability of this License to the Work, or (b) a
+copy of this License.
+
+2. License Grants
+
+ 2.1 Copyright Grant. Subject to the terms and conditions of this
+ License, each Licensor grants to you a perpetual, worldwide,
+ non-exclusive, royalty-free, copyright license to reproduce,
+ prepare derivative works of, publicly display, publicly perform,
+ sublicense and distribute its Work and any resulting derivative
+ works in any form.
+
+3. Limitations
+
+ 3.1 Redistribution. You may reproduce or distribute the Work only
+ if (a) you do so under this License, (b) you include a complete
+ copy of this License with your distribution, and (c) you retain
+ without modification any copyright, patent, trademark, or
+ attribution notices that are present in the Work.
+
+ 3.2 Derivative Works. You may specify that additional or different
+ terms apply to the use, reproduction, and distribution of your
+ derivative works of the Work ("Your Terms") only if (a) Your Terms
+ provide that the use limitation in Section 3.3 applies to your
+ derivative works, and (b) you identify the specific derivative
+ works that are subject to Your Terms. Notwithstanding Your Terms,
+ this License (including the redistribution requirements in Section
+ 3.1) will continue to apply to the Work itself.
+
+ 3.3 Use Limitation. The Work and any derivative works thereof only
+ may be used or intended for use non-commercially. Notwithstanding
+ the foregoing, NVIDIA and its affiliates may use the Work and any
+ derivative works commercially. As used herein, "non-commercially"
+ means for research or evaluation purposes only.
+
+ 3.4 Patent Claims. If you bring or threaten to bring a patent claim
+ against any Licensor (including any claim, cross-claim or
+ counterclaim in a lawsuit) to enforce any patents that you allege
+ are infringed by any Work, then your rights under this License from
+ such Licensor (including the grant in Section 2.1) will terminate
+ immediately.
+
+ 3.5 Trademarks. This License does not grant any rights to use any
+ Licensor’s or its affiliates’ names, logos, or trademarks, except
+ as necessary to reproduce the notices described in this License.
+
+ 3.6 Termination. If you violate any term of this License, then your
+ rights under this License (including the grant in Section 2.1) will
+ terminate immediately.
+
+4. Disclaimer of Warranty.
+
+THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
+MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
+NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
+THIS LICENSE.
+
+5. Limitation of Liability.
+
+EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
+THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
+SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
+INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
+OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
+(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
+LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
+COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
+THE POSSIBILITY OF SUCH DAMAGES.
+
+=======================================================================
diff --git a/stylegan3-fun/README.md b/stylegan3-fun/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8a4a2789052b6469e9fd939e0a974797810062cb
--- /dev/null
+++ b/stylegan3-fun/README.md
@@ -0,0 +1,507 @@
+# StyleGAN3-Fun
Let's have fun with StyleGAN2/ADA/3!
+
+SOTA GANs are hard to train and to explore, and StyleGAN2/ADA/3 are no different. The point of this repository is to allow
+the user to both easily train and explore the trained models without unnecessary headaches.
+
+As [before](https://github.com/PDillis/stylegan2-fun), we will build upon the official repository, which has the advantage
+of being backwards-compatible. As such, we can use our previously-trained models from StyleGAN2 and StyleGAN2-ADA. Please
+get acquainted with the official repository and its codebase, as we will be building upon it and as such, increase its
+capabilities (but hopefully not its complexity!).
+
+## Additions
+
+This repository adds/has the following changes (not yet the complete list):
+
+* ***Dataset Setup (`dataset_tool.py`)***
+ * **RGBA support**, so revert saving images to `.png` ([Issue #156](https://github.com/NVlabs/stylegan3/issues/156) by @1378dm). Training can use RGBA and images can be generated.
+ * ***TODO:*** ~~Check that training code is correct for normalizing the alpha channel~~, as well as making the
+ interpolation code work with this new format (look into [`moviepy.editor.VideoClip`](https://zulko.github.io/moviepy/getting_started/videoclips.html?highlight=mask#mask-clips))
+ * For now, interpolation videos will only be saved in RGB format, e.g., discarding the alpha channel.
+ * **`--center-crop-tall`**: add vertical black bars to the sides of each image in the dataset (rectangular images, with height > width),
+ and you wish to train a square model, in the same vein as the horizontal bars added when using `--center-crop-wide` (where width > height).
+ * This is useful when you don't want to lose information from the left and right side of the image by only using the center
+ crop (ibidem for `--center-crop-wide`, but for the top and bottom of the image)
+ * Note that each image doesn't have to be of the same size, and the added bars will only ensure you get a square image, which will then be
+ resized to the model's desired resolution (set by `--resolution`).
+ * Grayscale images in the dataset are converted to `RGB`
+ * If you want to turn this off, remove the respective line in `dataset_tool.py`, e.g., if your dataset is made of images in a folder, then the function to be used is
+ `open_image_folder` in `dataset_tool.py`, and the line to be removed is `img = img.convert('RGB')` in the `iterate_images` inner function.
+ * The dataset can be forced to be of a specific number of channels, that is, grayscale, RGB or RGBA.
+ * To use this, set `--force-channels=1` for grayscale, `--force-channels=3` for RGB, and `--force-channels=4` for RGBA.
+ * If the dataset tool encounters an error, print it along the offending image, but continue with the rest of the dataset
+ ([PR #39](https://github.com/NVlabs/stylegan3/pull/39) from [Andreas Jansson](https://github.com/andreasjansson)).
+ * For conditional models, we can use the subdirectories as the classes by adding `--subfolders-as-labels`. This will
+ generate the `dataset.json` file automatically as done by @pbaylies [here](https://github.com/pbaylies/stylegan2-ada/blob/a8f0b1c891312631f870c94f996bcd65e0b8aeef/dataset_tool.py#L772)
+ * Additionally, in the `--source` folder, we will save a `class_labels.txt` file, to further know which classes correspond to each subdirectory.
+
+* ***Training***
+ * Add `--cfg=stylegan2-ext`, which uses @aydao's extended modifications for handling large and diverse datasets.
+ * A good explanation is found in Gwern's blog [here](https://gwern.net/face#extended-stylegan2-danbooru2019-aydao)
+ * If you wish to fine-tune from @aydao's Anime model, use `--cfg=stylegan2-ext --resume=anime512` when running `train.py`
+ * Note: ***This is an extremely experimental configuration!*** The `.pkl` files will be ~1.1Gb each and training will slow down
+ significantly. Use at your own risk!
+ * `--blur-percent`: Blur both real and generated images before passing them to the Discriminator.
+ * The blur (`blur_init_sigma=10.0`) will completely fade after the selected percentage of the training is completed (using a linear ramp).
+ * Another experimental feature, should help with datasets that have a lot of variation, and you wish the model to slowly
+ learn to generate the objects and then its details.
+ * `--mirrory`: Added vertical mirroring for doubling the dataset size (quadrupling if `--mirror` is used; make sure your dataset has either or both
+ of these symmetries in order for it to make sense to use them)
+ * `--gamma`: If no R1 regularization is provided, the heuristic formula from [StyleGAN](https://github.com/NVlabs/stylegan2) will be used.
+ * Specifically, we will set `gamma=0.0002 * resolution ** 2 / batch_size`
+ * `--aug`: ***TODO:*** add [Deceive-D/APA](https://github.com/EndlessSora/DeceiveD) as an option.
+ * `--augpipe`: Now available to use is [StyleGAN2-ADA's](https://github.com/NVlabs/stylegan2-ada-pytorch) full list of augpipe, i.e., individual augmentations (`blit`, `geom`, `color`, `filter`, `noise`, `cutout`) or their combinations (`bg`, `bgc`, `bgcf`, `bgcfn`, `bgcfnc`).
+ * `--img-snap`: Set when to save snapshot images, so now it's independent of when the model is saved (e.g., save image snapshots more often to know how the model is training without saving the model itself, to save space).
+ * `--snap-res`: The resolution of the snapshots, depending on how many images you wish to see per snapshot. Available resolutions: `1080p`, `4k`, and `8k`.
+ * `--resume-kimg`: Starting number of `kimg`, useful when continuing training a previous run
+ * `--outdir`: Automatically set as `training-runs`, so no need to set beforehand (in general this is true throughout the repository)
+ * `--metrics`: Now set by default to `None`, so there's no need to worry about this one
+ * `--freezeD`: Renamed `--freezed` for better readability
+ * `--freezeM`: Freeze the first layers of the Mapping Network Gm (`G.mapping`)
+ * `--freezeE`: Freeze the embedding layer of the Generator (for class-conditional models)
+ * `--freezeG`: ***TODO:*** Freeze the first layers of the Synthesis Network (`G.synthesis`; less cost to transfer learn, focus on high layers?)
+ * `--resume`: All available pre-trained models from NVIDIA (and more) can be used with a simple dictionary, depending on the `--cfg` used.
+ For example, if you wish to use StyleGAN3's `config-r`, then set `--cfg=stylegan3-r`. In addition, if you wish to transfer learn from FFHQU at 1024 resolution, set `--resume=ffhqu1024`.
+ * The full list of currently available models to transfer learn from (or synthesize new images with) is the following (***TODO:*** add small description of each model,
+ so the user can better know which to use for their particular use-case; proper citation to original authors as well):
+
+
+ StyleGAN2 models
+
+ 1. Majority, if not all, are `config-f`: set `--cfg=stylegan2`
+ * `ffhq256`
+ * `ffhqu256`
+ * `ffhq512`
+ * `ffhq1024`
+ * `ffhqu1024`
+ * `celebahq256`
+ * `lsundog256`
+ * `afhqcat512`
+ * `afhqdog512`
+ * `afhqwild512`
+ * `afhq512`
+ * `brecahad512`
+ * `cifar10` (conditional, 10 classes)
+ * `metfaces1024`
+ * `metfacesu1024`
+ * `lsuncar512` (config-f)
+ * `lsuncat256` (config-f)
+ * `lsunchurch256` (config-f)
+ * `lsunhorse256` (config-f)
+ * `minecraft1024` (thanks to @jeffheaton)
+ * `imagenet512` (thanks to @shawwn)
+ * `wikiart1024-C` (conditional, 167 classes; thanks to @pbaylies)
+ * `wikiart1024-U` (thanks to @pbaylies)
+ * `maps1024` (thanks to @tjukanov)
+ * `fursona512` (thanks to @arfafax)
+ * `mlpony512` (thanks to @arfafax)
+ * `lhq1024` (thanks to @justinpinkney)
+ * `afhqcat256` (Deceive-D/APA models)
+ * `anime256` (Deceive-D/APA models)
+ * `cub256` (Deceive-D/APA models)
+ * `sddogs1024` (Self-Distilled StyleGAN models)
+ * `sdelephant512` (Self-Distilled StyleGAN models)
+ * `sdhorses512` (Self-Distilled StyleGAN models)
+ * `sdbicycles256` (Self-Distilled StyleGAN models)
+ * `sdlions512` (Self-Distilled StyleGAN models)
+ * `sdgiraffes512` (Self-Distilled StyleGAN models)
+ * `sdparrots512` (Self-Distilled StyleGAN models)
+ 2. Extended StyleGAN2 config from @aydao: set `--cfg=stylegan2-ext`
+ * `anime512` (thanks to @aydao; writeup by @gwern: https://gwern.net/Faces#extended-stylegan2-danbooru2019-aydao)
+
+
+
+ StyleGAN3 models
+
+ 1. `config-t`: set `--cfg=stylegan3-t`
+ * `afhq512`
+ * `ffhqu256`
+ * `ffhq1024`
+ * `ffhqu1024`
+ * `metfaces1024`
+ * `metfacesu1024`
+ * `landscapes256` (thanks to @justinpinkney)
+ * `wikiart1024` (thanks to @justinpinkney)
+ * `mechfuture256` (thanks to @edstoica; 29 kimg tick)
+ * `vivflowers256` (thanks to @edstoica; 68 kimg tick)
+ * `alienglass256` (thanks to @edstoica; 38 kimg tick)
+ * `scificity256` (thanks to @edstoica; 210 kimg tick)
+ * `scifiship256` (thanks to @edstoica; 168 kimg tick)
+ 2. `config-r`: set `--cfg=stylegan3-r`
+ * `afhq512`
+ * `ffhq1024`
+ * `ffhqu1024`
+ * `ffhqu256`
+ * `metfaces1024`
+ * `metfacesu1024`
+
+
+ * The main sources of these pretrained models are both the [official NVIDIA repository](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/research/models/stylegan3),
+ as well as other community repositories, such as [Justin Pinkney](https://github.com/justinpinkney) 's [Awesome Pretrained StyleGAN2](https://github.com/justinpinkney/awesome-pretrained-stylegan2)
+ and [Awesome Pretrained StyleGAN3](https://github.com/justinpinkney/awesome-pretrained-stylegan3), [Deceive-D/APA](https://github.com/EndlessSora/DeceiveD),
+ [Self-Distilled StyleGAN/Internet Photos](https://github.com/self-distilled-stylegan/self-distilled-internet-photos), and [edstoica](https://github.com/edstoica) 's
+ [Wombo Dream](https://www.wombo.art/) [-based models](https://github.com/edstoica/lucid_stylegan3_datasets_models). Others can be found around the net and are properly credited in this repository,
+ so long as they can be easily downloaded with [`dnnlib.util.open_url`](https://github.com/PDillis/stylegan3-fun/blob/4ce9d6f7601641ba1e2906ed97f2739a63fb96e2/dnnlib/util.py#L396).
+
+* ***Interpolation videos***
+ * [Random interpolation](https://youtu.be/DNfocO1IOUE)
+ * [Generate images/interpolations with the internal representations of the model](https://nvlabs-fi-cdn.nvidia.com/_web/stylegan3/videos/video_8_internal_activations.mp4)
+ * Usage: Add `--layer=` to specify which layer to use for interpolation.
+ * If you don't know the names of the layers available for your model, add the flag `--available-layers` and the
+ layers will be printed to the console, along their names, number of channels, and sizes.
+ * Use one of `--grayscale` or `--rgb` to specify whether to save the images as grayscale or RGB during the interpolation.
+ * For `--rgb`, three consecutive channels (starting at `--starting-channel=0`) will be used to create the RGB image. For `--grayscale`, only the first channel will be used.
+ * Style-mixing
+ * [Sightseeding](https://twitter.com/PDillis/status/1270341433401249793?s=20&t=yLueNkagqsidZFqZ2jNPAw) (jumpiness has been fixed)
+ * [Circular interpolation](https://youtu.be/4nktYGjSVHg)
+ * [Visual-reactive interpolation](https://youtu.be/KoEAkPnE-zA) (Beta)
+ * Audiovisual-reactive interpolation (TODO)
+ * ***TODO:*** Give support to RGBA models!
+* ***Projection into the latent space***
+ * [Project into $\mathcal{W}+$](https://arxiv.org/abs/1904.03189)
+ * Additional losses to use for better projection (e.g., using VGG16 or [CLIP](https://github.com/openai/CLIP))
+* ***[Discriminator Synthesis](https://arxiv.org/abs/2111.02175)*** (official code)
+ * Generate a static image (`python discriminator_synthesis.py dream --help`) or a [video](https://youtu.be/hEJKWL2VQTE) with a feedback loop (`python discriminator_synthesis.py dream-zoom --help`,
+ `python discriminator_synthesis.py channel-zoom --help`, or `python discriminator_synthesis.py interp --help`)
+ * Start from a random image (`random` for noise or `perlin` for 2D fractal Perlin noise, using
+ [Mathieu Duchesneau's implementation](https://github.com/duchesneaumathieu/pyperlin)) or from an existing one
+* ***Expansion on GUI/`visualizer.py`***
+ * Added the rest of the affine transformations
+ * Added widget for class-conditional models (***TODO:*** mix classes with continuous values for `cls`!)
+* ***General model and code additions***
+ * [Multi-modal truncation trick](https://arxiv.org/abs/2202.12211): find the different clusters in your model and use the closest one to your dlatent, in order to increase the fidelity
+ * Usage: Run `python multimodal_truncation.py get-centroids --network=` to use default values; for extra options, run `python multimodal_truncation.py get-centroids --help`
+ * StyleGAN3: anchor the latent space for easier to follow interpolations (thanks to [Rivers Have Wings](https://github.com/crowsonkb) and [nshepperd](https://github.com/nshepperd)).
+ * Use CPU instead of GPU if desired (not recommended, but perfectly fine for generating images, whenever the custom CUDA kernels fail to compile).
+ * Add missing dependencies and channels so that the [`conda`](https://docs.conda.io/en/latest/) environment is correctly setup in Windows
+ (PR's [#111](https://github.com/NVlabs/stylegan3/pull/111)/[#125](https://github.com/NVlabs/stylegan3/pull/125) and [#80](https://github.com/NVlabs/stylegan3/pull/80) /[#143](https://github.com/NVlabs/stylegan3/pull/143) from the base, respectively)
+ * Use [StyleGAN-NADA](https://github.com/rinongal/StyleGAN-nada) models with any part of the code (Issue [#9](https://github.com/PDillis/stylegan3-fun/issues/9))
+ * The StyleGAN-NADA models must first be converted via [Vadim Epstein](https://github.com/eps696) 's conversion code found [here](https://github.com/eps696/stylegan2ada#tweaking-models).
+ * Add PR [#173](https://github.com/NVlabs/stylegan3/pull/173) for adding the last remaining unknown kwarg for using StyleGAN2 models using TF 1.15.
+* ***TODO*** list (this is a long one with more to come, so any help is appreciated):
+ * Add `--device={cuda, ref}` option to pass to each of the custom operations in order to (theoretically) be able to use AMD GPUs, as explained in
+ @l4rz's post [here](https://github.com/l4rz/practical-aspects-of-stylegan2-training#hardware)
+ * Define a [custom Generator](https://github.com/dvschultz/stylegan2-ada-pytorch/blob/59e05bb115c1c7d0de56be0523754076c2b7ee83/legacy.py#L131) in `legacy.py` to modify the output size
+ * Related: the [multi-latent](https://github.com/dvschultz/stylegan2-ada-pytorch/blob/main/training/stylegan2_multi.py), i.e., the one from [@eps696](https://github.com/eps696/stylegan2)
+ * Add [Top-K training](https://arxiv.org/abs/2002.06224) as done [here](https://github.com/dvschultz/stylegan2-ada/blob/8f4ab24f494483542d31bf10f4fdb0005dc62739/train.py#L272) and [here](https://github.com/dvschultz/stylegan2-ada-pytorch/blob/59e05bb115c1c7d0de56be0523754076c2b7ee83/training/loss.py#L79)
+ * Add panorama/SinGAN/feature interpolation from [StyleGAN of All Trades](https://arxiv.org/abs/2111.01619)
+ * [PTI](https://github.com/danielroich/PTI) for better inversion
+ * [Better sampling](https://arxiv.org/abs/2110.08009)
+ * [Add cross-model interpolation](https://twitter.com/arfafax/status/1297681537337446400?s=20&t=xspnTaLFTvd7y4krg8tkxA)
+ * Blend different models (average checkpoints, copy weights, create initial network), as in @aydao's [StyleGAN2-Surgery](https://github.com/aydao/stylegan2-surgery)
+ * Add multi-crop for the dataset creation, as used in [Earth View](https://github.com/PDillis/earthview#multi-crop---data_augmentpy).
+ * Make it easy to download pretrained models from Drive, otherwise a lot of models can't be used with `dnnlib.util.open_url`
+ (e.g., [StyleGAN-Human](https://github.com/stylegan-human/StyleGAN-Human) models)
+ * Finish documentation for better user experience, add videos/images, code samples, visuals...
+ * Add [Ensembling Off-the-shelf Models for GAN Training](https://arxiv.org/abs/2112.09130) and [Any-resolution Training for High-resolution Image Synthesis](https://chail.github.io/anyres-gan/)
+
+## Notebooks (Coming soon!)
+
+## Sponsors ![GitHub Sponsor](https://img.shields.io/github/sponsors/PDillis?label=Sponsor&logo=GitHub)
+
+This repository has been sponsored by:
+
+[isosceles](https://www.jasonfletcher.info/vjloops/)
+
+Thank you so much!
+
+If you wish to sponsor me, click here: [![](https://img.shields.io/static/v1?label=Sponsor&message=%E2%9D%A4&logo=GitHub&color=%23fe8e86)](https://github.com/sponsors/PDillis)
+
+
+---
+
+## Alias-Free Generative Adversarial Networks (StyleGAN3)
Official PyTorch implementation of the NeurIPS 2021 paper
+
+![Teaser image](./docs/stylegan3-teaser-1920x1006.png)
+
+**Alias-Free Generative Adversarial Networks**
+Tero Karras, Miika Aittala, Samuli Laine, Erik Härkönen, Janne Hellsten, Jaakko Lehtinen, Timo Aila
+https://nvlabs.github.io/stylegan3
+
+Abstract: *We observe that despite their hierarchical convolutional nature, the synthesis process of typical generative adversarial networks depends on absolute pixel coordinates in an unhealthy manner. This manifests itself as, e.g., detail appearing to be glued to image coordinates instead of the surfaces of depicted objects. We trace the root cause to careless signal processing that causes aliasing in the generator network. Interpreting all signals in the network as continuous, we derive generally applicable, small architectural changes that guarantee that unwanted information cannot leak into the hierarchical synthesis process. The resulting networks match the FID of StyleGAN2 but differ dramatically in their internal representations, and they are fully equivariant to translation and rotation even at subpixel scales. Our results pave the way for generative models better suited for video and animation.*
+
+For business inquiries, please visit our website and submit the form: [NVIDIA Research Licensing](https://www.nvidia.com/en-us/research/inquiries/)
+
+## Release notes
+
+This repository is an updated version of [stylegan2-ada-pytorch](https://github.com/NVlabs/stylegan2-ada-pytorch), with several new features:
+- Alias-free generator architecture and training configurations (`stylegan3-t`, `stylegan3-r`).
+- Tools for interactive visualization (`visualizer.py`), spectral analysis (`avg_spectra.py`), and video generation (`gen_video.py`).
+- Equivariance metrics (`eqt50k_int`, `eqt50k_frac`, `eqr50k`).
+- General improvements: reduced memory usage, slightly faster training, bug fixes.
+
+Compatibility:
+- Compatible with old network pickles created using [stylegan2-ada](https://github.com/NVlabs/stylegan2-ada) and [stylegan2-ada-pytorch](https://github.com/NVlabs/stylegan2-ada-pytorch). (Note: running old StyleGAN2 models on StyleGAN3 code will produce the same results as running them on stylegan2-ada/stylegan2-ada-pytorch. To benefit from the StyleGAN3 architecture, you need to retrain.)
+- Supports old StyleGAN2 training configurations, including ADA and transfer learning. See [Training configurations](./docs/configs.md) for details.
+- Improved compatibility with Ampere GPUs and newer versions of PyTorch, CuDNN, etc.
+
+## Synthetic image detection
+
+While new generator approaches enable new media synthesis capabilities, they may also present a new challenge for AI forensics algorithms for detection and attribution of synthetic media. In collaboration with digital forensic researchers participating in DARPA's SemaFor program, we curated a synthetic image dataset that allowed the researchers to test and validate the performance of their image detectors in advance of the public release. Please see [here](https://github.com/NVlabs/stylegan3-detector) for more details.
+
+## Additional material
+
+- [Result videos](https://nvlabs-fi-cdn.nvidia.com/stylegan3/videos/)
+- [Curated example images](https://nvlabs-fi-cdn.nvidia.com/stylegan3/images/)
+- [StyleGAN3 pre-trained models](https://ngc.nvidia.com/catalog/models/nvidia:research:stylegan3) for config T (translation equiv.) and config R (translation and rotation equiv.)
+ > Access individual networks via `https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/`, where `` is one of:
+ > `stylegan3-t-ffhq-1024x1024.pkl`, `stylegan3-t-ffhqu-1024x1024.pkl`, `stylegan3-t-ffhqu-256x256.pkl`
+ > `stylegan3-r-ffhq-1024x1024.pkl`, `stylegan3-r-ffhqu-1024x1024.pkl`, `stylegan3-r-ffhqu-256x256.pkl`
+ > `stylegan3-t-metfaces-1024x1024.pkl`, `stylegan3-t-metfacesu-1024x1024.pkl`
+ > `stylegan3-r-metfaces-1024x1024.pkl`, `stylegan3-r-metfacesu-1024x1024.pkl`
+ > `stylegan3-t-afhqv2-512x512.pkl`
+ > `stylegan3-r-afhqv2-512x512.pkl`
+- [StyleGAN2 pre-trained models](https://ngc.nvidia.com/catalog/models/nvidia:research:stylegan2) compatible with this codebase
+ > Access individual networks via `https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/`, where `` is one of:
+ > `stylegan2-ffhq-1024x1024.pkl`, `stylegan2-ffhq-512x512.pkl`, `stylegan2-ffhq-256x256.pkl`
+ > `stylegan2-ffhqu-1024x1024.pkl`, `stylegan2-ffhqu-256x256.pkl`
+ > `stylegan2-metfaces-1024x1024.pkl`, `stylegan2-metfacesu-1024x1024.pkl`
+ > `stylegan2-afhqv2-512x512.pkl`
+ > `stylegan2-afhqcat-512x512.pkl`, `stylegan2-afhqdog-512x512.pkl`, `stylegan2-afhqwild-512x512.pkl`
+ > `stylegan2-brecahad-512x512.pkl`, `stylegan2-cifar10-32x32.pkl`
+ > `stylegan2-celebahq-256x256.pkl`, `stylegan2-lsundog-256x256.pkl`
+
+## Requirements
+
+* Linux and Windows are supported, but we recommend Linux for performance and compatibility reasons.
+* 1–8 high-end NVIDIA GPUs with at least 12 GB of memory. We have done all testing and development using Tesla V100 and A100 GPUs.
+* 64-bit Python 3.8 and PyTorch 1.9.0 (or later). See https://pytorch.org for PyTorch install instructions.
+* CUDA toolkit 11.1 or later. (Why is a separate CUDA toolkit installation required? See [Troubleshooting](./docs/troubleshooting.md#why-is-cuda-toolkit-installation-necessary)).
+* GCC 7 or later (Linux) or Visual Studio (Windows) compilers. Recommended GCC version depends on CUDA version, see for example [CUDA 11.4 system requirements](https://docs.nvidia.com/cuda/archive/11.4.1/cuda-installation-guide-linux/index.html#system-requirements).
+* Python libraries: see [environment.yml](./environment.yml) for exact library dependencies. You can use the following commands with Miniconda3 to create and activate your StyleGAN3 Python environment:
+ - `conda env create -f environment.yml`
+ - `conda activate stylegan3`
+* Docker users:
+ - Ensure you have correctly installed the [NVIDIA container runtime](https://docs.docker.com/config/containers/resource_constraints/#gpu).
+ - Use the [provided Dockerfile](./Dockerfile) to build an image with the required library dependencies.
+
+The code relies heavily on custom PyTorch extensions that are compiled on the fly using NVCC. On Windows, the compilation requires Microsoft Visual Studio. We recommend installing [Visual Studio Community Edition](https://visualstudio.microsoft.com/vs/) and adding it into `PATH` using `"C:\Program Files (x86)\Microsoft Visual Studio\\Community\VC\Auxiliary\Build\vcvars64.bat"`.
+
+See [Troubleshooting](./docs/troubleshooting.md) for help on common installation and run-time problems.
+
+## Getting started
+
+Pre-trained networks are stored as `*.pkl` files that can be referenced using local filenames or URLs:
+
+```bash
+# Generate an image using pre-trained AFHQv2 model ("Ours" in Figure 1, left).
+python gen_images.py --outdir=out --trunc=1 --seeds=2 \
+ --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl
+
+# Render a 4x2 grid of interpolations for seeds 0 through 31.
+python gen_video.py --output=lerp.mp4 --trunc=1 --seeds=0-31 --grid=4x2 \
+ --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl
+```
+
+Outputs from the above commands are placed under `out/*.png`, controlled by `--outdir`. Downloaded network pickles are cached under `$HOME/.cache/dnnlib`, which can be overridden by setting the `DNNLIB_CACHE_DIR` environment variable. The default PyTorch extension build directory is `$HOME/.cache/torch_extensions`, which can be overridden by setting `TORCH_EXTENSIONS_DIR`.
+
+**Docker**: You can run the above curated image example using Docker as follows:
+
+```bash
+# Build the stylegan3:latest image
+docker build --tag stylegan3 .
+
+# Run the gen_images.py script using Docker:
+docker run --gpus all -it --rm --user $(id -u):$(id -g) \
+ -v `pwd`:/scratch --workdir /scratch -e HOME=/scratch \
+ stylegan3 \
+ python gen_images.py --outdir=out --trunc=1 --seeds=2 \
+ --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl
+```
+
+Note: The Docker image requires NVIDIA driver release `r470` or later.
+
+The `docker run` invocation may look daunting, so let's unpack its contents here:
+
+- `--gpus all -it --rm --user $(id -u):$(id -g)`: with all GPUs enabled, run an interactive session with current user's UID/GID to avoid Docker writing files as root.
+- ``-v `pwd`:/scratch --workdir /scratch``: mount current running dir (e.g., the top of this git repo on your host machine) to `/scratch` in the container and use that as the current working dir.
+- `-e HOME=/scratch`: let PyTorch and StyleGAN3 code know where to cache temporary files such as pre-trained models and custom PyTorch extension build results. Note: if you want more fine-grained control, you can instead set `TORCH_EXTENSIONS_DIR` (for custom extensions build dir) and `DNNLIB_CACHE_DIR` (for pre-trained model download cache). You want these cache dirs to reside on persistent volumes so that their contents are retained across multiple `docker run` invocations.
+
+## Interactive visualization
+
+This release contains an interactive model visualization tool that can be used to explore various characteristics of a trained model. To start it, run:
+
+```bash
+python visualizer.py
+```
+
+
+
+## Using networks from Python
+
+You can use pre-trained networks in your own Python code as follows:
+
+```python
+with open('ffhq.pkl', 'rb') as f:
+ G = pickle.load(f)['G_ema'].cuda() # torch.nn.Module
+z = torch.randn([1, G.z_dim]).cuda() # latent codes
+c = None # class labels (not used in this example)
+img = G(z, c) # NCHW, float32, dynamic range [-1, +1], no truncation
+```
+
+The above code requires `torch_utils` and `dnnlib` to be accessible via `PYTHONPATH`. It does not need source code for the networks themselves — their class definitions are loaded from the pickle via `torch_utils.persistence`.
+
+The pickle contains three networks. `'G'` and `'D'` are instantaneous snapshots taken during training, and `'G_ema'` represents a moving average of the generator weights over several training steps. The networks are regular instances of `torch.nn.Module`, with all of their parameters and buffers placed on the CPU at import and gradient computation disabled by default.
+
+The generator consists of two submodules, `G.mapping` and `G.synthesis`, that can be executed separately. They also support various additional options:
+
+```python
+w = G.mapping(z, c, truncation_psi=0.5, truncation_cutoff=8)
+img = G.synthesis(w, noise_mode='const', force_fp32=True)
+```
+
+Please refer to [`gen_images.py`](./gen_images.py) for complete code example.
+
+## Preparing datasets
+
+Datasets are stored as uncompressed ZIP archives containing uncompressed PNG files and a metadata file `dataset.json` for labels. Custom datasets can be created from a folder containing images; see [`python dataset_tool.py --help`](./docs/dataset-tool-help.txt) for more information. Alternatively, the folder can also be used directly as a dataset, without running it through `dataset_tool.py` first, but doing so may lead to suboptimal performance.
+
+**FFHQ**: Download the [Flickr-Faces-HQ dataset](https://github.com/NVlabs/ffhq-dataset) as 1024x1024 images and create a zip archive using `dataset_tool.py`:
+
+```bash
+# Original 1024x1024 resolution.
+python dataset_tool.py --source=/tmp/images1024x1024 --dest=~/datasets/ffhq-1024x1024.zip
+
+# Scaled down 256x256 resolution.
+python dataset_tool.py --source=/tmp/images1024x1024 --dest=~/datasets/ffhq-256x256.zip \
+ --resolution=256x256
+```
+
+See the [FFHQ README](https://github.com/NVlabs/ffhq-dataset) for information on how to obtain the unaligned FFHQ dataset images. Use the same steps as above to create a ZIP archive for training and validation.
+
+**MetFaces**: Download the [MetFaces dataset](https://github.com/NVlabs/metfaces-dataset) and create a ZIP archive:
+
+```bash
+python dataset_tool.py --source=~/downloads/metfaces/images --dest=~/datasets/metfaces-1024x1024.zip
+```
+
+See the [MetFaces README](https://github.com/NVlabs/metfaces-dataset) for information on how to obtain the unaligned MetFaces dataset images. Use the same steps as above to create a ZIP archive for training and validation.
+
+**AFHQv2**: Download the [AFHQv2 dataset](https://github.com/clovaai/stargan-v2/blob/master/README.md#animal-faces-hq-dataset-afhq) and create a ZIP archive:
+
+```bash
+python dataset_tool.py --source=~/downloads/afhqv2 --dest=~/datasets/afhqv2-512x512.zip
+```
+
+Note that the above command creates a single combined dataset using all images of all three classes (cats, dogs, and wild animals), matching the setup used in the StyleGAN3 paper. Alternatively, you can also create a separate dataset for each class:
+
+```bash
+python dataset_tool.py --source=~/downloads/afhqv2/train/cat --dest=~/datasets/afhqv2cat-512x512.zip
+python dataset_tool.py --source=~/downloads/afhqv2/train/dog --dest=~/datasets/afhqv2dog-512x512.zip
+python dataset_tool.py --source=~/downloads/afhqv2/train/wild --dest=~/datasets/afhqv2wild-512x512.zip
+```
+
+## Training
+
+You can train new networks using `train.py`. For example:
+
+```bash
+# Train StyleGAN3-T for AFHQv2 using 8 GPUs.
+python train.py --outdir=~/training-runs --cfg=stylegan3-t --data=~/datasets/afhqv2-512x512.zip \
+ --gpus=8 --batch=32 --gamma=8.2 --mirror=1
+
+# Fine-tune StyleGAN3-R for MetFaces-U using 1 GPU, starting from the pre-trained FFHQ-U pickle.
+python train.py --outdir=~/training-runs --cfg=stylegan3-r --data=~/datasets/metfacesu-1024x1024.zip \
+ --gpus=8 --batch=32 --gamma=6.6 --mirror=1 --kimg=5000 --snap=5 \
+ --resume=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhqu-1024x1024.pkl
+
+# Train StyleGAN2 for FFHQ at 1024x1024 resolution using 8 GPUs.
+python train.py --outdir=~/training-runs --cfg=stylegan2 --data=~/datasets/ffhq-1024x1024.zip \
+ --gpus=8 --batch=32 --gamma=10 --mirror=1 --aug=noaug
+```
+
+Note that the result quality and training time depend heavily on the exact set of options. The most important ones (`--gpus`, `--batch`, and `--gamma`) must be specified explicitly, and they should be selected with care. See [`python train.py --help`](./docs/train-help.txt) for the full list of options and [Training configurations](./docs/configs.md) for general guidelines & recommendations, along with the expected training speed & memory usage in different scenarios.
+
+The results of each training run are saved to a newly created directory, for example `~/training-runs/00000-stylegan3-t-afhqv2-512x512-gpus8-batch32-gamma8.2`. The training loop exports network pickles (`network-snapshot-.pkl`) and random image grids (`fakes.png`) at regular intervals (controlled by `--snap`). For each exported pickle, it evaluates FID (controlled by `--metrics`) and logs the result in `metric-fid50k_full.jsonl`. It also records various statistics in `training_stats.jsonl`, as well as `*.tfevents` if TensorBoard is installed.
+
+## Quality metrics
+
+By default, `train.py` automatically computes FID for each network pickle exported during training. We recommend inspecting `metric-fid50k_full.jsonl` (or TensorBoard) at regular intervals to monitor the training progress. When desired, the automatic computation can be disabled with `--metrics=none` to speed up the training slightly.
+
+Additional quality metrics can also be computed after the training:
+
+```bash
+# Previous training run: look up options automatically, save result to JSONL file.
+python calc_metrics.py --metrics=eqt50k_int,eqr50k \
+ --network=~/training-runs/00000-stylegan3-r-mydataset/network-snapshot-000000.pkl
+
+# Pre-trained network pickle: specify dataset explicitly, print result to stdout.
+python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq-1024x1024.zip --mirror=1 \
+ --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhq-1024x1024.pkl
+```
+
+The first example looks up the training configuration and performs the same operation as if `--metrics=eqt50k_int,eqr50k` had been specified during training. The second example downloads a pre-trained network pickle, in which case the values of `--data` and `--mirror` must be specified explicitly.
+
+Note that the metrics can be quite expensive to compute (up to 1h), and many of them have an additional one-off cost for each new dataset (up to 30min). Also note that the evaluation is done using a different random seed each time, so the results will vary if the same metric is computed multiple times.
+
+Recommended metrics:
+* `fid50k_full`: Fréchet inception distance[1] against the full dataset.
+* `kid50k_full`: Kernel inception distance[2] against the full dataset.
+* `pr50k3_full`: Precision and recall[3] againt the full dataset.
+* `ppl2_wend`: Perceptual path length[4] in W, endpoints, full image.
+* `eqt50k_int`: Equivariance[5] w.r.t. integer translation (EQ-T).
+* `eqt50k_frac`: Equivariance w.r.t. fractional translation (EQ-Tfrac).
+* `eqr50k`: Equivariance w.r.t. rotation (EQ-R).
+
+Legacy metrics:
+* `fid50k`: Fréchet inception distance against 50k real images.
+* `kid50k`: Kernel inception distance against 50k real images.
+* `pr50k3`: Precision and recall against 50k real images.
+* `is50k`: Inception score[6] for CIFAR-10.
+
+References:
+1. [GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium](https://arxiv.org/abs/1706.08500), Heusel et al. 2017
+2. [Demystifying MMD GANs](https://arxiv.org/abs/1801.01401), Bińkowski et al. 2018
+3. [Improved Precision and Recall Metric for Assessing Generative Models](https://arxiv.org/abs/1904.06991), Kynkäänniemi et al. 2019
+4. [A Style-Based Generator Architecture for Generative Adversarial Networks](https://arxiv.org/abs/1812.04948), Karras et al. 2018
+5. [Alias-Free Generative Adversarial Networks](https://nvlabs.github.io/stylegan3), Karras et al. 2021
+6. [Improved Techniques for Training GANs](https://arxiv.org/abs/1606.03498), Salimans et al. 2016
+
+## Spectral analysis
+
+The easiest way to inspect the spectral properties of a given generator is to use the built-in FFT mode in `visualizer.py`. In addition, you can visualize average 2D power spectra (Appendix A, Figure 15) as follows:
+
+```bash
+# Calculate dataset mean and std, needed in subsequent steps.
+python avg_spectra.py stats --source=~/datasets/ffhq-1024x1024.zip
+
+# Calculate average spectrum for the training data.
+python avg_spectra.py calc --source=~/datasets/ffhq-1024x1024.zip \
+ --dest=tmp/training-data.npz --mean=112.684 --std=69.509
+
+# Calculate average spectrum for a pre-trained generator.
+python avg_spectra.py calc \
+ --source=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhq-1024x1024.pkl \
+ --dest=tmp/stylegan3-r.npz --mean=112.684 --std=69.509 --num=70000
+
+# Display results.
+python avg_spectra.py heatmap tmp/training-data.npz
+python avg_spectra.py heatmap tmp/stylegan3-r.npz
+python avg_spectra.py slices tmp/training-data.npz tmp/stylegan3-r.npz
+```
+
+
+
+## License
+
+Copyright © 2021, NVIDIA Corporation & affiliates. All rights reserved.
+
+This work is made available under the [Nvidia Source Code License](https://github.com/NVlabs/stylegan3/blob/main/LICENSE.txt).
+
+## Citation
+
+```
+@inproceedings{Karras2021,
+ author = {Tero Karras and Miika Aittala and Samuli Laine and Erik H\"ark\"onen and Janne Hellsten and Jaakko Lehtinen and Timo Aila},
+ title = {Alias-Free Generative Adversarial Networks},
+ booktitle = {Proc. NeurIPS},
+ year = {2021}
+}
+```
+
+## Development
+
+This is a research reference implementation and is treated as a one-time code drop. As such, we do not accept outside code contributions in the form of pull requests.
+
+## Acknowledgements
+
+We thank David Luebke, Ming-Yu Liu, Koki Nagano, Tuomas Kynkäänniemi, and Timo Viitanen for reviewing early drafts and helpful suggestions. Frédo Durand for early discussions. Tero Kuosmanen for maintaining our compute infrastructure. AFHQ authors for an updated version of their dataset. Getty Images for the training images in the Beaches dataset. We did not receive external funding or additional revenues for this project.
diff --git a/stylegan3-fun/__pycache__/legacy.cpython-311.pyc b/stylegan3-fun/__pycache__/legacy.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..02d5137b9d3c86b5e908ef76cf0d2cd54e0f656a
Binary files /dev/null and b/stylegan3-fun/__pycache__/legacy.cpython-311.pyc differ
diff --git a/stylegan3-fun/__pycache__/legacy.cpython-38.pyc b/stylegan3-fun/__pycache__/legacy.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e1f75fe0e1ac3d1668e6786e3daac1b052037f97
Binary files /dev/null and b/stylegan3-fun/__pycache__/legacy.cpython-38.pyc differ
diff --git a/stylegan3-fun/__pycache__/legacy.cpython-39.pyc b/stylegan3-fun/__pycache__/legacy.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..533685985f92f0c4783d784faae03c01cf996f2a
Binary files /dev/null and b/stylegan3-fun/__pycache__/legacy.cpython-39.pyc differ
diff --git a/stylegan3-fun/avg_spectra.py b/stylegan3-fun/avg_spectra.py
new file mode 100644
index 0000000000000000000000000000000000000000..a53a7b3b7be5345477e82b154eb535f75da59b78
--- /dev/null
+++ b/stylegan3-fun/avg_spectra.py
@@ -0,0 +1,276 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Compare average power spectra between real and generated images,
+or between multiple generators."""
+
+import os
+import numpy as np
+import torch
+import torch.fft
+import scipy.ndimage
+import matplotlib.pyplot as plt
+import click
+import tqdm
+import dnnlib
+
+import legacy
+from training import dataset
+
+#----------------------------------------------------------------------------
+# Setup an iterator for streaming images, in uint8 NCHW format, based on the
+# respective command line options.
+
+def stream_source_images(source, num, seed, device, data_loader_kwargs=None): # => num_images, image_size, image_iter
+ ext = source.split('.')[-1].lower()
+ if data_loader_kwargs is None:
+ data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)
+
+ if ext == 'pkl':
+ if num is None:
+ raise click.ClickException('--num is required when --source points to network pickle')
+ with dnnlib.util.open_url(source) as f:
+ G = legacy.load_network_pkl(f)['G_ema'].to(device)
+ def generate_image(seed):
+ rnd = np.random.RandomState(seed)
+ z = torch.from_numpy(rnd.randn(1, G.z_dim)).to(device)
+ c = torch.zeros([1, G.c_dim], device=device)
+ if G.c_dim > 0:
+ c[:, rnd.randint(G.c_dim)] = 1
+ return (G(z=z, c=c) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+ _ = generate_image(seed) # warm up
+ image_iter = (generate_image(seed + idx) for idx in range(num))
+ return num, G.img_resolution, image_iter
+
+ elif ext == 'zip' or os.path.isdir(source):
+ dataset_obj = dataset.ImageFolderDataset(path=source, max_size=num, random_seed=seed)
+ if num is not None and num != len(dataset_obj):
+ raise click.ClickException(f'--source contains fewer than {num} images')
+ data_loader = torch.utils.data.DataLoader(dataset_obj, batch_size=1, **data_loader_kwargs)
+ image_iter = (image.to(device) for image, _label in data_loader)
+ return len(dataset_obj), dataset_obj.resolution, image_iter
+
+ else:
+ raise click.ClickException('--source must point to network pickle, dataset zip, or directory')
+
+#----------------------------------------------------------------------------
+# Load average power spectrum from the specified .npz file and construct
+# the corresponding heatmap for visualization.
+
+def construct_heatmap(npz_file, smooth):
+ npz_data = np.load(npz_file)
+ spectrum = npz_data['spectrum']
+ image_size = npz_data['image_size']
+ hmap = np.log10(spectrum) * 10 # dB
+ hmap = np.fft.fftshift(hmap)
+ hmap = np.concatenate([hmap, hmap[:1, :]], axis=0)
+ hmap = np.concatenate([hmap, hmap[:, :1]], axis=1)
+ if smooth > 0:
+ sigma = spectrum.shape[0] / image_size * smooth
+ hmap = scipy.ndimage.gaussian_filter(hmap, sigma=sigma, mode='nearest')
+ return hmap, image_size
+
+#----------------------------------------------------------------------------
+
+@click.group()
+def main():
+ """Compare average power spectra between real and generated images,
+ or between multiple generators.
+
+ Example:
+
+ \b
+ # Calculate dataset mean and std, needed in subsequent steps.
+ python avg_spectra.py stats --source=~/datasets/ffhq-1024x1024.zip
+
+ \b
+ # Calculate average spectrum for the training data.
+ python avg_spectra.py calc --source=~/datasets/ffhq-1024x1024.zip \\
+ --dest=tmp/training-data.npz --mean=112.684 --std=69.509
+
+ \b
+ # Calculate average spectrum for a pre-trained generator.
+ python avg_spectra.py calc \\
+ --source=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhq-1024x1024.pkl \\
+ --dest=tmp/stylegan3-r.npz --mean=112.684 --std=69.509 --num=70000
+
+ \b
+ # Display results.
+ python avg_spectra.py heatmap tmp/training-data.npz
+ python avg_spectra.py heatmap tmp/stylegan3-r.npz
+ python avg_spectra.py slices tmp/training-data.npz tmp/stylegan3-r.npz
+
+ \b
+ # Save as PNG.
+ python avg_spectra.py heatmap tmp/training-data.npz --save=tmp/training-data.png --dpi=300
+ python avg_spectra.py heatmap tmp/stylegan3-r.npz --save=tmp/stylegan3-r.png --dpi=300
+ python avg_spectra.py slices tmp/training-data.npz tmp/stylegan3-r.npz --save=tmp/slices.png --dpi=300
+ """
+
+#----------------------------------------------------------------------------
+
+@main.command()
+@click.option('--source', help='Network pkl, dataset zip, or directory', metavar='[PKL|ZIP|DIR]', required=True)
+@click.option('--num', help='Number of images to process [default: all]', metavar='INT', type=click.IntRange(min=1))
+@click.option('--seed', help='Random seed for selecting the images', metavar='INT', type=click.IntRange(min=0), default=0, show_default=True)
+def stats(source, num, seed, device=torch.device('cuda')):
+ """Calculate dataset mean and standard deviation needed by 'calc'."""
+ torch.multiprocessing.set_start_method('spawn')
+ num_images, _image_size, image_iter = stream_source_images(source=source, num=num, seed=seed, device=device)
+
+ # Accumulate moments.
+ moments = torch.zeros([3], dtype=torch.float64, device=device)
+ for image in tqdm.tqdm(image_iter, total=num_images):
+ image = image.to(torch.float64)
+ moments += torch.stack([torch.ones_like(image).sum(), image.sum(), image.square().sum()])
+ moments = moments / moments[0]
+
+ # Compute mean and standard deviation.
+ mean = moments[1]
+ std = (moments[2] - moments[1].square()).sqrt()
+ print(f'--mean={mean:g} --std={std:g}')
+
+#----------------------------------------------------------------------------
+
+@main.command()
+@click.option('--source', help='Network pkl, dataset zip, or directory', metavar='[PKL|ZIP|DIR]', required=True)
+@click.option('--dest', help='Where to store the result', metavar='NPZ', required=True)
+@click.option('--mean', help='Dataset mean for whitening', metavar='FLOAT', type=float, required=True)
+@click.option('--std', help='Dataset standard deviation for whitening', metavar='FLOAT', type=click.FloatRange(min=0), required=True)
+@click.option('--num', help='Number of images to process [default: all]', metavar='INT', type=click.IntRange(min=1))
+@click.option('--seed', help='Random seed for selecting the images', metavar='INT', type=click.IntRange(min=0), default=0, show_default=True)
+@click.option('--beta', help='Shape parameter for the Kaiser window', metavar='FLOAT', type=click.FloatRange(min=0), default=8, show_default=True)
+@click.option('--interp', help='Frequency-domain interpolation factor', metavar='INT', type=click.IntRange(min=1), default=4, show_default=True)
+def calc(source, dest, mean, std, num, seed, beta, interp, device=torch.device('cuda')):
+ """Calculate average power spectrum and store it in .npz file."""
+ torch.multiprocessing.set_start_method('spawn')
+ num_images, image_size, image_iter = stream_source_images(source=source, num=num, seed=seed, device=device)
+ spectrum_size = image_size * interp
+ padding = spectrum_size - image_size
+
+ # Setup window function.
+ window = torch.kaiser_window(image_size, periodic=False, beta=beta, device=device)
+ window *= window.square().sum().rsqrt()
+ window = window.ger(window).unsqueeze(0).unsqueeze(1)
+
+ # Accumulate power spectrum.
+ spectrum = torch.zeros([spectrum_size, spectrum_size], dtype=torch.float64, device=device)
+ for image in tqdm.tqdm(image_iter, total=num_images):
+ image = (image.to(torch.float64) - mean) / std
+ image = torch.nn.functional.pad(image * window, [0, padding, 0, padding])
+ spectrum += torch.fft.fftn(image, dim=[2,3]).abs().square().mean(dim=[0,1])
+ spectrum /= num_images
+
+ # Save result.
+ if os.path.dirname(dest):
+ os.makedirs(os.path.dirname(dest), exist_ok=True)
+ np.savez(dest, spectrum=spectrum.cpu().numpy(), image_size=image_size)
+
+#----------------------------------------------------------------------------
+
+@main.command()
+@click.argument('npz-file', nargs=1)
+@click.option('--save', help='Save the plot and exit', metavar='[PNG|PDF|...]')
+@click.option('--dpi', help='Figure resolution', metavar='FLOAT', type=click.FloatRange(min=1), default=100, show_default=True)
+@click.option('--smooth', help='Amount of smoothing', metavar='FLOAT', type=click.FloatRange(min=0), default=1.25, show_default=True)
+def heatmap(npz_file, save, smooth, dpi):
+ """Visualize 2D heatmap based on the given .npz file."""
+ hmap, image_size = construct_heatmap(npz_file=npz_file, smooth=smooth)
+
+ # Setup plot.
+ plt.figure(figsize=[6, 4.8], dpi=dpi, tight_layout=True)
+ freqs = np.linspace(-0.5, 0.5, num=hmap.shape[0], endpoint=True) * image_size
+ ticks = np.linspace(freqs[0], freqs[-1], num=5, endpoint=True)
+ levels = np.linspace(-40, 20, num=13, endpoint=True)
+
+ # Draw heatmap.
+ plt.xlim(ticks[0], ticks[-1])
+ plt.ylim(ticks[0], ticks[-1])
+ plt.xticks(ticks)
+ plt.yticks(ticks)
+ plt.contourf(freqs, freqs, hmap, levels=levels, extend='both', cmap='Blues')
+ plt.gca().set_aspect('equal')
+ plt.colorbar(ticks=levels)
+ plt.contour(freqs, freqs, hmap, levels=levels, extend='both', linestyles='solid', linewidths=1, colors='midnightblue', alpha=0.2)
+
+ # Display or save.
+ if save is None:
+ plt.show()
+ else:
+ if os.path.dirname(save):
+ os.makedirs(os.path.dirname(save), exist_ok=True)
+ plt.savefig(save)
+
+#----------------------------------------------------------------------------
+
+@main.command()
+@click.argument('npz-files', nargs=-1, required=True)
+@click.option('--save', help='Save the plot and exit', metavar='[PNG|PDF|...]')
+@click.option('--dpi', help='Figure resolution', metavar='FLOAT', type=click.FloatRange(min=1), default=100, show_default=True)
+@click.option('--smooth', help='Amount of smoothing', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True)
+def slices(npz_files, save, dpi, smooth):
+ """Visualize 1D slices based on the given .npz files."""
+ cases = [dnnlib.EasyDict(npz_file=npz_file) for npz_file in npz_files]
+ for c in cases:
+ c.hmap, c.image_size = construct_heatmap(npz_file=c.npz_file, smooth=smooth)
+ c.label = os.path.splitext(os.path.basename(c.npz_file))[0]
+
+ # Check consistency.
+ image_size = cases[0].image_size
+ hmap_size = cases[0].hmap.shape[0]
+ if any(c.image_size != image_size or c.hmap.shape[0] != hmap_size for c in cases):
+ raise click.ClickException('All .npz must have the same resolution')
+
+ # Setup plot.
+ plt.figure(figsize=[12, 4.6], dpi=dpi, tight_layout=True)
+ hmap_center = hmap_size // 2
+ hmap_range = np.arange(hmap_center, hmap_size)
+ freqs0 = np.linspace(0, image_size / 2, num=(hmap_size // 2 + 1), endpoint=True)
+ freqs45 = np.linspace(0, image_size / np.sqrt(2), num=(hmap_size // 2 + 1), endpoint=True)
+ xticks0 = np.linspace(freqs0[0], freqs0[-1], num=9, endpoint=True)
+ xticks45 = np.round(np.linspace(freqs45[0], freqs45[-1], num=9, endpoint=True))
+ yticks = np.linspace(-50, 30, num=9, endpoint=True)
+
+ # Draw 0 degree slice.
+ plt.subplot(1, 2, 1)
+ plt.title('0\u00b0 slice')
+ plt.xlim(xticks0[0], xticks0[-1])
+ plt.ylim(yticks[0], yticks[-1])
+ plt.xticks(xticks0)
+ plt.yticks(yticks)
+ for c in cases:
+ plt.plot(freqs0, c.hmap[hmap_center, hmap_range], label=c.label)
+ plt.grid()
+ plt.legend(loc='upper right')
+
+ # Draw 45 degree slice.
+ plt.subplot(1, 2, 2)
+ plt.title('45\u00b0 slice')
+ plt.xlim(xticks45[0], xticks45[-1])
+ plt.ylim(yticks[0], yticks[-1])
+ plt.xticks(xticks45)
+ plt.yticks(yticks)
+ for c in cases:
+ plt.plot(freqs45, c.hmap[hmap_range, hmap_range], label=c.label)
+ plt.grid()
+ plt.legend(loc='upper right')
+
+ # Display or save.
+ if save is None:
+ plt.show()
+ else:
+ if os.path.dirname(save):
+ os.makedirs(os.path.dirname(save), exist_ok=True)
+ plt.savefig(save)
+
+#----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+ main() # pylint: disable=no-value-for-parameter
+
+#----------------------------------------------------------------------------
diff --git a/stylegan3-fun/calc_metrics.py b/stylegan3-fun/calc_metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..52e1e9404dbaa8901352fc74475e6052e103f760
--- /dev/null
+++ b/stylegan3-fun/calc_metrics.py
@@ -0,0 +1,188 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Calculate quality metrics for previous training run or pretrained network pickle."""
+
+import os
+import click
+import json
+import tempfile
+import copy
+import torch
+
+import dnnlib
+import legacy
+from metrics import metric_main
+from metrics import metric_utils
+from torch_utils import training_stats
+from torch_utils import custom_ops
+from torch_utils import misc
+from torch_utils.ops import conv2d_gradfix
+
+#----------------------------------------------------------------------------
+
+def subprocess_fn(rank, args, temp_dir):
+ dnnlib.util.Logger(should_flush=True)
+
+ # Init torch.distributed.
+ if args.num_gpus > 1:
+ init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
+ if os.name == 'nt':
+ init_method = 'file:///' + init_file.replace('\\', '/')
+ torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus)
+ else:
+ init_method = f'file://{init_file}'
+ torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus)
+
+ # Init torch_utils.
+ sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None
+ training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
+ if rank != 0 or not args.verbose:
+ custom_ops.verbosity = 'none'
+
+ # Configure torch.
+ device = torch.device('cuda', rank)
+ torch.backends.cuda.matmul.allow_tf32 = False
+ torch.backends.cudnn.allow_tf32 = False
+ conv2d_gradfix.enabled = True
+
+ # Print network summary.
+ G = copy.deepcopy(args.G).eval().requires_grad_(False).to(device)
+ if rank == 0 and args.verbose:
+ z = torch.empty([1, G.z_dim], device=device)
+ c = torch.empty([1, G.c_dim], device=device)
+ misc.print_module_summary(G, [z, c])
+
+ # Calculate each metric.
+ for metric in args.metrics:
+ if rank == 0 and args.verbose:
+ print(f'Calculating {metric}...')
+ progress = metric_utils.ProgressMonitor(verbose=args.verbose)
+ result_dict = metric_main.calc_metric(metric=metric, G=G, dataset_kwargs=args.dataset_kwargs,
+ num_gpus=args.num_gpus, rank=rank, device=device, progress=progress)
+ if rank == 0:
+ metric_main.report_metric(result_dict, run_dir=args.run_dir, snapshot_pkl=args.network_pkl)
+ if rank == 0 and args.verbose:
+ print()
+
+ # Done.
+ if rank == 0 and args.verbose:
+ print('Exiting...')
+
+#----------------------------------------------------------------------------
+
+def parse_comma_separated_list(s):
+ if isinstance(s, list):
+ return s
+ if s is None or s.lower() == 'none' or s == '':
+ return []
+ return s.split(',')
+
+#----------------------------------------------------------------------------
+
+@click.command()
+@click.pass_context
+@click.option('network_pkl', '--network', help='Network pickle filename or URL', metavar='PATH', required=True)
+@click.option('--metrics', help='Quality metrics', metavar='[NAME|A,B,C|none]', type=parse_comma_separated_list, default='fid50k_full', show_default=True)
+@click.option('--data', help='Dataset to evaluate against [default: look up]', metavar='[ZIP|DIR]')
+@click.option('--mirror', help='Enable dataset x-flips [default: look up]', type=bool, metavar='BOOL')
+@click.option('--gpus', help='Number of GPUs to use', type=int, default=1, metavar='INT', show_default=True)
+@click.option('--verbose', help='Print optional information', type=bool, default=True, metavar='BOOL', show_default=True)
+
+def calc_metrics(ctx, network_pkl, metrics, data, mirror, gpus, verbose):
+ """Calculate quality metrics for previous training run or pretrained network pickle.
+
+ Examples:
+
+ \b
+ # Previous training run: look up options automatically, save result to JSONL file.
+ python calc_metrics.py --metrics=eqt50k_int,eqr50k \\
+ --network=~/training-runs/00000-stylegan3-r-mydataset/network-snapshot-000000.pkl
+
+ \b
+ # Pre-trained network pickle: specify dataset explicitly, print result to stdout.
+ python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq-1024x1024.zip --mirror=1 \\
+ --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhq-1024x1024.pkl
+
+ \b
+ Recommended metrics:
+ fid50k_full Frechet inception distance against the full dataset.
+ kid50k_full Kernel inception distance against the full dataset.
+ pr50k3_full Precision and recall againt the full dataset.
+ ppl2_wend Perceptual path length in W, endpoints, full image.
+ eqt50k_int Equivariance w.r.t. integer translation (EQ-T).
+ eqt50k_frac Equivariance w.r.t. fractional translation (EQ-T_frac).
+ eqr50k Equivariance w.r.t. rotation (EQ-R).
+
+ \b
+ Legacy metrics:
+ fid50k Frechet inception distance against 50k real images.
+ kid50k Kernel inception distance against 50k real images.
+ pr50k3 Precision and recall against 50k real images.
+ is50k Inception score for CIFAR-10.
+ """
+ dnnlib.util.Logger(should_flush=True)
+
+ # Validate arguments.
+ args = dnnlib.EasyDict(metrics=metrics, num_gpus=gpus, network_pkl=network_pkl, verbose=verbose)
+ if not all(metric_main.is_valid_metric(metric) for metric in args.metrics):
+ ctx.fail('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
+ if not args.num_gpus >= 1:
+ ctx.fail('--gpus must be at least 1')
+
+ # Load network.
+ if not dnnlib.util.is_url(network_pkl, allow_file_urls=True) and not os.path.isfile(network_pkl):
+ ctx.fail('--network must point to a file or URL')
+ if args.verbose:
+ print(f'Loading network from "{network_pkl}"...')
+ with dnnlib.util.open_url(network_pkl, verbose=args.verbose) as f:
+ network_dict = legacy.load_network_pkl(f)
+ args.G = network_dict['G_ema'] # subclass of torch.nn.Module
+
+ # Initialize dataset options.
+ if data is not None:
+ args.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data)
+ elif network_dict['training_set_kwargs'] is not None:
+ args.dataset_kwargs = dnnlib.EasyDict(network_dict['training_set_kwargs'])
+ else:
+ ctx.fail('Could not look up dataset options; please specify --data')
+
+ # Finalize dataset options.
+ args.dataset_kwargs.resolution = args.G.img_resolution
+ args.dataset_kwargs.use_labels = (args.G.c_dim != 0)
+ if mirror is not None:
+ args.dataset_kwargs.xflip = mirror
+
+ # Print dataset options.
+ if args.verbose:
+ print('Dataset options:')
+ print(json.dumps(args.dataset_kwargs, indent=2))
+
+ # Locate run dir.
+ args.run_dir = None
+ if os.path.isfile(network_pkl):
+ pkl_dir = os.path.dirname(network_pkl)
+ if os.path.isfile(os.path.join(pkl_dir, 'training_options.json')):
+ args.run_dir = pkl_dir
+
+ # Launch processes.
+ if args.verbose:
+ print('Launching processes...')
+ torch.multiprocessing.set_start_method('spawn')
+ with tempfile.TemporaryDirectory() as temp_dir:
+ if args.num_gpus == 1:
+ subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
+ else:
+ torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus)
+
+#----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+ calc_metrics() # pylint: disable=no-value-for-parameter
+
+#----------------------------------------------------------------------------
diff --git a/stylegan3-fun/dataset_tool.py b/stylegan3-fun/dataset_tool.py
new file mode 100644
index 0000000000000000000000000000000000000000..3474c5271162d212872ab1298af2f242bf0227b7
--- /dev/null
+++ b/stylegan3-fun/dataset_tool.py
@@ -0,0 +1,547 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Tool for creating ZIP/PNG based datasets."""
+
+import functools
+import gzip
+import io
+import json
+import os
+import pickle
+import re
+import sys
+import tarfile
+import zipfile
+from pathlib import Path
+from typing import Callable, Optional, Tuple, Union
+
+import click
+import numpy as np
+import PIL.Image
+from tqdm import tqdm
+from torch_utils import gen_utils
+
+
+# ----------------------------------------------------------------------------
+
+
+def error(msg):
+ print('Error: ' + msg)
+ sys.exit(1)
+
+
+# ----------------------------------------------------------------------------
+
+
+def parse_tuple(s: str) -> Tuple[int, int]:
+ """Parse a 'M,N' or 'MxN' integer tuple.
+
+ Example:
+ '4x2' returns (4,2)
+ '0,1' returns (0,1)
+ """
+ m = re.match(r'^(\d+)[x,](\d+)$', s)
+ if m:
+ return int(m.group(1)), int(m.group(2))
+ raise ValueError(f'cannot parse tuple {s}')
+
+
+# ----------------------------------------------------------------------------
+
+
+def maybe_min(a: int, b: Optional[int]) -> int:
+ if b is not None:
+ return min(a, b)
+ return a
+
+
+# ----------------------------------------------------------------------------
+
+
+def file_ext(name: Union[str, Path]) -> str:
+ return str(name).split('.')[-1]
+
+
+# ----------------------------------------------------------------------------
+
+
+def is_image_ext(fname: Union[str, Path]) -> bool:
+ ext = file_ext(fname).lower()
+ return f'.{ext}' in PIL.Image.EXTENSION # type: ignore
+
+
+# ----------------------------------------------------------------------------
+
+
+def open_image_folder(source_dir, force_channels: int = None, *, max_images: Optional[int], subfolders_as_labels: Optional[bool] = False):
+ input_images = [str(f) for f in sorted(Path(source_dir).rglob('*')) if is_image_ext(f) and os.path.isfile(f)]
+
+ # Load labels.
+ labels = {}
+ meta_fname = os.path.join(source_dir, 'dataset.json')
+ if os.path.isfile(meta_fname) and not subfolders_as_labels:
+ # The `dataset.json` file exists and will be used (whether empty or not)
+ with open(meta_fname, 'r') as file:
+ labels = json.load(file)['labels']
+ if labels is not None:
+ labels = {x[0]: x[1] for x in labels}
+ else:
+ labels = {}
+ elif subfolders_as_labels:
+ # Use the folders in the directory as the labels themselves
+ # Get the subfolder names from the input_images names
+ labels = {os.path.relpath(fname, source_dir).replace('\\', '/'): os.path.basename(os.path.dirname(fname)) for fname in input_images}
+ # Change folder name (value) to a number (from 0 to n-1)
+ label_names = list(set(labels.values()))
+ label_names.sort()
+ labels = {fname: label_names.index(label) for fname, label in labels.items()}
+ print(f'Conditional dataset has {len(label_names)} labels! Saving to `class_labels.txt` in the source directory...')
+ with open(os.path.join(source_dir, 'class_labels.txt'), 'w') as f:
+ # Write, one per line, the index and the label name
+ for i, label in enumerate(label_names):
+ f.write(f'{i}: {label}\n')
+
+ max_idx = maybe_min(len(input_images), max_images)
+
+ def iterate_images():
+ for idx, fname in enumerate(input_images):
+ arch_fname = os.path.relpath(fname, source_dir)
+ arch_fname = arch_fname.replace('\\', '/')
+ # Adding Pull #39 from Andreas Jansson: https://github.com/NVlabs/stylegan3/pull/39
+ try:
+ img = PIL.Image.open(fname) # Let PIL handle the mode
+ # Convert grayscale image to RGB
+ if img.mode == 'L':
+ img = img.convert('RGB')
+ # Force the number of channels if so requested
+ if force_channels is not None:
+ img = img.convert(gen_utils.channels_dict[int(force_channels)])
+ img = np.array(img)
+ except Exception as e:
+ sys.stderr.write(f'Failed to read {fname}: {e}')
+ continue
+ yield dict(img=img, label=labels.get(arch_fname))
+ if idx >= max_idx-1:
+ break
+ return max_idx, iterate_images()
+
+
+# ----------------------------------------------------------------------------
+
+
+def open_image_zip(source, force_channels: int = None, *, max_images: Optional[int]):
+ with zipfile.ZipFile(source, mode='r') as z:
+ input_images = [str(f) for f in sorted(z.namelist()) if is_image_ext(f)]
+
+ # Load labels.
+ labels = {}
+ if 'dataset.json' in z.namelist():
+ with z.open('dataset.json', 'r') as file:
+ labels = json.load(file)['labels']
+ if labels is not None:
+ labels = {x[0]: x[1] for x in labels}
+ else:
+ labels = {}
+
+ max_idx = maybe_min(len(input_images), max_images)
+
+ def iterate_images():
+ with zipfile.ZipFile(source, mode='r') as z:
+ for idx, fname in enumerate(input_images):
+ with z.open(fname, 'r') as file:
+ # Same as above: PR #39 by Andreas Jansson and turn Grayscale to RGB
+ try:
+ img = PIL.Image.open(file) # type: ignore
+ if img.mode == 'L':
+ img = img.convert('RGB')
+ # Force the number of channels if so requested
+ if force_channels is not None:
+ img = img.convert(gen_utils.channels_dict[int(force_channels)])
+ img = np.array(img)
+ except Exception as e:
+ sys.stderr.write(f'Failed to read {fname}: {e}')
+ continue
+ yield dict(img=img, label=labels.get(fname))
+ if idx >= max_idx-1:
+ break
+ return max_idx, iterate_images()
+
+
+# ----------------------------------------------------------------------------
+
+
+def open_lmdb(lmdb_dir: str, *, max_images: Optional[int]):
+ import cv2 # pip install opencv-python # pylint: disable=import-error
+ import lmdb # pip install lmdb # pylint: disable=import-error
+
+ with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn:
+ max_idx = maybe_min(txn.stat()['entries'], max_images)
+
+ def iterate_images():
+ with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn:
+ for idx, (_key, value) in enumerate(txn.cursor()):
+ try:
+ try:
+ img = cv2.imdecode(np.frombuffer(value, dtype=np.uint8), 1)
+ if img is None:
+ raise IOError('cv2.imdecode failed')
+ img = img[:, :, ::-1] # BGR => RGB
+ except IOError:
+ img = np.array(PIL.Image.open(io.BytesIO(value)))
+ yield dict(img=img, label=None)
+ if idx >= max_idx-1:
+ break
+ except:
+ print(sys.exc_info()[1])
+
+ return max_idx, iterate_images()
+
+
+# ----------------------------------------------------------------------------
+
+
+def open_cifar10(tarball: str, *, max_images: Optional[int]):
+ images = []
+ labels = []
+
+ with tarfile.open(tarball, 'r:gz') as tar:
+ for batch in range(1, 6):
+ member = tar.getmember(f'cifar-10-batches-py/data_batch_{batch}')
+ with tar.extractfile(member) as file:
+ data = pickle.load(file, encoding='latin1')
+ images.append(data['data'].reshape(-1, 3, 32, 32))
+ labels.append(data['labels'])
+
+ images = np.concatenate(images)
+ labels = np.concatenate(labels)
+ images = images.transpose([0, 2, 3, 1]) # NCHW -> NHWC
+ assert images.shape == (50000, 32, 32, 3) and images.dtype == np.uint8
+ assert labels.shape == (50000,) and labels.dtype in [np.int32, np.int64]
+ assert np.min(images) == 0 and np.max(images) == 255
+ assert np.min(labels) == 0 and np.max(labels) == 9
+
+ max_idx = maybe_min(len(images), max_images)
+
+ def iterate_images():
+ for idx, img in enumerate(images):
+ yield dict(img=img, label=int(labels[idx]))
+ if idx >= max_idx-1:
+ break
+
+ return max_idx, iterate_images()
+
+
+# ----------------------------------------------------------------------------
+
+
+def open_mnist(images_gz: str, *, max_images: Optional[int]):
+ labels_gz = images_gz.replace('-images-idx3-ubyte.gz', '-labels-idx1-ubyte.gz')
+ assert labels_gz != images_gz
+ images = []
+ labels = []
+
+ with gzip.open(images_gz, 'rb') as f:
+ images = np.frombuffer(f.read(), np.uint8, offset=16)
+ with gzip.open(labels_gz, 'rb') as f:
+ labels = np.frombuffer(f.read(), np.uint8, offset=8)
+
+ images = images.reshape(-1, 28, 28)
+ images = np.pad(images, [(0, 0), (2, 2), (2, 2)], 'constant', constant_values=0)
+ assert images.shape == (60000, 32, 32) and images.dtype == np.uint8
+ assert labels.shape == (60000,) and labels.dtype == np.uint8
+ assert np.min(images) == 0 and np.max(images) == 255
+ assert np.min(labels) == 0 and np.max(labels) == 9
+
+ max_idx = maybe_min(len(images), max_images)
+
+ def iterate_images():
+ for idx, img in enumerate(images):
+ yield dict(img=img, label=int(labels[idx]))
+ if idx >= max_idx-1:
+ break
+
+ return max_idx, iterate_images()
+
+
+# ----------------------------------------------------------------------------
+
+
+def make_transform(
+ transform: Optional[str],
+ output_width: Optional[int],
+ output_height: Optional[int]
+) -> Callable[[np.ndarray], Optional[np.ndarray]]:
+ def scale(width, height, img):
+ w = img.shape[1]
+ h = img.shape[0]
+ if width == w and height == h:
+ return img
+ img = PIL.Image.fromarray(img)
+ ww = width if width is not None else w
+ hh = height if height is not None else h
+ img = img.resize((ww, hh), PIL.Image.LANCZOS)
+ return np.array(img)
+
+ def center_crop(width, height, img):
+ crop = np.min(img.shape[:2])
+ img = img[(img.shape[0] - crop) // 2: (img.shape[0] + crop) // 2,
+ (img.shape[1] - crop) // 2: (img.shape[1] + crop) // 2]
+ img = PIL.Image.fromarray(img, gen_utils.channels_dict[img.shape[2]])
+ img = img.resize((width, height), PIL.Image.LANCZOS)
+ return np.array(img)
+
+ def center_crop_wide(width, height, img):
+ ch = int(np.round(width * img.shape[0] / img.shape[1]))
+ if img.shape[1] < width or ch < height:
+ return None
+
+ img = img[(img.shape[0] - ch) // 2: (img.shape[0] + ch) // 2]
+ img = PIL.Image.fromarray(img, gen_utils.channels_dict[img.shape[2]])
+ img = img.resize((width, height), PIL.Image.LANCZOS)
+ img = np.array(img)
+
+ canvas = np.zeros([width, width, 3], dtype=np.uint8)
+ canvas[(width - height) // 2 : (width + height) // 2, :] = img
+ return canvas
+
+ def center_crop_tall(width, height, img):
+ ch = int(np.round(height * img.shape[1] / img.shape[0]))
+ if img.shape[0] < height or ch < width:
+ return None
+
+ img = img[:, (img.shape[1] - ch) // 2: (img.shape[1] + ch) // 2] # center-crop: [width0, height0, C] -> [width0, height, C]
+ img = PIL.Image.fromarray(img, gen_utils.channels_dict[img.shape[2]])
+ img = img.resize((width, height), PIL.Image.LANCZOS) # resize: [width0, height, 3] -> [width, height, 3]
+ img = np.array(img)
+
+ canvas = np.zeros([height, height, 3], dtype=np.uint8) # square canvas
+ canvas[:, (height - width) // 2: (height + width) // 2] = img # replace the middle with img
+ return canvas
+
+ if transform is None:
+ return functools.partial(scale, output_width, output_height)
+ if transform == 'center-crop':
+ if (output_width is None) or (output_height is None):
+ error(f'must specify --resolution=WxH when using {transform} transform')
+ return functools.partial(center_crop, output_width, output_height)
+ if transform == 'center-crop-wide':
+ if (output_width is None) or (output_height is None):
+ error(f'must specify --resolution=WxH when using {transform} transform')
+ return functools.partial(center_crop_wide, output_width, output_height)
+ if transform == 'center-crop-tall':
+ if (output_width is None) or (output_height is None):
+ error(f'must specify --resolution=WxH when using {transform} transform')
+ return functools.partial(center_crop_tall, output_width, output_height)
+ assert False, 'unknown transform'
+
+
+# ----------------------------------------------------------------------------
+
+
+def open_dataset(source, force_channels, *, max_images: Optional[int], subfolders_as_labels: Optional[bool] = False):
+ if os.path.isdir(source):
+ if source.rstrip('/').endswith('_lmdb'):
+ return open_lmdb(source, max_images=max_images)
+ else:
+ return open_image_folder(source, force_channels, max_images=max_images, subfolders_as_labels=subfolders_as_labels)
+ elif os.path.isfile(source):
+ if os.path.basename(source) == 'cifar-10-python.tar.gz':
+ return open_cifar10(source, max_images=max_images)
+ elif os.path.basename(source) == 'train-images-idx3-ubyte.gz':
+ return open_mnist(source, max_images=max_images)
+ elif file_ext(source) == 'zip':
+ return open_image_zip(source, force_channels, max_images=max_images)
+ else:
+ assert False, 'unknown archive type'
+ else:
+ error(f'Missing input file or directory: {source}')
+
+
+# ----------------------------------------------------------------------------
+
+
+def open_dest(dest: str) -> Tuple[str, Callable[[str, Union[bytes, str]], None], Callable[[], None]]:
+ dest_ext = file_ext(dest)
+
+ if dest_ext == 'zip':
+ if os.path.dirname(dest) != '':
+ os.makedirs(os.path.dirname(dest), exist_ok=True)
+ zf = zipfile.ZipFile(file=dest, mode='w', compression=zipfile.ZIP_STORED)
+ def zip_write_bytes(fname: str, data: Union[bytes, str]):
+ zf.writestr(fname, data)
+ return '', zip_write_bytes, zf.close
+ else:
+ # If the output folder already exists, check that it is
+ # empty.
+ #
+ # Note: creating the output directory is not strictly
+ # necessary as folder_write_bytes() also mkdirs, but it's better
+ # to give an error message earlier in case the dest folder
+ # somehow cannot be created.
+ if os.path.isdir(dest) and len(os.listdir(dest)) != 0:
+ error('--dest folder must be empty')
+ os.makedirs(dest, exist_ok=True)
+
+ def folder_write_bytes(fname: str, data: Union[bytes, str]):
+ os.makedirs(os.path.dirname(fname), exist_ok=True)
+ with open(fname, 'wb') as fout:
+ if isinstance(data, str):
+ data = data.encode('utf8')
+ fout.write(data)
+ return dest, folder_write_bytes, lambda: None
+
+
+# ----------------------------------------------------------------------------
+
+
+@click.command()
+@click.pass_context
+@click.option('--source', help='Directory or archive name for input dataset', required=True, metavar='PATH')
+@click.option('--dest', help='Output directory or archive name for output dataset', required=True, metavar='PATH')
+@click.option('--max-images', help='Output only up to `max-images` images', type=int, default=None)
+@click.option('--force-channels', help='Force the number of channels in the image (1: grayscale, 3: RGB, 4: RGBA)', type=click.Choice(['1', '3', '4']), default=None)
+@click.option('--subfolders-as-labels', help='Use the folder names as the labels, to avoid setting up `dataset.json`', is_flag=True)
+@click.option('--transform', help='Input crop/resize mode', type=click.Choice(['center-crop', 'center-crop-wide', 'center-crop-tall']))
+@click.option('--resolution', help='Output resolution (e.g., \'512x512\')', metavar='WxH', type=parse_tuple)
+def convert_dataset(
+ ctx: click.Context,
+ source: str,
+ dest: str,
+ max_images: Optional[int],
+ force_channels: Optional[int],
+ subfolders_as_labels: Optional[bool],
+ transform: Optional[str],
+ resolution: Optional[Tuple[int, int]]
+):
+ """Convert an image dataset into a dataset archive usable with StyleGAN2 ADA PyTorch.
+
+ The input dataset format is guessed from the --source argument:
+
+ \b
+ --source *_lmdb/ Load LSUN dataset
+ --source cifar-10-python.tar.gz Load CIFAR-10 dataset
+ --source train-images-idx3-ubyte.gz Load MNIST dataset
+ --source path/ Recursively load all images from path/
+ --source dataset.zip Recursively load all images from dataset.zip
+
+ Specifying the output format and path:
+
+ \b
+ --dest /path/to/dir Save output files under /path/to/dir
+ --dest /path/to/dataset.zip Save output files into /path/to/dataset.zip
+
+ The output dataset format can be either an image folder or an uncompressed zip archive.
+ Zip archives makes it easier to move datasets around file servers and clusters, and may
+ offer better training performance on network file systems.
+
+ Images within the dataset archive will be stored as uncompressed PNG.
+ Uncompresed PNGs can be efficiently decoded in the training loop.
+
+ Class labels are stored in a file called 'dataset.json' that is stored at the
+ dataset root folder. This file has the following structure:
+
+ \b
+ {
+ "labels": [
+ ["00000/img00000000.png",6],
+ ["00000/img00000001.png",9],
+ ... repeated for every image in the dataset
+ ["00049/img00049999.png",1]
+ ]
+ }
+
+ If the 'dataset.json' file cannot be found, the dataset is interpreted as
+ not containing class labels.
+
+ Image scale/crop and resolution requirements:
+
+ Output images must be square-shaped and they must all have the same power-of-two
+ dimensions.
+
+ To scale arbitrary input image size to a specific width and height, use the
+ --resolution option. Output resolution will be either the original
+ input resolution (if resolution was not specified) or the one specified with
+ --resolution option.
+
+ Use the --transform=center-crop or --transform=center-crop-wide options to apply a
+ center crop transform on the input image. These options should be used with the
+ --resolution option. For example:
+
+ \b
+ python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \\
+ --transform=center-crop-wide --resolution=512x384
+ """
+
+ PIL.Image.init() # type: ignore
+
+ if dest == '':
+ ctx.fail('--dest output filename or directory must not be an empty string')
+
+ num_files, input_iter = open_dataset(source, force_channels, max_images=max_images, subfolders_as_labels=subfolders_as_labels)
+ archive_root_dir, save_bytes, close_dest = open_dest(dest)
+
+ if resolution is None: resolution = (None, None)
+ transform_image = make_transform(transform, *resolution)
+
+ dataset_attrs = None
+
+ labels = []
+ for idx, image in tqdm(enumerate(input_iter), total=num_files):
+ idx_str = f'{idx:08d}'
+ archive_fname = f'{idx_str[:5]}/img{idx_str}.png'
+
+ # Apply crop and resize.
+ img = transform_image(image['img'])
+
+ # Transform may drop images.
+ if img is None:
+ continue
+
+ # Error check to require uniform image attributes across
+ # the whole dataset.
+ channels = img.shape[2] if img.ndim == 3 else 1
+ cur_image_attrs = {
+ 'width': img.shape[1],
+ 'height': img.shape[0],
+ 'channels': channels
+ }
+ if dataset_attrs is None:
+ dataset_attrs = cur_image_attrs
+ width = dataset_attrs['width']
+ height = dataset_attrs['height']
+ if width != height:
+ error(f'Image dimensions after scale and crop are required to be square. Got {width}x{height}')
+ if dataset_attrs['channels'] not in [1, 3, 4]:
+ error('Input images must be stored as grayscale, RGB, or RGBA')
+ if width != 2 ** int(np.floor(np.log2(width))):
+ error('Image width/height after scale and crop are required to be power-of-two')
+ elif dataset_attrs != cur_image_attrs:
+ err = [f' dataset {k}/cur image {k}: {dataset_attrs[k]}/{cur_image_attrs[k]}' for k in dataset_attrs.keys()] # pylint: disable=unsubscriptable-object
+ error(f'Image {archive_fname} attributes must be equal across all images of the dataset. Got:\n' + '\n'.join(err))
+
+ # Save the image as an uncompressed PNG.
+ img = PIL.Image.fromarray(img, gen_utils.channels_dict[channels])
+ image_bits = io.BytesIO()
+ img.save(image_bits, format='png', compress_level=0, optimize=False)
+ save_bytes(os.path.join(archive_root_dir, archive_fname), image_bits.getbuffer())
+ labels.append([archive_fname, image['label']] if image['label'] is not None else None)
+
+ metadata = {
+ 'labels': labels if all(x is not None for x in labels) else None
+ }
+ save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata))
+ close_dest()
+
+# ----------------------------------------------------------------------------
+
+
+if __name__ == "__main__":
+ convert_dataset() # pylint: disable=no-value-for-parameter
+
+# ----------------------------------------------------------------------------
diff --git a/stylegan3-fun/discriminator_synthesis.py b/stylegan3-fun/discriminator_synthesis.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a3cbeb9008436b1640f5b33d7ddb6acc29fb76e
--- /dev/null
+++ b/stylegan3-fun/discriminator_synthesis.py
@@ -0,0 +1,1007 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+from torchvision import transforms
+
+import PIL
+from PIL import Image
+
+try:
+ import ffmpeg
+except ImportError:
+ raise ImportError('ffmpeg-python not found! Install it via "pip install ffmpeg-python"')
+
+import scipy.ndimage as nd
+import numpy as np
+
+import os
+import click
+from typing import Union, Tuple, Optional, List, Type
+from tqdm import tqdm
+import re
+
+from torch_utils import gen_utils
+from network_features import DiscriminatorFeatures
+
+
+# ----------------------------------------------------------------------------
+
+
+@click.group()
+def main():
+ pass
+
+
+# ----------------------------------------------------------------------------
+
+
+def get_available_layers(max_resolution: int) -> List[str]:
+ """Helper function to get the available layers given a max resolution (first block in the Discriminator)"""
+ max_res_log2 = int(np.log2(max_resolution))
+ block_resolutions = [2**i for i in range(max_res_log2, 2, -1)]
+
+ available_layers = ['from_rgb']
+ for block_res in block_resolutions:
+ # We don't add the skip layer, as it's the same as conv1 (due to in-place addition; could be changed)
+ available_layers.extend([f'b{block_res}_conv0', f'b{block_res}_conv1'])
+ # We also skip 'b4_mbstd', as it doesn't add any new information compared to b8_conv1
+ available_layers.extend(['b4_conv', 'fc', 'out'])
+ return available_layers
+
+
+# ----------------------------------------------------------------------------
+# DeepDream code; modified from Erik Linder-Norén's repository: https://github.com/eriklindernoren/PyTorch-Deep-Dream
+
+def get_image(seed: int = 0,
+ image_noise: str = 'random',
+ starting_image: Union[str, os.PathLike] = None,
+ image_size: int = 1024,
+ convert_to_grayscale: bool = False,
+ device: torch.device = torch.device('cpu')) -> Tuple[PIL.Image.Image, str]:
+ """Set the random seed (NumPy + PyTorch), as well as get an image from a path or generate a random one with the seed"""
+ torch.manual_seed(seed)
+ rnd = np.random.RandomState(seed)
+
+ # Load image or generate a random one if none is provided
+ if starting_image is not None:
+ image = Image.open(starting_image).convert('RGB').resize((image_size, image_size), Image.LANCZOS)
+ else:
+ if image_noise == 'random':
+ starting_image = f'random_image-seed_{seed:08d}.jpg'
+ image = Image.fromarray(rnd.randint(0, 255, (image_size, image_size, 3), dtype='uint8'))
+ elif image_noise == 'perlin':
+ try:
+ # Graciously using Mathieu Duchesneau's implementation: https://github.com/duchesneaumathieu/pyperlin
+ from pyperlin import FractalPerlin2D
+ starting_image = f'perlin_image-seed_{seed:08d}.jpg'
+ shape = (3, image_size, image_size)
+ resolutions = [(2**i, 2**i) for i in range(1, 6+1)] # for lacunarity = 2.0 # TODO: set as cli variable
+ factors = [0.5**i for i in range(6)] # for persistence = 0.5 TODO: set as cli variables
+ g_cuda = torch.Generator(device=device).manual_seed(seed)
+ rgb = FractalPerlin2D(shape, resolutions, factors, generator=g_cuda)().cpu().numpy()
+ rgb = (255 * (rgb + 1) / 2).astype(np.uint8) # [-1.0, 1.0] => [0, 255]
+ image = Image.fromarray(rgb.transpose(1, 2, 0), 'RGB') # Reshape leads us to weird tiling
+
+ except ImportError:
+ raise ImportError('pyperlin not found! Install it via "pip install pyperlin"')
+
+ if convert_to_grayscale:
+ image = image.convert('L').convert('RGB') # We do a little trolling to Pillow (so we have a 3-channel image)
+
+ return image, starting_image
+
+
+def crop_resize_rotate(img: PIL.Image.Image,
+ crop_size: int = None,
+ new_size: int = None,
+ rotation_deg: float = None,
+ translate_x: float = 0.0,
+ translate_y: float = 0.0) -> PIL.Image.Image:
+ """Center-crop the input image into a square of sides crop_size; can be resized to new_size; rotated rotation_deg counter-clockwise"""
+ # Center-crop the input image
+ if crop_size is not None:
+ w, h = img.size # Input image width and height
+ img = img.crop(box=((w - crop_size) // 2, # Left pixel coordinate
+ (h - crop_size) // 2, # Upper pixel coordinate
+ (w + crop_size) // 2, # Right pixel coordinate
+ (h + crop_size) // 2)) # Lower pixel coordinate
+ # Resize
+ if new_size is not None:
+ img = img.resize(size=(new_size, new_size), # Requested size of the image in pixels; (width, height)
+ resample=Image.LANCZOS) # Resampling filter
+ # Rotation and translation
+ if rotation_deg is not None:
+ img = img.rotate(angle=rotation_deg, # Angle to rotate image, counter-clockwise
+ resample=Image.BICUBIC, # Resampling filter; options: Image.Resampling.{NEAREST, BILINEAR, BICUBIC}
+ expand=False, # If True, the whole rotated image will be shown
+ translate=(translate_x, translate_y), # Translate the image, from top-left corner (post-rotation)
+ fillcolor=(0, 0, 0)) # Black background
+ # TODO: tile the background
+ return img
+
+
+mean = np.array([0.485, 0.456, 0.406])
+std = np.array([0.229, 0.224, 0.225])
+
+preprocess = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
+
+
+def deprocess(image_np: torch.Tensor) -> np.ndarray:
+ image_np = image_np.squeeze().transpose(1, 2, 0)
+ image_np = image_np * std.reshape((1, 1, 3)) + mean.reshape((1, 1, 3))
+ # image_np = (image_np + 1.0) / 2.0
+ image_np = np.clip(image_np, 0.0, 1.0)
+ image_np = (255 * image_np).astype('uint8')
+ return image_np
+
+
+def clip(image_tensor: torch.Tensor) -> torch.Tensor:
+ """Clamp per channel"""
+ for c in range(3):
+ m, s = mean[c], std[c]
+ image_tensor[0, c] = torch.clamp(image_tensor[0, c], -m / s, (1 - m) / s)
+ return image_tensor
+
+
+def dream(image: PIL.Image.Image,
+ model: torch.nn.Module,
+ layers: List[str],
+ channels: List[int] = None,
+ normed: bool = False,
+ sqrt_normed: bool = False,
+ iterations: int = 20,
+ lr: float = 1e-2) -> np.ndarray:
+ """ Updates the image to maximize outputs for n iterations """
+ Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
+ image = Variable(Tensor(image), requires_grad=True)
+ for i in range(iterations):
+ model.zero_grad()
+ out = model.get_layers_features(image, layers=layers, channels=channels, normed=normed, sqrt_normed=sqrt_normed)
+ loss = sum(layer.norm() for layer in out) # More than one layer may be used
+ loss.backward()
+ avg_grad = np.abs(image.grad.data.cpu().numpy()).mean()
+ norm_lr = lr / avg_grad
+ image.data += norm_lr * image.grad.data
+ image.data = clip(image.data)
+ # image.data = torch.clamp(image.data, -1.0, 1.0)
+ image.grad.data.zero_()
+ return image.cpu().data.numpy()
+
+
+def deep_dream(image: PIL.Image.Image,
+ model: torch.nn.Module,
+ model_resolution: int,
+ layers: List[str],
+ channels: List[int],
+ seed: Union[int, Type[None]],
+ normed: bool,
+ sqrt_normed: bool,
+ iterations: int,
+ lr: float,
+ octave_scale: float,
+ num_octaves: int,
+ unzoom_octave: bool = False,
+ disable_inner_tqdm: bool = False,
+ ignore_initial_transform: bool = False) -> np.ndarray:
+ """ Main deep dream method """
+ # Center-crop and resize
+ if not ignore_initial_transform:
+ image = crop_resize_rotate(img=image, crop_size=min(image.size), new_size=model_resolution)
+ # Preprocess image
+ image = preprocess(image)
+ # image = torch.from_numpy(np.array(image)).permute(-1, 0, 1) / 127.5 - 1.0 # alternative
+ image = image.unsqueeze(0).cpu().data.numpy()
+ # Extract image representations for each octave
+ octaves = [image]
+ for _ in range(num_octaves - 1):
+ # Alternatively, see if we get better results with: https://www.tensorflow.org/tutorials/generative/deepdream#taking_it_up_an_octave
+ octave = nd.zoom(octaves[-1], (1, 1, 1 / octave_scale, 1 / octave_scale), order=1)
+ # Necessary for StyleGAN's Discriminator, as it cannot handle any image size
+ if unzoom_octave:
+ octave = nd.zoom(octave, np.array(octaves[-1].shape) / np.array(octave.shape), order=1)
+ octaves.append(octave)
+
+ detail = np.zeros_like(octaves[-1])
+ tqdm_desc = f'Dreaming w/layers {"|".join(x for x in layers)}'
+ tqdm_desc = f'Seed: {seed} - {tqdm_desc}' if seed is not None else tqdm_desc
+ for octave, octave_base in enumerate(tqdm(octaves[::-1], desc=tqdm_desc, disable=disable_inner_tqdm)):
+ if octave > 0:
+ # Upsample detail to new octave dimension
+ detail = nd.zoom(detail, np.array(octave_base.shape) / np.array(detail.shape), order=1)
+ # Add deep dream detail from previous octave to new base
+ input_image = octave_base + detail
+ # Get new deep dream image
+ dreamed_image = dream(input_image, model, layers, channels, normed, sqrt_normed, iterations, lr)
+ # Extract deep dream details
+ detail = dreamed_image - octave_base
+
+ return deprocess(dreamed_image)
+
+
+# ----------------------------------------------------------------------------
+
+# Helper functions (all base code taken from: https://pytorch.org/tutorials/advanced/neural_style_tutorial.html)
+
+
+class ContentLoss(nn.Module):
+
+ def __init__(self, target,):
+ super(ContentLoss, self).__init__()
+ # we 'detach' the target content from the tree used
+ # to dynamically compute the gradient: this is a stated value,
+ # not a variable. Otherwise the forward method of the criterion
+ # will throw an error.
+ self.target = target.detach()
+
+ def forward(self, input):
+ self.loss = F.mse_loss(input, self.target)
+ return input
+
+
+def gram_matrix(input):
+ a, b, c, d = input.size() # (batch_size, no. feature maps, dims of a f. map (N=c*d))
+
+ features = input.view(a * b, c * d) # resize F_XL into \hat F_XL
+
+ G = torch.mm(features, features.t()) # compute the gram product
+
+ # 'Normalize' the values of the gram matrix by dividing by the number of element in each feature maps.
+ return G.div(a * b * c * d) # can also do torch.numel(input) to get the number of elements
+
+
+class StyleLoss(nn.Module):
+ def __init__(self, target_feature):
+ super(StyleLoss, self).__init__()
+ self.target = gram_matrix(target_feature).detach()
+
+ def forward(self, input):
+ G = gram_matrix(input)
+ self.loss = F.mse_loss(G, self.target)
+ return input
+
+
+@main.command(name='style-transfer', help='Use the StyleGAN2/3 Discriminator to perform style transfer')
+@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
+@click.option('--cfg', type=click.Choice(['stylegan3-t', 'stylegan3-r', 'stylegan2']), help='Model base configuration', default=None)
+@click.option('--content', type=str, help='Content image filename (url or local path)', required=True)
+@click.option('--style', type=str, help='Style image filename (url or local path)', required=True)
+def style_transfer_discriminator(
+ ctx: click.Context,
+ network_pkl: str,
+ cfg: str,
+ content: str,
+ style: str,
+):
+ print('Coming soon!')
+ # Reference: https://pytorch.org/tutorials/advanced/neural_style_tutorial.html
+
+ # Set up device
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+ imsize = 512 if torch.cuda.is_available() else 128 # use small size if no gpu
+
+ loader = transforms.Compose([transforms.Resize(imsize), # scale imported image
+ transforms.ToTensor()]) # transform it into a torch tensor
+
+ # Helper function
+ def image_loader(image_name):
+ image = Image.open(image_name)
+ # fake batch dimension required to fit network's input dimensions
+ image = loader(image).unsqueeze(0)
+ return image.to(device, torch.float)
+
+ style_img = image_loader(style)
+ content_img = image_loader(content)
+
+ # This shouldn't really happen, but just in case
+ assert style_img.size() == content_img.size(), 'Style and content images must be the same size'
+
+ unloader = transforms.ToPILImage() # reconvert into PIL image
+
+ # Load Discriminator
+ D = gen_utils.load_network('D', network_pkl, cfg, device)
+ # TODO: finish this!
+
+
+# ----------------------------------------------------------------------------
+
+
+@main.command(name='dream', help='Discriminator Dreaming with the StyleGAN2/3 Discriminator and the chosen layers')
+@click.pass_context
+@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
+@click.option('--cfg', type=click.Choice(['stylegan3-t', 'stylegan3-r', 'stylegan2']), help='Model base configuration', default=None)
+# Synthesis options
+@click.option('--seeds', type=gen_utils.num_range, help='Random seeds to use. Accepted comma-separated values, ranges, or combinations: "a,b,c", "a-c", "a,b-d,e".', default='0')
+@click.option('--random-image-noise', '-noise', 'image_noise', type=click.Choice(['random', 'perlin']), default='perlin', show_default=True)
+@click.option('--starting-image', type=str, help='Path to image to start from', default=None)
+@click.option('--convert-to-grayscale', '-grayscale', is_flag=True, help='Add flag to grayscale the initial image')
+@click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)', default=None)
+@click.option('--lr', 'learning_rate', type=float, help='Learning rate', default=1e-2, show_default=True)
+@click.option('--iterations', '-it', type=int, help='Number of gradient ascent steps per octave', default=20, show_default=True)
+# Layer options
+@click.option('--layers', type=str, help='Layers of the Discriminator to use as the features. If "all", will generate a dream image per available layer in the loaded model. If "use_all", will use all available layers.', default='b16_conv1', show_default=True)
+@click.option('--channels', type=gen_utils.num_range, help='Comma-separated list and/or range of the channels of the Discriminator to use as the features. If "None", will use all channels in each specified layer.', default=None, show_default=True)
+@click.option('--normed', 'norm_model_layers', is_flag=True, help='Add flag to divide the features of each layer of D by its number of elements')
+@click.option('--sqrt-normed', 'sqrt_norm_model_layers', is_flag=True, help='Add flag to divide the features of each layer of D by the square root of its number of elements')
+# Octaves options
+@click.option('--num-octaves', type=int, help='Number of octaves', default=5, show_default=True)
+@click.option('--octave-scale', type=float, help='Image scale between octaves', default=1.4, show_default=True)
+@click.option('--unzoom-octave', type=bool, help='Set to True for the octaves to be unzoomed (this will be slower)', default=True, show_default=True)
+# Extra parameters for saving the results
+@click.option('--outdir', type=click.Path(file_okay=False), help='Directory path to save the results', default=os.path.join(os.getcwd(), 'out', 'discriminator_synthesis'), show_default=True, metavar='DIR')
+@click.option('--description', '-desc', type=str, help='Additional description name for the directory path to save results', default='', show_default=True)
+def discriminator_dream(
+ ctx: click.Context,
+ network_pkl: Union[str, os.PathLike],
+ cfg: Optional[str],
+ seeds: List[int],
+ image_noise: str,
+ starting_image: Union[str, os.PathLike],
+ convert_to_grayscale: bool,
+ class_idx: Optional[int], # TODO: conditional model
+ learning_rate: float,
+ iterations: int,
+ layers: str,
+ channels: Optional[List[int]],
+ norm_model_layers: bool,
+ sqrt_norm_model_layers: bool,
+ num_octaves: int,
+ octave_scale: float,
+ unzoom_octave: bool,
+ outdir: Union[str, os.PathLike],
+ description: str,
+):
+ # Set up device
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+ # Load Discriminator
+ D = gen_utils.load_network('D', network_pkl, cfg, device)
+
+ # Get the model resolution (image resizing and getting available layers)
+ model_resolution = D.img_resolution
+
+ # TODO: do this better, as we can combine these conditions later
+ layers = layers.split(',')
+
+ # We will use the features of the Discriminator, on the layer specified by the user
+ model = DiscriminatorFeatures(D).requires_grad_(False).to(device)
+
+ if 'all' in layers:
+ # Get all the available layers in a list
+ layers = get_available_layers(max_resolution=model_resolution)
+
+ for seed in seeds:
+ # Get the image and image name
+ image, starting_image = get_image(seed=seed, image_noise=image_noise,
+ starting_image=starting_image,
+ image_size=model_resolution,
+ convert_to_grayscale=convert_to_grayscale)
+
+ # Make the run dir in the specified output directory
+ desc = f'discriminator-dream-all_layers-seed_{seed}'
+ desc = f'{desc}-{description}' if len(description) != 0 else desc
+ run_dir = gen_utils.make_run_dir(outdir, desc)
+
+ # Save starting image
+ image.save(os.path.join(run_dir, f'{os.path.basename(starting_image).split(".")[0]}.jpg'))
+
+ # Save the configuration used
+ ctx.obj = {
+ 'network_pkl': network_pkl,
+ 'synthesis_options': {
+ 'seed': seed,
+ 'random_image_noise': image_noise,
+ 'starting_image': starting_image,
+ 'class_idx': class_idx,
+ 'learning_rate': learning_rate,
+ 'iterations': iterations},
+ 'layer_options': {
+ 'layer': layers,
+ 'channels': channels,
+ 'norm_model_layers': norm_model_layers,
+ 'sqrt_norm_model_layers': sqrt_norm_model_layers},
+ 'octaves_options': {
+ 'num_octaves': num_octaves,
+ 'octave_scale': octave_scale,
+ 'unzoom_octave': unzoom_octave},
+ 'extra_parameters': {
+ 'outdir': run_dir,
+ 'description': description}
+ }
+ # Save the run configuration
+ gen_utils.save_config(ctx=ctx, run_dir=run_dir)
+
+ # For each layer:
+ for layer in layers:
+ # Extract deep dream image
+ dreamed_image = deep_dream(image, model, model_resolution, layers=[layer], channels=channels, seed=seed, normed=norm_model_layers,
+ sqrt_normed=sqrt_norm_model_layers, iterations=iterations, lr=learning_rate,
+ octave_scale=octave_scale, num_octaves=num_octaves, unzoom_octave=unzoom_octave)
+
+ # Save the resulting dreamed image
+ filename = f'layer-{layer}_dreamed_{os.path.basename(starting_image).split(".")[0]}.jpg'
+ Image.fromarray(dreamed_image, 'RGB').save(os.path.join(run_dir, filename))
+
+ else:
+ if 'use_all' in layers:
+ # Get all available layers
+ layers = get_available_layers(max_resolution=model_resolution)
+ else:
+ # Parse the layers given by the user and leave only those available by the model
+ available_layers = get_available_layers(max_resolution=model_resolution)
+ layers = [layer for layer in layers if layer in available_layers]
+
+ # Make the run dir in the specified output directory
+ desc = f'discriminator-dream-layers_{"-".join(x for x in layers)}'
+ desc = f'{desc}-{description}' if len(description) != 0 else desc
+ run_dir = gen_utils.make_run_dir(outdir, desc)
+
+ starting_images, used_seeds = [], []
+ for seed in seeds:
+ # Get the image and image name
+ image, starting_image = get_image(seed=seed, image_noise=image_noise,
+ starting_image=starting_image,
+ image_size=model_resolution,
+ convert_to_grayscale=convert_to_grayscale)
+
+ # Extract deep dream image
+ dreamed_image = deep_dream(image, model, model_resolution, layers=layers, channels=channels, seed=seed, normed=norm_model_layers,
+ sqrt_normed=sqrt_norm_model_layers, iterations=iterations, lr=learning_rate,
+ octave_scale=octave_scale, num_octaves=num_octaves, unzoom_octave=unzoom_octave)
+
+ # For logging later
+ starting_images.append(starting_image)
+ used_seeds.append(seed)
+
+ # Save the resulting image and initial image
+ filename = f'dreamed_{os.path.basename(starting_image)}'
+ Image.fromarray(dreamed_image, 'RGB').save(os.path.join(run_dir, filename))
+ image.save(os.path.join(run_dir, os.path.basename(starting_image)))
+ starting_image = None
+
+ # Save the configuration used
+ ctx.obj = {
+ 'network_pkl': network_pkl,
+ 'synthesis_options': {
+ 'seeds': used_seeds,
+ 'starting_image': starting_images,
+ 'class_idx': class_idx,
+ 'learning_rate': learning_rate,
+ 'iterations': iterations},
+ 'layer_options': {
+ 'layer': layers,
+ 'channels': channels,
+ 'norm_model_layers': norm_model_layers,
+ 'sqrt_norm_model_layers': sqrt_norm_model_layers},
+ 'octaves_options': {
+ 'octave_scale': octave_scale,
+ 'num_octaves': num_octaves,
+ 'unzoom_octave': unzoom_octave},
+ 'extra_parameters': {
+ 'outdir': run_dir,
+ 'description': description}
+ }
+ # Save the run configuration
+ gen_utils.save_config(ctx=ctx, run_dir=run_dir)
+
+
+# ----------------------------------------------------------------------------
+
+
+@main.command(name='dream-zoom',
+ help='Zoom/rotate/translate after each Discriminator Dreaming iteration. A video will be saved.')
+@click.pass_context
+@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
+@click.option('--cfg', type=click.Choice(['stylegan3-t', 'stylegan3-r', 'stylegan2']), help='Model base configuration', default=None)
+# Synthesis options
+@click.option('--seed', type=int, help='Random seed to use', default=0, show_default=True)
+@click.option('--random-image-noise', '-noise', 'image_noise', type=click.Choice(['random', 'perlin']), default='random', show_default=True)
+@click.option('--starting-image', type=str, help='Path to image to start from', default=None)
+@click.option('--convert-to-grayscale', '-grayscale', is_flag=True, help='Add flag to grayscale the initial image')
+@click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)', default=None)
+@click.option('--lr', 'learning_rate', type=float, help='Learning rate', default=5e-3, show_default=True)
+@click.option('--iterations', '-it', type=click.IntRange(min=1), help='Number of gradient ascent steps per octave', default=10, show_default=True)
+# Layer options
+@click.option('--layers', type=str, help='Comma-separated list of the layers of the Discriminator to use as the features. If "use_all", will use all available layers.', default='b16_conv0', show_default=True)
+@click.option('--channels', type=gen_utils.num_range, help='Comma-separated list and/or range of the channels of the Discriminator to use as the features. If "None", will use all channels in each specified layer.', default=None, show_default=True)
+@click.option('--normed', 'norm_model_layers', is_flag=True, help='Add flag to divide the features of each layer of D by its number of elements')
+@click.option('--sqrt-normed', 'sqrt_norm_model_layers', is_flag=True, help='Add flag to divide the features of each layer of D by the square root of its number of elements')
+# Octaves options
+@click.option('--num-octaves', type=click.IntRange(min=1), help='Number of octaves', default=5, show_default=True)
+@click.option('--octave-scale', type=float, help='Image scale between octaves', default=1.4, show_default=True)
+@click.option('--unzoom-octave', type=bool, help='Set to True for the octaves to be unzoomed (this will be slower)', default=False, show_default=True)
+# Individual frame manipulation options
+@click.option('--pixel-zoom', '-zoom', type=int, help='How many pixels to zoom per step (positive for zoom in, negative for zoom out, padded with black)', default=2, show_default=True)
+@click.option('--rotation-deg', '-rot', type=float, help='Rotate image counter-clockwise per frame (padded with black)', default=0.0, show_default=True)
+@click.option('--translate-x', '-tx', type=float, help='Translate the image in the horizontal axis per frame (from left to right, padded with black)', default=0.0, show_default=True)
+@click.option('--translate-y', '-ty', type=float, help='Translate the image in the vertical axis per frame (from top to bottom, padded with black)', default=0.0, show_default=True)
+# Video options
+@click.option('--fps', type=gen_utils.parse_fps, help='FPS for the mp4 video of optimization progress (if saved)', default=25, show_default=True)
+@click.option('--duration-sec', type=float, help='Duration length of the video', default=15.0, show_default=True)
+@click.option('--reverse-video', is_flag=True, help='Add flag to reverse the generated video')
+@click.option('--include-starting-image', type=bool, help='Include the starting image in the final video', default=True, show_default=True)
+# Extra parameters for saving the results
+@click.option('--outdir', type=click.Path(file_okay=False), help='Directory path to save the results', default=os.path.join(os.getcwd(), 'out', 'discriminator_synthesis'), show_default=True, metavar='DIR')
+@click.option('--description', '-desc', type=str, help='Additional description name for the directory path to save results', default='', show_default=True)
+def discriminator_dream_zoom(
+ ctx: click.Context,
+ network_pkl: Union[str, os.PathLike],
+ cfg: Optional[str],
+ seed: int,
+ image_noise: Optional[str],
+ starting_image: Optional[Union[str, os.PathLike]],
+ convert_to_grayscale: bool,
+ class_idx: Optional[int], # TODO: conditional model
+ learning_rate: float,
+ iterations: int,
+ layers: str,
+ channels: List[int],
+ norm_model_layers: Optional[bool],
+ sqrt_norm_model_layers: Optional[bool],
+ num_octaves: int,
+ octave_scale: float,
+ unzoom_octave: Optional[bool],
+ pixel_zoom: int,
+ rotation_deg: float,
+ translate_x: int,
+ translate_y: int,
+ fps: int,
+ duration_sec: float,
+ reverse_video: bool,
+ include_starting_image: bool,
+ outdir: Union[str, os.PathLike],
+ description: str,
+):
+ # Set up device
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+ # Load Discriminator
+ D = gen_utils.load_network('D', network_pkl, cfg, device)
+
+ # Get the model resolution (for resizing the starting image if needed)
+ model_resolution = D.img_resolution
+ zoom_size = model_resolution - 2 * pixel_zoom
+
+ layers = layers.split(',')
+ if 'use_all' in layers:
+ # Get all available layers
+ layers = get_available_layers(max_resolution=model_resolution)
+ else:
+ # Parse the layers given by the user and leave only those available by the model
+ available_layers = get_available_layers(max_resolution=model_resolution)
+ layers = [layer for layer in layers if layer in available_layers]
+
+ # We will use the features of the Discriminator, on the layer specified by the user
+ model = DiscriminatorFeatures(D).requires_grad_(False).to(device)
+
+ # Get the image and image name
+ image, starting_image = get_image(seed=seed, image_noise=image_noise,
+ starting_image=starting_image,
+ image_size=model_resolution,
+ convert_to_grayscale=convert_to_grayscale)
+
+ # Make the run dir in the specified output directory
+ desc = 'discriminator-dream-zoom'
+ desc = f'{desc}-{description}' if len(description) != 0 else desc
+ run_dir = gen_utils.make_run_dir(outdir, desc)
+
+ # Save the configuration used
+ ctx.obj = {
+ 'network_pkl': network_pkl,
+ 'synthesis_options': {
+ 'seed': seed,
+ 'random_image_noise': image_noise,
+ 'starting_image': starting_image,
+ 'class_idx': class_idx,
+ 'learning_rate': learning_rate,
+ 'iterations': iterations
+ },
+ 'layer_options': {
+ 'layers': layers,
+ 'channels': channels,
+ 'norm_model_layers': norm_model_layers,
+ 'sqrt_norm_model_layers': sqrt_norm_model_layers
+ },
+ 'octaves_options': {
+ 'num_octaves': num_octaves,
+ 'octave_scale': octave_scale,
+ 'unzoom_octave': unzoom_octave
+ },
+ 'frame_manipulation_options': {
+ 'pixel_zoom': pixel_zoom,
+ 'rotation_deg': rotation_deg,
+ 'translate_x': translate_x,
+ 'translate_y': translate_y,
+ },
+ 'video_options': {
+ 'fps': fps,
+ 'duration_sec': duration_sec,
+ 'reverse_video': reverse_video,
+ 'include_starting_image': include_starting_image,
+ },
+ 'extra_parameters': {
+ 'outdir': run_dir,
+ 'description': description
+ }
+ }
+ # Save the run configuration
+ gen_utils.save_config(ctx=ctx, run_dir=run_dir)
+
+ num_frames = int(np.rint(duration_sec * fps)) # Number of frames for the video
+ n_digits = int(np.log10(num_frames)) + 1 # Number of digits for naming each frame
+
+ # Save the starting image
+ starting_image_name = f'dreamed_{0:0{n_digits}d}.jpg' if include_starting_image else 'starting_image.jpg'
+ image.save(os.path.join(run_dir, starting_image_name))
+
+ for idx, frame in enumerate(tqdm(range(num_frames), desc='Dreaming...', unit='frame')):
+ # Zoom in after the first frame
+ if idx > 0:
+ image = crop_resize_rotate(image, crop_size=zoom_size, new_size=model_resolution,
+ rotation_deg=rotation_deg, translate_x=translate_x, translate_y=translate_y)
+ # Extract deep dream image
+ dreamed_image = deep_dream(image, model, model_resolution, layers=layers, seed=seed, normed=norm_model_layers,
+ sqrt_normed=sqrt_norm_model_layers, iterations=iterations, channels=channels,
+ lr=learning_rate, octave_scale=octave_scale, num_octaves=num_octaves,
+ unzoom_octave=unzoom_octave, disable_inner_tqdm=True)
+
+ # Save the resulting image and initial image
+ filename = f'dreamed_{idx + 1:0{n_digits}d}.jpg'
+ Image.fromarray(dreamed_image, 'RGB').save(os.path.join(run_dir, filename))
+
+ # Now, the dreamed image is the starting image
+ image = Image.fromarray(dreamed_image, 'RGB')
+
+ # Save the final video
+ gen_utils.save_video_from_images(run_dir=run_dir, image_names=f'dreamed_%0{n_digits}d.jpg',
+ video_name='dream-zoom', fps=fps, reverse_video=reverse_video)
+
+
+# ----------------------------------------------------------------------------
+
+@main.command(name='channel-zoom', help='Dream zoom using only the specified channels in the selected layer')
+@click.pass_context
+@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
+@click.option('--cfg', type=click.Choice(['stylegan3-t', 'stylegan3-r', 'stylegan2']), help='Model base configuration', default=None)
+# Synthesis options
+@click.option('--seed', type=int, help='Random seed to use', default=0, show_default=True)
+@click.option('--random-image-noise', '-noise', 'image_noise', type=click.Choice(['random', 'perlin']), default='random', show_default=True)
+@click.option('--starting-image', type=str, help='Path to image to start from', default=None)
+@click.option('--convert-to-grayscale', '-grayscale', is_flag=True, help='Add flag to grayscale the initial image')
+@click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)', default=None)
+@click.option('--lr', 'learning_rate', type=float, help='Learning rate', default=5e-3, show_default=True)
+@click.option('--iterations', '-it', type=click.IntRange(min=1), help='Number of gradient ascent steps per octave', default=10, show_default=True)
+# Layer options
+@click.option('--layer', type=str, help='Layers of the Discriminator to use as the features.', default='b8_conv0', show_default=True)
+@click.option('--normed', 'norm_model_layers', is_flag=True, help='Add flag to divide the features of each layer of D by its number of elements')
+@click.option('--sqrt-normed', 'sqrt_norm_model_layers', is_flag=True, help='Add flag to divide the features of each layer of D by the square root of its number of elements')
+# Octaves options
+@click.option('--num-octaves', type=click.IntRange(min=1), help='Number of octaves', default=5, show_default=True)
+@click.option('--octave-scale', type=float, help='Image scale between octaves', default=1.4, show_default=True)
+@click.option('--unzoom-octave', type=bool, help='Set to True for the octaves to be unzoomed (this will be slower)', default=False, show_default=True)
+# Individual frame manipulation options
+@click.option('--pixel-zoom', '-zoom', type=int, help='How many pixels to zoom per step (positive for zoom in, negative for zoom out, padded with black)', default=2, show_default=True)
+@click.option('--rotation-deg', '-rot', type=float, help='Rotate image counter-clockwise per frame (padded with black)', default=0.0, show_default=True)
+@click.option('--translate-x', '-tx', type=float, help='Translate the image in the horizontal axis per frame (from left to right, padded with black)', default=0.0, show_default=True)
+@click.option('--translate-y', '-ty', type=float, help='Translate the image in the vertical axis per frame (from top to bottom, padded with black)', default=0.0, show_default=True)
+# Video options
+@click.option('--frames-per-channel', type=click.IntRange(min=1), help='Number of frames per channel', default=1, show_default=True)
+@click.option('--fps', type=gen_utils.parse_fps, help='FPS for the mp4 video of optimization progress (if saved)', default=25, show_default=True)
+@click.option('--reverse-video', is_flag=True, help='Add flag to reverse the generated video')
+@click.option('--include-starting-image', type=bool, help='Include the starting image in the final video', default=True, show_default=True)
+# Extra parameters for saving the results
+@click.option('--outdir', type=click.Path(file_okay=False), help='Directory path to save the results', default=os.path.join(os.getcwd(), 'out', 'discriminator_synthesis'), show_default=True, metavar='DIR')
+@click.option('--description', '-desc', type=str, help='Additional description name for the directory path to save results', default='', show_default=True)
+def channel_zoom(
+ ctx: click.Context,
+ network_pkl: Union[str, os.PathLike],
+ cfg: Optional[str],
+ seed: int,
+ image_noise: Optional[str],
+ starting_image: Optional[Union[str, os.PathLike]],
+ convert_to_grayscale: bool,
+ class_idx: Optional[int], # TODO: conditional model
+ learning_rate: float,
+ iterations: int,
+ layer: str,
+ norm_model_layers: Optional[bool],
+ sqrt_norm_model_layers: Optional[bool],
+ num_octaves: int,
+ octave_scale: float,
+ unzoom_octave: Optional[bool],
+ pixel_zoom: int,
+ rotation_deg: float,
+ translate_x: int,
+ translate_y: int,
+ frames_per_channel: int,
+ fps: int,
+ reverse_video: bool,
+ include_starting_image: bool,
+ outdir: Union[str, os.PathLike],
+ description: str,
+):
+ """Zoom in using all the channels of a network (or a specified layer)"""
+ # Set up device
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+ # Load Discriminator
+ D = gen_utils.load_network('D', network_pkl, cfg, device)
+
+ # Get the model resolution (for resizing the starting image if needed)
+ model_resolution = D.img_resolution
+ zoom_size = model_resolution - 2 * pixel_zoom
+
+ if 'use_all' in layer:
+ ctx.fail('Cannot use "use_all" with this command. Please specify the layers you want to use.')
+ else:
+ # Parse the layers given by the user and leave only those available by the model
+ available_layers = get_available_layers(max_resolution=model_resolution)
+ assert layer in available_layers, f'Layer {layer} not available. Available layers: {available_layers}'
+ layers = [layer]
+
+ # We will use the features of the Discriminator, on the layer specified by the user
+ model = DiscriminatorFeatures(D).requires_grad_(False).to(device)
+
+ # Get the image and image name
+ image, starting_image = get_image(seed=seed, image_noise=image_noise,
+ starting_image=starting_image,
+ image_size=model_resolution,
+ convert_to_grayscale=convert_to_grayscale)
+
+ # Make the run dir in the specified output directory
+ desc = 'discriminator-channel-zoom'
+ desc = f'{desc}-{description}' if len(description) != 0 else desc
+ run_dir = gen_utils.make_run_dir(outdir, desc)
+
+ # Finally, let's get the number of channels in the selected layer
+ channels_dict = {res: D.get_submodule(f'b{res}.conv0').out_channels for res in D.block_resolutions}
+ channels_dict[4] = D.get_submodule('b4.conv').out_channels # Last block has a different name
+ # Get the dimension of the block from the selected layer (e.g., from 'b128_conv0' get '128')
+ block_resolution = re.search(r'b(\d+)_', layer).group(1)
+ total_channels = channels_dict[int(block_resolution)]
+ # Make a list of all the channels, each repeated frames_per_channel
+ channels = np.repeat(np.arange(total_channels), frames_per_channel)
+
+ num_frames = int(np.rint(total_channels * frames_per_channel)) # Number of frames for the video
+ n_digits = int(np.log10(num_frames)) + 1 # Number of digits for naming each frame
+
+ # Save the starting image
+ starting_image_name = f'dreamed_{0:0{n_digits}d}.jpg' if include_starting_image else 'starting_image.jpg'
+ image.save(os.path.join(run_dir, starting_image_name))
+
+ for idx, frame in enumerate(tqdm(range(num_frames), desc='Dreaming...', unit='frame')):
+ # Zoom in after the first frame
+ if idx > 0:
+ image = crop_resize_rotate(image, crop_size=zoom_size, new_size=model_resolution,
+ rotation_deg=rotation_deg, translate_x=translate_x, translate_y=translate_y)
+ # Extract deep dream image
+ dreamed_image = deep_dream(image, model, model_resolution, layers=layers, seed=seed, normed=norm_model_layers,
+ sqrt_normed=sqrt_norm_model_layers, iterations=iterations, channels=channels[idx:idx + 1],
+ lr=learning_rate, octave_scale=octave_scale, num_octaves=num_octaves,
+ unzoom_octave=unzoom_octave, disable_inner_tqdm=True)
+
+ # Save the resulting image and initial image
+ filename = f'dreamed_{idx + 1:0{n_digits}d}.jpg'
+ Image.fromarray(dreamed_image, 'RGB').save(os.path.join(run_dir, filename))
+
+ # Now, the dreamed image is the starting image
+ image = Image.fromarray(dreamed_image, 'RGB')
+
+ # Save the final video
+ gen_utils.save_video_from_images(run_dir=run_dir, image_names=f'dreamed_%0{n_digits}d.jpg', video_name='channel-zoom',
+ fps=fps, reverse_video=reverse_video)
+
+ # Save the configuration used
+ ctx.obj = {
+ 'network_pkl': network_pkl,
+ 'synthesis_options': {
+ 'seed': seed,
+ 'random_image_noise': image_noise,
+ 'starting_image': starting_image,
+ 'class_idx': class_idx,
+ 'learning_rate': learning_rate,
+ 'iterations': iterations
+ },
+ 'layer_options': {
+ 'layer': layer,
+ 'channels': 'all',
+ 'total_channels': total_channels,
+ 'norm_model_layers': norm_model_layers,
+ 'sqrt_norm_model_layers': sqrt_norm_model_layers
+ },
+ 'octaves_options': {
+ 'num_octaves': num_octaves,
+ 'octave_scale': octave_scale,
+ 'unzoom_octave': unzoom_octave
+ },
+ 'frame_manipulation_options': {
+ 'pixel_zoom': pixel_zoom,
+ 'rotation_deg': rotation_deg,
+ 'translate_x': translate_x,
+ 'translate_y': translate_y,
+ },
+ 'video_options': {
+ 'fps': fps,
+ 'frames_per_channel': frames_per_channel,
+ 'reverse_video': reverse_video,
+ 'include_starting_image': include_starting_image,
+ },
+ 'extra_parameters': {
+ 'outdir': run_dir,
+ 'description': description
+ }
+ }
+ # Save the run configuration
+ gen_utils.save_config(ctx=ctx, run_dir=run_dir)
+
+
+# ----------------------------------------------------------------------------
+
+
+@main.command(name='interp', help='Interpolate between two or more seeds')
+@click.pass_context
+@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
+@click.option('--cfg', type=click.Choice(['stylegan3-t', 'stylegan3-r', 'stylegan2']), help='Model base configuration', default=None)
+# Synthesis options
+@click.option('--seeds', type=gen_utils.num_range, help='Random seeds to generate the Perlin noise from', required=True)
+@click.option('--interp-type', '-interp', type=click.Choice(['linear', 'spherical']), help='Type of interpolation in Z or W', default='spherical', show_default=True)
+@click.option('--smooth', is_flag=True, help='Add flag to smooth the interpolation between the seeds')
+@click.option('--random-image-noise', '-noise', 'image_noise', type=click.Choice(['random', 'perlin']), default='random', show_default=True)
+@click.option('--starting-image', type=str, help='Path to image to start from', default=None)
+@click.option('--convert-to-grayscale', '-grayscale', is_flag=True, help='Add flag to grayscale the initial image')
+@click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)', default=None)
+@click.option('--lr', 'learning_rate', type=float, help='Learning rate', default=5e-3, show_default=True)
+@click.option('--iterations', '-it', type=click.IntRange(min=1), help='Number of gradient ascent steps per octave', default=10, show_default=True)
+# Layer options
+@click.option('--layers', type=str, help='Comma-separated list of the layers of the Discriminator to use as the features. If "use_all", will use all available layers.', default='b16_conv0', show_default=True)
+@click.option('--channels', type=gen_utils.num_range, help='Comma-separated list and/or range of the channels of the Discriminator to use as the features. If "None", will use all channels in each specified layer.', default=None, show_default=True)
+@click.option('--normed', 'norm_model_layers', is_flag=True, help='Add flag to divide the features of each layer of D by its number of elements')
+@click.option('--sqrt-normed', 'sqrt_norm_model_layers', is_flag=True, help='Add flag to divide the features of each layer of D by the square root of its number of elements')
+# Octaves options
+@click.option('--num-octaves', type=click.IntRange(min=1), help='Number of octaves', default=5, show_default=True)
+@click.option('--octave-scale', type=float, help='Image scale between octaves', default=1.4, show_default=True)
+@click.option('--unzoom-octave', type=bool, help='Set to True for the octaves to be unzoomed (this will be slower)', default=False, show_default=True)
+# TODO: Individual frame manipulation options
+# Video options
+@click.option('--seed-sec', '-sec', type=float, help='Number of seconds between each seed transition', default=5.0, show_default=True)
+@click.option('--fps', type=gen_utils.parse_fps, help='FPS for the mp4 video of optimization progress (if saved)', default=25, show_default=True)
+# Extra parameters for saving the results
+@click.option('--outdir', type=click.Path(file_okay=False), help='Directory path to save the results', default=os.path.join(os.getcwd(), 'out', 'discriminator_synthesis'), show_default=True, metavar='DIR')
+@click.option('--description', '-desc', type=str, help='Additional description name for the directory path to save results', default='', show_default=True)
+def random_interpolation(
+ ctx: click.Context,
+ network_pkl: Union[str, os.PathLike],
+ cfg: Optional[str],
+ seeds: List[int],
+ interp_type: Optional[str],
+ smooth: Optional[bool],
+ image_noise: Optional[str],
+ starting_image: Optional[Union[str, os.PathLike]],
+ convert_to_grayscale: bool,
+ class_idx: Optional[int], # TODO: conditional model
+ learning_rate: float,
+ iterations: int,
+ layers: str,
+ channels: List[int],
+ norm_model_layers: Optional[bool],
+ sqrt_norm_model_layers: Optional[bool],
+ num_octaves: int,
+ octave_scale: float,
+ unzoom_octave: Optional[bool],
+ seed_sec: float,
+ fps: int,
+ outdir: Union[str, os.PathLike],
+ description: str,
+):
+ """Do a latent walk between random Perlin images (given the seeds) and generate a video with these frames."""
+ # TODO: To make this better and more stable, we generate Perlin noise animations, not interpolations
+ # Set up device
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+ # Load Discriminator
+ D = gen_utils.load_network('D', network_pkl, cfg, device)
+
+ # Get model resolution
+ model_resolution = D.img_resolution
+ model = DiscriminatorFeatures(D).requires_grad_(False).to(device)
+
+ layers = layers.split(',')
+ # Get all available layers
+ if 'use_all' in layers:
+ layers = get_available_layers(max_resolution=model_resolution)
+ else:
+ # Parse the layers given by the user and leave only those available by the model
+ available_layers = get_available_layers(max_resolution=model_resolution)
+ layers = [layer for layer in layers if layer in available_layers]
+
+ # Make the run dir in the specified output directory
+ desc = f'random-interp-layers_{"-".join(x for x in layers)}'
+ desc = f'{desc}-{description}' if len(description) != 0 else desc
+ run_dir = gen_utils.make_run_dir(outdir, desc)
+
+ # Number of steps to take between each random image
+ n_steps = int(np.rint(seed_sec * fps))
+ # Total number of frames
+ num_frames = int(n_steps * (len(seeds) - 1))
+ # Total video length in seconds
+ duration_sec = num_frames / fps
+
+ # Number of digits for naming purposes
+ n_digits = int(np.log10(num_frames)) + 1
+
+ # Create interpolation of noises
+ random_images = []
+ for seed in seeds:
+ # Get the starting seed and image
+ image, _ = get_image(seed=seed, image_noise=image_noise, starting_image=starting_image,
+ image_size=model_resolution, convert_to_grayscale=convert_to_grayscale)
+ image = np.array(image) / 255.0
+ random_images.append(image)
+ random_images = np.stack(random_images)
+
+ all_images = np.empty([0] + list(random_images.shape[1:]), dtype=np.float32)
+ # Do interpolation
+ for i in range(len(random_images) - 1):
+ # Interpolate between each pair of images
+ interp = gen_utils.interpolate(random_images[i], random_images[i + 1], n_steps, interp_type, smooth)
+ # Append it to the list of all images
+ all_images = np.append(all_images, interp, axis=0)
+
+ # DeepDream expects a list of PIL.Image objects
+ pil_images = []
+ for idx in range(len(all_images)):
+ im = (255 * all_images[idx]).astype(dtype=np.uint8)
+ pil_images.append(Image.fromarray(im))
+
+ for idx, image in enumerate(tqdm(pil_images, desc='Interpolating...', unit='frame', total=num_frames)):
+ # Extract deep dream image
+ dreamed_image = deep_dream(image, model, model_resolution, layers=layers, channels=channels, seed=None,
+ normed=norm_model_layers, disable_inner_tqdm=True, ignore_initial_transform=True,
+ sqrt_normed=sqrt_norm_model_layers, iterations=iterations, lr=learning_rate,
+ octave_scale=octave_scale, num_octaves=num_octaves, unzoom_octave=unzoom_octave)
+
+ # Save the resulting image and initial image
+ filename = f'{image_noise}-interpolation_frame_{idx:0{n_digits}d}.jpg'
+ Image.fromarray(dreamed_image, 'RGB').save(os.path.join(run_dir, filename))
+
+ # Save the configuration used
+ ctx.obj = {
+ 'network_pkl': network_pkl,
+ 'synthesis_options': {
+ 'seeds': seeds,
+ 'starting_image': starting_image,
+ 'class_idx': class_idx,
+ 'learning_rate': learning_rate,
+ 'iterations': iterations},
+ 'layer_options': {
+ 'layer': layers,
+ 'channels': channels,
+ 'norm_model_layers': norm_model_layers,
+ 'sqrt_norm_model_layers': sqrt_norm_model_layers},
+ 'octaves_options': {
+ 'octave_scale': octave_scale,
+ 'num_octaves': num_octaves,
+ 'unzoom_octave': unzoom_octave},
+ 'extra_parameters': {
+ 'outdir': run_dir,
+ 'description': description}
+ }
+ # Save the run configuration
+ gen_utils.save_config(ctx=ctx, run_dir=run_dir)
+
+ # Generate video
+ print('Saving video...')
+ ffmpeg_command = r'/usr/bin/ffmpeg' if os.name != 'nt' else r'C:\\Ffmpeg\\bin\\ffmpeg.exe'
+ stream = ffmpeg.input(os.path.join(run_dir, f'{image_noise}-interpolation_frame_%0{n_digits}d.jpg'), framerate=fps)
+ stream = ffmpeg.output(stream, os.path.join(run_dir, f'{image_noise}-interpolation.mp4'), crf=20, pix_fmt='yuv420p')
+ ffmpeg.run(stream, capture_stdout=True, capture_stderr=True, cmd=ffmpeg_command)
+
+# ----------------------------------------------------------------------------
+
+
+if __name__ == '__main__':
+ main()
+
+
+# ----------------------------------------------------------------------------
diff --git a/stylegan3-fun/dnnlib/__init__.py b/stylegan3-fun/dnnlib/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7423bffe245d0ff3f32e8658aa67daae454e64e
--- /dev/null
+++ b/stylegan3-fun/dnnlib/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+from .util import EasyDict, make_cache_dir_path
diff --git a/stylegan3-fun/dnnlib/__pycache__/__init__.cpython-311.pyc b/stylegan3-fun/dnnlib/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ea4b58fe3709e1f3438b930f419ff6b5cce8853e
Binary files /dev/null and b/stylegan3-fun/dnnlib/__pycache__/__init__.cpython-311.pyc differ
diff --git a/stylegan3-fun/dnnlib/__pycache__/__init__.cpython-38.pyc b/stylegan3-fun/dnnlib/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..32ca03c24c7a7f0484ac7948e754be8c76e34841
Binary files /dev/null and b/stylegan3-fun/dnnlib/__pycache__/__init__.cpython-38.pyc differ
diff --git a/stylegan3-fun/dnnlib/__pycache__/__init__.cpython-39.pyc b/stylegan3-fun/dnnlib/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8905d9cfae18562108b3a5998940b25d32b28d1e
Binary files /dev/null and b/stylegan3-fun/dnnlib/__pycache__/__init__.cpython-39.pyc differ
diff --git a/stylegan3-fun/dnnlib/__pycache__/util.cpython-311.pyc b/stylegan3-fun/dnnlib/__pycache__/util.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7500bf0ce951aae371afb7116a1e4eaf9b750435
Binary files /dev/null and b/stylegan3-fun/dnnlib/__pycache__/util.cpython-311.pyc differ
diff --git a/stylegan3-fun/dnnlib/__pycache__/util.cpython-38.pyc b/stylegan3-fun/dnnlib/__pycache__/util.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b032ba7057ce253a7d314a48b079cd5288d4b491
Binary files /dev/null and b/stylegan3-fun/dnnlib/__pycache__/util.cpython-38.pyc differ
diff --git a/stylegan3-fun/dnnlib/__pycache__/util.cpython-39.pyc b/stylegan3-fun/dnnlib/__pycache__/util.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a5995fdccafd5ea57a606dec89703755353350b8
Binary files /dev/null and b/stylegan3-fun/dnnlib/__pycache__/util.cpython-39.pyc differ
diff --git a/stylegan3-fun/dnnlib/util.py b/stylegan3-fun/dnnlib/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bbdf3bd8fe1c138cd969d37dcc52190b45c4c16
--- /dev/null
+++ b/stylegan3-fun/dnnlib/util.py
@@ -0,0 +1,491 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Miscellaneous utility classes and functions."""
+
+import ctypes
+import fnmatch
+import importlib
+import inspect
+import numpy as np
+import os
+import shutil
+import sys
+import types
+import io
+import pickle
+import re
+import requests
+import html
+import hashlib
+import glob
+import tempfile
+import urllib
+import urllib.request
+import uuid
+
+from distutils.util import strtobool
+from typing import Any, List, Tuple, Union
+
+
+# Util classes
+# ------------------------------------------------------------------------------------------
+
+
+class EasyDict(dict):
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
+
+ def __getattr__(self, name: str) -> Any:
+ try:
+ return self[name]
+ except KeyError:
+ raise AttributeError(name)
+
+ def __setattr__(self, name: str, value: Any) -> None:
+ self[name] = value
+
+ def __delattr__(self, name: str) -> None:
+ del self[name]
+
+
+class Logger(object):
+ """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
+
+ def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
+ self.file = None
+
+ if file_name is not None:
+ self.file = open(file_name, file_mode)
+
+ self.should_flush = should_flush
+ self.stdout = sys.stdout
+ self.stderr = sys.stderr
+
+ sys.stdout = self
+ sys.stderr = self
+
+ def __enter__(self) -> "Logger":
+ return self
+
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+ self.close()
+
+ def write(self, text: Union[str, bytes]) -> None:
+ """Write text to stdout (and a file) and optionally flush."""
+ if isinstance(text, bytes):
+ text = text.decode()
+ if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
+ return
+
+ if self.file is not None:
+ self.file.write(text)
+
+ self.stdout.write(text)
+
+ if self.should_flush:
+ self.flush()
+
+ def flush(self) -> None:
+ """Flush written text to both stdout and a file, if open."""
+ if self.file is not None:
+ self.file.flush()
+
+ self.stdout.flush()
+
+ def close(self) -> None:
+ """Flush, close possible files, and remove stdout/stderr mirroring."""
+ self.flush()
+
+ # if using multiple loggers, prevent closing in wrong order
+ if sys.stdout is self:
+ sys.stdout = self.stdout
+ if sys.stderr is self:
+ sys.stderr = self.stderr
+
+ if self.file is not None:
+ self.file.close()
+ self.file = None
+
+
+# Cache directories
+# ------------------------------------------------------------------------------------------
+
+_dnnlib_cache_dir = None
+
+def set_cache_dir(path: str) -> None:
+ global _dnnlib_cache_dir
+ _dnnlib_cache_dir = path
+
+def make_cache_dir_path(*paths: str) -> str:
+ if _dnnlib_cache_dir is not None:
+ return os.path.join(_dnnlib_cache_dir, *paths)
+ if 'DNNLIB_CACHE_DIR' in os.environ:
+ return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
+ if 'HOME' in os.environ:
+ return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
+ if 'USERPROFILE' in os.environ:
+ return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
+ return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
+
+# Small util functions
+# ------------------------------------------------------------------------------------------
+
+
+def format_time(seconds: Union[int, float]) -> str:
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
+ s = int(np.rint(seconds))
+
+ if s < 60:
+ return "{0}s".format(s)
+ elif s < 60 * 60:
+ return "{0}m {1:02}s".format(s // 60, s % 60)
+ elif s < 24 * 60 * 60:
+ return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
+ else:
+ return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
+
+
+def format_time_brief(seconds: Union[int, float]) -> str:
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
+ s = int(np.rint(seconds))
+
+ if s < 60:
+ return "{0}s".format(s)
+ elif s < 60 * 60:
+ return "{0}m {1:02}s".format(s // 60, s % 60)
+ elif s < 24 * 60 * 60:
+ return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
+ else:
+ return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
+
+
+def ask_yes_no(question: str) -> bool:
+ """Ask the user the question until the user inputs a valid answer."""
+ while True:
+ try:
+ print("{0} [y/n]".format(question))
+ return strtobool(input().lower())
+ except ValueError:
+ pass
+
+
+def tuple_product(t: Tuple) -> Any:
+ """Calculate the product of the tuple elements."""
+ result = 1
+
+ for v in t:
+ result *= v
+
+ return result
+
+
+_str_to_ctype = {
+ "uint8": ctypes.c_ubyte,
+ "uint16": ctypes.c_uint16,
+ "uint32": ctypes.c_uint32,
+ "uint64": ctypes.c_uint64,
+ "int8": ctypes.c_byte,
+ "int16": ctypes.c_int16,
+ "int32": ctypes.c_int32,
+ "int64": ctypes.c_int64,
+ "float32": ctypes.c_float,
+ "float64": ctypes.c_double
+}
+
+
+def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
+ """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
+ type_str = None
+
+ if isinstance(type_obj, str):
+ type_str = type_obj
+ elif hasattr(type_obj, "__name__"):
+ type_str = type_obj.__name__
+ elif hasattr(type_obj, "name"):
+ type_str = type_obj.name
+ else:
+ raise RuntimeError("Cannot infer type name from input")
+
+ assert type_str in _str_to_ctype.keys()
+
+ my_dtype = np.dtype(type_str)
+ my_ctype = _str_to_ctype[type_str]
+
+ assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
+
+ return my_dtype, my_ctype
+
+
+def is_pickleable(obj: Any) -> bool:
+ try:
+ with io.BytesIO() as stream:
+ pickle.dump(obj, stream)
+ return True
+ except:
+ return False
+
+
+# Functionality to import modules/objects by name, and call functions by name
+# ------------------------------------------------------------------------------------------
+
+def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
+ """Searches for the underlying module behind the name to some python object.
+ Returns the module and the object name (original name with module part removed)."""
+
+ # allow convenience shorthands, substitute them by full names
+ obj_name = re.sub("^np.", "numpy.", obj_name)
+ obj_name = re.sub("^tf.", "tensorflow.", obj_name)
+
+ # list alternatives for (module_name, local_obj_name)
+ parts = obj_name.split(".")
+ name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
+
+ # try each alternative in turn
+ for module_name, local_obj_name in name_pairs:
+ try:
+ module = importlib.import_module(module_name) # may raise ImportError
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
+ return module, local_obj_name
+ except:
+ pass
+
+ # maybe some of the modules themselves contain errors?
+ for module_name, _local_obj_name in name_pairs:
+ try:
+ importlib.import_module(module_name) # may raise ImportError
+ except ImportError:
+ if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
+ raise
+
+ # maybe the requested attribute is missing?
+ for module_name, local_obj_name in name_pairs:
+ try:
+ module = importlib.import_module(module_name) # may raise ImportError
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
+ except ImportError:
+ pass
+
+ # we are out of luck, but we have no idea why
+ raise ImportError(obj_name)
+
+
+def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
+ """Traverses the object name and returns the last (rightmost) python object."""
+ if obj_name == '':
+ return module
+ obj = module
+ for part in obj_name.split("."):
+ obj = getattr(obj, part)
+ return obj
+
+
+def get_obj_by_name(name: str) -> Any:
+ """Finds the python object with the given name."""
+ module, obj_name = get_module_from_obj_name(name)
+ return get_obj_from_module(module, obj_name)
+
+
+def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
+ """Finds the python object with the given name and calls it as a function."""
+ assert func_name is not None
+ func_obj = get_obj_by_name(func_name)
+ assert callable(func_obj)
+ return func_obj(*args, **kwargs)
+
+
+def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
+ """Finds the python class with the given name and constructs it with the given arguments."""
+ return call_func_by_name(*args, func_name=class_name, **kwargs)
+
+
+def get_module_dir_by_obj_name(obj_name: str) -> str:
+ """Get the directory path of the module containing the given object name."""
+ module, _ = get_module_from_obj_name(obj_name)
+ return os.path.dirname(inspect.getfile(module))
+
+
+def is_top_level_function(obj: Any) -> bool:
+ """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
+ return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
+
+
+def get_top_level_function_name(obj: Any) -> str:
+ """Return the fully-qualified name of a top-level function."""
+ assert is_top_level_function(obj)
+ module = obj.__module__
+ if module == '__main__':
+ module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
+ return module + "." + obj.__name__
+
+
+# File system helpers
+# ------------------------------------------------------------------------------------------
+
+def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
+ """List all files recursively in a given directory while ignoring given file and directory names.
+ Returns list of tuples containing both absolute and relative paths."""
+ assert os.path.isdir(dir_path)
+ base_name = os.path.basename(os.path.normpath(dir_path))
+
+ if ignores is None:
+ ignores = []
+
+ result = []
+
+ for root, dirs, files in os.walk(dir_path, topdown=True):
+ for ignore_ in ignores:
+ dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
+
+ # dirs need to be edited in-place
+ for d in dirs_to_remove:
+ dirs.remove(d)
+
+ files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
+
+ absolute_paths = [os.path.join(root, f) for f in files]
+ relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
+
+ if add_base_to_relative:
+ relative_paths = [os.path.join(base_name, p) for p in relative_paths]
+
+ assert len(absolute_paths) == len(relative_paths)
+ result += zip(absolute_paths, relative_paths)
+
+ return result
+
+
+def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
+ """Takes in a list of tuples of (src, dst) paths and copies files.
+ Will create all necessary directories."""
+ for file in files:
+ target_dir_name = os.path.dirname(file[1])
+
+ # will create all intermediate-level directories
+ if not os.path.exists(target_dir_name):
+ os.makedirs(target_dir_name)
+
+ shutil.copyfile(file[0], file[1])
+
+
+# URL helpers
+# ------------------------------------------------------------------------------------------
+
+def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
+ """Determine whether the given object is a valid URL string."""
+ if not isinstance(obj, str) or not "://" in obj:
+ return False
+ if allow_file_urls and obj.startswith('file://'):
+ return True
+ try:
+ res = requests.compat.urlparse(obj)
+ if not res.scheme or not res.netloc or not "." in res.netloc:
+ return False
+ res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
+ if not res.scheme or not res.netloc or not "." in res.netloc:
+ return False
+ except:
+ return False
+ return True
+
+
+def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
+ """Download the given URL and return a binary-mode file object to access the data."""
+ assert num_attempts >= 1
+ assert not (return_filename and (not cache))
+
+ # Doesn't look like an URL scheme so interpret it as a local filename.
+ if not re.match('^[a-z]+://', url):
+ return url if return_filename else open(url, "rb")
+
+ # Handle file URLs. This code handles unusual file:// patterns that
+ # arise on Windows:
+ #
+ # file:///c:/foo.txt
+ #
+ # which would translate to a local '/c:/foo.txt' filename that's
+ # invalid. Drop the forward slash for such pathnames.
+ #
+ # If you touch this code path, you should test it on both Linux and
+ # Windows.
+ #
+ # Some internet resources suggest using urllib.request.url2pathname() but
+ # but that converts forward slashes to backslashes and this causes
+ # its own set of problems.
+ if url.startswith('file://'):
+ filename = urllib.parse.urlparse(url).path
+ if re.match(r'^/[a-zA-Z]:', filename):
+ filename = filename[1:]
+ return filename if return_filename else open(filename, "rb")
+
+ assert is_url(url)
+
+ # Lookup from cache.
+ if cache_dir is None:
+ cache_dir = make_cache_dir_path('downloads')
+
+ url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
+ if cache:
+ cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
+ if len(cache_files) == 1:
+ filename = cache_files[0]
+ return filename if return_filename else open(filename, "rb")
+
+ # Download.
+ url_name = None
+ url_data = None
+ with requests.Session() as session:
+ if verbose:
+ print("Downloading %s ..." % url, end="", flush=True)
+ for attempts_left in reversed(range(num_attempts)):
+ try:
+ with session.get(url) as res:
+ res.raise_for_status()
+ if len(res.content) == 0:
+ raise IOError("No data received")
+
+ if len(res.content) < 8192:
+ content_str = res.content.decode("utf-8")
+ if "download_warning" in res.headers.get("Set-Cookie", ""):
+ links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
+ if len(links) == 1:
+ url = requests.compat.urljoin(url, links[0])
+ raise IOError("Google Drive virus checker nag")
+ if "Google Drive - Quota exceeded" in content_str:
+ raise IOError("Google Drive download quota exceeded -- please try again later")
+
+ match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
+ url_name = match[1] if match else url
+ url_data = res.content
+ if verbose:
+ print(" done")
+ break
+ except KeyboardInterrupt:
+ raise
+ except:
+ if not attempts_left:
+ if verbose:
+ print(" failed")
+ raise
+ if verbose:
+ print(".", end="", flush=True)
+
+ # Save to cache.
+ if cache:
+ safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
+ cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
+ temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
+ os.makedirs(cache_dir, exist_ok=True)
+ with open(temp_file, "wb") as f:
+ f.write(url_data)
+ os.replace(temp_file, cache_file) # atomic
+ if return_filename:
+ return cache_file
+
+ # Return data as file object.
+ assert not return_filename
+ return io.BytesIO(url_data)
diff --git a/stylegan3-fun/docs/avg_spectra_screen0.png b/stylegan3-fun/docs/avg_spectra_screen0.png
new file mode 100644
index 0000000000000000000000000000000000000000..94339f820ebf7bf8acbcdbfb547ec646481554da
Binary files /dev/null and b/stylegan3-fun/docs/avg_spectra_screen0.png differ
diff --git a/stylegan3-fun/docs/avg_spectra_screen0_half.png b/stylegan3-fun/docs/avg_spectra_screen0_half.png
new file mode 100644
index 0000000000000000000000000000000000000000..0768901c8cc004aa9dce401eb198ade4627720ab
Binary files /dev/null and b/stylegan3-fun/docs/avg_spectra_screen0_half.png differ
diff --git a/stylegan3-fun/docs/configs.md b/stylegan3-fun/docs/configs.md
new file mode 100644
index 0000000000000000000000000000000000000000..7469c87f78556cf0ed3faa3753e31da7b3afb7f0
--- /dev/null
+++ b/stylegan3-fun/docs/configs.md
@@ -0,0 +1,201 @@
+# Training configurations
+
+This document provides guidelines for selecting appropriate training options for various scenarios, as well as an extensive list of recommended configurations.
+
+#### Example
+
+In the remainder of this document, we summarize each configuration as follows:
+
+| Config
| s/kimg
(V100) | s/kimg
(A100) | GPU
mem | Options
+| :--------------------------- | :--------------: | :--------------: | :------------: | :--
+| StyleGAN3‑T | 18.47 | 12.29 | 4.3 | `--cfg=stylegan3-t --gpus=8 --batch=32 --gamma=8.2 --mirror=1`
+
+This corresponds to the following command line:
+
+```.bash
+# Train StyleGAN3-T for AFHQv2 using 8 GPUs.
+python train.py --outdir=~/training-runs --cfg=stylegan3-t --data=~/datasets/afhqv2-512x512.zip \
+ --gpus=8 --batch=32 --gamma=8.2 --mirror=1
+```
+
+Explanation of the columns:
+- **Config**: StyleGAN3-T (translation equiv.), StyleGAN3-R (translation and rotation equiv.), or StyleGAN2. Reflects the value of `--cfg`.
+- **s/kimg**: Raw training speed, measured separately on Tesla V100 and A100 using our recommended Docker image. The number indicates how many seconds, on average, it takes to process 1000 images from the training set. The number tends to vary slightly over the course of training; typically by no more than ±20%.
+- **GPU mem**: Maximum GPU memory usage observed during training, reported in gigabytes per GPU. The above example uses 8 GPUs, which means that the total GPU memory usage is around 34.4 GB.
+- **Options**: Command line options for `train.py`, excluding `--outdir` and `--data`.
+
+#### Total training time
+
+In addition the raw s/kimg number, the training time also depends on the `--kimg` and `--metric` options. `--kimg` controls the total number of training iterations and is set to 25000 by default. This is long enough to reach convergence in typical cases, but in practice the results should already look quite reasonable around 5000 kimg. `--metrics` determines which quality metrics are computed periodically during training. The default is `fid50k_full`, which increases the training time slightly; typically by no more than 5%. The automatic computation can be disabled by specifying `--metrics=none`.
+
+In the above example, the total training time on V100 is approximately 18.47 s/kimg * 25000 kimg * 1.05 ≈ 485,000 seconds ≈ 5 days and 14 hours. Disabling metric computation (`--metrics=none`) reduces this to approximately 5 days and 8 hours.
+
+## General guidelines
+
+The most important hyperparameter that needs to be tuned on a per-dataset basis is the R1 regularization weight, `--gamma`, that must be specified explicitly for `train.py`. As a rule of thumb, the value of `--gamma` scales quadratically with respect to the training set resolution: doubling the resolution (e.g., 256x256 → 512x512) means that `--gamma` should be multiplied by 4 (e.g., 2 → 8). The optimal value is usually the same for `--cfg=stylegan3-t` and `--cfg=stylegan3-r`, but considerably lower for `--cfg=stylegan2`.
+
+In practice, we recommend selecting the value of `--gamma` as follows:
+- Find the closest match for your specific case in this document (config, resolution, and GPU count).
+- Try training with the same `--gamma` first.
+- Then, try increasing the value by 2x and 4x, and also decreasing it by 2x and 4x.
+- Pick the value that yields the lowest FID.
+
+The results may also be improved by adjusting `--mirror` and `--aug`, depending on the training data. Specifying `--mirror=1` augments the dataset with random *x*-flips, which effectively doubles the number of images. This is generally beneficial with datasets that are horizontally symmetric (e.g., FFHQ), but it can be harmful if the images contain noticeable asymmetric features (e.g., text or letters). Specifying `--aug=noaug` disables adaptive discriminator augmentation (ADA), which may improve the results slightly if the training set is large enough (at least 100k images when accounting for *x*-flips). With small datasets (less than 30k images), it is generally a good idea to leave the augmentations enabled.
+
+It is possible to speed up the training by decreasing network capacity, i.e., `--cbase=16384`. This typically leads to lower quality results, but the difference is less pronounced with low-resolution datasets (e.g., 256x256).
+
+#### Scaling to different number of GPUs
+
+You can select the number of GPUs by changing the value of `--gpu`; this does not affect the convergence curves or training dynamics in any way. By default, the total batch size (`--batch`) is divided evenly among the GPUs, which means that decreasing the number of GPUs yields higher per-GPU memory usage. To avoid running out of memory, you can decrease the per-GPU batch size by specifying `--batch-gpu`, which performs the same computation in multiple passes using gradient accumulation.
+
+By default, `train.py` exports network snapshots once every 200 kimg, i.e., the product of `--snap=50` and `--tick=4`. When using few GPUs (e.g., 1–2), this means that it may take a very long time for the first snapshot to appear. We recommend increasing the snapshot frequency in such cases by specifying `--snap=20`, `--snap=10`, or `--snap=5`.
+
+Note that the configurations listed in this document have been specifically tuned for 8 GPUs. The safest way to scale them to different GPU counts is to adjust `--gpu`, `--batch-gpu`, and `--snap` as described above, but it may be possible to reach faster convergence by adjusting some of the other hyperparameters as well. Note, however, that adjusting the total batch size (`--batch`) requires some experimentation; decreasing `--batch` usually necessitates increasing regularization (`--gamma`) and/or decreasing the learning rates (most importantly `--dlr`).
+
+#### Transfer learning
+
+Transfer learning makes it possible to reach very good results very quickly, especially when the training set is small and/or the images resemble the ones produced by a pre-trained model. To enable transfer learning, you can point `--resume` to one of the pre-trained models that we provide for [StyleGAN3](https://ngc.nvidia.com/catalog/models/nvidia:research:stylegan3) and [StyleGAN2](https://ngc.nvidia.com/catalog/models/nvidia:research:stylegan2). For example:
+
+```.bash
+# Fine-tune StyleGAN3-R for MetFaces-U using 1 GPU, starting from the pre-trained FFHQ-U pickle.
+python train.py --outdir=~/training-runs --cfg=stylegan3-r --data=~/datasets/metfacesu-1024x1024.zip \
+ --gpus=8 --batch=32 --gamma=6.6 --mirror=1 --kimg=5000 --snap=5 \
+ --resume=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhqu-1024x1024.pkl
+```
+
+The pre-trained model should be selected to match the specified config, resolution, and architecture-related hyperparameters (e.g., `--cbase`, `--map-depth`, and `--mbstd-group`). You check this by looking at the `fakes_init.png` exported by `train.py` at the beginning; if the configuration is correct, the images should look reasonable.
+
+With transfer learning, the results may be improved slightly by adjusting `--freezed`, in addition to the above guidelines for `--gamma`, `--mirror`, and `--aug`. In our experience, `--freezed=10` and `--freezed=13` tend to work reasonably well.
+
+## Recommended configurations
+
+This section lists recommended settings for StyleGAN3-T and StyleGAN3-R for different resolutions and GPU counts, selected according to the above guidelines. These are intended to provide a good starting point when experimenting with a new dataset. Please note that many of the options (e.g., `--gamma`, `--mirror`, and `--aug`) are still worth adjusting on a case-by-case basis.
+
+#### 128x128 resolution
+
+| Config
| GPUs
| s/kimg
(V100) | s/kimg
(A100) | GPU
mem | Options
+| :--------------------------- | :----------: | :--------------: | :--------------: | :------------: | :--
+| StyleGAN3‑T | 1 | 73.68 | 27.20 | 7.2 | `--cfg=stylegan3-t --gpus=1 --batch=32 --gamma=0.5 --batch-gpu=16 --snap=10`
+| StyleGAN3‑T | 2 | 37.30 | 13.74 | 7.1 | `--cfg=stylegan3-t --gpus=2 --batch=32 --gamma=0.5 --snap=20`
+| StyleGAN3‑T | 4 | 20.66 | 7.52 | 4.1 | `--cfg=stylegan3-t --gpus=4 --batch=32 --gamma=0.5`
+| StyleGAN3‑T | 8 | 11.31 | 4.40 | 2.6 | `--cfg=stylegan3-t --gpus=8 --batch=32 --gamma=0.5`
+| StyleGAN3‑R | 1 | 58.44 | 34.23 | 8.3 | `--cfg=stylegan3-r --gpus=1 --batch=32 --gamma=0.5 --batch-gpu=16 --snap=10`
+| StyleGAN3‑R | 2 | 29.92 | 17.29 | 8.2 | `--cfg=stylegan3-r --gpus=2 --batch=32 --gamma=0.5 --snap=20`
+| StyleGAN3‑R | 4 | 15.49 | 9.53 | 4.5 | `--cfg=stylegan3-r --gpus=4 --batch=32 --gamma=0.5`
+| StyleGAN3‑R | 8 | 8.43 | 5.69 | 2.7 | `--cfg=stylegan3-r --gpus=8 --batch=32 --gamma=0.5`
+
+#### 256x256 resolution
+
+| Config
| GPUs
| s/kimg
(V100) | s/kimg
(A100) | GPU
mem | Options
+| :--------------------------- | :----------: | :--------------: | :--------------: | :------------: | :--
+| StyleGAN3‑T | 1 | 89.15 | 49.81 | 9.5 | `--cfg=stylegan3-t --gpus=1 --batch=32 --gamma=2 --batch-gpu=16 --snap=10`
+| StyleGAN3‑T | 2 | 45.45 | 25.05 | 9.3 | `--cfg=stylegan3-t --gpus=2 --batch=32 --gamma=2 --snap=20`
+| StyleGAN3‑T | 4 | 23.94 | 13.26 | 5.2 | `--cfg=stylegan3-t --gpus=4 --batch=32 --gamma=2`
+| StyleGAN3‑T | 8 | 13.04 | 7.32 | 3.1 | `--cfg=stylegan3-t --gpus=8 --batch=32 --gamma=2`
+| StyleGAN3‑R | 1 | 87.37 | 56.73 | 6.7 | `--cfg=stylegan3-r --gpus=1 --batch=32 --gamma=2 --batch-gpu=8 --snap=10`
+| StyleGAN3‑R | 2 | 44.12 | 28.60 | 6.7 | `--cfg=stylegan3-r --gpus=2 --batch=32 --gamma=2 --batch-gpu=8 --snap=20`
+| StyleGAN3‑R | 4 | 22.42 | 14.39 | 6.6 | `--cfg=stylegan3-r --gpus=4 --batch=32 --gamma=2`
+| StyleGAN3‑R | 8 | 11.88 | 8.03 | 3.7 | `--cfg=stylegan3-r --gpus=8 --batch=32 --gamma=2`
+
+#### 512x512 resolution
+
+| Config
| GPUs
| s/kimg
(V100) | s/kimg
(A100) | GPU
mem | Options
+| :--------------------------- | :----------: | :---------------: | :---------------: | :------------: | :--
+| StyleGAN3‑T | 1 | 137.33 | 90.25 | 7.8 | `--cfg=stylegan3-t --gpus=1 --batch=32 --gamma=8 --batch-gpu=8 --snap=10`
+| StyleGAN3‑T | 2 | 69.65 | 45.42 | 7.7 | `--cfg=stylegan3-t --gpus=2 --batch=32 --gamma=8 --batch-gpu=8 --snap=20`
+| StyleGAN3‑T | 4 | 34.88 | 22.81 | 7.6 | `--cfg=stylegan3-t --gpus=4 --batch=32 --gamma=8`
+| StyleGAN3‑T | 8 | 18.47 | 12.29 | 4.3 | `--cfg=stylegan3-t --gpus=8 --batch=32 --gamma=8`
+| StyleGAN3‑R | 1 | 158.91 | 110.13 | 6.0 | `--cfg=stylegan3-r --gpus=1 --batch=32 --gamma=8 --batch-gpu=4 --snap=10`
+| StyleGAN3‑R | 2 | 79.96 | 55.18 | 6.0 | `--cfg=stylegan3-r --gpus=2 --batch=32 --gamma=8 --batch-gpu=4 --snap=20`
+| StyleGAN3‑R | 4 | 40.86 | 27.99 | 5.9 | `--cfg=stylegan3-r --gpus=4 --batch=32 --gamma=8 --batch-gpu=4`
+| StyleGAN3‑R | 8 | 20.44 | 14.04 | 5.9 | `--cfg=stylegan3-r --gpus=8 --batch=32 --gamma=8`
+
+#### 1024x1024 resolution
+
+| Config
| GPUs
| s/kimg
(V100) | s/kimg
(A100) | GPU
mem | Options
+| :--------------------------- | :----------: | :---------------: | :---------------: | :-------------: | :--
+| StyleGAN3‑T | 1 | 221.85 | 156.91 | 7.0 | `--cfg=stylegan3-t --gpus=1 --batch=32 --gamma=32 --batch-gpu=4 --snap=5`
+| StyleGAN3‑T | 2 | 113.44 | 79.16 | 6.8 | `--cfg=stylegan3-t --gpus=2 --batch=32 --gamma=32 --batch-gpu=4 --snap=10`
+| StyleGAN3‑T | 4 | 57.04 | 39.62 | 6.7 | `--cfg=stylegan3-t --gpus=4 --batch=32 --gamma=32 --batch-gpu=4 --snap=20`
+| StyleGAN3‑T | 8 | 28.71 | 20.01 | 6.6 | `--cfg=stylegan3-t --gpus=8 --batch=32 --gamma=32`
+| StyleGAN3‑R | 1 | 263.44 | 184.81 | 10.2 | `--cfg=stylegan3-r --gpus=1 --batch=32 --gamma=32 --batch-gpu=4 --snap=5`
+| StyleGAN3‑R | 2 | 134.22 | 92.58 | 10.1 | `--cfg=stylegan3-r --gpus=2 --batch=32 --gamma=32 --batch-gpu=4 --snap=10`
+| StyleGAN3‑R | 4 | 67.33 | 46.53 | 10.0 | `--cfg=stylegan3-r --gpus=4 --batch=32 --gamma=32 --batch-gpu=4 --snap=20`
+| StyleGAN3‑R | 8 | 34.12 | 23.42 | 9.9 | `--cfg=stylegan3-r --gpus=8 --batch=32 --gamma=32`
+
+## Configurations used in StyleGAN3 paper
+
+This section lists the exact settings that we used in the "Alias-Free Generative Adversarial Networks" paper.
+
+#### FFHQ-U and FFHQ at 1024x1024 resolution
+
+| Config
| s/kimg
(V100) | s/kimg
(A100) | GPU
mem | Options
+| :--------------------------- | :--------------: | :--------------: | :------------: | :--
+| StyleGAN2 | 17.55 | 14.57 | 6.2 | `--cfg=stylegan2 --gpus=8 --batch=32 --gamma=10 --mirror=1 --aug=noaug`
+| StyleGAN3‑T | 28.71 | 20.01 | 6.6 | `--cfg=stylegan3-t --gpus=8 --batch=32 --gamma=32.8 --mirror=1 --aug=noaug`
+| StyleGAN3‑R | 34.12 | 23.42 | 9.9 | `--cfg=stylegan3-r --gpus=8 --batch=32 --gamma=32.8 --mirror=1 --aug=noaug`
+
+#### MetFaces-U at 1024x1024 resolution
+
+| Config
| s/kimg
(V100) | s/kimg
(A100) | GPU
mem | Options
+| :--------------------------- | :--------------: | :--------------: | :-------------: | :--
+| StyleGAN2 | 18.74 | 11.80 | 7.4 | `--cfg=stylegan2 --gpus=8 --batch=32 --gamma=10 --mirror=1 --kimg=5000 --snap=10 --resume=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhqu-1024x1024.pkl`
+| StyleGAN3‑T | 29.84 | 21.06 | 7.7 | `--cfg=stylegan3-t --gpus=8 --batch=32 --gamma=16.4 --mirror=1 --kimg=5000 --snap=10 --resume=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhqu-1024x1024.pkl`
+| StyleGAN3‑R | 35.10 | 24.32 | 10.9 | `--cfg=stylegan3-r --gpus=8 --batch=32 --gamma=6.6 --mirror=1 --kimg=5000 --snap=10 --resume=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhqu-1024x1024.pkl`
+
+#### MetFaces at 1024x1024 resolution
+
+| Config
| s/kimg
(V100) | s/kimg
(A100) | GPU
mem | Options
+| :--------------------------- | :--------------: | :--------------: | :-------------: | :--
+| StyleGAN2 | 18.74 | 11.80 | 7.4 | `--cfg=stylegan2 --gpus=8 --batch=32 --gamma=5 --mirror=1 --kimg=5000 --snap=10 --resume=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-1024x1024.pkl`
+| StyleGAN3‑T | 29.84 | 21.06 | 7.7 | `--cfg=stylegan3-t --gpus=8 --batch=32 --gamma=6.6 --mirror=1 --kimg=5000 --snap=10 --resume=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhq-1024x1024.pkl`
+| StyleGAN3‑R | 35.10 | 24.32 | 10.9 | `--cfg=stylegan3-r --gpus=8 --batch=32 --gamma=3.3 --mirror=1 --kimg=5000 --snap=10 --resume=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhq-1024x1024.pkl`
+
+#### AFHQv2 at 512x512 resolution
+
+| Config
| s/kimg
(V100) | s/kimg
(A100) | GPU
mem | Options
+| :--------------------------- | :--------------: | :--------------: | :------------: | :--
+| StyleGAN2 | 10.90 | 6.60 | 3.9 | `--cfg=stylegan2 --gpus=8 --batch=32 --gamma=5 --mirror=1`
+| StyleGAN3‑T | 18.47 | 12.29 | 4.3 | `--cfg=stylegan3-t --gpus=8 --batch=32 --gamma=8.2 --mirror=1`
+| StyleGAN3‑R | 20.44 | 14.04 | 5.9 | `--cfg=stylegan3-r --gpus=8 --batch=32 --gamma=16.4 --mirror=1`
+
+#### FFHQ-U ablations at 256x256 resolution
+
+| Config
| s/kimg
(V100) | s/kimg
(A100) | GPU
mem | Options
+| :--------------------------- | :-------------: | :-------------: | :------------: | :--
+| StyleGAN2 | 3.61 | 2.19 | 2.7 | `--cfg=stylegan2 --gpus=8 --batch=64 --gamma=1 --mirror=1 --aug=noaug --cbase=16384 --glr=0.0025 --dlr=0.0025 --mbstd-group=8`
+| StyleGAN3‑T | 7.40 | 3.74 | 3.5 | `--cfg=stylegan3-t --gpus=8 --batch=64 --gamma=1 --mirror=1 --aug=noaug --cbase=16384 --dlr=0.0025`
+| StyleGAN3‑R | 6.71 | 4.81 | 4.2 | `--cfg=stylegan3-r --gpus=8 --batch=64 --gamma=1 --mirror=1 --aug=noaug --cbase=16384 --dlr=0.0025`
+
+## Old StyleGAN2-ADA configurations
+
+This section lists command lines that can be used to match the configurations provided by our previous [StyleGAN2-ADA](https://github.com/NVlabs/stylegan2-ada-pytorch) codebase. The first table corresponds to `--cfg=auto` (default) for different resolutions and GPU counts, while the second table lists the remaining alternatives.
+
+#### Default configuration
+
+| Res.
| GPUs
| s/kimg
(V100) | s/kimg
(A100) | GPU
mem | Options
+| :---------------------- | :----------: | :---------------: | :--------------: | :------------: | :--
+| 128² | 1 | 12.51 | 6.79 | 6.2 | `--cfg=stylegan2 --gpus=1 --batch=32 --gamma=0.1024 --map-depth=2 --glr=0.0025 --dlr=0.0025 --cbase=16384`
+| 128² | 2 | 6.43 | 3.45 | 6.2 | `--cfg=stylegan2 --gpus=2 --batch=64 --gamma=0.0512 --map-depth=2 --glr=0.0025 --dlr=0.0025 --cbase=16384`
+| 128² | 4 | 3.82 | 2.23 | 3.5 | `--cfg=stylegan2 --gpus=4 --batch=64 --gamma=0.0512 --map-depth=2 --glr=0.0025 --dlr=0.0025 --cbase=16384`
+| 256² | 1 | 20.84 | 12.53 | 4.5 | `--cfg=stylegan2 --gpus=1 --batch=16 --gamma=0.8192 --map-depth=2 --glr=0.0025 --dlr=0.0025 --cbase=16384`
+| 256² | 2 | 10.93 | 6.36 | 4.5 | `--cfg=stylegan2 --gpus=2 --batch=32 --gamma=0.4096 --map-depth=2 --glr=0.0025 --dlr=0.0025 --cbase=16384`
+| 256² | 4 | 5.39 | 3.20 | 4.5 | `--cfg=stylegan2 --gpus=4 --batch=64 --gamma=0.2048 --map-depth=2 --glr=0.0025 --dlr=0.0025 --cbase=16384`
+| 256² | 8 | 3.89 | 2.38 | 2.6 | `--cfg=stylegan2 --gpus=8 --batch=64 --gamma=0.2048 --map-depth=2 --glr=0.0025 --dlr=0.0025 --cbase=16384`
+| 512² | 1 | 71.59 | 41.06 | 6.8 | `--cfg=stylegan2 --gpus=1 --batch=8 --gamma=6.5536 --map-depth=2 --glr=0.0025 --dlr=0.0025`
+| 512² | 2 | 36.79 | 20.83 | 6.8 | `--cfg=stylegan2 --gpus=2 --batch=16 --gamma=3.2768 --map-depth=2 --glr=0.0025 --dlr=0.0025`
+| 512² | 4 | 18.12 | 10.45 | 6.7 | `--cfg=stylegan2 --gpus=4 --batch=32 --gamma=1.6384 --map-depth=2 --glr=0.0025 --dlr=0.0025`
+| 512² | 8 | 9.09 | 5.24 | 6.8 | `--cfg=stylegan2 --gpus=8 --batch=64 --gamma=0.8192 --map-depth=2 --glr=0.0025 --dlr=0.0025`
+| 1024² | 1 | 141.83 | 90.39 | 7.2 | `--cfg=stylegan2 --gpus=1 --batch=4 --gamma=52.4288 --map-depth=2`
+| 1024² | 2 | 73.13 | 46.04 | 7.2 | `--cfg=stylegan2 --gpus=2 --batch=8 --gamma=26.2144 --map-depth=2`
+| 1024² | 4 | 36.95 | 23.15 | 7.0 | `--cfg=stylegan2 --gpus=4 --batch=16 --gamma=13.1072 --map-depth=2`
+| 1024² | 8 | 18.47 | 11.66 | 7.3 | `--cfg=stylegan2 --gpus=8 --batch=32 --gamma=6.5536 --map-depth=2`
+
+#### Repro configurations
+
+| Name
| s/kimg
(V100) | s/kimg
(A100) | GPU
mem | Options
+| :---------------------- | :--------------: | :--------------: | :------------: | :--
+| `stylegan2` | 17.55 | 14.57 | 6.2 | `--cfg=stylegan2 --gpus=8 --batch=32 --gamma=10`
+| `paper256` | 4.01 | 2.47 | 2.7 | `--cfg=stylegan2 --gpus=8 --batch=64 --gamma=1 --cbase=16384 --glr=0.0025 --dlr=0.0025 --mbstd-group=8`
+| `paper512` | 9.11 | 5.28 | 6.7 | `--cfg=stylegan2 --gpus=8 --batch=64 --gamma=0.5 --glr=0.0025 --dlr=0.0025 --mbstd-group=8`
+| `paper1024` | 18.56 | 11.75 | 6.9 | `--cfg=stylegan2 --gpus=8 --batch=32 --gamma=2`
diff --git a/stylegan3-fun/docs/dataset-tool-help.txt b/stylegan3-fun/docs/dataset-tool-help.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b8fe51ded07dbc24c0261ba40722ade632e5c625
--- /dev/null
+++ b/stylegan3-fun/docs/dataset-tool-help.txt
@@ -0,0 +1,70 @@
+Usage: dataset_tool.py [OPTIONS]
+
+ Convert an image dataset into a dataset archive usable with StyleGAN2 ADA
+ PyTorch.
+
+ The input dataset format is guessed from the --source argument:
+
+ --source *_lmdb/ Load LSUN dataset
+ --source cifar-10-python.tar.gz Load CIFAR-10 dataset
+ --source train-images-idx3-ubyte.gz Load MNIST dataset
+ --source path/ Recursively load all images from path/
+ --source dataset.zip Recursively load all images from dataset.zip
+
+ Specifying the output format and path:
+
+ --dest /path/to/dir Save output files under /path/to/dir
+ --dest /path/to/dataset.zip Save output files into /path/to/dataset.zip
+
+ The output dataset format can be either an image folder or an uncompressed
+ zip archive. Zip archives makes it easier to move datasets around file
+ servers and clusters, and may offer better training performance on network
+ file systems.
+
+ Images within the dataset archive will be stored as uncompressed PNG.
+ Uncompresed PNGs can be efficiently decoded in the training loop.
+
+ Class labels are stored in a file called 'dataset.json' that is stored at
+ the dataset root folder. This file has the following structure:
+
+ {
+ "labels": [
+ ["00000/img00000000.png",6],
+ ["00000/img00000001.png",9],
+ ... repeated for every image in the datase
+ ["00049/img00049999.png",1]
+ ]
+ }
+
+ If the 'dataset.json' file cannot be found, the dataset is interpreted as
+ not containing class labels.
+
+ Image scale/crop and resolution requirements:
+
+ Output images must be square-shaped and they must all have the same power-
+ of-two dimensions.
+
+ To scale arbitrary input image size to a specific width and height, use
+ the --resolution option. Output resolution will be either the original
+ input resolution (if resolution was not specified) or the one specified
+ with --resolution option.
+
+ Use the --transform=center-crop or --transform=center-crop-wide options to
+ apply a center crop transform on the input image. These options should be
+ used with the --resolution option. For example:
+
+ python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \
+ --transform=center-crop-wide --resolution=512x384
+
+Options:
+ --source PATH Directory or archive name for input dataset
+ [required]
+
+ --dest PATH Output directory or archive name for output
+ dataset [required]
+
+ --max-images INTEGER Output only up to `max-images` images
+ --transform [center-crop|center-crop-wide]
+ Input crop/resize mode
+ --resolution WxH Output resolution (e.g., '512x512')
+ --help Show this message and exit.
diff --git a/stylegan3-fun/docs/stylegan3-teaser-1920x1006.png b/stylegan3-fun/docs/stylegan3-teaser-1920x1006.png
new file mode 100644
index 0000000000000000000000000000000000000000..943beebfc6c647078edb5d28bad17ed11ff2600e
--- /dev/null
+++ b/stylegan3-fun/docs/stylegan3-teaser-1920x1006.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6d71fd93e8f9456e6a4d751649effb3096290f2a115d4d62b91de35458918ac2
+size 1819731
diff --git a/stylegan3-fun/docs/train-help.txt b/stylegan3-fun/docs/train-help.txt
new file mode 100644
index 0000000000000000000000000000000000000000..5839bf11123d1838ac61c773da5899d7e9199361
--- /dev/null
+++ b/stylegan3-fun/docs/train-help.txt
@@ -0,0 +1,53 @@
+Usage: train.py [OPTIONS]
+
+ Train a GAN using the techniques described in the paper "Alias-Free
+ Generative Adversarial Networks".
+
+ Examples:
+
+ # Train StyleGAN3-T for AFHQv2 using 8 GPUs.
+ python train.py --outdir=~/training-runs --cfg=stylegan3-t --data=~/datasets/afhqv2-512x512.zip \
+ --gpus=8 --batch=32 --gamma=8.2 --mirror=1
+
+ # Fine-tune StyleGAN3-R for MetFaces-U using 1 GPU, starting from the pre-trained FFHQ-U pickle.
+ python train.py --outdir=~/training-runs --cfg=stylegan3-r --data=~/datasets/metfacesu-1024x1024.zip \
+ --gpus=8 --batch=32 --gamma=6.6 --mirror=1 --kimg=5000 --snap=5 \
+ --resume=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhqu-1024x1024.pkl
+
+ # Train StyleGAN2 for FFHQ at 1024x1024 resolution using 8 GPUs.
+ python train.py --outdir=~/training-runs --cfg=stylegan2 --data=~/datasets/ffhq-1024x1024.zip \
+ --gpus=8 --batch=32 --gamma=10 --mirror=1 --aug=noaug
+
+Options:
+ --outdir DIR Where to save the results [required]
+ --cfg [stylegan3-t|stylegan3-r|stylegan2]
+ Base configuration [required]
+ --data [ZIP|DIR] Training data [required]
+ --gpus INT Number of GPUs to use [required]
+ --batch INT Total batch size [required]
+ --gamma FLOAT R1 regularization weight [required]
+ --cond BOOL Train conditional model [default: False]
+ --mirror BOOL Enable dataset x-flips [default: False]
+ --aug [noaug|ada|fixed] Augmentation mode [default: ada]
+ --resume [PATH|URL] Resume from given network pickle
+ --freezed INT Freeze first layers of D [default: 0]
+ --p FLOAT Probability for --aug=fixed [default: 0.2]
+ --target FLOAT Target value for --aug=ada [default: 0.6]
+ --batch-gpu INT Limit batch size per GPU
+ --cbase INT Capacity multiplier [default: 32768]
+ --cmax INT Max. feature maps [default: 512]
+ --glr FLOAT G learning rate [default: varies]
+ --dlr FLOAT D learning rate [default: 0.002]
+ --map-depth INT Mapping network depth [default: varies]
+ --mbstd-group INT Minibatch std group size [default: 4]
+ --desc STR String to include in result dir name
+ --metrics [NAME|A,B,C|none] Quality metrics [default: fid50k_full]
+ --kimg KIMG Total training duration [default: 25000]
+ --tick KIMG How often to print progress [default: 4]
+ --snap TICKS How often to save snapshots [default: 50]
+ --seed INT Random seed [default: 0]
+ --fp32 BOOL Disable mixed-precision [default: False]
+ --nobench BOOL Disable cuDNN benchmarking [default: False]
+ --workers INT DataLoader worker processes [default: 3]
+ -n, --dry-run Print training options and exit
+ --help Show this message and exit.
diff --git a/stylegan3-fun/docs/troubleshooting.md b/stylegan3-fun/docs/troubleshooting.md
new file mode 100644
index 0000000000000000000000000000000000000000..a96c143374338a647e1cefcc773da5ba83e5925b
--- /dev/null
+++ b/stylegan3-fun/docs/troubleshooting.md
@@ -0,0 +1,31 @@
+# Troubleshooting
+
+Our PyTorch code uses custom [CUDA extensions](https://pytorch.org/tutorials/advanced/cpp_extension.html) to speed up some of the network layers. Getting these to run can sometimes be a hassle.
+
+This page aims to give guidance on how to diagnose and fix run-time problems related to these extensions.
+
+## Before you start
+
+1. Try Docker first! Ensure you can successfully run our models using the recommended Docker image. Follow the instructions in [README.md](/README.md) to get it running.
+2. Can't use Docker? Read on..
+
+## Installing dependencies
+
+Make sure you've installed everything listed on the requirements section in the [README.md](/README.md). The key components w.r.t. custom extensions are:
+
+- **[CUDA toolkit 11.1](https://developer.nvidia.com/cuda-toolkit)** or later (this is not the same as `cudatoolkit` from Conda).
+ - PyTorch invokes `nvcc` to compile our CUDA kernels.
+- **ninja**
+ - PyTorch uses [Ninja](https://ninja-build.org/) as its build system.
+- **GCC** (Linux) or **Visual Studio** (Windows)
+ - GCC 7.x or later is required. Earlier versions such as GCC 6.3 [are known not to work](https://github.com/NVlabs/stylegan3/issues/2).
+
+#### Why is CUDA toolkit installation necessary?
+
+The PyTorch package contains the required CUDA toolkit libraries needed to run PyTorch, so why is a separate CUDA toolkit installation required? Our models use custom CUDA kernels to implement operations such as efficient resampling of 2D images. PyTorch code invokes the CUDA compiler at run-time to compile these kernels on first-use. The tools and libraries required for this compilation are not bundled in PyTorch and thus a host CUDA toolkit installation is required.
+
+## Things to try
+
+- Completely remove: `$HOME/.cache/torch_extensions` (Linux) or `C:\Users\\AppData\Local\torch_extensions\torch_extensions\Cache` (Windows) and re-run StyleGAN3 python code.
+- Run ninja in `$HOME/.cache/torch_extensions` to see that it builds.
+- Inspect the `build.ninja` in the build directories under `$HOME/.cache/torch_extensions` and check CUDA tools and versions are consistent with what you intended to use.
diff --git a/stylegan3-fun/docs/visualizer_screen0.png b/stylegan3-fun/docs/visualizer_screen0.png
new file mode 100644
index 0000000000000000000000000000000000000000..17c8cffa30f3f096006e6671d44d0de955008a9d
--- /dev/null
+++ b/stylegan3-fun/docs/visualizer_screen0.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fafce2c734f6eb2bed6467d3a1e83f9910e0bc4b374414c79da95c7c52d68ab9
+size 1257068
diff --git a/stylegan3-fun/docs/visualizer_screen0_half.png b/stylegan3-fun/docs/visualizer_screen0_half.png
new file mode 100644
index 0000000000000000000000000000000000000000..ab76f7f2aeaec8aa479c899852e1408fdd45e417
Binary files /dev/null and b/stylegan3-fun/docs/visualizer_screen0_half.png differ
diff --git a/stylegan3-fun/environment.yml b/stylegan3-fun/environment.yml
new file mode 100644
index 0000000000000000000000000000000000000000..993278329e31456e24a7167e4556e880527ccf62
--- /dev/null
+++ b/stylegan3-fun/environment.yml
@@ -0,0 +1,35 @@
+name: stylegan3
+channels:
+ - pytorch
+ - nvidia
+ - conda-forge # PR #80 by @SetZero / #143 by @coldwaterq
+dependencies:
+ - python >= 3.8
+ - pip
+ - numpy>=1.20
+ - click>=8.0
+ - pillow=8.3.1
+ - scipy=1.7.1
+ - pytorch::pytorch=1.11 # We can now use Pytorch 1.11!
+ - nvidia::cudatoolkit=11.3 # Necessary to remove previous PR by @edstoica
+ - requests=2.26.0
+ - tqdm=4.62.2
+ - ninja=1.10.2
+ - matplotlib=3.4.2
+ - imageio=2.9.0
+ - pip:
+ - imgui==1.3.0
+ - glfw==2.2.0
+ - pyopengl==3.1.5
+ - imageio-ffmpeg==0.4.3
+ - pyspng
+ - pyperlin # for Discriminator Synthesis
+ - psutil # PR #125 by @fastflair / #111 by @siddharthksah
+ - tensorboard # PR #125 by @fastflair
+ - torchvision==0.12.0 # For "Discriminator Synthesis" / discriminator_synthesis.py
+ - pyperlin # For "Discriminator Synthesis" / discriminator_synthesis.py
+ - scikit-learn # For "Self-Distilled StyleGAN" / multimodal_truncation.py
+ - moviepy==1.0.3
+ - ffmpeg-python==0.2.0
+ - scikit-video==1.1.11
+ - setuptools==59.5.0 # PR #8 by @ZibbeZabbe
\ No newline at end of file
diff --git a/stylegan3-fun/gen_images.py b/stylegan3-fun/gen_images.py
new file mode 100644
index 0000000000000000000000000000000000000000..66fdef44fd163d5207911a2b4fd7853aff73855c
--- /dev/null
+++ b/stylegan3-fun/gen_images.py
@@ -0,0 +1,145 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Generate images using pretrained network pickle."""
+
+import os
+import re
+from typing import List, Optional, Tuple, Union
+
+import click
+import dnnlib
+import numpy as np
+import PIL.Image
+import torch
+
+import legacy
+
+#----------------------------------------------------------------------------
+
+def parse_range(s: Union[str, List]) -> List[int]:
+ '''Parse a comma separated list of numbers or ranges and return a list of ints.
+
+ Example: '1,2,5-10' returns [1, 2, 5, 6, 7]
+ '''
+ if isinstance(s, list): return s
+ ranges = []
+ range_re = re.compile(r'^(\d+)-(\d+)$')
+ for p in s.split(','):
+ m = range_re.match(p)
+ if m:
+ ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
+ else:
+ ranges.append(int(p))
+ return ranges
+
+#----------------------------------------------------------------------------
+
+def parse_vec2(s: Union[str, Tuple[float, float]]) -> Tuple[float, float]:
+ '''Parse a floating point 2-vector of syntax 'a,b'.
+
+ Example:
+ '0,1' returns (0,1)
+ '''
+ if isinstance(s, tuple): return s
+ parts = s.split(',')
+ if len(parts) == 2:
+ return (float(parts[0]), float(parts[1]))
+ raise ValueError(f'cannot parse 2-vector {s}')
+
+#----------------------------------------------------------------------------
+
+def make_transform(translate: Tuple[float,float], angle: float):
+ m = np.eye(3)
+ s = np.sin(angle/360.0*np.pi*2)
+ c = np.cos(angle/360.0*np.pi*2)
+ m[0][0] = c
+ m[0][1] = s
+ m[0][2] = translate[0]
+ m[1][0] = -s
+ m[1][1] = c
+ m[1][2] = translate[1]
+ return m
+
+#----------------------------------------------------------------------------
+
+@click.command()
+@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
+@click.option('--seeds', type=parse_range, help='List of random seeds (e.g., \'0,1,4-6\')', required=True)
+@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
+@click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
+@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
+@click.option('--translate', help='Translate XY-coordinate (e.g. \'0.3,1\')', type=parse_vec2, default='0,0', show_default=True, metavar='VEC2')
+@click.option('--rotate', help='Rotation angle in degrees', type=float, default=0, show_default=True, metavar='ANGLE')
+@click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
+def generate_images(
+ network_pkl: str,
+ seeds: List[int],
+ truncation_psi: float,
+ noise_mode: str,
+ outdir: str,
+ translate: Tuple[float,float],
+ rotate: float,
+ class_idx: Optional[int]
+):
+ """Generate images using pretrained network pickle.
+
+ Examples:
+
+ \b
+ # Generate an image using pre-trained AFHQv2 model ("Ours" in Figure 1, left).
+ python gen_images.py --outdir=out --trunc=1 --seeds=2 \\
+ --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl
+
+ \b
+ # Generate uncurated images with truncation using the MetFaces-U dataset
+ python gen_images.py --outdir=out --trunc=0.7 --seeds=600-605 \\
+ --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-metfacesu-1024x1024.pkl
+ """
+
+ print('Loading networks from "%s"...' % network_pkl)
+ device = torch.device('cuda')
+ with dnnlib.util.open_url(network_pkl) as f:
+ G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
+
+ os.makedirs(outdir, exist_ok=True)
+
+ # Labels.
+ label = torch.zeros([1, G.c_dim], device=device)
+ if G.c_dim != 0:
+ if class_idx is None:
+ raise click.ClickException('Must specify class label with --class when using a conditional network')
+ label[:, class_idx] = 1
+ else:
+ if class_idx is not None:
+ print ('warn: --class=lbl ignored when running on an unconditional network')
+
+ # Generate images.
+ for seed_idx, seed in enumerate(seeds):
+ print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
+ z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
+
+ # Construct an inverse rotation/translation matrix and pass to the generator. The
+ # generator expects this matrix as an inverse to avoid potentially failing numerical
+ # operations in the network.
+ if hasattr(G.synthesis, 'input'):
+ m = make_transform(translate, rotate)
+ m = np.linalg.inv(m)
+ G.synthesis.input.transform.copy_(torch.from_numpy(m))
+
+ img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
+ img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+ PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png')
+
+
+#----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+ generate_images() # pylint: disable=no-value-for-parameter
+
+#----------------------------------------------------------------------------
diff --git a/stylegan3-fun/gen_video.py b/stylegan3-fun/gen_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..51926d6a284704491f3a2b25cfcb4e07dfe56b20
--- /dev/null
+++ b/stylegan3-fun/gen_video.py
@@ -0,0 +1,281 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Generate lerp videos using pretrained network pickle."""
+
+import copy
+import os
+import re
+from typing import List, Optional, Tuple, Union
+
+import click
+import dnnlib
+import imageio
+import numpy as np
+import scipy.interpolate
+import torch
+from tqdm import tqdm
+
+import legacy
+from torch_utils import gen_utils
+
+
+# ----------------------------------------------------------------------------
+
+
+def layout_grid(img, grid_w=None, grid_h=1, float_to_uint8=True, chw_to_hwc=True, to_numpy=True):
+ batch_size, channels, img_h, img_w = img.shape
+ if grid_w is None:
+ grid_w = batch_size // grid_h
+ assert batch_size == grid_w * grid_h
+ if float_to_uint8:
+ img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+ img = img.reshape(grid_h, grid_w, channels, img_h, img_w)
+ img = img.permute(2, 0, 3, 1, 4)
+ img = img.reshape(channels, grid_h * img_h, grid_w * img_w)
+ if chw_to_hwc:
+ img = img.permute(1, 2, 0)
+ if to_numpy:
+ img = img.cpu().numpy()
+ return img
+
+
+# ----------------------------------------------------------------------------
+
+
+def parse_vec2(s: Union[str, Tuple[float, float]]) -> Tuple[float, float]:
+ """Parse a floating point 2-vector of syntax 'a,b'.
+
+ Example:
+ '0,1' returns (0,1)
+ """
+ if isinstance(s, tuple): return s
+ parts = s.split(',')
+ if len(parts) == 2:
+ return (float(parts[0]), float(parts[1]))
+ raise ValueError(f'cannot parse 2-vector {s}')
+
+
+# ----------------------------------------------------------------------------
+
+
+def make_transform(translate: Tuple[float,float], angle: float):
+ m = np.eye(3)
+ s = np.sin(angle/360.0*np.pi*2)
+ c = np.cos(angle/360.0*np.pi*2)
+ m[0][0] = c
+ m[0][1] = s
+ m[0][2] = translate[0]
+ m[1][0] = -s
+ m[1][1] = c
+ m[1][2] = translate[1]
+ return m
+
+
+# ----------------------------------------------------------------------------
+
+
+def gen_interp_video(G,
+ mp4: str,
+ seeds: List[int],
+ shuffle_seed: int = None,
+ w_frames: int = 60*4,
+ kind: str = 'cubic',
+ grid_dims: Tuple[int] = (1,1),
+ num_keyframes: int = None,
+ wraps: int = 2,
+ psi: float = 1.0,
+ device: torch.device = torch.device('cuda'),
+ stabilize_video: bool = True,
+ **video_kwargs):
+ grid_w = grid_dims[0]
+ grid_h = grid_dims[1]
+
+ if stabilize_video:
+ # Thanks to @RiversHaveWings and @nshepperd1
+ if hasattr(G.synthesis, 'input'):
+ shift = G.synthesis.input.affine(G.mapping.w_avg.unsqueeze(0))
+ G.synthesis.input.affine.bias.data.add_(shift.squeeze(0))
+ G.synthesis.input.affine.weight.data.zero_()
+
+ # Get the Generator's transform
+ m = G.synthesis.input.transform if hasattr(G.synthesis, 'input') else None
+
+ if num_keyframes is None:
+ if len(seeds) % (grid_w*grid_h) != 0:
+ raise ValueError('Number of input seeds must be divisible by grid W*H')
+ num_keyframes = len(seeds) // (grid_w*grid_h)
+
+ all_seeds = np.zeros(num_keyframes*grid_h*grid_w, dtype=np.int64)
+ for idx in range(num_keyframes*grid_h*grid_w):
+ all_seeds[idx] = seeds[idx % len(seeds)]
+
+ if shuffle_seed is not None:
+ rng = np.random.RandomState(seed=shuffle_seed)
+ rng.shuffle(all_seeds)
+
+ zs = torch.from_numpy(np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds])).to(device)
+ ws = G.mapping(z=zs, c=None, truncation_psi=psi)
+ _ = G.synthesis(ws[:1]) # warm up
+ ws = ws.reshape(grid_h, grid_w, num_keyframes, *ws.shape[1:])
+
+ # Interpolation.
+ grid = []
+ for yi in range(grid_h):
+ row = []
+ for xi in range(grid_w):
+ x = np.arange(-num_keyframes * wraps, num_keyframes * (wraps + 1))
+ y = np.tile(ws[yi][xi].cpu().numpy(), [wraps * 2 + 1, 1, 1])
+ interp = scipy.interpolate.interp1d(x, y, kind=kind, axis=0)
+ row.append(interp)
+ grid.append(row)
+
+ # Render video.
+ video_out = imageio.get_writer(mp4, mode='I', fps=60, codec='libx264', **video_kwargs)
+ for frame_idx in tqdm(range(num_keyframes * w_frames)):
+ imgs = []
+ # Construct an inverse affine matrix and pass to the generator. The generator expects
+ # this matrix as an inverse to avoid potentially failing numerical operations in the network.
+ if hasattr(G.synthesis, 'input'):
+ # Set default values for each affine transformation
+ total_rotation = 0.0 # If >= 0.0, will rotate the pixels counter-clockwise w.r.t. the center; in radians
+ total_translation_x = 0.0 # If >= 0.0, will translate all pixels to the right; if <= 0.0, to the left
+ total_translation_y = 0.0 # If >= 0.0, will translate all pixels upwards; if <= 0.0, downwards
+ total_scale_x = 1.0 # If <= 1.0, will zoom in; else, will zoom out (x-axis)
+ total_scale_y = 1.0 # If <= 1.0, will zoom in; else, will zoom out (y-axis)
+ total_shear_x = 0.0 # If >= 0.0, will shear pixels to the right, keeping y fixed; if <= 0.0, to the left
+ total_shear_y = 0.0 # If >= 0.0, will shear pixels upwards, keeping x fixed; if <= 0.0, downwards
+ mirror_x = False # Mirror along the x-axis; if True, will flip the image horizontally (can't be a function of frame_idx)
+ mirror_y = False # Mirror along the y-axis; if True, will flip the image vertically (can't be a function of frame_idx)
+
+ # Go nuts with these. They can be constants as above to fix centering/rotation in your video,
+ # or you can make them functions of frame_idx to animate them, such as (uncomment as many as you want to try):
+ # total_scale_x = 1 + np.sin(np.pi*frame_idx/(num_keyframes * w_frames))/2 # will oscillate between 0.5 and 1.5
+ # total_rotation = 4*np.pi*frame_idx/(num_keyframes * w_frames) # 4 will dictate the number of rotations, so 1 full rotation
+ # total_shear_y = 2*np.sin(2*np.pi*frame_idx/(num_keyframes * w_frames)) # will oscillate between -2 and 2
+
+ # We then use these values to construct the affine matrix
+ m = gen_utils.make_affine_transform(m, angle=total_rotation, translate_x=total_translation_x,
+ translate_y=total_translation_y, scale_x=total_scale_x,
+ scale_y=total_scale_y, shear_x=total_shear_x, shear_y=total_shear_y,
+ mirror_x=mirror_x, mirror_y=mirror_y)
+ m = np.linalg.inv(m)
+ # Finally, we pass the matrix to the generator
+ G.synthesis.input.transform.copy_(torch.from_numpy(m))
+
+ # The rest stays the same, for all you gen_video.py lovers out there
+ for yi in range(grid_h):
+ for xi in range(grid_w):
+ interp = grid[yi][xi]
+ w = torch.from_numpy(interp(frame_idx / w_frames)).to(device)
+ img = G.synthesis(ws=w.unsqueeze(0), noise_mode='const')[0]
+ imgs.append(img)
+ video_out.append_data(layout_grid(torch.stack(imgs), grid_w=grid_w, grid_h=grid_h))
+ video_out.close()
+
+
+# ----------------------------------------------------------------------------
+
+
+def parse_range(s: Union[str, List[int]]) -> List[int]:
+ """Parse a comma separated list of numbers or ranges and return a list of ints.
+
+ Example: '1,2,5-10' returns [1, 2, 5, 6, 7]
+ """
+ if isinstance(s, list): return s
+ ranges = []
+ range_re = re.compile(r'^(\d+)-(\d+)$')
+ for p in s.split(','):
+ m = range_re.match(p)
+ if m:
+ ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
+ else:
+ ranges.append(int(p))
+ return ranges
+
+
+# ----------------------------------------------------------------------------
+
+
+def parse_tuple(s: Union[str, Tuple[int,int]]) -> Tuple[int, int]:
+ """Parse a 'M,N' or 'MxN' integer tuple.
+
+ Example:
+ '4x2' returns (4,2)
+ '0,1' returns (0,1)
+ """
+ if isinstance(s, tuple): return s
+ m = re.match(r'^(\d+)[x,](\d+)$', s)
+ if m:
+ return (int(m.group(1)), int(m.group(2)))
+ raise ValueError(f'cannot parse tuple {s}')
+
+
+# ----------------------------------------------------------------------------
+
+
+@click.command()
+@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
+@click.option('--seeds', type=parse_range, help='List of random seeds', required=True)
+@click.option('--shuffle-seed', type=int, help='Random seed to use for shuffling seed order', default=None)
+@click.option('--grid', type=parse_tuple, help='Grid width/height, e.g. \'4x3\' (default: 1x1)', default=(1,1))
+@click.option('--num-keyframes', type=int, help='Number of seeds to interpolate through. If not specified, determine based on the length of the seeds array given by --seeds.', default=None)
+@click.option('--w-frames', type=int, help='Number of frames to interpolate between latents', default=120)
+@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
+@click.option('--stabilize-video', is_flag=True, help='Stabilize the video by anchoring the mapping to w_avg')
+@click.option('--output', help='Output .mp4 filename', type=str, required=True, metavar='FILE')
+def generate_images(
+ network_pkl: str,
+ seeds: List[int],
+ shuffle_seed: Optional[int],
+ truncation_psi: float,
+ grid: Tuple[int,int],
+ num_keyframes: Optional[int],
+ stabilize_video: bool,
+ w_frames: int,
+ output: str
+):
+ """Render a latent vector interpolation video.
+
+ Examples:
+
+ \b
+ # Render a 4x2 grid of interpolations for seeds 0 through 31.
+ python gen_video.py --output=lerp.mp4 --trunc=1 --seeds=0-31 --grid=4x2 \\
+ --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl
+
+ Animation length and seed keyframes:
+
+ The animation length is either determined based on the --seeds value or explicitly
+ specified using the --num-keyframes option.
+
+ When num keyframes is specified with --num-keyframes, the output video length
+ will be 'num_keyframes*w_frames' frames.
+
+ If --num-keyframes is not specified, the number of seeds given with
+ --seeds must be divisible by grid size W*H (--grid). In this case the
+ output video length will be '# seeds/(w*h)*w_frames' frames.
+ """
+
+ print('Loading networks from "%s"...' % network_pkl)
+ device = torch.device('cuda')
+ with dnnlib.util.open_url(network_pkl) as f:
+ G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
+
+ gen_interp_video(G=G, mp4=output, bitrate='12M', grid_dims=grid, num_keyframes=num_keyframes, w_frames=w_frames,
+ seeds=seeds, shuffle_seed=shuffle_seed, psi=truncation_psi, stabilize_video=stabilize_video)
+
+
+# ----------------------------------------------------------------------------
+
+
+if __name__ == "__main__":
+ generate_images() # pylint: disable=no-value-for-parameter
+
+
+# ----------------------------------------------------------------------------
diff --git a/stylegan3-fun/generate.py b/stylegan3-fun/generate.py
new file mode 100644
index 0000000000000000000000000000000000000000..e93267ff334b61fcad34297da728cfe18484cda0
--- /dev/null
+++ b/stylegan3-fun/generate.py
@@ -0,0 +1,838 @@
+import os
+import sys
+from typing import List, Optional, Union, Tuple
+import click
+
+import dnnlib
+from torch_utils import gen_utils
+import copy
+
+import scipy
+import numpy as np
+import PIL.Image
+import torch
+
+import legacy
+from viz.renderer import Renderer
+
+os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = 'hide'
+import moviepy.editor
+
+
+# ----------------------------------------------------------------------------
+
+
+# We group the different types of generation (images, grid, video, wacky stuff) into a main function
+@click.group()
+def main():
+ pass
+
+
+# ----------------------------------------------------------------------------
+
+
+@main.command(name='images')
+@click.pass_context
+@click.option('--network', 'network_pkl', help='Network pickle filename: can be URL, local file, or the name of the model in torch_utils.gen_utils.resume_specs', required=True)
+@click.option('--device', help='Device to use for image generation; using the CPU is slower than the GPU', type=click.Choice(['cpu', 'cuda']), default='cuda', show_default=True)
+@click.option('--cfg', type=click.Choice(gen_utils.available_cfgs), help='Config of the network, used only if you want to use the pretrained models in torch_utils.gen_utils.resume_specs')
+# Synthesis options (feed a list of seeds or give the projected w to synthesize)
+@click.option('--seeds', type=gen_utils.num_range, help='List of random seeds')
+@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
+@click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)', default=None, show_default=True)
+@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
+@click.option('--anchor-latent-space', '-anchor', is_flag=True, help='Anchor the latent space to w_avg to stabilize the video')
+@click.option('--projected-w', help='Projection result file; can be either .npy or .npz files', type=click.Path(exists=True, dir_okay=False), metavar='FILE')
+@click.option('--new-center', type=gen_utils.parse_new_center, help='New center for the W latent space; a seed (int) or a path to a projected dlatent (.npy/.npz)', default=None)
+# Save the output of the intermediate layers
+@click.option('--layer', 'layer_name', type=str, help='Layer name to extract; if unsure, use `--available-layers`', default=None, show_default=True)
+@click.option('--available-layers', is_flag=True, help='List the available layers in the trained model and exit')
+@click.option('--starting-channel', 'starting_channel', type=int, help='Starting channel for the layer extraction', default=0, show_default=True)
+@click.option('--grayscale', 'save_grayscale', type=bool, help='Use the first channel starting from `--starting-channel` to generate a grayscale image.', default=False, show_default=True)
+@click.option('--rgb', 'save_rgb', type=bool, help='Use 3 consecutive channels (if they exist) to generate a RGB image, starting from `--starting-channel`.', default=False, show_default=True)
+@click.option('--rgba', 'save_rgba', type=bool, help='Use 4 consecutive channels (if they exist) to generate a RGBA image, starting from `--starting-channel`.', default=False, show_default=True)
+@click.option('--img-scale-db', 'img_scale_db', type=click.FloatRange(min=-40, max=40), help='Scale the image pixel values, akin to "exposure" (lower, the image is grayer/, higher the more white/burnt regions)', default=0, show_default=True)
+@click.option('--img-normalize', 'img_normalize', type=bool, help='Normalize images of the selected layer and channel', default=False, show_default=True)
+# Grid options
+@click.option('--save-grid', is_flag=True, help='Use flag to save image grid')
+@click.option('--grid-width', '-gw', type=click.IntRange(min=1), help='Grid width (number of columns)', default=None)
+@click.option('--grid-height', '-gh', type=click.IntRange(min=1), help='Grid height (number of rows)', default=None)
+# Extra parameters for saving the results
+@click.option('--save-dlatents', is_flag=True, help='Use flag to save individual dlatents (W) for each individual resulting image')
+@click.option('--outdir', type=click.Path(file_okay=False), help='Directory path to save the results', default=os.path.join(os.getcwd(), 'out', 'images'), show_default=True, metavar='DIR')
+@click.option('--description', '-desc', type=str, help='Description name for the directory path to save results', default='generate-images', show_default=True)
+def generate_images(
+ ctx: click.Context,
+ network_pkl: str,
+ device: Optional[str],
+ cfg: Optional[str],
+ seeds: Optional[List[int]],
+ truncation_psi: Optional[float],
+ class_idx: Optional[int],
+ noise_mode: Optional[str],
+ anchor_latent_space: Optional[bool],
+ projected_w: Optional[Union[str, os.PathLike]],
+ new_center: Tuple[str, Union[int, np.ndarray]],
+ layer_name: Optional[str],
+ available_layers: Optional[bool],
+ starting_channel: Optional[int],
+ save_grayscale: Optional[bool],
+ save_rgb: Optional[bool],
+ save_rgba: Optional[bool],
+ img_scale_db: Optional[float],
+ img_normalize: Optional[bool],
+ save_grid: Optional[bool],
+ grid_width: int,
+ grid_height: int,
+ save_dlatents: Optional[bool],
+ outdir: Union[str, os.PathLike],
+ description: str,
+):
+ """Generate images using pretrained network pickle.
+
+ Examples:
+
+ \b
+ # Generate curated MetFaces images without truncation (Fig.10 left)
+ python generate.py images --trunc=1 --seeds=85,265,297,849 \\
+ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
+
+ \b
+ # Generate uncurated MetFaces images with truncation (Fig.12 upper left)
+ python generate.py images --trunc=0.7 --seeds=600-605 \\
+ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
+
+ \b
+ # Generate class conditional CIFAR-10 images (Fig.17 left, Car)
+ python generate.py images --seeds=0-35 --class=1 \\
+ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/cifar10.pkl
+
+ \b
+ # Render an image from projected W
+ python generate.py images --projected_w=projected_w.npz \\
+ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
+
+ \b
+ Generate class conditional StyleGAN2 WikiArt images, save each individually, and save all of them as a grid
+ python generate.py images --cfg=stylegan2 --network=wikiart1024-C --class=155 \\
+ --trunc=0.7 --seeds=10-50 --save-grid
+ """
+ # Sanity check
+ if len(seeds) < 1:
+ ctx.fail('Use `--seeds` to specify at least one seed.')
+
+ device = torch.device('cuda') if torch.cuda.is_available() and device == 'cuda' else torch.device('cpu')
+
+ # Load the network
+ G = gen_utils.load_network('G_ema', network_pkl, cfg, device)
+
+ if available_layers:
+ click.secho(f'Printing available layers (name, channels and size) for "{network_pkl}"...', fg='blue')
+ _ = Renderer().render(G=G, available_layers=available_layers)
+ sys.exit(1)
+
+ # Setup for using CPU
+ if device.type == 'cpu':
+ gen_utils.use_cpu(G)
+
+ # Stabilize/anchor the latent space
+ if anchor_latent_space:
+ gen_utils.anchor_latent_space(G)
+
+ description = 'generate-images' if len(description) == 0 else description
+ description = f'{description}-{layer_name}_layer' if layer_name is not None else description
+ # Create the run dir with the given name description
+ run_dir = gen_utils.make_run_dir(outdir, description)
+
+ # Synthesize the result of a W projection.
+ if projected_w is not None:
+ if seeds is not None:
+ print('warn: --seeds is ignored when using --projected-w')
+ print(f'Generating images from projected W "{projected_w}"')
+ ws, ext = gen_utils.get_latent_from_file(projected_w, return_ext=True)
+ ws = torch.tensor(ws, device=device)
+ assert ws.shape[1:] == (G.num_ws, G.w_dim)
+ n_digits = int(np.log10(len(ws))) + 1 # number of digits for naming the images
+ if ext == '.npy':
+ img = gen_utils.w_to_img(G, ws, noise_mode)[0]
+ PIL.Image.fromarray(img, gen_utils.channels_dict[G.synthesis.img_channels]).save(f'{run_dir}/proj.png')
+ else:
+ for idx, w in enumerate(ws):
+ img = gen_utils.w_to_img(G, w, noise_mode)[0]
+ PIL.Image.fromarray(img,
+ gen_utils.channels_dict[G.synthesis.img_channels]).save(f'{run_dir}/proj{idx:0{n_digits}d}.png')
+ return
+
+ # Labels.
+ class_idx = gen_utils.parse_class(G, class_idx, ctx)
+ label = torch.zeros([1, G.c_dim], device=device)
+ if G.c_dim != 0:
+ label[:, class_idx] = 1
+ else:
+ if class_idx is not None:
+ print('warn: --class=lbl ignored when running on an unconditional network')
+
+ if seeds is None:
+ ctx.fail('--seeds option is required when not using --projected-w')
+
+ # Recenter the latent space, if specified
+ if new_center is None:
+ w_avg = G.mapping.w_avg
+ else:
+ new_center, new_center_value = new_center
+ # We get the new center using the int (a seed) or recovered dlatent (an np.ndarray)
+ if isinstance(new_center_value, int):
+ w_avg = gen_utils.get_w_from_seed(G, device, new_center_value,
+ truncation_psi=1.0) # We want the pure dlatent
+ elif isinstance(new_center_value, np.ndarray):
+ w_avg = torch.from_numpy(new_center_value).to(device)
+ else:
+ ctx.fail('Error: New center has strange format! Only an int (seed) or a file (.npy/.npz) are accepted!')
+
+ # Generate images.
+ images = []
+ for seed_idx, seed in enumerate(seeds):
+ print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
+ dlatent = gen_utils.get_w_from_seed(G, device, seed, truncation_psi=1.0)
+ # Do truncation trick with center (new or global)
+ w = w_avg + (dlatent - w_avg) * truncation_psi
+
+ # TODO: this is starting to look like an auxiliary function!
+ # Save the intermediate layer output.
+ if layer_name is not None:
+ # Sanity check (meh, could be done better)
+ submodule_names = {name: mod for name, mod in G.synthesis.named_modules()}
+ assert layer_name in submodule_names, f'Layer "{layer_name}" not found in the network! Available layers: {", ".join(submodule_names)}'
+ assert True in (save_grayscale, save_rgb, save_rgba), 'You must select to save the image in at least one of the three possible formats! (L, RGB, RGBA)'
+
+ sel_channels = 3 if save_rgb else (1 if save_grayscale else 4)
+ res = Renderer().render(G=G, layer_name=layer_name, dlatent=w, sel_channels=sel_channels,
+ base_channel=starting_channel, img_scale_db=img_scale_db, img_normalize=img_normalize)
+ img = res.image
+ else:
+ img = gen_utils.w_to_img(G, w, noise_mode)[0]
+
+ if save_grid:
+ images.append(img)
+
+ # Get the image format, whether user-specified or the one from the model
+ try:
+ img_format = gen_utils.channels_dict[sel_channels]
+ except NameError:
+ img_format = gen_utils.channels_dict[G.synthesis.img_channels]
+
+ # Save image, avoiding grayscale errors in PIL
+ PIL.Image.fromarray(img[:, :, 0] if img.shape[-1] == 1 else img,
+ img_format).save(os.path.join(run_dir, f'seed{seed}.png'))
+ if save_dlatents:
+ np.save(os.path.join(run_dir, f'seed{seed}.npy'), w.unsqueeze(0).cpu().numpy())
+
+ if save_grid:
+ print('Saving image grid...')
+ images = np.array(images)
+
+ # We let the function infer the shape of the grid
+ if (grid_width, grid_height) == (None, None):
+ grid = gen_utils.create_image_grid(images)
+ # The user tells the specific shape of the grid, but one value may be None
+ else:
+ grid = gen_utils.create_image_grid(images, (grid_width, grid_height))
+
+ grid = grid[:, :, 0] if grid.shape[-1] == 1 else grid
+ PIL.Image.fromarray(grid, img_format).save(os.path.join(run_dir, 'grid.png'))
+
+ # Save the configuration used
+ ctx.obj = {
+ 'network_pkl': network_pkl,
+ 'device': device.type,
+ 'config': cfg,
+ 'synthesis': {
+ 'seeds': seeds,
+ 'truncation_psi': truncation_psi,
+ 'class_idx': class_idx,
+ 'noise_mode': noise_mode,
+ 'anchor_latent_space': anchor_latent_space,
+ 'projected_w': projected_w,
+ 'new_center': new_center
+ },
+ 'intermediate_representations': {
+ 'layer': layer_name,
+ 'starting_channel': starting_channel,
+ 'grayscale': save_grayscale,
+ 'rgb': save_rgb,
+ 'rgba': save_rgba,
+ 'img_scale_db': img_scale_db,
+ 'img_normalize': img_normalize
+ },
+ 'grid_options': {
+ 'save_grid': save_grid,
+ 'grid_width': grid_width,
+ 'grid_height': grid_height,
+ },
+ 'extra_parameters': {
+ 'save_dlatents': save_dlatents,
+ 'run_dir': run_dir,
+ 'description': description,
+ }
+ }
+ gen_utils.save_config(ctx=ctx, run_dir=run_dir)
+
+
+# ----------------------------------------------------------------------------
+
+
+@main.command(name='random-video')
+@click.pass_context
+@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
+@click.option('--cfg', type=click.Choice(gen_utils.available_cfgs), help='Config of the network, used only if you want to use the pretrained models in torch_utils.gen_utils.resume_specs')
+# Synthesis options
+@click.option('--seeds', type=gen_utils.num_range, help='List of random seeds', required=True)
+@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
+@click.option('--new-center', type=gen_utils.parse_new_center, help='New center for the W latent space; a seed (int) or a path to a projected dlatent (.npy/.npz)', default=None)
+@click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
+@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
+@click.option('--anchor-latent-space', '-anchor', is_flag=True, help='Anchor the latent space to w_avg to stabilize the video')
+# Save the output of the intermediate layers
+@click.option('--layer', 'layer_name', type=str, help='Layer name to extract; if unsure, use `--available-layers`', default=None, show_default=True)
+@click.option('--available-layers', is_flag=True, help='List the available layers in the trained model and exit')
+@click.option('--starting-channel', 'starting_channel', type=int, help='Starting channel for the layer extraction', default=0, show_default=True)
+@click.option('--grayscale', 'save_grayscale', type=bool, help='Use the first channel starting from `--starting-channel` to generate a grayscale image.', default=False, show_default=True)
+@click.option('--rgb', 'save_rgb', type=bool, help='Use 3 consecutive channels (if they exist) to generate a RGB image, starting from `--starting-channel`.', default=False, show_default=True)
+@click.option('--img-scale-db', 'img_scale_db', type=click.FloatRange(min=-40, max=40), help='Scale the image pixel values, akin to "exposure" (lower, the image is grayer/, higher the more white/burnt regions)', default=0, show_default=True)
+@click.option('--img-normalize', 'img_normalize', type=bool, help='Normalize images of the selected layer and channel', default=False, show_default=True)
+# Video options
+@click.option('--grid-width', '-gw', type=click.IntRange(min=1), help='Video grid width / number of columns', default=None, show_default=True)
+@click.option('--grid-height', '-gh', type=click.IntRange(min=1), help='Video grid height / number of rows', default=None, show_default=True)
+@click.option('--slowdown', type=gen_utils.parse_slowdown, help='Slow down the video by this amount; will be approximated to the nearest power of 2', default='1', show_default=True)
+@click.option('--duration-sec', '-sec', type=float, help='Duration length of the video', default=30.0, show_default=True)
+@click.option('--fps', type=click.IntRange(min=1), help='Video FPS.', default=30, show_default=True)
+@click.option('--compress', is_flag=True, help='Add flag to compress the final mp4 file with ffmpeg-python (same resolution, lower file size)')
+# Extra parameters for saving the results
+@click.option('--outdir', type=click.Path(file_okay=False), help='Directory path to save the results', default=os.path.join(os.getcwd(), 'out', 'video'), show_default=True, metavar='DIR')
+@click.option('--description', '-desc', type=str, help='Description name for the directory path to save results')
+def random_interpolation_video(
+ ctx: click.Context,
+ network_pkl: Union[str, os.PathLike],
+ cfg: Optional[str],
+ seeds: List[int],
+ truncation_psi: Optional[float],
+ new_center: Tuple[str, Union[int, np.ndarray]],
+ class_idx: Optional[int],
+ noise_mode: Optional[str],
+ anchor_latent_space: Optional[bool],
+ layer_name: Optional[str],
+ available_layers: Optional[bool],
+ starting_channel: Optional[int],
+ save_grayscale: Optional[bool],
+ save_rgb: Optional[bool],
+ img_scale_db: Optional[float],
+ img_normalize: Optional[bool],
+ grid_width: int,
+ grid_height: int,
+ slowdown: Optional[int],
+ duration_sec: Optional[float],
+ fps: int,
+ outdir: Union[str, os.PathLike],
+ description: str,
+ compress: bool,
+ smoothing_sec: Optional[float] = 3.0 # for Gaussian blur; won't be a command-line parameter, change at own risk
+):
+ """
+ Generate a random interpolation video using a pretrained network.
+
+ Examples:
+
+ \b
+ # Generate a 30-second long, untruncated MetFaces video at 30 FPS (3 rows and 2 columns; horizontal):
+ python generate.py random-video --seeds=0-5 \\
+ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
+
+ \b
+ # Generate a 60-second long, truncated 1x2 MetFaces video at 60 FPS (2 rows and 1 column; vertical):
+ python generate.py random-video --trunc=0.7 --seeds=10,20 --grid-width=1 --grid-height=2 \\
+ --fps=60 -sec=60 --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
+
+ """
+ # Sanity check
+ if len(seeds) < 1:
+ ctx.fail('Use `--seeds` to specify at least one seed.')
+
+ device = torch.device('cuda')
+
+ # Load the network
+ G = gen_utils.load_network('G_ema', network_pkl, cfg, device)
+
+ # Print the available layers in the model
+ if available_layers:
+ click.secho(f'Printing available layers (name, channels and size) for "{network_pkl}"...', fg='blue')
+ _ = Renderer().render(G=G, available_layers=available_layers)
+ sys.exit(1)
+
+ # Sadly, render can only generate one image at a time, so for now we'll just use the first seed
+ if layer_name is not None and len(seeds) > 1:
+ print(f'Note: Only one seed is supported for layer extraction, using seed "{seeds[0]}"...')
+ seeds = seeds[:1]
+
+ # Stabilize/anchor the latent space
+ if anchor_latent_space:
+ gen_utils.anchor_latent_space(G)
+
+ # Create the run dir with the given name description; add slowdown if different from the default (1)
+ desc = 'random-video'
+ desc = f'random-video-{description}' if description is not None else desc
+ desc = f'{desc}-{slowdown}xslowdown' if slowdown != 1 else desc
+ desc = f'{desc}-{layer_name}_layer' if layer_name is not None else desc
+ run_dir = gen_utils.make_run_dir(outdir, desc)
+
+ # Number of frames in the video and its total duration in seconds
+ num_frames = int(np.rint(duration_sec * fps))
+ total_duration = duration_sec * slowdown
+
+ print('Generating latent vectors...')
+ # TODO: let another helper function handle each case, we will use it for the grid
+ # If there's more than one seed provided and the shape isn't specified by the user
+ if (grid_width is None and grid_height is None) and len(seeds) >= 1:
+ # TODO: this can be done by another function
+ # Number of images in the grid video according to the seeds provided
+ num_seeds = len(seeds)
+ # Get the grid width and height according to num, giving priority to the number of columns
+ grid_width = max(int(np.ceil(np.sqrt(num_seeds))), 1)
+ grid_height = max((num_seeds - 1) // grid_width + 1, 1)
+ grid_size = (grid_width, grid_height)
+ shape = [num_frames, G.z_dim] # This is per seed
+ # Get the z latents
+ all_latents = np.stack([np.random.RandomState(seed).randn(*shape).astype(np.float32) for seed in seeds], axis=1)
+
+ # If only one seed is provided, but the user specifies the grid shape:
+ elif None not in (grid_width, grid_height) and len(seeds) == 1:
+ grid_size = (grid_width, grid_height)
+ shape = [num_frames, np.prod(grid_size), G.z_dim]
+ # Since we have one seed, we use it to generate all latents
+ all_latents = np.random.RandomState(*seeds).randn(*shape).astype(np.float32)
+
+ # If one or more seeds are provided, and the user also specifies the grid shape:
+ elif None not in (grid_width, grid_height) and len(seeds) >= 1:
+ # Case is similar to the first one
+ num_seeds = len(seeds)
+ grid_size = (grid_width, grid_height)
+ available_slots = np.prod(grid_size)
+ if available_slots < num_seeds:
+ diff = num_seeds - available_slots
+ click.secho(f'More seeds were provided ({num_seeds}) than available spaces in the grid ({available_slots})',
+ fg='red')
+ click.secho(f'Removing the last {diff} seeds: {seeds[-diff:]}', fg='blue')
+ seeds = seeds[:available_slots]
+ shape = [num_frames, G.z_dim]
+ all_latents = np.stack([np.random.RandomState(seed).randn(*shape).astype(np.float32) for seed in seeds], axis=1)
+
+ else:
+ ctx.fail('Error: wrong combination of arguments! Please provide either a list of seeds, one seed and the grid '
+ 'width and height, or more than one seed and the grid width and height')
+
+ # Let's smooth out the random latents so that now they form a loop (and are correctly generated in a 512-dim space)
+ all_latents = scipy.ndimage.gaussian_filter(all_latents, sigma=[smoothing_sec * fps, 0, 0], mode='wrap')
+ all_latents /= np.sqrt(np.mean(np.square(all_latents)))
+
+ # Name of the video
+ mp4_name = f'{grid_width}x{grid_height}-slerp-{slowdown}xslowdown'
+
+ # Labels.
+ class_idx = gen_utils.parse_class(G, class_idx, ctx)
+ label = torch.zeros([1, G.c_dim], device=device)
+ if G.c_dim != 0:
+ label[:, class_idx] = 1
+ else:
+ if class_idx is not None:
+ print('warn: --class=lbl ignored when running on an unconditional network')
+
+ # Let's slowdown the video, if so desired
+ while slowdown > 1:
+ all_latents, duration_sec, num_frames = gen_utils.double_slowdown(latents=all_latents,
+ duration=duration_sec,
+ frames=num_frames)
+ slowdown //= 2
+
+ if new_center is None:
+ w_avg = G.mapping.w_avg
+ else:
+ new_center, new_center_value = new_center
+ # We get the new center using the int (a seed) or recovered dlatent (an np.ndarray)
+ if isinstance(new_center_value, int):
+ w_avg = gen_utils.get_w_from_seed(G, device, new_center_value,
+ truncation_psi=1.0) # We want the pure dlatent
+ elif isinstance(new_center_value, np.ndarray):
+ w_avg = torch.from_numpy(new_center_value).to(device)
+ else:
+ ctx.fail('Error: New center has strange format! Only an int (seed) or a file (.npy/.npz) are accepted!')
+
+ # Auxiliary function for moviepy
+ def make_frame(t):
+ frame_idx = int(np.clip(np.round(t * fps), 0, num_frames - 1))
+ latents = torch.from_numpy(all_latents[frame_idx]).to(device)
+ # Do the truncation trick (with the global centroid or the new center provided by the user)
+ w = G.mapping(latents, None)
+ w = w_avg + (w - w_avg) * truncation_psi
+
+ # Get the images
+
+ # Save the intermediate layer output.
+ if layer_name is not None:
+ # Sanity check (again, could be done better)
+ submodule_names = {name: mod for name, mod in G.synthesis.named_modules()}
+ assert layer_name in submodule_names, f'Layer "{layer_name}" not found in the network! Available layers: {", ".join(submodule_names)}'
+ assert True in (save_grayscale, save_rgb), 'You must select to save the video in at least one of the two possible formats! (L, RGB)'
+
+ sel_channels = 3 if save_rgb else 1
+ res = Renderer().render(G=G, layer_name=layer_name, dlatent=w, sel_channels=sel_channels,
+ base_channel=starting_channel, img_scale_db=img_scale_db, img_normalize=img_normalize)
+ images = res.image
+ images = np.expand_dims(np.array(images), axis=0)
+ else:
+ images = gen_utils.w_to_img(G, w, noise_mode) # Remember, it can only be a single image
+ # RGBA -> RGB, if necessary
+ images = images[:, :, :, :3]
+
+ # Generate the grid for this timestamp
+ grid = gen_utils.create_image_grid(images, grid_size)
+ # moviepy.editor.VideoClip expects 3 channels
+ if grid.shape[2] == 1:
+ grid = grid.repeat(3, 2)
+ return grid
+
+ # Generate video using the respective make_frame function
+ videoclip = moviepy.editor.VideoClip(make_frame, duration=duration_sec)
+ videoclip.set_duration(total_duration)
+
+ mp4_name = f'{mp4_name}_{layer_name}' if layer_name is not None else mp4_name
+
+ # Change the video parameters (codec, bitrate) if you so desire
+ final_video = os.path.join(run_dir, f'{mp4_name}.mp4')
+ videoclip.write_videofile(final_video, fps=fps, codec='libx264', bitrate='16M')
+
+ # Save the configuration used
+ new_center = 'w_avg' if new_center is None else new_center
+ ctx.obj = {
+ 'network_pkl': network_pkl,
+ 'config': cfg,
+ 'synthesis_options': {
+ 'seeds': seeds,
+ 'truncation_psi': truncation_psi,
+ 'new_center': new_center,
+ 'class_idx': class_idx,
+ 'noise_mode': noise_mode,
+ 'anchor_latent_space': anchor_latent_space
+ },
+ 'intermediate_representations': {
+ 'layer': layer_name,
+ 'starting_channel': starting_channel,
+ 'grayscale': save_grayscale,
+ 'rgb': save_rgb,
+ 'img_scale_db': img_scale_db,
+ 'img_normalize': img_normalize
+ },
+ 'video_options': {
+ 'grid_width': grid_width,
+ 'grid_height': grid_height,
+ 'slowdown': slowdown,
+ 'duration_sec': duration_sec,
+ 'video_fps': fps,
+ 'compress': compress,
+ 'smoothing_sec': smoothing_sec
+ },
+ 'extra_parameters': {
+ 'run_dir': run_dir,
+ 'description': desc
+ }
+ }
+ gen_utils.save_config(ctx=ctx, run_dir=run_dir)
+
+ # Compress the video (lower file size, same resolution)
+ if compress:
+ gen_utils.compress_video(original_video=final_video, original_video_name=mp4_name, outdir=run_dir, ctx=ctx)
+
+
+# ----------------------------------------------------------------------------
+
+
+@main.command('circular-video')
+@click.pass_context
+@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
+@click.option('--cfg', type=click.Choice(gen_utils.available_cfgs), help='Config of the network, used only if you want to use the pretrained models in torch_utils.gen_utils.resume_specs')
+# Synthesis options
+@click.option('--seed', type=int, help='Random seed', required=True)
+@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
+@click.option('--trunc-start', 'truncation_psi_start', type=float, help='Initial value of pulsating truncation psi', default=None, show_default=True)
+@click.option('--trunc-end', 'truncation_psi_end', type=float, help='Maximum/minimum value of pulsating truncation psi', default=None, show_default=True)
+@click.option('--global-pulse', 'global_pulsation_trick', is_flag=True, help='If set, the truncation psi will pulsate globally (on all grid cells)')
+@click.option('--wave-pulse', 'wave_pulsation_trick', is_flag=True, help='If set, the truncation psi will pulsate in a wave-like fashion from the upper left to the lower right in the grid')
+@click.option('--frequency', 'pulsation_frequency', type=int, help='Frequency of the pulsation', default=1, show_default=True)
+@click.option('--new-center', type=str, help='New center for the W latent space; a seed (int) or a path to a projected dlatent (.npy/.npz)', default=None)
+@click.option('--new-w-avg', 'new_w_avg', type=gen_utils.parse_new_center, help='Path to a new "global" w_avg (seed or .npy/.npz file) to be used in the truncation trick', default=None)
+@click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
+@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
+@click.option('--anchor-latent-space', '-anchor', is_flag=True, help='Anchor the latent space to w_avg to stabilize the video')
+@click.option('--flesh', 'aydao_flesh_digression', is_flag=True, help='If set, we will slowly modify the constant input to the network (based on @aydao\'s work')
+# Video options
+@click.option('--grid-width', '-gw', type=click.IntRange(min=1), help='Video grid width / number of columns', required=True)
+@click.option('--grid-height', '-gh', type=click.IntRange(min=1), help='Video grid height / number of rows', required=True)
+@click.option('--duration-sec', '-sec', type=float, help='Duration length of the video', default=10.0, show_default=True)
+@click.option('--fps', type=click.IntRange(min=1), help='Video FPS.', default=30, show_default=True)
+@click.option('--compress', is_flag=True, help='Add flag to compress the final mp4 file with ffmpeg-python (same resolution, lower file size)')
+# Extra parameters for saving the results
+@click.option('--outdir', type=click.Path(file_okay=False), help='Directory path to save the results', default=os.path.join(os.getcwd(), 'out', 'video'), show_default=True, metavar='DIR')
+@click.option('--description', '-desc', type=str, help='Description name for the directory path to save results')
+def circular_video(
+ ctx: click.Context,
+ network_pkl: Union[str, os.PathLike],
+ cfg: Optional[str],
+ seed: int,
+ truncation_psi: Optional[float],
+ truncation_psi_start: Optional[float],
+ truncation_psi_end: Optional[float],
+ global_pulsation_trick: Optional[bool],
+ wave_pulsation_trick: Optional[bool],
+ pulsation_frequency: Optional[int],
+ new_center: Tuple[str, Union[int, np.ndarray]],
+ new_w_avg: Optional[Union[str, os.PathLike]],
+ class_idx: Optional[int],
+ noise_mode: Optional[str],
+ anchor_latent_space: Optional[bool],
+ aydao_flesh_digression: Optional[bool],
+ grid_width: int,
+ grid_height: int,
+ duration_sec: float,
+ fps: int,
+ compress: Optional[bool],
+ outdir: Union[str, os.PathLike],
+ description: str
+):
+ """
+ Generate a circular interpolation video in two random axes of Z, given a seed
+ """
+
+ device = torch.device('cuda')
+
+ # Load the network
+ G = gen_utils.load_network('G_ema', network_pkl, cfg, device)
+
+ # Get the constant input
+ if aydao_flesh_digression:
+ if hasattr(G.synthesis, 'b4'):
+ model_type = 'stylegan2'
+ const_input = copy.deepcopy(G.synthesis.b4.const).cpu().numpy()
+ elif hasattr(G.synthesis, 'input'):
+ model_type = 'stylegan3'
+ input_frequencies = copy.deepcopy(G.synthesis.input.freqs).cpu().numpy()
+ input_phases = copy.deepcopy(G.synthesis.input.phases).cpu().numpy()
+ else:
+ ctx.fail('Error: This option is only available for StyleGAN2 and StyleGAN3 models!')
+
+ # Get the labels, if the model is conditional
+ class_idx = gen_utils.parse_class(G, class_idx, ctx)
+ label = torch.zeros([1, G.c_dim], device=device)
+ if G.c_dim != 0:
+ label[:, class_idx] = 1
+ else:
+ if class_idx is not None:
+ print('warn: --class=lbl ignored when running on an unconditional network')
+
+ # Get center of the latent space (global or user-indicated)
+ if new_center is None:
+ w_avg = G.mapping.w_avg
+ w_avg = w_avg.view(1, 1, -1) # [w_dim] => [1, 1, w_dim]
+ else:
+ # It's an int, so use as a seed
+ if new_center.isdigit():
+ w_avg = gen_utils.get_w_from_seed(G, device, int(new_center), truncation_psi=1.0).to(device)
+ # It's a file, so load it
+ elif os.path.isfile(new_center):
+ w_avg = gen_utils.get_latent_from_file(new_center, return_ext=False)
+ w_avg = torch.from_numpy(w_avg).to(device)
+ # It's a directory, so get all latents inside it (including subdirectories, so be careful)
+ elif os.path.isdir(new_center):
+ w_avg = gen_utils.parse_all_projected_dlatents(new_center)
+ w_avg = torch.tensor(w_avg).squeeze(1).to(device)
+ else:
+ message = 'Only seeds (int) or paths to latent files (.npy/.npz) or directories containing these are allowed for "--new-center"'
+ raise ctx.fail(message)
+
+ # Some sanity checks
+ num_centers = len(w_avg)
+ if num_centers == 0:
+ raise ctx.fail('No centers were found! If files, makes sure they are .npy or .npz files.')
+ # Just one is provided, so this will be a sort of 'global' center
+ elif num_centers == 1:
+ print(f'Using only one center (if more than one is desired, provide a directory with all of them)')
+ elif num_centers != grid_height * grid_width:
+ message = f"Number of centers ({num_centers}) doesn't match the grid size ({grid_height}x{grid_width})"
+ raise ctx.fail(message)
+
+ print('Using wave pulsation trick' if wave_pulsation_trick else 'Using global pulsation trick' if global_pulsation_trick else 'Using standard truncation trick...')
+ # Stabilize/anchor the latent space
+ if anchor_latent_space:
+ gen_utils.anchor_latent_space(G)
+
+ # Create the run dir with the given name description; add slowdown if different from the default (1)
+ desc = 'circular-video'
+ desc = f'circular-video-{description}' if description is not None else desc
+ desc = f'{desc}-aydao-flesh-digression' if aydao_flesh_digression else desc
+ run_dir = gen_utils.make_run_dir(outdir, desc)
+
+ # Calculate the total number of frames in the video
+ num_frames = int(np.rint(duration_sec * fps))
+
+ grid_size = (grid_width, grid_height)
+ # Get the latents with the random state
+ random_state = np.random.RandomState(seed)
+ # Choose two random dims on which to plot the circles (from 0 to G.z_dim-1),
+ # one pair for each element of the grid (2*grid_width*grid_height in total)
+ try:
+ z1, z2 = np.split(random_state.choice(G.z_dim, 2 * np.prod(grid_size), replace=False), 2)
+ except ValueError:
+ # Extreme case: G.z_dim < 2 * grid_width * grid_height (low G.z_dim most likely)
+ z1, z2 = np.split(random_state.choice(G.z_dim, 2 * np.prod(grid_size), replace=True), 2)
+
+ # We partition the circle in equal strides w.r.t. num_frames
+ get_angles = lambda num_frames: np.linspace(0, 2*np.pi, num_frames)
+ angles = get_angles(num_frames=num_frames)
+
+ # Basic Polar to Cartesian transformation
+ polar_to_cartesian = lambda radius, theta: (radius * np.cos(theta), radius * np.sin(theta))
+ # Using a fixed radius (this value is irrelevant), we generate the circles in each chosen grid
+ Z1, Z2 = polar_to_cartesian(radius=5.0, theta=angles)
+
+ # Our latents will be comprising mostly of zeros
+ all_latents = np.zeros([num_frames, np.prod(grid_size), G.z_dim]).astype(np.float32)
+ # Obtain all the frames belonging to the specific box in the grid,
+ # replacing the zero values with the circle perimeter values
+ for box in range(np.prod(grid_size)):
+ box_frames = all_latents[:, box]
+ box_frames[:, [z1[box], z2[box]]] = np.vstack((Z1, Z2)).T
+
+ if aydao_flesh_digression:
+ # We will modify the constant input to the network (for --cfg=stylegan2)
+ if model_type == 'stylegan2':
+ const_input_interpolation = np.random.randn(num_frames, *const_input.shape).astype(np.float32) / 4 # [num_frames, G.w_dim, 4, 4] ; "/ 4" is arbitrary
+ const_input_interpolation = scipy.ndimage.gaussian_filter(const_input_interpolation, sigma=[fps, 0, 0, 0], mode='wrap')
+ const_input_interpolation /= np.sqrt(np.mean(np.square(const_input_interpolation)))
+ elif model_type == 'stylegan3':
+ const_freq_interpolation = np.random.randn(num_frames, *input_frequencies.shape).astype(np.float32) / 32 # [num_frames, G.w_dim, 2]
+ const_freq_interpolation = scipy.ndimage.gaussian_filter(const_freq_interpolation, sigma=[5.0*fps, 0, 0], mode='wrap')
+ const_freq_interpolation /= np.sqrt(np.mean(np.square(const_freq_interpolation)))
+
+ const_phase_interpolation = np.random.randn(num_frames, *input_phases.shape).astype(np.float32) / 8 # [num_frames, G.w_dim, 2]
+ const_phase_interpolation = scipy.ndimage.gaussian_filter(const_phase_interpolation, sigma=[5.0*fps, 0], mode='wrap')
+ const_phase_interpolation /= np.sqrt(np.mean(np.square(const_phase_interpolation)))
+
+ # Convert to torch tensor
+ if new_w_avg is not None:
+ print("Moving all the latent space towards the new center...")
+ _, new_w_avg = new_w_avg
+ # We get the new center using the int (a seed) or recovered dlatent (an np.ndarray)
+ if isinstance(new_w_avg, int):
+ new_w_avg = gen_utils.get_w_from_seed(G, device, new_w_avg,
+ truncation_psi=1.0) # We want the pure dlatent
+ elif isinstance(new_w_avg, np.ndarray):
+ new_w_avg = torch.from_numpy(new_w_avg).to(device) # [1, num_ws, w_dim]
+ else:
+ ctx.fail('Error: New center has strange format! Only an int (seed) or a file (.npy/.npz) are accepted!')
+
+ # Auxiliary function for moviepy
+ def make_frame(t):
+ frame_idx = int(np.clip(np.round(t * fps), 0, num_frames - 1))
+ latents = torch.from_numpy(all_latents[frame_idx]).to(device)
+ # Get the images with the respective label
+ dlatents = gen_utils.z_to_dlatent(G, latents, label, truncation_psi=1.0) # Get the pure dlatent
+ # Do truncation trick
+ # For the truncation trick (supersedes any value chosen for truncation_psi)
+ if None not in (truncation_psi_start, truncation_psi_end):
+ # For both, truncation psi will have the general form of a sinusoid: psi = (cos(t) + alpha) / beta
+ if global_pulsation_trick:
+ tr = gen_utils.global_pulsate_psi(psi_start=truncation_psi_start,
+ psi_end=truncation_psi_end,
+ n_steps=num_frames)
+ elif wave_pulsation_trick:
+ tr = gen_utils.wave_pulse_truncation_psi(psi_start=truncation_psi_start,
+ psi_end=truncation_psi_end,
+ n_steps=num_frames,
+ grid_shape=grid_size,
+ frequency=pulsation_frequency,
+ time=frame_idx)
+ # Define how to use the truncation psi
+ if global_pulsation_trick:
+ tr = tr[frame_idx].to(device)
+ elif wave_pulsation_trick:
+ tr = tr.to(device)
+ else:
+ # It's a float, so we can just use it
+ tr = truncation_psi
+
+ w = w_avg + (dlatents - w_avg) * tr
+ # Modify the constant input
+ if aydao_flesh_digression:
+ if model_type == 'stylegan2':
+ G.synthesis.b4.const.copy_(torch.from_numpy(const_input_interpolation[frame_idx]))
+ elif model_type == 'stylegan3':
+ pass
+ # G.synthesis.input.freqs.copy_(torch.from_numpy(const_freq_interpolation[frame_idx]))
+ # G.synthesis.input.phases.copy_(torch.from_numpy(const_phase_interpolation[frame_idx]))
+ # G.synthesis.input.phases.copy_(torch.from_numpy(
+ # input_phases * np.cos(np.pi * frame_idx / num_frames) ** 2
+ # ))
+ # Get the images
+ images = gen_utils.w_to_img(G, w, noise_mode, new_w_avg, tr)
+ # RGBA -> RGB
+ images = images[:, :, :, :3]
+ # Generate the grid for this timestep
+ grid = gen_utils.create_image_grid(images, grid_size)
+ # Grayscale => RGB
+ if grid.shape[2] == 1:
+ grid = grid.repeat(3, 2)
+ return grid
+
+ # Generate video using the respective make_frame function
+ videoclip = moviepy.editor.VideoClip(make_frame, duration=duration_sec)
+ videoclip.set_duration(duration_sec)
+
+ # Name of the video
+ mp4_name = f'{grid_width}x{grid_height}-circular'
+
+ # Change the video parameters (codec, bitrate) if you so desire
+ final_video = os.path.join(run_dir, f'{mp4_name}.mp4')
+ videoclip.write_videofile(final_video, fps=fps, codec='libx264', bitrate='16M')
+
+ # Save the configuration used
+ new_center = 'w_avg' if new_center is None else new_center
+ ctx.obj = {
+ 'network_pkl': network_pkl,
+ 'config': cfg,
+ 'seed': seed,
+ 'z1, z2': [[int(i), int(j)] for i, j in zip(z1, z2)],
+ 'truncation_psi': truncation_psi if isinstance(truncation_psi, float) else 'pulsating',
+ 'truncation_psi_start': truncation_psi_start,
+ 'truncation_psi_end': truncation_psi_end,
+ 'new_center': new_center,
+ 'class_idx': class_idx,
+ 'noise_mode': noise_mode,
+ 'grid_width': grid_width,
+ 'grid_height': grid_height,
+ 'duration_sec': duration_sec,
+ 'video_fps': fps,
+ 'run_dir': run_dir,
+ 'description': desc,
+ 'compress': compress
+ }
+ gen_utils.save_config(ctx=ctx, run_dir=run_dir)
+
+ # Compress the video (lower file size, same resolution)
+ if compress:
+ gen_utils.compress_video(original_video=final_video, original_video_name=mp4_name, outdir=run_dir, ctx=ctx)
+
+# ----------------------------------------------------------------------------
+
+
+if __name__ == "__main__":
+ main() # pylint: disable=no-value-for-parameter
+
+
+# ----------------------------------------------------------------------------
diff --git a/stylegan3-fun/gui_utils/__init__.py b/stylegan3-fun/gui_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..939e7c6c8f94c4ea1141885c3c3295fe083b06aa
--- /dev/null
+++ b/stylegan3-fun/gui_utils/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+# empty
diff --git a/stylegan3-fun/gui_utils/__pycache__/__init__.cpython-311.pyc b/stylegan3-fun/gui_utils/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1b0d539a74dbddf790db7b463e00fc80168963bc
Binary files /dev/null and b/stylegan3-fun/gui_utils/__pycache__/__init__.cpython-311.pyc differ
diff --git a/stylegan3-fun/gui_utils/__pycache__/__init__.cpython-38.pyc b/stylegan3-fun/gui_utils/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4701a573b48329466e0dc579765107b3b17b1c5e
Binary files /dev/null and b/stylegan3-fun/gui_utils/__pycache__/__init__.cpython-38.pyc differ
diff --git a/stylegan3-fun/gui_utils/__pycache__/__init__.cpython-39.pyc b/stylegan3-fun/gui_utils/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d0730f16823c204ca018797744d89fab8799a78f
Binary files /dev/null and b/stylegan3-fun/gui_utils/__pycache__/__init__.cpython-39.pyc differ
diff --git a/stylegan3-fun/gui_utils/__pycache__/gl_utils.cpython-311.pyc b/stylegan3-fun/gui_utils/__pycache__/gl_utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dfd5ce0b793a6f45d40d5325b41b4757c7c78a2a
Binary files /dev/null and b/stylegan3-fun/gui_utils/__pycache__/gl_utils.cpython-311.pyc differ
diff --git a/stylegan3-fun/gui_utils/__pycache__/gl_utils.cpython-38.pyc b/stylegan3-fun/gui_utils/__pycache__/gl_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8689b3933467e41a2a8eb7a5e24c3d83bf3fb1e0
Binary files /dev/null and b/stylegan3-fun/gui_utils/__pycache__/gl_utils.cpython-38.pyc differ
diff --git a/stylegan3-fun/gui_utils/__pycache__/gl_utils.cpython-39.pyc b/stylegan3-fun/gui_utils/__pycache__/gl_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..19751c61dc5f6b1129e847ce8a35da24dcc95de3
Binary files /dev/null and b/stylegan3-fun/gui_utils/__pycache__/gl_utils.cpython-39.pyc differ
diff --git a/stylegan3-fun/gui_utils/__pycache__/glfw_window.cpython-311.pyc b/stylegan3-fun/gui_utils/__pycache__/glfw_window.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8882d30f8678e033f44091d4e19b47e1780a3f65
Binary files /dev/null and b/stylegan3-fun/gui_utils/__pycache__/glfw_window.cpython-311.pyc differ
diff --git a/stylegan3-fun/gui_utils/__pycache__/glfw_window.cpython-38.pyc b/stylegan3-fun/gui_utils/__pycache__/glfw_window.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1c86b18529713392e7541cb68dcbcd2f32006f9d
Binary files /dev/null and b/stylegan3-fun/gui_utils/__pycache__/glfw_window.cpython-38.pyc differ
diff --git a/stylegan3-fun/gui_utils/__pycache__/glfw_window.cpython-39.pyc b/stylegan3-fun/gui_utils/__pycache__/glfw_window.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..12578ba57343bcc5a863e186cb48ea4e20d4c264
Binary files /dev/null and b/stylegan3-fun/gui_utils/__pycache__/glfw_window.cpython-39.pyc differ
diff --git a/stylegan3-fun/gui_utils/__pycache__/imgui_utils.cpython-311.pyc b/stylegan3-fun/gui_utils/__pycache__/imgui_utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..72d27ff3bcf5247658599d11f012b3023a2dfbb3
Binary files /dev/null and b/stylegan3-fun/gui_utils/__pycache__/imgui_utils.cpython-311.pyc differ
diff --git a/stylegan3-fun/gui_utils/__pycache__/imgui_utils.cpython-38.pyc b/stylegan3-fun/gui_utils/__pycache__/imgui_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..74826e2d6738f1208708e6d9b87534ba194fdb3a
Binary files /dev/null and b/stylegan3-fun/gui_utils/__pycache__/imgui_utils.cpython-38.pyc differ
diff --git a/stylegan3-fun/gui_utils/__pycache__/imgui_utils.cpython-39.pyc b/stylegan3-fun/gui_utils/__pycache__/imgui_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..905340a52d53378cb2da0ac85bace8283f203a18
Binary files /dev/null and b/stylegan3-fun/gui_utils/__pycache__/imgui_utils.cpython-39.pyc differ
diff --git a/stylegan3-fun/gui_utils/__pycache__/imgui_window.cpython-311.pyc b/stylegan3-fun/gui_utils/__pycache__/imgui_window.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..25f4389cc59908964c556132f4619a276746619e
Binary files /dev/null and b/stylegan3-fun/gui_utils/__pycache__/imgui_window.cpython-311.pyc differ
diff --git a/stylegan3-fun/gui_utils/__pycache__/imgui_window.cpython-38.pyc b/stylegan3-fun/gui_utils/__pycache__/imgui_window.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..43d3415a7dfa5d2e17e39f54ca92e125f573ce91
Binary files /dev/null and b/stylegan3-fun/gui_utils/__pycache__/imgui_window.cpython-38.pyc differ
diff --git a/stylegan3-fun/gui_utils/__pycache__/imgui_window.cpython-39.pyc b/stylegan3-fun/gui_utils/__pycache__/imgui_window.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..888f7e302d56a3e2075e933c1a3fa1ef48b9c15e
Binary files /dev/null and b/stylegan3-fun/gui_utils/__pycache__/imgui_window.cpython-39.pyc differ
diff --git a/stylegan3-fun/gui_utils/__pycache__/text_utils.cpython-311.pyc b/stylegan3-fun/gui_utils/__pycache__/text_utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fef3b61532e0372734175170e55282c45de818af
Binary files /dev/null and b/stylegan3-fun/gui_utils/__pycache__/text_utils.cpython-311.pyc differ
diff --git a/stylegan3-fun/gui_utils/__pycache__/text_utils.cpython-38.pyc b/stylegan3-fun/gui_utils/__pycache__/text_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6a6eb96d3e3bf73ee632aaa40ae3605040e443b6
Binary files /dev/null and b/stylegan3-fun/gui_utils/__pycache__/text_utils.cpython-38.pyc differ
diff --git a/stylegan3-fun/gui_utils/__pycache__/text_utils.cpython-39.pyc b/stylegan3-fun/gui_utils/__pycache__/text_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..228d8174c1b5754e37eb587a9ebcfa91bccfde5f
Binary files /dev/null and b/stylegan3-fun/gui_utils/__pycache__/text_utils.cpython-39.pyc differ
diff --git a/stylegan3-fun/gui_utils/gl_utils.py b/stylegan3-fun/gui_utils/gl_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..64b6cb6510ab2f5075effe8684d824f50bd38272
--- /dev/null
+++ b/stylegan3-fun/gui_utils/gl_utils.py
@@ -0,0 +1,374 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import os
+import functools
+import contextlib
+import numpy as np
+import OpenGL.GL as gl
+import OpenGL.GL.ARB.texture_float
+import dnnlib
+
+#----------------------------------------------------------------------------
+
+def init_egl():
+ assert os.environ['PYOPENGL_PLATFORM'] == 'egl' # Must be set before importing OpenGL.
+ import OpenGL.EGL as egl
+ import ctypes
+
+ # Initialize EGL.
+ display = egl.eglGetDisplay(egl.EGL_DEFAULT_DISPLAY)
+ assert display != egl.EGL_NO_DISPLAY
+ major = ctypes.c_int32()
+ minor = ctypes.c_int32()
+ ok = egl.eglInitialize(display, major, minor)
+ assert ok
+ assert major.value * 10 + minor.value >= 14
+
+ # Choose config.
+ config_attribs = [
+ egl.EGL_RENDERABLE_TYPE, egl.EGL_OPENGL_BIT,
+ egl.EGL_SURFACE_TYPE, egl.EGL_PBUFFER_BIT,
+ egl.EGL_NONE
+ ]
+ configs = (ctypes.c_int32 * 1)()
+ num_configs = ctypes.c_int32()
+ ok = egl.eglChooseConfig(display, config_attribs, configs, 1, num_configs)
+ assert ok
+ assert num_configs.value == 1
+ config = configs[0]
+
+ # Create dummy pbuffer surface.
+ surface_attribs = [
+ egl.EGL_WIDTH, 1,
+ egl.EGL_HEIGHT, 1,
+ egl.EGL_NONE
+ ]
+ surface = egl.eglCreatePbufferSurface(display, config, surface_attribs)
+ assert surface != egl.EGL_NO_SURFACE
+
+ # Setup GL context.
+ ok = egl.eglBindAPI(egl.EGL_OPENGL_API)
+ assert ok
+ context = egl.eglCreateContext(display, config, egl.EGL_NO_CONTEXT, None)
+ assert context != egl.EGL_NO_CONTEXT
+ ok = egl.eglMakeCurrent(display, surface, surface, context)
+ assert ok
+
+#----------------------------------------------------------------------------
+
+_texture_formats = {
+ ('uint8', 1): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_LUMINANCE, internalformat=gl.GL_LUMINANCE8),
+ ('uint8', 2): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_LUMINANCE_ALPHA, internalformat=gl.GL_LUMINANCE8_ALPHA8),
+ ('uint8', 3): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_RGB, internalformat=gl.GL_RGB8),
+ ('uint8', 4): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_RGBA, internalformat=gl.GL_RGBA8),
+ ('float32', 1): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_LUMINANCE, internalformat=OpenGL.GL.ARB.texture_float.GL_LUMINANCE32F_ARB),
+ ('float32', 2): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_LUMINANCE_ALPHA, internalformat=OpenGL.GL.ARB.texture_float.GL_LUMINANCE_ALPHA32F_ARB),
+ ('float32', 3): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_RGB, internalformat=gl.GL_RGB32F),
+ ('float32', 4): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_RGBA, internalformat=gl.GL_RGBA32F),
+}
+
+def get_texture_format(dtype, channels):
+ return _texture_formats[(np.dtype(dtype).name, int(channels))]
+
+#----------------------------------------------------------------------------
+
+def prepare_texture_data(image):
+ image = np.asarray(image)
+ if image.ndim == 2:
+ image = image[:, :, np.newaxis]
+ if image.dtype.name == 'float64':
+ image = image.astype('float32')
+ return image
+
+#----------------------------------------------------------------------------
+
+def draw_pixels(image, *, pos=0, zoom=1, align=0, rint=True):
+ pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
+ zoom = np.broadcast_to(np.asarray(zoom, dtype='float32'), [2])
+ align = np.broadcast_to(np.asarray(align, dtype='float32'), [2])
+ image = prepare_texture_data(image)
+ height, width, channels = image.shape
+ size = zoom * [width, height]
+ pos = pos - size * align
+ if rint:
+ pos = np.rint(pos)
+ fmt = get_texture_format(image.dtype, channels)
+
+ gl.glPushAttrib(gl.GL_CURRENT_BIT | gl.GL_PIXEL_MODE_BIT)
+ gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT)
+ gl.glRasterPos2f(pos[0], pos[1])
+ gl.glPixelZoom(zoom[0], -zoom[1])
+ gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1)
+ gl.glDrawPixels(width, height, fmt.format, fmt.type, image)
+ gl.glPopClientAttrib()
+ gl.glPopAttrib()
+
+#----------------------------------------------------------------------------
+
+def read_pixels(width, height, *, pos=0, dtype='uint8', channels=3):
+ pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
+ dtype = np.dtype(dtype)
+ fmt = get_texture_format(dtype, channels)
+ image = np.empty([height, width, channels], dtype=dtype)
+
+ gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT)
+ gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1)
+ gl.glReadPixels(int(np.round(pos[0])), int(np.round(pos[1])), width, height, fmt.format, fmt.type, image)
+ gl.glPopClientAttrib()
+ return np.flipud(image)
+
+#----------------------------------------------------------------------------
+
+class Texture:
+ def __init__(self, *, image=None, width=None, height=None, channels=None, dtype=None, bilinear=True, mipmap=True):
+ self.gl_id = None
+ self.bilinear = bilinear
+ self.mipmap = mipmap
+
+ # Determine size and dtype.
+ if image is not None:
+ image = prepare_texture_data(image)
+ self.height, self.width, self.channels = image.shape
+ self.dtype = image.dtype
+ else:
+ assert width is not None and height is not None
+ self.width = width
+ self.height = height
+ self.channels = channels if channels is not None else 3
+ self.dtype = np.dtype(dtype) if dtype is not None else np.uint8
+
+ # Validate size and dtype.
+ assert isinstance(self.width, int) and self.width >= 0
+ assert isinstance(self.height, int) and self.height >= 0
+ assert isinstance(self.channels, int) and self.channels >= 1
+ assert self.is_compatible(width=width, height=height, channels=channels, dtype=dtype)
+
+ # Create texture object.
+ self.gl_id = gl.glGenTextures(1)
+ with self.bind():
+ gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
+ gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)
+ gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR if self.bilinear else gl.GL_NEAREST)
+ gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR_MIPMAP_LINEAR if self.mipmap else gl.GL_NEAREST)
+ self.update(image)
+
+ def delete(self):
+ if self.gl_id is not None:
+ gl.glDeleteTextures([self.gl_id])
+ self.gl_id = None
+
+ def __del__(self):
+ try:
+ self.delete()
+ except:
+ pass
+
+ @contextlib.contextmanager
+ def bind(self):
+ prev_id = gl.glGetInteger(gl.GL_TEXTURE_BINDING_2D)
+ gl.glBindTexture(gl.GL_TEXTURE_2D, self.gl_id)
+ yield
+ gl.glBindTexture(gl.GL_TEXTURE_2D, prev_id)
+
+ def update(self, image):
+ if image is not None:
+ image = prepare_texture_data(image)
+ assert self.is_compatible(image=image)
+ with self.bind():
+ fmt = get_texture_format(self.dtype, self.channels)
+ gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT)
+ gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1)
+ gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, fmt.internalformat, self.width, self.height, 0, fmt.format, fmt.type, image)
+ if self.mipmap:
+ gl.glGenerateMipmap(gl.GL_TEXTURE_2D)
+ gl.glPopClientAttrib()
+
+ def draw(self, *, pos=0, zoom=1, align=0, rint=False, color=1, alpha=1, rounding=0):
+ zoom = np.broadcast_to(np.asarray(zoom, dtype='float32'), [2])
+ size = zoom * [self.width, self.height]
+ with self.bind():
+ gl.glPushAttrib(gl.GL_ENABLE_BIT)
+ gl.glEnable(gl.GL_TEXTURE_2D)
+ draw_rect(pos=pos, size=size, align=align, rint=rint, color=color, alpha=alpha, rounding=rounding)
+ gl.glPopAttrib()
+
+ def is_compatible(self, *, image=None, width=None, height=None, channels=None, dtype=None): # pylint: disable=too-many-return-statements
+ if image is not None:
+ if image.ndim != 3:
+ return False
+ ih, iw, ic = image.shape
+ if not self.is_compatible(width=iw, height=ih, channels=ic, dtype=image.dtype):
+ return False
+ if width is not None and self.width != width:
+ return False
+ if height is not None and self.height != height:
+ return False
+ if channels is not None and self.channels != channels:
+ return False
+ if dtype is not None and self.dtype != dtype:
+ return False
+ return True
+
+#----------------------------------------------------------------------------
+
+class Framebuffer:
+ def __init__(self, *, texture=None, width=None, height=None, channels=None, dtype=None, msaa=0):
+ self.texture = texture
+ self.gl_id = None
+ self.gl_color = None
+ self.gl_depth_stencil = None
+ self.msaa = msaa
+
+ # Determine size and dtype.
+ if texture is not None:
+ assert isinstance(self.texture, Texture)
+ self.width = texture.width
+ self.height = texture.height
+ self.channels = texture.channels
+ self.dtype = texture.dtype
+ else:
+ assert width is not None and height is not None
+ self.width = width
+ self.height = height
+ self.channels = channels if channels is not None else 4
+ self.dtype = np.dtype(dtype) if dtype is not None else np.float32
+
+ # Validate size and dtype.
+ assert isinstance(self.width, int) and self.width >= 0
+ assert isinstance(self.height, int) and self.height >= 0
+ assert isinstance(self.channels, int) and self.channels >= 1
+ assert width is None or width == self.width
+ assert height is None or height == self.height
+ assert channels is None or channels == self.channels
+ assert dtype is None or dtype == self.dtype
+
+ # Create framebuffer object.
+ self.gl_id = gl.glGenFramebuffers(1)
+ with self.bind():
+
+ # Setup color buffer.
+ if self.texture is not None:
+ assert self.msaa == 0
+ gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0, gl.GL_TEXTURE_2D, self.texture.gl_id, 0)
+ else:
+ fmt = get_texture_format(self.dtype, self.channels)
+ self.gl_color = gl.glGenRenderbuffers(1)
+ gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self.gl_color)
+ gl.glRenderbufferStorageMultisample(gl.GL_RENDERBUFFER, self.msaa, fmt.internalformat, self.width, self.height)
+ gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0, gl.GL_RENDERBUFFER, self.gl_color)
+
+ # Setup depth/stencil buffer.
+ self.gl_depth_stencil = gl.glGenRenderbuffers(1)
+ gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self.gl_depth_stencil)
+ gl.glRenderbufferStorageMultisample(gl.GL_RENDERBUFFER, self.msaa, gl.GL_DEPTH24_STENCIL8, self.width, self.height)
+ gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER, gl.GL_DEPTH_STENCIL_ATTACHMENT, gl.GL_RENDERBUFFER, self.gl_depth_stencil)
+
+ def delete(self):
+ if self.gl_id is not None:
+ gl.glDeleteFramebuffers([self.gl_id])
+ self.gl_id = None
+ if self.gl_color is not None:
+ gl.glDeleteRenderbuffers(1, [self.gl_color])
+ self.gl_color = None
+ if self.gl_depth_stencil is not None:
+ gl.glDeleteRenderbuffers(1, [self.gl_depth_stencil])
+ self.gl_depth_stencil = None
+
+ def __del__(self):
+ try:
+ self.delete()
+ except:
+ pass
+
+ @contextlib.contextmanager
+ def bind(self):
+ prev_fbo = gl.glGetInteger(gl.GL_FRAMEBUFFER_BINDING)
+ prev_rbo = gl.glGetInteger(gl.GL_RENDERBUFFER_BINDING)
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self.gl_id)
+ if self.width is not None and self.height is not None:
+ gl.glViewport(0, 0, self.width, self.height)
+ yield
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, prev_fbo)
+ gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, prev_rbo)
+
+ def blit(self, dst=None):
+ assert dst is None or isinstance(dst, Framebuffer)
+ with self.bind():
+ gl.glBindFramebuffer(gl.GL_DRAW_FRAMEBUFFER, 0 if dst is None else dst.fbo)
+ gl.glBlitFramebuffer(0, 0, self.width, self.height, 0, 0, self.width, self.height, gl.GL_COLOR_BUFFER_BIT, gl.GL_NEAREST)
+
+#----------------------------------------------------------------------------
+
+def draw_shape(vertices, *, mode=gl.GL_TRIANGLE_FAN, pos=0, size=1, color=1, alpha=1):
+ assert vertices.ndim == 2 and vertices.shape[1] == 2
+ pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
+ size = np.broadcast_to(np.asarray(size, dtype='float32'), [2])
+ color = np.broadcast_to(np.asarray(color, dtype='float32'), [3])
+ alpha = np.clip(np.broadcast_to(np.asarray(alpha, dtype='float32'), []), 0, 1)
+
+ gl.glPushClientAttrib(gl.GL_CLIENT_VERTEX_ARRAY_BIT)
+ gl.glPushAttrib(gl.GL_CURRENT_BIT | gl.GL_TRANSFORM_BIT)
+ gl.glMatrixMode(gl.GL_MODELVIEW)
+ gl.glPushMatrix()
+
+ gl.glEnableClientState(gl.GL_VERTEX_ARRAY)
+ gl.glEnableClientState(gl.GL_TEXTURE_COORD_ARRAY)
+ gl.glVertexPointer(2, gl.GL_FLOAT, 0, vertices)
+ gl.glTexCoordPointer(2, gl.GL_FLOAT, 0, vertices)
+ gl.glTranslate(pos[0], pos[1], 0)
+ gl.glScale(size[0], size[1], 1)
+ gl.glColor4f(color[0] * alpha, color[1] * alpha, color[2] * alpha, alpha)
+ gl.glDrawArrays(mode, 0, vertices.shape[0])
+
+ gl.glPopMatrix()
+ gl.glPopAttrib()
+ gl.glPopClientAttrib()
+
+#----------------------------------------------------------------------------
+
+def draw_rect(*, pos=0, pos2=None, size=None, align=0, rint=False, color=1, alpha=1, rounding=0):
+ assert pos2 is None or size is None
+ pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
+ pos2 = np.broadcast_to(np.asarray(pos2, dtype='float32'), [2]) if pos2 is not None else None
+ size = np.broadcast_to(np.asarray(size, dtype='float32'), [2]) if size is not None else None
+ size = size if size is not None else pos2 - pos if pos2 is not None else np.array([1, 1], dtype='float32')
+ pos = pos - size * align
+ if rint:
+ pos = np.rint(pos)
+ rounding = np.broadcast_to(np.asarray(rounding, dtype='float32'), [2])
+ rounding = np.minimum(np.abs(rounding) / np.maximum(np.abs(size), 1e-8), 0.5)
+ if np.min(rounding) == 0:
+ rounding *= 0
+ vertices = _setup_rect(float(rounding[0]), float(rounding[1]))
+ draw_shape(vertices, mode=gl.GL_TRIANGLE_FAN, pos=pos, size=size, color=color, alpha=alpha)
+
+@functools.lru_cache(maxsize=10000)
+def _setup_rect(rx, ry):
+ t = np.linspace(0, np.pi / 2, 1 if max(rx, ry) == 0 else 64)
+ s = 1 - np.sin(t); c = 1 - np.cos(t)
+ x = [c * rx, 1 - s * rx, 1 - c * rx, s * rx]
+ y = [s * ry, c * ry, 1 - s * ry, 1 - c * ry]
+ v = np.stack([x, y], axis=-1).reshape(-1, 2)
+ return v.astype('float32')
+
+#----------------------------------------------------------------------------
+
+def draw_circle(*, center=0, radius=100, hole=0, color=1, alpha=1):
+ hole = np.broadcast_to(np.asarray(hole, dtype='float32'), [])
+ vertices = _setup_circle(float(hole))
+ draw_shape(vertices, mode=gl.GL_TRIANGLE_STRIP, pos=center, size=radius, color=color, alpha=alpha)
+
+@functools.lru_cache(maxsize=10000)
+def _setup_circle(hole):
+ t = np.linspace(0, np.pi * 2, 128)
+ s = np.sin(t); c = np.cos(t)
+ v = np.stack([c, s, c * hole, s * hole], axis=-1).reshape(-1, 2)
+ return v.astype('float32')
+
+#----------------------------------------------------------------------------
diff --git a/stylegan3-fun/gui_utils/glfw_window.py b/stylegan3-fun/gui_utils/glfw_window.py
new file mode 100644
index 0000000000000000000000000000000000000000..83264eb89a855ec5038cf255994ee2b4b3ddb5ee
--- /dev/null
+++ b/stylegan3-fun/gui_utils/glfw_window.py
@@ -0,0 +1,229 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import time
+import glfw
+import OpenGL.GL as gl
+from . import gl_utils
+
+#----------------------------------------------------------------------------
+
+class GlfwWindow: # pylint: disable=too-many-public-methods
+ def __init__(self, *, title='GlfwWindow', window_width=1920, window_height=1080, deferred_show=True, close_on_esc=True):
+ self._glfw_window = None
+ self._drawing_frame = False
+ self._frame_start_time = None
+ self._frame_delta = 0
+ self._fps_limit = None
+ self._vsync = None
+ self._skip_frames = 0
+ self._deferred_show = deferred_show
+ self._close_on_esc = close_on_esc
+ self._esc_pressed = False
+ self._drag_and_drop_paths = None
+ self._capture_next_frame = False
+ self._captured_frame = None
+
+ # Create window.
+ glfw.init()
+ glfw.window_hint(glfw.VISIBLE, False)
+ self._glfw_window = glfw.create_window(width=window_width, height=window_height, title=title, monitor=None, share=None)
+ self._attach_glfw_callbacks()
+ self.make_context_current()
+
+ # Adjust window.
+ self.set_vsync(False)
+ self.set_window_size(window_width, window_height)
+ if not self._deferred_show:
+ glfw.show_window(self._glfw_window)
+
+ def close(self):
+ if self._drawing_frame:
+ self.end_frame()
+ if self._glfw_window is not None:
+ glfw.destroy_window(self._glfw_window)
+ self._glfw_window = None
+ #glfw.terminate() # Commented out to play it nice with other glfw clients.
+
+ def __del__(self):
+ try:
+ self.close()
+ except:
+ pass
+
+ @property
+ def window_width(self):
+ return self.content_width
+
+ @property
+ def window_height(self):
+ return self.content_height + self.title_bar_height
+
+ @property
+ def content_width(self):
+ width, _height = glfw.get_window_size(self._glfw_window)
+ return width
+
+ @property
+ def content_height(self):
+ _width, height = glfw.get_window_size(self._glfw_window)
+ return height
+
+ @property
+ def title_bar_height(self):
+ _left, top, _right, _bottom = glfw.get_window_frame_size(self._glfw_window)
+ return top
+
+ @property
+ def monitor_width(self):
+ _, _, width, _height = glfw.get_monitor_workarea(glfw.get_primary_monitor())
+ return width
+
+ @property
+ def monitor_height(self):
+ _, _, _width, height = glfw.get_monitor_workarea(glfw.get_primary_monitor())
+ return height
+
+ @property
+ def frame_delta(self):
+ return self._frame_delta
+
+ def set_title(self, title):
+ glfw.set_window_title(self._glfw_window, title)
+
+ def set_window_size(self, width, height):
+ width = min(width, self.monitor_width)
+ height = min(height, self.monitor_height)
+ glfw.set_window_size(self._glfw_window, width, max(height - self.title_bar_height, 0))
+ if width == self.monitor_width and height == self.monitor_height:
+ self.maximize()
+
+ def set_content_size(self, width, height):
+ self.set_window_size(width, height + self.title_bar_height)
+
+ def maximize(self):
+ glfw.maximize_window(self._glfw_window)
+
+ def set_position(self, x, y):
+ glfw.set_window_pos(self._glfw_window, x, y + self.title_bar_height)
+
+ def center(self):
+ self.set_position((self.monitor_width - self.window_width) // 2, (self.monitor_height - self.window_height) // 2)
+
+ def set_vsync(self, vsync):
+ vsync = bool(vsync)
+ if vsync != self._vsync:
+ glfw.swap_interval(1 if vsync else 0)
+ self._vsync = vsync
+
+ def set_fps_limit(self, fps_limit):
+ self._fps_limit = int(fps_limit)
+
+ def should_close(self):
+ return glfw.window_should_close(self._glfw_window) or (self._close_on_esc and self._esc_pressed)
+
+ def skip_frame(self):
+ self.skip_frames(1)
+
+ def skip_frames(self, num): # Do not update window for the next N frames.
+ self._skip_frames = max(self._skip_frames, int(num))
+
+ def is_skipping_frames(self):
+ return self._skip_frames > 0
+
+ def capture_next_frame(self):
+ self._capture_next_frame = True
+
+ def pop_captured_frame(self):
+ frame = self._captured_frame
+ self._captured_frame = None
+ return frame
+
+ def pop_drag_and_drop_paths(self):
+ paths = self._drag_and_drop_paths
+ self._drag_and_drop_paths = None
+ return paths
+
+ def draw_frame(self): # To be overridden by subclass.
+ self.begin_frame()
+ # Rendering code goes here.
+ self.end_frame()
+
+ def make_context_current(self):
+ if self._glfw_window is not None:
+ glfw.make_context_current(self._glfw_window)
+
+ def begin_frame(self):
+ # End previous frame.
+ if self._drawing_frame:
+ self.end_frame()
+
+ # Apply FPS limit.
+ if self._frame_start_time is not None and self._fps_limit is not None:
+ delay = self._frame_start_time - time.perf_counter() + 1 / self._fps_limit
+ if delay > 0:
+ time.sleep(delay)
+ cur_time = time.perf_counter()
+ if self._frame_start_time is not None:
+ self._frame_delta = cur_time - self._frame_start_time
+ self._frame_start_time = cur_time
+
+ # Process events.
+ glfw.poll_events()
+
+ # Begin frame.
+ self._drawing_frame = True
+ self.make_context_current()
+
+ # Initialize GL state.
+ gl.glViewport(0, 0, self.content_width, self.content_height)
+ gl.glMatrixMode(gl.GL_PROJECTION)
+ gl.glLoadIdentity()
+ gl.glTranslate(-1, 1, 0)
+ gl.glScale(2 / max(self.content_width, 1), -2 / max(self.content_height, 1), 1)
+ gl.glMatrixMode(gl.GL_MODELVIEW)
+ gl.glLoadIdentity()
+ gl.glEnable(gl.GL_BLEND)
+ gl.glBlendFunc(gl.GL_ONE, gl.GL_ONE_MINUS_SRC_ALPHA) # Pre-multiplied alpha.
+
+ # Clear.
+ gl.glClearColor(0, 0, 0, 1)
+ gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT)
+
+ def end_frame(self):
+ assert self._drawing_frame
+ self._drawing_frame = False
+
+ # Skip frames if requested.
+ if self._skip_frames > 0:
+ self._skip_frames -= 1
+ return
+
+ # Capture frame if requested.
+ if self._capture_next_frame:
+ self._captured_frame = gl_utils.read_pixels(self.content_width, self.content_height)
+ self._capture_next_frame = False
+
+ # Update window.
+ if self._deferred_show:
+ glfw.show_window(self._glfw_window)
+ self._deferred_show = False
+ glfw.swap_buffers(self._glfw_window)
+
+ def _attach_glfw_callbacks(self):
+ glfw.set_key_callback(self._glfw_window, self._glfw_key_callback)
+ glfw.set_drop_callback(self._glfw_window, self._glfw_drop_callback)
+
+ def _glfw_key_callback(self, _window, key, _scancode, action, _mods):
+ if action == glfw.PRESS and key == glfw.KEY_ESCAPE:
+ self._esc_pressed = True
+
+ def _glfw_drop_callback(self, _window, paths):
+ self._drag_and_drop_paths = paths
+
+#----------------------------------------------------------------------------
diff --git a/stylegan3-fun/gui_utils/imgui_utils.py b/stylegan3-fun/gui_utils/imgui_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..333024bd6999bf2b18a5cb96766c4da3798666a2
--- /dev/null
+++ b/stylegan3-fun/gui_utils/imgui_utils.py
@@ -0,0 +1,169 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import contextlib
+import imgui
+
+#----------------------------------------------------------------------------
+
+def set_default_style(color_scheme='dark', spacing=9, indent=23, scrollbar=27):
+ s = imgui.get_style()
+ s.window_padding = [spacing, spacing]
+ s.item_spacing = [spacing, spacing]
+ s.item_inner_spacing = [spacing, spacing]
+ s.columns_min_spacing = spacing
+ s.indent_spacing = indent
+ s.scrollbar_size = scrollbar
+ s.frame_padding = [4, 3]
+ s.window_border_size = 1
+ s.child_border_size = 1
+ s.popup_border_size = 1
+ s.frame_border_size = 1
+ s.window_rounding = 0
+ s.child_rounding = 0
+ s.popup_rounding = 3
+ s.frame_rounding = 3
+ s.scrollbar_rounding = 3
+ s.grab_rounding = 3
+
+ getattr(imgui, f'style_colors_{color_scheme}')(s)
+ c0 = s.colors[imgui.COLOR_MENUBAR_BACKGROUND]
+ c1 = s.colors[imgui.COLOR_FRAME_BACKGROUND]
+ s.colors[imgui.COLOR_POPUP_BACKGROUND] = [x * 0.7 + y * 0.3 for x, y in zip(c0, c1)][:3] + [1]
+
+#----------------------------------------------------------------------------
+
+@contextlib.contextmanager
+def grayed_out(cond=True):
+ if cond:
+ s = imgui.get_style()
+ text = s.colors[imgui.COLOR_TEXT_DISABLED]
+ grab = s.colors[imgui.COLOR_SCROLLBAR_GRAB]
+ back = s.colors[imgui.COLOR_MENUBAR_BACKGROUND]
+ imgui.push_style_color(imgui.COLOR_TEXT, *text)
+ imgui.push_style_color(imgui.COLOR_CHECK_MARK, *grab)
+ imgui.push_style_color(imgui.COLOR_SLIDER_GRAB, *grab)
+ imgui.push_style_color(imgui.COLOR_SLIDER_GRAB_ACTIVE, *grab)
+ imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND, *back)
+ imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_HOVERED, *back)
+ imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_ACTIVE, *back)
+ imgui.push_style_color(imgui.COLOR_BUTTON, *back)
+ imgui.push_style_color(imgui.COLOR_BUTTON_HOVERED, *back)
+ imgui.push_style_color(imgui.COLOR_BUTTON_ACTIVE, *back)
+ imgui.push_style_color(imgui.COLOR_HEADER, *back)
+ imgui.push_style_color(imgui.COLOR_HEADER_HOVERED, *back)
+ imgui.push_style_color(imgui.COLOR_HEADER_ACTIVE, *back)
+ imgui.push_style_color(imgui.COLOR_POPUP_BACKGROUND, *back)
+ yield
+ imgui.pop_style_color(14)
+ else:
+ yield
+
+#----------------------------------------------------------------------------
+
+@contextlib.contextmanager
+def item_width(width=None):
+ if width is not None:
+ imgui.push_item_width(width)
+ yield
+ imgui.pop_item_width()
+ else:
+ yield
+
+#----------------------------------------------------------------------------
+
+def scoped_by_object_id(method):
+ def decorator(self, *args, **kwargs):
+ imgui.push_id(str(id(self)))
+ res = method(self, *args, **kwargs)
+ imgui.pop_id()
+ return res
+ return decorator
+
+#----------------------------------------------------------------------------
+
+def button(label, width=0, enabled=True):
+ with grayed_out(not enabled):
+ clicked = imgui.button(label, width=width)
+ clicked = clicked and enabled
+ return clicked
+
+#----------------------------------------------------------------------------
+
+def collapsing_header(text, visible=None, flags=0, default=False, enabled=True, show=True):
+ expanded = False
+ if show:
+ if default:
+ flags |= imgui.TREE_NODE_DEFAULT_OPEN
+ if not enabled:
+ flags |= imgui.TREE_NODE_LEAF
+ with grayed_out(not enabled):
+ expanded, visible = imgui.collapsing_header(text, visible=visible, flags=flags)
+ expanded = expanded and enabled
+ return expanded, visible
+
+#----------------------------------------------------------------------------
+
+def popup_button(label, width=0, enabled=True):
+ if button(label, width, enabled):
+ imgui.open_popup(label)
+ opened = imgui.begin_popup(label)
+ return opened
+
+#----------------------------------------------------------------------------
+
+def input_text(label, value, buffer_length, flags, width=None, help_text=''):
+ old_value = value
+ color = list(imgui.get_style().colors[imgui.COLOR_TEXT])
+ if value == '':
+ color[-1] *= 0.5
+ with item_width(width):
+ imgui.push_style_color(imgui.COLOR_TEXT, *color)
+ value = value if value != '' else help_text
+ changed, value = imgui.input_text(label, value, buffer_length, flags)
+ value = value if value != help_text else ''
+ imgui.pop_style_color(1)
+ if not flags & imgui.INPUT_TEXT_ENTER_RETURNS_TRUE:
+ changed = (value != old_value)
+ return changed, value
+
+#----------------------------------------------------------------------------
+
+def drag_previous_control(enabled=True):
+ dragging = False
+ dx = 0
+ dy = 0
+ if imgui.begin_drag_drop_source(imgui.DRAG_DROP_SOURCE_NO_PREVIEW_TOOLTIP):
+ if enabled:
+ dragging = True
+ dx, dy = imgui.get_mouse_drag_delta()
+ imgui.reset_mouse_drag_delta()
+ imgui.end_drag_drop_source()
+ return dragging, dx, dy
+
+#----------------------------------------------------------------------------
+
+def drag_button(label, width=0, enabled=True):
+ clicked = button(label, width=width, enabled=enabled)
+ dragging, dx, dy = drag_previous_control(enabled=enabled)
+ return clicked, dragging, dx, dy
+
+#----------------------------------------------------------------------------
+
+def drag_hidden_window(label, x, y, width, height, enabled=True):
+ imgui.push_style_color(imgui.COLOR_WINDOW_BACKGROUND, 0, 0, 0, 0)
+ imgui.push_style_color(imgui.COLOR_BORDER, 0, 0, 0, 0)
+ imgui.set_next_window_position(x, y)
+ imgui.set_next_window_size(width, height)
+ imgui.begin(label, closable=False, flags=(imgui.WINDOW_NO_TITLE_BAR | imgui.WINDOW_NO_RESIZE | imgui.WINDOW_NO_MOVE))
+ dragging, dx, dy = drag_previous_control(enabled=enabled)
+ imgui.end()
+ imgui.pop_style_color(2)
+ return dragging, dx, dy
+
+#----------------------------------------------------------------------------
diff --git a/stylegan3-fun/gui_utils/imgui_window.py b/stylegan3-fun/gui_utils/imgui_window.py
new file mode 100644
index 0000000000000000000000000000000000000000..30d539a1382def526050c83978d1118348ac77ad
--- /dev/null
+++ b/stylegan3-fun/gui_utils/imgui_window.py
@@ -0,0 +1,103 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import os
+import imgui
+import imgui.integrations.glfw
+
+from . import glfw_window
+from . import imgui_utils
+from . import text_utils
+
+#----------------------------------------------------------------------------
+
+class ImguiWindow(glfw_window.GlfwWindow):
+ def __init__(self, *, title='ImguiWindow', font=None, font_sizes=range(14,24), **glfw_kwargs):
+ if font is None:
+ font = text_utils.get_default_font()
+ font_sizes = {int(size) for size in font_sizes}
+ super().__init__(title=title, **glfw_kwargs)
+
+ # Init fields.
+ self._imgui_context = None
+ self._imgui_renderer = None
+ self._imgui_fonts = None
+ self._cur_font_size = max(font_sizes)
+
+ # Delete leftover imgui.ini to avoid unexpected behavior.
+ if os.path.isfile('imgui.ini'):
+ os.remove('imgui.ini')
+
+ # Init ImGui.
+ self._imgui_context = imgui.create_context()
+ self._imgui_renderer = _GlfwRenderer(self._glfw_window)
+ self._attach_glfw_callbacks()
+ imgui.get_io().ini_saving_rate = 0 # Disable creating imgui.ini at runtime.
+ imgui.get_io().mouse_drag_threshold = 0 # Improve behavior with imgui_utils.drag_custom().
+ self._imgui_fonts = {size: imgui.get_io().fonts.add_font_from_file_ttf(font, size) for size in font_sizes}
+ self._imgui_renderer.refresh_font_texture()
+
+ def close(self):
+ self.make_context_current()
+ self._imgui_fonts = None
+ if self._imgui_renderer is not None:
+ self._imgui_renderer.shutdown()
+ self._imgui_renderer = None
+ if self._imgui_context is not None:
+ #imgui.destroy_context(self._imgui_context) # Commented out to avoid creating imgui.ini at the end.
+ self._imgui_context = None
+ super().close()
+
+ def _glfw_key_callback(self, *args):
+ super()._glfw_key_callback(*args)
+ self._imgui_renderer.keyboard_callback(*args)
+
+ @property
+ def font_size(self):
+ return self._cur_font_size
+
+ @property
+ def spacing(self):
+ return round(self._cur_font_size * 0.4)
+
+ def set_font_size(self, target): # Applied on next frame.
+ self._cur_font_size = min((abs(key - target), key) for key in self._imgui_fonts.keys())[1]
+
+ def begin_frame(self):
+ # Begin glfw frame.
+ super().begin_frame()
+
+ # Process imgui events.
+ self._imgui_renderer.mouse_wheel_multiplier = self._cur_font_size / 10
+ if self.content_width > 0 and self.content_height > 0:
+ self._imgui_renderer.process_inputs()
+
+ # Begin imgui frame.
+ imgui.new_frame()
+ imgui.push_font(self._imgui_fonts[self._cur_font_size])
+ imgui_utils.set_default_style(spacing=self.spacing, indent=self.font_size, scrollbar=self.font_size+4)
+
+ def end_frame(self):
+ imgui.pop_font()
+ imgui.render()
+ imgui.end_frame()
+ self._imgui_renderer.render(imgui.get_draw_data())
+ super().end_frame()
+
+#----------------------------------------------------------------------------
+# Wrapper class for GlfwRenderer to fix a mouse wheel bug on Linux.
+
+class _GlfwRenderer(imgui.integrations.glfw.GlfwRenderer):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.mouse_wheel_multiplier = 1
+
+ def scroll_callback(self, window, x_offset, y_offset):
+ self.io.mouse_wheel += y_offset * self.mouse_wheel_multiplier
+
+#----------------------------------------------------------------------------
diff --git a/stylegan3-fun/gui_utils/text_utils.py b/stylegan3-fun/gui_utils/text_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..35e5e4a16dc62c4be80df5432208bce5d386bf16
--- /dev/null
+++ b/stylegan3-fun/gui_utils/text_utils.py
@@ -0,0 +1,123 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import functools
+from typing import Optional
+
+import dnnlib
+import numpy as np
+import PIL.Image
+import PIL.ImageFont
+import scipy.ndimage
+
+from . import gl_utils
+
+#----------------------------------------------------------------------------
+
+def get_default_font():
+ url = 'http://fonts.gstatic.com/s/opensans/v17/mem8YaGs126MiZpBA-U1UpcaXcl0Aw.ttf' # Open Sans regular
+ return dnnlib.util.open_url(url, return_filename=True)
+
+#----------------------------------------------------------------------------
+
+@functools.lru_cache(maxsize=None)
+def get_pil_font(font=None, size=32):
+ if font is None:
+ font = get_default_font()
+ return PIL.ImageFont.truetype(font=font, size=size)
+
+#----------------------------------------------------------------------------
+
+def get_array(string, *, dropshadow_radius: int=None, **kwargs):
+ if dropshadow_radius is not None:
+ offset_x = int(np.ceil(dropshadow_radius*2/3))
+ offset_y = int(np.ceil(dropshadow_radius*2/3))
+ return _get_array_priv(string, dropshadow_radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs)
+ else:
+ return _get_array_priv(string, **kwargs)
+
+@functools.lru_cache(maxsize=10000)
+def _get_array_priv(
+ string: str, *,
+ size: int = 32,
+ max_width: Optional[int]=None,
+ max_height: Optional[int]=None,
+ min_size=10,
+ shrink_coef=0.8,
+ dropshadow_radius: int=None,
+ offset_x: int=None,
+ offset_y: int=None,
+ **kwargs
+):
+ cur_size = size
+ array = None
+ while True:
+ if dropshadow_radius is not None:
+ # separate implementation for dropshadow text rendering
+ array = _get_array_impl_dropshadow(string, size=cur_size, radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs)
+ else:
+ array = _get_array_impl(string, size=cur_size, **kwargs)
+ height, width, _ = array.shape
+ if (max_width is None or width <= max_width) and (max_height is None or height <= max_height) or (cur_size <= min_size):
+ break
+ cur_size = max(int(cur_size * shrink_coef), min_size)
+ return array
+
+#----------------------------------------------------------------------------
+
+@functools.lru_cache(maxsize=10000)
+def _get_array_impl(string, *, font=None, size=32, outline=0, outline_pad=3, outline_coef=3, outline_exp=2, line_pad: int=None):
+ pil_font = get_pil_font(font=font, size=size)
+ lines = [pil_font.getmask(line, 'L') for line in string.split('\n')]
+ lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines]
+ width = max(line.shape[1] for line in lines)
+ lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines]
+ line_spacing = line_pad if line_pad is not None else size // 2
+ lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:]
+ mask = np.concatenate(lines, axis=0)
+ alpha = mask
+ if outline > 0:
+ mask = np.pad(mask, int(np.ceil(outline * outline_pad)), mode='constant', constant_values=0)
+ alpha = mask.astype(np.float32) / 255
+ alpha = scipy.ndimage.gaussian_filter(alpha, outline)
+ alpha = 1 - np.maximum(1 - alpha * outline_coef, 0) ** outline_exp
+ alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8)
+ alpha = np.maximum(alpha, mask)
+ return np.stack([mask, alpha], axis=-1)
+
+#----------------------------------------------------------------------------
+
+@functools.lru_cache(maxsize=10000)
+def _get_array_impl_dropshadow(string, *, font=None, size=32, radius: int, offset_x: int, offset_y: int, line_pad: int=None, **kwargs):
+ assert (offset_x > 0) and (offset_y > 0)
+ pil_font = get_pil_font(font=font, size=size)
+ lines = [pil_font.getmask(line, 'L') for line in string.split('\n')]
+ lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines]
+ width = max(line.shape[1] for line in lines)
+ lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines]
+ line_spacing = line_pad if line_pad is not None else size // 2
+ lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:]
+ mask = np.concatenate(lines, axis=0)
+ alpha = mask
+
+ mask = np.pad(mask, 2*radius + max(abs(offset_x), abs(offset_y)), mode='constant', constant_values=0)
+ alpha = mask.astype(np.float32) / 255
+ alpha = scipy.ndimage.gaussian_filter(alpha, radius)
+ alpha = 1 - np.maximum(1 - alpha * 1.5, 0) ** 1.4
+ alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8)
+ alpha = np.pad(alpha, [(offset_y, 0), (offset_x, 0)], mode='constant')[:-offset_y, :-offset_x]
+ alpha = np.maximum(alpha, mask)
+ return np.stack([mask, alpha], axis=-1)
+
+#----------------------------------------------------------------------------
+
+@functools.lru_cache(maxsize=10000)
+def get_texture(string, bilinear=True, mipmap=True, **kwargs):
+ return gl_utils.Texture(image=get_array(string, **kwargs), bilinear=bilinear, mipmap=mipmap)
+
+#----------------------------------------------------------------------------
diff --git a/stylegan3-fun/legacy.py b/stylegan3-fun/legacy.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ff07230da3321626a54f7411bbeb59bada634c4
--- /dev/null
+++ b/stylegan3-fun/legacy.py
@@ -0,0 +1,346 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Converting legacy network pickle into the new format."""
+
+import click
+import pickle
+import re
+import copy
+import numpy as np
+import torch
+import dnnlib
+from torch_utils import misc
+
+#----------------------------------------------------------------------------
+
+def load_network_pkl(f, force_fp16=False):
+ data = _LegacyUnpickler(f).load()
+ try:
+ # Legacy TensorFlow pickle => convert.
+ if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
+ tf_G, tf_D, tf_Gs = data
+ G = convert_tf_generator(tf_G)
+ D = convert_tf_discriminator(tf_D)
+ G_ema = convert_tf_generator(tf_Gs)
+ data = dict(G=G, D=D, G_ema=G_ema)
+
+ # Add missing fields.
+ if 'training_set_kwargs' not in data:
+ data['training_set_kwargs'] = None
+ if 'augment_pipe' not in data:
+ data['augment_pipe'] = None
+
+ # Validate contents.
+ if 'G' in data:
+ assert isinstance(data['G'], torch.nn.Module)
+ if 'D' in data:
+ assert isinstance(data['D'], torch.nn.Module)
+ assert isinstance(data['G_ema'], torch.nn.Module)
+ assert isinstance(data['training_set_kwargs'], (dict, type(None)))
+ assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))
+
+ # Force FP16.
+ if force_fp16:
+ for key in ['G', 'D', 'G_ema']:
+ old = data[key]
+ kwargs = copy.deepcopy(old.init_kwargs)
+ fp16_kwargs = kwargs.get('synthesis_kwargs', kwargs)
+ fp16_kwargs.num_fp16_res = 4
+ fp16_kwargs.conv_clamp = 256
+ if kwargs != old.init_kwargs:
+ new = type(old)(**kwargs).eval().requires_grad_(False)
+ misc.copy_params_and_buffers(old, new, require_all=True)
+ data[key] = new
+ except KeyError:
+ # Most likely a StyleGAN-NADA pkl, so pass and return data
+ pass
+ return data
+
+#----------------------------------------------------------------------------
+
+class _TFNetworkStub(dnnlib.EasyDict):
+ pass
+
+class _LegacyUnpickler(pickle.Unpickler):
+ def find_class(self, module, name):
+ if module == 'dnnlib.tflib.network' and name == 'Network':
+ return _TFNetworkStub
+ return super().find_class(module, name)
+
+#----------------------------------------------------------------------------
+
+def _collect_tf_params(tf_net):
+ # pylint: disable=protected-access
+ tf_params = dict()
+ def recurse(prefix, tf_net):
+ for name, value in tf_net.variables:
+ tf_params[prefix + name] = value
+ for name, comp in tf_net.components.items():
+ recurse(prefix + name + '/', comp)
+ recurse('', tf_net)
+ return tf_params
+
+#----------------------------------------------------------------------------
+
+def _populate_module_params(module, *patterns):
+ for name, tensor in misc.named_params_and_buffers(module):
+ found = False
+ value = None
+ for pattern, value_fn in zip(patterns[0::2], patterns[1::2]):
+ match = re.fullmatch(pattern, name)
+ if match:
+ found = True
+ if value_fn is not None:
+ value = value_fn(*match.groups())
+ break
+ try:
+ assert found
+ if value is not None:
+ tensor.copy_(torch.from_numpy(np.array(value)))
+ except:
+ print(name, list(tensor.shape))
+ raise
+
+#----------------------------------------------------------------------------
+
+def convert_tf_generator(tf_G):
+ if tf_G.version < 4:
+ raise ValueError('TensorFlow pickle version too low')
+
+ # Collect kwargs.
+ tf_kwargs = tf_G.static_kwargs
+ known_kwargs = set()
+ def kwarg(tf_name, default=None, none=None):
+ known_kwargs.add(tf_name)
+ val = tf_kwargs.get(tf_name, default)
+ return val if val is not None else none
+
+ # Convert kwargs.
+ from training import networks_stylegan2
+ network_class = networks_stylegan2.Generator
+ kwargs = dnnlib.EasyDict(
+ z_dim = kwarg('latent_size', 512),
+ c_dim = kwarg('label_size', 0),
+ w_dim = kwarg('dlatent_size', 512),
+ img_resolution = kwarg('resolution', 1024),
+ img_channels = kwarg('num_channels', 3),
+ channel_base = kwarg('fmap_base', 16384) * 2,
+ channel_max = kwarg('fmap_max', 512),
+ num_fp16_res = kwarg('num_fp16_res', 0),
+ conv_clamp = kwarg('conv_clamp', None),
+ architecture = kwarg('architecture', 'skip'),
+ resample_filter = kwarg('resample_kernel', [1,3,3,1]),
+ use_noise = kwarg('use_noise', True),
+ activation = kwarg('nonlinearity', 'lrelu'),
+ mapping_kwargs = dnnlib.EasyDict(
+ num_layers = kwarg('mapping_layers', 8),
+ embed_features = kwarg('label_fmaps', None),
+ layer_features = kwarg('mapping_fmaps', None),
+ activation = kwarg('mapping_nonlinearity', 'lrelu'),
+ lr_multiplier = kwarg('mapping_lrmul', 0.01),
+ w_avg_beta = kwarg('w_avg_beta', 0.995, none=1),
+ ),
+ )
+ # Check for unknown kwargs.
+ kwarg('truncation_psi')
+ kwarg('truncation_cutoff')
+ kwarg('style_mixing_prob')
+ kwarg('structure')
+ kwarg('conditioning')
+ kwarg('fused_modconv')
+ kwarg('randomize_noise') # PR #173 by @cobanov, for StyleGAN2 models/TF 1.15
+ kwarg('resolution_h') # For transferring from --network=anime1024 --cfg=stylegan2-ext
+ kwarg('resolution_w') # For transferring from --network=anime1024 --cfg=stylegan2-ext
+ unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
+ if len(unknown_kwargs) > 0:
+ raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
+
+ # Collect params.
+ tf_params = _collect_tf_params(tf_G)
+
+ # Add changes from --cfg=stylegan2-ext; currently can't think of another way to do this
+ if 'resolution_h' and 'resolution_w' in tf_kwargs:
+ kwargs.channel_max = 1024
+ kwargs.w_dim = 1024
+ kwargs.z_dim = 1024
+ kwargs.mapping_kwargs.layer_features = 1024
+ kwargs.mapping_kwargs.num_layers = 4
+ kwargs.channel_base *= 2
+
+ for name, value in list(tf_params.items()):
+ match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
+ if match:
+ r = kwargs.img_resolution // (2 ** int(match.group(1)))
+ tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
+ kwargs.synthesis.kwargs.architecture = 'orig'
+ #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
+
+ # Convert params.
+ G = network_class(**kwargs).eval().requires_grad_(False)
+ # pylint: disable=unnecessary-lambda
+ # pylint: disable=f-string-without-interpolation
+ _populate_module_params(G,
+ r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'],
+ r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(),
+ r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'],
+ r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
+ r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'],
+ r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0],
+ r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
+ r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'],
+ r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0],
+ r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'],
+ r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
+ r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
+ r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
+ r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
+ r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
+ r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
+ r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
+ r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
+ r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
+ r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
+ r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
+ r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
+ r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
+ r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
+ r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
+ r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
+ r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
+ r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
+ r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
+ r'.*\.resample_filter', None,
+ r'.*\.act_filter', None,
+ )
+ return G
+
+#----------------------------------------------------------------------------
+
+def convert_tf_discriminator(tf_D):
+ if tf_D.version < 4:
+ raise ValueError('TensorFlow pickle version too low')
+
+ # Collect kwargs.
+ tf_kwargs = tf_D.static_kwargs
+ known_kwargs = set()
+ def kwarg(tf_name, default=None):
+ known_kwargs.add(tf_name)
+ return tf_kwargs.get(tf_name, default)
+
+ # Convert kwargs.
+ kwargs = dnnlib.EasyDict(
+ c_dim = kwarg('label_size', 0),
+ img_resolution = kwarg('resolution', 1024),
+ img_channels = kwarg('num_channels', 3),
+ architecture = kwarg('architecture', 'resnet'),
+ channel_base = kwarg('fmap_base', 16384) * 2,
+ channel_max = kwarg('fmap_max', 512),
+ num_fp16_res = kwarg('num_fp16_res', 0),
+ conv_clamp = kwarg('conv_clamp', None),
+ cmap_dim = kwarg('mapping_fmaps', None),
+ block_kwargs = dnnlib.EasyDict(
+ activation = kwarg('nonlinearity', 'lrelu'),
+ resample_filter = kwarg('resample_kernel', [1,3,3,1]),
+ freeze_layers = kwarg('freeze_layers', 0),
+ ),
+ mapping_kwargs = dnnlib.EasyDict(
+ num_layers = kwarg('mapping_layers', 0),
+ embed_features = kwarg('mapping_fmaps', None),
+ layer_features = kwarg('mapping_fmaps', None),
+ activation = kwarg('nonlinearity', 'lrelu'),
+ lr_multiplier = kwarg('mapping_lrmul', 0.1),
+ ),
+ epilogue_kwargs = dnnlib.EasyDict(
+ mbstd_group_size = kwarg('mbstd_group_size', None),
+ mbstd_num_channels = kwarg('mbstd_num_features', 1),
+ activation = kwarg('nonlinearity', 'lrelu'),
+ ),
+ )
+
+ # Check for unknown kwargs.
+ kwarg('structure')
+ kwarg('conditioning')
+ kwarg('resolution_h') # For transferring from --network=anime1024 --cfg=stylegan2-ext
+ kwarg('resolution_w') # For transferring from --network=anime1024 --cfg=stylegan2-ext
+ unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
+ if len(unknown_kwargs) > 0:
+ raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
+
+ # Add changes from --cfg=stylegan2-ext. (I really can't think of another way tbh)
+ if 'resolution_h' and 'resolution_w' in tf_kwargs:
+ kwargs.epilogue_kwargs.mbstd_num_channels = 4
+
+ # Collect params.
+ tf_params = _collect_tf_params(tf_D)
+ for name, value in list(tf_params.items()):
+ match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
+ if match:
+ r = kwargs.img_resolution // (2 ** int(match.group(1)))
+ tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
+ kwargs.architecture = 'orig'
+ #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
+
+ # Convert params.
+ from training import networks_stylegan2
+ D = networks_stylegan2.Discriminator(**kwargs).eval().requires_grad_(False)
+ # pylint: disable=unnecessary-lambda
+ # pylint: disable=f-string-without-interpolation
+ _populate_module_params(D,
+ r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
+ r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'],
+ r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1),
+ r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
+ r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
+ r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(),
+ r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'],
+ r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(),
+ r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'],
+ r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1),
+ r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'],
+ r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(),
+ r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'],
+ r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(),
+ r'b4\.out\.bias', lambda: tf_params[f'Output/bias'],
+ r'.*\.resample_filter', None,
+ )
+ return D
+
+#----------------------------------------------------------------------------
+
+@click.command()
+@click.option('--source', help='Input pickle', required=True, metavar='PATH')
+@click.option('--dest', help='Output pickle', required=True, metavar='PATH')
+@click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True)
+def convert_network_pickle(source, dest, force_fp16):
+ """Convert legacy network pickle into the native PyTorch format.
+
+ The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
+ It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.
+
+ Example:
+
+ \b
+ python legacy.py \\
+ --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
+ --dest=stylegan2-cat-config-f.pkl
+ """
+ print(f'Loading "{source}"...')
+ with dnnlib.util.open_url(source) as f:
+ data = load_network_pkl(f, force_fp16=force_fp16)
+ print(f'Saving "{dest}"...')
+ with open(dest, 'wb') as f:
+ pickle.dump(data, f)
+ print('Done.')
+
+#----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+ convert_network_pickle() # pylint: disable=no-value-for-parameter
+
+#----------------------------------------------------------------------------
diff --git a/stylegan3-fun/live_visualreactive.py b/stylegan3-fun/live_visualreactive.py
new file mode 100644
index 0000000000000000000000000000000000000000..09e774d6ea5b187e6c2e80e272571a622cd1b8c0
--- /dev/null
+++ b/stylegan3-fun/live_visualreactive.py
@@ -0,0 +1,581 @@
+import os
+from typing import List, Union, Optional, Tuple, Type
+import click
+import time
+import copy
+
+import dnnlib
+from torch_utils import gen_utils
+
+import numpy as np
+
+import cv2
+import imutils
+import PIL.Image
+
+import scipy
+from sklearn.preprocessing import StandardScaler
+from sklearn.decomposition import PCA
+
+import torch
+from torchvision import transforms
+
+import legacy
+
+from network_features import VGG16FeaturesNVIDIA
+
+import mediapipe as mp
+
+
+# ----------------------------------------------------------------------------
+
+
+def parse_height(s: str = None) -> Union[int, Type[None]]:
+ """Parse height argument."""
+ if s is not None:
+ if s == 'max':
+ return s
+ else:
+ return int(s)
+ return None
+
+
+# ----------------------------------------------------------------------------
+
+
+# TODO: Analyze latent space/variant to the proposed PCA https://openreview.net/pdf?id=SlzEll3EsKv
+# TODO: Add hand tracking/normalization here: https://github.com/caillonantoine/hand_osc/blob/master/detect.py
+
+@click.command()
+@click.pass_context
+@click.option('--network', 'network_pkl', help='Network pickle filename: can be URL, local file, or the name of the model in torch_utils.gen_utils.resume_specs', required=True)
+@click.option('--device', help='Device to use for image generation; using the CPU is slower than the GPU', type=click.Choice(['cpu', 'cuda']), default='cuda', show_default=True)
+@click.option('--cfg', type=click.Choice(['stylegan2', 'stylegan3-t', 'stylegan3-r']), help='Config of the network, used only if you want to use the pretrained models in torch_utils.gen_utils.resume_specs')
+# Synthesis options (feed a list of seeds or give the projected w to synthesize)
+@click.option('--seed', type=click.INT, help='Random seed to use for static synthesized image', default=0, show_default=True)
+@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=0.6, show_default=True)
+@click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)', default=None, show_default=True)
+@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
+@click.option('--new-center', type=gen_utils.parse_new_center, help='New center for the W latent space; a seed (int) or a path to a projected dlatent (.npy/.npz)', default=None)
+@click.option('--mirror', is_flag=True, help='Mirror the synthesized image')
+@click.option('--demo-height', type=int, help='Height of the demo window', default=360, show_default=True)
+@click.option('--demo-width', type=int, help='Width of the demo window', default=None, show_default=True)
+@click.option('--only-synth', is_flag=True, help='Only synthesize the image and save it to disk')
+@click.option('--layer', type=str, help='Layer to use for the feature extractor', default='conv4_1', show_default=True)
+# How to set the fake dlatent
+@click.option('--v0', is_flag=True, help='Average the features of VGG and use a static dlatent to do style-mixing')
+@click.option('--v1', is_flag=True, help='Separate the input image into regions for coarse, middle, and fine layers for style-mixing')
+@click.option('--v2', is_flag=True, help='Manipulate the input to the Generator (StyleGAN2 and 3)')
+@click.option('--v3', is_flag=True, help='Latent mirror. Warning, should be used with low-resolution models (e.g., 16x16)')
+# TODO: intermediate layers?
+# Video options
+@click.option('--display-height', type=parse_height, help="Height of the display window; if 'max', will use G.img_resolution", default=None, show_default=True)
+@click.option('--anchor-latent-space', '-anchor', is_flag=True, help='Anchor the latent space to w_avg to stabilize the video')
+@click.option('--fps', type=click.IntRange(min=1), help='Save the video with this framerate.', default=30, show_default=True)
+@click.option('--compress', is_flag=True, help='Add flag to compress the final mp4 file with `ffmpeg-python` (same resolution, lower file size)')
+# Extra parameters
+@click.option('--outdir', type=click.Path(file_okay=False), help='Directory path to save the results', default=os.path.join(os.getcwd(), 'out', 'videos'), show_default=True, metavar='DIR')
+@click.option('--description', '-desc', type=str, help='Description name for the directory path to save results', default='live_visual-reactive', show_default=True)
+@click.option('--verbose', is_flag=True, help='Print FPS of the live interpolation ever second; plot the detected hands for `--v2`')
+def live_visual_reactive(
+ ctx,
+ network_pkl: str,
+ device: Optional[str],
+ cfg: str,
+ seed: int,
+ truncation_psi: float,
+ class_idx: int,
+ noise_mode: str,
+ new_center: Union[int, str],
+ mirror: bool,
+ demo_height: int,
+ demo_width: int,
+ only_synth: bool,
+ layer: str,
+ v0: bool,
+ v1: bool,
+ v2: bool,
+ v3: bool,
+ display_height: Optional[int],
+ anchor_latent_space: bool,
+ fps: int,
+ compress: bool,
+ outdir: str,
+ description: str,
+ verbose: Optional[bool]):
+ """Live Visual-Reactive interpolation. A camera/webcamera is needed to be accessed by OpenCV."""
+ # Set device; GPU is recommended
+ device = torch.device('cuda') if torch.cuda.is_available() and device == 'cuda' else torch.device('cpu')
+
+ if v0 or v1:
+ # Load the feature extractor; here, VGG16
+ print('Loading VGG16 and its features...')
+ url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
+ with dnnlib.util.open_url(url) as f:
+ vgg16 = torch.jit.load(f).eval().to(device)
+
+ vgg16_features = VGG16FeaturesNVIDIA(vgg16).requires_grad_(False).to(device)
+ del vgg16
+
+ # If model name exists in the gen_utils.resume_specs dictionary, use it instead of the full url
+ try:
+ network_pkl = gen_utils.resume_specs[cfg][network_pkl]
+ except KeyError:
+ # Otherwise, it's a local file or an url
+ pass
+
+ print('Loading Generator...')
+ with dnnlib.util.open_url(network_pkl) as f:
+ G = legacy.load_network_pkl(f)['G_ema'].eval().requires_grad_(False).to(device) # type: ignore
+
+ # Stabilize/anchor the latent space
+ if anchor_latent_space:
+ gen_utils.anchor_latent_space(G)
+
+ # Warm up the Generator
+ ws = G.mapping(z=torch.randn(1, 512, device=device), c=None, truncation_psi=1.0)
+ _ = G.synthesis(ws[:1])
+
+ # Label, in case it's a class-conditional model
+ class_idx = gen_utils.parse_class(G, class_idx, ctx)
+ label = torch.zeros([1, G.c_dim], device=device)
+ if G.c_dim != 0:
+ label[:, class_idx] = 1
+ else:
+ if class_idx is not None:
+ print('warn: --class=lbl ignored when running on an unconditional network')
+
+ # Recenter the latent space, if specified
+ if new_center is None:
+ w_avg = G.mapping.w_avg
+ else:
+ new_center, new_center_value = new_center
+ # We get the new center using the int (a seed) or recovered dlatent (an np.ndarray)
+ if isinstance(new_center_value, int):
+ w_avg = gen_utils.get_w_from_seed(G, device, new_center_value,
+ truncation_psi=1.0) # We want the pure dlatent
+ elif isinstance(new_center_value, np.ndarray):
+ w_avg = torch.from_numpy(new_center_value).to(device)
+ else:
+ ctx.fail('Error: New center has strange format! Only an int (seed) or a file (.npy/.npz) are accepted!')
+
+ # Set up the video capture dimensions
+ height = demo_height
+ width = int(4.0/3*demo_height) if demo_width is None else demo_width
+ sheight = int(height)
+ swidth = sheight
+
+ # Set display size. If none specified or user selects 'max', will use image resolution of the network
+ display_height = G.img_resolution if (display_height is None or display_height == 'max') else display_height
+
+ # Fix a dlatent to do style mixing
+ static_w = G.mapping(torch.randn(1, G.z_dim, device=device, generator=torch.Generator(device=device).manual_seed(seed)), c=None)
+
+ cam = cv2.VideoCapture(0)
+ idx = 0
+
+ start_time = time.time()
+ x = 1 # displays the frame rate every 1 second if verbose is True
+ counter = 0
+ starting = True # Initialize some default values only one time
+ recording_flag = False
+
+ # Preprocess each image for VGG16
+ preprocess = transforms.Compose([transforms.Normalize(mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])])
+
+ if v0 or v1:
+ while cam.isOpened():
+ # read frame
+ idx += 1
+ ret_val, img = cam.read()
+ img = imutils.resize(img, height=height)
+ if mirror:
+ img = cv2.flip(img, 1)
+ img = np.array(img).transpose(2, 0, 1)
+ img = torch.from_numpy(img).unsqueeze(0).float().to(device)
+
+ frame = preprocess(img).to(device)
+ fake_z = vgg16_features.get_layers_features(frame, layers=[layer])[0]
+
+ # v0
+ if v0:
+ fake_z = fake_z.view(1, 512, -1).mean(2)
+
+ # Perform EMA with previous fake_z
+ if counter == 0:
+ prev_fake_z = fake_z
+ # Do EMA
+ fake_z = 0.2 * prev_fake_z + 0.8 * fake_z
+ prev_fake_z = fake_z
+
+ fake_w = gen_utils.z_to_dlatent(G, fake_z, label, truncation_psi)
+
+ # Do style mixing
+ fake_w[:, 4:] = static_w[:, 4:]
+
+ # v1
+ elif v1:
+ _n, _c, h, w = fake_z.shape
+
+ # Separate into coarse/middle/fine according to areas
+ coarse_fake_z = fake_z[:, :, :h//2, :]
+ middle_fake_z = fake_z[:, :, h//2:, :w//2]
+ fine_fake_z = fake_z[:, :, h//2:, w//2:]
+
+ # Convert them to the expected shape (each region will be their own latent)
+ coarse_fake_z = coarse_fake_z.reshape(1, G.z_dim, -1).mean(2)
+ middle_fake_z = middle_fake_z.reshape(1, G.z_dim, -1).mean(2)
+ fine_fake_z = fine_fake_z.reshape(1, G.z_dim, -1).mean(2)
+
+ # Get the respective dlatents
+ coarse_fake_w = gen_utils.z_to_dlatent(G, coarse_fake_z, label, 1.0)
+ middle_fake_w = gen_utils.z_to_dlatent(G, middle_fake_z, label, 1.0)
+ fine_fake_w = gen_utils.z_to_dlatent(G, fine_fake_z, label, 1.0)
+ fake_w = torch.cat([coarse_fake_w[:, :4], middle_fake_w[:, 4:8], fine_fake_w[:, 8:]], dim=1) # [1, G.num_ws, G.z_dim]
+
+ # Perform EMA with previous fake_w
+ if counter == 0 and starting:
+ prev_fake_w = fake_w
+ starting = False
+ # Do EMA
+ fake_w = 0.4 * prev_fake_w + 0.6 * fake_w
+ prev_fake_w = fake_w
+
+ # Set images to expected data type
+ img = img.clamp(0, 255).data[0].cpu().numpy()
+ img = img.transpose(1, 2, 0).astype('uint8')
+
+ simg = gen_utils.w_to_img(G, fake_w, noise_mode, w_avg, truncation_psi)[0]
+ simg = cv2.cvtColor(simg, cv2.COLOR_BGR2RGB)
+
+ # display
+ if not only_synth:
+ display_width = int(4/3*display_height)
+ # Resize input image from the camera
+ img = cv2.resize(img, (display_width, display_height))
+ # Resize accordingly the synthesized image
+ simg = cv2.resize(simg, (display_height, display_height), interpolation=cv2.INTER_CUBIC)
+ img = np.concatenate((img, simg), axis=1)
+ cv2.imshow('Visuorreactive Demo', img)
+ else:
+ # Resize the synthesized image to the desired display height/width
+ simg = cv2.resize(simg, (display_height, display_height))
+ cv2.imshow('Visuorreactive Demo - Only Synth Image', simg)
+
+ counter += 1
+
+ # FPS counter
+ if (time.time() - start_time) > x and verbose:
+ print(f"FPS: {counter / (time.time() - start_time):0.2f}")
+ counter = 0
+ start_time = time.time()
+
+ # ESC to quit; SPACE to start recording
+ key = cv2.waitKey(1)
+
+ if key == 27:
+ break
+ elif key == 32:
+ # Transition from not recording to recording
+ if not recording_flag:
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+ w = width + swidth if not only_synth else G.img_resolution
+ out = cv2.VideoWriter('output.mp4', fourcc, fps, (w, height))
+ recording_flag = True
+ else:
+ recording_flag = False
+ out.release()
+
+ if recording_flag:
+ out.write(img)
+
+ cam.release()
+ cv2.destroyAllWindows()
+
+ elif v2:
+ # TODO: Clean this, this is a mess
+ mp_drawing = mp.solutions.drawing_utils
+ mp_drawing_styles = mp.solutions.drawing_styles
+ mp_hands = mp.solutions.hands
+
+ # Generate a loop of images
+ num_frames = 900
+ shape = [num_frames, 1, G.z_dim]
+ # Generate a loop of images
+ all_latents = np.random.RandomState(seed).randn(*shape).astype(np.float32)
+ all_latents = scipy.ndimage.gaussian_filter(all_latents, sigma=[3.0 * 30, 0, 0], mode='wrap')
+ all_latents /= np.sqrt(np.mean(np.square(all_latents)))
+ all_latents = torch.from_numpy(all_latents).to(device)
+
+ c = 0
+
+ if hasattr(G.synthesis, 'b4'):
+ model_type = 'stylegan2'
+ const_input = copy.deepcopy(G.synthesis.b4.const).cpu().numpy()
+ const_input_interpolation = np.random.randn(num_frames, *const_input.shape).astype(np.float32) # [num_frames, G.w_dim, 4, 4]
+ const_input_interpolation = scipy.ndimage.gaussian_filter(const_input_interpolation, sigma=[fps, 0, 0, 0], mode='wrap')
+ const_input_interpolation /= np.sqrt(np.mean(np.square(const_input_interpolation))) / 2
+ elif hasattr(G.synthesis, 'input'):
+ model_type = 'stylegan3'
+
+ with mp_hands.Hands(
+ model_complexity=0,
+ min_detection_confidence=0.5,
+ min_tracking_confidence=0.5) as hands:
+ counter = 0
+
+ while cam.isOpened():
+ success, image = cam.read()
+ if not success:
+ print("Ignoring empty camera frame.")
+ continue
+
+ # To improve performance, optionally mark the image as not writeable to
+ # pass by reference.
+ image.flags.writeable = False
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+ results = hands.process(image)
+
+ # Get the hand rotation w.r.t. the landmarks (concretely, the wrist [0] and middle finger [9]
+ # will define the Y axis; the x-axis will be 90 degrees counter-clockwise from it)
+ # Set the vertical direction as 0.0
+ if results.multi_hand_landmarks:
+ # Just the first hand
+ hand = results.multi_hand_landmarks[-1]
+
+ base = hand.landmark[0]
+ middle = hand.landmark[9]
+ dx = middle.x - base.x
+ dy = middle.y - base.y
+ angle = - np.pi / 2 - np.arctan2(dy, dx) # Set s.t. line(palm, middle finger) is the vertical axis
+
+ # Let's get the position in x and y of the center of the whole hand
+ # Calculate the center of the hand (average of all landmarks)
+ # Set the center of the image as the origin
+ x = 0.0
+ y = 0.0
+ z = 0.0
+ area_points = []
+ for idx, landmark in enumerate(hand.landmark):
+ x += landmark.x
+ y += landmark.y
+ z += landmark.z
+ if idx in range(0, 21, 4):
+ area_points.append([landmark.x, landmark.y])
+ x /= len(hand.landmark)
+ y /= len(hand.landmark)
+ z /= len(hand.landmark)
+
+ x -= 0.5
+ y -= 0.5
+
+ # Calculate the distance to the origin
+ dist = np.sqrt(x ** 2 + y ** 2)
+ # Normalize it
+ dist *= 4*2 ** 0.5 # Max distance is 1/sqrt(2)
+
+ # Get the area of the hand enclosed between the 5 fingers and the wrist
+ # We will use the trapezoidal rule to approximate the area
+ hand_area = 0.0
+ for i in range(len(area_points) - 1):
+ hand_area += (area_points[i][0] - area_points[i + 1][0]) * (area_points[i][1] + area_points[i + 1][1])
+ hand_area += (area_points[-1][0] - area_points[0][0]) * (area_points[-1][1] + area_points[0][1])
+ hand_area = abs(hand_area) / 2
+
+
+ else:
+ # EMAs toward zero
+ angle = 0.0 if starting else prev_angle * 0.9
+ x = 0.0 if starting else prev_x * 0.9
+ y = 0.0 if starting else prev_y * 0.9
+ z = 0.0 if starting else prev_z * 0.9
+ dist = 0.0 if starting else prev_dist * 0.9
+ hand_area = 0.0 if starting else prev_hand_area * 0.9
+
+ if counter == 0 and starting:
+ prev_angle = angle
+ prev_x = x
+ prev_y = y
+ prev_z = z
+ prev_dist = dist
+ prev_hand_area = hand_area
+ starting = False
+
+ # ema these values
+ angle = 0.2 * prev_angle + 0.8 * angle
+ prev_angle = angle
+
+ x = 0.2 * prev_x + 0.8 * x
+ prev_x = x
+ y = 0.2 * prev_y + 0.8 * y
+ prev_y = y
+ z = 0.2 * prev_z + 0.8 * z
+ prev_z = z
+
+ dist = 0.2 * prev_dist + 0.8 * dist
+ prev_dist = dist
+
+ hand_area = 0.2 * prev_hand_area + 0.8 * hand_area
+ prev_hand_area = hand_area
+
+ # FPS and angle
+ if (time.time() - start_time) > x and verbose:
+ print(f'[{c % num_frames} / {num_frames}] FPS: {counter / (time.time() - start_time):0.2f}, '
+ f'Angle (rad): {angle:.3f}, Hand Center: ({x:.3f}, {y:.3f}, {z:.3f}), Distance: {dist:.3f}, Area: {hand_area:.3f}')
+ counter = 0
+ start_time = time.time()
+
+ if hasattr(G.synthesis, 'input'):
+ # Rotate and translate the image
+ m = gen_utils.make_affine_transform(None, angle=angle, translate_x=x, translate_y=-y,
+ scale_x=1+2*z, scale_y=1+2*z)
+ m = np.linalg.inv(m)
+ # Finally, we pass the matrix to the generator
+ G.synthesis.input.transform.copy_(torch.from_numpy(m))
+
+ elif hasattr(G.synthesis, 'b4'):
+ G.synthesis.b4.const.copy_(torch.from_numpy((1 - dist) * const_input + const_input_interpolation[c % num_frames] * dist))
+
+ # Draw the hand annotations on the image.
+ image.flags.writeable = True
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+ # Replace image with a white background
+ image[:] = (255, 255, 255)
+ if results.multi_hand_landmarks and verbose:
+ for hand_landmarks in results.multi_hand_landmarks:
+ mp_drawing.draw_landmarks(
+ image,
+ hand_landmarks,
+ mp_hands.HAND_CONNECTIONS,
+ mp_drawing_styles.get_default_hand_landmarks_style(),
+ mp_drawing_styles.get_default_hand_connections_style())
+
+ # Get the latent vectors
+ latent = all_latents[c % num_frames]
+ c += 1
+
+ # Synthesize the image
+ simg = gen_utils.z_to_img(G=G, latents=latent, label=label,
+ truncation_psi=truncation_psi, noise_mode=noise_mode)[0]
+ simg = cv2.cvtColor(simg, cv2.COLOR_BGR2RGB)
+
+ # display
+ if not only_synth:
+ # Let's vertically concatenate the input image and the synthesized image
+ # Resize input image from the camera
+ img_width = display_height
+ img_height = int(img_width * height / width)
+ img = cv2.resize(image, (img_width, img_height))
+ w, h = img_width, img_height
+
+ # Resize accordingly the synthesized image
+ simg = cv2.resize(simg, (display_height, display_height), interpolation=cv2.INTER_CUBIC)
+ h += display_height
+
+ # Concatenate and show the images
+ img = np.concatenate((simg, img), axis=0)
+ cv2.imshow('Visuorreactive Demo', img)
+
+ # display_width = int(4/3*display_height)
+ # Resize input image from the camera
+ # img = cv2.resize(image, (display_width, display_height))
+ # w, h = display_width, display_height
+ # Resize accordingly the synthesized image
+ # simg = cv2.resize(simg, (display_height, display_height), interpolation=cv2.INTER_CUBIC)
+ # w += h
+ # img = np.concatenate((img, simg), axis=1)
+ # cv2.imshow('Visuorreactive Demo', img)
+ else:
+ # Resize the synthesized image to the desired display height/width
+ simg = cv2.resize(simg, (display_height, display_height))
+ w, h = display_height, display_height
+ cv2.imshow('Visuorreactive Demo - Only Synth Image', simg)
+
+ counter += 1
+
+ key = cv2.waitKey(1)
+ # User presses 'ESC' to exit
+ if key == 27:
+ break
+ elif key == 32:
+ # Transition from not recording to recording
+ if not recording_flag:
+ print('Recording started')
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+ out = cv2.VideoWriter(f'visualreactive_input_v2_{model_type}.mp4', fourcc, fps, (w, h))
+ recording_flag = True
+ else:
+ print('Recording stopped')
+ recording_flag = False
+ out.release()
+
+ if recording_flag:
+ out.write(img)
+ cam.release()
+
+ elif v3:
+ # Set number of rows and columns for the generated "mirror"
+ num_frames = 900
+ nrows = 16
+ ncols = 20
+ shape = [num_frames, nrows * ncols, G.z_dim]
+ # Generate a loop of images
+ all_latents = np.random.RandomState(seed).randn(*shape).astype(np.float32)
+ all_latents = scipy.ndimage.gaussian_filter(all_latents, sigma=[3.0 * 30, 0, 0], mode='wrap')
+ all_latents /= np.sqrt(np.mean(np.square(all_latents)))
+ all_latents = torch.from_numpy(all_latents).to(device)
+
+ c = 0
+ while cam.isOpened():
+ ret, img = cam.read()
+ gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ blur = cv2.GaussianBlur(gray, (5, 5), 0)
+ canny = cv2.Canny(blur, 10, 70)
+ ret, mask = cv2.threshold(canny, 70, 255, cv2.THRESH_BINARY)
+
+ # Truncation trick
+ # Reshape the mask to the same size as the latent vector
+ mask = cv2.resize(mask, (nrows, ncols), interpolation=cv2.INTER_AREA)
+ mask = mask.astype(np.float32) / float(mask.max())
+ trunc = torch.from_numpy(mask).view(-1, 1, 1).to(device)
+ trunc = 1.0 - trunc
+
+ # Get the latent vectors
+ z = all_latents[c % num_frames]
+ w = G.mapping(z, None)
+ w = w * trunc + G.mapping.w_avg * (1 - trunc)
+
+ c += 1
+
+ simg = gen_utils.w_to_img(G, w, truncation_psi=truncation_psi)
+ simg = gen_utils.create_image_grid(simg, (ncols, nrows))
+ simg = cv2.cvtColor(simg, cv2.COLOR_BGR2RGB)
+
+ # Resize the synthesized image to the desired display height/width
+ simg = cv2.resize(simg, (int(display_height * ncols / nrows), display_height))
+
+ cv2.imshow('Video feed', simg)
+
+ counter += 1
+ # FPS counter
+ if (time.time() - start_time) > x and verbose:
+ print(f"FPS: {counter / (time.time() - start_time):0.2f}")
+ counter = 0
+ start_time = time.time()
+
+ key = cv2.waitKey(1)
+ # User presses 'ESC' to exit
+ if key == 27:
+ break
+ cam.release()
+
+
+# ----------------------------------------------------------------------------
+
+
+if __name__ == '__main__':
+ live_visual_reactive()
+
+
+# ----------------------------------------------------------------------------
diff --git a/stylegan3-fun/metrics/__init__.py b/stylegan3-fun/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..939e7c6c8f94c4ea1141885c3c3295fe083b06aa
--- /dev/null
+++ b/stylegan3-fun/metrics/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+# empty
diff --git a/stylegan3-fun/metrics/equivariance.py b/stylegan3-fun/metrics/equivariance.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5559aca37b45e246a2ed279193cf4d59011f0b9
--- /dev/null
+++ b/stylegan3-fun/metrics/equivariance.py
@@ -0,0 +1,267 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Equivariance metrics (EQ-T, EQ-T_frac, and EQ-R) from the paper
+"Alias-Free Generative Adversarial Networks"."""
+
+import copy
+import numpy as np
+import torch
+import torch.fft
+from torch_utils.ops import upfirdn2d
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+# Utilities.
+
+def sinc(x):
+ y = (x * np.pi).abs()
+ z = torch.sin(y) / y.clamp(1e-30, float('inf'))
+ return torch.where(y < 1e-30, torch.ones_like(x), z)
+
+def lanczos_window(x, a):
+ x = x.abs() / a
+ return torch.where(x < 1, sinc(x), torch.zeros_like(x))
+
+def rotation_matrix(angle):
+ angle = torch.as_tensor(angle).to(torch.float32)
+ mat = torch.eye(3, device=angle.device)
+ mat[0, 0] = angle.cos()
+ mat[0, 1] = angle.sin()
+ mat[1, 0] = -angle.sin()
+ mat[1, 1] = angle.cos()
+ return mat
+
+#----------------------------------------------------------------------------
+# Apply integer translation to a batch of 2D images. Corresponds to the
+# operator T_x in Appendix E.1.
+
+def apply_integer_translation(x, tx, ty):
+ _N, _C, H, W = x.shape
+ tx = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device)
+ ty = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device)
+ ix = tx.round().to(torch.int64)
+ iy = ty.round().to(torch.int64)
+
+ z = torch.zeros_like(x)
+ m = torch.zeros_like(x)
+ if abs(ix) < W and abs(iy) < H:
+ y = x[:, :, max(-iy,0) : H+min(-iy,0), max(-ix,0) : W+min(-ix,0)]
+ z[:, :, max(iy,0) : H+min(iy,0), max(ix,0) : W+min(ix,0)] = y
+ m[:, :, max(iy,0) : H+min(iy,0), max(ix,0) : W+min(ix,0)] = 1
+ return z, m
+
+#----------------------------------------------------------------------------
+# Apply integer translation to a batch of 2D images. Corresponds to the
+# operator T_x in Appendix E.2.
+
+def apply_fractional_translation(x, tx, ty, a=3):
+ _N, _C, H, W = x.shape
+ tx = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device)
+ ty = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device)
+ ix = tx.floor().to(torch.int64)
+ iy = ty.floor().to(torch.int64)
+ fx = tx - ix
+ fy = ty - iy
+ b = a - 1
+
+ z = torch.zeros_like(x)
+ zx0 = max(ix - b, 0)
+ zy0 = max(iy - b, 0)
+ zx1 = min(ix + a, 0) + W
+ zy1 = min(iy + a, 0) + H
+ if zx0 < zx1 and zy0 < zy1:
+ taps = torch.arange(a * 2, device=x.device) - b
+ filter_x = (sinc(taps - fx) * sinc((taps - fx) / a)).unsqueeze(0)
+ filter_y = (sinc(taps - fy) * sinc((taps - fy) / a)).unsqueeze(1)
+ y = x
+ y = upfirdn2d.filter2d(y, filter_x / filter_x.sum(), padding=[b,a,0,0])
+ y = upfirdn2d.filter2d(y, filter_y / filter_y.sum(), padding=[0,0,b,a])
+ y = y[:, :, max(b-iy,0) : H+b+a+min(-iy-a,0), max(b-ix,0) : W+b+a+min(-ix-a,0)]
+ z[:, :, zy0:zy1, zx0:zx1] = y
+
+ m = torch.zeros_like(x)
+ mx0 = max(ix + a, 0)
+ my0 = max(iy + a, 0)
+ mx1 = min(ix - b, 0) + W
+ my1 = min(iy - b, 0) + H
+ if mx0 < mx1 and my0 < my1:
+ m[:, :, my0:my1, mx0:mx1] = 1
+ return z, m
+
+#----------------------------------------------------------------------------
+# Construct an oriented low-pass filter that applies the appropriate
+# bandlimit with respect to the input and output of the given affine 2D
+# image transformation.
+
+def construct_affine_bandlimit_filter(mat, a=3, amax=16, aflt=64, up=4, cutoff_in=1, cutoff_out=1):
+ assert a <= amax < aflt
+ mat = torch.as_tensor(mat).to(torch.float32)
+
+ # Construct 2D filter taps in input & output coordinate spaces.
+ taps = ((torch.arange(aflt * up * 2 - 1, device=mat.device) + 1) / up - aflt).roll(1 - aflt * up)
+ yi, xi = torch.meshgrid(taps, taps)
+ xo, yo = (torch.stack([xi, yi], dim=2) @ mat[:2, :2].t()).unbind(2)
+
+ # Convolution of two oriented 2D sinc filters.
+ fi = sinc(xi * cutoff_in) * sinc(yi * cutoff_in)
+ fo = sinc(xo * cutoff_out) * sinc(yo * cutoff_out)
+ f = torch.fft.ifftn(torch.fft.fftn(fi) * torch.fft.fftn(fo)).real
+
+ # Convolution of two oriented 2D Lanczos windows.
+ wi = lanczos_window(xi, a) * lanczos_window(yi, a)
+ wo = lanczos_window(xo, a) * lanczos_window(yo, a)
+ w = torch.fft.ifftn(torch.fft.fftn(wi) * torch.fft.fftn(wo)).real
+
+ # Construct windowed FIR filter.
+ f = f * w
+
+ # Finalize.
+ c = (aflt - amax) * up
+ f = f.roll([aflt * up - 1] * 2, dims=[0,1])[c:-c, c:-c]
+ f = torch.nn.functional.pad(f, [0, 1, 0, 1]).reshape(amax * 2, up, amax * 2, up)
+ f = f / f.sum([0,2], keepdim=True) / (up ** 2)
+ f = f.reshape(amax * 2 * up, amax * 2 * up)[:-1, :-1]
+ return f
+
+#----------------------------------------------------------------------------
+# Apply the given affine transformation to a batch of 2D images.
+
+def apply_affine_transformation(x, mat, up=4, **filter_kwargs):
+ _N, _C, H, W = x.shape
+ mat = torch.as_tensor(mat).to(dtype=torch.float32, device=x.device)
+
+ # Construct filter.
+ f = construct_affine_bandlimit_filter(mat, up=up, **filter_kwargs)
+ assert f.ndim == 2 and f.shape[0] == f.shape[1] and f.shape[0] % 2 == 1
+ p = f.shape[0] // 2
+
+ # Construct sampling grid.
+ theta = mat.inverse()
+ theta[:2, 2] *= 2
+ theta[0, 2] += 1 / up / W
+ theta[1, 2] += 1 / up / H
+ theta[0, :] *= W / (W + p / up * 2)
+ theta[1, :] *= H / (H + p / up * 2)
+ theta = theta[:2, :3].unsqueeze(0).repeat([x.shape[0], 1, 1])
+ g = torch.nn.functional.affine_grid(theta, x.shape, align_corners=False)
+
+ # Resample image.
+ y = upfirdn2d.upsample2d(x=x, f=f, up=up, padding=p)
+ z = torch.nn.functional.grid_sample(y, g, mode='bilinear', padding_mode='zeros', align_corners=False)
+
+ # Form mask.
+ m = torch.zeros_like(y)
+ c = p * 2 + 1
+ m[:, :, c:-c, c:-c] = 1
+ m = torch.nn.functional.grid_sample(m, g, mode='nearest', padding_mode='zeros', align_corners=False)
+ return z, m
+
+#----------------------------------------------------------------------------
+# Apply fractional rotation to a batch of 2D images. Corresponds to the
+# operator R_\alpha in Appendix E.3.
+
+def apply_fractional_rotation(x, angle, a=3, **filter_kwargs):
+ angle = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device)
+ mat = rotation_matrix(angle)
+ return apply_affine_transformation(x, mat, a=a, amax=a*2, **filter_kwargs)
+
+#----------------------------------------------------------------------------
+# Modify the frequency content of a batch of 2D images as if they had undergo
+# fractional rotation -- but without actually rotating them. Corresponds to
+# the operator R^*_\alpha in Appendix E.3.
+
+def apply_fractional_pseudo_rotation(x, angle, a=3, **filter_kwargs):
+ angle = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device)
+ mat = rotation_matrix(-angle)
+ f = construct_affine_bandlimit_filter(mat, a=a, amax=a*2, up=1, **filter_kwargs)
+ y = upfirdn2d.filter2d(x=x, f=f)
+ m = torch.zeros_like(y)
+ c = f.shape[0] // 2
+ m[:, :, c:-c, c:-c] = 1
+ return y, m
+
+#----------------------------------------------------------------------------
+# Compute the selected equivariance metrics for the given generator.
+
+def compute_equivariance_metrics(opts, num_samples, batch_size, translate_max=0.125, rotate_max=1, compute_eqt_int=False, compute_eqt_frac=False, compute_eqr=False):
+ assert compute_eqt_int or compute_eqt_frac or compute_eqr
+
+ # Setup generator and labels.
+ G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
+ I = torch.eye(3, device=opts.device)
+ M = getattr(getattr(getattr(G, 'synthesis', None), 'input', None), 'transform', None)
+ if M is None:
+ raise ValueError('Cannot compute equivariance metrics; the given generator does not support user-specified image transformations')
+ c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size)
+
+ # Sampling loop.
+ sums = None
+ progress = opts.progress.sub(tag='eq sampling', num_items=num_samples)
+ for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
+ progress.update(batch_start)
+ s = []
+
+ # Randomize noise buffers, if any.
+ for name, buf in G.named_buffers():
+ if name.endswith('.noise_const'):
+ buf.copy_(torch.randn_like(buf))
+
+ # Run mapping network.
+ z = torch.randn([batch_size, G.z_dim], device=opts.device)
+ c = next(c_iter)
+ ws = G.mapping(z=z, c=c)
+
+ # Generate reference image.
+ M[:] = I
+ orig = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
+
+ # Integer translation (EQ-T).
+ if compute_eqt_int:
+ t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max
+ t = (t * G.img_resolution).round() / G.img_resolution
+ M[:] = I
+ M[:2, 2] = -t
+ img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
+ ref, mask = apply_integer_translation(orig, t[0], t[1])
+ s += [(ref - img).square() * mask, mask]
+
+ # Fractional translation (EQ-T_frac).
+ if compute_eqt_frac:
+ t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max
+ M[:] = I
+ M[:2, 2] = -t
+ img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
+ ref, mask = apply_fractional_translation(orig, t[0], t[1])
+ s += [(ref - img).square() * mask, mask]
+
+ # Rotation (EQ-R).
+ if compute_eqr:
+ angle = (torch.rand([], device=opts.device) * 2 - 1) * (rotate_max * np.pi)
+ M[:] = rotation_matrix(-angle)
+ img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
+ ref, ref_mask = apply_fractional_rotation(orig, angle)
+ pseudo, pseudo_mask = apply_fractional_pseudo_rotation(img, angle)
+ mask = ref_mask * pseudo_mask
+ s += [(ref - pseudo).square() * mask, mask]
+
+ # Accumulate results.
+ s = torch.stack([x.to(torch.float64).sum() for x in s])
+ sums = sums + s if sums is not None else s
+ progress.update(num_samples)
+
+ # Compute PSNRs.
+ if opts.num_gpus > 1:
+ torch.distributed.all_reduce(sums)
+ sums = sums.cpu()
+ mses = sums[0::2] / sums[1::2]
+ psnrs = np.log10(2) * 20 - mses.log10() * 10
+ psnrs = tuple(psnrs.numpy())
+ return psnrs[0] if len(psnrs) == 1 else psnrs
+
+#----------------------------------------------------------------------------
diff --git a/stylegan3-fun/metrics/frechet_inception_distance.py b/stylegan3-fun/metrics/frechet_inception_distance.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d5231cd70cf178ae3664a5974240b102ce3b0cf
--- /dev/null
+++ b/stylegan3-fun/metrics/frechet_inception_distance.py
@@ -0,0 +1,41 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Frechet Inception Distance (FID) from the paper
+"GANs trained by a two time-scale update rule converge to a local Nash
+equilibrium". Matches the original implementation by Heusel et al. at
+https://github.com/bioinf-jku/TTUR/blob/master/fid.py"""
+
+import numpy as np
+import scipy.linalg
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+def compute_fid(opts, max_real, num_gen):
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+ detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
+ detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
+
+ mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov()
+
+ mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov()
+
+ if opts.rank != 0:
+ return float('nan')
+
+ m = np.square(mu_gen - mu_real).sum()
+ s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
+ fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
+ return float(fid)
+
+#----------------------------------------------------------------------------
diff --git a/stylegan3-fun/metrics/inception_score.py b/stylegan3-fun/metrics/inception_score.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0a3a442b4cd7f993773ee77c5394796c28c2ef8
--- /dev/null
+++ b/stylegan3-fun/metrics/inception_score.py
@@ -0,0 +1,38 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Inception Score (IS) from the paper "Improved techniques for training
+GANs". Matches the original implementation by Salimans et al. at
+https://github.com/openai/improved-gan/blob/master/inception_score/model.py"""
+
+import numpy as np
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+def compute_is(opts, num_gen, num_splits):
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+ detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
+ detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer.
+
+ gen_probs = metric_utils.compute_feature_stats_for_generator(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ capture_all=True, max_items=num_gen).get_all()
+
+ if opts.rank != 0:
+ return float('nan'), float('nan')
+
+ scores = []
+ for i in range(num_splits):
+ part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits]
+ kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
+ kl = np.mean(np.sum(kl, axis=1))
+ scores.append(np.exp(kl))
+ return float(np.mean(scores)), float(np.std(scores))
+
+#----------------------------------------------------------------------------
diff --git a/stylegan3-fun/metrics/kernel_inception_distance.py b/stylegan3-fun/metrics/kernel_inception_distance.py
new file mode 100644
index 0000000000000000000000000000000000000000..d69325c1ef4e2894817ef6003e9335c4de657199
--- /dev/null
+++ b/stylegan3-fun/metrics/kernel_inception_distance.py
@@ -0,0 +1,46 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Kernel Inception Distance (KID) from the paper "Demystifying MMD
+GANs". Matches the original implementation by Binkowski et al. at
+https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py"""
+
+import numpy as np
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size):
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+ detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
+ detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
+
+ real_features = metric_utils.compute_feature_stats_for_dataset(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all()
+
+ gen_features = metric_utils.compute_feature_stats_for_generator(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all()
+
+ if opts.rank != 0:
+ return float('nan')
+
+ n = real_features.shape[1]
+ m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size)
+ t = 0
+ for _subset_idx in range(num_subsets):
+ x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)]
+ y = real_features[np.random.choice(real_features.shape[0], m, replace=False)]
+ a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
+ b = (x @ y.T / n + 1) ** 3
+ t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
+ kid = t / num_subsets / m
+ return float(kid)
+
+#----------------------------------------------------------------------------
diff --git a/stylegan3-fun/metrics/metric_main.py b/stylegan3-fun/metrics/metric_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..1179712c5105d9c905b772cc9f1c989812a783ce
--- /dev/null
+++ b/stylegan3-fun/metrics/metric_main.py
@@ -0,0 +1,153 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Main API for computing and reporting quality metrics."""
+
+import os
+import time
+import json
+import torch
+import dnnlib
+
+from . import metric_utils
+from . import frechet_inception_distance
+from . import kernel_inception_distance
+from . import precision_recall
+from . import perceptual_path_length
+from . import inception_score
+from . import equivariance
+
+#----------------------------------------------------------------------------
+
+_metric_dict = dict() # name => fn
+
+def register_metric(fn):
+ assert callable(fn)
+ _metric_dict[fn.__name__] = fn
+ return fn
+
+def is_valid_metric(metric):
+ return metric in _metric_dict
+
+def list_valid_metrics():
+ return list(_metric_dict.keys())
+
+#----------------------------------------------------------------------------
+
+def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments.
+ assert is_valid_metric(metric)
+ opts = metric_utils.MetricOptions(**kwargs)
+
+ # Calculate.
+ start_time = time.time()
+ results = _metric_dict[metric](opts)
+ total_time = time.time() - start_time
+
+ # Broadcast results.
+ for key, value in list(results.items()):
+ if opts.num_gpus > 1:
+ value = torch.as_tensor(value, dtype=torch.float64, device=opts.device)
+ torch.distributed.broadcast(tensor=value, src=0)
+ value = float(value.cpu())
+ results[key] = value
+
+ # Decorate with metadata.
+ return dnnlib.EasyDict(
+ results = dnnlib.EasyDict(results),
+ metric = metric,
+ total_time = total_time,
+ total_time_str = dnnlib.util.format_time(total_time),
+ num_gpus = opts.num_gpus,
+ )
+
+#----------------------------------------------------------------------------
+
+def report_metric(result_dict, run_dir=None, snapshot_pkl=None):
+ metric = result_dict['metric']
+ assert is_valid_metric(metric)
+ if run_dir is not None and snapshot_pkl is not None:
+ snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir)
+
+ jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time()))
+ print(jsonl_line)
+ if run_dir is not None and os.path.isdir(run_dir):
+ with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f:
+ f.write(jsonl_line + '\n')
+
+#----------------------------------------------------------------------------
+# Recommended metrics.
+
+@register_metric
+def fid50k_full(opts):
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
+ fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000)
+ return dict(fid50k_full=fid)
+
+@register_metric
+def kid50k_full(opts):
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
+ kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000)
+ return dict(kid50k_full=kid)
+
+@register_metric
+def pr50k3_full(opts):
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
+ precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
+ return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall)
+
+@register_metric
+def ppl2_wend(opts):
+ ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2)
+ return dict(ppl2_wend=ppl)
+
+@register_metric
+def eqt50k_int(opts):
+ opts.G_kwargs.update(force_fp32=True)
+ psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_int=True)
+ return dict(eqt50k_int=psnr)
+
+@register_metric
+def eqt50k_frac(opts):
+ opts.G_kwargs.update(force_fp32=True)
+ psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_frac=True)
+ return dict(eqt50k_frac=psnr)
+
+@register_metric
+def eqr50k(opts):
+ opts.G_kwargs.update(force_fp32=True)
+ psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqr=True)
+ return dict(eqr50k=psnr)
+
+#----------------------------------------------------------------------------
+# Legacy metrics.
+
+@register_metric
+def fid50k(opts):
+ opts.dataset_kwargs.update(max_size=None)
+ fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000)
+ return dict(fid50k=fid)
+
+@register_metric
+def kid50k(opts):
+ opts.dataset_kwargs.update(max_size=None)
+ kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000)
+ return dict(kid50k=kid)
+
+@register_metric
+def pr50k3(opts):
+ opts.dataset_kwargs.update(max_size=None)
+ precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
+ return dict(pr50k3_precision=precision, pr50k3_recall=recall)
+
+@register_metric
+def is50k(opts):
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
+ mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10)
+ return dict(is50k_mean=mean, is50k_std=std)
+
+#----------------------------------------------------------------------------
diff --git a/stylegan3-fun/metrics/metric_utils.py b/stylegan3-fun/metrics/metric_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..af122b21b5a7874d63b79ee40c2cb36d4ab4e5a2
--- /dev/null
+++ b/stylegan3-fun/metrics/metric_utils.py
@@ -0,0 +1,279 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Miscellaneous utilities used internally by the quality metrics."""
+
+import os
+import time
+import hashlib
+import pickle
+import copy
+import uuid
+import numpy as np
+import torch
+import dnnlib
+
+#----------------------------------------------------------------------------
+
+class MetricOptions:
+ def __init__(self, G=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1, rank=0, device=None, progress=None, cache=True):
+ assert 0 <= rank < num_gpus
+ self.G = G
+ self.G_kwargs = dnnlib.EasyDict(G_kwargs)
+ self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs)
+ self.num_gpus = num_gpus
+ self.rank = rank
+ self.device = device if device is not None else torch.device('cuda', rank)
+ self.progress = progress.sub() if progress is not None and rank == 0 else ProgressMonitor()
+ self.cache = cache
+
+#----------------------------------------------------------------------------
+
+_feature_detector_cache = dict()
+
+def get_feature_detector_name(url):
+ return os.path.splitext(url.split('/')[-1])[0]
+
+def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False):
+ assert 0 <= rank < num_gpus
+ key = (url, device)
+ if key not in _feature_detector_cache:
+ is_leader = (rank == 0)
+ if not is_leader and num_gpus > 1:
+ torch.distributed.barrier() # leader goes first
+ with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f:
+ _feature_detector_cache[key] = pickle.load(f).to(device)
+ if is_leader and num_gpus > 1:
+ torch.distributed.barrier() # others follow
+ return _feature_detector_cache[key]
+
+#----------------------------------------------------------------------------
+
+def iterate_random_labels(opts, batch_size):
+ if opts.G.c_dim == 0:
+ c = torch.zeros([batch_size, opts.G.c_dim], device=opts.device)
+ while True:
+ yield c
+ else:
+ dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
+ while True:
+ c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_size)]
+ c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
+ yield c
+
+#----------------------------------------------------------------------------
+
+class FeatureStats:
+ def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None):
+ self.capture_all = capture_all
+ self.capture_mean_cov = capture_mean_cov
+ self.max_items = max_items
+ self.num_items = 0
+ self.num_features = None
+ self.all_features = None
+ self.raw_mean = None
+ self.raw_cov = None
+
+ def set_num_features(self, num_features):
+ if self.num_features is not None:
+ assert num_features == self.num_features
+ else:
+ self.num_features = num_features
+ self.all_features = []
+ self.raw_mean = np.zeros([num_features], dtype=np.float64)
+ self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)
+
+ def is_full(self):
+ return (self.max_items is not None) and (self.num_items >= self.max_items)
+
+ def append(self, x):
+ x = np.asarray(x, dtype=np.float32)
+ assert x.ndim == 2
+ if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
+ if self.num_items >= self.max_items:
+ return
+ x = x[:self.max_items - self.num_items]
+
+ self.set_num_features(x.shape[1])
+ self.num_items += x.shape[0]
+ if self.capture_all:
+ self.all_features.append(x)
+ if self.capture_mean_cov:
+ x64 = x.astype(np.float64)
+ self.raw_mean += x64.sum(axis=0)
+ self.raw_cov += x64.T @ x64
+
+ def append_torch(self, x, num_gpus=1, rank=0):
+ assert isinstance(x, torch.Tensor) and x.ndim == 2
+ assert 0 <= rank < num_gpus
+ if num_gpus > 1:
+ ys = []
+ for src in range(num_gpus):
+ y = x.clone()
+ torch.distributed.broadcast(y, src=src)
+ ys.append(y)
+ x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples
+ self.append(x.cpu().numpy())
+
+ def get_all(self):
+ assert self.capture_all
+ return np.concatenate(self.all_features, axis=0)
+
+ def get_all_torch(self):
+ return torch.from_numpy(self.get_all())
+
+ def get_mean_cov(self):
+ assert self.capture_mean_cov
+ mean = self.raw_mean / self.num_items
+ cov = self.raw_cov / self.num_items
+ cov = cov - np.outer(mean, mean)
+ return mean, cov
+
+ def save(self, pkl_file):
+ with open(pkl_file, 'wb') as f:
+ pickle.dump(self.__dict__, f)
+
+ @staticmethod
+ def load(pkl_file):
+ with open(pkl_file, 'rb') as f:
+ s = dnnlib.EasyDict(pickle.load(f))
+ obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items)
+ obj.__dict__.update(s)
+ return obj
+
+#----------------------------------------------------------------------------
+
+class ProgressMonitor:
+ def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000):
+ self.tag = tag
+ self.num_items = num_items
+ self.verbose = verbose
+ self.flush_interval = flush_interval
+ self.progress_fn = progress_fn
+ self.pfn_lo = pfn_lo
+ self.pfn_hi = pfn_hi
+ self.pfn_total = pfn_total
+ self.start_time = time.time()
+ self.batch_time = self.start_time
+ self.batch_items = 0
+ if self.progress_fn is not None:
+ self.progress_fn(self.pfn_lo, self.pfn_total)
+
+ def update(self, cur_items):
+ assert (self.num_items is None) or (cur_items <= self.num_items)
+ if (cur_items < self.batch_items + self.flush_interval) and (self.num_items is None or cur_items < self.num_items):
+ return
+ cur_time = time.time()
+ total_time = cur_time - self.start_time
+ time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1)
+ if (self.verbose) and (self.tag is not None):
+ print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}')
+ self.batch_time = cur_time
+ self.batch_items = cur_items
+
+ if (self.progress_fn is not None) and (self.num_items is not None):
+ self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * (cur_items / self.num_items), self.pfn_total)
+
+ def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1):
+ return ProgressMonitor(
+ tag = tag,
+ num_items = num_items,
+ flush_interval = flush_interval,
+ verbose = self.verbose,
+ progress_fn = self.progress_fn,
+ pfn_lo = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_lo,
+ pfn_hi = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_hi,
+ pfn_total = self.pfn_total,
+ )
+
+#----------------------------------------------------------------------------
+
+def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, data_loader_kwargs=None, max_items=None, **stats_kwargs):
+ dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
+ if data_loader_kwargs is None:
+ data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)
+
+ # Try to lookup from cache.
+ cache_file = None
+ if opts.cache:
+ # Choose cache file name.
+ args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs, stats_kwargs=stats_kwargs)
+ md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8'))
+ cache_tag = f'{dataset.name}-{get_feature_detector_name(detector_url)}-{md5.hexdigest()}'
+ cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl')
+
+ # Check if the file exists (all processes must agree).
+ flag = os.path.isfile(cache_file) if opts.rank == 0 else False
+ if opts.num_gpus > 1:
+ flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device)
+ torch.distributed.broadcast(tensor=flag, src=0)
+ flag = (float(flag.cpu()) != 0)
+
+ # Load.
+ if flag:
+ return FeatureStats.load(cache_file)
+
+ # Initialize.
+ num_items = len(dataset)
+ if max_items is not None:
+ num_items = min(num_items, max_items)
+ stats = FeatureStats(max_items=num_items, **stats_kwargs)
+ progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi)
+ detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
+
+ # Main loop.
+ item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)]
+ for images, _labels in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs):
+ if images.shape[1] == 1:
+ images = images.repeat([1, 3, 1, 1])
+ features = detector(images.to(opts.device), **detector_kwargs)
+ stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
+ progress.update(stats.num_items)
+
+ # Save to cache.
+ if cache_file is not None and opts.rank == 0:
+ os.makedirs(os.path.dirname(cache_file), exist_ok=True)
+ temp_file = cache_file + '.' + uuid.uuid4().hex
+ stats.save(temp_file)
+ os.replace(temp_file, cache_file) # atomic
+ return stats
+
+#----------------------------------------------------------------------------
+
+def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, **stats_kwargs):
+ if batch_gen is None:
+ batch_gen = min(batch_size, 4)
+ assert batch_size % batch_gen == 0
+
+ # Setup generator and labels.
+ G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
+ c_iter = iterate_random_labels(opts=opts, batch_size=batch_gen)
+
+ # Initialize.
+ stats = FeatureStats(**stats_kwargs)
+ assert stats.max_items is not None
+ progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)
+ detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
+
+ # Main loop.
+ while not stats.is_full():
+ images = []
+ for _i in range(batch_size // batch_gen):
+ z = torch.randn([batch_gen, G.z_dim], device=opts.device)
+ img = G(z=z, c=next(c_iter), **opts.G_kwargs)
+ img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+ images.append(img)
+ images = torch.cat(images)
+ if images.shape[1] == 1:
+ images = images.repeat([1, 3, 1, 1])
+ features = detector(images, **detector_kwargs)
+ stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
+ progress.update(stats.num_items)
+ return stats
+
+#----------------------------------------------------------------------------
diff --git a/stylegan3-fun/metrics/perceptual_path_length.py b/stylegan3-fun/metrics/perceptual_path_length.py
new file mode 100644
index 0000000000000000000000000000000000000000..c68519fea298b076ef317b5ea75e22a77225baaf
--- /dev/null
+++ b/stylegan3-fun/metrics/perceptual_path_length.py
@@ -0,0 +1,125 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Perceptual Path Length (PPL) from the paper "A Style-Based Generator
+Architecture for Generative Adversarial Networks". Matches the original
+implementation by Karras et al. at
+https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py"""
+
+import copy
+import numpy as np
+import torch
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+# Spherical interpolation of a batch of vectors.
+def slerp(a, b, t):
+ a = a / a.norm(dim=-1, keepdim=True)
+ b = b / b.norm(dim=-1, keepdim=True)
+ d = (a * b).sum(dim=-1, keepdim=True)
+ p = t * torch.acos(d)
+ c = b - d * a
+ c = c / c.norm(dim=-1, keepdim=True)
+ d = a * torch.cos(p) + c * torch.sin(p)
+ d = d / d.norm(dim=-1, keepdim=True)
+ return d
+
+#----------------------------------------------------------------------------
+
+class PPLSampler(torch.nn.Module):
+ def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16):
+ assert space in ['z', 'w']
+ assert sampling in ['full', 'end']
+ super().__init__()
+ self.G = copy.deepcopy(G)
+ self.G_kwargs = G_kwargs
+ self.epsilon = epsilon
+ self.space = space
+ self.sampling = sampling
+ self.crop = crop
+ self.vgg16 = copy.deepcopy(vgg16)
+
+ def forward(self, c):
+ # Generate random latents and interpolation t-values.
+ t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0)
+ z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2)
+
+ # Interpolate in W or Z.
+ if self.space == 'w':
+ w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2)
+ wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2))
+ wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon)
+ else: # space == 'z'
+ zt0 = slerp(z0, z1, t.unsqueeze(1))
+ zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon)
+ wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2)
+
+ # Randomize noise buffers.
+ for name, buf in self.G.named_buffers():
+ if name.endswith('.noise_const'):
+ buf.copy_(torch.randn_like(buf))
+
+ # Generate images.
+ img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs)
+
+ # Center crop.
+ if self.crop:
+ assert img.shape[2] == img.shape[3]
+ c = img.shape[2] // 8
+ img = img[:, :, c*3 : c*7, c*2 : c*6]
+
+ # Downsample to 256x256.
+ factor = self.G.img_resolution // 256
+ if factor > 1:
+ img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5])
+
+ # Scale dynamic range from [-1,1] to [0,255].
+ img = (img + 1) * (255 / 2)
+ if self.G.img_channels == 1:
+ img = img.repeat([1, 3, 1, 1])
+
+ # Evaluate differential LPIPS.
+ lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2)
+ dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2
+ return dist
+
+#----------------------------------------------------------------------------
+
+def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size):
+ vgg16_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
+ vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose)
+
+ # Setup sampler and labels.
+ sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16)
+ sampler.eval().requires_grad_(False).to(opts.device)
+ c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size)
+
+ # Sampling loop.
+ dist = []
+ progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples)
+ for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
+ progress.update(batch_start)
+ x = sampler(next(c_iter))
+ for src in range(opts.num_gpus):
+ y = x.clone()
+ if opts.num_gpus > 1:
+ torch.distributed.broadcast(y, src=src)
+ dist.append(y)
+ progress.update(num_samples)
+
+ # Compute PPL.
+ if opts.rank != 0:
+ return float('nan')
+ dist = torch.cat(dist)[:num_samples].cpu().numpy()
+ lo = np.percentile(dist, 1, interpolation='lower')
+ hi = np.percentile(dist, 99, interpolation='higher')
+ ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean()
+ return float(ppl)
+
+#----------------------------------------------------------------------------
diff --git a/stylegan3-fun/metrics/precision_recall.py b/stylegan3-fun/metrics/precision_recall.py
new file mode 100644
index 0000000000000000000000000000000000000000..120ef801ba488ae5288d14cbee49b566492b6695
--- /dev/null
+++ b/stylegan3-fun/metrics/precision_recall.py
@@ -0,0 +1,62 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Precision/Recall (PR) from the paper "Improved Precision and Recall
+Metric for Assessing Generative Models". Matches the original implementation
+by Kynkaanniemi et al. at
+https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py"""
+
+import torch
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size):
+ assert 0 <= rank < num_gpus
+ num_cols = col_features.shape[0]
+ num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus
+ col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches)
+ dist_batches = []
+ for col_batch in col_batches[rank :: num_gpus]:
+ dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0]
+ for src in range(num_gpus):
+ dist_broadcast = dist_batch.clone()
+ if num_gpus > 1:
+ torch.distributed.broadcast(dist_broadcast, src=src)
+ dist_batches.append(dist_broadcast.cpu() if rank == 0 else None)
+ return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None
+
+#----------------------------------------------------------------------------
+
+def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size):
+ detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
+ detector_kwargs = dict(return_features=True)
+
+ real_features = metric_utils.compute_feature_stats_for_dataset(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device)
+
+ gen_features = metric_utils.compute_feature_stats_for_generator(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device)
+
+ results = dict()
+ for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]:
+ kth = []
+ for manifold_batch in manifold.split(row_batch_size):
+ dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
+ kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None)
+ kth = torch.cat(kth) if opts.rank == 0 else None
+ pred = []
+ for probes_batch in probes.split(row_batch_size):
+ dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
+ pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None)
+ results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan')
+ return results['precision'], results['recall']
+
+#----------------------------------------------------------------------------
diff --git a/stylegan3-fun/multimodal_truncation.py b/stylegan3-fun/multimodal_truncation.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbb5cb2d1a2e65a9395f8df6017c6c25195104af
--- /dev/null
+++ b/stylegan3-fun/multimodal_truncation.py
@@ -0,0 +1,150 @@
+import os
+from typing import List, Optional, Union, Tuple
+import click
+
+import dnnlib
+from torch_utils import gen_utils
+
+import scipy
+import numpy as np
+import PIL.Image
+import torch
+
+import legacy
+
+from sklearn.cluster import KMeans
+from sklearn.preprocessing import StandardScaler
+
+
+# ----------------------------------------------------------------------------
+
+# TODO/hax: Use this for generation: https://huggingface.co/spaces/SIGGRAPH2022/Self-Distilled-StyleGAN/blob/main/model.py
+# SGANXL uses it for generation w/L2 norm: https://github.com/autonomousvision/stylegan_xl/blob/4241ff9cfeb69d617427107a75d69e9d1c2d92f2/torch_utils/gen_utils.py#L428
+@click.group()
+def main():
+ pass
+
+
+# ----------------------------------------------------------------------------
+
+
+@main.command(name='get-centroids')
+@click.pass_context
+@click.option('--network', 'network_pkl', help='Network pickle filename: can be URL, local file, or the name of the model in torch_utils.gen_utils.resume_specs', required=True)
+@click.option('--cfg', type=click.Choice(['stylegan2', 'stylegan3-t', 'stylegan3-r']), help='Config of the network, used only if you want to use the pretrained models in torch_utils.gen_utils.resume_specs')
+@click.option('--device', help='Device to use for image generation; using the CPU is slower than the GPU', type=click.Choice(['cpu', 'cuda']), default='cuda', show_default=True)
+# Centroids options
+@click.option('--seed', type=int, help='Random seed to use', default=0, show_default=True)
+@click.option('--num-latents', type=int, help='Number of latents to use for clustering; not recommended to change', default=60000, show_default=True)
+@click.option('--num-clusters', type=click.Choice(['32', '64', '128']), help='Number of cluster centroids to find', default='64', show_default=True)
+# Extra parameters
+@click.option('--anchor-latent-space', '-anchor', is_flag=True, help='Anchor the latent space to w_avg to stabilize the video')
+@click.option('--plot-pca', '-pca', is_flag=True, help='Plot and save the PCA of the disentangled latent space W')
+@click.option('--dim-pca', '-dim', type=click.IntRange(min=2, max=3), help='Number of dimensions to use for the PCA', default=3, show_default=True)
+@click.option('--verbose', type=bool, help='Verbose mode for KMeans (during centroids calculation)', show_default=True, default=False)
+@click.option('--outdir', type=click.Path(file_okay=False), help='Directory path to save the results', default=os.path.join(os.getcwd(), 'out', 'clusters'), show_default=True, metavar='DIR')
+@click.option('--description', '-desc', type=str, help='Description name for the directory path to save results', default='pure_centroids', show_default=True)
+def get_centroids(
+ ctx: click.Context,
+ network_pkl: str,
+ cfg: Optional[str],
+ device: Optional[str],
+ seed: Optional[int],
+ num_latents: Optional[int],
+ num_clusters: Optional[str],
+ anchor_latent_space: Optional[bool],
+ plot_pca: Optional[bool],
+ dim_pca: Optional[int],
+ verbose: Optional[bool],
+ outdir: Union[str, os.PathLike],
+ description: Optional[str]
+):
+ """Find the cluster centers in the latent space of the selected model"""
+ device = torch.device('cuda') if torch.cuda.is_available() and device == 'cuda' else torch.device('cpu')
+
+ # Load the network
+ G = gen_utils.load_network('G_ema', network_pkl, cfg, device)
+
+ # Setup for using CPU
+ if device.type == 'cpu':
+ gen_utils.use_cpu(G)
+
+ # Stabilize/anchor the latent space
+ if anchor_latent_space:
+ gen_utils.anchor_latent_space(G)
+
+ desc = f'multimodal-truncation-{num_clusters}clusters'
+ desc = f'{desc}-{description}' if len(description) != 0 else desc
+ # Create the run dir with the given name description
+ run_dir = gen_utils.make_run_dir(outdir, desc)
+
+ print('Generating all the latents...')
+ z = torch.from_numpy(np.random.RandomState(seed).randn(num_latents, G.z_dim)).to(device)
+ w = G.mapping(z, None)[:, 0, :]
+
+ # Get the centroids
+ print('Finding the cluster centroids. Patience...')
+ scaler = StandardScaler()
+ scaler.fit(w.cpu())
+
+ # Scale the dlatents and perform KMeans with the selected number of clusters
+ w_scaled = scaler.transform(w.cpu())
+ kmeans = KMeans(n_clusters=int(num_clusters), random_state=0, init='random', verbose=int(verbose)).fit(w_scaled)
+
+ # Get the centroids and inverse transform them to the original space
+ w_avg_multi = torch.from_numpy(scaler.inverse_transform(kmeans.cluster_centers_)).to(device)
+
+ print('Success! Saving the centroids...')
+ for idx, w_avg in enumerate(w_avg_multi):
+ w_avg = torch.tile(w_avg, (1, G.mapping.num_ws, 1))
+ img = gen_utils.w_to_img(G, w_avg)[0]
+ # Save image and dlatent/new centroid
+ PIL.Image.fromarray(img, 'RGB').save(os.path.join(run_dir, f'pure_centroid_no{idx+1:03d}-{num_clusters}clusters.jpg'))
+ np.save(os.path.join(run_dir, f'centroid_{idx+1:03d}-{num_clusters}clusters.npy'), w_avg.unsqueeze(0).cpu().numpy())
+
+ # Save the configuration used
+ ctx.obj = {
+ 'model_options': {
+ 'network_pkl': network_pkl,
+ 'model_configuration': cfg},
+ 'centroids_options': {
+ 'seed': seed,
+ 'num_latents': num_latents,
+ 'num_clusters': num_clusters},
+ 'extra_parameters': {
+ 'anchor_latent_space': anchor_latent_space,
+ 'outdir': run_dir,
+ 'description': description}
+ }
+ # Save the run configuration
+ gen_utils.save_config(ctx=ctx, run_dir=run_dir)
+
+ if plot_pca:
+ print('Plotting the PCA of the disentangled latent space...')
+ import matplotlib.pyplot as plt
+ from sklearn.decomposition import PCA
+
+ pca = PCA(n_components=dim_pca)
+ fit_pca = pca.fit(w_scaled)
+ fit_pca = pca.fit_transform(w_scaled)
+ kmeans_pca = KMeans(n_clusters=int(num_clusters), random_state=0, verbose=0, init='random').fit_predict(fit_pca)
+
+ fig = plt.figure(figsize=(20, 10))
+ ax = fig.add_subplot(111, projection='3d' if dim_pca == 3 else None)
+ axes = fit_pca[:, 0], fit_pca[:, 1], fit_pca[:, 2] if dim_pca == 3 else fit_pca[:, 0], fit_pca[:, 1]
+ ax.scatter(*axes, c=kmeans_pca, cmap='inferno', edgecolor='k', s=40, alpha=0.5)
+ ax.set_title(r"$| \mathcal{W} | \rightarrow $" + f'{dim_pca}')
+ ax.axis('off')
+ plt.savefig(os.path.join(run_dir, f'pca_{dim_pca}dim_{num_clusters}clusters.png'))
+
+ print('Done!')
+
+
+# ----------------------------------------------------------------------------
+
+
+if __name__ == '__main__':
+ main()
+
+
+# ----------------------------------------------------------------------------
diff --git a/stylegan3-fun/network_features.py b/stylegan3-fun/network_features.py
new file mode 100644
index 0000000000000000000000000000000000000000..aced06dba58c747e430afc459c20f1211eff51e5
--- /dev/null
+++ b/stylegan3-fun/network_features.py
@@ -0,0 +1,335 @@
+import torch
+import torch.nn as nn
+from torchvision import models
+
+import numpy as np
+
+from typing import List, Tuple
+from collections import OrderedDict
+import operator
+
+
+# ----------------------------------------------------------------------------
+
+
+class VGG16Features(torch.nn.Module):
+ """
+ Use pre-trained VGG16 provided by PyTorch. Code modified from lainwired/pacifinapacific
+ https://github.com/pacifinapacific/StyleGAN_LatentEditor. My modification is that we can use
+ the ReLU activation if we want, or the pure conv1_1, conv1_2, conv3_2, and conv4_2 activations.
+
+ My conclusions are that it's best to have one model of VGG, so I will use the one provided by NVIDIA
+ as it is both easier to slice and it can return LPIPS if so desired.
+ """
+ # Image2StyleGAN: How to Embed Images into the StyleGAN latent space? https://arxiv.org/abs/1904.03189,
+ # layers = [0, 2, 12, 19]
+ # Image2StyleGAN++: How to Edit the Embedded Images? https://arxiv.org/abs/1911.11544,
+ # layers = [0, 2, 7, 14], but make sure to return conv3_3 twice for the Style Loss
+ def __init__(self, device, use_relu=False):
+ super(VGG16Features, self).__init__()
+ # Load and partition the model
+ vgg16 = models.vgg16(pretrained=True).to(device)
+ self.vgg16_features = vgg16.features
+ self.avgpool = vgg16.avgpool # TODO: more work can be done to partition any part of the model, but not my jam
+ self.classifier = vgg16.classifier
+
+ self.conv1_1 = torch.nn.Sequential()
+ self.conv1_2 = torch.nn.Sequential()
+ self.conv3_2 = torch.nn.Sequential()
+ self.conv4_2 = torch.nn.Sequential()
+
+ layers = [0, 2, 12, 19]
+ if use_relu:
+ layers = [layer + 1 for layer in layers]
+
+ for i in range(layers[0] + 1):
+ self.conv1_1.add_module(str(i), self.vgg16_features[i])
+
+ for i in range(layers[0] + 1, layers[1] + 1):
+ self.conv1_2.add_module(str(i), self.vgg16_features[i])
+
+ for i in range(layers[1] + 1, layers[2] + 1):
+ self.conv3_2.add_module(str(i), self.vgg16_features[i])
+
+ for i in range(layers[2] + 1, layers[3] + 1):
+ self.conv4_2.add_module(str(i), self.vgg16_features[i])
+
+ # We're not optimizing VGG16
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def get_feature_layers(self, x):
+ conv1_1 = self.conv1_1(x)
+ conv1_2 = self.conv1_2(conv1_1)
+ conv3_2 = self.conv3_2(conv1_2)
+ conv4_2 = self.conv4_2(conv3_2)
+
+ conv1_1 = conv1_1 / torch.numel(conv1_1)
+ conv1_2 = conv1_2 / torch.numel(conv1_2)
+ conv3_2 = conv3_2 / torch.numel(conv3_2)
+ conv4_2 = conv4_2 / torch.numel(conv4_2)
+
+ return conv1_1, conv1_2, conv3_2, conv4_2
+
+
+class VGG16FeaturesNVIDIA(torch.nn.Module):
+ def __init__(self, vgg16):
+ super(VGG16FeaturesNVIDIA, self).__init__()
+ # NOTE: ReLU is already included in the output of every conv output
+ self.conv1_1 = vgg16.layers.conv1
+ self.conv1_2 = vgg16.layers.conv2
+ self.pool1 = vgg16.layers.pool1
+
+ self.conv2_1 = vgg16.layers.conv3
+ self.conv2_2 = vgg16.layers.conv4
+ self.pool2 = vgg16.layers.pool2
+
+ self.conv3_1 = vgg16.layers.conv5
+ self.conv3_2 = vgg16.layers.conv6
+ self.conv3_3 = vgg16.layers.conv7
+ self.pool3 = vgg16.layers.pool3
+
+ self.conv4_1 = vgg16.layers.conv8
+ self.conv4_2 = vgg16.layers.conv9
+ self.conv4_3 = vgg16.layers.conv10
+ self.pool4 = vgg16.layers.pool4
+
+ self.conv5_1 = vgg16.layers.conv11
+ self.conv5_2 = vgg16.layers.conv12
+ self.conv5_3 = vgg16.layers.conv13
+ self.pool5 = vgg16.layers.pool5
+ self.adavgpool = torch.nn.AdaptiveAvgPool2d(output_size=(7, 7)) # We need this for 256x256 images (> 224x224)
+
+ self.fc1 = vgg16.layers.fc1
+ self.fc2 = vgg16.layers.fc2
+ self.fc3 = vgg16.layers.fc3
+ self.softmax = vgg16.layers.softmax
+
+ def get_layers_features(self, x: torch.Tensor, layers: List[str], normed: bool = False, sqrt_normed: bool = False):
+ """
+ x is an image/tensor of shape [1, 3, 256, 256], and layers is a list of the names of the layers you wish
+ to return in order to compare the activations/features with another image.
+
+ Example:
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ img1 = torch.randn(1, 3, 256, 256, device=device)
+ img2 = torch.randn(1, 3, 256, 256, device=device)
+ layers = ['conv1_1', 'conv1_2', 'conv3_3', 'conv3_3', 'fc3'] # Indeed, return twice conv3_3
+
+ # Load the VGG16 feature detector.
+ url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
+ with dnnlib.util.open_url(url) as f:
+ vgg16 = torch.jit.load(f).eval().to(device)
+
+ vgg16 = VGG16FeaturesNVIDIA(vgg16)
+
+ # Get the desired features from the layers list
+ features1 = vgg16.get_layers_features(img1, layers)
+ features2 = vgg16.get_layers_features(img2, layers)
+
+ # Get, e.g., the MSE loss between the two features
+ mse = torch.nn.MSELoss(reduction='mean')
+ loss = sum(map(lambda x, y: mse(x, y), features1, features2))
+ """
+ # Legend: => conv2d, -> max pool 2d, ~> adaptive average pool 2d, ->> fc layer; shapes of input/output are shown
+ assert layers is not None
+
+ features_dict = OrderedDict()
+ features_dict['conv1_1'] = self.conv1_1(x) # [1, 3, 256, 256] => [1, 64, 256, 256]
+ features_dict['conv1_2'] = self.conv1_2(features_dict['conv1_1']) # [1, 64, 256, 256] => [1, 64, 256, 256]
+ features_dict['pool1'] = self.pool1(features_dict['conv1_2']) # [1, 64, 256, 256] -> [1, 64, 128, 128]
+
+ features_dict['conv2_1'] = self.conv2_1(features_dict['pool1']) # [1, 64, 128, 128] => [1, 128, 128, 128]
+ features_dict['conv2_2'] = self.conv2_2(features_dict['conv2_1']) # [1, 128, 128, 128] => [1, 128, 128, 128]
+ features_dict['pool2'] = self.pool2(features_dict['conv2_2']) # [1, 128, 128, 128] -> [1, 128, 64, 64]
+
+ features_dict['conv3_1'] = self.conv3_1(features_dict['pool2']) # [1, 128, 64, 64] => [1, 256, 64, 64]
+ features_dict['conv3_2'] = self.conv3_2(features_dict['conv3_1']) # [1, 256, 64, 64] => [1, 256, 64, 64]
+ features_dict['conv3_3'] = self.conv3_3(features_dict['conv3_2']) # [1, 256, 64, 64] => [1, 256, 64, 64]
+ features_dict['pool3'] = self.pool3(features_dict['conv3_3']) # [1, 256, 64, 64] -> [1, 256, 32, 32]
+
+ features_dict['conv4_1'] = self.conv4_1(features_dict['pool3']) # [1, 256, 32, 32] => [1, 512, 32, 32]
+ features_dict['conv4_2'] = self.conv4_2(features_dict['conv4_1']) # [1, 512, 32, 32] => [1, 512, 32, 32]
+ features_dict['conv4_3'] = self.conv4_3(features_dict['conv4_2']) # [1, 512, 32, 32] => [1, 512, 32, 32]
+ features_dict['pool4'] = self.pool4(features_dict['conv4_3']) # [1, 512, 32, 32] -> [1, 512, 16, 16]
+
+ features_dict['conv5_1'] = self.conv5_1(features_dict['pool4']) # [1, 512, 16, 16] => [1, 512, 16, 16]
+ features_dict['conv5_2'] = self.conv5_2(features_dict['conv5_1']) # [1, 512, 16, 16] => [1, 512, 16, 16]
+ features_dict['conv5_3'] = self.conv5_3(features_dict['conv5_2']) # [1, 512, 16, 16] => [1, 512, 16, 16]
+ features_dict['pool5'] = self.pool5(features_dict['conv5_3']) # [1, 512, 16, 16] -> [1, 512, 8, 8]
+
+ features_dict['adavgpool'] = self.adavgpool(features_dict['pool5']) # [1, 512, 8, 8] ~> [1, 512, 7, 7]
+ features_dict['fc1'] = self.fc1(features_dict['adavgpool']) # [1, 512, 7, 7] ->> [1, 4096]; w/ReLU
+ features_dict['fc2'] = self.fc2(features_dict['fc1']) # [1, 4096] ->> [1, 4096]; w/ReLU
+ features_dict['fc3'] = self.softmax(self.fc3(features_dict['fc2'])) # [1, 4096] ->> [1, 1000]; w/o ReLU; apply softmax
+
+ result_list = list()
+ for layer in layers:
+ if normed:
+ # Divide each layer by the number of elements in it
+ result_list.append(features_dict[layer] / torch.numel(features_dict[layer]))
+ elif sqrt_normed:
+ # Divide each layer by the square root of the number of elements in it
+ result_list.append(features_dict[layer] / torch.tensor(torch.numel(features_dict[layer]),
+ dtype=torch.float).sqrt())
+ else:
+ result_list.append(features_dict[layer])
+ return result_list
+
+
+# ----------------------------------------------------------------------------
+
+
+class DiscriminatorFeatures(torch.nn.Module):
+ def __init__(self, D):
+ super(DiscriminatorFeatures, self).__init__()
+
+ # assert D.init_kwargs.architecture == 'resnet' # removed as some resnet models don't have this attribute
+ self.block_resolutions = D.block_resolutions
+
+ # For loop to get all the inner features of the trained Discriminator with a resnet architecture
+ for res in self.block_resolutions:
+ if res == D.img_resolution:
+ setattr(self, 'from_rgb', operator.attrgetter(f'b{res}.fromrgb')(D))
+ setattr(self, f'b{res}_skip', operator.attrgetter(f'b{res}.skip')(D))
+ setattr(self, f'b{res}_conv0', operator.attrgetter(f'b{res}.conv0')(D))
+ setattr(self, f'b{res}_conv1', operator.attrgetter(f'b{res}.conv1')(D))
+
+ # Unique, last block with a fc/out, so we can extract features in a regular fashion
+ setattr(self, 'b4_mbstd', D.b4.mbstd)
+ setattr(self, 'b4_conv', D.b4.conv)
+ setattr(self, 'adavgpool', nn.AdaptiveAvgPool2d(4)) # Necessary if images are of different resolution than D.img_resolution
+ setattr(self, 'fc', D.b4.fc)
+ setattr(self, 'out', D.b4.out)
+
+ def get_block_resolutions(self):
+ """Get the block resolutions available for the current Discriminator. Remove?"""
+ return self.block_resolutions
+
+ def get_layers_features(self,
+ x: torch.Tensor, # Input image
+ layers: List[str] = None,
+ channels: List[int] = None,
+ normed: bool = False,
+ sqrt_normed: bool = False) -> Tuple[torch.Tensor, ...]:
+ """
+ Get the feature of a specific layer of the Discriminator (with resnet architecture). The following shows the
+ shapes of an image, x, as it flows through the different blocks that compose the Discriminator.
+
+ *** Legend: => conv2d, -> flatten, ->> fc layer, ~> mbstd layer, +> adaptive average pool ***
+
+ # First block / DiscriminatorBlock
+ from_rgb = self.from_rgb(x) # [1, 3, 1024, 1024] => [1, 32, 1024, 1024]
+ b1024_skip = self.b1024_skip(from_rgb, gain=np.sqrt(0.5)) # [1, 32, 1024, 1024] => [1, 64, 512, 512]
+ b1024_conv0 = self.b1024_conv0(from_rgb) # [1, 32, 1024, 1024] => [1, 32, 1024, 1024]
+ b1024_conv1 = self.b1024_conv1(b1024_conv0, gain=np.sqrt(0.5)) # [1, 32, 1024, 1024] => [1, 64, 512, 512]
+ b1024_conv1 = b1024_skip.add_(b1024_conv1) # [1, 64, 512, 512]
+
+ # Second block / DiscriminatorBlock
+ b512_skip = self.b512_skip(b1024_conv1, gain=np.sqrt(0.5)) # [1, 64, 512, 512] => [1, 128, 256, 256]
+ b512_conv0 = self.b512_conv0(b1024_conv1) # [1, 64, 512, 512] => [1, 64, 512, 512]
+ b512_conv1 = self.b512_conv1(b512_conv0, gain=np.sqrt(0.5)) # [1, 64, 512, 512] => [1, 128, 256, 256]
+ b512_conv1 = b512_skip.add_(b512_conv1) # [1, 128, 256, 256]
+
+ # Third block / DiscriminatorBlock
+ b256_skip = self.b256_skip(b512_conv1, gain=np.sqrt(0.5)) # [1, 128, 256, 256] => [1, 256, 128, 128]
+ b256_conv0 = self.b256_conv0(b512_conv1) # [1, 128, 256, 256] => [1, 128, 256, 256]
+ b256_conv1 = self.b256_conv1(b256_conv0, gain=np.sqrt(0.5)) # [1, 128, 256, 256] => [1, 256, 128, 128]
+ b256_conv1 = b256_skip.add_(b256_conv1) # [1, 256, 128, 128]
+
+ # Fourth block / DiscriminatorBlock
+ b128_skip = self.b128_skip(b256_conv1, gain=np.sqrt(0.5)) # [1, 256, 128, 128] => [1, 512, 64 ,64]
+ b128_conv0 = self.b128_conv0(b256_conv1) # [1, 256, 128, 128] => [1, 256, 128, 128]
+ b128_conv1 = self.b128_conv1(b128_conv0, gain=np.sqrt(0.5)) # [1, 256, 128, 128] => [1, 512, 64, 64]
+ b128_conv1 = b128_skip.add_(b128_conv1) # [1, 512, 64, 64]
+
+ # Fifth block / DiscriminatorBlock
+ b64_skip = self.b64_skip(b128_conv1, gain=np.sqrt(0.5)) # [1, 512, 64, 64] => [1, 512, 32, 32]
+ b64_conv0 = self.b64_conv0(b128_conv1) # [1, 512, 64, 64] => [1, 512, 64, 64]
+ b64_conv1 = self.b64_conv1(b64_conv0, gain=np.sqrt(0.5)) # [1, 512, 64, 64] => [1, 512, 32, 32]
+ b64_conv1 = b64_skip.add_(b64_conv1) # [1, 512, 32, 32]
+
+ # Sixth block / DiscriminatorBlock
+ b32_skip = self.b32_skip(b64_conv1, gain=np.sqrt(0.5)) # [1, 512, 32, 32] => [1, 512, 16, 16]
+ b32_conv0 = self.b32_conv0(b64_conv1) # [1, 512, 32, 32] => [1, 512, 32, 32]
+ b32_conv1 = self.b32_conv1(b32_conv0, gain=np.sqrt(0.5)) # [1, 512, 32, 32] => [1, 512, 16, 16]
+ b32_conv1 = b32_skip.add_(b32_conv1) # [1, 512, 16, 16]
+
+ # Seventh block / DiscriminatorBlock
+ b16_skip = self.b16_skip(b32_conv1, gain=np.sqrt(0.5)) # [1, 512, 16, 16] => [1, 512, 8, 8]
+ b16_conv0 = self.b16_conv0(b32_conv1) # [1, 512, 16, 16] => [1, 512, 16, 16]
+ b16_conv1 = self.b16_conv1(b16_conv0, gain=np.sqrt(0.5)) # [1, 512, 16, 16] => [1, 512, 8, 8]
+ b16_conv1 = b16_skip.add_(b16_conv1) # [1, 512, 8, 8]
+
+ # Eighth block / DiscriminatorBlock
+ b8_skip = self.b8_skip(b16_conv1, gain=np.sqrt(0.5)) # [1, 512, 8, 8] => [1, 512, 4, 4]
+ b8_conv0 = self.b8_conv0(b16_conv1) # [1, 512, 8, 8] => [1, 512, 8, 8]
+ b8_conv1 = self.b8_conv1(b8_conv0, gain=np.sqrt(0.5)) # [1, 512, 8, 8] => [1, 512, 4, 4]
+ b8_conv1 = b8_skip.add_(b8_conv1) # [1, 512, 4, 4]
+
+ # Ninth block / DiscriminatorEpilogue
+ b4_mbstd = self.b4_mbstd(b8_conv1) # [1, 512, 4, 4] ~> [1, 513, 4, 4]
+ b4_conv = self.adavgpool(self.b4_conv(b4_mbstd)) # [1, 513, 4, 4] => [1, 512, 4, 4] +> [1, 512, 4, 4]
+ fc = self.fc(b4_conv.flatten(1)) # [1, 512, 4, 4] -> [1, 8192] ->> [1, 512]
+ out = self.out(fc) # [1, 512] ->> [1, 1]
+ """
+ assert not (normed and sqrt_normed), 'Choose one of the normalizations!'
+
+ # Return the full output if no layers are indicated
+ if layers is None:
+ layers = ['out']
+
+ features_dict = OrderedDict() # Can just be a dictionary, but I plan to use the order of the features later on
+ features_dict['from_rgb'] = getattr(self, 'from_rgb')(x) # [1, 3, D.img_resolution, D.img_resolution] =>
+ # => [1, 32, D.img_resolution, D.img_resolution]
+
+ for idx, res in enumerate(self.block_resolutions):
+
+ # conv0 and skip from the first block use from_rgb
+ if idx == 0:
+ features_dict[f'b{res}_skip'] = getattr(self, f'b{res}_skip')(
+ features_dict['from_rgb'], gain=np.sqrt(0.5))
+ features_dict[f'b{res}_conv0'] = getattr(self, f'b{res}_conv0')(features_dict['from_rgb'])
+
+ # The rest use the previous block's conv1
+ else:
+ features_dict[f'b{res}_skip'] = getattr(self, f'b{res}_skip')(
+ features_dict[f'b{self.block_resolutions[idx - 1]}_conv1'], gain=np.sqrt(0.5)
+ )
+ features_dict[f'b{res}_conv0'] = getattr(self, f'b{res}_conv0')(
+ features_dict[f'b{self.block_resolutions[idx - 1]}_conv1']
+ )
+ # Finally, pass the current block's conv0 and do the skip connection addition
+ features_dict[f'b{res}_conv1'] = getattr(self, f'b{res}_conv1')(features_dict[f'b{res}_conv0'],
+ gain=np.sqrt(0.5))
+ features_dict[f'b{res}_conv1'] = features_dict[f'b{res}_skip'].add_(features_dict[f'b{res}_conv1'])
+
+ # Irrespective of the image size/model size, the last block will be the same:
+ features_dict['b4_mbstd'] = getattr(self, 'b4_mbstd')(features_dict['b8_conv1']) # [1, 512, 4, 4] ~> [1, 513, 4, 4]
+ features_dict['b4_conv'] = getattr(self, 'b4_conv')(features_dict['b4_mbstd']) # [1, 513, 4, 4] => [1, 512, 4, 4]
+ features_dict['b4_conv'] = getattr(self, 'adavgpool')(features_dict['b4_conv']) # [1, 512, 4, 4] +> [1, 512, 4, 4] (Needed if x's resolution is not D.img_resolution)
+ features_dict['fc'] = getattr(self, 'fc')(features_dict['b4_conv'].flatten(1)) # [1, 512, 4, 4] -> [1, 8192] ->> [1, 512]
+ features_dict['out'] = getattr(self, 'out')(features_dict['fc']) # [1, 512] ->> [1, 1]
+
+ result_list = list()
+ for layer in layers:
+ if channels is not None:
+ max_channels = features_dict[layer].shape[1] # The number of channels in the layer
+ channels = [c for c in channels if c < max_channels] # Remove channels that are too high
+ channels = [c for c in channels if c >= 0] # Remove channels that are too low
+ channels = list(set(channels)) # Remove duplicates
+ if layer not in ['fc', 'out']:
+ features_dict[layer] = features_dict[layer][:, channels, :, :] # [1, max_channels, size, size] => [1, len(channels), size, size]
+ else:
+ features_dict[layer] = features_dict[layer][:, channels] # [1, max_channels] => [1, len(channels)]
+ # Two options to normalize, otherwise we only add the unmodified output; recommended if using more than one layer
+ if normed:
+ result_list.append(features_dict[layer] / torch.numel(features_dict[layer]))
+ elif sqrt_normed:
+ result_list.append(features_dict[layer] / torch.tensor(torch.numel(features_dict[layer]),
+ dtype=torch.float).sqrt())
+ else:
+ result_list.append(features_dict[layer])
+
+ return tuple(result_list)
diff --git a/stylegan3-fun/out/seed0002.png b/stylegan3-fun/out/seed0002.png
new file mode 100644
index 0000000000000000000000000000000000000000..7b41b60fdbfa4fe0c64d4e40a50ba2f4d41c4011
--- /dev/null
+++ b/stylegan3-fun/out/seed0002.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7929bc0e024c6f49582d68e91806aca60e41b9d114c132e537e79f4d22c96eaa
+size 1142534
diff --git a/stylegan3-fun/projector.py b/stylegan3-fun/projector.py
new file mode 100644
index 0000000000000000000000000000000000000000..fdbac926319370ee242504b3349e9c32ac4a4728
--- /dev/null
+++ b/stylegan3-fun/projector.py
@@ -0,0 +1,551 @@
+import copy
+import os
+from time import perf_counter
+
+import click
+from typing import List, Tuple
+import imageio
+import numpy as np
+import PIL.Image
+
+import torch
+import torch.nn.functional as F
+
+import dnnlib
+from dnnlib.util import format_time
+import legacy
+
+from torch_utils import gen_utils
+from tqdm import tqdm
+from pytorch_ssim import SSIM # from https://github.com/Po-Hsun-Su/pytorch-ssim
+
+from network_features import VGG16FeaturesNVIDIA, DiscriminatorFeatures
+
+from metrics import metric_utils
+
+# ----------------------------------------------------------------------------
+
+
+def project(
+ G,
+ target: PIL.Image.Image, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
+ *,
+ projection_seed: int,
+ truncation_psi: float,
+ num_steps: int = 1000,
+ w_avg_samples: int = 10000,
+ initial_learning_rate: float = 0.1,
+ initial_noise_factor: float = 0.05,
+ constant_learning_rate: bool = False,
+ lr_rampdown_length: float = 0.25,
+ lr_rampup_length: float = 0.05,
+ noise_ramp_length: float = 0.75,
+ regularize_noise_weight: float = 1e5,
+ project_in_wplus: bool = False,
+ loss_paper: str = 'sgan2', # ['sgan2' || Experimental: 'im2sgan' | 'clip' | 'discriminator']
+ normed: bool = False,
+ sqrt_normed: bool = False,
+ start_wavg: bool = True,
+ device: torch.device,
+ D = None) -> Tuple[torch.Tensor, dict]: # output shape: [num_steps, C, 512], C depending on resolution of G
+ """
+ Projecting a 'target' image into the W latent space. The user has an option to project into W+, where all elements
+ in the latent vector are different. Likewise, the projection process can start from the W midpoint or from a random
+ point, though results have shown that starting from the midpoint (start_wavg) yields the best results.
+ """
+ assert target.size == (G.img_resolution, G.img_resolution)
+
+ G = copy.deepcopy(G).eval().requires_grad_(False).to(device)
+
+ # Compute w stats.
+ z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
+ w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None) # [N, L, C]
+ if project_in_wplus: # Thanks to @pbaylies for a clean way on how to do this
+ print('Projecting in W+ latent space...')
+ if start_wavg:
+ print(f'Starting from W midpoint using {w_avg_samples} samples...')
+ w_avg = torch.mean(w_samples, dim=0, keepdim=True) # [1, L, C]
+ else:
+ print(f'Starting from a random vector (seed: {projection_seed})...')
+ z = np.random.RandomState(projection_seed).randn(1, G.z_dim)
+ w_avg = G.mapping(torch.from_numpy(z).to(device), None) # [1, L, C]
+ w_avg = G.mapping.w_avg + truncation_psi * (w_avg - G.mapping.w_avg)
+ else:
+ print('Projecting in W latent space...')
+ w_samples = w_samples[:, :1, :] # [N, 1, C]
+ if start_wavg:
+ print(f'Starting from W midpoint using {w_avg_samples} samples...')
+ w_avg = torch.mean(w_samples, dim=0, keepdim=True) # [1, 1, C]
+ else:
+ print(f'Starting from a random vector (seed: {projection_seed})...')
+ z = np.random.RandomState(projection_seed).randn(1, G.z_dim)
+ w_avg = G.mapping(torch.from_numpy(z).to(device), None)[:, :1, :] # [1, 1, C]; fake w_avg
+ w_avg = G.mapping.w_avg + truncation_psi * (w_avg - G.mapping.w_avg)
+ w_std = (torch.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5
+ # Setup noise inputs (only for StyleGAN2 models)
+ noise_buffs = {name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name}
+
+ # Features for target image. Reshape to 256x256 if it's larger to use with VGG16 (unnecessary for CLIP due to preprocess step)
+ if loss_paper in ['sgan2', 'im2sgan', 'discriminator']:
+ target = np.array(target, dtype=np.uint8)
+ target = torch.tensor(target.transpose([2, 0, 1]), device=device)
+ target = target.unsqueeze(0).to(device).to(torch.float32)
+ if target.shape[2] > 256:
+ target = F.interpolate(target, size=(256, 256), mode='area')
+
+ if loss_paper in ['sgan2', 'im2sgan']:
+ # Load the VGG16 feature detector.
+ url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
+ vgg16 = metric_utils.get_feature_detector(url, device=device)
+
+ # Define the target features and possible new losses
+ if loss_paper == 'sgan2':
+ target_features = vgg16(target, resize_images=False, return_lpips=True)
+
+ elif loss_paper == 'im2sgan':
+ # Use specific layers
+ vgg16_features = VGG16FeaturesNVIDIA(vgg16)
+ # Too cumbersome to add as command-line arg, so we leave it here; use whatever you need, as many times as needed
+ layers = ['conv1_1', 'conv1_2', 'conv2_1', 'conv2_2', 'conv3_1', 'conv3_2', 'conv3_3', 'conv4_1', 'conv4_2',
+ 'conv4_3', 'conv5_1', 'conv5_2', 'conv5_3', 'fc1', 'fc2', 'fc3']
+ target_features = vgg16_features.get_layers_features(target, layers, normed=normed, sqrt_normed=sqrt_normed)
+ # Uncomment the next line if you also want to use LPIPS features
+ # lpips_target_features = vgg16(target_images, resize_images=False, return_lpips=True)
+
+ mse = torch.nn.MSELoss(reduction='mean')
+ ssim_out = SSIM() # can be used as a loss; recommended usage: ssim_loss = 1 - ssim_out(img1, img2)
+
+ elif loss_paper == 'discriminator':
+ disc = DiscriminatorFeatures(D).requires_grad_(False).to(device)
+
+ layers = ['b128_conv0', 'b128_conv1', 'b64_conv0', 'b64_conv1', 'b32_conv0', 'b32_conv1',
+ 'b16_conv0', 'b16_conv1', 'b8_conv0', 'b8_conv1', 'b4_conv']
+
+ target_features = disc.get_layers_features(target, layers, normed=normed, sqrt_normed=sqrt_normed)
+ mse = torch.nn.MSELoss(reduction='mean')
+ ssim_out = SSIM()
+
+ elif loss_paper == 'clip':
+ import clip
+ model, preprocess = clip.load('ViT-B/32', device=device) # TODO: let user decide which model to use (use list given by clip.available_models()
+
+ target = preprocess(target).unsqueeze(0).to(device)
+ # text = either we give a target image or a text as target
+ target_features = model.encode_image(target)
+
+ mse = torch.nn.MSELoss(reduction='mean')
+
+ w_opt = w_avg.clone().detach().requires_grad_(True)
+ w_out = torch.zeros([num_steps] + list(w_opt.shape[1:]), dtype=torch.float32, device=device)
+ optimizer = torch.optim.Adam([w_opt] + list(noise_buffs.values()), betas=(0.9, 0.999), lr=initial_learning_rate)
+
+ # Init noise.
+ for buf in noise_buffs.values():
+ buf[:] = torch.randn_like(buf)
+ buf.requires_grad = True
+
+ for step in range(num_steps):
+ # Learning rate schedule.
+ t = step / num_steps
+ w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
+
+ if constant_learning_rate:
+ # Turn off the rampup/rampdown of the learning rate
+ lr_ramp = 1.0
+ else:
+ lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
+ lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
+ lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
+ lr = initial_learning_rate * lr_ramp
+ for param_group in optimizer.param_groups:
+ param_group['lr'] = lr
+
+ # Synth images from opt_w.
+ w_noise = torch.randn_like(w_opt) * w_noise_scale
+ if project_in_wplus:
+ ws = w_opt + w_noise
+ else:
+ ws = (w_opt + w_noise).repeat([1, G.mapping.num_ws, 1])
+ synth_images = G.synthesis(ws, noise_mode='const')
+
+ # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
+ synth_images = (synth_images + 1) * (255/2)
+ if synth_images.shape[2] > 256:
+ synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')
+
+ # Reshape synthetic images if G was trained with grayscale data
+ if synth_images.shape[1] == 1:
+ synth_images = synth_images.repeat(1, 3, 1, 1) # [1, 1, 256, 256] => [1, 3, 256, 256]
+
+ # Features for synth images.
+ if loss_paper == 'sgan2':
+ synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
+ dist = (target_features - synth_features).square().sum()
+
+ # Noise regularization.
+ reg_loss = 0.0
+ for v in noise_buffs.values():
+ noise = v[None, None, :, :] # must be [1,1,H,W] for F.avg_pool2d()
+ while True:
+ reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2
+ reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2
+ if noise.shape[2] <= 8:
+ break
+ noise = F.avg_pool2d(noise, kernel_size=2)
+ loss = dist + reg_loss * regularize_noise_weight
+ # Print in the same line (avoid cluttering the commandline)
+ n_digits = int(np.log10(num_steps)) + 1 if num_steps > 0 else 1
+ message = f'step {step + 1:{n_digits}d}/{num_steps}: dist {dist:.7e} | loss {loss.item():.7e}'
+ print(message, end='\r')
+
+ last_status = {'dist': dist.item(), 'loss': loss.item()}
+
+ elif loss_paper == 'im2sgan':
+ # Uncomment to also use LPIPS features as loss (must be better fine-tuned):
+ # lpips_synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
+
+ synth_features = vgg16_features.get_layers_features(synth_images, layers, normed=normed, sqrt_normed=sqrt_normed)
+ percept_error = sum(map(lambda x, y: mse(x, y), target_features, synth_features))
+
+ # Also uncomment to add the LPIPS loss to the perception error (to-be better fine-tuned)
+ # percept_error += 1e1 * (lpips_target_features - lpips_synth_features).square().sum()
+
+ # Pixel-level MSE
+ mse_error = mse(synth_images, target) / (G.img_channels * G.img_resolution * G.img_resolution)
+ ssim_loss = ssim_out(target, synth_images) # tracking SSIM (can also be added the total loss)
+ loss = percept_error + mse_error # + 1e-2 * (1 - ssim_loss) # needs to be fine-tuned
+
+ # Noise regularization.
+ reg_loss = 0.0
+ for v in noise_buffs.values():
+ noise = v[None, None, :, :] # must be [1,1,H,W] for F.avg_pool2d()
+ while True:
+ reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2
+ reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2
+ if noise.shape[2] <= 8:
+ break
+ noise = F.avg_pool2d(noise, kernel_size=2)
+ loss += reg_loss * regularize_noise_weight
+ # We print in the same line (avoid cluttering the commandline)
+ n_digits = int(np.log10(num_steps)) + 1 if num_steps > 0 else 1
+ message = f'step {step + 1:{n_digits}d}/{num_steps}: percept loss {percept_error.item():.7e} | ' \
+ f'pixel mse {mse_error.item():.7e} | ssim {ssim_loss.item():.7e} | loss {loss.item():.7e}'
+ print(message, end='\r')
+
+ last_status = {'percept_error': percept_error.item(),
+ 'pixel_mse': mse_error.item(),
+ 'ssim': ssim_loss.item(),
+ 'loss': loss.item()}
+
+ elif loss_paper == 'discriminator':
+ synth_features = disc.get_layers_features(synth_images, layers, normed=normed, sqrt_normed=sqrt_normed)
+ percept_error = sum(map(lambda x, y: mse(x, y), target_features, synth_features))
+
+ # Also uncomment to add the LPIPS loss to the perception error (to-be better fine-tuned)
+ # percept_error += 1e1 * (lpips_target_features - lpips_synth_features).square().sum()
+
+ # Pixel-level MSE
+ mse_error = mse(synth_images, target) / (G.img_channels * G.img_resolution * G.img_resolution)
+ ssim_loss = ssim_out(target, synth_images) # tracking SSIM (can also be added the total loss)
+ loss = percept_error + mse_error # + 1e-2 * (1 - ssim_loss) # needs to be fine-tuned
+
+ # Noise regularization.
+ reg_loss = 0.0
+ for v in noise_buffs.values():
+ noise = v[None, None, :, :] # must be [1,1,H,W] for F.avg_pool2d()
+ while True:
+ reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2
+ reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2
+ if noise.shape[2] <= 8:
+ break
+ noise = F.avg_pool2d(noise, kernel_size=2)
+ loss += reg_loss * regularize_noise_weight
+ # We print in the same line (avoid cluttering the commandline)
+ n_digits = int(np.log10(num_steps)) + 1 if num_steps > 0 else 1
+ message = f'step {step + 1:{n_digits}d}/{num_steps}: percept loss {percept_error.item():.7e} | ' \
+ f'pixel mse {mse_error.item():.7e} | ssim {ssim_loss.item():.7e} | loss {loss.item():.7e}'
+ print(message, end='\r')
+
+ last_status = {'percept_error': percept_error.item(),
+ 'pixel_mse': mse_error.item(),
+ 'ssim': ssim_loss.item(),
+ 'loss': loss.item()}
+
+ elif loss_paper == 'clip':
+
+ import torchvision.transforms as T
+ synth_img = F.interpolate(synth_images, size=(224, 224), mode='area')
+ prep = T.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
+ synth_img = prep(synth_img)
+ # synth_images = synth_images.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8).cpu().numpy()[0] # NCWH => WHC
+ # synth_images = preprocess(PIL.Image.fromarray(synth_images, 'RGB')).unsqueeze(0).to(device)
+ synth_features = model.encode_image(synth_img)
+ dist = mse(target_features, synth_features)
+
+ # Noise regularization.
+ reg_loss = 0.0
+ for v in noise_buffs.values():
+ noise = v[None, None, :, :] # must be [1,1,H,W] for F.avg_pool2d()
+ while True:
+ reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2
+ reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2
+ if noise.shape[2] <= 8:
+ break
+ noise = F.avg_pool2d(noise, kernel_size=2)
+ loss = dist + reg_loss * regularize_noise_weight
+ # Print in the same line (avoid cluttering the commandline)
+ n_digits = int(np.log10(num_steps)) + 1 if num_steps > 0 else 1
+ message = f'step {step + 1:{n_digits}d}/{num_steps}: dist {dist:.7e}'
+ print(message, end='\r')
+
+ last_status = {'dist': dist.item(), 'loss': loss.item()}
+
+ # Step
+ optimizer.zero_grad(set_to_none=True)
+ loss.backward()
+ optimizer.step()
+
+ # Save projected W for each optimization step.
+ w_out[step] = w_opt.detach()[0]
+
+ # Normalize noise.
+ with torch.no_grad():
+ for buf in noise_buffs.values():
+ buf -= buf.mean()
+ buf *= buf.square().mean().rsqrt()
+
+ # Save run config
+ run_config = {
+ 'optimization_options': {
+ 'num_steps': num_steps,
+ 'initial_learning_rate': initial_learning_rate,
+ 'constant_learning_rate': constant_learning_rate,
+ 'regularize_noise_weight': regularize_noise_weight,
+ },
+ 'projection_options': {
+ 'w_avg_samples': w_avg_samples,
+ 'initial_noise_factor': initial_noise_factor,
+ 'lr_rampdown_length': lr_rampdown_length,
+ 'lr_rampup_length': lr_rampup_length,
+ 'noise_ramp_length': noise_ramp_length,
+ },
+ 'latent_space_options': {
+ 'project_in_wplus': project_in_wplus,
+ 'start_wavg': start_wavg,
+ 'projection_seed': projection_seed,
+ 'truncation_psi': truncation_psi,
+ },
+ 'loss_options': {
+ 'loss_paper': loss_paper,
+ 'vgg16_normed': normed,
+ 'vgg16_sqrt_normed': sqrt_normed,
+ },
+ 'elapsed_time': '',
+ 'last_commandline_status': last_status
+ }
+
+ if project_in_wplus:
+ return w_out, run_config # [num_steps, L, C]
+ return w_out.repeat([1, G.mapping.num_ws, 1]), run_config # [num_steps, 1, C] => [num_steps, L, C]
+
+
+# ----------------------------------------------------------------------------
+
+
+@click.command()
+@click.pass_context
+@click.option('--network', '-net', 'network_pkl', help='Network pickle filename', required=True)
+@click.option('--cfg', help='Config of the network, used only if you want to use one of the models that are in torch_utils.gen_utils.resume_specs', type=click.Choice(['stylegan2', 'stylegan3-t', 'stylegan3-r']))
+@click.option('--target', '-t', 'target_fname', type=click.Path(exists=True, dir_okay=False), help='Target image file to project to', required=True, metavar='FILE')
+# Optimization options
+@click.option('--num-steps', '-nsteps', help='Number of optimization steps', type=click.IntRange(min=0), default=1000, show_default=True)
+@click.option('--init-lr', '-lr', 'initial_learning_rate', type=float, help='Initial learning rate of the optimization process', default=0.1, show_default=True)
+@click.option('--constant-lr', 'constant_learning_rate', is_flag=True, help='Add flag to use a constant learning rate throughout the optimization (turn off the rampup/rampdown)')
+@click.option('--reg-noise-weight', '-regw', 'regularize_noise_weight', type=float, help='Noise weight regularization', default=1e5, show_default=True)
+@click.option('--seed', type=int, help='Random seed', default=303, show_default=True)
+@click.option('--stabilize-projection', is_flag=True, help='Add flag to stabilize the latent space/anchor to w_avg, making it easier to project (only for StyleGAN3 config-r/t models)')
+# Video options
+@click.option('--save-video', '-video', is_flag=True, help='Save an mp4 video of optimization progress')
+@click.option('--compress', is_flag=True, help='Compress video with ffmpeg-python; same resolution, lower memory size')
+@click.option('--fps', type=int, help='FPS for the mp4 video of optimization progress (if saved)', default=30, show_default=True)
+# Options on which space to project to (W or W+) and where to start: the middle point of W (w_avg) or a specific seed
+@click.option('--project-in-wplus', '-wplus', is_flag=True, help='Project in the W+ latent space')
+@click.option('--start-wavg', '-wavg', type=bool, help='Start with the average W vector, ootherwise will start from a random seed (provided by user)', default=True, show_default=True)
+@click.option('--projection-seed', type=int, help='Seed to start projection from', default=None, show_default=True)
+@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi to use in projection when using a projection seed', default=0.7, show_default=True)
+# Decide the loss to use when projecting (all other apart from o.g. StyleGAN2's are experimental, you can select the VGG16 features/layers to use in the im2sgan loss)
+@click.option('--loss-paper', '-loss', type=click.Choice(['sgan2', 'im2sgan', 'discriminator', 'clip']), help='Loss to use (if using "im2sgan", make sure to norm the VGG16 features)', default='sgan2', show_default=True)
+# im2sgan loss options (try with and without them, though I've found --vgg-normed to work best for me)
+@click.option('--vgg-normed', 'normed', is_flag=True, help='Add flag to norm the VGG16 features by the number of elements per layer that was used')
+@click.option('--vgg-sqrt-normed', 'sqrt_normed', is_flag=True, help='Add flag to norm the VGG16 features by the square root of the number of elements per layer that was used')
+# Extra parameters for saving the results
+@click.option('--save-every-step', '-saveall', is_flag=True, help='Save every step taken in the projection (save both the dlatent as a.npy and its respective image).')
+@click.option('--outdir', type=click.Path(file_okay=False), help='Directory path to save the results', default=os.path.join(os.getcwd(), 'out', 'projection'), show_default=True, metavar='DIR')
+@click.option('--description', '-desc', type=str, help='Extra description to add to the experiment name', default='')
+def run_projection(
+ ctx: click.Context,
+ network_pkl: str,
+ cfg: str,
+ target_fname: str,
+ num_steps: int,
+ initial_learning_rate: float,
+ constant_learning_rate: bool,
+ regularize_noise_weight: float,
+ seed: int,
+ stabilize_projection: bool,
+ save_video: bool,
+ compress: bool,
+ fps: int,
+ project_in_wplus: bool,
+ start_wavg: bool,
+ projection_seed: int,
+ truncation_psi: float,
+ loss_paper: str,
+ normed: bool,
+ sqrt_normed: bool,
+ save_every_step: bool,
+ outdir: str,
+ description: str,
+):
+ """Project given image to the latent space of pretrained network pickle.
+
+ Examples:
+
+ \b
+ python projector.py --target=~/mytarget.png --project-in-wplus --save-video --num-steps=5000 \\
+ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
+ """
+ torch.manual_seed(seed)
+
+ # If we're not starting from the W midpoint, assert the user fed a seed to start from
+ if not start_wavg:
+ if projection_seed is None:
+ ctx.fail('Provide a seed to start from if not starting from the midpoint. Use "--projection-seed" to do so')
+
+ # Load networks.
+ # If model name exists in the gen_utils.resume_specs dictionary, use it instead of the full url
+ try:
+ network_pkl = gen_utils.resume_specs[cfg][network_pkl]
+ except KeyError:
+ # Otherwise, it's a local file or an url
+ pass
+ print('Loading networks from "%s"...' % network_pkl)
+ device = torch.device('cuda')
+ with dnnlib.util.open_url(network_pkl) as fp:
+ G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device)
+ if loss_paper == 'discriminator':
+ # We must also load the Discriminator
+ with dnnlib.util.open_url(network_pkl) as fp:
+ D = legacy.load_network_pkl(fp)['D'].requires_grad_(False).to(device)
+
+ # Load target image.
+ target_pil = PIL.Image.open(target_fname).convert('RGB')
+ w, h = target_pil.size
+ s = min(w, h)
+ target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
+ target_pil = target_pil.resize((G.img_resolution, G.img_resolution), PIL.Image.LANCZOS)
+ target_uint8 = np.array(target_pil, dtype=np.uint8)
+
+ # Stabilize the latent space to make things easier (for StyleGAN3's config t and r models)
+ if stabilize_projection:
+ gen_utils.anchor_latent_space(G)
+
+ # Optimize projection.
+ start_time = perf_counter()
+ projected_w_steps, run_config = project(
+ G,
+ target=target_pil,
+ num_steps=num_steps,
+ initial_learning_rate=initial_learning_rate,
+ constant_learning_rate=constant_learning_rate,
+ regularize_noise_weight=regularize_noise_weight,
+ project_in_wplus=project_in_wplus,
+ start_wavg=start_wavg,
+ projection_seed=projection_seed,
+ truncation_psi=truncation_psi,
+ loss_paper=loss_paper,
+ normed=normed,
+ sqrt_normed=sqrt_normed,
+ device=device,
+ D=D if loss_paper == 'discriminator' else None
+ )
+ elapsed_time = format_time(perf_counter()-start_time)
+ print(f'\nElapsed time: {elapsed_time}')
+ run_config['elapsed_time'] = elapsed_time
+ # Make the run dir automatically
+ desc = 'projection-wplus' if project_in_wplus else 'projection-w'
+ desc = f'{desc}-wavgstart' if start_wavg else f'{desc}-seed{projection_seed}start'
+ desc = f'{desc}-{description}' if len(description) != 0 else desc
+ desc = f'{desc}-{loss_paper}'
+ run_dir = gen_utils.make_run_dir(outdir, desc)
+
+ # Save the configuration used
+ ctx.obj = {
+ 'network_pkl': network_pkl,
+ 'description': description,
+ 'target_image': target_fname,
+ 'outdir': run_dir,
+ 'save_video': save_video,
+ 'seed': seed,
+ 'video_fps': fps,
+ 'save_every_step': save_every_step,
+ 'run_config': run_config
+ }
+ # Save the run configuration
+ gen_utils.save_config(ctx=ctx, run_dir=run_dir)
+
+ # Render debug output: optional video and projected image and W vector.
+ result_name = os.path.join(run_dir, 'proj')
+ npy_name = os.path.join(run_dir, 'projected')
+ # If we project in W+, add to the name of the results
+ if project_in_wplus:
+ result_name, npy_name = f'{result_name}_wplus', f'{npy_name}_wplus'
+ # Either in W or W+, we can start from the W midpoint or one given by the projection seed
+ if start_wavg:
+ result_name, npy_name = f'{result_name}_wavg', f'{npy_name}_wavg'
+ else:
+ result_name, npy_name = f'{result_name}_seed-{projection_seed}', f'{npy_name}_seed-{projection_seed}'
+
+ # Save the target image
+ target_pil.save(os.path.join(run_dir, 'target.jpg'))
+
+ if save_every_step:
+ # Save every projected frame and W vector. TODO: This can be optimized to be saved as training progresses
+ n_digits = int(np.log10(num_steps)) + 1 if num_steps > 0 else 1
+ for step in tqdm(range(num_steps), desc='Saving projection results', unit='steps'):
+ w = projected_w_steps[step]
+ synth_image = gen_utils.w_to_img(G, dlatents=w, noise_mode='const')[0]
+ PIL.Image.fromarray(synth_image, 'RGB').save(f'{result_name}_step{step:0{n_digits}d}.jpg')
+ np.save(f'{npy_name}_step{step:0{n_digits}d}.npy', w.unsqueeze(0).cpu().numpy())
+ else:
+ # Save only the final projected frame and W vector.
+ print('Saving projection results...')
+ projected_w = projected_w_steps[-1]
+ synth_image = gen_utils.w_to_img(G, dlatents=projected_w, noise_mode='const')[0]
+ PIL.Image.fromarray(synth_image, 'RGB').save(f'{result_name}_final.jpg')
+ np.save(f'{npy_name}_final.npy', projected_w.unsqueeze(0).cpu().numpy())
+
+ # Save the optimization video and compress it if so desired
+ if save_video:
+ video = imageio.get_writer(f'{result_name}.mp4', mode='I', fps=fps, codec='libx264', bitrate='16M')
+ print(f'Saving optimization progress video "{result_name}.mp4"')
+ for projected_w in projected_w_steps:
+ synth_image = gen_utils.w_to_img(G, dlatents=projected_w, noise_mode='const')[0]
+ video.append_data(np.concatenate([target_uint8, synth_image], axis=1)) # left side target, right projection
+ video.close()
+
+ if save_video and compress:
+ # Compress the video; might fail, and is a basic command that can also be better optimized
+ gen_utils.compress_video(original_video=f'{result_name}.mp4',
+ original_video_name=f'{result_name.split(os.sep)[-1]}',
+ outdir=run_dir,
+ ctx=ctx)
+
+# ----------------------------------------------------------------------------
+
+
+if __name__ == "__main__":
+ run_projection() # pylint: disable=no-value-for-parameter
+
+
+# ----------------------------------------------------------------------------
diff --git a/stylegan3-fun/pytorch_ssim/__init__.py b/stylegan3-fun/pytorch_ssim/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..865ff65754da3efd705fd099371bcebb0044da1e
--- /dev/null
+++ b/stylegan3-fun/pytorch_ssim/__init__.py
@@ -0,0 +1,78 @@
+# Code from Evan Su/Po-Hsun-Su: https://github.com/Po-Hsun-Su/pytorch-ssim
+
+import torch
+import torch.nn.functional as F
+from torch.autograd import Variable
+from math import exp
+
+
+def gaussian(window_size, sigma):
+ gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
+ return gauss/gauss.sum()
+
+
+def create_window(window_size, channel):
+ _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
+ window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
+ return window
+
+
+def _ssim(img1, img2, window, window_size, channel, size_average=True):
+ mu1 = F.conv2d(img1, window, padding=window_size//2, groups=channel)
+ mu2 = F.conv2d(img2, window, padding=window_size//2, groups=channel)
+
+ mu1_sq = mu1.pow(2)
+ mu2_sq = mu2.pow(2)
+ mu1_mu2 = mu1*mu2
+
+ sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size//2, groups=channel) - mu1_sq
+ sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size//2, groups=channel) - mu2_sq
+ sigma12 = F.conv2d(img1 * img2, window, padding=window_size//2, groups=channel) - mu1_mu2
+
+ C1 = 0.01**2
+ C2 = 0.03**2
+
+ ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
+
+ if size_average:
+ return ssim_map.mean()
+ else:
+ return ssim_map.mean(1).mean(1).mean(1)
+
+
+class SSIM(torch.nn.Module):
+ def __init__(self, window_size = 11, size_average = True):
+ super(SSIM, self).__init__()
+ self.window_size = window_size
+ self.size_average = size_average
+ self.channel = 1
+ self.window = create_window(window_size, self.channel)
+
+ def forward(self, img1, img2):
+ (_, channel, _, _) = img1.size()
+
+ if channel == self.channel and self.window.data.type() == img1.data.type():
+ window = self.window
+ else:
+ window = create_window(self.window_size, channel)
+
+ if img1.is_cuda:
+ window = window.cuda(img1.get_device())
+ window = window.type_as(img1)
+
+ self.window = window
+ self.channel = channel
+
+ return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
+
+
+def ssim(img1, img2, window_size=11, size_average=True):
+ (_, channel, _, _) = img1.size()
+ window = create_window(window_size, channel)
+
+ if img1.is_cuda:
+ window = window.cuda(img1.get_device())
+ window = window.type_as(img1)
+
+ return _ssim(img1, img2, window, window_size, channel, size_average)
diff --git a/stylegan3-fun/sightseeding.py b/stylegan3-fun/sightseeding.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd845ab48187f789253a08ed0ccdf89cb69b9afc
--- /dev/null
+++ b/stylegan3-fun/sightseeding.py
@@ -0,0 +1,199 @@
+import os
+from typing import List, Union, Tuple
+import click
+
+import dnnlib
+import legacy
+
+import torch
+
+import numpy as np
+from torch_utils import gen_utils
+
+os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide"
+import moviepy.editor
+
+
+# ----------------------------------------------------------------------------
+
+
+@click.command()
+@click.pass_context
+@click.option('--network', '-net', 'network_pkl', help='Network pickle filename', required=True)
+@click.option('--cfg', type=click.Choice(['stylegan2', 'stylegan3-t', 'stylegan3-r']), help='Config of the network, used only if you want to use the pretrained models in torch_utils.gen_utils.resume_specs')
+# Synthesis options
+@click.option('--seeds', '-s', type=gen_utils.num_range, help='List of seeds to visit in order ("a,b,c", "a-b", "a,b-c,d,e-f,a", ...', required=True)
+@click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
+@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
+@click.option('--new-center', type=gen_utils.parse_new_center, help='New center for the W latent space; a seed (int) or a path to a projected dlatent (.npy/.npz)', default=None)
+@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
+@click.option('--seed-sec', '-sec', type=float, help='Number of seconds between each seed transition', default=5.0, show_default=True)
+@click.option('--interp-type', '-interp', type=click.Choice(['linear', 'spherical']), help='Type of interpolation in Z or W', default='spherical', show_default=True)
+@click.option('--interp-in-z', is_flag=True, help='Add flag to interpolate in Z instead of in W')
+# Video options
+@click.option('--smooth', is_flag=True, help='Add flag to smooth the transition between the latent vectors')
+@click.option('--fps', type=gen_utils.parse_fps, help='Video FPS.', default=30, show_default=True)
+@click.option('--compress', is_flag=True, help='Add flag to compress the final mp4 file via ffmpeg-python (same resolution, lower file size)')
+# Run options
+@click.option('--outdir', type=click.Path(file_okay=False), help='Directory path to save the results', default=os.path.join(os.getcwd(), 'out', 'sightseeding'), show_default=True, metavar='DIR')
+@click.option('--desc', type=str, help='Additional description for the directory name where', default='', show_default=True)
+def sightseeding(
+ ctx: click.Context,
+ network_pkl: Union[str, os.PathLike],
+ cfg: str,
+ seeds: List[int],
+ class_idx: int,
+ truncation_psi: float,
+ new_center: Tuple[str, Union[int, np.ndarray]],
+ noise_mode: str,
+ seed_sec: float,
+ interp_type: str,
+ interp_in_z: bool,
+ smooth: bool,
+ fps: int,
+ compress: bool,
+ outdir: Union[str, os.PathLike],
+ desc: str,
+):
+ """
+ Examples:
+ # Will go from seeds 0 through 5, coming to the starting one in the end; the transition between each pair of seeds
+ taking 7.5 seconds, spherically (and smoothly) interpolating in W, compressing the final video with ffmpeg-python
+ python sightseeding.py --seeds=0-5,0 --seed-sec=7.5 --smooth --compress \
+ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/afhqwild.pkl
+ """
+ # Sanity check:
+ if len(seeds) < 2:
+ ctx.fail('Please enter more than one seed to interpolate between!')
+
+ device = torch.device('cuda')
+
+ # Load the network
+ G = gen_utils.load_network('G_ema', network_pkl, cfg, device)
+
+ # Get center of the latent space (global or user-indicated)
+ if new_center is None:
+ w_avg = G.mapping.w_avg
+ else:
+ new_center, new_center_value = new_center
+ # We get the new center using the int (a seed) or recovered dlatent (an np.ndarray)
+ if isinstance(new_center_value, int):
+ w_avg = gen_utils.get_w_from_seed(G, device, new_center_value,
+ truncation_psi=1.0) # We want the pure dlatent
+ elif isinstance(new_center_value, np.ndarray):
+ w_avg = torch.from_numpy(new_center_value).to(device)
+ else:
+ ctx.fail('Error: New center has strange format! Only an int (seed) or a file (.npy/.npz) are accepted!')
+
+ # Create the run dir with the given name description
+ desc = f'{desc}-sightseeding' if len(desc) != 0 else 'sightseeding'
+ desc = f'{desc}-{interp_type}-smooth' if smooth else f'{desc}-{interp_type}'
+ desc = f'{desc}-in-Z' if interp_in_z else f'{desc}-in-W'
+ run_dir = gen_utils.make_run_dir(outdir, desc)
+
+ # Number of steps to take between each latent vector
+ n_steps = int(np.rint(seed_sec * fps))
+ # Total number of frames
+ num_frames = int(n_steps * (len(seeds) - 1))
+ # Total video length in seconds
+ duration_sec = num_frames / fps
+
+ # Labels
+ label = torch.zeros([1, G.c_dim], device=device)
+ if G.c_dim != 0:
+ if class_idx is None:
+ ctx.fail('Must specify class label with --class when using a conditional network')
+ label[:, class_idx] = 1
+ else:
+ if class_idx is not None:
+ print('warn: --class=lbl ignored when running on an unconditional network')
+
+ # Generate the random vectors from each seed
+ print('Generating Z vectors...')
+ all_z = np.stack([np.random.RandomState(seed).randn(G.z_dim).astype(np.float32) for seed in seeds])
+ # If user wants to interpolate in Z
+ if interp_in_z:
+ print(f'Interpolating in Z...(interpolation type: {interp_type})')
+ src_z = np.empty([0] + list(all_z.shape[1:]), dtype=np.float32)
+ for i in range(len(all_z) - 1):
+ # We interpolate between each pair of latents
+ interp = gen_utils.interpolate(all_z[i], all_z[i + 1], n_steps, interp_type, smooth)
+ # Append it to our source
+ src_z = np.append(src_z, interp, axis=0)
+ # Convert to dlatent vectors
+ print('Generating W vectors...')
+ src_w = G.mapping(torch.from_numpy(src_z).to(device), label)
+
+ # Otherwise, interpolation is done in W
+ else:
+ print(f'Interpolating in W... (interpolation type: {interp_type})')
+ print('Generating W vectors...')
+ all_w = G.mapping(torch.from_numpy(all_z).to(device), label).cpu()
+ src_w = np.empty([0] + list(all_w.shape[1:]), dtype=np.float32)
+ for i in range(len(all_w) - 1):
+ # We interpolate between each pair of dlatents
+ interp = gen_utils.interpolate(all_w[i], all_w[i + 1], n_steps, interp_type, smooth)
+ # Append it to our source
+ src_w = np.append(src_w, interp, axis=0)
+ src_w = torch.from_numpy(src_w).to(device)
+
+ # Do the truncation trick
+ src_w = w_avg + (src_w - w_avg) * truncation_psi
+
+ # Auxiliary function for moviepy
+ def make_frame(t):
+ frame_idx = int(np.clip(np.round(t * fps), 0, num_frames - 1))
+ w = src_w[frame_idx].unsqueeze(0) # [18, 512] -> [1, 18, 512]
+ image = gen_utils.w_to_img(G, w, noise_mode)
+ # Generate the grid for this timestamp
+ grid = gen_utils.create_image_grid(image, (1, 1))
+ # grayscale => RGB
+ if grid.shape[2] == 1:
+ grid = grid.repeat(3, 2)
+ return grid
+
+ # Generate video using make_frame
+ print('Generating sightseeding video...')
+ videoclip = moviepy.editor.VideoClip(make_frame, duration=duration_sec)
+ videoclip.set_duration(duration_sec)
+ mp4_name = '-'.join(map(str, seeds)) # Make it clear by the file name what is the path taken
+ mp4_name = f'{mp4_name}-sightseeding' if len(mp4_name) < 50 else 'sightseeding' # arbitrary rule of mine
+
+ # Set the video parameters (change if you like)
+ final_video = os.path.join(run_dir, f'{mp4_name}.mp4')
+ videoclip.write_videofile(final_video, fps=fps, codec='libx264', bitrate='16M')
+
+ # Save the configuration used for the experiment
+ ctx.obj = {
+ 'network_pkl': network_pkl,
+ 'config': cfg,
+ 'seeds': seeds,
+ 'class_idx': class_idx,
+ 'truncation_psi': truncation_psi,
+ 'noise_mode': noise_mode,
+ 'seed_sec': seed_sec,
+ 'duration_sec': duration_sec,
+ 'interp_type': interp_type,
+ 'interp_in_z': interp_in_z,
+ 'smooth_video': smooth,
+ 'video_fps': fps,
+ 'compress': compress,
+ 'run_dir': run_dir,
+ 'description': desc,
+ }
+ # Save the run configuration
+ gen_utils.save_config(ctx=ctx, run_dir=run_dir)
+
+ # Compress the video (lower file size, same resolution)
+ if compress:
+ gen_utils.compress_video(original_video=final_video, original_video_name=mp4_name, outdir=run_dir, ctx=ctx)
+
+
+# ----------------------------------------------------------------------------
+
+
+if __name__ == '__main__':
+ sightseeding()
+
+
+# ----------------------------------------------------------------------------
diff --git a/stylegan3-fun/style_mixing.py b/stylegan3-fun/style_mixing.py
new file mode 100644
index 0000000000000000000000000000000000000000..82e7b6aaea855bd10e427eca2b1cd596fa25d497
--- /dev/null
+++ b/stylegan3-fun/style_mixing.py
@@ -0,0 +1,429 @@
+import os
+from typing import List, Union, Optional, Tuple
+import click
+
+import dnnlib
+from torch_utils import gen_utils
+
+import numpy as np
+import PIL.Image
+import scipy
+import torch
+
+import legacy
+
+os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = 'hide'
+import moviepy.editor
+
+# ----------------------------------------------------------------------------
+
+
+# TODO: this is no longer true for StyleGAN3, we have 14 layers irrespective of resolution
+def parse_styles(s: str) -> List[int]:
+ """
+ Helper function for parsing style layers. s will be a comma-separated list of values, and these can be
+ either ranges ('a-b'), ints ('a', 'b', 'c', ...), or the style layer names ('coarse', 'middle', 'fine').
+
+ A combination of these can also be used. For example, if the user wishes to mix the 'coarse' and 'fine'
+ layers, then the input can be: 'coarse,fine'. If just the 'middle' and '14-17' layers are to be used,
+ then 'middle,14-17' or '14-17,middle' can be the used as input.
+
+ The repeated styles will be deleted, as these won't add anything to our final result.
+ """
+ style_layers_dict = {'coarse': list(range(0, 4)), 'middle': list(range(4, 8)), 'fine': list(range(8, 18))}
+ str_list = s.split(',')
+ nums = []
+ for el in str_list:
+ if el in style_layers_dict:
+ nums.extend(style_layers_dict[el])
+ else:
+ nums.extend(gen_utils.num_range(el, remove_repeated=True))
+ # Sanity check: delete repeating numbers and limit values between 0 and 17
+ nums = list(set([max(min(x, 17), 0) for x in nums]))
+ return nums
+
+
+# TODO: For StyleGAN3, there's only 'coarse' and 'fine' groups, though the boundary is not 100% clear
+def style_names(max_style: int, file_name: str, desc: str, col_styles: List[int]) -> Tuple[str, str]:
+ """
+ Add the styles if they are being used (from the StyleGAN paper)
+ to both the file name and the new directory to be created.
+ """
+ if list(range(0, 4)) == col_styles:
+ styles = 'coarse_styles'
+ elif list(range(4, 8)) == col_styles:
+ styles = 'middle_styles'
+ elif list(range(8, max_style)) == col_styles:
+ styles = 'fine_styles'
+ elif list(range(0, 8)) == col_styles:
+ styles = 'coarse+middle_styles'
+ elif list(range(4, max_style)) == col_styles:
+ styles = 'middle+fine_styles'
+ elif list(range(0, 4)) + list(range(8, max_style)) == col_styles:
+ styles = 'coarse+fine_styles'
+ else:
+ styles = 'custom_styles'
+
+ file_name = f'{file_name}-{styles}'
+ desc = f'{desc}-{styles}'
+
+ return file_name, desc
+
+
+def _parse_cols(s: str, G, device: torch.device, truncation_psi: float) -> List[torch.Tensor]:
+ """s can be a path to a npy/npz file or a seed number (int)"""
+ s = s.split(',')
+ w = torch.Tensor().to(device)
+ for el in s:
+ if os.path.isfile(el):
+ w_el = gen_utils.get_latent_from_file(el) # np.ndarray
+ w_el = torch.from_numpy(w_el).to(device) # torch.tensor
+ w = torch.cat((w_el, w))
+ else:
+ nums = gen_utils.num_range(el, remove_repeated=True)
+ for n in nums:
+ w = torch.cat((gen_utils.get_w_from_seed(G, device, n, truncation_psi), w))
+ return w
+
+
+# ----------------------------------------------------------------------------
+
+
+# We group the different types of style-mixing (grid and video) into a main function
+@click.group()
+def main():
+ pass
+
+
+# ----------------------------------------------------------------------------
+
+
+@main.command(name='grid')
+@click.pass_context
+@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
+@click.option('--cfg', type=click.Choice(['stylegan2', 'stylegan3-t', 'stylegan3-r']), help='Config of the network, used only if you want to use the pretrained models in torch_utils.gen_utils.resume_specs')
+@click.option('--device', help='Device to use for image generation; using the CPU is slower than the GPU', type=click.Choice(['cpu', 'cuda']), default='cuda', show_default=True)
+# Synthesis options
+@click.option('--row-seeds', '-rows', 'row_seeds', type=gen_utils.num_range, help='Random seeds to use for image rows', required=True)
+@click.option('--col-seeds', '-cols', 'col_seeds', type=gen_utils.num_range, help='Random seeds to use for image columns', required=True)
+@click.option('--styles', 'col_styles', type=parse_styles, help='Style layers to use; can pass "coarse", "middle", "fine", or a list or range of ints', default='0-6', show_default=True)
+@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
+@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
+@click.option('--anchor-latent-space', '-anchor', is_flag=True, help='Anchor the latent space to w_avg to stabilize the video')
+# Extra parameters for saving the results
+@click.option('--outdir', type=click.Path(file_okay=False), help='Directory path to save the results', default=os.path.join(os.getcwd(), 'out', 'images'), show_default=True, metavar='DIR')
+@click.option('--description', '-desc', type=str, help='Description name for the directory path to save results', default='', show_default=True)
+def generate_style_mix(
+ ctx: click.Context,
+ network_pkl: str,
+ cfg: Optional[str],
+ device: Optional[str],
+ row_seeds: List[int],
+ col_seeds: List[int],
+ col_styles: List[int],
+ truncation_psi: float,
+ noise_mode: str,
+ anchor_latent_space: bool,
+ outdir: str,
+ description: str,
+):
+ """Generate style-mixing images using pretrained network pickle.
+
+ Examples:
+
+ \b
+ python style_mixing.py grid --rows=85,100,75,458,1500 --cols=55,821,1789,293 \\
+ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
+ """
+ # TODO: add class_idx
+ device = torch.device('cuda') if torch.cuda.is_available() and device == 'cuda' else torch.device('cpu')
+
+ # Load the network
+ G = gen_utils.load_network('G_ema', network_pkl, cfg, device)
+
+ # Setup for using CPU
+ if device.type == 'cpu':
+ gen_utils.use_cpu(G)
+
+ # Stabilize/anchor the latent space
+ if anchor_latent_space:
+ gen_utils.anchor_latent_space(G)
+
+ # Sanity check: loaded model and selected styles must be compatible
+ max_style = G.mapping.num_ws
+ if max(col_styles) > max_style:
+ click.secho(f'WARNING: Maximum col-style allowed: {max_style - 1} for loaded network "{network_pkl}" '
+ f'of resolution {G.img_resolution}x{G.img_resolution}', fg='red')
+ click.secho('Removing col-styles exceeding this value...', fg='blue')
+ col_styles[:] = [style for style in col_styles if style < max_style]
+
+ print('Generating W vectors...')
+ all_seeds = list(set(row_seeds + col_seeds)) # TODO: change this in order to use _parse_cols
+ all_z = np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds])
+ all_w = G.mapping(torch.from_numpy(all_z).to(device), None)
+ w_avg = G.mapping.w_avg
+ all_w = w_avg + (all_w - w_avg) * truncation_psi
+ w_dict = {seed: w for seed, w in zip(all_seeds, list(all_w))}
+
+ print('Generating images...')
+ all_images = gen_utils.w_to_img(G, all_w, noise_mode)
+ image_dict = {(seed, seed): image for seed, image in zip(all_seeds, list(all_images))}
+
+ print('Generating style-mixed images...')
+ for row_seed in row_seeds:
+ for col_seed in col_seeds:
+ w = w_dict[row_seed].clone()
+ w[col_styles] = w_dict[col_seed][col_styles]
+ image = gen_utils.w_to_img(G, w, noise_mode)[0]
+ image_dict[(row_seed, col_seed)] = image
+
+ # Name of grid and run directory
+ grid_name = 'grid'
+ description = 'stylemix-grid' if len(description) == 0 else description
+ # Add to the name the styles (from the StyleGAN paper) if they are being used
+ grid_name, description = style_names(max_style, grid_name, description, col_styles)
+ # Create the run dir with the given name description
+ run_dir = gen_utils.make_run_dir(outdir, description)
+
+ print('Saving image grid...')
+ W = G.img_resolution
+ H = G.img_resolution
+ canvas = PIL.Image.new(gen_utils.channels_dict[G.synthesis.img_channels], # Handle RGBA case
+ (W * (len(col_seeds) + 1), H * (len(row_seeds) + 1)), 'black')
+ for row_idx, row_seed in enumerate([0] + row_seeds):
+ for col_idx, col_seed in enumerate([0] + col_seeds):
+ if row_idx == 0 and col_idx == 0:
+ continue
+ key = (row_seed, col_seed)
+ if row_idx == 0:
+ key = (col_seed, col_seed)
+ if col_idx == 0:
+ key = (row_seed, row_seed)
+ canvas.paste(PIL.Image.fromarray(image_dict[key],
+ gen_utils.channels_dict[G.synthesis.img_channels]),
+ (W * col_idx, H * row_idx))
+ canvas.save(os.path.join(run_dir, f'{grid_name}.png'))
+
+ print('Saving individual images...')
+ for (row_seed, col_seed), image in image_dict.items():
+ PIL.Image.fromarray(image,
+ gen_utils.channels_dict[G.synthesis.img_channels]).save(os.path.join(run_dir, f'{row_seed}-{col_seed}.png'))
+
+ # Save the configuration used
+ ctx.obj = {
+ 'network_pkl': network_pkl,
+ 'row_seeds': row_seeds,
+ 'col_seeds': col_seeds,
+ 'col_styles': col_styles,
+ 'truncation_psi': truncation_psi,
+ 'noise_mode': noise_mode,
+ 'run_dir': run_dir,
+ 'description': description,
+ }
+ gen_utils.save_config(ctx=ctx, run_dir=run_dir)
+
+
+# ----------------------------------------------------------------------------
+
+
+@main.command(name='video')
+@click.pass_context
+@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
+@click.option('--cfg', type=click.Choice(['stylegan2', 'stylegan3-t', 'stylegan3-r']), help='Config of the network, used only if you want to use the pretrained models in torch_utils.gen_utils.resume_specs')
+# Synthesis options
+@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
+@click.option('--noise-mode', type=click.Choice(['const', 'random', 'none']), help='Noise mode', default='const', show_default=True)
+@click.option('--anchor-latent-space', '-anchor', is_flag=True, help='Anchor the latent space to w_avg to stabilize the video')
+@click.option('--row-seed', '-row', 'row_seed', type=int, help='Random seed to use for video row', required=True)
+@click.option('--columns', '-cols', 'columns', type=str, help='Path to dlatents (.npy/.npz) or seeds to use ("a", "b-c", "e,f-g,h,i", etc.), or a combination of both', required=True)
+@click.option('--styles', 'col_styles', type=parse_styles, help='Style layers to use; can pass "coarse", "middle", "fine", or a list or range of ints', default='0-6', show_default=True)
+@click.option('--only-stylemix', is_flag=True, help='Add flag to only show the style-mixed images in the video')
+# Video options
+@click.option('--compress', is_flag=True, help='Add flag to compress the final mp4 file via ffmpeg-python (same resolution, lower file size)')
+@click.option('--duration-sec', type=float, help='Duration of the video in seconds', default=30, show_default=True)
+@click.option('--fps', type=click.IntRange(min=1), help='Video FPS.', default=30, show_default=True)
+# Extra parameters for saving the results
+@click.option('--outdir', type=click.Path(file_okay=False), help='Directory path to save the results', default=os.path.join(os.getcwd(), 'out', 'video'), show_default=True, metavar='DIR')
+@click.option('--description', '-desc', type=str, help='Description name for the directory path to save results', default='', show_default=True)
+def random_stylemix_video(
+ ctx: click.Context,
+ network_pkl: str,
+ cfg: Optional[str],
+ row_seed: int,
+ columns: str,
+ col_styles: List[int],
+ only_stylemix: bool,
+ compress: bool,
+ truncation_psi: float,
+ noise_mode: str,
+ anchor_latent_space: bool,
+ fps: int,
+ duration_sec: float,
+ outdir: Union[str, os.PathLike],
+ description: str,
+ smoothing_sec: Optional[float] = 3.0 # for Gaussian blur; won't be a parameter, change at own risk
+):
+ """Generate random style-mixing video using pretrained network pickle.
+
+ Examples:
+
+ \b
+ python style_mixing.py video --row=85 --cols=55,821,1789 --fps=60 \\
+ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
+
+ \b
+ python style_mixing.py video --row=0 --cols=7-10 --styles=fine,1,3,5-7 --duration-sec=60 \\
+ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
+ """
+ # TODO: add class_idx
+ # Calculate number of frames
+ num_frames = int(np.rint(duration_sec * fps))
+
+ device = torch.device('cuda')
+
+ # Load the network
+ G = gen_utils.load_network('G_ema', network_pkl, cfg, device)
+
+ # Stabilize/anchor the latent space
+ if anchor_latent_space:
+ gen_utils.anchor_latent_space(G)
+
+ # Get the average dlatent
+ w_avg = G.mapping.w_avg
+
+ # Sanity check: loaded model and selected styles must be compatible
+ max_style = G.mapping.num_ws
+ if max(col_styles) > max_style:
+ click.secho(f'WARNING: Maximum col-style allowed: {max_style - 1} for loaded network "{network_pkl}" '
+ f'of resolution {G.img_resolution}x{G.img_resolution}', fg='red')
+ click.secho('Removing col-styles exceeding this value...', fg='blue')
+ col_styles[:] = [style for style in col_styles if style < max_style]
+
+ # First column (video) latents
+ print('Generating source W vectors...')
+ src_shape = [num_frames, G.z_dim]
+ src_z = np.random.RandomState(row_seed).randn(*src_shape).astype(np.float32)
+ src_z = scipy.ndimage.gaussian_filter(src_z, sigma=[smoothing_sec * fps, 0], mode='wrap') # wrap to form a loop
+ src_z /= np.sqrt(np.mean(np.square(src_z))) # normalize
+
+ # Map to W and do truncation trick
+ src_w = G.mapping(torch.from_numpy(src_z).to(device), None)
+ src_w = w_avg + (src_w - w_avg) * truncation_psi
+
+ # First row (images) latents
+ dst_w = _parse_cols(columns, G, device, truncation_psi)
+ # dst_z = np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in col_seeds])
+ # dst_w = G.mapping(torch.from_numpy(dst_z).to(device), None)
+ # dst_w = w_avg + (dst_w - w_avg) * truncation_psi
+
+ # Width and height of the generated image
+ W = G.img_resolution
+ H = G.img_resolution
+
+ # Video name
+ mp4_name = f'{len(dst_w)}x1'
+ # Run dir name
+ description = 'stylemix-video' if len(description) == 0 else description
+ # Add to the name the styles (from the StyleGAN paper) if they are being used to both file and run dir
+ mp4_name, description = style_names(max_style, mp4_name, description, col_styles)
+ # Create the run dir with the description
+ run_dir = gen_utils.make_run_dir(outdir, description)
+
+ # If user wishes to only show the style-transferred images (nice for 1x1 case)
+ if only_stylemix:
+ print('Generating style-mixing video (saving only the style-transferred images)...')
+ # We generate a canvas where we will paste all the generated images
+ canvas = PIL.Image.new('RGB', (W * len(dst_w), H * len([row_seed])), 'black') # use any color you want
+
+ def make_frame(t):
+ # Get the frame number according to time t
+ frame_idx = int(np.clip(np.round(t * fps), 0, num_frames - 1))
+ # For each of the column images
+ for col, _ in enumerate(dst_w):
+ # Select the pertinent latent w column
+ w_col = dst_w[col].unsqueeze(0) # [18, 512] -> [1, 18, 512]
+ # Replace the values defined by col_styles
+ w_col[:, col_styles] = src_w[frame_idx, col_styles]
+ # Generate the style-mixed images
+ col_images = gen_utils.w_to_img(G, w_col, noise_mode)
+ # Paste them in their respective spot in the grid
+ for row, image in enumerate(list(col_images)):
+ canvas.paste(PIL.Image.fromarray(image, 'RGB'), (col * H, row * W))
+
+ return np.array(canvas)
+
+ mp4_name = f'{mp4_name}-only-stylemix'
+ else:
+ print('Generating style-mixing video (saving the whole grid)...')
+ # Generate an empty canvas where we will paste all the generated images
+ canvas = PIL.Image.new('RGB', (W * (len(dst_w) + 1), H * (len([row_seed]) + 1)), 'black')
+
+ # Generate all destination images (first row; static images)
+ dst_images = gen_utils.w_to_img(G, dst_w, noise_mode)
+ # Paste them in the canvas
+ for col, dst_image in enumerate(list(dst_images)):
+ canvas.paste(PIL.Image.fromarray(dst_image, 'RGB'), ((col + 1) * H, 0))
+
+ def make_frame(t):
+ # Get the frame number according to time t
+ frame_idx = int(np.clip(np.round(t * fps), 0, num_frames - 1))
+ # Get the image at this frame (first column; video)
+ src_image = gen_utils.w_to_img(G, src_w[frame_idx], noise_mode)[0]
+ # Paste it to the lower left
+ canvas.paste(PIL.Image.fromarray(src_image, 'RGB'), (0, H))
+
+ # For each of the column images (destination images)
+ for col, _ in enumerate(list(dst_images)):
+ # Select pertinent latent w column
+ w_col = dst_w[col].unsqueeze(0) # [18, 512] -> [1, 18, 512]
+ # Replace the values defined by col_styles
+ w_col[:, col_styles] = src_w[frame_idx, col_styles]
+ # Generate these style-mixed images
+ col_images = gen_utils.w_to_img(G, w_col, noise_mode)
+ # Paste them in their respective spot in the grid
+ for row, image in enumerate(list(col_images)):
+ canvas.paste(PIL.Image.fromarray(image, 'RGB'), ((col + 1) * H, (row + 1) * W))
+
+ return np.array(canvas)
+
+ mp4_name = f'{mp4_name}-style-mixing'
+
+ # Generate video using the respective make_frame function
+ videoclip = moviepy.editor.VideoClip(make_frame, duration=duration_sec)
+ videoclip.set_duration(duration_sec)
+
+ # Change the video parameters (codec, bitrate) if you so desire
+ final_video = os.path.join(run_dir, f'{mp4_name}.mp4')
+ videoclip.write_videofile(final_video, fps=fps, codec='libx264', bitrate='16M')
+
+ # Save the configuration used for the experiment
+ ctx.obj = {
+ 'network_pkl': network_pkl,
+ 'row_seed': row_seed,
+ 'columns': columns,
+ 'col_styles': col_styles,
+ 'only_stylemix': only_stylemix,
+ 'compress': compress,
+ 'truncation_psi': truncation_psi,
+ 'noise_mode': noise_mode,
+ 'duration_sec': duration_sec,
+ 'video_fps': fps,
+ 'run_dir': run_dir,
+ 'description': description,
+ }
+ # Save the run configuration
+ gen_utils.save_config(ctx=ctx, run_dir=run_dir)
+
+ # Compress the video (smaller file size, same resolution; not guaranteed though)
+ if compress:
+ gen_utils.compress_video(original_video=final_video, original_video_name=mp4_name, outdir=run_dir, ctx=ctx)
+
+
+# ----------------------------------------------------------------------------
+
+
+if __name__ == "__main__":
+ main() # pylint: disable=no-value-for-parameter
+
+
+# ----------------------------------------------------------------------------
diff --git a/stylegan3-fun/torch_utils/__init__.py b/stylegan3-fun/torch_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..939e7c6c8f94c4ea1141885c3c3295fe083b06aa
--- /dev/null
+++ b/stylegan3-fun/torch_utils/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+# empty
diff --git a/stylegan3-fun/torch_utils/__pycache__/__init__.cpython-311.pyc b/stylegan3-fun/torch_utils/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7e57162ae2ef506cbdd0cefb13d7fd35c12d0d7f
Binary files /dev/null and b/stylegan3-fun/torch_utils/__pycache__/__init__.cpython-311.pyc differ
diff --git a/stylegan3-fun/torch_utils/__pycache__/__init__.cpython-38.pyc b/stylegan3-fun/torch_utils/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d2abeb3ff166cf2c5895ceb9b235a1dab5a4d55c
Binary files /dev/null and b/stylegan3-fun/torch_utils/__pycache__/__init__.cpython-38.pyc differ
diff --git a/stylegan3-fun/torch_utils/__pycache__/__init__.cpython-39.pyc b/stylegan3-fun/torch_utils/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..85bed4d85963a4e4657624348b99d9c3cde0ee7c
Binary files /dev/null and b/stylegan3-fun/torch_utils/__pycache__/__init__.cpython-39.pyc differ
diff --git a/stylegan3-fun/torch_utils/__pycache__/custom_ops.cpython-311.pyc b/stylegan3-fun/torch_utils/__pycache__/custom_ops.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ea303bb3f7fab9265f8af0864382ad8fbdf16a7d
Binary files /dev/null and b/stylegan3-fun/torch_utils/__pycache__/custom_ops.cpython-311.pyc differ
diff --git a/stylegan3-fun/torch_utils/__pycache__/custom_ops.cpython-38.pyc b/stylegan3-fun/torch_utils/__pycache__/custom_ops.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4dcb01ff4a635119795b3a47e9305eaeb07b5bfe
Binary files /dev/null and b/stylegan3-fun/torch_utils/__pycache__/custom_ops.cpython-38.pyc differ
diff --git a/stylegan3-fun/torch_utils/__pycache__/custom_ops.cpython-39.pyc b/stylegan3-fun/torch_utils/__pycache__/custom_ops.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..32df4fab25c6933a0b634218226893a9f3bfcf89
Binary files /dev/null and b/stylegan3-fun/torch_utils/__pycache__/custom_ops.cpython-39.pyc differ
diff --git a/stylegan3-fun/torch_utils/__pycache__/gen_utils.cpython-311.pyc b/stylegan3-fun/torch_utils/__pycache__/gen_utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cf7a3d96ada638a9744ad7c4d48b4a3d0a5d7dc7
Binary files /dev/null and b/stylegan3-fun/torch_utils/__pycache__/gen_utils.cpython-311.pyc differ
diff --git a/stylegan3-fun/torch_utils/__pycache__/gen_utils.cpython-38.pyc b/stylegan3-fun/torch_utils/__pycache__/gen_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..350925607ed402b945fa1f3488c3e5d67fefef8d
Binary files /dev/null and b/stylegan3-fun/torch_utils/__pycache__/gen_utils.cpython-38.pyc differ
diff --git a/stylegan3-fun/torch_utils/__pycache__/gen_utils.cpython-39.pyc b/stylegan3-fun/torch_utils/__pycache__/gen_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a1d2ef2c7c10229b4ad784adcbae81ec00368bfe
Binary files /dev/null and b/stylegan3-fun/torch_utils/__pycache__/gen_utils.cpython-39.pyc differ
diff --git a/stylegan3-fun/torch_utils/__pycache__/misc.cpython-311.pyc b/stylegan3-fun/torch_utils/__pycache__/misc.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e4696ff22dfb7e5b08db8d10ca729bfd54acae0c
Binary files /dev/null and b/stylegan3-fun/torch_utils/__pycache__/misc.cpython-311.pyc differ
diff --git a/stylegan3-fun/torch_utils/__pycache__/misc.cpython-38.pyc b/stylegan3-fun/torch_utils/__pycache__/misc.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..687f33cfb5cc5ad57561494dc6bf3310166384cf
Binary files /dev/null and b/stylegan3-fun/torch_utils/__pycache__/misc.cpython-38.pyc differ
diff --git a/stylegan3-fun/torch_utils/__pycache__/misc.cpython-39.pyc b/stylegan3-fun/torch_utils/__pycache__/misc.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f29e425f174973b56ad990acbc9ca4159c32a7bb
Binary files /dev/null and b/stylegan3-fun/torch_utils/__pycache__/misc.cpython-39.pyc differ
diff --git a/stylegan3-fun/torch_utils/__pycache__/persistence.cpython-311.pyc b/stylegan3-fun/torch_utils/__pycache__/persistence.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a82cbf74e981d77f82b45ed6558ce96debefdad1
Binary files /dev/null and b/stylegan3-fun/torch_utils/__pycache__/persistence.cpython-311.pyc differ
diff --git a/stylegan3-fun/torch_utils/__pycache__/persistence.cpython-38.pyc b/stylegan3-fun/torch_utils/__pycache__/persistence.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1817d1f5df03758ffc7e5fbb970c7e9ce9eda2c9
Binary files /dev/null and b/stylegan3-fun/torch_utils/__pycache__/persistence.cpython-38.pyc differ
diff --git a/stylegan3-fun/torch_utils/__pycache__/persistence.cpython-39.pyc b/stylegan3-fun/torch_utils/__pycache__/persistence.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..695c7db07b74c815789e4f657488ff2e8d74e9dd
Binary files /dev/null and b/stylegan3-fun/torch_utils/__pycache__/persistence.cpython-39.pyc differ
diff --git a/stylegan3-fun/torch_utils/custom_ops.py b/stylegan3-fun/torch_utils/custom_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..439e445b16da7ac985f7a1f2053e665385d47e87
--- /dev/null
+++ b/stylegan3-fun/torch_utils/custom_ops.py
@@ -0,0 +1,157 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import glob
+import hashlib
+import importlib
+import os
+import re
+import shutil
+import uuid
+
+import torch
+import torch.utils.cpp_extension
+from torch.utils.file_baton import FileBaton
+
+#----------------------------------------------------------------------------
+# Global options.
+
+verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
+
+#----------------------------------------------------------------------------
+# Internal helper funcs.
+
+def _find_compiler_bindir():
+ patterns = [
+ 'C:/Program Files*/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
+ 'C:/Program Files*/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
+ 'C:/Program Files*/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
+ 'C:/Program Files*/Microsoft Visual Studio */vc/bin',
+ ]
+ for pattern in patterns:
+ matches = sorted(glob.glob(pattern))
+ if len(matches):
+ return matches[-1]
+ return None
+
+#----------------------------------------------------------------------------
+
+def _get_mangled_gpu_name():
+ name = torch.cuda.get_device_name().lower()
+ out = []
+ for c in name:
+ if re.match('[a-z0-9_-]+', c):
+ out.append(c)
+ else:
+ out.append('-')
+ return ''.join(out)
+
+#----------------------------------------------------------------------------
+# Main entry point for compiling and loading C++/CUDA plugins.
+
+_cached_plugins = dict()
+
+def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs):
+ assert verbosity in ['none', 'brief', 'full']
+ if headers is None:
+ headers = []
+ if source_dir is not None:
+ sources = [os.path.join(source_dir, fname) for fname in sources]
+ headers = [os.path.join(source_dir, fname) for fname in headers]
+
+ # Already cached?
+ if module_name in _cached_plugins:
+ return _cached_plugins[module_name]
+
+ # Print status.
+ if verbosity == 'full':
+ print(f'Setting up PyTorch plugin "{module_name}"...')
+ elif verbosity == 'brief':
+ print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
+ verbose_build = (verbosity == 'full')
+
+ # Compile and load.
+ try: # pylint: disable=too-many-nested-blocks
+ # Make sure we can find the necessary compiler binaries.
+ if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
+ compiler_bindir = _find_compiler_bindir()
+ if compiler_bindir is None:
+ raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
+ os.environ['PATH'] += ';' + compiler_bindir
+
+ # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
+ # break the build or unnecessarily restrict what's available to nvcc.
+ # Unset it to let nvcc decide based on what's available on the
+ # machine.
+ os.environ['TORCH_CUDA_ARCH_LIST'] = ''
+
+ # Incremental build md5sum trickery. Copies all the input source files
+ # into a cached build directory under a combined md5 digest of the input
+ # source files. Copying is done only if the combined digest has changed.
+ # This keeps input file timestamps and filenames the same as in previous
+ # extension builds, allowing for fast incremental rebuilds.
+ #
+ # This optimization is done only in case all the source files reside in
+ # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
+ # environment variable is set (we take this as a signal that the user
+ # actually cares about this.)
+ #
+ # EDIT: We now do it regardless of TORCH_EXTENSIOS_DIR, in order to work
+ # around the *.cu dependency bug in ninja config.
+ #
+ all_source_files = sorted(sources + headers)
+ all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files)
+ if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ):
+
+ # Compute combined hash digest for all source files.
+ hash_md5 = hashlib.md5()
+ for src in all_source_files:
+ with open(src, 'rb') as f:
+ hash_md5.update(f.read())
+
+ # Select cached build directory name.
+ source_digest = hash_md5.hexdigest()
+ build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
+ cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')
+
+ if not os.path.isdir(cached_build_dir):
+ tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
+ os.makedirs(tmpdir)
+ for src in all_source_files:
+ shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src)))
+ try:
+ os.replace(tmpdir, cached_build_dir) # atomic
+ except OSError:
+ # source directory already exists, delete tmpdir and its contents.
+ shutil.rmtree(tmpdir)
+ if not os.path.isdir(cached_build_dir): raise
+
+ # Compile.
+ cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources]
+ torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir,
+ verbose=verbose_build, sources=cached_sources, **build_kwargs)
+ else:
+ torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
+
+ # Load.
+ module = importlib.import_module(module_name)
+
+ except:
+ if verbosity == 'brief':
+ print('Failed!')
+ raise
+
+ # Print status and add to cache dict.
+ if verbosity == 'full':
+ print(f'Done setting up PyTorch plugin "{module_name}".')
+ elif verbosity == 'brief':
+ print('Done.')
+ _cached_plugins[module_name] = module
+ return module
+
+#----------------------------------------------------------------------------
diff --git a/stylegan3-fun/torch_utils/gen_utils.py b/stylegan3-fun/torch_utils/gen_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..83f8bbaefeae2ac67defacd97401f0d50662395a
--- /dev/null
+++ b/stylegan3-fun/torch_utils/gen_utils.py
@@ -0,0 +1,729 @@
+import os
+import re
+import json
+
+from typing import List, Tuple, Union, Optional, Type
+from collections import OrderedDict
+from locale import atof
+
+import click
+import numpy as np
+import torch
+
+import dnnlib
+import legacy
+
+
+# ----------------------------------------------------------------------------
+
+
+channels_dict = {1: 'L', 3: 'RGB', 4: 'RGBA'}
+
+
+# ----------------------------------------------------------------------------
+
+
+available_cfgs = ['stylegan2', 'stylegan2-ext', 'stylegan3-t', 'stylegan3-r']
+
+
+# ----------------------------------------------------------------------------
+
+
+def create_image_grid(images: np.ndarray, grid_size: Optional[Tuple[int, int]] = None):
+ """
+ Create a grid with the fed images
+ Args:
+ images (np.array): array of images
+ grid_size (tuple(int)): size of grid (grid_width, grid_height)
+ Returns:
+ grid (np.array): image grid of size grid_size
+ """
+ # Sanity check
+ assert images.ndim == 3 or images.ndim == 4, f'Images has {images.ndim} dimensions (shape: {images.shape})!'
+ num, img_h, img_w, c = images.shape
+ # If user specifies the grid shape, use it
+ if grid_size is not None:
+ grid_w, grid_h = tuple(grid_size)
+ # If one of the sides is None, then we must infer it (this was divine inspiration)
+ if grid_w is None:
+ grid_w = num // grid_h + min(num % grid_h, 1)
+ elif grid_h is None:
+ grid_h = num // grid_w + min(num % grid_w, 1)
+
+ # Otherwise, we can infer it by the number of images (priority is given to grid_w)
+ else:
+ grid_w = max(int(np.ceil(np.sqrt(num))), 1)
+ grid_h = max((num - 1) // grid_w + 1, 1)
+
+ # Sanity check
+ assert grid_w * grid_h >= num, 'Number of rows and columns in the grid must be greater than the number of images!'
+ # Get the grid
+ grid = np.zeros([grid_h * img_h, grid_w * img_h] + list(images.shape[-1:]), dtype=images.dtype)
+ # Paste each image in the grid
+ for idx in range(num):
+ x = (idx % grid_w) * img_w
+ y = (idx // grid_w) * img_h
+ grid[y:y + img_h, x:x + img_w, ...] = images[idx]
+ return grid
+
+
+# ----------------------------------------------------------------------------
+
+
+def parse_fps(fps: Union[str, int]) -> int:
+ """Return FPS for the video; at worst, video will be 1 FPS, but no lower.
+ Useful if we don't have click, else simply use click.IntRange(min=1)"""
+ if isinstance(fps, int):
+ return max(fps, 1)
+ try:
+ fps = int(atof(fps))
+ return max(fps, 1)
+ except ValueError:
+ print(f'Typo in "--fps={fps}", will use default value of 30')
+ return 30
+
+
+def num_range(s: str, remove_repeated: bool = False) -> List[int]:
+ """
+ Extended helper function from the original (original is contained here).
+ Accept a comma separated list of numbers 'a,b,c', a range 'a-c', or a combination
+ of both 'a,b-c', 'a-b,c', 'a,b-c,d,e-f,...', and return as a list of ints.
+ """
+ nums = []
+ range_re = re.compile(r'^(\d+)-(\d+)$')
+ for el in s.split(','):
+ match = range_re.match(el)
+ if match:
+ # Sanity check 1: accept ranges 'a-b' or 'b-a', with a<=b
+ lower, upper = int(match.group(1)), int(match.group(2))
+ if lower <= upper:
+ r = list(range(lower, upper + 1))
+ else:
+ r = list(range(upper, lower + 1))
+ # We will extend nums as r is also a list
+ nums.extend(r)
+ else:
+ # It's a single number, so just append it (if it's an int)
+ try:
+ nums.append(int(atof(el)))
+ except ValueError:
+ continue # we ignore bad values
+ # Sanity check 2: delete repeating numbers by default, but keep order given by user
+ if remove_repeated:
+ nums = list(OrderedDict.fromkeys(nums))
+ return nums
+
+
+def float_list(s: str) -> List[float]:
+ """
+ Helper function for parsing a string of comma-separated floats and returning each float
+ """
+ str_list = s.split(',')
+ nums = []
+ float_re = re.compile(r'^(\d+.\d+)$')
+ for el in str_list:
+ match = float_re.match(el)
+ if match:
+ nums.append(float(match.group(1)))
+ else:
+ try:
+ nums.append(float(el))
+ except ValueError:
+ continue # Ignore bad values
+
+ return nums
+
+
+def parse_slowdown(slowdown: Union[str, int]) -> int:
+ """Function to parse the 'slowdown' parameter by the user. Will approximate to the nearest power of 2."""
+ # TODO: slowdown should be any int
+ if not isinstance(slowdown, int):
+ try:
+ slowdown = atof(slowdown)
+ except ValueError:
+ print(f'Typo in "{slowdown}"; will use default value of 1')
+ slowdown = 1
+ assert slowdown > 0, '"slowdown" cannot be negative or 0!'
+ # Let's approximate slowdown to the closest power of 2 (nothing happens if it's already a power of 2)
+ slowdown = 2**int(np.rint(np.log2(slowdown)))
+ return max(slowdown, 1) # Guard against 0.5, 0.25, ... cases
+
+
+def parse_new_center(s: str) -> Tuple[str, Union[int, Tuple[np.ndarray, Optional[str]]]]:
+ """Get a new center for the W latent space (a seed or projected dlatent; to be transformed later)"""
+ try:
+ new_center = int(s) # it's a seed
+ return s, new_center
+ except ValueError:
+ new_center = get_latent_from_file(s, return_ext=False) # it's a projected dlatent
+ return s, new_center
+
+
+def parse_all_projected_dlatents(s: str) -> List[torch.Tensor]:
+ """Get all the dlatents (.npy/.npz files) in a given directory"""
+ # Get all the files in the directory and subdirectories
+ files = [os.path.join(dp, f) for dp, dn, fn in os.walk(os.path.expanduser(s)) for f in fn]
+ # Filter only the .npy or .npz files
+ files = [f for f in files if f.endswith('.npy') or f.endswith('.npz')]
+ # Sort them by name, but only according to the last digits in the name (in case there's an error before)
+ files = sorted(files, key=lambda x: int(''.join(filter(str.isdigit, x))))
+ # Get the full path
+ # files = [os.path.join(s, f) for f in files]
+ # Get the dlatents
+ dlatents = [get_latent_from_file(f, return_ext=False) for f in files]
+
+ return dlatents
+
+
+def load_network(name: str, network_pkl: Union[str, os.PathLike], cfg: Optional[str], device: torch.device):
+ """Load and return the discriminator D from a trained network."""
+ # Define the model
+ if cfg is not None:
+ assert network_pkl in resume_specs[cfg], f'This model is not available for config {cfg}!'
+ network_pkl = resume_specs[cfg][network_pkl]
+ print(f'Loading networks from "{network_pkl}"...')
+ with dnnlib.util.open_url(network_pkl) as f:
+ net = legacy.load_network_pkl(f)[name].eval().requires_grad_(False).to(device) # type: ignore
+
+ return net
+
+
+def parse_class(G, class_idx: int, ctx: click.Context) -> Union[int, Type[None]]:
+ """Parse the class_idx and return it, if it's allowed by the conditional model G"""
+ if G.c_dim == 0:
+ # Unconditional model
+ return None
+ # Conditional model, so class must be specified by user
+ if class_idx is None:
+ ctx.fail('Must specify class label with --class when using a conditional network!')
+ if class_idx not in range(G.c_dim):
+ ctx.fail(f'Your class label can be at most {G.c_dim - 1}!')
+ print(f'Using class {class_idx} (available labels: range({G.c_dim - 1})...)')
+ return class_idx
+
+
+# ----------------------------------------------------------------------------
+
+
+def save_video_from_images(run_dir: str,
+ image_names: str,
+ video_name: str,
+ fps: int = 30,
+ reverse_video: bool = True,
+ crf: int = 20,
+ pix_fmt: str = 'yuv420p') -> None:
+ """ Save a .mp4 video from the images in the run_dir directory; the video can also be saved in reverse """
+ print('Saving video...')
+ try:
+ import ffmpeg
+ except ImportError:
+ raise ImportError('ffmpeg-python not found! Install it via "pip install ffmpeg-python"')
+ # Get the ffmpeg command for the current OS (not tested in MacOS!)
+ if os.name == 'nt':
+ ffmpeg_command = r'C:\\Ffmpeg\\bin\\ffmpeg.exe'
+ else:
+ # Get where is the ffmpeg command via `whereis ffmpeg` in the terminal
+ ffmpeg_command = os.popen('whereis ffmpeg').read().split(' ')[1:]
+ # Remove any ffprobe and ffplay commands
+ ffmpeg_command = [c for c in ffmpeg_command if 'ffprobe' not in c and 'ffplay' not in c]
+ # If there are more, just select the first one and remove the newline character
+ ffmpeg_command = ffmpeg_command[0].replace('\n', '')
+
+ stream = ffmpeg.input(os.path.join(run_dir, image_names), framerate=fps)
+ stream = ffmpeg.output(stream, os.path.join(run_dir, f'{video_name}.mp4'), crf=crf, pix_fmt=pix_fmt)
+ ffmpeg.run(stream, capture_stdout=True, capture_stderr=True, cmd=ffmpeg_command)
+
+ # Save the reversed video apart from the original one, so the user can compare both
+ if reverse_video:
+ stream = ffmpeg.input(os.path.join(run_dir, f'{video_name}.mp4'))
+ stream = stream.video.filter('reverse')
+ stream = ffmpeg.output(stream, os.path.join(run_dir, f'{video_name}_reversed.mp4'), crf=crf, pix_fmt=pix_fmt)
+ ffmpeg.run(stream, capture_stdout=True, capture_stderr=True) # ibidem
+
+
+def compress_video(
+ original_video: Union[str, os.PathLike],
+ original_video_name: Union[str, os.PathLike],
+ outdir: Union[str, os.PathLike],
+ ctx: click.Context) -> None:
+ """ Helper function to compress the original_video using ffmpeg-python. moviepy creates huge videos, so use
+ ffmpeg to 'compress' it (won't be perfect, 'compression' will depend on the video dimensions). ffmpeg
+ can also be used to e.g. resize the video, make a GIF, save all frames in the video to the outdir, etc.
+ """
+ try:
+ import ffmpeg
+ except (ModuleNotFoundError, ImportError):
+ ctx.fail('Missing ffmpeg! Install it via "pip install ffmpeg-python"')
+
+ print('Compressing the video...')
+ resized_video_name = os.path.join(outdir, f'{original_video_name}-compressed.mp4')
+ ffmpeg.input(original_video).output(resized_video_name).run(capture_stdout=True, capture_stderr=True)
+ print('Success!')
+
+
+# ----------------------------------------------------------------------------
+
+
+def interpolation_checks(
+ t: Union[float, np.ndarray],
+ v0: np.ndarray,
+ v1: np.ndarray) -> Tuple[Union[float, np.ndarray], np.ndarray, np.ndarray]:
+ """Tests for the interpolation functions"""
+ # Make sure 0.0<=t<=1.0
+ assert np.min(t) >= 0.0 and np.max(t) <= 1.0
+ # Guard against v0 and v1 not being NumPy arrays
+ if not isinstance(v0, np.ndarray):
+ v0 = np.array(v0)
+ if not isinstance(v1, np.ndarray):
+ v1 = np.array(v1)
+ # Both should have the same shape in order to interpolate between them
+ assert v0.shape == v1.shape, f'Incompatible shapes! v0: {v0.shape}, v1: {v1.shape}'
+ return t, v0, v1
+
+
+def lerp(
+ t: Union[float, np.ndarray],
+ v0: Union[float, list, tuple, np.ndarray],
+ v1: Union[float, list, tuple, np.ndarray]) -> np.ndarray:
+ """
+ Linear interpolation between v0 (starting) and v1 (final) vectors; for optimal results,
+ use t as an np.ndarray to return all results at once via broadcasting
+ """
+ t, v0, v1 = interpolation_checks(t, v0, v1)
+ v2 = (1.0 - t) * v0 + t * v1
+ return v2
+
+
+def slerp(
+ t: Union[float, np.ndarray],
+ v0: Union[float, list, tuple, np.ndarray],
+ v1: Union[float, list, tuple, np.ndarray],
+ dot_threshold: float = 0.9995) -> np.ndarray:
+ """
+ Spherical linear interpolation between v0 (starting) and v1 (final) vectors; for optimal
+ results, use t as an np.ndarray to return all results at once via broadcasting.
+
+ dot_threshold is the threshold for considering if the two vectors are collinear (not recommended to alter).
+
+ Adapted from the Python code at: https://en.wikipedia.org/wiki/Slerp (at the time, now no longer available).
+ Most likely taken from Jonathan Blow's code in C++:
+ http://number-none.com/product/Understanding%20Slerp,%20Then%20Not%20Using%20It
+ """
+ t, v0, v1 = interpolation_checks(t, v0, v1)
+ # Copy vectors to reuse them later
+ v0_copy = np.copy(v0)
+ v1_copy = np.copy(v1)
+ # Normalize the vectors to get the directions and angles
+ v0 = v0 / np.linalg.norm(v0)
+ v1 = v1 / np.linalg.norm(v1)
+ # Dot product with the normalized vectors (can't always use np.dot, so we use the definition)
+ dot = np.sum(v0 * v1)
+ # If it's ~1, vectors are ~colineal, so use lerp on the original vectors
+ if np.abs(dot) > dot_threshold:
+ return lerp(t, v0_copy, v1_copy)
+ # Stay within domain of arccos
+ dot = np.clip(dot, -1.0, 1.0)
+ # Calculate initial angle between v0 and v1
+ theta_0 = np.arccos(dot)
+ sin_theta_0 = np.sin(theta_0)
+ # Divide the angle into t steps
+ theta_t = theta_0 * t
+ sin_theta_t = np.sin(theta_t)
+ # Finish the slerp algorithm
+ s0 = np.sin(theta_0 - theta_t) / sin_theta_0
+ s1 = sin_theta_t / sin_theta_0
+ v2 = s0 * v0_copy + s1 * v1_copy
+ return v2
+
+
+def interpolate(
+ v0: Union[float, list, tuple, np.ndarray],
+ v1: Union[float, list, tuple, np.ndarray],
+ n_steps: int,
+ interp_type: str = 'spherical',
+ smooth: bool = False) -> np.ndarray:
+ """
+ Interpolation function between two vectors, v0 and v1. We will either do a 'linear' or 'spherical' interpolation,
+ taking n_steps. The steps can be 'smooth'-ed out, so that the transition between vectors isn't too drastic.
+ """
+ t_array = np.linspace(0, 1, num=n_steps, endpoint=False)
+ # TODO: have a dictionary with easing functions that contains my 'smooth' one (might be useful for someone else)
+ if smooth:
+ # Smooth out the interpolation with a polynomial of order 3 (cubic function f)
+ # Constructed f by setting f'(0) = f'(1) = 0, and f(0) = 0, f(1) = 1 => f(t) = -2t^3+3t^2 = t^2 (3-2t)
+ # NOTE: I've merely rediscovered the Smoothstep function S_1(x): https://en.wikipedia.org/wiki/Smoothstep
+ t_array = t_array ** 2 * (3 - 2 * t_array) # One line thanks to NumPy arrays
+ # TODO: this might be possible to optimize by using the fact they're numpy arrays, but haven't found a nice way yet
+ funcs_dict = {'linear': lerp, 'spherical': slerp}
+ vectors = np.array([funcs_dict[interp_type](t, v0, v1) for t in t_array], dtype=np.float32)
+ return vectors
+
+
+# ----------------------------------------------------------------------------
+
+
+def double_slowdown(latents: np.ndarray, duration: float, frames: int) -> Tuple[np.ndarray, float, int]:
+ """
+ Auxiliary function to slow down the video by 2x. We return the new latents, duration, and frames of the video
+ """
+ # Make an empty latent vector with double the amount of frames, but keep the others the same
+ z = np.empty(np.multiply(latents.shape, [2, 1, 1]), dtype=np.float32)
+ # In the even frames, populate it with the latents
+ for i in range(len(latents)):
+ z[2 * i] = latents[i]
+ # Interpolate in the odd frames
+ for i in range(1, len(z), 2):
+ # slerp between (t=0.5) even frames; for the last frame, we loop to the first one (z[0])
+ z[i] = slerp(0.5, z[i - 1], z[i + 1]) if i != len(z) - 1 else slerp(0.5, z[0], z[i - 1])
+ # TODO: we could change this to any slowdown: slerp(1/slowdown, ...), and we return z, slowdown * duration, ...
+ # Return the new latents, and the respective new duration and number of frames
+ return z, 2 * duration, 2 * frames
+
+
+def global_pulsate_psi(psi_start: float, psi_end: float, n_steps: int, frequency: float = 1.0) -> torch.Tensor:
+ """
+ Pulsate the truncation psi parameter between start and end, taking n_steps on a sinusoidal wave.
+ """
+ alpha = (psi_start + psi_end) / (psi_start - psi_end)
+ beta = 2 / (psi_start - psi_end)
+
+ total_time = 2 * np.pi # This value doesn't matter in the end
+
+ timesteps = torch.arange(0, total_time, total_time / n_steps)
+ truncation_psi = (torch.cos(frequency * timesteps) + alpha) / beta
+
+ return truncation_psi
+
+
+def wave_pulse_truncation_psi(psi_start: float,
+ psi_end: float,
+ n_steps: int,
+ grid_shape: Tuple[int, int],
+ frequency: int,
+ time: int) -> torch.Tensor: # Output shape: [num_grid_cells, 1, 1]
+ """
+ Pulsate the truncation psi parameter between start and end, taking n_steps on a sinusoidal wave on a grid
+ Note: The output shape should be [math.prod(grid_shape), 1, 1]
+ """
+ # Let's save some headaches, shall we?
+ if psi_start == psi_end:
+ import math
+ return torch.ones(math.prod(grid_shape), 1, 1) * psi_start
+
+ # We define a total time, but note it's specific to our definition of the wave below (the 2*pi in the conditions)
+ total_time = 5 * torch.pi # T
+ timesteps = torch.arange(0, total_time, total_time / n_steps)
+
+ # Envolope function
+ def envelope(time):
+ """ Envelope function that will regulate the amplitude of the wave; usage: envelope(time) * wave(time) """
+ # Example: a 1D Gabor filter
+ # gaussian = torch.exp(-(time - total_time / 2) ** 2 / 16)
+ # sinusoid = torch.exp(1j * torch.pi(time - total_time / 2) / 2)
+ # return torch.sin(time * torch.pi / total_time) / 2 + 0.5
+ return torch.tensor(1.0)
+
+ # Define the grid itself as a 2D grid where we will evaluate our wave function/psi
+ width, height = grid_shape
+ xs = torch.arange(0, 2*torch.pi, 2*torch.pi/width)
+ ys = torch.arange(0, 2*torch.pi, 2*torch.pi/height)
+ x, y = torch.meshgrid(xs, ys, indexing='xy')
+
+ # Define the wave equation (go crazy here!)
+ # In my case, I will use a sinusoidal wave with source at the upper-left corner of the grid
+ # The wave will travel radially from the source, and will be truncated at the edges of the grid with the psi_start value
+ r = torch.sqrt(x ** 2 + y ** 2)
+
+ # The wave function is defined by parts, that is, keep it constant (psi_start) before and after the wave; its
+ # general shape in 1D will be psi(x, t) = (cos(f(x-t)) + alpha) / beta, where alpha and beta are defined so as to
+ # satisfy the boundary conditions (psi(x, 0) = psi_start, psi(x, T/2) = psi_end, psi(x, T) = psi_start))
+ alpha = (psi_start + psi_end) / (psi_start - psi_end)
+ beta = 2 / (psi_start - psi_end)
+
+ def truncate(value):
+ """
+ Auxiliary function to interpolate between your start and end psi. Use to translate from "value=0" (psi_start)
+ to "value=1" (psi_end) """
+ return psi_start + value * (psi_end - psi_start)
+
+ # Define the wave function by parts, that is, keep it constant (psi_start) before and after the wave
+ truncation_psi = torch.where(torch.gt(r, timesteps[time]) | torch.lt(r, timesteps[time] - 2 * torch.pi),
+ torch.tensor(psi_start),
+ (torch.cos(frequency * (r - timesteps[time])) + alpha) / beta)
+
+ # Make sure the output is of the right shape
+ truncation_psi = truncation_psi.view(width*height, 1, 1)
+ return truncation_psi
+
+
+# ----------------------------------------------------------------------------
+
+
+def make_affine_transform(m: Union[torch.Tensor, np.ndarray] = None,
+ angle: float = 0.0,
+ translate_x: float = 0.0,
+ translate_y: float = 0.0,
+ scale_x: float = 1.0,
+ scale_y: float = 1.0,
+ shear_x: float = 0.0,
+ shear_y: float = 0.0,
+ mirror_x: bool = False,
+ mirror_y: bool = False) -> np.array:
+ """Make affine transformation with the given parameters. If none are passed, will return the identity.
+ As a guide for affine transformations: https://en.wikipedia.org/wiki/Affine_transformation"""
+ # m is the starting affine transformation matrix (e.g., G.synthesis.input.transform)
+ if m is None:
+ m = np.eye(3, dtype=np.float64)
+ elif isinstance(m, torch.Tensor):
+ m = m.cpu().numpy()
+ elif isinstance(m, np.ndarray):
+ pass
+ # Remember these are the inverse transformations!
+ # Rotation matrix
+ rotation_matrix = np.array([[np.cos(angle), np.sin(angle), 0.0],
+ [-np.sin(angle), np.cos(angle), 0.0],
+ [0.0, 0.0, 1.0]], dtype=np.float64)
+ # Translation matrix
+ translation_matrix = np.array([[1.0, 0.0, -translate_x],
+ [0.0, 1.0, -translate_y],
+ [0.0, 0.0, 1.0]], dtype=np.float64)
+ # Scale matrix (don't let it go into negative or 0)
+ scale_matrix = np.array([[1. / max(scale_x, 1e-4), 0.0, 0.0],
+ [0.0, 1. / max(scale_y, 1e-4), 0.0],
+ [0.0, 0.0, 1.0]], dtype=np.float64)
+ # Shear matrix
+ shear_matrix = np.array([[1.0, -shear_x, 0.0],
+ [-shear_y, 1.0, 0.0],
+ [0.0, 0.0, 1.0]], dtype=np.float64)
+ # Mirror/reflection in x matrix
+ xmirror_matrix = np.array([[1.0 - 2 * mirror_x, 0.0, 0.0],
+ [0.0, 1.0, 0.0],
+ [0.0, 0.0, 1.0]], dtype=np.float64)
+ # Mirror/reflection in y matrix
+ ymirror_matrix = np.array([[1.0, 0.0, 0.0],
+ [0.0, 1.0 - 2 * mirror_y, 0.0],
+ [0.0, 0.0, 1.0]], dtype=np.float64)
+
+ # Make the resulting affine transformation (note that these are non-commutative, so we *choose* this order)
+ m = m @ rotation_matrix @ translation_matrix @ scale_matrix @ shear_matrix @ xmirror_matrix @ ymirror_matrix
+ return m
+
+
+def anchor_latent_space(G) -> None:
+ # Thanks to @RiversHaveWings and @nshepperd1
+ if hasattr(G.synthesis, 'input'):
+ # Unconditional models differ by a bit
+ if G.c_dim == 0:
+ shift = G.synthesis.input.affine(G.mapping.w_avg.unsqueeze(0)).squeeze(0)
+ else:
+ shift = G.synthesis.input.affine(G.mapping.w_avg).mean(0)
+ G.synthesis.input.affine.bias.data.add_(shift)
+ G.synthesis.input.affine.weight.data.zero_()
+
+
+def force_fp32(G) -> None:
+ """Force fp32 as in during training"""
+ G.synthesis.num_fp16_res = 0
+ for name, layer in G.synthesis.named_modules():
+ if hasattr(layer, 'conv_clamp'):
+ layer.conv_clamp = None
+ layer.use_fp16 = False
+
+
+def use_cpu(G) -> None:
+ """Use the CPU instead of the GPU; force_fp32 must be set to True, apart from the device setting"""
+ # @nurpax found this before: https://github.com/NVlabs/stylegan2-ada-pytorch/issues/54#issuecomment-793713965, but we
+ # will use @JCBrouwer's solution: https://github.com/NVlabs/stylegan2-ada-pytorch/issues/105#issuecomment-838577639
+ import functools
+ G.forward = functools.partial(G.forward, force_fp32=True)
+
+# ----------------------------------------------------------------------------
+
+resume_specs = {
+ # For StyleGAN2/ADA models; --cfg=stylegan2
+ 'stylegan2': {
+ # Official NVIDIA models
+ 'ffhq256': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-256x256.pkl',
+ 'ffhqu256': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhqu-256x256.pkl',
+ 'ffhq512': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-512x512.pkl',
+ 'ffhq1024': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-1024x1024.pkl',
+ 'ffhqu1024': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhqu-1024x1024.pkl',
+ 'celebahq256': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-celebahq-256x256.pkl',
+ 'lsundog256': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-lsundog-256x256.pkl',
+ 'afhqcat512': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-afhqcat-512x512.pkl',
+ 'afhqdog512': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-afhqdog-512x512.pkl',
+ 'afhqwild512': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-afhqwild-512x512.pkl',
+ 'afhq512': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-afhqv2-512x512.pkl',
+ 'brecahad512': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-brecahad-512x512.pkl',
+ 'cifar10': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-cifar10-32x32.pkl',
+ 'metfaces1024': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-metfaces-1024x1024.pkl',
+ 'metfacesu1024': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-metfacesu-1024x1024.pkl',
+ # Other configs are available at: https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/, but I will list here the config-f only
+ 'lsuncar512': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-car-config-f.pkl', # config-f
+ 'lsuncat256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl', # config-f
+ 'lsunchurch256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-church-config-f.pkl', # config-f
+ 'lsunhorse256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-horse-config-f.pkl', # config-f
+ # Community models. More can be found at: https://github.com/justinpinkney/awesome-pretrained-stylegan2 by @justinpinkney, but weren't added here
+ 'minecraft1024': 'https://github.com/jeffheaton/pretrained-gan-minecraft/releases/download/v1/minecraft-gan-2020-12-22.pkl', # Thanks to @jeffheaton
+ 'imagenet512': 'https://battle.shawwn.com/sdc/stylegan2-imagenet-512/model.ckpt-533504.pkl', # Thanks to @shawwn
+ 'wikiart1024-C': 'https://archive.org/download/wikiart-stylegan2-conditional-model/WikiArt5.pkl', # Thanks to @pbaylies; conditional (167 classes in total: --class=0 to 166)
+ 'wikiart1024-U': 'https://archive.org/download/wikiart-stylegan2-conditional-model/WikiArt_Uncond2.pkl', # Thanks to @pbaylies; unconditional
+ 'maps1024': 'https://archive.org/download/mapdreamer/mapdreamer.pkl', # Thanks to @tjukanov
+ 'fursona512': 'https://thisfursonadoesnotexist.com/model/network-e621-r-512-3194880.pkl', # Thanks to @arfafax
+ 'mlpony512': 'https://thisponydoesnotexist.net/model/network-ponies-1024-151552.pkl', # Thanks to @arfafax
+ 'lhq1024': 'https://huggingface.co/justinpinkney/lhq-sg2-1024/resolve/main/lhq-sg2-1024.pkl', # Thanks to @justinpinkney
+ # Deceive-D/APA models (ignoring the faces models): https://github.com/EndlessSora/DeceiveD
+ 'afhqcat256': 'https://drive.google.com/u/0/uc?export=download&confirm=zFoN&id=1P9ouHIK-W8JTb6bvecfBe4c_3w6gmMJK',
+ 'anime256': 'https://drive.google.com/u/0/uc?export=download&confirm=6Uie&id=1EWOdieqELYmd2xRxUR4gnx7G10YI5dyP',
+ 'cub256': 'https://drive.google.com/u/0/uc?export=download&confirm=KwZS&id=1J0qactT55ofAvzddDE_xnJEY8s3vbo1_',
+ # Self-Distilled StyleGAN (full body representation of each class): https://github.com/self-distilled-stylegan/self-distilled-internet-photos
+ 'sddogs1024': 'https://storage.googleapis.com/self-distilled-stylegan/dogs_1024_pytorch.pkl',
+ 'sdelephant512': 'https://storage.googleapis.com/self-distilled-stylegan/elephants_512_pytorch.pkl',
+ 'sdhorses256': 'https://storage.googleapis.com/self-distilled-stylegan/horses_256_pytorch.pkl',
+ 'sdbicycles256': 'https://storage.googleapis.com/self-distilled-stylegan/bicycles_256_pytorch.pkl',
+ 'sdlions512': 'https://storage.googleapis.com/self-distilled-stylegan/lions_512_pytorch.pkl',
+ 'sdgiraffes512': 'https://storage.googleapis.com/self-distilled-stylegan/giraffes_512_pytorch.pkl',
+ 'sdparrots512': 'https://storage.googleapis.com/self-distilled-stylegan/parrots_512_pytorch.pkl'
+
+ },
+ # For StyleGAN2 extended (--cfg=styelgan2-ext)
+ 'stylegan2-ext': {
+ 'anime512': 'https://drive.google.com/u/0/uc?export=download&confirm=zFoN&id=1A-E_E32WAtTHRlOzjhhYhyyBDXLJN9_H' # Thanks to @aydao
+ },
+ # For StyleGAN3 config-r models (--cfg=stylegan3-r)
+ 'stylegan3-r': {
+ # Official NVIDIA models
+ 'afhq512': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl',
+ 'ffhq1024': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhq-1024x1024.pkl',
+ 'ffhqu1024': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhqu-1024x1024.pkl',
+ 'ffhqu256': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhqu-256x256.pkl',
+ 'metfaces1024': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-metfaces-1024x1024.pkl',
+ 'metfacesu1024': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-metfacesu-1024x1024.pkl',
+ },
+ # For StyleGAN3 config-t models (--cfg=stylegan3-t)
+ 'stylegan3-t': {
+ # Official NVIDIA models
+ 'afhq512': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-afhqv2-512x512.pkl',
+ 'ffhq1024': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhq-1024x1024.pkl',
+ 'ffhqu1024': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhqu-1024x1024.pkl',
+ 'ffhqu256': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhqu-256x256.pkl',
+ 'metfaces1024': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-metfaces-1024x1024.pkl',
+ 'metfacesu1024': 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-metfacesu-1024x1024.pkl',
+ # Community models, found in: https://github.com/justinpinkney/awesome-pretrained-stylegan3 by @justinpinkney
+ 'landscapes256': 'https://drive.google.com/u/0/uc?export=download&confirm=eJHe&id=14UGDDOusZ9TMb-pOrF0PAjMGVWLSAii1', # Thanks to @justinpinkney
+ 'wikiart1024': 'https://drive.google.com/u/0/uc?export=download&confirm=2tz5&id=18MOpwTMJsl_Z17q-wQVnaRLCUFZYSNkj', # Thanks to @justinpinkney
+ # -> Wombo Dream-based models found in: https://github.com/edstoica/lucid_stylegan3_datasets_models by @edstoica; TODO: more to come, update the list as they are released!
+ 'mechfuture256': 'https://www.dropbox.com/s/v2oie53cz62ozvu/network-snapshot-000029.pkl?dl=1', # Thanks to @edstoica; 29kimg tick
+ 'vivflowers256': 'https://www.dropbox.com/s/o33lhgnk91hstvx/network-snapshot-000069.pkl?dl=1', # Thanks to @edstoica; 68kimg tick
+ 'alienglass256': 'https://www.dropbox.com/s/gur14k0e7kspguy/network-snapshot-000038.pkl?dl=1', # Thanks to @edstoica; 38kimg tick
+ 'scificity256': 'https://www.dropbox.com/s/1kfsmlct4mriphc/network-snapshot-000210.pkl?dl=1', # Thanks to @edstoica; 210kimg tick
+ 'scifiship256': 'https://www.dropbox.com/s/02br3mjkma1hubc/network-snapshot-000162.pkl?dl=1', # Thanks to @edstoica; 168kimg tick
+ }
+}
+
+# ----------------------------------------------------------------------------
+
+
+# TODO: all of the following functions must work for RGBA images
+def w_to_img(G, dlatents: Union[List[torch.Tensor], torch.Tensor],
+ noise_mode: str = 'const',
+ new_w_avg: torch.Tensor = None,
+ truncation_psi: float = 1.0) -> np.ndarray:
+ """
+ Get an image/np.ndarray from a dlatent W using G and the selected noise_mode. The final shape of the
+ returned image will be [len(dlatents), G.img_resolution, G.img_resolution, G.img_channels].
+ Note: this function should be used after doing the truncation trick!
+ Note: Optionally, you can also pass a new_w_avg to use instead of the one in G, with a reverse
+ truncation trick
+ """
+ # If we have a single dlatent, we need to add a batch dimension
+ assert isinstance(dlatents, torch.Tensor), f'dlatents should be a torch.Tensor!: "{type(dlatents)}"'
+ if len(dlatents.shape) == 2:
+ dlatents = dlatents.unsqueeze(0) # An individual dlatent => [1, G.mapping.num_ws, G.mapping.w_dim]
+ if new_w_avg is not None:
+ new_w_avg = new_w_avg.to(next(G.parameters()).device)
+ dlatents = (dlatents - new_w_avg) * (1 - truncation_psi) + new_w_avg
+ synth_image = G.synthesis(dlatents, noise_mode=noise_mode)
+ synth_image = (synth_image + 1) * 255/2 # [-1.0, 1.0] -> [0.0, 255.0]
+ synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8).cpu().numpy() # NCWH => NWHC
+ return synth_image
+
+
+def z_to_dlatent(G, latents: torch.Tensor, label: torch.Tensor, truncation_psi: float = 1.0) -> torch.Tensor:
+ """Get the dlatent from the given latent, class label and truncation psi"""
+ assert isinstance(latents, torch.Tensor), f'latents should be a torch.Tensor!: "{type(latents)}"'
+ assert isinstance(label, torch.Tensor), f'label should be a torch.Tensor!: "{type(label)}"'
+ if len(latents.shape) == 1:
+ latents = latents.unsqueeze(0) # An individual latent => [1, G.z_dim]
+ dlatents = G.mapping(z=latents, c=label, truncation_psi=truncation_psi)
+
+ return dlatents
+
+
+def z_to_img(G, latents: torch.Tensor, label: torch.Tensor, truncation_psi: float, noise_mode: str = 'const') -> np.ndarray:
+ """
+ Get an image/np.ndarray from a latent Z using G, the label, truncation_psi, and noise_mode. The shape
+ of the output image/np.ndarray will be [len(latents), G.img_resolution, G.img_resolution, G.img_channels]
+ """
+ dlatents = z_to_dlatent(G=G, latents=latents, label=label, truncation_psi=1.0)
+ dlatents = G.mapping.w_avg + (G.mapping.w_avg - dlatents) * truncation_psi
+ img = w_to_img(G=G, dlatents=dlatents, noise_mode=noise_mode) # Let's not redo code
+ return img
+
+
+def get_w_from_seed(G, device: torch.device, seed: int, truncation_psi: float, new_w_avg: torch.Tensor = None) -> torch.Tensor:
+ """Get the dlatent from a random seed, using the truncation trick (this could be optional)"""
+ z = np.random.RandomState(seed).randn(1, G.z_dim)
+ w = G.mapping(torch.from_numpy(z).to(device), None)
+ w_avg = G.mapping.w_avg if new_w_avg is None else new_w_avg.to(device)
+ w = w_avg + (w - w_avg) * truncation_psi
+
+ return w
+
+
+def get_latent_from_file(file: Union[str, os.PathLike],
+ return_ext: bool = False,
+ named_latent: str = 'w') -> Tuple[np.ndarray, Optional[str]]:
+ """Get dlatent (w) from a .npy or .npz file"""
+ filename, file_extension = os.path.splitext(file)
+ assert file_extension in ['.npy', '.npz'], f'"{file}" has wrong file format! Only ".npy" or ".npz" are allowed'
+ if file_extension == '.npy':
+ latent = np.load(file)
+ extension = '.npy'
+ else:
+ latent = np.load(file)[named_latent]
+ extension = '.npz'
+ if len(latent.shape) == 4:
+ latent = latent[0]
+ return (latent, extension) if return_ext else latent
+
+
+# ----------------------------------------------------------------------------
+
+
+def save_config(ctx: click.Context, run_dir: Union[str, os.PathLike], save_name: str = 'config.json') -> None:
+ """Save the configuration stored in ctx.obj into a JSON file at the output directory."""
+ with open(os.path.join(run_dir, save_name), 'w') as f:
+ json.dump(ctx.obj, f, indent=4, sort_keys=True)
+
+
+# ----------------------------------------------------------------------------
+
+
+def make_run_dir(outdir: Union[str, os.PathLike], desc: str, dry_run: bool = False) -> str:
+ """Reject modernity, return to automatically create the run dir."""
+ # Pick output directory.
+ prev_run_dirs = []
+ if os.path.isdir(outdir): # sanity check, but click.Path() should clear this one
+ prev_run_dirs = [x for x in os.listdir(outdir) if os.path.isdir(os.path.join(outdir, x))]
+ prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs]
+ prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None]
+ cur_run_id = max(prev_run_ids, default=-1) + 1 # start with 00000
+ run_dir = os.path.join(outdir, f'{cur_run_id:05d}-{desc}')
+ assert not os.path.exists(run_dir) # make sure it doesn't already exist
+
+ # Don't create the dir if it's a dry-run
+ if not dry_run:
+ print('Creating output directory...')
+ os.makedirs(run_dir)
+ return run_dir
diff --git a/stylegan3-fun/torch_utils/misc.py b/stylegan3-fun/torch_utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..335397dd1662d8f5bfd44e17899a00549867f4bc
--- /dev/null
+++ b/stylegan3-fun/torch_utils/misc.py
@@ -0,0 +1,266 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import re
+import contextlib
+import numpy as np
+import torch
+import warnings
+import dnnlib
+
+#----------------------------------------------------------------------------
+# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
+# same constant is used multiple times.
+
+_constant_cache = dict()
+
+def constant(value, shape=None, dtype=None, device=None, memory_format=None):
+ value = np.asarray(value)
+ if shape is not None:
+ shape = tuple(shape)
+ if dtype is None:
+ dtype = torch.get_default_dtype()
+ if device is None:
+ device = torch.device('cpu')
+ if memory_format is None:
+ memory_format = torch.contiguous_format
+
+ key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
+ tensor = _constant_cache.get(key, None)
+ if tensor is None:
+ tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
+ if shape is not None:
+ tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
+ tensor = tensor.contiguous(memory_format=memory_format)
+ _constant_cache[key] = tensor
+ return tensor
+
+#----------------------------------------------------------------------------
+# Replace NaN/Inf with specified numerical values.
+
+try:
+ nan_to_num = torch.nan_to_num # 1.8.0a0
+except AttributeError:
+ def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
+ assert isinstance(input, torch.Tensor)
+ if posinf is None:
+ posinf = torch.finfo(input.dtype).max
+ if neginf is None:
+ neginf = torch.finfo(input.dtype).min
+ assert nan == 0
+ return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
+
+#----------------------------------------------------------------------------
+# Symbolic assert.
+
+try:
+ symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
+except AttributeError:
+ symbolic_assert = torch.Assert # 1.7.0
+
+#----------------------------------------------------------------------------
+# Context manager to temporarily suppress known warnings in torch.jit.trace().
+# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
+
+@contextlib.contextmanager
+def suppress_tracer_warnings():
+ flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
+ warnings.filters.insert(0, flt)
+ yield
+ warnings.filters.remove(flt)
+
+#----------------------------------------------------------------------------
+# Assert that the shape of a tensor matches the given list of integers.
+# None indicates that the size of a dimension is allowed to vary.
+# Performs symbolic assertion when used in torch.jit.trace().
+
+def assert_shape(tensor, ref_shape):
+ if tensor.ndim != len(ref_shape):
+ raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
+ for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
+ if ref_size is None:
+ pass
+ elif isinstance(ref_size, torch.Tensor):
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
+ symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
+ elif isinstance(size, torch.Tensor):
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
+ symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
+ elif size != ref_size:
+ raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
+
+#----------------------------------------------------------------------------
+# Function decorator that calls torch.autograd.profiler.record_function().
+
+def profiled_function(fn):
+ def decorator(*args, **kwargs):
+ with torch.autograd.profiler.record_function(fn.__name__):
+ return fn(*args, **kwargs)
+ decorator.__name__ = fn.__name__
+ return decorator
+
+#----------------------------------------------------------------------------
+# Sampler for torch.utils.data.DataLoader that loops over the dataset
+# indefinitely, shuffling items as it goes.
+
+class InfiniteSampler(torch.utils.data.Sampler):
+ def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
+ assert len(dataset) > 0
+ assert num_replicas > 0
+ assert 0 <= rank < num_replicas
+ assert 0 <= window_size <= 1
+ super().__init__(dataset)
+ self.dataset = dataset
+ self.rank = rank
+ self.num_replicas = num_replicas
+ self.shuffle = shuffle
+ self.seed = seed
+ self.window_size = window_size
+
+ def __iter__(self):
+ order = np.arange(len(self.dataset))
+ rnd = None
+ window = 0
+ if self.shuffle:
+ rnd = np.random.RandomState(self.seed)
+ rnd.shuffle(order)
+ window = int(np.rint(order.size * self.window_size))
+
+ idx = 0
+ while True:
+ i = idx % order.size
+ if idx % self.num_replicas == self.rank:
+ yield order[i]
+ if window >= 2:
+ j = (i - rnd.randint(window)) % order.size
+ order[i], order[j] = order[j], order[i]
+ idx += 1
+
+#----------------------------------------------------------------------------
+# Utilities for operating with torch.nn.Module parameters and buffers.
+
+def params_and_buffers(module):
+ assert isinstance(module, torch.nn.Module)
+ return list(module.parameters()) + list(module.buffers())
+
+def named_params_and_buffers(module):
+ assert isinstance(module, torch.nn.Module)
+ return list(module.named_parameters()) + list(module.named_buffers())
+
+def copy_params_and_buffers(src_module, dst_module, require_all=False):
+ assert isinstance(src_module, torch.nn.Module)
+ assert isinstance(dst_module, torch.nn.Module)
+ src_tensors = dict(named_params_and_buffers(src_module))
+ for name, tensor in named_params_and_buffers(dst_module):
+ assert (name in src_tensors) or (not require_all)
+ if name in src_tensors:
+ tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
+
+#----------------------------------------------------------------------------
+# Context manager for easily enabling/disabling DistributedDataParallel
+# synchronization.
+
+@contextlib.contextmanager
+def ddp_sync(module, sync):
+ assert isinstance(module, torch.nn.Module)
+ if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
+ yield
+ else:
+ with module.no_sync():
+ yield
+
+#----------------------------------------------------------------------------
+# Check DistributedDataParallel consistency across processes.
+
+def check_ddp_consistency(module, ignore_regex=None):
+ assert isinstance(module, torch.nn.Module)
+ for name, tensor in named_params_and_buffers(module):
+ fullname = type(module).__name__ + '.' + name
+ if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
+ continue
+ tensor = tensor.detach()
+ if tensor.is_floating_point():
+ tensor = nan_to_num(tensor)
+ other = tensor.clone()
+ torch.distributed.broadcast(tensor=other, src=0)
+ assert (tensor == other).all(), fullname
+
+#----------------------------------------------------------------------------
+# Print summary table of module hierarchy.
+
+def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
+ assert isinstance(module, torch.nn.Module)
+ assert not isinstance(module, torch.jit.ScriptModule)
+ assert isinstance(inputs, (tuple, list))
+
+ # Register hooks.
+ entries = []
+ nesting = [0]
+ def pre_hook(_mod, _inputs):
+ nesting[0] += 1
+ def post_hook(mod, _inputs, outputs):
+ nesting[0] -= 1
+ if nesting[0] <= max_nesting:
+ outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
+ outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
+ entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
+ hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
+ hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
+
+ # Run module.
+ outputs = module(*inputs)
+ for hook in hooks:
+ hook.remove()
+
+ # Identify unique outputs, parameters, and buffers.
+ tensors_seen = set()
+ for e in entries:
+ e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
+ e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
+ e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
+ tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
+
+ # Filter out redundant entries.
+ if skip_redundant:
+ entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
+
+ # Construct table.
+ rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
+ rows += [['---'] * len(rows[0])]
+ param_total = 0
+ buffer_total = 0
+ submodule_names = {mod: name for name, mod in module.named_modules()}
+ for e in entries:
+ name = '' if e.mod is module else submodule_names[e.mod]
+ param_size = sum(t.numel() for t in e.unique_params)
+ buffer_size = sum(t.numel() for t in e.unique_buffers)
+ output_shapes = [str(list(t.shape)) for t in e.outputs]
+ output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
+ rows += [[
+ name + (':0' if len(e.outputs) >= 2 else ''),
+ str(param_size) if param_size else '-',
+ str(buffer_size) if buffer_size else '-',
+ (output_shapes + ['-'])[0],
+ (output_dtypes + ['-'])[0],
+ ]]
+ for idx in range(1, len(e.outputs)):
+ rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
+ param_total += param_size
+ buffer_total += buffer_size
+ rows += [['---'] * len(rows[0])]
+ rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
+
+ # Print table.
+ widths = [max(len(cell) for cell in column) for column in zip(*rows)]
+ print()
+ for row in rows:
+ print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
+ print()
+ return outputs
+
+#----------------------------------------------------------------------------
diff --git a/stylegan3-fun/torch_utils/ops/__init__.py b/stylegan3-fun/torch_utils/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..939e7c6c8f94c4ea1141885c3c3295fe083b06aa
--- /dev/null
+++ b/stylegan3-fun/torch_utils/ops/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+# empty
diff --git a/stylegan3-fun/torch_utils/ops/__pycache__/__init__.cpython-311.pyc b/stylegan3-fun/torch_utils/ops/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b85c2b9c15de012ca3c0cb2ec7d18d7c9861312e
Binary files /dev/null and b/stylegan3-fun/torch_utils/ops/__pycache__/__init__.cpython-311.pyc differ
diff --git a/stylegan3-fun/torch_utils/ops/__pycache__/__init__.cpython-38.pyc b/stylegan3-fun/torch_utils/ops/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e2381120b63c8d692b36af3c700db9c943e2828f
Binary files /dev/null and b/stylegan3-fun/torch_utils/ops/__pycache__/__init__.cpython-38.pyc differ
diff --git a/stylegan3-fun/torch_utils/ops/__pycache__/__init__.cpython-39.pyc b/stylegan3-fun/torch_utils/ops/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..07c1ea1b631fec5048fcd3dea9accd5422bc5145
Binary files /dev/null and b/stylegan3-fun/torch_utils/ops/__pycache__/__init__.cpython-39.pyc differ
diff --git a/stylegan3-fun/torch_utils/ops/__pycache__/bias_act.cpython-311.pyc b/stylegan3-fun/torch_utils/ops/__pycache__/bias_act.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d292017c37285392f6ce01431c1e1f44f911336d
Binary files /dev/null and b/stylegan3-fun/torch_utils/ops/__pycache__/bias_act.cpython-311.pyc differ
diff --git a/stylegan3-fun/torch_utils/ops/__pycache__/bias_act.cpython-38.pyc b/stylegan3-fun/torch_utils/ops/__pycache__/bias_act.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..00eec118f7033e9ef7ded26a84f040211bd2b4e7
Binary files /dev/null and b/stylegan3-fun/torch_utils/ops/__pycache__/bias_act.cpython-38.pyc differ
diff --git a/stylegan3-fun/torch_utils/ops/__pycache__/bias_act.cpython-39.pyc b/stylegan3-fun/torch_utils/ops/__pycache__/bias_act.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..af9dfaccfd2713a9323e416e44358a154d331dd7
Binary files /dev/null and b/stylegan3-fun/torch_utils/ops/__pycache__/bias_act.cpython-39.pyc differ
diff --git a/stylegan3-fun/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-311.pyc b/stylegan3-fun/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6b23e87f94cf16e18b164386cadb6d39147318b7
Binary files /dev/null and b/stylegan3-fun/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-311.pyc differ
diff --git a/stylegan3-fun/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-38.pyc b/stylegan3-fun/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2a1677db317e46cb26bb99b8c6a12749f1cf351b
Binary files /dev/null and b/stylegan3-fun/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-38.pyc differ
diff --git a/stylegan3-fun/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-39.pyc b/stylegan3-fun/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fd2262f1a91c00d092bb557ab63ce920f76c1aae
Binary files /dev/null and b/stylegan3-fun/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-39.pyc differ
diff --git a/stylegan3-fun/torch_utils/ops/__pycache__/conv2d_resample.cpython-311.pyc b/stylegan3-fun/torch_utils/ops/__pycache__/conv2d_resample.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3e642ee6da590130adc6a4a83fa0af9629792e50
Binary files /dev/null and b/stylegan3-fun/torch_utils/ops/__pycache__/conv2d_resample.cpython-311.pyc differ
diff --git a/stylegan3-fun/torch_utils/ops/__pycache__/conv2d_resample.cpython-38.pyc b/stylegan3-fun/torch_utils/ops/__pycache__/conv2d_resample.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..74364f7c4229fbef658ab70ac8773541ffc4bc75
Binary files /dev/null and b/stylegan3-fun/torch_utils/ops/__pycache__/conv2d_resample.cpython-38.pyc differ
diff --git a/stylegan3-fun/torch_utils/ops/__pycache__/conv2d_resample.cpython-39.pyc b/stylegan3-fun/torch_utils/ops/__pycache__/conv2d_resample.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b25a8f84a71d99c1303b41725c5e5b4b0739fb72
Binary files /dev/null and b/stylegan3-fun/torch_utils/ops/__pycache__/conv2d_resample.cpython-39.pyc differ
diff --git a/stylegan3-fun/torch_utils/ops/__pycache__/fma.cpython-311.pyc b/stylegan3-fun/torch_utils/ops/__pycache__/fma.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7e69bd8f0c1531f6f05a04e55768ce9befe9b91a
Binary files /dev/null and b/stylegan3-fun/torch_utils/ops/__pycache__/fma.cpython-311.pyc differ
diff --git a/stylegan3-fun/torch_utils/ops/__pycache__/fma.cpython-38.pyc b/stylegan3-fun/torch_utils/ops/__pycache__/fma.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..038e0e42d71f807b1450409d982b5b01c409b14e
Binary files /dev/null and b/stylegan3-fun/torch_utils/ops/__pycache__/fma.cpython-38.pyc differ
diff --git a/stylegan3-fun/torch_utils/ops/__pycache__/fma.cpython-39.pyc b/stylegan3-fun/torch_utils/ops/__pycache__/fma.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..228b63a6292253cf8d15827657a0347a485cac89
Binary files /dev/null and b/stylegan3-fun/torch_utils/ops/__pycache__/fma.cpython-39.pyc differ
diff --git a/stylegan3-fun/torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-311.pyc b/stylegan3-fun/torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7eec3e721c9aa7ec86953924a50ff62cc4705547
Binary files /dev/null and b/stylegan3-fun/torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-311.pyc differ
diff --git a/stylegan3-fun/torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-38.pyc b/stylegan3-fun/torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e71bd9c162969266e20cbe276d69d1bec367c069
Binary files /dev/null and b/stylegan3-fun/torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-38.pyc differ
diff --git a/stylegan3-fun/torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-39.pyc b/stylegan3-fun/torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8637865ea3db81cd85611a473ff192748f1ac731
Binary files /dev/null and b/stylegan3-fun/torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-39.pyc differ
diff --git a/stylegan3-fun/torch_utils/ops/__pycache__/upfirdn2d.cpython-311.pyc b/stylegan3-fun/torch_utils/ops/__pycache__/upfirdn2d.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d638b9fbade9632299155b939eb35b107dfbdea8
Binary files /dev/null and b/stylegan3-fun/torch_utils/ops/__pycache__/upfirdn2d.cpython-311.pyc differ
diff --git a/stylegan3-fun/torch_utils/ops/__pycache__/upfirdn2d.cpython-38.pyc b/stylegan3-fun/torch_utils/ops/__pycache__/upfirdn2d.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2c9344a0a6c05f93eae3f947fc5729e2a56ad887
Binary files /dev/null and b/stylegan3-fun/torch_utils/ops/__pycache__/upfirdn2d.cpython-38.pyc differ
diff --git a/stylegan3-fun/torch_utils/ops/__pycache__/upfirdn2d.cpython-39.pyc b/stylegan3-fun/torch_utils/ops/__pycache__/upfirdn2d.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..98bc69ad967139ea117f4537b2df16544d0f6dae
Binary files /dev/null and b/stylegan3-fun/torch_utils/ops/__pycache__/upfirdn2d.cpython-39.pyc differ
diff --git a/stylegan3-fun/torch_utils/ops/bias_act.cpp b/stylegan3-fun/torch_utils/ops/bias_act.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..3adaeee2ae44e96655d354c2bdfb81de8ebfe6c6
--- /dev/null
+++ b/stylegan3-fun/torch_utils/ops/bias_act.cpp
@@ -0,0 +1,99 @@
+// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include
+#include
+#include
+#include "bias_act.h"
+
+//------------------------------------------------------------------------
+
+static bool has_same_layout(torch::Tensor x, torch::Tensor y)
+{
+ if (x.dim() != y.dim())
+ return false;
+ for (int64_t i = 0; i < x.dim(); i++)
+ {
+ if (x.size(i) != y.size(i))
+ return false;
+ if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
+ return false;
+ }
+ return true;
+}
+
+//------------------------------------------------------------------------
+
+static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
+{
+ // Validate arguments.
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+ TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
+ TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
+ TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
+ TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
+ TORCH_CHECK(b.dim() == 1, "b must have rank 1");
+ TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
+ TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
+ TORCH_CHECK(grad >= 0, "grad must be non-negative");
+
+ // Validate layout.
+ TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
+ TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
+ TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
+ TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
+ TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
+
+ // Create output tensor.
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+ torch::Tensor y = torch::empty_like(x);
+ TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
+
+ // Initialize CUDA kernel parameters.
+ bias_act_kernel_params p;
+ p.x = x.data_ptr();
+ p.b = (b.numel()) ? b.data_ptr() : NULL;
+ p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
+ p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
+ p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
+ p.y = y.data_ptr();
+ p.grad = grad;
+ p.act = act;
+ p.alpha = alpha;
+ p.gain = gain;
+ p.clamp = clamp;
+ p.sizeX = (int)x.numel();
+ p.sizeB = (int)b.numel();
+ p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
+
+ // Choose CUDA kernel.
+ void* kernel;
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
+ {
+ kernel = choose_bias_act_kernel(p);
+ });
+ TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
+
+ // Launch CUDA kernel.
+ p.loopX = 4;
+ int blockSize = 4 * 32;
+ int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
+ void* args[] = {&p};
+ AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
+ return y;
+}
+
+//------------------------------------------------------------------------
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ m.def("bias_act", &bias_act);
+}
+
+//------------------------------------------------------------------------
diff --git a/stylegan3-fun/torch_utils/ops/bias_act.cu b/stylegan3-fun/torch_utils/ops/bias_act.cu
new file mode 100644
index 0000000000000000000000000000000000000000..ed1d16f14eadd1344939e074ace1375cfd936cea
--- /dev/null
+++ b/stylegan3-fun/torch_utils/ops/bias_act.cu
@@ -0,0 +1,173 @@
+// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include
+#include "bias_act.h"
+
+//------------------------------------------------------------------------
+// Helpers.
+
+template struct InternalType;
+template <> struct InternalType { typedef double scalar_t; };
+template <> struct InternalType { typedef float scalar_t; };
+template <> struct InternalType { typedef float scalar_t; };
+
+//------------------------------------------------------------------------
+// CUDA kernel.
+
+template
+__global__ void bias_act_kernel(bias_act_kernel_params p)
+{
+ typedef typename InternalType::scalar_t scalar_t;
+ int G = p.grad;
+ scalar_t alpha = (scalar_t)p.alpha;
+ scalar_t gain = (scalar_t)p.gain;
+ scalar_t clamp = (scalar_t)p.clamp;
+ scalar_t one = (scalar_t)1;
+ scalar_t two = (scalar_t)2;
+ scalar_t expRange = (scalar_t)80;
+ scalar_t halfExpRange = (scalar_t)40;
+ scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
+ scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
+
+ // Loop over elements.
+ int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
+ for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
+ {
+ // Load.
+ scalar_t x = (scalar_t)((const T*)p.x)[xi];
+ scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
+ scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
+ scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
+ scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
+ scalar_t yy = (gain != 0) ? yref / gain : 0;
+ scalar_t y = 0;
+
+ // Apply bias.
+ ((G == 0) ? x : xref) += b;
+
+ // linear
+ if (A == 1)
+ {
+ if (G == 0) y = x;
+ if (G == 1) y = x;
+ }
+
+ // relu
+ if (A == 2)
+ {
+ if (G == 0) y = (x > 0) ? x : 0;
+ if (G == 1) y = (yy > 0) ? x : 0;
+ }
+
+ // lrelu
+ if (A == 3)
+ {
+ if (G == 0) y = (x > 0) ? x : x * alpha;
+ if (G == 1) y = (yy > 0) ? x : x * alpha;
+ }
+
+ // tanh
+ if (A == 4)
+ {
+ if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
+ if (G == 1) y = x * (one - yy * yy);
+ if (G == 2) y = x * (one - yy * yy) * (-two * yy);
+ }
+
+ // sigmoid
+ if (A == 5)
+ {
+ if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
+ if (G == 1) y = x * yy * (one - yy);
+ if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
+ }
+
+ // elu
+ if (A == 6)
+ {
+ if (G == 0) y = (x >= 0) ? x : exp(x) - one;
+ if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
+ }
+
+ // selu
+ if (A == 7)
+ {
+ if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
+ if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
+ }
+
+ // softplus
+ if (A == 8)
+ {
+ if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
+ if (G == 1) y = x * (one - exp(-yy));
+ if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
+ }
+
+ // swish
+ if (A == 9)
+ {
+ if (G == 0)
+ y = (x < -expRange) ? 0 : x / (exp(-x) + one);
+ else
+ {
+ scalar_t c = exp(xref);
+ scalar_t d = c + one;
+ if (G == 1)
+ y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
+ else
+ y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
+ yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
+ }
+ }
+
+ // Apply gain.
+ y *= gain * dy;
+
+ // Clamp.
+ if (clamp >= 0)
+ {
+ if (G == 0)
+ y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
+ else
+ y = (yref > -clamp & yref < clamp) ? y : 0;
+ }
+
+ // Store.
+ ((T*)p.y)[xi] = (T)y;
+ }
+}
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template void* choose_bias_act_kernel(const bias_act_kernel_params& p)
+{
+ if (p.act == 1) return (void*)bias_act_kernel;
+ if (p.act == 2) return (void*)bias_act_kernel;
+ if (p.act == 3) return (void*)bias_act_kernel;
+ if (p.act == 4) return (void*)bias_act_kernel;
+ if (p.act == 5) return (void*)bias_act_kernel;
+ if (p.act == 6) return (void*)bias_act_kernel;
+ if (p.act == 7) return (void*)bias_act_kernel;
+ if (p.act == 8) return (void*)bias_act_kernel;
+ if (p.act == 9) return (void*)bias_act_kernel;
+ return NULL;
+}
+
+//------------------------------------------------------------------------
+// Template specializations.
+
+template void* choose_bias_act_kernel (const bias_act_kernel_params& p);
+template void* choose_bias_act_kernel (const bias_act_kernel_params& p);
+template void* choose_bias_act_kernel (const bias_act_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/stylegan3-fun/torch_utils/ops/bias_act.h b/stylegan3-fun/torch_utils/ops/bias_act.h
new file mode 100644
index 0000000000000000000000000000000000000000..60b81c6058d54638a6d74a13046fa388442d767d
--- /dev/null
+++ b/stylegan3-fun/torch_utils/ops/bias_act.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+//------------------------------------------------------------------------
+// CUDA kernel parameters.
+
+struct bias_act_kernel_params
+{
+ const void* x; // [sizeX]
+ const void* b; // [sizeB] or NULL
+ const void* xref; // [sizeX] or NULL
+ const void* yref; // [sizeX] or NULL
+ const void* dy; // [sizeX] or NULL
+ void* y; // [sizeX]
+
+ int grad;
+ int act;
+ float alpha;
+ float gain;
+ float clamp;
+
+ int sizeX;
+ int sizeB;
+ int stepB;
+ int loopX;
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template void* choose_bias_act_kernel(const bias_act_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/stylegan3-fun/torch_utils/ops/bias_act.py b/stylegan3-fun/torch_utils/ops/bias_act.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2b53d7da34c76d53251bb9cbc2eb071c50af921
--- /dev/null
+++ b/stylegan3-fun/torch_utils/ops/bias_act.py
@@ -0,0 +1,209 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Custom PyTorch ops for efficient bias and activation."""
+
+import os
+import numpy as np
+import torch
+import dnnlib
+
+from .. import custom_ops
+from .. import misc
+
+#----------------------------------------------------------------------------
+
+activation_funcs = {
+ 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
+ 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
+ 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
+ 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
+ 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
+ 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
+ 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
+ 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
+ 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
+}
+
+#----------------------------------------------------------------------------
+
+_plugin = None
+_null_tensor = torch.empty([0])
+
+def _init():
+ global _plugin
+ if _plugin is None:
+ _plugin = custom_ops.get_plugin(
+ module_name='bias_act_plugin',
+ sources=['bias_act.cpp', 'bias_act.cu'],
+ headers=['bias_act.h'],
+ source_dir=os.path.dirname(__file__),
+ extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'],
+ )
+ return True
+
+#----------------------------------------------------------------------------
+
+def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
+ r"""Fused bias and activation function.
+
+ Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
+ and scales the result by `gain`. Each of the steps is optional. In most cases,
+ the fused op is considerably more efficient than performing the same calculation
+ using standard PyTorch ops. It supports first and second order gradients,
+ but not third order gradients.
+
+ Args:
+ x: Input activation tensor. Can be of any shape.
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
+ as `x`. The shape must be known, and it must match the dimension of `x`
+ corresponding to `dim`.
+ dim: The dimension in `x` corresponding to the elements of `b`.
+ The value of `dim` is ignored if `b` is not specified.
+ act: Name of the activation function to evaluate, or `"linear"` to disable.
+ Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
+ See `activation_funcs` for a full list. `None` is not allowed.
+ alpha: Shape parameter for the activation function, or `None` to use the default.
+ gain: Scaling factor for the output tensor, or `None` to use default.
+ See `activation_funcs` for the default scaling of each activation function.
+ If unsure, consider specifying 1.
+ clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
+ the clamping (default).
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
+
+ Returns:
+ Tensor of the same shape and datatype as `x`.
+ """
+ assert isinstance(x, torch.Tensor)
+ assert impl in ['ref', 'cuda']
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
+ return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
+ return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
+ """Slow reference implementation of `bias_act()` using standard TensorFlow ops.
+ """
+ assert isinstance(x, torch.Tensor)
+ assert clamp is None or clamp >= 0
+ spec = activation_funcs[act]
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
+ gain = float(gain if gain is not None else spec.def_gain)
+ clamp = float(clamp if clamp is not None else -1)
+
+ # Add bias.
+ if b is not None:
+ assert isinstance(b, torch.Tensor) and b.ndim == 1
+ assert 0 <= dim < x.ndim
+ assert b.shape[0] == x.shape[dim]
+ x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
+
+ # Evaluate activation function.
+ alpha = float(alpha)
+ x = spec.func(x, alpha=alpha)
+
+ # Scale by gain.
+ gain = float(gain)
+ if gain != 1:
+ x = x * gain
+
+ # Clamp.
+ if clamp >= 0:
+ x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
+ return x
+
+#----------------------------------------------------------------------------
+
+_bias_act_cuda_cache = dict()
+
+def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
+ """Fast CUDA implementation of `bias_act()` using custom ops.
+ """
+ # Parse arguments.
+ assert clamp is None or clamp >= 0
+ spec = activation_funcs[act]
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
+ gain = float(gain if gain is not None else spec.def_gain)
+ clamp = float(clamp if clamp is not None else -1)
+
+ # Lookup from cache.
+ key = (dim, act, alpha, gain, clamp)
+ if key in _bias_act_cuda_cache:
+ return _bias_act_cuda_cache[key]
+
+ # Forward op.
+ class BiasActCuda(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x, b): # pylint: disable=arguments-differ
+ ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format
+ x = x.contiguous(memory_format=ctx.memory_format)
+ b = b.contiguous() if b is not None else _null_tensor
+ y = x
+ if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
+ y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
+ ctx.save_for_backward(
+ x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
+ b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
+ y if 'y' in spec.ref else _null_tensor)
+ return y
+
+ @staticmethod
+ def backward(ctx, dy): # pylint: disable=arguments-differ
+ dy = dy.contiguous(memory_format=ctx.memory_format)
+ x, b, y = ctx.saved_tensors
+ dx = None
+ db = None
+
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
+ dx = dy
+ if act != 'linear' or gain != 1 or clamp >= 0:
+ dx = BiasActCudaGrad.apply(dy, x, b, y)
+
+ if ctx.needs_input_grad[1]:
+ db = dx.sum([i for i in range(dx.ndim) if i != dim])
+
+ return dx, db
+
+ # Backward op.
+ class BiasActCudaGrad(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
+ ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format
+ dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
+ ctx.save_for_backward(
+ dy if spec.has_2nd_grad else _null_tensor,
+ x, b, y)
+ return dx
+
+ @staticmethod
+ def backward(ctx, d_dx): # pylint: disable=arguments-differ
+ d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
+ dy, x, b, y = ctx.saved_tensors
+ d_dy = None
+ d_x = None
+ d_b = None
+ d_y = None
+
+ if ctx.needs_input_grad[0]:
+ d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
+
+ if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
+ d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
+
+ if spec.has_2nd_grad and ctx.needs_input_grad[2]:
+ d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
+
+ return d_dy, d_x, d_b, d_y
+
+ # Add to cache.
+ _bias_act_cuda_cache[key] = BiasActCuda
+ return BiasActCuda
+
+#----------------------------------------------------------------------------
diff --git a/stylegan3-fun/torch_utils/ops/conv2d_gradfix.py b/stylegan3-fun/torch_utils/ops/conv2d_gradfix.py
new file mode 100644
index 0000000000000000000000000000000000000000..156b6b22503b3be31d134b48ed1725e911f0e7ce
--- /dev/null
+++ b/stylegan3-fun/torch_utils/ops/conv2d_gradfix.py
@@ -0,0 +1,203 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Custom replacement for `torch.nn.functional.conv2d` that supports
+arbitrarily high order gradients with zero performance penalty."""
+
+import contextlib
+import torch
+from pkg_resources import parse_version
+
+# pylint: disable=redefined-builtin
+# pylint: disable=arguments-differ
+# pylint: disable=protected-access
+
+#----------------------------------------------------------------------------
+
+enabled = False # Enable the custom op by setting this to true.
+weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
+_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11
+
+@contextlib.contextmanager
+def no_weight_gradients(disable=True):
+ global weight_gradients_disabled
+ old = weight_gradients_disabled
+ if disable:
+ weight_gradients_disabled = True
+ yield
+ weight_gradients_disabled = old
+
+#----------------------------------------------------------------------------
+
+def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
+ if _should_use_custom_op(input):
+ return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
+ return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
+
+def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
+ if _should_use_custom_op(input):
+ return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
+ return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
+
+#----------------------------------------------------------------------------
+
+def _should_use_custom_op(input):
+ assert isinstance(input, torch.Tensor)
+ if (not enabled) or (not torch.backends.cudnn.enabled):
+ return False
+ if _use_pytorch_1_11_api:
+ # The work-around code doesn't work on PyTorch 1.11.0 onwards
+ return False
+ if input.device.type != 'cuda':
+ return False
+ return True
+
+def _tuple_of_ints(xs, ndim):
+ xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
+ assert len(xs) == ndim
+ assert all(isinstance(x, int) for x in xs)
+ return xs
+
+#----------------------------------------------------------------------------
+
+_conv2d_gradfix_cache = dict()
+_null_tensor = torch.empty([0])
+
+def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
+ # Parse arguments.
+ ndim = 2
+ weight_shape = tuple(weight_shape)
+ stride = _tuple_of_ints(stride, ndim)
+ padding = _tuple_of_ints(padding, ndim)
+ output_padding = _tuple_of_ints(output_padding, ndim)
+ dilation = _tuple_of_ints(dilation, ndim)
+
+ # Lookup from cache.
+ key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
+ if key in _conv2d_gradfix_cache:
+ return _conv2d_gradfix_cache[key]
+
+ # Validate arguments.
+ assert groups >= 1
+ assert len(weight_shape) == ndim + 2
+ assert all(stride[i] >= 1 for i in range(ndim))
+ assert all(padding[i] >= 0 for i in range(ndim))
+ assert all(dilation[i] >= 0 for i in range(ndim))
+ if not transpose:
+ assert all(output_padding[i] == 0 for i in range(ndim))
+ else: # transpose
+ assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
+
+ # Helpers.
+ common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
+ def calc_output_padding(input_shape, output_shape):
+ if transpose:
+ return [0, 0]
+ return [
+ input_shape[i + 2]
+ - (output_shape[i + 2] - 1) * stride[i]
+ - (1 - 2 * padding[i])
+ - dilation[i] * (weight_shape[i + 2] - 1)
+ for i in range(ndim)
+ ]
+
+ # Forward & backward.
+ class Conv2d(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input, weight, bias):
+ assert weight.shape == weight_shape
+ ctx.save_for_backward(
+ input if weight.requires_grad else _null_tensor,
+ weight if input.requires_grad else _null_tensor,
+ )
+ ctx.input_shape = input.shape
+
+ # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere).
+ if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0):
+ a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1])
+ b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1)
+ c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2)
+ c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1)
+ c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3)
+ return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))
+
+ # General case => cuDNN.
+ if transpose:
+ return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
+ return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, weight = ctx.saved_tensors
+ input_shape = ctx.input_shape
+ grad_input = None
+ grad_weight = None
+ grad_bias = None
+
+ if ctx.needs_input_grad[0]:
+ p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape)
+ op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
+ grad_input = op.apply(grad_output, weight, None)
+ assert grad_input.shape == input_shape
+
+ if ctx.needs_input_grad[1] and not weight_gradients_disabled:
+ grad_weight = Conv2dGradWeight.apply(grad_output, input)
+ assert grad_weight.shape == weight_shape
+
+ if ctx.needs_input_grad[2]:
+ grad_bias = grad_output.sum([0, 2, 3])
+
+ return grad_input, grad_weight, grad_bias
+
+ # Gradient with respect to the weights.
+ class Conv2dGradWeight(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, grad_output, input):
+ ctx.save_for_backward(
+ grad_output if input.requires_grad else _null_tensor,
+ input if grad_output.requires_grad else _null_tensor,
+ )
+ ctx.grad_output_shape = grad_output.shape
+ ctx.input_shape = input.shape
+
+ # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere).
+ if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0):
+ a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
+ b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
+ c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape)
+ return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))
+
+ # General case => cuDNN.
+ name = 'aten::cudnn_convolution_transpose_backward_weight' if transpose else 'aten::cudnn_convolution_backward_weight'
+ flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
+ return torch._C._jit_get_operation(name)(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
+
+ @staticmethod
+ def backward(ctx, grad2_grad_weight):
+ grad_output, input = ctx.saved_tensors
+ grad_output_shape = ctx.grad_output_shape
+ input_shape = ctx.input_shape
+ grad2_grad_output = None
+ grad2_input = None
+
+ if ctx.needs_input_grad[0]:
+ grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
+ assert grad2_grad_output.shape == grad_output_shape
+
+ if ctx.needs_input_grad[1]:
+ p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape)
+ op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
+ grad2_input = op.apply(grad_output, grad2_grad_weight, None)
+ assert grad2_input.shape == input_shape
+
+ return grad2_grad_output, grad2_input
+
+ _conv2d_gradfix_cache[key] = Conv2d
+ return Conv2d
+
+#----------------------------------------------------------------------------
diff --git a/stylegan3-fun/torch_utils/ops/conv2d_resample.py b/stylegan3-fun/torch_utils/ops/conv2d_resample.py
new file mode 100644
index 0000000000000000000000000000000000000000..5eb5877d7ffe4af74a2165f1d8d8c39dfac2476b
--- /dev/null
+++ b/stylegan3-fun/torch_utils/ops/conv2d_resample.py
@@ -0,0 +1,143 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""2D convolution with optional up/downsampling."""
+
+import torch
+
+from .. import misc
+from . import conv2d_gradfix
+from . import upfirdn2d
+from .upfirdn2d import _parse_padding
+from .upfirdn2d import _get_filter_size
+
+#----------------------------------------------------------------------------
+
+def _get_weight_shape(w):
+ with misc.suppress_tracer_warnings(): # this value will be treated as a constant
+ shape = [int(sz) for sz in w.shape]
+ misc.assert_shape(w, shape)
+ return shape
+
+#----------------------------------------------------------------------------
+
+def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
+ """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
+ """
+ _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w)
+
+ # Flip weight if requested.
+ # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
+ if not flip_weight and (kw > 1 or kh > 1):
+ w = w.flip([2, 3])
+
+ # Execute using conv2d_gradfix.
+ op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
+ return op(x, w, stride=stride, padding=padding, groups=groups)
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
+ r"""2D convolution with optional up/downsampling.
+
+ Padding is performed only once at the beginning, not between the operations.
+
+ Args:
+ x: Input tensor of shape
+ `[batch_size, in_channels, in_height, in_width]`.
+ w: Weight tensor of shape
+ `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
+ f: Low-pass filter for up/downsampling. Must be prepared beforehand by
+ calling upfirdn2d.setup_filter(). None = identity (default).
+ up: Integer upsampling factor (default: 1).
+ down: Integer downsampling factor (default: 1).
+ padding: Padding with respect to the upsampled image. Can be a single number
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+ (default: 0).
+ groups: Split input channels into N groups (default: 1).
+ flip_weight: False = convolution, True = correlation (default: True).
+ flip_filter: False = convolution, True = correlation (default: False).
+
+ Returns:
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+ """
+ # Validate arguments.
+ assert isinstance(x, torch.Tensor) and (x.ndim == 4)
+ assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
+ assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
+ assert isinstance(up, int) and (up >= 1)
+ assert isinstance(down, int) and (down >= 1)
+ assert isinstance(groups, int) and (groups >= 1)
+ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
+ fw, fh = _get_filter_size(f)
+ px0, px1, py0, py1 = _parse_padding(padding)
+
+ # Adjust padding to account for up/downsampling.
+ if up > 1:
+ px0 += (fw + up - 1) // 2
+ px1 += (fw - up) // 2
+ py0 += (fh + up - 1) // 2
+ py1 += (fh - up) // 2
+ if down > 1:
+ px0 += (fw - down + 1) // 2
+ px1 += (fw - down) // 2
+ py0 += (fh - down + 1) // 2
+ py1 += (fh - down) // 2
+
+ # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
+ if kw == 1 and kh == 1 and (down > 1 and up == 1):
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
+ return x
+
+ # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
+ if kw == 1 and kh == 1 and (up > 1 and down == 1):
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
+ x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
+ return x
+
+ # Fast path: downsampling only => use strided convolution.
+ if down > 1 and up == 1:
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
+ x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
+ return x
+
+ # Fast path: upsampling with optional downsampling => use transpose strided convolution.
+ if up > 1:
+ if groups == 1:
+ w = w.transpose(0, 1)
+ else:
+ w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
+ w = w.transpose(1, 2)
+ w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
+ px0 -= kw - 1
+ px1 -= kw - up
+ py0 -= kh - 1
+ py1 -= kh - up
+ pxt = max(min(-px0, -px1), 0)
+ pyt = max(min(-py0, -py1), 0)
+ x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
+ if down > 1:
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
+ return x
+
+ # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
+ if up == 1 and down == 1:
+ if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
+ return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
+
+ # Fallback: Generic reference implementation.
+ x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
+ if down > 1:
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
+ return x
+
+#----------------------------------------------------------------------------
diff --git a/stylegan3-fun/torch_utils/ops/filtered_lrelu.cpp b/stylegan3-fun/torch_utils/ops/filtered_lrelu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..ff4149b8b46b54d2f400ae10e44d19f20503ba1f
--- /dev/null
+++ b/stylegan3-fun/torch_utils/ops/filtered_lrelu.cpp
@@ -0,0 +1,300 @@
+// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include
+#include
+#include
+#include "filtered_lrelu.h"
+
+//------------------------------------------------------------------------
+
+static std::tuple filtered_lrelu(
+ torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, torch::Tensor si,
+ int up, int down, int px0, int px1, int py0, int py1, int sx, int sy, float gain, float slope, float clamp, bool flip_filters, bool writeSigns)
+{
+ // Set CUDA device.
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+
+ // Validate arguments.
+ TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() && b.device() == x.device(), "all input tensors must reside on the same device");
+ TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat, "fu and fd must be float32");
+ TORCH_CHECK(b.dtype() == x.dtype(), "x and b must have the same dtype");
+ TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat, "x and b must be float16 or float32");
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
+ TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
+ TORCH_CHECK(x.numel() > 0, "x is empty");
+ TORCH_CHECK((fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2), "fu and fd must be rank 1 or 2");
+ TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX, "fu is too large");
+ TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX, "fd is too large");
+ TORCH_CHECK(fu.numel() > 0, "fu is empty");
+ TORCH_CHECK(fd.numel() > 0, "fd is empty");
+ TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1), "b must be a vector with the same number of channels as x");
+ TORCH_CHECK(up >= 1 && down >= 1, "up and down must be at least 1");
+
+ // Figure out how much shared memory is available on the device.
+ int maxSharedBytes = 0;
+ AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, x.device().index()));
+ int sharedKB = maxSharedBytes >> 10;
+
+ // Populate enough launch parameters to check if a CUDA kernel exists.
+ filtered_lrelu_kernel_params p;
+ p.up = up;
+ p.down = down;
+ p.fuShape = make_int2((int)fu.size(-1), fu.dim() == 2 ? (int)fu.size(0) : 0); // shape [n, 0] indicates separable filter.
+ p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? (int)fd.size(0) : 0);
+ filtered_lrelu_kernel_spec test_spec = choose_filtered_lrelu_kernel(p, sharedKB);
+ if (!test_spec.exec)
+ {
+ // No kernel found - return empty tensors and indicate missing kernel with return code of -1.
+ return std::make_tuple(torch::Tensor(), torch::Tensor(), -1);
+ }
+
+ // Input/output element size.
+ int64_t sz = (x.dtype() == torch::kHalf) ? 2 : 4;
+
+ // Input sizes.
+ int64_t xw = (int)x.size(3);
+ int64_t xh = (int)x.size(2);
+ int64_t fut_w = (int)fu.size(-1) - 1;
+ int64_t fut_h = (int)fu.size(0) - 1;
+ int64_t fdt_w = (int)fd.size(-1) - 1;
+ int64_t fdt_h = (int)fd.size(0) - 1;
+
+ // Logical size of upsampled buffer.
+ int64_t cw = xw * up + (px0 + px1) - fut_w;
+ int64_t ch = xh * up + (py0 + py1) - fut_h;
+ TORCH_CHECK(cw > fdt_w && ch > fdt_h, "upsampled buffer must be at least the size of downsampling filter");
+ TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, "upsampled buffer is too large");
+
+ // Compute output size and allocate.
+ int64_t yw = (cw - fdt_w + (down - 1)) / down;
+ int64_t yh = (ch - fdt_h + (down - 1)) / down;
+ TORCH_CHECK(yw > 0 && yh > 0, "output must be at least 1x1");
+ TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, "output is too large");
+ torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(), x.suggest_memory_format());
+
+ // Allocate sign tensor.
+ torch::Tensor so;
+ torch::Tensor s = si;
+ bool readSigns = !!s.numel();
+ int64_t sw_active = 0; // Active width of sign tensor.
+ if (writeSigns)
+ {
+ sw_active = yw * down - (down - 1) + fdt_w; // Active width in elements.
+ int64_t sh = yh * down - (down - 1) + fdt_h; // Height = active height.
+ int64_t sw = (sw_active + 15) & ~15; // Width = active width in elements, rounded up to multiple of 16.
+ TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, "signs is too large");
+ s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
+ }
+ else if (readSigns)
+ sw_active = s.size(3) << 2;
+
+ // Validate sign tensor if in use.
+ if (readSigns || writeSigns)
+ {
+ TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
+ TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
+ TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
+ TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
+ TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
+ TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX, "signs is too large");
+ }
+
+ // Populate rest of CUDA kernel parameters.
+ p.x = x.data_ptr();
+ p.y = y.data_ptr();
+ p.b = b.data_ptr();
+ p.s = (readSigns || writeSigns) ? s.data_ptr() : 0;
+ p.fu = fu.data_ptr();
+ p.fd = fd.data_ptr();
+ p.pad0 = make_int2(px0, py0);
+ p.gain = gain;
+ p.slope = slope;
+ p.clamp = clamp;
+ p.flip = (flip_filters) ? 1 : 0;
+ p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
+ p.yShape = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
+ p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3), (int)s.size(2)) : make_int2(0, 0); // Width is in bytes. Contiguous.
+ p.sOfs = make_int2(sx, sy);
+ p.swLimit = (sw_active + 3) >> 2; // Rounded up to bytes.
+
+ // x, y, b strides are in bytes.
+ p.xStride = make_longlong4(sz * x.stride(3), sz * x.stride(2), sz * x.stride(1), sz * x.stride(0));
+ p.yStride = make_longlong4(sz * y.stride(3), sz * y.stride(2), sz * y.stride(1), sz * y.stride(0));
+ p.bStride = sz * b.stride(0);
+
+ // fu, fd strides are in elements.
+ p.fuStride = make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0);
+ p.fdStride = make_longlong3(fd.stride(-1), fd.dim() == 2 ? fd.stride(0) : 0, 0);
+
+ // Determine if indices don't fit in int32. Support negative strides although Torch currently never produces those.
+ bool index64b = false;
+ if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true;
+ if (std::min(x.size(0) * p.xStride.w, 0ll) + std::min(x.size(1) * p.xStride.z, 0ll) + std::min(x.size(2) * p.xStride.y, 0ll) + std::min(x.size(3) * p.xStride.x, 0ll) < -INT_MAX) index64b = true;
+ if (std::max(x.size(0) * p.xStride.w, 0ll) + std::max(x.size(1) * p.xStride.z, 0ll) + std::max(x.size(2) * p.xStride.y, 0ll) + std::max(x.size(3) * p.xStride.x, 0ll) > INT_MAX) index64b = true;
+ if (std::min(y.size(0) * p.yStride.w, 0ll) + std::min(y.size(1) * p.yStride.z, 0ll) + std::min(y.size(2) * p.yStride.y, 0ll) + std::min(y.size(3) * p.yStride.x, 0ll) < -INT_MAX) index64b = true;
+ if (std::max(y.size(0) * p.yStride.w, 0ll) + std::max(y.size(1) * p.yStride.z, 0ll) + std::max(y.size(2) * p.yStride.y, 0ll) + std::max(y.size(3) * p.yStride.x, 0ll) > INT_MAX) index64b = true;
+ if (s.numel() > INT_MAX) index64b = true;
+
+ // Choose CUDA kernel.
+ filtered_lrelu_kernel_spec spec = { 0 };
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_cuda", [&]
+ {
+ if constexpr (sizeof(scalar_t) <= 4) // Exclude doubles. constexpr prevents template instantiation.
+ {
+ // Choose kernel based on index type, datatype and sign read/write modes.
+ if (!index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB);
+ else if (!index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB);
+ else if (!index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB);
+ else if ( index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB);
+ else if ( index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB);
+ else if ( index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel(p, sharedKB);
+ }
+ });
+ TORCH_CHECK(spec.exec, "internal error - CUDA kernel not found") // This should not happen because we tested earlier that kernel exists.
+
+ // Launch CUDA kernel.
+ void* args[] = {&p};
+ int bx = spec.numWarps * 32;
+ int gx = (p.yShape.x - 1) / spec.tileOut.x + 1;
+ int gy = (p.yShape.y - 1) / spec.tileOut.y + 1;
+ int gz = p.yShape.z * p.yShape.w;
+
+ // Repeat multiple horizontal tiles in a CTA?
+ if (spec.xrep)
+ {
+ p.tilesXrep = spec.xrep;
+ p.tilesXdim = gx;
+
+ gx = (gx + p.tilesXrep - 1) / p.tilesXrep;
+ std::swap(gx, gy);
+ }
+ else
+ {
+ p.tilesXrep = 0;
+ p.tilesXdim = 0;
+ }
+
+ // Launch filter setup kernel.
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0, at::cuda::getCurrentCUDAStream()));
+
+ // Copy kernels to constant memory.
+ if ( writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream())));
+ else if (!writeSigns && readSigns) AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream())));
+ else if (!writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream())));
+
+ // Set cache and shared memory configurations for main kernel.
+ AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared));
+ if (spec.dynamicSharedKB) // Need dynamically allocated shared memory?
+ AT_CUDA_CHECK(cudaFuncSetAttribute(spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize, spec.dynamicSharedKB << 10));
+ AT_CUDA_CHECK(cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte));
+
+ // Launch main kernel.
+ const int maxSubGz = 65535; // CUDA maximum for block z dimension.
+ for (int zofs=0; zofs < gz; zofs += maxSubGz) // Do multiple launches if gz is too big.
+ {
+ p.blockZofs = zofs;
+ int subGz = std::min(maxSubGz, gz - zofs);
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args, spec.dynamicSharedKB << 10, at::cuda::getCurrentCUDAStream()));
+ }
+
+ // Done.
+ return std::make_tuple(y, so, 0);
+}
+
+//------------------------------------------------------------------------
+
+static torch::Tensor filtered_lrelu_act(torch::Tensor x, torch::Tensor si, int sx, int sy, float gain, float slope, float clamp, bool writeSigns)
+{
+ // Set CUDA device.
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+
+ // Validate arguments.
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
+ TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
+ TORCH_CHECK(x.numel() > 0, "x is empty");
+ TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat || x.dtype() == torch::kDouble, "x must be float16, float32 or float64");
+
+ // Output signs if we don't have sign input.
+ torch::Tensor so;
+ torch::Tensor s = si;
+ bool readSigns = !!s.numel();
+ if (writeSigns)
+ {
+ int64_t sw = x.size(3);
+ sw = (sw + 15) & ~15; // Round to a multiple of 16 for coalescing.
+ s = so = torch::empty({x.size(0), x.size(1), x.size(2), sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
+ }
+
+ // Validate sign tensor if in use.
+ if (readSigns || writeSigns)
+ {
+ TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
+ TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
+ TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
+ TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
+ TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
+ TORCH_CHECK(s.size(2) <= INT_MAX && (s.size(3) << 2) <= INT_MAX, "signs tensor is too large");
+ }
+
+ // Initialize CUDA kernel parameters.
+ filtered_lrelu_act_kernel_params p;
+ p.x = x.data_ptr();
+ p.s = (readSigns || writeSigns) ? s.data_ptr() : 0;
+ p.gain = gain;
+ p.slope = slope;
+ p.clamp = clamp;
+ p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
+ p.xStride = make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0));
+ p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3) << 2, (int)s.size(2)) : make_int2(0, 0); // Width is in elements. Contiguous.
+ p.sOfs = make_int2(sx, sy);
+
+ // Choose CUDA kernel.
+ void* func = 0;
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_act_cuda", [&]
+ {
+ if (writeSigns)
+ func = choose_filtered_lrelu_act_kernel();
+ else if (readSigns)
+ func = choose_filtered_lrelu_act_kernel();
+ else
+ func = choose_filtered_lrelu_act_kernel();
+ });
+ TORCH_CHECK(func, "internal error - CUDA kernel not found");
+
+ // Launch CUDA kernel.
+ void* args[] = {&p};
+ int bx = 128; // 4 warps per block.
+
+ // Logical size of launch = writeSigns ? p.s : p.x
+ uint32_t gx = writeSigns ? p.sShape.x : p.xShape.x;
+ uint32_t gy = writeSigns ? p.sShape.y : p.xShape.y;
+ uint32_t gz = p.xShape.z * p.xShape.w; // Same as in p.sShape if signs are in use.
+ gx = (gx - 1) / bx + 1;
+
+ // Make sure grid y and z dimensions are within CUDA launch limits. Kernel loops internally to do the rest.
+ const uint32_t gmax = 65535;
+ gy = std::min(gy, gmax);
+ gz = std::min(gz, gmax);
+
+ // Launch.
+ AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0, at::cuda::getCurrentCUDAStream()));
+ return so;
+}
+
+//------------------------------------------------------------------------
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ m.def("filtered_lrelu", &filtered_lrelu); // The whole thing.
+ m.def("filtered_lrelu_act_", &filtered_lrelu_act); // Activation and sign tensor handling only. Modifies data tensor in-place.
+}
+
+//------------------------------------------------------------------------
diff --git a/stylegan3-fun/torch_utils/ops/filtered_lrelu.cu b/stylegan3-fun/torch_utils/ops/filtered_lrelu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..8e6f47f873d42f7181a0faf64779377e70be3012
--- /dev/null
+++ b/stylegan3-fun/torch_utils/ops/filtered_lrelu.cu
@@ -0,0 +1,1284 @@
+// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include
+#include "filtered_lrelu.h"
+#include
+
+//------------------------------------------------------------------------
+// Helpers.
+
+enum // Filter modes.
+{
+ MODE_SUSD = 0, // Separable upsampling, separable downsampling.
+ MODE_FUSD = 1, // Full upsampling, separable downsampling.
+ MODE_SUFD = 2, // Separable upsampling, full downsampling.
+ MODE_FUFD = 3, // Full upsampling, full downsampling.
+};
+
+template struct InternalType;
+template <> struct InternalType
+{
+ typedef double scalar_t; typedef double2 vec2_t; typedef double4 vec4_t;
+ __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_double2(0, 0); }
+ __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_double4(0, 0, 0, 0); }
+ __device__ __forceinline__ static double clamp(double x, double c) { return fmin(fmax(x, -c), c); }
+};
+template <> struct InternalType
+{
+ typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
+ __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
+ __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
+ __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
+};
+template <> struct InternalType
+{
+ typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
+ __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
+ __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
+ __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
+};
+
+#define MIN(A, B) ((A) < (B) ? (A) : (B))
+#define MAX(A, B) ((A) > (B) ? (A) : (B))
+#define CEIL_DIV(A, B) (((B)==1) ? (A) : \
+ ((B)==2) ? ((int)((A)+1) >> 1) : \
+ ((B)==4) ? ((int)((A)+3) >> 2) : \
+ (((A) + ((A) > 0 ? (B) - 1 : 0)) / (B)))
+
+// This works only up to blocks of size 256 x 256 and for all N that are powers of two.
+template __device__ __forceinline__ void fast_div_mod(int& x, int& y, unsigned int i)
+{
+ if ((N & (N-1)) && N <= 256)
+ y = (i * ((1<<24)/N + 1)) >> 24; // Assumes N <= 256, i < N*256.
+ else
+ y = i/N;
+
+ x = i - y*N;
+}
+
+// Type cast stride before reading it.
+template __device__ __forceinline__ T get_stride(const int64_t& x)
+{
+ return *reinterpret_cast(&x);
+}
+
+//------------------------------------------------------------------------
+// Filters, setup kernel, copying function.
+
+#define MAX_FILTER_SIZE 32
+
+// Combined up/down filter buffers so that transfer can be done with one copy.
+__device__ float g_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in global memory, written by setup kernel.
+__device__ __constant__ float c_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in constant memory, read by main kernel.
+
+// Accessors to combined buffers to index up/down filters individually.
+#define c_fu (c_fbuf)
+#define c_fd (c_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
+#define g_fu (g_fbuf)
+#define g_fd (g_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
+
+// Set up filters into global memory buffer.
+static __global__ void setup_filters_kernel(filtered_lrelu_kernel_params p)
+{
+ for (int idx = threadIdx.x; idx < MAX_FILTER_SIZE * MAX_FILTER_SIZE; idx += blockDim.x)
+ {
+ int x, y;
+ fast_div_mod(x, y, idx);
+
+ int fu_x = p.flip ? x : (p.fuShape.x - 1 - x);
+ int fu_y = p.flip ? y : (p.fuShape.y - 1 - y);
+ if (p.fuShape.y > 0)
+ g_fu[idx] = (x >= p.fuShape.x || y >= p.fuShape.y) ? 0.0f : p.fu[fu_x * p.fuStride.x + fu_y * p.fuStride.y];
+ else
+ g_fu[idx] = (x >= p.fuShape.x || y > 0) ? 0.0f : p.fu[fu_x * p.fuStride.x];
+
+ int fd_x = p.flip ? x : (p.fdShape.x - 1 - x);
+ int fd_y = p.flip ? y : (p.fdShape.y - 1 - y);
+ if (p.fdShape.y > 0)
+ g_fd[idx] = (x >= p.fdShape.x || y >= p.fdShape.y) ? 0.0f : p.fd[fd_x * p.fdStride.x + fd_y * p.fdStride.y];
+ else
+ g_fd[idx] = (x >= p.fdShape.x || y > 0) ? 0.0f : p.fd[fd_x * p.fdStride.x];
+ }
+}
+
+// Host function to copy filters written by setup kernel into constant buffer for main kernel.
+template static cudaError_t copy_filters(cudaStream_t stream)
+{
+ void* src = 0;
+ cudaError_t err = cudaGetSymbolAddress(&src, g_fbuf);
+ if (err) return err;
+ return cudaMemcpyToSymbolAsync(c_fbuf, src, 2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE * sizeof(float), 0, cudaMemcpyDeviceToDevice, stream);
+}
+
+//------------------------------------------------------------------------
+// Coordinate spaces:
+// - Relative to input tensor: inX, inY, tileInX, tileInY
+// - Relative to input tile: relInX, relInY, tileInW, tileInH
+// - Relative to upsampled tile: relUpX, relUpY, tileUpW, tileUpH
+// - Relative to output tile: relOutX, relOutY, tileOutW, tileOutH
+// - Relative to output tensor: outX, outY, tileOutX, tileOutY
+//
+// Relationships between coordinate spaces:
+// - inX = tileInX + relInX
+// - inY = tileInY + relInY
+// - relUpX = relInX * up + phaseInX
+// - relUpY = relInY * up + phaseInY
+// - relUpX = relOutX * down
+// - relUpY = relOutY * down
+// - outX = tileOutX + relOutX
+// - outY = tileOutY + relOutY
+
+extern __shared__ char s_buf_raw[]; // When sharedKB <= 48, allocate shared memory statically inside the kernel, otherwise use the externally allocated shared memory buffer.
+
+template
+static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p)
+{
+ // Check that we don't try to support non-existing filter modes.
+ static_assert(up == 1 || up == 2 || up == 4, "only up=1, up=2, up=4 scales supported");
+ static_assert(down == 1 || down == 2 || down == 4, "only down=1, down=2, down=4 scales supported");
+ static_assert(fuSize >= up, "upsampling filter size must be at least upsampling factor");
+ static_assert(fdSize >= down, "downsampling filter size must be at least downsampling factor");
+ static_assert(fuSize % up == 0, "upsampling filter size must be divisible with upsampling factor");
+ static_assert(fdSize % down == 0, "downsampling filter size must be divisible with downsampling factor");
+ static_assert(fuSize <= MAX_FILTER_SIZE && fdSize <= MAX_FILTER_SIZE, "filter size greater than MAX_FILTER_SIZE");
+ static_assert(up != 1 || (fuSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "up=1 supported only for 1x1 full filters");
+ static_assert(down != 1 || (fdSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "down=1 supported only for 1x1 full filters");
+ static_assert(!(up == 4 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "full filters not supported for up=4");
+ static_assert(!(down == 4 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "full filters not supported for down=4");
+
+ // Static definitions.
+ typedef typename InternalType::scalar_t scalar_t;
+ typedef typename InternalType::vec2_t vec2_t;
+ typedef typename InternalType::vec4_t vec4_t;
+ const int tileUpW = (tileOutW * down + (fdSize - 1) - (down - 1) + 3) & ~3; // Upsampled tile width, rounded up to multiple of 4.
+ const int tileUpH = tileOutH * down + (fdSize - 1) - (down - 1); // Upsampled tile height.
+ const int tileInW = CEIL_DIV(tileUpW + (fuSize - 1), up); // Input tile width.
+ const int tileInH = CEIL_DIV(tileUpH + (fuSize - 1), up); // Input tile height.
+ const int tileUpH_up = CEIL_DIV(tileUpH, up) * up; // Upsampled tile height rounded up to a multiple of up.
+ const int tileInH_up = CEIL_DIV(tileUpH_up + (fuSize - 1), up); // For allocations only, to avoid shared memory read overruns with up=2 and up=4.
+
+ // Merge 1x1 downsampling into last upsampling step for upf1 and ups2.
+ const bool downInline = (down == 1) && ((up == 1 && filterMode == MODE_FUFD) || (up == 2 && filterMode == MODE_SUFD));
+
+ // Sizes of logical buffers.
+ const int szIn = tileInH_up * tileInW;
+ const int szUpX = tileInH_up * tileUpW;
+ const int szUpXY = downInline ? 0 : (tileUpH * tileUpW);
+ const int szDownX = tileUpH * tileOutW;
+
+ // Sizes for shared memory arrays.
+ const int s_buf0_size_base =
+ (filterMode == MODE_SUSD) ? MAX(szIn, szUpXY) :
+ (filterMode == MODE_FUSD) ? MAX(szIn, szDownX) :
+ (filterMode == MODE_SUFD) ? MAX(szIn, szUpXY) :
+ (filterMode == MODE_FUFD) ? szIn :
+ -1;
+ const int s_buf1_size_base =
+ (filterMode == MODE_SUSD) ? MAX(szUpX, szDownX) :
+ (filterMode == MODE_FUSD) ? szUpXY :
+ (filterMode == MODE_SUFD) ? szUpX :
+ (filterMode == MODE_FUFD) ? szUpXY :
+ -1;
+
+ // Ensure U128 alignment.
+ const int s_buf0_size = (s_buf0_size_base + 3) & ~3;
+ const int s_buf1_size = (s_buf1_size_base + 3) & ~3;
+
+ // Check at compile time that we don't use too much shared memory.
+ static_assert((s_buf0_size + s_buf1_size) * sizeof(scalar_t) <= (sharedKB << 10), "shared memory overflow");
+
+ // Declare shared memory arrays.
+ scalar_t* s_buf0;
+ scalar_t* s_buf1;
+ if (sharedKB <= 48)
+ {
+ // Allocate shared memory arrays here.
+ __shared__ scalar_t s_buf0_st[(sharedKB > 48) ? (1<<24) : (s_buf0_size + s_buf1_size)]; // Prevent launching if this isn't optimized away when unused.
+ s_buf0 = s_buf0_st;
+ s_buf1 = s_buf0 + s_buf0_size;
+ }
+ else
+ {
+ // Use the dynamically allocated shared memory array.
+ s_buf0 = (scalar_t*)s_buf_raw;
+ s_buf1 = s_buf0 + s_buf0_size;
+ }
+
+ // Pointers to the buffers.
+ scalar_t* s_tileIn; // Input tile: [relInX * tileInH + relInY]
+ scalar_t* s_tileUpX; // After horizontal upsampling: [relInY * tileUpW + relUpX]
+ scalar_t* s_tileUpXY; // After upsampling: [relUpY * tileUpW + relUpX]
+ scalar_t* s_tileDownX; // After horizontal downsampling: [relUpY * tileOutW + relOutX]
+ if (filterMode == MODE_SUSD)
+ {
+ s_tileIn = s_buf0;
+ s_tileUpX = s_buf1;
+ s_tileUpXY = s_buf0;
+ s_tileDownX = s_buf1;
+ }
+ else if (filterMode == MODE_FUSD)
+ {
+ s_tileIn = s_buf0;
+ s_tileUpXY = s_buf1;
+ s_tileDownX = s_buf0;
+ }
+ else if (filterMode == MODE_SUFD)
+ {
+ s_tileIn = s_buf0;
+ s_tileUpX = s_buf1;
+ s_tileUpXY = s_buf0;
+ }
+ else if (filterMode == MODE_FUFD)
+ {
+ s_tileIn = s_buf0;
+ s_tileUpXY = s_buf1;
+ }
+
+ // Allow large grids in z direction via per-launch offset.
+ int channelIdx = blockIdx.z + p.blockZofs;
+ int batchIdx = channelIdx / p.yShape.z;
+ channelIdx -= batchIdx * p.yShape.z;
+
+ // Offset to output feature map. In bytes.
+ index_t mapOfsOut = channelIdx * get_stride(p.yStride.z) + batchIdx * get_stride(p.yStride.w);
+
+ // Sign shift amount.
+ uint32_t signXo = ((threadIdx.x + p.sOfs.x) << 1) & 6;
+
+ // Inner tile loop.
+ #pragma unroll 1
+ for (int tileIdx = 0; !enableXrep || (tileIdx < MIN(p.tilesXrep, p.tilesXdim - p.tilesXrep * blockIdx.y)); tileIdx++)
+ {
+ // Locate output tile.
+ int tileX = enableXrep ? blockIdx.y * p.tilesXrep + tileIdx : blockIdx.x;
+ int tileOutX = tileX * tileOutW;
+ int tileOutY = (enableXrep ? blockIdx.x : blockIdx.y) * tileOutH;
+
+ // Locate input tile.
+ int tmpX = tileOutX * down - p.pad0.x;
+ int tmpY = tileOutY * down - p.pad0.y;
+ int tileInX = CEIL_DIV(tmpX, up);
+ int tileInY = CEIL_DIV(tmpY, up);
+ const int phaseInX = tileInX * up - tmpX;
+ const int phaseInY = tileInY * up - tmpY;
+
+ // Extra sync if input and output buffers are the same and we are not on first tile.
+ if (enableXrep && tileIdx > 0 && (filterMode == MODE_FUSD || (filterMode == MODE_SUFD && !downInline) || (filterMode == MODE_FUFD && downInline)))
+ __syncthreads();
+
+ // Load input tile & apply bias. Unrolled.
+ scalar_t b = (scalar_t)*(const T*)((const char*)p.b + (channelIdx * get_stride(p.bStride)));
+ index_t mapOfsIn = channelIdx * get_stride(p.xStride.z) + batchIdx * get_stride(p.xStride.w);
+ int idx = threadIdx.x;
+ const int loopCountIN = CEIL_DIV(tileInW * tileInH, threadsPerBlock);
+ #pragma unroll
+ for (int loop = 0; loop < loopCountIN; loop++)
+ {
+ int relInX, relInY;
+ fast_div_mod(relInX, relInY, idx);
+ int inX = tileInX + relInX;
+ int inY = tileInY + relInY;
+ scalar_t v = 0;
+
+ if ((uint32_t)inX < p.xShape.x && (uint32_t)inY < p.xShape.y)
+ v = (scalar_t)*((const T*)((const char*)p.x + (inX * get_stride(p.xStride.x) + inY * get_stride(p.xStride.y) + mapOfsIn))) + b;
+
+ bool skip = (loop == loopCountIN-1) && (idx >= tileInW * tileInH);
+ if (!skip)
+ s_tileIn[idx] = v;
+
+ idx += threadsPerBlock;
+ }
+
+ if (filterMode == MODE_SUSD || filterMode == MODE_SUFD) // Separable upsampling filter.
+ {
+ // Horizontal upsampling.
+ __syncthreads();
+ if (up == 4)
+ {
+ for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)
+ {
+ int relUpX0, relInY;
+ fast_div_mod(relUpX0, relInY, idx);
+ int relInX0 = relUpX0 / up;
+ int src0 = relInX0 + tileInW * relInY;
+ int dst = relInY * tileUpW + relUpX0;
+ vec4_t v = InternalType::zero_vec4();
+ scalar_t a = s_tileIn[src0];
+ if (phaseInX == 0)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileIn[src0 + step + 1];
+ v.y += a * (scalar_t)c_fu[step * up + 3];
+ v.z += a * (scalar_t)c_fu[step * up + 2];
+ v.w += a * (scalar_t)c_fu[step * up + 1];
+ }
+ }
+ else if (phaseInX == 1)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 1];
+ v.y += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileIn[src0 + step + 1];
+ v.z += a * (scalar_t)c_fu[step * up + 3];
+ v.w += a * (scalar_t)c_fu[step * up + 2];
+ }
+ }
+ else if (phaseInX == 2)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 2];
+ v.y += a * (scalar_t)c_fu[step * up + 1];
+ v.z += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileIn[src0 + step + 1];
+ v.w += a * (scalar_t)c_fu[step * up + 3];
+ }
+ }
+ else // (phaseInX == 3)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 3];
+ v.y += a * (scalar_t)c_fu[step * up + 2];
+ v.z += a * (scalar_t)c_fu[step * up + 1];
+ v.w += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileIn[src0 + step + 1];
+ }
+ }
+ s_tileUpX[dst+0] = v.x;
+ s_tileUpX[dst+1] = v.y;
+ s_tileUpX[dst+2] = v.z;
+ s_tileUpX[dst+3] = v.w;
+ }
+ }
+ else if (up == 2)
+ {
+ bool p0 = (phaseInX == 0);
+ for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)
+ {
+ int relUpX0, relInY;
+ fast_div_mod(relUpX0, relInY, idx);
+ int relInX0 = relUpX0 / up;
+ int src0 = relInX0 + tileInW * relInY;
+ int dst = relInY * tileUpW + relUpX0;
+ vec2_t v = InternalType::zero_vec2();
+ scalar_t a = s_tileIn[src0];
+ if (p0) // (phaseInX == 0)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileIn[src0 + step + 1];
+ v.y += a * (scalar_t)c_fu[step * up + 1];
+ }
+ }
+ else // (phaseInX == 1)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 1];
+ v.y += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileIn[src0 + step + 1];
+ }
+ }
+ s_tileUpX[dst+0] = v.x;
+ s_tileUpX[dst+1] = v.y;
+ }
+ }
+
+ // Vertical upsampling & nonlinearity.
+
+ __syncthreads();
+ int groupMask = 15 << ((threadIdx.x & 31) & ~3);
+ int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs.
+ int sShapeMaxY = MIN(p.sShape.y, tileOutY * down + tileUpH); // Avoid out-of-tile sign writes.
+ if (up == 4)
+ {
+ minY -= 3; // Adjust according to block height.
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x)
+ {
+ int relUpX, relInY0;
+ fast_div_mod(relUpX, relInY0, idx);
+ int relUpY0 = relInY0 * up;
+ int src0 = relInY0 * tileUpW + relUpX;
+ int dst = relUpY0 * tileUpW + relUpX;
+ vec4_t v = InternalType::zero_vec4();
+
+ scalar_t a = s_tileUpX[src0];
+ if (phaseInY == 0)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
+ v.y += a * (scalar_t)c_fu[step * up + 3];
+ v.z += a * (scalar_t)c_fu[step * up + 2];
+ v.w += a * (scalar_t)c_fu[step * up + 1];
+ }
+ }
+ else if (phaseInY == 1)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 1];
+ v.y += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
+ v.z += a * (scalar_t)c_fu[step * up + 3];
+ v.w += a * (scalar_t)c_fu[step * up + 2];
+ }
+ }
+ else if (phaseInY == 2)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 2];
+ v.y += a * (scalar_t)c_fu[step * up + 1];
+ v.z += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
+ v.w += a * (scalar_t)c_fu[step * up + 3];
+ }
+ }
+ else // (phaseInY == 3)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 3];
+ v.y += a * (scalar_t)c_fu[step * up + 2];
+ v.z += a * (scalar_t)c_fu[step * up + 1];
+ v.w += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
+ }
+ }
+
+ int x = tileOutX * down + relUpX;
+ int y = tileOutY * down + relUpY0;
+ int signX = x + p.sOfs.x;
+ int signY = y + p.sOfs.y;
+ int signZ = blockIdx.z + p.blockZofs;
+ int signXb = signX >> 2;
+ index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
+ index_t si1 = si0 + p.sShape.x;
+ index_t si2 = si0 + p.sShape.x * 2;
+ index_t si3 = si0 + p.sShape.x * 3;
+
+ v.x *= (scalar_t)((float)up * (float)up * p.gain);
+ v.y *= (scalar_t)((float)up * (float)up * p.gain);
+ v.z *= (scalar_t)((float)up * (float)up * p.gain);
+ v.w *= (scalar_t)((float)up * (float)up * p.gain);
+
+ if (signWrite)
+ {
+ if (!enableWriteSkip)
+ {
+ // Determine and write signs.
+ int sx = __float_as_uint(v.x) >> 31 << 0;
+ int sy = __float_as_uint(v.y) >> 31 << 8;
+ int sz = __float_as_uint(v.z) >> 31 << 16;
+ int sw = __float_as_uint(v.w) >> 31 << 24;
+ if (sx) v.x *= p.slope;
+ if (sy) v.y *= p.slope;
+ if (sz) v.z *= p.slope;
+ if (sw) v.w *= p.slope;
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); }
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); }
+ if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType::clamp(v.z, p.clamp); }
+ if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType::clamp(v.w, p.clamp); }
+
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
+ {
+ // Combine signs.
+ uint32_t s = sx + sy + sw + sz;
+ s <<= (signX & 3) << 1;
+ s |= __shfl_xor_sync(groupMask, s, 1);
+ s |= __shfl_xor_sync(groupMask, s, 2);
+
+ // Write signs.
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
+ if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }
+ if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }
+ }
+ }
+ else
+ {
+ // Determine and write signs.
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
+ {
+ int sx = __float_as_uint(v.x) >> 31 << 0;
+ int sy = __float_as_uint(v.y) >> 31 << 8;
+ int sz = __float_as_uint(v.z) >> 31 << 16;
+ int sw = __float_as_uint(v.w) >> 31 << 24;
+ if (sx) v.x *= p.slope;
+ if (sy) v.y *= p.slope;
+ if (sz) v.z *= p.slope;
+ if (sw) v.w *= p.slope;
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); }
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); }
+ if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType::clamp(v.z, p.clamp); }
+ if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType::clamp(v.w, p.clamp); }
+
+ // Combine signs.
+ uint32_t s = sx + sy + sw + sz;
+ s <<= (signX & 3) << 1;
+ s |= __shfl_xor_sync(groupMask, s, 1);
+ s |= __shfl_xor_sync(groupMask, s, 2);
+
+ // Write signs.
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
+ if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }
+ if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }
+ }
+ else
+ {
+ // Just compute the values.
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp);
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp);
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp);
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp);
+ }
+ }
+ }
+ else if (signRead) // Read signs and apply.
+ {
+ if ((uint32_t)signXb < p.swLimit)
+ {
+ int ss = (signX & 3) << 1;
+ if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> ss; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; }
+ if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> ss; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; }
+ if ((uint32_t)(signY + 2) < p.sShape.y) { int s = p.s[si2] >> ss; if (s & 1) v.z *= p.slope; if (s & 2) v.z = 0.f; }
+ if ((uint32_t)(signY + 3) < p.sShape.y) { int s = p.s[si3] >> ss; if (s & 1) v.w *= p.slope; if (s & 2) v.w = 0.f; }
+ }
+ }
+ else // Forward pass with no sign write.
+ {
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp);
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp);
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp);
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp);
+ }
+
+ s_tileUpXY[dst + 0 * tileUpW] = v.x;
+ if (relUpY0 + 1 < tileUpH) s_tileUpXY[dst + 1 * tileUpW] = v.y;
+ if (relUpY0 + 2 < tileUpH) s_tileUpXY[dst + 2 * tileUpW] = v.z;
+ if (relUpY0 + 3 < tileUpH) s_tileUpXY[dst + 3 * tileUpW] = v.w;
+ }
+ }
+ else if (up == 2)
+ {
+ minY -= 1; // Adjust according to block height.
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x)
+ {
+ int relUpX, relInY0;
+ fast_div_mod(relUpX, relInY0, idx);
+ int relUpY0 = relInY0 * up;
+ int src0 = relInY0 * tileUpW + relUpX;
+ int dst = relUpY0 * tileUpW + relUpX;
+ vec2_t v = InternalType::zero_vec2();
+
+ scalar_t a = s_tileUpX[src0];
+ if (phaseInY == 0)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
+ v.y += a * (scalar_t)c_fu[step * up + 1];
+ }
+ }
+ else // (phaseInY == 1)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 1];
+ v.y += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
+ }
+ }
+
+ int x = tileOutX * down + relUpX;
+ int y = tileOutY * down + relUpY0;
+ int signX = x + p.sOfs.x;
+ int signY = y + p.sOfs.y;
+ int signZ = blockIdx.z + p.blockZofs;
+ int signXb = signX >> 2;
+ index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
+ index_t si1 = si0 + p.sShape.x;
+
+ v.x *= (scalar_t)((float)up * (float)up * p.gain);
+ v.y *= (scalar_t)((float)up * (float)up * p.gain);
+
+ if (signWrite)
+ {
+ if (!enableWriteSkip)
+ {
+ // Determine and write signs.
+ int sx = __float_as_uint(v.x) >> 31 << 0;
+ int sy = __float_as_uint(v.y) >> 31 << 8;
+ if (sx) v.x *= p.slope;
+ if (sy) v.y *= p.slope;
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); }
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); }
+
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
+ {
+ // Combine signs.
+ int s = sx + sy;
+ s <<= signXo;
+ s |= __shfl_xor_sync(groupMask, s, 1);
+ s |= __shfl_xor_sync(groupMask, s, 2);
+
+ // Write signs.
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
+ }
+ }
+ else
+ {
+ // Determine and write signs.
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
+ {
+ int sx = __float_as_uint(v.x) >> 31 << 0;
+ int sy = __float_as_uint(v.y) >> 31 << 8;
+ if (sx) v.x *= p.slope;
+ if (sy) v.y *= p.slope;
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); }
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); }
+
+ // Combine signs.
+ int s = sx + sy;
+ s <<= signXo;
+ s |= __shfl_xor_sync(groupMask, s, 1);
+ s |= __shfl_xor_sync(groupMask, s, 2);
+
+ // Write signs.
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
+ }
+ else
+ {
+ // Just compute the values.
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp);
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp);
+ }
+ }
+ }
+ else if (signRead) // Read signs and apply.
+ {
+ if ((uint32_t)signXb < p.swLimit)
+ {
+ if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> signXo; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; }
+ if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> signXo; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; }
+ }
+ }
+ else // Forward pass with no sign write.
+ {
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp);
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp);
+ }
+
+ if (!downInline)
+ {
+ // Write into temporary buffer.
+ s_tileUpXY[dst] = v.x;
+ if (relUpY0 < tileUpH - 1)
+ s_tileUpXY[dst + tileUpW] = v.y;
+ }
+ else
+ {
+ // Write directly into output buffer.
+ if ((uint32_t)x < p.yShape.x)
+ {
+ int ymax = MIN(p.yShape.y, tileUpH + tileOutY * down);
+ index_t ofs = x * get_stride(p.yStride.x) + y * get_stride(p.yStride.y) + mapOfsOut;
+ if ((uint32_t)y + 0 < p.yShape.y) *((T*)((char*)p.y + ofs)) = (T)(v.x * (scalar_t)c_fd[0]);
+ if ((uint32_t)y + 1 < ymax) *((T*)((char*)p.y + ofs + get_stride(p.yStride.y))) = (T)(v.y * (scalar_t)c_fd[0]);
+ }
+ }
+ }
+ }
+ }
+ else if (filterMode == MODE_FUSD || filterMode == MODE_FUFD)
+ {
+ // Full upsampling filter.
+
+ if (up == 2)
+ {
+ // 2 x 2-wide.
+ __syncthreads();
+ int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH + p.sOfs.y : 0; // Skip already written signs.
+ for (int idx = threadIdx.x * 4; idx < tileUpW * tileUpH; idx += blockDim.x * 4)
+ {
+ int relUpX0, relUpY0;
+ fast_div_mod(relUpX0, relUpY0, idx);
+ int relInX0 = CEIL_DIV(relUpX0 - phaseInX, up);
+ int relInY0 = CEIL_DIV(relUpY0 - phaseInY, up);
+ int src0 = relInX0 + tileInW * relInY0;
+ int tap0y = (relInY0 * up + phaseInY - relUpY0);
+
+ #define X_LOOP(TAPY, PX) \
+ for (int sx = 0; sx < fuSize / up; sx++) \
+ { \
+ v.x += a * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
+ v.z += b * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 0) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \
+ v.y += a * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
+ v.w += b * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 1) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \
+ }
+
+ vec4_t v = InternalType::zero_vec4();
+ if (tap0y == 0 && phaseInX == 0)
+ #pragma unroll
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
+ #pragma unroll
+ X_LOOP(0, 0) }
+ if (tap0y == 0 && phaseInX == 1)
+ #pragma unroll
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
+ #pragma unroll
+ X_LOOP(0, 1) }
+ if (tap0y == 1 && phaseInX == 0)
+ #pragma unroll
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
+ #pragma unroll
+ X_LOOP(1, 0) }
+ if (tap0y == 1 && phaseInX == 1)
+ #pragma unroll
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
+ #pragma unroll
+ X_LOOP(1, 1) }
+
+ #undef X_LOOP
+
+ int x = tileOutX * down + relUpX0;
+ int y = tileOutY * down + relUpY0;
+ int signX = x + p.sOfs.x;
+ int signY = y + p.sOfs.y;
+ int signZ = blockIdx.z + p.blockZofs;
+ int signXb = signX >> 2;
+ index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
+
+ v.x *= (scalar_t)((float)up * (float)up * p.gain);
+ v.y *= (scalar_t)((float)up * (float)up * p.gain);
+ v.z *= (scalar_t)((float)up * (float)up * p.gain);
+ v.w *= (scalar_t)((float)up * (float)up * p.gain);
+
+ if (signWrite)
+ {
+ if (!enableWriteSkip)
+ {
+ // Determine and write signs.
+ int sx = __float_as_uint(v.x) >> 31;
+ int sy = __float_as_uint(v.y) >> 31;
+ int sz = __float_as_uint(v.z) >> 31;
+ int sw = __float_as_uint(v.w) >> 31;
+ if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType::clamp(v.x, p.clamp); }
+ if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType::clamp(v.y, p.clamp); }
+ if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType::clamp(v.z, p.clamp); }
+ if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType::clamp(v.w, p.clamp); }
+
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
+ {
+ p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);
+ }
+ }
+ else
+ {
+ // Determine and write signs.
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
+ {
+ int sx = __float_as_uint(v.x) >> 31;
+ int sy = __float_as_uint(v.y) >> 31;
+ int sz = __float_as_uint(v.z) >> 31;
+ int sw = __float_as_uint(v.w) >> 31;
+ if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType::clamp(v.x, p.clamp); }
+ if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType::clamp(v.y, p.clamp); }
+ if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType::clamp(v.z, p.clamp); }
+ if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType::clamp(v.w, p.clamp); }
+
+ p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);
+ }
+ else
+ {
+ // Just compute the values.
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp);
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp);
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp);
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp);
+ }
+ }
+ }
+ else if (signRead) // Read sign and apply.
+ {
+ if ((uint32_t)signY < p.sShape.y)
+ {
+ int s = 0;
+ if ((uint32_t)signXb < p.swLimit) s = p.s[si];
+ if ((uint32_t)signXb + 1 < p.swLimit) s |= p.s[si + 1] << 8;
+ s >>= (signX & 3) << 1;
+ if (s & 0x01) v.x *= p.slope; if (s & 0x02) v.x = 0.f;
+ if (s & 0x04) v.y *= p.slope; if (s & 0x08) v.y = 0.f;
+ if (s & 0x10) v.z *= p.slope; if (s & 0x20) v.z = 0.f;
+ if (s & 0x40) v.w *= p.slope; if (s & 0x80) v.w = 0.f;
+ }
+ }
+ else // Forward pass with no sign write.
+ {
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp);
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp);
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp);
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp);
+ }
+
+ s_tileUpXY[idx + 0] = v.x;
+ s_tileUpXY[idx + 1] = v.y;
+ s_tileUpXY[idx + 2] = v.z;
+ s_tileUpXY[idx + 3] = v.w;
+ }
+ }
+ else if (up == 1)
+ {
+ __syncthreads();
+ uint32_t groupMask = 15 << ((threadIdx.x & 31) & ~3);
+ int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs.
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH; idx += blockDim.x)
+ {
+ int relUpX0, relUpY0;
+ fast_div_mod(relUpX0, relUpY0, idx);
+ scalar_t v = s_tileIn[idx] * (scalar_t)c_fu[0]; // 1x1 filter.
+
+ int x = tileOutX * down + relUpX0;
+ int y = tileOutY * down + relUpY0;
+ int signX = x + p.sOfs.x;
+ int signY = y + p.sOfs.y;
+ int signZ = blockIdx.z + p.blockZofs;
+ int signXb = signX >> 2;
+ index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
+ v *= (scalar_t)((float)up * (float)up * p.gain);
+
+ if (signWrite)
+ {
+ if (!enableWriteSkip)
+ {
+ // Determine and write sign.
+ uint32_t s = 0;
+ uint32_t signXbit = (1u << signXo);
+ if (v < 0.f)
+ {
+ s = signXbit;
+ v *= p.slope;
+ }
+ if (fabsf(v) > p.clamp)
+ {
+ s = signXbit * 2;
+ v = InternalType::clamp(v, p.clamp);
+ }
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
+ {
+ s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
+ s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
+ p.s[si] = s; // Write.
+ }
+ }
+ else
+ {
+ // Determine and write sign.
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
+ {
+ uint32_t s = 0;
+ uint32_t signXbit = (1u << signXo);
+ if (v < 0.f)
+ {
+ s = signXbit;
+ v *= p.slope;
+ }
+ if (fabsf(v) > p.clamp)
+ {
+ s = signXbit * 2;
+ v = InternalType::clamp(v, p.clamp);
+ }
+ s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
+ s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
+ p.s[si] = s; // Write.
+ }
+ else
+ {
+ // Just compute the value.
+ if (v < 0.f) v *= p.slope;
+ v = InternalType::clamp(v, p.clamp);
+ }
+ }
+ }
+ else if (signRead)
+ {
+ // Read sign and apply if within sign tensor bounds.
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y)
+ {
+ int s = p.s[si];
+ s >>= signXo;
+ if (s & 1) v *= p.slope;
+ if (s & 2) v = 0.f;
+ }
+ }
+ else // Forward pass with no sign write.
+ {
+ if (v < 0.f) v *= p.slope;
+ v = InternalType::clamp(v, p.clamp);
+ }
+
+ if (!downInline) // Write into temporary buffer.
+ s_tileUpXY[idx] = v;
+ else if ((uint32_t)x < p.yShape.x && (uint32_t)y < p.yShape.y) // Write directly into output buffer
+ *((T*)((char*)p.y + (x * get_stride(p.yStride.x) + y * get_stride(p.yStride.y) + mapOfsOut))) = (T)(v * (scalar_t)c_fd[0]);
+ }
+ }
+ }
+
+ // Downsampling.
+ if (filterMode == MODE_SUSD || filterMode == MODE_FUSD)
+ {
+ // Horizontal downsampling.
+ __syncthreads();
+ if (down == 4 && tileOutW % 4 == 0)
+ {
+ // Calculate 4 pixels at a time.
+ for (int idx = threadIdx.x * 4; idx < tileOutW * tileUpH; idx += blockDim.x * 4)
+ {
+ int relOutX0, relUpY;
+ fast_div_mod(relOutX0, relUpY, idx);
+ int relUpX0 = relOutX0 * down;
+ int src0 = relUpY * tileUpW + relUpX0;
+ vec4_t v = InternalType::zero_vec4();
+ #pragma unroll
+ for (int step = 0; step < fdSize; step++)
+ {
+ v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step];
+ v.y += s_tileUpXY[src0 + 4 + step] * (scalar_t)c_fd[step];
+ v.z += s_tileUpXY[src0 + 8 + step] * (scalar_t)c_fd[step];
+ v.w += s_tileUpXY[src0 + 12 + step] * (scalar_t)c_fd[step];
+ }
+ s_tileDownX[idx+0] = v.x;
+ s_tileDownX[idx+1] = v.y;
+ s_tileDownX[idx+2] = v.z;
+ s_tileDownX[idx+3] = v.w;
+ }
+ }
+ else if ((down == 2 || down == 4) && (tileOutW % 2 == 0))
+ {
+ // Calculate 2 pixels at a time.
+ for (int idx = threadIdx.x * 2; idx < tileOutW * tileUpH; idx += blockDim.x * 2)
+ {
+ int relOutX0, relUpY;
+ fast_div_mod(relOutX0, relUpY, idx);
+ int relUpX0 = relOutX0 * down;
+ int src0 = relUpY * tileUpW + relUpX0;
+ vec2_t v = InternalType::zero_vec2();
+ #pragma unroll
+ for (int step = 0; step < fdSize; step++)
+ {
+ v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step];
+ v.y += s_tileUpXY[src0 + down + step] * (scalar_t)c_fd[step];
+ }
+ s_tileDownX[idx+0] = v.x;
+ s_tileDownX[idx+1] = v.y;
+ }
+ }
+ else
+ {
+ // Calculate 1 pixel at a time.
+ for (int idx = threadIdx.x; idx < tileOutW * tileUpH; idx += blockDim.x)
+ {
+ int relOutX0, relUpY;
+ fast_div_mod(relOutX0, relUpY, idx);
+ int relUpX0 = relOutX0 * down;
+ int src = relUpY * tileUpW + relUpX0;
+ scalar_t v = 0.f;
+ #pragma unroll
+ for (int step = 0; step < fdSize; step++)
+ v += s_tileUpXY[src + step] * (scalar_t)c_fd[step];
+ s_tileDownX[idx] = v;
+ }
+ }
+
+ // Vertical downsampling & store output tile.
+ __syncthreads();
+ for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x)
+ {
+ int relOutX, relOutY0;
+ fast_div_mod(relOutX, relOutY0, idx);
+ int relUpY0 = relOutY0 * down;
+ int src0 = relUpY0 * tileOutW + relOutX;
+ scalar_t v = 0;
+ #pragma unroll
+ for (int step = 0; step < fdSize; step++)
+ v += s_tileDownX[src0 + step * tileOutW] * (scalar_t)c_fd[step];
+
+ int outX = tileOutX + relOutX;
+ int outY = tileOutY + relOutY0;
+
+ if (outX < p.yShape.x & outY < p.yShape.y)
+ *((T*)((char*)p.y + (outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut))) = (T)v;
+ }
+ }
+ else if (filterMode == MODE_SUFD || filterMode == MODE_FUFD)
+ {
+ // Full downsampling filter.
+ if (down == 2)
+ {
+ // 2-wide.
+ __syncthreads();
+ for (int idx = threadIdx.x * 2; idx < tileOutW * tileOutH; idx += blockDim.x * 2)
+ {
+ int relOutX0, relOutY0;
+ fast_div_mod(relOutX0, relOutY0, idx);
+ int relUpX0 = relOutX0 * down;
+ int relUpY0 = relOutY0 * down;
+ int src0 = relUpY0 * tileUpW + relUpX0;
+ vec2_t v = InternalType::zero_vec2();
+ #pragma unroll
+ for (int sy = 0; sy < fdSize; sy++)
+ #pragma unroll
+ for (int sx = 0; sx < fdSize; sx++)
+ {
+ v.x += s_tileUpXY[src0 + 0 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];
+ v.y += s_tileUpXY[src0 + 2 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];
+ }
+
+ int outX = tileOutX + relOutX0;
+ int outY = tileOutY + relOutY0;
+ if ((uint32_t)outY < p.yShape.y)
+ {
+ index_t ofs = outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut;
+ if (outX + 0 < p.yShape.x) *((T*)((char*)p.y + ofs)) = (T)v.x;
+ if (outX + 1 < p.yShape.x) *((T*)((char*)p.y + ofs + get_stride(p.yStride.x))) = (T)v.y;
+ }
+ }
+ }
+ else if (down == 1 && !downInline)
+ {
+ // Thread per pixel.
+ __syncthreads();
+ for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x)
+ {
+ int relOutX0, relOutY0;
+ fast_div_mod(relOutX0, relOutY0, idx);
+ scalar_t v = s_tileUpXY[idx] * (scalar_t)c_fd[0]; // 1x1 filter.
+
+ int outX = tileOutX + relOutX0;
+ int outY = tileOutY + relOutY0;
+ if ((uint32_t)outX < p.yShape.x && (uint32_t)outY < p.yShape.y)
+ *((T*)((char*)p.y + (outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut))) = (T)v;
+ }
+ }
+ }
+
+ if (!enableXrep)
+ break;
+ }
+}
+
+//------------------------------------------------------------------------
+// Compute activation function and signs for upsampled data tensor, modifying data tensor in-place. Used for accelerating the generic variant.
+// Sign tensor is known to be contiguous, and p.x and p.s have the same z, w dimensions. 64-bit indexing is always used.
+
+template
+static __global__ void filtered_lrelu_act_kernel(filtered_lrelu_act_kernel_params p)
+{
+ typedef typename InternalType::scalar_t scalar_t;
+
+ // Indexing.
+ int32_t x = threadIdx.x + blockIdx.x * blockDim.x;
+ int32_t ymax = signWrite ? p.sShape.y : p.xShape.y;
+ int32_t qmax = p.xShape.z * p.xShape.w; // Combined minibatch*channel maximum index.
+
+ // Loop to accommodate oversized tensors.
+ for (int32_t q = blockIdx.z; q < qmax; q += gridDim.z)
+ for (int32_t y = blockIdx.y; y < ymax; y += gridDim.y)
+ {
+ // Extract z and w (channel, minibatch index).
+ int32_t w = q / p.xShape.z;
+ int32_t z = q - w * p.xShape.z;
+
+ // Choose behavior based on sign read/write mode.
+ if (signWrite)
+ {
+ // Process value if in p.x.
+ uint32_t s = 0;
+ if (x < p.xShape.x && y < p.xShape.y)
+ {
+ int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
+ T* pv = ((T*)p.x) + ix;
+ scalar_t v = (scalar_t)(*pv);
+
+ // Gain, LReLU, clamp.
+ v *= p.gain;
+ if (v < 0.f)
+ {
+ v *= p.slope;
+ s = 1; // Sign.
+ }
+ if (fabsf(v) > p.clamp)
+ {
+ v = InternalType::clamp(v, p.clamp);
+ s = 2; // Clamp.
+ }
+
+ *pv = (T)v; // Write value.
+ }
+
+ // Coalesce into threads 0 and 16 of warp.
+ uint32_t m = (threadIdx.x & 16) ? 0xffff0000u : 0x0000ffffu;
+ s <<= ((threadIdx.x & 15) << 1); // Shift into place.
+ s |= __shfl_xor_sync(m, s, 1); // Distribute.
+ s |= __shfl_xor_sync(m, s, 2);
+ s |= __shfl_xor_sync(m, s, 4);
+ s |= __shfl_xor_sync(m, s, 8);
+
+ // Write signs if leader and in p.s.
+ if (!(threadIdx.x & 15) && x < p.sShape.x) // y is always in.
+ {
+ uint64_t is = x + p.sShape.x * (y + (int64_t)p.sShape.y * q); // Contiguous.
+ ((uint32_t*)p.s)[is >> 4] = s;
+ }
+ }
+ else if (signRead)
+ {
+ // Process value if in p.x.
+ if (x < p.xShape.x) // y is always in.
+ {
+ int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
+ T* pv = ((T*)p.x) + ix;
+ scalar_t v = (scalar_t)(*pv);
+ v *= p.gain;
+
+ // Apply sign buffer offset.
+ uint32_t sx = x + p.sOfs.x;
+ uint32_t sy = y + p.sOfs.y;
+
+ // Read and apply signs if we land inside valid region of sign buffer.
+ if (sx < p.sShape.x && sy < p.sShape.y)
+ {
+ uint64_t is = (sx >> 2) + (p.sShape.x >> 2) * (sy + (uint64_t)p.sShape.y * q); // Contiguous.
+ unsigned char s = p.s[is];
+ s >>= (sx & 3) << 1; // Shift into place.
+ if (s & 1) // Sign?
+ v *= p.slope;
+ if (s & 2) // Clamp?
+ v = 0.f;
+ }
+
+ *pv = (T)v; // Write value.
+ }
+ }
+ else
+ {
+ // Forward pass with no sign write. Process value if in p.x.
+ if (x < p.xShape.x) // y is always in.
+ {
+ int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
+ T* pv = ((T*)p.x) + ix;
+ scalar_t v = (scalar_t)(*pv);
+ v *= p.gain;
+ if (v < 0.f)
+ v *= p.slope;
+ if (fabsf(v) > p.clamp)
+ v = InternalType::clamp(v, p.clamp);
+ *pv = (T)v; // Write value.
+ }
+ }
+ }
+}
+
+template void* choose_filtered_lrelu_act_kernel(void)
+{
+ return (void*)filtered_lrelu_act_kernel;
+}
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB)
+{
+ filtered_lrelu_kernel_spec s = { 0 };
+
+ // Return the first matching kernel.
+#define CASE(SH, U, FU, D, FD, MODE, TW, TH, W, XR, WS) \
+ if (sharedKB >= SH) \
+ if ((p.fuShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_SUFD)) || (p.fuShape.y > 0 && (MODE == MODE_FUSD || MODE == MODE_FUFD))) \
+ if ((p.fdShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_FUSD)) || (p.fdShape.y > 0 && (MODE == MODE_SUFD || MODE == MODE_FUFD))) \
+ if (p.up == U && p.fuShape.x <= FU && p.fuShape.y <= FU && p.down == D && p.fdShape.x <= FD && p.fdShape.y <= FD) \
+ { \
+ static_assert((D*TW % 4) == 0, "down * tileWidth must be divisible by 4"); \
+ static_assert(FU % U == 0, "upscaling filter size must be multiple of upscaling factor"); \
+ static_assert(FD % D == 0, "downscaling filter size must be multiple of downscaling factor"); \
+ s.setup = (void*)setup_filters_kernel; \
+ s.exec = (void*)filtered_lrelu_kernel; \
+ s.tileOut = make_int2(TW, TH); \
+ s.numWarps = W; \
+ s.xrep = XR; \
+ s.dynamicSharedKB = (SH == 48) ? 0 : SH; \
+ return s; \
+ }
+
+ // Launch parameters for various kernel specializations.
+ // Small filters must be listed before large filters, otherwise the kernel for larger filter will always match first.
+ // Kernels that use more shared memory must be listed before those that use less, for the same reason.
+
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/1,1, /*mode*/MODE_FUFD, /*tw,th,warps,xrep,wskip*/64, 178, 32, 0, 0) // 1t-upf1-downf1
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/152, 95, 16, 0, 0) // 4t-ups2-downf1
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 22, 16, 0, 0) // 4t-upf1-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 29, 16, 11, 0) // 4t-ups2-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/60, 28, 16, 0, 0) // 4t-upf2-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 28, 16, 0, 0) // 4t-ups2-downf2
+ CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 31, 16, 11, 0) // 4t-ups4-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 36, 16, 0, 0) // 4t-ups4-downf2
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 22, 16, 12, 0) // 4t-ups2-downs4
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/29, 15, 16, 0, 0) // 4t-upf2-downs4
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/96, 150, 28, 0, 0) // 6t-ups2-downf1
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 35, 24, 0, 0) // 6t-upf1-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 16, 10, 0) // 6t-ups2-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/58, 28, 24, 8, 0) // 6t-upf2-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/52, 28, 16, 0, 0) // 6t-ups2-downf2
+ CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 51, 16, 5, 0) // 6t-ups4-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 56, 16, 6, 0) // 6t-ups4-downf2
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 18, 16, 12, 0) // 6t-ups2-downs4
+ CASE(/*sharedKB*/96, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 31, 32, 6, 0) // 6t-upf2-downs4 96kB
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 13, 24, 0, 0) // 6t-upf2-downs4
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/148, 89, 24, 0, 0) // 8t-ups2-downf1
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 31, 16, 5, 0) // 8t-upf1-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 41, 16, 9, 0) // 8t-ups2-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 26, 24, 0, 0) // 8t-upf2-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 40, 16, 0, 0) // 8t-ups2-downf2
+ CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 24, 5, 0) // 8t-ups4-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 50, 16, 0, 0) // 8t-ups4-downf2
+ CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/24, 24, 32, 12, 1) // 8t-ups2-downs4 96kB
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 13, 16, 10, 1) // 8t-ups2-downs4
+ CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 28, 28, 4, 0) // 8t-upf2-downs4 96kB
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 10, 24, 0, 0) // 8t-upf2-downs4
+
+ #undef CASE
+ return s; // No kernel found.
+}
+
+//------------------------------------------------------------------------
diff --git a/stylegan3-fun/torch_utils/ops/filtered_lrelu.h b/stylegan3-fun/torch_utils/ops/filtered_lrelu.h
new file mode 100644
index 0000000000000000000000000000000000000000..2c403e3f275f472315662321cad54dd0dbc56d00
--- /dev/null
+++ b/stylegan3-fun/torch_utils/ops/filtered_lrelu.h
@@ -0,0 +1,90 @@
+// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include
+
+//------------------------------------------------------------------------
+// CUDA kernel parameters.
+
+struct filtered_lrelu_kernel_params
+{
+ // These parameters decide which kernel to use.
+ int up; // upsampling ratio (1, 2, 4)
+ int down; // downsampling ratio (1, 2, 4)
+ int2 fuShape; // [size, 1] | [size, size]
+ int2 fdShape; // [size, 1] | [size, size]
+
+ int _dummy; // Alignment.
+
+ // Rest of the parameters.
+ const void* x; // Input tensor.
+ void* y; // Output tensor.
+ const void* b; // Bias tensor.
+ unsigned char* s; // Sign tensor in/out. NULL if unused.
+ const float* fu; // Upsampling filter.
+ const float* fd; // Downsampling filter.
+
+ int2 pad0; // Left/top padding.
+ float gain; // Additional gain factor.
+ float slope; // Leaky ReLU slope on negative side.
+ float clamp; // Clamp after nonlinearity.
+ int flip; // Filter kernel flip for gradient computation.
+
+ int tilesXdim; // Original number of horizontal output tiles.
+ int tilesXrep; // Number of horizontal tiles per CTA.
+ int blockZofs; // Block z offset to support large minibatch, channel dimensions.
+
+ int4 xShape; // [width, height, channel, batch]
+ int4 yShape; // [width, height, channel, batch]
+ int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused.
+ int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
+ int swLimit; // Active width of sign tensor in bytes.
+
+ longlong4 xStride; // Strides of all tensors except signs, same component order as shapes.
+ longlong4 yStride; //
+ int64_t bStride; //
+ longlong3 fuStride; //
+ longlong3 fdStride; //
+};
+
+struct filtered_lrelu_act_kernel_params
+{
+ void* x; // Input/output, modified in-place.
+ unsigned char* s; // Sign tensor in/out. NULL if unused.
+
+ float gain; // Additional gain factor.
+ float slope; // Leaky ReLU slope on negative side.
+ float clamp; // Clamp after nonlinearity.
+
+ int4 xShape; // [width, height, channel, batch]
+ longlong4 xStride; // Input/output tensor strides, same order as in shape.
+ int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused.
+ int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel specialization.
+
+struct filtered_lrelu_kernel_spec
+{
+ void* setup; // Function for filter kernel setup.
+ void* exec; // Function for main operation.
+ int2 tileOut; // Width/height of launch tile.
+ int numWarps; // Number of warps per thread block, determines launch block size.
+ int xrep; // For processing multiple horizontal tiles per thread block.
+ int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants.
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+template void* choose_filtered_lrelu_act_kernel(void);
+template cudaError_t copy_filters(cudaStream_t stream);
+
+//------------------------------------------------------------------------
diff --git a/stylegan3-fun/torch_utils/ops/filtered_lrelu.py b/stylegan3-fun/torch_utils/ops/filtered_lrelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d8865f6900df9bae8e716b4dbaa22c09475815b
--- /dev/null
+++ b/stylegan3-fun/torch_utils/ops/filtered_lrelu.py
@@ -0,0 +1,279 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import os
+import numpy as np
+import torch
+import warnings
+
+from .. import custom_ops
+from .. import misc
+from . import upfirdn2d
+from . import bias_act
+
+#----------------------------------------------------------------------------
+
+_plugin = None
+
+def _init():
+ global _plugin
+ # Bob Burrough's PR (#45) so that the plugins work in Windows: https://github.com/NVlabs/stylegan3/pull/45
+ extras = {}
+ if os.name == 'nt':
+ extras['extra_cflags'] = ['/std:c++17']
+ if _plugin is None:
+ _plugin = custom_ops.get_plugin(
+ module_name='filtered_lrelu_plugin',
+ sources=['filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu'],
+ headers=['filtered_lrelu.h', 'filtered_lrelu.cu'],
+ source_dir=os.path.dirname(__file__),
+ extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'],
+ **extras,
+ )
+ return True
+
+def _get_filter_size(f):
+ if f is None:
+ return 1, 1
+ assert isinstance(f, torch.Tensor)
+ assert 1 <= f.ndim <= 2
+ return f.shape[-1], f.shape[0] # width, height
+
+def _parse_padding(padding):
+ if isinstance(padding, int):
+ padding = [padding, padding]
+ assert isinstance(padding, (list, tuple))
+ assert all(isinstance(x, (int, np.integer)) for x in padding)
+ padding = [int(x) for x in padding]
+ if len(padding) == 2:
+ px, py = padding
+ padding = [px, px, py, py]
+ px0, px1, py0, py1 = padding
+ return px0, px1, py0, py1
+
+#----------------------------------------------------------------------------
+
+def filtered_lrelu(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False, impl='cuda'):
+ r"""Filtered leaky ReLU for a batch of 2D images.
+
+ Performs the following sequence of operations for each channel:
+
+ 1. Add channel-specific bias if provided (`b`).
+
+ 2. Upsample the image by inserting N-1 zeros after each pixel (`up`).
+
+ 3. Pad the image with the specified number of zeros on each side (`padding`).
+ Negative padding corresponds to cropping the image.
+
+ 4. Convolve the image with the specified upsampling FIR filter (`fu`), shrinking it
+ so that the footprint of all output pixels lies within the input image.
+
+ 5. Multiply each value by the provided gain factor (`gain`).
+
+ 6. Apply leaky ReLU activation function to each value.
+
+ 7. Clamp each value between -clamp and +clamp, if `clamp` parameter is provided.
+
+ 8. Convolve the image with the specified downsampling FIR filter (`fd`), shrinking
+ it so that the footprint of all output pixels lies within the input image.
+
+ 9. Downsample the image by keeping every Nth pixel (`down`).
+
+ The fused op is considerably more efficient than performing the same calculation
+ using standard PyTorch ops. It supports gradients of arbitrary order.
+
+ Args:
+ x: Float32/float16/float64 input tensor of the shape
+ `[batch_size, num_channels, in_height, in_width]`.
+ fu: Float32 upsampling FIR filter of the shape
+ `[filter_height, filter_width]` (non-separable),
+ `[filter_taps]` (separable), or
+ `None` (identity).
+ fd: Float32 downsampling FIR filter of the shape
+ `[filter_height, filter_width]` (non-separable),
+ `[filter_taps]` (separable), or
+ `None` (identity).
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
+ as `x`. The length of vector must must match the channel dimension of `x`.
+ up: Integer upsampling factor (default: 1).
+ down: Integer downsampling factor. (default: 1).
+ padding: Padding with respect to the upsampled image. Can be a single number
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+ (default: 0).
+ gain: Overall scaling factor for signal magnitude (default: sqrt(2)).
+ slope: Slope on the negative side of leaky ReLU (default: 0.2).
+ clamp: Maximum magnitude for leaky ReLU output (default: None).
+ flip_filter: False = convolution, True = correlation (default: False).
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
+
+ Returns:
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+ """
+ assert isinstance(x, torch.Tensor)
+ assert impl in ['ref', 'cuda']
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
+ return _filtered_lrelu_cuda(up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter).apply(x, fu, fd, b, None, 0, 0)
+ return _filtered_lrelu_ref(x, fu=fu, fd=fd, b=b, up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter)
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def _filtered_lrelu_ref(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
+ """Slow and memory-inefficient reference implementation of `filtered_lrelu()` using
+ existing `upfirdn2n()` and `bias_act()` ops.
+ """
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
+ fu_w, fu_h = _get_filter_size(fu)
+ fd_w, fd_h = _get_filter_size(fd)
+ if b is not None:
+ assert isinstance(b, torch.Tensor) and b.dtype == x.dtype
+ misc.assert_shape(b, [x.shape[1]])
+ assert isinstance(up, int) and up >= 1
+ assert isinstance(down, int) and down >= 1
+ px0, px1, py0, py1 = _parse_padding(padding)
+ assert gain == float(gain) and gain > 0
+ assert slope == float(slope) and slope >= 0
+ assert clamp is None or (clamp == float(clamp) and clamp >= 0)
+
+ # Calculate output size.
+ batch_size, channels, in_h, in_w = x.shape
+ in_dtype = x.dtype
+ out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down
+ out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down
+
+ # Compute using existing ops.
+ x = bias_act.bias_act(x=x, b=b) # Apply bias.
+ x = upfirdn2d.upfirdn2d(x=x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
+ x = bias_act.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) # Bias, leaky ReLU, clamp.
+ x = upfirdn2d.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) # Downsample.
+
+ # Check output shape & dtype.
+ misc.assert_shape(x, [batch_size, channels, out_h, out_w])
+ assert x.dtype == in_dtype
+ return x
+
+#----------------------------------------------------------------------------
+
+_filtered_lrelu_cuda_cache = dict()
+
+def _filtered_lrelu_cuda(up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
+ """Fast CUDA implementation of `filtered_lrelu()` using custom ops.
+ """
+ assert isinstance(up, int) and up >= 1
+ assert isinstance(down, int) and down >= 1
+ px0, px1, py0, py1 = _parse_padding(padding)
+ assert gain == float(gain) and gain > 0
+ gain = float(gain)
+ assert slope == float(slope) and slope >= 0
+ slope = float(slope)
+ assert clamp is None or (clamp == float(clamp) and clamp >= 0)
+ clamp = float(clamp if clamp is not None else 'inf')
+
+ # Lookup from cache.
+ key = (up, down, px0, px1, py0, py1, gain, slope, clamp, flip_filter)
+ if key in _filtered_lrelu_cuda_cache:
+ return _filtered_lrelu_cuda_cache[key]
+
+ # Forward op.
+ class FilteredLReluCuda(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x, fu, fd, b, si, sx, sy): # pylint: disable=arguments-differ
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
+
+ # Replace empty up/downsample kernels with full 1x1 kernels (faster than separable).
+ if fu is None:
+ fu = torch.ones([1, 1], dtype=torch.float32, device=x.device)
+ if fd is None:
+ fd = torch.ones([1, 1], dtype=torch.float32, device=x.device)
+ assert 1 <= fu.ndim <= 2
+ assert 1 <= fd.ndim <= 2
+
+ # Replace separable 1x1 kernels with full 1x1 kernels when scale factor is 1.
+ if up == 1 and fu.ndim == 1 and fu.shape[0] == 1:
+ fu = fu.square()[None]
+ if down == 1 and fd.ndim == 1 and fd.shape[0] == 1:
+ fd = fd.square()[None]
+
+ # Missing sign input tensor.
+ if si is None:
+ si = torch.empty([0])
+
+ # Missing bias tensor.
+ if b is None:
+ b = torch.zeros([x.shape[1]], dtype=x.dtype, device=x.device)
+
+ # Construct internal sign tensor only if gradients are needed.
+ write_signs = (si.numel() == 0) and (x.requires_grad or b.requires_grad)
+
+ # Warn if input storage strides are not in decreasing order due to e.g. channels-last layout.
+ strides = [x.stride(i) for i in range(x.ndim) if x.size(i) > 1]
+ if any(a < b for a, b in zip(strides[:-1], strides[1:])):
+ warnings.warn("low-performance memory layout detected in filtered_lrelu input", RuntimeWarning)
+
+ # Call C++/Cuda plugin if datatype is supported.
+ if x.dtype in [torch.float16, torch.float32]:
+ if torch.cuda.current_stream(x.device) != torch.cuda.default_stream(x.device):
+ warnings.warn("filtered_lrelu called with non-default cuda stream but concurrent execution is not supported", RuntimeWarning)
+ y, so, return_code = _plugin.filtered_lrelu(x, fu, fd, b, si, up, down, px0, px1, py0, py1, sx, sy, gain, slope, clamp, flip_filter, write_signs)
+ else:
+ return_code = -1
+
+ # No Cuda kernel found? Fall back to generic implementation. Still more memory efficient than the reference implementation because
+ # only the bit-packed sign tensor is retained for gradient computation.
+ if return_code < 0:
+ warnings.warn("filtered_lrelu called with parameters that have no optimized CUDA kernel, using generic fallback", RuntimeWarning)
+
+ y = x.add(b.unsqueeze(-1).unsqueeze(-1)) # Add bias.
+ y = upfirdn2d.upfirdn2d(x=y, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
+ so = _plugin.filtered_lrelu_act_(y, si, sx, sy, gain, slope, clamp, write_signs) # Activation function and sign handling. Modifies y in-place.
+ y = upfirdn2d.upfirdn2d(x=y, f=fd, down=down, flip_filter=flip_filter) # Downsample.
+
+ # Prepare for gradient computation.
+ ctx.save_for_backward(fu, fd, (si if si.numel() else so))
+ ctx.x_shape = x.shape
+ ctx.y_shape = y.shape
+ ctx.s_ofs = sx, sy
+ return y
+
+ @staticmethod
+ def backward(ctx, dy): # pylint: disable=arguments-differ
+ fu, fd, si = ctx.saved_tensors
+ _, _, xh, xw = ctx.x_shape
+ _, _, yh, yw = ctx.y_shape
+ sx, sy = ctx.s_ofs
+ dx = None # 0
+ dfu = None; assert not ctx.needs_input_grad[1]
+ dfd = None; assert not ctx.needs_input_grad[2]
+ db = None # 3
+ dsi = None; assert not ctx.needs_input_grad[4]
+ dsx = None; assert not ctx.needs_input_grad[5]
+ dsy = None; assert not ctx.needs_input_grad[6]
+
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]:
+ pp = [
+ (fu.shape[-1] - 1) + (fd.shape[-1] - 1) - px0,
+ xw * up - yw * down + px0 - (up - 1),
+ (fu.shape[0] - 1) + (fd.shape[0] - 1) - py0,
+ xh * up - yh * down + py0 - (up - 1),
+ ]
+ gg = gain * (up ** 2) / (down ** 2)
+ ff = (not flip_filter)
+ sx = sx - (fu.shape[-1] - 1) + px0
+ sy = sy - (fu.shape[0] - 1) + py0
+ dx = _filtered_lrelu_cuda(up=down, down=up, padding=pp, gain=gg, slope=slope, clamp=None, flip_filter=ff).apply(dy, fd, fu, None, si, sx, sy)
+
+ if ctx.needs_input_grad[3]:
+ db = dx.sum([0, 2, 3])
+
+ return dx, dfu, dfd, db, dsi, dsx, dsy
+
+ # Add to cache.
+ _filtered_lrelu_cuda_cache[key] = FilteredLReluCuda
+ return FilteredLReluCuda
+
+#----------------------------------------------------------------------------
diff --git a/stylegan3-fun/torch_utils/ops/filtered_lrelu_ns.cu b/stylegan3-fun/torch_utils/ops/filtered_lrelu_ns.cu
new file mode 100644
index 0000000000000000000000000000000000000000..ef5d948c4fdf9cb0fe8a42f6268c61aeef6b2000
--- /dev/null
+++ b/stylegan3-fun/torch_utils/ops/filtered_lrelu_ns.cu
@@ -0,0 +1,27 @@
+// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include "filtered_lrelu.cu"
+
+// Template/kernel specializations for no signs mode (no gradients required).
+
+// Full op, 32-bit indexing.
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+
+// Full op, 64-bit indexing.
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+
+// Activation/signs only for generic variant. 64-bit indexing.
+template void* choose_filtered_lrelu_act_kernel(void);
+template void* choose_filtered_lrelu_act_kernel(void);
+template void* choose_filtered_lrelu_act_kernel(void);
+
+// Copy filters to constant memory.
+template cudaError_t copy_filters(cudaStream_t stream);
diff --git a/stylegan3-fun/torch_utils/ops/filtered_lrelu_rd.cu b/stylegan3-fun/torch_utils/ops/filtered_lrelu_rd.cu
new file mode 100644
index 0000000000000000000000000000000000000000..968347882e9aebd36204f67e201cd16226dd9132
--- /dev/null
+++ b/stylegan3-fun/torch_utils/ops/filtered_lrelu_rd.cu
@@ -0,0 +1,27 @@
+// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include "filtered_lrelu.cu"
+
+// Template/kernel specializations for sign read mode.
+
+// Full op, 32-bit indexing.
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+
+// Full op, 64-bit indexing.
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+
+// Activation/signs only for generic variant. 64-bit indexing.
+template void* choose_filtered_lrelu_act_kernel(void);
+template void* choose_filtered_lrelu_act_kernel(void);
+template void* choose_filtered_lrelu_act_kernel(void);
+
+// Copy filters to constant memory.
+template cudaError_t copy_filters(cudaStream_t stream);
diff --git a/stylegan3-fun/torch_utils/ops/filtered_lrelu_wr.cu b/stylegan3-fun/torch_utils/ops/filtered_lrelu_wr.cu
new file mode 100644
index 0000000000000000000000000000000000000000..a4c6a24aae908bc07248f7ff710cbd1a11a38bb1
--- /dev/null
+++ b/stylegan3-fun/torch_utils/ops/filtered_lrelu_wr.cu
@@ -0,0 +1,27 @@
+// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include "filtered_lrelu.cu"
+
+// Template/kernel specializations for sign write mode.
+
+// Full op, 32-bit indexing.
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+
+// Full op, 64-bit indexing.
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+
+// Activation/signs only for generic variant. 64-bit indexing.
+template void* choose_filtered_lrelu_act_kernel(void);
+template void* choose_filtered_lrelu_act_kernel(void);
+template void* choose_filtered_lrelu_act_kernel(void);
+
+// Copy filters to constant memory.
+template cudaError_t copy_filters(cudaStream_t stream);
diff --git a/stylegan3-fun/torch_utils/ops/fma.py b/stylegan3-fun/torch_utils/ops/fma.py
new file mode 100644
index 0000000000000000000000000000000000000000..51a45dfa0829987e8ee5214663e068cb3af2a8b9
--- /dev/null
+++ b/stylegan3-fun/torch_utils/ops/fma.py
@@ -0,0 +1,60 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
+
+import torch
+
+#----------------------------------------------------------------------------
+
+def fma(a, b, c): # => a * b + c
+ return _FusedMultiplyAdd.apply(a, b, c)
+
+#----------------------------------------------------------------------------
+
+class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
+ @staticmethod
+ def forward(ctx, a, b, c): # pylint: disable=arguments-differ
+ out = torch.addcmul(c, a, b)
+ ctx.save_for_backward(a, b)
+ ctx.c_shape = c.shape
+ return out
+
+ @staticmethod
+ def backward(ctx, dout): # pylint: disable=arguments-differ
+ a, b = ctx.saved_tensors
+ c_shape = ctx.c_shape
+ da = None
+ db = None
+ dc = None
+
+ if ctx.needs_input_grad[0]:
+ da = _unbroadcast(dout * b, a.shape)
+
+ if ctx.needs_input_grad[1]:
+ db = _unbroadcast(dout * a, b.shape)
+
+ if ctx.needs_input_grad[2]:
+ dc = _unbroadcast(dout, c_shape)
+
+ return da, db, dc
+
+#----------------------------------------------------------------------------
+
+def _unbroadcast(x, shape):
+ extra_dims = x.ndim - len(shape)
+ assert extra_dims >= 0
+ dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
+ if len(dim):
+ x = x.sum(dim=dim, keepdim=True)
+ if extra_dims:
+ x = x.reshape(-1, *x.shape[extra_dims+1:])
+ assert x.shape == shape
+ return x
+
+#----------------------------------------------------------------------------
diff --git a/stylegan3-fun/torch_utils/ops/grid_sample_gradfix.py b/stylegan3-fun/torch_utils/ops/grid_sample_gradfix.py
new file mode 100644
index 0000000000000000000000000000000000000000..441b37953c3392c3eb7dcf24c3d2891f17a3c18e
--- /dev/null
+++ b/stylegan3-fun/torch_utils/ops/grid_sample_gradfix.py
@@ -0,0 +1,83 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Custom replacement for `torch.nn.functional.grid_sample` that
+supports arbitrarily high order gradients between the input and output.
+Only works on 2D images and assumes
+`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
+
+import torch
+from pkg_resources import parse_version
+
+# pylint: disable=redefined-builtin
+# pylint: disable=arguments-differ
+# pylint: disable=protected-access
+
+#----------------------------------------------------------------------------
+
+enabled = False # Enable the custom op by setting this to true.
+_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11
+
+#----------------------------------------------------------------------------
+
+def grid_sample(input, grid):
+ if _should_use_custom_op():
+ return _GridSample2dForward.apply(input, grid)
+ return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
+
+#----------------------------------------------------------------------------
+
+def _should_use_custom_op():
+ return enabled
+
+#----------------------------------------------------------------------------
+
+class _GridSample2dForward(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input, grid):
+ assert input.ndim == 4
+ assert grid.ndim == 4
+ output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
+ ctx.save_for_backward(input, grid)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, grid = ctx.saved_tensors
+ grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
+ return grad_input, grad_grid
+
+#----------------------------------------------------------------------------
+
+class _GridSample2dBackward(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, grad_output, input, grid):
+ op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
+ if _use_pytorch_1_11_api:
+ output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
+ grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask)
+ else:
+ grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
+ ctx.save_for_backward(grid)
+ return grad_input, grad_grid
+
+ @staticmethod
+ def backward(ctx, grad2_grad_input, grad2_grad_grid):
+ _ = grad2_grad_grid # unused
+ grid, = ctx.saved_tensors
+ grad2_grad_output = None
+ grad2_input = None
+ grad2_grid = None
+
+ if ctx.needs_input_grad[0]:
+ grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
+
+ assert not ctx.needs_input_grad[2]
+ return grad2_grad_output, grad2_input, grad2_grid
+
+#----------------------------------------------------------------------------
diff --git a/stylegan3-fun/torch_utils/ops/upfirdn2d.cpp b/stylegan3-fun/torch_utils/ops/upfirdn2d.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..44fa337d8d4c34dfa010a59cd27d86857db671aa
--- /dev/null
+++ b/stylegan3-fun/torch_utils/ops/upfirdn2d.cpp
@@ -0,0 +1,107 @@
+// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include
+#include
+#include
+#include "upfirdn2d.h"
+
+//------------------------------------------------------------------------
+
+static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
+{
+ // Validate arguments.
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+ TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
+ TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
+ TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
+ TORCH_CHECK(x.numel() > 0, "x has zero size");
+ TORCH_CHECK(f.numel() > 0, "f has zero size");
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
+ TORCH_CHECK(f.dim() == 2, "f must be rank 2");
+ TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large");
+ TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
+ TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
+ TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
+
+ // Create output tensor.
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+ int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
+ int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
+ TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
+ torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
+ TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
+ TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large");
+
+ // Initialize CUDA kernel parameters.
+ upfirdn2d_kernel_params p;
+ p.x = x.data_ptr();
+ p.f = f.data_ptr();
+ p.y = y.data_ptr();
+ p.up = make_int2(upx, upy);
+ p.down = make_int2(downx, downy);
+ p.pad0 = make_int2(padx0, pady0);
+ p.flip = (flip) ? 1 : 0;
+ p.gain = gain;
+ p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
+ p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
+ p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
+ p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
+ p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
+ p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
+ p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
+ p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
+
+ // Choose CUDA kernel.
+ upfirdn2d_kernel_spec spec;
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
+ {
+ spec = choose_upfirdn2d_kernel(p);
+ });
+
+ // Set looping options.
+ p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
+ p.loopMinor = spec.loopMinor;
+ p.loopX = spec.loopX;
+ p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
+ p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
+
+ // Compute grid size.
+ dim3 blockSize, gridSize;
+ if (spec.tileOutW < 0) // large
+ {
+ blockSize = dim3(4, 32, 1);
+ gridSize = dim3(
+ ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
+ (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
+ p.launchMajor);
+ }
+ else // small
+ {
+ blockSize = dim3(256, 1, 1);
+ gridSize = dim3(
+ ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
+ (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
+ p.launchMajor);
+ }
+
+ // Launch CUDA kernel.
+ void* args[] = {&p};
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
+ return y;
+}
+
+//------------------------------------------------------------------------
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ m.def("upfirdn2d", &upfirdn2d);
+}
+
+//------------------------------------------------------------------------
diff --git a/stylegan3-fun/torch_utils/ops/upfirdn2d.cu b/stylegan3-fun/torch_utils/ops/upfirdn2d.cu
new file mode 100644
index 0000000000000000000000000000000000000000..3a33e31bbb1bbc1cd02ee7d2ede3943917f3906e
--- /dev/null
+++ b/stylegan3-fun/torch_utils/ops/upfirdn2d.cu
@@ -0,0 +1,384 @@
+// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include
+#include "upfirdn2d.h"
+
+//------------------------------------------------------------------------
+// Helpers.
+
+template struct InternalType;
+template <> struct InternalType { typedef double scalar_t; };
+template <> struct InternalType { typedef float scalar_t; };
+template <> struct InternalType { typedef float scalar_t; };
+
+static __device__ __forceinline__ int floor_div(int a, int b)
+{
+ int t = 1 - a / b;
+ return (a + t * b) / b - t;
+}
+
+//------------------------------------------------------------------------
+// Generic CUDA implementation for large filters.
+
+template static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
+{
+ typedef typename InternalType::scalar_t scalar_t;
+
+ // Calculate thread index.
+ int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
+ int outY = minorBase / p.launchMinor;
+ minorBase -= outY * p.launchMinor;
+ int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
+ int majorBase = blockIdx.z * p.loopMajor;
+ if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
+ return;
+
+ // Setup Y receptive field.
+ int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
+ int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
+ int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
+ int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
+ if (p.flip)
+ filterY = p.filterSize.y - 1 - filterY;
+
+ // Loop over major, minor, and X.
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
+ for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
+ {
+ int nc = major * p.sizeMinor + minor;
+ int n = nc / p.inSize.z;
+ int c = nc - n * p.inSize.z;
+ for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
+ {
+ // Setup X receptive field.
+ int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
+ int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
+ int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
+ int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
+ if (p.flip)
+ filterX = p.filterSize.x - 1 - filterX;
+
+ // Initialize pointers.
+ const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
+ const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
+ int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
+ int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
+
+ // Inner loop.
+ scalar_t v = 0;
+ for (int y = 0; y < h; y++)
+ {
+ for (int x = 0; x < w; x++)
+ {
+ v += (scalar_t)(*xp) * (scalar_t)(*fp);
+ xp += p.inStride.x;
+ fp += filterStepX;
+ }
+ xp += p.inStride.y - w * p.inStride.x;
+ fp += filterStepY - w * filterStepX;
+ }
+
+ // Store result.
+ v *= p.gain;
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
+ }
+ }
+}
+
+//------------------------------------------------------------------------
+// Specialized CUDA implementation for small filters.
+
+template
+static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
+{
+ typedef typename InternalType::scalar_t scalar_t;
+ const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
+ const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
+ __shared__ volatile scalar_t sf[filterH][filterW];
+ __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
+
+ // Calculate tile index.
+ int minorBase = blockIdx.x;
+ int tileOutY = minorBase / p.launchMinor;
+ minorBase -= tileOutY * p.launchMinor;
+ minorBase *= loopMinor;
+ tileOutY *= tileOutH;
+ int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
+ int majorBase = blockIdx.z * p.loopMajor;
+ if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
+ return;
+
+ // Load filter (flipped).
+ for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
+ {
+ int fy = tapIdx / filterW;
+ int fx = tapIdx - fy * filterW;
+ scalar_t v = 0;
+ if (fx < p.filterSize.x & fy < p.filterSize.y)
+ {
+ int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
+ int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
+ v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
+ }
+ sf[fy][fx] = v;
+ }
+
+ // Loop over major and X.
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
+ {
+ int baseNC = major * p.sizeMinor + minorBase;
+ int n = baseNC / p.inSize.z;
+ int baseC = baseNC - n * p.inSize.z;
+ for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
+ {
+ // Load input pixels.
+ int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
+ int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
+ int tileInX = floor_div(tileMidX, upx);
+ int tileInY = floor_div(tileMidY, upy);
+ __syncthreads();
+ for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
+ {
+ int relC = inIdx;
+ int relInX = relC / loopMinor;
+ int relInY = relInX / tileInW;
+ relC -= relInX * loopMinor;
+ relInX -= relInY * tileInW;
+ int c = baseC + relC;
+ int inX = tileInX + relInX;
+ int inY = tileInY + relInY;
+ scalar_t v = 0;
+ if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
+ v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
+ sx[relInY][relInX][relC] = v;
+ }
+
+ // Loop over output pixels.
+ __syncthreads();
+ for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
+ {
+ int relC = outIdx;
+ int relOutX = relC / loopMinor;
+ int relOutY = relOutX / tileOutW;
+ relC -= relOutX * loopMinor;
+ relOutX -= relOutY * tileOutW;
+ int c = baseC + relC;
+ int outX = tileOutX + relOutX;
+ int outY = tileOutY + relOutY;
+
+ // Setup receptive field.
+ int midX = tileMidX + relOutX * downx;
+ int midY = tileMidY + relOutY * downy;
+ int inX = floor_div(midX, upx);
+ int inY = floor_div(midY, upy);
+ int relInX = inX - tileInX;
+ int relInY = inY - tileInY;
+ int filterX = (inX + 1) * upx - midX - 1; // flipped
+ int filterY = (inY + 1) * upy - midY - 1; // flipped
+
+ // Inner loop.
+ if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
+ {
+ scalar_t v = 0;
+ #pragma unroll
+ for (int y = 0; y < filterH / upy; y++)
+ #pragma unroll
+ for (int x = 0; x < filterW / upx; x++)
+ v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
+ v *= p.gain;
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
+ }
+ }
+ }
+ }
+}
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
+{
+ int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
+ upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large, -1,-1,1, 4}; // contiguous
+ if (s == 1) spec = {(void*)upfirdn2d_kernel_large, -1,-1,4, 1}; // channels_last
+
+ // No up/downsampling.
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
+ {
+ // contiguous
+ if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1};
+ if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1};
+ if (s != 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (s != 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (s != 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (s != 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ if (s != 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ if (s != 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ // channels_last
+ if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (s == 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (s == 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (s == 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (s == 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (s == 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (s == 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ }
+
+ // 2x upsampling.
+ if (p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1)
+ {
+ // contiguous
+ if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1};
+ if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1};
+ if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1};
+ // channels_last
+ if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1};
+ }
+ if (p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
+ {
+ // contiguous
+ if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1};
+ // channels_last
+ if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1};
+ }
+ if (p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1)
+ {
+ // contiguous
+ if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1};
+ // channels_last
+ if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1};
+ if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small