rossellison committed on
Commit
e3a6a57
Parent: 689d7e7

Upload 194 files

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete set.

Files changed (50)
  1. .gitattributes +3 -0
  2. app.py +87 -0
  3. stylegan3-fun/.github/FUNDING.yml +3 -0
  4. stylegan3-fun/.github/ISSUE_TEMPLATE/bug_report.md +35 -0
  5. stylegan3-fun/.github/ISSUE_TEMPLATE/feature_request.md +20 -0
  6. stylegan3-fun/.gitignore +173 -0
  7. stylegan3-fun/Dockerfile +19 -0
  8. stylegan3-fun/LICENSE.txt +97 -0
  9. stylegan3-fun/README.md +507 -0
  10. stylegan3-fun/__pycache__/legacy.cpython-311.pyc +0 -0
  11. stylegan3-fun/__pycache__/legacy.cpython-38.pyc +0 -0
  12. stylegan3-fun/__pycache__/legacy.cpython-39.pyc +0 -0
  13. stylegan3-fun/avg_spectra.py +276 -0
  14. stylegan3-fun/calc_metrics.py +188 -0
  15. stylegan3-fun/dataset_tool.py +547 -0
  16. stylegan3-fun/discriminator_synthesis.py +1007 -0
  17. stylegan3-fun/dnnlib/__init__.py +9 -0
  18. stylegan3-fun/dnnlib/__pycache__/__init__.cpython-311.pyc +0 -0
  19. stylegan3-fun/dnnlib/__pycache__/__init__.cpython-38.pyc +0 -0
  20. stylegan3-fun/dnnlib/__pycache__/__init__.cpython-39.pyc +0 -0
  21. stylegan3-fun/dnnlib/__pycache__/util.cpython-311.pyc +0 -0
  22. stylegan3-fun/dnnlib/__pycache__/util.cpython-38.pyc +0 -0
  23. stylegan3-fun/dnnlib/__pycache__/util.cpython-39.pyc +0 -0
  24. stylegan3-fun/dnnlib/util.py +491 -0
  25. stylegan3-fun/docs/avg_spectra_screen0.png +0 -0
  26. stylegan3-fun/docs/avg_spectra_screen0_half.png +0 -0
  27. stylegan3-fun/docs/configs.md +201 -0
  28. stylegan3-fun/docs/dataset-tool-help.txt +70 -0
  29. stylegan3-fun/docs/stylegan3-teaser-1920x1006.png +3 -0
  30. stylegan3-fun/docs/train-help.txt +53 -0
  31. stylegan3-fun/docs/troubleshooting.md +31 -0
  32. stylegan3-fun/docs/visualizer_screen0.png +3 -0
  33. stylegan3-fun/docs/visualizer_screen0_half.png +0 -0
  34. stylegan3-fun/environment.yml +35 -0
  35. stylegan3-fun/gen_images.py +145 -0
  36. stylegan3-fun/gen_video.py +281 -0
  37. stylegan3-fun/generate.py +838 -0
  38. stylegan3-fun/gui_utils/__init__.py +9 -0
  39. stylegan3-fun/gui_utils/__pycache__/__init__.cpython-311.pyc +0 -0
  40. stylegan3-fun/gui_utils/__pycache__/__init__.cpython-38.pyc +0 -0
  41. stylegan3-fun/gui_utils/__pycache__/__init__.cpython-39.pyc +0 -0
  42. stylegan3-fun/gui_utils/__pycache__/gl_utils.cpython-311.pyc +0 -0
  43. stylegan3-fun/gui_utils/__pycache__/gl_utils.cpython-38.pyc +0 -0
  44. stylegan3-fun/gui_utils/__pycache__/gl_utils.cpython-39.pyc +0 -0
  45. stylegan3-fun/gui_utils/__pycache__/glfw_window.cpython-311.pyc +0 -0
  46. stylegan3-fun/gui_utils/__pycache__/glfw_window.cpython-38.pyc +0 -0
  47. stylegan3-fun/gui_utils/__pycache__/glfw_window.cpython-39.pyc +0 -0
  48. stylegan3-fun/gui_utils/__pycache__/imgui_utils.cpython-311.pyc +0 -0
  49. stylegan3-fun/gui_utils/__pycache__/imgui_utils.cpython-38.pyc +0 -0
  50. stylegan3-fun/gui_utils/__pycache__/imgui_utils.cpython-39.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+stylegan3-fun/docs/stylegan3-teaser-1920x1006.png filter=lfs diff=lfs merge=lfs -text
+stylegan3-fun/docs/visualizer_screen0.png filter=lfs diff=lfs merge=lfs -text
+stylegan3-fun/out/seed0002.png filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,87 @@
import os
import sys
import re
from typing import List, Optional, Tuple, Union
import random

sys.path.append('stylegan3-fun')  # change this to the path where dnnlib is located

import numpy as np
import PIL.Image
import torch
import streamlit as st
import dnnlib
import legacy


def parse_range(s: Union[str, List]) -> List[int]:
    '''Parse a comma-separated list of numbers or ranges and return a list of ints.

    Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
    '''
    if isinstance(s, list): return s
    ranges = []
    range_re = re.compile(r'^(\d+)-(\d+)$')
    for p in s.split(','):
        m = range_re.match(p)
        if m:
            ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
        else:
            ranges.append(int(p))
    return ranges

def make_transform(translate: Tuple[float,float], angle: float):
    # Build a 3x3 affine matrix combining rotation (angle in degrees) and translation.
    m = np.eye(3)
    s = np.sin(angle/360.0*np.pi*2)
    c = np.cos(angle/360.0*np.pi*2)
    m[0][0] = c
    m[0][1] = s
    m[0][2] = translate[0]
    m[1][0] = -s
    m[1][1] = c
    m[1][2] = translate[1]
    return m

def generate_image(network_pkl: str, seed: int, truncation_psi: float, noise_mode: str, translate: Tuple[float,float], rotate: float, class_idx: Optional[int]):
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with open(network_pkl, 'rb') as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device)  # type: ignore

    # Labels.
    label = torch.zeros([1, G.c_dim], device=device)
    if G.c_dim != 0:
        if class_idx is None:
            raise Exception('Must specify class label when using a conditional network')
        label[:, class_idx] = 1

    z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)

    # StyleGAN3 models expose an input transform; apply the inverse of the
    # user-specified translation/rotation so the output moves as requested.
    if hasattr(G.synthesis, 'input'):
        m = make_transform(translate, rotate)
        m = np.linalg.inv(m)
        G.synthesis.input.transform.copy_(torch.from_numpy(m))

    img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    img = PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
    return img

def main():
    st.title('Kpop Face Generator')

    st.write('Press the button below to generate a new image:')
    if st.button('Generate'):
        network_pkl = 'kpopGG.pkl'
        seed = random.randint(0, 99999)
        truncation_psi = 0.45
        noise_mode = 'const'
        translate = (0.0, 0.0)
        rotate = 0.0
        class_idx = None

        image = generate_image(network_pkl, seed, truncation_psi, noise_mode, translate, rotate, class_idx)
        st.image(image)

if __name__ == "__main__":
    main()
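For quick sanity-checking outside of Streamlit, the functions above can be driven directly. A minimal sketch, assuming a CUDA-capable machine and that the `kpopGG.pkl` checkpoint referenced in `main()` sits next to the script:

```python
# Smoke test for app.py; run from the directory containing app.py and kpopGG.pkl.
from app import generate_image, parse_range

print(parse_range('1,2,5-10'))  # [1, 2, 5, 6, 7, 8, 9, 10]

img = generate_image(network_pkl='kpopGG.pkl', seed=42, truncation_psi=0.45,
                     noise_mode='const', translate=(0.0, 0.0), rotate=0.0,
                     class_idx=None)
img.save('sample_seed42.png')  # generate_image returns a PIL.Image.Image
```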
stylegan3-fun/.github/FUNDING.yml ADDED
@@ -0,0 +1,3 @@
# These are supported funding model platforms

github: PDillis  # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
stylegan3-fun/.github/ISSUE_TEMPLATE/bug_report.md ADDED
@@ -0,0 +1,35 @@
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''

---

**Describe the bug**
A clear and concise description of what the bug is.

**To Reproduce**
Steps to reproduce the behavior:
1. In '...' directory, run command '...'
2. See error (copy&paste full log, including exceptions and **stacktraces**).

Please copy&paste text instead of screenshots for better searchability.

**Expected behavior**
A clear and concise description of what you expected to happen.

**Screenshots**
If applicable, add screenshots to help explain your problem.

**Desktop (please complete the following information):**
- OS: [e.g. Linux Ubuntu 20.04, Windows 10]
- PyTorch version (e.g., pytorch 1.9.0)
- CUDA toolkit version (e.g., CUDA 11.4)
- NVIDIA driver version
- GPU [e.g., Titan V, RTX 3090]
- Docker: did you use Docker? If yes, specify docker image URL (e.g., nvcr.io/nvidia/pytorch:21.08-py3)

**Additional context**
Add any other context about the problem here.
stylegan3-fun/.github/ISSUE_TEMPLATE/feature_request.md ADDED
@@ -0,0 +1,20 @@
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: ''
assignees: ''

---

**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]

**Describe the solution you'd like**
A clear and concise description of what you want to happen.

**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.

**Additional context**
Add any other context or screenshots about the feature request here.
stylegan3-fun/.gitignore ADDED
@@ -0,0 +1,173 @@
# Byte-compiled / optimized / DLL files
*/**/__pycache__/
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

# Conda temp
.condatmp/

# SGAN specific folders
datasets/
dlatents/
out/
training-runs/
pretrained/
video/
_screenshots/
stylegan3-fun/Dockerfile ADDED
@@ -0,0 +1,19 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

FROM nvcr.io/nvidia/pytorch:21.08-py3

ENV PYTHONDONTWRITEBYTECODE 1
ENV PYTHONUNBUFFERED 1

RUN pip install imageio imageio-ffmpeg==0.4.4 pyspng==0.1.0

WORKDIR /workspace

RUN (printf '#!/bin/bash\nexec \"$@\"\n' >> /entry.sh) && chmod a+x /entry.sh
ENTRYPOINT ["/entry.sh"]
stylegan3-fun/LICENSE.txt ADDED
@@ -0,0 +1,97 @@
Copyright (c) 2021, NVIDIA Corporation & affiliates. All rights reserved.


NVIDIA Source Code License for StyleGAN3


=======================================================================

1. Definitions

"Licensor" means any person or entity that distributes its Work.

"Software" means the original work of authorship made available under
this License.

"Work" means the Software and any additions to or derivative works of
the Software that are made available under this License.

The terms "reproduce," "reproduction," "derivative works," and
"distribution" have the meaning as provided under U.S. copyright law;
provided, however, that for the purposes of this License, derivative
works shall not include works that remain separable from, or merely
link (or bind by name) to the interfaces of, the Work.

Works, including the Software, are "made available" under this License
by including in or with the Work either (a) a copyright notice
referencing the applicability of this License to the Work, or (b) a
copy of this License.

2. License Grants

2.1 Copyright Grant. Subject to the terms and conditions of this
License, each Licensor grants to you a perpetual, worldwide,
non-exclusive, royalty-free, copyright license to reproduce,
prepare derivative works of, publicly display, publicly perform,
sublicense and distribute its Work and any resulting derivative
works in any form.

3. Limitations

3.1 Redistribution. You may reproduce or distribute the Work only
if (a) you do so under this License, (b) you include a complete
copy of this License with your distribution, and (c) you retain
without modification any copyright, patent, trademark, or
attribution notices that are present in the Work.

3.2 Derivative Works. You may specify that additional or different
terms apply to the use, reproduction, and distribution of your
derivative works of the Work ("Your Terms") only if (a) Your Terms
provide that the use limitation in Section 3.3 applies to your
derivative works, and (b) you identify the specific derivative
works that are subject to Your Terms. Notwithstanding Your Terms,
this License (including the redistribution requirements in Section
3.1) will continue to apply to the Work itself.

3.3 Use Limitation. The Work and any derivative works thereof only
may be used or intended for use non-commercially. Notwithstanding
the foregoing, NVIDIA and its affiliates may use the Work and any
derivative works commercially. As used herein, "non-commercially"
means for research or evaluation purposes only.

3.4 Patent Claims. If you bring or threaten to bring a patent claim
against any Licensor (including any claim, cross-claim or
counterclaim in a lawsuit) to enforce any patents that you allege
are infringed by any Work, then your rights under this License from
such Licensor (including the grant in Section 2.1) will terminate
immediately.

3.5 Trademarks. This License does not grant any rights to use any
Licensor’s or its affiliates’ names, logos, or trademarks, except
as necessary to reproduce the notices described in this License.

3.6 Termination. If you violate any term of this License, then your
rights under this License (including the grant in Section 2.1) will
terminate immediately.

4. Disclaimer of Warranty.

THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
THIS LICENSE.

5. Limitation of Liability.

EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
THE POSSIBILITY OF SUCH DAMAGES.

=======================================================================
stylegan3-fun/README.md ADDED
@@ -0,0 +1,507 @@
# StyleGAN3-Fun<br><sub>Let's have fun with StyleGAN2/ADA/3!</sub>

SOTA GANs are hard to train and to explore, and StyleGAN2/ADA/3 are no different. The point of this repository is to allow the user to both easily train and explore the trained models without unnecessary headaches.

As [before](https://github.com/PDillis/stylegan2-fun), we will build upon the official repository, which has the advantage of being backwards-compatible. As such, we can use our previously-trained models from StyleGAN2 and StyleGAN2-ADA. Please get acquainted with the official repository and its codebase, as we will be building upon it and, in doing so, increasing its capabilities (but hopefully not its complexity!).

## Additions

This repository adds/has the following changes (not yet the complete list):

* ***Dataset Setup (`dataset_tool.py`)***
    * **RGBA support**, so revert saving images to `.png` ([Issue #156](https://github.com/NVlabs/stylegan3/issues/156) by @1378dm). Training can use RGBA and images can be generated.
        * ***TODO:*** ~~Check that training code is correct for normalizing the alpha channel~~, as well as making the interpolation code work with this new format (look into [`moviepy.editor.VideoClip`](https://zulko.github.io/moviepy/getting_started/videoclips.html?highlight=mask#mask-clips))
        * For now, interpolation videos will only be saved in RGB format, i.e., discarding the alpha channel.
    * **`--center-crop-tall`**: add vertical black bars to the sides of each image in the dataset (for rectangular images with height > width) when you wish to train a square model, in the same vein as the horizontal bars added when using `--center-crop-wide` (where width > height); see the sketch right after this *Dataset Setup* block.
        * This is useful when you don't want to lose information from the left and right sides of the image by only using the center crop (ibidem for `--center-crop-wide`, but for the top and bottom of the image)
        * Note that the images don't all have to be of the same size, and the added bars will only ensure you get a square image, which will then be resized to the model's desired resolution (set by `--resolution`).
    * Grayscale images in the dataset are converted to `RGB`
        * If you want to turn this off, remove the respective line in `dataset_tool.py`; e.g., if your dataset is made of images in a folder, then the function to be used is `open_image_folder` in `dataset_tool.py`, and the line to be removed is `img = img.convert('RGB')` in the `iterate_images` inner function.
    * The dataset can be forced to be of a specific number of channels, that is, grayscale, RGB, or RGBA.
        * To use this, set `--force-channels=1` for grayscale, `--force-channels=3` for RGB, and `--force-channels=4` for RGBA.
    * If the dataset tool encounters an error, print it along with the offending image, but continue with the rest of the dataset ([PR #39](https://github.com/NVlabs/stylegan3/pull/39) from [Andreas Jansson](https://github.com/andreasjansson)).
    * For conditional models, we can use the subdirectories as the classes by adding `--subfolders-as-labels`. This will generate the `dataset.json` file automatically, as done by @pbaylies [here](https://github.com/pbaylies/stylegan2-ada/blob/a8f0b1c891312631f870c94f996bcd65e0b8aeef/dataset_tool.py#L772)
        * Additionally, in the `--source` folder, we will save a `class_labels.txt` file, to further know which classes correspond to each subdirectory.
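To make the `--center-crop-tall` padding described above concrete, here is a minimal sketch of the idea using Pillow; `pad_to_square` is a made-up helper name, and this illustrates the behavior rather than reproducing the repository's actual implementation:

```python
import PIL.Image

def pad_to_square(img: PIL.Image.Image) -> PIL.Image.Image:
    """Pad an image with black bars so it becomes square: bars appear on the
    left/right for tall images (as --center-crop-tall describes) and on the
    top/bottom for wide ones (as --center-crop-wide describes)."""
    side = max(img.size)
    canvas = PIL.Image.new('RGB', (side, side), (0, 0, 0))  # black background
    canvas.paste(img, ((side - img.width) // 2, (side - img.height) // 2))
    return canvas

# Example: a 640x960 (tall) image becomes 960x960 with vertical black bars,
# ready to be resized to the model's --resolution.
square = pad_to_square(PIL.Image.new('RGB', (640, 960)))
assert square.size == (960, 960)
```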
* ***Training***
    * Add `--cfg=stylegan2-ext`, which uses @aydao's extended modifications for handling large and diverse datasets.
        * A good explanation is found in Gwern's blog [here](https://gwern.net/face#extended-stylegan2-danbooru2019-aydao)
        * If you wish to fine-tune from @aydao's Anime model, use `--cfg=stylegan2-ext --resume=anime512` when running `train.py`
        * Note: ***This is an extremely experimental configuration!*** The `.pkl` files will be ~1.1 GB each and training will slow down significantly. Use at your own risk!
    * `--blur-percent`: Blur both real and generated images before passing them to the Discriminator.
        * The blur (`blur_init_sigma=10.0`) will completely fade after the selected percentage of the training is completed (using a linear ramp).
        * Another experimental feature; it should help with datasets that have a lot of variation, when you wish the model to first slowly learn to generate the objects and only then their details.
    * `--mirrory`: Added vertical mirroring for doubling the dataset size (quadrupling if `--mirror` is used; make sure your dataset has either or both of these symmetries in order for it to make sense to use them)
    * `--gamma`: If no R1 regularization weight is provided, the heuristic formula from [StyleGAN2](https://github.com/NVlabs/stylegan2) will be used.
        * Specifically, we will set `gamma=0.0002 * resolution ** 2 / batch_size` (a worked example follows this *Training* block).
    * `--aug`: ***TODO:*** add [Deceive-D/APA](https://github.com/EndlessSora/DeceiveD) as an option.
    * `--augpipe`: Now available to use is [StyleGAN2-ADA's](https://github.com/NVlabs/stylegan2-ada-pytorch) full list of augpipes, i.e., individual augmentations (`blit`, `geom`, `color`, `filter`, `noise`, `cutout`) or their combinations (`bg`, `bgc`, `bgcf`, `bgcfn`, `bgcfnc`).
    * `--img-snap`: Set when to save snapshot images, so now it's independent of when the model is saved (e.g., save image snapshots more often to know how the model is training without saving the model itself, to save space).
    * `--snap-res`: The resolution of the snapshots, depending on how many images you wish to see per snapshot. Available resolutions: `1080p`, `4k`, and `8k`.
    * `--resume-kimg`: Starting number of `kimg`, useful when continuing training from a previous run
    * `--outdir`: Automatically set to `training-runs`, so there's no need to set it beforehand (in general this is true throughout the repository)
    * `--metrics`: Now set by default to `None`, so there's no need to worry about this one
    * `--freezeD`: Renamed from `--freezed` for better readability
    * `--freezeM`: Freeze the first layers of the Mapping Network Gm (`G.mapping`)
    * `--freezeE`: Freeze the embedding layer of the Generator (for class-conditional models)
    * `--freezeG`: ***TODO:*** Freeze the first layers of the Synthesis Network (`G.synthesis`; less cost to transfer learn, focus on high layers?)
    * `--resume`: All available pre-trained models from NVIDIA (and more) can be used with a simple dictionary, depending on the `--cfg` used. For example, if you wish to use StyleGAN3's `config-r`, then set `--cfg=stylegan3-r`. In addition, if you wish to transfer learn from FFHQU at 1024 resolution, set `--resume=ffhqu1024`.
        * The full list of currently available models to transfer learn from (or synthesize new images with) is the following (***TODO:*** add a small description of each model, so the user can better know which to use for their particular use-case; proper citation to the original authors as well):

<details>
<summary>StyleGAN2 models</summary>

1. The majority, if not all, are `config-f`: set `--cfg=stylegan2`
    * `ffhq256`
    * `ffhqu256`
    * `ffhq512`
    * `ffhq1024`
    * `ffhqu1024`
    * `celebahq256`
    * `lsundog256`
    * `afhqcat512`
    * `afhqdog512`
    * `afhqwild512`
    * `afhq512`
    * `brecahad512`
    * `cifar10` (conditional, 10 classes)
    * `metfaces1024`
    * `metfacesu1024`
    * `lsuncar512` (config-f)
    * `lsuncat256` (config-f)
    * `lsunchurch256` (config-f)
    * `lsunhorse256` (config-f)
    * `minecraft1024` (thanks to @jeffheaton)
    * `imagenet512` (thanks to @shawwn)
    * `wikiart1024-C` (conditional, 167 classes; thanks to @pbaylies)
    * `wikiart1024-U` (thanks to @pbaylies)
    * `maps1024` (thanks to @tjukanov)
    * `fursona512` (thanks to @arfafax)
    * `mlpony512` (thanks to @arfafax)
    * `lhq1024` (thanks to @justinpinkney)
    * `afhqcat256` (Deceive-D/APA models)
    * `anime256` (Deceive-D/APA models)
    * `cub256` (Deceive-D/APA models)
    * `sddogs1024` (Self-Distilled StyleGAN models)
    * `sdelephant512` (Self-Distilled StyleGAN models)
    * `sdhorses512` (Self-Distilled StyleGAN models)
    * `sdbicycles256` (Self-Distilled StyleGAN models)
    * `sdlions512` (Self-Distilled StyleGAN models)
    * `sdgiraffes512` (Self-Distilled StyleGAN models)
    * `sdparrots512` (Self-Distilled StyleGAN models)
2. Extended StyleGAN2 config from @aydao: set `--cfg=stylegan2-ext`
    * `anime512` (thanks to @aydao; writeup by @gwern: https://gwern.net/Faces#extended-stylegan2-danbooru2019-aydao)
</details>

<details>
<summary>StyleGAN3 models</summary>

1. `config-t`: set `--cfg=stylegan3-t`
    * `afhq512`
    * `ffhqu256`
    * `ffhq1024`
    * `ffhqu1024`
    * `metfaces1024`
    * `metfacesu1024`
    * `landscapes256` (thanks to @justinpinkney)
    * `wikiart1024` (thanks to @justinpinkney)
    * `mechfuture256` (thanks to @edstoica; 29 kimg tick)
    * `vivflowers256` (thanks to @edstoica; 68 kimg tick)
    * `alienglass256` (thanks to @edstoica; 38 kimg tick)
    * `scificity256` (thanks to @edstoica; 210 kimg tick)
    * `scifiship256` (thanks to @edstoica; 168 kimg tick)
2. `config-r`: set `--cfg=stylegan3-r`
    * `afhq512`
    * `ffhq1024`
    * `ffhqu1024`
    * `ffhqu256`
    * `metfaces1024`
    * `metfacesu1024`
</details>

    * The main sources of these pretrained models are the [official NVIDIA repository](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/research/models/stylegan3) and other community repositories, such as [Justin Pinkney](https://github.com/justinpinkney)'s [Awesome Pretrained StyleGAN2](https://github.com/justinpinkney/awesome-pretrained-stylegan2) and [Awesome Pretrained StyleGAN3](https://github.com/justinpinkney/awesome-pretrained-stylegan3), [Deceive-D/APA](https://github.com/EndlessSora/DeceiveD), [Self-Distilled StyleGAN/Internet Photos](https://github.com/self-distilled-stylegan/self-distilled-internet-photos), and [edstoica](https://github.com/edstoica)'s [Wombo Dream](https://www.wombo.art/)[-based models](https://github.com/edstoica/lucid_stylegan3_datasets_models). Others can be found around the net and are properly credited in this repository, so long as they can be easily downloaded with [`dnnlib.util.open_url`](https://github.com/PDillis/stylegan3-fun/blob/4ce9d6f7601641ba1e2906ed97f2739a63fb96e2/dnnlib/util.py#L396).

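As promised in the `--gamma` bullet above, the heuristic is simple enough to check by hand; a quick sketch in plain Python (the values are illustrative):

```python
def heuristic_gamma(resolution: int, batch_size: int) -> float:
    """R1 gamma heuristic used when --gamma is not provided."""
    return 0.0002 * resolution ** 2 / batch_size

# For a 1024x1024 model trained with --batch=32 this gives ~6.55,
# close to the --gamma=6.6 used in the MetFaces fine-tuning example below.
print(heuristic_gamma(1024, 32))  # 6.5536
print(heuristic_gamma(512, 32))   # 1.6384
```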
* ***Interpolation videos***
    * [Random interpolation](https://youtu.be/DNfocO1IOUE)
    * [Generate images/interpolations with the internal representations of the model](https://nvlabs-fi-cdn.nvidia.com/_web/stylegan3/videos/video_8_internal_activations.mp4)
        * Usage: Add `--layer=<layer_name>` to specify which layer to use for interpolation.
        * If you don't know the names of the layers available for your model, add the flag `--available-layers` and the layers will be printed to the console, along with their names, number of channels, and sizes.
        * Use one of `--grayscale` or `--rgb` to specify whether to save the images as grayscale or RGB during the interpolation.
            * For `--rgb`, three consecutive channels (starting at `--starting-channel=0`) will be used to create the RGB image. For `--grayscale`, only the first channel will be used.
    * Style-mixing
    * [Sightseeding](https://twitter.com/PDillis/status/1270341433401249793?s=20&t=yLueNkagqsidZFqZ2jNPAw) (jumpiness has been fixed)
    * [Circular interpolation](https://youtu.be/4nktYGjSVHg)
    * [Visual-reactive interpolation](https://youtu.be/KoEAkPnE-zA) (beta)
    * Audiovisual-reactive interpolation (***TODO***)
    * ***TODO:*** Give support to RGBA models!
* ***Projection into the latent space***
    * [Project into $\mathcal{W}+$](https://arxiv.org/abs/1904.03189)
    * Additional losses to use for better projection (e.g., using VGG16 or [CLIP](https://github.com/openai/CLIP))
* ***[Discriminator Synthesis](https://arxiv.org/abs/2111.02175)*** (official code)
    * Generate a static image (`python discriminator_synthesis.py dream --help`) or a [video](https://youtu.be/hEJKWL2VQTE) with a feedback loop (`python discriminator_synthesis.py dream-zoom --help`, `python discriminator_synthesis.py channel-zoom --help`, or `python discriminator_synthesis.py interp --help`)
    * Start from a random image (`random` for noise or `perlin` for 2D fractal Perlin noise, using [Mathieu Duchesneau's implementation](https://github.com/duchesneaumathieu/pyperlin)) or from an existing one
* ***Expansion on GUI/`visualizer.py`***
    * Added the rest of the affine transformations
    * Added a widget for class-conditional models (***TODO:*** mix classes with continuous values for `cls`!)
* ***General model and code additions***
    * [Multi-modal truncation trick](https://arxiv.org/abs/2202.12211): find the different clusters in your model and use the one closest to your dlatent, in order to increase fidelity
        * Usage: Run `python multimodal_truncation.py get-centroids --network=<path_to_model>` to use default values; for extra options, run `python multimodal_truncation.py get-centroids --help`
    * StyleGAN3: anchor the latent space for easier-to-follow interpolations (thanks to [Rivers Have Wings](https://github.com/crowsonkb) and [nshepperd](https://github.com/nshepperd)).
    * Use the CPU instead of the GPU if desired (not recommended, but perfectly fine for generating images whenever the custom CUDA kernels fail to compile); see the sketch after this *Additions* list.
    * Add missing dependencies and channels so that the [`conda`](https://docs.conda.io/en/latest/) environment is correctly set up in Windows (PRs [#111](https://github.com/NVlabs/stylegan3/pull/111)/[#125](https://github.com/NVlabs/stylegan3/pull/125) and [#80](https://github.com/NVlabs/stylegan3/pull/80)/[#143](https://github.com/NVlabs/stylegan3/pull/143) from the base repository, respectively)
    * Use [StyleGAN-NADA](https://github.com/rinongal/StyleGAN-nada) models with any part of the code (Issue [#9](https://github.com/PDillis/stylegan3-fun/issues/9))
        * The StyleGAN-NADA models must first be converted via [Vadim Epstein](https://github.com/eps696)'s conversion code found [here](https://github.com/eps696/stylegan2ada#tweaking-models).
    * Add PR [#173](https://github.com/NVlabs/stylegan3/pull/173) for adding the last remaining unknown kwarg for using StyleGAN2 models trained with TF 1.15.
* ***TODO*** list (this is a long one with more to come, so any help is appreciated):
    * Add a `--device={cuda, ref}` option to pass to each of the custom operations in order to (theoretically) be able to use AMD GPUs, as explained in @l4rz's post [here](https://github.com/l4rz/practical-aspects-of-stylegan2-training#hardware)
    * Define a [custom Generator](https://github.com/dvschultz/stylegan2-ada-pytorch/blob/59e05bb115c1c7d0de56be0523754076c2b7ee83/legacy.py#L131) in `legacy.py` to modify the output size
        * Related: the [multi-latent](https://github.com/dvschultz/stylegan2-ada-pytorch/blob/main/training/stylegan2_multi.py) Generator, i.e., the one from [@eps696](https://github.com/eps696/stylegan2)
    * Add [Top-K training](https://arxiv.org/abs/2002.06224) as done [here](https://github.com/dvschultz/stylegan2-ada/blob/8f4ab24f494483542d31bf10f4fdb0005dc62739/train.py#L272) and [here](https://github.com/dvschultz/stylegan2-ada-pytorch/blob/59e05bb115c1c7d0de56be0523754076c2b7ee83/training/loss.py#L79)
    * Add panorama/SinGAN/feature interpolation from [StyleGAN of All Trades](https://arxiv.org/abs/2111.01619)
    * [PTI](https://github.com/danielroich/PTI) for better inversion
    * [Better sampling](https://arxiv.org/abs/2110.08009)
    * [Add cross-model interpolation](https://twitter.com/arfafax/status/1297681537337446400?s=20&t=xspnTaLFTvd7y4krg8tkxA)
    * Blend different models (average checkpoints, copy weights, create an initial network), as in @aydao's [StyleGAN2-Surgery](https://github.com/aydao/stylegan2-surgery)
    * Add multi-crop for dataset creation, as used in [Earth View](https://github.com/PDillis/earthview#multi-crop---data_augmentpy).
    * Make it easy to download pretrained models from Drive, as otherwise a lot of models can't be used with `dnnlib.util.open_url` (e.g., [StyleGAN-Human](https://github.com/stylegan-human/StyleGAN-Human) models)
    * Finish the documentation for a better user experience; add videos/images, code samples, visuals...
    * Add [Ensembling Off-the-shelf Models for GAN Training](https://arxiv.org/abs/2112.09130) and [Any-resolution Training for High-resolution Image Synthesis](https://chail.github.io/anyres-gan/)

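On the CPU option mentioned in the list above, here is a minimal sketch of what CPU-only sampling might look like, built only from calls shown elsewhere in this README (the checkpoint name is a placeholder; `force_fp32` keeps the reference ops in float32):

```python
# Hedged sketch: generating on CPU when the custom CUDA kernels are unavailable.
import sys
sys.path.append('stylegan3-fun')  # makes dnnlib and torch_utils importable
import pickle
import torch

with open('network.pkl', 'rb') as f:       # placeholder checkpoint name
    G = pickle.load(f)['G_ema'].to('cpu')

z = torch.randn([1, G.z_dim])
w = G.mapping(z, None, truncation_psi=0.7)
img = G.synthesis(w, noise_mode='const', force_fp32=True)  # NCHW in [-1, +1]
```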
## Notebooks (Coming soon!)

## Sponsors ![GitHub Sponsor](https://img.shields.io/github/sponsors/PDillis?label=Sponsor&logo=GitHub)

This repository has been sponsored by:

[isosceles](https://www.jasonfletcher.info/vjloops/)

Thank you so much!

If you wish to sponsor me, click here: [![](https://img.shields.io/static/v1?label=Sponsor&message=%E2%9D%A4&logo=GitHub&color=%23fe8e86)](https://github.com/sponsors/PDillis)


---

## Alias-Free Generative Adversarial Networks (StyleGAN3)<br><sub>Official PyTorch implementation of the NeurIPS 2021 paper</sub>

![Teaser image](./docs/stylegan3-teaser-1920x1006.png)

**Alias-Free Generative Adversarial Networks**<br>
Tero Karras, Miika Aittala, Samuli Laine, Erik H&auml;rk&ouml;nen, Janne Hellsten, Jaakko Lehtinen, Timo Aila<br>
https://nvlabs.github.io/stylegan3<br>

Abstract: *We observe that despite their hierarchical convolutional nature, the synthesis process of typical generative adversarial networks depends on absolute pixel coordinates in an unhealthy manner. This manifests itself as, e.g., detail appearing to be glued to image coordinates instead of the surfaces of depicted objects. We trace the root cause to careless signal processing that causes aliasing in the generator network. Interpreting all signals in the network as continuous, we derive generally applicable, small architectural changes that guarantee that unwanted information cannot leak into the hierarchical synthesis process. The resulting networks match the FID of StyleGAN2 but differ dramatically in their internal representations, and they are fully equivariant to translation and rotation even at subpixel scales. Our results pave the way for generative models better suited for video and animation.*

For business inquiries, please visit our website and submit the form: [NVIDIA Research Licensing](https://www.nvidia.com/en-us/research/inquiries/)

## Release notes

This repository is an updated version of [stylegan2-ada-pytorch](https://github.com/NVlabs/stylegan2-ada-pytorch), with several new features:
- Alias-free generator architecture and training configurations (`stylegan3-t`, `stylegan3-r`).
- Tools for interactive visualization (`visualizer.py`), spectral analysis (`avg_spectra.py`), and video generation (`gen_video.py`).
- Equivariance metrics (`eqt50k_int`, `eqt50k_frac`, `eqr50k`).
- General improvements: reduced memory usage, slightly faster training, bug fixes.

Compatibility:
- Compatible with old network pickles created using [stylegan2-ada](https://github.com/NVlabs/stylegan2-ada) and [stylegan2-ada-pytorch](https://github.com/NVlabs/stylegan2-ada-pytorch). (Note: running old StyleGAN2 models on StyleGAN3 code will produce the same results as running them on stylegan2-ada/stylegan2-ada-pytorch. To benefit from the StyleGAN3 architecture, you need to retrain.)
- Supports old StyleGAN2 training configurations, including ADA and transfer learning. See [Training configurations](./docs/configs.md) for details.
- Improved compatibility with Ampere GPUs and newer versions of PyTorch, CuDNN, etc.

## Synthetic image detection

While new generator approaches enable new media synthesis capabilities, they may also present a new challenge for AI forensics algorithms for detection and attribution of synthetic media. In collaboration with digital forensic researchers participating in DARPA's SemaFor program, we curated a synthetic image dataset that allowed the researchers to test and validate the performance of their image detectors in advance of the public release. Please see [here](https://github.com/NVlabs/stylegan3-detector) for more details.

## Additional material

- [Result videos](https://nvlabs-fi-cdn.nvidia.com/stylegan3/videos/)
- [Curated example images](https://nvlabs-fi-cdn.nvidia.com/stylegan3/images/)
- [StyleGAN3 pre-trained models](https://ngc.nvidia.com/catalog/models/nvidia:research:stylegan3) for config T (translation equiv.) and config R (translation and rotation equiv.)
  > <sub>Access individual networks via `https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/<MODEL>`, where `<MODEL>` is one of:</sub><br>
  > <sub>`stylegan3-t-ffhq-1024x1024.pkl`, `stylegan3-t-ffhqu-1024x1024.pkl`, `stylegan3-t-ffhqu-256x256.pkl`</sub><br>
  > <sub>`stylegan3-r-ffhq-1024x1024.pkl`, `stylegan3-r-ffhqu-1024x1024.pkl`, `stylegan3-r-ffhqu-256x256.pkl`</sub><br>
  > <sub>`stylegan3-t-metfaces-1024x1024.pkl`, `stylegan3-t-metfacesu-1024x1024.pkl`</sub><br>
  > <sub>`stylegan3-r-metfaces-1024x1024.pkl`, `stylegan3-r-metfacesu-1024x1024.pkl`</sub><br>
  > <sub>`stylegan3-t-afhqv2-512x512.pkl`</sub><br>
  > <sub>`stylegan3-r-afhqv2-512x512.pkl`</sub><br>
- [StyleGAN2 pre-trained models](https://ngc.nvidia.com/catalog/models/nvidia:research:stylegan2) compatible with this codebase
  > <sub>Access individual networks via `https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/<MODEL>`, where `<MODEL>` is one of:</sub><br>
  > <sub>`stylegan2-ffhq-1024x1024.pkl`, `stylegan2-ffhq-512x512.pkl`, `stylegan2-ffhq-256x256.pkl`</sub><br>
  > <sub>`stylegan2-ffhqu-1024x1024.pkl`, `stylegan2-ffhqu-256x256.pkl`</sub><br>
  > <sub>`stylegan2-metfaces-1024x1024.pkl`, `stylegan2-metfacesu-1024x1024.pkl`</sub><br>
  > <sub>`stylegan2-afhqv2-512x512.pkl`</sub><br>
  > <sub>`stylegan2-afhqcat-512x512.pkl`, `stylegan2-afhqdog-512x512.pkl`, `stylegan2-afhqwild-512x512.pkl`</sub><br>
  > <sub>`stylegan2-brecahad-512x512.pkl`, `stylegan2-cifar10-32x32.pkl`</sub><br>
  > <sub>`stylegan2-celebahq-256x256.pkl`, `stylegan2-lsundog-256x256.pkl`</sub><br>

## Requirements

* Linux and Windows are supported, but we recommend Linux for performance and compatibility reasons.
* 1&ndash;8 high-end NVIDIA GPUs with at least 12 GB of memory. We have done all testing and development using Tesla V100 and A100 GPUs.
* 64-bit Python 3.8 and PyTorch 1.9.0 (or later). See https://pytorch.org for PyTorch install instructions.
* CUDA toolkit 11.1 or later. (Why is a separate CUDA toolkit installation required? See [Troubleshooting](./docs/troubleshooting.md#why-is-cuda-toolkit-installation-necessary).)
* GCC 7 or later (Linux) or Visual Studio (Windows) compilers. The recommended GCC version depends on the CUDA version; see for example [CUDA 11.4 system requirements](https://docs.nvidia.com/cuda/archive/11.4.1/cuda-installation-guide-linux/index.html#system-requirements).
* Python libraries: see [environment.yml](./environment.yml) for exact library dependencies. You can use the following commands with Miniconda3 to create and activate your StyleGAN3 Python environment:
  - `conda env create -f environment.yml`
  - `conda activate stylegan3`
* Docker users:
  - Ensure you have correctly installed the [NVIDIA container runtime](https://docs.docker.com/config/containers/resource_constraints/#gpu).
  - Use the [provided Dockerfile](./Dockerfile) to build an image with the required library dependencies.

The code relies heavily on custom PyTorch extensions that are compiled on the fly using NVCC. On Windows, the compilation requires Microsoft Visual Studio. We recommend installing [Visual Studio Community Edition](https://visualstudio.microsoft.com/vs/) and adding it into `PATH` using `"C:\Program Files (x86)\Microsoft Visual Studio\<VERSION>\Community\VC\Auxiliary\Build\vcvars64.bat"`.

See [Troubleshooting](./docs/troubleshooting.md) for help on common installation and run-time problems.

## Getting started

Pre-trained networks are stored as `*.pkl` files that can be referenced using local filenames or URLs:

```bash
# Generate an image using pre-trained AFHQv2 model ("Ours" in Figure 1, left).
python gen_images.py --outdir=out --trunc=1 --seeds=2 \
    --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl

# Render a 4x2 grid of interpolations for seeds 0 through 31.
python gen_video.py --output=lerp.mp4 --trunc=1 --seeds=0-31 --grid=4x2 \
    --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl
```

Outputs from the above commands are placed under `out/*.png`, controlled by `--outdir`. Downloaded network pickles are cached under `$HOME/.cache/dnnlib`, which can be overridden by setting the `DNNLIB_CACHE_DIR` environment variable. The default PyTorch extension build directory is `$HOME/.cache/torch_extensions`, which can be overridden by setting `TORCH_EXTENSIONS_DIR`.
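If you prefer to relocate those caches from a Python wrapper rather than the shell, a small sketch (the cache paths and the local `--network` checkpoint are illustrative):

```python
# Illustrative only: point both caches at a large disk, then invoke the tool.
import os
import subprocess

env = dict(os.environ,
           DNNLIB_CACHE_DIR='/data/cache/dnnlib',         # downloaded .pkl files
           TORCH_EXTENSIONS_DIR='/data/cache/torch_ext')  # compiled CUDA ops
subprocess.run(['python', 'gen_images.py', '--outdir=out', '--trunc=1',
                '--seeds=2', '--network=ffhq.pkl'], env=env, check=True)
```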
**Docker**: You can run the above curated image example using Docker as follows:

```bash
# Build the stylegan3:latest image
docker build --tag stylegan3 .

# Run the gen_images.py script using Docker:
docker run --gpus all -it --rm --user $(id -u):$(id -g) \
    -v `pwd`:/scratch --workdir /scratch -e HOME=/scratch \
    stylegan3 \
    python gen_images.py --outdir=out --trunc=1 --seeds=2 \
        --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl
```

Note: The Docker image requires NVIDIA driver release `r470` or later.

The `docker run` invocation may look daunting, so let's unpack its contents here:

- `--gpus all -it --rm --user $(id -u):$(id -g)`: with all GPUs enabled, run an interactive session with the current user's UID/GID to avoid Docker writing files as root.
- ``-v `pwd`:/scratch --workdir /scratch``: mount the current running dir (e.g., the top of this git repo on your host machine) to `/scratch` in the container and use that as the current working dir.
- `-e HOME=/scratch`: let PyTorch and StyleGAN3 code know where to cache temporary files such as pre-trained models and custom PyTorch extension build results. Note: if you want more fine-grained control, you can instead set `TORCH_EXTENSIONS_DIR` (for the custom extensions build dir) and `DNNLIB_CACHE_DIR` (for the pre-trained model download cache). You want these cache dirs to reside on persistent volumes so that their contents are retained across multiple `docker run` invocations.

## Interactive visualization

This release contains an interactive model visualization tool that can be used to explore various characteristics of a trained model. To start it, run:

```bash
python visualizer.py
```

<a href="./docs/visualizer_screen0.png"><img alt="Visualizer screenshot" src="./docs/visualizer_screen0_half.png"></a>

## Using networks from Python

You can use pre-trained networks in your own Python code as follows:

```python
import pickle
import torch

with open('ffhq.pkl', 'rb') as f:
    G = pickle.load(f)['G_ema'].cuda()  # torch.nn.Module
z = torch.randn([1, G.z_dim]).cuda()    # latent codes
c = None                                # class labels (not used in this example)
img = G(z, c)                           # NCHW, float32, dynamic range [-1, +1], no truncation
```

The above code requires `torch_utils` and `dnnlib` to be accessible via `PYTHONPATH`. It does not need source code for the networks themselves &mdash; their class definitions are loaded from the pickle via `torch_utils.persistence`.
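One way to satisfy that requirement from a script living outside the repository is to extend `sys.path` at runtime, mirroring what `app.py` earlier in this commit does; the checkout path here is an assumption:

```python
import sys
sys.path.append('stylegan3-fun')  # checkout containing dnnlib/ and torch_utils/
```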
The pickle contains three networks. `'G'` and `'D'` are instantaneous snapshots taken during training, and `'G_ema'` represents a moving average of the generator weights over several training steps. The networks are regular instances of `torch.nn.Module`, with all of their parameters and buffers placed on the CPU at import and gradient computation disabled by default.

The generator consists of two submodules, `G.mapping` and `G.synthesis`, that can be executed separately. They also support various additional options:

```python
w = G.mapping(z, c, truncation_psi=0.5, truncation_cutoff=8)
img = G.synthesis(w, noise_mode='const', force_fp32=True)
```

Please refer to [`gen_images.py`](./gen_images.py) for a complete code example.

## Preparing datasets

Datasets are stored as uncompressed ZIP archives containing uncompressed PNG files and a metadata file `dataset.json` for labels. Custom datasets can be created from a folder containing images; see [`python dataset_tool.py --help`](./docs/dataset-tool-help.txt) for more information. Alternatively, the folder can also be used directly as a dataset, without running it through `dataset_tool.py` first, but doing so may lead to suboptimal performance.
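For reference, `dataset.json` holds a single `labels` list pairing archive-relative filenames with integer class indices; a sketch of writing one by hand (filenames and classes are invented, and the structure follows the convention used by the dataset tools as far as we know):

```python
# Hedged sketch of dataset.json's expected structure.
import json

labels = {
    "labels": [
        ["00000/img00000000.png", 0],
        ["00000/img00000001.png", 2],
    ]
}
with open('dataset.json', 'w') as f:
    json.dump(labels, f)
```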
358
+ **FFHQ**: Download the [Flickr-Faces-HQ dataset](https://github.com/NVlabs/ffhq-dataset) as 1024x1024 images and create a zip archive using `dataset_tool.py`:
359
+
360
+ ```bash
361
+ # Original 1024x1024 resolution.
362
+ python dataset_tool.py --source=/tmp/images1024x1024 --dest=~/datasets/ffhq-1024x1024.zip
363
+
364
+ # Scaled down 256x256 resolution.
365
+ python dataset_tool.py --source=/tmp/images1024x1024 --dest=~/datasets/ffhq-256x256.zip \
366
+ --resolution=256x256
367
+ ```
368
+
369
+ See the [FFHQ README](https://github.com/NVlabs/ffhq-dataset) for information on how to obtain the unaligned FFHQ dataset images. Use the same steps as above to create a ZIP archive for training and validation.
370
+
371
+ **MetFaces**: Download the [MetFaces dataset](https://github.com/NVlabs/metfaces-dataset) and create a ZIP archive:
372
+
373
+ ```bash
374
+ python dataset_tool.py --source=~/downloads/metfaces/images --dest=~/datasets/metfaces-1024x1024.zip
375
+ ```
376
+
377
+ See the [MetFaces README](https://github.com/NVlabs/metfaces-dataset) for information on how to obtain the unaligned MetFaces dataset images. Use the same steps as above to create a ZIP archive for training and validation.
378
+
379
+ **AFHQv2**: Download the [AFHQv2 dataset](https://github.com/clovaai/stargan-v2/blob/master/README.md#animal-faces-hq-dataset-afhq) and create a ZIP archive:
380
+
381
+ ```bash
382
+ python dataset_tool.py --source=~/downloads/afhqv2 --dest=~/datasets/afhqv2-512x512.zip
383
+ ```
384
+
385
+ Note that the above command creates a single combined dataset using all images of all three classes (cats, dogs, and wild animals), matching the setup used in the StyleGAN3 paper. Alternatively, you can also create a separate dataset for each class:
386
+
387
+ ```bash
388
+ python dataset_tool.py --source=~/downloads/afhqv2/train/cat --dest=~/datasets/afhqv2cat-512x512.zip
389
+ python dataset_tool.py --source=~/downloads/afhqv2/train/dog --dest=~/datasets/afhqv2dog-512x512.zip
390
+ python dataset_tool.py --source=~/downloads/afhqv2/train/wild --dest=~/datasets/afhqv2wild-512x512.zip
391
+ ```
392
+
393
+ ## Training
394
+
395
+ You can train new networks using `train.py`. For example:
396
+
397
+ ```bash
398
+ # Train StyleGAN3-T for AFHQv2 using 8 GPUs.
399
+ python train.py --outdir=~/training-runs --cfg=stylegan3-t --data=~/datasets/afhqv2-512x512.zip \
400
+ --gpus=8 --batch=32 --gamma=8.2 --mirror=1
401
+
402
+ # Fine-tune StyleGAN3-R for MetFaces-U using 1 GPU, starting from the pre-trained FFHQ-U pickle.
403
+ python train.py --outdir=~/training-runs --cfg=stylegan3-r --data=~/datasets/metfacesu-1024x1024.zip \
404
+ --gpus=8 --batch=32 --gamma=6.6 --mirror=1 --kimg=5000 --snap=5 \
405
+ --resume=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhqu-1024x1024.pkl
406
+
407
+ # Train StyleGAN2 for FFHQ at 1024x1024 resolution using 8 GPUs.
408
+ python train.py --outdir=~/training-runs --cfg=stylegan2 --data=~/datasets/ffhq-1024x1024.zip \
409
+ --gpus=8 --batch=32 --gamma=10 --mirror=1 --aug=noaug
410
+ ```
411
+
412
+ Note that the result quality and training time depend heavily on the exact set of options. The most important ones (`--gpus`, `--batch`, and `--gamma`) must be specified explicitly, and they should be selected with care. See [`python train.py --help`](./docs/train-help.txt) for the full list of options and [Training configurations](./docs/configs.md) for general guidelines &amp; recommendations, along with the expected training speed &amp; memory usage in different scenarios.
413
+
414
+ The results of each training run are saved to a newly created directory, for example `~/training-runs/00000-stylegan3-t-afhqv2-512x512-gpus8-batch32-gamma8.2`. The training loop exports network pickles (`network-snapshot-<KIMG>.pkl`) and random image grids (`fakes<KIMG>.png`) at regular intervals (controlled by `--snap`). For each exported pickle, it evaluates FID (controlled by `--metrics`) and logs the result in `metric-fid50k_full.jsonl`. It also records various statistics in `training_stats.jsonl`, as well as `*.tfevents` if TensorBoard is installed.
415
+
416
+ ## Quality metrics
417
+
418
+ By default, `train.py` automatically computes FID for each network pickle exported during training. We recommend inspecting `metric-fid50k_full.jsonl` (or TensorBoard) at regular intervals to monitor the training progress. When desired, the automatic computation can be disabled with `--metrics=none` to speed up the training slightly.
419
+
420
+ Additional quality metrics can also be computed after the training:
421
+
422
+ ```bash
423
+ # Previous training run: look up options automatically, save result to JSONL file.
424
+ python calc_metrics.py --metrics=eqt50k_int,eqr50k \
425
+ --network=~/training-runs/00000-stylegan3-r-mydataset/network-snapshot-000000.pkl
426
+
427
+ # Pre-trained network pickle: specify dataset explicitly, print result to stdout.
428
+ python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq-1024x1024.zip --mirror=1 \
429
+ --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhq-1024x1024.pkl
430
+ ```
431
+
432
+ The first example looks up the training configuration and performs the same operation as if `--metrics=eqt50k_int,eqr50k` had been specified during training. The second example downloads a pre-trained network pickle, in which case the values of `--data` and `--mirror` must be specified explicitly.
433
+
+ Note that the metrics can be quite expensive to compute (up to 1h), and many of them have an additional one-off cost for each new dataset (up to 30min). Also note that the evaluation is done using a different random seed each time, so the results will vary if the same metric is computed multiple times.
+
+ Recommended metrics:
+ * `fid50k_full`: Fr&eacute;chet inception distance<sup>[1]</sup> against the full dataset.
+ * `kid50k_full`: Kernel inception distance<sup>[2]</sup> against the full dataset.
+ * `pr50k3_full`: Precision and recall<sup>[3]</sup> against the full dataset.
+ * `ppl2_wend`: Perceptual path length<sup>[4]</sup> in W, endpoints, full image.
+ * `eqt50k_int`: Equivariance<sup>[5]</sup> w.r.t. integer translation (EQ-T).
+ * `eqt50k_frac`: Equivariance w.r.t. fractional translation (EQ-T<sub>frac</sub>).
+ * `eqr50k`: Equivariance w.r.t. rotation (EQ-R).
+
+ Legacy metrics:
+ * `fid50k`: Fr&eacute;chet inception distance against 50k real images.
+ * `kid50k`: Kernel inception distance against 50k real images.
+ * `pr50k3`: Precision and recall against 50k real images.
+ * `is50k`: Inception score<sup>[6]</sup> for CIFAR-10.
+
+ References:
+ 1. [GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium](https://arxiv.org/abs/1706.08500), Heusel et al. 2017
+ 2. [Demystifying MMD GANs](https://arxiv.org/abs/1801.01401), Bi&nacute;kowski et al. 2018
+ 3. [Improved Precision and Recall Metric for Assessing Generative Models](https://arxiv.org/abs/1904.06991), Kynk&auml;&auml;nniemi et al. 2019
+ 4. [A Style-Based Generator Architecture for Generative Adversarial Networks](https://arxiv.org/abs/1812.04948), Karras et al. 2018
+ 5. [Alias-Free Generative Adversarial Networks](https://nvlabs.github.io/stylegan3), Karras et al. 2021
+ 6. [Improved Techniques for Training GANs](https://arxiv.org/abs/1606.03498), Salimans et al. 2016
+
+ ## Spectral analysis
+
+ The easiest way to inspect the spectral properties of a given generator is to use the built-in FFT mode in `visualizer.py`. In addition, you can visualize average 2D power spectra (Appendix A, Figure 15) as follows:
+
+ ```bash
+ # Calculate dataset mean and std, needed in subsequent steps.
+ python avg_spectra.py stats --source=~/datasets/ffhq-1024x1024.zip
+
+ # Calculate average spectrum for the training data.
+ python avg_spectra.py calc --source=~/datasets/ffhq-1024x1024.zip \
+     --dest=tmp/training-data.npz --mean=112.684 --std=69.509
+
+ # Calculate average spectrum for a pre-trained generator.
+ python avg_spectra.py calc \
+     --source=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhq-1024x1024.pkl \
+     --dest=tmp/stylegan3-r.npz --mean=112.684 --std=69.509 --num=70000
+
+ # Display results.
+ python avg_spectra.py heatmap tmp/training-data.npz
+ python avg_spectra.py heatmap tmp/stylegan3-r.npz
+ python avg_spectra.py slices tmp/training-data.npz tmp/stylegan3-r.npz
+ ```
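+
+ The resulting `.npz` files can also be inspected directly. A minimal sketch (the files written by `avg_spectra.py calc` contain a 2D `spectrum` array and the scalar `image_size`, as saved by that script):
+
+ ```python
+ import numpy as np
+
+ data = np.load('tmp/training-data.npz')
+ spectrum = data['spectrum']            # 2D average power spectrum
+ image_size = int(data['image_size'])
+
+ hmap = np.log10(spectrum) * 10         # convert power to dB, as in the built-in 'heatmap' command
+ hmap = np.fft.fftshift(hmap)           # move the DC component to the center
+ print(f'image size: {image_size}, spectrum shape: {spectrum.shape}')
+ print(f'power range: {hmap.min():.1f} dB to {hmap.max():.1f} dB')
+ ```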
+
+ <a href="./docs/avg_spectra_screen0.png"><img alt="Average spectra screenshot" src="./docs/avg_spectra_screen0_half.png"></a>
+
+ ## License
+
+ Copyright &copy; 2021, NVIDIA Corporation & affiliates. All rights reserved.
+
+ This work is made available under the [Nvidia Source Code License](https://github.com/NVlabs/stylegan3/blob/main/LICENSE.txt).
+
+ ## Citation
+
+ ```
+ @inproceedings{Karras2021,
+   author = {Tero Karras and Miika Aittala and Samuli Laine and Erik H\"ark\"onen and Janne Hellsten and Jaakko Lehtinen and Timo Aila},
+   title = {Alias-Free Generative Adversarial Networks},
+   booktitle = {Proc. NeurIPS},
+   year = {2021}
+ }
+ ```
+
+ ## Development
+
+ This is a research reference implementation and is treated as a one-time code drop. As such, we do not accept outside code contributions in the form of pull requests.
+
+ ## Acknowledgements
+
+ We thank David Luebke, Ming-Yu Liu, Koki Nagano, Tuomas Kynk&auml;&auml;nniemi, and Timo Viitanen for reviewing early drafts and helpful suggestions; Fr&eacute;do Durand for early discussions; Tero Kuosmanen for maintaining our compute infrastructure; the AFHQ authors for an updated version of their dataset; and Getty Images for the training images in the Beaches dataset. We did not receive external funding or additional revenues for this project.
stylegan3-fun/__pycache__/legacy.cpython-311.pyc ADDED
Binary file (29 kB).
stylegan3-fun/__pycache__/legacy.cpython-38.pyc ADDED
Binary file (15.4 kB).
stylegan3-fun/__pycache__/legacy.cpython-39.pyc ADDED
Binary file (15.3 kB).
stylegan3-fun/avg_spectra.py ADDED
@@ -0,0 +1,276 @@
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ """Compare average power spectra between real and generated images,
+ or between multiple generators."""
+
+ import os
+ import numpy as np
+ import torch
+ import torch.fft
+ import scipy.ndimage
+ import matplotlib.pyplot as plt
+ import click
+ import tqdm
+ import dnnlib
+
+ import legacy
+ from training import dataset
+
+ #----------------------------------------------------------------------------
+ # Setup an iterator for streaming images, in uint8 NCHW format, based on the
+ # respective command line options.
+
+ def stream_source_images(source, num, seed, device, data_loader_kwargs=None): # => num_images, image_size, image_iter
+     ext = source.split('.')[-1].lower()
+     if data_loader_kwargs is None:
+         data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)
+
+     if ext == 'pkl':
+         if num is None:
+             raise click.ClickException('--num is required when --source points to network pickle')
+         with dnnlib.util.open_url(source) as f:
+             G = legacy.load_network_pkl(f)['G_ema'].to(device)
+         def generate_image(seed):
+             rnd = np.random.RandomState(seed)
+             z = torch.from_numpy(rnd.randn(1, G.z_dim)).to(device)
+             c = torch.zeros([1, G.c_dim], device=device)
+             if G.c_dim > 0:
+                 c[:, rnd.randint(G.c_dim)] = 1
+             return (G(z=z, c=c) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+         _ = generate_image(seed) # warm up
+         image_iter = (generate_image(seed + idx) for idx in range(num))
+         return num, G.img_resolution, image_iter
+
+     elif ext == 'zip' or os.path.isdir(source):
+         dataset_obj = dataset.ImageFolderDataset(path=source, max_size=num, random_seed=seed)
+         if num is not None and num != len(dataset_obj):
+             raise click.ClickException(f'--source contains fewer than {num} images')
+         data_loader = torch.utils.data.DataLoader(dataset_obj, batch_size=1, **data_loader_kwargs)
+         image_iter = (image.to(device) for image, _label in data_loader)
+         return len(dataset_obj), dataset_obj.resolution, image_iter
+
+     else:
+         raise click.ClickException('--source must point to network pickle, dataset zip, or directory')
+
+ #----------------------------------------------------------------------------
+ # Load average power spectrum from the specified .npz file and construct
+ # the corresponding heatmap for visualization.
+
+ def construct_heatmap(npz_file, smooth):
+     npz_data = np.load(npz_file)
+     spectrum = npz_data['spectrum']
+     image_size = npz_data['image_size']
+     hmap = np.log10(spectrum) * 10 # dB
+     hmap = np.fft.fftshift(hmap)
+     hmap = np.concatenate([hmap, hmap[:1, :]], axis=0)
+     hmap = np.concatenate([hmap, hmap[:, :1]], axis=1)
+     if smooth > 0:
+         sigma = spectrum.shape[0] / image_size * smooth
+         hmap = scipy.ndimage.gaussian_filter(hmap, sigma=sigma, mode='nearest')
+     return hmap, image_size
+
+ #----------------------------------------------------------------------------
+
+ @click.group()
+ def main():
+     """Compare average power spectra between real and generated images,
+     or between multiple generators.
+
+     Example:
+
+     \b
+     # Calculate dataset mean and std, needed in subsequent steps.
+     python avg_spectra.py stats --source=~/datasets/ffhq-1024x1024.zip
+
+     \b
+     # Calculate average spectrum for the training data.
+     python avg_spectra.py calc --source=~/datasets/ffhq-1024x1024.zip \\
+         --dest=tmp/training-data.npz --mean=112.684 --std=69.509
+
+     \b
+     # Calculate average spectrum for a pre-trained generator.
+     python avg_spectra.py calc \\
+         --source=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhq-1024x1024.pkl \\
+         --dest=tmp/stylegan3-r.npz --mean=112.684 --std=69.509 --num=70000
+
+     \b
+     # Display results.
+     python avg_spectra.py heatmap tmp/training-data.npz
+     python avg_spectra.py heatmap tmp/stylegan3-r.npz
+     python avg_spectra.py slices tmp/training-data.npz tmp/stylegan3-r.npz
+
+     \b
+     # Save as PNG.
+     python avg_spectra.py heatmap tmp/training-data.npz --save=tmp/training-data.png --dpi=300
+     python avg_spectra.py heatmap tmp/stylegan3-r.npz --save=tmp/stylegan3-r.png --dpi=300
+     python avg_spectra.py slices tmp/training-data.npz tmp/stylegan3-r.npz --save=tmp/slices.png --dpi=300
+     """
+
+ #----------------------------------------------------------------------------
+
+ @main.command()
+ @click.option('--source', help='Network pkl, dataset zip, or directory', metavar='[PKL|ZIP|DIR]', required=True)
+ @click.option('--num', help='Number of images to process [default: all]', metavar='INT', type=click.IntRange(min=1))
+ @click.option('--seed', help='Random seed for selecting the images', metavar='INT', type=click.IntRange(min=0), default=0, show_default=True)
+ def stats(source, num, seed, device=torch.device('cuda')):
+     """Calculate dataset mean and standard deviation needed by 'calc'."""
+     torch.multiprocessing.set_start_method('spawn')
+     num_images, _image_size, image_iter = stream_source_images(source=source, num=num, seed=seed, device=device)
+
+     # Accumulate moments.
+     moments = torch.zeros([3], dtype=torch.float64, device=device)
+     for image in tqdm.tqdm(image_iter, total=num_images):
+         image = image.to(torch.float64)
+         moments += torch.stack([torch.ones_like(image).sum(), image.sum(), image.square().sum()])
+     moments = moments / moments[0]
+
+     # Compute mean and standard deviation.
+     mean = moments[1]
+     std = (moments[2] - moments[1].square()).sqrt()
+     print(f'--mean={mean:g} --std={std:g}')
+
+ #----------------------------------------------------------------------------
+
+ @main.command()
+ @click.option('--source', help='Network pkl, dataset zip, or directory', metavar='[PKL|ZIP|DIR]', required=True)
+ @click.option('--dest', help='Where to store the result', metavar='NPZ', required=True)
+ @click.option('--mean', help='Dataset mean for whitening', metavar='FLOAT', type=float, required=True)
+ @click.option('--std', help='Dataset standard deviation for whitening', metavar='FLOAT', type=click.FloatRange(min=0), required=True)
+ @click.option('--num', help='Number of images to process [default: all]', metavar='INT', type=click.IntRange(min=1))
+ @click.option('--seed', help='Random seed for selecting the images', metavar='INT', type=click.IntRange(min=0), default=0, show_default=True)
+ @click.option('--beta', help='Shape parameter for the Kaiser window', metavar='FLOAT', type=click.FloatRange(min=0), default=8, show_default=True)
+ @click.option('--interp', help='Frequency-domain interpolation factor', metavar='INT', type=click.IntRange(min=1), default=4, show_default=True)
+ def calc(source, dest, mean, std, num, seed, beta, interp, device=torch.device('cuda')):
+     """Calculate average power spectrum and store it in .npz file."""
+     torch.multiprocessing.set_start_method('spawn')
+     num_images, image_size, image_iter = stream_source_images(source=source, num=num, seed=seed, device=device)
+     spectrum_size = image_size * interp
+     padding = spectrum_size - image_size
+
+     # Setup window function.
+     window = torch.kaiser_window(image_size, periodic=False, beta=beta, device=device)
+     window *= window.square().sum().rsqrt()
+     window = window.ger(window).unsqueeze(0).unsqueeze(1)
+
+     # Accumulate power spectrum.
+     spectrum = torch.zeros([spectrum_size, spectrum_size], dtype=torch.float64, device=device)
+     for image in tqdm.tqdm(image_iter, total=num_images):
+         image = (image.to(torch.float64) - mean) / std
+         image = torch.nn.functional.pad(image * window, [0, padding, 0, padding])
+         spectrum += torch.fft.fftn(image, dim=[2,3]).abs().square().mean(dim=[0,1])
+     spectrum /= num_images
+
+     # Save result.
+     if os.path.dirname(dest):
+         os.makedirs(os.path.dirname(dest), exist_ok=True)
+     np.savez(dest, spectrum=spectrum.cpu().numpy(), image_size=image_size)
+
+ #----------------------------------------------------------------------------
+
+ @main.command()
+ @click.argument('npz-file', nargs=1)
+ @click.option('--save', help='Save the plot and exit', metavar='[PNG|PDF|...]')
+ @click.option('--dpi', help='Figure resolution', metavar='FLOAT', type=click.FloatRange(min=1), default=100, show_default=True)
+ @click.option('--smooth', help='Amount of smoothing', metavar='FLOAT', type=click.FloatRange(min=0), default=1.25, show_default=True)
+ def heatmap(npz_file, save, smooth, dpi):
+     """Visualize 2D heatmap based on the given .npz file."""
+     hmap, image_size = construct_heatmap(npz_file=npz_file, smooth=smooth)
+
+     # Setup plot.
+     plt.figure(figsize=[6, 4.8], dpi=dpi, tight_layout=True)
+     freqs = np.linspace(-0.5, 0.5, num=hmap.shape[0], endpoint=True) * image_size
+     ticks = np.linspace(freqs[0], freqs[-1], num=5, endpoint=True)
+     levels = np.linspace(-40, 20, num=13, endpoint=True)
+
+     # Draw heatmap.
+     plt.xlim(ticks[0], ticks[-1])
+     plt.ylim(ticks[0], ticks[-1])
+     plt.xticks(ticks)
+     plt.yticks(ticks)
+     plt.contourf(freqs, freqs, hmap, levels=levels, extend='both', cmap='Blues')
+     plt.gca().set_aspect('equal')
+     plt.colorbar(ticks=levels)
+     plt.contour(freqs, freqs, hmap, levels=levels, extend='both', linestyles='solid', linewidths=1, colors='midnightblue', alpha=0.2)
+
+     # Display or save.
+     if save is None:
+         plt.show()
+     else:
+         if os.path.dirname(save):
+             os.makedirs(os.path.dirname(save), exist_ok=True)
+         plt.savefig(save)
+
+ #----------------------------------------------------------------------------
+
+ @main.command()
+ @click.argument('npz-files', nargs=-1, required=True)
+ @click.option('--save', help='Save the plot and exit', metavar='[PNG|PDF|...]')
+ @click.option('--dpi', help='Figure resolution', metavar='FLOAT', type=click.FloatRange(min=1), default=100, show_default=True)
+ @click.option('--smooth', help='Amount of smoothing', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True)
+ def slices(npz_files, save, dpi, smooth):
+     """Visualize 1D slices based on the given .npz files."""
+     cases = [dnnlib.EasyDict(npz_file=npz_file) for npz_file in npz_files]
+     for c in cases:
+         c.hmap, c.image_size = construct_heatmap(npz_file=c.npz_file, smooth=smooth)
+         c.label = os.path.splitext(os.path.basename(c.npz_file))[0]
+
+     # Check consistency.
+     image_size = cases[0].image_size
+     hmap_size = cases[0].hmap.shape[0]
+     if any(c.image_size != image_size or c.hmap.shape[0] != hmap_size for c in cases):
+         raise click.ClickException('All .npz must have the same resolution')
+
+     # Setup plot.
+     plt.figure(figsize=[12, 4.6], dpi=dpi, tight_layout=True)
+     hmap_center = hmap_size // 2
+     hmap_range = np.arange(hmap_center, hmap_size)
+     freqs0 = np.linspace(0, image_size / 2, num=(hmap_size // 2 + 1), endpoint=True)
+     freqs45 = np.linspace(0, image_size / np.sqrt(2), num=(hmap_size // 2 + 1), endpoint=True)
+     xticks0 = np.linspace(freqs0[0], freqs0[-1], num=9, endpoint=True)
+     xticks45 = np.round(np.linspace(freqs45[0], freqs45[-1], num=9, endpoint=True))
+     yticks = np.linspace(-50, 30, num=9, endpoint=True)
+
+     # Draw 0 degree slice.
+     plt.subplot(1, 2, 1)
+     plt.title('0\u00b0 slice')
+     plt.xlim(xticks0[0], xticks0[-1])
+     plt.ylim(yticks[0], yticks[-1])
+     plt.xticks(xticks0)
+     plt.yticks(yticks)
+     for c in cases:
+         plt.plot(freqs0, c.hmap[hmap_center, hmap_range], label=c.label)
+     plt.grid()
+     plt.legend(loc='upper right')
+
+     # Draw 45 degree slice.
+     plt.subplot(1, 2, 2)
+     plt.title('45\u00b0 slice')
+     plt.xlim(xticks45[0], xticks45[-1])
+     plt.ylim(yticks[0], yticks[-1])
+     plt.xticks(xticks45)
+     plt.yticks(yticks)
+     for c in cases:
+         plt.plot(freqs45, c.hmap[hmap_range, hmap_range], label=c.label)
+     plt.grid()
+     plt.legend(loc='upper right')
+
+     # Display or save.
+     if save is None:
+         plt.show()
+     else:
+         if os.path.dirname(save):
+             os.makedirs(os.path.dirname(save), exist_ok=True)
+         plt.savefig(save)
+
+ #----------------------------------------------------------------------------
+
+ if __name__ == "__main__":
+     main() # pylint: disable=no-value-for-parameter
+
+ #----------------------------------------------------------------------------
stylegan3-fun/calc_metrics.py ADDED
@@ -0,0 +1,188 @@
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ """Calculate quality metrics for previous training run or pretrained network pickle."""
+
+ import os
+ import click
+ import json
+ import tempfile
+ import copy
+ import torch
+
+ import dnnlib
+ import legacy
+ from metrics import metric_main
+ from metrics import metric_utils
+ from torch_utils import training_stats
+ from torch_utils import custom_ops
+ from torch_utils import misc
+ from torch_utils.ops import conv2d_gradfix
+
+ #----------------------------------------------------------------------------
+
+ def subprocess_fn(rank, args, temp_dir):
+     dnnlib.util.Logger(should_flush=True)
+
+     # Init torch.distributed.
+     if args.num_gpus > 1:
+         init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
+         if os.name == 'nt':
+             init_method = 'file:///' + init_file.replace('\\', '/')
+             torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus)
+         else:
+             init_method = f'file://{init_file}'
+             torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus)
+
+     # Init torch_utils.
+     sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None
+     training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
+     if rank != 0 or not args.verbose:
+         custom_ops.verbosity = 'none'
+
+     # Configure torch.
+     device = torch.device('cuda', rank)
+     torch.backends.cuda.matmul.allow_tf32 = False
+     torch.backends.cudnn.allow_tf32 = False
+     conv2d_gradfix.enabled = True
+
+     # Print network summary.
+     G = copy.deepcopy(args.G).eval().requires_grad_(False).to(device)
+     if rank == 0 and args.verbose:
+         z = torch.empty([1, G.z_dim], device=device)
+         c = torch.empty([1, G.c_dim], device=device)
+         misc.print_module_summary(G, [z, c])
+
+     # Calculate each metric.
+     for metric in args.metrics:
+         if rank == 0 and args.verbose:
+             print(f'Calculating {metric}...')
+         progress = metric_utils.ProgressMonitor(verbose=args.verbose)
+         result_dict = metric_main.calc_metric(metric=metric, G=G, dataset_kwargs=args.dataset_kwargs,
+             num_gpus=args.num_gpus, rank=rank, device=device, progress=progress)
+         if rank == 0:
+             metric_main.report_metric(result_dict, run_dir=args.run_dir, snapshot_pkl=args.network_pkl)
+         if rank == 0 and args.verbose:
+             print()
+
+     # Done.
+     if rank == 0 and args.verbose:
+         print('Exiting...')
+
+ #----------------------------------------------------------------------------
+
+ def parse_comma_separated_list(s):
+     if isinstance(s, list):
+         return s
+     if s is None or s.lower() == 'none' or s == '':
+         return []
+     return s.split(',')
+
+ #----------------------------------------------------------------------------
+
+ @click.command()
+ @click.pass_context
+ @click.option('network_pkl', '--network', help='Network pickle filename or URL', metavar='PATH', required=True)
+ @click.option('--metrics', help='Quality metrics', metavar='[NAME|A,B,C|none]', type=parse_comma_separated_list, default='fid50k_full', show_default=True)
+ @click.option('--data', help='Dataset to evaluate against [default: look up]', metavar='[ZIP|DIR]')
+ @click.option('--mirror', help='Enable dataset x-flips [default: look up]', type=bool, metavar='BOOL')
+ @click.option('--gpus', help='Number of GPUs to use', type=int, default=1, metavar='INT', show_default=True)
+ @click.option('--verbose', help='Print optional information', type=bool, default=True, metavar='BOOL', show_default=True)
+
+ def calc_metrics(ctx, network_pkl, metrics, data, mirror, gpus, verbose):
+     """Calculate quality metrics for previous training run or pretrained network pickle.
+
+     Examples:
+
+     \b
+     # Previous training run: look up options automatically, save result to JSONL file.
+     python calc_metrics.py --metrics=eqt50k_int,eqr50k \\
+         --network=~/training-runs/00000-stylegan3-r-mydataset/network-snapshot-000000.pkl
+
+     \b
+     # Pre-trained network pickle: specify dataset explicitly, print result to stdout.
+     python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq-1024x1024.zip --mirror=1 \\
+         --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhq-1024x1024.pkl
+
+     \b
+     Recommended metrics:
+       fid50k_full  Frechet inception distance against the full dataset.
+       kid50k_full  Kernel inception distance against the full dataset.
+       pr50k3_full  Precision and recall against the full dataset.
+       ppl2_wend    Perceptual path length in W, endpoints, full image.
+       eqt50k_int   Equivariance w.r.t. integer translation (EQ-T).
+       eqt50k_frac  Equivariance w.r.t. fractional translation (EQ-T_frac).
+       eqr50k       Equivariance w.r.t. rotation (EQ-R).
+
+     \b
+     Legacy metrics:
+       fid50k       Frechet inception distance against 50k real images.
+       kid50k       Kernel inception distance against 50k real images.
+       pr50k3       Precision and recall against 50k real images.
+       is50k        Inception score for CIFAR-10.
+     """
+     dnnlib.util.Logger(should_flush=True)
+
+     # Validate arguments.
+     args = dnnlib.EasyDict(metrics=metrics, num_gpus=gpus, network_pkl=network_pkl, verbose=verbose)
+     if not all(metric_main.is_valid_metric(metric) for metric in args.metrics):
+         ctx.fail('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
+     if not args.num_gpus >= 1:
+         ctx.fail('--gpus must be at least 1')
+
+     # Load network.
+     if not dnnlib.util.is_url(network_pkl, allow_file_urls=True) and not os.path.isfile(network_pkl):
+         ctx.fail('--network must point to a file or URL')
+     if args.verbose:
+         print(f'Loading network from "{network_pkl}"...')
+     with dnnlib.util.open_url(network_pkl, verbose=args.verbose) as f:
+         network_dict = legacy.load_network_pkl(f)
+         args.G = network_dict['G_ema'] # subclass of torch.nn.Module
+
+     # Initialize dataset options.
+     if data is not None:
+         args.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data)
+     elif network_dict['training_set_kwargs'] is not None:
+         args.dataset_kwargs = dnnlib.EasyDict(network_dict['training_set_kwargs'])
+     else:
+         ctx.fail('Could not look up dataset options; please specify --data')
+
+     # Finalize dataset options.
+     args.dataset_kwargs.resolution = args.G.img_resolution
+     args.dataset_kwargs.use_labels = (args.G.c_dim != 0)
+     if mirror is not None:
+         args.dataset_kwargs.xflip = mirror
+
+     # Print dataset options.
+     if args.verbose:
+         print('Dataset options:')
+         print(json.dumps(args.dataset_kwargs, indent=2))
+
+     # Locate run dir.
+     args.run_dir = None
+     if os.path.isfile(network_pkl):
+         pkl_dir = os.path.dirname(network_pkl)
+         if os.path.isfile(os.path.join(pkl_dir, 'training_options.json')):
+             args.run_dir = pkl_dir
+
+     # Launch processes.
+     if args.verbose:
+         print('Launching processes...')
+     torch.multiprocessing.set_start_method('spawn')
+     with tempfile.TemporaryDirectory() as temp_dir:
+         if args.num_gpus == 1:
+             subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
+         else:
+             torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus)
+
+ #----------------------------------------------------------------------------
+
+ if __name__ == "__main__":
+     calc_metrics() # pylint: disable=no-value-for-parameter
+
+ #----------------------------------------------------------------------------
stylegan3-fun/dataset_tool.py ADDED
@@ -0,0 +1,547 @@
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ """Tool for creating ZIP/PNG based datasets."""
+
+ import functools
+ import gzip
+ import io
+ import json
+ import os
+ import pickle
+ import re
+ import sys
+ import tarfile
+ import zipfile
+ from pathlib import Path
+ from typing import Callable, Optional, Tuple, Union
+
+ import click
+ import numpy as np
+ import PIL.Image
+ from tqdm import tqdm
+ from torch_utils import gen_utils
+
+
+ # ----------------------------------------------------------------------------
+
+
+ def error(msg):
+     print('Error: ' + msg)
+     sys.exit(1)
+
+
+ # ----------------------------------------------------------------------------
+
+
+ def parse_tuple(s: str) -> Tuple[int, int]:
+     """Parse a 'M,N' or 'MxN' integer tuple.
+
+     Example:
+         '4x2' returns (4,2)
+         '0,1' returns (0,1)
+     """
+     m = re.match(r'^(\d+)[x,](\d+)$', s)
+     if m:
+         return int(m.group(1)), int(m.group(2))
+     raise ValueError(f'cannot parse tuple {s}')
+
+
+ # ----------------------------------------------------------------------------
+
+
+ def maybe_min(a: int, b: Optional[int]) -> int:
+     if b is not None:
+         return min(a, b)
+     return a
+
+
+ # ----------------------------------------------------------------------------
+
+
+ def file_ext(name: Union[str, Path]) -> str:
+     return str(name).split('.')[-1]
+
+
+ # ----------------------------------------------------------------------------
+
+
+ def is_image_ext(fname: Union[str, Path]) -> bool:
+     ext = file_ext(fname).lower()
+     return f'.{ext}' in PIL.Image.EXTENSION # type: ignore
+
+
+ # ----------------------------------------------------------------------------
+
+
+ def open_image_folder(source_dir, force_channels: int = None, *, max_images: Optional[int], subfolders_as_labels: Optional[bool] = False):
+     input_images = [str(f) for f in sorted(Path(source_dir).rglob('*')) if is_image_ext(f) and os.path.isfile(f)]
+
+     # Load labels.
+     labels = {}
+     meta_fname = os.path.join(source_dir, 'dataset.json')
+     if os.path.isfile(meta_fname) and not subfolders_as_labels:
+         # The `dataset.json` file exists and will be used (whether empty or not)
+         with open(meta_fname, 'r') as file:
+             labels = json.load(file)['labels']
+             if labels is not None:
+                 labels = {x[0]: x[1] for x in labels}
+             else:
+                 labels = {}
+     elif subfolders_as_labels:
+         # Use the folders in the directory as the labels themselves
+         # Get the subfolder names from the input_images names
+         labels = {os.path.relpath(fname, source_dir).replace('\\', '/'): os.path.basename(os.path.dirname(fname)) for fname in input_images}
+         # Change folder name (value) to a number (from 0 to n-1)
+         label_names = list(set(labels.values()))
+         label_names.sort()
+         labels = {fname: label_names.index(label) for fname, label in labels.items()}
+         print(f'Conditional dataset has {len(label_names)} labels! Saving to `class_labels.txt` in the source directory...')
+         with open(os.path.join(source_dir, 'class_labels.txt'), 'w') as f:
+             # Write, one per line, the index and the label name
+             for i, label in enumerate(label_names):
+                 f.write(f'{i}: {label}\n')
+
+     max_idx = maybe_min(len(input_images), max_images)
+
+     def iterate_images():
+         for idx, fname in enumerate(input_images):
+             arch_fname = os.path.relpath(fname, source_dir)
+             arch_fname = arch_fname.replace('\\', '/')
+             # Adding Pull #39 from Andreas Jansson: https://github.com/NVlabs/stylegan3/pull/39
+             try:
+                 img = PIL.Image.open(fname) # Let PIL handle the mode
+                 # Convert grayscale image to RGB
+                 if img.mode == 'L':
+                     img = img.convert('RGB')
+                 # Force the number of channels if so requested
+                 if force_channels is not None:
+                     img = img.convert(gen_utils.channels_dict[int(force_channels)])
+                 img = np.array(img)
+             except Exception as e:
+                 sys.stderr.write(f'Failed to read {fname}: {e}\n')
+                 continue
+             yield dict(img=img, label=labels.get(arch_fname))
+             if idx >= max_idx-1:
+                 break
+     return max_idx, iterate_images()
+
+
+ # ----------------------------------------------------------------------------
+
+
+ def open_image_zip(source, force_channels: int = None, *, max_images: Optional[int]):
+     with zipfile.ZipFile(source, mode='r') as z:
+         input_images = [str(f) for f in sorted(z.namelist()) if is_image_ext(f)]
+
+         # Load labels.
+         labels = {}
+         if 'dataset.json' in z.namelist():
+             with z.open('dataset.json', 'r') as file:
+                 labels = json.load(file)['labels']
+                 if labels is not None:
+                     labels = {x[0]: x[1] for x in labels}
+                 else:
+                     labels = {}
+
+     max_idx = maybe_min(len(input_images), max_images)
+
+     def iterate_images():
+         with zipfile.ZipFile(source, mode='r') as z:
+             for idx, fname in enumerate(input_images):
+                 with z.open(fname, 'r') as file:
+                     # Same as above: PR #39 by Andreas Jansson and turn Grayscale to RGB
+                     try:
+                         img = PIL.Image.open(file) # type: ignore
+                         if img.mode == 'L':
+                             img = img.convert('RGB')
+                         # Force the number of channels if so requested
+                         if force_channels is not None:
+                             img = img.convert(gen_utils.channels_dict[int(force_channels)])
+                         img = np.array(img)
+                     except Exception as e:
+                         sys.stderr.write(f'Failed to read {fname}: {e}\n')
+                         continue
+                 yield dict(img=img, label=labels.get(fname))
+                 if idx >= max_idx-1:
+                     break
+     return max_idx, iterate_images()
+
+
+ # ----------------------------------------------------------------------------
+
+
+ def open_lmdb(lmdb_dir: str, *, max_images: Optional[int]):
+     import cv2 # pip install opencv-python # pylint: disable=import-error
+     import lmdb # pip install lmdb # pylint: disable=import-error
+
+     with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn:
+         max_idx = maybe_min(txn.stat()['entries'], max_images)
+
+     def iterate_images():
+         with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn:
+             for idx, (_key, value) in enumerate(txn.cursor()):
+                 try:
+                     try:
+                         img = cv2.imdecode(np.frombuffer(value, dtype=np.uint8), 1)
+                         if img is None:
+                             raise IOError('cv2.imdecode failed')
+                         img = img[:, :, ::-1] # BGR => RGB
+                     except IOError:
+                         img = np.array(PIL.Image.open(io.BytesIO(value)))
+                     yield dict(img=img, label=None)
+                     if idx >= max_idx-1:
+                         break
+                 except Exception:
+                     print(sys.exc_info()[1])
+
+     return max_idx, iterate_images()
+
+
+ # ----------------------------------------------------------------------------
+
+
+ def open_cifar10(tarball: str, *, max_images: Optional[int]):
+     images = []
+     labels = []
+
+     with tarfile.open(tarball, 'r:gz') as tar:
+         for batch in range(1, 6):
+             member = tar.getmember(f'cifar-10-batches-py/data_batch_{batch}')
+             with tar.extractfile(member) as file:
+                 data = pickle.load(file, encoding='latin1')
+             images.append(data['data'].reshape(-1, 3, 32, 32))
+             labels.append(data['labels'])
+
+     images = np.concatenate(images)
+     labels = np.concatenate(labels)
+     images = images.transpose([0, 2, 3, 1]) # NCHW -> NHWC
+     assert images.shape == (50000, 32, 32, 3) and images.dtype == np.uint8
+     assert labels.shape == (50000,) and labels.dtype in [np.int32, np.int64]
+     assert np.min(images) == 0 and np.max(images) == 255
+     assert np.min(labels) == 0 and np.max(labels) == 9
+
+     max_idx = maybe_min(len(images), max_images)
+
+     def iterate_images():
+         for idx, img in enumerate(images):
+             yield dict(img=img, label=int(labels[idx]))
+             if idx >= max_idx-1:
+                 break
+
+     return max_idx, iterate_images()
+
+
+ # ----------------------------------------------------------------------------
+
+
+ def open_mnist(images_gz: str, *, max_images: Optional[int]):
+     labels_gz = images_gz.replace('-images-idx3-ubyte.gz', '-labels-idx1-ubyte.gz')
+     assert labels_gz != images_gz
+     images = []
+     labels = []
+
+     with gzip.open(images_gz, 'rb') as f:
+         images = np.frombuffer(f.read(), np.uint8, offset=16)
+     with gzip.open(labels_gz, 'rb') as f:
+         labels = np.frombuffer(f.read(), np.uint8, offset=8)
+
+     images = images.reshape(-1, 28, 28)
+     images = np.pad(images, [(0, 0), (2, 2), (2, 2)], 'constant', constant_values=0)
+     assert images.shape == (60000, 32, 32) and images.dtype == np.uint8
+     assert labels.shape == (60000,) and labels.dtype == np.uint8
+     assert np.min(images) == 0 and np.max(images) == 255
+     assert np.min(labels) == 0 and np.max(labels) == 9
+
+     max_idx = maybe_min(len(images), max_images)
+
+     def iterate_images():
+         for idx, img in enumerate(images):
+             yield dict(img=img, label=int(labels[idx]))
+             if idx >= max_idx-1:
+                 break
+
+     return max_idx, iterate_images()
+
+
+ # ----------------------------------------------------------------------------
+
+
+ def make_transform(
+     transform: Optional[str],
+     output_width: Optional[int],
+     output_height: Optional[int]
+ ) -> Callable[[np.ndarray], Optional[np.ndarray]]:
+     def scale(width, height, img):
+         w = img.shape[1]
+         h = img.shape[0]
+         if width == w and height == h:
+             return img
+         img = PIL.Image.fromarray(img)
+         ww = width if width is not None else w
+         hh = height if height is not None else h
+         img = img.resize((ww, hh), PIL.Image.LANCZOS)
+         return np.array(img)
+
+     def center_crop(width, height, img):
+         crop = np.min(img.shape[:2])
+         img = img[(img.shape[0] - crop) // 2: (img.shape[0] + crop) // 2,
+                   (img.shape[1] - crop) // 2: (img.shape[1] + crop) // 2]
+         img = PIL.Image.fromarray(img, gen_utils.channels_dict[img.shape[2]])
+         img = img.resize((width, height), PIL.Image.LANCZOS)
+         return np.array(img)
+
+     def center_crop_wide(width, height, img):
+         ch = int(np.round(width * img.shape[0] / img.shape[1]))
+         if img.shape[1] < width or ch < height:
+             return None
+
+         img = img[(img.shape[0] - ch) // 2: (img.shape[0] + ch) // 2]
+         img = PIL.Image.fromarray(img, gen_utils.channels_dict[img.shape[2]])
+         img = img.resize((width, height), PIL.Image.LANCZOS)
+         img = np.array(img)
+
+         canvas = np.zeros([width, width, 3], dtype=np.uint8)
+         canvas[(width - height) // 2 : (width + height) // 2, :] = img
+         return canvas
+
+     def center_crop_tall(width, height, img):
+         ch = int(np.round(height * img.shape[1] / img.shape[0]))
+         if img.shape[0] < height or ch < width:
+             return None
+
+         img = img[:, (img.shape[1] - ch) // 2: (img.shape[1] + ch) // 2] # center-crop: [width0, height0, C] -> [width0, height, C]
+         img = PIL.Image.fromarray(img, gen_utils.channels_dict[img.shape[2]])
+         img = img.resize((width, height), PIL.Image.LANCZOS) # resize: [width0, height, 3] -> [width, height, 3]
+         img = np.array(img)
+
+         canvas = np.zeros([height, height, 3], dtype=np.uint8) # square canvas
+         canvas[:, (height - width) // 2: (height + width) // 2] = img # replace the middle with img
+         return canvas
+
+     if transform is None:
+         return functools.partial(scale, output_width, output_height)
+     if transform == 'center-crop':
+         if (output_width is None) or (output_height is None):
+             error(f'must specify --resolution=WxH when using {transform} transform')
+         return functools.partial(center_crop, output_width, output_height)
+     if transform == 'center-crop-wide':
+         if (output_width is None) or (output_height is None):
+             error(f'must specify --resolution=WxH when using {transform} transform')
+         return functools.partial(center_crop_wide, output_width, output_height)
+     if transform == 'center-crop-tall':
+         if (output_width is None) or (output_height is None):
+             error(f'must specify --resolution=WxH when using {transform} transform')
+         return functools.partial(center_crop_tall, output_width, output_height)
+     assert False, 'unknown transform'
+
+
+ # ----------------------------------------------------------------------------
+
+
+ def open_dataset(source, force_channels, *, max_images: Optional[int], subfolders_as_labels: Optional[bool] = False):
+     if os.path.isdir(source):
+         if source.rstrip('/').endswith('_lmdb'):
+             return open_lmdb(source, max_images=max_images)
+         else:
+             return open_image_folder(source, force_channels, max_images=max_images, subfolders_as_labels=subfolders_as_labels)
+     elif os.path.isfile(source):
+         if os.path.basename(source) == 'cifar-10-python.tar.gz':
+             return open_cifar10(source, max_images=max_images)
+         elif os.path.basename(source) == 'train-images-idx3-ubyte.gz':
+             return open_mnist(source, max_images=max_images)
+         elif file_ext(source) == 'zip':
+             return open_image_zip(source, force_channels, max_images=max_images)
+         else:
+             assert False, 'unknown archive type'
+     else:
+         error(f'Missing input file or directory: {source}')
+
+
+ # ----------------------------------------------------------------------------
+
+
+ def open_dest(dest: str) -> Tuple[str, Callable[[str, Union[bytes, str]], None], Callable[[], None]]:
+     dest_ext = file_ext(dest)
+
+     if dest_ext == 'zip':
+         if os.path.dirname(dest) != '':
+             os.makedirs(os.path.dirname(dest), exist_ok=True)
+         zf = zipfile.ZipFile(file=dest, mode='w', compression=zipfile.ZIP_STORED)
+         def zip_write_bytes(fname: str, data: Union[bytes, str]):
+             zf.writestr(fname, data)
+         return '', zip_write_bytes, zf.close
+     else:
+         # If the output folder already exists, check that it is
+         # empty.
+         #
+         # Note: creating the output directory is not strictly
+         # necessary as folder_write_bytes() also mkdirs, but it's better
+         # to give an error message earlier in case the dest folder
+         # somehow cannot be created.
+         if os.path.isdir(dest) and len(os.listdir(dest)) != 0:
+             error('--dest folder must be empty')
+         os.makedirs(dest, exist_ok=True)
+
+         def folder_write_bytes(fname: str, data: Union[bytes, str]):
+             os.makedirs(os.path.dirname(fname), exist_ok=True)
+             with open(fname, 'wb') as fout:
+                 if isinstance(data, str):
+                     data = data.encode('utf8')
+                 fout.write(data)
+         return dest, folder_write_bytes, lambda: None
+
+
+ # ----------------------------------------------------------------------------
+
+
+ @click.command()
+ @click.pass_context
+ @click.option('--source', help='Directory or archive name for input dataset', required=True, metavar='PATH')
+ @click.option('--dest', help='Output directory or archive name for output dataset', required=True, metavar='PATH')
+ @click.option('--max-images', help='Output only up to `max-images` images', type=int, default=None)
+ @click.option('--force-channels', help='Force the number of channels in the image (1: grayscale, 3: RGB, 4: RGBA)', type=click.Choice(['1', '3', '4']), default=None)
+ @click.option('--subfolders-as-labels', help='Use the folder names as the labels, to avoid setting up `dataset.json`', is_flag=True)
+ @click.option('--transform', help='Input crop/resize mode', type=click.Choice(['center-crop', 'center-crop-wide', 'center-crop-tall']))
+ @click.option('--resolution', help='Output resolution (e.g., \'512x512\')', metavar='WxH', type=parse_tuple)
+ def convert_dataset(
+     ctx: click.Context,
+     source: str,
+     dest: str,
+     max_images: Optional[int],
+     force_channels: Optional[int],
+     subfolders_as_labels: Optional[bool],
+     transform: Optional[str],
+     resolution: Optional[Tuple[int, int]]
+ ):
+     """Convert an image dataset into a dataset archive usable with StyleGAN2 ADA PyTorch.
+
+     The input dataset format is guessed from the --source argument:
+
+     \b
+     --source *_lmdb/                    Load LSUN dataset
+     --source cifar-10-python.tar.gz     Load CIFAR-10 dataset
+     --source train-images-idx3-ubyte.gz Load MNIST dataset
+     --source path/                      Recursively load all images from path/
+     --source dataset.zip                Recursively load all images from dataset.zip
+
+     Specifying the output format and path:
+
+     \b
+     --dest /path/to/dir                 Save output files under /path/to/dir
+     --dest /path/to/dataset.zip         Save output files into /path/to/dataset.zip
+
+     The output dataset format can be either an image folder or an uncompressed zip archive.
+     Zip archives make it easier to move datasets around file servers and clusters, and may
+     offer better training performance on network file systems.
+
+     Images within the dataset archive will be stored as uncompressed PNG.
+     Uncompressed PNGs can be efficiently decoded in the training loop.
+
+     Class labels are stored in a file called 'dataset.json' that is stored at the
+     dataset root folder. This file has the following structure:
+
+     \b
+     {
+         "labels": [
+             ["00000/img00000000.png",6],
+             ["00000/img00000001.png",9],
+             ... repeated for every image in the dataset
+             ["00049/img00049999.png",1]
+         ]
+     }
+
+     If the 'dataset.json' file cannot be found, the dataset is interpreted as
+     not containing class labels.
+
+     Image scale/crop and resolution requirements:
+
+     Output images must be square-shaped and they must all have the same power-of-two
+     dimensions.
+
+     To scale arbitrary input image size to a specific width and height, use the
+     --resolution option. Output resolution will be either the original
+     input resolution (if resolution was not specified) or the one specified with
+     --resolution option.
+
+     Use the --transform=center-crop or --transform=center-crop-wide options to apply a
+     center crop transform on the input image. These options should be used with the
+     --resolution option. For example:
+
+     \b
+     python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \\
+         --transform=center-crop-wide --resolution=512x384
+     """
+
+     PIL.Image.init() # type: ignore
+
+     if dest == '':
+         ctx.fail('--dest output filename or directory must not be an empty string')
+
+     num_files, input_iter = open_dataset(source, force_channels, max_images=max_images, subfolders_as_labels=subfolders_as_labels)
+     archive_root_dir, save_bytes, close_dest = open_dest(dest)
+
+     if resolution is None: resolution = (None, None)
+     transform_image = make_transform(transform, *resolution)
+
+     dataset_attrs = None
+
+     labels = []
+     for idx, image in tqdm(enumerate(input_iter), total=num_files):
+         idx_str = f'{idx:08d}'
+         archive_fname = f'{idx_str[:5]}/img{idx_str}.png'
+
+         # Apply crop and resize.
+         img = transform_image(image['img'])
+
+         # Transform may drop images.
+         if img is None:
+             continue
+
+         # Error check to require uniform image attributes across
+         # the whole dataset.
+         channels = img.shape[2] if img.ndim == 3 else 1
+         cur_image_attrs = {
+             'width': img.shape[1],
+             'height': img.shape[0],
+             'channels': channels
+         }
+         if dataset_attrs is None:
+             dataset_attrs = cur_image_attrs
+             width = dataset_attrs['width']
+             height = dataset_attrs['height']
+             if width != height:
+                 error(f'Image dimensions after scale and crop are required to be square. Got {width}x{height}')
+             if dataset_attrs['channels'] not in [1, 3, 4]:
+                 error('Input images must be stored as grayscale, RGB, or RGBA')
+             if width != 2 ** int(np.floor(np.log2(width))):
+                 error('Image width/height after scale and crop are required to be power-of-two')
+         elif dataset_attrs != cur_image_attrs:
+             err = [f'  dataset {k}/cur image {k}: {dataset_attrs[k]}/{cur_image_attrs[k]}' for k in dataset_attrs.keys()] # pylint: disable=unsubscriptable-object
+             error(f'Image {archive_fname} attributes must be equal across all images of the dataset. Got:\n' + '\n'.join(err))
+
+         # Save the image as an uncompressed PNG.
+         img = PIL.Image.fromarray(img, gen_utils.channels_dict[channels])
+         image_bits = io.BytesIO()
+         img.save(image_bits, format='png', compress_level=0, optimize=False)
+         save_bytes(os.path.join(archive_root_dir, archive_fname), image_bits.getbuffer())
+         labels.append([archive_fname, image['label']] if image['label'] is not None else None)
+
+     metadata = {
+         'labels': labels if all(x is not None for x in labels) else None
+     }
+     save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata))
+     close_dest()
+
+ # ----------------------------------------------------------------------------
+
+
+ if __name__ == "__main__":
+     convert_dataset() # pylint: disable=no-value-for-parameter
+
+ # ----------------------------------------------------------------------------
stylegan3-fun/discriminator_synthesis.py ADDED
@@ -0,0 +1,1007 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.autograd import Variable
5
+ from torchvision import transforms
6
+
7
+ import PIL
8
+ from PIL import Image
9
+
10
+ try:
11
+ import ffmpeg
12
+ except ImportError:
13
+ raise ImportError('ffmpeg-python not found! Install it via "pip install ffmpeg-python"')
14
+
15
+ import scipy.ndimage as nd
16
+ import numpy as np
17
+
18
+ import os
19
+ import click
20
+ from typing import Union, Tuple, Optional, List, Type
21
+ from tqdm import tqdm
22
+ import re
23
+
24
+ from torch_utils import gen_utils
25
+ from network_features import DiscriminatorFeatures
26
+
27
+
28
+ # ----------------------------------------------------------------------------
29
+
30
+
31
+ @click.group()
32
+ def main():
33
+ pass
34
+
35
+
36
+ # ----------------------------------------------------------------------------
37
+
38
+
39
+ def get_available_layers(max_resolution: int) -> List[str]:
40
+ """Helper function to get the available layers given a max resolution (first block in the Discriminator)"""
41
+ max_res_log2 = int(np.log2(max_resolution))
42
+ block_resolutions = [2**i for i in range(max_res_log2, 2, -1)]
43
+
44
+ available_layers = ['from_rgb']
45
+ for block_res in block_resolutions:
46
+ # We don't add the skip layer, as it's the same as conv1 (due to in-place addition; could be changed)
47
+ available_layers.extend([f'b{block_res}_conv0', f'b{block_res}_conv1'])
48
+ # We also skip 'b4_mbstd', as it doesn't add any new information compared to b8_conv1
49
+ available_layers.extend(['b4_conv', 'fc', 'out'])
50
+ return available_layers
51
+
52
+
53
+ # ----------------------------------------------------------------------------
54
+ # DeepDream code; modified from Erik Linder-Norén's repository: https://github.com/eriklindernoren/PyTorch-Deep-Dream
55
+
56
+ def get_image(seed: int = 0,
57
+ image_noise: str = 'random',
58
+ starting_image: Union[str, os.PathLike] = None,
59
+ image_size: int = 1024,
60
+ convert_to_grayscale: bool = False,
61
+ device: torch.device = torch.device('cpu')) -> Tuple[PIL.Image.Image, str]:
62
+ """Set the random seed (NumPy + PyTorch), as well as get an image from a path or generate a random one with the seed"""
63
+ torch.manual_seed(seed)
64
+ rnd = np.random.RandomState(seed)
65
+
66
+ # Load image or generate a random one if none is provided
67
+ if starting_image is not None:
68
+ image = Image.open(starting_image).convert('RGB').resize((image_size, image_size), Image.LANCZOS)
69
+ else:
70
+ if image_noise == 'random':
71
+ starting_image = f'random_image-seed_{seed:08d}.jpg'
72
+ image = Image.fromarray(rnd.randint(0, 255, (image_size, image_size, 3), dtype='uint8'))
73
+ elif image_noise == 'perlin':
74
+ try:
75
+ # Graciously using Mathieu Duchesneau's implementation: https://github.com/duchesneaumathieu/pyperlin
76
+ from pyperlin import FractalPerlin2D
77
+ starting_image = f'perlin_image-seed_{seed:08d}.jpg'
78
+ shape = (3, image_size, image_size)
79
+ resolutions = [(2**i, 2**i) for i in range(1, 6+1)] # for lacunarity = 2.0 # TODO: set as cli variable
80
+ factors = [0.5**i for i in range(6)] # for persistence = 0.5 TODO: set as cli variables
81
+ g_cuda = torch.Generator(device=device).manual_seed(seed)
82
+ rgb = FractalPerlin2D(shape, resolutions, factors, generator=g_cuda)().cpu().numpy()
83
+ rgb = (255 * (rgb + 1) / 2).astype(np.uint8) # [-1.0, 1.0] => [0, 255]
84
+ image = Image.fromarray(rgb.transpose(1, 2, 0), 'RGB') # Reshape leads us to weird tiling
85
+
86
+ except ImportError:
87
+ raise ImportError('pyperlin not found! Install it via "pip install pyperlin"')
88
+
89
+ if convert_to_grayscale:
90
+ image = image.convert('L').convert('RGB') # We do a little trolling to Pillow (so we have a 3-channel image)
91
+
92
+ return image, starting_image
93
+
94
+
95
+ def crop_resize_rotate(img: PIL.Image.Image,
96
+ crop_size: int = None,
97
+ new_size: int = None,
98
+ rotation_deg: float = None,
99
+ translate_x: float = 0.0,
100
+ translate_y: float = 0.0) -> PIL.Image.Image:
101
+ """Center-crop the input image into a square of sides crop_size; can be resized to new_size; rotated rotation_deg counter-clockwise"""
102
+ # Center-crop the input image
103
+ if crop_size is not None:
104
+ w, h = img.size # Input image width and height
105
+ img = img.crop(box=((w - crop_size) // 2, # Left pixel coordinate
106
+ (h - crop_size) // 2, # Upper pixel coordinate
107
+ (w + crop_size) // 2, # Right pixel coordinate
108
+ (h + crop_size) // 2)) # Lower pixel coordinate
109
+ # Resize
110
+ if new_size is not None:
111
+ img = img.resize(size=(new_size, new_size), # Requested size of the image in pixels; (width, height)
112
+ resample=Image.LANCZOS) # Resampling filter
113
+ # Rotation and translation
114
+ if rotation_deg is not None:
115
+ img = img.rotate(angle=rotation_deg, # Angle to rotate image, counter-clockwise
116
+ resample=Image.BICUBIC, # Resampling filter; options: Image.Resampling.{NEAREST, BILINEAR, BICUBIC}
117
+ expand=False, # If True, the whole rotated image will be shown
118
+ translate=(translate_x, translate_y), # Translate the image, from top-left corner (post-rotation)
119
+ fillcolor=(0, 0, 0)) # Black background
120
+ # TODO: tile the background
121
+ return img
122
+
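+ # Sketch of the per-frame zoom step used by the dream-zoom command below (values are illustrative):
+ # cropping pixel_zoom pixels off each side and resizing back to the model resolution zooms in slightly.
+ def _example_zoom_step(img: PIL.Image.Image, model_resolution: int = 256, pixel_zoom: int = 2) -> PIL.Image.Image:
+     return crop_resize_rotate(img, crop_size=model_resolution - 2 * pixel_zoom,
+                               new_size=model_resolution, rotation_deg=1.0)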
123
+
124
+ mean = np.array([0.485, 0.456, 0.406])
125
+ std = np.array([0.229, 0.224, 0.225])
126
+
127
+ preprocess = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
128
+
129
+
130
+ def deprocess(image_np: np.ndarray) -> np.ndarray:
131
+ image_np = image_np.squeeze().transpose(1, 2, 0)
132
+ image_np = image_np * std.reshape((1, 1, 3)) + mean.reshape((1, 1, 3))
133
+ # image_np = (image_np + 1.0) / 2.0
134
+ image_np = np.clip(image_np, 0.0, 1.0)
135
+ image_np = (255 * image_np).astype('uint8')
136
+ return image_np
137
+
138
+
139
+ def clip(image_tensor: torch.Tensor) -> torch.Tensor:
140
+ """Clamp per channel"""
141
+ for c in range(3):
142
+ m, s = mean[c], std[c]
143
+ image_tensor[0, c] = torch.clamp(image_tensor[0, c], -m / s, (1 - m) / s)
144
+ return image_tensor
145
+
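+ # Round-trip sketch: preprocess normalizes with the ImageNet statistics above, clip keeps each
+ # channel inside the displayable range, and deprocess undoes the normalization for saving.
+ def _example_round_trip(img: PIL.Image.Image) -> np.ndarray:
+     tensor = preprocess(img).unsqueeze(0)  # (1, 3, H, W), normalized
+     tensor = clip(tensor)                  # clamp per channel
+     return deprocess(tensor.numpy())       # back to uint8 HWC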
146
+
147
+ def dream(image: np.ndarray,
148
+ model: torch.nn.Module,
149
+ layers: List[str],
150
+ channels: List[int] = None,
151
+ normed: bool = False,
152
+ sqrt_normed: bool = False,
153
+ iterations: int = 20,
154
+ lr: float = 1e-2) -> np.ndarray:
155
+ """ Updates the image to maximize outputs for n iterations """
156
+ Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
157
+ image = Variable(Tensor(image), requires_grad=True)
158
+ for i in range(iterations):
159
+ model.zero_grad()
160
+ out = model.get_layers_features(image, layers=layers, channels=channels, normed=normed, sqrt_normed=sqrt_normed)
161
+ loss = sum(layer.norm() for layer in out) # More than one layer may be used
162
+ loss.backward()
163
+ avg_grad = np.abs(image.grad.data.cpu().numpy()).mean()
164
+ norm_lr = lr / avg_grad
165
+ image.data += norm_lr * image.grad.data
166
+ image.data = clip(image.data)
167
+ # image.data = torch.clamp(image.data, -1.0, 1.0)
168
+ image.grad.data.zero_()
169
+ return image.cpu().data.numpy()
170
+
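+ # Note: the update above is plain gradient ascent with a gradient-normalized step; dividing the
+ # learning rate by the mean absolute gradient keeps the per-pixel step size roughly constant
+ # across layers. With g = image.grad.data, one step is equivalent to:
+ #     image.data += (lr / g.abs().mean()) * g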
171
+
172
+ def deep_dream(image: PIL.Image.Image,
173
+ model: torch.nn.Module,
174
+ model_resolution: int,
175
+ layers: List[str],
176
+ channels: List[int],
177
+ seed: Optional[int],
178
+ normed: bool,
179
+ sqrt_normed: bool,
180
+ iterations: int,
181
+ lr: float,
182
+ octave_scale: float,
183
+ num_octaves: int,
184
+ unzoom_octave: bool = False,
185
+ disable_inner_tqdm: bool = False,
186
+ ignore_initial_transform: bool = False) -> np.ndarray:
187
+ """ Main deep dream method """
188
+ # Center-crop and resize
189
+ if not ignore_initial_transform:
190
+ image = crop_resize_rotate(img=image, crop_size=min(image.size), new_size=model_resolution)
191
+ # Preprocess image
192
+ image = preprocess(image)
193
+ # image = torch.from_numpy(np.array(image)).permute(-1, 0, 1) / 127.5 - 1.0 # alternative
194
+ image = image.unsqueeze(0).cpu().data.numpy()
195
+ # Extract image representations for each octave
196
+ octaves = [image]
197
+ for _ in range(num_octaves - 1):
198
+ # Alternatively, see if we get better results with: https://www.tensorflow.org/tutorials/generative/deepdream#taking_it_up_an_octave
199
+ octave = nd.zoom(octaves[-1], (1, 1, 1 / octave_scale, 1 / octave_scale), order=1)
200
+ # Necessary for StyleGAN's Discriminator, as it cannot handle arbitrary image sizes
201
+ if unzoom_octave:
202
+ octave = nd.zoom(octave, np.array(octaves[-1].shape) / np.array(octave.shape), order=1)
203
+ octaves.append(octave)
204
+
205
+ detail = np.zeros_like(octaves[-1])
206
+ tqdm_desc = f'Dreaming w/layers {"|".join(x for x in layers)}'
207
+ tqdm_desc = f'Seed: {seed} - {tqdm_desc}' if seed is not None else tqdm_desc
208
+ for octave, octave_base in enumerate(tqdm(octaves[::-1], desc=tqdm_desc, disable=disable_inner_tqdm)):
209
+ if octave > 0:
210
+ # Upsample detail to new octave dimension
211
+ detail = nd.zoom(detail, np.array(octave_base.shape) / np.array(detail.shape), order=1)
212
+ # Add deep dream detail from previous octave to new base
213
+ input_image = octave_base + detail
214
+ # Get new deep dream image
215
+ dreamed_image = dream(input_image, model, layers, channels, normed, sqrt_normed, iterations, lr)
216
+ # Extract deep dream details
217
+ detail = dreamed_image - octave_base
218
+
219
+ return deprocess(dreamed_image)
220
+
221
+
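+ # End-to-end sketch (hedged: assumes network_pkl points at a StyleGAN2/3 pickle you actually have,
+ # and uses gen_utils.load_network / DiscriminatorFeatures exactly as the commands below do):
+ def _example_deep_dream(network_pkl: str, device: torch.device) -> np.ndarray:
+     D = gen_utils.load_network('D', network_pkl, None, device)
+     model = DiscriminatorFeatures(D).requires_grad_(False).to(device)
+     image, _ = get_image(seed=0, image_noise='random', image_size=D.img_resolution)
+     return deep_dream(image, model, D.img_resolution, layers=['b16_conv1'], channels=None, seed=0,
+                       normed=False, sqrt_normed=False, iterations=20, lr=1e-2,
+                       octave_scale=1.4, num_octaves=5, unzoom_octave=True)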
222
+ # ----------------------------------------------------------------------------
223
+
224
+ # Helper functions (all base code taken from: https://pytorch.org/tutorials/advanced/neural_style_tutorial.html)
225
+
226
+
227
+ class ContentLoss(nn.Module):
228
+
229
+ def __init__(self, target,):
230
+ super(ContentLoss, self).__init__()
231
+ # we 'detach' the target content from the tree used
232
+ # to dynamically compute the gradient: this is a stated value,
233
+ # not a variable. Otherwise the forward method of the criterion
234
+ # will throw an error.
235
+ self.target = target.detach()
236
+
237
+ def forward(self, input):
238
+ self.loss = F.mse_loss(input, self.target)
239
+ return input
240
+
241
+
242
+ def gram_matrix(input):
243
+ a, b, c, d = input.size()  # a=batch size, b=number of feature maps, (c, d)=dimensions of a feature map (N=c*d)
244
+
245
+ features = input.view(a * b, c * d) # resize F_XL into \hat F_XL
246
+
247
+ G = torch.mm(features, features.t()) # compute the gram product
248
+
249
+ # 'Normalize' the values of the Gram matrix by dividing by the number of elements in each feature map.
250
+ return G.div(a * b * c * d) # can also do torch.numel(input) to get the number of elements
251
+
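+ # Quick sanity sketch: the Gram matrix is square in the number of feature maps and symmetric.
+ def _example_gram_matrix():
+     feats = torch.randn(1, 8, 4, 4)  # (batch, feature maps, H, W)
+     G = gram_matrix(feats)
+     assert G.shape == (8, 8) and torch.allclose(G, G.t())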
252
+
253
+ class StyleLoss(nn.Module):
254
+ def __init__(self, target_feature):
255
+ super(StyleLoss, self).__init__()
256
+ self.target = gram_matrix(target_feature).detach()
257
+
258
+ def forward(self, input):
259
+ G = gram_matrix(input)
260
+ self.loss = F.mse_loss(G, self.target)
261
+ return input
262
+
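+ # Usage sketch: both losses are transparent modules that record their loss on forward and return the
+ # input unchanged, so they can be spliced between a network's blocks during optimization.
+ def _example_losses():
+     target = torch.randn(1, 8, 4, 4)
+     content_loss, style_loss = ContentLoss(target), StyleLoss(target)
+     x = style_loss(content_loss(torch.randn(1, 8, 4, 4)))  # x passes through unchanged
+     return content_loss.loss.item(), style_loss.loss.item()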
263
+
264
+ @main.command(name='style-transfer', help='Use the StyleGAN2/3 Discriminator to perform style transfer')
265
+ @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
266
+ @click.option('--cfg', type=click.Choice(['stylegan3-t', 'stylegan3-r', 'stylegan2']), help='Model base configuration', default=None)
267
+ @click.option('--content', type=str, help='Content image filename (url or local path)', required=True)
268
+ @click.option('--style', type=str, help='Style image filename (url or local path)', required=True)
269
+ @click.pass_context
+ def style_transfer_discriminator(
270
+ ctx: click.Context,
271
+ network_pkl: str,
272
+ cfg: str,
273
+ content: str,
274
+ style: str,
275
+ ):
276
+ print('Coming soon!')
277
+ # Reference: https://pytorch.org/tutorials/advanced/neural_style_tutorial.html
278
+
279
+ # Set up device
280
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
281
+
282
+ imsize = 512 if torch.cuda.is_available() else 128 # use small size if no gpu
283
+
284
+ loader = transforms.Compose([transforms.Resize(imsize), # scale imported image
285
+ transforms.ToTensor()]) # transform it into a torch tensor
286
+
287
+ # Helper function
288
+ def image_loader(image_name):
289
+ image = Image.open(image_name)
290
+ # fake batch dimension required to fit network's input dimensions
291
+ image = loader(image).unsqueeze(0)
292
+ return image.to(device, torch.float)
293
+
294
+ style_img = image_loader(style)
295
+ content_img = image_loader(content)
296
+
297
+ # This shouldn't really happen, but just in case
298
+ assert style_img.size() == content_img.size(), 'Style and content images must be the same size'
299
+
300
+ unloader = transforms.ToPILImage() # reconvert into PIL image
301
+
302
+ # Load Discriminator
303
+ D = gen_utils.load_network('D', network_pkl, cfg, device)
304
+ # TODO: finish this!
305
+
306
+
307
+ # ----------------------------------------------------------------------------
308
+
309
+
310
+ @main.command(name='dream', help='Discriminator Dreaming with the StyleGAN2/3 Discriminator and the chosen layers')
311
+ @click.pass_context
312
+ @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
313
+ @click.option('--cfg', type=click.Choice(['stylegan3-t', 'stylegan3-r', 'stylegan2']), help='Model base configuration', default=None)
314
+ # Synthesis options
315
+ @click.option('--seeds', type=gen_utils.num_range, help='Random seeds to use. Accepted comma-separated values, ranges, or combinations: "a,b,c", "a-c", "a,b-d,e".', default='0')
316
+ @click.option('--random-image-noise', '-noise', 'image_noise', type=click.Choice(['random', 'perlin']), default='perlin', show_default=True)
317
+ @click.option('--starting-image', type=str, help='Path to image to start from', default=None)
318
+ @click.option('--convert-to-grayscale', '-grayscale', is_flag=True, help='Add flag to grayscale the initial image')
319
+ @click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)', default=None)
320
+ @click.option('--lr', 'learning_rate', type=float, help='Learning rate', default=1e-2, show_default=True)
321
+ @click.option('--iterations', '-it', type=int, help='Number of gradient ascent steps per octave', default=20, show_default=True)
322
+ # Layer options
323
+ @click.option('--layers', type=str, help='Layers of the Discriminator to use as the features. If "all", will generate a dream image per available layer in the loaded model. If "use_all", will use all available layers.', default='b16_conv1', show_default=True)
324
+ @click.option('--channels', type=gen_utils.num_range, help='Comma-separated list and/or range of the channels of the Discriminator to use as the features. If "None", will use all channels in each specified layer.', default=None, show_default=True)
325
+ @click.option('--normed', 'norm_model_layers', is_flag=True, help='Add flag to divide the features of each layer of D by its number of elements')
326
+ @click.option('--sqrt-normed', 'sqrt_norm_model_layers', is_flag=True, help='Add flag to divide the features of each layer of D by the square root of its number of elements')
327
+ # Octaves options
328
+ @click.option('--num-octaves', type=int, help='Number of octaves', default=5, show_default=True)
329
+ @click.option('--octave-scale', type=float, help='Image scale between octaves', default=1.4, show_default=True)
330
+ @click.option('--unzoom-octave', type=bool, help='Set to True for the octaves to be unzoomed (this will be slower)', default=True, show_default=True)
331
+ # Extra parameters for saving the results
332
+ @click.option('--outdir', type=click.Path(file_okay=False), help='Directory path to save the results', default=os.path.join(os.getcwd(), 'out', 'discriminator_synthesis'), show_default=True, metavar='DIR')
333
+ @click.option('--description', '-desc', type=str, help='Additional description name for the directory path to save results', default='', show_default=True)
334
+ def discriminator_dream(
335
+ ctx: click.Context,
336
+ network_pkl: Union[str, os.PathLike],
337
+ cfg: Optional[str],
338
+ seeds: List[int],
339
+ image_noise: str,
340
+ starting_image: Union[str, os.PathLike],
341
+ convert_to_grayscale: bool,
342
+ class_idx: Optional[int], # TODO: conditional model
343
+ learning_rate: float,
344
+ iterations: int,
345
+ layers: str,
346
+ channels: Optional[List[int]],
347
+ norm_model_layers: bool,
348
+ sqrt_norm_model_layers: bool,
349
+ num_octaves: int,
350
+ octave_scale: float,
351
+ unzoom_octave: bool,
352
+ outdir: Union[str, os.PathLike],
353
+ description: str,
354
+ ):
355
+ # Set up device
356
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
357
+
358
+ # Load Discriminator
359
+ D = gen_utils.load_network('D', network_pkl, cfg, device)
360
+
361
+ # Get the model resolution (image resizing and getting available layers)
362
+ model_resolution = D.img_resolution
363
+
364
+ # TODO: do this better, as we can combine these conditions later
365
+ layers = layers.split(',')
366
+
367
+ # We will use the features of the Discriminator, on the layer specified by the user
368
+ model = DiscriminatorFeatures(D).requires_grad_(False).to(device)
369
+
370
+ if 'all' in layers:
371
+ # Get all the available layers in a list
372
+ layers = get_available_layers(max_resolution=model_resolution)
373
+
374
+ for seed in seeds:
375
+ # Get the image and image name
376
+ image, starting_image = get_image(seed=seed, image_noise=image_noise,
377
+ starting_image=starting_image,
378
+ image_size=model_resolution,
379
+ convert_to_grayscale=convert_to_grayscale)
380
+
381
+ # Make the run dir in the specified output directory
382
+ desc = f'discriminator-dream-all_layers-seed_{seed}'
383
+ desc = f'{desc}-{description}' if len(description) != 0 else desc
384
+ run_dir = gen_utils.make_run_dir(outdir, desc)
385
+
386
+ # Save starting image
387
+ image.save(os.path.join(run_dir, f'{os.path.basename(starting_image).split(".")[0]}.jpg'))
388
+
389
+ # Save the configuration used
390
+ ctx.obj = {
391
+ 'network_pkl': network_pkl,
392
+ 'synthesis_options': {
393
+ 'seed': seed,
394
+ 'random_image_noise': image_noise,
395
+ 'starting_image': starting_image,
396
+ 'class_idx': class_idx,
397
+ 'learning_rate': learning_rate,
398
+ 'iterations': iterations},
399
+ 'layer_options': {
400
+ 'layer': layers,
401
+ 'channels': channels,
402
+ 'norm_model_layers': norm_model_layers,
403
+ 'sqrt_norm_model_layers': sqrt_norm_model_layers},
404
+ 'octaves_options': {
405
+ 'num_octaves': num_octaves,
406
+ 'octave_scale': octave_scale,
407
+ 'unzoom_octave': unzoom_octave},
408
+ 'extra_parameters': {
409
+ 'outdir': run_dir,
410
+ 'description': description}
411
+ }
412
+ # Save the run configuration
413
+ gen_utils.save_config(ctx=ctx, run_dir=run_dir)
414
+
415
+ # For each layer:
416
+ for layer in layers:
417
+ # Extract deep dream image
418
+ dreamed_image = deep_dream(image, model, model_resolution, layers=[layer], channels=channels, seed=seed, normed=norm_model_layers,
419
+ sqrt_normed=sqrt_norm_model_layers, iterations=iterations, lr=learning_rate,
420
+ octave_scale=octave_scale, num_octaves=num_octaves, unzoom_octave=unzoom_octave)
421
+
422
+ # Save the resulting dreamed image
423
+ filename = f'layer-{layer}_dreamed_{os.path.basename(starting_image).split(".")[0]}.jpg'
424
+ Image.fromarray(dreamed_image, 'RGB').save(os.path.join(run_dir, filename))
425
+
426
+ else:
427
+ if 'use_all' in layers:
428
+ # Get all available layers
429
+ layers = get_available_layers(max_resolution=model_resolution)
430
+ else:
431
+ # Parse the layers given by the user and leave only those available by the model
432
+ available_layers = get_available_layers(max_resolution=model_resolution)
433
+ layers = [layer for layer in layers if layer in available_layers]
434
+
435
+ # Make the run dir in the specified output directory
436
+ desc = f'discriminator-dream-layers_{"-".join(x for x in layers)}'
437
+ desc = f'{desc}-{description}' if len(description) != 0 else desc
438
+ run_dir = gen_utils.make_run_dir(outdir, desc)
439
+
440
+ starting_images, used_seeds = [], []
441
+ for seed in seeds:
442
+ # Get the image and image name
443
+ image, starting_image = get_image(seed=seed, image_noise=image_noise,
444
+ starting_image=starting_image,
445
+ image_size=model_resolution,
446
+ convert_to_grayscale=convert_to_grayscale)
447
+
448
+ # Extract deep dream image
449
+ dreamed_image = deep_dream(image, model, model_resolution, layers=layers, channels=channels, seed=seed, normed=norm_model_layers,
450
+ sqrt_normed=sqrt_norm_model_layers, iterations=iterations, lr=learning_rate,
451
+ octave_scale=octave_scale, num_octaves=num_octaves, unzoom_octave=unzoom_octave)
452
+
453
+ # For logging later
454
+ starting_images.append(starting_image)
455
+ used_seeds.append(seed)
456
+
457
+ # Save the resulting image and initial image
458
+ filename = f'dreamed_{os.path.basename(starting_image)}'
459
+ Image.fromarray(dreamed_image, 'RGB').save(os.path.join(run_dir, filename))
460
+ image.save(os.path.join(run_dir, os.path.basename(starting_image)))
461
+ starting_image = None
462
+
463
+ # Save the configuration used
464
+ ctx.obj = {
465
+ 'network_pkl': network_pkl,
466
+ 'synthesis_options': {
467
+ 'seeds': used_seeds,
468
+ 'starting_image': starting_images,
469
+ 'class_idx': class_idx,
470
+ 'learning_rate': learning_rate,
471
+ 'iterations': iterations},
472
+ 'layer_options': {
473
+ 'layer': layers,
474
+ 'channels': channels,
475
+ 'norm_model_layers': norm_model_layers,
476
+ 'sqrt_norm_model_layers': sqrt_norm_model_layers},
477
+ 'octaves_options': {
478
+ 'octave_scale': octave_scale,
479
+ 'num_octaves': num_octaves,
480
+ 'unzoom_octave': unzoom_octave},
481
+ 'extra_parameters': {
482
+ 'outdir': run_dir,
483
+ 'description': description}
484
+ }
485
+ # Save the run configuration
486
+ gen_utils.save_config(ctx=ctx, run_dir=run_dir)
487
+
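+ # Example invocation (illustrative; substitute a network pickle you actually have):
+ #   python discriminator_synthesis.py dream --network=ffhq.pkl --seeds=0-3 \
+ #       --layers=b16_conv1 --iterations=20 --num-octaves=5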
488
+
489
+ # ----------------------------------------------------------------------------
490
+
491
+
492
+ @main.command(name='dream-zoom',
493
+ help='Zoom/rotate/translate after each Discriminator Dreaming iteration. A video will be saved.')
494
+ @click.pass_context
495
+ @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
496
+ @click.option('--cfg', type=click.Choice(['stylegan3-t', 'stylegan3-r', 'stylegan2']), help='Model base configuration', default=None)
497
+ # Synthesis options
498
+ @click.option('--seed', type=int, help='Random seed to use', default=0, show_default=True)
499
+ @click.option('--random-image-noise', '-noise', 'image_noise', type=click.Choice(['random', 'perlin']), default='random', show_default=True)
500
+ @click.option('--starting-image', type=str, help='Path to image to start from', default=None)
501
+ @click.option('--convert-to-grayscale', '-grayscale', is_flag=True, help='Add flag to grayscale the initial image')
502
+ @click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)', default=None)
503
+ @click.option('--lr', 'learning_rate', type=float, help='Learning rate', default=5e-3, show_default=True)
504
+ @click.option('--iterations', '-it', type=click.IntRange(min=1), help='Number of gradient ascent steps per octave', default=10, show_default=True)
505
+ # Layer options
506
+ @click.option('--layers', type=str, help='Comma-separated list of the layers of the Discriminator to use as the features. If "use_all", will use all available layers.', default='b16_conv0', show_default=True)
507
+ @click.option('--channels', type=gen_utils.num_range, help='Comma-separated list and/or range of the channels of the Discriminator to use as the features. If "None", will use all channels in each specified layer.', default=None, show_default=True)
508
+ @click.option('--normed', 'norm_model_layers', is_flag=True, help='Add flag to divide the features of each layer of D by its number of elements')
509
+ @click.option('--sqrt-normed', 'sqrt_norm_model_layers', is_flag=True, help='Add flag to divide the features of each layer of D by the square root of its number of elements')
510
+ # Octaves options
511
+ @click.option('--num-octaves', type=click.IntRange(min=1), help='Number of octaves', default=5, show_default=True)
512
+ @click.option('--octave-scale', type=float, help='Image scale between octaves', default=1.4, show_default=True)
513
+ @click.option('--unzoom-octave', type=bool, help='Set to True for the octaves to be unzoomed (this will be slower)', default=False, show_default=True)
514
+ # Individual frame manipulation options
515
+ @click.option('--pixel-zoom', '-zoom', type=int, help='How many pixels to zoom per step (positive for zoom in, negative for zoom out, padded with black)', default=2, show_default=True)
516
+ @click.option('--rotation-deg', '-rot', type=float, help='Rotate image counter-clockwise per frame (padded with black)', default=0.0, show_default=True)
517
+ @click.option('--translate-x', '-tx', type=float, help='Translate the image in the horizontal axis per frame (from left to right, padded with black)', default=0.0, show_default=True)
518
+ @click.option('--translate-y', '-ty', type=float, help='Translate the image in the vertical axis per frame (from top to bottom, padded with black)', default=0.0, show_default=True)
519
+ # Video options
520
+ @click.option('--fps', type=gen_utils.parse_fps, help='FPS for the mp4 video of optimization progress (if saved)', default=25, show_default=True)
521
+ @click.option('--duration-sec', type=float, help='Duration length of the video', default=15.0, show_default=True)
522
+ @click.option('--reverse-video', is_flag=True, help='Add flag to reverse the generated video')
523
+ @click.option('--include-starting-image', type=bool, help='Include the starting image in the final video', default=True, show_default=True)
524
+ # Extra parameters for saving the results
525
+ @click.option('--outdir', type=click.Path(file_okay=False), help='Directory path to save the results', default=os.path.join(os.getcwd(), 'out', 'discriminator_synthesis'), show_default=True, metavar='DIR')
526
+ @click.option('--description', '-desc', type=str, help='Additional description name for the directory path to save results', default='', show_default=True)
527
+ def discriminator_dream_zoom(
528
+ ctx: click.Context,
529
+ network_pkl: Union[str, os.PathLike],
530
+ cfg: Optional[str],
531
+ seed: int,
532
+ image_noise: Optional[str],
533
+ starting_image: Optional[Union[str, os.PathLike]],
534
+ convert_to_grayscale: bool,
535
+ class_idx: Optional[int], # TODO: conditional model
536
+ learning_rate: float,
537
+ iterations: int,
538
+ layers: str,
539
+ channels: List[int],
540
+ norm_model_layers: Optional[bool],
541
+ sqrt_norm_model_layers: Optional[bool],
542
+ num_octaves: int,
543
+ octave_scale: float,
544
+ unzoom_octave: Optional[bool],
545
+ pixel_zoom: int,
546
+ rotation_deg: float,
547
+ translate_x: int,
548
+ translate_y: int,
549
+ fps: int,
550
+ duration_sec: float,
551
+ reverse_video: bool,
552
+ include_starting_image: bool,
553
+ outdir: Union[str, os.PathLike],
554
+ description: str,
555
+ ):
556
+ # Set up device
557
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
558
+
559
+ # Load Discriminator
560
+ D = gen_utils.load_network('D', network_pkl, cfg, device)
561
+
562
+ # Get the model resolution (for resizing the starting image if needed)
563
+ model_resolution = D.img_resolution
564
+ zoom_size = model_resolution - 2 * pixel_zoom
565
+
566
+ layers = layers.split(',')
567
+ if 'use_all' in layers:
568
+ # Get all available layers
569
+ layers = get_available_layers(max_resolution=model_resolution)
570
+ else:
571
+ # Parse the layers given by the user and leave only those available by the model
572
+ available_layers = get_available_layers(max_resolution=model_resolution)
573
+ layers = [layer for layer in layers if layer in available_layers]
574
+
575
+ # We will use the features of the Discriminator, on the layer specified by the user
576
+ model = DiscriminatorFeatures(D).requires_grad_(False).to(device)
577
+
578
+ # Get the image and image name
579
+ image, starting_image = get_image(seed=seed, image_noise=image_noise,
580
+ starting_image=starting_image,
581
+ image_size=model_resolution,
582
+ convert_to_grayscale=convert_to_grayscale)
583
+
584
+ # Make the run dir in the specified output directory
585
+ desc = 'discriminator-dream-zoom'
586
+ desc = f'{desc}-{description}' if len(description) != 0 else desc
587
+ run_dir = gen_utils.make_run_dir(outdir, desc)
588
+
589
+ # Save the configuration used
590
+ ctx.obj = {
591
+ 'network_pkl': network_pkl,
592
+ 'synthesis_options': {
593
+ 'seed': seed,
594
+ 'random_image_noise': image_noise,
595
+ 'starting_image': starting_image,
596
+ 'class_idx': class_idx,
597
+ 'learning_rate': learning_rate,
598
+ 'iterations': iterations
599
+ },
600
+ 'layer_options': {
601
+ 'layers': layers,
602
+ 'channels': channels,
603
+ 'norm_model_layers': norm_model_layers,
604
+ 'sqrt_norm_model_layers': sqrt_norm_model_layers
605
+ },
606
+ 'octaves_options': {
607
+ 'num_octaves': num_octaves,
608
+ 'octave_scale': octave_scale,
609
+ 'unzoom_octave': unzoom_octave
610
+ },
611
+ 'frame_manipulation_options': {
612
+ 'pixel_zoom': pixel_zoom,
613
+ 'rotation_deg': rotation_deg,
614
+ 'translate_x': translate_x,
615
+ 'translate_y': translate_y,
616
+ },
617
+ 'video_options': {
618
+ 'fps': fps,
619
+ 'duration_sec': duration_sec,
620
+ 'reverse_video': reverse_video,
621
+ 'include_starting_image': include_starting_image,
622
+ },
623
+ 'extra_parameters': {
624
+ 'outdir': run_dir,
625
+ 'description': description
626
+ }
627
+ }
628
+ # Save the run configuration
629
+ gen_utils.save_config(ctx=ctx, run_dir=run_dir)
630
+
631
+ num_frames = int(np.rint(duration_sec * fps)) # Number of frames for the video
632
+ n_digits = int(np.log10(num_frames)) + 1 # Number of digits for naming each frame
633
+
634
+ # Save the starting image
635
+ starting_image_name = f'dreamed_{0:0{n_digits}d}.jpg' if include_starting_image else 'starting_image.jpg'
636
+ image.save(os.path.join(run_dir, starting_image_name))
637
+
638
+ for idx, frame in enumerate(tqdm(range(num_frames), desc='Dreaming...', unit='frame')):
639
+ # Zoom in after the first frame
640
+ if idx > 0:
641
+ image = crop_resize_rotate(image, crop_size=zoom_size, new_size=model_resolution,
642
+ rotation_deg=rotation_deg, translate_x=translate_x, translate_y=translate_y)
643
+ # Extract deep dream image
644
+ dreamed_image = deep_dream(image, model, model_resolution, layers=layers, seed=seed, normed=norm_model_layers,
645
+ sqrt_normed=sqrt_norm_model_layers, iterations=iterations, channels=channels,
646
+ lr=learning_rate, octave_scale=octave_scale, num_octaves=num_octaves,
647
+ unzoom_octave=unzoom_octave, disable_inner_tqdm=True)
648
+
649
+ # Save the resulting image and initial image
650
+ filename = f'dreamed_{idx + 1:0{n_digits}d}.jpg'
651
+ Image.fromarray(dreamed_image, 'RGB').save(os.path.join(run_dir, filename))
652
+
653
+ # Now, the dreamed image is the starting image
654
+ image = Image.fromarray(dreamed_image, 'RGB')
655
+
656
+ # Save the final video
657
+ gen_utils.save_video_from_images(run_dir=run_dir, image_names=f'dreamed_%0{n_digits}d.jpg',
658
+ video_name='dream-zoom', fps=fps, reverse_video=reverse_video)
659
+
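+ # Example invocation (illustrative paths and values):
+ #   python discriminator_synthesis.py dream-zoom --network=ffhq.pkl --layers=b16_conv0 \
+ #       --pixel-zoom=2 --rotation-deg=0.5 --fps=25 --duration-sec=15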
660
+
661
+ # ----------------------------------------------------------------------------
662
+
663
+ @main.command(name='channel-zoom', help='Dream zoom using only the specified channels in the selected layer')
664
+ @click.pass_context
665
+ @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
666
+ @click.option('--cfg', type=click.Choice(['stylegan3-t', 'stylegan3-r', 'stylegan2']), help='Model base configuration', default=None)
667
+ # Synthesis options
668
+ @click.option('--seed', type=int, help='Random seed to use', default=0, show_default=True)
669
+ @click.option('--random-image-noise', '-noise', 'image_noise', type=click.Choice(['random', 'perlin']), default='random', show_default=True)
670
+ @click.option('--starting-image', type=str, help='Path to image to start from', default=None)
671
+ @click.option('--convert-to-grayscale', '-grayscale', is_flag=True, help='Add flag to grayscale the initial image')
672
+ @click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)', default=None)
673
+ @click.option('--lr', 'learning_rate', type=float, help='Learning rate', default=5e-3, show_default=True)
674
+ @click.option('--iterations', '-it', type=click.IntRange(min=1), help='Number of gradient ascent steps per octave', default=10, show_default=True)
675
+ # Layer options
676
+ @click.option('--layer', type=str, help='Layers of the Discriminator to use as the features.', default='b8_conv0', show_default=True)
677
+ @click.option('--normed', 'norm_model_layers', is_flag=True, help='Add flag to divide the features of each layer of D by its number of elements')
678
+ @click.option('--sqrt-normed', 'sqrt_norm_model_layers', is_flag=True, help='Add flag to divide the features of each layer of D by the square root of its number of elements')
679
+ # Octaves options
680
+ @click.option('--num-octaves', type=click.IntRange(min=1), help='Number of octaves', default=5, show_default=True)
681
+ @click.option('--octave-scale', type=float, help='Image scale between octaves', default=1.4, show_default=True)
682
+ @click.option('--unzoom-octave', type=bool, help='Set to True for the octaves to be unzoomed (this will be slower)', default=False, show_default=True)
683
+ # Individual frame manipulation options
684
+ @click.option('--pixel-zoom', '-zoom', type=int, help='How many pixels to zoom per step (positive for zoom in, negative for zoom out, padded with black)', default=2, show_default=True)
685
+ @click.option('--rotation-deg', '-rot', type=float, help='Rotate image counter-clockwise per frame (padded with black)', default=0.0, show_default=True)
686
+ @click.option('--translate-x', '-tx', type=float, help='Translate the image in the horizontal axis per frame (from left to right, padded with black)', default=0.0, show_default=True)
687
+ @click.option('--translate-y', '-ty', type=float, help='Translate the image in the vertical axis per frame (from top to bottom, padded with black)', default=0.0, show_default=True)
688
+ # Video options
689
+ @click.option('--frames-per-channel', type=click.IntRange(min=1), help='Number of frames per channel', default=1, show_default=True)
690
+ @click.option('--fps', type=gen_utils.parse_fps, help='FPS for the mp4 video of optimization progress (if saved)', default=25, show_default=True)
691
+ @click.option('--reverse-video', is_flag=True, help='Add flag to reverse the generated video')
692
+ @click.option('--include-starting-image', type=bool, help='Include the starting image in the final video', default=True, show_default=True)
693
+ # Extra parameters for saving the results
694
+ @click.option('--outdir', type=click.Path(file_okay=False), help='Directory path to save the results', default=os.path.join(os.getcwd(), 'out', 'discriminator_synthesis'), show_default=True, metavar='DIR')
695
+ @click.option('--description', '-desc', type=str, help='Additional description name for the directory path to save results', default='', show_default=True)
696
+ def channel_zoom(
697
+ ctx: click.Context,
698
+ network_pkl: Union[str, os.PathLike],
699
+ cfg: Optional[str],
700
+ seed: int,
701
+ image_noise: Optional[str],
702
+ starting_image: Optional[Union[str, os.PathLike]],
703
+ convert_to_grayscale: bool,
704
+ class_idx: Optional[int], # TODO: conditional model
705
+ learning_rate: float,
706
+ iterations: int,
707
+ layer: str,
708
+ norm_model_layers: Optional[bool],
709
+ sqrt_norm_model_layers: Optional[bool],
710
+ num_octaves: int,
711
+ octave_scale: float,
712
+ unzoom_octave: Optional[bool],
713
+ pixel_zoom: int,
714
+ rotation_deg: float,
715
+ translate_x: int,
716
+ translate_y: int,
717
+ frames_per_channel: int,
718
+ fps: int,
719
+ reverse_video: bool,
720
+ include_starting_image: bool,
721
+ outdir: Union[str, os.PathLike],
722
+ description: str,
723
+ ):
724
+ """Zoom in using all the channels of a network (or a specified layer)"""
725
+ # Set up device
726
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
727
+
728
+ # Load Discriminator
729
+ D = gen_utils.load_network('D', network_pkl, cfg, device)
730
+
731
+ # Get the model resolution (for resizing the starting image if needed)
732
+ model_resolution = D.img_resolution
733
+ zoom_size = model_resolution - 2 * pixel_zoom
734
+
735
+ if 'use_all' in layer:
736
+ ctx.fail('Cannot use "use_all" with this command. Please specify the layers you want to use.')
737
+ else:
738
+ # Parse the layers given by the user and leave only those available by the model
739
+ available_layers = get_available_layers(max_resolution=model_resolution)
740
+ assert layer in available_layers, f'Layer {layer} not available. Available layers: {available_layers}'
741
+ layers = [layer]
742
+
743
+ # We will use the features of the Discriminator, on the layer specified by the user
744
+ model = DiscriminatorFeatures(D).requires_grad_(False).to(device)
745
+
746
+ # Get the image and image name
747
+ image, starting_image = get_image(seed=seed, image_noise=image_noise,
748
+ starting_image=starting_image,
749
+ image_size=model_resolution,
750
+ convert_to_grayscale=convert_to_grayscale)
751
+
752
+ # Make the run dir in the specified output directory
753
+ desc = 'discriminator-channel-zoom'
754
+ desc = f'{desc}-{description}' if len(description) != 0 else desc
755
+ run_dir = gen_utils.make_run_dir(outdir, desc)
756
+
757
+ # Finally, let's get the number of channels in the selected layer
758
+ channels_dict = {res: D.get_submodule(f'b{res}.conv0').out_channels for res in D.block_resolutions}
759
+ channels_dict[4] = D.get_submodule('b4.conv').out_channels # Last block has a different name
760
+ # Get the dimension of the block from the selected layer (e.g., from 'b128_conv0' get '128')
761
+ block_resolution = re.search(r'b(\d+)_', layer).group(1)
762
+ total_channels = channels_dict[int(block_resolution)]
763
+ # Make a list of all the channels, each repeated frames_per_channel
764
+ channels = np.repeat(np.arange(total_channels), frames_per_channel)
765
+
766
+ num_frames = int(np.rint(total_channels * frames_per_channel)) # Number of frames for the video
767
+ n_digits = int(np.log10(num_frames)) + 1 # Number of digits for naming each frame
768
+
769
+ # Save the starting image
770
+ starting_image_name = f'dreamed_{0:0{n_digits}d}.jpg' if include_starting_image else 'starting_image.jpg'
771
+ image.save(os.path.join(run_dir, starting_image_name))
772
+
773
+ for idx, frame in enumerate(tqdm(range(num_frames), desc='Dreaming...', unit='frame')):
774
+ # Zoom in after the first frame
775
+ if idx > 0:
776
+ image = crop_resize_rotate(image, crop_size=zoom_size, new_size=model_resolution,
777
+ rotation_deg=rotation_deg, translate_x=translate_x, translate_y=translate_y)
778
+ # Extract deep dream image
779
+ dreamed_image = deep_dream(image, model, model_resolution, layers=layers, seed=seed, normed=norm_model_layers,
780
+ sqrt_normed=sqrt_norm_model_layers, iterations=iterations, channels=channels[idx:idx + 1],
781
+ lr=learning_rate, octave_scale=octave_scale, num_octaves=num_octaves,
782
+ unzoom_octave=unzoom_octave, disable_inner_tqdm=True)
783
+
784
+ # Save the resulting image and initial image
785
+ filename = f'dreamed_{idx + 1:0{n_digits}d}.jpg'
786
+ Image.fromarray(dreamed_image, 'RGB').save(os.path.join(run_dir, filename))
787
+
788
+ # Now, the dreamed image is the starting image
789
+ image = Image.fromarray(dreamed_image, 'RGB')
790
+
791
+ # Save the final video
792
+ gen_utils.save_video_from_images(run_dir=run_dir, image_names=f'dreamed_%0{n_digits}d.jpg', video_name='channel-zoom',
793
+ fps=fps, reverse_video=reverse_video)
794
+
795
+ # Save the configuration used
796
+ ctx.obj = {
797
+ 'network_pkl': network_pkl,
798
+ 'synthesis_options': {
799
+ 'seed': seed,
800
+ 'random_image_noise': image_noise,
801
+ 'starting_image': starting_image,
802
+ 'class_idx': class_idx,
803
+ 'learning_rate': learning_rate,
804
+ 'iterations': iterations
805
+ },
806
+ 'layer_options': {
807
+ 'layer': layer,
808
+ 'channels': 'all',
809
+ 'total_channels': total_channels,
810
+ 'norm_model_layers': norm_model_layers,
811
+ 'sqrt_norm_model_layers': sqrt_norm_model_layers
812
+ },
813
+ 'octaves_options': {
814
+ 'num_octaves': num_octaves,
815
+ 'octave_scale': octave_scale,
816
+ 'unzoom_octave': unzoom_octave
817
+ },
818
+ 'frame_manipulation_options': {
819
+ 'pixel_zoom': pixel_zoom,
820
+ 'rotation_deg': rotation_deg,
821
+ 'translate_x': translate_x,
822
+ 'translate_y': translate_y,
823
+ },
824
+ 'video_options': {
825
+ 'fps': fps,
826
+ 'frames_per_channel': frames_per_channel,
827
+ 'reverse_video': reverse_video,
828
+ 'include_starting_image': include_starting_image,
829
+ },
830
+ 'extra_parameters': {
831
+ 'outdir': run_dir,
832
+ 'description': description
833
+ }
834
+ }
835
+ # Save the run configuration
836
+ gen_utils.save_config(ctx=ctx, run_dir=run_dir)
837
+
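+ # Example invocation (illustrative): sweep every channel of b8_conv0, one frame per channel.
+ #   python discriminator_synthesis.py channel-zoom --network=ffhq.pkl --layer=b8_conv0 \
+ #       --frames-per-channel=1 --fps=25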
838
+
839
+ # ----------------------------------------------------------------------------
840
+
841
+
842
+ @main.command(name='interp', help='Interpolate between two or more seeds')
843
+ @click.pass_context
844
+ @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
845
+ @click.option('--cfg', type=click.Choice(['stylegan3-t', 'stylegan3-r', 'stylegan2']), help='Model base configuration', default=None)
846
+ # Synthesis options
847
+ @click.option('--seeds', type=gen_utils.num_range, help='Random seeds to generate the Perlin noise from', required=True)
848
+ @click.option('--interp-type', '-interp', type=click.Choice(['linear', 'spherical']), help='Type of interpolation in Z or W', default='spherical', show_default=True)
849
+ @click.option('--smooth', is_flag=True, help='Add flag to smooth the interpolation between the seeds')
850
+ @click.option('--random-image-noise', '-noise', 'image_noise', type=click.Choice(['random', 'perlin']), default='random', show_default=True)
851
+ @click.option('--starting-image', type=str, help='Path to image to start from', default=None)
852
+ @click.option('--convert-to-grayscale', '-grayscale', is_flag=True, help='Add flag to grayscale the initial image')
853
+ @click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)', default=None)
854
+ @click.option('--lr', 'learning_rate', type=float, help='Learning rate', default=5e-3, show_default=True)
855
+ @click.option('--iterations', '-it', type=click.IntRange(min=1), help='Number of gradient ascent steps per octave', default=10, show_default=True)
856
+ # Layer options
857
+ @click.option('--layers', type=str, help='Comma-separated list of the layers of the Discriminator to use as the features. If "use_all", will use all available layers.', default='b16_conv0', show_default=True)
858
+ @click.option('--channels', type=gen_utils.num_range, help='Comma-separated list and/or range of the channels of the Discriminator to use as the features. If "None", will use all channels in each specified layer.', default=None, show_default=True)
859
+ @click.option('--normed', 'norm_model_layers', is_flag=True, help='Add flag to divide the features of each layer of D by its number of elements')
860
+ @click.option('--sqrt-normed', 'sqrt_norm_model_layers', is_flag=True, help='Add flag to divide the features of each layer of D by the square root of its number of elements')
861
+ # Octaves options
862
+ @click.option('--num-octaves', type=click.IntRange(min=1), help='Number of octaves', default=5, show_default=True)
863
+ @click.option('--octave-scale', type=float, help='Image scale between octaves', default=1.4, show_default=True)
864
+ @click.option('--unzoom-octave', type=bool, help='Set to True for the octaves to be unzoomed (this will be slower)', default=False, show_default=True)
865
+ # TODO: Individual frame manipulation options
866
+ # Video options
867
+ @click.option('--seed-sec', '-sec', type=float, help='Number of seconds between each seed transition', default=5.0, show_default=True)
868
+ @click.option('--fps', type=gen_utils.parse_fps, help='FPS for the mp4 video of optimization progress (if saved)', default=25, show_default=True)
869
+ # Extra parameters for saving the results
870
+ @click.option('--outdir', type=click.Path(file_okay=False), help='Directory path to save the results', default=os.path.join(os.getcwd(), 'out', 'discriminator_synthesis'), show_default=True, metavar='DIR')
871
+ @click.option('--description', '-desc', type=str, help='Additional description name for the directory path to save results', default='', show_default=True)
872
+ def random_interpolation(
873
+ ctx: click.Context,
874
+ network_pkl: Union[str, os.PathLike],
875
+ cfg: Optional[str],
876
+ seeds: List[int],
877
+ interp_type: Optional[str],
878
+ smooth: Optional[bool],
879
+ image_noise: Optional[str],
880
+ starting_image: Optional[Union[str, os.PathLike]],
881
+ convert_to_grayscale: bool,
882
+ class_idx: Optional[int], # TODO: conditional model
883
+ learning_rate: float,
884
+ iterations: int,
885
+ layers: str,
886
+ channels: List[int],
887
+ norm_model_layers: Optional[bool],
888
+ sqrt_norm_model_layers: Optional[bool],
889
+ num_octaves: int,
890
+ octave_scale: float,
891
+ unzoom_octave: Optional[bool],
892
+ seed_sec: float,
893
+ fps: int,
894
+ outdir: Union[str, os.PathLike],
895
+ description: str,
896
+ ):
897
+ """Do a latent walk between random Perlin images (given the seeds) and generate a video with these frames."""
898
+ # TODO: To make this better and more stable, generate Perlin noise animations instead of interpolations
899
+ # Set up device
900
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
901
+
902
+ # Load Discriminator
903
+ D = gen_utils.load_network('D', network_pkl, cfg, device)
904
+
905
+ # Get model resolution
906
+ model_resolution = D.img_resolution
907
+ model = DiscriminatorFeatures(D).requires_grad_(False).to(device)
908
+
909
+ layers = layers.split(',')
910
+ # Get all available layers
911
+ if 'use_all' in layers:
912
+ layers = get_available_layers(max_resolution=model_resolution)
913
+ else:
914
+ # Parse the layers given by the user and leave only those available by the model
915
+ available_layers = get_available_layers(max_resolution=model_resolution)
916
+ layers = [layer for layer in layers if layer in available_layers]
917
+
918
+ # Make the run dir in the specified output directory
919
+ desc = f'random-interp-layers_{"-".join(x for x in layers)}'
920
+ desc = f'{desc}-{description}' if len(description) != 0 else desc
921
+ run_dir = gen_utils.make_run_dir(outdir, desc)
922
+
923
+ # Number of steps to take between each random image
924
+ n_steps = int(np.rint(seed_sec * fps))
925
+ # Total number of frames
926
+ num_frames = int(n_steps * (len(seeds) - 1))
927
+ # Total video length in seconds
928
+ duration_sec = num_frames / fps
929
+
930
+ # Number of digits for naming purposes
931
+ n_digits = int(np.log10(num_frames)) + 1
932
+
933
+ # Create interpolation of noises
934
+ random_images = []
935
+ for seed in seeds:
936
+ # Get the starting seed and image
937
+ image, _ = get_image(seed=seed, image_noise=image_noise, starting_image=starting_image,
938
+ image_size=model_resolution, convert_to_grayscale=convert_to_grayscale)
939
+ image = np.array(image) / 255.0
940
+ random_images.append(image)
941
+ random_images = np.stack(random_images)
942
+
943
+ all_images = np.empty([0] + list(random_images.shape[1:]), dtype=np.float32)
944
+ # Do interpolation
945
+ for i in range(len(random_images) - 1):
946
+ # Interpolate between each pair of images
947
+ interp = gen_utils.interpolate(random_images[i], random_images[i + 1], n_steps, interp_type, smooth)
948
+ # Append it to the list of all images
949
+ all_images = np.append(all_images, interp, axis=0)
950
+
951
+ # DeepDream expects a list of PIL.Image objects
952
+ pil_images = []
953
+ for idx in range(len(all_images)):
954
+ im = (255 * all_images[idx]).astype(dtype=np.uint8)
955
+ pil_images.append(Image.fromarray(im))
956
+
957
+ for idx, image in enumerate(tqdm(pil_images, desc='Interpolating...', unit='frame', total=num_frames)):
958
+ # Extract deep dream image
959
+ dreamed_image = deep_dream(image, model, model_resolution, layers=layers, channels=channels, seed=None,
960
+ normed=norm_model_layers, disable_inner_tqdm=True, ignore_initial_transform=True,
961
+ sqrt_normed=sqrt_norm_model_layers, iterations=iterations, lr=learning_rate,
962
+ octave_scale=octave_scale, num_octaves=num_octaves, unzoom_octave=unzoom_octave)
963
+
964
+ # Save the resulting image and initial image
965
+ filename = f'{image_noise}-interpolation_frame_{idx:0{n_digits}d}.jpg'
966
+ Image.fromarray(dreamed_image, 'RGB').save(os.path.join(run_dir, filename))
967
+
968
+ # Save the configuration used
969
+ ctx.obj = {
970
+ 'network_pkl': network_pkl,
971
+ 'synthesis_options': {
972
+ 'seeds': seeds,
973
+ 'starting_image': starting_image,
974
+ 'class_idx': class_idx,
975
+ 'learning_rate': learning_rate,
976
+ 'iterations': iterations},
977
+ 'layer_options': {
978
+ 'layer': layers,
979
+ 'channels': channels,
980
+ 'norm_model_layers': norm_model_layers,
981
+ 'sqrt_norm_model_layers': sqrt_norm_model_layers},
982
+ 'octaves_options': {
983
+ 'octave_scale': octave_scale,
984
+ 'num_octaves': num_octaves,
985
+ 'unzoom_octave': unzoom_octave},
986
+ 'extra_parameters': {
987
+ 'outdir': run_dir,
988
+ 'description': description}
989
+ }
990
+ # Save the run configuration
991
+ gen_utils.save_config(ctx=ctx, run_dir=run_dir)
992
+
993
+ # Generate video
994
+ print('Saving video...')
995
+ ffmpeg_command = r'/usr/bin/ffmpeg' if os.name != 'nt' else r'C:\\Ffmpeg\\bin\\ffmpeg.exe'
996
+ stream = ffmpeg.input(os.path.join(run_dir, f'{image_noise}-interpolation_frame_%0{n_digits}d.jpg'), framerate=fps)
997
+ stream = ffmpeg.output(stream, os.path.join(run_dir, f'{image_noise}-interpolation.mp4'), crf=20, pix_fmt='yuv420p')
998
+ ffmpeg.run(stream, capture_stdout=True, capture_stderr=True, cmd=ffmpeg_command)
999
+
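+ # Example invocation (illustrative): a three-seed spherical interpolation, 5 seconds per transition.
+ #   python discriminator_synthesis.py interp --network=ffhq.pkl --seeds=0,1,2 \
+ #       --interp-type=spherical --seed-sec=5.0 --fps=25
+ # Note: the video step assumes an ffmpeg binary at /usr/bin/ffmpeg (or C:\Ffmpeg\bin\ffmpeg.exe on
+ # Windows), as hard-coded above.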
1000
+ # ----------------------------------------------------------------------------
1001
+
1002
+
1003
+ if __name__ == '__main__':
1004
+ main()
1005
+
1006
+
1007
+ # ----------------------------------------------------------------------------
stylegan3-fun/dnnlib/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ from .util import EasyDict, make_cache_dir_path
stylegan3-fun/dnnlib/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (285 Bytes).
 
stylegan3-fun/dnnlib/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (245 Bytes).
 
stylegan3-fun/dnnlib/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (245 Bytes).
 
stylegan3-fun/dnnlib/__pycache__/util.cpython-311.pyc ADDED
Binary file (26.1 kB).
 
stylegan3-fun/dnnlib/__pycache__/util.cpython-38.pyc ADDED
Binary file (14.1 kB).
 
stylegan3-fun/dnnlib/__pycache__/util.cpython-39.pyc ADDED
Binary file (14.1 kB).
 
stylegan3-fun/dnnlib/util.py ADDED
@@ -0,0 +1,491 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Miscellaneous utility classes and functions."""
10
+
11
+ import ctypes
12
+ import fnmatch
13
+ import importlib
14
+ import inspect
15
+ import numpy as np
16
+ import os
17
+ import shutil
18
+ import sys
19
+ import types
20
+ import io
21
+ import pickle
22
+ import re
23
+ import requests
24
+ import html
25
+ import hashlib
26
+ import glob
27
+ import tempfile
28
+ import urllib
29
+ import urllib.request
30
+ import uuid
31
+
32
+ from distutils.util import strtobool
33
+ from typing import Any, List, Tuple, Union
34
+
35
+
36
+ # Util classes
37
+ # ------------------------------------------------------------------------------------------
38
+
39
+
40
+ class EasyDict(dict):
41
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
42
+
43
+ def __getattr__(self, name: str) -> Any:
44
+ try:
45
+ return self[name]
46
+ except KeyError:
47
+ raise AttributeError(name)
48
+
49
+ def __setattr__(self, name: str, value: Any) -> None:
50
+ self[name] = value
51
+
52
+ def __delattr__(self, name: str) -> None:
53
+ del self[name]
54
+
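+ # Usage sketch: attribute access and item access are interchangeable, e.g.
+ #     cfg = EasyDict(lr=1e-3); cfg.batch = 32
+ #     assert cfg['lr'] == cfg.lr and cfg.batch == cfg['batch']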
55
+
56
+ class Logger(object):
57
+ """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
58
+
59
+ def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
60
+ self.file = None
61
+
62
+ if file_name is not None:
63
+ self.file = open(file_name, file_mode)
64
+
65
+ self.should_flush = should_flush
66
+ self.stdout = sys.stdout
67
+ self.stderr = sys.stderr
68
+
69
+ sys.stdout = self
70
+ sys.stderr = self
71
+
72
+ def __enter__(self) -> "Logger":
73
+ return self
74
+
75
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
76
+ self.close()
77
+
78
+ def write(self, text: Union[str, bytes]) -> None:
79
+ """Write text to stdout (and a file) and optionally flush."""
80
+ if isinstance(text, bytes):
81
+ text = text.decode()
82
+ if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
83
+ return
84
+
85
+ if self.file is not None:
86
+ self.file.write(text)
87
+
88
+ self.stdout.write(text)
89
+
90
+ if self.should_flush:
91
+ self.flush()
92
+
93
+ def flush(self) -> None:
94
+ """Flush written text to both stdout and a file, if open."""
95
+ if self.file is not None:
96
+ self.file.flush()
97
+
98
+ self.stdout.flush()
99
+
100
+ def close(self) -> None:
101
+ """Flush, close possible files, and remove stdout/stderr mirroring."""
102
+ self.flush()
103
+
104
+ # if using multiple loggers, prevent closing in wrong order
105
+ if sys.stdout is self:
106
+ sys.stdout = self.stdout
107
+ if sys.stderr is self:
108
+ sys.stderr = self.stderr
109
+
110
+ if self.file is not None:
111
+ self.file.close()
112
+ self.file = None
113
+
114
+
115
+ # Cache directories
116
+ # ------------------------------------------------------------------------------------------
117
+
118
+ _dnnlib_cache_dir = None
119
+
120
+ def set_cache_dir(path: str) -> None:
121
+ global _dnnlib_cache_dir
122
+ _dnnlib_cache_dir = path
123
+
124
+ def make_cache_dir_path(*paths: str) -> str:
125
+ if _dnnlib_cache_dir is not None:
126
+ return os.path.join(_dnnlib_cache_dir, *paths)
127
+ if 'DNNLIB_CACHE_DIR' in os.environ:
128
+ return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
129
+ if 'HOME' in os.environ:
130
+ return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
131
+ if 'USERPROFILE' in os.environ:
132
+ return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
133
+ return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
134
+
135
+ # Small util functions
136
+ # ------------------------------------------------------------------------------------------
137
+
138
+
139
+ def format_time(seconds: Union[int, float]) -> str:
140
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
141
+ s = int(np.rint(seconds))
142
+
143
+ if s < 60:
144
+ return "{0}s".format(s)
145
+ elif s < 60 * 60:
146
+ return "{0}m {1:02}s".format(s // 60, s % 60)
147
+ elif s < 24 * 60 * 60:
148
+ return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
149
+ else:
150
+ return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
151
+
152
+
153
+ def format_time_brief(seconds: Union[int, float]) -> str:
154
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
155
+ s = int(np.rint(seconds))
156
+
157
+ if s < 60:
158
+ return "{0}s".format(s)
159
+ elif s < 60 * 60:
160
+ return "{0}m {1:02}s".format(s // 60, s % 60)
161
+ elif s < 24 * 60 * 60:
162
+ return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
163
+ else:
164
+ return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
165
+
166
+
167
+ def ask_yes_no(question: str) -> bool:
168
+ """Ask the user the question until the user inputs a valid answer."""
169
+ while True:
170
+ try:
171
+ print("{0} [y/n]".format(question))
172
+ return strtobool(input().lower())
173
+ except ValueError:
174
+ pass
175
+
176
+
177
+ def tuple_product(t: Tuple) -> Any:
178
+ """Calculate the product of the tuple elements."""
179
+ result = 1
180
+
181
+ for v in t:
182
+ result *= v
183
+
184
+ return result
185
+
186
+
187
+ _str_to_ctype = {
188
+ "uint8": ctypes.c_ubyte,
189
+ "uint16": ctypes.c_uint16,
190
+ "uint32": ctypes.c_uint32,
191
+ "uint64": ctypes.c_uint64,
192
+ "int8": ctypes.c_byte,
193
+ "int16": ctypes.c_int16,
194
+ "int32": ctypes.c_int32,
195
+ "int64": ctypes.c_int64,
196
+ "float32": ctypes.c_float,
197
+ "float64": ctypes.c_double
198
+ }
199
+
200
+
201
+ def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
202
+ """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
203
+ type_str = None
204
+
205
+ if isinstance(type_obj, str):
206
+ type_str = type_obj
207
+ elif hasattr(type_obj, "__name__"):
208
+ type_str = type_obj.__name__
209
+ elif hasattr(type_obj, "name"):
210
+ type_str = type_obj.name
211
+ else:
212
+ raise RuntimeError("Cannot infer type name from input")
213
+
214
+ assert type_str in _str_to_ctype.keys()
215
+
216
+ my_dtype = np.dtype(type_str)
217
+ my_ctype = _str_to_ctype[type_str]
218
+
219
+ assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
220
+
221
+ return my_dtype, my_ctype


def is_pickleable(obj: Any) -> bool:
    try:
        with io.BytesIO() as stream:
            pickle.dump(obj, stream)
        return True
    except:
        return False


# Functionality to import modules/objects by name, and call functions by name
# ------------------------------------------------------------------------------------------

def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
    """Searches for the module underlying the name of some python object.
    Returns the module and the object name (original name with the module part removed)."""

    # allow convenience shorthands, substitute them by full names
    obj_name = re.sub(r"^np\.", "numpy.", obj_name)
    obj_name = re.sub(r"^tf\.", "tensorflow.", obj_name)

    # list alternatives for (module_name, local_obj_name)
    parts = obj_name.split(".")
    name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]

    # try each alternative in turn
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name) # may raise ImportError
            get_obj_from_module(module, local_obj_name) # may raise AttributeError
            return module, local_obj_name
        except:
            pass

    # maybe some of the modules themselves contain errors?
    for module_name, _local_obj_name in name_pairs:
        try:
            importlib.import_module(module_name) # may raise ImportError
        except ImportError:
            if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
                raise

    # maybe the requested attribute is missing?
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name) # may raise ImportError
            get_obj_from_module(module, local_obj_name) # may raise AttributeError
        except ImportError:
            pass

    # we are out of luck, but we have no idea why
    raise ImportError(obj_name)


def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
    """Traverses the object name and returns the last (rightmost) python object."""
    if obj_name == '':
        return module
    obj = module
    for part in obj_name.split("."):
        obj = getattr(obj, part)
    return obj


def get_obj_by_name(name: str) -> Any:
    """Finds the python object with the given name."""
    module, obj_name = get_module_from_obj_name(name)
    return get_obj_from_module(module, obj_name)


def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
    """Finds the python object with the given name and calls it as a function."""
    assert func_name is not None
    func_obj = get_obj_by_name(func_name)
    assert callable(func_obj)
    return func_obj(*args, **kwargs)


def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
    """Finds the python class with the given name and constructs it with the given arguments."""
    return call_func_by_name(*args, func_name=class_name, **kwargs)
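
# Usage sketch (illustrative): both helpers resolve dotted names at run time, e.g.
#   construct_class_by_name(class_name='collections.OrderedDict')  # -> OrderedDict()
#   call_func_by_name(4.2, func_name='math.floor')                 # -> 4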


def get_module_dir_by_obj_name(obj_name: str) -> str:
    """Get the directory path of the module containing the given object name."""
    module, _ = get_module_from_obj_name(obj_name)
    return os.path.dirname(inspect.getfile(module))


def is_top_level_function(obj: Any) -> bool:
    """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
    return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__


def get_top_level_function_name(obj: Any) -> str:
    """Return the fully-qualified name of a top-level function."""
    assert is_top_level_function(obj)
    module = obj.__module__
    if module == '__main__':
        module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
    return module + "." + obj.__name__


# File system helpers
# ------------------------------------------------------------------------------------------

def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
    """List all files recursively in a given directory while ignoring given file and directory names.
    Returns list of tuples containing both absolute and relative paths."""
    assert os.path.isdir(dir_path)
    base_name = os.path.basename(os.path.normpath(dir_path))

    if ignores is None:
        ignores = []

    result = []

    for root, dirs, files in os.walk(dir_path, topdown=True):
        for ignore_ in ignores:
            dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]

            # dirs need to be edited in-place
            for d in dirs_to_remove:
                dirs.remove(d)

            files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]

        absolute_paths = [os.path.join(root, f) for f in files]
        relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]

        if add_base_to_relative:
            relative_paths = [os.path.join(base_name, p) for p in relative_paths]

        assert len(absolute_paths) == len(relative_paths)
        result += zip(absolute_paths, relative_paths)

    return result


def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
    """Takes in a list of tuples of (src, dst) paths and copies files.
    Will create all necessary directories."""
    for file in files:
        target_dir_name = os.path.dirname(file[1])

        # will create all intermediate-level directories
        if not os.path.exists(target_dir_name):
            os.makedirs(target_dir_name)

        shutil.copyfile(file[0], file[1])


# URL helpers
# ------------------------------------------------------------------------------------------

def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
    """Determine whether the given object is a valid URL string."""
    if not isinstance(obj, str) or "://" not in obj:
        return False
    if allow_file_urls and obj.startswith('file://'):
        return True
    try:
        res = requests.compat.urlparse(obj)
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
        res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
    except:
        return False
    return True


def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
    """Download the given URL and return a binary-mode file object to access the data."""
    assert num_attempts >= 1
    assert not (return_filename and (not cache))

    # Doesn't look like a URL scheme, so interpret it as a local filename.
    if not re.match('^[a-z]+://', url):
        return url if return_filename else open(url, "rb")

    # Handle file URLs. This code handles unusual file:// patterns that
    # arise on Windows:
    #
    # file:///c:/foo.txt
    #
    # which would translate to a local '/c:/foo.txt' filename that's
    # invalid. Drop the forward slash for such pathnames.
    #
    # If you touch this code path, you should test it on both Linux and
    # Windows.
    #
    # Some internet resources suggest using urllib.request.url2pathname(),
    # but that converts forward slashes to backslashes and this causes
    # its own set of problems.
    if url.startswith('file://'):
        filename = urllib.parse.urlparse(url).path
        if re.match(r'^/[a-zA-Z]:', filename):
            filename = filename[1:]
        return filename if return_filename else open(filename, "rb")

    assert is_url(url)

    # Lookup from cache.
    if cache_dir is None:
        cache_dir = make_cache_dir_path('downloads')

    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
    if cache:
        cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
        if len(cache_files) == 1:
            filename = cache_files[0]
            return filename if return_filename else open(filename, "rb")

    # Download.
    url_name = None
    url_data = None
    with requests.Session() as session:
        if verbose:
            print("Downloading %s ..." % url, end="", flush=True)
        for attempts_left in reversed(range(num_attempts)):
            try:
                with session.get(url) as res:
                    res.raise_for_status()
                    if len(res.content) == 0:
                        raise IOError("No data received")

                    if len(res.content) < 8192:
                        content_str = res.content.decode("utf-8")
                        if "download_warning" in res.headers.get("Set-Cookie", ""):
                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
                            if len(links) == 1:
                                url = requests.compat.urljoin(url, links[0])
                                raise IOError("Google Drive virus checker nag")
                        if "Google Drive - Quota exceeded" in content_str:
                            raise IOError("Google Drive download quota exceeded -- please try again later")

                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
                    url_name = match[1] if match else url
                    url_data = res.content
                    if verbose:
                        print(" done")
                    break
            except KeyboardInterrupt:
                raise
            except:
                if not attempts_left:
                    if verbose:
                        print(" failed")
                    raise
                if verbose:
                    print(".", end="", flush=True)

    # Save to cache.
    if cache:
        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
        temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
        os.makedirs(cache_dir, exist_ok=True)
        with open(temp_file, "wb") as f:
            f.write(url_data)
        os.replace(temp_file, cache_file) # atomic
        if return_filename:
            return cache_file

    # Return data as file object.
    assert not return_filename
    return io.BytesIO(url_data)
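
# Usage sketch (illustrative; example.com stands in for a real model URL):
#   with open_url('https://example.com/model.pkl') as f:
#       data = f.read()  # cached under the dnnlib cache dir after the first call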
stylegan3-fun/docs/avg_spectra_screen0.png ADDED
stylegan3-fun/docs/avg_spectra_screen0_half.png ADDED
stylegan3-fun/docs/configs.md ADDED
@@ -0,0 +1,201 @@
# Training configurations

This document provides guidelines for selecting appropriate training options for various scenarios, as well as an extensive list of recommended configurations.

#### Example

In the remainder of this document, we summarize each configuration as follows:

| <sub>Config</sub><br><br> | <sub>s/kimg</sub><br><sup>(V100)</sup> | <sub>s/kimg</sub><br><sup>(A100)</sup> | <sub>GPU</sub><br><sup>mem</sup> | <sub>Options</sub><br><br>
| :--------------------------- | :--------------: | :--------------: | :------------: | :--
| <sub>StyleGAN3&#8209;T</sub> | <sub>18.47</sub> | <sub>12.29</sub> | <sub>4.3</sub> | <sub>`--cfg=stylegan3-t --gpus=8 --batch=32 --gamma=8.2 --mirror=1`</sub>

This corresponds to the following command line:

```.bash
# Train StyleGAN3-T for AFHQv2 using 8 GPUs.
python train.py --outdir=~/training-runs --cfg=stylegan3-t --data=~/datasets/afhqv2-512x512.zip \
    --gpus=8 --batch=32 --gamma=8.2 --mirror=1
```

Explanation of the columns:
- **Config**: StyleGAN3-T (translation equiv.), StyleGAN3-R (translation and rotation equiv.), or StyleGAN2. Reflects the value of `--cfg`.
- **s/kimg**: Raw training speed, measured separately on Tesla V100 and A100 using our recommended Docker image. The number indicates how many seconds, on average, it takes to process 1000 images from the training set. The number tends to vary slightly over the course of training, typically by no more than &plusmn;20%.
- **GPU mem**: Maximum GPU memory usage observed during training, reported in gigabytes per GPU. The above example uses 8 GPUs, which means that the total GPU memory usage is around 34.4 GB.
- **Options**: Command line options for `train.py`, excluding `--outdir` and `--data`.

#### Total training time

In addition to the raw s/kimg number, the training time also depends on the `--kimg` and `--metrics` options. `--kimg` controls the total number of training iterations and is set to 25000 by default. This is long enough to reach convergence in typical cases, but in practice the results should already look quite reasonable around 5000 kimg. `--metrics` determines which quality metrics are computed periodically during training. The default is `fid50k_full`, which increases the training time slightly, typically by no more than 5%. The automatic computation can be disabled by specifying `--metrics=none`.

In the above example, the total training time on V100 is approximately 18.47 s/kimg * 25000 kimg * 1.05 &thickapprox; 485,000 seconds &thickapprox; 5 days and 14 hours. Disabling metric computation (`--metrics=none`) reduces this to approximately 5 days and 8 hours.
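
For a quick back-of-the-envelope check, the same estimate in a few lines of Python (values taken from the example above; the 1.05 factor is the ~5% metric overhead):

```.python
s_per_kimg = 18.47      # StyleGAN3-T at 512x512 on 8x V100 (from the table above)
total_kimg = 25000      # default --kimg
metric_overhead = 1.05  # default --metrics=fid50k_full adds roughly 5%

seconds = s_per_kimg * total_kimg * metric_overhead
print(f"{seconds:,.0f} s = {seconds / 86400:.1f} days")  # ~484,838 s = ~5.6 days
```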

## General guidelines

The most important hyperparameter that needs to be tuned on a per-dataset basis is the R<sub>1</sub> regularization weight, `--gamma`, that must be specified explicitly for `train.py`. As a rule of thumb, the value of `--gamma` scales quadratically with respect to the training set resolution: doubling the resolution (e.g., 256x256 &rarr; 512x512) means that `--gamma` should be multiplied by 4 (e.g., 2 &rarr; 8). The optimal value is usually the same for `--cfg=stylegan3-t` and `--cfg=stylegan3-r`, but considerably lower for `--cfg=stylegan2`.

In practice, we recommend selecting the value of `--gamma` as follows (see the sketch after this list):
- Find the closest match for your specific case in this document (config, resolution, and GPU count).
- Try training with the same `--gamma` first.
- Then, try increasing the value by 2x and 4x, and also decreasing it by 2x and 4x.
- Pick the value that yields the lowest FID.
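
A minimal sketch of the quadratic rule of thumb (the 512x512 reference point is taken from the tables below; treat the output only as a starting value for the search above):

```.python
def suggest_gamma(resolution: int, ref_resolution: int = 512, ref_gamma: float = 8.0) -> float:
    """The R1 weight scales with the square of the resolution ratio."""
    return ref_gamma * (resolution / ref_resolution) ** 2

print(suggest_gamma(256))   # 2.0  -- matches the 256x256 tables below
print(suggest_gamma(1024))  # 32.0 -- matches the 1024x1024 tables below
```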

The results may also be improved by adjusting `--mirror` and `--aug`, depending on the training data. Specifying `--mirror=1` augments the dataset with random *x*-flips, which effectively doubles the number of images. This is generally beneficial with datasets that are horizontally symmetric (e.g., FFHQ), but it can be harmful if the images contain noticeable asymmetric features (e.g., text or letters). Specifying `--aug=noaug` disables adaptive discriminator augmentation (ADA), which may improve the results slightly if the training set is large enough (at least 100k images when accounting for *x*-flips). With small datasets (less than 30k images), it is generally a good idea to leave the augmentations enabled.

It is possible to speed up the training by decreasing network capacity, i.e., `--cbase=16384`. This typically leads to lower quality results, but the difference is less pronounced with low-resolution datasets (e.g., 256x256).

#### Scaling to a different number of GPUs

You can select the number of GPUs by changing the value of `--gpus`; this does not affect the convergence curves or training dynamics in any way. By default, the total batch size (`--batch`) is divided evenly among the GPUs, which means that decreasing the number of GPUs yields higher per-GPU memory usage. To avoid running out of memory, you can decrease the per-GPU batch size by specifying `--batch-gpu`, which performs the same computation in multiple passes using gradient accumulation.
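
Conceptually, `--batch-gpu` is plain gradient accumulation. A minimal, self-contained PyTorch sketch of the idea (toy model for illustration; this is not the actual loop in `train.py`):

```.python
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

batch, batch_gpu = 32, 8          # i.e., --batch=32 --batch-gpu=8 on a single GPU
accum_steps = batch // batch_gpu  # 4 forward/backward passes per optimizer step

optimizer.zero_grad()
for _ in range(accum_steps):
    x = torch.randn(batch_gpu, 4)                  # stand-in for a real minibatch
    loss = model(x).square().mean() / accum_steps  # scale so gradients match one full batch
    loss.backward()                                # gradients accumulate across passes
optimizer.step()
```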

By default, `train.py` exports network snapshots once every 200 kimg, i.e., the product of `--snap=50` and `--tick=4`. When using few GPUs (e.g., 1&ndash;2), this means that it may take a very long time for the first snapshot to appear. We recommend increasing the snapshot frequency in such cases by specifying `--snap=20`, `--snap=10`, or `--snap=5`.

Note that the configurations listed in this document have been specifically tuned for 8 GPUs. The safest way to scale them to different GPU counts is to adjust `--gpus`, `--batch-gpu`, and `--snap` as described above, but it may be possible to reach faster convergence by adjusting some of the other hyperparameters as well. Note, however, that adjusting the total batch size (`--batch`) requires some experimentation; decreasing `--batch` usually necessitates increasing regularization (`--gamma`) and/or decreasing the learning rates (most importantly `--dlr`).

#### Transfer learning

Transfer learning makes it possible to reach very good results very quickly, especially when the training set is small and/or the images resemble the ones produced by a pre-trained model. To enable transfer learning, you can point `--resume` to one of the pre-trained models that we provide for [StyleGAN3](https://ngc.nvidia.com/catalog/models/nvidia:research:stylegan3) and [StyleGAN2](https://ngc.nvidia.com/catalog/models/nvidia:research:stylegan2). For example:

```.bash
# Fine-tune StyleGAN3-R for MetFaces-U using 8 GPUs, starting from the pre-trained FFHQ-U pickle.
python train.py --outdir=~/training-runs --cfg=stylegan3-r --data=~/datasets/metfacesu-1024x1024.zip \
    --gpus=8 --batch=32 --gamma=6.6 --mirror=1 --kimg=5000 --snap=5 \
    --resume=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhqu-1024x1024.pkl
```

The pre-trained model should be selected to match the specified config, resolution, and architecture-related hyperparameters (e.g., `--cbase`, `--map-depth`, and `--mbstd-group`). You can check this by looking at the `fakes_init.png` exported by `train.py` at the beginning; if the configuration is correct, the images should look reasonable.

With transfer learning, the results may be improved slightly by adjusting `--freezed`, in addition to the above guidelines for `--gamma`, `--mirror`, and `--aug`. In our experience, `--freezed=10` and `--freezed=13` tend to work reasonably well.

## Recommended configurations

This section lists recommended settings for StyleGAN3-T and StyleGAN3-R for different resolutions and GPU counts, selected according to the above guidelines. These are intended to provide a good starting point when experimenting with a new dataset. Please note that many of the options (e.g., `--gamma`, `--mirror`, and `--aug`) are still worth adjusting on a case-by-case basis.

#### 128x128 resolution

| <sub>Config</sub><br><br> | <sub>GPUs</sub><br><br> | <sub>s/kimg</sub><br><sup>(V100)</sup> | <sub>s/kimg</sub><br><sup>(A100)</sup> | <sub>GPU</sub><br><sup>mem</sup> | <sub>Options</sub><br><br>
| :--------------------------- | :----------: | :--------------: | :--------------: | :------------: | :--
| <sub>StyleGAN3&#8209;T</sub> | <sub>1</sub> | <sub>73.68</sub> | <sub>27.20</sub> | <sub>7.2</sub> | <sub>`--cfg=stylegan3-t --gpus=1 --batch=32 --gamma=0.5 --batch-gpu=16 --snap=10`</sub>
| <sub>StyleGAN3&#8209;T</sub> | <sub>2</sub> | <sub>37.30</sub> | <sub>13.74</sub> | <sub>7.1</sub> | <sub>`--cfg=stylegan3-t --gpus=2 --batch=32 --gamma=0.5 --snap=20`</sub>
| <sub>StyleGAN3&#8209;T</sub> | <sub>4</sub> | <sub>20.66</sub> | <sub>7.52</sub> | <sub>4.1</sub> | <sub>`--cfg=stylegan3-t --gpus=4 --batch=32 --gamma=0.5`</sub>
| <sub>StyleGAN3&#8209;T</sub> | <sub>8</sub> | <sub>11.31</sub> | <sub>4.40</sub> | <sub>2.6</sub> | <sub>`--cfg=stylegan3-t --gpus=8 --batch=32 --gamma=0.5`</sub>
| <sub>StyleGAN3&#8209;R</sub> | <sub>1</sub> | <sub>58.44</sub> | <sub>34.23</sub> | <sub>8.3</sub> | <sub>`--cfg=stylegan3-r --gpus=1 --batch=32 --gamma=0.5 --batch-gpu=16 --snap=10`</sub>
| <sub>StyleGAN3&#8209;R</sub> | <sub>2</sub> | <sub>29.92</sub> | <sub>17.29</sub> | <sub>8.2</sub> | <sub>`--cfg=stylegan3-r --gpus=2 --batch=32 --gamma=0.5 --snap=20`</sub>
| <sub>StyleGAN3&#8209;R</sub> | <sub>4</sub> | <sub>15.49</sub> | <sub>9.53</sub> | <sub>4.5</sub> | <sub>`--cfg=stylegan3-r --gpus=4 --batch=32 --gamma=0.5`</sub>
| <sub>StyleGAN3&#8209;R</sub> | <sub>8</sub> | <sub>8.43</sub> | <sub>5.69</sub> | <sub>2.7</sub> | <sub>`--cfg=stylegan3-r --gpus=8 --batch=32 --gamma=0.5`</sub>

#### 256x256 resolution

| <sub>Config</sub><br><br> | <sub>GPUs</sub><br><br> | <sub>s/kimg</sub><br><sup>(V100)</sup> | <sub>s/kimg</sub><br><sup>(A100)</sup> | <sub>GPU</sub><br><sup>mem</sup> | <sub>Options</sub><br><br>
| :--------------------------- | :----------: | :--------------: | :--------------: | :------------: | :--
| <sub>StyleGAN3&#8209;T</sub> | <sub>1</sub> | <sub>89.15</sub> | <sub>49.81</sub> | <sub>9.5</sub> | <sub>`--cfg=stylegan3-t --gpus=1 --batch=32 --gamma=2 --batch-gpu=16 --snap=10`</sub>
| <sub>StyleGAN3&#8209;T</sub> | <sub>2</sub> | <sub>45.45</sub> | <sub>25.05</sub> | <sub>9.3</sub> | <sub>`--cfg=stylegan3-t --gpus=2 --batch=32 --gamma=2 --snap=20`</sub>
| <sub>StyleGAN3&#8209;T</sub> | <sub>4</sub> | <sub>23.94</sub> | <sub>13.26</sub> | <sub>5.2</sub> | <sub>`--cfg=stylegan3-t --gpus=4 --batch=32 --gamma=2`</sub>
| <sub>StyleGAN3&#8209;T</sub> | <sub>8</sub> | <sub>13.04</sub> | <sub>7.32</sub> | <sub>3.1</sub> | <sub>`--cfg=stylegan3-t --gpus=8 --batch=32 --gamma=2`</sub>
| <sub>StyleGAN3&#8209;R</sub> | <sub>1</sub> | <sub>87.37</sub> | <sub>56.73</sub> | <sub>6.7</sub> | <sub>`--cfg=stylegan3-r --gpus=1 --batch=32 --gamma=2 --batch-gpu=8 --snap=10`</sub>
| <sub>StyleGAN3&#8209;R</sub> | <sub>2</sub> | <sub>44.12</sub> | <sub>28.60</sub> | <sub>6.7</sub> | <sub>`--cfg=stylegan3-r --gpus=2 --batch=32 --gamma=2 --batch-gpu=8 --snap=20`</sub>
| <sub>StyleGAN3&#8209;R</sub> | <sub>4</sub> | <sub>22.42</sub> | <sub>14.39</sub> | <sub>6.6</sub> | <sub>`--cfg=stylegan3-r --gpus=4 --batch=32 --gamma=2`</sub>
| <sub>StyleGAN3&#8209;R</sub> | <sub>8</sub> | <sub>11.88</sub> | <sub>8.03</sub> | <sub>3.7</sub> | <sub>`--cfg=stylegan3-r --gpus=8 --batch=32 --gamma=2`</sub>

#### 512x512 resolution

| <sub>Config</sub><br><br> | <sub>GPUs</sub><br><br> | <sub>s/kimg</sub><br><sup>(V100)</sup> | <sub>s/kimg</sub><br><sup>(A100)</sup> | <sub>GPU</sub><br><sup>mem</sup> | <sub>Options</sub><br><br>
| :--------------------------- | :----------: | :---------------: | :---------------: | :------------: | :--
| <sub>StyleGAN3&#8209;T</sub> | <sub>1</sub> | <sub>137.33</sub> | <sub>90.25</sub> | <sub>7.8</sub> | <sub>`--cfg=stylegan3-t --gpus=1 --batch=32 --gamma=8 --batch-gpu=8 --snap=10`</sub>
| <sub>StyleGAN3&#8209;T</sub> | <sub>2</sub> | <sub>69.65</sub> | <sub>45.42</sub> | <sub>7.7</sub> | <sub>`--cfg=stylegan3-t --gpus=2 --batch=32 --gamma=8 --batch-gpu=8 --snap=20`</sub>
| <sub>StyleGAN3&#8209;T</sub> | <sub>4</sub> | <sub>34.88</sub> | <sub>22.81</sub> | <sub>7.6</sub> | <sub>`--cfg=stylegan3-t --gpus=4 --batch=32 --gamma=8`</sub>
| <sub>StyleGAN3&#8209;T</sub> | <sub>8</sub> | <sub>18.47</sub> | <sub>12.29</sub> | <sub>4.3</sub> | <sub>`--cfg=stylegan3-t --gpus=8 --batch=32 --gamma=8`</sub>
| <sub>StyleGAN3&#8209;R</sub> | <sub>1</sub> | <sub>158.91</sub> | <sub>110.13</sub> | <sub>6.0</sub> | <sub>`--cfg=stylegan3-r --gpus=1 --batch=32 --gamma=8 --batch-gpu=4 --snap=10`</sub>
| <sub>StyleGAN3&#8209;R</sub> | <sub>2</sub> | <sub>79.96</sub> | <sub>55.18</sub> | <sub>6.0</sub> | <sub>`--cfg=stylegan3-r --gpus=2 --batch=32 --gamma=8 --batch-gpu=4 --snap=20`</sub>
| <sub>StyleGAN3&#8209;R</sub> | <sub>4</sub> | <sub>40.86</sub> | <sub>27.99</sub> | <sub>5.9</sub> | <sub>`--cfg=stylegan3-r --gpus=4 --batch=32 --gamma=8 --batch-gpu=4`</sub>
| <sub>StyleGAN3&#8209;R</sub> | <sub>8</sub> | <sub>20.44</sub> | <sub>14.04</sub> | <sub>5.9</sub> | <sub>`--cfg=stylegan3-r --gpus=8 --batch=32 --gamma=8`</sub>

#### 1024x1024 resolution

| <sub>Config</sub><br><br> | <sub>GPUs</sub><br><br> | <sub>s/kimg</sub><br><sup>(V100)</sup> | <sub>s/kimg</sub><br><sup>(A100)</sup> | <sub>GPU</sub><br><sup>mem</sup> | <sub>Options</sub><br><br>
| :--------------------------- | :----------: | :---------------: | :---------------: | :-------------: | :--
| <sub>StyleGAN3&#8209;T</sub> | <sub>1</sub> | <sub>221.85</sub> | <sub>156.91</sub> | <sub>7.0</sub> | <sub>`--cfg=stylegan3-t --gpus=1 --batch=32 --gamma=32 --batch-gpu=4 --snap=5`</sub>
| <sub>StyleGAN3&#8209;T</sub> | <sub>2</sub> | <sub>113.44</sub> | <sub>79.16</sub> | <sub>6.8</sub> | <sub>`--cfg=stylegan3-t --gpus=2 --batch=32 --gamma=32 --batch-gpu=4 --snap=10`</sub>
| <sub>StyleGAN3&#8209;T</sub> | <sub>4</sub> | <sub>57.04</sub> | <sub>39.62</sub> | <sub>6.7</sub> | <sub>`--cfg=stylegan3-t --gpus=4 --batch=32 --gamma=32 --batch-gpu=4 --snap=20`</sub>
| <sub>StyleGAN3&#8209;T</sub> | <sub>8</sub> | <sub>28.71</sub> | <sub>20.01</sub> | <sub>6.6</sub> | <sub>`--cfg=stylegan3-t --gpus=8 --batch=32 --gamma=32`</sub>
| <sub>StyleGAN3&#8209;R</sub> | <sub>1</sub> | <sub>263.44</sub> | <sub>184.81</sub> | <sub>10.2</sub> | <sub>`--cfg=stylegan3-r --gpus=1 --batch=32 --gamma=32 --batch-gpu=4 --snap=5`</sub>
| <sub>StyleGAN3&#8209;R</sub> | <sub>2</sub> | <sub>134.22</sub> | <sub>92.58</sub> | <sub>10.1</sub> | <sub>`--cfg=stylegan3-r --gpus=2 --batch=32 --gamma=32 --batch-gpu=4 --snap=10`</sub>
| <sub>StyleGAN3&#8209;R</sub> | <sub>4</sub> | <sub>67.33</sub> | <sub>46.53</sub> | <sub>10.0</sub> | <sub>`--cfg=stylegan3-r --gpus=4 --batch=32 --gamma=32 --batch-gpu=4 --snap=20`</sub>
| <sub>StyleGAN3&#8209;R</sub> | <sub>8</sub> | <sub>34.12</sub> | <sub>23.42</sub> | <sub>9.9</sub> | <sub>`--cfg=stylegan3-r --gpus=8 --batch=32 --gamma=32`</sub>

## Configurations used in StyleGAN3 paper

This section lists the exact settings that we used in the "Alias-Free Generative Adversarial Networks" paper.

#### FFHQ-U and FFHQ at 1024x1024 resolution

| <sub>Config</sub><br><br> | <sub>s/kimg</sub><br><sup>(V100)</sup> | <sub>s/kimg</sub><br><sup>(A100)</sup> | <sub>GPU</sub><br><sup>mem</sup> | <sub>Options</sub><br><br>
| :--------------------------- | :--------------: | :--------------: | :------------: | :--
| <sub>StyleGAN2</sub> | <sub>17.55</sub> | <sub>14.57</sub> | <sub>6.2</sub> | <sub>`--cfg=stylegan2 --gpus=8 --batch=32 --gamma=10 --mirror=1 --aug=noaug`</sub>
| <sub>StyleGAN3&#8209;T</sub> | <sub>28.71</sub> | <sub>20.01</sub> | <sub>6.6</sub> | <sub>`--cfg=stylegan3-t --gpus=8 --batch=32 --gamma=32.8 --mirror=1 --aug=noaug`</sub>
| <sub>StyleGAN3&#8209;R</sub> | <sub>34.12</sub> | <sub>23.42</sub> | <sub>9.9</sub> | <sub>`--cfg=stylegan3-r --gpus=8 --batch=32 --gamma=32.8 --mirror=1 --aug=noaug`</sub>

#### MetFaces-U at 1024x1024 resolution

| <sub>Config</sub><br><br> | <sub>s/kimg</sub><br><sup>(V100)</sup> | <sub>s/kimg</sub><br><sup>(A100)</sup> | <sub>GPU</sub><br><sup>mem</sup> | <sub>Options</sub><br><br>
| :--------------------------- | :--------------: | :--------------: | :-------------: | :--
| <sub>StyleGAN2</sub> | <sub>18.74</sub> | <sub>11.80</sub> | <sub>7.4</sub> | <sub>`--cfg=stylegan2 --gpus=8 --batch=32 --gamma=10 --mirror=1 --kimg=5000 --snap=10 --resume=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhqu-1024x1024.pkl`</sub>
| <sub>StyleGAN3&#8209;T</sub> | <sub>29.84</sub> | <sub>21.06</sub> | <sub>7.7</sub> | <sub>`--cfg=stylegan3-t --gpus=8 --batch=32 --gamma=16.4 --mirror=1 --kimg=5000 --snap=10 --resume=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhqu-1024x1024.pkl`</sub>
| <sub>StyleGAN3&#8209;R</sub> | <sub>35.10</sub> | <sub>24.32</sub> | <sub>10.9</sub> | <sub>`--cfg=stylegan3-r --gpus=8 --batch=32 --gamma=6.6 --mirror=1 --kimg=5000 --snap=10 --resume=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhqu-1024x1024.pkl`</sub>

#### MetFaces at 1024x1024 resolution

| <sub>Config</sub><br><br> | <sub>s/kimg</sub><br><sup>(V100)</sup> | <sub>s/kimg</sub><br><sup>(A100)</sup> | <sub>GPU</sub><br><sup>mem</sup> | <sub>Options</sub><br><br>
| :--------------------------- | :--------------: | :--------------: | :-------------: | :--
| <sub>StyleGAN2</sub> | <sub>18.74</sub> | <sub>11.80</sub> | <sub>7.4</sub> | <sub>`--cfg=stylegan2 --gpus=8 --batch=32 --gamma=5 --mirror=1 --kimg=5000 --snap=10 --resume=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-1024x1024.pkl`</sub>
| <sub>StyleGAN3&#8209;T</sub> | <sub>29.84</sub> | <sub>21.06</sub> | <sub>7.7</sub> | <sub>`--cfg=stylegan3-t --gpus=8 --batch=32 --gamma=6.6 --mirror=1 --kimg=5000 --snap=10 --resume=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhq-1024x1024.pkl`</sub>
| <sub>StyleGAN3&#8209;R</sub> | <sub>35.10</sub> | <sub>24.32</sub> | <sub>10.9</sub> | <sub>`--cfg=stylegan3-r --gpus=8 --batch=32 --gamma=3.3 --mirror=1 --kimg=5000 --snap=10 --resume=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhq-1024x1024.pkl`</sub>

#### AFHQv2 at 512x512 resolution

| <sub>Config</sub><br><br> | <sub>s/kimg</sub><br><sup>(V100)</sup> | <sub>s/kimg</sub><br><sup>(A100)</sup> | <sub>GPU</sub><br><sup>mem</sup> | <sub>Options</sub><br><br>
| :--------------------------- | :--------------: | :--------------: | :------------: | :--
| <sub>StyleGAN2</sub> | <sub>10.90</sub> | <sub>6.60</sub> | <sub>3.9</sub> | <sub>`--cfg=stylegan2 --gpus=8 --batch=32 --gamma=5 --mirror=1`</sub>
| <sub>StyleGAN3&#8209;T</sub> | <sub>18.47</sub> | <sub>12.29</sub> | <sub>4.3</sub> | <sub>`--cfg=stylegan3-t --gpus=8 --batch=32 --gamma=8.2 --mirror=1`</sub>
| <sub>StyleGAN3&#8209;R</sub> | <sub>20.44</sub> | <sub>14.04</sub> | <sub>5.9</sub> | <sub>`--cfg=stylegan3-r --gpus=8 --batch=32 --gamma=16.4 --mirror=1`</sub>

#### FFHQ-U ablations at 256x256 resolution

| <sub>Config</sub><br><br> | <sub>s/kimg</sub><br><sup>(V100)</sup> | <sub>s/kimg</sub><br><sup>(A100)</sup> | <sub>GPU</sub><br><sup>mem</sup> | <sub>Options</sub><br><br>
| :--------------------------- | :-------------: | :-------------: | :------------: | :--
| <sub>StyleGAN2</sub> | <sub>3.61</sub> | <sub>2.19</sub> | <sub>2.7</sub> | <sub>`--cfg=stylegan2 --gpus=8 --batch=64 --gamma=1 --mirror=1 --aug=noaug --cbase=16384 --glr=0.0025 --dlr=0.0025 --mbstd-group=8`</sub>
| <sub>StyleGAN3&#8209;T</sub> | <sub>7.40</sub> | <sub>3.74</sub> | <sub>3.5</sub> | <sub>`--cfg=stylegan3-t --gpus=8 --batch=64 --gamma=1 --mirror=1 --aug=noaug --cbase=16384 --dlr=0.0025`</sub>
| <sub>StyleGAN3&#8209;R</sub> | <sub>6.71</sub> | <sub>4.81</sub> | <sub>4.2</sub> | <sub>`--cfg=stylegan3-r --gpus=8 --batch=64 --gamma=1 --mirror=1 --aug=noaug --cbase=16384 --dlr=0.0025`</sub>

## Old StyleGAN2-ADA configurations

This section lists command lines that can be used to match the configurations provided by our previous [StyleGAN2-ADA](https://github.com/NVlabs/stylegan2-ada-pytorch) codebase. The first table corresponds to `--cfg=auto` (default) for different resolutions and GPU counts, while the second table lists the remaining alternatives.

#### Default configuration

| <sub>Res.</sub><br><br> | <sub>GPUs</sub><br><br> | <sub>s/kimg</sub><br><sup>(V100)</sup> | <sub>s/kimg</sub><br><sup>(A100)</sup> | <sub>GPU</sub><br><sup>mem</sup> | <sub>Options</sub><br><br>
| :---------------------- | :----------: | :---------------: | :--------------: | :------------: | :--
| <sub>128&sup2;</sub> | <sub>1</sub> | <sub>12.51</sub> | <sub>6.79</sub> | <sub>6.2</sub> | <sub>`--cfg=stylegan2 --gpus=1 --batch=32 --gamma=0.1024 --map-depth=2 --glr=0.0025 --dlr=0.0025 --cbase=16384`</sub>
| <sub>128&sup2;</sub> | <sub>2</sub> | <sub>6.43</sub> | <sub>3.45</sub> | <sub>6.2</sub> | <sub>`--cfg=stylegan2 --gpus=2 --batch=64 --gamma=0.0512 --map-depth=2 --glr=0.0025 --dlr=0.0025 --cbase=16384`</sub>
| <sub>128&sup2;</sub> | <sub>4</sub> | <sub>3.82</sub> | <sub>2.23</sub> | <sub>3.5</sub> | <sub>`--cfg=stylegan2 --gpus=4 --batch=64 --gamma=0.0512 --map-depth=2 --glr=0.0025 --dlr=0.0025 --cbase=16384`</sub>
| <sub>256&sup2;</sub> | <sub>1</sub> | <sub>20.84</sub> | <sub>12.53</sub> | <sub>4.5</sub> | <sub>`--cfg=stylegan2 --gpus=1 --batch=16 --gamma=0.8192 --map-depth=2 --glr=0.0025 --dlr=0.0025 --cbase=16384`</sub>
| <sub>256&sup2;</sub> | <sub>2</sub> | <sub>10.93</sub> | <sub>6.36</sub> | <sub>4.5</sub> | <sub>`--cfg=stylegan2 --gpus=2 --batch=32 --gamma=0.4096 --map-depth=2 --glr=0.0025 --dlr=0.0025 --cbase=16384`</sub>
| <sub>256&sup2;</sub> | <sub>4</sub> | <sub>5.39</sub> | <sub>3.20</sub> | <sub>4.5</sub> | <sub>`--cfg=stylegan2 --gpus=4 --batch=64 --gamma=0.2048 --map-depth=2 --glr=0.0025 --dlr=0.0025 --cbase=16384`</sub>
| <sub>256&sup2;</sub> | <sub>8</sub> | <sub>3.89</sub> | <sub>2.38</sub> | <sub>2.6</sub> | <sub>`--cfg=stylegan2 --gpus=8 --batch=64 --gamma=0.2048 --map-depth=2 --glr=0.0025 --dlr=0.0025 --cbase=16384`</sub>
| <sub>512&sup2;</sub> | <sub>1</sub> | <sub>71.59</sub> | <sub>41.06</sub> | <sub>6.8</sub> | <sub>`--cfg=stylegan2 --gpus=1 --batch=8 --gamma=6.5536 --map-depth=2 --glr=0.0025 --dlr=0.0025`</sub>
| <sub>512&sup2;</sub> | <sub>2</sub> | <sub>36.79</sub> | <sub>20.83</sub> | <sub>6.8</sub> | <sub>`--cfg=stylegan2 --gpus=2 --batch=16 --gamma=3.2768 --map-depth=2 --glr=0.0025 --dlr=0.0025`</sub>
| <sub>512&sup2;</sub> | <sub>4</sub> | <sub>18.12</sub> | <sub>10.45</sub> | <sub>6.7</sub> | <sub>`--cfg=stylegan2 --gpus=4 --batch=32 --gamma=1.6384 --map-depth=2 --glr=0.0025 --dlr=0.0025`</sub>
| <sub>512&sup2;</sub> | <sub>8</sub> | <sub>9.09</sub> | <sub>5.24</sub> | <sub>6.8</sub> | <sub>`--cfg=stylegan2 --gpus=8 --batch=64 --gamma=0.8192 --map-depth=2 --glr=0.0025 --dlr=0.0025`</sub>
| <sub>1024&sup2;</sub> | <sub>1</sub> | <sub>141.83</sub> | <sub>90.39</sub> | <sub>7.2</sub> | <sub>`--cfg=stylegan2 --gpus=1 --batch=4 --gamma=52.4288 --map-depth=2`</sub>
| <sub>1024&sup2;</sub> | <sub>2</sub> | <sub>73.13</sub> | <sub>46.04</sub> | <sub>7.2</sub> | <sub>`--cfg=stylegan2 --gpus=2 --batch=8 --gamma=26.2144 --map-depth=2`</sub>
| <sub>1024&sup2;</sub> | <sub>4</sub> | <sub>36.95</sub> | <sub>23.15</sub> | <sub>7.0</sub> | <sub>`--cfg=stylegan2 --gpus=4 --batch=16 --gamma=13.1072 --map-depth=2`</sub>
| <sub>1024&sup2;</sub> | <sub>8</sub> | <sub>18.47</sub> | <sub>11.66</sub> | <sub>7.3</sub> | <sub>`--cfg=stylegan2 --gpus=8 --batch=32 --gamma=6.5536 --map-depth=2`</sub>

#### Repro configurations

| <sub>Name</sub><br><br> | <sub>s/kimg</sub><br><sup>(V100)</sup> | <sub>s/kimg</sub><br><sup>(A100)</sup> | <sub>GPU</sub><br><sup>mem</sup> | <sub>Options</sub><br><br>
| :---------------------- | :--------------: | :--------------: | :------------: | :--
| <sub>`stylegan2`</sub> | <sub>17.55</sub> | <sub>14.57</sub> | <sub>6.2</sub> | <sub>`--cfg=stylegan2 --gpus=8 --batch=32 --gamma=10`</sub>
| <sub>`paper256`</sub> | <sub>4.01</sub> | <sub>2.47</sub> | <sub>2.7</sub> | <sub>`--cfg=stylegan2 --gpus=8 --batch=64 --gamma=1 --cbase=16384 --glr=0.0025 --dlr=0.0025 --mbstd-group=8`</sub>
| <sub>`paper512`</sub> | <sub>9.11</sub> | <sub>5.28</sub> | <sub>6.7</sub> | <sub>`--cfg=stylegan2 --gpus=8 --batch=64 --gamma=0.5 --glr=0.0025 --dlr=0.0025 --mbstd-group=8`</sub>
| <sub>`paper1024`</sub> | <sub>18.56</sub> | <sub>11.75</sub> | <sub>6.9</sub> | <sub>`--cfg=stylegan2 --gpus=8 --batch=32 --gamma=2`</sub>
stylegan3-fun/docs/dataset-tool-help.txt ADDED
@@ -0,0 +1,70 @@
Usage: dataset_tool.py [OPTIONS]

  Convert an image dataset into a dataset archive usable with StyleGAN2 ADA
  PyTorch.

  The input dataset format is guessed from the --source argument:

  --source *_lmdb/                    Load LSUN dataset
  --source cifar-10-python.tar.gz     Load CIFAR-10 dataset
  --source train-images-idx3-ubyte.gz Load MNIST dataset
  --source path/                      Recursively load all images from path/
  --source dataset.zip                Recursively load all images from dataset.zip

  Specifying the output format and path:

  --dest /path/to/dir                 Save output files under /path/to/dir
  --dest /path/to/dataset.zip         Save output files into /path/to/dataset.zip

  The output dataset format can be either an image folder or an uncompressed
  zip archive. Zip archives make it easier to move datasets around file
  servers and clusters, and may offer better training performance on network
  file systems.

  Images within the dataset archive will be stored as uncompressed PNG.
  Uncompressed PNGs can be efficiently decoded in the training loop.

  Class labels are stored in a file called 'dataset.json' that is stored at
  the dataset root folder. This file has the following structure:

  {
      "labels": [
          ["00000/img00000000.png",6],
          ["00000/img00000001.png",9],
          ... repeated for every image in the dataset
          ["00049/img00049999.png",1]
      ]
  }

  If the 'dataset.json' file cannot be found, the dataset is interpreted as
  not containing class labels.

  Image scale/crop and resolution requirements:

  Output images must be square-shaped and they must all have the same power-
  of-two dimensions.

  To scale arbitrary input image sizes to a specific width and height, use
  the --resolution option. Output resolution will be either the original
  input resolution (if resolution was not specified) or the one specified
  with the --resolution option.

  Use the --transform=center-crop or --transform=center-crop-wide options to
  apply a center crop transform on the input image. These options should be
  used with the --resolution option. For example:

  python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \
      --transform=center-crop-wide --resolution=512x384

Options:
  --source PATH                   Directory or archive name for input dataset
                                  [required]

  --dest PATH                     Output directory or archive name for output
                                  dataset  [required]

  --max-images INTEGER            Output only up to `max-images` images
  --transform [center-crop|center-crop-wide]
                                  Input crop/resize mode
  --resolution WxH                Output resolution (e.g., '512x512')
  --help                          Show this message and exit.
stylegan3-fun/docs/stylegan3-teaser-1920x1006.png ADDED

Git LFS Details

  • SHA256: 6d71fd93e8f9456e6a4d751649effb3096290f2a115d4d62b91de35458918ac2
  • Pointer size: 132 Bytes
  • Size of remote file: 1.82 MB
stylegan3-fun/docs/train-help.txt ADDED
@@ -0,0 +1,53 @@
Usage: train.py [OPTIONS]

  Train a GAN using the techniques described in the paper "Alias-Free
  Generative Adversarial Networks".

  Examples:

  # Train StyleGAN3-T for AFHQv2 using 8 GPUs.
  python train.py --outdir=~/training-runs --cfg=stylegan3-t --data=~/datasets/afhqv2-512x512.zip \
      --gpus=8 --batch=32 --gamma=8.2 --mirror=1

  # Fine-tune StyleGAN3-R for MetFaces-U using 8 GPUs, starting from the pre-trained FFHQ-U pickle.
  python train.py --outdir=~/training-runs --cfg=stylegan3-r --data=~/datasets/metfacesu-1024x1024.zip \
      --gpus=8 --batch=32 --gamma=6.6 --mirror=1 --kimg=5000 --snap=5 \
      --resume=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhqu-1024x1024.pkl

  # Train StyleGAN2 for FFHQ at 1024x1024 resolution using 8 GPUs.
  python train.py --outdir=~/training-runs --cfg=stylegan2 --data=~/datasets/ffhq-1024x1024.zip \
      --gpus=8 --batch=32 --gamma=10 --mirror=1 --aug=noaug

Options:
  --outdir DIR                    Where to save the results  [required]
  --cfg [stylegan3-t|stylegan3-r|stylegan2]
                                  Base configuration  [required]
  --data [ZIP|DIR]                Training data  [required]
  --gpus INT                      Number of GPUs to use  [required]
  --batch INT                     Total batch size  [required]
  --gamma FLOAT                   R1 regularization weight  [required]
  --cond BOOL                     Train conditional model  [default: False]
  --mirror BOOL                   Enable dataset x-flips  [default: False]
  --aug [noaug|ada|fixed]         Augmentation mode  [default: ada]
  --resume [PATH|URL]             Resume from given network pickle
  --freezed INT                   Freeze first layers of D  [default: 0]
  --p FLOAT                       Probability for --aug=fixed  [default: 0.2]
  --target FLOAT                  Target value for --aug=ada  [default: 0.6]
  --batch-gpu INT                 Limit batch size per GPU
  --cbase INT                     Capacity multiplier  [default: 32768]
  --cmax INT                      Max. feature maps  [default: 512]
  --glr FLOAT                     G learning rate  [default: varies]
  --dlr FLOAT                     D learning rate  [default: 0.002]
  --map-depth INT                 Mapping network depth  [default: varies]
  --mbstd-group INT               Minibatch std group size  [default: 4]
  --desc STR                      String to include in result dir name
  --metrics [NAME|A,B,C|none]     Quality metrics  [default: fid50k_full]
  --kimg KIMG                     Total training duration  [default: 25000]
  --tick KIMG                     How often to print progress  [default: 4]
  --snap TICKS                    How often to save snapshots  [default: 50]
  --seed INT                      Random seed  [default: 0]
  --fp32 BOOL                     Disable mixed-precision  [default: False]
  --nobench BOOL                  Disable cuDNN benchmarking  [default: False]
  --workers INT                   DataLoader worker processes  [default: 3]
  -n, --dry-run                   Print training options and exit
  --help                          Show this message and exit.
stylegan3-fun/docs/troubleshooting.md ADDED
@@ -0,0 +1,31 @@
# Troubleshooting

Our PyTorch code uses custom [CUDA extensions](https://pytorch.org/tutorials/advanced/cpp_extension.html) to speed up some of the network layers. Getting these to run can sometimes be a hassle.

This page aims to give guidance on how to diagnose and fix run-time problems related to these extensions.

## Before you start

1. Try Docker first! Ensure you can successfully run our models using the recommended Docker image. Follow the instructions in [README.md](/README.md) to get it running.
2. Can't use Docker? Read on.

## Installing dependencies

Make sure you've installed everything listed in the requirements section of the [README.md](/README.md). The key components w.r.t. the custom extensions are:

- **[CUDA toolkit 11.1](https://developer.nvidia.com/cuda-toolkit)** or later (this is not the same as `cudatoolkit` from Conda).
  - PyTorch invokes `nvcc` to compile our CUDA kernels.
- **ninja**
  - PyTorch uses [Ninja](https://ninja-build.org/) as its build system.
- **GCC** (Linux) or **Visual Studio** (Windows)
  - GCC 7.x or later is required. Earlier versions such as GCC 6.3 [are known not to work](https://github.com/NVlabs/stylegan3/issues/2).

#### Why is CUDA toolkit installation necessary?

The PyTorch package already contains the CUDA toolkit libraries needed to *run* PyTorch, so why is a separate CUDA toolkit installation required? Our models use custom CUDA kernels to implement operations such as efficient resampling of 2D images. PyTorch invokes the CUDA compiler at run-time to compile these kernels on first use. The tools and libraries required for this compilation are not bundled in PyTorch, which is why a host CUDA toolkit installation is required.
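
A quick way to inspect both sides (a sketch; `CUDA_HOME` points at the host toolkit whose `nvcc` will be used for the run-time compilation):

```.python
import torch
from torch.utils.cpp_extension import CUDA_HOME

print(torch.__version__, torch.version.cuda)  # PyTorch build and its bundled CUDA runtime version
print(CUDA_HOME)  # host CUDA toolkit used to compile the extensions (None if nvcc was not found)
```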

## Things to try

- Completely remove: `$HOME/.cache/torch_extensions` (Linux) or `C:\Users\<username>\AppData\Local\torch_extensions\torch_extensions\Cache` (Windows) and re-run StyleGAN3 python code.
- Run ninja in `$HOME/.cache/torch_extensions` to see that it builds.
- Inspect the `build.ninja` in the build directories under `$HOME/.cache/torch_extensions` and check that the CUDA tools and versions are consistent with what you intended to use.
stylegan3-fun/docs/visualizer_screen0.png ADDED

Git LFS Details

  • SHA256: fafce2c734f6eb2bed6467d3a1e83f9910e0bc4b374414c79da95c7c52d68ab9
  • Pointer size: 132 Bytes
  • Size of remote file: 1.26 MB
stylegan3-fun/docs/visualizer_screen0_half.png ADDED
stylegan3-fun/environment.yml ADDED
@@ -0,0 +1,35 @@
name: stylegan3
channels:
  - pytorch
  - nvidia
  - conda-forge  # PR #80 by @SetZero / #143 by @coldwaterq
dependencies:
  - python >= 3.8
  - pip
  - numpy>=1.20
  - click>=8.0
  - pillow=8.3.1
  - scipy=1.7.1
  - pytorch::pytorch=1.11  # We can now use PyTorch 1.11!
  - nvidia::cudatoolkit=11.3  # Necessary to remove previous PR by @edstoica
  - requests=2.26.0
  - tqdm=4.62.2
  - ninja=1.10.2
  - matplotlib=3.4.2
  - imageio=2.9.0
  - pip:
    - imgui==1.3.0
    - glfw==2.2.0
    - pyopengl==3.1.5
    - imageio-ffmpeg==0.4.3
    - pyspng
    - psutil  # PR #125 by @fastflair / #111 by @siddharthksah
    - tensorboard  # PR #125 by @fastflair
    - torchvision==0.12.0  # For "Discriminator Synthesis" / discriminator_synthesis.py
    - pyperlin  # For "Discriminator Synthesis" / discriminator_synthesis.py
    - scikit-learn  # For "Self-Distilled StyleGAN" / multimodal_truncation.py
    - moviepy==1.0.3
    - ffmpeg-python==0.2.0
    - scikit-video==1.1.11
    - setuptools==59.5.0  # PR #8 by @ZibbeZabbe
stylegan3-fun/gen_images.py ADDED
@@ -0,0 +1,145 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Generate images using pretrained network pickle."""

import os
import re
from typing import List, Optional, Tuple, Union

import click
import dnnlib
import numpy as np
import PIL.Image
import torch

import legacy

#----------------------------------------------------------------------------

def parse_range(s: Union[str, List]) -> List[int]:
    '''Parse a comma separated list of numbers or ranges and return a list of ints.

    Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
    '''
    if isinstance(s, list): return s
    ranges = []
    range_re = re.compile(r'^(\d+)-(\d+)$')
    for p in s.split(','):
        m = range_re.match(p)
        if m:
            ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
        else:
            ranges.append(int(p))
    return ranges

#----------------------------------------------------------------------------

def parse_vec2(s: Union[str, Tuple[float, float]]) -> Tuple[float, float]:
    '''Parse a floating point 2-vector of syntax 'a,b'.

    Example:
        '0,1' returns (0,1)
    '''
    if isinstance(s, tuple): return s
    parts = s.split(',')
    if len(parts) == 2:
        return (float(parts[0]), float(parts[1]))
    raise ValueError(f'cannot parse 2-vector {s}')

#----------------------------------------------------------------------------

def make_transform(translate: Tuple[float,float], angle: float):
    m = np.eye(3)
    s = np.sin(angle/360.0*np.pi*2)
    c = np.cos(angle/360.0*np.pi*2)
    m[0][0] = c
    m[0][1] = s
    m[0][2] = translate[0]
    m[1][0] = -s
    m[1][1] = c
    m[1][2] = translate[1]
    return m
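
# Usage sketch (illustrative): make_transform((0.3, 1.0), 45.0) returns a 3x3 affine
# matrix encoding a 45-degree rotation plus an (x, y) translation; generate_images()
# below inverts it and copies it into G.synthesis.input.transform.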


#----------------------------------------------------------------------------

@click.command()
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--seeds', type=parse_range, help='List of random seeds (e.g., \'0,1,4-6\')', required=True)
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
@click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
@click.option('--translate', help='Translate XY-coordinate (e.g. \'0.3,1\')', type=parse_vec2, default='0,0', show_default=True, metavar='VEC2')
@click.option('--rotate', help='Rotation angle in degrees', type=float, default=0, show_default=True, metavar='ANGLE')
@click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
def generate_images(
    network_pkl: str,
    seeds: List[int],
    truncation_psi: float,
    noise_mode: str,
    outdir: str,
    translate: Tuple[float,float],
    rotate: float,
    class_idx: Optional[int]
):
    """Generate images using pretrained network pickle.

    Examples:

    \b
    # Generate an image using pre-trained AFHQv2 model ("Ours" in Figure 1, left).
    python gen_images.py --outdir=out --trunc=1 --seeds=2 \\
        --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl

    \b
    # Generate uncurated images with truncation using the MetFaces-U dataset
    python gen_images.py --outdir=out --trunc=0.7 --seeds=600-605 \\
        --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-metfacesu-1024x1024.pkl
    """

    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore

    os.makedirs(outdir, exist_ok=True)

    # Labels.
    label = torch.zeros([1, G.c_dim], device=device)
    if G.c_dim != 0:
        if class_idx is None:
            raise click.ClickException('Must specify class label with --class when using a conditional network')
        label[:, class_idx] = 1
    else:
        if class_idx is not None:
            print('warn: --class=lbl ignored when running on an unconditional network')

    # Generate images.
    for seed_idx, seed in enumerate(seeds):
        print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
        z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)

        # Construct an inverse rotation/translation matrix and pass to the generator. The
        # generator expects this matrix as an inverse to avoid potentially failing numerical
        # operations in the network.
        if hasattr(G.synthesis, 'input'):
            m = make_transform(translate, rotate)
            m = np.linalg.inv(m)
            G.synthesis.input.transform.copy_(torch.from_numpy(m))

        img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png')


#----------------------------------------------------------------------------

if __name__ == "__main__":
    generate_images() # pylint: disable=no-value-for-parameter

#----------------------------------------------------------------------------
stylegan3-fun/gen_video.py ADDED
@@ -0,0 +1,281 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Generate lerp videos using pretrained network pickle."""

import copy
import os
import re
from typing import List, Optional, Tuple, Union

import click
import dnnlib
import imageio
import numpy as np
import scipy.interpolate
import torch
from tqdm import tqdm

import legacy
from torch_utils import gen_utils


# ----------------------------------------------------------------------------


def layout_grid(img, grid_w=None, grid_h=1, float_to_uint8=True, chw_to_hwc=True, to_numpy=True):
    batch_size, channels, img_h, img_w = img.shape
    if grid_w is None:
        grid_w = batch_size // grid_h
    assert batch_size == grid_w * grid_h
    if float_to_uint8:
        img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    img = img.reshape(grid_h, grid_w, channels, img_h, img_w)
    img = img.permute(2, 0, 3, 1, 4)
    img = img.reshape(channels, grid_h * img_h, grid_w * img_w)
    if chw_to_hwc:
        img = img.permute(1, 2, 0)
    if to_numpy:
        img = img.cpu().numpy()
    return img
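
# Shape sketch (illustrative): an input of shape [6, 3, H, W] with grid_w=3 and
# grid_h=2 comes back as a uint8 numpy array of shape [2*H, 3*W, 3], ready for imageio.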


# ----------------------------------------------------------------------------


def parse_vec2(s: Union[str, Tuple[float, float]]) -> Tuple[float, float]:
    """Parse a floating point 2-vector of syntax 'a,b'.

    Example:
        '0,1' returns (0,1)
    """
    if isinstance(s, tuple): return s
    parts = s.split(',')
    if len(parts) == 2:
        return (float(parts[0]), float(parts[1]))
    raise ValueError(f'cannot parse 2-vector {s}')


# ----------------------------------------------------------------------------


def make_transform(translate: Tuple[float,float], angle: float):
    m = np.eye(3)
    s = np.sin(angle/360.0*np.pi*2)
    c = np.cos(angle/360.0*np.pi*2)
    m[0][0] = c
    m[0][1] = s
    m[0][2] = translate[0]
    m[1][0] = -s
    m[1][1] = c
    m[1][2] = translate[1]
    return m


# ----------------------------------------------------------------------------

83
+ def gen_interp_video(G,
84
+ mp4: str,
85
+ seeds: List[int],
86
+ shuffle_seed: int = None,
87
+ w_frames: int = 60*4,
88
+ kind: str = 'cubic',
89
+ grid_dims: Tuple[int] = (1,1),
90
+ num_keyframes: int = None,
91
+ wraps: int = 2,
92
+ psi: float = 1.0,
93
+ device: torch.device = torch.device('cuda'),
94
+ stabilize_video: bool = True,
95
+ **video_kwargs):
96
+ grid_w = grid_dims[0]
97
+ grid_h = grid_dims[1]
98
+
99
+ if stabilize_video:
100
+ # Thanks to @RiversHaveWings and @nshepperd1
101
+ if hasattr(G.synthesis, 'input'):
102
+ shift = G.synthesis.input.affine(G.mapping.w_avg.unsqueeze(0))
103
+ G.synthesis.input.affine.bias.data.add_(shift.squeeze(0))
104
+ G.synthesis.input.affine.weight.data.zero_()
105
+
106
+ # Get the Generator's transform
107
+ m = G.synthesis.input.transform if hasattr(G.synthesis, 'input') else None
108
+
109
+ if num_keyframes is None:
110
+ if len(seeds) % (grid_w*grid_h) != 0:
111
+ raise ValueError('Number of input seeds must be divisible by grid W*H')
112
+ num_keyframes = len(seeds) // (grid_w*grid_h)
113
+
114
+ all_seeds = np.zeros(num_keyframes*grid_h*grid_w, dtype=np.int64)
115
+ for idx in range(num_keyframes*grid_h*grid_w):
116
+ all_seeds[idx] = seeds[idx % len(seeds)]
117
+
118
+ if shuffle_seed is not None:
119
+ rng = np.random.RandomState(seed=shuffle_seed)
120
+ rng.shuffle(all_seeds)
121
+
122
+ zs = torch.from_numpy(np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds])).to(device)
123
+ ws = G.mapping(z=zs, c=None, truncation_psi=psi)
124
+ _ = G.synthesis(ws[:1]) # warm up
125
+ ws = ws.reshape(grid_h, grid_w, num_keyframes, *ws.shape[1:])
126
+
127
+ # Interpolation.
128
+ grid = []
129
+ for yi in range(grid_h):
130
+ row = []
131
+ for xi in range(grid_w):
132
+ x = np.arange(-num_keyframes * wraps, num_keyframes * (wraps + 1))
133
+ y = np.tile(ws[yi][xi].cpu().numpy(), [wraps * 2 + 1, 1, 1])
134
+ interp = scipy.interpolate.interp1d(x, y, kind=kind, axis=0)
135
+ row.append(interp)
136
+ grid.append(row)
137
+
138
+ # Render video.
139
+ video_out = imageio.get_writer(mp4, mode='I', fps=60, codec='libx264', **video_kwargs)
140
+ for frame_idx in tqdm(range(num_keyframes * w_frames)):
141
+ imgs = []
142
+ # Construct an inverse affine matrix and pass to the generator. The generator expects
143
+ # this matrix as an inverse to avoid potentially failing numerical operations in the network.
144
+ if hasattr(G.synthesis, 'input'):
145
+ # Set default values for each affine transformation
146
+ total_rotation = 0.0 # If > 0.0, rotates the pixels counter-clockwise w.r.t. the center; in radians
147
+ total_translation_x = 0.0 # If > 0.0, translates all pixels to the right; if < 0.0, to the left
148
+ total_translation_y = 0.0 # If > 0.0, translates all pixels upwards; if < 0.0, downwards
149
+ total_scale_x = 1.0 # If < 1.0, zooms in; if > 1.0, zooms out (x-axis)
150
+ total_scale_y = 1.0 # If < 1.0, zooms in; if > 1.0, zooms out (y-axis)
151
+ total_shear_x = 0.0 # If > 0.0, shears pixels to the right, keeping y fixed; if < 0.0, to the left
152
+ total_shear_y = 0.0 # If > 0.0, shears pixels upwards, keeping x fixed; if < 0.0, downwards
153
+ mirror_x = False # Mirror along the x-axis; if True, will flip the image horizontally (can't be a function of frame_idx)
154
+ mirror_y = False # Mirror along the y-axis; if True, will flip the image vertically (can't be a function of frame_idx)
155
+
156
+ # Go nuts with these. They can be constants as above to fix centering/rotation in your video,
157
+ # or you can make them functions of frame_idx to animate them, such as (uncomment as many as you want to try):
158
+ # total_scale_x = 1 + np.sin(np.pi*frame_idx/(num_keyframes * w_frames))/2 # rises from 1.0 to 1.5 and back (sin sweeps [0, pi])
159
+ # total_rotation = 4*np.pi*frame_idx/(num_keyframes * w_frames) # 4*pi in total, i.e. 2 full rotations over the video
160
+ # total_shear_y = 2*np.sin(2*np.pi*frame_idx/(num_keyframes * w_frames)) # will oscillate between -2 and 2
161
+
162
+ # We then use these values to construct the affine matrix
163
+ m = gen_utils.make_affine_transform(m, angle=total_rotation, translate_x=total_translation_x,
164
+ translate_y=total_translation_y, scale_x=total_scale_x,
165
+ scale_y=total_scale_y, shear_x=total_shear_x, shear_y=total_shear_y,
166
+ mirror_x=mirror_x, mirror_y=mirror_y)
167
+ m = np.linalg.inv(m)
168
+ # Finally, we pass the matrix to the generator
169
+ G.synthesis.input.transform.copy_(torch.from_numpy(m))
170
+
171
+ # The rest stays the same, for all you gen_video.py lovers out there
172
+ for yi in range(grid_h):
173
+ for xi in range(grid_w):
174
+ interp = grid[yi][xi]
175
+ w = torch.from_numpy(interp(frame_idx / w_frames)).to(device)
176
+ img = G.synthesis(ws=w.unsqueeze(0), noise_mode='const')[0]
177
+ imgs.append(img)
178
+ video_out.append_data(layout_grid(torch.stack(imgs), grid_w=grid_w, grid_h=grid_h))
179
+ video_out.close()
180
+
181
+
182
+ # ----------------------------------------------------------------------------
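A minimal programmatic sketch of gen_interp_video (it assumes a CUDA device and mirrors the call made in the click command below; the network URL is the one from the docstring example):

import torch
import dnnlib
import legacy

network_pkl = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl'
with dnnlib.util.open_url(network_pkl) as f:
    G = legacy.load_network_pkl(f)['G_ema'].to(torch.device('cuda'))
# 8 seeds on a 4x2 grid -> num_keyframes = 8 / (4*2) = 1
gen_interp_video(G=G, mp4='lerp.mp4', seeds=list(range(8)), grid_dims=(4, 2), bitrate='12M')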
183
+
184
+
185
+ def parse_range(s: Union[str, List[int]]) -> List[int]:
186
+ """Parse a comma separated list of numbers or ranges and return a list of ints.
187
+
188
+ Example: '1,2,5-10' returns [1, 2, 5, 6, 7]
189
+ """
190
+ if isinstance(s, list): return s
191
+ ranges = []
192
+ range_re = re.compile(r'^(\d+)-(\d+)$')
193
+ for p in s.split(','):
194
+ m = range_re.match(p)
195
+ if m:
196
+ ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
197
+ else:
198
+ ranges.append(int(p))
199
+ return ranges
200
+
201
+
202
+ # ----------------------------------------------------------------------------
203
+
204
+
205
+ def parse_tuple(s: Union[str, Tuple[int,int]]) -> Tuple[int, int]:
206
+ """Parse a 'M,N' or 'MxN' integer tuple.
207
+
208
+ Example:
209
+ '4x2' returns (4,2)
210
+ '0,1' returns (0,1)
211
+ """
212
+ if isinstance(s, tuple): return s
213
+ m = re.match(r'^(\d+)[x,](\d+)$', s)
214
+ if m:
215
+ return (int(m.group(1)), int(m.group(2)))
216
+ raise ValueError(f'cannot parse tuple {s}')
217
+
218
+
219
+ # ----------------------------------------------------------------------------
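The three parsers above accept both raw user strings and already-parsed values, so they are safe to call programmatically as well as from click (a small illustration):

assert parse_vec2('0,1') == (0.0, 1.0)
assert parse_range('1,2,5-10') == [1, 2, 5, 6, 7, 8, 9, 10]
assert parse_tuple('4x2') == (4, 2)
assert parse_tuple((4, 2)) == (4, 2)  # non-string input passes through unchanged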
220
+
221
+
222
+ @click.command()
223
+ @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
224
+ @click.option('--seeds', type=parse_range, help='List of random seeds', required=True)
225
+ @click.option('--shuffle-seed', type=int, help='Random seed to use for shuffling seed order', default=None)
226
+ @click.option('--grid', type=parse_tuple, help='Grid width/height, e.g. \'4x3\' (default: 1x1)', default=(1,1))
227
+ @click.option('--num-keyframes', type=int, help='Number of seeds to interpolate through. If not specified, determine based on the length of the seeds array given by --seeds.', default=None)
228
+ @click.option('--w-frames', type=int, help='Number of frames to interpolate between latents', default=120)
229
+ @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
230
+ @click.option('--stabilize-video', is_flag=True, help='Stabilize the video by anchoring the mapping to w_avg')
231
+ @click.option('--output', help='Output .mp4 filename', type=str, required=True, metavar='FILE')
232
+ def generate_images(
233
+ network_pkl: str,
234
+ seeds: List[int],
235
+ shuffle_seed: Optional[int],
236
+ truncation_psi: float,
237
+ grid: Tuple[int,int],
238
+ num_keyframes: Optional[int],
239
+ stabilize_video: bool,
240
+ w_frames: int,
241
+ output: str
242
+ ):
243
+ """Render a latent vector interpolation video.
244
+
245
+ Examples:
246
+
247
+ \b
248
+ # Render a 4x2 grid of interpolations for seeds 0 through 31.
249
+ python gen_video.py --output=lerp.mp4 --trunc=1 --seeds=0-31 --grid=4x2 \\
250
+ --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl
251
+
252
+ Animation length and seed keyframes:
253
+
254
+ The animation length is either derived from the number of seeds given with
255
+ --seeds or set explicitly with the --num-keyframes option.
256
+
257
+ When the number of keyframes is set with --num-keyframes, the output video length
258
+ will be 'num_keyframes*w_frames' frames.
259
+
260
+ If --num-keyframes is not specified, the number of seeds given with
261
+ --seeds must be divisible by grid size W*H (--grid). In this case the
262
+ output video length will be '# seeds/(w*h)*w_frames' frames.
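+
+ For example, --seeds=0-31 with --grid=4x2 gives 32/(4*2) = 4 keyframes, so
+ the output video is 4*w_frames = 480 frames long (8 seconds at 60 FPS).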
263
+ """
264
+
265
+ print('Loading networks from "%s"...' % network_pkl)
266
+ device = torch.device('cuda')
267
+ with dnnlib.util.open_url(network_pkl) as f:
268
+ G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
269
+
270
+ gen_interp_video(G=G, mp4=output, bitrate='12M', grid_dims=grid, num_keyframes=num_keyframes, w_frames=w_frames,
271
+ seeds=seeds, shuffle_seed=shuffle_seed, psi=truncation_psi, stabilize_video=stabilize_video)
272
+
273
+
274
+ # ----------------------------------------------------------------------------
275
+
276
+
277
+ if __name__ == "__main__":
278
+ generate_images() # pylint: disable=no-value-for-parameter
279
+
280
+
281
+ # ----------------------------------------------------------------------------
stylegan3-fun/generate.py ADDED
@@ -0,0 +1,838 @@
1
+ import os
2
+ import sys
3
+ from typing import List, Optional, Union, Tuple
4
+ import click
5
+
6
+ import dnnlib
7
+ from torch_utils import gen_utils
8
+ import copy
9
+
10
+ import scipy
11
+ import numpy as np
12
+ import PIL.Image
13
+ import torch
14
+
15
+ import legacy
16
+ from viz.renderer import Renderer
17
+
18
+ os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = 'hide'
19
+ import moviepy.editor
20
+
21
+
22
+ # ----------------------------------------------------------------------------
23
+
24
+
25
+ # We group the different types of generation (images, grid, video, wacky stuff) into a main function
26
+ @click.group()
27
+ def main():
28
+ pass
29
+
30
+
31
+ # ----------------------------------------------------------------------------
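A note on dispatch: `main` is a click group, so each generator below is registered as a subcommand and invoked by name, e.g.:

python generate.py images --help
python generate.py random-video --help
python generate.py circular-video --help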
32
+
33
+
34
+ @main.command(name='images')
35
+ @click.pass_context
36
+ @click.option('--network', 'network_pkl', help='Network pickle filename: can be URL, local file, or the name of the model in torch_utils.gen_utils.resume_specs', required=True)
37
+ @click.option('--device', help='Device to use for image generation; using the CPU is slower than the GPU', type=click.Choice(['cpu', 'cuda']), default='cuda', show_default=True)
38
+ @click.option('--cfg', type=click.Choice(gen_utils.available_cfgs), help='Config of the network, used only if you want to use the pretrained models in torch_utils.gen_utils.resume_specs')
39
+ # Synthesis options (feed a list of seeds or give the projected w to synthesize)
40
+ @click.option('--seeds', type=gen_utils.num_range, help='List of random seeds')
41
+ @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
42
+ @click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)', default=None, show_default=True)
43
+ @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
44
+ @click.option('--anchor-latent-space', '-anchor', is_flag=True, help='Anchor the latent space to w_avg to stabilize the video')
45
+ @click.option('--projected-w', help='Projection result file; can be either .npy or .npz files', type=click.Path(exists=True, dir_okay=False), metavar='FILE')
46
+ @click.option('--new-center', type=gen_utils.parse_new_center, help='New center for the W latent space; a seed (int) or a path to a projected dlatent (.npy/.npz)', default=None)
47
+ # Save the output of the intermediate layers
48
+ @click.option('--layer', 'layer_name', type=str, help='Layer name to extract; if unsure, use `--available-layers`', default=None, show_default=True)
49
+ @click.option('--available-layers', is_flag=True, help='List the available layers in the trained model and exit')
50
+ @click.option('--starting-channel', 'starting_channel', type=int, help='Starting channel for the layer extraction', default=0, show_default=True)
51
+ @click.option('--grayscale', 'save_grayscale', type=bool, help='Use the first channel starting from `--starting-channel` to generate a grayscale image.', default=False, show_default=True)
52
+ @click.option('--rgb', 'save_rgb', type=bool, help='Use 3 consecutive channels (if they exist) to generate a RGB image, starting from `--starting-channel`.', default=False, show_default=True)
53
+ @click.option('--rgba', 'save_rgba', type=bool, help='Use 4 consecutive channels (if they exist) to generate a RGBA image, starting from `--starting-channel`.', default=False, show_default=True)
54
+ @click.option('--img-scale-db', 'img_scale_db', type=click.FloatRange(min=-40, max=40), help='Scale the image pixel values, akin to "exposure" (lower values give a grayer image; higher values give more white/burnt regions)', default=0, show_default=True)
55
+ @click.option('--img-normalize', 'img_normalize', type=bool, help='Normalize images of the selected layer and channel', default=False, show_default=True)
56
+ # Grid options
57
+ @click.option('--save-grid', is_flag=True, help='Use flag to save image grid')
58
+ @click.option('--grid-width', '-gw', type=click.IntRange(min=1), help='Grid width (number of columns)', default=None)
59
+ @click.option('--grid-height', '-gh', type=click.IntRange(min=1), help='Grid height (number of rows)', default=None)
60
+ # Extra parameters for saving the results
61
+ @click.option('--save-dlatents', is_flag=True, help='Use flag to save individual dlatents (W) for each individual resulting image')
62
+ @click.option('--outdir', type=click.Path(file_okay=False), help='Directory path to save the results', default=os.path.join(os.getcwd(), 'out', 'images'), show_default=True, metavar='DIR')
63
+ @click.option('--description', '-desc', type=str, help='Description name for the directory path to save results', default='generate-images', show_default=True)
64
+ def generate_images(
65
+ ctx: click.Context,
66
+ network_pkl: str,
67
+ device: Optional[str],
68
+ cfg: Optional[str],
69
+ seeds: Optional[List[int]],
70
+ truncation_psi: Optional[float],
71
+ class_idx: Optional[int],
72
+ noise_mode: Optional[str],
73
+ anchor_latent_space: Optional[bool],
74
+ projected_w: Optional[Union[str, os.PathLike]],
75
+ new_center: Tuple[str, Union[int, np.ndarray]],
76
+ layer_name: Optional[str],
77
+ available_layers: Optional[bool],
78
+ starting_channel: Optional[int],
79
+ save_grayscale: Optional[bool],
80
+ save_rgb: Optional[bool],
81
+ save_rgba: Optional[bool],
82
+ img_scale_db: Optional[float],
83
+ img_normalize: Optional[bool],
84
+ save_grid: Optional[bool],
85
+ grid_width: int,
86
+ grid_height: int,
87
+ save_dlatents: Optional[bool],
88
+ outdir: Union[str, os.PathLike],
89
+ description: str,
90
+ ):
91
+ """Generate images using pretrained network pickle.
92
+
93
+ Examples:
94
+
95
+ \b
96
+ # Generate curated MetFaces images without truncation (Fig.10 left)
97
+ python generate.py images --trunc=1 --seeds=85,265,297,849 \\
98
+ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
99
+
100
+ \b
101
+ # Generate uncurated MetFaces images with truncation (Fig.12 upper left)
102
+ python generate.py images --trunc=0.7 --seeds=600-605 \\
103
+ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
104
+
105
+ \b
106
+ # Generate class conditional CIFAR-10 images (Fig.17 left, Car)
107
+ python generate.py images --seeds=0-35 --class=1 \\
108
+ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/cifar10.pkl
109
+
110
+ \b
111
+ # Render an image from projected W
112
+ python generate.py images --projected_w=projected_w.npz \\
113
+ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
114
+
115
+ \b
116
+ Generate class conditional StyleGAN2 WikiArt images, save each individually, and save all of them as a grid
117
+ python generate.py images --cfg=stylegan2 --network=wikiart1024-C --class=155 \\
118
+ --trunc=0.7 --seeds=10-50 --save-grid
119
+ """
120
+ # Sanity check: seeds are required unless a projected W is given
121
+ if projected_w is None and (seeds is None or len(seeds) < 1):
122
+ ctx.fail('Use `--seeds` to specify at least one seed.')
123
+
124
+ device = torch.device('cuda') if torch.cuda.is_available() and device == 'cuda' else torch.device('cpu')
125
+
126
+ # Load the network
127
+ G = gen_utils.load_network('G_ema', network_pkl, cfg, device)
128
+
129
+ if available_layers:
130
+ click.secho(f'Printing available layers (name, channels and size) for "{network_pkl}"...', fg='blue')
131
+ _ = Renderer().render(G=G, available_layers=available_layers)
132
+ sys.exit(0) # listing layers is an informational exit, not an error
133
+
134
+ # Setup for using CPU
135
+ if device.type == 'cpu':
136
+ gen_utils.use_cpu(G)
137
+
138
+ # Stabilize/anchor the latent space
139
+ if anchor_latent_space:
140
+ gen_utils.anchor_latent_space(G)
141
+
142
+ description = 'generate-images' if len(description) == 0 else description
143
+ description = f'{description}-{layer_name}_layer' if layer_name is not None else description
144
+ # Create the run dir with the given name description
145
+ run_dir = gen_utils.make_run_dir(outdir, description)
146
+
147
+ # Synthesize the result of a W projection.
148
+ if projected_w is not None:
149
+ if seeds is not None:
150
+ print('warn: --seeds is ignored when using --projected-w')
151
+ print(f'Generating images from projected W "{projected_w}"')
152
+ ws, ext = gen_utils.get_latent_from_file(projected_w, return_ext=True)
153
+ ws = torch.tensor(ws, device=device)
154
+ assert ws.shape[1:] == (G.num_ws, G.w_dim)
155
+ n_digits = int(np.log10(len(ws))) + 1 # number of digits for naming the images
156
+ if ext == '.npy':
157
+ img = gen_utils.w_to_img(G, ws, noise_mode)[0]
158
+ PIL.Image.fromarray(img, gen_utils.channels_dict[G.synthesis.img_channels]).save(f'{run_dir}/proj.png')
159
+ else:
160
+ for idx, w in enumerate(ws):
161
+ img = gen_utils.w_to_img(G, w, noise_mode)[0]
162
+ PIL.Image.fromarray(img,
163
+ gen_utils.channels_dict[G.synthesis.img_channels]).save(f'{run_dir}/proj{idx:0{n_digits}d}.png')
164
+ return
165
+
166
+ # Labels.
167
+ class_idx = gen_utils.parse_class(G, class_idx, ctx)
168
+ label = torch.zeros([1, G.c_dim], device=device)
169
+ if G.c_dim != 0:
170
+ label[:, class_idx] = 1
171
+ else:
172
+ if class_idx is not None:
173
+ print('warn: --class=lbl ignored when running on an unconditional network')
174
+
175
+ if seeds is None:
176
+ ctx.fail('--seeds option is required when not using --projected-w')
177
+
178
+ # Recenter the latent space, if specified
179
+ if new_center is None:
180
+ w_avg = G.mapping.w_avg
181
+ else:
182
+ new_center, new_center_value = new_center
183
+ # We get the new center using the int (a seed) or recovered dlatent (an np.ndarray)
184
+ if isinstance(new_center_value, int):
185
+ w_avg = gen_utils.get_w_from_seed(G, device, new_center_value,
186
+ truncation_psi=1.0) # We want the pure dlatent
187
+ elif isinstance(new_center_value, np.ndarray):
188
+ w_avg = torch.from_numpy(new_center_value).to(device)
189
+ else:
190
+ ctx.fail('Error: New center has strange format! Only an int (seed) or a file (.npy/.npz) are accepted!')
191
+
192
+ # Generate images.
193
+ images = []
194
+ for seed_idx, seed in enumerate(seeds):
195
+ print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
196
+ dlatent = gen_utils.get_w_from_seed(G, device, seed, truncation_psi=1.0)
197
+ # Do truncation trick with center (new or global)
198
+ w = w_avg + (dlatent - w_avg) * truncation_psi
199
+
200
+ # TODO: this is starting to look like an auxiliary function!
201
+ # Save the intermediate layer output.
202
+ if layer_name is not None:
203
+ # Sanity check (meh, could be done better)
204
+ submodule_names = {name: mod for name, mod in G.synthesis.named_modules()}
205
+ assert layer_name in submodule_names, f'Layer "{layer_name}" not found in the network! Available layers: {", ".join(submodule_names)}'
206
+ assert any((save_grayscale, save_rgb, save_rgba)), 'You must choose at least one of the three output formats! (L, RGB, RGBA)'
207
+
208
+ sel_channels = 3 if save_rgb else (1 if save_grayscale else 4)
209
+ res = Renderer().render(G=G, layer_name=layer_name, dlatent=w, sel_channels=sel_channels,
210
+ base_channel=starting_channel, img_scale_db=img_scale_db, img_normalize=img_normalize)
211
+ img = res.image
212
+ else:
213
+ img = gen_utils.w_to_img(G, w, noise_mode)[0]
214
+
215
+ if save_grid:
216
+ images.append(img)
217
+
218
+ # Get the image format, whether user-specified or the one from the model
219
+ try:
220
+ img_format = gen_utils.channels_dict[sel_channels]
221
+ except NameError:
222
+ img_format = gen_utils.channels_dict[G.synthesis.img_channels]
223
+
224
+ # Save image, avoiding grayscale errors in PIL
225
+ PIL.Image.fromarray(img[:, :, 0] if img.shape[-1] == 1 else img,
226
+ img_format).save(os.path.join(run_dir, f'seed{seed}.png'))
227
+ if save_dlatents:
228
+ np.save(os.path.join(run_dir, f'seed{seed}.npy'), w.unsqueeze(0).cpu().numpy())
229
+
230
+ if save_grid:
231
+ print('Saving image grid...')
232
+ images = np.array(images)
233
+
234
+ # We let the function infer the shape of the grid
235
+ if (grid_width, grid_height) == (None, None):
236
+ grid = gen_utils.create_image_grid(images)
237
+ # The user tells the specific shape of the grid, but one value may be None
238
+ else:
239
+ grid = gen_utils.create_image_grid(images, (grid_width, grid_height))
240
+
241
+ grid = grid[:, :, 0] if grid.shape[-1] == 1 else grid
242
+ PIL.Image.fromarray(grid, img_format).save(os.path.join(run_dir, 'grid.png'))
243
+
244
+ # Save the configuration used
245
+ ctx.obj = {
246
+ 'network_pkl': network_pkl,
247
+ 'device': device.type,
248
+ 'config': cfg,
249
+ 'synthesis': {
250
+ 'seeds': seeds,
251
+ 'truncation_psi': truncation_psi,
252
+ 'class_idx': class_idx,
253
+ 'noise_mode': noise_mode,
254
+ 'anchor_latent_space': anchor_latent_space,
255
+ 'projected_w': projected_w,
256
+ 'new_center': new_center
257
+ },
258
+ 'intermediate_representations': {
259
+ 'layer': layer_name,
260
+ 'starting_channel': starting_channel,
261
+ 'grayscale': save_grayscale,
262
+ 'rgb': save_rgb,
263
+ 'rgba': save_rgba,
264
+ 'img_scale_db': img_scale_db,
265
+ 'img_normalize': img_normalize
266
+ },
267
+ 'grid_options': {
268
+ 'save_grid': save_grid,
269
+ 'grid_width': grid_width,
270
+ 'grid_height': grid_height,
271
+ },
272
+ 'extra_parameters': {
273
+ 'save_dlatents': save_dlatents,
274
+ 'run_dir': run_dir,
275
+ 'description': description,
276
+ }
277
+ }
278
+ gen_utils.save_config(ctx=ctx, run_dir=run_dir)
279
+
280
+
281
+ # ----------------------------------------------------------------------------
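The truncation trick used above, in isolation (a sketch with illustrative shapes; w_avg can be the mapping network's running average or any user-supplied dlatent):

import torch

num_ws, w_dim = 16, 512                   # illustrative dimensions
w_avg = torch.zeros(1, num_ws, w_dim)     # truncation center
dlatent = torch.randn(1, num_ws, w_dim)   # w for one seed at psi = 1.0
truncation_psi = 0.7
w = w_avg + (dlatent - w_avg) * truncation_psi  # psi=0 -> center, psi=1 -> untruncated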
282
+
283
+
284
+ @main.command(name='random-video')
285
+ @click.pass_context
286
+ @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
287
+ @click.option('--cfg', type=click.Choice(gen_utils.available_cfgs), help='Config of the network, used only if you want to use the pretrained models in torch_utils.gen_utils.resume_specs')
288
+ # Synthesis options
289
+ @click.option('--seeds', type=gen_utils.num_range, help='List of random seeds', required=True)
290
+ @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
291
+ @click.option('--new-center', type=gen_utils.parse_new_center, help='New center for the W latent space; a seed (int) or a path to a projected dlatent (.npy/.npz)', default=None)
292
+ @click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
293
+ @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
294
+ @click.option('--anchor-latent-space', '-anchor', is_flag=True, help='Anchor the latent space to w_avg to stabilize the video')
295
+ # Save the output of the intermediate layers
296
+ @click.option('--layer', 'layer_name', type=str, help='Layer name to extract; if unsure, use `--available-layers`', default=None, show_default=True)
297
+ @click.option('--available-layers', is_flag=True, help='List the available layers in the trained model and exit')
298
+ @click.option('--starting-channel', 'starting_channel', type=int, help='Starting channel for the layer extraction', default=0, show_default=True)
299
+ @click.option('--grayscale', 'save_grayscale', type=bool, help='Use the first channel starting from `--starting-channel` to generate a grayscale image.', default=False, show_default=True)
300
+ @click.option('--rgb', 'save_rgb', type=bool, help='Use 3 consecutive channels (if they exist) to generate a RGB image, starting from `--starting-channel`.', default=False, show_default=True)
301
+ @click.option('--img-scale-db', 'img_scale_db', type=click.FloatRange(min=-40, max=40), help='Scale the image pixel values, akin to "exposure" (lower values give a grayer image; higher values give more white/burnt regions)', default=0, show_default=True)
302
+ @click.option('--img-normalize', 'img_normalize', type=bool, help='Normalize images of the selected layer and channel', default=False, show_default=True)
303
+ # Video options
304
+ @click.option('--grid-width', '-gw', type=click.IntRange(min=1), help='Video grid width / number of columns', default=None, show_default=True)
305
+ @click.option('--grid-height', '-gh', type=click.IntRange(min=1), help='Video grid height / number of rows', default=None, show_default=True)
306
+ @click.option('--slowdown', type=gen_utils.parse_slowdown, help='Slow down the video by this amount; will be approximated to the nearest power of 2', default='1', show_default=True)
307
+ @click.option('--duration-sec', '-sec', type=float, help='Duration length of the video', default=30.0, show_default=True)
308
+ @click.option('--fps', type=click.IntRange(min=1), help='Video FPS.', default=30, show_default=True)
309
+ @click.option('--compress', is_flag=True, help='Add flag to compress the final mp4 file with ffmpeg-python (same resolution, lower file size)')
310
+ # Extra parameters for saving the results
311
+ @click.option('--outdir', type=click.Path(file_okay=False), help='Directory path to save the results', default=os.path.join(os.getcwd(), 'out', 'video'), show_default=True, metavar='DIR')
312
+ @click.option('--description', '-desc', type=str, help='Description name for the directory path to save results')
313
+ def random_interpolation_video(
314
+ ctx: click.Context,
315
+ network_pkl: Union[str, os.PathLike],
316
+ cfg: Optional[str],
317
+ seeds: List[int],
318
+ truncation_psi: Optional[float],
319
+ new_center: Tuple[str, Union[int, np.ndarray]],
320
+ class_idx: Optional[int],
321
+ noise_mode: Optional[str],
322
+ anchor_latent_space: Optional[bool],
323
+ layer_name: Optional[str],
324
+ available_layers: Optional[bool],
325
+ starting_channel: Optional[int],
326
+ save_grayscale: Optional[bool],
327
+ save_rgb: Optional[bool],
328
+ img_scale_db: Optional[float],
329
+ img_normalize: Optional[bool],
330
+ grid_width: int,
331
+ grid_height: int,
332
+ slowdown: Optional[int],
333
+ duration_sec: Optional[float],
334
+ fps: int,
335
+ outdir: Union[str, os.PathLike],
336
+ description: str,
337
+ compress: bool,
338
+ smoothing_sec: Optional[float] = 3.0 # for Gaussian blur; won't be a command-line parameter, change at own risk
339
+ ):
340
+ """
341
+ Generate a random interpolation video using a pretrained network.
342
+
343
+ Examples:
344
+
345
+ \b
346
+ # Generate a 30-second long, untruncated MetFaces video at 30 FPS (3 rows and 2 columns; horizontal):
347
+ python generate.py random-video --seeds=0-5 \\
348
+ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
349
+
350
+ \b
351
+ # Generate a 60-second long, truncated 1x2 MetFaces video at 60 FPS (2 rows and 1 column; vertical):
352
+ python generate.py random-video --trunc=0.7 --seeds=10,20 --grid-width=1 --grid-height=2 \\
353
+ --fps=60 -sec=60 --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
354
+
355
+ """
356
+ # Sanity check
357
+ if len(seeds) < 1:
358
+ ctx.fail('Use `--seeds` to specify at least one seed.')
359
+
360
+ device = torch.device('cuda')
361
+
362
+ # Load the network
363
+ G = gen_utils.load_network('G_ema', network_pkl, cfg, device)
364
+
365
+ # Print the available layers in the model
366
+ if available_layers:
367
+ click.secho(f'Printing available layers (name, channels and size) for "{network_pkl}"...', fg='blue')
368
+ _ = Renderer().render(G=G, available_layers=available_layers)
369
+ sys.exit(0) # listing layers is an informational exit, not an error
370
+
371
+ # Sadly, render can only generate one image at a time, so for now we'll just use the first seed
372
+ if layer_name is not None and len(seeds) > 1:
373
+ print(f'Note: Only one seed is supported for layer extraction, using seed "{seeds[0]}"...')
374
+ seeds = seeds[:1]
375
+
376
+ # Stabilize/anchor the latent space
377
+ if anchor_latent_space:
378
+ gen_utils.anchor_latent_space(G)
379
+
380
+ # Create the run dir with the given name description; add slowdown if different from the default (1)
381
+ desc = 'random-video'
382
+ desc = f'random-video-{description}' if description is not None else desc
383
+ desc = f'{desc}-{slowdown}xslowdown' if slowdown != 1 else desc
384
+ desc = f'{desc}-{layer_name}_layer' if layer_name is not None else desc
385
+ run_dir = gen_utils.make_run_dir(outdir, desc)
386
+
387
+ # Number of frames in the video and its total duration in seconds
388
+ num_frames = int(np.rint(duration_sec * fps))
389
+ total_duration = duration_sec * slowdown
390
+
391
+ print('Generating latent vectors...')
392
+ # TODO: let another helper function handle each case, we will use it for the grid
393
+ # If there's more than one seed provided and the shape isn't specified by the user
394
+ if (grid_width is None and grid_height is None) and len(seeds) >= 1:
395
+ # TODO: this can be done by another function
396
+ # Number of images in the grid video according to the seeds provided
397
+ num_seeds = len(seeds)
398
+ # Get the grid width and height according to num, giving priority to the number of columns
399
+ grid_width = max(int(np.ceil(np.sqrt(num_seeds))), 1)
400
+ grid_height = max((num_seeds - 1) // grid_width + 1, 1)
401
+ grid_size = (grid_width, grid_height)
402
+ shape = [num_frames, G.z_dim] # This is per seed
403
+ # Get the z latents
404
+ all_latents = np.stack([np.random.RandomState(seed).randn(*shape).astype(np.float32) for seed in seeds], axis=1)
405
+
406
+ # If only one seed is provided, but the user specifies the grid shape:
407
+ elif None not in (grid_width, grid_height) and len(seeds) == 1:
408
+ grid_size = (grid_width, grid_height)
409
+ shape = [num_frames, np.prod(grid_size), G.z_dim]
410
+ # Since we have one seed, we use it to generate all latents
411
+ all_latents = np.random.RandomState(*seeds).randn(*shape).astype(np.float32)
412
+
413
+ # If one or more seeds are provided, and the user also specifies the grid shape:
414
+ elif None not in (grid_width, grid_height) and len(seeds) >= 1:
415
+ # Case is similar to the first one
416
+ num_seeds = len(seeds)
417
+ grid_size = (grid_width, grid_height)
418
+ available_slots = np.prod(grid_size)
419
+ if available_slots < num_seeds:
420
+ diff = num_seeds - available_slots
421
+ click.secho(f'More seeds were provided ({num_seeds}) than available spaces in the grid ({available_slots})',
422
+ fg='red')
423
+ click.secho(f'Removing the last {diff} seeds: {seeds[-diff:]}', fg='blue')
424
+ seeds = seeds[:available_slots]
425
+ shape = [num_frames, G.z_dim]
426
+ all_latents = np.stack([np.random.RandomState(seed).randn(*shape).astype(np.float32) for seed in seeds], axis=1)
427
+
428
+ else:
429
+ ctx.fail('Error: wrong combination of arguments! Please provide either a list of seeds, one seed and the grid '
430
+ 'width and height, or more than one seed and the grid width and height')
431
+
432
+ # Let's smooth out the random latents so that now they form a loop (and are correctly generated in a 512-dim space)
433
+ all_latents = scipy.ndimage.gaussian_filter(all_latents, sigma=[smoothing_sec * fps, 0, 0], mode='wrap')
434
+ all_latents /= np.sqrt(np.mean(np.square(all_latents)))
435
+
436
+ # Name of the video
437
+ mp4_name = f'{grid_width}x{grid_height}-slerp-{slowdown}xslowdown'
438
+
439
+ # Labels.
440
+ class_idx = gen_utils.parse_class(G, class_idx, ctx)
441
+ label = torch.zeros([1, G.c_dim], device=device)
442
+ if G.c_dim != 0:
443
+ label[:, class_idx] = 1
444
+ else:
445
+ if class_idx is not None:
446
+ print('warn: --class=lbl ignored when running on an unconditional network')
447
+
448
+ # Slow down the video, if so desired
449
+ while slowdown > 1:
450
+ all_latents, duration_sec, num_frames = gen_utils.double_slowdown(latents=all_latents,
451
+ duration=duration_sec,
452
+ frames=num_frames)
453
+ slowdown //= 2
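+ # Each pass doubles the frame count and halves the remaining slowdown, so the
+ # effective slowdown is a power of 2 (as noted in the --slowdown help above)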
454
+
455
+ if new_center is None:
456
+ w_avg = G.mapping.w_avg
457
+ else:
458
+ new_center, new_center_value = new_center
459
+ # We get the new center using the int (a seed) or recovered dlatent (an np.ndarray)
460
+ if isinstance(new_center_value, int):
461
+ w_avg = gen_utils.get_w_from_seed(G, device, new_center_value,
462
+ truncation_psi=1.0) # We want the pure dlatent
463
+ elif isinstance(new_center_value, np.ndarray):
464
+ w_avg = torch.from_numpy(new_center_value).to(device)
465
+ else:
466
+ ctx.fail('Error: New center has strange format! Only an int (seed) or a file (.npy/.npz) are accepted!')
467
+
468
+ # Auxiliary function for moviepy
469
+ def make_frame(t):
470
+ frame_idx = int(np.clip(np.round(t * fps), 0, num_frames - 1))
471
+ latents = torch.from_numpy(all_latents[frame_idx]).to(device)
472
+ # Do the truncation trick (with the global centroid or the new center provided by the user)
473
+ w = G.mapping(latents, None)
474
+ w = w_avg + (w - w_avg) * truncation_psi
475
+
476
+ # Get the images
477
+
478
+ # Save the intermediate layer output.
479
+ if layer_name is not None:
480
+ # Sanity check (again, could be done better)
481
+ submodule_names = {name: mod for name, mod in G.synthesis.named_modules()}
482
+ assert layer_name in submodule_names, f'Layer "{layer_name}" not found in the network! Available layers: {", ".join(submodule_names)}'
483
+ assert any((save_grayscale, save_rgb)), 'You must choose at least one of the two output formats! (L, RGB)'
484
+
485
+ sel_channels = 3 if save_rgb else 1
486
+ res = Renderer().render(G=G, layer_name=layer_name, dlatent=w, sel_channels=sel_channels,
487
+ base_channel=starting_channel, img_scale_db=img_scale_db, img_normalize=img_normalize)
488
+ images = res.image
489
+ images = np.expand_dims(np.array(images), axis=0)
490
+ else:
491
+ images = gen_utils.w_to_img(G, w, noise_mode) # Remember, it can only be a single image
492
+ # RGBA -> RGB, if necessary
493
+ images = images[:, :, :, :3]
494
+
495
+ # Generate the grid for this timestamp
496
+ grid = gen_utils.create_image_grid(images, grid_size)
497
+ # moviepy.editor.VideoClip expects 3 channels
498
+ if grid.shape[2] == 1:
499
+ grid = grid.repeat(3, 2)
500
+ return grid
501
+
502
+ # Generate video using the respective make_frame function
503
+ videoclip = moviepy.editor.VideoClip(make_frame, duration=duration_sec)
504
+ videoclip = videoclip.set_duration(total_duration) # set_duration returns a new clip; keep it
505
+
506
+ mp4_name = f'{mp4_name}_{layer_name}' if layer_name is not None else mp4_name
507
+
508
+ # Change the video parameters (codec, bitrate) if you so desire
509
+ final_video = os.path.join(run_dir, f'{mp4_name}.mp4')
510
+ videoclip.write_videofile(final_video, fps=fps, codec='libx264', bitrate='16M')
511
+
512
+ # Save the configuration used
513
+ new_center = 'w_avg' if new_center is None else new_center
514
+ ctx.obj = {
515
+ 'network_pkl': network_pkl,
516
+ 'config': cfg,
517
+ 'synthesis_options': {
518
+ 'seeds': seeds,
519
+ 'truncation_psi': truncation_psi,
520
+ 'new_center': new_center,
521
+ 'class_idx': class_idx,
522
+ 'noise_mode': noise_mode,
523
+ 'anchor_latent_space': anchor_latent_space
524
+ },
525
+ 'intermediate_representations': {
526
+ 'layer': layer_name,
527
+ 'starting_channel': starting_channel,
528
+ 'grayscale': save_grayscale,
529
+ 'rgb': save_rgb,
530
+ 'img_scale_db': img_scale_db,
531
+ 'img_normalize': img_normalize
532
+ },
533
+ 'video_options': {
534
+ 'grid_width': grid_width,
535
+ 'grid_height': grid_height,
536
+ 'slowdown': slowdown,
537
+ 'duration_sec': duration_sec,
538
+ 'video_fps': fps,
539
+ 'compress': compress,
540
+ 'smoothing_sec': smoothing_sec
541
+ },
542
+ 'extra_parameters': {
543
+ 'run_dir': run_dir,
544
+ 'description': desc
545
+ }
546
+ }
547
+ gen_utils.save_config(ctx=ctx, run_dir=run_dir)
548
+
549
+ # Compress the video (lower file size, same resolution)
550
+ if compress:
551
+ gen_utils.compress_video(original_video=final_video, original_video_name=mp4_name, outdir=run_dir, ctx=ctx)
552
+
553
+
554
+ # ----------------------------------------------------------------------------
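How the looping latents above are built, distilled (a sketch; shapes and the seed are illustrative): i.i.d. Gaussian frames are blurred along the time axis with mode='wrap' so the last frame blends into the first, then renormalized to unit RMS so the mapping network sees typically-scaled z vectors.

import numpy as np
import scipy.ndimage

fps, duration_sec, smoothing_sec, z_dim = 30, 10.0, 3.0, 512
num_frames = int(np.rint(duration_sec * fps))
z = np.random.RandomState(0).randn(num_frames, 1, z_dim).astype(np.float32)
z = scipy.ndimage.gaussian_filter(z, sigma=[smoothing_sec * fps, 0, 0], mode='wrap')
z /= np.sqrt(np.mean(np.square(z)))  # unit RMS, like a typical N(0, I) sample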
555
+
556
+
557
+ @main.command('circular-video')
558
+ @click.pass_context
559
+ @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
560
+ @click.option('--cfg', type=click.Choice(gen_utils.available_cfgs), help='Config of the network, used only if you want to use the pretrained models in torch_utils.gen_utils.resume_specs')
561
+ # Synthesis options
562
+ @click.option('--seed', type=int, help='Random seed', required=True)
563
+ @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
564
+ @click.option('--trunc-start', 'truncation_psi_start', type=float, help='Initial value of pulsating truncation psi', default=None, show_default=True)
565
+ @click.option('--trunc-end', 'truncation_psi_end', type=float, help='Maximum/minimum value of pulsating truncation psi', default=None, show_default=True)
566
+ @click.option('--global-pulse', 'global_pulsation_trick', is_flag=True, help='If set, the truncation psi will pulsate globally (on all grid cells)')
567
+ @click.option('--wave-pulse', 'wave_pulsation_trick', is_flag=True, help='If set, the truncation psi will pulsate in a wave-like fashion from the upper left to the lower right in the grid')
568
+ @click.option('--frequency', 'pulsation_frequency', type=int, help='Frequency of the pulsation', default=1, show_default=True)
569
+ @click.option('--new-center', type=str, help='New center for the W latent space; a seed (int) or a path to a projected dlatent (.npy/.npz)', default=None)
570
+ @click.option('--new-w-avg', 'new_w_avg', type=gen_utils.parse_new_center, help='Path to a new "global" w_avg (seed or .npy/.npz file) to be used in the truncation trick', default=None)
571
+ @click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
572
+ @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
573
+ @click.option('--anchor-latent-space', '-anchor', is_flag=True, help='Anchor the latent space to w_avg to stabilize the video')
574
+ @click.option('--flesh', 'aydao_flesh_digression', is_flag=True, help='If set, we will slowly modify the constant input to the network (based on @aydao\'s work)')
575
+ # Video options
576
+ @click.option('--grid-width', '-gw', type=click.IntRange(min=1), help='Video grid width / number of columns', required=True)
577
+ @click.option('--grid-height', '-gh', type=click.IntRange(min=1), help='Video grid height / number of rows', required=True)
578
+ @click.option('--duration-sec', '-sec', type=float, help='Duration length of the video', default=10.0, show_default=True)
579
+ @click.option('--fps', type=click.IntRange(min=1), help='Video FPS.', default=30, show_default=True)
580
+ @click.option('--compress', is_flag=True, help='Add flag to compress the final mp4 file with ffmpeg-python (same resolution, lower file size)')
581
+ # Extra parameters for saving the results
582
+ @click.option('--outdir', type=click.Path(file_okay=False), help='Directory path to save the results', default=os.path.join(os.getcwd(), 'out', 'video'), show_default=True, metavar='DIR')
583
+ @click.option('--description', '-desc', type=str, help='Description name for the directory path to save results')
584
+ def circular_video(
585
+ ctx: click.Context,
586
+ network_pkl: Union[str, os.PathLike],
587
+ cfg: Optional[str],
588
+ seed: int,
589
+ truncation_psi: Optional[float],
590
+ truncation_psi_start: Optional[float],
591
+ truncation_psi_end: Optional[float],
592
+ global_pulsation_trick: Optional[bool],
593
+ wave_pulsation_trick: Optional[bool],
594
+ pulsation_frequency: Optional[int],
595
+ new_center: Optional[str],
596
+ new_w_avg: Optional[Union[str, os.PathLike]],
597
+ class_idx: Optional[int],
598
+ noise_mode: Optional[str],
599
+ anchor_latent_space: Optional[bool],
600
+ aydao_flesh_digression: Optional[bool],
601
+ grid_width: int,
602
+ grid_height: int,
603
+ duration_sec: float,
604
+ fps: int,
605
+ compress: Optional[bool],
606
+ outdir: Union[str, os.PathLike],
607
+ description: str
608
+ ):
609
+ """
610
+ Generate a circular interpolation video in two random axes of Z, given a seed
611
+ """
612
+
613
+ device = torch.device('cuda')
614
+
615
+ # Load the network
616
+ G = gen_utils.load_network('G_ema', network_pkl, cfg, device)
617
+
618
+ # Get the constant input
619
+ if aydao_flesh_digression:
620
+ if hasattr(G.synthesis, 'b4'):
621
+ model_type = 'stylegan2'
622
+ const_input = copy.deepcopy(G.synthesis.b4.const).cpu().numpy()
623
+ elif hasattr(G.synthesis, 'input'):
624
+ model_type = 'stylegan3'
625
+ input_frequencies = copy.deepcopy(G.synthesis.input.freqs).cpu().numpy()
626
+ input_phases = copy.deepcopy(G.synthesis.input.phases).cpu().numpy()
627
+ else:
628
+ ctx.fail('Error: This option is only available for StyleGAN2 and StyleGAN3 models!')
629
+
630
+ # Get the labels, if the model is conditional
631
+ class_idx = gen_utils.parse_class(G, class_idx, ctx)
632
+ label = torch.zeros([1, G.c_dim], device=device)
633
+ if G.c_dim != 0:
634
+ label[:, class_idx] = 1
635
+ else:
636
+ if class_idx is not None:
637
+ print('warn: --class=lbl ignored when running on an unconditional network')
638
+
639
+ # Get center of the latent space (global or user-indicated)
640
+ if new_center is None:
641
+ w_avg = G.mapping.w_avg
642
+ w_avg = w_avg.view(1, 1, -1) # [w_dim] => [1, 1, w_dim]
643
+ else:
644
+ # It's an int, so use as a seed
645
+ if new_center.isdigit():
646
+ w_avg = gen_utils.get_w_from_seed(G, device, int(new_center), truncation_psi=1.0).to(device)
647
+ # It's a file, so load it
648
+ elif os.path.isfile(new_center):
649
+ w_avg = gen_utils.get_latent_from_file(new_center, return_ext=False)
650
+ w_avg = torch.from_numpy(w_avg).to(device)
651
+ # It's a directory, so get all latents inside it (including subdirectories, so be careful)
652
+ elif os.path.isdir(new_center):
653
+ w_avg = gen_utils.parse_all_projected_dlatents(new_center)
654
+ w_avg = torch.tensor(w_avg).squeeze(1).to(device)
655
+ else:
656
+ message = 'Only seeds (int) or paths to latent files (.npy/.npz) or directories containing these are allowed for "--new-center"'
657
+ raise ctx.fail(message)
658
+
659
+ # Some sanity checks
660
+ num_centers = len(w_avg)
661
+ if num_centers == 0:
662
+ raise ctx.fail('No centers were found! If files, makes sure they are .npy or .npz files.')
663
+ # Just one is provided, so this will be a sort of 'global' center
664
+ elif num_centers == 1:
665
+ print(f'Using only one center (if more than one is desired, provide a directory with all of them)')
666
+ elif num_centers != grid_height * grid_width:
667
+ message = f"Number of centers ({num_centers}) doesn't match the grid size ({grid_height}x{grid_width})"
668
+ raise ctx.fail(message)
669
+
670
+ print('Using wave pulsation trick' if wave_pulsation_trick else 'Using global pulsation trick' if global_pulsation_trick else 'Using standard truncation trick...')
671
+ # Stabilize/anchor the latent space
672
+ if anchor_latent_space:
673
+ gen_utils.anchor_latent_space(G)
674
+
675
+ # Create the run dir with the given name description; add slowdown if different from the default (1)
676
+ desc = 'circular-video'
677
+ desc = f'circular-video-{description}' if description is not None else desc
678
+ desc = f'{desc}-aydao-flesh-digression' if aydao_flesh_digression else desc
679
+ run_dir = gen_utils.make_run_dir(outdir, desc)
680
+
681
+ # Calculate the total number of frames in the video
682
+ num_frames = int(np.rint(duration_sec * fps))
683
+
684
+ grid_size = (grid_width, grid_height)
685
+ # Get the latents with the random state
686
+ random_state = np.random.RandomState(seed)
687
+ # Choose two random dims on which to plot the circles (from 0 to G.z_dim-1),
688
+ # one pair for each element of the grid (2*grid_width*grid_height in total)
689
+ try:
690
+ z1, z2 = np.split(random_state.choice(G.z_dim, 2 * np.prod(grid_size), replace=False), 2)
691
+ except ValueError:
692
+ # Extreme case: G.z_dim < 2 * grid_width * grid_height (low G.z_dim most likely)
693
+ z1, z2 = np.split(random_state.choice(G.z_dim, 2 * np.prod(grid_size), replace=True), 2)
694
+
695
+ # We partition the circle in equal strides w.r.t. num_frames
696
+ get_angles = lambda num_frames: np.linspace(0, 2*np.pi, num_frames)
697
+ angles = get_angles(num_frames=num_frames)
698
+
699
+ # Basic Polar to Cartesian transformation
700
+ polar_to_cartesian = lambda radius, theta: (radius * np.cos(theta), radius * np.sin(theta))
701
+ # Using a fixed radius (this value is irrelevant), we generate the circles in each chosen grid
702
+ Z1, Z2 = polar_to_cartesian(radius=5.0, theta=angles)
703
+
704
+ # Our latents will be comprising mostly of zeros
705
+ all_latents = np.zeros([num_frames, np.prod(grid_size), G.z_dim]).astype(np.float32)
706
+ # Obtain all the frames belonging to the specific box in the grid,
707
+ # replacing the zero values with the circle perimeter values
708
+ for box in range(np.prod(grid_size)):
709
+ box_frames = all_latents[:, box]
710
+ box_frames[:, [z1[box], z2[box]]] = np.vstack((Z1, Z2)).T
711
+
712
+ if aydao_flesh_digression:
713
+ # We will modify the constant input to the network (for --cfg=stylegan2)
714
+ if model_type == 'stylegan2':
715
+ const_input_interpolation = np.random.randn(num_frames, *const_input.shape).astype(np.float32) / 4 # [num_frames, G.w_dim, 4, 4] ; "/ 4" is arbitrary
716
+ const_input_interpolation = scipy.ndimage.gaussian_filter(const_input_interpolation, sigma=[fps, 0, 0, 0], mode='wrap')
717
+ const_input_interpolation /= np.sqrt(np.mean(np.square(const_input_interpolation)))
718
+ elif model_type == 'stylegan3':
719
+ const_freq_interpolation = np.random.randn(num_frames, *input_frequencies.shape).astype(np.float32) / 32 # [num_frames, G.w_dim, 2]
720
+ const_freq_interpolation = scipy.ndimage.gaussian_filter(const_freq_interpolation, sigma=[5.0*fps, 0, 0], mode='wrap')
721
+ const_freq_interpolation /= np.sqrt(np.mean(np.square(const_freq_interpolation)))
722
+
723
+ const_phase_interpolation = np.random.randn(num_frames, *input_phases.shape).astype(np.float32) / 8 # [num_frames, G.w_dim, 2]
724
+ const_phase_interpolation = scipy.ndimage.gaussian_filter(const_phase_interpolation, sigma=[5.0*fps, 0], mode='wrap')
725
+ const_phase_interpolation /= np.sqrt(np.mean(np.square(const_phase_interpolation)))
726
+
727
+ # Move the latent space toward a user-supplied center, if one was given
728
+ if new_w_avg is not None:
729
+ print("Moving all the latent space towards the new center...")
730
+ _, new_w_avg = new_w_avg
731
+ # We get the new center using the int (a seed) or recovered dlatent (an np.ndarray)
732
+ if isinstance(new_w_avg, int):
733
+ new_w_avg = gen_utils.get_w_from_seed(G, device, new_w_avg,
734
+ truncation_psi=1.0) # We want the pure dlatent
735
+ elif isinstance(new_w_avg, np.ndarray):
736
+ new_w_avg = torch.from_numpy(new_w_avg).to(device) # [1, num_ws, w_dim]
737
+ else:
738
+ ctx.fail('Error: New center has strange format! Only an int (seed) or a file (.npy/.npz) are accepted!')
739
+
740
+ # Auxiliary function for moviepy
741
+ def make_frame(t):
742
+ frame_idx = int(np.clip(np.round(t * fps), 0, num_frames - 1))
743
+ latents = torch.from_numpy(all_latents[frame_idx]).to(device)
744
+ # Get the images with the respective label
745
+ dlatents = gen_utils.z_to_dlatent(G, latents, label, truncation_psi=1.0) # Get the pure dlatent
746
+ # Do truncation trick
747
+ # For the truncation trick (supersedes any value chosen for truncation_psi)
748
+ if None not in (truncation_psi_start, truncation_psi_end):
749
+ # For both, truncation psi will have the general form of a sinusoid: psi = (cos(t) + alpha) / beta
750
+ if global_pulsation_trick:
751
+ tr = gen_utils.global_pulsate_psi(psi_start=truncation_psi_start,
752
+ psi_end=truncation_psi_end,
753
+ n_steps=num_frames)
754
+ elif wave_pulsation_trick:
755
+ tr = gen_utils.wave_pulse_truncation_psi(psi_start=truncation_psi_start,
756
+ psi_end=truncation_psi_end,
757
+ n_steps=num_frames,
758
+ grid_shape=grid_size,
759
+ frequency=pulsation_frequency,
760
+ time=frame_idx)
761
+ # Define how to use the truncation psi
762
+ if global_pulsation_trick:
763
+ tr = tr[frame_idx].to(device)
764
+ elif wave_pulsation_trick:
765
+ tr = tr.to(device)
766
+ else:
767
+ # It's a float, so we can just use it
768
+ tr = truncation_psi
769
+
770
+ w = w_avg + (dlatents - w_avg) * tr
771
+ # Modify the constant input
772
+ if aydao_flesh_digression:
773
+ if model_type == 'stylegan2':
774
+ G.synthesis.b4.const.copy_(torch.from_numpy(const_input_interpolation[frame_idx]))
775
+ elif model_type == 'stylegan3':
776
+ pass
777
+ # G.synthesis.input.freqs.copy_(torch.from_numpy(const_freq_interpolation[frame_idx]))
778
+ # G.synthesis.input.phases.copy_(torch.from_numpy(const_phase_interpolation[frame_idx]))
779
+ # G.synthesis.input.phases.copy_(torch.from_numpy(
780
+ # input_phases * np.cos(np.pi * frame_idx / num_frames) ** 2
781
+ # ))
782
+ # Get the images
783
+ images = gen_utils.w_to_img(G, w, noise_mode, new_w_avg, tr)
784
+ # RGBA -> RGB
785
+ images = images[:, :, :, :3]
786
+ # Generate the grid for this timestep
787
+ grid = gen_utils.create_image_grid(images, grid_size)
788
+ # Grayscale => RGB
789
+ if grid.shape[2] == 1:
790
+ grid = grid.repeat(3, 2)
791
+ return grid
792
+
793
+ # Generate video using the respective make_frame function
794
+ videoclip = moviepy.editor.VideoClip(make_frame, duration=duration_sec)
795
+ videoclip = videoclip.set_duration(duration_sec) # set_duration returns a new clip; keep it
796
+
797
+ # Name of the video
798
+ mp4_name = f'{grid_width}x{grid_height}-circular'
799
+
800
+ # Change the video parameters (codec, bitrate) if you so desire
801
+ final_video = os.path.join(run_dir, f'{mp4_name}.mp4')
802
+ videoclip.write_videofile(final_video, fps=fps, codec='libx264', bitrate='16M')
803
+
804
+ # Save the configuration used
805
+ new_center = 'w_avg' if new_center is None else new_center
806
+ ctx.obj = {
807
+ 'network_pkl': network_pkl,
808
+ 'config': cfg,
809
+ 'seed': seed,
810
+ 'z1, z2': [[int(i), int(j)] for i, j in zip(z1, z2)],
811
+ 'truncation_psi': truncation_psi if isinstance(truncation_psi, float) else 'pulsating',
812
+ 'truncation_psi_start': truncation_psi_start,
813
+ 'truncation_psi_end': truncation_psi_end,
814
+ 'new_center': new_center,
815
+ 'class_idx': class_idx,
816
+ 'noise_mode': noise_mode,
817
+ 'grid_width': grid_width,
818
+ 'grid_height': grid_height,
819
+ 'duration_sec': duration_sec,
820
+ 'video_fps': fps,
821
+ 'run_dir': run_dir,
822
+ 'description': desc,
823
+ 'compress': compress
824
+ }
825
+ gen_utils.save_config(ctx=ctx, run_dir=run_dir)
826
+
827
+ # Compress the video (lower file size, same resolution)
828
+ if compress:
829
+ gen_utils.compress_video(original_video=final_video, original_video_name=mp4_name, outdir=run_dir, ctx=ctx)
830
+
831
+ # ----------------------------------------------------------------------------
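The circular path above, distilled (a sketch; the dimension pair is illustrative): each grid cell traces a circle of fixed radius in a random 2-D slice of Z while every other coordinate stays at zero, so the video loops back to its starting frame.

import numpy as np

num_frames, z_dim, radius = 300, 512, 5.0
theta = np.linspace(0, 2 * np.pi, num_frames)   # one full revolution
z1, z2 = 7, 42                                  # illustrative dimension pair
latents = np.zeros((num_frames, z_dim), dtype=np.float32)
latents[:, z1] = radius * np.cos(theta)
latents[:, z2] = radius * np.sin(theta)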
832
+
833
+
834
+ if __name__ == "__main__":
835
+ main() # pylint: disable=no-value-for-parameter
836
+
837
+
838
+ # ----------------------------------------------------------------------------
stylegan3-fun/gui_utils/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ # empty
stylegan3-fun/gui_utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (199 Bytes). View file
 
stylegan3-fun/gui_utils/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (181 Bytes). View file
 
stylegan3-fun/gui_utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (181 Bytes). View file
 
stylegan3-fun/gui_utils/__pycache__/gl_utils.cpython-311.pyc ADDED
Binary file (26.9 kB). View file
 
stylegan3-fun/gui_utils/__pycache__/gl_utils.cpython-38.pyc ADDED
Binary file (12.1 kB). View file
 
stylegan3-fun/gui_utils/__pycache__/gl_utils.cpython-39.pyc ADDED
Binary file (12.1 kB). View file
 
stylegan3-fun/gui_utils/__pycache__/glfw_window.cpython-311.pyc ADDED
Binary file (13.2 kB). View file
 
stylegan3-fun/gui_utils/__pycache__/glfw_window.cpython-38.pyc ADDED
Binary file (7.8 kB). View file
 
stylegan3-fun/gui_utils/__pycache__/glfw_window.cpython-39.pyc ADDED
Binary file (7.78 kB). View file
 
stylegan3-fun/gui_utils/__pycache__/imgui_utils.cpython-311.pyc ADDED
Binary file (9.73 kB). View file
 
stylegan3-fun/gui_utils/__pycache__/imgui_utils.cpython-38.pyc ADDED
Binary file (5.1 kB). View file
 
stylegan3-fun/gui_utils/__pycache__/imgui_utils.cpython-39.pyc ADDED
Binary file (5.19 kB). View file